""" Evaluation and Analysis Tools for Legal-BERT """ import torch import numpy as np import json from typing import Dict, List, Any, Tuple from collections import defaultdict # Try to import visualization libraries try: import matplotlib.pyplot as plt import seaborn as sns VISUALIZATION_AVAILABLE = True except ImportError: VISUALIZATION_AVAILABLE = False print("⚠️ Warning: matplotlib/seaborn not available. Visualizations will be skipped.") # Import hierarchical risk analysis try: from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer HIERARCHICAL_AVAILABLE = True except ImportError: HIERARCHICAL_AVAILABLE = False print("⚠️ Warning: hierarchical_risk module not available.") class LegalBertEvaluator: """ Comprehensive evaluation for Legal-BERT with discovered risk patterns """ def __init__(self, model, tokenizer, risk_discovery): self.model = model self.tokenizer = tokenizer self.risk_discovery = risk_discovery self.evaluation_results = {} def evaluate_model(self, test_loader, save_results: bool = True) -> Dict[str, Any]: """Comprehensive model evaluation""" print("🔍 Starting comprehensive evaluation...") # Collect predictions all_predictions = [] all_true_labels = [] all_severity_preds = [] all_severity_true = [] all_importance_preds = [] all_importance_true = [] all_confidences = [] self.model.eval() with torch.no_grad(): for batch in test_loader: device = next(self.model.parameters()).device input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) # Get predictions using the correct method outputs = self.model.forward_single_clause(input_ids, attention_mask) # Calculate predictions and confidences from logits risk_probs = torch.softmax(outputs['calibrated_logits'], dim=-1) predicted_risk_ids = torch.argmax(risk_probs, dim=-1) confidences = torch.max(risk_probs, dim=-1)[0] # Store results all_predictions.extend(predicted_risk_ids.cpu().numpy()) all_true_labels.extend(batch['risk_label'].numpy()) all_severity_preds.extend(outputs['severity_score'].cpu().numpy()) all_severity_true.extend(batch['severity_score'].numpy()) all_importance_preds.extend(outputs['importance_score'].cpu().numpy()) all_importance_true.extend(batch['importance_score'].numpy()) all_confidences.extend(confidences.cpu().numpy()) # Calculate metrics results = { 'classification_metrics': self._calculate_classification_metrics( all_true_labels, all_predictions, all_confidences ), 'regression_metrics': self._calculate_regression_metrics( all_severity_true, all_severity_preds, all_importance_true, all_importance_preds ), 'risk_pattern_analysis': self._analyze_risk_patterns( all_true_labels, all_predictions ) } self.evaluation_results = results if save_results: self.save_evaluation_results(results) print("✅ Evaluation complete!") return results def _calculate_classification_metrics(self, true_labels: List[int], predictions: List[int], confidences: List[float]) -> Dict[str, Any]: """Calculate classification metrics""" from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix accuracy = accuracy_score(true_labels, predictions) precision, recall, f1, support = precision_recall_fscore_support( true_labels, predictions, average='weighted' ) # Per-class metrics precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support( true_labels, predictions, average=None ) # Confusion matrix cm = confusion_matrix(true_labels, predictions) # Confidence analysis avg_confidence = np.mean(confidences) confidence_std = 
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'precision_per_class': precision_per_class.tolist(),
            'recall_per_class': recall_per_class.tolist(),
            'f1_per_class': f1_per_class.tolist(),
            'confusion_matrix': cm.tolist(),
            'avg_confidence': avg_confidence,
            'confidence_std': confidence_std
        }

    def _calculate_regression_metrics(self, severity_true: List[float],
                                      severity_pred: List[float],
                                      importance_true: List[float],
                                      importance_pred: List[float]) -> Dict[str, Any]:
        """Calculate regression metrics"""
        from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

        # Severity metrics
        severity_mse = mean_squared_error(severity_true, severity_pred)
        severity_mae = mean_absolute_error(severity_true, severity_pred)
        severity_r2 = r2_score(severity_true, severity_pred)

        # Importance metrics
        importance_mse = mean_squared_error(importance_true, importance_pred)
        importance_mae = mean_absolute_error(importance_true, importance_pred)
        importance_r2 = r2_score(importance_true, importance_pred)

        return {
            'severity': {
                'mse': severity_mse,
                'mae': severity_mae,
                'r2_score': severity_r2
            },
            'importance': {
                'mse': importance_mse,
                'mae': importance_mae,
                'r2_score': importance_r2
            }
        }

    def _analyze_risk_patterns(self, true_labels: List[int],
                               predictions: List[int]) -> Dict[str, Any]:
        """Analyze discovered risk patterns"""
        discovered_patterns = self.risk_discovery.discovered_patterns
        pattern_names = list(discovered_patterns.keys())

        # Pattern distribution
        true_distribution = defaultdict(int)
        pred_distribution = defaultdict(int)
        for label in true_labels:
            true_distribution[pattern_names[label]] += 1
        for pred in predictions:
            pred_distribution[pattern_names[pred]] += 1

        # Pattern-specific performance
        pattern_performance = {}
        for i, pattern_name in enumerate(pattern_names):
            pattern_true = [1 if label == i else 0 for label in true_labels]
            pattern_pred = [1 if pred == i else 0 for pred in predictions]

            if sum(pattern_true) > 0:  # Avoid division by zero
                true_positives = sum(1 for t, p in zip(pattern_true, pattern_pred) if t == 1 and p == 1)
                precision = true_positives / max(sum(pattern_pred), 1)
                recall = true_positives / sum(pattern_true)
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
                pattern_performance[pattern_name] = {
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1,
                    'support': sum(pattern_true)
                }

        return {
            'true_distribution': dict(true_distribution),
            'predicted_distribution': dict(pred_distribution),
            'pattern_performance': pattern_performance,
            'discovered_patterns_info': discovered_patterns
        }

    def generate_report(self) -> str:
        """Generate comprehensive evaluation report"""
        if not self.evaluation_results:
            raise ValueError("Must run evaluation first")

        results = self.evaluation_results
        report = []
        report.append("=" * 80)
        report.append("🏛️ LEGAL-BERT EVALUATION REPORT")
        report.append("=" * 80)

        # Classification Performance
        report.append("\n📊 RISK CLASSIFICATION PERFORMANCE")
        report.append("-" * 50)
        clf_metrics = results['classification_metrics']
        report.append(f"Accuracy: {clf_metrics['accuracy']:.4f}")
        report.append(f"Precision: {clf_metrics['precision']:.4f}")
        report.append(f"Recall: {clf_metrics['recall']:.4f}")
        report.append(f"F1-Score: {clf_metrics['f1_score']:.4f}")
        report.append(f"Average Confidence: {clf_metrics['avg_confidence']:.4f}")

        # Regression Performance
        report.append("\n📈 REGRESSION PERFORMANCE")
        report.append("-" * 50)
        reg_metrics = results['regression_metrics']
report.append("Severity Prediction:") report.append(f" MSE: {reg_metrics['severity']['mse']:.4f}") report.append(f" MAE: {reg_metrics['severity']['mae']:.4f}") report.append(f" R²: {reg_metrics['severity']['r2_score']:.4f}") report.append("Importance Prediction:") report.append(f" MSE: {reg_metrics['importance']['mse']:.4f}") report.append(f" MAE: {reg_metrics['importance']['mae']:.4f}") report.append(f" R²: {reg_metrics['importance']['r2_score']:.4f}") # Risk Pattern Analysis report.append("\n🔍 DISCOVERED RISK PATTERNS") report.append("-" * 50) pattern_analysis = results['risk_pattern_analysis'] report.append("Pattern Distribution (True vs Predicted):") for pattern, count in pattern_analysis['true_distribution'].items(): pred_count = pattern_analysis['predicted_distribution'].get(pattern, 0) report.append(f" {pattern}: {count} → {pred_count}") report.append("\nPattern-Specific Performance:") for pattern, metrics in pattern_analysis['pattern_performance'].items(): report.append(f" {pattern}:") report.append(f" Precision: {metrics['precision']:.4f}") report.append(f" Recall: {metrics['recall']:.4f}") report.append(f" F1-Score: {metrics['f1_score']:.4f}") report.append(f" Support: {metrics['support']}") # Discovered Patterns Info report.append("\n🎯 DISCOVERED PATTERN DETAILS") report.append("-" * 50) for pattern_name, details in pattern_analysis['discovered_patterns_info'].items(): report.append(f"\n{pattern_name}:") # Handle different pattern structures (LDA vs K-Means) if 'clause_count' in details: report.append(f" Clauses: {details['clause_count']}") if 'avg_risk_intensity' in details: report.append(f" Risk Intensity: {details['avg_risk_intensity']:.3f}") if 'avg_legal_complexity' in details: report.append(f" Legal Complexity: {details['avg_legal_complexity']:.3f}") # Handle both 'key_terms' and 'top_words' (LDA uses top_words) if 'key_terms' in details: report.append(f" Key Terms: {', '.join(details['key_terms'][:5])}") elif 'top_words' in details: report.append(f" Top Words: {', '.join(details['top_words'][:5])}") # Show topic distribution if available (LDA-specific) if 'topic_distribution' in details: report.append(f" Topic Distribution: {details['topic_distribution']:.3f}") report.append("\n" + "=" * 80) return "\n".join(report) def plot_confusion_matrix(self, save_path: str = None): """Plot confusion matrix""" if not VISUALIZATION_AVAILABLE: print("⚠️ Visualization libraries not available. Skipping plot.") return if not self.evaluation_results: raise ValueError("Must run evaluation first") cm = np.array(self.evaluation_results['classification_metrics']['confusion_matrix']) plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') plt.title('Confusion Matrix - Risk Classification') plt.ylabel('True Label') plt.xlabel('Predicted Label') if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"💾 Confusion matrix saved to: {save_path}") else: plt.show() plt.close() def plot_risk_distribution(self, save_path: str = None): """Plot risk pattern distribution""" if not VISUALIZATION_AVAILABLE: print("⚠️ Visualization libraries not available. 
Skipping plot.") return if not self.evaluation_results: raise ValueError("Must run evaluation first") pattern_analysis = self.evaluation_results['risk_pattern_analysis'] patterns = list(pattern_analysis['true_distribution'].keys()) true_counts = [pattern_analysis['true_distribution'][p] for p in patterns] pred_counts = [pattern_analysis['predicted_distribution'].get(p, 0) for p in patterns] x = np.arange(len(patterns)) width = 0.35 fig, ax = plt.subplots(figsize=(12, 6)) ax.bar(x - width/2, true_counts, width, label='True', alpha=0.8) ax.bar(x + width/2, pred_counts, width, label='Predicted', alpha=0.8) ax.set_xlabel('Risk Patterns') ax.set_ylabel('Count') ax.set_title('Risk Pattern Distribution - True vs Predicted') ax.set_xticks(x) ax.set_xticklabels(patterns, rotation=45, ha='right') ax.legend() plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') print(f"💾 Risk distribution plot saved to: {save_path}") else: plt.show() plt.close() def save_evaluation_results(self, results: Dict[str, Any]): """Save evaluation results to file""" # Convert numpy arrays to lists for JSON serialization json_results = self._convert_for_json(results) with open('evaluation_results.json', 'w') as f: json.dump(json_results, f, indent=2) # Save report report = self.generate_report() with open('evaluation_report.txt', 'w') as f: f.write(report) print("💾 Evaluation results saved:") print(" - evaluation_results.json") print(" - evaluation_report.txt") def _convert_for_json(self, obj): """Convert numpy arrays to lists for JSON serialization""" if isinstance(obj, dict): return {key: self._convert_for_json(value) for key, value in obj.items()} elif isinstance(obj, list): return [self._convert_for_json(item) for item in obj] elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) else: return obj def analyze_attention_patterns(self, test_clauses: List[str], max_samples: int = 10) -> Dict[str, Any]: """ Analyze attention patterns for clause importance interpretation. Args: test_clauses: List of clause texts to analyze max_samples: Maximum number of samples to analyze Returns: Dictionary containing attention analysis results """ print(f"🔍 Analyzing attention patterns for {min(len(test_clauses), max_samples)} samples...") self.model.eval() attention_results = [] with torch.no_grad(): for idx, clause in enumerate(test_clauses[:max_samples]): # Tokenize tokens = self.tokenizer.tokenize_clauses([clause]) input_ids = tokens['input_ids'].to(self.model.config.device) attention_mask = tokens['attention_mask'].to(self.model.config.device) # Get attention analysis analysis = self.model.analyze_attention(input_ids, attention_mask, self.tokenizer) # Get prediction prediction = self.model.predict_risk_pattern(input_ids, attention_mask) result = { 'clause_index': idx, 'clause_preview': clause[:100] + '...' 
                result = {
                    'clause_index': idx,
                    'clause_preview': clause[:100] + '...' if len(clause) > 100 else clause,
                    'predicted_risk': int(prediction['predicted_risk_id'][0]),
                    'severity': float(prediction['severity_score'][0]),
                    'importance': float(prediction['importance_score'][0]),
                    'top_tokens': analysis.get('top_tokens', []),
                    'top_token_scores': analysis.get('top_token_scores', np.array([])).tolist()
                }
                attention_results.append(result)

        print(f"✅ Attention analysis complete for {len(attention_results)} clauses")
        return {
            'num_analyzed': len(attention_results),
            'clause_analyses': attention_results
        }

    def evaluate_hierarchical_risk(self, test_loader, contract_ids: List[int]) -> Dict[str, Any]:
        """
        Evaluate hierarchical risk aggregation (clause → contract level).

        Args:
            test_loader: DataLoader with test clauses
            contract_ids: List of contract IDs for each clause in test set

        Returns:
            Contract-level risk assessment results
        """
        if not HIERARCHICAL_AVAILABLE:
            print("⚠️ Hierarchical risk analysis not available")
            return {'error': 'hierarchical_risk module not found'}

        print("📊 Performing hierarchical risk evaluation (clause → contract level)...")

        # Collect clause-level predictions grouped by contract
        contract_predictions = defaultdict(list)

        self.model.eval()
        clause_idx = 0
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.model.config.device)
                attention_mask = batch['attention_mask'].to(self.model.config.device)

                # Get predictions
                predictions = self.model.predict_risk_pattern(input_ids, attention_mask)

                # Group by contract
                batch_size = input_ids.size(0)
                for i in range(batch_size):
                    contract_id = contract_ids[clause_idx]
                    clause_pred = {
                        'predicted_risk_id': int(predictions['predicted_risk_id'][i]),
                        'confidence': float(predictions['confidence'][i]),
                        'severity_score': float(predictions['severity_score'][i]),
                        'importance_score': float(predictions['importance_score'][i])
                    }
                    contract_predictions[contract_id].append(clause_pred)
                    clause_idx += 1

        # Aggregate to contract level
        aggregator = HierarchicalRiskAggregator()
        contract_results = {}
        for contract_id, clause_preds in contract_predictions.items():
            contract_risk = aggregator.aggregate_contract_risk(
                clause_preds, method='weighted_mean'
            )
            contract_results[contract_id] = contract_risk

        print(f"✅ Analyzed {len(contract_results)} contracts")

        # Summary statistics
        contract_severities = [r['contract_severity'] for r in contract_results.values()]
        contract_importances = [r['contract_importance'] for r in contract_results.values()]

        summary = {
            'num_contracts': len(contract_results),
            'contract_results': contract_results,
            'summary_statistics': {
                'avg_contract_severity': float(np.mean(contract_severities)),
                'std_contract_severity': float(np.std(contract_severities)),
                'max_contract_severity': float(np.max(contract_severities)),
                'min_contract_severity': float(np.min(contract_severities)),
                'avg_contract_importance': float(np.mean(contract_importances)),
                'high_risk_contracts': sum(1 for s in contract_severities if s >= 7.0)
            }
        }
        return summary
    def analyze_risk_dependencies(self, test_loader, contract_ids: List[int],
                                  num_risk_types: int = 7) -> Dict[str, Any]:
        """
        Analyze dependencies and interactions between risk types.

        Args:
            test_loader: DataLoader with test clauses
            contract_ids: List of contract IDs for each clause
            num_risk_types: Number of risk categories

        Returns:
            Risk dependency analysis including co-occurrence and correlations
        """
        if not HIERARCHICAL_AVAILABLE:
            print("⚠️ Risk dependency analysis not available")
            return {'error': 'hierarchical_risk module not found'}

        print("🔗 Analyzing risk dependencies and interactions...")

        # Collect predictions grouped by contract
        contract_predictions = defaultdict(list)

        self.model.eval()
        clause_idx = 0
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.model.config.device)
                attention_mask = batch['attention_mask'].to(self.model.config.device)

                predictions = self.model.predict_risk_pattern(input_ids, attention_mask)

                batch_size = input_ids.size(0)
                for i in range(batch_size):
                    contract_id = contract_ids[clause_idx]
                    clause_pred = {
                        'predicted_risk_id': int(predictions['predicted_risk_id'][i]),
                        'confidence': float(predictions['confidence'][i]),
                        'severity_score': float(predictions['severity_score'][i]),
                        'importance_score': float(predictions['importance_score'][i])
                    }
                    contract_predictions[contract_id].append(clause_pred)
                    clause_idx += 1

        # Analyze dependencies
        dependency_analyzer = RiskDependencyAnalyzer()

        # Compute correlation across contracts
        contract_pred_lists = list(contract_predictions.values())
        correlation_matrix = dependency_analyzer.compute_risk_correlation(
            contract_pred_lists, num_risk_types
        )

        # Analyze amplification effects
        all_clause_preds = [pred for preds in contract_pred_lists for pred in preds]
        amplification = dependency_analyzer.analyze_risk_amplification(all_clause_preds)

        # Find common risk chains
        all_chains = []
        for clause_preds in contract_pred_lists:
            chains = dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
            all_chains.extend(chains)

        # Count most common chains
        from collections import Counter
        chain_counts = Counter(tuple(chain) for chain in all_chains)
        most_common_chains = chain_counts.most_common(10)

        print("✅ Risk dependency analysis complete")

        return {
            'correlation_matrix': correlation_matrix.tolist(),
            'risk_amplification': amplification,
            'common_risk_chains': [
                {'chain': list(chain), 'count': count}
                for chain, count in most_common_chains
            ],
            'total_chains_found': len(all_chains)
        }


# Mock imports for environments without sklearn/matplotlib
try:
    import torch
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
except ImportError:
    print("⚠️ Warning: Some evaluation dependencies not available. Using mock implementations.")
Using mock implementations.") # Mock torch class MockTensor: def __init__(self, data): self.data = data def numpy(self): return self.data def to(self, device): return self class MockModule: def eval(self): pass def __getattr__(self, name): return lambda *args, **kwargs: None torch = type('torch', (), { 'no_grad': lambda: type('context', (), {'__enter__': lambda self: None, '__exit__': lambda *args: None})() })() # Mock sklearn functions def accuracy_score(y_true, y_pred): return sum([1 for t, p in zip(y_true, y_pred) if t == p]) / len(y_true) def precision_recall_fscore_support(y_true, y_pred, average=None): return 0.5, 0.5, 0.5, None def confusion_matrix(y_true, y_pred): return [[1, 0], [0, 1]] def mean_squared_error(y_true, y_pred): return sum([(t - p) ** 2 for t, p in zip(y_true, y_pred)]) / len(y_true) def mean_absolute_error(y_true, y_pred): return sum([abs(t - p) for t, p in zip(y_true, y_pred)]) / len(y_true) def r2_score(y_true, y_pred): return 0.5