"""
StructuredReasoningDecoder: Cross-modal decoder that generates abstract reasoning outputs.
Cross-attends from graph node embeddings — never from raw text.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class GraphPositionalEncoding(nn.Module):
    """Learned positional encoding indexed by graph position and tree depth.

    The result is the node-position embedding, plus the depth embedding when
    ``tree_depths`` is supplied. Indices outside a table are clamped to its
    first/last row instead of raising.

    Args:
        hidden_dim: Width of the returned embeddings.
        max_nodes: Number of rows in the node-position table.
        max_depth: Number of rows in the depth table. Defaults to 16, the
            value that used to be hard-coded.

    Raises:
        ValueError: If ``max_nodes`` or ``max_depth`` is smaller than 1.
    """

    def __init__(self, hidden_dim: int, max_nodes: int = 256, max_depth: int = 16):
        super().__init__()
        if max_nodes < 1 or max_depth < 1:
            raise ValueError(
                f"max_nodes and max_depth must be >= 1, got {max_nodes} and {max_depth}"
            )
        self.max_nodes = max_nodes
        self.max_depth = max_depth
        self.node_pos_embed = nn.Embedding(max_nodes, hidden_dim)
        self.depth_embed = nn.Embedding(max_depth, hidden_dim)
        nn.init.normal_(self.node_pos_embed.weight, std=0.02)
        nn.init.normal_(self.depth_embed.weight, std=0.02)

    def forward(
        self,
        graph_positions: torch.Tensor,
        tree_depths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return positional embeddings for the given indices.

        Args:
            graph_positions: Integer tensor of node indices; values are
                clamped into ``[0, max_nodes - 1]``.
            tree_depths: Optional integer tensor of depths, clamped into
                ``[0, max_depth - 1]``. It is added element-wise to the node
                embeddings, so it must broadcast against ``graph_positions``.

        Returns:
            Tensor shaped like ``graph_positions`` with a trailing
            ``hidden_dim`` axis.
        """
        # Clamp rather than raise so oversized graphs reuse the last row.
        pos_emb = self.node_pos_embed(graph_positions.clamp(0, self.max_nodes - 1))
        if tree_depths is not None:
            pos_emb = pos_emb + self.depth_embed(tree_depths.clamp(0, self.max_depth - 1))
        return pos_emb
class CrossModalDecoderBlock(nn.Module):
    """Pre-LayerNorm decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer reads a layer-normalised copy of the running state and its
    output is added back through dropout (residual connection).

    Args:
        config: Object exposing ``hidden_dim``, ``num_heads``, ``ffn_dim`` and
            ``dropout``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``; the
            head reshape in the attention methods would otherwise fail later.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_dim % config.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({config.hidden_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.hidden_dim = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # Fused projections: one matmul yields Q, K and V (or K and V).
        self.self_attn_qkv = nn.Linear(config.hidden_dim, config.hidden_dim * 3)
        self.self_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_q = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.cross_attn_kv = nn.Linear(config.hidden_dim, config.hidden_dim * 2)
        self.cross_attn_out = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, config.ffn_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.ffn_dim, config.hidden_dim),
        )
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.ln3 = nn.LayerNorm(config.hidden_dim)
        # Shared by the attention weights and the residual branches.
        self.dropout = nn.Dropout(config.dropout)

    @staticmethod
    def _expand_cross_mask(mask):
        """Reshape a cross-attention mask so it broadcasts over (B, H, L, N) scores."""
        if mask.dim() == 2:
            # (B, N) key-padding mask: add head and query axes.
            return mask[:, None, None, :]
        if mask.dim() == 3:
            # (B, L, N) per-query mask: add the head axis.
            return mask.unsqueeze(1)
        # Four-dimensional masks are used unchanged.
        return mask

    def _self_attention(self, x, mask=None):
        """Multi-head self-attention over ``x`` of shape (B, L, D).

        ``mask`` must broadcast against the (B, H, L, L) score tensor;
        positions where it equals 0 are excluded from the softmax.
        """
        B, L, D = x.shape
        qkv = self.self_attn_qkv(x).view(B, L, 3, self.num_heads, self.head_dim)
        # -> (3, B, H, L, head_dim) so Q, K and V can be unpacked along axis 0.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        return self.self_attn_out(out)

    def _cross_attention(self, q_state, kv_state, mask=None):
        """Multi-head attention from ``q_state`` (B, L, D) onto ``kv_state`` (B, N, D).

        ``mask`` may be (B, N), (B, L, N) or already broadcastable to
        (B, H, L, N); entries equal to 0 are excluded from the softmax.
        A query row whose keys are all masked softmaxes over only ``-inf``
        and therefore yields NaN.
        """
        B, L, D = q_state.shape
        _, N, _ = kv_state.shape
        q = self.cross_attn_q(q_state).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        kv = self.cross_attn_kv(kv_state).view(B, N, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            # A plain unsqueeze(1) on a (B, N) mask would line the batch axis
            # up with the head axis; expand by rank instead.
            attn = attn.masked_fill(self._expand_cross_mask(mask) == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, D)
        return self.cross_attn_out(out)

    def forward(self, x, encoder_hidden, self_mask=None, cross_mask=None):
        """Apply the three residual sub-layers and return the updated state.

        Args:
            x: Decoder state of shape (B, L, D).
            encoder_hidden: Memory of shape (B, N, D) that cross-attention reads.
            self_mask: Optional self-attention mask (0 = blocked).
            cross_mask: Optional cross-attention mask (0 = blocked).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.dropout(self._self_attention(self.ln1(x), self_mask))
        x = x + self.dropout(self._cross_attention(self.ln2(x), encoder_hidden, cross_mask))
        x = x + self.dropout(self.ffn(self.ln3(x)))
        return x
class StructuredReasoningDecoder(nn.Module):
    """Autoregressive decoder that cross-attends to encoder (graph node) states.

    Token embeddings, optionally offset by graph positional encodings, pass
    through a stack of ``CrossModalDecoderBlock`` layers, a final LayerNorm
    and a linear head producing vocabulary logits.

    Args:
        config: Object exposing ``hidden_dim``, ``vocab_size``,
            ``max_graph_nodes``, ``num_decoder_layers`` and ``dropout``, plus
            whatever ``CrossModalDecoderBlock`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.graph_pos_enc = GraphPositionalEncoding(config.hidden_dim, config.max_graph_nodes)
        self.layers = nn.ModuleList(
            [CrossModalDecoderBlock(config) for _ in range(config.num_decoder_layers)]
        )
        self.final_ln = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        # NOTE(review): this module is created but forward() never applies it;
        # kept so the attribute remains available to existing callers.
        self.dropout = nn.Dropout(config.dropout)
        nn.init.normal_(self.token_embed.weight, std=0.02)
        nn.init.normal_(self.lm_head.weight, std=0.02)

    @staticmethod
    def _build_self_mask(seq_len, attention_mask, device):
        """Build the boolean self-attention mask (True = may attend).

        Returns a (1, 1, L, L) causal mask, or that mask combined with
        ``attention_mask`` when one is given. Supported ranks: (B, L) key
        padding, (B, L, L) per-query, or four-dimensional and broadcastable.
        """
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
        causal = causal.view(1, 1, seq_len, seq_len)
        if attention_mask is None:
            return causal
        keep = attention_mask.to(device) != 0
        if keep.dim() == 2:
            keep = keep[:, None, None, :]
        elif keep.dim() == 3:
            keep = keep[:, None, :, :]
        # Always allow a position to see itself: a row with every key masked
        # (e.g. left padding) would otherwise softmax over only -inf -> NaN.
        diagonal = torch.eye(seq_len, dtype=torch.bool, device=device).view(1, 1, seq_len, seq_len)
        return (causal & keep) | diagonal

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Apply top-k then nucleus (top-p) filtering to (B, V) logits.

        Filtered entries are set to ``-inf``; the input tensor is not modified.
        """
        if top_k > 0:
            # topk raises when k exceeds the vocabulary size, so clamp it.
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, -1:]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token that crosses top_p is still kept,
            # and never drop the most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the decision from sorted order back to vocabulary order.
            remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))
        return logits

    def forward(self, input_ids, encoder_hidden, graph_positions=None,
                tree_depths=None, attention_mask=None, encoder_mask=None):
        """Compute next-token logits for every position of ``input_ids``.

        Args:
            input_ids: Long tensor of token ids, shape (B, L).
            encoder_hidden: Encoder states the layers cross-attend to.
            graph_positions: Optional node indices for the positional
                encoding; skipped when ``None``.
            tree_depths: Optional depths forwarded to the positional encoding.
            attention_mask: Optional decoder-side mask (0 = ignore), combined
                with the causal mask. See ``_build_self_mask`` for shapes.
            encoder_mask: Optional mask over encoder positions (0 = ignore),
                either (B, N) or (B, L, N).

        Returns:
            Logits of shape (B, L, vocab_size).
        """
        _, seq_len = input_ids.shape
        x = self.token_embed(input_ids)
        if graph_positions is not None:
            x = x + self.graph_pos_enc(graph_positions, tree_depths)
        self_mask = self._build_self_mask(seq_len, attention_mask, input_ids.device)
        if encoder_mask is not None and encoder_mask.dim() == 2:
            # The layers add only a head axis to the cross mask; give a
            # (B, N) padding mask its query axis here so it broadcasts per batch.
            encoder_mask = encoder_mask.unsqueeze(1)
        for layer in self.layers:
            x = layer(x, encoder_hidden, self_mask=self_mask, cross_mask=encoder_mask)
        return self.lm_head(self.final_ln(x))

    @torch.no_grad()
    def generate(self, encoder_hidden, start_token_id, graph_positions=None,
                 encoder_mask=None, max_length=128, temperature=0.8,
                 top_k=50, top_p=0.9):
        """Sample token sequences autoregressively.

        The full prefix is re-encoded at every step (no key/value cache). The
        module's train/eval mode is not changed here, so call ``eval()`` first
        if dropout should be disabled.

        Args:
            encoder_hidden: Encoder states; the batch size is read from axis 0.
            start_token_id: Token id placed at position 0 of every sequence.
            graph_positions: Accepted for interface compatibility; it is not
                used while generating.
            encoder_mask: Optional mask over encoder positions.
            max_length: Total output length including the start token.
            temperature: Divides the logits before sampling; ``0`` selects the
                arg-max token (greedy decoding) instead of dividing by zero.
            top_k: Keep only the ``top_k`` highest logits; ``0`` disables.
            top_p: Nucleus threshold; ``1.0`` or more disables.

        Returns:
            Long tensor of shape (B, max_length); (B, 1) if ``max_length <= 1``.
        """
        batch_size = encoder_hidden.shape[0]
        device = encoder_hidden.device
        generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_length - 1):
            logits = self.forward(generated, encoder_hidden, encoder_mask=encoder_mask)
            next_logits = logits[:, -1, :]
            if temperature == 0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                filtered = self._filter_logits(next_logits / temperature, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated
|