| """GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture. |
| |
| Per the SOTA audit (May 2026): |
| - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval) |
| - Path A adopted: small from-scratch model + KG adapters + MEDS interop |
| |
| Architecture: |
| - 12M backbone params (Chinchilla-respecting for ~20M token corpus) |
| - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512 |
| - SwiGLU MLP (ffn:d_model = 2.67) |
| - Tied embeddings (saves ~12M at vocab=32k) |
| - Dropout 0.1 everywhere (small-data critical) |
| - Block-causal attention (Diffusion Forcing) |
| - Per-token sigma noise (independent) |
| - GATED KG cross-attention (tanh(α)·xattn, α init=0) |
| - Layers 4, 6, 7 (3 of 8) |
| - Lets model learn to use KG progressively, doesn't disrupt early loss |
| - DF objective + LM-aux loss (joint training, paper-grade) |
| |
| Sources audited: |
| - CoMET (Aug 2025): tokens-per-param ratio |
| - CLMBR (Stanford): adapter pattern for cross-site transfer |
| - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs |
| - Genie (DeepMind 2024): gated cross-attention pattern |
| - SD3 (Esser 2024): AdaLN-Zero zero-init gates |
| """ |
| from __future__ import annotations |
| import math |
| from dataclasses import dataclass, field |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class CDFv13Config:
    """Hyper-parameters for the CDF v13 transformer.

    Field order is part of the interface (the dataclass can be constructed
    positionally), so new fields must only ever be appended.

    Raises:
        ValueError: if the attention-head geometry cannot work, i.e.
            ``d_model`` is not divisible by ``n_heads`` or the resulting
            head dimension is odd (rotary embeddings rotate channel pairs).
    """

    # --- tokens / sequence layout ---
    vocab_size: int = 32768    # rows of the token embedding and width of the output head
    mask_token: int = 32767    # token id written into corrupted positions by the losses
    max_seq_len: int = 512     # side length of the precomputed attention mask
    block_size: int = 16       # tokens per block in the block-causal mask

    # --- backbone ---
    d_model: int = 384         # residual-stream width
    n_heads: int = 6           # self-attention heads (also used for KG cross-attention)
    n_layers: int = 8          # number of transformer blocks
    ffn: int = 1024            # hidden width of the MLP
    dropout: float = 0.1       # attention / MLP / residual dropout
    emb_dropout: float = 0.1   # dropout applied to the summed input embeddings
    use_swiglu: bool = True    # SwiGLU MLP when True, GELU MLP otherwise
    use_rmsnorm: bool = True   # RMSNorm when True, nn.LayerNorm otherwise
    tie_embeddings: bool = True  # share the output head weight with the token embedding

    # --- optional variants ---
    use_qk_norm: bool = False    # RMSNorm on per-head queries and keys
    use_adaln: bool = False      # inject sigma/condition through adaLN instead of the input sum
    bidirectional: bool = False  # all-False (unrestricted) mask instead of block-causal

    # Probability, per sequence, of replacing the condition id with 0 and
    # (independently) zeroing the KG embeddings inside the loss functions.
    cond_dropout: float = 0.10

    # --- knowledge-graph adapters ---
    use_kg: bool = True
    kg_dim: int = 3072  # width of the raw KG embeddings before projection to d_model
    # Zero-based block indices that receive a gated KG cross-attention layer.
    kg_attn_layers: list[int] = field(default_factory=lambda: [4, 6, 7])

    # NOTE(review): the next two fields are not read by the model code
    # visible in this file -- confirm where they are consumed.
    use_latent_action: bool = False
    n_latent_actions: int = 512

    n_conditions: int = 64  # rows of the condition embedding table

    def __post_init__(self) -> None:
        # Fail at construction time instead of deep inside the first forward
        # pass (qkv reshape / rotary broadcast) with an opaque shape error.
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head dimension d_model // n_heads = {head_dim} must be even "
                "for rotary position embeddings"
            )
|
|
|
|
class RMSNorm(nn.Module):
    """Scale-only normalisation by the root-mean-square of the last dimension.

    The statistic and the scaling are evaluated in float32; the result is
    cast back to the dtype of the input.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))  # learnable per-channel gain
        self.eps = eps                             # added under the square root

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        scaled = (x32 * inv_rms) * self.weight.float()
        return scaled.to(x.dtype)
|
|
|
|
class SwiGLU(nn.Module):
    """Gated MLP: ``dropout(w_down(silu(w_gate(x)) * w_up(x)))``.

    All three projections are bias-free.
    """

    def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
        super().__init__()
        # Creation order is kept (gate, up, down) so seeded initialisation
        # and parameter ordering stay the same.
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        value = self.w_up(x)
        projected = self.w_down(gate * value)
        return self.dropout(projected)
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (Su et al. 2021).

    Cos/sin tables of shape ``(max_seq, dim)`` are precomputed once and kept
    as non-persistent buffers, so they are not written to the state dict.
    """

    def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_seq).float()
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        # Frequencies are repeated so the table spans the full head dimension.
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos", table.cos(), persistent=False)
        self.register_buffer("sin", table.sin(), persistent=False)

    @staticmethod
    def _rotate_half(x):
        """Map the two halves (a, b) of the last dimension to (-b, a)."""
        first, second = x.chunk(2, dim=-1)
        return torch.cat([-second, first], dim=-1)

    def forward(self, q, k, seq_len):
        """Rotate ``q`` and ``k`` using the first ``seq_len`` table rows.

        The tables are cast to the dtype and device of ``q`` and broadcast
        against the trailing ``(seq_len, dim)`` axes of both inputs.
        Returns the rotated ``(q, k)`` pair.
        """
        cos = self.cos[:seq_len].to(dtype=q.dtype, device=q.device)
        sin = self.sin[:seq_len].to(dtype=q.dtype, device=q.device)
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot
|
|
|
|
class PerTokenSigmaEmbed(nn.Module):
    """Embed a per-position noise level ``sigma`` into ``d`` channels.

    ``sigma`` is expanded into ``d // 2`` sine and ``d // 2`` cosine
    features on a geometric frequency grid, then passed through a
    Linear -> SiLU -> Linear projection.
    """

    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d),
            nn.SiLU(),
            nn.Linear(d, d),
        )

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        n_freqs = self.d // 2
        steps = torch.arange(n_freqs, device=sigma.device)
        freqs = torch.exp(-math.log(10000.0) * steps / n_freqs)
        # A trailing axis is added to sigma so it broadcasts over frequencies.
        angles = sigma.float().unsqueeze(-1) * freqs
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        return self.proj(features)
|
|
|
|
class GatedKGCrossAttention(nn.Module):
    """Residual cross-attention from the token sequence to KG node embeddings.

    The attention output is multiplied by ``tanh(alpha)`` before being added
    to the residual stream. ``alpha`` is a learnable scalar that starts at 0,
    so at initialisation the layer returns its input unchanged and the model
    behaves like the no-KG model until ``alpha`` moves away from zero.

    Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
    """

    def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # Raw KG embeddings are first projected into the model width.
        self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)  # keys and values fused
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = RMSNorm(d_model)
        self.norm_kv = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Gate parameter; tanh(0) == 0 closes the branch at init.
        self.alpha = nn.Parameter(torch.zeros(1))

    def _split_heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """(batch, length, d_model) -> (batch, n_heads, length, head_dim)."""
        return t.reshape(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
        """
        x_seq:  (B, T, d_model)
        kg_raw: (B, N_kg, kg_dim)  -- raw KG embeddings (e.g. 3072)

        Returns ``x_seq + tanh(alpha) * dropout(out_proj(attention))``.
        """
        batch, seq_len, width = x_seq.shape
        kg_tokens = self.kg_in_proj(kg_raw)
        n_nodes = kg_tokens.size(1)

        queries = self.q_proj(self.norm_q(x_seq))
        keys, values = self.kv_proj(self.norm_kv(kg_tokens)).chunk(2, dim=-1)

        queries = self._split_heads(queries, batch, seq_len)
        keys = self._split_heads(keys, batch, n_nodes)
        values = self._split_heads(values, batch, n_nodes)

        attn_dropout = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            queries, keys, values, dropout_p=attn_dropout)
        attended = attended.transpose(1, 2).reshape(batch, seq_len, width)

        update = self.dropout(self.out_proj(attended))
        return x_seq + torch.tanh(self.alpha) * update
|
|
|
|
class CDFv13Block(nn.Module):
    """Pre-norm transformer block with optional adaLN and gated KG cross-attn.

    Sub-layers, in order: self-attention with rotary embeddings, then (only
    for layers listed in ``cfg.kg_attn_layers`` when ``cfg.use_kg`` is set)
    gated KG cross-attention, then the MLP. Each sub-layer is residual.
    """

    def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
                 layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.rope = rope  # shared rotary table, passed in by the owner
        self.layer_idx = layer_idx
        self.head_dim = cfg.d_model // cfg.n_heads

        norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(cfg.d_model)
        self.norm2 = norm_cls(cfg.d_model)
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        if cfg.use_qk_norm:
            # Normalisation is applied per head, over head_dim.
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)

        if cfg.use_adaln:
            # Emits shift/scale/gate for attention and for the MLP (6 chunks).
            self.adaln = nn.Sequential(
                nn.SiLU(), nn.Linear(cfg.d_model, 6 * cfg.d_model, bias=True))

        if cfg.use_swiglu:
            self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(cfg.d_model, cfg.ffn, bias=False),
                nn.GELU(),
                nn.Linear(cfg.ffn, cfg.d_model, bias=False),
                nn.Dropout(cfg.dropout),
            )
        self.dropout = nn.Dropout(cfg.dropout)

        self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
        if self.use_kg_in_layer:
            self.kg_xattn = GatedKGCrossAttention(
                cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)

    def forward(self, x, attn_mask, kg_raw=None, cond_vec=None):
        """Run the block.

        Args:
            x: hidden states, ``(B, T, D)``.
            attn_mask: boolean ``(T, T)`` tensor in which ``True`` marks a
                (query, key) pair that must NOT attend.
            kg_raw: optional raw KG embeddings handed to the KG
                cross-attention; ignored by layers without a KG adapter.
            cond_vec: optional conditioning tensor fed to the adaLN
                projection; only used when ``cfg.use_adaln`` is set.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        B, T, D = x.shape

        modulate = self.cfg.use_adaln and cond_vec is not None
        if modulate:
            sh_msa, sc_msa, g_msa, sh_mlp, sc_mlp, g_mlp = (
                self.adaln(cond_vec).chunk(6, dim=-1))

        # --- self-attention ---
        h = self.norm1(x)
        if modulate:
            h = h * (1 + sc_msa) + sh_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, H, T, head_dim)
        if self.cfg.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rope(q, k, T)
        # SDPA's boolean mask uses True == "may attend", the inverse of our
        # convention. A boolean mask works for any query dtype and avoids
        # building a float (+1 / -inf) mask on every call; the +1 offset was
        # a softmax no-op.
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=(~attn_mask)[None, None],
            dropout_p=self.cfg.dropout if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.dropout(self.proj(out))
        x = x + (g_msa * attn_out if modulate else attn_out)

        # --- gated KG cross-attention (adds its own residual) ---
        if self.use_kg_in_layer and kg_raw is not None:
            x = self.kg_xattn(x, kg_raw)

        # --- MLP ---
        h2 = self.norm2(x)
        if modulate:
            h2 = h2 * (1 + sc_mlp) + sh_mlp
        mlp_out = self.mlp(h2)
        x = x + (g_mlp * mlp_out if modulate else mlp_out)
        return x
|
|
|
|
class CDFv13Transformer(nn.Module):
    """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective.

    Token, per-token noise (sigma) and condition embeddings feed a stack of
    ``CDFv13Block`` layers under a precomputed block-causal (or unrestricted)
    attention mask; a linear head, optionally tied to the token embedding,
    produces vocabulary logits.
    """

    def __init__(self, cfg: CDFv13Config | None = None):
        super().__init__()
        self.cfg = cfg or CDFv13Config()
        c = self.cfg
        norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm

        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.emb_dropout = nn.Dropout(c.emb_dropout)
        self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)

        # One rotary table shared by every block.
        self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)

        self.blocks = nn.ModuleList([
            CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
        ])
        self.final_norm = norm_cls(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
        if c.tie_embeddings:
            self.head.weight = self.tok_emb.weight

        # Attention mask, True == blocked. Block-causal: a query may attend
        # to every key in its own block or an earlier one.
        T = c.max_seq_len
        block_id = torch.arange(T) // c.block_size
        if getattr(c, "bidirectional", False):
            mask = torch.zeros(T, T, dtype=torch.bool)
        else:
            mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        self.register_buffer("block_mask", mask, persistent=False)

        self.apply(self._init_weights)
        # Explicit so the adaLN-Zero init never depends on the order in
        # which Module.apply happens to visit modules.
        self._zero_init_adaln()

    def _init_weights(self, m):
        """Per-module callback for ``Module.apply``: N(0, 0.02) weights, zero biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)
        if m is self:
            # Re-applying this callback re-randomises the adaLN projections;
            # restore the zero init once, when the root module is reached.
            self._zero_init_adaln()

    def _zero_init_adaln(self):
        """Zero the last adaLN projection of every block (no-op without adaLN)."""
        if not self.cfg.use_adaln:
            return
        for blk in self.blocks:
            nn.init.zeros_(blk.adaln[-1].weight)
            nn.init.zeros_(blk.adaln[-1].bias)

    def forward(self, x, sigma, cond, kg_raw=None):
        """Return vocabulary logits.

        Args:
            x: token ids, ``(B, T)``.
            sigma: per-token noise level, same leading shape as ``x``.
            cond: condition ids, one per sequence; their embedding is
                broadcast over the time axis.
            kg_raw: optional raw KG embeddings forwarded to every block.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: if ``T`` exceeds the size of the precomputed mask.
        """
        T = x.shape[1]
        max_len = self.block_mask.size(0)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={max_len}")

        cond_vec = None
        if self.cfg.use_adaln:
            # Conditioning goes through adaLN instead of the input sum.
            cond_vec = self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
            h = self.tok_emb(x)
        else:
            h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
        h = self.emb_dropout(h)
        mask = self.block_mask[:T, :T]
        for blk in self.blocks:
            h = blk(h, mask, kg_raw=kg_raw, cond_vec=cond_vec)
        h = self.final_norm(h)
        return self.head(h)

    def _corrupted_token_ce(self, x_clean, cond, kg_raw, mode):
        """Shared core of both losses.

        Applies conditioning dropout, samples a per-token sigma, replaces
        tokens by ``mask_token`` with probability sigma, runs the model and
        returns ``(ce, corrupt)``: the per-token cross-entropy against the
        clean tokens and the boolean corruption mask, both ``(B, T)``.

        Random draws happen in a fixed order: condition drop, KG drop (only
        if ``kg_raw`` is given), sigma, corruption.
        """
        B, T = x_clean.shape
        device = x_clean.device

        # Condition id 0 and all-zero KG embeddings are the "dropped" inputs.
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            drop_kg = torch.rand(B, device=device) < self.cfg.cond_dropout
            # The keep mask is built in kg_raw's dtype so half-precision
            # embeddings are not silently upcast to float32.
            keep_kg = (~drop_kg).to(kg_raw.dtype).reshape(B, 1, 1)
            kg_raw = kg_raw * keep_kg

        # Any mode other than "logit_normal" falls back to uniform sampling.
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
        else:
            sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)

        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = x_clean.masked_fill(corrupt, self.cfg.mask_token)
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        ce = F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size),
            x_clean.reshape(-1),
            reduction="none",
        ).reshape(B, T)
        return ce, corrupt

    def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
                               mode: str = "uniform") -> torch.Tensor:
        """Standard absorbing-state DF loss with per-token sigma.

        mode: 'uniform' (default — safer for discrete than logit-normal per audit)
              'logit_normal' (SD3-style — keep as ablation only)

        Returns the mean cross-entropy over corrupted positions; the
        denominator is clamped to at least 1.
        """
        ce, corrupt = self._corrupted_token_ce(x_clean, cond, kg_raw, mode)
        picked = corrupt.float()
        return (ce * picked).sum() / picked.sum().clamp(min=1.0)

    @staticmethod
    def recurrence_weights(x_clean, struct_ids, lam: float = 0.25, w_min: float = 0.02):
        """RAVEN recurrence-aware weights (Rajamohan et al., arXiv 2603.24562).

        w[i,t] = max(lam ** count, w_min), where `count` is the number of prior
        occurrences of token x[i,t] earlier in patient i's sequence. First
        occurrences get full weight; repeats decay geometrically toward w_min.
        Structural tokens (ids in `struct_ids`; may be empty or None) get
        weight 0. Vectorized (no Python Counter loop).
        Returns a (B, T) float tensor on x_clean.device.
        """
        T = x_clean.shape[1]
        device = x_clean.device

        # same[b, t, s] is True when positions t and s hold the same token.
        same = x_clean.unsqueeze(2) == x_clean.unsqueeze(1)
        pos = torch.arange(T, device=device)
        earlier = pos.unsqueeze(0) < pos.unsqueeze(1)  # earlier[t, s] == (s < t)
        count = (same & earlier.unsqueeze(0)).sum(dim=2).float()
        w = torch.clamp(lam ** count, min=w_min)
        if struct_ids:
            sid = torch.tensor(sorted(struct_ids), device=device)
            is_struct = (x_clean.unsqueeze(-1) == sid).any(-1)
            w = w.masked_fill(is_struct, 0.0)
        return w

    def recurrence_aware_loss(self, x_clean, cond, struct_ids, kg_raw=None,
                              lam: float = 0.25, w_min: float = 0.02,
                              mode: str = "uniform") -> torch.Tensor:
        """Diffusion-forcing loss reweighted by RAVEN recurrence decay — the
        objective that makes GEMEO predict NOVEL events, not repeats. This is the
        loss used to train the released `gemeo-sus` flagship.

        Per-token cross-entropy on corrupted positions is weighted by
        ``recurrence_weights`` and divided by the weight sum, clamped to at
        least 1.
        """
        ce, corrupt = self._corrupted_token_ce(x_clean, cond, kg_raw, mode)
        w = self.recurrence_weights(x_clean, struct_ids, lam, w_min) * corrupt.float()
        return (ce * w).sum() / w.sum().clamp(min=1.0)
|
|