# Source: StruCTA repository, structa/decoder.py
# Uploaded by YOUSSEF88 ("Upload structa/decoder.py"), commit 94ebc81 (verified).
"""
StructuredReasoningDecoder: Cross-modal decoder that generates abstract reasoning outputs.
Cross-attends from graph node embeddings — never from raw text.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class GraphPositionalEncoding(nn.Module):
    """Learned positional encoding for tokens aligned to graph nodes.

    The output is a per-node-index embedding, optionally summed with a
    tree-depth embedding. Out-of-range indices are clamped into the valid
    range instead of raising, so every index >= the table size shares the
    last row and negative indices share row 0.

    Args:
        hidden_dim: Size of each embedding vector.
        max_nodes: Number of distinct node-position embeddings.
        max_depth: Number of distinct tree-depth embeddings. Defaults to 16,
            the previously hard-coded value.
    """

    def __init__(self, hidden_dim: int, max_nodes: int = 256, max_depth: int = 16):
        super().__init__()
        self.max_nodes = max_nodes
        self.max_depth = max_depth
        self.node_pos_embed = nn.Embedding(max_nodes, hidden_dim)
        self.depth_embed = nn.Embedding(max_depth, hidden_dim)
        nn.init.normal_(self.node_pos_embed.weight, std=0.02)
        nn.init.normal_(self.depth_embed.weight, std=0.02)

    def forward(
        self,
        graph_positions: torch.Tensor,
        tree_depths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed integer node positions (and optionally depths).

        Args:
            graph_positions: Integer tensor of node indices; output has this
                shape plus a trailing ``hidden_dim`` axis.
            tree_depths: Optional integer tensor with the same shape as
                ``graph_positions``.

        Returns:
            Float tensor of shape ``graph_positions.shape + (hidden_dim,)``.
        """
        pos_emb = self.node_pos_embed(graph_positions.clamp(0, self.max_nodes - 1))
        if tree_depths is not None:
            pos_emb = pos_emb + self.depth_embed(tree_depths.clamp(0, self.max_depth - 1))
        return pos_emb
class CrossModalDecoderBlock(nn.Module):
    """Pre-LayerNorm transformer decoder block.

    Applies, in order and each wrapped in a residual connection with dropout:

    1. multi-head self-attention over the decoder sequence,
    2. multi-head cross-attention whose keys/values come from ``encoder_hidden``,
    3. a GELU feed-forward network.

    Args:
        config: Object exposing ``hidden_dim``, ``num_heads``, ``ffn_dim`` and
            ``dropout``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``. Without
            this check, the per-head ``view`` in attention would fail later with
            a less helpful error.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_dim % config.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Fused projections: Q/K/V for self-attention, K/V for cross-attention.
        self.self_attn_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3)
        self.self_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_q = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2)
        self.cross_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, config.ffn_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.ffn_dim, config.hidden_dim),
        )
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.ln3 = nn.LayerNorm(config.hidden_dim)
        # Shared by the attention probabilities and all three residual branches.
        self.dropout = nn.Dropout(config.dropout)

    def _self_attention(self, x, mask=None):
        """Multi-head self-attention.

        Args:
            x: Tensor of shape (B, L, hidden_dim).
            mask: Optional tensor broadcastable to (B, num_heads, L, L). Entries
                equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (B, L, hidden_dim).
        """
        B, L, D = x.shape
        qkv = self.self_attn_qkv(x).view(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, L, head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, L, L)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        # NOTE: a row whose keys are all masked softmaxes to NaN; callers must
        # leave at least one key visible per query.
        attn = self.dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        return self.self_attn_out(out)

    @staticmethod
    def _expand_cross_mask(mask):
        """Reshape a cross-attention mask to broadcast against (B, H, L, N) scores.

        Accepted shapes:

        * (B, N): per-example key-padding mask, expanded to (B, 1, 1, N).
        * (B, L, N): per-query mask, expanded to (B, 1, L, N). This is the
          previous behaviour.
        * 4-D: assumed to be already broadcastable and returned unchanged.

        Raises:
            ValueError: For any other rank.
        """
        if mask.dim() == 2:
            return mask[:, None, None, :]
        if mask.dim() == 3:
            return mask.unsqueeze(1)
        if mask.dim() == 4:
            return mask
        raise ValueError(f"cross-attention mask must be 2-, 3- or 4-D, got {mask.dim()}-D")

    def _cross_attention(self, q_state, kv_state, mask=None):
        """Multi-head attention from decoder states to encoder states.

        Args:
            q_state: Decoder states, shape (B, L, hidden_dim). Queries come
                from here.
            kv_state: Encoder states, shape (B, N, hidden_dim). Keys and values
                come from here.
            mask: Optional (B, N), (B, L, N) or 4-D mask. Entries equal to 0
                hide the corresponding encoder position.

        Returns:
            Tensor of shape (B, L, hidden_dim).
        """
        B, L, D = q_state.shape
        _, N, _ = kv_state.shape
        q = self.cross_attn_q(q_state).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.cross_attn_kv(kv_state).view(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]  # each (B, H, N, head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, L, N)
        if mask is not None:
            attn = attn.masked_fill(self._expand_cross_mask(mask) == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        return self.cross_attn_out(out)

    def forward(self, x, encoder_hidden, self_mask=None, cross_mask=None):
        """Run self-attention, cross-attention and the FFN.

        Each sub-layer is pre-normalized and wrapped in a residual connection
        with dropout.

        Args:
            x: Decoder states, shape (B, L, hidden_dim).
            encoder_hidden: Encoder states, shape (B, N, hidden_dim).
            self_mask: Optional self-attention mask (0 = masked).
            cross_mask: Optional cross-attention mask (0 = masked); see
                ``_expand_cross_mask`` for the accepted shapes.

        Returns:
            Tensor of shape (B, L, hidden_dim).
        """
        x = x + self.dropout(self._self_attention(self.ln1(x), self_mask))
        x = x + self.dropout(self._cross_attention(self.ln2(x), encoder_hidden, cross_mask))
        x = x + self.dropout(self.ffn(self.ln3(x)))
        return x
class StructuredReasoningDecoder(nn.Module):
    """Autoregressive token decoder that cross-attends to encoder states.

    Token embeddings (optionally plus graph positional encodings) pass through
    a stack of ``CrossModalDecoderBlock`` layers, a final LayerNorm and a
    bias-free LM head.

    Args:
        config: Object exposing ``hidden_dim``, ``vocab_size``,
            ``max_graph_nodes`` and ``num_decoder_layers``, plus the fields
            ``CrossModalDecoderBlock`` reads (``num_heads``, ``ffn_dim``,
            ``dropout``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.graph_pos_enc = GraphPositionalEncoding(config.hidden_dim, config.max_graph_nodes)
        self.layers = nn.ModuleList([
            CrossModalDecoderBlock(config) for _ in range(config.num_decoder_layers)
        ])
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        # NOTE(review): defined but not applied in forward — confirm whether
        # embedding dropout was intended.
        self.dropout = nn.Dropout(config.dropout)
        nn.init.normal_(self.token_embed.weight, std=0.02)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    @staticmethod
    def _build_self_mask(seq_len, attention_mask, device):
        """Build a {0, 1} self-attention mask: causal, optionally AND-ed with padding.

        ``attention_mask`` may be (B, L), where 0 marks a padding token that
        must not be attended to, or any tensor broadcastable to (B, 1, L, L).
        The diagonal is always kept visible. Without that, a padded query whose
        whole prefix is padding would get an all ``-inf`` row, producing NaN
        that later layers would spread through ``attn @ v``.
        """
        causal = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(1, 1, seq_len, seq_len)
        if attention_mask is None:
            return causal
        keep = (attention_mask != 0).to(causal.dtype)
        if keep.dim() == 2:
            keep = keep[:, None, None, :]  # (B, L) -> (B, 1, 1, L): mask keys
        mask = causal * keep
        eye = torch.eye(seq_len, device=device, dtype=causal.dtype).view(1, 1, seq_len, seq_len)
        return torch.maximum(mask, eye)

    def forward(self, input_ids, encoder_hidden, graph_positions=None,
                tree_depths=None, attention_mask=None, encoder_mask=None):
        """Compute next-token logits for every position.

        Args:
            input_ids: Long tensor (B, L) of token ids.
            encoder_hidden: Encoder states (B, N, hidden_dim) used as
                cross-attention keys/values.
            graph_positions: Optional (B, L) node indices. When given, adds the
                graph positional encoding.
            tree_depths: Optional (B, L) depths. Only used together with
                ``graph_positions``.
            attention_mask: Optional (B, L) decoder padding mask (0 = padding),
                combined with the causal mask.
            encoder_mask: Optional cross-attention mask (0 = masked), passed
                through to every layer.

        Returns:
            Logits of shape (B, L, vocab_size).
        """
        seq_len = input_ids.shape[1]
        x = self.token_embed(input_ids)
        if graph_positions is not None:
            x = x + self.graph_pos_enc(graph_positions, tree_depths)
        self_mask = self._build_self_mask(seq_len, attention_mask, input_ids.device)
        for layer in self.layers:
            x = layer(x, encoder_hidden, self_mask=self_mask, cross_mask=encoder_mask)
        return self.lm_head(self.final_ln(x))

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Return ``logits`` with tokens outside the top-k / nucleus sets set to -inf.

        ``top_k`` is clamped to the vocabulary size; ``top_k <= 0`` disables it.
        ``top_p >= 1`` disables nucleus filtering. The highest-scoring token is
        always kept.
        """
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the token that first crosses top_p is kept, and
            # never drop the single best token.
            sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
            sorted_remove[..., 0] = False
            # Map the removal flags from sorted order back to vocabulary order.
            remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(self, encoder_hidden, start_token_id, graph_positions=None,
                 encoder_mask=None, max_length=128, temperature=0.8,
                 top_k=50, top_p=0.9):
        """Sample token sequences starting from ``start_token_id``.

        Each step re-runs ``forward`` over the whole prefix (there is no KV
        cache). This method does not change train/eval mode, so call
        ``.eval()`` first if dropout should be off.

        Args:
            encoder_hidden: Encoder states (B, N, hidden_dim).
            start_token_id: Token id placed at position 0 of every sequence.
            graph_positions: Optional (B, >= max_length - 1) node indices,
                aligned with output positions. Each step uses the slice
                covering the current prefix.
            encoder_mask: Optional cross-attention mask passed to ``forward``.
            max_length: Total output length, including the start token.
            temperature: Softmax temperature. Values ``<= 0`` select greedy
                (argmax) decoding.
            top_k: Keep only the k highest logits (``<= 0`` disables it;
                clamped to the vocabulary size).
            top_p: Nucleus threshold (``>= 1`` disables it).

        Returns:
            Long tensor of shape (B, max(max_length, 1)).

        Raises:
            ValueError: If ``graph_positions`` is too short for ``max_length``.
        """
        batch_size = encoder_hidden.shape[0]
        device = encoder_hidden.device
        generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        steps = max_length - 1
        if graph_positions is not None and steps > 0 and graph_positions.shape[1] < steps:
            raise ValueError(
                f"graph_positions covers {graph_positions.shape[1]} positions, "
                f"but generation needs at least {steps}"
            )
        for _ in range(steps):
            prefix_len = generated.shape[1]
            positions = None if graph_positions is None else graph_positions[:, :prefix_len]
            logits = self.forward(
                generated, encoder_hidden,
                graph_positions=positions, encoder_mask=encoder_mask,
            )
            next_logits = logits[:, -1, :]
            if temperature <= 0:
                # Greedy decoding; avoids dividing by zero.
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                next_logits = self._filter_logits(next_logits / temperature, top_k, top_p)
                next_token = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated