File size: 5,924 Bytes
9b1c753 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
#!/usr/bin/env python3
"""
Quick Test Script - Verify Installation and Setup
Tests basic functionality without full training
"""
import importlib
import os
import sys
def test_imports():
"""Test if all required packages can be imported"""
print("π Testing imports...")
errors = []
# Core dependencies
packages = [
('torch', 'PyTorch'),
('transformers', 'Transformers'),
('sklearn', 'scikit-learn'),
('pandas', 'Pandas'),
('numpy', 'NumPy')
]
for package, name in packages:
try:
__import__(package)
print(f" β
{name}")
except ImportError:
print(f" β {name} - NOT INSTALLED")
errors.append(name)
# Optional dependencies
optional_packages = [
('matplotlib', 'Matplotlib'),
('seaborn', 'Seaborn')
]
print("\nπ Optional packages:")
for package, name in optional_packages:
try:
__import__(package)
print(f" β
{name}")
except ImportError:
print(f" β οΈ {name} - Not installed (visualizations will be skipped)")
return len(errors) == 0, errors
def test_module_imports(modules=None):
    """Check that the project's own top-level modules import without error.

    Any exception raised while importing a module (missing file, syntax error,
    a failing import inside the module, ...) counts as a failure. Its message
    is printed truncated to 50 characters.

    Args:
        modules: Iterable of module names to import. ``None`` uses the default
            project modules (config, data_loader, risk_discovery, model,
            trainer, evaluator, utils).

    Returns:
        ``(all_ok, failed_module_names)``.
    """
    if modules is None:
        modules = [
            'config',
            'data_loader',
            'risk_discovery',
            'model',
            'trainer',
            'evaluator',
            'utils',
        ]

    print("\n🔍 Testing project modules...")
    errors = []
    for module in modules:
        try:
            importlib.import_module(module)
        except Exception as e:  # deliberately broad: this is a diagnostic report
            print(f"  ❌ {module}.py - {str(e)[:50]}")
            errors.append(module)
        else:
            print(f"  ✅ {module}.py")
    return len(errors) == 0, errors
def test_configuration():
    """Instantiate ``LegalBertConfig`` and print its key settings.

    Also reports whether ``config.data_path`` exists. A missing dataset is
    printed as a warning only and does not fail this check.

    Returns:
        ``(True, None)`` on success, or ``(False, error_message)`` if importing
        or instantiating the configuration, or reading one of its attributes,
        raised.
    """
    print("\n🔍 Testing configuration...")
    try:
        from config import LegalBertConfig

        config = LegalBertConfig()
        print("  ✅ Configuration loaded")
        print(f"     - BERT Model: {config.bert_model_name}")
        print(f"     - Batch Size: {config.batch_size}")
        print(f"     - Epochs: {config.num_epochs}")
        print(f"     - Device: {config.device}")
        print(f"     - Data Path: {config.data_path}")

        # The config itself loaded fine; a missing dataset is only a warning.
        if os.path.exists(config.data_path):
            print(f"  ✅ CUAD dataset found at {config.data_path}")
        else:
            print(f"  ⚠️  CUAD dataset NOT found at {config.data_path}")
            print("     Please download and place it at this location")

        return True, None
    except Exception as e:  # diagnostic: report any failure instead of crashing
        print(f"  ❌ Configuration failed: {e}")
        return False, str(e)
def test_model_initialization():
    """Check that model.py exports its public classes and the tokenizer works.

    Only ``LegalBertTokenizer`` is exercised, on a single sample clause. Full
    model construction is deliberately skipped to keep this check quick; the
    printed message points to train.py for that.

    Returns:
        ``(True, None)`` on success, or ``(False, error_message)`` if an
        import, tokenizer construction, or tokenization raised.
    """
    print("\n🔍 Testing model initialization...")
    try:
        from config import LegalBertConfig
        # The two model classes are imported only to prove model.py defines them.
        from model import FullyLearningBasedLegalBERT, HierarchicalLegalBERT, LegalBertTokenizer  # noqa: F401

        config = LegalBertConfig()

        # Tokenizer round-trip on a representative clause.
        tokenizer = LegalBertTokenizer(config.bert_model_name)
        test_text = "This agreement shall be governed by the laws of Delaware."
        encoded = tokenizer.tokenize_clauses([test_text], max_length=128)
        print("  ✅ Tokenizer works")
        # NOTE(review): assumes encoded['input_ids'] is an array/tensor with .shape; confirm in model.py.
        print(f"     Input shape: {encoded['input_ids'].shape}")

        print("  ℹ️  Model initialization skipped (use train.py to fully initialize)")
        return True, None
    except Exception as e:  # diagnostic: report any failure instead of crashing
        print(f"  ❌ Model initialization failed: {e}")
        return False, str(e)
def test_data_loader():
    """Check that ``CUADDataLoader`` can be constructed for the configured data path.

    Skipped, and counted as passing, when ``config.data_path`` does not exist,
    so a missing dataset alone does not fail the quick test. The loader is only
    constructed; no explicit load call is made here.

    Returns:
        ``(True, None)`` on success or skip, or ``(False, error_message)`` if
        an import, the config, or the loader constructor raised.
    """
    print("\n🔍 Testing data loader...")
    try:
        from config import LegalBertConfig
        from data_loader import CUADDataLoader

        config = LegalBertConfig()

        if not os.path.exists(config.data_path):
            print("  ⚠️  Data file not found, skipping data loader test")
            return True, None

        # Constructing the loader is the whole check; the instance is not used further.
        CUADDataLoader(config.data_path)
        print("  ✅ Data loader initialized")

        print("  ℹ️  Full data loading skipped (use train.py for full load)")
        return True, None
    except Exception as e:  # diagnostic: report any failure instead of crashing
        print(f"  ❌ Data loader failed: {e}")
        return False, str(e)
def main():
    """Run every quick check, print a summary, and return an exit status.

    All checks run even after an earlier one fails, so a single run surfaces
    every problem at once.

    Returns:
        0 if every check passed, 1 otherwise (suitable as a process exit code).
    """
    print("=" * 80)
    print("🧪 LEGAL-BERT PROJECT - QUICK TEST")
    print("=" * 80)

    all_passed = True

    # Required third-party packages: also tell the user how to install them.
    passed, errors = test_imports()
    if not passed:
        print(f"\n❌ Missing required packages: {', '.join(errors)}")
        print("   Install with: pip install -r requirements.txt")
        all_passed = False

    # Project modules.
    passed, errors = test_module_imports()
    if not passed:
        print(f"\n❌ Module import errors: {', '.join(errors)}")
        all_passed = False

    # These checks print their own diagnostics; only the pass/fail flag is used here.
    for check in (test_configuration, test_model_initialization, test_data_loader):
        passed, _error = check()
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 80)
    if all_passed:
        print("✅ ALL TESTS PASSED!")
        print("=" * 80)
        print("\n🚀 Ready to train! Run: python train.py")
    else:
        print("❌ SOME TESTS FAILED")
        print("=" * 80)
        print("\n⚠️  Please fix the issues above before training")

    return 0 if all_passed else 1
if __name__ == "__main__":
    # Only run the checks when executed as a script; main()'s 0/1 becomes the exit status.
    raise SystemExit(main())
|