Upload src/diffusion_forcing_v13.py with huggingface_hub
Browse files- src/diffusion_forcing_v13.py +100 -6
src/diffusion_forcing_v13.py
CHANGED
|
@@ -50,6 +50,10 @@ class CDFv13Config:
|
|
| 50 |
use_swiglu: bool = True
|
| 51 |
use_rmsnorm: bool = True
|
| 52 |
tie_embeddings: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# Diffusion forcing
|
| 54 |
cond_dropout: float = 0.10
|
| 55 |
# KG conditioning (GATED adapters)
|
|
@@ -187,6 +191,14 @@ class CDFv13Block(nn.Module):
|
|
| 187 |
self.norm2 = norm_cls(cfg.d_model)
|
| 188 |
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 189 |
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if cfg.use_swiglu:
|
| 191 |
self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
|
| 192 |
else:
|
|
@@ -204,12 +216,21 @@ class CDFv13Block(nn.Module):
|
|
| 204 |
self.kg_xattn = GatedKGCrossAttention(
|
| 205 |
cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
|
| 206 |
|
| 207 |
-
def forward(self, x, attn_mask, kg_raw=None):
|
| 208 |
B, T, D = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
# MSA
|
| 210 |
h = self.norm1(x)
|
|
|
|
|
|
|
| 211 |
qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
|
| 212 |
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
|
|
|
|
|
|
|
| 213 |
q, k = self.rope(q, k, T)
|
| 214 |
out = F.scaled_dot_product_attention(
|
| 215 |
q, k, v,
|
|
@@ -217,12 +238,17 @@ class CDFv13Block(nn.Module):
|
|
| 217 |
dropout_p=self.cfg.dropout if self.training else 0.0,
|
| 218 |
)
|
| 219 |
out = out.transpose(1, 2).reshape(B, T, D)
|
| 220 |
-
|
|
|
|
| 221 |
# Gated KG cross-attn (if enabled at this layer)
|
| 222 |
if self.use_kg_in_layer and kg_raw is not None:
|
| 223 |
x = self.kg_xattn(x, kg_raw)
|
| 224 |
# MLP
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
return x
|
| 227 |
|
| 228 |
|
|
@@ -258,7 +284,15 @@ class CDFv13Transformer(nn.Module):
|
|
| 258 |
# Block-causal mask buffer
|
| 259 |
T = c.max_seq_len
|
| 260 |
block_id = torch.arange(T) // c.block_size
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
self.register_buffer("block_mask", mask, persistent=False)
|
| 263 |
|
| 264 |
# Init
|
|
@@ -270,14 +304,25 @@ class CDFv13Transformer(nn.Module):
|
|
| 270 |
if m.bias is not None: nn.init.zeros_(m.bias)
|
| 271 |
elif isinstance(m, nn.Embedding):
|
| 272 |
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
def forward(self, x, sigma, cond, kg_raw=None):
|
| 275 |
B, T = x.shape
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
h = self.emb_dropout(h)
|
| 278 |
mask = self.block_mask[:T, :T]
|
| 279 |
for blk in self.blocks:
|
| 280 |
-
h = blk(h, mask, kg_raw=kg_raw)
|
| 281 |
h = self.final_norm(h)
|
| 282 |
return self.head(h)
|
| 283 |
|
|
@@ -312,3 +357,52 @@ class CDFv13Transformer(nn.Module):
|
|
| 312 |
).reshape(B, T)
|
| 313 |
n = corrupt.float().sum().clamp(min=1.0)
|
| 314 |
return (ce * corrupt.float()).sum() / n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
use_swiglu: bool = True
|
| 51 |
use_rmsnorm: bool = True
|
| 52 |
tie_embeddings: bool = True
|
| 53 |
+
# SOTA upgrades (opt-in; default off keeps backward-compat with v13 checkpoints)
|
| 54 |
+
use_qk_norm: bool = False # RMSNorm on Q,K per head before RoPE (Gemma2/3-style)
|
| 55 |
+
use_adaln: bool = False # AdaLN-Zero (DiT/SD3) per-token sigma+cond conditioning
|
| 56 |
+
bidirectional: bool = False # full attention (pure masked diffusion); else block-causal
|
| 57 |
# Diffusion forcing
|
| 58 |
cond_dropout: float = 0.10
|
| 59 |
# KG conditioning (GATED adapters)
|
|
|
|
| 191 |
self.norm2 = norm_cls(cfg.d_model)
|
| 192 |
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
|
| 193 |
self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
|
| 194 |
+
self.head_dim = cfg.d_model // cfg.n_heads
|
| 195 |
+
# QK-norm: per-head RMSNorm on Q,K before RoPE (stabilises attn logits)
|
| 196 |
+
if cfg.use_qk_norm:
|
| 197 |
+
self.q_norm = RMSNorm(self.head_dim)
|
| 198 |
+
self.k_norm = RMSNorm(self.head_dim)
|
| 199 |
+
# AdaLN-Zero: per-token modulation (shift/scale/gate) for MSA + MLP
|
| 200 |
+
if cfg.use_adaln:
|
| 201 |
+
self.adaln = nn.Sequential(nn.SiLU(), nn.Linear(cfg.d_model, 6 * cfg.d_model, bias=True))
|
| 202 |
if cfg.use_swiglu:
|
| 203 |
self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
|
| 204 |
else:
|
|
|
|
| 216 |
self.kg_xattn = GatedKGCrossAttention(
|
| 217 |
cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
|
| 218 |
|
| 219 |
+
def forward(self, x, attn_mask, kg_raw=None, cond_vec=None):
|
| 220 |
B, T, D = x.shape
|
| 221 |
+
# AdaLN-Zero modulation (per-token shift/scale/gate) from sigma+cond
|
| 222 |
+
if self.cfg.use_adaln and cond_vec is not None:
|
| 223 |
+
sh_msa, sc_msa, g_msa, sh_mlp, sc_mlp, g_mlp = self.adaln(cond_vec).chunk(6, dim=-1)
|
| 224 |
+
else:
|
| 225 |
+
sh_msa = sc_msa = g_msa = sh_mlp = sc_mlp = g_mlp = None
|
| 226 |
# MSA
|
| 227 |
h = self.norm1(x)
|
| 228 |
+
if sc_msa is not None:
|
| 229 |
+
h = h * (1 + sc_msa) + sh_msa
|
| 230 |
qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
|
| 231 |
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
|
| 232 |
+
if self.cfg.use_qk_norm:
|
| 233 |
+
q = self.q_norm(q); k = self.k_norm(k)
|
| 234 |
q, k = self.rope(q, k, T)
|
| 235 |
out = F.scaled_dot_product_attention(
|
| 236 |
q, k, v,
|
|
|
|
| 238 |
dropout_p=self.cfg.dropout if self.training else 0.0,
|
| 239 |
)
|
| 240 |
out = out.transpose(1, 2).reshape(B, T, D)
|
| 241 |
+
attn_out = self.dropout(self.proj(out))
|
| 242 |
+
x = x + (g_msa * attn_out if g_msa is not None else attn_out)
|
| 243 |
# Gated KG cross-attn (if enabled at this layer)
|
| 244 |
if self.use_kg_in_layer and kg_raw is not None:
|
| 245 |
x = self.kg_xattn(x, kg_raw)
|
| 246 |
# MLP
|
| 247 |
+
h2 = self.norm2(x)
|
| 248 |
+
if sc_mlp is not None:
|
| 249 |
+
h2 = h2 * (1 + sc_mlp) + sh_mlp
|
| 250 |
+
mlp_out = self.mlp(h2)
|
| 251 |
+
x = x + (g_mlp * mlp_out if g_mlp is not None else mlp_out)
|
| 252 |
return x
|
| 253 |
|
| 254 |
|
|
|
|
| 284 |
# Block-causal mask buffer
|
| 285 |
T = c.max_seq_len
|
| 286 |
block_id = torch.arange(T) // c.block_size
|
| 287 |
+
# Block-causal (Diffusion Forcing): a query may attend to its own block and
|
| 288 |
+
# all EARLIER blocks; future blocks are masked. mask[i,j]=True => BLOCKED.
|
| 289 |
+
# (Fixes a prior inverted mask that blocked the past instead of the future.)
|
| 290 |
+
# Set cfg.bidirectional=True for full bidirectional attention (pure masked
|
| 291 |
+
# diffusion / gap-fill), which disables the causal mask entirely.
|
| 292 |
+
if getattr(c, "bidirectional", False):
|
| 293 |
+
mask = torch.zeros(T, T, dtype=torch.bool)
|
| 294 |
+
else:
|
| 295 |
+
mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
|
| 296 |
self.register_buffer("block_mask", mask, persistent=False)
|
| 297 |
|
| 298 |
# Init
|
|
|
|
| 304 |
if m.bias is not None: nn.init.zeros_(m.bias)
|
| 305 |
elif isinstance(m, nn.Embedding):
|
| 306 |
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 307 |
+
# AdaLN-Zero: zero the modulation output so each block starts as identity
|
| 308 |
+
if self.cfg.use_adaln:
|
| 309 |
+
for blk in self.blocks:
|
| 310 |
+
nn.init.zeros_(blk.adaln[-1].weight)
|
| 311 |
+
nn.init.zeros_(blk.adaln[-1].bias)
|
| 312 |
|
| 313 |
def forward(self, x, sigma, cond, kg_raw=None):
|
| 314 |
B, T = x.shape
|
| 315 |
+
cond_vec = None
|
| 316 |
+
if self.cfg.use_adaln:
|
| 317 |
+
# AdaLN path: conditioning enters via per-token modulation, not additive
|
| 318 |
+
cond_vec = self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
|
| 319 |
+
h = self.tok_emb(x)
|
| 320 |
+
else:
|
| 321 |
+
h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
|
| 322 |
h = self.emb_dropout(h)
|
| 323 |
mask = self.block_mask[:T, :T]
|
| 324 |
for blk in self.blocks:
|
| 325 |
+
h = blk(h, mask, kg_raw=kg_raw, cond_vec=cond_vec)
|
| 326 |
h = self.final_norm(h)
|
| 327 |
return self.head(h)
|
| 328 |
|
|
|
|
| 357 |
).reshape(B, T)
|
| 358 |
n = corrupt.float().sum().clamp(min=1.0)
|
| 359 |
return (ce * corrupt.float()).sum() / n
|
| 360 |
+
|
| 361 |
+
@staticmethod
|
| 362 |
+
def recurrence_weights(x_clean, struct_ids, lam: float = 0.25, w_min: float = 0.02):
|
| 363 |
+
"""RAVEN recurrence-aware weights (Rajamohan et al., arXiv 2603.24562).
|
| 364 |
+
|
| 365 |
+
w[i,t] = max(lam ** count, w_min), where `count` is the number of prior
|
| 366 |
+
occurrences of token x[i,t] earlier in patient i's sequence. First
|
| 367 |
+
occurrences get full weight; repeats decay geometrically toward w_min.
|
| 368 |
+
Structural tokens get weight 0. Vectorized (no Python Counter loop).
|
| 369 |
+
Returns a (B, T) float tensor on x_clean.device.
|
| 370 |
+
"""
|
| 371 |
+
B, T = x_clean.shape
|
| 372 |
+
device = x_clean.device
|
| 373 |
+
# prior-occurrence count per position via equality-with-earlier-positions
|
| 374 |
+
eq = (x_clean.unsqueeze(2) == x_clean.unsqueeze(1)) # (B,T,T): eq[b,t,s] = x[b,t]==x[b,s]
|
| 375 |
+
earlier = torch.tril(torch.ones(T, T, device=device), diagonal=-1).bool() # [t,s]=True if s<t
|
| 376 |
+
count = (eq & earlier.unsqueeze(0)).sum(dim=2).float() # (B,T): #earlier positions s<t with same token
|
| 377 |
+
w = torch.clamp(lam ** count, min=w_min)
|
| 378 |
+
if struct_ids:
|
| 379 |
+
sid = torch.tensor(sorted(struct_ids), device=device)
|
| 380 |
+
is_struct = (x_clean.unsqueeze(-1) == sid).any(-1)
|
| 381 |
+
w = w.masked_fill(is_struct, 0.0)
|
| 382 |
+
return w
|
| 383 |
+
|
| 384 |
+
def recurrence_aware_loss(self, x_clean, cond, struct_ids, kg_raw=None,
|
| 385 |
+
lam: float = 0.25, w_min: float = 0.02,
|
| 386 |
+
mode: str = "uniform") -> torch.Tensor:
|
| 387 |
+
"""Diffusion-forcing loss reweighted by RAVEN recurrence decay — the
|
| 388 |
+
objective that makes GEMEO predict NOVEL events, not repeats. This is the
|
| 389 |
+
loss used to train the released `gemeo-sus` flagship."""
|
| 390 |
+
B, T = x_clean.shape
|
| 391 |
+
device = x_clean.device
|
| 392 |
+
drop = torch.rand(B, device=device) < self.cfg.cond_dropout
|
| 393 |
+
cond = torch.where(drop, torch.zeros_like(cond), cond)
|
| 394 |
+
if kg_raw is not None:
|
| 395 |
+
drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
|
| 396 |
+
kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
|
| 397 |
+
if mode == "logit_normal":
|
| 398 |
+
sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
|
| 399 |
+
else:
|
| 400 |
+
sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
|
| 401 |
+
corrupt = torch.rand(B, T, device=device) < sigma
|
| 402 |
+
x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
|
| 403 |
+
logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
|
| 404 |
+
ce = F.cross_entropy(
|
| 405 |
+
logits.reshape(-1, self.cfg.vocab_size), x_clean.reshape(-1),
|
| 406 |
+
reduction="none").reshape(B, T)
|
| 407 |
+
w = self.recurrence_weights(x_clean, struct_ids, lam, w_min) * corrupt.float()
|
| 408 |
+
return (ce * w).sum() / w.sum().clamp(min=1.0)
|