File size: 11,506 Bytes
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""
Advanced Analysis Script for Legal-BERT
Demonstrates attention analysis, hierarchical risk modeling, and risk dependencies

This script showcases the newly implemented features:
1. Attention mechanism analysis for clause importance
2. Hierarchical risk aggregation (clause → contract level)
3. Risk dependency and interaction analysis
"""
import torch
import json
from typing import Dict, List, Any
import numpy as np

from config import LegalBertConfig
from model import HierarchicalLegalBERT, LegalBertTokenizer
from evaluator import LegalBertEvaluator
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer
from risk_discovery import UnsupervisedRiskDiscovery


def load_trained_model(model_path: str, config: LegalBertConfig):
    """Restore a trained HierarchicalLegalBERT from a checkpoint file.

    The checkpoint is read with ``torch.load`` onto ``config.device``. The
    number of risk heads is the size of the checkpoint's
    ``'discovered_patterns'`` entry (0 when that entry is absent), and the
    weights come from its ``'model_state_dict'`` entry. The model is moved to
    ``config.device`` and switched to eval mode before being returned.

    Args:
        model_path: Path of the saved checkpoint.
        config: Configuration supplying the device and the hierarchical
            model hyper-parameters.

    Returns:
        The loaded model, or ``None`` when a ``FileNotFoundError`` is raised
        while loading (e.g. the checkpoint file does not exist).

    NOTE(review): depending on the installed torch version's ``weights_only``
    default, ``torch.load`` may perform full pickle deserialization — only
    load checkpoints from trusted sources.
    """
    print(f"πŸ“‚ Loading model from {model_path}...")

    try:
        saved = torch.load(model_path, map_location=config.device)
        risk_count = len(saved.get('discovered_patterns', {}))

        print("πŸ“Š Loading Hierarchical BERT model")
        network = HierarchicalLegalBERT(
            config,
            num_discovered_risks=risk_count,
            hidden_dim=config.hierarchical_hidden_dim,
            num_lstm_layers=config.hierarchical_num_lstm_layers,
        )
        network.load_state_dict(saved['model_state_dict'])
        network.to(config.device)
        network.eval()
    except FileNotFoundError:
        print("⚠️ Model file not found. Please train the model first.")
        return None

    print("βœ… Model loaded successfully")
    return network


def demo_attention_analysis(model, tokenizer, sample_clauses: List[str]):
    """Print attention-based token importance and predictions for sample clauses.

    At most the first three entries of ``sample_clauses`` are processed. Each
    clause is tokenized on its own, inspected with ``model.analyze_attention``
    and scored with ``model.predict_risk_pattern``. The first batch element of
    every prediction field is printed. When the attention analysis contains a
    ``'top_tokens'`` entry, up to five token/score pairs are listed too.

    Args:
        model: Loaded model exposing ``config.device``, ``analyze_attention``
            and ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``; it is also passed
            through to ``analyze_attention``.
        sample_clauses: Clause texts to analyse.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("πŸ” ATTENTION MECHANISM ANALYSIS")
    print(rule)

    for number, text in enumerate(sample_clauses[:3], start=1):
        print(f"\nπŸ“„ Analyzing Clause {number}:")
        # Long clauses are shown truncated to their first 100 characters.
        preview = f"{text[:100]}..." if len(text) > 100 else text
        print(f"Text: {preview}")

        encoded = tokenizer.tokenize_clauses([text])
        ids = encoded['input_ids'].to(model.config.device)
        mask = encoded['attention_mask'].to(model.config.device)

        attention_info = model.analyze_attention(ids, mask, tokenizer)
        result = model.predict_risk_pattern(ids, mask)

        print(f"\n  Predicted Risk ID: {result['predicted_risk_id'][0]}")
        print(f"  Severity: {result['severity_score'][0]:.2f}/10")
        print(f"  Importance: {result['importance_score'][0]:.2f}/10")
        print(f"  Confidence: {result['confidence'][0]:.2%}")

        if 'top_tokens' not in attention_info:
            continue

        print("\n  🎯 Most Important Tokens:")
        # NOTE(review): tokens are sliced directly while scores are taken from
        # row [0]; confirm both entries share the same (batched or not) layout.
        leading_tokens = attention_info['top_tokens'][:5]
        leading_scores = attention_info['top_token_scores'][0][:5]
        for token_text, weight in zip(leading_tokens, leading_scores):
            print(f"    {token_text}: {weight:.4f}")

    print("\nβœ… Attention analysis complete")


def demo_hierarchical_risk(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Print clause-to-contract risk aggregation for each sample contract.

    For every contract, each clause is tokenized and scored on its own with
    ``model.predict_risk_pattern``, with the model in eval mode and gradients
    disabled. The per-clause results are combined by
    ``HierarchicalRiskAggregator.aggregate_contract_risk`` using the
    ``'weighted_mean'`` method. A contract-level summary is printed, followed
    by the aggregator's textual report for that contract.

    Args:
        model: Loaded model exposing ``config.device`` and
            ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``.
        contract_clauses: Mapping of contract name to its list of clause texts.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("πŸ“Š HIERARCHICAL RISK AGGREGATION (Clause β†’ Contract)")
    print(separator)

    aggregator = HierarchicalRiskAggregator()

    def _score_clause(text):
        # Convert the first batch element of each prediction field into a
        # plain Python number so the aggregator receives simple dicts.
        encoded = tokenizer.tokenize_clauses([text])
        ids = encoded['input_ids'].to(model.config.device)
        mask = encoded['attention_mask'].to(model.config.device)
        raw = model.predict_risk_pattern(ids, mask)
        return {
            'predicted_risk_id': int(raw['predicted_risk_id'][0]),
            'confidence': float(raw['confidence'][0]),
            'severity_score': float(raw['severity_score'][0]),
            'importance_score': float(raw['importance_score'][0]),
        }

    for name, texts in contract_clauses.items():
        print(f"\nπŸ“‹ Analyzing Contract: {name}")
        print(f"   Number of clauses: {len(texts)}")

        model.eval()
        with torch.no_grad():
            per_clause = [_score_clause(text) for text in texts]

        summary = aggregator.aggregate_contract_risk(per_clause, method='weighted_mean')

        print("\n   Contract-Level Assessment:")
        print(f"   β”œβ”€ Risk Category: {summary['contract_risk_id']}")
        print(f"   β”œβ”€ Overall Severity: {summary['contract_severity']:.2f}/10")
        print(f"   β”œβ”€ Overall Importance: {summary['contract_importance']:.2f}/10")
        print(f"   β”œβ”€ Confidence: {summary['contract_confidence']:.2%}")
        print(f"   └─ High-Risk Clauses: {len(summary['high_risk_clauses'])}")

        print(aggregator.generate_contract_report(per_clause, name))

    print("\nβœ… Hierarchical risk analysis complete")


def demo_risk_dependencies(model, tokenizer, contract_clauses: Dict[str, List[str]],
                           num_risk_types: int = 7):
    """Run and print the risk dependency / interaction analysis.

    Every clause of every contract is classified individually with
    ``model.predict_risk_pattern``, with the model in eval mode and gradients
    disabled. The per-contract prediction lists are then passed to
    ``RiskDependencyAnalyzer`` to print:

    1. a risk-type correlation matrix (``compute_risk_correlation``),
    2. per-risk severity statistics from ``analyze_risk_amplification``,
       sorted by average severity (highest first),
    3. the five most frequent risk chains (``find_risk_chains`` with a
       window of 3).

    Args:
        model: Loaded model exposing ``config.device`` and
            ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``.
        contract_clauses: Mapping of contract name to its clause texts.
        num_risk_types: Requested dimension of the correlation matrix.
            Defaults to 7, the value previously hard-coded here. If any
            predicted risk ID is ``>= num_risk_types``, the dimension is
            widened to ``max_id + 1``. The model's head count comes from the
            checkpoint and is not fixed at 7.
    """
    from collections import Counter

    print("\n" + "="*80)
    print("πŸ”— RISK DEPENDENCY & INTERACTION ANALYSIS")
    print("="*80)

    dependency_analyzer = RiskDependencyAnalyzer()

    # One list of clause-level prediction dicts per contract (name unused).
    all_contract_predictions = []

    model.eval()
    with torch.no_grad():
        for clauses in contract_clauses.values():
            clause_predictions = []

            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                input_ids = tokens['input_ids'].to(model.config.device)
                attention_mask = tokens['attention_mask'].to(model.config.device)

                pred = model.predict_risk_pattern(input_ids, attention_mask)

                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0])
                })

            all_contract_predictions.append(clause_predictions)

    # Flattened view of every clause prediction across all contracts.
    all_clauses = [pred for contract in all_contract_predictions for pred in contract]

    # NOTE(review): compute_risk_correlation presumably indexes its matrix by
    # predicted_risk_id, so the dimension must cover every observed ID.
    observed_ids = [pred['predicted_risk_id'] for pred in all_clauses]
    if observed_ids and max(observed_ids) >= num_risk_types:
        widened = max(observed_ids) + 1
        print(f"\n  Note: predicted risk IDs exceed {num_risk_types} risk types; "
              f"using {widened} instead.")
        num_risk_types = widened

    # Compute risk correlation
    print("\nπŸ“ˆ Computing risk correlation matrix...")
    correlation = dependency_analyzer.compute_risk_correlation(
        all_contract_predictions,
        num_risk_types=num_risk_types
    )

    print(f"\n  Risk Type Correlation Matrix ({num_risk_types}x{num_risk_types}):")
    print("  " + "-"*50)
    for i, row in enumerate(correlation):
        print(f"  Risk {i}: " + " ".join([f"{val:6.3f}" for val in row]))

    # Analyze risk amplification
    print("\n⚑ Analyzing risk amplification effects...")
    amplification = dependency_analyzer.analyze_risk_amplification(all_clauses)

    print("\n  Risk Amplification Analysis:")
    for risk_id, stats in sorted(amplification.items(),
                                 key=lambda x: x[1]['avg_severity'],
                                 reverse=True):
        print(f"  Risk {risk_id}:")
        print(f"    β”œβ”€ Avg Severity: {stats['avg_severity']:.2f}")
        print(f"    β”œβ”€ Max Severity: {stats['max_severity']:.2f}")
        print(f"    β”œβ”€ Clause Count: {stats['clause_count']}")
        print(f"    └─ Severity Variance: {stats['severity_variance']:.2f}")

    # Find risk chains: count identical chains (as tuples) across all contracts.
    print("\nπŸ”— Identifying common risk chains...")
    chain_counts = Counter(
        tuple(chain)
        for clause_preds in all_contract_predictions
        for chain in dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
    )
    most_common = chain_counts.most_common(5)

    print("\n  Top 5 Most Common Risk Chains:")
    if not most_common:
        print("    (no risk chains found)")
    for chain, count in most_common:
        print(f"    {list(chain)} β†’ appeared {count} times")

    print("\nβœ… Risk dependency analysis complete")


def main():
    """Run the three advanced-analysis demonstrations end to end.

    Loads ``best_model.pt`` from ``config.model_save_path``. If loading
    returns ``None``, a hint to run ``train.py`` is printed and the function
    returns. Otherwise the tokenizer is built and the attention,
    hierarchical-risk and risk-dependency demos run on built-in sample
    clauses and contracts.

    Any exception raised by a demo is reported with its traceback and the
    function returns. The "COMPLETE" banner is printed only when all demos
    finished; previously it was printed even after a failure.
    """
    print("="*80)
    print("πŸ›οΈ  LEGAL-BERT ADVANCED ANALYSIS DEMONSTRATION")
    print("="*80)

    # Initialize configuration
    config = LegalBertConfig()

    # Load model
    model_path = f"{config.model_save_path}/best_model.pt"
    model = load_trained_model(model_path, config)

    if model is None:
        print("\n⚠️ Cannot proceed without trained model.")
        print("   Please run 'python train.py' first to train the model.")
        return

    # Initialize tokenizer
    tokenizer = LegalBertTokenizer(config.bert_model_name)

    # Sample clauses for demonstration
    sample_clauses = [
        "The Company shall indemnify and hold harmless the Customer from any claims, damages, or liabilities arising from breach of this Agreement.",
        "Either party may terminate this Agreement upon thirty (30) days written notice to the other party.",
        "All intellectual property rights in the deliverables shall remain the exclusive property of the Company.",
        "The Customer agrees to pay the Company a monthly fee of $10,000 for the services provided under this Agreement."
    ]

    # Sample contracts (multiple clauses per contract)
    contract_clauses = {
        "Service_Agreement_001": [
            "The Service Provider agrees to provide software development services as specified in Exhibit A.",
            "Payment shall be made within 30 days of invoice receipt.",
            "The Service Provider shall indemnify Client against all third-party claims arising from the services.",
            "This Agreement may be terminated by either party with 60 days notice."
        ],
        "License_Agreement_002": [
            "Licensor grants Licensee a non-exclusive, worldwide license to use the Software.",
            "Licensee shall pay annual license fees of $50,000.",
            "All intellectual property rights remain with Licensor.",
            "Confidential information must be kept confidential for 5 years."
        ]
    }

    # Run demonstrations
    try:
        # 1. Attention Analysis
        demo_attention_analysis(model, tokenizer, sample_clauses)

        # 2. Hierarchical Risk Modeling
        demo_hierarchical_risk(model, tokenizer, contract_clauses)

        # 3. Risk Dependencies
        demo_risk_dependencies(model, tokenizer, contract_clauses)

    except Exception as e:
        # Top-level script boundary: report the failure with its traceback and
        # stop, rather than falling through to the success banner below.
        print(f"\n❌ Error during analysis: {e}")
        import traceback
        traceback.print_exc()
        return

    print("\n" + "="*80)
    print("πŸŽ‰ ADVANCED ANALYSIS DEMONSTRATION COMPLETE")
    print("="*80)
    print("\nThese features are now integrated into the evaluation pipeline.")
    print("Use them during training evaluation or post-training analysis.")


# Run the demonstration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()