"""
Advanced Analysis Script for Legal-BERT

Demonstrates attention analysis, hierarchical risk modeling, and risk dependencies.

This script showcases the newly implemented features:
1. Attention mechanism analysis for clause importance
2. Hierarchical risk aggregation (clause → contract level)
3. Risk dependency and interaction analysis
"""
import json  # noqa: F401  (currently unused in this script; kept for compatibility)
import traceback
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

import numpy as np  # noqa: F401  (currently unused in this script; kept for compatibility)
import torch

from config import LegalBertConfig
from model import HierarchicalLegalBERT, LegalBertTokenizer
from evaluator import LegalBertEvaluator  # noqa: F401  (currently unused here)
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer
from risk_discovery import UnsupervisedRiskDiscovery  # noqa: F401  (currently unused here)

# Risk-type count that demo_risk_dependencies used unconditionally before it
# could be told (or detect) the model's real number of discovered risks.
_DEFAULT_NUM_RISK_TYPES = 7


def load_trained_model(model_path: str, config: LegalBertConfig):
    """Load a trained Hierarchical Legal-BERT model.

    Args:
        model_path: Path to a checkpoint dict containing ``'model_state_dict'``.
            Its optional ``'discovered_patterns'`` entry determines how many
            risk types the model is constructed with (0 if absent).
        config: Supplies the target device and the hierarchical
            hyper-parameters (``hierarchical_hidden_dim``,
            ``hierarchical_num_lstm_layers``).

    Returns:
        The model moved to ``config.device`` and switched to eval mode, or
        ``None`` when ``model_path`` does not exist.
    """
    print(f"📂 Loading model from {model_path}...")
    # Only the file read is guarded: a missing checkpoint is the one expected
    # failure here. Errors while building the model or loading its weights
    # should surface rather than be reported as "file not found".
    try:
        # NOTE(review): torch.load unpickles the file — only load checkpoints
        # you produced yourself.
        checkpoint = torch.load(model_path, map_location=config.device)
    except FileNotFoundError:
        print("⚠️ Model file not found. Please train the model first.")
        return None

    num_discovered_risks = len(checkpoint.get('discovered_patterns', {}))

    print("📊 Loading Hierarchical BERT model")
    model = HierarchicalLegalBERT(
        config,
        num_discovered_risks=num_discovered_risks,
        hidden_dim=config.hierarchical_hidden_dim,
        num_lstm_layers=config.hierarchical_num_lstm_layers,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.device)
    model.eval()

    print("✅ Model loaded successfully")
    return model


def _encode_clause(tokenizer, clause: str, device) -> Tuple[Any, Any]:
    """Tokenize one clause and move its tensors to ``device``.

    Returns:
        ``(input_ids, attention_mask)`` from ``tokenizer.tokenize_clauses([clause])``.
    """
    tokens = tokenizer.tokenize_clauses([clause])
    return tokens['input_ids'].to(device), tokens['attention_mask'].to(device)


def _predict_clauses(model, tokenizer, clauses: List[str]) -> List[Dict[str, Any]]:
    """Score each clause individually with ``model.predict_risk_pattern``.

    Every clause is tokenized and predicted on its own (batch of one) inside
    ``torch.no_grad()``.

    Returns:
        One dict per clause, in input order, with ``predicted_risk_id`` (int)
        and ``confidence``, ``severity_score``, ``importance_score`` (float),
        each read from index 0 of the prediction output.
    """
    device = model.config.device
    predictions = []
    with torch.no_grad():
        for clause in clauses:
            input_ids, attention_mask = _encode_clause(tokenizer, clause, device)
            pred = model.predict_risk_pattern(input_ids, attention_mask)
            predictions.append({
                'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                'confidence': float(pred['confidence'][0]),
                'severity_score': float(pred['severity_score'][0]),
                'importance_score': float(pred['importance_score'][0]),
            })
    return predictions


def _resolve_num_risk_types(model, num_risk_types: Optional[int]) -> int:
    """Choose the risk-type count used for the correlation matrix.

    An explicit ``num_risk_types`` wins. Otherwise the model's
    ``num_discovered_risks`` attribute is used when it is a positive int —
    presumably the value passed to ``HierarchicalLegalBERT`` at construction
    (TODO confirm the model stores it under that name) — falling back to 7.
    """
    if num_risk_types is not None:
        return num_risk_types
    detected = getattr(model, 'num_discovered_risks', None)
    if isinstance(detected, int) and detected > 0:
        return detected
    return _DEFAULT_NUM_RISK_TYPES


def demo_attention_analysis(model, tokenizer, sample_clauses: List[str]):
    """Print attention-based token importance and risk predictions.

    Only the first three clauses of ``sample_clauses`` are analysed.
    """
    print("\n" + "=" * 80)
    print("🔍 ATTENTION MECHANISM ANALYSIS")
    print("=" * 80)

    device = model.config.device
    for idx, clause in enumerate(sample_clauses[:3]):
        print(f"\n📄 Analyzing Clause {idx + 1}:")
        print(f"Text: {clause[:100]}..." if len(clause) > 100 else f"Text: {clause}")

        input_ids, attention_mask = _encode_clause(tokenizer, clause, device)

        # Attention analysis is left outside no_grad in case it relies on autograd.
        analysis = model.analyze_attention(input_ids, attention_mask, tokenizer)

        # Plain inference: no autograd graph needed.
        with torch.no_grad():
            prediction = model.predict_risk_pattern(input_ids, attention_mask)

        print(f"\n   Predicted Risk ID: {prediction['predicted_risk_id'][0]}")
        print(f"   Severity: {prediction['severity_score'][0]:.2f}/10")
        print(f"   Importance: {prediction['importance_score'][0]:.2f}/10")
        print(f"   Confidence: {prediction['confidence'][0]:.2%}")

        if 'top_tokens' in analysis:
            print("\n   🎯 Most Important Tokens:")
            # NOTE(review): tokens are sliced directly but scores are indexed
            # with [0] first — confirm the shapes analyze_attention returns.
            for token, score in zip(analysis['top_tokens'][:5],
                                    analysis['top_token_scores'][0][:5]):
                print(f"      {token}: {score:.4f}")

    print("\n✅ Attention analysis complete")


def demo_hierarchical_risk(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Aggregate clause-level predictions to a contract-level risk assessment.

    Args:
        contract_clauses: Mapping of contract name to its clause texts.
    """
    print("\n" + "=" * 80)
    print("📊 HIERARCHICAL RISK AGGREGATION (Clause → Contract)")
    print("=" * 80)

    aggregator = HierarchicalRiskAggregator()

    for contract_name, clauses in contract_clauses.items():
        print(f"\n📋 Analyzing Contract: {contract_name}")
        print(f"   Number of clauses: {len(clauses)}")

        model.eval()
        clause_predictions = _predict_clauses(model, tokenizer, clauses)

        # Aggregate to contract level
        contract_risk = aggregator.aggregate_contract_risk(
            clause_predictions,
            method='weighted_mean',
        )

        print("\n   Contract-Level Assessment:")
        print(f"   ├─ Risk Category: {contract_risk['contract_risk_id']}")
        print(f"   ├─ Overall Severity: {contract_risk['contract_severity']:.2f}/10")
        print(f"   ├─ Overall Importance: {contract_risk['contract_importance']:.2f}/10")
        print(f"   ├─ Confidence: {contract_risk['contract_confidence']:.2%}")
        print(f"   └─ High-Risk Clauses: {len(contract_risk['high_risk_clauses'])}")

        report = aggregator.generate_contract_report(clause_predictions, contract_name)
        print(report)

    print("\n✅ Hierarchical risk analysis complete")


def demo_risk_dependencies(model, tokenizer, contract_clauses: Dict[str, List[str]],
                           num_risk_types: Optional[int] = None):
    """Print risk correlations, amplification statistics and common risk chains.

    Args:
        contract_clauses: Mapping of contract name to its clause texts.
        num_risk_types: Size of the correlation matrix. When ``None``, the
            count is taken from the model if it exposes
            ``num_discovered_risks``, otherwise 7.
    """
    print("\n" + "=" * 80)
    print("🔗 RISK DEPENDENCY & INTERACTION ANALYSIS")
    print("=" * 80)

    dependency_analyzer = RiskDependencyAnalyzer()
    num_risk_types = _resolve_num_risk_types(model, num_risk_types)

    # One list of clause predictions per contract, in mapping order.
    model.eval()
    all_contract_predictions = [
        _predict_clauses(model, tokenizer, clauses)
        for clauses in contract_clauses.values()
    ]

    # Risk correlation across contracts
    print("\n📈 Computing risk correlation matrix...")
    correlation = dependency_analyzer.compute_risk_correlation(
        all_contract_predictions,
        num_risk_types=num_risk_types,
    )

    print(f"\n   Risk Type Correlation Matrix ({num_risk_types}x{num_risk_types}):")
    print("   " + "-" * 50)
    for i, row in enumerate(correlation):
        print(f"   Risk {i}: " + " ".join(f"{val:6.3f}" for val in row))

    # Risk amplification over every clause of every contract
    print("\n⚔ Analyzing risk amplification effects...")
    all_clauses = [pred for contract in all_contract_predictions for pred in contract]
    amplification = dependency_analyzer.analyze_risk_amplification(all_clauses)

    print("\n   Risk Amplification Analysis:")
    for risk_id, stats in sorted(amplification.items(),
                                 key=lambda item: item[1]['avg_severity'],
                                 reverse=True):
        print(f"   Risk {risk_id}:")
        print(f"      ├─ Avg Severity: {stats['avg_severity']:.2f}")
        print(f"      ├─ Max Severity: {stats['max_severity']:.2f}")
        print(f"      ├─ Clause Count: {stats['clause_count']}")
        print(f"      └─ Severity Variance: {stats['severity_variance']:.2f}")

    # Risk chains within each contract (sliding window of 3 clauses)
    print("\n🔗 Identifying common risk chains...")
    all_chains = []
    for clause_preds in all_contract_predictions:
        all_chains.extend(dependency_analyzer.find_risk_chains(clause_preds, window_size=3))

    chain_counts = Counter(tuple(chain) for chain in all_chains)

    print("\n   Top 5 Most Common Risk Chains:")
    for chain, count in chain_counts.most_common(5):
        print(f"   {list(chain)} → appeared {count} times")

    print("\n✅ Risk dependency analysis complete")


def main():
    """Load the trained model and run all three demonstrations on sample data."""
    print("=" * 80)
    print("🏛️ LEGAL-BERT ADVANCED ANALYSIS DEMONSTRATION")
    print("=" * 80)

    config = LegalBertConfig()

    model_path = f"{config.model_save_path}/best_model.pt"
    model = load_trained_model(model_path, config)

    if model is None:
        print("\n⚠️ Cannot proceed without trained model.")
        print("   Please run 'python train.py' first to train the model.")
        return

    tokenizer = LegalBertTokenizer(config.bert_model_name)

    # Sample clauses for demonstration
    sample_clauses = [
        "The Company shall indemnify and hold harmless the Customer from any claims, "
        "damages, or liabilities arising from breach of this Agreement.",
        "Either party may terminate this Agreement upon thirty (30) days written notice "
        "to the other party.",
        "All intellectual property rights in the deliverables shall remain the exclusive "
        "property of the Company.",
        "The Customer agrees to pay the Company a monthly fee of $10,000 for the services "
        "provided under this Agreement.",
    ]

    # Sample contracts (multiple clauses per contract)
    contract_clauses = {
        "Service_Agreement_001": [
            "The Service Provider agrees to provide software development services as "
            "specified in Exhibit A.",
            "Payment shall be made within 30 days of invoice receipt.",
            "The Service Provider shall indemnify Client against all third-party claims "
            "arising from the services.",
            "This Agreement may be terminated by either party with 60 days notice.",
        ],
        "License_Agreement_002": [
            "Licensor grants Licensee a non-exclusive, worldwide license to use the Software.",
            "Licensee shall pay annual license fees of $50,000.",
            "All intellectual property rights remain with Licensor.",
            "Confidential information must be kept confidential for 5 years.",
        ],
    }

    # Top-level boundary: report any failure with its traceback instead of crashing.
    try:
        # 1. Attention Analysis
        demo_attention_analysis(model, tokenizer, sample_clauses)

        # 2. Hierarchical Risk Modeling
        demo_hierarchical_risk(model, tokenizer, contract_clauses)

        # 3. Risk Dependencies
        demo_risk_dependencies(model, tokenizer, contract_clauses)
    except Exception as e:
        print(f"\n❌ Error during analysis: {e}")
        traceback.print_exc()

    print("\n" + "=" * 80)
    print("🎉 ADVANCED ANALYSIS DEMONSTRATION COMPLETE")
    print("=" * 80)
    print("\nThese features are now integrated into the evaluation pipeline.")
    print("Use them during training evaluation or post-training analysis.")


if __name__ == "__main__":
    main()