# modules.py
import torch
import torch.nn as nn
class CrossAttentionDelta(nn.Module):
"""
Enhanced version of CrossAttentionDelta that computes the update delta (Δ) using cross-attention.
Improvements:
1. Pre-norm architecture (layer norm before attention)
2. More sophisticated attention patterns
3. Ability to incorporate reasoning trace
"""
def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
# Pre-norm layer normalization (applied before attention)
self.pre_norm = nn.LayerNorm(hidden_dim)
# Cross-attention mechanism
self.cross_attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Post-attention layer normalization
self.post_norm = nn.LayerNorm(hidden_dim)
# Trace integration module (to incorporate reasoning trace T)
self.trace_integration = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim)
)
# Enhanced MLP for delta computation
self.delta_mlp = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim * 4), # Larger intermediate expansion
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim)
)
# Final layer normalization
self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            reasoning_trace (tuple of torch.Tensor, optional): Reasoning trace from
                the base model. Each tensor has shape (batch_size, seq_len, hidden_dim).
        Returns:
            delta (torch.Tensor): The computed update delta (batch_size, seq_len, hidden_dim).
            attn_weights (torch.Tensor): Head-averaged attention weights
                (batch_size, seq_len, seq_len), useful for visualization.
        """
batch_size, seq_len, _ = h0.shape
# --- Pre-norm Architecture ---
# Apply layer normalization before attention (pre-norm)
h0_norm = self.pre_norm(h0)
        # --- Attention over the hidden states ---
        # Query, key, and value all come from the normalized h0, so despite the
        # module name this step is self-attention over the initial hidden states.
        # need_weights=True also returns head-averaged attention weights for
        # visualization.
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True
        )
# Residual connection and post-norm
c = self.post_norm(h0 + attn_output)
# --- Reasoning Trace Integration (if provided) ---
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            # Use the last layer of the reasoning trace (the most semantically
            # abstract representation)
            last_layer = reasoning_trace[-1]
# Integrate the reasoning trace with the current context
trace_info = self.trace_integration(
torch.cat([c, last_layer], dim=-1)
)
# Add the trace information to the context
c = c + trace_info
# --- Enhanced MLP for Delta ---
# Concatenate original h0 with context c
mlp_input = torch.cat((h0, c), dim=-1)
# Compute delta through enhanced MLP
delta = self.delta_mlp(mlp_input)
# Apply final normalization
delta = self.final_norm(delta)
return delta, attn_weights
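

# A minimal usage sketch for CrossAttentionDelta (shapes are illustrative;
# hidden_dim must match the base model's hidden size):
#
#   delta_module = CrossAttentionDelta(hidden_dim=768, num_heads=8)
#   h0 = torch.randn(2, 16, 768)            # (batch_size, seq_len, hidden_dim)
#   delta, attn_weights = delta_module(h0)  # delta: (2, 16, 768)
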
class GatingMechanism(nn.Module):
"""
Gating mechanism to selectively apply updates.
Learns when to apply the delta update based on the hidden state and delta.
"""
def __init__(self, hidden_dim, dropout=0.1):
super().__init__()
self.gate_network = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
nn.Sigmoid() # Output between 0 and 1
)
def forward(self, h0, delta):
"""
Args:
h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).
Returns:
gate (torch.Tensor): Gate values between 0 and 1 (batch_size, seq_len, 1).
"""
# Concatenate h0 and delta
gate_input = torch.cat([h0, delta], dim=-1)
# Compute gate values
gate = self.gate_network(gate_input)
return gate
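

# A minimal usage sketch for GatingMechanism. The combination rule shown
# (h0 + gate * delta) is an assumption about how the surrounding model applies
# the gate; the (batch_size, seq_len, 1) gate broadcasts over the hidden dim:
#
#   gate_module = GatingMechanism(hidden_dim=768)
#   gate = gate_module(h0, delta)  # (2, 16, 1), values in (0, 1)
#   h1 = h0 + gate * delta         # gated residual update
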
class EnhancedQAHead(nn.Module):
"""
Enhanced Question Answering head with deeper architecture and bilinear scoring.
"""
def __init__(self, hidden_dim, dropout=0.1):
super().__init__()
# Deeper representation before prediction
self.start_transform = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim)
)
self.end_transform = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim)
)
# Bilinear layer for start position scoring
self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)
# Bilinear layer for end position scoring
self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)
        # Learned global vector scored against each token representation
        # through the bilinear layers
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))
def forward(self, hidden_states):
"""
Args:
hidden_states (torch.Tensor): Hidden states (batch_size, seq_len, hidden_dim).
Returns:
dict: Dictionary with start_logits and end_logits.
"""
batch_size, seq_len, hidden_dim = hidden_states.shape
# Transform hidden states
start_rep = self.start_transform(hidden_states)
end_rep = self.end_transform(hidden_states)
        # Broadcast the learned global vector to (batch_size, seq_len, hidden_dim);
        # expand returns a view, so no memory is copied
        global_rep = self.global_rep.expand(batch_size, seq_len, -1)
# Compute start and end logits using bilinear scoring
start_logits = self.start_bilinear(start_rep, global_rep).squeeze(-1)
end_logits = self.end_bilinear(end_rep, global_rep).squeeze(-1)
return {"start_logits": start_logits, "end_logits": end_logits}