# gemeo-sus / src/diffusion_forcing_v13.py
# Uploaded by timmers with huggingface_hub (commit bd692df).
"""GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.
Per the SOTA audit (May 2026):
- Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
- Path A adopted: small from-scratch model + KG adapters + MEDS interop
Architecture:
- 12M backbone params (Chinchilla-respecting for ~20M token corpus)
- d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
- SwiGLU MLP (ffn:d_model = 2.67)
- Tied embeddings (saves ~12M at vocab=32k)
- Dropout 0.1 everywhere (small-data critical)
- Block-causal attention (Diffusion Forcing)
- Per-token sigma noise (independent)
- GATED KG cross-attention (tanh(α)·xattn, α init=0)
  - Layers 4, 6, 7 (3 of 8)
  - Lets the model learn to use the KG progressively without disrupting early loss
- DF objective + LM-aux loss (joint training, paper-grade)
Sources audited:
- CoMET (Aug 2025): tokens-per-param ratio
- CLMBR (Stanford): adapter pattern for cross-site transfer
- MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
- Genie (DeepMind 2024): gated cross-attention pattern
- SD3 (Esser 2024): AdaLN-Zero zero-init gates
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class CDFv13Config:
    """Hyper-parameters for the CDF v13 diffusion-forcing transformer.

    Fields are grouped by concern. Their declaration order defines the positional
    constructor, so it is kept stable.
    """

    # ---- vocabulary / sequence layout ------------------------------------
    vocab_size: int = 32768  # MEDS-derived; real vocabularies are usually much smaller
    mask_token: int = 32767  # absorbing-state id; must stay < vocab_size (update it with the vocab)
    max_seq_len: int = 512  # longest sequence the attention mask is built for
    block_size: int = 16  # tokens per diffusion-forcing block

    # ---- backbone (sized for a ~20M-token corpus) ------------------------
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 8
    ffn: int = 1024  # MLP hidden width (gate/up width when use_swiglu)
    dropout: float = 0.1
    emb_dropout: float = 0.1
    use_swiglu: bool = True
    use_rmsnorm: bool = True
    tie_embeddings: bool = True

    # ---- optional upgrades (off by default: v13 checkpoint compatible) ---
    use_qk_norm: bool = False  # per-head RMSNorm on Q and K before RoPE
    use_adaln: bool = False  # AdaLN-Zero modulation from sigma + condition
    bidirectional: bool = False  # True: full attention; False: block-causal

    # ---- diffusion forcing -----------------------------------------------
    cond_dropout: float = 0.10  # per-sample probability of dropping cond (and KG) in the losses

    # ---- gated KG cross-attention adapters ---------------------------------
    use_kg: bool = True
    kg_dim: int = 3072  # width of the raw KG embeddings fed to the adapters
    kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])  # 0-based block indices

    # ---- latent actions (dropped per audit) --------------------------------
    use_latent_action: bool = False
    n_latent_actions: int = 512

    # ---- global conditioning -------------------------------------------------
    n_conditions: int = 64
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable gain.

    LLaMA/Mistral style. The statistic and scaling are computed in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))  # per-channel gain, identity at init
        self.eps = eps  # added to the mean square before the reciprocal sqrt

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        return (normed * self.weight.float()).to(x.dtype)
class SwiGLU(nn.Module):
    """Gated MLP ``dropout(w_down(silu(w_gate(x)) * w_up(x)))`` (LLaMA/Gemma/Mistral).

    All three projections are bias-free.
    """

    def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
        super().__init__()
        # Registration order fixes the state_dict key order: w_gate, w_up, w_down.
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x))
        hidden = gated * self.w_up(x)
        return self.dropout(self.w_down(hidden))
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (Su et al. 2021), "rotate-half" layout.

    cos/sin tables of shape (max_seq, dim) are precomputed as non-persistent buffers,
    so they are not saved in checkpoints. ``forward`` rotates the last dimension of
    ``q`` and ``k`` with the first ``seq_len`` rows of those tables.
    """

    def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_seq).float()
        angles = torch.einsum("p,f->pf", positions, inv_freq)  # (max_seq, dim // 2)
        table = torch.cat([angles, angles], dim=-1)  # (max_seq, dim)
        self.register_buffer("cos", table.cos(), persistent=False)
        self.register_buffer("sin", table.sin(), persistent=False)

    def forward(self, q, k, seq_len):
        cos = self.cos[:seq_len].to(q.dtype).to(q.device)
        sin = self.sin[:seq_len].to(q.dtype).to(q.device)

        def rotate_half(t):
            first, second = t.chunk(2, dim=-1)
            return torch.cat([-second, first], dim=-1)

        q_rot = (q * cos) + (rotate_half(q) * sin)
        k_rot = (k * cos) + (rotate_half(k) * sin)
        return q_rot, k_rot
class PerTokenSigmaEmbed(nn.Module):
    """Embed a per-position diffusion noise level ``sigma`` into ``d`` channels.

    ``sigma`` (any shape; the loss methods pass (B, T)) is mapped to sinusoidal
    features using frequencies ``exp(-ln(10000) * i / (d // 2))``. These are all <= 1,
    so a sigma in [0, 1] gives angles in [0, 1] radian. The features then pass
    through a two-layer SiLU MLP.

    Returns a tensor of shape ``sigma.shape + (d,)`` in the dtype of the MLP weights.
    """

    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        if self.d % 2:
            # sin/cos only fill 2 * (d // 2) channels; zero-pad so odd widths match proj.
            emb = F.pad(emb, (0, 1))
        # Features are built in float32 for accuracy. Cast them to the MLP's dtype so
        # a model converted with .to(torch.bfloat16) / .half() (no autocast) does not
        # fail with a Linear dtype mismatch. This is a no-op for float32 models.
        return self.proj(emb.to(self.proj[0].weight.dtype))
class GatedKGCrossAttention(nn.Module):
    """Cross-attention from the token stream to a KG ego-subgraph, with a tanh gate.

    Returns ``x_seq + tanh(alpha) * dropout(out_proj(attn(x_seq, kg)))``, so the
    residual add happens here. ``alpha`` is a learnable scalar initialised to 0, so at
    init the adapter adds nothing to the residual stream and the model behaves like
    the no-KG model until training opens the gate. This avoids loss spikes on small
    data.
    Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).

    Args:
        d_model: width of the token stream; must be divisible by ``n_heads``.
        kg_dim: width of the raw KG node embeddings (projected to ``d_model`` inline).
        n_heads: number of attention heads.
        dropout: used as attention-probability dropout (training only) and on the
            projected output.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        if d_model % n_heads:
            raise ValueError(f"d_model={d_model} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Project KG to d_model inline so no separate KGProjector module is needed.
        self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = RMSNorm(d_model)
        self.norm_kv = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Scalar gate; tanh(0) = 0 keeps the adapter silent at initialisation.
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor,
                kg_mask: "torch.Tensor | None" = None) -> torch.Tensor:
        """Apply gated cross-attention and return the updated residual stream.

        Args:
            x_seq: (B, T, d_model) token states (queries).
            kg_raw: (B, N_kg, kg_dim) raw KG node embeddings (keys/values).
            kg_mask: optional (B, N_kg) mask, truthy for real nodes and falsy for
                padding. Padding nodes receive no attention. A sample whose nodes
                are all padding gets no KG contribution (its output equals x_seq).
                ``None`` (default) attends to every node.

        Returns:
            (B, T, d_model) tensor.
        """
        B, T, D = x_seq.shape
        kg_proj = self.kg_in_proj(kg_raw)  # (B, N_kg, D)
        N_kg = kg_proj.size(1)
        q = self.q_proj(self.norm_q(x_seq))
        k, v = self.kv_proj(self.norm_kv(kg_proj)).chunk(2, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)

        attn_mask = None
        has_node = None
        if kg_mask is not None:
            valid = kg_mask.to(device=x_seq.device, dtype=torch.bool)
            has_node = valid.any(dim=-1)  # (B,)
            # A row with no valid node would softmax over all -inf and produce NaN.
            # Let it attend everywhere here and zero its output below instead.
            valid = valid | ~has_node.unsqueeze(-1)
            attn_mask = valid[:, None, None, :]  # (B, 1, 1, N_kg); True = attend

        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask,
            dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        if has_node is not None:
            out = out * has_node.to(out.dtype)[:, None, None]
        gate = torch.tanh(self.alpha)
        return x_seq + gate * self.dropout(self.out_proj(out))
class CDFv13Block(nn.Module):
    """Pre-norm transformer block with optional AdaLN-Zero and gated KG cross-attn.

    Order: self-attention (masked by ``attn_mask``) -> optional gated KG
    cross-attention (only in layers listed in ``cfg.kg_attn_layers``) -> MLP. Each
    sub-layer is added to the residual stream.
    """

    def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
                 layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.rope = rope  # RoPE module shared with the other blocks
        self.layer_idx = layer_idx
        norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(cfg.d_model)
        self.norm2 = norm_cls(cfg.d_model)
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.head_dim = cfg.d_model // cfg.n_heads
        # QK-norm: per-head RMSNorm on Q,K before RoPE (stabilises attention logits)
        if cfg.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        # AdaLN-Zero: per-token shift/scale/gate for the attention and MLP branches
        if cfg.use_adaln:
            self.adaln = nn.Sequential(nn.SiLU(), nn.Linear(cfg.d_model, 6 * cfg.d_model, bias=True))
        if cfg.use_swiglu:
            self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(cfg.d_model, cfg.ffn, bias=False),
                nn.GELU(),
                nn.Linear(cfg.ffn, cfg.d_model, bias=False),
                nn.Dropout(cfg.dropout),
            )
        self.dropout = nn.Dropout(cfg.dropout)
        # Gated KG cross-attention (only in the configured layers)
        self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
        if self.use_kg_in_layer:
            self.kg_xattn = GatedKGCrossAttention(
                cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)

    def forward(self, x, attn_mask, kg_raw=None, cond_vec=None):
        """Run one block.

        Args:
            x: (B, T, d_model) residual stream.
            attn_mask: (T, T) bool tensor, True where query i must NOT attend to key j.
            kg_raw: optional raw KG embeddings for the gated adapter (ignored in
                layers without one).
            cond_vec: conditioning fed to AdaLN-Zero; ignored unless ``cfg.use_adaln``.

        Returns:
            (B, T, d_model) updated residual stream.
        """
        B, T, D = x.shape
        if self.cfg.use_adaln and cond_vec is not None:
            sh_msa, sc_msa, g_msa, sh_mlp, sc_mlp, g_mlp = self.adaln(cond_vec).chunk(6, dim=-1)
        else:
            sh_msa = sc_msa = g_msa = sh_mlp = sc_mlp = g_mlp = None

        # Self-attention
        h = self.norm1(x)
        if sc_msa is not None:
            h = h * (1 + sc_msa) + sh_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, T, head_dim)
        if self.cfg.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rope(q, k, T)
        # SDPA's boolean mask means "True = may attend", the inverse of attn_mask.
        # Passing it directly avoids building a float mask per layer per call and
        # keeps the mask dtype independent of the (possibly low-precision) queries.
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=~attn_mask,
            dropout_p=self.cfg.dropout if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.dropout(self.proj(out))
        x = x + (g_msa * attn_out if g_msa is not None else attn_out)

        # Gated KG cross-attention (the adapter performs its own residual add)
        if self.use_kg_in_layer and kg_raw is not None:
            x = self.kg_xattn(x, kg_raw)

        # MLP
        h2 = self.norm2(x)
        if sc_mlp is not None:
            h2 = h2 * (1 + sc_mlp) + sh_mlp
        mlp_out = self.mlp(h2)
        x = x + (g_mlp * mlp_out if g_mlp is not None else mlp_out)
        return x
class CDFv13Transformer(nn.Module):
    """Audit-compliant CDF v13: small backbone + gated KG adapters + DF objective.

    Pipeline: token embedding conditioned on a per-token noise level ``sigma`` and a
    global condition id (added to the embedding, or fed to AdaLN-Zero when
    ``cfg.use_adaln``), then ``cfg.n_layers`` ``CDFv13Block`` layers sharing one RoPE
    module, a final norm, and a vocabulary head (weight-tied to the token embedding
    when ``cfg.tie_embeddings``).
    """

    def __init__(self, cfg: "CDFv13Config | None" = None):
        super().__init__()
        self.cfg = cfg or CDFv13Config()
        c = self.cfg
        norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm
        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.emb_dropout = nn.Dropout(c.emb_dropout)
        # Per-token sigma embedding (additive, or AdaLN input when use_adaln)
        self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
        # Global condition embedding (broadcast over positions)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)
        # One RoPE table shared by every block, sized for 2x max_seq_len
        self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)
        self.blocks = nn.ModuleList([
            CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
        ])
        self.final_norm = norm_cls(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
        if c.tie_embeddings:
            self.head.weight = self.tok_emb.weight
        # Attention mask buffer: mask[i, j] = True => query i may NOT attend to key j.
        # Block-causal (Diffusion Forcing): a query sees its own block and all
        # earlier blocks; later blocks are masked. cfg.bidirectional=True gives full
        # attention (pure masked diffusion / gap-fill) with nothing masked.
        T = c.max_seq_len
        if getattr(c, "bidirectional", False):
            mask = torch.zeros(T, T, dtype=torch.bool)
        else:
            block_id = torch.arange(T) // c.block_size
            mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        self.register_buffer("block_mask", mask, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser for ``self.apply`` (children are visited first).

        Linear/Embedding weights ~ N(0, 0.02), Linear biases = 0. When AdaLN is
        enabled, each block's modulation output layer is zeroed after its children
        were initialised, so every block starts as the identity (AdaLN-Zero).
        """
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, CDFv13Block) and hasattr(m, "adaln"):
            nn.init.zeros_(m.adaln[-1].weight)
            nn.init.zeros_(m.adaln[-1].bias)

    def forward(self, x, sigma, cond, kg_raw=None):
        """Return vocabulary logits of shape (B, T, vocab_size).

        Args:
            x: (B, T) token ids (corrupted positions hold ``cfg.mask_token``).
            sigma: per-token noise level broadcastable to x (the losses pass (B, T)).
            cond: (B,) global condition ids.
            kg_raw: optional KG embeddings forwarded to the gated adapters.

        Raises:
            ValueError: if T exceeds the length the attention mask was built for.
        """
        T = x.size(1)
        max_len = self.block_mask.size(0)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={max_len} used to build the attention mask")
        cond_vec = None
        if self.cfg.use_adaln:
            # AdaLN path: conditioning enters via per-token modulation, not additively
            cond_vec = self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
            h = self.tok_emb(x)
        else:
            h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
        h = self.emb_dropout(h)
        mask = self.block_mask[:T, :T]
        for blk in self.blocks:
            h = blk(h, mask, kg_raw=kg_raw, cond_vec=cond_vec)
        h = self.final_norm(h)
        return self.head(h)

    def _corrupt_batch(self, x_clean, cond, kg_raw, mode):
        """Shared DF corruption: CFG dropout, per-token sigma, absorbing-state masking.

        Random draws happen in a fixed order (cond drop, KG drop, sigma, corruption).

        Returns:
            ``(x_noisy, sigma, corrupt, cond, kg_raw)``. ``corrupt`` is the (B, T)
            bool mask of positions replaced by ``cfg.mask_token``.

        Raises:
            ValueError: if ``mode`` is not 'uniform' or 'logit_normal'.
        """
        if mode not in ("uniform", "logit_normal"):
            raise ValueError(f"unknown sigma mode {mode!r}; expected 'uniform' or 'logit_normal'")
        B, T = x_clean.shape
        device = x_clean.device
        # CFG condition dropout: dropped samples use condition id 0.
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            drop_kg = torch.rand(B, device=device) < self.cfg.cond_dropout
            # masked_fill keeps kg_raw's dtype; multiplying by a float32 mask would
            # promote bf16/fp16 KG features to float32.
            kg_raw = kg_raw.masked_fill(drop_kg.reshape(B, 1, 1), 0.0)
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device))
        else:
            sigma = torch.rand(B, T, device=device)
        sigma = sigma.clamp(0.01, 0.99)
        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
        return x_noisy, sigma, corrupt, cond, kg_raw

    def _per_token_ce(self, x_clean, x_noisy, sigma, cond, kg_raw):
        """Return (B, T) cross-entropy of the logits for ``x_noisy`` vs ``x_clean``."""
        B, T = x_clean.shape
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        return F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size),
            x_clean.reshape(-1),
            reduction="none",
        ).reshape(B, T)

    def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
                               mode: str = "uniform") -> torch.Tensor:
        """Absorbing-state diffusion-forcing loss with independent per-token sigma.

        Each sample's condition (and KG, if given) is dropped with probability
        ``cfg.cond_dropout``. Every position gets its own sigma and is replaced by
        ``cfg.mask_token`` with probability sigma. Cross-entropy against the clean
        tokens is averaged over corrupted positions only.

        Args:
            x_clean: (B, T) clean token ids.
            cond: (B,) condition ids.
            kg_raw: optional (B, N_kg, kg_dim) KG embeddings.
            mode: 'uniform' (default; safer for discrete data per audit) or
                'logit_normal' (SD3-style; keep as an ablation only).

        Raises:
            ValueError: for any other ``mode``.
        """
        x_noisy, sigma, corrupt, cond, kg_raw = self._corrupt_batch(x_clean, cond, kg_raw, mode)
        ce = self._per_token_ce(x_clean, x_noisy, sigma, cond, kg_raw)
        weight = corrupt.float()
        # Denominator is an integer count, so clamping at 1 only guards "nothing corrupted".
        return (ce * weight).sum() / weight.sum().clamp(min=1.0)

    @staticmethod
    def recurrence_weights(x_clean, struct_ids, lam: float = 0.25, w_min: float = 0.02):
        """RAVEN recurrence-aware weights (Rajamohan et al., arXiv 2603.24562).

        w[i,t] = max(lam ** count, w_min), where ``count`` is the number of earlier
        occurrences of token x[i,t] in row i. First occurrences get weight 1; repeats
        decay geometrically toward w_min. Tokens in ``struct_ids`` get weight 0.
        Vectorised with a (B, T, T) equality tensor (no Python Counter loop).
        Returns a (B, T) float tensor on x_clean.device.
        """
        B, T = x_clean.shape
        device = x_clean.device
        # prior-occurrence count per position via equality with earlier positions
        eq = (x_clean.unsqueeze(2) == x_clean.unsqueeze(1))  # eq[b,t,s] = x[b,t]==x[b,s]
        earlier = torch.tril(torch.ones(T, T, device=device), diagonal=-1).bool()  # [t,s]: s<t
        count = (eq & earlier.unsqueeze(0)).sum(dim=2).float()  # (B, T)
        w = torch.clamp(lam ** count, min=w_min)
        if struct_ids:
            sid = torch.tensor(sorted(struct_ids), device=device)
            is_struct = (x_clean.unsqueeze(-1) == sid).any(-1)
            w = w.masked_fill(is_struct, 0.0)
        return w

    def recurrence_aware_loss(self, x_clean, cond, struct_ids, kg_raw=None,
                              lam: float = 0.25, w_min: float = 0.02,
                              mode: str = "uniform") -> torch.Tensor:
        """Diffusion-forcing loss reweighted by RAVEN recurrence decay.

        Same corruption as ``diffusion_forcing_loss``, but each corrupted position's
        cross-entropy is weighted by ``recurrence_weights`` and the result is a
        weighted mean. This favours predicting novel events over repeats; per the
        original docstring it is the loss used to train the released `gemeo-sus`
        flagship.

        Raises:
            ValueError: if ``mode`` is not 'uniform' or 'logit_normal'.
        """
        x_noisy, sigma, corrupt, cond, kg_raw = self._corrupt_batch(x_clean, cond, kg_raw, mode)
        ce = self._per_token_ce(x_clean, x_noisy, sigma, cond, kg_raw)
        w = self.recurrence_weights(x_clean, struct_ids, lam, w_min) * corrupt.float()
        # Weights are fractional (down to w_min, or 0 for structural tokens). Clamping
        # the denominator at 1.0 would shrink the loss whenever the total weight is
        # < 1, so only guard against an all-zero weight sum.
        return (ce * w).sum() / w.sum().clamp(min=torch.finfo(w.dtype).eps)