|
|
""" |
|
|
Advanced Analysis Script for Legal-BERT |
|
|
Demonstrates attention analysis, hierarchical risk modeling, and risk dependencies |
|
|
|
|
|
This script showcases the newly implemented features: |
|
|
1. Attention mechanism analysis for clause importance |
|
|
2. Hierarchical risk aggregation (clause -> contract level)
|
|
3. Risk dependency and interaction analysis |
|
|
""" |
|
|
import torch |
|
|
import json |
|
|
from typing import Dict, List, Any |
|
|
import numpy as np |
|
|
|
|
|
from config import LegalBertConfig |
|
|
from model import HierarchicalLegalBERT, LegalBertTokenizer |
|
|
from evaluator import LegalBertEvaluator |
|
|
from hierarchical_risk import HierarchicalRiskAggregator, RiskDependencyAnalyzer |
|
|
from risk_discovery import UnsupervisedRiskDiscovery |
|
|
|
|
|
|
|
|
def load_trained_model(model_path: str, config: LegalBertConfig):
    """Load a trained Hierarchical Legal-BERT model from a checkpoint.

    Args:
        model_path: Path to the ``.pt`` checkpoint produced by training.
        config: Project configuration supplying ``device`` and the
            hierarchical model hyperparameters.

    Returns:
        The loaded model in eval mode on ``config.device``, or ``None``
        if the checkpoint file does not exist.
    """
    print(f"Loading model from {model_path}...")

    try:
        # NOTE(review): torch.load uses pickle under the hood — only load
        # checkpoints produced by this project's own training run.
        checkpoint = torch.load(model_path, map_location=config.device)

        # The classifier head size must match the number of risk patterns
        # discovered during training (stored alongside the weights).
        num_discovered_risks = len(checkpoint.get('discovered_patterns', {}))

        print("Loading Hierarchical BERT model")
        model = HierarchicalLegalBERT(
            config,
            num_discovered_risks=num_discovered_risks,
            hidden_dim=config.hierarchical_hidden_dim,
            num_lstm_layers=config.hierarchical_num_lstm_layers
        )

        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(config.device)
        model.eval()
        print("[OK] Model loaded successfully")
        return model

    except FileNotFoundError:
        print("[WARN] Model file not found. Please train the model first.")
        return None
|
|
|
|
|
|
|
|
def demo_attention_analysis(model, tokenizer, sample_clauses: List[str]):
    """Demonstrate attention-based clause importance analysis.

    Runs up to three sample clauses through the model, printing the
    predicted risk pattern and the highest-scoring attention tokens.

    Args:
        model: Trained model exposing ``analyze_attention`` and
            ``predict_risk_pattern``.
        tokenizer: Tokenizer wrapper with ``tokenize_clauses``.
        sample_clauses: Clause texts; only the first three are analyzed.
    """
    print("\n" + "="*80)
    print("ATTENTION MECHANISM ANALYSIS")
    print("="*80)

    model.eval()
    for idx, clause in enumerate(sample_clauses[:3]):
        print(f"\nAnalyzing Clause {idx + 1}:")
        # Truncate long clauses in the console output only.
        print(f"Text: {clause[:100]}..." if len(clause) > 100 else f"Text: {clause}")

        tokens = tokenizer.tokenize_clauses([clause])
        input_ids = tokens['input_ids'].to(model.config.device)
        attention_mask = tokens['attention_mask'].to(model.config.device)

        # Inference only — wrap in no_grad for consistency with the other
        # demos and to avoid building an autograd graph.
        with torch.no_grad():
            analysis = model.analyze_attention(input_ids, attention_mask, tokenizer)
            prediction = model.predict_risk_pattern(input_ids, attention_mask)

        print(f"\n  Predicted Risk ID: {prediction['predicted_risk_id'][0]}")
        print(f"  Severity: {prediction['severity_score'][0]:.2f}/10")
        print(f"  Importance: {prediction['importance_score'][0]:.2f}/10")
        print(f"  Confidence: {prediction['confidence'][0]:.2%}")

        if 'top_tokens' in analysis:
            print(f"\n  Most Important Tokens:")
            for token, score in zip(analysis['top_tokens'][:5],
                                    analysis['top_token_scores'][0][:5]):
                print(f"    {token}: {score:.4f}")

    print("\n[OK] Attention analysis complete")
|
|
|
|
|
|
|
|
def demo_hierarchical_risk(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Demonstrate hierarchical risk aggregation (clause -> contract).

    For each contract, predicts a risk pattern per clause, aggregates the
    clause-level predictions into a contract-level assessment, and prints
    a per-contract report.

    Args:
        model: Trained model exposing ``predict_risk_pattern``.
        tokenizer: Tokenizer wrapper with ``tokenize_clauses``.
        contract_clauses: Mapping of contract name -> list of clause texts.
    """
    print("\n" + "="*80)
    print("HIERARCHICAL RISK AGGREGATION (Clause -> Contract)")
    print("="*80)

    aggregator = HierarchicalRiskAggregator()

    for contract_name, clauses in contract_clauses.items():
        print(f"\nAnalyzing Contract: {contract_name}")
        print(f"  Number of clauses: {len(clauses)}")

        clause_predictions = []

        model.eval()
        with torch.no_grad():
            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                input_ids = tokens['input_ids'].to(model.config.device)
                attention_mask = tokens['attention_mask'].to(model.config.device)

                pred = model.predict_risk_pattern(input_ids, attention_mask)

                # Convert tensor/array elements to plain Python scalars so
                # the aggregator works with JSON-friendly values.
                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0])
                })

        contract_risk = aggregator.aggregate_contract_risk(
            clause_predictions,
            method='weighted_mean'
        )

        print(f"\n  Contract-Level Assessment:")
        print(f"  - Risk Category: {contract_risk['contract_risk_id']}")
        print(f"  - Overall Severity: {contract_risk['contract_severity']:.2f}/10")
        print(f"  - Overall Importance: {contract_risk['contract_importance']:.2f}/10")
        print(f"  - Confidence: {contract_risk['contract_confidence']:.2%}")
        print(f"  - High-Risk Clauses: {len(contract_risk['high_risk_clauses'])}")

        report = aggregator.generate_contract_report(clause_predictions, contract_name)
        print(report)

    print("\n[OK] Hierarchical risk analysis complete")
|
|
|
|
|
|
|
|
def demo_risk_dependencies(model, tokenizer, contract_clauses: Dict[str, List[str]]):
    """Demonstrate risk dependency and interaction analysis.

    Predicts a risk pattern for every clause of every contract, then
    reports (1) the risk-type correlation matrix, (2) per-risk
    amplification statistics, and (3) the most common risk chains found
    within a sliding window over each contract's clause sequence.

    Args:
        model: Trained model exposing ``predict_risk_pattern``.
        tokenizer: Tokenizer wrapper with ``tokenize_clauses``.
        contract_clauses: Mapping of contract name -> ordered clause texts.
    """
    print("\n" + "="*80)
    print("RISK DEPENDENCY & INTERACTION ANALYSIS")
    print("="*80)

    dependency_analyzer = RiskDependencyAnalyzer()

    # One prediction list per contract; clause order is preserved so that
    # chain detection sees real adjacency.
    all_contract_predictions = []

    model.eval()
    with torch.no_grad():
        for contract_name, clauses in contract_clauses.items():
            clause_predictions = []

            for clause in clauses:
                tokens = tokenizer.tokenize_clauses([clause])
                input_ids = tokens['input_ids'].to(model.config.device)
                attention_mask = tokens['attention_mask'].to(model.config.device)

                pred = model.predict_risk_pattern(input_ids, attention_mask)

                clause_predictions.append({
                    'predicted_risk_id': int(pred['predicted_risk_id'][0]),
                    'confidence': float(pred['confidence'][0]),
                    'severity_score': float(pred['severity_score'][0]),
                    'importance_score': float(pred['importance_score'][0])
                })

            all_contract_predictions.append(clause_predictions)

    print("\nComputing risk correlation matrix...")
    correlation = dependency_analyzer.compute_risk_correlation(
        all_contract_predictions,
        num_risk_types=7
    )

    print("\n  Risk Type Correlation Matrix (7x7):")
    print("  " + "-"*50)
    for i, row in enumerate(correlation):
        print(f"  Risk {i}: " + " ".join([f"{val:6.3f}" for val in row]))

    print("\nAnalyzing risk amplification effects...")
    # Flatten across contracts: amplification stats are per risk type.
    all_clauses = [pred for contract in all_contract_predictions for pred in contract]
    amplification = dependency_analyzer.analyze_risk_amplification(all_clauses)

    print("\n  Risk Amplification Analysis:")
    for risk_id, stats in sorted(amplification.items(),
                                 key=lambda x: x[1]['avg_severity'],
                                 reverse=True):
        print(f"  Risk {risk_id}:")
        print(f"    - Avg Severity: {stats['avg_severity']:.2f}")
        print(f"    - Max Severity: {stats['max_severity']:.2f}")
        print(f"    - Clause Count: {stats['clause_count']}")
        print(f"    - Severity Variance: {stats['severity_variance']:.2f}")

    print("\nIdentifying common risk chains...")
    all_chains = []
    for clause_preds in all_contract_predictions:
        chains = dependency_analyzer.find_risk_chains(clause_preds, window_size=3)
        all_chains.extend(chains)

    # Chains are lists; convert to tuples so they are hashable for Counter.
    from collections import Counter
    chain_counts = Counter(tuple(chain) for chain in all_chains)
    most_common = chain_counts.most_common(5)

    print(f"\n  Top 5 Most Common Risk Chains:")
    for chain, count in most_common:
        print(f"    {list(chain)} -> appeared {count} times")

    print("\n[OK] Risk dependency analysis complete")
|
|
|
|
|
|
|
|
def main():
    """Run the full advanced-analysis demonstration end to end.

    Loads the trained checkpoint (aborting with a message if absent),
    builds the tokenizer, and runs the three demos on small hard-coded
    sample inputs.
    """
    print("="*80)
    print("LEGAL-BERT ADVANCED ANALYSIS DEMONSTRATION")
    print("="*80)

    config = LegalBertConfig()

    model_path = f"{config.model_save_path}/best_model.pt"
    model = load_trained_model(model_path, config)

    # A trained checkpoint is a hard prerequisite for every demo below.
    if model is None:
        print("\n[WARN] Cannot proceed without trained model.")
        print("  Please run 'python train.py' first to train the model.")
        return

    tokenizer = LegalBertTokenizer(config.bert_model_name)

    # Stand-alone clauses for the attention demo.
    sample_clauses = [
        "The Company shall indemnify and hold harmless the Customer from any claims, damages, or liabilities arising from breach of this Agreement.",
        "Either party may terminate this Agreement upon thirty (30) days written notice to the other party.",
        "All intellectual property rights in the deliverables shall remain the exclusive property of the Company.",
        "The Customer agrees to pay the Company a monthly fee of $10,000 for the services provided under this Agreement."
    ]

    # Two small synthetic contracts for the hierarchical / dependency demos.
    contract_clauses = {
        "Service_Agreement_001": [
            "The Service Provider agrees to provide software development services as specified in Exhibit A.",
            "Payment shall be made within 30 days of invoice receipt.",
            "The Service Provider shall indemnify Client against all third-party claims arising from the services.",
            "This Agreement may be terminated by either party with 60 days notice."
        ],
        "License_Agreement_002": [
            "Licensor grants Licensee a non-exclusive, worldwide license to use the Software.",
            "Licensee shall pay annual license fees of $50,000.",
            "All intellectual property rights remain with Licensor.",
            "Confidential information must be kept confidential for 5 years."
        ]
    }

    try:
        demo_attention_analysis(model, tokenizer, sample_clauses)
        demo_hierarchical_risk(model, tokenizer, contract_clauses)
        demo_risk_dependencies(model, tokenizer, contract_clauses)
    except Exception as e:
        # Top-level boundary: report the failure but still print the
        # closing banner so the demo ends cleanly.
        print(f"\n[ERROR] Error during analysis: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "="*80)
    print("ADVANCED ANALYSIS DEMONSTRATION COMPLETE")
    print("="*80)
    print("\nThese features are now integrated into the evaluation pipeline.")
    print("Use them during training evaluation or post-training analysis.")
|
|
|
|
|
|
|
|
# Script entry point: run the demonstration only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|
|