| """ |
| StructuredReasoningDecoder: Cross-modal decoder that generates abstract reasoning outputs. |
| Cross-attends from graph node embeddings — never from raw text. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
|
|
|
|
class GraphPositionalEncoding(nn.Module):
    """Learned positional encoding built from graph node index and tree depth.

    Two embedding tables are summed: one indexed by node position and one
    indexed by tree depth. Out-of-range indices are clamped into the valid
    range of their table instead of raising.

    Args:
        hidden_dim: Size of each embedding vector.
        max_nodes: Number of rows in the node-position table.
        max_depth: Number of rows in the depth table (defaults to 16, the
            value that was previously hard-coded).

    Raises:
        ValueError: If ``max_nodes`` or ``max_depth`` is not positive.
    """

    def __init__(self, hidden_dim: int, max_nodes: int = 256, max_depth: int = 16):
        super().__init__()
        if max_nodes < 1 or max_depth < 1:
            raise ValueError(
                f"max_nodes and max_depth must be positive, got "
                f"max_nodes={max_nodes}, max_depth={max_depth}"
            )
        self.max_nodes = max_nodes
        self.max_depth = max_depth
        self.node_pos_embed = nn.Embedding(max_nodes, hidden_dim)
        self.depth_embed = nn.Embedding(max_depth, hidden_dim)
        nn.init.normal_(self.node_pos_embed.weight, std=0.02)
        nn.init.normal_(self.depth_embed.weight, std=0.02)

    def forward(self, graph_positions, tree_depths=None):
        """Return the positional embedding for each index.

        Args:
            graph_positions: Integer tensor of node indices; values are
                clamped to ``[0, max_nodes - 1]``.
            tree_depths: Optional integer tensor of depths, broadcastable
                against ``graph_positions``; values are clamped to
                ``[0, max_depth - 1]``. When ``None`` only the node-position
                embedding is returned.

        Returns:
            Tensor of shape ``graph_positions.shape + (hidden_dim,)``.
        """
        pos_emb = self.node_pos_embed(graph_positions.clamp(0, self.max_nodes - 1))
        if tree_depths is not None:
            # The clamp bound is derived from the table size so the two can
            # never drift apart.
            pos_emb = pos_emb + self.depth_embed(tree_depths.clamp(0, self.max_depth - 1))
        return pos_emb
|
|
|
|
class CrossModalDecoderBlock(nn.Module):
    """Pre-LayerNorm decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is applied to a LayerNorm-ed copy of the input and added
    back through a residual connection with dropout.

    ``config`` must provide ``hidden_dim``, ``num_heads``, ``ffn_dim`` and
    ``dropout``.

    Raises:
        ValueError: If ``config.hidden_dim`` is not divisible by
            ``config.num_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_dim % config.num_heads != 0:
            # Without this check the failure only shows up later as an
            # opaque ``view`` size error inside the attention code.
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.self_attn_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3)
        self.self_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_q = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2)
        self.cross_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, config.ffn_dim),
            nn.GELU(), nn.Dropout(config.dropout),
            nn.Linear(config.ffn_dim, config.hidden_dim),
        )
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.ln3 = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

    def _attend(self, q, k, v, mask=None):
        """Scaled dot-product attention over per-head tensors.

        ``q`` is ``(B, H, L, head_dim)``; ``k`` and ``v`` are
        ``(B, H, S, head_dim)``. ``mask`` must broadcast to ``(B, H, L, S)``;
        positions where it equals 0 are excluded. Returns the merged heads
        as ``(B, L, H * head_dim)``.
        """
        B, _, L, _ = q.shape
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # NOTE: a row whose keys are all masked yields NaN after softmax.
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        return out.view(B, L, self.num_heads * self.head_dim)

    def _self_attention(self, x, mask=None):
        """Multi-head self-attention over ``x`` of shape ``(B, L, hidden_dim)``.

        ``mask`` (optional) must broadcast to ``(B, num_heads, L, L)``;
        zero entries are masked out.
        """
        B, L, _ = x.shape
        qkv = self.self_attn_qkv(x).view(B, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        return self.self_attn_out(self._attend(q, k, v, mask))

    def _cross_attention(self, q_state, kv_state, mask=None):
        """Multi-head attention from ``q_state`` onto ``kv_state``.

        Args:
            q_state: Queries, ``(B, L, hidden_dim)``.
            kv_state: Keys/values source, ``(B, N, hidden_dim)``.
            mask: Optional mask with zeros at excluded key positions. Accepts
                a key-padding mask ``(B, N)``, a per-query mask ``(B, L, N)``,
                or anything already broadcastable to ``(B, num_heads, L, N)``.
        """
        B, L, _ = q_state.shape
        N = kv_state.shape[1]
        q = self.cross_attn_q(q_state).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.cross_attn_kv(kv_state).view(B, N, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4)
        if mask is not None:
            if mask.dim() == 2:
                # (B, N) -> (B, 1, 1, N). A plain unsqueeze(1) would give
                # (B, 1, N), which broadcasts the batch axis against heads.
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                # (B, L, N) -> (B, 1, L, N): share the mask across heads.
                mask = mask.unsqueeze(1)
        return self.cross_attn_out(self._attend(q, k, v, mask))

    def forward(self, x, encoder_hidden, self_mask=None, cross_mask=None):
        """Apply self-attention, cross-attention and the FFN with residuals.

        Args:
            x: Decoder states, ``(B, L, hidden_dim)``.
            encoder_hidden: Encoder states attended to, ``(B, N, hidden_dim)``.
            self_mask: Optional self-attention mask (see ``_self_attention``).
            cross_mask: Optional cross-attention mask (see ``_cross_attention``).

        Returns:
            Updated decoder states with the same shape as ``x``.
        """
        x = x + self.dropout(self._self_attention(self.ln1(x), self_mask))
        x = x + self.dropout(self._cross_attention(self.ln2(x), encoder_hidden, cross_mask))
        x = x + self.dropout(self.ffn(self.ln3(x)))
        return x
|
|
|
|
class StructuredReasoningDecoder(nn.Module):
    """Autoregressive decoder that cross-attends to encoder hidden states.

    Token embeddings (optionally offset by graph positional encodings) pass
    through ``config.num_decoder_layers`` ``CrossModalDecoderBlock`` layers,
    a final LayerNorm, and a bias-free projection to vocabulary logits.

    ``config`` must provide ``hidden_dim``, ``vocab_size``,
    ``max_graph_nodes``, ``num_decoder_layers`` and ``dropout``, plus the
    fields read by ``CrossModalDecoderBlock``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.graph_pos_enc = GraphPositionalEncoding(config.hidden_dim, config.max_graph_nodes)
        self.layers = nn.ModuleList([
            CrossModalDecoderBlock(config) for _ in range(config.num_decoder_layers)
        ])
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        # Kept as a public attribute; it is not applied anywhere in forward().
        self.dropout = nn.Dropout(config.dropout)
        nn.init.normal_(self.token_embed.weight, std=0.02)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    @staticmethod
    def _build_self_mask(length: int, attention_mask: Optional[torch.Tensor], device):
        """Build the decoder self-attention mask (1 = attend, 0 = blocked).

        Returns the lower-triangular causal mask of shape ``(1, 1, L, L)``.
        When ``attention_mask`` is given it is combined with the causal mask:
        a ``(B, L)`` mask marks valid key positions, a ``(B, L, L)`` mask is
        shared across heads, and a 4-D mask is used as provided. The diagonal
        is always kept so no query row ends up fully masked (which would
        turn into NaN after softmax).
        """
        causal = torch.tril(torch.ones(length, length, device=device)).view(1, 1, length, length)
        if attention_mask is None:
            return causal
        keep = (attention_mask != 0).to(device=device, dtype=causal.dtype)
        if keep.dim() == 2:
            keep = keep[:, None, None, :]
        elif keep.dim() == 3:
            keep = keep[:, None, :, :]
        diagonal = torch.eye(length, device=device).view(1, 1, length, length)
        return torch.maximum(causal * keep, diagonal)

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Apply top-k then nucleus (top-p) filtering to ``(B, V)`` logits.

        Filtered entries are set to ``-inf``; the input tensor is not
        modified. ``top_k <= 0`` disables top-k and ``top_p >= 1.0`` disables
        nucleus filtering. The highest-scoring token is always kept.
        """
        if top_k > 0:
            # Clamp so a top_k larger than the vocabulary does not crash topk.
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, -1:]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the removal flags from sorted order back to vocabulary order.
            remove = torch.zeros_like(remove_sorted).scatter(1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    def forward(self, input_ids, encoder_hidden, graph_positions=None,
                tree_depths=None, attention_mask=None, encoder_mask=None):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Token ids, ``(B, L)``.
            encoder_hidden: Encoder states to cross-attend to,
                ``(B, N, hidden_dim)``.
            graph_positions: Optional node indices added as positional
                encodings; must produce embeddings broadcastable to
                ``(B, L, hidden_dim)``.
            tree_depths: Optional depths, used only when ``graph_positions``
                is given.
            attention_mask: Optional decoder-side mask (1 = valid, 0 = pad),
                combined with the causal mask; see ``_build_self_mask``.
            encoder_mask: Optional mask over encoder positions, forwarded to
                every layer as ``cross_mask``.

        Returns:
            Logits of shape ``(B, L, vocab_size)``.
        """
        _, L = input_ids.shape
        x = self.token_embed(input_ids)
        if graph_positions is not None:
            x = x + self.graph_pos_enc(graph_positions, tree_depths)
        self_mask = self._build_self_mask(L, attention_mask, input_ids.device)
        for layer in self.layers:
            x = layer(x, encoder_hidden, self_mask=self_mask, cross_mask=encoder_mask)
        x = self.final_ln(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, encoder_hidden, start_token_id, graph_positions=None,
                 encoder_mask=None, max_length=128, temperature=0.8,
                 top_k=50, top_p=0.9):
        """Sample token sequences autoregressively.

        Every step re-runs ``forward`` on the full prefix (no key/value
        cache) and always produces exactly ``max_length`` tokens; there is no
        end-of-sequence handling. The module's train/eval mode is left as is,
        so call ``.eval()`` first to disable dropout.

        Args:
            encoder_hidden: Encoder states, ``(B, N, hidden_dim)``.
            start_token_id: Id used as the first token of every sequence.
            graph_positions: Accepted for interface compatibility but not
                forwarded to ``forward``.
                NOTE(review): confirm whether positions should be applied
                during generation.
            encoder_mask: Optional mask over encoder positions.
            max_length: Total output length including the start token.
            temperature: Softmax temperature; ``0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` highest logits (``<= 0`` disables).
            top_p: Nucleus probability mass (``>= 1.0`` disables).

        Returns:
            Long tensor of shape ``(B, max(max_length, 1))``.

        Raises:
            ValueError: If ``temperature`` is negative.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")
        B = encoder_hidden.shape[0]
        device = encoder_hidden.device
        generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            logits = self.forward(generated, encoder_hidden, None, None, None, encoder_mask)
            next_logits = logits[:, -1, :]
            if temperature == 0:
                # Limit of temperature -> 0: take the arg-max instead of
                # dividing by zero. Filtering never removes the top token.
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                filtered = self._filter_logits(next_logits / temperature, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated
|
|