"""Smoke test: verify that Hierarchical BERT is the only model in the project.

Runs six independent checks (config attributes, model imports, trainer
import, model construction, presence of the forward/predict methods, trainer
construction) and prints one status line per check.

The process exits with status 0 only when every check passed; otherwise it
exits with status 1 and the success banner is not printed.
"""

import sys

_RULE = "=" * 70
_INDENT = "   "

# Status glyphs are spelled with \N{...} escapes so the source file stays
# pure ASCII and cannot be corrupted by a wrong-encoding round trip.
_PASS = "\N{WHITE HEAVY CHECK MARK}"
_FAIL = "\N{CROSS MARK}"
_TARGET = "\N{DIRECT HIT}"
_BULLET = "\N{BULLET}"
_NOTE = "\N{MEMO}"
_DOCS = "\N{OPEN BOOK}"


def _run_check(heading, ok_message, fail_prefix, check):
    """Run one check and print its outcome.

    Args:
        heading: Line announcing the check, e.g. "1. Testing config...".
        ok_message: Text printed after the pass glyph when ``check`` returns.
        fail_prefix: Text printed before " failed: <error>" when it raises.
        check: Zero-argument callable performing the check.

    Returns:
        ``(passed, result)`` where ``result`` is whatever ``check`` returned,
        or ``None`` when it raised.
    """
    print(f"\n{heading}")
    try:
        result = check()
    except Exception as e:
        # Broad on purpose: any failure (ImportError, missing attribute,
        # constructor error) is reported and the remaining checks still run.
        print(f"{_INDENT}{_FAIL} {fail_prefix} failed: {e}")
        return False, None
    print(f"{_INDENT}{_PASS} {ok_message}")
    return True, result


def _require_attrs(obj, names):
    """Raise AttributeError for the first name in ``names`` missing on ``obj``.

    Uses an explicit raise rather than ``assert`` so the check is not
    stripped when Python runs with ``-O``.
    """
    for name in names:
        if not hasattr(obj, name):
            raise AttributeError(f"Missing {name}")


def _check_config():
    """Check that LegalBertConfig exposes the hierarchical parameters."""
    from config import LegalBertConfig

    _require_attrs(
        LegalBertConfig(),
        ("hierarchical_hidden_dim", "hierarchical_num_lstm_layers"),
    )


def _check_model_imports():
    """Check that the model module exports the hierarchical model and tokenizer."""
    from model import HierarchicalLegalBERT, LegalBertTokenizer  # noqa: F401


def _check_trainer_import():
    """Check that the trainer module exports LegalBertTrainer."""
    from trainer import LegalBertTrainer  # noqa: F401


def _build_hierarchical_model():
    """Construct and return a HierarchicalLegalBERT with fixed test settings."""
    # Imported here (not reused from earlier checks) so a failure in an
    # earlier step cannot surface as a misleading NameError in this one.
    from config import LegalBertConfig
    from model import HierarchicalLegalBERT

    print(f"{_INDENT}Testing hierarchical model...")
    return HierarchicalLegalBERT(
        LegalBertConfig(),
        num_discovered_risks=7,
        hidden_dim=256,
        num_lstm_layers=2,
    )


def _check_forward_methods(model):
    """Check that ``model`` has the three forward/predict entry points."""
    if model is None:
        raise RuntimeError(
            "hierarchical model unavailable (initialization did not succeed)"
        )
    _require_attrs(
        model,
        ("forward_single_clause", "forward_document", "predict_document"),
    )


def _check_trainer():
    """Check that LegalBertTrainer can be constructed from a default config."""
    from config import LegalBertConfig
    from trainer import LegalBertTrainer

    LegalBertTrainer(LegalBertConfig())


def main():
    """Run all checks and print a summary.

    Returns:
        0 when every check passed, 1 when at least one failed.
    """
    print(_RULE)
    print("HIERARCHICAL BERT TEST (Hierarchical Only)")
    print(_RULE)

    outcomes = []

    passed, _ = _run_check(
        "1. Testing config...",
        "Config OK (hierarchical parameters present)",
        "Config",
        _check_config,
    )
    outcomes.append(passed)

    passed, _ = _run_check(
        "2. Testing model imports...",
        "Model imports OK (HierarchicalLegalBERT only)",
        "Model imports",
        _check_model_imports,
    )
    outcomes.append(passed)

    passed, _ = _run_check(
        "3. Testing trainer imports...",
        "Trainer imports OK",
        "Trainer imports",
        _check_trainer_import,
    )
    outcomes.append(passed)

    # The model built here is reused by check 5, so it is constructed once.
    passed, model = _run_check(
        "4. Testing hierarchical model initialization...",
        "Hierarchical model OK",
        "Model initialization",
        _build_hierarchical_model,
    )
    outcomes.append(passed)

    passed, _ = _run_check(
        "5. Testing forward method compatibility...",
        "Forward methods OK",
        "Forward methods",
        lambda: _check_forward_methods(model),
    )
    outcomes.append(passed)

    passed, _ = _run_check(
        "6. Testing trainer compatibility...",
        "Trainer compatible",
        "Trainer compatibility",
        _check_trainer,
    )
    outcomes.append(passed)

    failures = outcomes.count(False)

    print("\n" + _RULE)
    if failures:
        # Only claim success when it is true; report the failure count instead.
        print(f"{_FAIL} {failures} of {len(outcomes)} TESTS FAILED")
        print(_RULE)
        return 1

    print(f"{_PASS} ALL TESTS PASSED!")
    print(_RULE)
    print(f"\n{_TARGET} Ready to train:")
    print(f"{_INDENT}{_BULLET} python train.py              # Train Hierarchical BERT")
    print(f"{_INDENT}{_BULLET} python train.py --epochs 10  # With custom epochs")
    print(f"\n{_NOTE} Hierarchical BERT is now the ONLY model")
    print(f"{_DOCS} See HIERARCHICAL_ONLY_SUMMARY.md for details")
    return 0


if __name__ == "__main__":
    sys.exit(main())