File size: 11,210 Bytes
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21613a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""
Inference Script for Legal-BERT Risk Analysis
Run trained model on new legal clauses
"""

import argparse
import json
from typing import Any, Dict, List, Tuple

import torch

from model import HierarchicalLegalBERT, LegalBertTokenizer
from config import LegalBertConfig


def load_trained_model(
    checkpoint_path: str,
    config: LegalBertConfig
) -> Tuple[HierarchicalLegalBERT, Dict[str, Any]]:
    """Load a trained model and its discovered risk patterns from a checkpoint.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.
        config: Runtime configuration. Supplies the target device and, when
            the checkpoint carries no ``'config'`` entry, the architecture
            parameters (hidden size and number of LSTM layers).

    Returns:
        A ``(model, discovered_patterns)`` tuple. The model has its weights
        loaded, is moved to ``config.device`` and is in eval mode.
        ``discovered_patterns`` is the checkpoint's ``'discovered_patterns'``
        entry, or an empty dict when that entry is missing or ``None``.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"πŸ“₯ Loading model from: {checkpoint_path}")

    # PyTorch 2.6+ requires weights_only=False for custom classes.
    # SECURITY: weights_only=False unpickles arbitrary objects, so only load
    # checkpoints that were produced by a trusted training run.
    checkpoint = torch.load(checkpoint_path, map_location=config.device, weights_only=False)

    # Fail early with a clear message instead of after building the model.
    if 'model_state_dict' not in checkpoint:
        raise KeyError(
            f"Checkpoint '{checkpoint_path}' has no 'model_state_dict' entry"
        )

    # Look the patterns up once; they size the classifier head and are returned.
    discovered_patterns = checkpoint.get('discovered_patterns') or {}
    num_risks = len(discovered_patterns)
    print(f"   Model has {num_risks} discovered risk patterns")

    # Prefer the architecture parameters stored with the checkpoint so the
    # rebuilt model matches the shapes of the saved weights.
    if 'config' in checkpoint:
        saved_config = checkpoint['config']
        hidden_dim = saved_config.hierarchical_hidden_dim
        num_lstm_layers = saved_config.hierarchical_num_lstm_layers
        print(f"   Using saved architecture: hidden_dim={hidden_dim}, lstm_layers={num_lstm_layers}")
    else:
        # Fallback to the current config (backward compatibility with
        # checkpoints saved without a config).
        hidden_dim = config.hierarchical_hidden_dim
        num_lstm_layers = config.hierarchical_num_lstm_layers
        print("   ⚠️  Warning: No config in checkpoint, using current config")

    model = HierarchicalLegalBERT(
        config=config,
        num_discovered_risks=num_risks,
        hidden_dim=hidden_dim,
        num_lstm_layers=num_lstm_layers
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.device)
    model.eval()

    print("   βœ… Model loaded successfully")

    return model, discovered_patterns


def predict_single_clause(
    model: HierarchicalLegalBERT,
    tokenizer: LegalBertTokenizer,
    clause: str,
    config: LegalBertConfig
) -> Dict[str, Any]:
    """Classify one clause in isolation and return plain-Python results.

    The clause is tokenized as a one-element batch, moved to
    ``config.device`` and passed through ``model.forward_single_clause``
    with gradients disabled.

    Returns:
        A dict with the original ``clause`` text, the arg-max class of the
        softmaxed ``calibrated_logits`` (``predicted_risk_id``), its
        probability (``confidence``), the full probability tensor converted
        to nested lists (``risk_probabilities``), and the model's
        ``severity_score`` and ``importance_score`` as Python scalars.
    """
    batch = tokenizer.tokenize_clauses([clause], config.max_sequence_length)
    ids = batch['input_ids'].to(config.device)
    mask = batch['attention_mask'].to(config.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model.forward_single_clause(ids, mask)

        probabilities = torch.softmax(outputs['calibrated_logits'], dim=-1)
        best_class = torch.argmax(probabilities, dim=-1)
        best_prob = torch.max(probabilities, dim=-1)[0]

        # .item() needs single-element tensors, which holds for a batch of one.
        result = {
            'clause': clause,
            'predicted_risk_id': best_class.cpu().item(),
            'confidence': best_prob.cpu().item(),
            'risk_probabilities': probabilities.cpu().numpy().tolist(),
            'severity_score': outputs['severity_score'].cpu().item(),
            'importance_score': outputs['importance_score'].cpu().item()
        }

    return result


def predict_document(
    model: HierarchicalLegalBERT,
    tokenizer: LegalBertTokenizer,
    document: List[List[str]],
    config: LegalBertConfig
) -> Dict[str, Any]:
    """
    Predict risks for a full document with context

    Every clause is tokenized on its own, the token tensors are grouped by
    section and handed to ``model.predict_document``. The clause text is then
    attached to each entry of the returned ``'clauses'`` list.

    Args:
        model: Trained model providing ``predict_document``.
        tokenizer: Tokenizer providing ``tokenize_clauses``.
        document: List of sections, each containing list of clauses
            Example: [
                ['clause1', 'clause2'],  # Section 1
                ['clause3', 'clause4'],  # Section 2
            ]
        config: Supplies ``max_sequence_length`` and the target ``device``.

    Returns:
        The dict returned by ``model.predict_document``, with a ``'text'``
        key added to each item of ``results['clauses']``.
    """

    print(f"πŸ“„ Analyzing document with {len(document)} sections...")

    device = config.device
    doc_structure = []
    clause_texts = []  # flat, in document order

    for section in document:
        section_tokens = []
        for clause in section:
            encoded = tokenizer.tokenize_clauses([clause], config.max_sequence_length)
            # Place the tensors on the configured device, consistent with
            # predict_single_clause; .to() is a no-op when already there.
            section_tokens.append({
                'input_ids': encoded['input_ids'][0].to(device),
                'attention_mask': encoded['attention_mask'][0].to(device)
            })
            clause_texts.append(clause)
        doc_structure.append(section_tokens)

    # Predict with context
    results = model.predict_document(doc_structure)

    # Merge predictions with clause texts.
    # NOTE(review): assumes results['clauses'] is flat and in the same
    # order as the clauses were tokenized - confirm against the model.
    for pred, text in zip(results['clauses'], clause_texts):
        pred['text'] = text

    return results


def format_prediction_output(
    prediction: Dict[str, Any],
    risk_patterns: Dict[str, Any]
) -> str:
    """Format one clause prediction as a human-readable report.

    Args:
        prediction: Prediction dict. Must contain ``predicted_risk_id``,
            ``confidence``, ``severity_score``, ``importance_score`` and
            ``risk_probabilities`` (a flat list, or a list wrapping one
            list). The clause is read from ``text``, then ``clause``.
        risk_patterns: Discovered patterns. The position of a key in this
            dict is treated as the risk id; keys may be of any type.

    Returns:
        A multi-line string with the clause, the predicted pattern, its
        keywords, the scores and the three most probable patterns.

    Raises:
        KeyError: If a required key is missing from ``prediction``.
    """
    pattern_names = list(risk_patterns.keys())
    risk_id = prediction['predicted_risk_id']

    keywords = "N/A"
    # Reject negative ids explicitly; they would otherwise index from the end
    # of the list and report an unrelated pattern.
    if 0 <= risk_id < len(pattern_names):
        pattern_key = pattern_names[risk_id]
        risk_name = str(pattern_key)
        risk_info = risk_patterns[pattern_key]

        if isinstance(risk_info, dict):
            raw_keywords = risk_info.get('keywords') or risk_info.get('top_words') or []
            # str() each entry so non-string keywords cannot break the join.
            top_keywords = [str(word) for word in list(raw_keywords)[:5]]
            if top_keywords:
                keywords = ', '.join(top_keywords)
    else:
        risk_name = f"Risk Pattern {risk_id}"

    header = f"""
{'='*70}
πŸ“‹ CLAUSE ANALYSIS
{'='*70}

πŸ“ Clause:
   {prediction.get('text', prediction.get('clause', 'N/A'))}

🎯 Risk Classification:
   Pattern: {risk_name}
   Confidence: {prediction['confidence']:.1%}
   Keywords: {keywords}

πŸ“Š Risk Scores:
   Severity:   {prediction['severity_score']:.2f}/10
   Importance: {prediction['importance_score']:.2f}/10

πŸ” Probability Distribution:
"""

    probs = prediction['risk_probabilities']

    # Handle nested list structure (e.g., [[prob1, prob2, ...]])
    if isinstance(probs, list) and len(probs) > 0 and isinstance(probs[0], list):
        probs = probs[0]

    # Indices of the three highest probabilities, highest first.
    top_3_indices = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:3]

    rows = []
    for idx in top_3_indices:
        if idx < len(pattern_names):
            # Convert pattern name to string and truncate if needed
            label = str(pattern_names[idx])
            if len(label) > 40:
                label = label[:37] + "..."
        else:
            label = f"Risk Pattern {idx:2d}"
        # One shared 40-column label field keeps both row kinds aligned.
        rows.append(f"   {label:40s} {probs[idx]:.1%}\n")

    return header + "".join(rows)


def main():
    """Command-line entry point for risk-analysis inference.

    Parses the arguments, loads the model, tokenizer and discovered risk
    patterns, then runs exactly one mode:

    * ``--clause``: analyze a single clause.
    * ``--document``: analyze a JSON document, either
      ``{"sections": [["clause1", "clause2"], ["clause3"]]}`` or the bare
      list of sections.
    * neither: analyze a built-in set of demo clauses and print usage tips.

    ``--clause`` takes precedence when both are supplied. When ``--output``
    is given, the results are written to that path as JSON.
    """
    parser = argparse.ArgumentParser(description='Legal-BERT Risk Analysis Inference')
    parser.add_argument('--checkpoint', type=str, default='models/legal_bert/final_model.pt',
                       help='Path to model checkpoint')
    parser.add_argument('--clause', type=str, help='Single clause to analyze')
    parser.add_argument('--document', type=str, help='Path to JSON file with document structure')
    parser.add_argument('--output', type=str, help='Path to save results (JSON)')
    args = parser.parse_args()

    rule = "=" * 70

    print(rule)
    print("πŸ›οΈ  LEGAL-BERT RISK ANALYSIS INFERENCE")
    print(rule)

    # Initialize config
    config = LegalBertConfig()
    print("\nπŸ“‹ Configuration:")
    print(f"   Device: {config.device}")
    print(f"   Max sequence length: {config.max_sequence_length}")

    # Load model
    model, risk_patterns = load_trained_model(args.checkpoint, config)
    tokenizer = LegalBertTokenizer(config.bert_model_name)

    print(f"\nπŸ” Discovered Risk Patterns ({len(risk_patterns)}):")
    pattern_names = list(risk_patterns.keys())
    for name in pattern_names[:5]:
        # Pattern names may be non-strings; convert for display.
        display_name = str(name)
        print(f"   β€’ {display_name}")
    if len(risk_patterns) > 5:
        print(f"   ... and {len(risk_patterns) - 5} more")

    results = []

    # Single clause mode
    if args.clause:
        print("\n" + rule)
        print("MODE: Single Clause Analysis")
        print(rule)

        prediction = predict_single_clause(model, tokenizer, args.clause, config)
        print(format_prediction_output(prediction, risk_patterns))
        results.append(prediction)

    # Document mode
    elif args.document:
        print("\n" + rule)
        print("MODE: Full Document Analysis (with context)")
        print(rule)

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(args.document, 'r', encoding='utf-8') as f:
            doc_data = json.load(f)

        # Expected format: {"sections": [["clause1", "clause2"], ["clause3"]]}
        # A bare top-level list of sections is accepted as well.
        if isinstance(doc_data, dict):
            document = doc_data.get('sections', [])
        else:
            document = doc_data

        prediction = predict_document(model, tokenizer, document, config)
        summary = prediction['summary']

        print("\nπŸ“Š Document Summary:")
        print(f"   Sections: {summary['num_sections']}")
        print(f"   Clauses: {summary['num_clauses']}")
        print(f"   Average Severity: {summary['avg_severity']:.2f}/10")
        print(f"   High Risk Clauses: {summary['high_risk_count']}")

        print("\nπŸ“‹ Clause-by-Clause Analysis:")
        for clause_pred in prediction['clauses']:
            print(format_prediction_output(clause_pred, risk_patterns))

        # In document mode the saved result is the whole prediction dict,
        # not a list of per-clause predictions.
        results = prediction

    # Demo mode (no arguments)
    else:
        print("\n" + rule)
        print("MODE: Demo Analysis")
        print(rule)
        print("\nπŸ’‘ Running demo with sample clauses...")

        demo_clauses = [
            "The party shall indemnify and hold harmless all damages and losses.",
            "This agreement shall be governed by the laws of the state of California.",
            "Payment must be made within thirty days of invoice date.",
            "The licensee must not disclose confidential information to third parties.",
            "Company shall comply with all applicable laws and regulations."
        ]

        for clause in demo_clauses:
            prediction = predict_single_clause(model, tokenizer, clause, config)
            print(format_prediction_output(prediction, risk_patterns))
            results.append(prediction)

    # Save results if output path provided
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"\nπŸ’Ύ Results saved to: {args.output}")

    print("\n" + rule)
    print("βœ… INFERENCE COMPLETE")
    print(rule)

    # Usage tips (demo mode only)
    if not args.clause and not args.document:
        print("\nπŸ’‘ Usage Examples:")
        print('\n   Single clause:')
        print('   python3 inference.py --clause "The party shall indemnify..."')
        print('\n   Full document:')
        print('   python3 inference.py --document contract.json')
        print('\n   Save results:')
        print('   python3 inference.py --clause "..." --output results.json')


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()