"""GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.

Per the SOTA audit (May 2026):
  - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
  - Path A adopted: small from-scratch model + KG adapters + MEDS interop

Architecture:
  - 12M backbone params (Chinchilla-respecting for ~20M token corpus)
  - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
  - SwiGLU MLP (ffn:d_model = 2.67)
  - Tied embeddings (saves ~12M at vocab=32k)
  - Dropout 0.1 everywhere (small-data critical)
  - Block-causal attention (Diffusion Forcing)
  - Per-token sigma noise (independent)
  - GATED KG cross-attention (tanh(α)·xattn, α init=0)
    - Layers 4, 6, 7 (3 of 8)
    - Lets model learn to use KG progressively, doesn't disrupt early loss
  - DF objective + LM-aux loss (joint training, paper-grade)

Sources audited:
  - CoMET (Aug 2025): tokens-per-param ratio
  - CLMBR (Stanford): adapter pattern for cross-site transfer
  - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
  - Genie (DeepMind 2024): gated cross-attention pattern
  - SD3 (Esser 2024): AdaLN-Zero zero-init gates
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CDFv13Config:
    # Vocab + sequence
    vocab_size: int = 32768       # MEDS-derived (will be much smaller in practice)
    mask_token: int = 32767
    max_seq_len: int = 512
    block_size: int = 16
    # Architecture (Chinchilla-correct for ~20M tokens)
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 8
    ffn: int = 1024                # MLP hidden width; SwiGLU uses 3 projections (gate, up, down), the GELU MLP uses 2
    dropout: float = 0.1
    emb_dropout: float = 0.1
    use_swiglu: bool = True
    use_rmsnorm: bool = True
    tie_embeddings: bool = True
    # SOTA upgrades (opt-in; default off keeps backward-compat with v13 checkpoints)
    use_qk_norm: bool = False      # RMSNorm on Q,K per head before RoPE (Gemma 3-style)
    use_adaln: bool = False        # AdaLN-Zero (DiT/SD3) per-token sigma+cond conditioning
    bidirectional: bool = False    # full attention (pure masked diffusion); else block-causal
    # Diffusion forcing
    cond_dropout: float = 0.10
    # KG conditioning (GATED adapters)
    use_kg: bool = True
    kg_dim: int = 3072
    kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])
    # Latent action
    use_latent_action: bool = False  # Dropped per audit (concept shaky)
    n_latent_actions: int = 512
    # Conditioning
    n_conditions: int = 64
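

# Illustrative sketch (not part of the original module): a rough parameter
# count for the transformer blocks implied by a config like the one above.
# Assumptions: bias-free projections, RMSNorm (one weight vector per norm),
# and no embeddings, sigma/cond embeddings, AdaLN or KG adapters in the total.
# `estimate_block_params` is a hypothetical helper name; its defaults simply
# mirror the CDFv13Config defaults.
def estimate_block_params(d_model: int = 384, n_layers: int = 8,
                          ffn: int = 1024, swiglu: bool = True) -> int:
    attn = 4 * d_model * d_model                # fused QKV (3x) + output projection (1x)
    mlp = (3 if swiglu else 2) * d_model * ffn  # gate/up/down vs. up/down
    norms = 2 * d_model                         # norm1 + norm2 weights
    return n_layers * (attn + mlp + norms)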


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (LLaMA/Mistral style)."""
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (norm * self.weight.float()).to(x.dtype)


class SwiGLU(nn.Module):
    """SwiGLU MLP (used in LLaMA/Gemma/Mistral)."""
    def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
        super().__init__()
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))


class RotaryEmbedding(nn.Module):
    """RoPE (Su et al. 2021)."""
    def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    def forward(self, q, k, seq_len):
        cos = self.cos[:seq_len].to(q.dtype).to(q.device)
        sin = self.sin[:seq_len].to(q.dtype).to(q.device)
        def rot_half(x):
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([-x2, x1], dim=-1)
        return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)


class PerTokenSigmaEmbed(nn.Module):
    """Sinusoidal embedding of per-position diffusion noise sigma in [0,1]."""
    def __init__(self, d: int):
        super().__init__()
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        half = self.d // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        return self.proj(emb)


class GatedKGCrossAttention(nn.Module):
    """Cross-attention to KG ego-subgraph, with GATED output.

    `tanh(alpha) * cross_attn(x_seq, x_kg)` where alpha is a learnable scalar
    initialized to 0. This means at init the cross-attention contributes
    NOTHING to the residual stream, so the model trains identically to
    no-KG until it discovers KG is useful. Prevents catastrophic loss
    spikes on small data.

    Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
    """
    def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Project KG to d_model (run inline so we don't need separate KGProjector module)
        self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = RMSNorm(d_model)
        self.norm_kv = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Gate (scalar per block, init=0)
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
        """
        x_seq: (B, T, d_model)
        kg_raw: (B, N_kg, kg_dim)   -- raw KG embeddings (e.g. 3072)
        """
        B, T, D = x_seq.shape
        kg_proj = self.kg_in_proj(kg_raw)  # (B, N_kg, D)
        N_kg = kg_proj.size(1)
        q = self.q_proj(self.norm_q(x_seq))
        kv = self.kv_proj(self.norm_kv(kg_proj))
        k, v = kv.chunk(2, dim=-1)
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N_kg, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout.p if self.training else 0.0)
        out = out.transpose(1, 2).reshape(B, T, D)
        gate = torch.tanh(self.alpha)
        return x_seq + gate * self.dropout(self.out_proj(out))


class CDFv13Block(nn.Module):
    """Pre-norm transformer block + optional gated KG cross-attn."""
    def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
                 layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.rope = rope
        self.layer_idx = layer_idx
        norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(cfg.d_model)
        self.norm2 = norm_cls(cfg.d_model)
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.head_dim = cfg.d_model // cfg.n_heads
        # QK-norm: per-head RMSNorm on Q,K before RoPE (stabilises attn logits)
        if cfg.use_qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        # AdaLN-Zero: per-token modulation (shift/scale/gate) for MSA + MLP
        if cfg.use_adaln:
            self.adaln = nn.Sequential(nn.SiLU(), nn.Linear(cfg.d_model, 6 * cfg.d_model, bias=True))
        if cfg.use_swiglu:
            self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(cfg.d_model, cfg.ffn, bias=False),
                nn.GELU(),
                nn.Linear(cfg.ffn, cfg.d_model, bias=False),
                nn.Dropout(cfg.dropout),
            )
        self.dropout = nn.Dropout(cfg.dropout)
        # Gated KG cross-attention (only in specified layers)
        self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
        if self.use_kg_in_layer:
            self.kg_xattn = GatedKGCrossAttention(
                cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)

    def forward(self, x, attn_mask, kg_raw=None, cond_vec=None):
        B, T, D = x.shape
        # AdaLN-Zero modulation (per-token shift/scale/gate) from sigma+cond
        if self.cfg.use_adaln and cond_vec is not None:
            sh_msa, sc_msa, g_msa, sh_mlp, sc_mlp, g_mlp = self.adaln(cond_vec).chunk(6, dim=-1)
        else:
            sh_msa = sc_msa = g_msa = sh_mlp = sc_mlp = g_mlp = None
        # MSA
        h = self.norm1(x)
        if sc_msa is not None:
            h = h * (1 + sc_msa) + sh_msa
        qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        if self.cfg.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
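            # attn_mask uses True = BLOCKED; SDPA's boolean mask uses True = ATTEND,
            # so invert it. A boolean mask also sidesteps any dtype mismatch with q/k/v.
            attn_mask=(~attn_mask)[None, None],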
            dropout_p=self.cfg.dropout if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        attn_out = self.dropout(self.proj(out))
        x = x + (g_msa * attn_out if g_msa is not None else attn_out)
        # Gated KG cross-attn (if enabled at this layer)
        if self.use_kg_in_layer and kg_raw is not None:
            x = self.kg_xattn(x, kg_raw)
        # MLP
        h2 = self.norm2(x)
        if sc_mlp is not None:
            h2 = h2 * (1 + sc_mlp) + sh_mlp
        mlp_out = self.mlp(h2)
        x = x + (g_mlp * mlp_out if g_mlp is not None else mlp_out)
        return x
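

# Illustrative sketch (not part of the original module): the block-causal
# rule in plain Python, which may be handy for unit-testing the mask buffer
# built in CDFv13Transformer. Position p belongs to block p // block_size, and
# a query is blocked only from keys in strictly later blocks. `is_blocked` and
# `block_causal_rows` are hypothetical helper names.
def is_blocked(query_pos: int, key_pos: int, block_size: int = 16) -> bool:
    return (key_pos // block_size) > (query_pos // block_size)


def block_causal_rows(seq_len: int, block_size: int) -> list:
    """Return rows[i][j] == True where query i may NOT attend to key j."""
    return [[is_blocked(i, j, block_size) for j in range(seq_len)]
            for i in range(seq_len)]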


class CDFv13Transformer(nn.Module):
    """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective."""

    def __init__(self, cfg: CDFv13Config | None = None):
        super().__init__()
        self.cfg = cfg or CDFv13Config()
        c = self.cfg
        norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm

        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.emb_dropout = nn.Dropout(c.emb_dropout)

        # Per-token sigma embedding (additive)
        self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
        # Global condition embedding (additive, broadcast)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)

        # RoPE
        self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)

        # Blocks
        self.blocks = nn.ModuleList([
            CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
        ])
        self.final_norm = norm_cls(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
        if c.tie_embeddings:
            self.head.weight = self.tok_emb.weight

        # Block-causal mask buffer
        T = c.max_seq_len
        block_id = torch.arange(T) // c.block_size
        # Block-causal (Diffusion Forcing): a query may attend to its own block and
        # all EARLIER blocks; future blocks are masked. mask[i,j]=True => BLOCKED.
        # (Fixes a prior inverted mask that blocked the past instead of the future.)
        # Set cfg.bidirectional=True for full bidirectional attention (pure masked
        # diffusion / gap-fill), which disables the causal mask entirely.
        if getattr(c, "bidirectional", False):
            mask = torch.zeros(T, T, dtype=torch.bool)
        else:
            mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
        self.register_buffer("block_mask", mask, persistent=False)

        # Init: generic weights first, then zero the AdaLN modulation output once
        # so each block starts as identity (AdaLN-Zero).
        self.apply(self._init_weights)
        if c.use_adaln:
            for blk in self.blocks:
                nn.init.zeros_(blk.adaln[-1].weight)
                nn.init.zeros_(blk.adaln[-1].bias)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, x, sigma, cond, kg_raw=None):
        B, T = x.shape
        cond_vec = None
        if self.cfg.use_adaln:
            # AdaLN path: conditioning enters via per-token modulation, not additive
            cond_vec = self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
            h = self.tok_emb(x)
        else:
            h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
        h = self.emb_dropout(h)
        mask = self.block_mask[:T, :T]
        for blk in self.blocks:
            h = blk(h, mask, kg_raw=kg_raw, cond_vec=cond_vec)
        h = self.final_norm(h)
        return self.head(h)

    def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
                                mode: str = "uniform") -> torch.Tensor:
        """Standard absorbing-state DF loss with per-token sigma.

        mode: 'uniform' (default — safer for discrete than logit-normal per audit)
              'logit_normal' (SD3-style — keep as ablation only)
        """
        B, T = x_clean.shape
        device = x_clean.device
        # CFG cond dropout
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
            kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
        # Sample per-token sigma
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
        else:
            sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
        # Absorbing-state corruption
        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        ce = F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size),
            x_clean.reshape(-1),
            reduction="none",
        ).reshape(B, T)
        n = corrupt.float().sum().clamp(min=1.0)
        return (ce * corrupt.float()).sum() / n

    @staticmethod
    def recurrence_weights(x_clean, struct_ids, lam: float = 0.25, w_min: float = 0.02):
        """RAVEN recurrence-aware weights (Rajamohan et al., arXiv 2603.24562).

        w[i,t] = max(lam ** count, w_min), where `count` is the number of prior
        occurrences of token x[i,t] earlier in patient i's sequence. First
        occurrences get full weight; repeats decay geometrically toward w_min.
        Structural tokens get weight 0. Vectorized (no Python Counter loop).
        Returns a (B, T) float tensor on x_clean.device.
        """
        B, T = x_clean.shape
        device = x_clean.device
        # prior-occurrence count per position via equality-with-earlier-positions
        eq = (x_clean.unsqueeze(2) == x_clean.unsqueeze(1))      # (B,T,T): eq[b,t,s] = x[b,t]==x[b,s]
        earlier = torch.tril(torch.ones(T, T, device=device), diagonal=-1).bool()  # [t,s]=True if s<t
        count = (eq & earlier.unsqueeze(0)).sum(dim=2).float()   # (B,T): #earlier positions s<t with same token
        w = torch.clamp(lam ** count, min=w_min)
        if struct_ids:
            sid = torch.tensor(sorted(struct_ids), device=device)
            is_struct = (x_clean.unsqueeze(-1) == sid).any(-1)
            w = w.masked_fill(is_struct, 0.0)
        return w

    def recurrence_aware_loss(self, x_clean, cond, struct_ids, kg_raw=None,
                              lam: float = 0.25, w_min: float = 0.02,
                              mode: str = "uniform") -> torch.Tensor:
        """Diffusion-forcing loss reweighted by RAVEN recurrence decay — the
        objective that makes GEMEO predict NOVEL events, not repeats. This is the
        loss used to train the released `gemeo-sus` flagship."""
        B, T = x_clean.shape
        device = x_clean.device
        drop = torch.rand(B, device=device) < self.cfg.cond_dropout
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
            kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
        else:
            sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        ce = F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size), x_clean.reshape(-1),
            reduction="none").reshape(B, T)
        w = self.recurrence_weights(x_clean, struct_ids, lam, w_min) * corrupt.float()
        return (ce * w).sum() / w.sum().clamp(min=1.0)
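

# Illustrative sketch (not part of the original module): a slow pure-Python
# reference for the recurrence weights of a single sequence, intended only as
# a cross-check for the vectorized CDFv13Transformer.recurrence_weights.
# `reference_recurrence_weights` is a hypothetical helper name.
def reference_recurrence_weights(tokens, struct_ids=(), lam: float = 0.25,
                                 w_min: float = 0.02) -> list:
    seen = {}      # token id -> number of earlier occurrences
    weights = []
    for tok in tokens:
        count = seen.get(tok, 0)
        # Structural tokens get weight 0 but still count as occurrences,
        # mirroring the vectorized version (which counts before masking).
        weights.append(0.0 if tok in struct_ids else max(lam ** count, w_min))
        seen[tok] = count + 1
    return weights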