|
|
""" |
|
|
Evaluation Script for Legal-BERT |
|
|
Executes Week 8: Comprehensive Evaluation & Analysis |
|
|
""" |
|
|
import torch |
|
|
import os |
|
|
import json |
|
|
from datetime import datetime |
|
|
|
|
|
from config import LegalBertConfig |
|
|
from trainer import LegalBertTrainer, collate_batch |
|
|
from evaluator import LegalBertEvaluator |
|
|
from data_loader import CUADDataLoader |
|
|
from risk_discovery import UnsupervisedRiskDiscovery |
|
|
|
|
|
def _to_json_serializable(obj):
    """Recursively convert *obj* into something ``json.dump`` accepts.

    Objects exposing ``tolist()`` (e.g. NumPy arrays/scalars, torch tensors)
    are converted with it. Dicts, lists and tuples are walked recursively;
    tuples become lists, which is how JSON encodes them anyway. Dict keys
    exposing ``item()`` (e.g. NumPy integer keys) are unwrapped to plain
    Python scalars, because ``json`` rejects non-builtin key types.
    Anything else is returned unchanged.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    if isinstance(obj, dict):
        return {
            (key.item() if hasattr(key, 'item') else key): _to_json_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    return obj


def _load_trainer(config):
    """Rebuild a LegalBertTrainer from ``final_model.pt`` in ``config.model_save_path``.

    Restores the risk-discovery component and the HierarchicalLegalBERT
    weights, then puts the model in eval mode.

    Returns:
        The restored trainer, or ``None`` (after printing why) when the
        checkpoint file is missing or has no risk-discovery data.
    """
    print("\nLoading trained model...")
    model_path = os.path.join(config.model_save_path, 'final_model.pt')
    if not os.path.exists(model_path):
        print(f"[ERROR] Model not found at {model_path}")
        print("Please train the model first using: python train.py")
        return None

    # SECURITY: weights_only=False unpickles arbitrary Python objects (the
    # checkpoint stores the risk-discovery object and a config object).
    # Only load checkpoints produced by a trusted source.
    checkpoint = torch.load(model_path, map_location=config.device, weights_only=False)

    trainer = LegalBertTrainer(config)

    # Prefer the full pickled risk-discovery object. Otherwise fall back to the
    # discovered-pattern list and derive the cluster count from it.
    if 'risk_discovery_model' in checkpoint:
        trainer.risk_discovery = checkpoint['risk_discovery_model']
    elif 'discovered_patterns' in checkpoint:
        patterns = checkpoint['discovered_patterns']
        trainer.risk_discovery.discovered_patterns = patterns
        trainer.risk_discovery.n_clusters = len(patterns)
    else:
        print(f"[ERROR] Checkpoint at {model_path} has no risk discovery data "
              "('risk_discovery_model' or 'discovered_patterns').")
        print("Please re-train the model using: python train.py")
        return None

    from model import HierarchicalLegalBERT

    # Rebuild with the architecture the weights were trained with, so
    # load_state_dict sees matching shapes. Fields absent from the saved
    # config fall back to the current config.
    saved_config = checkpoint.get('config')
    if saved_config is not None:
        hidden_dim = getattr(saved_config, 'hierarchical_hidden_dim',
                             config.hierarchical_hidden_dim)
        num_lstm_layers = getattr(saved_config, 'hierarchical_num_lstm_layers',
                                  config.hierarchical_num_lstm_layers)
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print("   [WARNING] No config in checkpoint, using current config")

    print("Loading Hierarchical BERT model")
    trainer.model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=trainer.risk_discovery.n_clusters,
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers,
    ).to(config.device)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    # Disable dropout and other training-only behaviour for evaluation.
    # This is harmless if the evaluator also switches to eval mode.
    trainer.model.eval()

    print("[OK] Model loaded successfully!")
    return trainer


def _build_test_loader(config, trainer):
    """Build the CUAD test dataset and a non-shuffling DataLoader over it.

    Risk labels come from the restored risk-discovery component. Severity
    and importance targets come from the trainer's synthetic score generator.

    Returns:
        ``(test_dataset, test_loader)``.
    """
    from torch.utils.data import DataLoader
    from trainer import LegalClauseDataset

    print("\nLoading test data...")
    data_loader = CUADDataLoader(config.data_path)
    # The returned clause frame and contracts are not used here; only the
    # splits created afterwards are.
    data_loader.load_data()
    splits = data_loader.create_splits()

    test_clauses = splits['test']['clause_text'].tolist()
    test_dataset = LegalClauseDataset(
        clauses=test_clauses,
        risk_labels=trainer.risk_discovery.get_risk_labels(test_clauses),
        severity_scores=trainer._generate_synthetic_scores(test_clauses, 'severity'),
        importance_scores=trainer._generate_synthetic_scores(test_clauses, 'importance'),
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_batch,
    )
    return test_dataset, test_loader


def main():
    """Run the Legal-BERT evaluation pipeline end to end.

    Restores the trained model from ``<config.model_save_path>/final_model.pt``
    and evaluates it on the CUAD test split. It then prints a report and
    writes ``evaluation_results.json``, ``confusion_matrix.png`` and
    ``risk_distribution.png`` into ``config.checkpoint_dir``.

    Returns:
        ``(evaluator, results)`` on success, or ``None`` when the trained
        model could not be loaded.
    """
    print("=" * 80)
    print("LEGAL-BERT EVALUATION PIPELINE")
    print("=" * 80)

    config = LegalBertConfig()

    trainer = _load_trainer(config)
    if trainer is None:
        return None

    test_dataset, test_loader = _build_test_loader(config, trainer)
    print(f"[OK] Test data prepared: {len(test_dataset)} samples")

    # Phase 1: model evaluation
    print("\n" + "=" * 80)
    print("PHASE 1: MODEL EVALUATION")
    print("=" * 80)

    evaluator = LegalBertEvaluator(
        model=trainer.model,
        tokenizer=trainer.tokenizer,
        risk_discovery=trainer.risk_discovery,
    )
    results = evaluator.evaluate_model(test_loader, save_results=True)

    print("\n" + "=" * 80)
    print("EVALUATION REPORT")
    print("=" * 80)
    print(evaluator.generate_report())

    # The output directory is not guaranteed to exist yet; the JSON file and
    # the plots below are all written into it.
    os.makedirs(config.checkpoint_dir, exist_ok=True)

    results_path = os.path.join(config.checkpoint_dir, 'evaluation_results.json')
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(_to_json_serializable(results), f, indent=2)
    print(f"\nDetailed results saved to: {results_path}")

    print("\nGenerating visualizations...")
    evaluator.plot_confusion_matrix(
        save_path=os.path.join(config.checkpoint_dir, 'confusion_matrix.png'))
    evaluator.plot_risk_distribution(
        save_path=os.path.join(config.checkpoint_dir, 'risk_distribution.png'))

    print("\n" + "=" * 80)
    print("EVALUATION COMPLETE!")
    print("=" * 80)

    clf_metrics = results['classification_metrics']
    print("\nKey Metrics:")
    print(f"   Accuracy:  {clf_metrics['accuracy']:.4f}")
    print(f"   F1-Score:  {clf_metrics['f1_score']:.4f}")
    print(f"   Precision: {clf_metrics['precision']:.4f}")
    print(f"   Recall:    {clf_metrics['recall']:.4f}")

    reg_metrics = results['regression_metrics']
    print("\nRegression Performance:")
    print(f"   Severity R^2:   {reg_metrics['severity']['r2_score']:.4f}")
    print(f"   Importance R^2: {reg_metrics['importance']['r2_score']:.4f}")

    print("\nNext Steps:")
    print("   1. Apply calibration methods: python calibrate.py")
    print("   2. Analyze error cases")
    print("   3. Compare with baseline methods")

    return evaluator, results
|
|
|
|
|
if __name__ == "__main__": |
|
|
evaluator, results = main() |
|
|
|