File size: 12,683 Bytes
a0fa886
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""GEMEO-CDF v13 — audit-driven Chinchilla-correct architecture.

Per the SOTA audit (May 2026):
  - Path B (CLMBR fine-tune) BLOCKED: CLMBR-T-base is HF-gated (manual approval)
  - Path A adopted: small from-scratch model + KG adapters + MEDS interop

Architecture:
  - 12M backbone params (Chinchilla-respecting for ~20M token corpus)
  - d_model=384, n_layers=8, n_heads=6, ffn=1024, ctx=512
  - SwiGLU MLP (ffn:d_model = 2.67)
  - Tied embeddings (saves ~12M at vocab=32k)
  - Dropout 0.1 everywhere (small-data critical)
  - Block-causal attention (Diffusion Forcing)
  - Per-token sigma noise (independent)
  - GATED KG cross-attention (tanh(α)·xattn, α init=0)
    - Layers 4, 6, 7 (3 of 8)
    - Lets model learn to use KG progressively, doesn't disrupt early loss
  - DF objective + LM-aux loss (joint training, paper-grade)

Sources audited:
  - CoMET (Aug 2025): tokens-per-param ratio
  - CLMBR (Stanford): adapter pattern for cross-site transfer
  - MDLM (Sahoo 2024): masked diffusion, matches AR at equal FLOPs
  - Genie (DeepMind 2024): gated cross-attention pattern
  - SD3 (Esser 2024): AdaLN-Zero zero-init gates
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class CDFv13Config:
    """Hyperparameter bundle consumed by the CDF v13 modules.

    Fields are grouped by concern. Their declaration order is also the
    positional-constructor order, so it must stay stable.
    """

    # --- Vocabulary and sequence geometry ---
    vocab_size: int = 32768  # MEDS-derived (will be much smaller in practice)
    mask_token: int = 32767  # id written over corrupted positions by the DF loss
    max_seq_len: int = 512   # side length of the precomputed block mask
    block_size: int = 16     # number of consecutive positions sharing one block id

    # --- Backbone (Chinchilla-correct for ~20M tokens) ---
    d_model: int = 384
    n_heads: int = 6
    n_layers: int = 8
    ffn: int = 1024          # MLP hidden width; the SwiGLU variant has two input projections of this width
    dropout: float = 0.1     # attention / residual / MLP dropout
    emb_dropout: float = 0.1  # dropout applied to the summed input embeddings
    use_swiglu: bool = True   # False selects a plain GELU MLP
    use_rmsnorm: bool = True  # False selects nn.LayerNorm
    tie_embeddings: bool = True  # share the output head weight with the token embedding

    # --- Diffusion forcing ---
    cond_dropout: float = 0.10  # per-sample probability of blanking cond / KG in the DF loss

    # --- KG conditioning (GATED adapters) ---
    use_kg: bool = True
    kg_dim: int = 3072       # feature width of the raw KG embeddings
    kg_attn_layers: list = field(default_factory=lambda: [4, 6, 7])  # layer indices that get a KG cross-attention adapter

    # --- Latent action ---
    use_latent_action: bool = False  # Dropped per audit (concept shaky)
    n_latent_actions: int = 512

    # --- Conditioning ---
    n_conditions: int = 64   # number of rows in the global condition embedding table


class RMSNorm(nn.Module):
    """RMS normalisation over the last axis with a learned per-feature scale.

    The statistics and the scaling are computed in float32; the result is
    cast back to the dtype of the input.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Scale starts at one, so the layer is a plain RMS normalisation at init.
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        # eps sits inside the square root, so an all-zero row maps to zeros.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normalised = x32 * inv_rms
        scaled = normalised * self.weight.float()
        return scaled.to(x.dtype)


class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``dropout(down(silu(gate(x)) * up(x)))``.

    All three projections are bias-free. ``d_hidden`` is the width of both
    the gate and the up branch.
    """

    def __init__(self, d_in: int, d_hidden: int, dropout: float = 0.1):
        super().__init__()
        # Two parallel input projections (gate and value) and one output projection.
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w_gate(x))
        value = self.w_up(x)
        projected = self.w_down(gate * value)
        return self.dropout(projected)


class RotaryEmbedding(nn.Module):
    """RoPE (Su et al. 2021) in the "rotate-half" layout.

    Feature ``i`` is paired with feature ``i + dim // 2``; each pair is
    rotated by an angle proportional to the position. The cos/sin tables are
    cached as non-persistent buffers of shape ``(n_positions, dim)``.

    Args:
        dim: per-head feature width; must be even.
        max_seq: number of positions to precompute. Longer sequences are
            handled by growing the cache on first use.
        base: frequency base of the geometric progression.

    Raises:
        ValueError: if ``dim`` is odd (the pairing would be ill-defined and
            the table width would not match the inputs).
    """

    def __init__(self, dim: int, max_seq: int = 8192, base: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        self.dim = dim
        self.base = base
        cos, sin = self._build_tables(max_seq)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def _build_tables(self, n_pos: int):
        """Return (cos, sin) tables of shape (n_pos, dim), computed on CPU."""
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        positions = torch.arange(n_pos).float()
        freqs = torch.outer(positions, inv_freq)
        # Duplicate so that feature i and feature i + dim/2 share one angle.
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Map (x1, x2) -> (-x2, x1) on the two halves of the last axis."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q, k, seq_len):
        """Rotate ``q`` and ``k`` by their positions.

        Args:
            q, k: tensors whose last two axes are ``(seq_len, dim)``.
            seq_len: number of positions; rows ``0..seq_len-1`` of the tables
                are broadcast against the inputs.

        Returns:
            ``(q_rotated, k_rotated)`` with the shapes and dtype of the inputs.
        """
        if seq_len > self.cos.size(0):
            # Previously this sliced a too-short table and failed with an
            # opaque broadcast error. Rebuild with the same formula instead,
            # so the existing rows keep their values.
            cos, sin = self._build_tables(seq_len)
            self.cos = cos.to(self.cos.device)
            self.sin = sin.to(self.sin.device)
        cos = self.cos[:seq_len].to(device=q.device, dtype=q.dtype)
        sin = self.sin[:seq_len].to(device=q.device, dtype=q.dtype)
        q_rot = (q * cos) + (self._rotate_half(q) * sin)
        k_rot = (k * cos) + (self._rotate_half(k) * sin)
        return q_rot, k_rot


class PerTokenSigmaEmbed(nn.Module):
    """Sinusoidal embedding of per-position diffusion noise sigma, followed by an MLP.

    Args:
        d: output width; must be even because the features are ``d // 2``
            sines concatenated with ``d // 2`` cosines.

    Raises:
        ValueError: if ``d`` is odd (the sinusoid features would have width
            ``d - 1`` and could not feed the ``d``-wide projection).
    """

    def __init__(self, d: int):
        super().__init__()
        if d % 2 != 0:
            raise ValueError(f"PerTokenSigmaEmbed requires an even d, got {d}")
        self.d = d
        self.proj = nn.Sequential(
            nn.Linear(d, d), nn.SiLU(), nn.Linear(d, d),
        )

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        """Embed ``sigma`` of shape ``(...)`` into ``(..., d)``.

        The sinusoids are evaluated in float32 and then cast to the dtype of
        the projection weights, so the module also works after
        ``.half()`` / ``.bfloat16()`` / ``.double()``.
        """
        half = self.d // 2
        # Geometric frequency ladder from 1 down towards 1/10000.
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=sigma.device) / half
        )
        ang = sigma.float().unsqueeze(-1) * freqs
        emb = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)
        # Match the parameter dtype; a no-op for the default float32 module.
        emb = emb.to(self.proj[0].weight.dtype)
        return self.proj(emb)


class GatedKGCrossAttention(nn.Module):
    """Residual cross-attention from sequence tokens to KG nodes behind a learned gate.

    The block returns ``x_seq + tanh(alpha) * dropout(out_proj(attn))`` where
    ``alpha`` is a single learnable scalar that starts at 0. At
    initialisation the gate is therefore exactly zero and the block is the
    identity on ``x_seq``; the KG pathway only starts to contribute once
    training moves ``alpha`` away from zero. This avoids disturbing the
    backbone early on, which matters on small data.

    Pattern from: Genie (DeepMind 2024), Flamingo (DeepMind 2022).
    """

    def __init__(self, d_model: int, kg_dim: int, n_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Raw KG features are projected to d_model here, so no separate
        # projector module is needed upstream.
        self.kg_in_proj = nn.Linear(kg_dim, d_model, bias=False)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.kv_proj = nn.Linear(d_model, 2 * d_model, bias=False)  # keys and values fused
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm_q = RMSNorm(d_model)
        self.norm_kv = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # One scalar gate per block, zero-initialised.
        self.alpha = nn.Parameter(torch.zeros(1))

    def _split_heads(self, t: torch.Tensor, batch: int, length: int) -> torch.Tensor:
        """Reshape (batch, length, d_model) into (batch, n_heads, length, head_dim)."""
        return t.reshape(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x_seq: torch.Tensor, kg_raw: torch.Tensor) -> torch.Tensor:
        """
        x_seq: (B, T, d_model)
        kg_raw: (B, N_kg, kg_dim)   -- raw KG embeddings (e.g. 3072)

        Returns x_seq plus the gated cross-attention update, shape (B, T, d_model).
        """
        batch, seq_len, width = x_seq.shape
        kg_tokens = self.kg_in_proj(kg_raw)  # (B, N_kg, d_model)
        n_nodes = kg_tokens.size(1)

        queries = self.q_proj(self.norm_q(x_seq))
        keys, values = self.kv_proj(self.norm_kv(kg_tokens)).chunk(2, dim=-1)

        queries = self._split_heads(queries, batch, seq_len)
        keys = self._split_heads(keys, batch, n_nodes)
        values = self._split_heads(values, batch, n_nodes)

        # Attention-weight dropout is only active while training.
        attn_dropout = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            queries, keys, values, dropout_p=attn_dropout)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, width)

        update = self.dropout(self.out_proj(merged))
        return x_seq + torch.tanh(self.alpha) * update


class CDFv13Block(nn.Module):
    """Pre-norm transformer block + optional gated KG cross-attn."""
    def __init__(self, cfg: CDFv13Config, rope: RotaryEmbedding,
                 layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.rope = rope
        self.layer_idx = layer_idx
        norm_cls = RMSNorm if cfg.use_rmsnorm else nn.LayerNorm
        self.norm1 = norm_cls(cfg.d_model)
        self.norm2 = norm_cls(cfg.d_model)
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        if cfg.use_swiglu:
            self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(cfg.d_model, cfg.ffn, bias=False),
                nn.GELU(),
                nn.Linear(cfg.ffn, cfg.d_model, bias=False),
                nn.Dropout(cfg.dropout),
            )
        self.dropout = nn.Dropout(cfg.dropout)
        self.head_dim = cfg.d_model // cfg.n_heads
        # Gated KG cross-attention (only in specified layers)
        self.use_kg_in_layer = cfg.use_kg and layer_idx in cfg.kg_attn_layers
        if self.use_kg_in_layer:
            self.kg_xattn = GatedKGCrossAttention(
                cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)

    def forward(self, x, attn_mask, kg_raw=None):
        B, T, D = x.shape
        # MSA
        h = self.norm1(x)
        qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.rope(q, k, T)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=(~attn_mask).float().masked_fill(attn_mask, float("-inf"))[None, None],
            dropout_p=self.cfg.dropout if self.training else 0.0,
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        x = x + self.dropout(self.proj(out))
        # Gated KG cross-attn (if enabled at this layer)
        if self.use_kg_in_layer and kg_raw is not None:
            x = self.kg_xattn(x, kg_raw)
        # MLP
        x = x + self.mlp(self.norm2(x))
        return x


class CDFv13Transformer(nn.Module):
    """Audit-compliant CDF v13: 12M backbone + KG adapters + DF objective.

    The input embedding is the sum of a token embedding, a per-token sigma
    embedding and a per-sequence condition embedding. It is followed by a
    stack of block-causal transformer blocks, a final norm, and a vocabulary
    head (optionally weight-tied to the token embedding).

    Raises:
        ValueError: at construction if ``d_model`` is not divisible by
            ``n_heads``.
    """

    def __init__(self, cfg: "CDFv13Config | None" = None):
        super().__init__()
        self.cfg = cfg or CDFv13Config()
        c = self.cfg
        if c.d_model % c.n_heads != 0:
            # Otherwise the per-head reshape in the first forward pass fails.
            raise ValueError(
                f"d_model ({c.d_model}) must be divisible by n_heads ({c.n_heads})")
        norm_cls = RMSNorm if c.use_rmsnorm else nn.LayerNorm

        self.tok_emb = nn.Embedding(c.vocab_size, c.d_model)
        self.emb_dropout = nn.Dropout(c.emb_dropout)

        # Per-token sigma embedding (additive)
        self.sigma_emb = PerTokenSigmaEmbed(c.d_model)
        # Global condition embedding (additive, broadcast over positions)
        self.cond_emb = nn.Embedding(c.n_conditions, c.d_model)

        # RoPE table shared by every block
        self.rope = RotaryEmbedding(c.d_model // c.n_heads, max_seq=c.max_seq_len * 2)

        # Blocks
        self.blocks = nn.ModuleList([
            CDFv13Block(c, self.rope, layer_idx=i) for i in range(c.n_layers)
        ])
        self.final_norm = norm_cls(c.d_model)
        self.head = nn.Linear(c.d_model, c.vocab_size, bias=False)
        if c.tie_embeddings:
            self.head.weight = self.tok_emb.weight

        # Block-causal mask buffer (True = attention blocked)
        self.register_buffer(
            "block_mask",
            self._build_block_mask(c.max_seq_len, c.block_size),
            persistent=False,
        )

        # Init
        self.apply(self._init_weights)

    @staticmethod
    def _build_block_mask(seq_len: int, block_size: int) -> torch.Tensor:
        """Return the (seq_len, seq_len) boolean block-causal mask.

        Entry ``[query, key]`` is True when the key lies in a LATER block
        than the query, i.e. when attention must be blocked. Positions in the
        same block and in earlier blocks stay visible.
        """
        block_id = torch.arange(seq_len) // block_size
        # Row index is the query, column index is the key.
        return block_id.unsqueeze(0) > block_id.unsqueeze(1)

    def _init_weights(self, m):
        """Normal(0, 0.02) init for Linear/Embedding weights, zeros for biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, x, sigma, cond, kg_raw=None):
        """Compute vocabulary logits.

        Args:
            x: token ids, shape (B, T).
            sigma: per-token noise level, shape (B, T).
            cond: condition ids, shape (B,).
            kg_raw: optional raw KG embeddings forwarded to every block.

        Returns:
            Logits of shape (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds the length of the precomputed block mask.
        """
        B, T = x.shape
        max_len = self.block_mask.size(0)
        if T > max_len:
            # Slicing would silently return a smaller mask and fail later
            # with an unrelated-looking broadcast error.
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={max_len}")
        h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
        h = self.emb_dropout(h)
        mask = self.block_mask[:T, :T]
        for blk in self.blocks:
            h = blk(h, mask, kg_raw=kg_raw)
        h = self.final_norm(h)
        return self.head(h)

    def diffusion_forcing_loss(self, x_clean, cond, kg_raw=None,
                                mode: str = "uniform") -> torch.Tensor:
        """Standard absorbing-state DF loss with per-token sigma.

        Each position draws its own sigma and is replaced by the mask token
        with probability sigma. The loss is the mean cross-entropy against
        the clean tokens over the corrupted positions only.

        Args:
            x_clean: clean token ids, shape (B, T).
            cond: condition ids, shape (B,). With probability
                ``cfg.cond_dropout`` per sample the id is replaced by 0.
            kg_raw: optional raw KG embeddings, shape (B, N_kg, kg_dim). With
                probability ``cfg.cond_dropout`` per sample (drawn
                independently of the cond dropout) they are zeroed.
            mode: 'uniform' (default — safer for discrete than logit-normal per audit)
                  'logit_normal' (SD3-style — keep as ablation only).
                  Any other value behaves like 'uniform'.

        Returns:
            Scalar loss tensor.
        """
        B, T = x_clean.shape
        device = x_clean.device
        p_drop = self.cfg.cond_dropout
        # CFG cond dropout
        drop = torch.rand(B, device=device) < p_drop
        cond = torch.where(drop, torch.zeros_like(cond), cond)
        if kg_raw is not None:
            keep_kg = torch.rand(B, device=device) >= p_drop
            # Build the keep-mask in kg_raw's dtype so half/bf16 inputs are
            # not silently upcast to float32.
            kg_raw = kg_raw * keep_kg.to(kg_raw.dtype).reshape(B, 1, 1)
        # Sample per-token sigma
        if mode == "logit_normal":
            sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
        else:
            sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
        # Absorbing-state corruption
        corrupt = torch.rand(B, T, device=device) < sigma
        x_noisy = x_clean.masked_fill(corrupt, self.cfg.mask_token)
        logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
        ce = F.cross_entropy(
            logits.reshape(-1, self.cfg.vocab_size),
            x_clean.reshape(-1),
            reduction="none",
        ).reshape(B, T)
        weight = corrupt.float()
        # Clamp so a batch with no corrupted position yields 0 rather than NaN.
        n = weight.sum().clamp(min=1.0)
        return (ce * weight).sum() / n