Backup-bdg committed on
Commit
f9dcb77
·
verified ·
1 Parent(s): 688711a

Update model weights after training (epoch 5, loss 6.9589)

Browse files
audio_decoder.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0644bb8cb74a2a1d0e055138e41ec52d65d83dca9bc9466cbdd8f388f1aa96b2
3
  size 1458410612
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c225077ec0e29909d0f390011f666158ae658fa3385cf8032280f5203da09cae
3
  size 1458410612
config.json CHANGED
@@ -49,10 +49,10 @@
49
  "image_size_step": 32,
50
  "video_min_size": 128,
51
  "video_max_size": 320,
52
- "video_base_size": 192,
53
  "video_size_step": 32,
54
  "video_min_frames": 8,
55
- "video_max_frames": 24,
56
  "video_base_frames": 16,
57
  "video_frame_step": 4,
58
  "multi_scale_strategy": "adaptive",
 
49
  "image_size_step": 32,
50
  "video_min_size": 128,
51
  "video_max_size": 320,
52
+ "video_base_size": 320,
53
  "video_size_step": 32,
54
  "video_min_frames": 8,
55
+ "video_max_frames": 8,
56
  "video_base_frames": 16,
57
  "video_frame_step": 4,
58
  "multi_scale_strategy": "adaptive",
configuration_xoron.py CHANGED
@@ -213,11 +213,17 @@ class XoronConfig(PreTrainedConfig):
213
  # Output path (used during training)
214
  output_dir: str = "./xoron-model",
215
 
 
 
 
216
  **kwargs,
217
  ):
218
  # Call parent init
219
  super().__init__(**kwargs)
220
 
 
 
 
221
  # Model identification
222
  self.model_name = model_name
223
 
 
213
  # Output path (used during training)
214
  output_dir: str = "./xoron-model",
215
 
216
+ # Training Configuration
217
+ modality_dropout_prob: float = 0.0,
218
+
219
  **kwargs,
220
  ):
221
  # Call parent init
222
  super().__init__(**kwargs)
223
 
224
+ # Training Configuration
225
+ self.modality_dropout_prob = modality_dropout_prob
226
+
227
  # Model identification
228
  self.model_name = model_name
229
 
cross_attention.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:10a70bf7bf4edce737146b199b106166957aa843440edfc45831f1d6033b7e11
3
  size 174191400
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5dc29d69984df0e49cf508c56c03b7a18a7a49baf89a414fa3128513d753e7e
3
  size 174191400
llm.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b78daf2a6be38a3c0753175dd705363f8a348dc24b7d7a6fb9539715c530f22e
3
- size 1506831304
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5de86313a868d4108f814a3debd9d1ed31dc72281458ef9c7824b9a4398ce28f
3
+ size 1506832040
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "metadata": {
3
- "total_size": 7309365038,
4
  "format": "components"
5
  },
6
  "weight_map": {
@@ -36,6 +36,7 @@
36
  "llm.model.layers.1.self_attn.o_proj.linear.weight": "llm.safetensors",
37
  "llm.model.layers.1.input_layernorm.weight": "llm.safetensors",
38
  "llm.model.layers.1.post_attention_layernorm.weight": "llm.safetensors",
 
39
  "llm.model.layers.1.mlp.router.input_norm.weight": "llm.safetensors",
40
  "llm.model.layers.1.mlp.router.gate.weight": "llm.safetensors",
41
  "llm.model.layers.1.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
@@ -150,6 +151,7 @@
150
  "llm.model.layers.3.self_attn.o_proj.linear.weight": "llm.safetensors",
151
  "llm.model.layers.3.input_layernorm.weight": "llm.safetensors",
152
  "llm.model.layers.3.post_attention_layernorm.weight": "llm.safetensors",
 
153
  "llm.model.layers.3.mlp.router.input_norm.weight": "llm.safetensors",
154
  "llm.model.layers.3.mlp.router.gate.weight": "llm.safetensors",
155
  "llm.model.layers.3.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
@@ -264,6 +266,7 @@
264
  "llm.model.layers.5.self_attn.o_proj.linear.weight": "llm.safetensors",
265
  "llm.model.layers.5.input_layernorm.weight": "llm.safetensors",
266
  "llm.model.layers.5.post_attention_layernorm.weight": "llm.safetensors",
 
267
  "llm.model.layers.5.mlp.router.input_norm.weight": "llm.safetensors",
268
  "llm.model.layers.5.mlp.router.gate.weight": "llm.safetensors",
269
  "llm.model.layers.5.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
@@ -378,6 +381,7 @@
378
  "llm.model.layers.7.self_attn.o_proj.linear.weight": "llm.safetensors",
379
  "llm.model.layers.7.input_layernorm.weight": "llm.safetensors",
380
  "llm.model.layers.7.post_attention_layernorm.weight": "llm.safetensors",
 
381
  "llm.model.layers.7.mlp.router.input_norm.weight": "llm.safetensors",
382
  "llm.model.layers.7.mlp.router.gate.weight": "llm.safetensors",
383
  "llm.model.layers.7.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
@@ -492,6 +496,7 @@
492
  "llm.model.layers.9.self_attn.o_proj.linear.weight": "llm.safetensors",
493
  "llm.model.layers.9.input_layernorm.weight": "llm.safetensors",
494
  "llm.model.layers.9.post_attention_layernorm.weight": "llm.safetensors",
 
495
  "llm.model.layers.9.mlp.router.input_norm.weight": "llm.safetensors",
496
  "llm.model.layers.9.mlp.router.gate.weight": "llm.safetensors",
497
  "llm.model.layers.9.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
@@ -606,6 +611,7 @@
606
  "llm.model.layers.11.self_attn.o_proj.linear.weight": "llm.safetensors",
607
  "llm.model.layers.11.input_layernorm.weight": "llm.safetensors",
608
  "llm.model.layers.11.post_attention_layernorm.weight": "llm.safetensors",
 
609
  "llm.model.layers.11.mlp.router.input_norm.weight": "llm.safetensors",
610
  "llm.model.layers.11.mlp.router.gate.weight": "llm.safetensors",
611
  "llm.model.layers.11.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
1
  {
2
  "metadata": {
3
+ "total_size": 7309365134,
4
  "format": "components"
5
  },
6
  "weight_map": {
 
36
  "llm.model.layers.1.self_attn.o_proj.linear.weight": "llm.safetensors",
37
  "llm.model.layers.1.input_layernorm.weight": "llm.safetensors",
38
  "llm.model.layers.1.post_attention_layernorm.weight": "llm.safetensors",
39
+ "llm.model.layers.1.mlp.router.expert_bias": "llm.safetensors",
40
  "llm.model.layers.1.mlp.router.input_norm.weight": "llm.safetensors",
41
  "llm.model.layers.1.mlp.router.gate.weight": "llm.safetensors",
42
  "llm.model.layers.1.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
151
  "llm.model.layers.3.self_attn.o_proj.linear.weight": "llm.safetensors",
152
  "llm.model.layers.3.input_layernorm.weight": "llm.safetensors",
153
  "llm.model.layers.3.post_attention_layernorm.weight": "llm.safetensors",
154
+ "llm.model.layers.3.mlp.router.expert_bias": "llm.safetensors",
155
  "llm.model.layers.3.mlp.router.input_norm.weight": "llm.safetensors",
156
  "llm.model.layers.3.mlp.router.gate.weight": "llm.safetensors",
157
  "llm.model.layers.3.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
266
  "llm.model.layers.5.self_attn.o_proj.linear.weight": "llm.safetensors",
267
  "llm.model.layers.5.input_layernorm.weight": "llm.safetensors",
268
  "llm.model.layers.5.post_attention_layernorm.weight": "llm.safetensors",
269
+ "llm.model.layers.5.mlp.router.expert_bias": "llm.safetensors",
270
  "llm.model.layers.5.mlp.router.input_norm.weight": "llm.safetensors",
271
  "llm.model.layers.5.mlp.router.gate.weight": "llm.safetensors",
272
  "llm.model.layers.5.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
381
  "llm.model.layers.7.self_attn.o_proj.linear.weight": "llm.safetensors",
382
  "llm.model.layers.7.input_layernorm.weight": "llm.safetensors",
383
  "llm.model.layers.7.post_attention_layernorm.weight": "llm.safetensors",
384
+ "llm.model.layers.7.mlp.router.expert_bias": "llm.safetensors",
385
  "llm.model.layers.7.mlp.router.input_norm.weight": "llm.safetensors",
386
  "llm.model.layers.7.mlp.router.gate.weight": "llm.safetensors",
387
  "llm.model.layers.7.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
496
  "llm.model.layers.9.self_attn.o_proj.linear.weight": "llm.safetensors",
497
  "llm.model.layers.9.input_layernorm.weight": "llm.safetensors",
498
  "llm.model.layers.9.post_attention_layernorm.weight": "llm.safetensors",
499
+ "llm.model.layers.9.mlp.router.expert_bias": "llm.safetensors",
500
  "llm.model.layers.9.mlp.router.input_norm.weight": "llm.safetensors",
501
  "llm.model.layers.9.mlp.router.gate.weight": "llm.safetensors",
502
  "llm.model.layers.9.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
 
611
  "llm.model.layers.11.self_attn.o_proj.linear.weight": "llm.safetensors",
612
  "llm.model.layers.11.input_layernorm.weight": "llm.safetensors",
613
  "llm.model.layers.11.post_attention_layernorm.weight": "llm.safetensors",
614
+ "llm.model.layers.11.mlp.router.expert_bias": "llm.safetensors",
615
  "llm.model.layers.11.mlp.router.input_norm.weight": "llm.safetensors",
616
  "llm.model.layers.11.mlp.router.gate.weight": "llm.safetensors",
617
  "llm.model.layers.11.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
modeling_xoron.py CHANGED
@@ -436,18 +436,25 @@ def compute_qk_scale(head_dim: int) -> float:
436
  return head_dim ** -0.25
437
 
438
 
439
- @dataclass
440
  class AttentionKVCache:
441
- """
442
- KV Cache for efficient autoregressive attention.
443
 
444
- Features:
445
- - Memory-efficient storage
446
- - Support for cross-attention caching
447
  """
448
- key_cache: torch.Tensor = None
449
- value_cache: torch.Tensor = None
450
- seen_tokens: int = 0
 
 
 
 
 
 
 
 
 
 
451
 
452
  def update(
453
  self,
@@ -455,36 +462,45 @@ class AttentionKVCache:
455
  value_states: torch.Tensor,
456
  ) -> Tuple[torch.Tensor, torch.Tensor]:
457
  """
458
- Update cache with new key/value states.
459
 
460
  Args:
461
  key_states: New key states [batch, num_heads, seq_len, head_dim]
462
  value_states: New value states [batch, num_heads, seq_len, head_dim]
463
 
464
  Returns:
465
- Updated key and value states including cache
466
  """
467
- if self.key_cache is None:
468
- self.key_cache = key_states
469
- self.value_cache = value_states
470
- else:
471
- self.key_cache = torch.cat([self.key_cache, key_states], dim=2)
472
- self.value_cache = torch.cat([self.value_cache, value_states], dim=2)
473
-
474
- self.seen_tokens += key_states.shape[2]
475
 
476
- return self.key_cache, self.value_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
  def get_seq_length(self) -> int:
479
  """Get current sequence length in cache."""
480
- if self.key_cache is None:
481
- return 0
482
- return self.key_cache.shape[2]
483
 
484
  def reset(self):
485
- """Reset the cache."""
486
- self.key_cache = None
487
- self.value_cache = None
488
  self.seen_tokens = 0
489
 
490
 
@@ -1447,12 +1463,12 @@ class PerceiverAttention(nn.Module):
1447
  sin = sin.unsqueeze(0).unsqueeze(0)
1448
  k = apply_rope(k, cos, sin)
1449
 
1450
- # Attention
1451
- attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
1452
- attn = torch.clamp(attn, min=-11.0, max=11.0)
1453
- attn = attn.softmax(dim=-1)
1454
-
1455
- out = torch.matmul(attn, v)
1456
  out = out.transpose(1, 2).reshape(b, n, self.inner_dim)
1457
 
1458
  return self.to_out(out)
@@ -1848,6 +1864,171 @@ class MultimodalProjector(nn.Module):
1848
  EPS = 1e-5
1849
 
1850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1851
  class MoERouter(nn.Module):
1852
  """
1853
  SOTA Router for Mixture of Experts v2.0 - FP16 native.
@@ -2036,6 +2217,10 @@ class MoELayer(nn.Module):
2036
 
2037
  top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
2038
 
 
 
 
 
2039
  final_output = torch.zeros_like(hidden_flat)
2040
 
2041
  for expert_idx in range(self.num_experts):
@@ -4726,21 +4911,23 @@ class RotaryMultiHeadLatentAttention(nn.Module):
4726
 
4727
  present_key_value = (key, value) if use_cache else None
4728
 
4729
- # Expand KV for grouped query attention
4730
- if self.num_key_value_groups > 1:
4731
- key = key.repeat_interleave(self.num_key_value_groups, dim=1)
4732
- value = value.repeat_interleave(self.num_key_value_groups, dim=1)
4733
-
4734
- # Attention computation
4735
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * self.scale
4736
-
4737
- if attention_mask is not None:
4738
- attn_weights = attn_weights + attention_mask
4739
-
4740
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
4741
- attn_weights = self.dropout(attn_weights)
4742
-
4743
- output = torch.matmul(attn_weights, value)
 
 
4744
  output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
4745
  output = self.o_proj(output)
4746
 
@@ -6719,6 +6906,7 @@ class DualStreamSelfAttention(nn.Module):
6719
  """
6720
  Symmetric Dual-Stream Self-Attention (SD3/Flux-style).
6721
  Two parallel streams with cross-stream information exchange.
 
6722
  """
6723
 
6724
  def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
@@ -6727,6 +6915,8 @@ class DualStreamSelfAttention(nn.Module):
6727
  self.num_heads = num_heads
6728
  self.head_dim = hidden_size // num_heads
6729
  self.scale = self.head_dim ** -0.5
 
 
6730
 
6731
  self.to_qkv_a = nn.Linear(hidden_size, hidden_size * 3, bias=False)
6732
  self.to_qkv_b = nn.Linear(hidden_size, hidden_size * 3, bias=False)
@@ -6752,8 +6942,6 @@ class DualStreamSelfAttention(nn.Module):
6752
  q_b, k_b, v_b = qkv_b.unbind(dim=2)
6753
 
6754
  cos, sin = self.rope_2d(x_a, height, width)
6755
- # cos/sin shape: [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
6756
- # to broadcast with q/k shape: [B, num_heads, seq_len, head_dim]
6757
  cos = cos.unsqueeze(0).unsqueeze(1)
6758
  sin = sin.unsqueeze(0).unsqueeze(1)
6759
 
@@ -6772,13 +6960,15 @@ class DualStreamSelfAttention(nn.Module):
6772
  k_combined = torch.cat([k_a, k_b], dim=2)
6773
  v_combined = torch.cat([v_a, v_b], dim=2)
6774
 
6775
- attn_a = torch.matmul(q_a, k_combined.transpose(-1, -2)) * self.scale
6776
- attn_a = F.softmax(attn_a, dim=-1, dtype=torch.float32).to(x_a.dtype)
6777
- out_a = torch.matmul(attn_a, v_combined)
6778
-
6779
- attn_b = torch.matmul(q_b, k_combined.transpose(-1, -2)) * self.scale
6780
- attn_b = F.softmax(attn_b, dim=-1, dtype=torch.float32).to(x_b.dtype)
6781
- out_b = torch.matmul(attn_b, v_combined)
 
 
6782
 
6783
  out_a = out_a.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
6784
  out_b = out_b.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
@@ -6815,10 +7005,12 @@ class CrossAttention(nn.Module):
6815
  k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
6816
  v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
6817
 
6818
- attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
6819
- attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype)
6820
-
6821
- out = torch.matmul(attn, v)
 
 
6822
  out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
6823
  out = self.to_out(out)
6824
 
@@ -6984,12 +7176,18 @@ class MoEDiT(nn.Module):
6984
  self.final_norm = nn.LayerNorm(hidden_size)
6985
  self.unpatch_embed = UnpatchEmbed(patch_size, out_channels, hidden_size)
6986
 
 
 
6987
  self._init_weights()
6988
 
6989
  def _init_weights(self):
6990
  nn.init.zeros_(self.unpatch_embed.proj.weight)
6991
  nn.init.zeros_(self.unpatch_embed.proj.bias)
6992
 
 
 
 
 
6993
  def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
6994
  batch_size, channels, height, width = x.shape
6995
  patch_height = height // self.patch_size
@@ -7010,7 +7208,13 @@ class MoEDiT(nn.Module):
7010
  x_b = x_patches.clone()
7011
 
7012
  for block in self.blocks:
7013
- x_a, x_b = block(x_a, x_b, context_proj, t_emb, patch_height, patch_width)
 
 
 
 
 
 
7014
 
7015
  x_combined = (x_a + x_b) / 2
7016
  x_combined = self.final_norm(x_combined)
@@ -7518,11 +7722,12 @@ class SpatialAttention(nn.Module):
7518
  q = apply_rope(q, cos, sin)
7519
  k = apply_rope(k, cos, sin)
7520
 
7521
- # Attention only within each frame: [B*T, heads, H*W, H*W]
7522
- attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
7523
- attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype)
7524
-
7525
- out = torch.matmul(attn, v)
 
7526
  out = out.transpose(1, 2).reshape(batch_size * frames, spatial_len, self.hidden_size)
7527
  out = self.to_out(out)
7528
 
@@ -7573,16 +7778,12 @@ class TemporalAttention(nn.Module):
7573
  q = apply_rope(q, cos, sin)
7574
  k = apply_rope(k, cos, sin)
7575
 
7576
- # Attention across time for each position: [B*H*W, heads, T, T]
7577
- attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
7578
-
7579
- if causal:
7580
- causal_mask = torch.triu(torch.ones(frames, frames, device=x.device, dtype=torch.bool), diagonal=1)
7581
- attn = attn.masked_fill(causal_mask, float('-inf'))
7582
-
7583
- attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype)
7584
-
7585
- out = torch.matmul(attn, v)
7586
  out = out.transpose(1, 2).reshape(batch_size * spatial_len, frames, self.hidden_size)
7587
 
7588
  # Reshape back to [B, T*H*W, hidden]
@@ -7647,10 +7848,12 @@ class CrossAttention3D(nn.Module):
7647
  k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7648
  v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7649
 
7650
- attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
7651
- attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype)
7652
-
7653
- out = torch.matmul(attn, v)
 
 
7654
  out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
7655
  out = self.to_out(out)
7656
 
@@ -7768,6 +7971,12 @@ class VideoUNet3D(nn.Module):
7768
 
7769
  nn.init.zeros_(self.output_proj[-1].weight)
7770
  nn.init.zeros_(self.output_proj[-1].bias)
 
 
 
 
 
 
7771
 
7772
  def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, first_frame_latent: Optional[torch.Tensor] = None) -> torch.Tensor:
7773
  batch_size, channels, frames, height, width = x.shape
@@ -7786,7 +7995,13 @@ class VideoUNet3D(nn.Module):
7786
  temporal_context = t_emb.unsqueeze(1).expand(-1, frames * height * width, -1)
7787
 
7788
  for block in self.transformer_blocks:
7789
- h = block(h, context, height, width, frames, temporal_context)
 
 
 
 
 
 
7790
 
7791
  h = h.reshape(batch_size, frames, height, width, self.hidden_size).permute(0, 4, 1, 2, 3)
7792
 
@@ -8260,12 +8475,31 @@ def apply_rotary_pos_emb(
8260
  return q_embed, k_embed
8261
 
8262
 
8263
- @dataclass
8264
  class KVCache:
8265
- """KV Cache for efficient autoregressive generation."""
8266
- key_cache: torch.Tensor
8267
- value_cache: torch.Tensor
8268
- seen_tokens: int = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8269
 
8270
  def update(
8271
  self,
@@ -8273,20 +8507,43 @@ class KVCache:
8273
  value_states: torch.Tensor,
8274
  chunk_size: Optional[int] = None,
8275
  ) -> Tuple[torch.Tensor, torch.Tensor]:
8276
- if self.key_cache is None:
8277
- self.key_cache = key_states
8278
- self.value_cache = value_states
8279
- else:
8280
- self.key_cache = torch.cat([self.key_cache, key_states], dim=2)
8281
- self.value_cache = torch.cat([self.value_cache, value_states], dim=2)
8282
 
8283
- self.seen_tokens = self.key_cache.shape[2]
8284
-
8285
- if chunk_size is not None and self.key_cache.shape[2] > chunk_size * 2:
8286
- self.key_cache = self.key_cache[:, :, -chunk_size * 2:]
8287
- self.value_cache = self.value_cache[:, :, -chunk_size * 2:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8288
 
8289
- return self.key_cache, self.value_cache
 
 
8290
 
8291
 
8292
  def ring_attention(
@@ -8298,7 +8555,7 @@ def ring_attention(
8298
  ) -> torch.Tensor:
8299
  """
8300
  Ring Attention for distributed long-context processing.
8301
- Processes sequence in chunks with proper attention accumulation.
8302
 
8303
  Args:
8304
  query: [batch, heads, seq_len, head_dim]
@@ -8312,24 +8569,45 @@ def ring_attention(
8312
  """
8313
  batch_size, num_heads, seq_len, head_dim = query.shape
8314
  kv_len = key.shape[2]
8315
- scale = head_dim ** -0.5
8316
 
 
8317
  if seq_len <= chunk_size and kv_len <= chunk_size:
8318
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
8319
-
8320
- if causal:
8321
- causal_mask = torch.triu(torch.ones(seq_len, kv_len, device=query.device, dtype=torch.bool), diagonal=1)
8322
- if kv_len > seq_len:
8323
- causal_mask = causal_mask[:, -seq_len:]
8324
- attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
8325
 
8326
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
8327
- return torch.matmul(attn_weights, value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8328
 
 
 
8329
  output = torch.zeros_like(query)
8330
  max_logits = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=query.device, dtype=query.dtype)
8331
  sum_exp = torch.zeros((batch_size, num_heads, seq_len, 1), device=query.device, dtype=query.dtype)
8332
 
 
 
 
 
 
 
8333
  num_kv_chunks = (kv_len + chunk_size - 1) // chunk_size
8334
 
8335
  for kv_idx in range(num_kv_chunks):
@@ -8342,13 +8620,11 @@ def ring_attention(
8342
  attn_chunk = torch.matmul(query, key_chunk.transpose(-1, -2)) * scale
8343
 
8344
  if causal:
8345
- chunk_len = kv_end - kv_start
8346
- for q_idx in range(seq_len):
8347
- q_pos = q_idx + (kv_len - seq_len) if kv_len > seq_len else q_idx
8348
- for k_idx in range(chunk_len):
8349
- k_pos = kv_start + k_idx
8350
- if k_pos > q_pos:
8351
- attn_chunk[:, :, q_idx, k_idx] = float('-inf')
8352
 
8353
  chunk_max = attn_chunk.max(dim=-1, keepdim=True)[0]
8354
  new_max = torch.maximum(max_logits, chunk_max)
@@ -8465,31 +8741,35 @@ class MultiHeadLatentAttention(nn.Module):
8465
  self.ring_chunk_size if self.use_ring_attention else None
8466
  )
8467
 
8468
- if self.num_key_value_groups > 1:
8469
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
8470
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
8471
-
8472
  if self.use_ring_attention:
 
 
 
 
 
 
 
8473
  attn_output = ring_attention(
8474
- query_states, key_states, value_states,
8475
  chunk_size=self.ring_chunk_size,
8476
  causal=True,
8477
  )
8478
  else:
8479
- attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scale
8480
-
 
8481
  kv_len = key_states.shape[2]
8482
- causal_mask = torch.triu(
8483
- torch.ones(seq_len, kv_len, device=hidden_states.device, dtype=torch.bool),
8484
- diagonal=kv_len - seq_len + 1
 
 
 
 
 
 
 
8485
  )
8486
- attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
8487
-
8488
- if attention_mask is not None:
8489
- attn_weights = attn_weights + attention_mask
8490
-
8491
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
8492
- attn_output = torch.matmul(attn_weights, value_states)
8493
 
8494
  attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
8495
  attn_output = self.o_proj(attn_output)
@@ -8518,6 +8798,11 @@ class AuxLosslessMoERouter(nn.Module):
8518
  self.input_norm = LlamaRMSNorm(hidden_size)
8519
  self.gate = nn.Linear(hidden_size, num_experts, bias=False)
8520
  nn.init.normal_(self.gate.weight, mean=0.0, std=0.01)
 
 
 
 
 
8521
 
8522
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8523
  batch_size, seq_len, hidden_dim = hidden_states.shape
@@ -8526,7 +8811,10 @@ class AuxLosslessMoERouter(nn.Module):
8526
  hidden_norm = self.input_norm(hidden_flat)
8527
  router_logits = self.gate(hidden_norm)
8528
 
8529
- router_probs = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
 
 
 
8530
 
8531
  top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
8532
 
@@ -8614,8 +8902,13 @@ class AuxLosslessMoELayer(nn.Module):
8614
  batch_size, seq_len, hidden_size = hidden_states.shape
8615
  original_dtype = hidden_states.dtype
8616
  hidden_flat = hidden_states.view(-1, hidden_size)
 
8617
 
8618
- top_k_probs, top_k_indices, _ = self.router(hidden_states)
 
 
 
 
8619
 
8620
  final_output = torch.zeros_like(hidden_flat)
8621
 
@@ -8635,10 +8928,37 @@ class AuxLosslessMoELayer(nn.Module):
8635
 
8636
  final_output = final_output.view(batch_size, seq_len, hidden_size)
8637
 
8638
- aux_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
 
 
8639
 
8640
  return final_output, aux_loss
8641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8642
 
8643
  MoELayer = AuxLosslessMoELayer
8644
 
 
436
  return head_dim ** -0.25
437
 
438
 
 
439
  class AttentionKVCache:
440
+ """Pre-allocated KV Cache — static buffer with index-based filling.
 
441
 
442
+ Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
443
+ Buffer is allocated once at first use and reused via slice assignment.
 
444
  """
445
+
446
+ __slots__ = ('key_cache', 'value_cache', 'seen_tokens', '_max_len')
447
+
448
+ def __init__(self, max_seq_len: int = 131072):
449
+ self.key_cache: torch.Tensor = None
450
+ self.value_cache: torch.Tensor = None
451
+ self.seen_tokens: int = 0
452
+ self._max_len = max_seq_len
453
+
454
+ def _allocate(self, batch: int, heads: int, head_dim: int, device: torch.device, dtype: torch.dtype):
455
+ """Allocate static buffer on first use."""
456
+ self.key_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
457
+ self.value_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
458
 
459
  def update(
460
  self,
 
462
  value_states: torch.Tensor,
463
  ) -> Tuple[torch.Tensor, torch.Tensor]:
464
  """
465
+ Update cache with new key/value states using index-based filling.
466
 
467
  Args:
468
  key_states: New key states [batch, num_heads, seq_len, head_dim]
469
  value_states: New value states [batch, num_heads, seq_len, head_dim]
470
 
471
  Returns:
472
+ Updated key and value states including cache (views, no copy)
473
  """
474
+ batch, heads, new_len, head_dim = key_states.shape
 
 
 
 
 
 
 
475
 
476
+ if self.key_cache is None:
477
+ self._allocate(batch, heads, head_dim, key_states.device, key_states.dtype)
478
+ self.seen_tokens = 0
479
+
480
+ # Grow buffer if needed (rare fallback)
481
+ if self.seen_tokens + new_len > self.key_cache.shape[2]:
482
+ new_max = max(self.key_cache.shape[2] * 2, self.seen_tokens + new_len)
483
+ new_key = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
484
+ new_val = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
485
+ new_key[:, :, :self.seen_tokens] = self.key_cache[:, :, :self.seen_tokens]
486
+ new_val[:, :, :self.seen_tokens] = self.value_cache[:, :, :self.seen_tokens]
487
+ self.key_cache = new_key
488
+ self.value_cache = new_val
489
+
490
+ # Index-based fill — no allocation, no fragmentation
491
+ self.key_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = key_states
492
+ self.value_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = value_states
493
+ self.seen_tokens += new_len
494
+
495
+ # Return valid slice (view, no copy)
496
+ return self.key_cache[:, :, :self.seen_tokens], self.value_cache[:, :, :self.seen_tokens]
497
 
498
  def get_seq_length(self) -> int:
499
  """Get current sequence length in cache."""
500
+ return self.seen_tokens
 
 
501
 
502
  def reset(self):
503
+ """Reset cache position without deallocating the buffer."""
 
 
504
  self.seen_tokens = 0
505
 
506
 
 
1463
  sin = sin.unsqueeze(0).unsqueeze(0)
1464
  k = apply_rope(k, cos, sin)
1465
 
1466
+ # Flash Attention 2.0 via SDPA — FP16-safe with Q/K pre-scaling
1467
+ qk_scale = d ** -0.25
1468
+ out = F.scaled_dot_product_attention(
1469
+ q * qk_scale, k * qk_scale, v,
1470
+ is_causal=False, scale=1.0,
1471
+ )
1472
  out = out.transpose(1, 2).reshape(b, n, self.inner_dim)
1473
 
1474
  return self.to_out(out)
 
1864
  EPS = 1e-5
1865
 
1866
 
1867
+ class ExpertUtilizationTracker:
1868
+ """
1869
+ Tracks expert utilization across MoE layers.
1870
+
1871
+ Attach to any MoE layer to log per-expert usage histograms.
1872
+ Every `report_interval` steps, prints a report showing:
1873
+ - Frequency of use per expert
1874
+ - Cold experts (used < 1% of tokens)
1875
+ - Count of experts offloaded to CPU (if ExpertOffloadManager is available)
1876
+
1877
+ Usage:
1878
+ tracker = ExpertUtilizationTracker(num_experts=8, layer_name="layer.3.moe")
1879
+ # In forward: tracker.record(top_k_indices)
1880
+ # Every N steps: tracker.step() (auto-prints when interval hit)
1881
+ """
1882
+
1883
+ def __init__(
1884
+ self,
1885
+ num_experts: int,
1886
+ layer_name: str = "moe",
1887
+ report_interval: int = 100,
1888
+ cold_threshold_pct: float = 1.0,
1889
+ ):
1890
+ self.num_experts = num_experts
1891
+ self.layer_name = layer_name
1892
+ self.report_interval = report_interval
1893
+ self.cold_threshold_pct = cold_threshold_pct
1894
+
1895
+ self._counts = torch.zeros(num_experts, dtype=torch.long)
1896
+ self._total_tokens = 0
1897
+ self._step = 0
1898
+ self._offload_manager = None # Link to ExpertOffloadManager if available
1899
+
1900
+ def link_offload_manager(self, manager):
1901
+ """Link an ExpertOffloadManager for cold-expert reporting."""
1902
+ self._offload_manager = manager
1903
+
1904
+ def record(self, expert_indices: torch.Tensor):
1905
+ """
1906
+ Record expert selections from a forward pass.
1907
+
1908
+ Args:
1909
+ expert_indices: [num_tokens, top_k] tensor of selected expert indices
1910
+ """
1911
+ indices_flat = expert_indices.detach().cpu().reshape(-1)
1912
+ for idx in range(self.num_experts):
1913
+ self._counts[idx] += (indices_flat == idx).sum().item()
1914
+ self._total_tokens += expert_indices.shape[0]
1915
+
1916
+ def step(self):
1917
+ """Advance step counter. Prints report and resets when interval is hit."""
1918
+ self._step += 1
1919
+ if self._step % self.report_interval == 0:
1920
+ self._print_report()
1921
+ self._reset()
1922
+
1923
+ def _reset(self):
1924
+ """Reset accumulators for next interval."""
1925
+ self._counts.zero_()
1926
+ self._total_tokens = 0
1927
+
1928
+ def _print_report(self):
1929
+ """Print expert utilization histogram."""
1930
+ if self._total_tokens == 0:
1931
+ return
1932
+
1933
+ freqs = self._counts.float()
1934
+ total_assignments = freqs.sum().item()
1935
+ if total_assignments == 0:
1936
+ return
1937
+
1938
+ pcts = (freqs / total_assignments * 100).tolist()
1939
+
1940
+ # Identify cold experts
1941
+ cold_experts = [i for i, p in enumerate(pcts) if p < self.cold_threshold_pct]
1942
+
1943
+ # Build histogram
1944
+ max_pct = max(pcts) if pcts else 0
1945
+ bar_max = 30 # max bar width
1946
+
1947
+ lines = [f"\n{'='*60}"]
1948
+ lines.append(f" Expert Utilization — {self.layer_name} (step {self._step})")
1949
+ lines.append(f" {self._total_tokens:,} tokens, {int(total_assignments):,} assignments")
1950
+ lines.append(f"{'─'*60}")
1951
+
1952
+ for i, pct in enumerate(pcts):
1953
+ bar_len = int(pct / max_pct * bar_max) if max_pct > 0 else 0
1954
+ bar = "█" * bar_len
1955
+ cold_tag = " ❄️" if pct < self.cold_threshold_pct else ""
1956
+ lines.append(f" Expert {i:2d} │{bar:<{bar_max}}│ {pct:5.1f}% ({int(self._counts[i]):>6d}){cold_tag}")
1957
+
1958
+ lines.append(f"{'─'*60}")
1959
+
1960
+ if cold_experts:
1961
+ lines.append(f" ❄️ Cold experts (<{self.cold_threshold_pct}%): {cold_experts}")
1962
+ else:
1963
+ lines.append(f" ✅ All experts active (no cold experts)")
1964
+
1965
+ # Report offloaded experts if manager linked
1966
+ if self._offload_manager is not None:
1967
+ status = self._offload_manager.get_status()
1968
+ lines.append(f" 💾 Offloaded to CPU: {status['cpu']}/{status['total']}")
1969
+
1970
+ # Compute load balance score (1.0 = perfectly balanced)
1971
+ ideal_pct = 100.0 / self.num_experts
1972
+ balance = 1.0 - (sum(abs(p - ideal_pct) for p in pcts) / (2 * 100))
1973
+ lines.append(f" ⚖️ Load balance score: {balance:.3f} (1.0 = perfect)")
1974
+
1975
+ lines.append(f"{'='*60}")
1976
+ print("\n".join(lines))
1977
+
1978
def get_stats(self) -> dict:
    """Snapshot the current utilization numbers as a plain dict.

    Returns:
        A dict with the step counter, layer name, token total, raw
        per-expert counts, per-expert percentage of assignments, indices
        of experts below ``self.cold_threshold_pct`` and a balance score.
        When nothing has been counted, every percentage is 0.0 (so every
        expert is listed as cold when the threshold is positive) and the
        balance score is 0.0.
    """
    n_experts = self.num_experts
    assigned = self._counts.sum().item()

    if assigned == 0:
        shares = [0.0] * n_experts
    else:
        shares = (self._counts.float() / assigned * 100).tolist()

    below_threshold = [
        idx for idx, share in enumerate(shares)
        if share < self.cold_threshold_pct
    ]

    uniform_share = 100.0 / n_experts
    if assigned > 0:
        # 1 minus half the summed absolute deviation from a uniform split.
        score = 1.0 - (sum(abs(s - uniform_share) for s in shares) / (2 * 100))
    else:
        score = 0.0

    return {
        "step": self._step,
        "layer_name": self.layer_name,
        "total_tokens": self._total_tokens,
        "expert_counts": self._counts.tolist(),
        "expert_pcts": shares,
        "cold_experts": below_threshold,
        "balance_score": score,
    }
1999
+
2000
+
2001
def attach_utilization_trackers(
    model: torch.nn.Module,
    report_interval: int = 100,
) -> list:
    """
    Attach an ExpertUtilizationTracker to every MoE-like submodule of *model*.

    A submodule is treated as an MoE layer when it exposes both an
    ``experts`` and a ``router`` attribute. Each match gets a tracker stored
    on it as ``_utilization_tracker``; if the submodule also carries an
    ``_expert_offload_manager`` it is linked to the tracker.

    Returns the trackers in discovery order so the training loop can call
    ``step()`` on them.
    """
    attached = []

    for module_name, submodule in model.named_modules():
        looks_like_moe = hasattr(submodule, 'experts') and hasattr(submodule, 'router')
        if not looks_like_moe:
            continue

        tracker = ExpertUtilizationTracker(
            num_experts=len(submodule.experts),
            layer_name=module_name,
            report_interval=report_interval,
        )

        # Let the tracker report offload status when a manager exists.
        if hasattr(submodule, '_expert_offload_manager'):
            tracker.link_offload_manager(submodule._expert_offload_manager)

        submodule._utilization_tracker = tracker
        attached.append(tracker)

    if attached:
        print(f" 📊 Attached {len(attached)} expert utilization trackers (report every {report_interval} steps)")

    return attached
2030
+
2031
+
2032
  class MoERouter(nn.Module):
2033
  """
2034
  SOTA Router for Mixture of Experts v2.0 - FP16 native.
 
2217
 
2218
  top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
2219
 
2220
+ # Record expert utilization if tracker is attached
2221
+ if hasattr(self, '_utilization_tracker'):
2222
+ self._utilization_tracker.record(top_k_indices)
2223
+
2224
  final_output = torch.zeros_like(hidden_flat)
2225
 
2226
  for expert_idx in range(self.num_experts):
 
4911
 
4912
  present_key_value = (key, value) if use_cache else None
4913
 
4914
+ # True GQA via SDPA — no repeat_interleave, O(N) memory, FP16-safe
4915
+ qk_scale = self.head_dim ** -0.25
4916
+ kv_len = key.shape[2]
4917
+ use_causal = (attention_mask is None and seq_len > 1 and seq_len == kv_len)
4918
+
4919
+ dropout_p = self.dropout.p if self.training else 0.0
4920
+
4921
+ output = F.scaled_dot_product_attention(
4922
+ query * qk_scale,
4923
+ key * qk_scale,
4924
+ value,
4925
+ attn_mask=attention_mask,
4926
+ is_causal=use_causal,
4927
+ dropout_p=dropout_p,
4928
+ scale=1.0,
4929
+ enable_gqa=(self.num_key_value_groups > 1),
4930
+ )
4931
  output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
4932
  output = self.o_proj(output)
4933
 
 
6906
  """
6907
  Symmetric Dual-Stream Self-Attention (SD3/Flux-style).
6908
  Two parallel streams with cross-stream information exchange.
6909
+ Uses Flash Attention 2.0 via SDPA for O(N) memory.
6910
  """
6911
 
6912
  def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
 
6915
  self.num_heads = num_heads
6916
  self.head_dim = hidden_size // num_heads
6917
  self.scale = self.head_dim ** -0.5
6918
+ # Pre-compute Q/K scaling for FP16 stability
6919
+ self._qk_scale = self.head_dim ** -0.25
6920
 
6921
  self.to_qkv_a = nn.Linear(hidden_size, hidden_size * 3, bias=False)
6922
  self.to_qkv_b = nn.Linear(hidden_size, hidden_size * 3, bias=False)
 
6942
  q_b, k_b, v_b = qkv_b.unbind(dim=2)
6943
 
6944
  cos, sin = self.rope_2d(x_a, height, width)
 
 
6945
  cos = cos.unsqueeze(0).unsqueeze(1)
6946
  sin = sin.unsqueeze(0).unsqueeze(1)
6947
 
 
6960
  k_combined = torch.cat([k_a, k_b], dim=2)
6961
  v_combined = torch.cat([v_a, v_b], dim=2)
6962
 
6963
+ # Flash Attention 2.0 via SDPA — O(N) memory, FP16-safe with pre-scaling
6964
+ out_a = F.scaled_dot_product_attention(
6965
+ q_a * self._qk_scale, k_combined * self._qk_scale, v_combined,
6966
+ is_causal=False, scale=1.0,
6967
+ )
6968
+ out_b = F.scaled_dot_product_attention(
6969
+ q_b * self._qk_scale, k_combined * self._qk_scale, v_combined,
6970
+ is_causal=False, scale=1.0,
6971
+ )
6972
 
6973
  out_a = out_a.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
6974
  out_b = out_b.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
 
7005
  k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7006
  v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7007
 
7008
+ # Flash Attention 2.0 via SDPA — O(N) memory, non-causal cross-attention
7009
+ qk_scale = self.head_dim ** -0.25
7010
+ out = F.scaled_dot_product_attention(
7011
+ q * qk_scale, k * qk_scale, v,
7012
+ is_causal=False, scale=1.0,
7013
+ )
7014
  out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
7015
  out = self.to_out(out)
7016
 
 
7176
  self.final_norm = nn.LayerNorm(hidden_size)
7177
  self.unpatch_embed = UnpatchEmbed(patch_size, out_channels, hidden_size)
7178
 
7179
+ self.gradient_checkpointing = False
7180
+
7181
  self._init_weights()
7182
 
7183
  def _init_weights(self):
7184
  nn.init.zeros_(self.unpatch_embed.proj.weight)
7185
  nn.init.zeros_(self.unpatch_embed.proj.bias)
7186
 
7187
def enable_gradient_checkpointing(self):
    """Switch on gradient checkpointing for this module.

    Only sets ``self.gradient_checkpointing`` to ``True``; the flag is
    consulted by ``forward`` together with ``self.training`` to decide
    whether each block is run through ``torch.utils.checkpoint``.
    """
    self.gradient_checkpointing = True
7190
+
7191
  def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
7192
  batch_size, channels, height, width = x.shape
7193
  patch_height = height // self.patch_size
 
7208
  x_b = x_patches.clone()
7209
 
7210
  for block in self.blocks:
7211
+ if self.gradient_checkpointing and self.training:
7212
+ x_a, x_b = torch.utils.checkpoint.checkpoint(
7213
+ block, x_a, x_b, context_proj, t_emb, patch_height, patch_width,
7214
+ use_reentrant=False
7215
+ )
7216
+ else:
7217
+ x_a, x_b = block(x_a, x_b, context_proj, t_emb, patch_height, patch_width)
7218
 
7219
  x_combined = (x_a + x_b) / 2
7220
  x_combined = self.final_norm(x_combined)
 
7722
  q = apply_rope(q, cos, sin)
7723
  k = apply_rope(k, cos, sin)
7724
 
7725
+ # Flash Attention 2.0 via SDPA — O(N) memory, non-causal spatial attention
7726
+ qk_scale = self.head_dim ** -0.25
7727
+ out = F.scaled_dot_product_attention(
7728
+ q * qk_scale, k * qk_scale, v,
7729
+ is_causal=False, scale=1.0,
7730
+ )
7731
  out = out.transpose(1, 2).reshape(batch_size * frames, spatial_len, self.hidden_size)
7732
  out = self.to_out(out)
7733
 
 
7778
  q = apply_rope(q, cos, sin)
7779
  k = apply_rope(k, cos, sin)
7780
 
7781
+ # Flash Attention 2.0 via SDPA — causal temporal attention
7782
+ qk_scale = self.head_dim ** -0.25
7783
+ out = F.scaled_dot_product_attention(
7784
+ q * qk_scale, k * qk_scale, v,
7785
+ is_causal=causal, scale=1.0,
7786
+ )
 
 
 
 
7787
  out = out.transpose(1, 2).reshape(batch_size * spatial_len, frames, self.hidden_size)
7788
 
7789
  # Reshape back to [B, T*H*W, hidden]
 
7848
  k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7849
  v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
7850
 
7851
+ # Flash Attention 2.0 via SDPA — non-causal cross-attention
7852
+ qk_scale = self.head_dim ** -0.25
7853
+ out = F.scaled_dot_product_attention(
7854
+ q * qk_scale, k * qk_scale, v,
7855
+ is_causal=False, scale=1.0,
7856
+ )
7857
  out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
7858
  out = self.to_out(out)
7859
 
 
7971
 
7972
  nn.init.zeros_(self.output_proj[-1].weight)
7973
  nn.init.zeros_(self.output_proj[-1].bias)
7974
+
7975
+ self.gradient_checkpointing = False
7976
+
7977
def enable_gradient_checkpointing(self):
    """Switch on gradient checkpointing for this module.

    Only sets ``self.gradient_checkpointing`` to ``True``; the flag is
    consulted by ``forward`` together with ``self.training`` to decide
    whether each transformer block is run through
    ``torch.utils.checkpoint``.
    """
    self.gradient_checkpointing = True
7980
 
7981
  def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, first_frame_latent: Optional[torch.Tensor] = None) -> torch.Tensor:
7982
  batch_size, channels, frames, height, width = x.shape
 
7995
  temporal_context = t_emb.unsqueeze(1).expand(-1, frames * height * width, -1)
7996
 
7997
  for block in self.transformer_blocks:
7998
+ if self.gradient_checkpointing and self.training:
7999
+ h = torch.utils.checkpoint.checkpoint(
8000
+ block, h, context, height, width, frames, temporal_context,
8001
+ use_reentrant=False
8002
+ )
8003
+ else:
8004
+ h = block(h, context, height, width, frames, temporal_context)
8005
 
8006
  h = h.reshape(batch_size, frames, height, width, self.hidden_size).permute(0, 4, 1, 2, 3)
8007
 
 
8475
  return q_embed, k_embed
8476
 
8477
 
 
8478
  class KVCache:
8479
+ """Pre-allocated KV Cache static buffer with index-based filling.
8480
+
8481
+ Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
8482
+ Buffer is allocated once at first use and reused via slice assignment.
8483
+ """
8484
+
8485
+ __slots__ = ('key_cache', 'value_cache', 'seen_tokens', '_max_len')
8486
+
8487
+ def __init__(
8488
+ self,
8489
+ key_cache: torch.Tensor = None,
8490
+ value_cache: torch.Tensor = None,
8491
+ seen_tokens: int = 0,
8492
+ max_seq_len: int = 131072,
8493
+ ):
8494
+ self.key_cache = key_cache
8495
+ self.value_cache = value_cache
8496
+ self.seen_tokens = seen_tokens
8497
+ self._max_len = max_seq_len
8498
+
8499
+ def _allocate(self, batch: int, heads: int, head_dim: int, device: torch.device, dtype: torch.dtype):
8500
+ """Allocate static buffer on first use."""
8501
+ self.key_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
8502
+ self.value_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
8503
 
8504
  def update(
8505
  self,
 
8507
  value_states: torch.Tensor,
8508
  chunk_size: Optional[int] = None,
8509
  ) -> Tuple[torch.Tensor, torch.Tensor]:
8510
+ batch, heads, new_len, head_dim = key_states.shape
 
 
 
 
 
8511
 
8512
+ if self.key_cache is None:
8513
+ # Allocate static buffer on first call
8514
+ self._allocate(batch, heads, head_dim, key_states.device, key_states.dtype)
8515
+ self.seen_tokens = 0
8516
+
8517
+ # Check if we need to apply chunk windowing
8518
+ if chunk_size is not None and self.seen_tokens + new_len > chunk_size * 2:
8519
+ # Shift: keep only the last chunk_size tokens, then append
8520
+ keep = chunk_size
8521
+ if self.seen_tokens > keep:
8522
+ self.key_cache[:, :, :keep] = self.key_cache[:, :, self.seen_tokens - keep:self.seen_tokens].clone()
8523
+ self.value_cache[:, :, :keep] = self.value_cache[:, :, self.seen_tokens - keep:self.seen_tokens].clone()
8524
+ self.seen_tokens = keep
8525
+
8526
+ # Grow buffer if needed (rare fallback — avoids crash on very long sequences)
8527
+ if self.seen_tokens + new_len > self.key_cache.shape[2]:
8528
+ new_max = max(self.key_cache.shape[2] * 2, self.seen_tokens + new_len)
8529
+ new_key = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
8530
+ new_val = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
8531
+ new_key[:, :, :self.seen_tokens] = self.key_cache[:, :, :self.seen_tokens]
8532
+ new_val[:, :, :self.seen_tokens] = self.value_cache[:, :, :self.seen_tokens]
8533
+ self.key_cache = new_key
8534
+ self.value_cache = new_val
8535
+
8536
+ # Index-based fill — no allocation, no fragmentation
8537
+ self.key_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = key_states
8538
+ self.value_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = value_states
8539
+ self.seen_tokens += new_len
8540
+
8541
+ # Return only the valid slice (view, no copy)
8542
+ return self.key_cache[:, :, :self.seen_tokens], self.value_cache[:, :, :self.seen_tokens]
8543
 
8544
+ def reset(self):
8545
+ """Reset cache position without deallocating the buffer."""
8546
+ self.seen_tokens = 0
8547
 
8548
 
8549
  def ring_attention(
 
8555
  ) -> torch.Tensor:
8556
  """
8557
  Ring Attention for distributed long-context processing.
8558
+ Processes sequence in chunks with online softmax accumulation.
8559
 
8560
  Args:
8561
  query: [batch, heads, seq_len, head_dim]
 
8569
  """
8570
  batch_size, num_heads, seq_len, head_dim = query.shape
8571
  kv_len = key.shape[2]
 
8572
 
8573
+ # Short path: use SDPA directly for small sequences
8574
  if seq_len <= chunk_size and kv_len <= chunk_size:
8575
+ qk_scale = head_dim ** -0.25
8576
+ use_causal = causal and seq_len == kv_len and seq_len > 1
 
 
 
 
 
8577
 
8578
+ if use_causal:
8579
+ return F.scaled_dot_product_attention(
8580
+ query * qk_scale, key * qk_scale, value,
8581
+ is_causal=True, scale=1.0,
8582
+ )
8583
+ elif causal and kv_len > seq_len:
8584
+ # KV cache case: build explicit causal mask
8585
+ causal_mask = torch.zeros(seq_len, kv_len, device=query.device, dtype=query.dtype)
8586
+ q_pos = torch.arange(seq_len, device=query.device) + (kv_len - seq_len)
8587
+ k_pos = torch.arange(kv_len, device=query.device)
8588
+ causal_mask = torch.where(k_pos.unsqueeze(0) > q_pos.unsqueeze(1), float('-inf'), 0.0)
8589
+ return F.scaled_dot_product_attention(
8590
+ query * qk_scale, key * qk_scale, value,
8591
+ attn_mask=causal_mask, scale=1.0,
8592
+ )
8593
+ else:
8594
+ return F.scaled_dot_product_attention(
8595
+ query * qk_scale, key * qk_scale, value,
8596
+ is_causal=False, scale=1.0,
8597
+ )
8598
 
8599
+ # Long path: chunked attention with online softmax (FlashAttention-style)
8600
+ scale = head_dim ** -0.5
8601
  output = torch.zeros_like(query)
8602
  max_logits = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=query.device, dtype=query.dtype)
8603
  sum_exp = torch.zeros((batch_size, num_heads, seq_len, 1), device=query.device, dtype=query.dtype)
8604
 
8605
+ # Pre-compute query positions for vectorized causal masking
8606
+ if causal:
8607
+ q_positions = torch.arange(seq_len, device=query.device)
8608
+ if kv_len > seq_len:
8609
+ q_positions = q_positions + (kv_len - seq_len)
8610
+
8611
  num_kv_chunks = (kv_len + chunk_size - 1) // chunk_size
8612
 
8613
  for kv_idx in range(num_kv_chunks):
 
8620
  attn_chunk = torch.matmul(query, key_chunk.transpose(-1, -2)) * scale
8621
 
8622
  if causal:
8623
+ # Vectorized causal mask — replaces O(n²) nested Python loop
8624
+ k_positions = torch.arange(kv_start, kv_end, device=query.device)
8625
+ # mask[i, j] = True where k_pos[j] > q_pos[i] (future tokens)
8626
+ causal_mask = k_positions.unsqueeze(0) > q_positions.unsqueeze(1) # [seq_len, chunk_len]
8627
+ attn_chunk = attn_chunk.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
 
 
8628
 
8629
  chunk_max = attn_chunk.max(dim=-1, keepdim=True)[0]
8630
  new_max = torch.maximum(max_logits, chunk_max)
 
8741
  self.ring_chunk_size if self.use_ring_attention else None
8742
  )
8743
 
 
 
 
 
8744
  if self.use_ring_attention:
8745
+ # Ring attention needs matched head counts — expand KV heads
8746
+ if self.num_key_value_groups > 1:
8747
+ key_expanded = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
8748
+ value_expanded = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
8749
+ else:
8750
+ key_expanded = key_states
8751
+ value_expanded = value_states
8752
  attn_output = ring_attention(
8753
+ query_states, key_expanded, value_expanded,
8754
  chunk_size=self.ring_chunk_size,
8755
  causal=True,
8756
  )
8757
  else:
8758
+ # True GQA via SDPA — no repeat_interleave, O(N) memory
8759
+ # SDPA natively handles N query heads with M KV heads via enable_gqa
8760
+ qk_scale = self.head_dim ** -0.25
8761
  kv_len = key_states.shape[2]
8762
+ use_causal = (attention_mask is None and seq_len > 1 and seq_len == kv_len)
8763
+
8764
+ attn_output = F.scaled_dot_product_attention(
8765
+ query_states * qk_scale,
8766
+ key_states * qk_scale,
8767
+ value_states,
8768
+ attn_mask=attention_mask,
8769
+ is_causal=use_causal,
8770
+ scale=1.0, # Already scaled Q and K
8771
+ enable_gqa=(self.num_key_value_groups > 1),
8772
  )
 
 
 
 
 
 
 
8773
 
8774
  attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
8775
  attn_output = self.o_proj(attn_output)
 
8798
  self.input_norm = LlamaRMSNorm(hidden_size)
8799
  self.gate = nn.Linear(hidden_size, num_experts, bias=False)
8800
  nn.init.normal_(self.gate.weight, mean=0.0, std=0.01)
8801
+
8802
+ # DeepSeek-style expert bias for aux-lossless load balancing
8803
+ # This learnable bias steers token routing to underutilized experts
8804
+ # without requiring an auxiliary loss term
8805
+ self.expert_bias = nn.Parameter(torch.zeros(num_experts))
8806
 
8807
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8808
  batch_size, seq_len, hidden_dim = hidden_states.shape
 
8811
  hidden_norm = self.input_norm(hidden_flat)
8812
  router_logits = self.gate(hidden_norm)
8813
 
8814
+ # Add expert bias for load balancing (aux-lossless mechanism)
8815
+ biased_logits = router_logits + self.expert_bias
8816
+
8817
+ router_probs = F.softmax(biased_logits, dim=-1, dtype=hidden_states.dtype)
8818
 
8819
  top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
8820
 
 
8902
  batch_size, seq_len, hidden_size = hidden_states.shape
8903
  original_dtype = hidden_states.dtype
8904
  hidden_flat = hidden_states.view(-1, hidden_size)
8905
+ num_tokens = hidden_flat.shape[0]
8906
 
8907
+ top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
8908
+
8909
+ # Record expert utilization if tracker is attached
8910
+ if hasattr(self, '_utilization_tracker'):
8911
+ self._utilization_tracker.record(top_k_indices)
8912
 
8913
  final_output = torch.zeros_like(hidden_flat)
8914
 
 
8928
 
8929
  final_output = final_output.view(batch_size, seq_len, hidden_size)
8930
 
8931
+ # Aux-lossless: z-loss only for router logit stability
8932
+ # The expert_bias in the router handles load balancing architecturally
8933
+ aux_loss = self._compute_aux_loss(router_logits, top_k_indices, num_tokens)
8934
 
8935
  return final_output, aux_loss
8936
 
8937
+ def _compute_aux_loss(
8938
+ self,
8939
+ router_logits: torch.Tensor,
8940
+ top_k_indices: torch.Tensor,
8941
+ num_tokens: int,
8942
+ ) -> torch.Tensor:
8943
+ """
8944
+ Aux-lossless auxiliary loss.
8945
+
8946
+ Uses z-loss to keep router logits from growing unboundedly (FP16 stability),
8947
+ plus a soft utilization penalty that activates only when experts go completely
8948
+ cold. The expert_bias parameter handles routine load balancing.
8949
+ """
8950
+ # z-loss: prevents router logit explosion in FP16
8951
+ z_loss = torch.logsumexp(router_logits, dim=-1).square().mean() * 0.0001
8952
+
8953
+ # Soft utilization penalty: only penalizes fully-dead experts
8954
+ # This does NOT hurt convergence because it only activates at extremes
8955
+ expert_mask = F.one_hot(top_k_indices, self.num_experts).float()
8956
+ tokens_per_expert = expert_mask.sum(dim=(0, 1)) # [num_experts]
8957
+ fraction_used = (tokens_per_expert > 0).float().mean()
8958
+ utilization_loss = (1.0 - fraction_used) * 0.01 # Very soft penalty
8959
+
8960
+ return z_loss + utilization_loss
8961
+
8962
 
8963
  MoELayer = AuxLosslessMoELayer
8964
 
streaming_state.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
- "epoch": 26,
3
- "unique_samples": 586,
4
- "total_yields": 1172,
5
  "dataset_positions": {
6
  "WebSight": 386,
7
  "ScienceQA": 364,
@@ -18,16 +18,16 @@
18
  "CodeParrot-Clean": 200,
19
  "ShareGPT-Clean": 200,
20
  "Synth-Issues": 200,
21
- "Dolly-15k": 200,
22
- "Conversation-Summarization": 200,
23
  "Synth-ShellTimeout": 200,
24
  "Synth-Docker": 200,
25
- "Synth-Documents": 200,
26
  "HumanEval-JavaScript": 164,
27
- "OpenOrca": 200,
28
  "Synth-MultiStepExecution": 200,
29
  "Synth-Citation": 200,
30
- "NoRobots": 200,
31
  "Synth-LanguageSetup": 200,
32
  "Function-Calling-ChatML": 200,
33
  "Synth-CoT": 200,
@@ -75,7 +75,7 @@
75
  "Synth-Debugging": 200,
76
  "Tool-Calls-SingleTurn": 200,
77
  "Tool-Calls-Multiturn": 200,
78
- "OpenAssistant": 200,
79
  "T2V-Sora-Preferences-2": 200,
80
  "T2V-Human-Preferences": 200,
81
  "Sora-Alignment-Likert": 198,
@@ -84,7 +84,9 @@
84
  "WebVid-10M": 200,
85
  "Sora-Physics-Likert": 198,
86
  "TIP-I2V": 200,
87
- "Pexels-I2V-350k": 200
 
 
88
  },
89
  "modality_positions": {
90
  "text": {
@@ -92,11 +94,11 @@
92
  "Midjourney-Prompts": 200,
93
  "CodeParrot-Clean": 200,
94
  "ShareGPT-Clean": 200,
95
- "Dolly-15k": 200,
96
- "Conversation-Summarization": 200,
97
  "HumanEval-JavaScript": 164,
98
- "OpenOrca": 200,
99
- "NoRobots": 200,
100
  "Function-Calling-ChatML": 200,
101
  "Python-Code-18k": 200,
102
  "Code-Feedback": 200,
@@ -119,7 +121,9 @@
119
  "HumanEval-Rust": 164,
120
  "Tool-Calls-SingleTurn": 200,
121
  "Tool-Calls-Multiturn": 200,
122
- "OpenAssistant": 200
 
 
123
  },
124
  "image": {
125
  "WebSight": 386,
@@ -144,9 +148,9 @@
144
  "audio": {}
145
  },
146
  "modality_counts": {
147
- "text": 0,
148
  "image": 0,
149
- "video": 586,
150
  "audio": 0
151
  },
152
  "last_modality": null
 
1
  {
2
+ "epoch": 35,
3
+ "unique_samples": 400,
4
+ "total_yields": 800,
5
  "dataset_positions": {
6
  "WebSight": 386,
7
  "ScienceQA": 364,
 
18
  "CodeParrot-Clean": 200,
19
  "ShareGPT-Clean": 200,
20
  "Synth-Issues": 200,
21
+ "Dolly-15k": 450,
22
+ "Conversation-Summarization": 450,
23
  "Synth-ShellTimeout": 200,
24
  "Synth-Docker": 200,
25
+ "Synth-Documents": 450,
26
  "HumanEval-JavaScript": 164,
27
+ "OpenOrca": 450,
28
  "Synth-MultiStepExecution": 200,
29
  "Synth-Citation": 200,
30
+ "NoRobots": 450,
31
  "Synth-LanguageSetup": 200,
32
  "Function-Calling-ChatML": 200,
33
  "Synth-CoT": 200,
 
75
  "Synth-Debugging": 200,
76
  "Tool-Calls-SingleTurn": 200,
77
  "Tool-Calls-Multiturn": 200,
78
+ "OpenAssistant": 450,
79
  "T2V-Sora-Preferences-2": 200,
80
  "T2V-Human-Preferences": 200,
81
  "Sora-Alignment-Likert": 198,
 
84
  "WebVid-10M": 200,
85
  "Sora-Physics-Likert": 198,
86
  "TIP-I2V": 200,
87
+ "Pexels-I2V-350k": 200,
88
+ "SmolTalk-OpenHermes": 250,
89
+ "SmolTalk-All": 250
90
  },
91
  "modality_positions": {
92
  "text": {
 
94
  "Midjourney-Prompts": 200,
95
  "CodeParrot-Clean": 200,
96
  "ShareGPT-Clean": 200,
97
+ "Dolly-15k": 450,
98
+ "Conversation-Summarization": 450,
99
  "HumanEval-JavaScript": 164,
100
+ "OpenOrca": 450,
101
+ "NoRobots": 450,
102
  "Function-Calling-ChatML": 200,
103
  "Python-Code-18k": 200,
104
  "Code-Feedback": 200,
 
121
  "HumanEval-Rust": 164,
122
  "Tool-Calls-SingleTurn": 200,
123
  "Tool-Calls-Multiturn": 200,
124
+ "OpenAssistant": 450,
125
+ "SmolTalk-OpenHermes": 250,
126
+ "SmolTalk-All": 250
127
  },
128
  "image": {
129
  "WebSight": 386,
 
148
  "audio": {}
149
  },
150
  "modality_counts": {
151
+ "text": 400,
152
  "image": 0,
153
+ "video": 0,
154
  "audio": 0
155
  },
156
  "last_modality": null
trainer_state.json CHANGED
@@ -1,32 +1,32 @@
1
  {
2
  "best_model_checkpoint": "/kaggle/working/xoron-final",
3
- "best_metric": 3.869171884744816,
4
- "epoch": 4,
5
- "epochs_completed": 4,
6
- "global_step": 298,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [],
10
  "logging_steps": 50,
11
- "max_steps": 298,
12
- "num_train_epochs": 4,
13
  "total_flos": 0,
14
  "train_batch_size": 1,
15
  "effective_batch_size": 16,
16
  "learning_rate": 0.0001,
17
  "max_grad_norm": 1.0,
18
  "trainable_components": [
19
- "vision",
20
- "video",
21
  "llm",
22
  "cross_attention",
23
- "video_generation",
24
  "modality_markers"
25
  ],
26
  "frozen_components": [
 
 
27
  "audio",
28
  "speech",
29
- "image_generation"
 
30
  ],
31
  "trial_name": null,
32
  "trial_params": null
 
1
  {
2
  "best_model_checkpoint": "/kaggle/working/xoron-final",
3
+ "best_metric": 6.958861378133297,
4
+ "epoch": 5,
5
+ "epochs_completed": 5,
6
+ "global_step": 250,
7
  "is_local_process_zero": true,
8
  "is_world_process_zero": true,
9
  "log_history": [],
10
  "logging_steps": 50,
11
+ "max_steps": 250,
12
+ "num_train_epochs": 5,
13
  "total_flos": 0,
14
  "train_batch_size": 1,
15
  "effective_batch_size": 16,
16
  "learning_rate": 0.0001,
17
  "max_grad_norm": 1.0,
18
  "trainable_components": [
 
 
19
  "llm",
20
  "cross_attention",
 
21
  "modality_markers"
22
  ],
23
  "frozen_components": [
24
+ "vision",
25
+ "video",
26
  "audio",
27
  "speech",
28
+ "image_generation",
29
+ "video_generation"
30
  ],
31
  "trial_name": null,
32
  "trial_params": null
training_state.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9b37a03cba59de5ddbc9ab88c301e76b8a0fa5bc81d6d471cbefe513d0699cf
3
- size 724684421
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a751ecf22021470154d58846b700d04286522c14cda7393ece31f907eff5a2c7
3
+ size 1514911851