# modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import config


class CrossAttentionDelta(nn.Module):
    """
    Computes the update delta (Δ) from the initial hidden states using cross-attention.

    Enhancements:
    1. Pre-norm architecture (layer norm applied before attention)
    2. Attention weights are returned so attention patterns can be inspected
    3. Ability to incorporate a reasoning trace from the base model
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-norm layer normalization (applied before attention)
        self.pre_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention mechanism
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Post-attention layer normalization
        self.post_norm = nn.LayerNorm(hidden_dim)

        # Trace integration module (to incorporate reasoning trace T)
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Enhanced MLP for delta computation
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),  # Larger intermediate expansion
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

        # Final layer normalization
        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            reasoning_trace (tuple of torch.Tensor, optional): Reasoning trace from the base model.
                Each tensor has shape (batch_size, seq_len, hidden_dim).

        Returns:
            delta (torch.Tensor): The computed update delta (batch_size, seq_len, hidden_dim).
            attn_weights (torch.Tensor): Attention weights (batch_size, seq_len, seq_len).
        """
        batch_size, seq_len, _ = h0.shape

        # --- Pre-norm architecture ---
        # Apply layer normalization before attention (pre-norm)
        h0_norm = self.pre_norm(h0)

        # --- Attention over the initial hidden states ---
        # Query, key, and value are all h0_norm; attention weights are returned
        # so attention patterns can be visualized.
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True
        )

        # Residual connection and post-norm
        c = self.post_norm(h0 + attn_output)

        # --- Reasoning trace integration (if provided) ---
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            # Use the last layer from the reasoning trace (most semantic)
            last_layer = reasoning_trace[-1]
            # Integrate the reasoning trace with the current context
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )
            # Add the trace information to the context
            c = c + trace_info

        # --- Enhanced MLP for the delta ---
        # Concatenate the original h0 with the context c
        mlp_input = torch.cat((h0, c), dim=-1)
        # Compute the delta through the enhanced MLP
        delta = self.delta_mlp(mlp_input)
        # Apply final normalization
        delta = self.final_norm(delta)

        return delta, attn_weights


class GatingMechanism(nn.Module):
    """
    Gating mechanism to selectively apply updates.
    Learns when to apply the delta update based on the hidden state and delta.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        self.gate_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output between 0 and 1
        )

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).

        Returns:
            gate (torch.Tensor): Gate values between 0 and 1 (batch_size, seq_len, 1).
        """
        # Concatenate h0 and delta
        gate_input = torch.cat([h0, delta], dim=-1)
        # Compute gate values
        gate = self.gate_network(gate_input)
        return gate
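

# Illustrative sketch (not part of the original module's API): how the gate from
# GatingMechanism might modulate the delta from CrossAttentionDelta. The combination
# h1 = h0 + gate * delta is an assumption for clarity; the actual composition is
# expected to live in the parent model rather than in this file.
def example_gated_update(h0, delta, gate):
    """Hypothetical helper: apply a gated residual update h1 = h0 + gate * delta."""
    # gate has shape (batch_size, seq_len, 1) and broadcasts over hidden_dim
    return h0 + gate * delta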


class EnhancedQAHead(nn.Module):
    """
    Enhanced question answering head with a deeper architecture and bilinear scoring.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # Deeper representation before prediction
        self.start_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.end_transform = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Bilinear layer for start position scoring
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)
        # Bilinear layer for end position scoring
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Global representation for bilinear scoring
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states (batch_size, seq_len, hidden_dim).

        Returns:
            dict: Dictionary with start_logits and end_logits.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Transform hidden states
        start_rep = self.start_transform(hidden_states)
        end_rep = self.end_transform(hidden_states)

        # Expand the global representation to (batch_size, seq_len, hidden_dim)
        global_rep = self.global_rep.expand(batch_size, seq_len, -1)

        # Compute start and end logits using bilinear scoring
        start_logits = self.start_bilinear(start_rep, global_rep).squeeze(-1)
        end_logits = self.end_bilinear(end_rep, global_rep).squeeze(-1)

        return {"start_logits": start_logits, "end_logits": end_logits}