"""AAM Diffusion LLM — Anchored Diffusion Decoder Replaces the standard softmax → token ID pipeline with: 1. Model predicts continuous vector (NO softmax) 2. 2-3 step anchored diffusion refinement 3. Disambiguation + coherence + Evoformer feedback 4. Final projection to vocabulary Key Insight (from Losion): Standard diffusion LLM: starts from NOISE → needs 50-1000 steps Anchored diffusion: starts from PREDICTED VECTOR (already meaningful) → 2-3 steps only The predicted vector serves as an "anchor" — it's already in the right neighborhood of the output space. The decoder just needs to refine it. AAM-specific: The anchor comes from graph-conditioned denoising, so it's already shaped by evidence/anomaly/reasoning from the RSVS Knowledge Graph. """ from __future__ import annotations import math from dataclasses import dataclass from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F @dataclass class AnchoredDecoderConfig: """Configuration for Anchored Diffusion Decoder.""" d_model: int = 768 d_vocab: int = 32000 n_refine_steps: int = 3 d_refine: int = 512 use_evoformer_feedback: bool = True n_feedback_iterations: int = 2 disambiguation_heads: int = 8 class DisambiguationBlock(nn.Module): """Resolve between similar tokens based on graph context. The predicted continuous vector may fall between two tokens with similar meanings (e.g., "bukti" vs "dugaan"). This block uses local context and graph-conditioned attention to disambiguate. """ def __init__(self, d_model: int, n_heads: int = 8) -> None: super().__init__() self.d_model = d_model self.n_heads = n_heads self.d_kv = d_model // n_heads self.q_proj = nn.Linear(d_model, d_model, bias=False) self.k_proj = nn.Linear(d_model, d_model, bias=False) self.v_proj = nn.Linear(d_model, d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) self.gate = nn.Sequential( nn.Linear(d_model, 1, bias=False), nn.Sigmoid(), ) self.norm = nn.RMSNorm(d_model) self.scale = math.sqrt(self.d_kv) def forward(self, x: torch.Tensor, graph_context: Optional[torch.Tensor] = None) -> torch.Tensor: batch, seq_len, _ = x.shape q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) # Use graph context as key/value if available, otherwise self-attention if graph_context is not None: k = self.k_proj(graph_context) v = self.v_proj(graph_context) if k.dim() == 3: k = k.unsqueeze(1).expand(-1, self.n_heads, -1, -1).reshape(batch, -1, self.d_kv) v = v.unsqueeze(1).expand(-1, self.n_heads, -1, -1).reshape(batch, -1, self.d_kv) k = k.unsqueeze(1).transpose(1, 2) if k.dim() == 3 else k # Simplified: use x for k,v if graph_context shape is tricky k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) else: k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) # Causal mask scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1) scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf")) attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(x.dtype) context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_model) context = self.out_proj(context) gate = self.gate(x) refined = x + gate * context refined = self.norm(refined) return refined class CoherenceBlock(nn.Module): """Ensure parallel tokens are consistent with each other and the graph. When predicting multiple tokens in parallel (from the continuous vector pipeline), each token's vector is predicted independently. This block ensures they are coherent as a sequence. """ def __init__(self, d_model: int, d_refine: int = 512) -> None: super().__init__() self.d_model = d_model self.coherence_mlp = nn.Sequential( nn.Linear(d_model, d_refine, bias=False), nn.SiLU(), nn.Linear(d_refine, d_model, bias=False), ) self.gate = nn.Sequential( nn.Linear(d_model, 1, bias=False), nn.Sigmoid(), ) self.norm = nn.RMSNorm(d_model) def forward(self, x: torch.Tensor) -> torch.Tensor: mlp_out = self.coherence_mlp(x) gate = self.gate(x) refined = x + gate * mlp_out refined = self.norm(refined) return refined class AnchoredDiffusionDecoder(nn.Module): """Anchored Diffusion Decoder — the core output pipeline for AAM v2.0. Replaces the standard softmax → token ID pipeline with: 1. Model predicts continuous vector (NO softmax) 2. 2-3 step anchored diffusion refinement 3. Disambiguation + coherence + Evoformer feedback 4. Final projection to vocabulary The key innovation: the predicted vector is ALREADY meaningful (it's the model's best prediction after graph-conditioned denoising). The decoder doesn't need to find the output from scratch — it just refines. """ def __init__(self, config: Optional[AnchoredDecoderConfig] = None) -> None: super().__init__() self.config = config or AnchoredDecoderConfig() self.d_model = self.config.d_model self.d_vocab = self.config.d_vocab self.n_refine_steps = self.config.n_refine_steps self.disambiguation = DisambiguationBlock( d_model=self.d_model, n_heads=self.config.disambiguation_heads, ) self.coherence_blocks = nn.ModuleList([ CoherenceBlock(d_model=self.d_model, d_refine=self.config.d_refine) for _ in range(self.n_refine_steps) ]) if self.config.use_evoformer_feedback: self.feedback_proj = nn.Sequential( nn.Linear(self.d_model, self.d_model, bias=False), nn.SiLU(), nn.Linear(self.d_model, self.d_model, bias=False), ) self.feedback_gate = nn.Sequential( nn.Linear(self.d_model, 1, bias=False), nn.Sigmoid(), ) self.feedback_norm = nn.RMSNorm(self.d_model) self.vocab_proj = nn.Linear(self.d_model, self.d_vocab, bias=False) self.pre_proj_norm = nn.RMSNorm(self.d_model) def forward( self, predicted_vectors: torch.Tensor, context: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, object]]: x = predicted_vectors info = {"n_refine_steps": self.n_refine_steps} if self.config.use_evoformer_feedback: for fb_iter in range(self.config.n_feedback_iterations): disambiguated = self.disambiguation(x, context) refined = disambiguated for step in range(self.n_refine_steps): refined = self.coherence_blocks[step](refined) feedback = self.feedback_proj(refined - x) gate = self.feedback_gate(x) x = self.feedback_norm(x + gate * feedback) info["feedback_iterations"] = self.config.n_feedback_iterations else: x = self.disambiguation(x, context) for step in range(self.n_refine_steps): x = self.coherence_blocks[step](x) x = self.pre_proj_norm(x) logits = self.vocab_proj(x) delta = (x - predicted_vectors).norm(dim=-1).mean().item() info["refinement_delta"] = delta return logits, info def predict_continuous(self, hidden_states: torch.Tensor) -> torch.Tensor: """Produce continuous prediction vectors (NO softmax).""" return hidden_states class ContinuousOutputHead(nn.Module): """Continuous output head that produces prediction vectors without softmax. Replaces the standard nn.Linear → softmax pipeline with: nn.Linear → continuous vector → AnchoredDiffusionDecoder → logits """ def __init__( self, d_model: int, d_vocab: int = 32000, decoder_config: Optional[AnchoredDecoderConfig] = None, ) -> None: super().__init__() self.d_model = d_model self.d_vocab = d_vocab self.predict_proj = nn.Sequential( nn.Linear(d_model, d_model, bias=False), nn.SiLU(), nn.Linear(d_model, d_model, bias=False), ) if decoder_config is None: decoder_config = AnchoredDecoderConfig(d_model=d_model, d_vocab=d_vocab) else: decoder_config.d_model = d_model decoder_config.d_vocab = d_vocab self.decoder = AnchoredDiffusionDecoder(decoder_config) def forward( self, hidden_states: torch.Tensor, use_diffusion: bool = True, context: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, object]]: pred_vectors = self.predict_proj(hidden_states) if use_diffusion: return self.decoder(pred_vectors, context=context) else: logits = self.decoder.vocab_proj(self.decoder.pre_proj_norm(pred_vectors)) return logits, {"mode": "standard"} def get_continuous_vectors(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.predict_proj(hidden_states)