"""
Evaluation Script for Legal-BERT
Executes Week 8: Comprehensive Evaluation & Analysis
"""
import torch
import os
import json
from datetime import datetime
from config import LegalBertConfig
from trainer import LegalBertTrainer, collate_batch
from evaluator import LegalBertEvaluator
from data_loader import CUADDataLoader
from risk_discovery import UnsupervisedRiskDiscovery
def main():
    """Execute Legal-BERT evaluation pipeline"""
    print("=" * 80)
    print("LEGAL-BERT EVALUATION PIPELINE")
    print("=" * 80)

    # Initialize configuration
    config = LegalBertConfig()

    # Load trained model
    print("\nLoading trained model...")
    model_path = os.path.join(config.model_save_path, 'final_model.pt')
    if not os.path.exists(model_path):
        print(f"❌ Error: Model not found at {model_path}")
        print("Please train the model first using: python train.py")
        return
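
    # NOTE: weights_only=False is required here because the checkpoint stores
    # pickled Python objects (e.g. the fitted risk-discovery model), not just
    # tensors. Only load checkpoints from trusted sources with this flag.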
    checkpoint = torch.load(model_path, map_location=config.device, weights_only=False)

    # Initialize trainer and load model
    trainer = LegalBertTrainer(config)

    # Restore risk discovery patterns
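    # Assumed checkpoint layout (inferred from the branches below): newer
    # checkpoints store the full fitted object under 'risk_discovery_model';
    # older ones only store 'discovered_patterns', whose length gives the
    # number of clusters.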
    if 'risk_discovery_model' in checkpoint:
        trainer.risk_discovery = checkpoint['risk_discovery_model']
    else:
        # Fallback for older checkpoints that predate the full-object format
        trainer.risk_discovery.discovered_patterns = checkpoint['discovered_patterns']
        trainer.risk_discovery.n_clusters = len(checkpoint['discovered_patterns'])

    # Load Hierarchical BERT model
    from model import HierarchicalLegalBERT

    # Use the architecture parameters stored in the checkpoint so the rebuilt
    # model matches the trained weights exactly
    if 'config' in checkpoint:
        saved_config = checkpoint['config']
        hidden_dim = saved_config.hierarchical_hidden_dim
        num_lstm_layers = saved_config.hierarchical_num_lstm_layers
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        # Fall back to the current config (backward compatibility with older checkpoints)
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print("   ⚠️ Warning: no config in checkpoint, using current config")

    print("Loading Hierarchical BERT model")
    trainer.model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=trainer.risk_discovery.n_clusters,
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers
    ).to(config.device)
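    # load_state_dict defaults to strict=True and raises on any shape mismatch,
    # which is why hidden_dim / num_lstm_layers above must come from the
    # checkpoint whenever it provides them.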
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    print("✅ Model loaded successfully!")

    # Load test data
    print("\nLoading test data...")
    data_loader = CUADDataLoader(config.data_path)
    df_clauses, contracts = data_loader.load_data()
    splits = data_loader.create_splits()

    # Prepare test loader
    test_clauses = splits['test']['clause_text'].tolist()
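    # Targets are regenerated with the restored risk-discovery model and the
    # trainer's synthetic scorers, presumably mirroring how they were produced
    # at training time.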
    risk_labels = trainer.risk_discovery.get_risk_labels(test_clauses)
    severity_scores = trainer._generate_synthetic_scores(test_clauses, 'severity')
    importance_scores = trainer._generate_synthetic_scores(test_clauses, 'importance')

    from trainer import LegalClauseDataset
    from torch.utils.data import DataLoader

    test_dataset = LegalClauseDataset(
        clauses=test_clauses,
        risk_labels=risk_labels,
        severity_scores=severity_scores,
        importance_scores=importance_scores,
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length
    )
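    # shuffle=False keeps predictions aligned with the test-split order;
    # num_workers=0 loads every batch in the main process.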
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_batch
    )
print(f"β
Test data prepared: {len(test_dataset)} samples")
# Initialize evaluator
print("\n" + "=" * 80)
print("π PHASE 1: MODEL EVALUATION")
print("=" * 80)
evaluator = LegalBertEvaluator(
model=trainer.model,
tokenizer=trainer.tokenizer,
risk_discovery=trainer.risk_discovery
)
# Run evaluation
results = evaluator.evaluate_model(test_loader, save_results=True)

    # Generate and display report
    print("\n" + "=" * 80)
    print("EVALUATION REPORT")
    print("=" * 80)
    report = evaluator.generate_report()
    print(report)

    # Save detailed results
    results_path = os.path.join(config.checkpoint_dir, 'evaluation_results.json')

    # Convert numpy arrays to lists for JSON serialization
    def convert_to_serializable(obj):
        if hasattr(obj, 'tolist'):
            # Duck-typed: covers numpy arrays/scalars and torch tensors alike
            return obj.tolist()
        elif isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [convert_to_serializable(item) for item in obj]
        else:
            return obj
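    # Example of the recursive conversion (values are illustrative):
    #   convert_to_serializable({'cm': np.array([[5, 1], [0, 4]]), 'acc': 0.9})
    #   -> {'cm': [[5, 1], [0, 4]], 'acc': 0.9}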

    results_serializable = convert_to_serializable(results)
    with open(results_path, 'w') as f:
        json.dump(results_serializable, f, indent=2)
    print(f"\n💾 Detailed results saved to: {results_path}")

    # Generate visualizations
    print("\nGenerating visualizations...")
    evaluator.plot_confusion_matrix(save_path=os.path.join(config.checkpoint_dir, 'confusion_matrix.png'))
    evaluator.plot_risk_distribution(save_path=os.path.join(config.checkpoint_dir, 'risk_distribution.png'))

    # Summary
    print("\n" + "=" * 80)
    print("✅ EVALUATION COMPLETE!")
    print("=" * 80)

    clf_metrics = results['classification_metrics']
    print("\n🎯 Key Metrics:")
    print(f"   Accuracy:  {clf_metrics['accuracy']:.4f}")
    print(f"   F1-Score:  {clf_metrics['f1_score']:.4f}")
    print(f"   Precision: {clf_metrics['precision']:.4f}")
    print(f"   Recall:    {clf_metrics['recall']:.4f}")

    reg_metrics = results['regression_metrics']
    print("\nRegression Performance:")
    print(f"   Severity R²:   {reg_metrics['severity']['r2_score']:.4f}")
    print(f"   Importance R²: {reg_metrics['importance']['r2_score']:.4f}")

    print("\n🎯 Next Steps:")
    print("   1. Apply calibration methods: python calibrate.py")
    print("   2. Analyze error cases")
    print("   3. Compare with baseline methods")

    return evaluator, results


if __name__ == "__main__":
    evaluator, results = main()