# code2-repo / advanced_analysis.py
# Deepu1965's picture
# Upload folder using huggingface_hub
# 9b1c753 verified
"""
Advanced Analysis Script for Legal-BERT
Demonstrates attention analysis, hierarchical risk modeling, and risk dependencies
This script showcases the newly implemented features:
1. Attention mechanism analysis for clause importance
2. Hierarchical risk aggregation (clause → contract level)
3. Risk dependency and interaction analysis
"""
import torch
import json
from typing import Dict, List, Any
import numpy as np
from config import LegalBertConfig
from model import HierarchicalLegalBERT, LegalBertTokenizer
from evaluator import LegalBertEvaluator
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer
from risk_discovery import UnsupervisedRiskDiscovery
def load_trained_model(model_path: str, config: LegalBertConfig):
    """Load a trained Hierarchical Legal-BERT model from a checkpoint file.

    Args:
        model_path: Path to a checkpoint read with ``torch.load``. It must hold
            a ``'model_state_dict'`` entry; the length of its optional
            ``'discovered_patterns'`` entry sizes the model's risk head.
        config: Project configuration. ``device``, ``hierarchical_hidden_dim``
            and ``hierarchical_num_lstm_layers`` are read here.

    Returns:
        The model, moved to ``config.device`` and put in eval mode, or
        ``None`` when ``model_path`` does not exist.
    """
    print(f"📂 Loading model from {model_path}...")

    # Guard only the checkpoint read: a FileNotFoundError raised later while
    # constructing or loading the model must not be reported as a missing
    # checkpoint file.
    # NOTE(review): newer torch versions default torch.load to weights_only=True;
    # confirm 'discovered_patterns' contains only types that loader accepts.
    try:
        checkpoint = torch.load(model_path, map_location=config.device)
    except FileNotFoundError:
        print("⚠️ Model file not found. Please train the model first.")
        return None

    num_discovered_risks = len(checkpoint.get('discovered_patterns', {}))
    if num_discovered_risks == 0:
        # The model would be built with zero risk types, which is unlikely
        # to match the saved weights; make that visible before loading.
        print("⚠️ Checkpoint has no 'discovered_patterns'; "
              "the saved weights may not match the rebuilt model.")

    print("📊 Loading Hierarchical BERT model")
    model = HierarchicalLegalBERT(
        config,
        num_discovered_risks=num_discovered_risks,
        hidden_dim=config.hierarchical_hidden_dim,
        num_lstm_layers=config.hierarchical_num_lstm_layers,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(config.device)
    model.eval()
    print("✅ Model loaded successfully")
    return model
def demo_attention_analysis(model, tokenizer, sample_clauses: List[str],
                            max_clauses: int = 3):
    """Print attention-based token importance and risk predictions per clause.

    Args:
        model: Trained model exposing ``config.device``, ``analyze_attention``
            and ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``, whose result holds
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        sample_clauses: Clause texts to analyze.
        max_clauses: How many leading clauses to analyze (default 3, the
            previously hard-coded limit).
    """
    print("\n" + "=" * 80)
    print("🔍 ATTENTION MECHANISM ANALYSIS")
    print("=" * 80)

    device = model.config.device
    # Match the other demos: inference-mode model.
    model.eval()

    for idx, clause in enumerate(sample_clauses[:max_clauses]):
        print(f"\n📄 Analyzing Clause {idx + 1}:")
        print(f"Text: {clause[:100]}..." if len(clause) > 100 else f"Text: {clause}")

        tokens = tokenizer.tokenize_clauses([clause])
        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        # analyze_attention is left outside no_grad in case it relies on
        # gradients (not visible from here).
        analysis = model.analyze_attention(input_ids, attention_mask, tokenizer)

        # Plain prediction needs no autograd graph, as in the other demos.
        with torch.no_grad():
            prediction = model.predict_risk_pattern(input_ids, attention_mask)

        print(f"\n Predicted Risk ID: {prediction['predicted_risk_id'][0]}")
        print(f" Severity: {prediction['severity_score'][0]:.2f}/10")
        print(f" Importance: {prediction['importance_score'][0]:.2f}/10")
        print(f" Confidence: {prediction['confidence'][0]:.2%}")

        if 'top_tokens' in analysis:
            print("\n 🎯 Most Important Tokens:")
            # NOTE(review): scores are indexed by batch ([0]) but tokens are
            # not; confirm the structure analyze_attention returns.
            for token, score in zip(analysis['top_tokens'][:5],
                                    analysis['top_token_scores'][0][:5]):
                print(f" {token}: {score:.4f}")

    print("\n✅ Attention analysis complete")
def demo_hierarchical_risk(model, tokenizer, contract_clauses: Dict[str, List[str]],
                           aggregation_method: str = 'weighted_mean'):
    """Score each contract's clauses and aggregate them into a contract-level risk.

    Each clause is scored with ``model.predict_risk_pattern``. The clause
    predictions are combined with ``HierarchicalRiskAggregator.aggregate_contract_risk``,
    and the report from ``generate_contract_report`` is printed.

    Args:
        model: Trained model exposing ``config.device`` and ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``.
        contract_clauses: Mapping of contract name to its clause texts.
        aggregation_method: ``method`` passed to ``aggregate_contract_risk``
            (default ``'weighted_mean'``, the previously hard-coded value).
    """
    print("\n" + "=" * 80)
    print("📊 HIERARCHICAL RISK AGGREGATION (Clause → Contract)")
    print("=" * 80)

    aggregator = HierarchicalRiskAggregator()
    device = model.config.device
    # Loop-invariant: set inference mode once, not once per contract.
    model.eval()

    for contract_name, clauses in contract_clauses.items():
        print(f"\n📋 Analyzing Contract: {contract_name}")
        print(f" Number of clauses: {len(clauses)}")

        if not clauses:
            # Nothing to aggregate; don't hand an empty list to the aggregator.
            print(" ⚠️ No clauses found, skipping contract.")
            continue

        # Clause-level predictions, reduced to plain Python scalars.
        clause_predictions = []
        with torch.no_grad():
            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                pred = model.predict_risk_pattern(
                    tokens['input_ids'].to(device),
                    tokens['attention_mask'].to(device),
                )
                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0]),
                })

        contract_risk = aggregator.aggregate_contract_risk(
            clause_predictions,
            method=aggregation_method,
        )

        print("\n Contract-Level Assessment:")
        print(f" ├─ Risk Category: {contract_risk['contract_risk_id']}")
        print(f" ├─ Overall Severity: {contract_risk['contract_severity']:.2f}/10")
        print(f" ├─ Overall Importance: {contract_risk['contract_importance']:.2f}/10")
        print(f" ├─ Confidence: {contract_risk['contract_confidence']:.2%}")
        print(f" └─ High-Risk Clauses: {len(contract_risk['high_risk_clauses'])}")

        report = aggregator.generate_contract_report(clause_predictions, contract_name)
        print(report)

    print("\n✅ Hierarchical risk analysis complete")
def demo_risk_dependencies(model, tokenizer, contract_clauses: Dict[str, List[str]],
                           num_risk_types: int = 7):
    """Print risk correlation, amplification and chain statistics across contracts.

    Every clause of every contract is scored with ``model.predict_risk_pattern``.
    The per-contract prediction lists are then passed to ``RiskDependencyAnalyzer``.

    Args:
        model: Trained model exposing ``config.device`` and ``predict_risk_pattern``.
        tokenizer: Tokenizer exposing ``tokenize_clauses``.
        contract_clauses: Mapping of contract name to its clause texts.
        num_risk_types: Size requested for the correlation matrix (default 7,
            the previously hard-coded value). It is raised automatically when
            a predicted risk id would not fit.
    """
    from collections import Counter

    print("\n" + "=" * 80)
    print("🔗 RISK DEPENDENCY & INTERACTION ANALYSIS")
    print("=" * 80)

    dependency_analyzer = RiskDependencyAnalyzer()
    device = model.config.device

    # One list of clause predictions per contract, in contract order.
    all_contract_predictions = []
    model.eval()
    with torch.no_grad():
        for clauses in contract_clauses.values():
            clause_predictions = []
            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                pred = model.predict_risk_pattern(
                    tokens['input_ids'].to(device),
                    tokens['attention_mask'].to(device),
                )
                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0]),
                })
            all_contract_predictions.append(clause_predictions)

    all_clauses = [pred for contract in all_contract_predictions for pred in contract]

    # The model may discover a different number of risk types than the
    # default. Make sure every predicted id fits the matrix.
    # NOTE(review): assumes compute_risk_correlation indexes rows by
    # predicted_risk_id; confirm against hierarchical_risk.
    max_risk_id = max((p['predicted_risk_id'] for p in all_clauses), default=-1)
    if max_risk_id >= num_risk_types:
        print(f"\n⚠️ Predicted risk id {max_risk_id} exceeds num_risk_types={num_risk_types}; "
              f"using {max_risk_id + 1} risk types instead.")
        num_risk_types = max_risk_id + 1

    print("\n📈 Computing risk correlation matrix...")
    correlation = dependency_analyzer.compute_risk_correlation(
        all_contract_predictions,
        num_risk_types=num_risk_types,
    )
    print(f"\n Risk Type Correlation Matrix ({num_risk_types}x{num_risk_types}):")
    print(" " + "-" * 50)
    for i, row in enumerate(correlation):
        print(f" Risk {i}: " + " ".join(f"{val:6.3f}" for val in row))

    print("\n⚡ Analyzing risk amplification effects...")
    amplification = dependency_analyzer.analyze_risk_amplification(all_clauses)
    print("\n Risk Amplification Analysis:")
    # Highest average severity first.
    for risk_id, stats in sorted(amplification.items(),
                                 key=lambda item: item[1]['avg_severity'],
                                 reverse=True):
        print(f" Risk {risk_id}:")
        print(f" ├─ Avg Severity: {stats['avg_severity']:.2f}")
        print(f" ├─ Max Severity: {stats['max_severity']:.2f}")
        print(f" ├─ Clause Count: {stats['clause_count']}")
        print(f" └─ Severity Variance: {stats['severity_variance']:.2f}")

    print("\n🔗 Identifying common risk chains...")
    # Chains are converted to tuples so they can be counted as dict keys.
    chain_counts = Counter(
        tuple(chain)
        for clause_preds in all_contract_predictions
        for chain in dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
    )
    most_common = chain_counts.most_common(5)

    print("\n Top 5 Most Common Risk Chains:")
    if not most_common:
        print(" (no risk chains found)")
    for chain, count in most_common:
        print(f" {list(chain)} → appeared {count} times")

    print("\n✅ Risk dependency analysis complete")
def main():
    """Run the three advanced-analysis demos against the saved best model.

    Loads ``best_model.pt`` from ``config.model_save_path``. If it is missing,
    training instructions are printed and the function returns. Each demo
    runs independently, so a failure in one is reported without skipping
    the others. The closing banner reflects whether any demo failed.
    """
    import os
    import traceback

    print("=" * 80)
    print("🏛️ LEGAL-BERT ADVANCED ANALYSIS DEMONSTRATION")
    print("=" * 80)

    config = LegalBertConfig()

    model_path = os.path.join(config.model_save_path, "best_model.pt")
    model = load_trained_model(model_path, config)
    if model is None:
        print("\n⚠️ Cannot proceed without trained model.")
        print(" Please run 'python train.py' first to train the model.")
        return

    tokenizer = LegalBertTokenizer(config.bert_model_name)

    # Standalone clauses for the attention demo.
    sample_clauses = [
        "The Company shall indemnify and hold harmless the Customer from any claims, damages, or liabilities arising from breach of this Agreement.",
        "Either party may terminate this Agreement upon thirty (30) days written notice to the other party.",
        "All intellectual property rights in the deliverables shall remain the exclusive property of the Company.",
        "The Customer agrees to pay the Company a monthly fee of $10,000 for the services provided under this Agreement.",
    ]

    # Multi-clause contracts for the aggregation and dependency demos.
    contract_clauses = {
        "Service_Agreement_001": [
            "The Service Provider agrees to provide software development services as specified in Exhibit A.",
            "Payment shall be made within 30 days of invoice receipt.",
            "The Service Provider shall indemnify Client against all third-party claims arising from the services.",
            "This Agreement may be terminated by either party with 60 days notice.",
        ],
        "License_Agreement_002": [
            "Licensor grants Licensee a non-exclusive, worldwide license to use the Software.",
            "Licensee shall pay annual license fees of $50,000.",
            "All intellectual property rights remain with Licensor.",
            "Confidential information must be kept confidential for 5 years.",
        ],
    }

    demos = (
        ("Attention Analysis", demo_attention_analysis, (model, tokenizer, sample_clauses)),
        ("Hierarchical Risk Modeling", demo_hierarchical_risk, (model, tokenizer, contract_clauses)),
        ("Risk Dependencies", demo_risk_dependencies, (model, tokenizer, contract_clauses)),
    )

    failed = []
    for demo_name, demo_fn, demo_args in demos:
        try:
            demo_fn(*demo_args)
        except Exception as e:  # top-level boundary: report, then continue
            print(f"\n❌ Error during analysis: {e}")
            traceback.print_exc()
            failed.append(demo_name)

    print("\n" + "=" * 80)
    if failed:
        print(f"⚠️ ADVANCED ANALYSIS FINISHED WITH ERRORS IN: {', '.join(failed)}")
    else:
        print("🎉 ADVANCED ANALYSIS DEMONSTRATION COMPLETE")
    print("=" * 80)
    print("\nThese features are now integrated into the evaluation pipeline.")
    print("Use them during training evaluation or post-training analysis.")


if __name__ == "__main__":
    main()