"""
Evaluation Script for Legal-BERT
Executes Week 8: Comprehensive Evaluation & Analysis
"""
import torch
import os
import json
from config import LegalBertConfig
from trainer import LegalBertTrainer, collate_batch
from evaluator import LegalBertEvaluator
from data_loader import CUADDataLoader
# Not referenced directly, but the checkpoint unpickles UnsupervisedRiskDiscovery objects
from risk_discovery import UnsupervisedRiskDiscovery
def main():
    """Execute Legal-BERT evaluation pipeline"""
    print("=" * 80)
    print("🔍 LEGAL-BERT EVALUATION PIPELINE")
    print("=" * 80)

    # Initialize configuration
    config = LegalBertConfig()

    # Load trained model
    print("\n📂 Loading trained model...")
    model_path = os.path.join(config.model_save_path, 'final_model.pt')
    if not os.path.exists(model_path):
        print(f"❌ Error: Model not found at {model_path}")
        print("Please train the model first using: python train.py")
        return None, None  # keep the caller's tuple unpacking valid
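
    # NOTE: weights_only=False is required below because the checkpoint pickles
    # custom objects (the risk-discovery model and config), not just tensors.
    # Unpickling can execute arbitrary code, so only load trusted checkpoints.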
    checkpoint = torch.load(model_path, map_location=config.device, weights_only=False)

    # Initialize trainer and load model
    trainer = LegalBertTrainer(config)

    # Restore risk discovery patterns
    if 'risk_discovery_model' in checkpoint:
        trainer.risk_discovery = checkpoint['risk_discovery_model']
    else:
        # Fallback for older checkpoints that stored only the raw patterns
        trainer.risk_discovery.discovered_patterns = checkpoint['discovered_patterns']
        trainer.risk_discovery.n_clusters = len(checkpoint['discovered_patterns'])
    # Load Hierarchical BERT model
    from model import HierarchicalLegalBERT

    # CRITICAL FIX: use the config stored in the checkpoint so the architecture
    # parameters match the trained weights
    if 'config' in checkpoint:
        saved_config = checkpoint['config']
        hidden_dim = saved_config.hierarchical_hidden_dim
        num_lstm_layers = saved_config.hierarchical_num_lstm_layers
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        # Fallback to the current config (for backward compatibility)
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print("   ⚠️ Warning: no config in checkpoint, using current config")

    print("📊 Loading Hierarchical BERT model")
    trainer.model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=trainer.risk_discovery.n_clusters,
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers
    ).to(config.device)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    print("✅ Model loaded successfully!")
    # Load test data
    print("\n📊 Loading test data...")
    data_loader = CUADDataLoader(config.data_path)
    df_clauses, contracts = data_loader.load_data()
    splits = data_loader.create_splits()

    # Prepare test loader
    test_clauses = splits['test']['clause_text'].tolist()
    risk_labels = trainer.risk_discovery.get_risk_labels(test_clauses)
    severity_scores = trainer._generate_synthetic_scores(test_clauses, 'severity')
    importance_scores = trainer._generate_synthetic_scores(test_clauses, 'importance')
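
    # Note that the severity/importance targets come from the trainer's synthetic
    # scoring heuristic, so the regression metrics below measure agreement with
    # that heuristic rather than with human annotations.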
    from trainer import LegalClauseDataset
    from torch.utils.data import DataLoader

    test_dataset = LegalClauseDataset(
        clauses=test_clauses,
        risk_labels=risk_labels,
        severity_scores=severity_scores,
        importance_scores=importance_scores,
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_batch
    )
    print(f"✅ Test data prepared: {len(test_dataset)} samples")
    # Initialize evaluator
    print("\n" + "=" * 80)
    print("📈 PHASE 1: MODEL EVALUATION")
    print("=" * 80)
    evaluator = LegalBertEvaluator(
        model=trainer.model,
        tokenizer=trainer.tokenizer,
        risk_discovery=trainer.risk_discovery
    )

    # Run evaluation
    results = evaluator.evaluate_model(test_loader, save_results=True)

    # Generate and display report
    print("\n" + "=" * 80)
    print("📄 EVALUATION REPORT")
    print("=" * 80)
    report = evaluator.generate_report()
    print(report)
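
    # Ensure the output directory exists; a defensive addition in case evaluation
    # runs on a machine where checkpoint_dir was never created during training.
    os.makedirs(config.checkpoint_dir, exist_ok=True)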
    # Save detailed results
    results_path = os.path.join(config.checkpoint_dir, 'evaluation_results.json')

    # Recursively convert numpy arrays/scalars and torch tensors (anything with
    # a .tolist() method) into plain Python types for JSON serialization
    def convert_to_serializable(obj):
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [convert_to_serializable(item) for item in obj]
        else:
            return obj

    results_serializable = convert_to_serializable(results)
    with open(results_path, 'w') as f:
        json.dump(results_serializable, f, indent=2)
    print(f"\n💾 Detailed results saved to: {results_path}")
    # Generate visualizations
    print("\n📊 Generating visualizations...")
    evaluator.plot_confusion_matrix(save_path=os.path.join(config.checkpoint_dir, 'confusion_matrix.png'))
    evaluator.plot_risk_distribution(save_path=os.path.join(config.checkpoint_dir, 'risk_distribution.png'))
    # Summary
    print("\n" + "=" * 80)
    print("✅ EVALUATION COMPLETE!")
    print("=" * 80)

    clf_metrics = results['classification_metrics']
    print("\n🎯 Key Metrics:")
    print(f"   Accuracy:  {clf_metrics['accuracy']:.4f}")
    print(f"   F1-Score:  {clf_metrics['f1_score']:.4f}")
    print(f"   Precision: {clf_metrics['precision']:.4f}")
    print(f"   Recall:    {clf_metrics['recall']:.4f}")

    reg_metrics = results['regression_metrics']
    print("\n📈 Regression Performance:")
    print(f"   Severity R²:   {reg_metrics['severity']['r2_score']:.4f}")
    print(f"   Importance R²: {reg_metrics['importance']['r2_score']:.4f}")

    print("\n🎯 Next Steps:")
    print("   1. Apply calibration methods: python calibrate.py")
    print("   2. Analyze error cases")
    print("   3. Compare with baseline methods")

    return evaluator, results


if __name__ == "__main__":
    evaluator, results = main()