| """AAM Diffusion LLM - Anchored Diffusion Decoder |
| |
| Replaces the standard softmax -> token ID pipeline with: |
| 1. Model predicts continuous vector (NO softmax) |
| 2. 2-3 step anchored diffusion refinement |
| 3. Disambiguation + coherence + Evoformer feedback |
| 4. Final projection to vocabulary |
| |
| Key Insight (from Losion): |
| Standard diffusion LLM: starts from NOISE -> needs 50-1000 steps |
| Anchored diffusion: starts from PREDICTED VECTOR (already meaningful) -> 2-3 steps only |
| |
| The predicted vector serves as an "anchor" - it's already in the right |
| neighborhood of the output space. The decoder just needs to refine it. |
| |
| AAM-specific: The anchor comes from graph-conditioned denoising, so it's |
| already shaped by evidence/anomaly/reasoning from the RSVS Knowledge Graph. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class AnchoredDecoderConfig:
    """Hyperparameters for the anchored diffusion decoder.

    Attributes:
        d_model: Width of the continuous prediction vectors being refined.
        d_vocab: Number of vocabulary entries produced by the final projection.
        n_refine_steps: How many coherence refinement blocks are built and applied.
        d_refine: Hidden width of the MLP inside each coherence block.
        use_evoformer_feedback: Whether the decoder runs its gated feedback loop.
        n_feedback_iterations: Number of feedback-loop passes when it is enabled.
        disambiguation_heads: Attention head count of the disambiguation block.
    """

    # Field order is part of the positional __init__ signature; keep it stable.
    d_model: int = 768
    d_vocab: int = 32000
    n_refine_steps: int = 3
    d_refine: int = 512
    use_evoformer_feedback: bool = True
    n_feedback_iterations: int = 2
    disambiguation_heads: int = 8
|
|
|
|
| class DisambiguationBlock(nn.Module): |
| """Resolve between similar tokens based on graph context. |
| |
| The predicted continuous vector may fall between two tokens with similar |
| meanings (e.g., "bukti" vs "dugaan"). This block uses local context |
| and graph-conditioned attention to disambiguate. |
| """ |
|
|
| def __init__(self, d_model: int, n_heads: int = 8) -> None: |
| super().__init__() |
| self.d_model = d_model |
| self.n_heads = n_heads |
| self.d_kv = d_model // n_heads |
|
|
| self.q_proj = nn.Linear(d_model, d_model, bias=False) |
| self.k_proj = nn.Linear(d_model, d_model, bias=False) |
| self.v_proj = nn.Linear(d_model, d_model, bias=False) |
| self.out_proj = nn.Linear(d_model, d_model, bias=False) |
|
|
| self.gate = nn.Sequential( |
| nn.Linear(d_model, 1, bias=False), |
| nn.Sigmoid(), |
| ) |
|
|
| self.norm = nn.RMSNorm(d_model) |
| self.scale = math.sqrt(self.d_kv) |
|
|
| def forward(self, x: torch.Tensor, graph_context: Optional[torch.Tensor] = None) -> torch.Tensor: |
| batch, seq_len, _ = x.shape |
|
|
| q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) |
|
|
| |
| if graph_context is not None: |
| k = self.k_proj(graph_context) |
| v = self.v_proj(graph_context) |
| if k.dim() == 3: |
| k = k.unsqueeze(1).expand(-1, self.n_heads, -1, -1).reshape(batch, -1, self.d_kv) |
| v = v.unsqueeze(1).expand(-1, self.n_heads, -1, -1).reshape(batch, -1, self.d_kv) |
| k = k.unsqueeze(1).transpose(1, 2) if k.dim() == 3 else k |
| |
| k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) |
| v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) |
| else: |
| k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) |
| v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2) |
|
|
| |
| scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale |
| mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1) |
| scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf")) |
|
|
| attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(x.dtype) |
| context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_model) |
| context = self.out_proj(context) |
|
|
| gate = self.gate(x) |
| refined = x + gate * context |
| refined = self.norm(refined) |
|
|
| return refined |
|
|
|
|
| class CoherenceBlock(nn.Module): |
| """Ensure parallel tokens are consistent with each other and the graph. |
| |
| When predicting multiple tokens in parallel (from the continuous vector |
| pipeline), each token's vector is predicted independently. This block |
| ensures they are coherent as a sequence. |
| """ |
|
|
| def __init__(self, d_model: int, d_refine: int = 512) -> None: |
| super().__init__() |
| self.d_model = d_model |
|
|
| self.coherence_mlp = nn.Sequential( |
| nn.Linear(d_model, d_refine, bias=False), |
| nn.SiLU(), |
| nn.Linear(d_refine, d_model, bias=False), |
| ) |
|
|
| self.gate = nn.Sequential( |
| nn.Linear(d_model, 1, bias=False), |
| nn.Sigmoid(), |
| ) |
|
|
| self.norm = nn.RMSNorm(d_model) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| mlp_out = self.coherence_mlp(x) |
| gate = self.gate(x) |
| refined = x + gate * mlp_out |
| refined = self.norm(refined) |
| return refined |
|
|
|
|
class AnchoredDiffusionDecoder(nn.Module):
    """Anchored Diffusion Decoder - the core output pipeline for AAM v2.0.

    Takes already-meaningful continuous prediction vectors (the "anchor") and
    turns them into vocabulary logits:

    1. Disambiguation followed by a stack of coherence blocks.
    2. Optionally, a gated feedback loop that repeats step 1 and folds the
       resulting change back into the running vectors.
    3. RMS normalization and a linear projection to the vocabulary.

    No softmax is applied; the output is raw logits.
    """

    def __init__(self, config: Optional[AnchoredDecoderConfig] = None) -> None:
        """Build all sub-modules from ``config`` (defaults used when omitted)."""
        super().__init__()
        cfg = config or AnchoredDecoderConfig()
        self.config = cfg
        self.d_model = cfg.d_model
        self.d_vocab = cfg.d_vocab
        self.n_refine_steps = cfg.n_refine_steps

        self.disambiguation = DisambiguationBlock(
            d_model=self.d_model,
            n_heads=cfg.disambiguation_heads,
        )

        refine_stack = [
            CoherenceBlock(d_model=self.d_model, d_refine=cfg.d_refine)
            for _ in range(self.n_refine_steps)
        ]
        self.coherence_blocks = nn.ModuleList(refine_stack)

        # Feedback modules exist only when the feedback loop is enabled.
        if cfg.use_evoformer_feedback:
            self.feedback_proj = nn.Sequential(
                nn.Linear(self.d_model, self.d_model, bias=False),
                nn.SiLU(),
                nn.Linear(self.d_model, self.d_model, bias=False),
            )
            self.feedback_gate = nn.Sequential(
                nn.Linear(self.d_model, 1, bias=False),
                nn.Sigmoid(),
            )
            self.feedback_norm = nn.RMSNorm(self.d_model)

        self.vocab_proj = nn.Linear(self.d_model, self.d_vocab, bias=False)
        self.pre_proj_norm = nn.RMSNorm(self.d_model)

    def _refine_once(self, vectors: torch.Tensor, context: Optional[torch.Tensor]) -> torch.Tensor:
        """Run disambiguation, then the first ``n_refine_steps`` coherence blocks."""
        hidden = self.disambiguation(vectors, context)
        for idx in range(self.n_refine_steps):
            hidden = self.coherence_blocks[idx](hidden)
        return hidden

    def forward(
        self,
        predicted_vectors: torch.Tensor,
        context: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, object]]:
        """Refine ``predicted_vectors`` and project them to vocabulary logits.

        Args:
            predicted_vectors: Continuous prediction vectors to refine.
            context: Optional tensor forwarded to the disambiguation block.

        Returns:
            ``(logits, info)``. ``info`` holds ``"n_refine_steps"``,
            ``"feedback_iterations"`` (only when feedback is enabled) and
            ``"refinement_delta"``: the mean L2 distance, as a Python float,
            between the normalized refined vectors and the input vectors.
        """
        info: Dict[str, object] = {"n_refine_steps": self.n_refine_steps}
        state = predicted_vectors

        if not self.config.use_evoformer_feedback:
            state = self._refine_once(state, context)
        else:
            for _ in range(self.config.n_feedback_iterations):
                candidate = self._refine_once(state, context)
                # Feed the change produced by refinement back through a gate.
                correction = self.feedback_proj(candidate - state)
                mix = self.feedback_gate(state)
                state = self.feedback_norm(state + mix * correction)
            info["feedback_iterations"] = self.config.n_feedback_iterations

        state = self.pre_proj_norm(state)
        logits = self.vocab_proj(state)

        info["refinement_delta"] = (state - predicted_vectors).norm(dim=-1).mean().item()

        return logits, info

    def predict_continuous(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` unchanged as the continuous prediction (NO softmax)."""
        return hidden_states
|
|
|
|
class ContinuousOutputHead(nn.Module):
    """Continuous output head that produces prediction vectors without softmax.

    Replaces the standard nn.Linear -> softmax pipeline with:
        nn.Linear -> continuous vector -> AnchoredDiffusionDecoder -> logits
    """

    def __init__(
        self,
        d_model: int,
        d_vocab: int = 32000,
        decoder_config: Optional[AnchoredDecoderConfig] = None,
    ) -> None:
        """Build the prediction MLP and the anchored decoder.

        Args:
            d_model: Size of the last dimension of the incoming hidden states.
            d_vocab: Vocabulary size of the produced logits.
            decoder_config: Optional decoder settings. Its ``d_model`` and
                ``d_vocab`` are overridden by the arguments above; the
                override is applied to a copy, so the caller's object is
                left unmodified.
        """
        # Local import: the file's top-level imports only bring in `dataclass`.
        from dataclasses import replace

        super().__init__()
        self.d_model = d_model
        self.d_vocab = d_vocab

        self.predict_proj = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.SiLU(),
            nn.Linear(d_model, d_model, bias=False),
        )

        if decoder_config is None:
            decoder_config = AnchoredDecoderConfig(d_model=d_model, d_vocab=d_vocab)
        else:
            # Copy instead of assigning onto the argument: a config shared
            # between several heads must not be rewritten behind the
            # caller's back.
            decoder_config = replace(decoder_config, d_model=d_model, d_vocab=d_vocab)

        self.decoder = AnchoredDiffusionDecoder(decoder_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        use_diffusion: bool = True,
        context: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, object]]:
        """Map hidden states to vocabulary logits.

        Args:
            hidden_states: Input tensor fed to the prediction MLP.
            use_diffusion: When True, run the full anchored decoder. When
                False, skip refinement and apply only the decoder's final
                norm and vocabulary projection.
            context: Optional tensor forwarded to the decoder; ignored when
                ``use_diffusion`` is False.

        Returns:
            ``(logits, info)``. In the non-diffusion path ``info`` is
            ``{"mode": "standard"}``.
        """
        pred_vectors = self.predict_proj(hidden_states)

        if use_diffusion:
            return self.decoder(pred_vectors, context=context)

        logits = self.decoder.vocab_proj(self.decoder.pre_proj_norm(pred_vectors))
        return logits, {"mode": "standard"}

    def get_continuous_vectors(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the continuous prediction vectors without decoding them."""
        return self.predict_proj(hidden_states)
|
|