"""
Quick Test Script - Verify Installation and Setup

Tests basic functionality without full training.
"""
import importlib
import os
import sys
def test_imports(packages=None, optional_packages=None):
    """Check that required and optional packages can be imported.

    Args:
        packages: Iterable of ``(module_name, display_name)`` pairs that
            must import cleanly. ``None`` selects the built-in list
            (torch, transformers, sklearn, pandas, numpy).
        optional_packages: Iterable of ``(module_name, display_name)`` pairs
            whose failure is reported as a warning only. ``None`` selects
            the built-in list (matplotlib, seaborn).

    Returns:
        tuple: ``(passed, errors)`` where ``errors`` is the list of display
        names of required packages that could not be imported and
        ``passed`` is ``True`` when that list is empty.
    """
    if packages is None:
        packages = [
            ('torch', 'PyTorch'),
            ('transformers', 'Transformers'),
            ('sklearn', 'scikit-learn'),
            ('pandas', 'Pandas'),
            ('numpy', 'NumPy'),
        ]
    if optional_packages is None:
        optional_packages = [
            ('matplotlib', 'Matplotlib'),
            ('seaborn', 'Seaborn'),
        ]

    # Status tags are plain ASCII so the report renders on any console encoding.
    print("Testing imports...")

    errors = []
    for module_name, name in packages:
        failure = _import_failure(module_name)
        if failure is None:
            print(f"  [OK] {name}")
            continue
        if isinstance(failure, ImportError):
            print(f"  [FAIL] {name} - NOT INSTALLED")
        else:
            # The package was found but raised while loading; report it
            # instead of letting the whole checker crash.
            print(f"  [FAIL] {name} - FAILED TO LOAD: {str(failure)[:50]}")
        errors.append(name)

    print("\nOptional packages:")
    for module_name, name in optional_packages:
        failure = _import_failure(module_name)
        if failure is None:
            print(f"  [OK] {name}")
        elif isinstance(failure, ImportError):
            print(f"  [WARN] {name} - Not installed (visualizations will be skipped)")
        else:
            print(f"  [WARN] {name} - Failed to load (visualizations will be skipped)")

    return not errors, errors


def _import_failure(module_name):
    """Return the exception raised by importing *module_name*, or ``None``."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:
        return exc
    return None
def test_module_imports(modules=None):
    """Check that the project's own modules can be imported.

    Args:
        modules: Iterable of module names to import. ``None`` selects the
            built-in list (config, data_loader, risk_discovery, model,
            trainer, evaluator, utils).

    Returns:
        tuple: ``(passed, errors)`` where ``errors`` is the list of module
        names whose import raised and ``passed`` is ``True`` when that list
        is empty.
    """
    if modules is None:
        modules = [
            'config',
            'data_loader',
            'risk_discovery',
            'model',
            'trainer',
            'evaluator',
            'utils',
        ]

    print("\nTesting project modules...")

    errors = []
    for module in modules:
        try:
            importlib.import_module(module)
        except Exception as e:
            # Any failure while executing the module counts, not just a
            # missing file; the message is clipped to keep the report compact.
            print(f"  [FAIL] {module}.py - {str(e)[:50]}")
            errors.append(module)
        else:
            print(f"  [OK] {module}.py")

    return not errors, errors
def test_configuration():
    """Instantiate ``LegalBertConfig`` and print its key settings.

    Also reports whether anything exists at ``config.data_path``; a missing
    dataset is only a warning and does not fail this check.

    Returns:
        tuple: ``(True, None)`` on success, or ``(False, message)`` where
        ``message`` is the text of the exception raised while importing,
        constructing or reading the configuration.
    """
    print("\nTesting configuration...")

    try:
        from config import LegalBertConfig

        config = LegalBertConfig()

        print("  [OK] Configuration loaded")
        print(f"     - BERT Model: {config.bert_model_name}")
        print(f"     - Batch Size: {config.batch_size}")
        print(f"     - Epochs: {config.num_epochs}")
        print(f"     - Device: {config.device}")
        print(f"     - Data Path: {config.data_path}")

        if os.path.exists(config.data_path):
            print(f"  [OK] CUAD dataset found at {config.data_path}")
        else:
            print(f"  [WARN] CUAD dataset NOT found at {config.data_path}")
            print("     Please download and place it at this location")

        return True, None
    except Exception as e:
        print(f"  [FAIL] Configuration failed: {e}")
        return False, str(e)
def test_model_initialization():
    """Verify the model module imports and the tokenizer can encode text.

    Builds a ``LegalBertTokenizer`` from the configured BERT model name and
    tokenizes one sample clause. The model classes themselves are imported
    but deliberately not instantiated here.

    Returns:
        tuple: ``(True, None)`` on success, or ``(False, message)`` where
        ``message`` is the text of the exception that was raised.
    """
    print("\nTesting model initialization...")

    try:
        from config import LegalBertConfig
        # The two model classes are imported only to confirm they are
        # defined; a missing name makes this check fail.
        from model import (  # noqa: F401
            FullyLearningBasedLegalBERT,
            HierarchicalLegalBERT,
            LegalBertTokenizer,
        )

        config = LegalBertConfig()

        tokenizer = LegalBertTokenizer(config.bert_model_name)
        test_text = "This agreement shall be governed by the laws of Delaware."
        encoded = tokenizer.tokenize_clauses([test_text], max_length=128)
        print("  [OK] Tokenizer works")
        print(f"     Input shape: {encoded['input_ids'].shape}")

        print("  [INFO] Model initialization skipped (use train.py to fully initialize)")

        return True, None
    except Exception as e:
        print(f"  [FAIL] Model initialization failed: {e}")
        return False, str(e)
def test_data_loader():
    """Verify that ``CUADDataLoader`` can be constructed, if the data exists.

    When nothing exists at ``config.data_path`` the check is skipped and
    still counts as passed. Otherwise the loader is constructed with that
    path; no data is loaded here.

    Returns:
        tuple: ``(True, None)`` on success or skip, or ``(False, message)``
        where ``message`` is the text of the exception that was raised.
    """
    print("\nTesting data loader...")

    try:
        from config import LegalBertConfig
        from data_loader import CUADDataLoader

        config = LegalBertConfig()

        if not os.path.exists(config.data_path):
            print("  [WARN] Data file not found, skipping data loader test")
            return True, None

        # Only construction is exercised; the instance itself is not needed.
        CUADDataLoader(config.data_path)
        print("  [OK] Data loader initialized")

        print("  [INFO] Full data loading skipped (use train.py for full load)")

        return True, None
    except Exception as e:
        print(f"  [FAIL] Data loader failed: {e}")
        return False, str(e)
def main():
    """Run every quick check in order and print a summary.

    All checks are run even after one fails, so a single invocation reports
    every problem.

    Returns:
        int: ``0`` when all checks passed, ``1`` otherwise.
    """
    banner = "=" * 80
    print(banner)
    print("LEGAL-BERT PROJECT - QUICK TEST")
    print(banner)

    all_passed = True

    passed, errors = test_imports()
    if not passed:
        print(f"\n[FAIL] Missing required packages: {', '.join(errors)}")
        print("   Install with: pip install -r requirements.txt")
        all_passed = False

    passed, errors = test_module_imports()
    if not passed:
        print(f"\n[FAIL] Module import errors: {', '.join(errors)}")
        all_passed = False

    # These checks print their own failure details, so only the pass/fail
    # flag is consumed here.
    for check in (test_configuration, test_model_initialization, test_data_loader):
        passed, _ = check()
        if not passed:
            all_passed = False

    print("\n" + banner)
    if all_passed:
        print("[OK] ALL TESTS PASSED!")
        print(banner)
        print("\nReady to train! Run: python train.py")
    else:
        print("[FAIL] SOME TESTS FAILED")
        print(banner)
        print("\n[WARN] Please fix the issues above before training")

    return 0 if all_passed else 1
# Script entry point: hand main()'s status code to the interpreter as the
# process exit code.
if __name__ == "__main__":
    raise SystemExit(main())