|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import config
|
|
|
|
|
|
class CrossAttentionDelta(nn.Module):
    """
    Computes an update delta (Δ) for hidden states using attention.

    Architecture:
    1. Pre-norm (layer norm before attention) for training stability.
    2. Attention over the initial hidden states. NOTE(review): despite the
       class name, query, key, and value are all the same tensor here, so
       this is self-attention as written — confirm whether an external
       context was intended as key/value.
    3. Optional integration of a reasoning trace from the base model.
    4. A deep MLP mapping [h0, context] to the final delta.
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        """
        Args:
            hidden_dim (int): Dimensionality of the hidden states.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability used throughout.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-norm: normalize inputs before attention.
        self.pre_norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.post_norm = nn.LayerNorm(hidden_dim)

        # Fuses the attended context with the last reasoning-trace layer.
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Maps the concatenation [h0, context] to the delta.
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states
                (batch_size, seq_len, hidden_dim).
            reasoning_trace (sequence of torch.Tensor, optional): Reasoning
                trace from the base model; only the last tensor is used.
                Each tensor has shape (batch_size, seq_len, hidden_dim).

        Returns:
            tuple:
                delta (torch.Tensor): The computed update delta
                    (batch_size, seq_len, hidden_dim).
                attn_weights (torch.Tensor): Attention weights averaged over
                    heads (batch_size, seq_len, seq_len).
        """
        # Pre-norm before attention; the residual below uses the raw h0.
        h0_norm = self.pre_norm(h0)

        # need_weights=True returns weights averaged across heads by default.
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True,
        )

        # Residual connection + post-norm yields the attended context.
        c = self.post_norm(h0 + attn_output)

        # Optionally fold in information from the last reasoning-trace layer.
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            last_layer = reasoning_trace[-1]
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )
            c = c + trace_info

        # Delta is computed from both the original states and the context.
        mlp_input = torch.cat((h0, c), dim=-1)
        delta = self.delta_mlp(mlp_input)
        delta = self.final_norm(delta)

        return delta, attn_weights
|
|
|
|
|
|
class GatingMechanism(nn.Module):
    """
    Learned gate controlling how strongly a delta update is applied.

    Produces a per-position scalar in (0, 1) from the concatenation of the
    initial hidden state and the proposed delta.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # Three-layer MLP ending in a sigmoid so the gate lies in (0, 1).
        self.gate_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states
                (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta
                (batch_size, seq_len, hidden_dim).

        Returns:
            torch.Tensor: Gate values between 0 and 1
                (batch_size, seq_len, 1).
        """
        features = torch.cat((h0, delta), dim=-1)
        return self.gate_network(features)
|
|
|
|
|
|
class EnhancedQAHead(nn.Module):
    """
    Question-answering head scoring span starts and ends with bilinear forms.

    Each position is transformed by a small MLP and then scored against a
    learned global representation via a bilinear interaction.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()

        # Both start and end use the same two-layer MLP shape; build them
        # with a local factory (same construction order as before).
        def _position_mlp():
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
            )

        self.start_transform = _position_mlp()
        self.end_transform = _position_mlp()

        # Bilinear scorers comparing each position to the global vector.
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Learned global representation, broadcast to every position.
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states
                (batch_size, seq_len, hidden_dim).

        Returns:
            dict: "start_logits" and "end_logits", each of shape
                (batch_size, seq_len).
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Position-wise features for start/end scoring.
        start_features = self.start_transform(hidden_states)
        end_features = self.end_transform(hidden_states)

        # Expand the global vector across batch and sequence (no copy).
        global_context = self.global_rep.expand(batch_size, seq_len, -1)

        return {
            "start_logits": self.start_bilinear(start_features, global_context).squeeze(-1),
            "end_logits": self.end_bilinear(end_features, global_context).squeeze(-1),
        }
|
|
|
|