"""
Legal-BERT Model Architecture - Fully Learning-Based

Includes Hierarchical BERT for document-level understanding.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Any, Optional, Tuple

try:
    from transformers import AutoModel, AutoTokenizer
except ImportError:
    # The classes below provide a mock fallback for running without
    # transformers; that fallback is only reachable if the import is optional.
    AutoModel = None
    AutoTokenizer = None

# Hidden size assumed when no real encoder is loaded (BERT-base hidden size).
_DEFAULT_BERT_HIDDEN_SIZE = 768


def _load_bert_encoder(config, fallback_warning: str):
    """Load the encoder named by ``config.bert_model_name`` or return None.

    On success the encoder's dropout is set to ``config.dropout_rate``.
    Dropout modules are created when the model is built, so changing the
    config afterwards does not affect them; the rate is therefore also
    written onto every existing ``nn.Dropout`` module (best effort: attention
    implementations that cache the probability elsewhere are not touched).

    On any failure (transformers missing, model not found, ...) the
    ``fallback_warning`` and the reason are printed and None is returned so
    the caller can switch to its mock path.
    """
    try:
        if AutoModel is None:
            raise ImportError("transformers is not installed")
        bert = AutoModel.from_pretrained(config.bert_model_name)
        bert.config.hidden_dropout_prob = config.dropout_rate
        bert.config.attention_probs_dropout_prob = config.dropout_rate
        for module in bert.modules():
            if isinstance(module, nn.Dropout):
                module.p = config.dropout_rate
        return bert
    except Exception as exc:  # deliberate best-effort fallback to the mock
        print(fallback_warning)
        print(f"   Reason: {exc!r}")
        return None


class FullyLearningBasedLegalBERT(nn.Module):
    """
    Legal-BERT model that learns from discovered risk patterns.

    The number of risk classes is a constructor argument; no risk categories
    are hardcoded. Three heads share the pooled encoder output: a risk
    classifier, a severity regressor and an importance regressor (both
    regressors output values in the 0-10 range).
    """

    def __init__(self, config, num_discovered_risks: int = 7):
        """
        Args:
            config: Object exposing ``bert_model_name`` and ``dropout_rate``.
            num_discovered_risks: Number of output classes of the risk head.
        """
        super().__init__()
        self.config = config
        self.num_discovered_risks = num_discovered_risks

        # Real encoder, or None when it cannot be loaded (mock mode).
        self.bert = _load_bert_encoder(
            config,
            "⚠️ Warning: Using mock BERT model (transformers not available)",
        )

        # Size the heads from the loaded encoder instead of assuming BERT-base.
        if self.bert is not None:
            hidden_size = getattr(
                self.bert.config, "hidden_size", _DEFAULT_BERT_HIDDEN_SIZE
            )
        else:
            hidden_size = _DEFAULT_BERT_HIDDEN_SIZE
        self.hidden_size = hidden_size

        # Risk classification head (for discovered risk patterns)
        self.risk_classifier = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 2, num_discovered_risks),
        )

        # Severity regression head: sigmoid output in 0-1, scaled to 0-10
        # in forward().
        self.severity_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid(),
        )

        # Importance regression head: sigmoid output in 0-1, scaled to 0-10
        # in forward().
        self.importance_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid(),
        )

        # Temperature scaling for calibration (initialised to 1 = no scaling)
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                output_attentions: bool = False) -> Dict[str, torch.Tensor]:
        """Forward pass through the model.

        Args:
            input_ids: Token IDs from tokenizer
            attention_mask: Attention mask for valid tokens
            output_attentions: If True, return attention weights for analysis

        Returns:
            Dict with 'risk_logits', 'calibrated_logits' (logits divided by
            the temperature), 'severity_score', 'importance_score' (both
            scaled to 0-10) and 'pooled_output'. 'attentions' is added only
            when requested and the real encoder produced them.
        """
        if self.bert is not None:
            # Real BERT forward pass
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            pooled_output = outputs.pooler_output
            attentions = outputs.attentions if output_attentions else None
        else:
            # Mock output for testing: random features on the input's device
            # (not merely "the default GPU").
            batch_size = input_ids.size(0)
            pooled_output = torch.randn(batch_size, self.hidden_size).to(
                input_ids.device
            )
            attentions = None

        # Multi-task predictions
        risk_logits = self.risk_classifier(pooled_output)
        severity_score = self.severity_regressor(pooled_output).squeeze(-1) * 10
        importance_score = self.importance_regressor(pooled_output).squeeze(-1) * 10

        # Apply temperature scaling to classification logits
        calibrated_logits = risk_logits / self.temperature

        result = {
            'risk_logits': risk_logits,
            'calibrated_logits': calibrated_logits,
            'severity_score': severity_score,
            'importance_score': importance_score,
            'pooled_output': pooled_output,
        }

        if output_attentions and attentions is not None:
            result['attentions'] = attentions

        return result

    def predict_risk_pattern(self,
                             input_ids: torch.Tensor,
                             attention_mask: torch.Tensor,
                             return_attentions: bool = False) -> Dict[str, Any]:
        """Make predictions and return interpretable results.

        Note: switches the module to eval mode and leaves it there.

        Args:
            input_ids: Token IDs from tokenizer
            attention_mask: Attention mask for valid tokens
            return_attentions: If True, include attention weights for analysis

        Returns:
            Dict of NumPy arrays: 'predicted_risk_id', 'risk_probabilities',
            'confidence', 'severity_score', 'importance_score'; plus the raw
            'attentions' tensors when requested and available.
        """
        self.eval()

        with torch.no_grad():
            outputs = self.forward(
                input_ids, attention_mask, output_attentions=return_attentions
            )

            # Probabilities come from the temperature-calibrated logits.
            risk_probs = torch.softmax(outputs['calibrated_logits'], dim=-1)
            predicted_risk = torch.argmax(risk_probs, dim=-1)
            confidence = torch.max(risk_probs, dim=-1)[0]

            result = {
                'predicted_risk_id': predicted_risk.cpu().numpy(),
                'risk_probabilities': risk_probs.cpu().numpy(),
                'confidence': confidence.cpu().numpy(),
                'severity_score': outputs['severity_score'].cpu().numpy(),
                'importance_score': outputs['importance_score'].cpu().numpy(),
            }

            if return_attentions and 'attentions' in outputs:
                result['attentions'] = outputs['attentions']

            return result

    def analyze_attention(self,
                          input_ids: torch.Tensor,
                          attention_mask: torch.Tensor,
                          tokenizer: Optional['LegalBertTokenizer'] = None) -> Dict[str, Any]:
        """Analyze attention patterns to score token importance.

        Extracts the encoder's attention weights and derives a per-token
        importance score. Useful for interpretability.

        Note: switches the module to eval mode and leaves it there.

        Args:
            input_ids: Token IDs from tokenizer
            attention_mask: Attention mask for valid tokens
            tokenizer: Tokenizer wrapper used to decode tokens (optional)

        Returns:
            ``{'error': ...}`` when no attention weights are available (e.g.
            mock mode). Otherwise a dictionary containing:
            - token_importance: Per-token importance scores (padding zeroed)
            - top_token_indices / top_token_scores: Top-k (k <= 10) positions
              and their scores per sample
            - attention_weights: The 'cls_attention' and 'global_attention'
              components the importance score is averaged from
            - layer_analysis: Per-layer [CLS] attention, averaged over heads
            - tokens / top_tokens: Decoded tokens of the FIRST sample only,
              present when a usable tokenizer is given
        """
        self.eval()

        with torch.no_grad():
            outputs = self.forward(input_ids, attention_mask, output_attentions=True)

            attentions = outputs.get('attentions')
            if attentions is None:
                return {'error': 'Attention weights not available'}

            # `attentions` is one tensor per layer; assumed to follow the
            # Hugging Face layout (batch, num_heads, seq_len, seq_len).
            seq_len = input_ids.size(1)

            # Shape: (num_layers, batch, num_heads, seq_len, seq_len)
            all_attentions = torch.stack(attentions)

            # Row 0 is the [CLS] query position: how the token used for
            # classification distributes its attention over the sequence.
            # Averaged across layers and heads -> (batch, seq_len)
            cls_attention = all_attentions[:, :, :, 0, :].mean(dim=[0, 2])

            # Attention each token receives, averaged over layers, heads and
            # all query positions (global importance) -> (batch, seq_len)
            global_attention = all_attentions.mean(dim=[0, 2, 3])

            # Combine CLS attention and global attention for the final score
            token_importance = (cls_attention + global_attention) / 2

            # Mask out padding tokens
            token_importance = token_importance * attention_mask

            # Get top-k most important tokens per sample
            k = min(10, seq_len)
            top_values, top_indices = torch.topk(token_importance, k, dim=1)

            result = {
                'token_importance': token_importance.cpu().numpy(),
                'top_token_indices': top_indices.cpu().numpy(),
                'top_token_scores': top_values.cpu().numpy(),
                'attention_weights': {
                    'cls_attention': cls_attention.cpu().numpy(),
                    'global_attention': global_attention.cpu().numpy(),
                },
            }

            # Layer-wise analysis: [CLS] row averaged across heads
            result['layer_analysis'] = [
                {
                    'layer': layer_idx,
                    'cls_attention': layer_attn[:, :, 0, :].mean(dim=1).cpu().numpy(),
                }
                for layer_idx, layer_attn in enumerate(attentions)
            ]

            # Decode tokens if a tokenizer is provided (first sample only)
            if tokenizer is not None and tokenizer.tokenizer is not None:
                # Plain Python ints are the documented input for decoding.
                tokens = tokenizer.tokenizer.convert_ids_to_tokens(
                    input_ids[0].tolist()
                )
                result['tokens'] = tokens
                result['top_tokens'] = [tokens[idx] for idx in top_indices[0].tolist()]

            return result


class LegalBertTokenizer:
    """Tokenizer wrapper for Legal-BERT with a mock fallback."""

    def __init__(self, model_name: str = "bert-base-uncased"):
        """Load the tokenizer for ``model_name``; use a mock if that fails."""
        try:
            if AutoTokenizer is None:
                raise ImportError("transformers is not installed")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as exc:  # deliberate best-effort fallback
            print("⚠️ Warning: Using mock tokenizer (transformers not available)")
            print(f"   Reason: {exc!r}")
            self.tokenizer = None

    def tokenize_clauses(self, clauses: List[str], max_length: int = 512) -> Dict[str, torch.Tensor]:
        """Tokenize legal clauses for model input.

        Returns:
            Dict with 'input_ids' and 'attention_mask' tensors. In mock mode
            both have shape (len(clauses), max_length): random ids and an
            all-ones integer mask.
        """
        if self.tokenizer is None:
            # Mock tokenization for testing
            batch_size = len(clauses)
            return {
                'input_ids': torch.randint(0, 1000, (batch_size, max_length)),
                # Integer mask, like the mask the real tokenizer returns.
                'attention_mask': torch.ones(
                    batch_size, max_length, dtype=torch.long
                ),
            }

        # Real tokenization
        encoded = self.tokenizer(
            clauses,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )

        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }

    def decode_tokens(self, token_ids: torch.Tensor) -> List[str]:
        """Decode token IDs back to text (one string per row)."""
        if self.tokenizer is None:
            return ["Mock decoded text"] * token_ids.size(0)

        return self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)


# ============================================================================
# HIERARCHICAL BERT FOR DOCUMENT-LEVEL UNDERSTANDING
# ============================================================================

class HierarchicalLegalBERT(nn.Module):
    """
    Hierarchical BERT for document-level contract understanding.

    Architecture:
        Clause Encoding (BERT) -> Section Aggregation (LSTM + Attention)
        -> Document Aggregation (LSTM + Attention)

    Unlike the flat model, which scores every clause independently, clauses
    here are scored from LSTM states that have seen the other clauses of the
    same section.

    Usage:
        # Training: clause-level labels via forward_single_clause()
        # Inference: full documents via predict_document()
        document = [
            [clause_1_inputs, clause_2_inputs],  # Section 1
            [clause_3_inputs, clause_4_inputs],  # Section 2
        ]
        results = model.predict_document(document)

    where each ``clause_n_inputs`` is a dict holding the 'input_ids' and
    'attention_mask' tensors of one clause.
    """

    def __init__(
        self,
        config,
        num_discovered_risks: int = 7,
        hidden_dim: int = 256,
        num_lstm_layers: int = 2
    ):
        """
        Args:
            config: Object exposing ``bert_model_name`` and ``dropout_rate``.
            num_discovered_risks: Number of output classes of the risk head.
            hidden_dim: Hidden size per direction of both LSTMs.
            num_lstm_layers: Number of stacked layers in both LSTMs.
        """
        super().__init__()
        self.config = config
        self.num_discovered_risks = num_discovered_risks
        self.hidden_dim = hidden_dim

        # BERT for clause encoding, or None when it cannot be loaded.
        self.bert = _load_bert_encoder(config, "⚠️ Warning: Using mock BERT model")
        if self.bert is not None:
            self.bert_hidden_size = getattr(
                self.bert.config, "hidden_size", _DEFAULT_BERT_HIDDEN_SIZE
            )
        else:
            self.bert_hidden_size = _DEFAULT_BERT_HIDDEN_SIZE

        # nn.LSTM only applies dropout between stacked layers.
        lstm_dropout = config.dropout_rate if num_lstm_layers > 1 else 0

        # Level 1: Clause-to-Section (captures context within a section)
        self.clause_to_section = nn.LSTM(
            input_size=self.bert_hidden_size,
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=lstm_dropout,
            batch_first=True,
        )

        # Level 2: Section-to-Document (captures context across sections)
        self.section_to_document = nn.LSTM(
            input_size=hidden_dim * 2,  # Bidirectional
            hidden_size=hidden_dim,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=lstm_dropout,
            batch_first=True,
        )

        # Attention mechanisms for interpretability
        self.clause_attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

        self.section_attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

        # Task-specific prediction heads; they operate on the context-aware
        # (bidirectional LSTM) clause representations.
        self.risk_classifier = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim, num_discovered_risks),
        )

        self.severity_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

        self.importance_regressor = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim * 2, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid(),
        )

        self.temperature = nn.Parameter(torch.ones(1))

    def encode_clause(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode clauses with BERT and return the pooled output.

        Returns a tensor of shape [batch, bert_hidden_size]; in mock mode it
        is random noise on the device of ``input_ids``.
        """
        if self.bert is None:
            batch_size = input_ids.size(0)
            return torch.randn(batch_size, self.bert_hidden_size).to(input_ids.device)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.pooler_output

    def forward_single_clause(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass for batches of independent clauses (training path).

        Each clause is fed through the clause-level LSTM as a sequence of
        length one, so the prediction heads receive inputs of the same shape
        as in document mode even though no section context is available.

        Returns:
            Dict with 'risk_logits', 'calibrated_logits', 'severity_score',
            'importance_score' (scores scaled to 0-10) and 'pooled_output'
            (the LSTM representation, [batch, hidden_dim*2]).
        """
        # Encode clause with BERT
        clause_embedding = self.encode_clause(input_ids, attention_mask)

        # Single-timestep pass through the LSTM
        lstm_out, _ = self.clause_to_section(clause_embedding.unsqueeze(1))
        context_aware_repr = lstm_out.squeeze(1)  # [batch, hidden_dim*2]

        # Make predictions
        risk_logits = self.risk_classifier(context_aware_repr)
        severity_score = self.severity_regressor(context_aware_repr).squeeze(-1) * 10
        importance_score = self.importance_regressor(context_aware_repr).squeeze(-1) * 10

        return {
            'risk_logits': risk_logits,
            'calibrated_logits': risk_logits / self.temperature,
            'severity_score': severity_score,
            'importance_score': importance_score,
            'pooled_output': context_aware_repr,
        }

    def forward_document(
        self,
        document_structure: List[List[Dict[str, torch.Tensor]]]
    ) -> Dict[str, Any]:
        """
        Forward pass for a FULL DOCUMENT (inference with context).

        Args:
            document_structure: List of sections, each a list of clause
                inputs. Every clause tensor gets a batch dimension added via
                ``unsqueeze(0)``, so it is expected WITHOUT one.
                Example:
                [
                    [  # Section 1
                        {'input_ids': tensor, 'attention_mask': tensor},
                        {'input_ids': tensor, 'attention_mask': tensor}
                    ],
                    [  # Section 2
                        {'input_ids': tensor, 'attention_mask': tensor}
                    ]
                ]

        Returns:
            Dict with:
            - 'document_embedding': [1, hidden_dim*2] document vector (zeros
              when there is no non-empty section)
            - 'clause_predictions': one dict per clause with logits, scores
              and its 'section_idx' / 'clause_idx'
            - 'attention_weights': {'clause': [per non-empty section],
              'section': tensor or None}

            Empty sections are skipped: they keep their position in
            'section_idx' but contribute no entry to the attention lists.
        """
        device = next(self.parameters()).device

        section_vectors = []
        all_clause_predictions = []
        attention_weights = {'clause': [], 'section': None}

        # Process each section
        for section_idx, section_clauses in enumerate(document_structure):
            if not section_clauses:
                continue

            # Encode all clauses in this section, one BERT call per clause
            clause_embeddings = [
                self.encode_clause(
                    clause_input['input_ids'].unsqueeze(0).to(device),
                    clause_input['attention_mask'].unsqueeze(0).to(device),
                )
                for clause_input in section_clauses
            ]

            # Stack: [num_clauses, bert_hidden_size]
            clause_hidden = torch.cat(clause_embeddings, dim=0)

            # LSTM over clauses -> context-aware representations
            # clause_lstm_out: [1, num_clauses, hidden_dim*2]
            clause_lstm_out, _ = self.clause_to_section(clause_hidden.unsqueeze(0))

            # Attention over clauses -> section representation
            attention_logits = self.clause_attention(clause_lstm_out)
            clause_attn = F.softmax(attention_logits, dim=1)
            section_vec = torch.sum(clause_lstm_out * clause_attn, dim=1)

            section_vectors.append(section_vec)
            attention_weights['clause'].append(clause_attn.squeeze(0))

            # Predict for each clause using its context-aware representation
            for clause_idx, clause_repr in enumerate(clause_lstm_out[0]):
                risk_logits = self.risk_classifier(clause_repr)
                severity = self.severity_regressor(clause_repr).squeeze() * 10
                importance = self.importance_regressor(clause_repr).squeeze() * 10

                all_clause_predictions.append({
                    'risk_logits': risk_logits,
                    'calibrated_logits': risk_logits / self.temperature,
                    'severity_score': severity,
                    'importance_score': importance,
                    'section_idx': section_idx,
                    'clause_idx': clause_idx,
                })

        # Aggregate sections -> document
        if section_vectors:
            section_hidden = torch.cat(section_vectors, dim=0)
            section_lstm_out, _ = self.section_to_document(section_hidden.unsqueeze(0))

            attention_logits = self.section_attention(section_lstm_out)
            section_attn = F.softmax(attention_logits, dim=1)
            document_vec = torch.sum(section_lstm_out * section_attn, dim=1)

            attention_weights['section'] = section_attn.squeeze(0)
        else:
            document_vec = torch.zeros(1, self.hidden_dim * 2).to(device)

        return {
            'document_embedding': document_vec,
            'clause_predictions': all_clause_predictions,
            'attention_weights': attention_weights,
        }

    def predict_document(
        self,
        document_structure: List[List[Dict[str, torch.Tensor]]]
    ) -> Dict[str, Any]:
        """Run document inference and return plain-Python formatted output.

        Note: switches the module to eval mode and leaves it there.

        Returns:
            Dict with:
            - 'clauses': per-clause predictions (risk id, probabilities,
              confidence, severity and importance scores)
            - 'attention_weights': clause / section attention as nested lists
            - 'summary': 'num_sections' (length of the input, empty sections
              included), 'num_clauses', 'avg_severity' (0 when there are no
              clauses) and 'high_risk_count' (severity above 7)
        """
        self.eval()

        with torch.no_grad():
            outputs = self.forward_document(document_structure)

            # Format predictions
            predictions = []
            for pred in outputs['clause_predictions']:
                risk_probs = F.softmax(pred['calibrated_logits'], dim=0).cpu().numpy()
                predicted_risk = int(risk_probs.argmax())

                predictions.append({
                    'section_idx': pred['section_idx'],
                    'clause_idx': pred['clause_idx'],
                    'predicted_risk_id': predicted_risk,
                    'risk_probabilities': risk_probs.tolist(),
                    'confidence': float(risk_probs[predicted_risk]),
                    'severity_score': pred['severity_score'].item(),
                    'importance_score': pred['importance_score'].item(),
                })

            section_attn = outputs['attention_weights']['section']
            severities = [p['severity_score'] for p in predictions]

            return {
                'clauses': predictions,
                'attention_weights': {
                    'clause': [
                        attn.cpu().numpy().tolist()
                        for attn in outputs['attention_weights']['clause']
                    ],
                    'section': (
                        section_attn.cpu().numpy().tolist()
                        if section_attn is not None else None
                    ),
                },
                'summary': {
                    'num_sections': len(document_structure),
                    'num_clauses': len(predictions),
                    'avg_severity': sum(severities) / len(severities) if severities else 0,
                    'high_risk_count': sum(1 for s in severities if s > 7),
                },
            }