Any-to-Any
Transformers
Safetensors
English
xoron
multimodal
Mixture of Experts
text-to-image
image editing
image to video
text-to-video
video editing
text-to-speech
speech-to-text
speech-to-speech
image-to-text
video-to-text
agentic
tool-use
flow-matching
3d-rope
titok
vidtok
dual-stream-attention
zero-shot-voice-cloning
bigvgan
snake-activation
multi-receptive-field-fusion
custom_code
Update model weights after training (epoch 5, loss 6.9589)
Browse files- audio_decoder.safetensors +1 -1
- config.json +2 -2
- configuration_xoron.py +6 -0
- cross_attention.safetensors +1 -1
- llm.safetensors +2 -2
- model.safetensors.index.json +7 -1
- modeling_xoron.py +456 -136
- streaming_state.json +21 -17
- trainer_state.json +10 -10
- training_state.pt +2 -2
audio_decoder.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 1458410612
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c225077ec0e29909d0f390011f666158ae658fa3385cf8032280f5203da09cae
|
| 3 |
size 1458410612
|
config.json
CHANGED
|
@@ -49,10 +49,10 @@
|
|
| 49 |
"image_size_step": 32,
|
| 50 |
"video_min_size": 128,
|
| 51 |
"video_max_size": 320,
|
| 52 |
-
"video_base_size":
|
| 53 |
"video_size_step": 32,
|
| 54 |
"video_min_frames": 8,
|
| 55 |
-
"video_max_frames":
|
| 56 |
"video_base_frames": 16,
|
| 57 |
"video_frame_step": 4,
|
| 58 |
"multi_scale_strategy": "adaptive",
|
|
|
|
| 49 |
"image_size_step": 32,
|
| 50 |
"video_min_size": 128,
|
| 51 |
"video_max_size": 320,
|
| 52 |
+
"video_base_size": 320,
|
| 53 |
"video_size_step": 32,
|
| 54 |
"video_min_frames": 8,
|
| 55 |
+
"video_max_frames": 8,
|
| 56 |
"video_base_frames": 16,
|
| 57 |
"video_frame_step": 4,
|
| 58 |
"multi_scale_strategy": "adaptive",
|
configuration_xoron.py
CHANGED
|
@@ -213,11 +213,17 @@ class XoronConfig(PreTrainedConfig):
|
|
| 213 |
# Output path (used during training)
|
| 214 |
output_dir: str = "./xoron-model",
|
| 215 |
|
|
|
|
|
|
|
|
|
|
| 216 |
**kwargs,
|
| 217 |
):
|
| 218 |
# Call parent init
|
| 219 |
super().__init__(**kwargs)
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
# Model identification
|
| 222 |
self.model_name = model_name
|
| 223 |
|
|
|
|
| 213 |
# Output path (used during training)
|
| 214 |
output_dir: str = "./xoron-model",
|
| 215 |
|
| 216 |
+
# Training Configuration
|
| 217 |
+
modality_dropout_prob: float = 0.0,
|
| 218 |
+
|
| 219 |
**kwargs,
|
| 220 |
):
|
| 221 |
# Call parent init
|
| 222 |
super().__init__(**kwargs)
|
| 223 |
|
| 224 |
+
# Training Configuration
|
| 225 |
+
self.modality_dropout_prob = modality_dropout_prob
|
| 226 |
+
|
| 227 |
# Model identification
|
| 228 |
self.model_name = model_name
|
| 229 |
|
cross_attention.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 174191400
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5dc29d69984df0e49cf508c56c03b7a18a7a49baf89a414fa3128513d753e7e
|
| 3 |
size 174191400
|
llm.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5de86313a868d4108f814a3debd9d1ed31dc72281458ef9c7824b9a4398ce28f
|
| 3 |
+
size 1506832040
|
model.safetensors.index.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
{
|
| 2 |
"metadata": {
|
| 3 |
-
"total_size":
|
| 4 |
"format": "components"
|
| 5 |
},
|
| 6 |
"weight_map": {
|
|
@@ -36,6 +36,7 @@
|
|
| 36 |
"llm.model.layers.1.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 37 |
"llm.model.layers.1.input_layernorm.weight": "llm.safetensors",
|
| 38 |
"llm.model.layers.1.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 39 |
"llm.model.layers.1.mlp.router.input_norm.weight": "llm.safetensors",
|
| 40 |
"llm.model.layers.1.mlp.router.gate.weight": "llm.safetensors",
|
| 41 |
"llm.model.layers.1.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
@@ -150,6 +151,7 @@
|
|
| 150 |
"llm.model.layers.3.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 151 |
"llm.model.layers.3.input_layernorm.weight": "llm.safetensors",
|
| 152 |
"llm.model.layers.3.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 153 |
"llm.model.layers.3.mlp.router.input_norm.weight": "llm.safetensors",
|
| 154 |
"llm.model.layers.3.mlp.router.gate.weight": "llm.safetensors",
|
| 155 |
"llm.model.layers.3.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
@@ -264,6 +266,7 @@
|
|
| 264 |
"llm.model.layers.5.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 265 |
"llm.model.layers.5.input_layernorm.weight": "llm.safetensors",
|
| 266 |
"llm.model.layers.5.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 267 |
"llm.model.layers.5.mlp.router.input_norm.weight": "llm.safetensors",
|
| 268 |
"llm.model.layers.5.mlp.router.gate.weight": "llm.safetensors",
|
| 269 |
"llm.model.layers.5.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
@@ -378,6 +381,7 @@
|
|
| 378 |
"llm.model.layers.7.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 379 |
"llm.model.layers.7.input_layernorm.weight": "llm.safetensors",
|
| 380 |
"llm.model.layers.7.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 381 |
"llm.model.layers.7.mlp.router.input_norm.weight": "llm.safetensors",
|
| 382 |
"llm.model.layers.7.mlp.router.gate.weight": "llm.safetensors",
|
| 383 |
"llm.model.layers.7.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
@@ -492,6 +496,7 @@
|
|
| 492 |
"llm.model.layers.9.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 493 |
"llm.model.layers.9.input_layernorm.weight": "llm.safetensors",
|
| 494 |
"llm.model.layers.9.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 495 |
"llm.model.layers.9.mlp.router.input_norm.weight": "llm.safetensors",
|
| 496 |
"llm.model.layers.9.mlp.router.gate.weight": "llm.safetensors",
|
| 497 |
"llm.model.layers.9.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
@@ -606,6 +611,7 @@
|
|
| 606 |
"llm.model.layers.11.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 607 |
"llm.model.layers.11.input_layernorm.weight": "llm.safetensors",
|
| 608 |
"llm.model.layers.11.post_attention_layernorm.weight": "llm.safetensors",
|
|
|
|
| 609 |
"llm.model.layers.11.mlp.router.input_norm.weight": "llm.safetensors",
|
| 610 |
"llm.model.layers.11.mlp.router.gate.weight": "llm.safetensors",
|
| 611 |
"llm.model.layers.11.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 1 |
{
|
| 2 |
"metadata": {
|
| 3 |
+
"total_size": 7309365134,
|
| 4 |
"format": "components"
|
| 5 |
},
|
| 6 |
"weight_map": {
|
|
|
|
| 36 |
"llm.model.layers.1.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 37 |
"llm.model.layers.1.input_layernorm.weight": "llm.safetensors",
|
| 38 |
"llm.model.layers.1.post_attention_layernorm.weight": "llm.safetensors",
|
| 39 |
+
"llm.model.layers.1.mlp.router.expert_bias": "llm.safetensors",
|
| 40 |
"llm.model.layers.1.mlp.router.input_norm.weight": "llm.safetensors",
|
| 41 |
"llm.model.layers.1.mlp.router.gate.weight": "llm.safetensors",
|
| 42 |
"llm.model.layers.1.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 151 |
"llm.model.layers.3.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 152 |
"llm.model.layers.3.input_layernorm.weight": "llm.safetensors",
|
| 153 |
"llm.model.layers.3.post_attention_layernorm.weight": "llm.safetensors",
|
| 154 |
+
"llm.model.layers.3.mlp.router.expert_bias": "llm.safetensors",
|
| 155 |
"llm.model.layers.3.mlp.router.input_norm.weight": "llm.safetensors",
|
| 156 |
"llm.model.layers.3.mlp.router.gate.weight": "llm.safetensors",
|
| 157 |
"llm.model.layers.3.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 266 |
"llm.model.layers.5.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 267 |
"llm.model.layers.5.input_layernorm.weight": "llm.safetensors",
|
| 268 |
"llm.model.layers.5.post_attention_layernorm.weight": "llm.safetensors",
|
| 269 |
+
"llm.model.layers.5.mlp.router.expert_bias": "llm.safetensors",
|
| 270 |
"llm.model.layers.5.mlp.router.input_norm.weight": "llm.safetensors",
|
| 271 |
"llm.model.layers.5.mlp.router.gate.weight": "llm.safetensors",
|
| 272 |
"llm.model.layers.5.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 381 |
"llm.model.layers.7.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 382 |
"llm.model.layers.7.input_layernorm.weight": "llm.safetensors",
|
| 383 |
"llm.model.layers.7.post_attention_layernorm.weight": "llm.safetensors",
|
| 384 |
+
"llm.model.layers.7.mlp.router.expert_bias": "llm.safetensors",
|
| 385 |
"llm.model.layers.7.mlp.router.input_norm.weight": "llm.safetensors",
|
| 386 |
"llm.model.layers.7.mlp.router.gate.weight": "llm.safetensors",
|
| 387 |
"llm.model.layers.7.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 496 |
"llm.model.layers.9.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 497 |
"llm.model.layers.9.input_layernorm.weight": "llm.safetensors",
|
| 498 |
"llm.model.layers.9.post_attention_layernorm.weight": "llm.safetensors",
|
| 499 |
+
"llm.model.layers.9.mlp.router.expert_bias": "llm.safetensors",
|
| 500 |
"llm.model.layers.9.mlp.router.input_norm.weight": "llm.safetensors",
|
| 501 |
"llm.model.layers.9.mlp.router.gate.weight": "llm.safetensors",
|
| 502 |
"llm.model.layers.9.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
|
|
|
| 611 |
"llm.model.layers.11.self_attn.o_proj.linear.weight": "llm.safetensors",
|
| 612 |
"llm.model.layers.11.input_layernorm.weight": "llm.safetensors",
|
| 613 |
"llm.model.layers.11.post_attention_layernorm.weight": "llm.safetensors",
|
| 614 |
+
"llm.model.layers.11.mlp.router.expert_bias": "llm.safetensors",
|
| 615 |
"llm.model.layers.11.mlp.router.input_norm.weight": "llm.safetensors",
|
| 616 |
"llm.model.layers.11.mlp.router.gate.weight": "llm.safetensors",
|
| 617 |
"llm.model.layers.11.mlp.experts.0.gate_proj.lora_A": "llm.safetensors",
|
modeling_xoron.py
CHANGED
|
@@ -436,18 +436,25 @@ def compute_qk_scale(head_dim: int) -> float:
|
|
| 436 |
return head_dim ** -0.25
|
| 437 |
|
| 438 |
|
| 439 |
-
@dataclass
|
| 440 |
class AttentionKVCache:
|
| 441 |
-
"""
|
| 442 |
-
KV Cache for efficient autoregressive attention.
|
| 443 |
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
- Support for cross-attention caching
|
| 447 |
"""
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
def update(
|
| 453 |
self,
|
|
@@ -455,36 +462,45 @@ class AttentionKVCache:
|
|
| 455 |
value_states: torch.Tensor,
|
| 456 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 457 |
"""
|
| 458 |
-
Update cache with new key/value states.
|
| 459 |
|
| 460 |
Args:
|
| 461 |
key_states: New key states [batch, num_heads, seq_len, head_dim]
|
| 462 |
value_states: New value states [batch, num_heads, seq_len, head_dim]
|
| 463 |
|
| 464 |
Returns:
|
| 465 |
-
Updated key and value states including cache
|
| 466 |
"""
|
| 467 |
-
|
| 468 |
-
self.key_cache = key_states
|
| 469 |
-
self.value_cache = value_states
|
| 470 |
-
else:
|
| 471 |
-
self.key_cache = torch.cat([self.key_cache, key_states], dim=2)
|
| 472 |
-
self.value_cache = torch.cat([self.value_cache, value_states], dim=2)
|
| 473 |
-
|
| 474 |
-
self.seen_tokens += key_states.shape[2]
|
| 475 |
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
def get_seq_length(self) -> int:
|
| 479 |
"""Get current sequence length in cache."""
|
| 480 |
-
|
| 481 |
-
return 0
|
| 482 |
-
return self.key_cache.shape[2]
|
| 483 |
|
| 484 |
def reset(self):
|
| 485 |
-
"""Reset the
|
| 486 |
-
self.key_cache = None
|
| 487 |
-
self.value_cache = None
|
| 488 |
self.seen_tokens = 0
|
| 489 |
|
| 490 |
|
|
@@ -1447,12 +1463,12 @@ class PerceiverAttention(nn.Module):
|
|
| 1447 |
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 1448 |
k = apply_rope(k, cos, sin)
|
| 1449 |
|
| 1450 |
-
# Attention
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
out = out.transpose(1, 2).reshape(b, n, self.inner_dim)
|
| 1457 |
|
| 1458 |
return self.to_out(out)
|
|
@@ -1848,6 +1864,171 @@ class MultimodalProjector(nn.Module):
|
|
| 1848 |
EPS = 1e-5
|
| 1849 |
|
| 1850 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1851 |
class MoERouter(nn.Module):
|
| 1852 |
"""
|
| 1853 |
SOTA Router for Mixture of Experts v2.0 - FP16 native.
|
|
@@ -2036,6 +2217,10 @@ class MoELayer(nn.Module):
|
|
| 2036 |
|
| 2037 |
top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
|
| 2038 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2039 |
final_output = torch.zeros_like(hidden_flat)
|
| 2040 |
|
| 2041 |
for expert_idx in range(self.num_experts):
|
|
@@ -4726,21 +4911,23 @@ class RotaryMultiHeadLatentAttention(nn.Module):
|
|
| 4726 |
|
| 4727 |
present_key_value = (key, value) if use_cache else None
|
| 4728 |
|
| 4729 |
-
#
|
| 4730 |
-
|
| 4731 |
-
|
| 4732 |
-
|
| 4733 |
-
|
| 4734 |
-
|
| 4735 |
-
|
| 4736 |
-
|
| 4737 |
-
|
| 4738 |
-
|
| 4739 |
-
|
| 4740 |
-
|
| 4741 |
-
|
| 4742 |
-
|
| 4743 |
-
|
|
|
|
|
|
|
| 4744 |
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
| 4745 |
output = self.o_proj(output)
|
| 4746 |
|
|
@@ -6719,6 +6906,7 @@ class DualStreamSelfAttention(nn.Module):
|
|
| 6719 |
"""
|
| 6720 |
Symmetric Dual-Stream Self-Attention (SD3/Flux-style).
|
| 6721 |
Two parallel streams with cross-stream information exchange.
|
|
|
|
| 6722 |
"""
|
| 6723 |
|
| 6724 |
def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
|
|
@@ -6727,6 +6915,8 @@ class DualStreamSelfAttention(nn.Module):
|
|
| 6727 |
self.num_heads = num_heads
|
| 6728 |
self.head_dim = hidden_size // num_heads
|
| 6729 |
self.scale = self.head_dim ** -0.5
|
|
|
|
|
|
|
| 6730 |
|
| 6731 |
self.to_qkv_a = nn.Linear(hidden_size, hidden_size * 3, bias=False)
|
| 6732 |
self.to_qkv_b = nn.Linear(hidden_size, hidden_size * 3, bias=False)
|
|
@@ -6752,8 +6942,6 @@ class DualStreamSelfAttention(nn.Module):
|
|
| 6752 |
q_b, k_b, v_b = qkv_b.unbind(dim=2)
|
| 6753 |
|
| 6754 |
cos, sin = self.rope_2d(x_a, height, width)
|
| 6755 |
-
# cos/sin shape: [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
|
| 6756 |
-
# to broadcast with q/k shape: [B, num_heads, seq_len, head_dim]
|
| 6757 |
cos = cos.unsqueeze(0).unsqueeze(1)
|
| 6758 |
sin = sin.unsqueeze(0).unsqueeze(1)
|
| 6759 |
|
|
@@ -6772,13 +6960,15 @@ class DualStreamSelfAttention(nn.Module):
|
|
| 6772 |
k_combined = torch.cat([k_a, k_b], dim=2)
|
| 6773 |
v_combined = torch.cat([v_a, v_b], dim=2)
|
| 6774 |
|
| 6775 |
-
|
| 6776 |
-
|
| 6777 |
-
|
| 6778 |
-
|
| 6779 |
-
|
| 6780 |
-
|
| 6781 |
-
|
|
|
|
|
|
|
| 6782 |
|
| 6783 |
out_a = out_a.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
|
| 6784 |
out_b = out_b.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
|
|
@@ -6815,10 +7005,12 @@ class CrossAttention(nn.Module):
|
|
| 6815 |
k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 6816 |
v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 6817 |
|
| 6818 |
-
|
| 6819 |
-
|
| 6820 |
-
|
| 6821 |
-
|
|
|
|
|
|
|
| 6822 |
out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 6823 |
out = self.to_out(out)
|
| 6824 |
|
|
@@ -6984,12 +7176,18 @@ class MoEDiT(nn.Module):
|
|
| 6984 |
self.final_norm = nn.LayerNorm(hidden_size)
|
| 6985 |
self.unpatch_embed = UnpatchEmbed(patch_size, out_channels, hidden_size)
|
| 6986 |
|
|
|
|
|
|
|
| 6987 |
self._init_weights()
|
| 6988 |
|
| 6989 |
def _init_weights(self):
|
| 6990 |
nn.init.zeros_(self.unpatch_embed.proj.weight)
|
| 6991 |
nn.init.zeros_(self.unpatch_embed.proj.bias)
|
| 6992 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6993 |
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 6994 |
batch_size, channels, height, width = x.shape
|
| 6995 |
patch_height = height // self.patch_size
|
|
@@ -7010,7 +7208,13 @@ class MoEDiT(nn.Module):
|
|
| 7010 |
x_b = x_patches.clone()
|
| 7011 |
|
| 7012 |
for block in self.blocks:
|
| 7013 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7014 |
|
| 7015 |
x_combined = (x_a + x_b) / 2
|
| 7016 |
x_combined = self.final_norm(x_combined)
|
|
@@ -7518,11 +7722,12 @@ class SpatialAttention(nn.Module):
|
|
| 7518 |
q = apply_rope(q, cos, sin)
|
| 7519 |
k = apply_rope(k, cos, sin)
|
| 7520 |
|
| 7521 |
-
# Attention
|
| 7522 |
-
|
| 7523 |
-
|
| 7524 |
-
|
| 7525 |
-
|
|
|
|
| 7526 |
out = out.transpose(1, 2).reshape(batch_size * frames, spatial_len, self.hidden_size)
|
| 7527 |
out = self.to_out(out)
|
| 7528 |
|
|
@@ -7573,16 +7778,12 @@ class TemporalAttention(nn.Module):
|
|
| 7573 |
q = apply_rope(q, cos, sin)
|
| 7574 |
k = apply_rope(k, cos, sin)
|
| 7575 |
|
| 7576 |
-
# Attention
|
| 7577 |
-
|
| 7578 |
-
|
| 7579 |
-
|
| 7580 |
-
|
| 7581 |
-
|
| 7582 |
-
|
| 7583 |
-
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype)
|
| 7584 |
-
|
| 7585 |
-
out = torch.matmul(attn, v)
|
| 7586 |
out = out.transpose(1, 2).reshape(batch_size * spatial_len, frames, self.hidden_size)
|
| 7587 |
|
| 7588 |
# Reshape back to [B, T*H*W, hidden]
|
|
@@ -7647,10 +7848,12 @@ class CrossAttention3D(nn.Module):
|
|
| 7647 |
k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7648 |
v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7649 |
|
| 7650 |
-
|
| 7651 |
-
|
| 7652 |
-
|
| 7653 |
-
|
|
|
|
|
|
|
| 7654 |
out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 7655 |
out = self.to_out(out)
|
| 7656 |
|
|
@@ -7768,6 +7971,12 @@ class VideoUNet3D(nn.Module):
|
|
| 7768 |
|
| 7769 |
nn.init.zeros_(self.output_proj[-1].weight)
|
| 7770 |
nn.init.zeros_(self.output_proj[-1].bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7771 |
|
| 7772 |
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, first_frame_latent: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 7773 |
batch_size, channels, frames, height, width = x.shape
|
|
@@ -7786,7 +7995,13 @@ class VideoUNet3D(nn.Module):
|
|
| 7786 |
temporal_context = t_emb.unsqueeze(1).expand(-1, frames * height * width, -1)
|
| 7787 |
|
| 7788 |
for block in self.transformer_blocks:
|
| 7789 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7790 |
|
| 7791 |
h = h.reshape(batch_size, frames, height, width, self.hidden_size).permute(0, 4, 1, 2, 3)
|
| 7792 |
|
|
@@ -8260,12 +8475,31 @@ def apply_rotary_pos_emb(
|
|
| 8260 |
return q_embed, k_embed
|
| 8261 |
|
| 8262 |
|
| 8263 |
-
@dataclass
|
| 8264 |
class KVCache:
|
| 8265 |
-
"""KV Cache
|
| 8266 |
-
|
| 8267 |
-
|
| 8268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8269 |
|
| 8270 |
def update(
|
| 8271 |
self,
|
|
@@ -8273,20 +8507,43 @@ class KVCache:
|
|
| 8273 |
value_states: torch.Tensor,
|
| 8274 |
chunk_size: Optional[int] = None,
|
| 8275 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 8276 |
-
|
| 8277 |
-
self.key_cache = key_states
|
| 8278 |
-
self.value_cache = value_states
|
| 8279 |
-
else:
|
| 8280 |
-
self.key_cache = torch.cat([self.key_cache, key_states], dim=2)
|
| 8281 |
-
self.value_cache = torch.cat([self.value_cache, value_states], dim=2)
|
| 8282 |
|
| 8283 |
-
|
| 8284 |
-
|
| 8285 |
-
|
| 8286 |
-
self.
|
| 8287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8288 |
|
| 8289 |
-
|
|
|
|
|
|
|
| 8290 |
|
| 8291 |
|
| 8292 |
def ring_attention(
|
|
@@ -8298,7 +8555,7 @@ def ring_attention(
|
|
| 8298 |
) -> torch.Tensor:
|
| 8299 |
"""
|
| 8300 |
Ring Attention for distributed long-context processing.
|
| 8301 |
-
Processes sequence in chunks with
|
| 8302 |
|
| 8303 |
Args:
|
| 8304 |
query: [batch, heads, seq_len, head_dim]
|
|
@@ -8312,24 +8569,45 @@ def ring_attention(
|
|
| 8312 |
"""
|
| 8313 |
batch_size, num_heads, seq_len, head_dim = query.shape
|
| 8314 |
kv_len = key.shape[2]
|
| 8315 |
-
scale = head_dim ** -0.5
|
| 8316 |
|
|
|
|
| 8317 |
if seq_len <= chunk_size and kv_len <= chunk_size:
|
| 8318 |
-
|
| 8319 |
-
|
| 8320 |
-
if causal:
|
| 8321 |
-
causal_mask = torch.triu(torch.ones(seq_len, kv_len, device=query.device, dtype=torch.bool), diagonal=1)
|
| 8322 |
-
if kv_len > seq_len:
|
| 8323 |
-
causal_mask = causal_mask[:, -seq_len:]
|
| 8324 |
-
attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
|
| 8325 |
|
| 8326 |
-
|
| 8327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8328 |
|
|
|
|
|
|
|
| 8329 |
output = torch.zeros_like(query)
|
| 8330 |
max_logits = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=query.device, dtype=query.dtype)
|
| 8331 |
sum_exp = torch.zeros((batch_size, num_heads, seq_len, 1), device=query.device, dtype=query.dtype)
|
| 8332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8333 |
num_kv_chunks = (kv_len + chunk_size - 1) // chunk_size
|
| 8334 |
|
| 8335 |
for kv_idx in range(num_kv_chunks):
|
|
@@ -8342,13 +8620,11 @@ def ring_attention(
|
|
| 8342 |
attn_chunk = torch.matmul(query, key_chunk.transpose(-1, -2)) * scale
|
| 8343 |
|
| 8344 |
if causal:
|
| 8345 |
-
|
| 8346 |
-
|
| 8347 |
-
|
| 8348 |
-
|
| 8349 |
-
|
| 8350 |
-
if k_pos > q_pos:
|
| 8351 |
-
attn_chunk[:, :, q_idx, k_idx] = float('-inf')
|
| 8352 |
|
| 8353 |
chunk_max = attn_chunk.max(dim=-1, keepdim=True)[0]
|
| 8354 |
new_max = torch.maximum(max_logits, chunk_max)
|
|
@@ -8465,31 +8741,35 @@ class MultiHeadLatentAttention(nn.Module):
|
|
| 8465 |
self.ring_chunk_size if self.use_ring_attention else None
|
| 8466 |
)
|
| 8467 |
|
| 8468 |
-
if self.num_key_value_groups > 1:
|
| 8469 |
-
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 8470 |
-
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 8471 |
-
|
| 8472 |
if self.use_ring_attention:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8473 |
attn_output = ring_attention(
|
| 8474 |
-
query_states,
|
| 8475 |
chunk_size=self.ring_chunk_size,
|
| 8476 |
causal=True,
|
| 8477 |
)
|
| 8478 |
else:
|
| 8479 |
-
|
| 8480 |
-
|
|
|
|
| 8481 |
kv_len = key_states.shape[2]
|
| 8482 |
-
|
| 8483 |
-
|
| 8484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8485 |
)
|
| 8486 |
-
attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
|
| 8487 |
-
|
| 8488 |
-
if attention_mask is not None:
|
| 8489 |
-
attn_weights = attn_weights + attention_mask
|
| 8490 |
-
|
| 8491 |
-
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(hidden_states.dtype)
|
| 8492 |
-
attn_output = torch.matmul(attn_weights, value_states)
|
| 8493 |
|
| 8494 |
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
| 8495 |
attn_output = self.o_proj(attn_output)
|
|
@@ -8518,6 +8798,11 @@ class AuxLosslessMoERouter(nn.Module):
|
|
| 8518 |
self.input_norm = LlamaRMSNorm(hidden_size)
|
| 8519 |
self.gate = nn.Linear(hidden_size, num_experts, bias=False)
|
| 8520 |
nn.init.normal_(self.gate.weight, mean=0.0, std=0.01)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8521 |
|
| 8522 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 8523 |
batch_size, seq_len, hidden_dim = hidden_states.shape
|
|
@@ -8526,7 +8811,10 @@ class AuxLosslessMoERouter(nn.Module):
|
|
| 8526 |
hidden_norm = self.input_norm(hidden_flat)
|
| 8527 |
router_logits = self.gate(hidden_norm)
|
| 8528 |
|
| 8529 |
-
|
|
|
|
|
|
|
|
|
|
| 8530 |
|
| 8531 |
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| 8532 |
|
|
@@ -8614,8 +8902,13 @@ class AuxLosslessMoELayer(nn.Module):
|
|
| 8614 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
| 8615 |
original_dtype = hidden_states.dtype
|
| 8616 |
hidden_flat = hidden_states.view(-1, hidden_size)
|
|
|
|
| 8617 |
|
| 8618 |
-
top_k_probs, top_k_indices,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8619 |
|
| 8620 |
final_output = torch.zeros_like(hidden_flat)
|
| 8621 |
|
|
@@ -8635,10 +8928,37 @@ class AuxLosslessMoELayer(nn.Module):
|
|
| 8635 |
|
| 8636 |
final_output = final_output.view(batch_size, seq_len, hidden_size)
|
| 8637 |
|
| 8638 |
-
|
|
|
|
|
|
|
| 8639 |
|
| 8640 |
return final_output, aux_loss
|
| 8641 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8642 |
|
| 8643 |
MoELayer = AuxLosslessMoELayer
|
| 8644 |
|
|
|
|
| 436 |
return head_dim ** -0.25
|
| 437 |
|
| 438 |
|
|
|
|
| 439 |
class AttentionKVCache:
|
| 440 |
+
"""Pre-allocated KV Cache — static buffer with index-based filling.
|
|
|
|
| 441 |
|
| 442 |
+
Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
|
| 443 |
+
Buffer is allocated once at first use and reused via slice assignment.
|
|
|
|
| 444 |
"""
|
| 445 |
+
|
| 446 |
+
__slots__ = ('key_cache', 'value_cache', 'seen_tokens', '_max_len')
|
| 447 |
+
|
| 448 |
+
def __init__(self, max_seq_len: int = 131072):
|
| 449 |
+
self.key_cache: torch.Tensor = None
|
| 450 |
+
self.value_cache: torch.Tensor = None
|
| 451 |
+
self.seen_tokens: int = 0
|
| 452 |
+
self._max_len = max_seq_len
|
| 453 |
+
|
| 454 |
+
def _allocate(self, batch: int, heads: int, head_dim: int, device: torch.device, dtype: torch.dtype):
|
| 455 |
+
"""Allocate static buffer on first use."""
|
| 456 |
+
self.key_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
|
| 457 |
+
self.value_cache = torch.zeros(batch, heads, self._max_len, head_dim, device=device, dtype=dtype)
|
| 458 |
|
| 459 |
def update(
|
| 460 |
self,
|
|
|
|
| 462 |
value_states: torch.Tensor,
|
| 463 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 464 |
"""
|
| 465 |
+
Update cache with new key/value states using index-based filling.
|
| 466 |
|
| 467 |
Args:
|
| 468 |
key_states: New key states [batch, num_heads, seq_len, head_dim]
|
| 469 |
value_states: New value states [batch, num_heads, seq_len, head_dim]
|
| 470 |
|
| 471 |
Returns:
|
| 472 |
+
Updated key and value states including cache (views, no copy)
|
| 473 |
"""
|
| 474 |
+
batch, heads, new_len, head_dim = key_states.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
+
if self.key_cache is None:
|
| 477 |
+
self._allocate(batch, heads, head_dim, key_states.device, key_states.dtype)
|
| 478 |
+
self.seen_tokens = 0
|
| 479 |
+
|
| 480 |
+
# Grow buffer if needed (rare fallback)
|
| 481 |
+
if self.seen_tokens + new_len > self.key_cache.shape[2]:
|
| 482 |
+
new_max = max(self.key_cache.shape[2] * 2, self.seen_tokens + new_len)
|
| 483 |
+
new_key = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
|
| 484 |
+
new_val = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
|
| 485 |
+
new_key[:, :, :self.seen_tokens] = self.key_cache[:, :, :self.seen_tokens]
|
| 486 |
+
new_val[:, :, :self.seen_tokens] = self.value_cache[:, :, :self.seen_tokens]
|
| 487 |
+
self.key_cache = new_key
|
| 488 |
+
self.value_cache = new_val
|
| 489 |
+
|
| 490 |
+
# Index-based fill — no allocation, no fragmentation
|
| 491 |
+
self.key_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = key_states
|
| 492 |
+
self.value_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = value_states
|
| 493 |
+
self.seen_tokens += new_len
|
| 494 |
+
|
| 495 |
+
# Return valid slice (view, no copy)
|
| 496 |
+
return self.key_cache[:, :, :self.seen_tokens], self.value_cache[:, :, :self.seen_tokens]
|
| 497 |
|
| 498 |
def get_seq_length(self) -> int:
|
| 499 |
"""Get current sequence length in cache."""
|
| 500 |
+
return self.seen_tokens
|
|
|
|
|
|
|
| 501 |
|
| 502 |
def reset(self):
|
| 503 |
+
"""Reset cache position without deallocating the buffer."""
|
|
|
|
|
|
|
| 504 |
self.seen_tokens = 0
|
| 505 |
|
| 506 |
|
|
|
|
| 1463 |
sin = sin.unsqueeze(0).unsqueeze(0)
|
| 1464 |
k = apply_rope(k, cos, sin)
|
| 1465 |
|
| 1466 |
+
# Flash Attention 2.0 via SDPA — FP16-safe with Q/K pre-scaling
|
| 1467 |
+
qk_scale = d ** -0.25
|
| 1468 |
+
out = F.scaled_dot_product_attention(
|
| 1469 |
+
q * qk_scale, k * qk_scale, v,
|
| 1470 |
+
is_causal=False, scale=1.0,
|
| 1471 |
+
)
|
| 1472 |
out = out.transpose(1, 2).reshape(b, n, self.inner_dim)
|
| 1473 |
|
| 1474 |
return self.to_out(out)
|
|
|
|
| 1864 |
EPS = 1e-5
|
| 1865 |
|
| 1866 |
|
| 1867 |
+
class ExpertUtilizationTracker:
|
| 1868 |
+
"""
|
| 1869 |
+
Tracks expert utilization across MoE layers.
|
| 1870 |
+
|
| 1871 |
+
Attach to any MoE layer to log per-expert usage histograms.
|
| 1872 |
+
Every `report_interval` steps, prints a report showing:
|
| 1873 |
+
- Frequency of use per expert
|
| 1874 |
+
- Cold experts (used < 1% of tokens)
|
| 1875 |
+
- Count of experts offloaded to CPU (if ExpertOffloadManager is available)
|
| 1876 |
+
|
| 1877 |
+
Usage:
|
| 1878 |
+
tracker = ExpertUtilizationTracker(num_experts=8, layer_name="layer.3.moe")
|
| 1879 |
+
# In forward: tracker.record(top_k_indices)
|
| 1880 |
+
# Every N steps: tracker.step() (auto-prints when interval hit)
|
| 1881 |
+
"""
|
| 1882 |
+
|
| 1883 |
+
def __init__(
|
| 1884 |
+
self,
|
| 1885 |
+
num_experts: int,
|
| 1886 |
+
layer_name: str = "moe",
|
| 1887 |
+
report_interval: int = 100,
|
| 1888 |
+
cold_threshold_pct: float = 1.0,
|
| 1889 |
+
):
|
| 1890 |
+
self.num_experts = num_experts
|
| 1891 |
+
self.layer_name = layer_name
|
| 1892 |
+
self.report_interval = report_interval
|
| 1893 |
+
self.cold_threshold_pct = cold_threshold_pct
|
| 1894 |
+
|
| 1895 |
+
self._counts = torch.zeros(num_experts, dtype=torch.long)
|
| 1896 |
+
self._total_tokens = 0
|
| 1897 |
+
self._step = 0
|
| 1898 |
+
self._offload_manager = None # Link to ExpertOffloadManager if available
|
| 1899 |
+
|
| 1900 |
+
def link_offload_manager(self, manager):
|
| 1901 |
+
"""Link an ExpertOffloadManager for cold-expert reporting."""
|
| 1902 |
+
self._offload_manager = manager
|
| 1903 |
+
|
| 1904 |
+
def record(self, expert_indices: torch.Tensor):
|
| 1905 |
+
"""
|
| 1906 |
+
Record expert selections from a forward pass.
|
| 1907 |
+
|
| 1908 |
+
Args:
|
| 1909 |
+
expert_indices: [num_tokens, top_k] tensor of selected expert indices
|
| 1910 |
+
"""
|
| 1911 |
+
indices_flat = expert_indices.detach().cpu().reshape(-1)
|
| 1912 |
+
for idx in range(self.num_experts):
|
| 1913 |
+
self._counts[idx] += (indices_flat == idx).sum().item()
|
| 1914 |
+
self._total_tokens += expert_indices.shape[0]
|
| 1915 |
+
|
| 1916 |
+
def step(self):
    """Advance step counter. Prints report and resets when interval is hit."""
    self._step += 1
    if self._step % self.report_interval != 0:
        return
    # Interval reached: emit the report and start a fresh window.
    self._print_report()
    self._reset()
|
| 1922 |
+
|
| 1923 |
+
def _reset(self):
    """Clear accumulated counts ahead of the next reporting window."""
    self._counts.zero_()
    self._total_tokens = 0
|
| 1927 |
+
|
| 1928 |
+
def _print_report(self):
    """Print expert utilization histogram."""
    if self._total_tokens == 0:
        return

    counts_f = self._counts.float()
    total_assignments = counts_f.sum().item()
    if total_assignments == 0:
        return

    pcts = (counts_f / total_assignments * 100).tolist()
    cold_experts = [i for i, p in enumerate(pcts) if p < self.cold_threshold_pct]

    bar_max = 30  # widest histogram bar, in characters
    max_pct = max(pcts) if pcts else 0

    # Header block
    lines = [
        f"\n{'='*60}",
        f" Expert Utilization — {self.layer_name} (step {self._step})",
        f" {self._total_tokens:,} tokens, {int(total_assignments):,} assignments",
        f"{'─'*60}",
    ]

    # One histogram row per expert, scaled so the busiest expert fills the bar
    for i, pct in enumerate(pcts):
        bar_len = int(pct / max_pct * bar_max) if max_pct > 0 else 0
        bar = "█" * bar_len
        cold_tag = " ❄️" if pct < self.cold_threshold_pct else ""
        lines.append(f" Expert {i:2d} │{bar:<{bar_max}}│ {pct:5.1f}% ({int(self._counts[i]):>6d}){cold_tag}")

    lines.append(f"{'─'*60}")

    # Cold-expert summary
    if cold_experts:
        lines.append(f" ❄️ Cold experts (<{self.cold_threshold_pct}%): {cold_experts}")
    else:
        lines.append(f" ✅ All experts active (no cold experts)")

    # CPU-offload summary, only when an offload manager was linked
    if self._offload_manager is not None:
        status = self._offload_manager.get_status()
        lines.append(f" 💾 Offloaded to CPU: {status['cpu']}/{status['total']}")

    # Load balance: 1.0 means every expert handled an equal share
    ideal_pct = 100.0 / self.num_experts
    balance = 1.0 - (sum(abs(p - ideal_pct) for p in pcts) / (2 * 100))
    lines.append(f" ⚖️ Load balance score: {balance:.3f} (1.0 = perfect)")

    lines.append(f"{'='*60}")
    print("\n".join(lines))
|
| 1977 |
+
|
| 1978 |
+
def get_stats(self) -> dict:
    """Return current stats as a dict (for programmatic access)."""
    assignments = self._counts.sum().item()
    if assignments == 0:
        pcts = [0.0] * self.num_experts
        balance = 0.0
    else:
        pcts = (self._counts.float() / assignments * 100).tolist()
        even_share = 100.0 / self.num_experts
        # 1.0 = perfectly even split across experts
        balance = 1.0 - sum(abs(p - even_share) for p in pcts) / (2 * 100)

    cold = [e for e, p in enumerate(pcts) if p < self.cold_threshold_pct]

    return {
        "step": self._step,
        "layer_name": self.layer_name,
        "total_tokens": self._total_tokens,
        "expert_counts": self._counts.tolist(),
        "expert_pcts": pcts,
        "cold_experts": cold,
        "balance_score": balance,
    }
|
| 1999 |
+
|
| 2000 |
+
|
| 2001 |
+
def attach_utilization_trackers(
    model: torch.nn.Module,
    report_interval: int = 100,
) -> list:
    """
    Find all MoE layers in a model and attach ExpertUtilizationTrackers.

    Returns list of trackers for manual step() calls in the training loop.
    """
    trackers = []
    for module_name, submodule in model.named_modules():
        # Heuristic for an MoE layer: it owns both `experts` and a `router`.
        if not (hasattr(submodule, 'experts') and hasattr(submodule, 'router')):
            continue

        tracker = ExpertUtilizationTracker(
            num_experts=len(submodule.experts),
            layer_name=module_name,
            report_interval=report_interval,
        )
        # Wire up offload reporting when the layer carries an offload manager
        if hasattr(submodule, '_expert_offload_manager'):
            tracker.link_offload_manager(submodule._expert_offload_manager)

        submodule._utilization_tracker = tracker
        trackers.append(tracker)

    if trackers:
        print(f" 📊 Attached {len(trackers)} expert utilization trackers (report every {report_interval} steps)")

    return trackers
|
| 2030 |
+
|
| 2031 |
+
|
| 2032 |
class MoERouter(nn.Module):
|
| 2033 |
"""
|
| 2034 |
SOTA Router for Mixture of Experts v2.0 - FP16 native.
|
|
|
|
| 2217 |
|
| 2218 |
top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
|
| 2219 |
|
| 2220 |
+
# Record expert utilization if tracker is attached
|
| 2221 |
+
if hasattr(self, '_utilization_tracker'):
|
| 2222 |
+
self._utilization_tracker.record(top_k_indices)
|
| 2223 |
+
|
| 2224 |
final_output = torch.zeros_like(hidden_flat)
|
| 2225 |
|
| 2226 |
for expert_idx in range(self.num_experts):
|
|
|
|
| 4911 |
|
| 4912 |
present_key_value = (key, value) if use_cache else None
|
| 4913 |
|
| 4914 |
+
# True GQA via SDPA — no repeat_interleave, O(N) memory, FP16-safe
|
| 4915 |
+
qk_scale = self.head_dim ** -0.25
|
| 4916 |
+
kv_len = key.shape[2]
|
| 4917 |
+
use_causal = (attention_mask is None and seq_len > 1 and seq_len == kv_len)
|
| 4918 |
+
|
| 4919 |
+
dropout_p = self.dropout.p if self.training else 0.0
|
| 4920 |
+
|
| 4921 |
+
output = F.scaled_dot_product_attention(
|
| 4922 |
+
query * qk_scale,
|
| 4923 |
+
key * qk_scale,
|
| 4924 |
+
value,
|
| 4925 |
+
attn_mask=attention_mask,
|
| 4926 |
+
is_causal=use_causal,
|
| 4927 |
+
dropout_p=dropout_p,
|
| 4928 |
+
scale=1.0,
|
| 4929 |
+
enable_gqa=(self.num_key_value_groups > 1),
|
| 4930 |
+
)
|
| 4931 |
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
| 4932 |
output = self.o_proj(output)
|
| 4933 |
|
|
|
|
| 6906 |
"""
|
| 6907 |
Symmetric Dual-Stream Self-Attention (SD3/Flux-style).
|
| 6908 |
Two parallel streams with cross-stream information exchange.
|
| 6909 |
+
Uses Flash Attention 2.0 via SDPA for O(N) memory.
|
| 6910 |
"""
|
| 6911 |
|
| 6912 |
def __init__(self, hidden_size: int, num_heads: int = 8, max_height: int = 64, max_width: int = 64):
|
|
|
|
| 6915 |
self.num_heads = num_heads
|
| 6916 |
self.head_dim = hidden_size // num_heads
|
| 6917 |
self.scale = self.head_dim ** -0.5
|
| 6918 |
+
# Pre-compute Q/K scaling for FP16 stability
|
| 6919 |
+
self._qk_scale = self.head_dim ** -0.25
|
| 6920 |
|
| 6921 |
self.to_qkv_a = nn.Linear(hidden_size, hidden_size * 3, bias=False)
|
| 6922 |
self.to_qkv_b = nn.Linear(hidden_size, hidden_size * 3, bias=False)
|
|
|
|
| 6942 |
q_b, k_b, v_b = qkv_b.unbind(dim=2)
|
| 6943 |
|
| 6944 |
cos, sin = self.rope_2d(x_a, height, width)
|
|
|
|
|
|
|
| 6945 |
cos = cos.unsqueeze(0).unsqueeze(1)
|
| 6946 |
sin = sin.unsqueeze(0).unsqueeze(1)
|
| 6947 |
|
|
|
|
| 6960 |
k_combined = torch.cat([k_a, k_b], dim=2)
|
| 6961 |
v_combined = torch.cat([v_a, v_b], dim=2)
|
| 6962 |
|
| 6963 |
+
# Flash Attention 2.0 via SDPA — O(N) memory, FP16-safe with pre-scaling
|
| 6964 |
+
out_a = F.scaled_dot_product_attention(
|
| 6965 |
+
q_a * self._qk_scale, k_combined * self._qk_scale, v_combined,
|
| 6966 |
+
is_causal=False, scale=1.0,
|
| 6967 |
+
)
|
| 6968 |
+
out_b = F.scaled_dot_product_attention(
|
| 6969 |
+
q_b * self._qk_scale, k_combined * self._qk_scale, v_combined,
|
| 6970 |
+
is_causal=False, scale=1.0,
|
| 6971 |
+
)
|
| 6972 |
|
| 6973 |
out_a = out_a.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
|
| 6974 |
out_b = out_b.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
|
|
|
|
| 7005 |
k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7006 |
v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7007 |
|
| 7008 |
+
# Flash Attention 2.0 via SDPA — O(N) memory, non-causal cross-attention
|
| 7009 |
+
qk_scale = self.head_dim ** -0.25
|
| 7010 |
+
out = F.scaled_dot_product_attention(
|
| 7011 |
+
q * qk_scale, k * qk_scale, v,
|
| 7012 |
+
is_causal=False, scale=1.0,
|
| 7013 |
+
)
|
| 7014 |
out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 7015 |
out = self.to_out(out)
|
| 7016 |
|
|
|
|
| 7176 |
self.final_norm = nn.LayerNorm(hidden_size)
|
| 7177 |
self.unpatch_embed = UnpatchEmbed(patch_size, out_channels, hidden_size)
|
| 7178 |
|
| 7179 |
+
self.gradient_checkpointing = False
|
| 7180 |
+
|
| 7181 |
self._init_weights()
|
| 7182 |
|
| 7183 |
def _init_weights(self):
|
| 7184 |
nn.init.zeros_(self.unpatch_embed.proj.weight)
|
| 7185 |
nn.init.zeros_(self.unpatch_embed.proj.bias)
|
| 7186 |
|
| 7187 |
+
def enable_gradient_checkpointing(self):
|
| 7188 |
+
"""Enable gradient checkpointing for memory efficiency."""
|
| 7189 |
+
self.gradient_checkpointing = True
|
| 7190 |
+
|
| 7191 |
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 7192 |
batch_size, channels, height, width = x.shape
|
| 7193 |
patch_height = height // self.patch_size
|
|
|
|
| 7208 |
x_b = x_patches.clone()
|
| 7209 |
|
| 7210 |
for block in self.blocks:
|
| 7211 |
+
if self.gradient_checkpointing and self.training:
|
| 7212 |
+
x_a, x_b = torch.utils.checkpoint.checkpoint(
|
| 7213 |
+
block, x_a, x_b, context_proj, t_emb, patch_height, patch_width,
|
| 7214 |
+
use_reentrant=False
|
| 7215 |
+
)
|
| 7216 |
+
else:
|
| 7217 |
+
x_a, x_b = block(x_a, x_b, context_proj, t_emb, patch_height, patch_width)
|
| 7218 |
|
| 7219 |
x_combined = (x_a + x_b) / 2
|
| 7220 |
x_combined = self.final_norm(x_combined)
|
|
|
|
| 7722 |
q = apply_rope(q, cos, sin)
|
| 7723 |
k = apply_rope(k, cos, sin)
|
| 7724 |
|
| 7725 |
+
# Flash Attention 2.0 via SDPA — O(N) memory, non-causal spatial attention
|
| 7726 |
+
qk_scale = self.head_dim ** -0.25
|
| 7727 |
+
out = F.scaled_dot_product_attention(
|
| 7728 |
+
q * qk_scale, k * qk_scale, v,
|
| 7729 |
+
is_causal=False, scale=1.0,
|
| 7730 |
+
)
|
| 7731 |
out = out.transpose(1, 2).reshape(batch_size * frames, spatial_len, self.hidden_size)
|
| 7732 |
out = self.to_out(out)
|
| 7733 |
|
|
|
|
| 7778 |
q = apply_rope(q, cos, sin)
|
| 7779 |
k = apply_rope(k, cos, sin)
|
| 7780 |
|
| 7781 |
+
# Flash Attention 2.0 via SDPA — causal temporal attention
|
| 7782 |
+
qk_scale = self.head_dim ** -0.25
|
| 7783 |
+
out = F.scaled_dot_product_attention(
|
| 7784 |
+
q * qk_scale, k * qk_scale, v,
|
| 7785 |
+
is_causal=causal, scale=1.0,
|
| 7786 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7787 |
out = out.transpose(1, 2).reshape(batch_size * spatial_len, frames, self.hidden_size)
|
| 7788 |
|
| 7789 |
# Reshape back to [B, T*H*W, hidden]
|
|
|
|
| 7848 |
k = self.to_k(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7849 |
v = self.to_v(context).reshape(batch_size, ctx_len, self.heads, self.head_dim).transpose(1, 2)
|
| 7850 |
|
| 7851 |
+
# Flash Attention 2.0 via SDPA — non-causal cross-attention
|
| 7852 |
+
qk_scale = self.head_dim ** -0.25
|
| 7853 |
+
out = F.scaled_dot_product_attention(
|
| 7854 |
+
q * qk_scale, k * qk_scale, v,
|
| 7855 |
+
is_causal=False, scale=1.0,
|
| 7856 |
+
)
|
| 7857 |
out = out.transpose(1, 2).reshape(batch_size, seq_len, -1)
|
| 7858 |
out = self.to_out(out)
|
| 7859 |
|
|
|
|
| 7971 |
|
| 7972 |
nn.init.zeros_(self.output_proj[-1].weight)
|
| 7973 |
nn.init.zeros_(self.output_proj[-1].bias)
|
| 7974 |
+
|
| 7975 |
+
self.gradient_checkpointing = False
|
| 7976 |
+
|
| 7977 |
+
def enable_gradient_checkpointing(self):
|
| 7978 |
+
"""Enable gradient checkpointing for memory efficiency."""
|
| 7979 |
+
self.gradient_checkpointing = True
|
| 7980 |
|
| 7981 |
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor, first_frame_latent: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 7982 |
batch_size, channels, frames, height, width = x.shape
|
|
|
|
| 7995 |
temporal_context = t_emb.unsqueeze(1).expand(-1, frames * height * width, -1)
|
| 7996 |
|
| 7997 |
for block in self.transformer_blocks:
|
| 7998 |
+
if self.gradient_checkpointing and self.training:
|
| 7999 |
+
h = torch.utils.checkpoint.checkpoint(
|
| 8000 |
+
block, h, context, height, width, frames, temporal_context,
|
| 8001 |
+
use_reentrant=False
|
| 8002 |
+
)
|
| 8003 |
+
else:
|
| 8004 |
+
h = block(h, context, height, width, frames, temporal_context)
|
| 8005 |
|
| 8006 |
h = h.reshape(batch_size, frames, height, width, self.hidden_size).permute(0, 4, 1, 2, 3)
|
| 8007 |
|
|
|
|
| 8475 |
return q_embed, k_embed
|
| 8476 |
|
| 8477 |
|
|
|
|
| 8478 |
class KVCache:
|
| 8479 |
+
"""Pre-allocated KV Cache — static buffer with index-based filling.
|
| 8480 |
+
|
| 8481 |
+
Eliminates VRAM fragmentation from torch.cat during autoregressive generation.
|
| 8482 |
+
Buffer is allocated once at first use and reused via slice assignment.
|
| 8483 |
+
"""
|
| 8484 |
+
|
| 8485 |
+
__slots__ = ('key_cache', 'value_cache', 'seen_tokens', '_max_len')
|
| 8486 |
+
|
| 8487 |
+
def __init__(
    self,
    key_cache: Optional[torch.Tensor] = None,
    value_cache: Optional[torch.Tensor] = None,
    seen_tokens: int = 0,
    max_seq_len: int = 131072,
):
    """Create a (possibly empty) pre-allocated KV cache.

    Args:
        key_cache: Existing key buffer [batch, heads, max_len, head_dim], or
            None to allocate lazily on the first update() call.
        value_cache: Existing value buffer matching key_cache, or None.
        seen_tokens: Number of valid tokens already stored in the buffers.
        max_seq_len: Capacity of the lazily allocated static buffer.
    """
    # PEP 484: a parameter defaulting to None is Optional[...], not a bare
    # Tensor annotation (implicit Optional is disallowed by type checkers).
    self.key_cache = key_cache
    self.value_cache = value_cache
    self.seen_tokens = seen_tokens
    self._max_len = max_seq_len
|
| 8498 |
+
|
| 8499 |
+
def _allocate(self, batch: int, heads: int, head_dim: int, device: torch.device, dtype: torch.dtype):
    """Allocate the static key/value buffers (done once, on first use)."""
    buf_shape = (batch, heads, self._max_len, head_dim)
    self.key_cache = torch.zeros(buf_shape, device=device, dtype=dtype)
    self.value_cache = torch.zeros(buf_shape, device=device, dtype=dtype)
|
| 8503 |
|
| 8504 |
def update(
|
| 8505 |
self,
|
|
|
|
| 8507 |
value_states: torch.Tensor,
|
| 8508 |
chunk_size: Optional[int] = None,
|
| 8509 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 8510 |
+
batch, heads, new_len, head_dim = key_states.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8511 |
|
| 8512 |
+
if self.key_cache is None:
|
| 8513 |
+
# Allocate static buffer on first call
|
| 8514 |
+
self._allocate(batch, heads, head_dim, key_states.device, key_states.dtype)
|
| 8515 |
+
self.seen_tokens = 0
|
| 8516 |
+
|
| 8517 |
+
# Check if we need to apply chunk windowing
|
| 8518 |
+
if chunk_size is not None and self.seen_tokens + new_len > chunk_size * 2:
|
| 8519 |
+
# Shift: keep only the last chunk_size tokens, then append
|
| 8520 |
+
keep = chunk_size
|
| 8521 |
+
if self.seen_tokens > keep:
|
| 8522 |
+
self.key_cache[:, :, :keep] = self.key_cache[:, :, self.seen_tokens - keep:self.seen_tokens].clone()
|
| 8523 |
+
self.value_cache[:, :, :keep] = self.value_cache[:, :, self.seen_tokens - keep:self.seen_tokens].clone()
|
| 8524 |
+
self.seen_tokens = keep
|
| 8525 |
+
|
| 8526 |
+
# Grow buffer if needed (rare fallback — avoids crash on very long sequences)
|
| 8527 |
+
if self.seen_tokens + new_len > self.key_cache.shape[2]:
|
| 8528 |
+
new_max = max(self.key_cache.shape[2] * 2, self.seen_tokens + new_len)
|
| 8529 |
+
new_key = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
|
| 8530 |
+
new_val = torch.zeros(batch, heads, new_max, head_dim, device=key_states.device, dtype=key_states.dtype)
|
| 8531 |
+
new_key[:, :, :self.seen_tokens] = self.key_cache[:, :, :self.seen_tokens]
|
| 8532 |
+
new_val[:, :, :self.seen_tokens] = self.value_cache[:, :, :self.seen_tokens]
|
| 8533 |
+
self.key_cache = new_key
|
| 8534 |
+
self.value_cache = new_val
|
| 8535 |
+
|
| 8536 |
+
# Index-based fill — no allocation, no fragmentation
|
| 8537 |
+
self.key_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = key_states
|
| 8538 |
+
self.value_cache[:, :, self.seen_tokens:self.seen_tokens + new_len] = value_states
|
| 8539 |
+
self.seen_tokens += new_len
|
| 8540 |
+
|
| 8541 |
+
# Return only the valid slice (view, no copy)
|
| 8542 |
+
return self.key_cache[:, :, :self.seen_tokens], self.value_cache[:, :, :self.seen_tokens]
|
| 8543 |
|
| 8544 |
+
def reset(self):
    """Rewind the write position; keeps the allocated buffers for reuse."""
    self.seen_tokens = 0
|
| 8547 |
|
| 8548 |
|
| 8549 |
def ring_attention(
|
|
|
|
| 8555 |
) -> torch.Tensor:
|
| 8556 |
"""
|
| 8557 |
Ring Attention for distributed long-context processing.
|
| 8558 |
+
Processes sequence in chunks with online softmax accumulation.
|
| 8559 |
|
| 8560 |
Args:
|
| 8561 |
query: [batch, heads, seq_len, head_dim]
|
|
|
|
| 8569 |
"""
|
| 8570 |
batch_size, num_heads, seq_len, head_dim = query.shape
|
| 8571 |
kv_len = key.shape[2]
|
|
|
|
| 8572 |
|
| 8573 |
+
# Short path: use SDPA directly for small sequences
|
| 8574 |
if seq_len <= chunk_size and kv_len <= chunk_size:
|
| 8575 |
+
qk_scale = head_dim ** -0.25
|
| 8576 |
+
use_causal = causal and seq_len == kv_len and seq_len > 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8577 |
|
| 8578 |
+
if use_causal:
|
| 8579 |
+
return F.scaled_dot_product_attention(
|
| 8580 |
+
query * qk_scale, key * qk_scale, value,
|
| 8581 |
+
is_causal=True, scale=1.0,
|
| 8582 |
+
)
|
| 8583 |
+
elif causal and kv_len > seq_len:
|
| 8584 |
+
# KV cache case: build explicit causal mask
|
| 8585 |
+
causal_mask = torch.zeros(seq_len, kv_len, device=query.device, dtype=query.dtype)
|
| 8586 |
+
q_pos = torch.arange(seq_len, device=query.device) + (kv_len - seq_len)
|
| 8587 |
+
k_pos = torch.arange(kv_len, device=query.device)
|
| 8588 |
+
causal_mask = torch.where(k_pos.unsqueeze(0) > q_pos.unsqueeze(1), float('-inf'), 0.0)
|
| 8589 |
+
return F.scaled_dot_product_attention(
|
| 8590 |
+
query * qk_scale, key * qk_scale, value,
|
| 8591 |
+
attn_mask=causal_mask, scale=1.0,
|
| 8592 |
+
)
|
| 8593 |
+
else:
|
| 8594 |
+
return F.scaled_dot_product_attention(
|
| 8595 |
+
query * qk_scale, key * qk_scale, value,
|
| 8596 |
+
is_causal=False, scale=1.0,
|
| 8597 |
+
)
|
| 8598 |
|
| 8599 |
+
# Long path: chunked attention with online softmax (FlashAttention-style)
|
| 8600 |
+
scale = head_dim ** -0.5
|
| 8601 |
output = torch.zeros_like(query)
|
| 8602 |
max_logits = torch.full((batch_size, num_heads, seq_len, 1), float('-inf'), device=query.device, dtype=query.dtype)
|
| 8603 |
sum_exp = torch.zeros((batch_size, num_heads, seq_len, 1), device=query.device, dtype=query.dtype)
|
| 8604 |
|
| 8605 |
+
# Pre-compute query positions for vectorized causal masking
|
| 8606 |
+
if causal:
|
| 8607 |
+
q_positions = torch.arange(seq_len, device=query.device)
|
| 8608 |
+
if kv_len > seq_len:
|
| 8609 |
+
q_positions = q_positions + (kv_len - seq_len)
|
| 8610 |
+
|
| 8611 |
num_kv_chunks = (kv_len + chunk_size - 1) // chunk_size
|
| 8612 |
|
| 8613 |
for kv_idx in range(num_kv_chunks):
|
|
|
|
| 8620 |
attn_chunk = torch.matmul(query, key_chunk.transpose(-1, -2)) * scale
|
| 8621 |
|
| 8622 |
if causal:
|
| 8623 |
+
# Vectorized causal mask — replaces O(n²) nested Python loop
|
| 8624 |
+
k_positions = torch.arange(kv_start, kv_end, device=query.device)
|
| 8625 |
+
# mask[i, j] = True where k_pos[j] > q_pos[i] (future tokens)
|
| 8626 |
+
causal_mask = k_positions.unsqueeze(0) > q_positions.unsqueeze(1) # [seq_len, chunk_len]
|
| 8627 |
+
attn_chunk = attn_chunk.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
|
|
|
|
|
|
|
| 8628 |
|
| 8629 |
chunk_max = attn_chunk.max(dim=-1, keepdim=True)[0]
|
| 8630 |
new_max = torch.maximum(max_logits, chunk_max)
|
|
|
|
| 8741 |
self.ring_chunk_size if self.use_ring_attention else None
|
| 8742 |
)
|
| 8743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8744 |
if self.use_ring_attention:
|
| 8745 |
+
# Ring attention needs matched head counts — expand KV heads
|
| 8746 |
+
if self.num_key_value_groups > 1:
|
| 8747 |
+
key_expanded = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 8748 |
+
value_expanded = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
|
| 8749 |
+
else:
|
| 8750 |
+
key_expanded = key_states
|
| 8751 |
+
value_expanded = value_states
|
| 8752 |
attn_output = ring_attention(
|
| 8753 |
+
query_states, key_expanded, value_expanded,
|
| 8754 |
chunk_size=self.ring_chunk_size,
|
| 8755 |
causal=True,
|
| 8756 |
)
|
| 8757 |
else:
|
| 8758 |
+
# True GQA via SDPA — no repeat_interleave, O(N) memory
|
| 8759 |
+
# SDPA natively handles N query heads with M KV heads via enable_gqa
|
| 8760 |
+
qk_scale = self.head_dim ** -0.25
|
| 8761 |
kv_len = key_states.shape[2]
|
| 8762 |
+
use_causal = (attention_mask is None and seq_len > 1 and seq_len == kv_len)
|
| 8763 |
+
|
| 8764 |
+
attn_output = F.scaled_dot_product_attention(
|
| 8765 |
+
query_states * qk_scale,
|
| 8766 |
+
key_states * qk_scale,
|
| 8767 |
+
value_states,
|
| 8768 |
+
attn_mask=attention_mask,
|
| 8769 |
+
is_causal=use_causal,
|
| 8770 |
+
scale=1.0, # Already scaled Q and K
|
| 8771 |
+
enable_gqa=(self.num_key_value_groups > 1),
|
| 8772 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8773 |
|
| 8774 |
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
|
| 8775 |
attn_output = self.o_proj(attn_output)
|
|
|
|
| 8798 |
self.input_norm = LlamaRMSNorm(hidden_size)
|
| 8799 |
self.gate = nn.Linear(hidden_size, num_experts, bias=False)
|
| 8800 |
nn.init.normal_(self.gate.weight, mean=0.0, std=0.01)
|
| 8801 |
+
|
| 8802 |
+
# DeepSeek-style expert bias for aux-lossless load balancing
|
| 8803 |
+
# This learnable bias steers token routing to underutilized experts
|
| 8804 |
+
# without requiring an auxiliary loss term
|
| 8805 |
+
self.expert_bias = nn.Parameter(torch.zeros(num_experts))
|
| 8806 |
|
| 8807 |
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 8808 |
batch_size, seq_len, hidden_dim = hidden_states.shape
|
|
|
|
| 8811 |
hidden_norm = self.input_norm(hidden_flat)
|
| 8812 |
router_logits = self.gate(hidden_norm)
|
| 8813 |
|
| 8814 |
+
# Add expert bias for load balancing (aux-lossless mechanism)
|
| 8815 |
+
biased_logits = router_logits + self.expert_bias
|
| 8816 |
+
|
| 8817 |
+
router_probs = F.softmax(biased_logits, dim=-1, dtype=hidden_states.dtype)
|
| 8818 |
|
| 8819 |
top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| 8820 |
|
|
|
|
| 8902 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
| 8903 |
original_dtype = hidden_states.dtype
|
| 8904 |
hidden_flat = hidden_states.view(-1, hidden_size)
|
| 8905 |
+
num_tokens = hidden_flat.shape[0]
|
| 8906 |
|
| 8907 |
+
top_k_probs, top_k_indices, router_logits = self.router(hidden_states)
|
| 8908 |
+
|
| 8909 |
+
# Record expert utilization if tracker is attached
|
| 8910 |
+
if hasattr(self, '_utilization_tracker'):
|
| 8911 |
+
self._utilization_tracker.record(top_k_indices)
|
| 8912 |
|
| 8913 |
final_output = torch.zeros_like(hidden_flat)
|
| 8914 |
|
|
|
|
| 8928 |
|
| 8929 |
final_output = final_output.view(batch_size, seq_len, hidden_size)
|
| 8930 |
|
| 8931 |
+
# Aux-lossless: z-loss only for router logit stability
|
| 8932 |
+
# The expert_bias in the router handles load balancing architecturally
|
| 8933 |
+
aux_loss = self._compute_aux_loss(router_logits, top_k_indices, num_tokens)
|
| 8934 |
|
| 8935 |
return final_output, aux_loss
|
| 8936 |
|
| 8937 |
+
def _compute_aux_loss(
    self,
    router_logits: torch.Tensor,
    top_k_indices: torch.Tensor,
    num_tokens: int,
) -> torch.Tensor:
    """
    Aux-lossless auxiliary loss.

    Combines a z-loss that keeps router logits bounded (FP16 stability) with
    a very soft penalty that only fires when some expert received zero tokens.
    Routine load balancing is handled architecturally by the router's
    expert_bias, so no classic load-balancing aux loss is needed.

    Note: num_tokens is accepted for interface stability but unused here.
    """
    # z-loss: discourages unbounded router logit growth in FP16
    lse = torch.logsumexp(router_logits, dim=-1)
    z_loss = lse.square().mean() * 0.0001

    # Dead-expert penalty: zero unless at least one expert got no tokens.
    # Activating only at this extreme avoids hurting convergence.
    selections = F.one_hot(top_k_indices, self.num_experts).float()
    per_expert = selections.sum(dim=(0, 1))  # tokens routed to each expert
    alive_fraction = (per_expert > 0).float().mean()
    utilization_loss = (1.0 - alive_fraction) * 0.01

    return z_loss + utilization_loss
|
| 8961 |
+
|
| 8962 |
|
| 8963 |
MoELayer = AuxLosslessMoELayer
|
| 8964 |
|
streaming_state.json
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
{
|
| 2 |
-
"epoch":
|
| 3 |
-
"unique_samples":
|
| 4 |
-
"total_yields":
|
| 5 |
"dataset_positions": {
|
| 6 |
"WebSight": 386,
|
| 7 |
"ScienceQA": 364,
|
|
@@ -18,16 +18,16 @@
|
|
| 18 |
"CodeParrot-Clean": 200,
|
| 19 |
"ShareGPT-Clean": 200,
|
| 20 |
"Synth-Issues": 200,
|
| 21 |
-
"Dolly-15k":
|
| 22 |
-
"Conversation-Summarization":
|
| 23 |
"Synth-ShellTimeout": 200,
|
| 24 |
"Synth-Docker": 200,
|
| 25 |
-
"Synth-Documents":
|
| 26 |
"HumanEval-JavaScript": 164,
|
| 27 |
-
"OpenOrca":
|
| 28 |
"Synth-MultiStepExecution": 200,
|
| 29 |
"Synth-Citation": 200,
|
| 30 |
-
"NoRobots":
|
| 31 |
"Synth-LanguageSetup": 200,
|
| 32 |
"Function-Calling-ChatML": 200,
|
| 33 |
"Synth-CoT": 200,
|
|
@@ -75,7 +75,7 @@
|
|
| 75 |
"Synth-Debugging": 200,
|
| 76 |
"Tool-Calls-SingleTurn": 200,
|
| 77 |
"Tool-Calls-Multiturn": 200,
|
| 78 |
-
"OpenAssistant":
|
| 79 |
"T2V-Sora-Preferences-2": 200,
|
| 80 |
"T2V-Human-Preferences": 200,
|
| 81 |
"Sora-Alignment-Likert": 198,
|
|
@@ -84,7 +84,9 @@
|
|
| 84 |
"WebVid-10M": 200,
|
| 85 |
"Sora-Physics-Likert": 198,
|
| 86 |
"TIP-I2V": 200,
|
| 87 |
-
"Pexels-I2V-350k": 200
|
|
|
|
|
|
|
| 88 |
},
|
| 89 |
"modality_positions": {
|
| 90 |
"text": {
|
|
@@ -92,11 +94,11 @@
|
|
| 92 |
"Midjourney-Prompts": 200,
|
| 93 |
"CodeParrot-Clean": 200,
|
| 94 |
"ShareGPT-Clean": 200,
|
| 95 |
-
"Dolly-15k":
|
| 96 |
-
"Conversation-Summarization":
|
| 97 |
"HumanEval-JavaScript": 164,
|
| 98 |
-
"OpenOrca":
|
| 99 |
-
"NoRobots":
|
| 100 |
"Function-Calling-ChatML": 200,
|
| 101 |
"Python-Code-18k": 200,
|
| 102 |
"Code-Feedback": 200,
|
|
@@ -119,7 +121,9 @@
|
|
| 119 |
"HumanEval-Rust": 164,
|
| 120 |
"Tool-Calls-SingleTurn": 200,
|
| 121 |
"Tool-Calls-Multiturn": 200,
|
| 122 |
-
"OpenAssistant":
|
|
|
|
|
|
|
| 123 |
},
|
| 124 |
"image": {
|
| 125 |
"WebSight": 386,
|
|
@@ -144,9 +148,9 @@
|
|
| 144 |
"audio": {}
|
| 145 |
},
|
| 146 |
"modality_counts": {
|
| 147 |
-
"text":
|
| 148 |
"image": 0,
|
| 149 |
-
"video":
|
| 150 |
"audio": 0
|
| 151 |
},
|
| 152 |
"last_modality": null
|
|
|
|
| 1 |
{
|
| 2 |
+
"epoch": 35,
|
| 3 |
+
"unique_samples": 400,
|
| 4 |
+
"total_yields": 800,
|
| 5 |
"dataset_positions": {
|
| 6 |
"WebSight": 386,
|
| 7 |
"ScienceQA": 364,
|
|
|
|
| 18 |
"CodeParrot-Clean": 200,
|
| 19 |
"ShareGPT-Clean": 200,
|
| 20 |
"Synth-Issues": 200,
|
| 21 |
+
"Dolly-15k": 450,
|
| 22 |
+
"Conversation-Summarization": 450,
|
| 23 |
"Synth-ShellTimeout": 200,
|
| 24 |
"Synth-Docker": 200,
|
| 25 |
+
"Synth-Documents": 450,
|
| 26 |
"HumanEval-JavaScript": 164,
|
| 27 |
+
"OpenOrca": 450,
|
| 28 |
"Synth-MultiStepExecution": 200,
|
| 29 |
"Synth-Citation": 200,
|
| 30 |
+
"NoRobots": 450,
|
| 31 |
"Synth-LanguageSetup": 200,
|
| 32 |
"Function-Calling-ChatML": 200,
|
| 33 |
"Synth-CoT": 200,
|
|
|
|
| 75 |
"Synth-Debugging": 200,
|
| 76 |
"Tool-Calls-SingleTurn": 200,
|
| 77 |
"Tool-Calls-Multiturn": 200,
|
| 78 |
+
"OpenAssistant": 450,
|
| 79 |
"T2V-Sora-Preferences-2": 200,
|
| 80 |
"T2V-Human-Preferences": 200,
|
| 81 |
"Sora-Alignment-Likert": 198,
|
|
|
|
| 84 |
"WebVid-10M": 200,
|
| 85 |
"Sora-Physics-Likert": 198,
|
| 86 |
"TIP-I2V": 200,
|
| 87 |
+
"Pexels-I2V-350k": 200,
|
| 88 |
+
"SmolTalk-OpenHermes": 250,
|
| 89 |
+
"SmolTalk-All": 250
|
| 90 |
},
|
| 91 |
"modality_positions": {
|
| 92 |
"text": {
|
|
|
|
| 94 |
"Midjourney-Prompts": 200,
|
| 95 |
"CodeParrot-Clean": 200,
|
| 96 |
"ShareGPT-Clean": 200,
|
| 97 |
+
"Dolly-15k": 450,
|
| 98 |
+
"Conversation-Summarization": 450,
|
| 99 |
"HumanEval-JavaScript": 164,
|
| 100 |
+
"OpenOrca": 450,
|
| 101 |
+
"NoRobots": 450,
|
| 102 |
"Function-Calling-ChatML": 200,
|
| 103 |
"Python-Code-18k": 200,
|
| 104 |
"Code-Feedback": 200,
|
|
|
|
| 121 |
"HumanEval-Rust": 164,
|
| 122 |
"Tool-Calls-SingleTurn": 200,
|
| 123 |
"Tool-Calls-Multiturn": 200,
|
| 124 |
+
"OpenAssistant": 450,
|
| 125 |
+
"SmolTalk-OpenHermes": 250,
|
| 126 |
+
"SmolTalk-All": 250
|
| 127 |
},
|
| 128 |
"image": {
|
| 129 |
"WebSight": 386,
|
|
|
|
| 148 |
"audio": {}
|
| 149 |
},
|
| 150 |
"modality_counts": {
|
| 151 |
+
"text": 400,
|
| 152 |
"image": 0,
|
| 153 |
+
"video": 0,
|
| 154 |
"audio": 0
|
| 155 |
},
|
| 156 |
"last_modality": null
|
trainer_state.json
CHANGED
|
@@ -1,32 +1,32 @@
|
|
| 1 |
{
|
| 2 |
"best_model_checkpoint": "/kaggle/working/xoron-final",
|
| 3 |
-
"best_metric":
|
| 4 |
-
"epoch":
|
| 5 |
-
"epochs_completed":
|
| 6 |
-
"global_step":
|
| 7 |
"is_local_process_zero": true,
|
| 8 |
"is_world_process_zero": true,
|
| 9 |
"log_history": [],
|
| 10 |
"logging_steps": 50,
|
| 11 |
-
"max_steps":
|
| 12 |
-
"num_train_epochs":
|
| 13 |
"total_flos": 0,
|
| 14 |
"train_batch_size": 1,
|
| 15 |
"effective_batch_size": 16,
|
| 16 |
"learning_rate": 0.0001,
|
| 17 |
"max_grad_norm": 1.0,
|
| 18 |
"trainable_components": [
|
| 19 |
-
"vision",
|
| 20 |
-
"video",
|
| 21 |
"llm",
|
| 22 |
"cross_attention",
|
| 23 |
-
"video_generation",
|
| 24 |
"modality_markers"
|
| 25 |
],
|
| 26 |
"frozen_components": [
|
|
|
|
|
|
|
| 27 |
"audio",
|
| 28 |
"speech",
|
| 29 |
-
"image_generation"
|
|
|
|
| 30 |
],
|
| 31 |
"trial_name": null,
|
| 32 |
"trial_params": null
|
|
|
|
| 1 |
{
|
| 2 |
"best_model_checkpoint": "/kaggle/working/xoron-final",
|
| 3 |
+
"best_metric": 6.958861378133297,
|
| 4 |
+
"epoch": 5,
|
| 5 |
+
"epochs_completed": 5,
|
| 6 |
+
"global_step": 250,
|
| 7 |
"is_local_process_zero": true,
|
| 8 |
"is_world_process_zero": true,
|
| 9 |
"log_history": [],
|
| 10 |
"logging_steps": 50,
|
| 11 |
+
"max_steps": 250,
|
| 12 |
+
"num_train_epochs": 5,
|
| 13 |
"total_flos": 0,
|
| 14 |
"train_batch_size": 1,
|
| 15 |
"effective_batch_size": 16,
|
| 16 |
"learning_rate": 0.0001,
|
| 17 |
"max_grad_norm": 1.0,
|
| 18 |
"trainable_components": [
|
|
|
|
|
|
|
| 19 |
"llm",
|
| 20 |
"cross_attention",
|
|
|
|
| 21 |
"modality_markers"
|
| 22 |
],
|
| 23 |
"frozen_components": [
|
| 24 |
+
"vision",
|
| 25 |
+
"video",
|
| 26 |
"audio",
|
| 27 |
"speech",
|
| 28 |
+
"image_generation",
|
| 29 |
+
"video_generation"
|
| 30 |
],
|
| 31 |
"trial_name": null,
|
| 32 |
"trial_params": null
|
training_state.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a751ecf22021470154d58846b700d04286522c14cda7393ece31f907eff5a2c7
|
| 3 |
+
size 1514911851
|