timmers commited on
Commit
bd692df
·
verified ·
1 Parent(s): 166badd

Upload src/diffusion_forcing_v13.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/diffusion_forcing_v13.py +100 -6
src/diffusion_forcing_v13.py CHANGED
@@ -50,6 +50,10 @@ class CDFv13Config:
50
  use_swiglu: bool = True
51
  use_rmsnorm: bool = True
52
  tie_embeddings: bool = True
 
 
 
 
53
  # Diffusion forcing
54
  cond_dropout: float = 0.10
55
  # KG conditioning (GATED adapters)
@@ -187,6 +191,14 @@ class CDFv13Block(nn.Module):
187
  self.norm2 = norm_cls(cfg.d_model)
188
  self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
189
  self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
 
 
 
 
 
 
 
 
190
  if cfg.use_swiglu:
191
  self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
192
  else:
@@ -204,12 +216,21 @@ class CDFv13Block(nn.Module):
204
  self.kg_xattn = GatedKGCrossAttention(
205
  cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
206
 
207
- def forward(self, x, attn_mask, kg_raw=None):
208
  B, T, D = x.shape
 
 
 
 
 
209
  # MSA
210
  h = self.norm1(x)
 
 
211
  qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
212
  q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
 
 
213
  q, k = self.rope(q, k, T)
214
  out = F.scaled_dot_product_attention(
215
  q, k, v,
@@ -217,12 +238,17 @@ class CDFv13Block(nn.Module):
217
  dropout_p=self.cfg.dropout if self.training else 0.0,
218
  )
219
  out = out.transpose(1, 2).reshape(B, T, D)
220
- x = x + self.dropout(self.proj(out))
 
221
  # Gated KG cross-attn (if enabled at this layer)
222
  if self.use_kg_in_layer and kg_raw is not None:
223
  x = self.kg_xattn(x, kg_raw)
224
  # MLP
225
- x = x + self.mlp(self.norm2(x))
 
 
 
 
226
  return x
227
 
228
 
@@ -258,7 +284,15 @@ class CDFv13Transformer(nn.Module):
258
  # Block-causal mask buffer
259
  T = c.max_seq_len
260
  block_id = torch.arange(T) // c.block_size
261
- mask = block_id.unsqueeze(0) < block_id.unsqueeze(1)
 
 
 
 
 
 
 
 
262
  self.register_buffer("block_mask", mask, persistent=False)
263
 
264
  # Init
@@ -270,14 +304,25 @@ class CDFv13Transformer(nn.Module):
270
  if m.bias is not None: nn.init.zeros_(m.bias)
271
  elif isinstance(m, nn.Embedding):
272
  nn.init.normal_(m.weight, mean=0.0, std=0.02)
 
 
 
 
 
273
 
274
  def forward(self, x, sigma, cond, kg_raw=None):
275
  B, T = x.shape
276
- h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
 
 
 
 
 
 
277
  h = self.emb_dropout(h)
278
  mask = self.block_mask[:T, :T]
279
  for blk in self.blocks:
280
- h = blk(h, mask, kg_raw=kg_raw)
281
  h = self.final_norm(h)
282
  return self.head(h)
283
 
@@ -312,3 +357,52 @@ class CDFv13Transformer(nn.Module):
312
  ).reshape(B, T)
313
  n = corrupt.float().sum().clamp(min=1.0)
314
  return (ce * corrupt.float()).sum() / n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  use_swiglu: bool = True
51
  use_rmsnorm: bool = True
52
  tie_embeddings: bool = True
53
+ # SOTA upgrades (opt-in; default off keeps backward-compat with v13 checkpoints)
54
+ use_qk_norm: bool = False # RMSNorm on Q,K per head before RoPE (Gemma2/3-style)
55
+ use_adaln: bool = False # AdaLN-Zero (DiT/SD3) per-token sigma+cond conditioning
56
+ bidirectional: bool = False # full attention (pure masked diffusion); else block-causal
57
  # Diffusion forcing
58
  cond_dropout: float = 0.10
59
  # KG conditioning (GATED adapters)
 
191
  self.norm2 = norm_cls(cfg.d_model)
192
  self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
193
  self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
194
+ self.head_dim = cfg.d_model // cfg.n_heads
195
+ # QK-norm: per-head RMSNorm on Q,K before RoPE (stabilises attn logits)
196
+ if cfg.use_qk_norm:
197
+ self.q_norm = RMSNorm(self.head_dim)
198
+ self.k_norm = RMSNorm(self.head_dim)
199
+ # AdaLN-Zero: per-token modulation (shift/scale/gate) for MSA + MLP
200
+ if cfg.use_adaln:
201
+ self.adaln = nn.Sequential(nn.SiLU(), nn.Linear(cfg.d_model, 6 * cfg.d_model, bias=True))
202
  if cfg.use_swiglu:
203
  self.mlp = SwiGLU(cfg.d_model, cfg.ffn, cfg.dropout)
204
  else:
 
216
  self.kg_xattn = GatedKGCrossAttention(
217
  cfg.d_model, cfg.kg_dim, cfg.n_heads, cfg.dropout)
218
 
219
+ def forward(self, x, attn_mask, kg_raw=None, cond_vec=None):
220
  B, T, D = x.shape
221
+ # AdaLN-Zero modulation (per-token shift/scale/gate) from sigma+cond
222
+ if self.cfg.use_adaln and cond_vec is not None:
223
+ sh_msa, sc_msa, g_msa, sh_mlp, sc_mlp, g_mlp = self.adaln(cond_vec).chunk(6, dim=-1)
224
+ else:
225
+ sh_msa = sc_msa = g_msa = sh_mlp = sc_mlp = g_mlp = None
226
  # MSA
227
  h = self.norm1(x)
228
+ if sc_msa is not None:
229
+ h = h * (1 + sc_msa) + sh_msa
230
  qkv = self.qkv(h).reshape(B, T, 3, self.cfg.n_heads, self.head_dim)
231
  q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
232
+ if self.cfg.use_qk_norm:
233
+ q = self.q_norm(q); k = self.k_norm(k)
234
  q, k = self.rope(q, k, T)
235
  out = F.scaled_dot_product_attention(
236
  q, k, v,
 
238
  dropout_p=self.cfg.dropout if self.training else 0.0,
239
  )
240
  out = out.transpose(1, 2).reshape(B, T, D)
241
+ attn_out = self.dropout(self.proj(out))
242
+ x = x + (g_msa * attn_out if g_msa is not None else attn_out)
243
  # Gated KG cross-attn (if enabled at this layer)
244
  if self.use_kg_in_layer and kg_raw is not None:
245
  x = self.kg_xattn(x, kg_raw)
246
  # MLP
247
+ h2 = self.norm2(x)
248
+ if sc_mlp is not None:
249
+ h2 = h2 * (1 + sc_mlp) + sh_mlp
250
+ mlp_out = self.mlp(h2)
251
+ x = x + (g_mlp * mlp_out if g_mlp is not None else mlp_out)
252
  return x
253
 
254
 
 
284
  # Block-causal mask buffer
285
  T = c.max_seq_len
286
  block_id = torch.arange(T) // c.block_size
287
+ # Block-causal (Diffusion Forcing): a query may attend to its own block and
288
+ # all EARLIER blocks; future blocks are masked. mask[i,j]=True => BLOCKED.
289
+ # (Fixes a prior inverted mask that blocked the past instead of the future.)
290
+ # Set cfg.bidirectional=True for full bidirectional attention (pure masked
291
+ # diffusion / gap-fill), which disables the causal mask entirely.
292
+ if getattr(c, "bidirectional", False):
293
+ mask = torch.zeros(T, T, dtype=torch.bool)
294
+ else:
295
+ mask = block_id.unsqueeze(0) > block_id.unsqueeze(1)
296
  self.register_buffer("block_mask", mask, persistent=False)
297
 
298
  # Init
 
304
  if m.bias is not None: nn.init.zeros_(m.bias)
305
  elif isinstance(m, nn.Embedding):
306
  nn.init.normal_(m.weight, mean=0.0, std=0.02)
307
+ # AdaLN-Zero: zero the modulation output so each block starts as identity
308
+ if self.cfg.use_adaln:
309
+ for blk in self.blocks:
310
+ nn.init.zeros_(blk.adaln[-1].weight)
311
+ nn.init.zeros_(blk.adaln[-1].bias)
312
 
313
  def forward(self, x, sigma, cond, kg_raw=None):
314
  B, T = x.shape
315
+ cond_vec = None
316
+ if self.cfg.use_adaln:
317
+ # AdaLN path: conditioning enters via per-token modulation, not additive
318
+ cond_vec = self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
319
+ h = self.tok_emb(x)
320
+ else:
321
+ h = self.tok_emb(x) + self.sigma_emb(sigma) + self.cond_emb(cond).unsqueeze(1)
322
  h = self.emb_dropout(h)
323
  mask = self.block_mask[:T, :T]
324
  for blk in self.blocks:
325
+ h = blk(h, mask, kg_raw=kg_raw, cond_vec=cond_vec)
326
  h = self.final_norm(h)
327
  return self.head(h)
328
 
 
357
  ).reshape(B, T)
358
  n = corrupt.float().sum().clamp(min=1.0)
359
  return (ce * corrupt.float()).sum() / n
360
+
361
+ @staticmethod
362
+ def recurrence_weights(x_clean, struct_ids, lam: float = 0.25, w_min: float = 0.02):
363
+ """RAVEN recurrence-aware weights (Rajamohan et al., arXiv 2603.24562).
364
+
365
+ w[i,t] = max(lam ** count, w_min), where `count` is the number of prior
366
+ occurrences of token x[i,t] earlier in patient i's sequence. First
367
+ occurrences get full weight; repeats decay geometrically toward w_min.
368
+ Structural tokens get weight 0. Vectorized (no Python Counter loop).
369
+ Returns a (B, T) float tensor on x_clean.device.
370
+ """
371
+ B, T = x_clean.shape
372
+ device = x_clean.device
373
+ # prior-occurrence count per position via equality-with-earlier-positions
374
+ eq = (x_clean.unsqueeze(2) == x_clean.unsqueeze(1)) # (B,T,T): eq[b,t,s] = x[b,t]==x[b,s]
375
+ earlier = torch.tril(torch.ones(T, T, device=device), diagonal=-1).bool() # [t,s]=True if s<t
376
+ count = (eq & earlier.unsqueeze(0)).sum(dim=2).float() # (B,T): #earlier positions s<t with same token
377
+ w = torch.clamp(lam ** count, min=w_min)
378
+ if struct_ids:
379
+ sid = torch.tensor(sorted(struct_ids), device=device)
380
+ is_struct = (x_clean.unsqueeze(-1) == sid).any(-1)
381
+ w = w.masked_fill(is_struct, 0.0)
382
+ return w
383
+
384
+ def recurrence_aware_loss(self, x_clean, cond, struct_ids, kg_raw=None,
385
+ lam: float = 0.25, w_min: float = 0.02,
386
+ mode: str = "uniform") -> torch.Tensor:
387
+ """Diffusion-forcing loss reweighted by RAVEN recurrence decay — the
388
+ objective that makes GEMEO predict NOVEL events, not repeats. This is the
389
+ loss used to train the released `gemeo-sus` flagship."""
390
+ B, T = x_clean.shape
391
+ device = x_clean.device
392
+ drop = torch.rand(B, device=device) < self.cfg.cond_dropout
393
+ cond = torch.where(drop, torch.zeros_like(cond), cond)
394
+ if kg_raw is not None:
395
+ drop_kg = (torch.rand(B, device=device) < self.cfg.cond_dropout).float()
396
+ kg_raw = kg_raw * (1 - drop_kg).reshape(B, 1, 1)
397
+ if mode == "logit_normal":
398
+ sigma = torch.sigmoid(torch.randn(B, T, device=device)).clamp(0.01, 0.99)
399
+ else:
400
+ sigma = torch.rand(B, T, device=device).clamp(0.01, 0.99)
401
+ corrupt = torch.rand(B, T, device=device) < sigma
402
+ x_noisy = torch.where(corrupt, self.cfg.mask_token, x_clean)
403
+ logits = self.forward(x_noisy, sigma, cond, kg_raw=kg_raw)
404
+ ce = F.cross_entropy(
405
+ logits.reshape(-1, self.cfg.vocab_size), x_clean.reshape(-1),
406
+ reduction="none").reshape(B, T)
407
+ w = self.recurrence_weights(x_clean, struct_ids, lam, w_min) * corrupt.float()
408
+ return (ce * w).sum() / w.sum().clamp(min=1.0)