File size: 20,229 Bytes
9b1c753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""
Hierarchical Risk Modeling: Clause-to-Contract Level Aggregation

This module implements hierarchical risk assessment that aggregates clause-level
predictions to contract-level risk scores and insights.
"""
import numpy as np
import torch
from typing import Dict, List, Any, Tuple
from collections import defaultdict
import warnings


class HierarchicalRiskAggregator:
    """
    Aggregates clause-level risk predictions to contract-level risk assessment.

    Supported aggregation strategies (the ``method`` argument of
    ``aggregate_contract_risk``):
    - 'max': maximum risk (worst-case scenario)
    - 'mean': average risk (overall risk profile)
    - 'weighted_mean': importance-weighted average
    - 'severity_weighted': severity-weighted average
    - 'risk_distribution': risk distribution analysis

    Every clause prediction is a dict that must provide the keys
    ``predicted_risk_id``, ``confidence``, ``severity_score`` and
    ``importance_score``.
    """

    def __init__(self):
        """Initialize the aggregator and register the available strategies."""
        # Dispatch table: method name -> bound aggregation function.
        self.aggregation_methods = {
            'max': self._aggregate_max,
            'mean': self._aggregate_mean,
            'weighted_mean': self._aggregate_weighted,
            'risk_distribution': self._aggregate_distribution,
            'severity_weighted': self._aggregate_severity_weighted
        }

    def aggregate_contract_risk(self,
                               clause_predictions: List[Dict[str, Any]],
                               method: str = 'weighted_mean',
                               high_risk_threshold: float = 7.0) -> Dict[str, Any]:
        """
        Aggregate clause-level predictions to contract-level risk assessment.

        Args:
            clause_predictions: List of dictionaries containing clause-level
                predictions. Each dict must have the keys: predicted_risk_id,
                confidence, severity_score, importance_score.
            method: Aggregation method ('max', 'mean', 'weighted_mean',
                'severity_weighted', 'risk_distribution'). An unknown name
                triggers a warning and falls back to 'weighted_mean'.
            high_risk_threshold: Clauses whose severity_score is at or above
                this value are listed under 'high_risk_clauses'.

        Returns:
            Dictionary containing the contract-level risk assessment
            (contract_risk_id, contract_severity, contract_importance,
            contract_confidence, aggregation_method, rationale) plus
            num_clauses, clause_statistics, risk_distribution and
            high_risk_clauses. For an empty input the result is
            ``{'error': 'No clause predictions provided'}``.

        Raises:
            KeyError: If a clause prediction lacks one of the required keys.
        """
        if not clause_predictions:
            return {'error': 'No clause predictions provided'}

        if method not in self.aggregation_methods:
            warnings.warn(f"Unknown method {method}, using 'weighted_mean'")
            method = 'weighted_mean'

        # Column-wise views of the clause predictions.
        risk_ids = np.array([p['predicted_risk_id'] for p in clause_predictions])
        confidences = np.array([p['confidence'] for p in clause_predictions])
        severities = np.array([p['severity_score'] for p in clause_predictions])
        importances = np.array([p['importance_score'] for p in clause_predictions])

        # Apply the selected aggregation strategy.
        contract_risk = self.aggregation_methods[method](
            risk_ids, confidences, severities, importances
        )

        # Statistics that are reported regardless of the strategy.
        contract_risk.update({
            'num_clauses': len(clause_predictions),
            'clause_statistics': self._compute_clause_statistics(
                risk_ids, confidences, severities, importances
            ),
            'risk_distribution': self._compute_risk_distribution(risk_ids, severities),
            'high_risk_clauses': self._identify_high_risk_clauses(
                clause_predictions, threshold=high_risk_threshold
            )
        })

        return contract_risk

    def _aggregate_max(self, risk_ids, confidences, severities, importances) -> Dict[str, Any]:
        """Maximum risk aggregation: report the single most severe clause."""
        # np.argmax returns the first index on ties.
        max_severity_idx = np.argmax(severities)

        return {
            'contract_risk_id': int(risk_ids[max_severity_idx]),
            'contract_severity': float(severities[max_severity_idx]),
            'contract_importance': float(importances[max_severity_idx]),
            'contract_confidence': float(confidences[max_severity_idx]),
            'aggregation_method': 'max',
            'rationale': 'Based on highest severity clause'
        }

    def _aggregate_mean(self, risk_ids, confidences, severities, importances) -> Dict[str, Any]:
        """Simple mean aggregation; the risk type is the most frequent one."""
        unique_risks, counts = np.unique(risk_ids, return_counts=True)
        dominant_risk = unique_risks[np.argmax(counts)]

        return {
            'contract_risk_id': int(dominant_risk),
            'contract_severity': float(np.mean(severities)),
            'contract_importance': float(np.mean(importances)),
            'contract_confidence': float(np.mean(confidences)),
            'aggregation_method': 'mean',
            'rationale': 'Based on average across all clauses'
        }

    @staticmethod
    def _normalized_weights(values):
        """Scale ``values`` to sum to 1; use uniform weights if the sum is not positive."""
        total = np.sum(values)
        if total > 0:
            return values / total
        return np.ones_like(values) / len(values)

    @staticmethod
    def _weighted_summary(risk_ids, confidences, severities, importances,
                          weights, method_name: str, rationale: str) -> Dict[str, Any]:
        """Build the contract-level result for a given per-clause weight vector."""
        # Accumulate the weight carried by each risk type; the heaviest wins
        # (first one encountered on ties).
        risk_weights = defaultdict(float)
        for risk_id, weight in zip(risk_ids, weights):
            risk_weights[risk_id] += weight
        dominant_risk = max(risk_weights, key=risk_weights.get)

        return {
            'contract_risk_id': int(dominant_risk),
            'contract_severity': float(np.sum(severities * weights)),
            'contract_importance': float(np.sum(importances * weights)),
            'contract_confidence': float(np.sum(confidences * weights)),
            'aggregation_method': method_name,
            'rationale': rationale
        }

    def _aggregate_weighted(self, risk_ids, confidences, severities, importances) -> Dict[str, Any]:
        """Importance-weighted aggregation."""
        return self._weighted_summary(
            risk_ids, confidences, severities, importances,
            weights=self._normalized_weights(importances),
            method_name='weighted_mean',
            rationale='Weighted by clause importance scores'
        )

    def _aggregate_severity_weighted(self, risk_ids, confidences, severities, importances) -> Dict[str, Any]:
        """Severity-weighted aggregation (emphasizes high-risk clauses)."""
        return self._weighted_summary(
            risk_ids, confidences, severities, importances,
            weights=self._normalized_weights(severities),
            method_name='severity_weighted',
            rationale='Weighted by clause severity (emphasizes high-risk clauses)'
        )

    def _aggregate_distribution(self, risk_ids, confidences, severities, importances) -> Dict[str, Any]:
        """Risk distribution-based aggregation; also reports risk-type entropy."""
        unique_risks, counts = np.unique(risk_ids, return_counts=True)
        risk_proportions = counts / len(risk_ids)

        # Shannon entropy (natural log) of the risk-type proportions; the
        # small epsilon keeps log() finite.
        entropy = -np.sum(risk_proportions * np.log(risk_proportions + 1e-10))

        # The dominant risk type is the one with the highest mean severity.
        risk_severities = {
            risk_id: np.mean(severities[risk_ids == risk_id])
            for risk_id in unique_risks
        }
        dominant_risk = max(risk_severities, key=risk_severities.get)

        return {
            'contract_risk_id': int(dominant_risk),
            'contract_severity': float(np.mean(severities)),
            'contract_importance': float(np.mean(importances)),
            'contract_confidence': float(np.mean(confidences)),
            'risk_diversity': float(entropy),
            'aggregation_method': 'risk_distribution',
            'rationale': 'Based on risk distribution analysis'
        }

    def _compute_clause_statistics(self, risk_ids, confidences, severities, importances) -> Dict[str, float]:
        """Compute a statistical summary of the clause-level predictions."""
        return {
            'mean_severity': float(np.mean(severities)),
            'std_severity': float(np.std(severities)),
            'max_severity': float(np.max(severities)),
            'min_severity': float(np.min(severities)),
            'mean_importance': float(np.mean(importances)),
            'std_importance': float(np.std(importances)),
            'mean_confidence': float(np.mean(confidences)),
            'std_confidence': float(np.std(confidences))
        }

    def _compute_risk_distribution(self, risk_ids, severities) -> Dict[int, Dict[str, float]]:
        """Compute count, proportion and severity statistics per risk type."""
        risk_dist = {}

        for risk_id in np.unique(risk_ids):
            mask = risk_ids == risk_id
            risk_dist[int(risk_id)] = {
                'count': int(np.sum(mask)),
                # Mean of a boolean mask is the fraction of clauses of this type.
                'proportion': float(np.mean(mask)),
                'avg_severity': float(np.mean(severities[mask])),
                'max_severity': float(np.max(severities[mask]))
            }

        return risk_dist

    def _identify_high_risk_clauses(self, clause_predictions: List[Dict[str, Any]],
                                   threshold: float = 7.0) -> List[Dict[str, Any]]:
        """Return clauses with severity_score >= threshold, most severe first."""
        high_risk = [
            {
                'clause_index': idx,
                'risk_id': pred['predicted_risk_id'],
                'severity': pred['severity_score'],
                'importance': pred['importance_score'],
                'confidence': pred['confidence']
            }
            for idx, pred in enumerate(clause_predictions)
            if pred['severity_score'] >= threshold
        ]

        # Stable sort: clauses with equal severity keep their document order.
        high_risk.sort(key=lambda x: x['severity'], reverse=True)

        return high_risk

    def compare_contracts(self, contract_a_predictions: List[Dict[str, Any]],
                         contract_b_predictions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Compare risk profiles of two contracts.

        Both contracts are aggregated with the 'weighted_mean' strategy.

        Args:
            contract_a_predictions: Clause predictions for contract A
            contract_b_predictions: Clause predictions for contract B

        Returns:
            Comparison analysis including relative risk levels and differences.
            'riskier_contract' is 'Contract A', 'Contract B', or 'Equal' when
            both severities are identical. If either contract has no clause
            predictions, a dict with a single 'error' key is returned instead.
        """
        contract_a = self.aggregate_contract_risk(contract_a_predictions, method='weighted_mean')
        contract_b = self.aggregate_contract_risk(contract_b_predictions, method='weighted_mean')

        # An empty contract yields an error dict with no severity to compare.
        for label, result in (('A', contract_a), ('B', contract_b)):
            if 'error' in result:
                return {'error': f"Contract {label}: {result['error']}"}

        severity_diff = contract_a['contract_severity'] - contract_b['contract_severity']
        importance_diff = contract_a['contract_importance'] - contract_b['contract_importance']

        if severity_diff > 0:
            riskier_contract = 'Contract A'
        elif severity_diff < 0:
            riskier_contract = 'Contract B'
        else:
            # A tie must not be reported as one contract being riskier.
            riskier_contract = 'Equal'

        risk_difference = abs(severity_diff)
        if risk_difference > 2.0:
            risk_magnitude = 'high'
        elif risk_difference > 1.0:
            risk_magnitude = 'moderate'
        else:
            risk_magnitude = 'low'

        def summarize(contract: Dict[str, Any]) -> Dict[str, Any]:
            return {
                'severity': contract['contract_severity'],
                'importance': contract['contract_importance'],
                'num_clauses': contract['num_clauses'],
                'high_risk_clauses': len(contract['high_risk_clauses'])
            }

        return {
            'contract_a': summarize(contract_a),
            'contract_b': summarize(contract_b),
            'comparison': {
                'riskier_contract': riskier_contract,
                'severity_difference': float(severity_diff),
                'importance_difference': float(importance_diff),
                'risk_magnitude': risk_magnitude
            }
        }

    def generate_contract_report(self, clause_predictions: List[Dict[str, Any]],
                                contract_name: str = "Contract") -> str:
        """
        Generate a human-readable report of contract risk assessment.

        The contract is aggregated with the 'weighted_mean' strategy. Symbols
        in the report are written as Unicode escapes so the source stays ASCII
        and cannot be corrupted by a wrong file encoding.

        Args:
            clause_predictions: Clause-level predictions
            contract_name: Name/identifier for the contract

        Returns:
            Formatted text report. For an empty ``clause_predictions`` the
            report contains only the header and the aggregation error message.
        """
        contract_risk = self.aggregate_contract_risk(clause_predictions, method='weighted_mean')

        rule = '=' * 70
        sub_rule = '-' * 70

        parts = [
            f"\n{rule}\n",
            f"CONTRACT RISK ASSESSMENT REPORT: {contract_name}\n",
            f"{rule}\n\n",
        ]

        # Nothing to assess: report the reason instead of failing on missing keys.
        if 'error' in contract_risk:
            parts.append(f"{contract_risk['error']}\n")
            parts.append(f"\n{rule}\n")
            return "".join(parts)

        # U+1F4CA BAR CHART
        parts.append("\U0001F4CA OVERALL ASSESSMENT\n")
        parts.append(f"{sub_rule}\n")
        parts.append(f"Risk Category ID: {contract_risk['contract_risk_id']}\n")
        parts.append(f"Overall Severity: {contract_risk['contract_severity']:.2f}/10.0\n")
        parts.append(f"Overall Importance: {contract_risk['contract_importance']:.2f}/10.0\n")
        parts.append(f"Confidence Level: {contract_risk['contract_confidence']:.2%}\n")
        parts.append(f"Number of Clauses Analyzed: {contract_risk['num_clauses']}\n\n")

        # Severity interpretation (red / orange / yellow / green circles).
        severity = contract_risk['contract_severity']
        if severity >= 8.0:
            risk_level = "\U0001F534 CRITICAL RISK"
        elif severity >= 6.0:
            risk_level = "\U0001F7E0 HIGH RISK"
        elif severity >= 4.0:
            risk_level = "\U0001F7E1 MODERATE RISK"
        else:
            risk_level = "\U0001F7E2 LOW RISK"

        parts.append(f"Risk Level: {risk_level}\n\n")

        # High-risk clauses (already sorted by severity, descending).
        high_risk = contract_risk['high_risk_clauses']
        if high_risk:
            # U+26A0 U+FE0F WARNING SIGN, U+2265 GREATER-THAN OR EQUAL TO
            parts.append("\u26A0\uFE0F  HIGH-RISK CLAUSES (Severity \u2265 7.0)\n")
            parts.append(f"{sub_rule}\n")
            for clause in high_risk[:5]:  # Show top 5
                parts.append(
                    f"Clause {clause['clause_index']}: "
                    f"Severity={clause['severity']:.2f}, "
                    f"Importance={clause['importance']:.2f}, "
                    f"Confidence={clause['confidence']:.2%}\n"
                )
            if len(high_risk) > 5:
                parts.append(f"... and {len(high_risk) - 5} more high-risk clauses\n")
            parts.append("\n")

        # Risk distribution, highest average severity first.
        # U+1F4C8 CHART WITH UPWARDS TREND
        parts.append("\U0001F4C8 RISK DISTRIBUTION\n")
        parts.append(f"{sub_rule}\n")
        risk_dist = contract_risk['risk_distribution']
        for risk_id, stats in sorted(risk_dist.items(), key=lambda x: x[1]['avg_severity'], reverse=True):
            parts.append(
                f"Risk Type {risk_id}: "
                f"{stats['count']} clauses ({stats['proportion']:.1%}), "
                f"Avg Severity={stats['avg_severity']:.2f}\n"
            )

        parts.append(f"\n{rule}\n")

        return "".join(parts)


class RiskDependencyAnalyzer:
    """
    Analyzes dependencies and interactions between different risk types in a contract.

    Provides:
    - Co-occurrence patterns (which risk types appear in the same contract)
    - Risk chains (windows of consecutive clauses mixing risk types)
    - Cross-contract correlation of risk-type presence
    - Per-risk-type severity/importance summaries
    """

    def __init__(self):
        """Initialize the risk dependency analyzer"""
        # Result of the most recent analyze_risk_cooccurrence() call.
        self.cooccurrence_matrix = None

    def analyze_risk_cooccurrence(self, clause_predictions: List[Dict[str, Any]],
                                 num_risk_types: int = 7) -> np.ndarray:
        """
        Analyze co-occurrence of risk types within a contract.

        Entry (i, j) is 1 when both risk type i and risk type j are predicted
        for at least one clause of the contract, and 0 otherwise. Risk ids
        outside ``range(num_risk_types)`` are ignored.

        Args:
            clause_predictions: Clause-level predictions
            num_risk_types: Total number of risk types

        Returns:
            Co-occurrence matrix (num_risk_types x num_risk_types); it is also
            stored on ``self.cooccurrence_matrix``.
        """
        risk_ids = [p['predicted_risk_id'] for p in clause_predictions]

        # Presence indicator per risk type, computed once instead of once per
        # (i, j) pair; the pairwise matrix is its outer product.
        presence = np.array([float(i in risk_ids) for i in range(num_risk_types)])
        cooccur = np.outer(presence, presence)

        self.cooccurrence_matrix = cooccur
        return cooccur

    def find_risk_chains(self, clause_predictions: List[Dict[str, Any]],
                        window_size: int = 3) -> List[List[int]]:
        """
        Identify sequences of related risks (risk chains) in contract clauses.

        A chain is a window of ``window_size`` consecutive clauses that
        contains at least two different risk types.

        Args:
            clause_predictions: Clause-level predictions (ordered by clause position)
            window_size: Size of sliding window to find risk chains

        Returns:
            List of risk chains (sequences of risk IDs); empty when the window
            is smaller than 2 or larger than the number of clauses.
        """
        # A window of fewer than two clauses cannot hold two risk types, and a
        # negative size would otherwise be misread as a negative slice bound.
        if window_size < 2 or len(clause_predictions) < window_size:
            return []

        risk_ids = [p['predicted_risk_id'] for p in clause_predictions]

        windows = (
            risk_ids[start:start + window_size]
            for start in range(len(risk_ids) - window_size + 1)
        )
        # Only keep chains with at least 2 different risk types.
        return [chain for chain in windows if len(set(chain)) >= 2]

    def compute_risk_correlation(self, contract_predictions: List[List[Dict[str, Any]]],
                                num_risk_types: int = 7) -> np.ndarray:
        """
        Compute correlation between risk types across multiple contracts.

        Each contract is reduced to a binary presence vector over the risk
        types, and the Pearson correlation between those presence columns is
        returned.

        Args:
            contract_predictions: List of contract predictions (each is list of clause predictions)
            num_risk_types: Total number of risk types

        Returns:
            Correlation matrix showing how risk types co-occur across
            contracts. Entries involving a risk type that is present in all
            contracts or in none are NaN (zero variance). With fewer than two
            contracts the whole matrix is NaN.

        Raises:
            IndexError: If a predicted risk id is outside ``range(num_risk_types)``.
        """
        num_contracts = len(contract_predictions)

        # Correlation needs at least two observations.
        if num_contracts < 2:
            return np.full((num_risk_types, num_risk_types), np.nan)

        # Binary matrix: contracts x risk_types.
        risk_matrix = np.zeros((num_contracts, num_risk_types))

        for contract_idx, clause_preds in enumerate(contract_predictions):
            for risk_id in {p['predicted_risk_id'] for p in clause_preds}:
                # Reject negative ids explicitly: numpy would silently wrap
                # them around to a different risk type's column.
                if not 0 <= risk_id < num_risk_types:
                    raise IndexError(
                        f"risk id {risk_id} is out of range for "
                        f"{num_risk_types} risk types"
                    )
                risk_matrix[contract_idx, risk_id] = 1

        # Constant columns divide 0 by 0 inside corrcoef; the resulting NaN is
        # the documented outcome, so the floating-point warning is silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            correlation = np.corrcoef(risk_matrix.T)

        return correlation

    def analyze_risk_amplification(self, clause_predictions: List[Dict[str, Any]]) -> Dict[int, Dict[str, float]]:
        """
        Summarize severity and importance per risk type.

        NOTE(review): despite the name, this does not measure how one risk
        type affects another; it only reports per-type statistics.

        Args:
            clause_predictions: Clause-level predictions

        Returns:
            Dictionary mapping each predicted risk_id to avg_severity,
            max_severity, avg_importance, clause_count and severity_variance
            of the clauses carrying that risk type.
        """
        # Group clauses by risk type.
        risk_groups = defaultdict(list)
        for pred in clause_predictions:
            risk_groups[pred['predicted_risk_id']].append(pred)

        amplification = {}

        for risk_id, clauses in risk_groups.items():
            severities = [c['severity_score'] for c in clauses]
            importances = [c['importance_score'] for c in clauses]

            amplification[risk_id] = {
                'avg_severity': float(np.mean(severities)),
                'max_severity': float(np.max(severities)),
                'avg_importance': float(np.mean(importances)),
                'clause_count': len(clauses),
                'severity_variance': float(np.var(severities))
            }

        return amplification