"""
Calibration Script for Legal-BERT
Executes Week 7: Model Calibration & Uncertainty Quantification
"""
import os
import json
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader

from config import LegalBertConfig
from data_loader import CUADDataLoader
from model import HierarchicalLegalBERT
from trainer import LegalBertTrainer, LegalClauseDataset, collate_batch

class CalibrationFramework:
    """
    Calibration methods for Legal-BERT confidence scores
    Week 7 implementation: Temperature Scaling, Platt Scaling, Isotonic Regression
    """
    
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.temperature = 1.0
        
    def collect_logits_and_labels(self, data_loader):
        """Collect logits and true labels from validation set"""
        all_logits = []
        all_labels = []
        
        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['risk_label']
                
                # Use the correct method for HierarchicalLegalBERT
                outputs = self.model.forward_single_clause(input_ids, attention_mask)
                logits = outputs['risk_logits']
                
                all_logits.append(logits.cpu())
                all_labels.append(labels)
        
        return torch.cat(all_logits), torch.cat(all_labels)
    
    def temperature_scaling(self, val_loader, lr=0.01, max_iter=50):
        """
        Apply temperature scaling calibration
        Learns optimal temperature to calibrate confidence scores
        """
        print("🌑️  Applying temperature scaling...")
        
        # Collect validation logits and labels
        logits, labels = self.collect_logits_and_labels(val_loader)
        
        # Create temperature parameter
        temperature = torch.nn.Parameter(torch.ones(1) * 1.5)
        optimizer = torch.optim.LBFGS([temperature], lr=lr, max_iter=max_iter)
        
        criterion = torch.nn.CrossEntropyLoss()
        
        def eval_loss():
            optimizer.zero_grad()
            loss = criterion(logits / temperature, labels)
            loss.backward()
            return loss
        
        optimizer.step(eval_loss)
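        # The single step() call above runs the full LBFGS fit: torch.optim.LBFGS
        # re-evaluates the closure internally, up to max_iter iterations.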
        
        self.temperature = temperature.item()
        print(f"  βœ… Optimal temperature: {self.temperature:.4f}")
        
        return self.temperature
    
    def apply_temperature(self, logits):
        """Apply learned temperature to logits"""
        return logits / self.temperature
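
    # Illustrative effect of apply_temperature (hypothetical numbers), T = 2.0:
    #   logits = torch.tensor([[4.0, 1.0, 0.0]])
    #   softmax(logits)       -> ~[0.936, 0.047, 0.017]  (overconfident)
    #   softmax(logits / 2.0) -> ~[0.736, 0.164, 0.100]  (softened, same argmax)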
    
    def _collect_confidences(self, data_loader):
        """Run the model over a loader; return (confidence, prediction, label) arrays."""
        confidences = []
        predictions = []
        true_labels = []
        
        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['risk_label']
                
                # Use the single-clause forward pass of HierarchicalLegalBERT
                outputs = self.model.forward_single_clause(input_ids, attention_mask)
                logits = self.apply_temperature(outputs['risk_logits'])
                
                probs = torch.softmax(logits, dim=-1)
                conf, pred = torch.max(probs, dim=-1)
                
                confidences.extend(conf.cpu().numpy())
                predictions.extend(pred.cpu().numpy())
                true_labels.extend(labels.numpy())
        
        return np.array(confidences), np.array(predictions), np.array(true_labels)
    
    def calculate_ece(self, data_loader, n_bins=15):
        """
        Calculate Expected Calibration Error (ECE): the average gap between
        confidence and accuracy across confidence bins, weighted by bin size.
        """
        print("πŸ“Š Calculating Expected Calibration Error (ECE)...")
        
        confidences, predictions, true_labels = self._collect_confidences(data_loader)
        
        # ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)| over equal-width bins B_m
        ece = 0.0
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        
        for i in range(n_bins):
            bin_lower = bin_boundaries[i]
            bin_upper = bin_boundaries[i + 1]
            
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            prop_in_bin = np.mean(in_bin)
            
            if prop_in_bin > 0:
                accuracy_in_bin = np.mean(predictions[in_bin] == true_labels[in_bin])
                avg_confidence_in_bin = np.mean(confidences[in_bin])
                ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        
        print(f"  ECE: {ece:.4f}")
        return ece
    
    def calculate_mce(self, data_loader, n_bins=15):
        """
        Calculate Maximum Calibration Error (MCE): the largest gap between
        confidence and accuracy over all non-empty confidence bins.
        """
        print("πŸ“Š Calculating Maximum Calibration Error (MCE)...")
        
        confidences, predictions, true_labels = self._collect_confidences(data_loader)
        
        # MCE = max_m |acc(B_m) - conf(B_m)| over non-empty equal-width bins B_m
        mce = 0.0
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        
        for i in range(n_bins):
            bin_lower = bin_boundaries[i]
            bin_upper = bin_boundaries[i + 1]
            
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            
            if np.sum(in_bin) > 0:
                accuracy_in_bin = np.mean(predictions[in_bin] == true_labels[in_bin])
                avg_confidence_in_bin = np.mean(confidences[in_bin])
                mce = max(mce, np.abs(avg_confidence_in_bin - accuracy_in_bin))
        
        print(f"  MCE: {mce:.4f}")
        return mce
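
# Worked micro-example (hypothetical numbers) for the binning above: a bin that
# holds 4 of n samples with confidences [0.90, 0.80, 0.85, 0.95] (mean 0.875) and
# 3 of 4 correct (accuracy 0.75) contributes |0.875 - 0.75| * (4 / n) to ECE; its
# gap of 0.125 is also a candidate for the max taken by MCE.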

def main():
    """Execute calibration pipeline"""
    
    print("=" * 80)
    print("🌑️  LEGAL-BERT CALIBRATION PIPELINE")
    print("=" * 80)
    
    # Initialize configuration
    config = LegalBertConfig()
    
    # Load trained model
    print("\nπŸ“‚ Loading trained model...")
    model_path = os.path.join(config.model_save_path, 'final_model.pt')
    
    if not os.path.exists(model_path):
        print(f"❌ Error: Model not found at {model_path}")
        print("Please train the model first using: python train.py")
        return
    
    checkpoint = torch.load(model_path, map_location=config.device, weights_only=False)
    
    # weights_only=False: the checkpoint stores full Python objects (config,
    # discovered risk patterns) alongside the state dict.
    # Use the config saved in the checkpoint so the architecture hyperparameters
    # match the trained weights.
    if 'config' in checkpoint:
        saved_config = checkpoint['config']
        hidden_dim = saved_config.hierarchical_hidden_dim
        num_lstm_layers = saved_config.hierarchical_num_lstm_layers
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        # Fallback to current config (for backward compatibility)
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print(f"   ⚠️  Warning: No config in checkpoint, using current config")
    
    # Initialize and load Hierarchical BERT model
    print("πŸ“Š Loading Hierarchical BERT model")
    model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=len(checkpoint['discovered_patterns']),
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers
    ).to(config.device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print("βœ… Model loaded successfully!")
    
    # Load validation and test data
    print("\nπŸ“Š Loading data...")
    data_loader = CUADDataLoader(config.data_path)
    df_clauses, contracts = data_loader.load_data()
    splits = data_loader.create_splits()
    
    # Initialize trainer for helper methods
    trainer = LegalBertTrainer(config)
    
    # Restore risk discovery model (including fitted LDA/K-Means)
    if 'risk_discovery_model' in checkpoint:
        trainer.risk_discovery = checkpoint['risk_discovery_model']
    else:
        # Fallback for older models
        trainer.risk_discovery.discovered_patterns = checkpoint['discovered_patterns']
        trainer.risk_discovery.n_clusters = len(checkpoint['discovered_patterns'])
    
    trainer.model = model
    
    # Prepare validation and test loaders
    val_clauses = splits['val']['clause_text'].tolist()
    test_clauses = splits['test']['clause_text'].tolist()
    
    val_risk_labels = trainer.risk_discovery.get_risk_labels(val_clauses)
    test_risk_labels = trainer.risk_discovery.get_risk_labels(test_clauses)
    
    val_dataset = LegalClauseDataset(
        clauses=val_clauses,
        risk_labels=val_risk_labels,
        severity_scores=trainer._generate_synthetic_scores(val_clauses, 'severity'),
        importance_scores=trainer._generate_synthetic_scores(val_clauses, 'importance'),
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length
    )
    
    test_dataset = LegalClauseDataset(
        clauses=test_clauses,
        risk_labels=test_risk_labels,
        severity_scores=trainer._generate_synthetic_scores(test_clauses, 'severity'),
        importance_scores=trainer._generate_synthetic_scores(test_clauses, 'importance'),
        tokenizer=trainer.tokenizer,
        max_length=config.max_sequence_length
    )
    
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=collate_batch)
    
    print(f"βœ… Data loaded: {len(val_dataset)} val, {len(test_dataset)} test samples")
    
    # Initialize calibration framework
    print("\n" + "=" * 80)
    print("🌑️  PHASE 1: CALIBRATION")
    print("=" * 80)
    
    calibrator = CalibrationFramework(model, config.device)
    
    # Calculate pre-calibration metrics
    print("\nπŸ“Š Pre-calibration metrics:")
    ece_before = calibrator.calculate_ece(test_loader)
    mce_before = calibrator.calculate_mce(test_loader)
    
    # Apply temperature scaling
    print("\nπŸ”§ Calibrating model...")
    optimal_temp = calibrator.temperature_scaling(val_loader)
    
    # Calculate post-calibration metrics
    print("\nπŸ“Š Post-calibration metrics:")
    ece_after = calibrator.calculate_ece(test_loader)
    mce_after = calibrator.calculate_mce(test_loader)
    
    # Save calibration results
    print("\n" + "=" * 80)
    print("πŸ’Ύ SAVING RESULTS")
    print("=" * 80)
    
    calibration_results = {
        'calibration_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'optimal_temperature': optimal_temp,
        'metrics': {
            'pre_calibration': {
                'ece': float(ece_before),
                'mce': float(mce_before)
            },
            'post_calibration': {
                'ece': float(ece_after),
                'mce': float(mce_after)
            },
            'improvement': {
                'ece': float(ece_before - ece_after),
                'mce': float(mce_before - mce_after)
            }
        }
    }
    
    results_path = os.path.join(config.checkpoint_dir, 'calibration_results.json')
    with open(results_path, 'w') as f:
        json.dump(calibration_results, f, indent=2)
    
    print(f"βœ… Results saved to: {results_path}")
    
    # Save calibrated model
    calibrated_model_path = os.path.join(config.model_save_path, 'calibrated_model.pt')
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
        'discovered_patterns': checkpoint['discovered_patterns'],
        'temperature': optimal_temp,
        'calibration_results': calibration_results
    }, calibrated_model_path)
    
    print(f"βœ… Calibrated model saved to: {calibrated_model_path}")
    
    # Summary
    print("\n" + "=" * 80)
    print("βœ… CALIBRATION COMPLETE!")
    print("=" * 80)
    
    print(f"\n🎯 Calibration Results:")
    print(f"  Optimal Temperature: {optimal_temp:.4f}")
    print(f"\n  ECE Improvement: {ece_before:.4f} β†’ {ece_after:.4f} (Ξ” {ece_before - ece_after:.4f})")
    print(f"  MCE Improvement: {mce_before:.4f} β†’ {mce_after:.4f} (Ξ” {mce_before - mce_after:.4f})")
    
    if ece_after < 0.08:
        print(f"\n  βœ… Target ECE (<0.08) achieved!")
    else:
        print(f"\n  ⚠️  ECE slightly above target (0.08)")
    
    print(f"\n🎯 Next Steps:")
    print(f"  1. Analyze calibration quality across risk categories")
    print(f"  2. Compare with baseline methods")
    print(f"  3. Generate final implementation report")
    
    return calibrator, calibration_results

if __name__ == "__main__":
    calibrator, results = main()