File size: 11,506 Bytes
9b1c753 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
"""
Advanced Analysis Script for Legal-BERT
Demonstrates attention analysis, hierarchical risk modeling, and risk dependencies
This script showcases the newly implemented features:
1. Attention mechanism analysis for clause importance
2. Hierarchical risk aggregation (clause β contract level)
3. Risk dependency and interaction analysis
"""
import torch
import json
from typing import Dict, List, Any
import numpy as np
from config import LegalBertConfig
from model import HierarchicalLegalBERT, LegalBertTokenizer
from evaluator import LegalBertEvaluator
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer
from risk_discovery import UnsupervisedRiskDiscovery
def load_trained_model(model_path: str, config: "LegalBertConfig"):
    """Load a trained Hierarchical Legal-BERT model from a checkpoint.

    Args:
        model_path: Path to the ``.pt`` checkpoint produced by training.
        config: Configuration object supplying the target device and the
            hierarchical model hyper-parameters.

    Returns:
        The model in eval mode on ``config.device``, or ``None`` if the
        checkpoint file does not exist.
    """
    print(f"[LOAD] Loading model from {model_path}...")
    try:
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # this at trusted checkpoints produced by this project's train.py.
        checkpoint = torch.load(model_path, map_location=config.device)
        # Size the risk head from the discovered patterns stored alongside
        # the weights (missing/empty dict -> 0 discovered risks).
        num_discovered_risks = len(checkpoint.get('discovered_patterns', {}))
        print("[LOAD] Building Hierarchical BERT model")
        model = HierarchicalLegalBERT(
            config,
            num_discovered_risks=num_discovered_risks,
            hidden_dim=config.hierarchical_hidden_dim,
            num_lstm_layers=config.hierarchical_num_lstm_layers
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(config.device)
        model.eval()
        print("[OK] Model loaded successfully")
        return model
    except FileNotFoundError:
        print("[WARNING] Model file not found. Please train the model first.")
        return None
def demo_attention_analysis(model, tokenizer, sample_clauses: List[str]):
    """Demonstrate attention mechanism analysis on a few sample clauses.

    Args:
        model: Trained HierarchicalLegalBERT exposing ``analyze_attention``
            and ``predict_risk_pattern``.
        tokenizer: LegalBertTokenizer used to encode each clause.
        sample_clauses: Clause texts; only the first three are analyzed to
            keep demo output short.
    """
    print("\n" + "=" * 80)
    print("ATTENTION MECHANISM ANALYSIS")
    print("=" * 80)
    for idx, clause in enumerate(sample_clauses[:3]):
        print(f"\nAnalyzing Clause {idx + 1}:")
        # Truncate long clauses for display only; the full text is analyzed.
        preview = f"{clause[:100]}..." if len(clause) > 100 else clause
        print(f"Text: {preview}")
        # Tokenize the single clause and move tensors to the model's device.
        tokens = tokenizer.tokenize_clauses([clause])
        input_ids = tokens['input_ids'].to(model.config.device)
        attention_mask = tokens['attention_mask'].to(model.config.device)
        # NOTE(review): unlike the other demos this runs without
        # torch.no_grad() — presumably analyze_attention needs gradients;
        # confirm against the model implementation.
        analysis = model.analyze_attention(input_ids, attention_mask, tokenizer)
        prediction = model.predict_risk_pattern(input_ids, attention_mask)
        print(f"\n  Predicted Risk ID: {prediction['predicted_risk_id'][0]}")
        print(f"  Severity: {prediction['severity_score'][0]:.2f}/10")
        print(f"  Importance: {prediction['importance_score'][0]:.2f}/10")
        print(f"  Confidence: {prediction['confidence'][0]:.2%}")
        if 'top_tokens' in analysis:
            print("\n  Most Important Tokens:")
            # top_tokens appears flat while top_token_scores is batched
            # ([0] selects the single clause in this batch) — mirrors the
            # original indexing; verify against analyze_attention's output.
            for token, score in zip(analysis['top_tokens'][:5],
                                    analysis['top_token_scores'][0][:5]):
                print(f"    {token}: {score:.4f}")
    print("\n[OK] Attention analysis complete")
def demo_hierarchical_risk(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Demonstrate hierarchical risk aggregation (clause -> contract level).

    For each contract, predicts a risk pattern per clause, then rolls the
    clause-level predictions up into a single contract-level assessment and
    prints a per-contract report.

    Args:
        model: Trained HierarchicalLegalBERT.
        tokenizer: LegalBertTokenizer used to encode each clause.
        contract_clauses: Mapping of contract name -> list of clause texts.
    """
    print("\n" + "=" * 80)
    print("HIERARCHICAL RISK AGGREGATION (Clause -> Contract)")
    print("=" * 80)
    aggregator = HierarchicalRiskAggregator()
    for contract_name, clauses in contract_clauses.items():
        print(f"\nAnalyzing Contract: {contract_name}")
        print(f"  Number of clauses: {len(clauses)}")
        clause_predictions = []
        model.eval()
        with torch.no_grad():  # pure inference — no gradients needed
            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                input_ids = tokens['input_ids'].to(model.config.device)
                attention_mask = tokens['attention_mask'].to(model.config.device)
                pred = model.predict_risk_pattern(input_ids, attention_mask)
                # Convert tensor outputs to plain Python scalars so the
                # aggregator works with ordinary dicts.
                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0])
                })
        # Roll the clause-level risks up to one contract-level score.
        contract_risk = aggregator.aggregate_contract_risk(
            clause_predictions,
            method='weighted_mean'
        )
        print("\n  Contract-Level Assessment:")
        print(f"    - Risk Category: {contract_risk['contract_risk_id']}")
        print(f"    - Overall Severity: {contract_risk['contract_severity']:.2f}/10")
        print(f"    - Overall Importance: {contract_risk['contract_importance']:.2f}/10")
        print(f"    - Confidence: {contract_risk['contract_confidence']:.2%}")
        print(f"    - High-Risk Clauses: {len(contract_risk['high_risk_clauses'])}")
        report = aggregator.generate_contract_report(clause_predictions, contract_name)
        print(report)
    print("\n[OK] Hierarchical risk analysis complete")
def demo_risk_dependencies(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Demonstrate risk dependency and interaction analysis across contracts.

    Predicts clause-level risks for every contract, then computes a risk
    correlation matrix, per-risk amplification statistics, and the most
    common risk chains.

    Args:
        model: Trained HierarchicalLegalBERT.
        tokenizer: LegalBertTokenizer used to encode each clause.
        contract_clauses: Mapping of contract name -> list of clause texts.
    """
    from collections import Counter  # local import kept function-scoped, as in the original

    print("\n" + "=" * 80)
    print("RISK DEPENDENCY & INTERACTION ANALYSIS")
    print("=" * 80)
    dependency_analyzer = RiskDependencyAnalyzer()
    # Predict clause-level risks for every contract (keys unused, so iterate
    # values only).
    all_contract_predictions = []
    model.eval()
    with torch.no_grad():  # pure inference — no gradients needed
        for clauses in contract_clauses.values():
            clause_predictions = []
            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                input_ids = tokens['input_ids'].to(model.config.device)
                attention_mask = tokens['attention_mask'].to(model.config.device)
                pred = model.predict_risk_pattern(input_ids, attention_mask)
                # Plain-scalar dicts so the analyzer works without tensors.
                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0])
                })
            all_contract_predictions.append(clause_predictions)
    # 1. Pairwise correlation between the 7 risk types across contracts.
    print("\nComputing risk correlation matrix...")
    correlation = dependency_analyzer.compute_risk_correlation(
        all_contract_predictions,
        num_risk_types=7
    )
    print("\n  Risk Type Correlation Matrix (7x7):")
    print("  " + "-" * 50)
    for i, row in enumerate(correlation):
        print(f"  Risk {i}: " + " ".join(f"{val:6.3f}" for val in row))
    # 2. Amplification: severity statistics per risk type over all clauses.
    print("\nAnalyzing risk amplification effects...")
    all_clauses = [pred for contract in all_contract_predictions for pred in contract]
    amplification = dependency_analyzer.analyze_risk_amplification(all_clauses)
    print("\n  Risk Amplification Analysis:")
    for risk_id, stats in sorted(amplification.items(),
                                 key=lambda x: x[1]['avg_severity'],
                                 reverse=True):
        print(f"  Risk {risk_id}:")
        print(f"    - Avg Severity: {stats['avg_severity']:.2f}")
        print(f"    - Max Severity: {stats['max_severity']:.2f}")
        print(f"    - Clause Count: {stats['clause_count']}")
        print(f"    - Severity Variance: {stats['severity_variance']:.2f}")
    # 3. Risk chains: recurring sequences of risks within a sliding window.
    print("\nIdentifying common risk chains...")
    all_chains = []
    for clause_preds in all_contract_predictions:
        chains = dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
        all_chains.extend(chains)
    chain_counts = Counter(tuple(chain) for chain in all_chains)
    most_common = chain_counts.most_common(5)
    print("\n  Top 5 Most Common Risk Chains:")
    for chain, count in most_common:
        print(f"    {list(chain)} -> appeared {count} times")
    print("\n[OK] Risk dependency analysis complete")
def main():
    """Run the full advanced-analysis demonstration end to end.

    Loads the trained model, builds sample data, and runs the three demos
    (attention analysis, hierarchical risk, risk dependencies). Exits early
    with a message if no trained checkpoint is found.
    """
    print("=" * 80)
    print("LEGAL-BERT ADVANCED ANALYSIS DEMONSTRATION")
    print("=" * 80)
    # Initialize configuration and load the trained checkpoint.
    config = LegalBertConfig()
    model_path = f"{config.model_save_path}/best_model.pt"
    model = load_trained_model(model_path, config)
    if model is None:
        print("\n[WARNING] Cannot proceed without trained model.")
        print("  Please run 'python train.py' first to train the model.")
        return
    tokenizer = LegalBertTokenizer(config.bert_model_name)
    # Standalone sample clauses for the attention demo.
    sample_clauses = [
        "The Company shall indemnify and hold harmless the Customer from any claims, damages, or liabilities arising from breach of this Agreement.",
        "Either party may terminate this Agreement upon thirty (30) days written notice to the other party.",
        "All intellectual property rights in the deliverables shall remain the exclusive property of the Company.",
        "The Customer agrees to pay the Company a monthly fee of $10,000 for the services provided under this Agreement."
    ]
    # Sample contracts (multiple clauses per contract) for the hierarchical
    # and dependency demos.
    contract_clauses = {
        "Service_Agreement_001": [
            "The Service Provider agrees to provide software development services as specified in Exhibit A.",
            "Payment shall be made within 30 days of invoice receipt.",
            "The Service Provider shall indemnify Client against all third-party claims arising from the services.",
            "This Agreement may be terminated by either party with 60 days notice."
        ],
        "License_Agreement_002": [
            "Licensor grants Licensee a non-exclusive, worldwide license to use the Software.",
            "Licensee shall pay annual license fees of $50,000.",
            "All intellectual property rights remain with Licensor.",
            "Confidential information must be kept confidential for 5 years."
        ]
    }
    # Run the three demonstrations; any failure is reported with a traceback
    # but does not prevent the closing summary from printing.
    try:
        demo_attention_analysis(model, tokenizer, sample_clauses)
        demo_hierarchical_risk(model, tokenizer, contract_clauses)
        demo_risk_dependencies(model, tokenizer, contract_clauses)
    except Exception as e:
        print(f"\n[ERROR] Error during analysis: {e}")
        import traceback
        traceback.print_exc()
    print("\n" + "=" * 80)
    print("ADVANCED ANALYSIS DEMONSTRATION COMPLETE")
    print("=" * 80)
    print("\nThese features are now integrated into the evaluation pipeline.")
    print("Use them during training evaluation or post-training analysis.")
if __name__ == "__main__":
main()
|