code2-repo / test_setup.py
Deepu1965's picture
Upload folder using huggingface_hub
9b1c753 verified
#!/usr/bin/env python3
"""
Quick Test Script - Verify Installation and Setup
Tests basic functionality without full training
"""
import sys
import os
def test_imports(*, packages=None, optional_packages=None):
    """Test if all required packages can be imported.

    Each package is imported by module name and one status line is printed
    for it. A required package that fails to import is recorded as an error;
    an optional package that fails only produces a warning.

    Failures other than ``ImportError`` (for example an installed package
    whose native library cannot be loaded and raises ``OSError``) are reported
    as well instead of propagating, so one broken install does not abort the
    remaining checks.

    Args:
        packages: Optional iterable of ``(module_name, display_name)`` pairs
            that must be importable. Defaults to the project's core
            dependencies (PyTorch, Transformers, scikit-learn, Pandas, NumPy).
        optional_packages: Optional iterable of ``(module_name, display_name)``
            pairs whose absence is tolerated. Defaults to Matplotlib and
            Seaborn.

    Returns:
        ``(passed, errors)``: ``errors`` lists the display names of required
        packages that could not be imported, and ``passed`` is True only when
        that list is empty.
    """
    import importlib

    if packages is None:
        # Core dependencies
        packages = [
            ('torch', 'PyTorch'),
            ('transformers', 'Transformers'),
            ('sklearn', 'scikit-learn'),
            ('pandas', 'Pandas'),
            ('numpy', 'NumPy'),
        ]
    if optional_packages is None:
        # Optional dependencies: a missing one only triggers a warning.
        optional_packages = [
            ('matplotlib', 'Matplotlib'),
            ('seaborn', 'Seaborn'),
        ]

    print("πŸ” Testing imports...")
    errors = []
    for package, name in packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f" ❌ {name} - NOT INSTALLED")
            errors.append(name)
        except Exception as e:
            # Installed but broken (e.g. OSError from a missing shared
            # library): report it rather than crash the whole diagnostic.
            print(f" ❌ {name} - IMPORT FAILED: {str(e)[:50]}")
            errors.append(name)
        else:
            print(f" βœ… {name}")

    print("\nπŸ“Š Optional packages:")
    for package, name in optional_packages:
        try:
            importlib.import_module(package)
        except ImportError:
            print(f" ⚠️ {name} - Not installed (visualizations will be skipped)")
        except Exception as e:
            print(f" ⚠️ {name} - Failed to import: {str(e)[:50]}")
        else:
            print(f" βœ… {name}")

    return len(errors) == 0, errors
def test_module_imports():
    """Try to import each of the project's own modules.

    Any exception raised while importing a module (not only ImportError)
    marks that module as failed. The first 50 characters of the exception
    message are printed next to its name.

    Returns:
        ``(passed, errors)`` where ``errors`` holds the names of the modules
        that failed and ``passed`` is True only when no module failed.
    """
    project_modules = (
        'config',
        'data_loader',
        'risk_discovery',
        'model',
        'trainer',
        'evaluator',
        'utils',
    )
    print("\nπŸ” Testing project modules...")

    failed = []
    for mod_name in project_modules:
        try:
            __import__(mod_name)
            print(f" βœ… {mod_name}.py")
        except Exception as exc:
            print(f" ❌ {mod_name}.py - {str(exc)[:50]}")
            failed.append(mod_name)

    return not failed, failed
def test_configuration():
    """Instantiate ``LegalBertConfig`` and print its key settings.

    Also reports whether ``data_path`` exists on disk. A missing dataset is
    only a warning and does not fail the check.

    Returns:
        ``(True, None)`` on success, or ``(False, message)`` when importing or
        reading the configuration raised an exception.
    """
    print("\nπŸ” Testing configuration...")
    # (label shown to the user, attribute read from the config object)
    shown_settings = (
        ("BERT Model", "bert_model_name"),
        ("Batch Size", "batch_size"),
        ("Epochs", "num_epochs"),
        ("Device", "device"),
        ("Data Path", "data_path"),
    )
    try:
        from config import LegalBertConfig

        cfg = LegalBertConfig()
        print(" βœ… Configuration loaded")
        for label, attr in shown_settings:
            print(f" - {label}: {getattr(cfg, attr)}")

        # The dataset location is checked here, but its absence is not fatal.
        if not os.path.exists(cfg.data_path):
            print(f" ⚠️ CUAD dataset NOT found at {cfg.data_path}")
            print(" Please download and place it at this location")
        else:
            print(f" βœ… CUAD dataset found at {cfg.data_path}")
        return True, None
    except Exception as exc:
        print(f" ❌ Configuration failed: {exc}")
        return False, str(exc)
def test_model_initialization():
    """Exercise the project's tokenizer wrapper on one sample clause.

    The model classes are imported, but no model is constructed. The printed
    message defers full initialization to ``train.py``.

    Returns:
        ``(True, None)`` on success, or ``(False, message)`` if any step
        raised an exception.
    """
    print("\nπŸ” Testing model initialization...")
    try:
        from config import LegalBertConfig
        # The model classes are not used below; presumably importing them acts
        # as a presence check for model.py's public API.
        from model import FullyLearningBasedLegalBERT, HierarchicalLegalBERT, LegalBertTokenizer

        model_name = LegalBertConfig().bert_model_name
        tokenizer = LegalBertTokenizer(model_name)

        sample_clause = "This agreement shall be governed by the laws of Delaware."
        batch = tokenizer.tokenize_clauses([sample_clause], max_length=128)
        print(" βœ… Tokenizer works")
        print(f" Input shape: {batch['input_ids'].shape}")

        # No model is built here, keeping the check quick.
        print(" ℹ️ Model initialization skipped (use train.py to fully initialize)")
        return True, None
    except Exception as exc:
        print(f" ❌ Model initialization failed: {exc}")
        return False, str(exc)
def test_data_loader():
    """Construct ``CUADDataLoader`` for the configured dataset, if it exists.

    When ``data_path`` does not exist, the check is skipped and counts as
    passing. Only the loader object is constructed; the printed message
    defers full loading to ``train.py``.

    Returns:
        ``(True, None)`` on success or skip, or ``(False, message)`` if any
        step raised an exception.
    """
    print("\nπŸ” Testing data loader...")
    try:
        from config import LegalBertConfig
        from data_loader import CUADDataLoader

        data_path = LegalBertConfig().data_path
        if os.path.exists(data_path):
            # Construction alone is the check; the instance is not used further.
            CUADDataLoader(data_path)
            print(" βœ… Data loader initialized")
            print(" ℹ️ Full data loading skipped (use train.py for full load)")
        else:
            print(" ⚠️ Data file not found, skipping data loader test")
        return True, None
    except Exception as exc:
        print(f" ❌ Data loader failed: {exc}")
        return False, str(exc)
def main():
    """Run every setup check in order and print an overall verdict.

    The two import checks get an extra summary line on failure. The remaining
    checks print their own diagnostics, so only their pass flag is used here.

    Returns:
        Process exit status: ``0`` when every check passed, ``1`` otherwise.
    """
    rule = "=" * 80
    print(rule)
    print("πŸ§ͺ LEGAL-BERT PROJECT - QUICK TEST")
    print(rule)

    outcomes = []

    # Third-party packages: name what is missing and how to install it.
    imports_ok, missing = test_imports()
    if not imports_ok:
        print(f"\n❌ Missing required packages: {', '.join(missing)}")
        print(" Install with: pip install -r requirements.txt")
    outcomes.append(imports_ok)

    # Project modules.
    modules_ok, broken = test_module_imports()
    if not modules_ok:
        print(f"\n❌ Module import errors: {', '.join(broken)}")
    outcomes.append(modules_ok)

    # Configuration, model and data loader checks, in that order.
    for check in (test_configuration, test_model_initialization, test_data_loader):
        check_ok, _ = check()
        outcomes.append(check_ok)

    all_passed = all(outcomes)

    # Summary
    print("\n" + rule)
    if all_passed:
        print("βœ… ALL TESTS PASSED!")
        print(rule)
        print("\nπŸš€ Ready to train! Run: python train.py")
    else:
        print("❌ SOME TESTS FAILED")
        print(rule)
        print("\n⚠️ Please fix the issues above before training")
    return 0 if all_passed else 1
# Script entry point: main()'s return value (0 = all checks passed, 1 = some
# failed) becomes the process exit status. Raising SystemExit directly is
# what sys.exit() does under the hood.
if __name__ == "__main__":
    raise SystemExit(main())