""" StructuredReasoningDecoder: Cross-modal decoder that generates abstract reasoning outputs. Cross-attends from graph node embeddings — never from raw text. """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class GraphPositionalEncoding(nn.Module): def __init__(self, hidden_dim: int, max_nodes: int = 256): super().__init__() self.max_nodes = max_nodes self.node_pos_embed = nn.Embedding(max_nodes, hidden_dim) self.depth_embed = nn.Embedding(16, hidden_dim) nn.init.normal_(self.node_pos_embed.weight, std=0.02) nn.init.normal_(self.depth_embed.weight, std=0.02) def forward(self, graph_positions, tree_depths=None): pos_emb = self.node_pos_embed(graph_positions.clamp(0, self.max_nodes - 1)) if tree_depths is not None: pos_emb = pos_emb + self.depth_embed(tree_depths.clamp(0, 15)) return pos_emb class CrossModalDecoderBlock(nn.Module): def __init__(self, config): super().__init__() self.hidden_dim = config.hidden_dim self.num_heads = config.num_heads self.head_dim = config.hidden_dim // config.num_heads self.scale = 1.0 / math.sqrt(self.head_dim) self.self_attn_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3) self.self_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim) self.cross_attn_q = nn.Linear(config.hidden_dim, config.hidden_dim) self.cross_attn_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2) self.cross_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim) self.ffn = nn.Sequential( nn.Linear(config.hidden_dim, config.ffn_dim), nn.GELU(), nn.Dropout(config.dropout), nn.Linear(config.ffn_dim, config.hidden_dim), ) self.ln1 = nn.LayerNorm(config.hidden_dim) self.ln2 = nn.LayerNorm(config.hidden_dim) self.ln3 = nn.LayerNorm(config.hidden_dim) self.dropout = nn.Dropout(config.dropout) def _self_attention(self, x, mask=None): B, L, D = x.shape qkv = self.self_attn_qkv(x).view(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale if mask is not None: attn = attn.masked_fill(mask == 0, float("-inf")) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D) return self.self_attn_out(out) def _cross_attention(self, q_state, kv_state, mask=None): B, L, D = q_state.shape _, N, _ = kv_state.shape q = self.cross_attn_q(q_state).view(B, L, self.num_heads, self.head_dim).transpose(1, 2) kv = self.cross_attn_kv(kv_state).view(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale if mask is not None: attn = attn.masked_fill(mask.unsqueeze(1) == 0, float("-inf")) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D) return self.cross_attn_out(out) def forward(self, x, encoder_hidden, self_mask=None, cross_mask=None): x = x + self.dropout(self._self_attention(self.ln1(x), self_mask)) x = x + self.dropout(self._cross_attention(self.ln2(x), encoder_hidden, cross_mask)) x = x + self.dropout(self.ffn(self.ln3(x))) return x class StructuredReasoningDecoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_dim = config.hidden_dim self.vocab_size = config.vocab_size self.token_embed = nn.Embedding(config.vocab_size, config.hidden_dim) self.graph_pos_enc = GraphPositionalEncoding(config.hidden_dim, config.max_graph_nodes) self.layers = nn.ModuleList([ CrossModalDecoderBlock(config) for _ in range(config.num_decoder_layers) ]) self.final_ln = nn.LayerNorm(config.hidden_dim) self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False) self.dropout = nn.Dropout(config.dropout) nn.init.normal_(self.token_embed.weight, std=0.02) nn.init.normal_(self.lm_head.weight, std=0.02) def forward(self, input_ids, encoder_hidden, graph_positions=None, tree_depths=None, attention_mask=None, encoder_mask=None): B, L = input_ids.shape x = self.token_embed(input_ids) if graph_positions is not None: x = x + self.graph_pos_enc(graph_positions, tree_depths) causal_mask = torch.tril(torch.ones(L, L, device=input_ids.device)).view(1, 1, L, L) for layer in self.layers: x = layer(x, encoder_hidden, self_mask=causal_mask, cross_mask=encoder_mask) x = self.final_ln(x) return self.lm_head(x) @torch.no_grad() def generate(self, encoder_hidden, start_token_id, graph_positions=None, encoder_mask=None, max_length=128, temperature=0.8, top_k=50, top_p=0.9): B = encoder_hidden.shape[0] device = encoder_hidden.device generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device) for _ in range(max_length - 1): logits = self.forward(generated, encoder_hidden, None, None, None, encoder_mask) next_logits = logits[:, -1, :] / temperature if top_k > 0: v, _ = torch.topk(next_logits, top_k) next_logits[next_logits < v[:, -1:]] = float("-inf") if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = False for b in range(B): indices_to_remove = sorted_indices[b][sorted_indices_to_remove[b]] next_logits[b, indices_to_remove] = float("-inf") probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated = torch.cat([generated, next_token], dim=1) return generated