"""
Evaluation Script for Legal-BERT
Executes Week 8: Comprehensive Evaluation & Analysis
"""
import os
import json

import torch
from torch.utils.data import DataLoader

from config import LegalBertConfig
from model import HierarchicalLegalBERT
from trainer import LegalBertTrainer, LegalClauseDataset, collate_batch
from evaluator import LegalBertEvaluator
from data_loader import CUADDataLoader
from risk_discovery import UnsupervisedRiskDiscovery  # referenced indirectly when the checkpoint is unpickled

def main():
    """Execute Legal-BERT evaluation pipeline"""
    
    print("=" * 80)
    print("πŸ” LEGAL-BERT EVALUATION PIPELINE")
    print("=" * 80)
    
    # Initialize configuration
    config = LegalBertConfig()
    
    # Load trained model
    print("\nπŸ“‚ Loading trained model...")
    model_path = os.path.join(config.model_save_path, 'final_model.pt')
    
    if not os.path.exists(model_path):
        print(f"❌ Error: Model not found at {model_path}")
        print("Please train the model first using: python train.py")
        return
    
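    # weights_only=False: the checkpoint stores pickled Python objects
    # (the risk-discovery model and the training config), not just tensors.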
    checkpoint = torch.load(model_path, map_location=config.device, weights_only=False)
    
    # Initialize trainer and load model
    trainer = LegalBertTrainer(config)
    
    # Restore risk discovery patterns
    if 'risk_discovery_model' in checkpoint:
        trainer.risk_discovery = checkpoint['risk_discovery_model']
    else:
        # Fallback for older models
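        # (assumes older checkpoints stored the raw pattern dict; a KeyError
        # here means the checkpoint predates both formats)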
        trainer.risk_discovery.discovered_patterns = checkpoint['discovered_patterns']
        trainer.risk_discovery.n_clusters = len(checkpoint['discovered_patterns'])
    
    # Load Hierarchical BERT model
    
    # CRITICAL FIX: Use the config from checkpoint to get correct architecture parameters
    if 'config' in checkpoint:
        saved_config = checkpoint['config']
        hidden_dim = saved_config.hierarchical_hidden_dim
        num_lstm_layers = saved_config.hierarchical_num_lstm_layers
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        # Fallback to current config (for backward compatibility)
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print(f"   ⚠️  Warning: No config in checkpoint, using current config")
    
    print("πŸ“Š Loading Hierarchical BERT model")
    trainer.model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=trainer.risk_discovery.n_clusters,
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers
    ).to(config.device)
    
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    
    print("βœ… Model loaded successfully!")
    
    # Load test data
    print("\nπŸ“Š Loading test data...")
    data_loader = CUADDataLoader(config.data_path)
    df_clauses, contracts = data_loader.load_data()
    splits = data_loader.create_splits()
    
    # Prepare test loader
    test_clauses = splits['test']['clause_text'].tolist()
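    # Labels come from the unsupervised risk-discovery clusters; severity and
    # importance targets are regenerated with the same synthetic-scoring
    # scheme used during training, so evaluation matches the training setup.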
    risk_labels = trainer.risk_discovery.get_risk_labels(test_clauses)
    severity_scores = trainer._generate_synthetic_scores(test_clauses, 'severity')
    importance_scores = trainer._generate_synthetic_scores(test_clauses, 'importance')
    
    test_dataset = LegalClauseDataset(
        clauses=test_clauses,
        risk_labels=risk_labels,
        severity_scores=severity_scores,
        importance_scores=importance_scores,
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length
    )
    
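    # shuffle=False keeps predictions aligned with the original test order;
    # num_workers=0 keeps data loading in the main process.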
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_batch
    )
    
    print(f"βœ… Test data prepared: {len(test_dataset)} samples")
    
    # Initialize evaluator
    print("\n" + "=" * 80)
    print("πŸ“ˆ PHASE 1: MODEL EVALUATION")
    print("=" * 80)
    
    evaluator = LegalBertEvaluator(
        model=trainer.model,
        tokenizer=trainer.tokenizer,
        risk_discovery=trainer.risk_discovery
    )
    
    # Run evaluation
    results = evaluator.evaluate_model(test_loader, save_results=True)
    
    # Generate and display report
    print("\n" + "=" * 80)
    print("πŸ“„ EVALUATION REPORT")
    print("=" * 80)
    
    report = evaluator.generate_report()
    print(report)
    
    # Save detailed results
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    results_path = os.path.join(config.checkpoint_dir, 'evaluation_results.json')
    
    # Convert numpy arrays to lists for JSON serialization
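    # (duck-typed: anything exposing .tolist(), including numpy scalars,
    # is converted recursively through dicts and lists)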
    def convert_to_serializable(obj):
        if hasattr(obj, 'tolist'):
            return obj.tolist()
        elif isinstance(obj, dict):
            return {k: convert_to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [convert_to_serializable(item) for item in obj]
        else:
            return obj
    
    results_serializable = convert_to_serializable(results)
    
    with open(results_path, 'w') as f:
        json.dump(results_serializable, f, indent=2)
    
    print(f"\nπŸ’Ύ Detailed results saved to: {results_path}")
    
    # Generate visualizations
    print("\nπŸ“Š Generating visualizations...")
    evaluator.plot_confusion_matrix(save_path=os.path.join(config.checkpoint_dir, 'confusion_matrix.png'))
    evaluator.plot_risk_distribution(save_path=os.path.join(config.checkpoint_dir, 'risk_distribution.png'))
    
    # Summary
    print("\n" + "=" * 80)
    print("βœ… EVALUATION COMPLETE!")
    print("=" * 80)
    
    clf_metrics = results['classification_metrics']
    print(f"\n🎯 Key Metrics:")
    print(f"  Accuracy: {clf_metrics['accuracy']:.4f}")
    print(f"  F1-Score: {clf_metrics['f1_score']:.4f}")
    print(f"  Precision: {clf_metrics['precision']:.4f}")
    print(f"  Recall: {clf_metrics['recall']:.4f}")
    
    reg_metrics = results['regression_metrics']
    print(f"\nπŸ“ˆ Regression Performance:")
    print(f"  Severity RΒ²: {reg_metrics['severity']['r2_score']:.4f}")
    print(f"  Importance RΒ²: {reg_metrics['importance']['r2_score']:.4f}")
    
    print(f"\n🎯 Next Steps:")
    print(f"  1. Apply calibration methods: python calibrate.py")
    print(f"  2. Analyze error cases")
    print(f"  3. Compare with baseline methods")
    
    return evaluator, results

if __name__ == "__main__":
    # main() returns None when the model checkpoint is missing, so do not
    # unpack its return value here (unpacking None raises a TypeError).
    main()