|
|
""" |
|
|
Utilities and helper functions for Legal-BERT project |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
from typing import Dict, List, Any, Tuple |
|
|
import logging |
|
|
|
|
|
def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure root logging to both a file and the console.

    Args:
        log_level: Name of a logging level (case-insensitive), e.g. "INFO".

    Returns:
        A logger named after this module.
    """
    resolved_level = getattr(logging, log_level.upper())
    file_handler = logging.FileHandler('legal_bert.log')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=resolved_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger(__name__)
|
|
|
|
|
def ensure_directory_exists(path: str):
    """Create a directory (and any missing parents) if it doesn't exist.

    Args:
        path: Directory path to create. An empty path is ignored, which makes
            this safe to call with ``os.path.dirname()`` of a bare filename.
    """
    # Guard against '' (e.g. os.path.dirname("file.json")) - makedirs('') raises.
    if not path:
        return
    if not os.path.exists(path):
        # exist_ok avoids a race if another process creates the directory
        # between the existence check above and this call.
        os.makedirs(path, exist_ok=True)
        print(f"π Created directory: {path}")
|
|
|
|
|
def save_json(data: Dict[str, Any], filepath: str):
    """Save data to a JSON file, creating parent directories on demand.

    Args:
        data: JSON-serializable mapping to write.
        filepath: Destination path; may be a bare filename (no directory part).
    """
    parent = os.path.dirname(filepath)
    # A bare filename has no directory component; the previous unconditional
    # ensure_directory_exists('') call crashed on os.makedirs('').
    if parent:
        ensure_directory_exists(parent)
    # utf-8 keeps output portable regardless of the platform's locale encoding
    # (json.dump emits ASCII by default, so the bytes written do not change).
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(f"πΎ Saved JSON: {filepath}")
|
|
|
|
|
def load_json(filepath: str) -> Dict[str, Any]:
    """Read and return the JSON content of *filepath*.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The parsed JSON payload.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if os.path.exists(filepath):
        with open(filepath, 'r') as handle:
            payload = json.load(handle)
        print(f"π Loaded JSON: {filepath}")
        return payload
    raise FileNotFoundError(f"JSON file not found: {filepath}")
|
|
|
|
|
def clean_text(text: str) -> str:
    """Clean and normalize text.

    Non-string input yields an empty string. Characters outside a small
    punctuation whitelist are replaced with spaces, then all whitespace runs
    are collapsed to a single space and the ends trimmed.

    Args:
        text: Raw text to normalize.

    Returns:
        Normalized single-spaced text.
    """
    if not isinstance(text, str):
        return ""

    # Replace disallowed characters with spaces FIRST, then collapse
    # whitespace: the previous order (collapse, then substitute) left runs of
    # double spaces behind wherever special characters were removed.
    text = re.sub(r'[^\w\s.,;:()"-]', ' ', text)
    text = re.sub(r'\s+', ' ', text)

    return text.strip()
|
|
|
|
|
def extract_contract_metadata(filename: str) -> Dict[str, str]:
    """Derive contract metadata fields from an underscore-delimited filename.

    Expected layout: ``COMPANY_DATE_FILINGTYPE_EXHIBIT_AGREEMENT-TYPE.txt``.
    Positions that are not present fall back to 'Unknown'.

    Args:
        filename: Contract filename (``.txt`` suffix is stripped).

    Returns:
        Dict with 'company', 'date', 'filing_type', 'exhibit' and
        'agreement_type' keys.
    """
    tokens = filename.replace('.txt', '').split('_')

    def token_at(index: int) -> str:
        # Positional lookup with the shared 'Unknown' fallback.
        return tokens[index] if len(tokens) > index else 'Unknown'

    return {
        'company': token_at(0),
        'date': token_at(1),
        'filing_type': token_at(2),
        'exhibit': token_at(3),
        # Everything after the fourth underscore belongs to the agreement type.
        'agreement_type': '_'.join(tokens[4:]) if len(tokens) > 4 else 'Unknown',
    }
|
|
|
|
|
def format_risk_score(score: float) -> str:
    """Render a numeric risk score as a labeled severity band.

    Bands: <2 LOW, <5 MEDIUM, <8 HIGH, otherwise CRITICAL.
    """
    # Ordered (exclusive upper bound, label) pairs; first match wins.
    bands = ((2, "LOW"), (5, "MEDIUM"), (8, "HIGH"))
    for upper_bound, label in bands:
        if score < upper_bound:
            return f"{label} ({score:.2f})"
    return f"CRITICAL ({score:.2f})"
|
|
|
|
|
def calculate_statistics(values: List[float]) -> Dict[str, float]:
    """Compute mean/std/min/max/median for *values*.

    Args:
        values: Numeric samples; an empty list yields all-zero statistics.

    Returns:
        Dict with 'mean', 'std', 'min', 'max' and 'median' keys.
    """
    if not values:
        return {'mean': 0, 'std': 0, 'min': 0, 'max': 0, 'median': 0}

    import statistics

    # stdev needs at least two samples; a single sample has zero spread.
    spread = statistics.stdev(values) if len(values) > 1 else 0

    return {
        'mean': statistics.mean(values),
        'std': spread,
        'min': min(values),
        'max': max(values),
        'median': statistics.median(values),
    }
|
|
|
|
|
def set_seed(seed: int = 42):
    """Seed Python, NumPy and (when installed) PyTorch RNGs for reproducibility.

    Args:
        seed: Seed value shared by all random number generators.
    """
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    try:
        import torch
    except ImportError:
        # torch is optional (e.g. demo mode); Python/NumPy are still seeded.
        print(f"π² Random seed set to {seed} (torch not available)")
        return

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"π² Random seed set to {seed}")
|
|
|
|
|
def plot_training_history(history: Dict[str, List[float]], save_path: str = None):
    """Plot loss and accuracy curves from a training history.

    Args:
        history: Mapping with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' lists (one value per epoch).
        save_path: If given, write the figure there; otherwise display it.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting is best-effort; skip silently when matplotlib is absent.
        print("β οΈ matplotlib not available. Skipping training history plot.")
        return

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: loss curves.
    loss_ax.plot(history['train_loss'], label='Train Loss', marker='o')
    loss_ax.plot(history['val_loss'], label='Val Loss', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: accuracy curves.
    acc_ax.plot(history['train_acc'], label='Train Accuracy', marker='o')
    acc_ax.plot(history['val_acc'], label='Val Accuracy', marker='s')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.legend()
    acc_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"πΎ Training history plot saved to: {save_path}")
    else:
        plt.show()

    plt.close()
|
|
|
|
|
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a short human-readable string.

    Under a minute: '12.3s'; under an hour: '4m 5s'; otherwise '2h 30m'.
    """
    if seconds >= 3600:
        hours, remainder = divmod(int(seconds), 3600)
        return f"{hours}h {remainder // 60}m"
    if seconds >= 60:
        whole_minutes, whole_seconds = divmod(int(seconds), 60)
        return f"{whole_minutes}m {whole_seconds}s"
    return f"{seconds:.1f}s"
|
|
|
|
|
def print_progress_bar(iteration: int, total: int, prefix: str = 'Progress',
                       suffix: str = 'Complete', length: int = 50):
    """Print an in-place (carriage-return) console progress bar.

    Args:
        iteration: Current step, expected in 0..total.
        total: Total number of steps; non-positive totals render as 100%.
        prefix: Text printed before the bar.
        suffix: Text printed after the percentage.
        length: Bar width in characters.
    """
    # Guard against ZeroDivisionError when total is 0 (e.g. an empty work
    # list); treat an empty task list as already complete.
    if total <= 0:
        percent = 100.0
        filled_length = length
    else:
        percent = (100 * (iteration / float(total)))
        filled_length = int(length * iteration // total)
    bar = 'β' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent:.1f}% {suffix}', end='')
    if iteration == total:
        print()  # Newline once the bar reaches its final step.
|
|
|
|
|
def validate_config(config) -> List[str]:
    """Validate configuration settings.

    Args:
        config: Object whose attributes are the configuration values.

    Returns:
        List of human-readable error messages (empty when the config is valid).
    """
    problems: List[str] = []

    # All of these attributes must be present on the config object.
    for required in ('bert_model_name', 'data_path', 'batch_size', 'num_epochs'):
        if not hasattr(config, required):
            problems.append(f"Missing required config field: {required}")

    # When supplied, the data path must point at something on disk.
    if hasattr(config, 'data_path') and not os.path.exists(config.data_path):
        problems.append(f"Data path does not exist: {config.data_path}")

    # Numeric hyper-parameters must be in sane ranges.
    if hasattr(config, 'batch_size') and config.batch_size <= 0:
        problems.append("Batch size must be positive")

    if hasattr(config, 'num_epochs') and config.num_epochs <= 0:
        problems.append("Number of epochs must be positive")

    if hasattr(config, 'learning_rate') and (config.learning_rate <= 0 or config.learning_rate > 1):
        problems.append("Learning rate must be between 0 and 1")

    return problems
|
|
|
|
|
def create_model_summary(model, config) -> str:
    """Create a human-readable summary of the model architecture.

    Args:
        model: Model exposing a torch-style ``parameters()`` iterable; when
            counting fails for any reason the counts fall back to "Unknown".
        config: Config with bert_model_name, num_risk_categories,
            max_sequence_length, dropout_rate and device attributes.

    Returns:
        Multi-line summary string.
    """
    try:
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    except Exception:
        # A bare except here would also swallow KeyboardInterrupt/SystemExit;
        # Exception keeps the "Unknown" fallback while letting those propagate.
        total_params = "Unknown"
        trainable_params = "Unknown"

    summary = [
        "π MODEL SUMMARY",
        "=" * 50,
        "Architecture: Legal-BERT (Fully Learning-Based)",
        f"Base Model: {config.bert_model_name}",
        f"Risk Categories: {config.num_risk_categories} (discovered)",
        f"Max Sequence Length: {config.max_sequence_length}",
        f"Dropout Rate: {config.dropout_rate}",
        f"Total Parameters: {total_params}",
        f"Trainable Parameters: {trainable_params}",
        f"Device: {config.device}",
        "=" * 50
    ]

    return "\n".join(summary)
|
|
|
|
|
def check_dependencies() -> Dict[str, bool]:
    """Report availability of the key third-party packages.

    Returns:
        Mapping from package name to True/False availability.
    """
    required = ('torch', 'transformers', 'sklearn', 'numpy', 'pandas')
    availability = {}
    for module_name in required:
        # Import by name; availability is simply "did the import succeed".
        try:
            __import__(module_name)
        except ImportError:
            availability[module_name] = False
        else:
            availability[module_name] = True
    return availability
|
|
|
|
|
def print_dependency_status():
    """Print a formatted availability report for the project's dependencies."""
    deps = check_dependencies()

    print("π¦ DEPENDENCY STATUS")
    print("-" * 30)

    for name, available in deps.items():
        status = "β Available" if available else "β Missing"
        print(f"{name:12} : {status}")

    missing = [name for name, available in deps.items() if not available]

    if missing:
        print(f"\nβ οΈ Missing dependencies: {', '.join(missing)}")
        print("Install with: pip install torch transformers scikit-learn numpy pandas")
        print("For demo mode, dependencies are not required.")
    else:
        print("\nπ All dependencies available!")
|
|
|
|
|
def get_sample_contract_text() -> str:
    """Return a small, self-contained sample services agreement.

    Used for smoke-testing the clause-splitting and analysis pipeline
    without needing a real contract file on disk. The text contains six
    numbered ALL-CAPS sections, matching the header pattern that the
    'legal' splitting mode looks for.
    """
    return """
    SERVICES AGREEMENT

    This Services Agreement ("Agreement") is entered into as of the Effective Date
    by and between Company A ("Provider") and Company B ("Client").

    1. SERVICES
    Provider shall provide the services described in Exhibit A ("Services") to Client
    in accordance with the terms and conditions set forth herein.

    2. PAYMENT TERMS
    Client shall pay Provider the fees specified in Exhibit B within thirty (30) days
    of receipt of each invoice. Late payments shall incur a penalty of 1.5% per month.

    3. INDEMNIFICATION
    Each party shall indemnify and hold harmless the other party from and against any
    third-party claims arising out of such party's breach of this Agreement.

    4. LIMITATION OF LIABILITY
    In no event shall either party's liability exceed the total amount paid under this
    Agreement in the twelve (12) months preceding the claim.

    5. TERMINATION
    Either party may terminate this Agreement upon thirty (30) days written notice
    to the other party. Upon termination, all confidential information shall be returned.

    6. GOVERNING LAW
    This Agreement shall be governed by and construed in accordance with the laws
    of the State of Delaware.
    """
|
|
|
|
|
|
|
|
def split_into_clauses(text: str, method: str = 'sentence') -> List[str]:
    """Split a contract paragraph/document into individual clauses.

    Contracts routinely contain 50-500+ clauses while the model scores one
    clause at a time, so segmentation is a required preprocessing step.

    Args:
        text: Full contract text or paragraph.
        method: 'sentence' for basic sentence-level splitting, or 'legal'
            for legal-aware splitting that also strips numbered section
            headers.

    Returns:
        List of clause strings, each at least 10 characters long.

    Raises:
        ValueError: If *method* is neither 'sentence' nor 'legal'.

    Example:
        >>> split_into_clauses("The Company shall not be liable. Either party may terminate.")
        ['The Company shall not be liable.', 'Either party may terminate.']
    """
    if not text or not isinstance(text, str):
        return []

    if method == 'sentence':
        import re

        # Break after sentence-ending punctuation (or a newline) whenever
        # the next chunk starts with a capital letter.
        pieces = re.split(r'(?<=[.;])\s+(?=[A-Z])|(?<=\n)\s*(?=[A-Z])', text)
        stripped = (piece.strip() for piece in pieces)
        # Drop empties and fragments too short to be meaningful clauses.
        return [piece for piece in stripped if len(piece) >= 10]

    if method == 'legal':
        import re

        collected: List[str] = []
        # First carve the document around numbered ALL-CAPS section headings
        # (the capture group keeps the headings as separate chunks), then
        # sentence-split within each chunk.
        for chunk in re.split(r'\n\s*(\d+\.?\s+[A-Z][A-Z\s]+)\n', text):
            if not chunk.strip():
                continue
            for candidate in re.split(r'(?<=[.;])\s+(?=[A-Z(])', chunk):
                candidate = candidate.strip()
                if len(candidate) >= 10:
                    collected.append(candidate)

        return collected

    raise ValueError(f"Unknown method: {method}. Use 'sentence' or 'legal'")
|
|
|
|
|
|
|
|
def analyze_full_document(
    text: str,
    model,
    return_details: bool = True,
    use_context: bool = True,
    context_window: int = 1
) -> Dict[str, Any]:
    """Analyze a full contract document (multiple clauses).

    High-level production entry point: splits the text into clauses,
    scores each clause with *model* (optionally with neighboring clauses
    joined in as context, which helps resolve references like "Such
    Services" or "Section 5"), then aggregates document-level statistics.

    Args:
        text: Full contract text (can be many pages).
        model: Trained model exposing ``predict(text) -> dict`` with keys
            'risk_type', 'risk_name', 'confidence', 'severity', 'importance'.
            (assumed interface — confirm against the model class)
        return_details: If True, include the per-clause predictions under
            'all_clauses' in the result.
        use_context: If True (recommended), each clause is analyzed together
            with its surrounding clauses.
        context_window: Number of clauses before/after to include
            (1 = previous + current + next).

    Returns:
        Dict with 'document_summary', 'risk_distribution' and
        'high_risk_clauses' (top 10), plus 'all_clauses' when
        return_details is True; or a dict with an 'error' key when
        nothing could be analyzed.
    """
    # Segment the document using the legal-aware splitter.
    clauses = split_into_clauses(text, method='legal')

    if not clauses:
        return {
            'error': 'No clauses found in document',
            'n_clauses': 0
        }

    clause_predictions = []

    if use_context:
        print(f"π Analyzing document with {len(clauses)} clauses (context-aware)...")
        print(f"   Context window: Β±{context_window} clauses")
    else:
        print(f"π Analyzing document with {len(clauses)} clauses...")

    for i, clause in enumerate(clauses):
        try:
            if use_context:
                # Clamp the window to the document boundaries.
                start_idx = max(0, i - context_window)
                end_idx = min(len(clauses), i + context_window + 1)
                context_clauses = clauses[start_idx:end_idx]
                # The model sees the neighboring clauses joined into one
                # input string; only the current clause is recorded below.
                clause_with_context = " ".join(context_clauses)
                input_text = clause_with_context
            else:
                input_text = clause

            pred = model.predict(input_text)

            clause_predictions.append({
                'clause_id': i,
                'clause_text': clause,
                'analyzed_with_context': use_context,
                'risk_type': pred.get('risk_type'),
                'risk_name': pred.get('risk_name'),
                'confidence': pred.get('confidence'),
                'severity': pred.get('severity'),
                'importance': pred.get('importance')
            })

            # Lightweight progress feedback every 10 clauses.
            if (i + 1) % 10 == 0:
                print(f"   Processed {i + 1}/{len(clauses)} clauses...")

        except Exception as e:
            # Best-effort: a single bad clause must not abort the document.
            print(f"β οΈ Error analyzing clause {i}: {e}")
            continue

    if not clause_predictions:
        return {
            'error': 'Failed to analyze any clauses',
            'n_clauses': len(clauses)
        }

    # NOTE(review): the truthiness filters below drop clauses whose severity
    # or importance is exactly 0 (or None) from the averages — confirm this
    # is intended rather than `is not None` checks.
    severities = [p['severity'] for p in clause_predictions if p.get('severity')]
    importances = [p['importance'] for p in clause_predictions if p.get('importance')]

    # Clauses above the 7.0 severity threshold are flagged as high-risk.
    high_risk_clauses = [
        p for p in clause_predictions
        if p.get('severity', 0) > 7.0
    ]

    # Fraction of clauses per predicted risk label.
    from collections import Counter
    risk_counts = Counter([p['risk_name'] for p in clause_predictions if p.get('risk_name')])
    total = len(clause_predictions)
    risk_distribution = {
        risk: count / total
        for risk, count in risk_counts.items()
    }

    # Most frequent risk label, with a safe fallback for empty counts.
    dominant_risk = risk_counts.most_common(1)[0] if risk_counts else ('UNKNOWN', 0)

    result = {
        'document_summary': {
            'total_clauses': len(clauses),
            'analyzed_clauses': len(clause_predictions),
            'overall_severity': sum(severities) / len(severities) if severities else 0,
            'max_severity': max(severities) if severities else 0,
            'overall_importance': sum(importances) / len(importances) if importances else 0,
            'high_risk_clause_count': len(high_risk_clauses),
            'dominant_risk_type': dominant_risk[0],
            'dominant_risk_percentage': (dominant_risk[1] / total * 100) if total > 0 else 0
        },
        'risk_distribution': risk_distribution,
        # Only the top 10 high-risk clauses are surfaced in the summary.
        'high_risk_clauses': high_risk_clauses[:10] if high_risk_clauses else []
    }

    if return_details:
        result['all_clauses'] = clause_predictions

    print(f"β Analysis complete!")
    print(f"   Overall Severity: {result['document_summary']['overall_severity']:.2f}")
    print(f"   High-Risk Clauses: {len(high_risk_clauses)}")
    print(f"   Dominant Risk: {dominant_risk[0]} ({dominant_risk[1]} clauses)")

    return result
|
|
|
|
|
|
|
|
def analyze_with_section_context(text: str, model, return_details: bool = True) -> Dict[str, Any]:
    """Advanced context-aware analysis using document structure.

    Section-aware approach: identifies numbered ALL-CAPS section headers
    (e.g. "1. SERVICES"), splits each section into clauses, and scores
    each clause together with its section title and section text. This
    respects the document hierarchy, so references like "this Section"
    are analyzed with the text they refer to.

    Args:
        text: Full contract text.
        model: Trained model exposing ``predict(text) -> dict``
            (assumed interface — confirm against the model class).
        return_details: If True, include every clause prediction under
            'all_clauses'.

    Returns:
        Dict with 'document_summary', per-section 'sections' summaries,
        'risk_distribution' and optionally 'all_clauses'; or a dict with
        an 'error' key when nothing could be analyzed.
    """
    import re

    print("π Analyzing document with section-aware context...")

    # Numbered ALL-CAPS heading on its own line, e.g. "\n1. SERVICES\n".
    section_pattern = r'\n\s*(\d+\.?\d*\s+[A-Z][A-Z\s]+)\n'

    # The capture group makes re.split return
    # [preamble, header1, body1, header2, body2, ...].
    parts = re.split(section_pattern, text)

    sections = []
    # Any text before the first header is treated as the preamble section.
    current_section = {'title': 'Preamble', 'text': parts[0], 'clauses': []}

    # Walk (header, body) pairs: headers at odd indices, bodies right after.
    for i in range(1, len(parts), 2):
        if i + 1 < len(parts):
            # Flush the previous section before starting a new one.
            if current_section['text'].strip():
                section_clauses = split_into_clauses(current_section['text'], method='sentence')
                current_section['clauses'] = section_clauses
                sections.append(current_section)

            current_section = {
                'title': parts[i].strip(),
                'text': parts[i + 1],
                'clauses': []
            }

    # Flush the final section.
    if current_section['text'].strip():
        section_clauses = split_into_clauses(current_section['text'], method='sentence')
        current_section['clauses'] = section_clauses
        sections.append(current_section)

    print(f"   Identified {len(sections)} sections")

    all_predictions = []
    section_summaries = []

    for sect_idx, section in enumerate(sections):
        section_title = section['title']
        section_text = section['text']
        clauses = section['clauses']

        print(f"   Analyzing section: {section_title} ({len(clauses)} clauses)")

        section_predictions = []

        for clause_idx, clause in enumerate(clauses):
            try:
                # Default context: the section title plus the whole section.
                context_input = f"{section_title}. {section_text}"

                # For long sections, fall back to a +-2 clause sliding window
                # around the current clause to keep the input bounded.
                if len(context_input) > 1000:
                    window_start = max(0, clause_idx - 2)
                    window_end = min(len(clauses), clause_idx + 3)
                    nearby = " ".join(clauses[window_start:window_end])
                    context_input = f"{section_title}. {nearby}"

                pred = model.predict(context_input)

                prediction = {
                    'clause_id': len(all_predictions),
                    'section': section_title,
                    'clause_text': clause,
                    'risk_type': pred.get('risk_type'),
                    'risk_name': pred.get('risk_name'),
                    'confidence': pred.get('confidence'),
                    'severity': pred.get('severity'),
                    'importance': pred.get('importance'),
                    'analyzed_with_section_context': True
                }

                section_predictions.append(prediction)
                all_predictions.append(prediction)

            except Exception as e:
                # Best-effort: one failed clause must not abort the section.
                print(f"β οΈ Error in {section_title}, clause {clause_idx}: {e}")
                continue

        # Per-section aggregate statistics.
        if section_predictions:
            # NOTE(review): truthiness filter drops exactly-zero severities
            # from the averages — confirm intended.
            severities = [p['severity'] for p in section_predictions if p.get('severity')]
            avg_severity = sum(severities) / len(severities) if severities else 0

            section_summaries.append({
                'title': section_title,
                'clause_count': len(clauses),
                'avg_severity': avg_severity,
                'max_severity': max(severities) if severities else 0,
                'high_risk_count': sum(1 for s in severities if s > 7)
            })

    if not all_predictions:
        return {'error': 'No predictions generated'}

    from collections import Counter

    # Document-level aggregates across every analyzed clause.
    severities = [p['severity'] for p in all_predictions if p.get('severity')]
    risk_counts = Counter([p['risk_name'] for p in all_predictions if p.get('risk_name')])
    total = len(all_predictions)

    result = {
        'document_summary': {
            'total_sections': len(sections),
            'total_clauses': len(all_predictions),
            'overall_severity': sum(severities) / len(severities) if severities else 0,
            'max_severity': max(severities) if severities else 0,
            'high_risk_clause_count': sum(1 for s in severities if s > 7)
        },
        'sections': section_summaries,
        'risk_distribution': {risk: count/total for risk, count in risk_counts.items()},
        'all_clauses': all_predictions if return_details else []
    }

    print(f"β Analysis complete!")
    print(f"   {len(sections)} sections analyzed")
    print(f"   Overall severity: {result['document_summary']['overall_severity']:.2f}")

    return result
|
|
|
|
|
|
|
|
def print_document_analysis(results: Dict[str, Any]):
    """Pretty-print document analysis results to the console.

    Args:
        results: Output dict from analyze_full_document(); missing keys are
            rendered with safe defaults, so partial results print cleanly.
    """
    print("\n" + "=" * 80)
    print("π DOCUMENT RISK ANALYSIS REPORT")
    print("=" * 80)

    summary = results.get('document_summary', {})

    print(f"\nπ Document Overview:")
    print(f"   Total Clauses: {summary.get('total_clauses', 0)}")
    print(f"   Analyzed: {summary.get('analyzed_clauses', 0)}")

    print(f"\nβ οΈ Risk Assessment:")
    severity = summary.get('overall_severity', 0)
    # format_risk_score maps the numeric score to a LOW/MEDIUM/HIGH/CRITICAL band.
    print(f"   Overall Severity: {severity:.2f}/10 - {format_risk_score(severity)}")
    print(f"   Maximum Severity: {summary.get('max_severity', 0):.2f}/10")
    print(f"   Overall Importance: {summary.get('overall_importance', 0):.2f}/10")

    print(f"\nπ΄ High-Risk Clauses:")
    print(f"   Count: {summary.get('high_risk_clause_count', 0)}")

    print(f"\nπ Risk Distribution:")
    for risk_type, percentage in results.get('risk_distribution', {}).items():
        print(f"   {risk_type}: {percentage*100:.1f}%")

    print(f"\nπ― Dominant Risk:")
    print(f"   {summary.get('dominant_risk_type', 'N/A')} "
          f"({summary.get('dominant_risk_percentage', 0):.1f}% of clauses)")

    # Show at most the top 5 high-risk clauses, with text truncated to 100 chars.
    high_risk = results.get('high_risk_clauses', [])
    if high_risk:
        print(f"\nπ Top High-Risk Clauses:")
        for i, clause in enumerate(high_risk[:5], 1):
            print(f"\n   {i}. {clause['risk_name']} (Severity: {clause['severity']:.1f})")
            text = clause['clause_text'][:100] + "..." if len(clause['clause_text']) > 100 else clause['clause_text']
            print(f"      \"{text}\"")

    print("\n" + "=" * 80)
|
|
|
|
|
|
|
|
def parse_document_hierarchically(text: str) -> List[List[str]]:
    """Parse a document into a sections -> clauses hierarchy.

    Args:
        text: Full document text.

    Returns:
        One list of clause strings per detected numbered section, e.g.
        ``[['clause1', 'clause2'], ['clause3', 'clause4']]``. When no
        numbered section headers are found, the whole document is returned
        as a single section.
    """
    header_pattern = r'\n\s*(\d+\.?\d*\s+[A-Z][A-Z\s]+)\n'
    pieces = re.split(header_pattern, text)

    structure: List[List[str]] = []

    # re.split with a capture group yields
    # [preamble, header, body, header, body, ...], so headers sit at odd
    # indices with their bodies immediately after.
    for header_idx in range(1, len(pieces), 2):
        body_idx = header_idx + 1
        if body_idx >= len(pieces):
            continue
        body = pieces[body_idx].strip()
        section_clauses = split_into_clauses(body, method='sentence')
        if section_clauses:
            structure.append(section_clauses)

    # Fallback: no numbered sections detected, treat the text as one section.
    if not structure:
        whole_document = split_into_clauses(text, method='sentence')
        if whole_document:
            structure.append(whole_document)

    return structure
|
|
|
|
|
|
|
|
def prepare_hierarchical_input(clauses: List[str], tokenizer) -> List[Dict[str, Any]]:
    """Tokenize each clause independently for the hierarchical model.

    Args:
        clauses: Clause texts to encode.
        tokenizer: LegalBertTokenizer-like object exposing
            ``tokenize_clauses(list_of_texts, max_length)``.

    Returns:
        One dict per clause holding 'input_ids' and 'attention_mask' with
        the singleton batch dimension squeezed away.
    """
    prepared: List[Dict[str, Any]] = []
    for clause_text in clauses:
        # Encode one clause at a time, then drop the batch dimension so each
        # entry is a single-clause tensor pair.
        encoded = tokenizer.tokenize_clauses([clause_text], max_length=128)
        prepared.append({
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
        })
    return prepared