# modules.py
import torch
import torch.nn as nn


class CrossAttentionDelta(nn.Module):
    """
    Computes the update delta (Δ) using an attention block over the
    initial hidden states.

    Improvements over the baseline version:
    1. Pre-norm architecture (layer norm before attention)
    2. More sophisticated attention patterns
    3. Ability to incorporate a reasoning trace
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-norm layer normalization (applied before attention)
        self.pre_norm = nn.LayerNorm(hidden_dim)

        # Attention mechanism
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Post-attention layer normalization
        self.post_norm = nn.LayerNorm(hidden_dim)

        # Trace integration module (to incorporate reasoning trace T)
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Enhanced MLP for delta computation
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),  # larger intermediate expansion
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        # Final layer normalization
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            reasoning_trace (tuple of torch.Tensor, optional): Reasoning trace from
                the base model. Each tensor has shape (batch_size, seq_len, hidden_dim).

        Returns:
            delta (torch.Tensor): The computed update delta (batch_size, seq_len, hidden_dim).
            attn_weights (torch.Tensor): Attention weights, averaged over heads
                (batch_size, seq_len, seq_len).
        """
        # --- Pre-norm architecture ---
        # Apply layer normalization before attention (pre-norm)
        h0_norm = self.pre_norm(h0)

        # --- Attention over h0 ---
        # Query, key, and value all come from h0_norm, so this is effectively
        # self-attention over the initial hidden states. Attention weights are
        # requested so attention patterns can be visualized.
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True,
        )

        # Residual connection and post-norm
        c = self.post_norm(h0 + attn_output)

        # --- Reasoning trace integration (if provided) ---
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            # Use the last layer from the reasoning trace (most semantic)
            last_layer = reasoning_trace[-1]

            # Integrate the reasoning trace with the current context
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )

            # Add the trace information to the context
            c = c + trace_info

        # --- Enhanced MLP for the delta ---
        # Concatenate the original h0 with the context c
        mlp_input = torch.cat((h0, c), dim=-1)

        # Compute the delta through the enhanced MLP
        delta = self.delta_mlp(mlp_input)

        # Apply final normalization
        delta = self.final_norm(delta)

        return delta, attn_weights
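

# A minimal sketch of how the modules in this file are intended to compose.
# The gated residual update shown here is an assumption inferred from the
# GatingMechanism docstring below ("selectively apply updates"); it is not
# confirmed elsewhere in this file, and the variable names are illustrative:
#
#   delta, attn = delta_module(h0, reasoning_trace)  # CrossAttentionDelta
#   gate = gate_module(h0, delta)                    # GatingMechanism, values in (0, 1)
#   h = h0 + gate * delta                            # gate (B, S, 1) broadcasts over hidden_dim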


class GatingMechanism(nn.Module):
    """
    Gating mechanism to selectively apply updates.
    Learns when to apply the delta update based on the hidden state and the delta.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.gate_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # output between 0 and 1
        )

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).

        Returns:
            gate (torch.Tensor): Gate values between 0 and 1 (batch_size, seq_len, 1).
        """
        # Concatenate h0 and delta
        gate_input = torch.cat([h0, delta], dim=-1)

        # Compute gate values
        gate = self.gate_network(gate_input)

        return gate


class EnhancedQAHead(nn.Module):
    """
    Enhanced question-answering head with a deeper architecture and
    bilinear scoring.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # Deeper representations before prediction
        self.start_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.end_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Bilinear layer for start position scoring
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Bilinear layer for end position scoring
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Learned global representation used as the second bilinear argument
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states (batch_size, seq_len, hidden_dim).

        Returns:
            dict: Dictionary with start_logits and end_logits, each of shape
                (batch_size, seq_len).
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Transform hidden states
        start_rep = self.start_transform(hidden_states)
        end_rep = self.end_transform(hidden_states)

        # Expand the global representation to match the batch and sequence dims
        global_rep = self.global_rep.expand(batch_size, seq_len, -1)

        # Compute start and end logits using bilinear scoring
        start_logits = self.start_bilinear(start_rep, global_rep).squeeze(-1)
        end_logits = self.end_bilinear(end_rep, global_rep).squeeze(-1)

        return {"start_logits": start_logits, "end_logits": end_logits}
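

# A minimal smoke test for the modules above. The dimensions, the one-layer
# reasoning trace, and the gated update rule are illustrative assumptions,
# not values taken from the project's training configuration.
if __name__ == "__main__":
    hidden_dim, batch_size, seq_len = 64, 2, 16

    delta_module = CrossAttentionDelta(hidden_dim, num_heads=8)
    gate_module = GatingMechanism(hidden_dim)
    qa_head = EnhancedQAHead(hidden_dim)

    h0 = torch.randn(batch_size, seq_len, hidden_dim)
    trace = (torch.randn(batch_size, seq_len, hidden_dim),)  # single-layer trace

    delta, attn_weights = delta_module(h0, reasoning_trace=trace)
    gate = gate_module(h0, delta)
    h = h0 + gate * delta  # gate (B, S, 1) broadcasts over hidden_dim

    outputs = qa_head(h)
    print(delta.shape, attn_weights.shape, gate.shape)
    # torch.Size([2, 16, 64]) torch.Size([2, 16, 16]) torch.Size([2, 16, 1])
    print(outputs["start_logits"].shape, outputs["end_logits"].shape)
    # torch.Size([2, 16]) torch.Size([2, 16])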