Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "smithblack-0/SHRAM-dev" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "smithblack-0/SHRAM-dev",
        "prompt": "Once upon a time,",
        "max_tokens": 512,
        "temperature": 0.5
    }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files
- README.md +2 -2
- config.json +2 -2
- configuration.py +22 -22
- huggingface.py +301 -376
README.md
CHANGED
|
@@ -82,9 +82,10 @@ contains no weights. All values are overridable via kwargs.
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
-
| `load_balance_loss_type` |
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
|
|
|
| 88 |
| `mlp_width` | 1366 |
|
| 89 |
| `mosrah_overallocation_factor` | 2.0 |
|
| 90 |
| `mosrah_rope_theta` | 10000.0 |
|
|
@@ -95,7 +96,6 @@ contains no weights. All values are overridable via kwargs.
|
|
| 95 |
| `output_hidden_states` | False |
|
| 96 |
| `rms_norm_eps` | 1e-05 |
|
| 97 |
| `rope_mode` | main_sequence |
|
| 98 |
-
| `routing_mode` | integral |
|
| 99 |
| `tie_word_embeddings` | False |
|
| 100 |
| `training_sequence_length` | 1024 |
|
| 101 |
| `use_cache` | True |
|
|
|
|
| 82 |
| `embedding_width` | 512 |
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
+
| `load_balance_loss_type` | temporal_overcapacity |
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
| 88 |
+
| `maximum_expert_overclaim` | 20 |
|
| 89 |
| `mlp_width` | 1366 |
|
| 90 |
| `mosrah_overallocation_factor` | 2.0 |
|
| 91 |
| `mosrah_rope_theta` | 10000.0 |
|
|
|
|
| 96 |
| `output_hidden_states` | False |
|
| 97 |
| `rms_norm_eps` | 1e-05 |
|
| 98 |
| `rope_mode` | main_sequence |
|
|
|
|
| 99 |
| `tie_word_embeddings` | False |
|
| 100 |
| `training_sequence_length` | 1024 |
|
| 101 |
| `use_cache` | True |
|
config.json
CHANGED
|
@@ -9,9 +9,10 @@
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
-
"load_balance_loss_type": "
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
|
|
|
| 15 |
"mlp_width": 1366,
|
| 16 |
"model_type": "shram",
|
| 17 |
"mosrah_overallocation_factor": 2.0,
|
|
@@ -22,7 +23,6 @@
|
|
| 22 |
"num_sliding_window_heads": 16,
|
| 23 |
"rms_norm_eps": 1e-05,
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
-
"routing_mode": "integral",
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
"transformers_version": "5.10.2",
|
|
|
|
| 9 |
"embedding_width": 512,
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
+
"load_balance_loss_type": "temporal_overcapacity",
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
| 15 |
+
"maximum_expert_overclaim": 20,
|
| 16 |
"mlp_width": 1366,
|
| 17 |
"model_type": "shram",
|
| 18 |
"mosrah_overallocation_factor": 2.0,
|
|
|
|
| 23 |
"num_sliding_window_heads": 16,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"rope_mode": "main_sequence",
|
|
|
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
"transformers_version": "5.10.2",
|
configuration.py
CHANGED
|
@@ -91,17 +91,18 @@ class ShramConfig(PretrainedConfig):
|
|
| 91 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 92 |
Default 10.
|
| 93 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 94 |
-
One of ``"gshard"``, ``"ce"``,
|
| 95 |
-
is the default;
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
``"
|
|
|
|
| 105 |
"""
|
| 106 |
|
| 107 |
model_type = "shram"
|
|
@@ -136,8 +137,8 @@ class ShramConfig(PretrainedConfig):
|
|
| 136 |
tie_word_embeddings: bool = False,
|
| 137 |
mosrah_overallocation_factor: float = 2.0,
|
| 138 |
max_bid_rounds: int = 10,
|
| 139 |
-
load_balance_loss_type: str = "
|
| 140 |
-
|
| 141 |
**kwargs
|
| 142 |
):
|
| 143 |
if head_dim % 2 != 0:
|
|
@@ -178,7 +179,13 @@ class ShramConfig(PretrainedConfig):
|
|
| 178 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 179 |
)
|
| 180 |
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if load_balance_loss_type not in _supported_loss_types:
|
| 183 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 184 |
raise ValueError(
|
|
@@ -186,13 +193,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 186 |
f"got {load_balance_loss_type!r}."
|
| 187 |
)
|
| 188 |
|
| 189 |
-
_supported_routing_modes = {"default", "integral"}
|
| 190 |
-
if routing_mode not in _supported_routing_modes:
|
| 191 |
-
supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
|
| 192 |
-
raise ValueError(
|
| 193 |
-
f"routing_mode must be one of {supported}, got {routing_mode!r}."
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
self.vocab_size = vocab_size
|
| 197 |
self.embedding_width = embedding_width
|
| 198 |
self.mlp_width = mlp_width
|
|
@@ -213,7 +213,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 213 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 214 |
self.max_bid_rounds = max_bid_rounds
|
| 215 |
self.load_balance_loss_type = load_balance_loss_type
|
| 216 |
-
self.
|
| 217 |
self.attention_dropout = attention_dropout
|
| 218 |
self.use_cache = use_cache
|
| 219 |
|
|
|
|
| 91 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 92 |
Default 10.
|
| 93 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 94 |
+
One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
|
| 95 |
+
``"temporal_overcapacity"`` is the default; it fires only when an expert
|
| 96 |
+
exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
|
| 97 |
+
and shuts off automatically once routing is balanced, allowing it to be
|
| 98 |
+
used with a strong weight without interfering with task training during
|
| 99 |
+
balanced routing. Default ``"temporal_overcapacity"``.
|
| 100 |
+
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 101 |
+
its ideal allocation trajectory before the temporal overcapacity loss
|
| 102 |
+
fires. A value of 0 means violations trigger immediately at any imbalance.
|
| 103 |
+
Larger values permit short-lived semantic specialization before correction.
|
| 104 |
+
Only used when ``load_balance_loss_type="temporal_overcapacity"``.
|
| 105 |
+
Must be non-negative. Default 20.
|
| 106 |
"""
|
| 107 |
|
| 108 |
model_type = "shram"
|
|
|
|
| 137 |
tie_word_embeddings: bool = False,
|
| 138 |
mosrah_overallocation_factor: float = 2.0,
|
| 139 |
max_bid_rounds: int = 10,
|
| 140 |
+
load_balance_loss_type: str = "temporal_overcapacity",
|
| 141 |
+
maximum_expert_overclaim: int = 20,
|
| 142 |
**kwargs
|
| 143 |
):
|
| 144 |
if head_dim % 2 != 0:
|
|
|
|
| 179 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 180 |
)
|
| 181 |
|
| 182 |
+
if maximum_expert_overclaim < 0:
|
| 183 |
+
raise ValueError(
|
| 184 |
+
f"maximum_expert_overclaim must be non-negative, "
|
| 185 |
+
f"got {maximum_expert_overclaim}."
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
|
| 189 |
if load_balance_loss_type not in _supported_loss_types:
|
| 190 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 191 |
raise ValueError(
|
|
|
|
| 193 |
f"got {load_balance_loss_type!r}."
|
| 194 |
)
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
self.vocab_size = vocab_size
|
| 197 |
self.embedding_width = embedding_width
|
| 198 |
self.mlp_width = mlp_width
|
|
|
|
| 213 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 214 |
self.max_bid_rounds = max_bid_rounds
|
| 215 |
self.load_balance_loss_type = load_balance_loss_type
|
| 216 |
+
self.maximum_expert_overclaim = maximum_expert_overclaim
|
| 217 |
self.attention_dropout = attention_dropout
|
| 218 |
self.use_cache = use_cache
|
| 219 |
|
huggingface.py
CHANGED
|
@@ -178,17 +178,18 @@ class ShramConfig(PretrainedConfig):
|
|
| 178 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 179 |
Default 10.
|
| 180 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 181 |
-
One of ``"gshard"``, ``"ce"``,
|
| 182 |
-
is the default;
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
``"
|
|
|
|
| 192 |
"""
|
| 193 |
|
| 194 |
model_type = "shram"
|
|
@@ -223,8 +224,8 @@ class ShramConfig(PretrainedConfig):
|
|
| 223 |
tie_word_embeddings: bool = False,
|
| 224 |
mosrah_overallocation_factor: float = 2.0,
|
| 225 |
max_bid_rounds: int = 10,
|
| 226 |
-
load_balance_loss_type: str = "
|
| 227 |
-
|
| 228 |
**kwargs
|
| 229 |
):
|
| 230 |
if head_dim % 2 != 0:
|
|
@@ -265,7 +266,13 @@ class ShramConfig(PretrainedConfig):
|
|
| 265 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 266 |
)
|
| 267 |
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
if load_balance_loss_type not in _supported_loss_types:
|
| 270 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 271 |
raise ValueError(
|
|
@@ -273,13 +280,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 273 |
f"got {load_balance_loss_type!r}."
|
| 274 |
)
|
| 275 |
|
| 276 |
-
_supported_routing_modes = {"default", "integral"}
|
| 277 |
-
if routing_mode not in _supported_routing_modes:
|
| 278 |
-
supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
|
| 279 |
-
raise ValueError(
|
| 280 |
-
f"routing_mode must be one of {supported}, got {routing_mode!r}."
|
| 281 |
-
)
|
| 282 |
-
|
| 283 |
self.vocab_size = vocab_size
|
| 284 |
self.embedding_width = embedding_width
|
| 285 |
self.mlp_width = mlp_width
|
|
@@ -300,7 +300,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 300 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 301 |
self.max_bid_rounds = max_bid_rounds
|
| 302 |
self.load_balance_loss_type = load_balance_loss_type
|
| 303 |
-
self.
|
| 304 |
self.attention_dropout = attention_dropout
|
| 305 |
self.use_cache = use_cache
|
| 306 |
|
|
@@ -1478,10 +1478,7 @@ Returns a plain dict with keys:
|
|
| 1478 |
- "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
|
| 1479 |
- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
|
| 1480 |
- "max_vio": detached scalar maximum routing-imbalance across all decoder layers
|
| 1481 |
-
- "
|
| 1482 |
-
- "raw_logit_std": detached scalar mean per-layer per-token routing logit spread
|
| 1483 |
-
- "logit_std": detached scalar mean per-layer per-token combined (logit + bias) spread
|
| 1484 |
-
- "bias_alignment": detached scalar mean per-layer cosine similarity of bias vs logits
|
| 1485 |
"""
|
| 1486 |
|
| 1487 |
|
|
@@ -2725,71 +2722,38 @@ This module implements the routing mechanism described in Appendix A.Routing of
|
|
| 2725 |
paper. Given an input hidden state x, the router produces two outputs used downstream:
|
| 2726 |
|
| 2727 |
- selected_heads (I): which K of the L available expert heads each token routes to,
|
| 2728 |
-
determined by TopK over capacity-balanced
|
| 2729 |
- routing_probs (P): the weights used for the weighted output reduction, gathered from
|
| 2730 |
-
the
|
| 2731 |
-
per token.
|
| 2732 |
-
|
| 2733 |
-
Base routing uses two learnable projection matrices and two gradient-isolated pathways:
|
| 2734 |
-
|
| 2735 |
-
- routing_weight (A): shape (L, embedding_width). Maps input to per-head routing
|
| 2736 |
-
scores. Receives gradients from task loss; balance_weight is isolated.
|
| 2737 |
-
- balance_weight (B): shape (L, embedding_width). Maps input to per-head load-balance
|
| 2738 |
-
correction scores. Receives gradients from load_balance_loss; routing_weight is
|
| 2739 |
-
isolated.
|
| 2740 |
-
|
| 2741 |
-
The two gradient-isolated base pathways over numerically identical values:
|
| 2742 |
-
|
| 2743 |
-
- semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
|
| 2744 |
-
balance_weight is isolated from task loss.
|
| 2745 |
-
- load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance gradients
|
| 2746 |
-
reach balance_weight; routing_weight and x are isolated from load balance loss.
|
| 2747 |
-
|
| 2748 |
-
Integral routing extension (routing_mode == "integral"):
|
| 2749 |
-
|
| 2750 |
-
Standard routing is parallel — each token routes based on its own hidden state alone,
|
| 2751 |
-
with no direct read on what earlier tokens in the sequence have already selected.
|
| 2752 |
-
Integral routing adds a cumulative-sum signal that gives each token a view of the
|
| 2753 |
-
prior routing history within the sequence.
|
| 2754 |
-
|
| 2755 |
-
Two additional (L, L) parameter matrices are introduced:
|
| 2756 |
-
|
| 2757 |
-
- routing_integral_weight (A'): shape (L, L). Maps the cumulative logit history to
|
| 2758 |
-
per-head semantic corrections. Receives gradients from task loss.
|
| 2759 |
-
- balance_integral_weight (B'): shape (L, L). Maps the cumulative logit history to
|
| 2760 |
-
per-head load-balance corrections. Receives gradients from load_balance_loss.
|
| 2761 |
|
| 2762 |
-
|
| 2763 |
-
sequence dimension: u[n] = sum(logits[0..n-1]), shape (B, N, L). Position 0 receives
|
| 2764 |
-
zeros (no prior history). The same gradient isolation pattern as A/B applies:
|
| 2765 |
|
| 2766 |
-
-
|
| 2767 |
-
|
|
|
|
| 2768 |
|
| 2769 |
-
|
| 2770 |
-
|
| 2771 |
-
|
|
|
|
|
|
|
| 2772 |
|
| 2773 |
-
|
| 2774 |
-
|
| 2775 |
|
| 2776 |
-
|
| 2777 |
-
|
| 2778 |
-
|
| 2779 |
-
overloaded expert).
|
| 2780 |
|
| 2781 |
-
The router
|
| 2782 |
-
loss (see load_balance_loss.py)
|
| 2783 |
-
|
| 2784 |
-
|
| 2785 |
-
|
| 2786 |
-
|
| 2787 |
-
|
| 2788 |
-
|
| 2789 |
-
|
| 2790 |
-
where f_bl is the per-batch-item realised routing frequency of head l and 1/L is the
|
| 2791 |
-
perfectly balanced target. MaxVio is averaged over batch items and is a monitoring
|
| 2792 |
-
quantity only; it never contributes gradients.
|
| 2793 |
|
| 2794 |
Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
| 2795 |
"""
|
|
@@ -2804,7 +2768,7 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
|
| 2804 |
# -----------
|
| 2805 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2806 |
|
| 2807 |
-
This module provides
|
| 2808 |
helpers, and a factory that selects among the formulations. All formulations
|
| 2809 |
share the same external contract:
|
| 2810 |
|
|
@@ -2814,9 +2778,8 @@ share the same external contract:
|
|
| 2814 |
active_mask: Tensor[B, N],
|
| 2815 |
) -> scalar Tensor
|
| 2816 |
|
| 2817 |
-
logits:
|
| 2818 |
-
|
| 2819 |
-
Gradient flows to expert_bias through this tensor.
|
| 2820 |
assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
|
| 2821 |
is 1.0 if token (b, n) was assigned to head l. Dead tokens
|
| 2822 |
should carry zero entries.
|
|
@@ -2826,17 +2789,19 @@ share the same external contract:
|
|
| 2826 |
Token reduction is split into two helpers with distinct roles:
|
| 2827 |
|
| 2828 |
reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L).
|
| 2829 |
-
Called by
|
| 2830 |
|
| 2831 |
reduce_probability_tokens — produces per-batch-item mean assignment probabilities
|
| 2832 |
-
p_bl (B, L). Called only by gshard and bce. Gradient flows
|
| 2833 |
-
|
| 2834 |
|
| 2835 |
CE delegates probability computation to F.cross_entropy, which handles its own
|
| 2836 |
log_softmax and operates directly on the raw (B, N, L) logits.
|
| 2837 |
|
| 2838 |
-
|
| 2839 |
-
|
|
|
|
|
|
|
| 2840 |
"""
|
| 2841 |
|
| 2842 |
|
|
@@ -3010,30 +2975,181 @@ def bce_loss(
|
|
| 3010 |
"""
|
| 3011 |
f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
|
| 3012 |
p_bl = reduce_probability_tokens(logits, active_mask)
|
| 3013 |
-
# Clamp
|
| 3014 |
-
#
|
| 3015 |
-
#
|
| 3016 |
-
|
| 3017 |
-
|
| 3018 |
-
|
| 3019 |
-
|
| 3020 |
-
|
| 3021 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3022 |
|
| 3023 |
|
| 3024 |
# ---------------------------------------------------------------------------
|
| 3025 |
# Factory
|
| 3026 |
# ---------------------------------------------------------------------------
|
| 3027 |
|
| 3028 |
-
|
| 3029 |
-
|
| 3030 |
-
|
| 3031 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3032 |
}
|
| 3033 |
|
| 3034 |
|
| 3035 |
def make_load_balance_loss(
|
| 3036 |
loss_type: str,
|
|
|
|
| 3037 |
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 3038 |
"""Return a load-balance loss callable for the requested formulation.
|
| 3039 |
|
|
@@ -3045,11 +3161,14 @@ def make_load_balance_loss(
|
|
| 3045 |
active_mask: Tensor[B, N],
|
| 3046 |
) -> scalar Tensor
|
| 3047 |
|
| 3048 |
-
|
| 3049 |
-
|
|
|
|
| 3050 |
|
| 3051 |
Args:
|
| 3052 |
-
loss_type:
|
|
|
|
|
|
|
| 3053 |
|
| 3054 |
Returns:
|
| 3055 |
Loss callable matching the shared contract.
|
|
@@ -3062,34 +3181,29 @@ def make_load_balance_loss(
|
|
| 3062 |
raise ValueError(
|
| 3063 |
f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
|
| 3064 |
)
|
| 3065 |
-
return _LOSS_REGISTRY[loss_type]
|
| 3066 |
|
| 3067 |
|
| 3068 |
class MoSRAHRouter(nn.Module):
|
| 3069 |
"""Token-choice router for MoSRAH sparse attention.
|
| 3070 |
|
| 3071 |
-
Each input token independently selects K of the L available expert heads.
|
| 3072 |
-
|
| 3073 |
-
|
| 3074 |
-
two-pathway architecture and the integral routing extension.
|
| 3075 |
|
| 3076 |
-
|
| 3077 |
-
|
| 3078 |
-
construction.
|
| 3079 |
|
| 3080 |
Attributes:
|
| 3081 |
-
routing_weight:
|
| 3082 |
-
|
| 3083 |
-
|
| 3084 |
-
Present only when ``routing_mode == "integral"``.
|
| 3085 |
-
balance_integral_weight: B', shape (L, L). Integral load-balance pathway.
|
| 3086 |
-
Present only when ``routing_mode == "integral"``.
|
| 3087 |
-
routing_mode: ``"integral"`` or ``"default"``, from config.
|
| 3088 |
|
| 3089 |
Args:
|
| 3090 |
config: Model configuration. Must expose ``embedding_width``,
|
| 3091 |
-
``num_mosrah_heads`` (L), ``num_selected_heads`` (K),
|
| 3092 |
-
``
|
|
|
|
| 3093 |
"""
|
| 3094 |
|
| 3095 |
def __init__(self, config: ShramConfig) -> None:
|
|
@@ -3102,40 +3216,19 @@ class MoSRAHRouter(nn.Module):
|
|
| 3102 |
self.capacity = config.mosrah_packed_length
|
| 3103 |
|
| 3104 |
self.max_bid_rounds = config.max_bid_rounds
|
| 3105 |
-
self.
|
| 3106 |
-
|
| 3107 |
-
|
| 3108 |
-
|
| 3109 |
-
|
| 3110 |
-
# HuggingFace _init_weights does not override kaiming initialization.
|
| 3111 |
-
self.routing_weight = nn.Parameter(
|
| 3112 |
-
torch.empty(config.num_mosrah_heads, config.embedding_width)
|
| 3113 |
)
|
| 3114 |
-
nn.init.kaiming_uniform_(self.routing_weight)
|
| 3115 |
|
| 3116 |
-
#
|
| 3117 |
-
# correction scores (B, N, L). Receives gradients only from load_balance_loss.
|
| 3118 |
# nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
|
| 3119 |
-
self.
|
| 3120 |
torch.empty(config.num_mosrah_heads, config.embedding_width)
|
| 3121 |
)
|
| 3122 |
-
nn.init.
|
| 3123 |
-
|
| 3124 |
-
if self.routing_mode == "integral":
|
| 3125 |
-
L = config.num_mosrah_heads
|
| 3126 |
-
# A': integral semantic matrix. Maps cumulative logit history (B, N, L) to
|
| 3127 |
-
# per-head semantic corrections (B, N, L). Shape (L, L). Receives gradients
|
| 3128 |
-
# from task loss; balance_integral_weight is isolated from task loss.
|
| 3129 |
-
# Zero-initialized so that corrections start at zero and grow from gradient
|
| 3130 |
-
# updates — kaiming init produces corrections that immediately overwhelm the
|
| 3131 |
-
# base routing signal via the cumsum feedback path.
|
| 3132 |
-
self.routing_integral_weight = nn.Parameter(torch.zeros(L, L))
|
| 3133 |
-
|
| 3134 |
-
# B': integral load-balance matrix. Maps cumulative logit history (B, N, L)
|
| 3135 |
-
# to per-head load-balance corrections (B, N, L). Shape (L, L). Receives
|
| 3136 |
-
# gradients from load_balance_loss; routing_integral_weight is isolated.
|
| 3137 |
-
# Zero-initialized for the same reason as routing_integral_weight.
|
| 3138 |
-
self.balance_integral_weight = nn.Parameter(torch.zeros(L, L))
|
| 3139 |
|
| 3140 |
@staticmethod
|
| 3141 |
def get_best_proposals(
|
|
@@ -3380,226 +3473,87 @@ class MoSRAHRouter(nn.Module):
|
|
| 3380 |
Returns:
|
| 3381 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3382 |
Each token's K selected head indices, determined by TopK on
|
| 3383 |
-
capacity-balanced
|
| 3384 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3385 |
-
num_selected_heads). Gathered from pre-capacity
|
| 3386 |
selected_heads indices and renormalized to sum to 1 per token.
|
| 3387 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3388 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3389 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3390 |
-
- ``
|
| 3391 |
-
|
| 3392 |
-
- ``bias_std``: mean per-token std of balance_logits; near-zero
|
| 3393 |
-
means balance corrections have not built up relative to routing scale.
|
| 3394 |
-
- ``logit_std``: mean per-token std of semantic_logits; lower than
|
| 3395 |
-
raw_logit_std means balance is flattening preferences (healthy correction).
|
| 3396 |
-
- ``bias_alignment``: mean cosine similarity of routing_logits vs
|
| 3397 |
-
balance_logits per token. Negative means balance opposes routing direction
|
| 3398 |
-
(healthy correction); positive means runaway reinforcement.
|
| 3399 |
"""
|
| 3400 |
B, N, _ = x.shape
|
| 3401 |
L = self.num_mosrah_heads
|
| 3402 |
K = self.num_selected_heads
|
| 3403 |
|
| 3404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3405 |
|
| 3406 |
-
|
| 3407 |
-
|
| 3408 |
-
# cosine similarity. Extracted to _compute_bias_diagnostics to keep the forward
|
| 3409 |
-
# body free of non-(B,N,L) reduction logic.
|
| 3410 |
-
bias_diagnostics = self._compute_bias_diagnostics(
|
| 3411 |
-
logits["routing_logits"], logits["balance_logits"], logits["semantic_logits"]
|
| 3412 |
)
|
| 3413 |
|
| 3414 |
-
#
|
| 3415 |
-
#
|
| 3416 |
-
#
|
| 3417 |
-
|
| 3418 |
-
|
| 3419 |
-
#
|
| 3420 |
-
#
|
| 3421 |
-
|
| 3422 |
-
|
|
|
|
|
|
|
| 3423 |
used_capacity,
|
| 3424 |
self.capacity,
|
| 3425 |
self.num_selected_heads,
|
| 3426 |
self.max_bid_rounds,
|
| 3427 |
)
|
| 3428 |
-
|
| 3429 |
|
| 3430 |
-
|
| 3431 |
-
|
|
|
|
| 3432 |
|
| 3433 |
-
# Routing probabilities P: gathered from pre-capacity semantic softmax at
|
| 3434 |
-
# selected_heads positions, renormalized so they sum to 1 per token.
|
| 3435 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
|
| 3436 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 3437 |
|
| 3438 |
-
# assignment_mask: (B, N, L) float — 1.0 at each token's K selected heads, 0 elsewhere.
|
| 3439 |
-
# The discrete routing decision; no gradient flows through it. Passed alongside
|
| 3440 |
-
# load_balancing_logits and active_mask to the loss and max_vio methods, which
|
| 3441 |
-
# own all frequency aggregation and reduction internally.
|
| 3442 |
-
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3443 |
-
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 3444 |
-
|
| 3445 |
-
load_balance_loss = self._load_balance_loss(
|
| 3446 |
-
logits["load_balancing_logits"], assignment_mask, active_mask
|
| 3447 |
-
)
|
| 3448 |
-
|
| 3449 |
-
# MaxVio: detached monitoring scalar averaged over batch items. Computed from
|
| 3450 |
-
# the same (B, N, L) assignment_mask so frequencies are consistent with the loss.
|
| 3451 |
-
max_vio = self._compute_max_vio(assignment_mask, active_mask, L)
|
| 3452 |
-
|
| 3453 |
router_diagnostics = {
|
| 3454 |
"load_balance_loss": load_balance_loss,
|
| 3455 |
-
"max_vio":
|
| 3456 |
-
|
| 3457 |
}
|
| 3458 |
return selected_heads, routing_probs, router_diagnostics
|
| 3459 |
|
| 3460 |
-
|
| 3461 |
-
|
| 3462 |
-
"""Compute the exclusive cumulative sum along the sequence dimension.
|
| 3463 |
-
|
| 3464 |
-
u[n] = sum(logits[0..n-1]): position n receives the accumulated sum of all
|
| 3465 |
-
prior positions, giving it a read on the routing preferences expressed by
|
| 3466 |
-
earlier tokens in the sequence. Position 0 always receives zeros — no prior
|
| 3467 |
-
history exists at the first position.
|
| 3468 |
-
|
| 3469 |
-
Args:
|
| 3470 |
-
logits: Shape (B, N, L). Any per-head score tensor along a sequence.
|
| 3471 |
-
|
| 3472 |
-
Returns:
|
| 3473 |
-
Exclusive cumsum, shape (B, N, L). Same dtype and device as input.
|
| 3474 |
-
"""
|
| 3475 |
-
shifted = torch.cat(
|
| 3476 |
-
[torch.zeros_like(logits[:, :1, :]), logits[:, :-1, :]], dim=1
|
| 3477 |
-
)
|
| 3478 |
-
return shifted.cumsum(dim=1)
|
| 3479 |
-
|
| 3480 |
-
def _compute_routing_logits(
|
| 3481 |
-
self, x: torch.Tensor, active_mask: torch.Tensor
|
| 3482 |
-
) -> dict[str, torch.Tensor]:
|
| 3483 |
-
"""Compute the gradient-isolated logit pathways from input hidden states.
|
| 3484 |
-
|
| 3485 |
-
Base pathways (both modes):
|
| 3486 |
-
|
| 3487 |
-
Two gradient-isolated pathways over numerically identical values:
|
| 3488 |
-
- semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
|
| 3489 |
-
balance_weight is isolated from task loss.
|
| 3490 |
-
- load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance
|
| 3491 |
-
gradients reach balance_weight; routing_weight and x are isolated.
|
| 3492 |
-
|
| 3493 |
-
Integral extension (routing_mode == "integral"):
|
| 3494 |
-
|
| 3495 |
-
Dead tokens are zeroed out of the logits before computing the cumsum, so
|
| 3496 |
-
inactive positions do not contribute to the routing history of downstream
|
| 3497 |
-
live tokens. u_semantic and u_load therefore represent history from live
|
| 3498 |
-
tokens only.
|
| 3499 |
-
|
| 3500 |
-
u_semantic = exclusive_cumsum(semantic_logits * active_mask) — (B, N, L)
|
| 3501 |
-
u_load = exclusive_cumsum(load_balancing_logits * active_mask) — (B, N, L)
|
| 3502 |
-
|
| 3503 |
-
semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
|
| 3504 |
-
load_balancing_logits += (A'·u_load).detach() + B'·u_load
|
| 3505 |
-
|
| 3506 |
-
Detaching the full (B'·u_semantic) result mirrors the (B·x).detach() base
|
| 3507 |
-
pattern: it isolates balance_integral_weight from task loss AND prevents
|
| 3508 |
-
double-counting the cumsum gradient path back to routing_weight.
|
| 3509 |
-
The same reasoning applies to (A'·u_load).detach() in the load-balance
|
| 3510 |
-
pathway — u_load already has no path to routing_weight (routing_logits is
|
| 3511 |
-
detached in load_balancing_logits), and the detach additionally blocks
|
| 3512 |
-
routing_integral_weight.
|
| 3513 |
|
| 3514 |
Args:
|
| 3515 |
x: Input hidden states, shape (batch, seq_len, embedding_width).
|
| 3516 |
-
active_mask: Boolean active-token mask, shape (batch, seq_len). Dead tokens
|
| 3517 |
-
are excluded from the cumsum history in integral mode.
|
| 3518 |
|
| 3519 |
Returns:
|
| 3520 |
-
|
| 3521 |
-
- ``routing_logits``: A·x, shape (B, N, L).
|
| 3522 |
-
- ``balance_logits``: B·x, shape (B, N, L).
|
| 3523 |
-
- ``semantic_logits``: combined task-loss pathway, shape (B, N, L).
|
| 3524 |
-
- ``load_balancing_logits``: combined load-balance pathway, shape (B, N, L).
|
| 3525 |
"""
|
| 3526 |
-
|
| 3527 |
-
balance_logits = F.linear(x, self.balance_weight) # (B, N, L)
|
| 3528 |
-
semantic_logits = routing_logits + balance_logits.detach()
|
| 3529 |
-
load_balancing_logits = routing_logits.detach() + F.linear(x.detach(), self.balance_weight)
|
| 3530 |
-
|
| 3531 |
-
if self.routing_mode == "integral":
|
| 3532 |
-
# Zero out dead token positions before cumsum so inactive tokens do not
|
| 3533 |
-
# contaminate the routing history of subsequent live tokens.
|
| 3534 |
-
live = active_mask.unsqueeze(-1) # (B, N, 1)
|
| 3535 |
-
u_semantic = self.exclusive_cumsum(semantic_logits * live) # (B, N, L)
|
| 3536 |
-
u_load = self.exclusive_cumsum(load_balancing_logits * live) # (B, N, L)
|
| 3537 |
-
|
| 3538 |
-
# Semantic pathway: A' trains on task loss; B' term is fully detached to
|
| 3539 |
-
# isolate balance_integral_weight from task loss and prevent double-counting
|
| 3540 |
-
# the cumsum gradient path back to routing_weight.
|
| 3541 |
-
semantic_logits = (
|
| 3542 |
-
semantic_logits
|
| 3543 |
-
+ F.linear(u_semantic, self.routing_integral_weight)
|
| 3544 |
-
+ F.linear(u_semantic, self.balance_integral_weight).detach()
|
| 3545 |
-
)
|
| 3546 |
-
|
| 3547 |
-
# Load-balance pathway: B' trains on load_balance_loss; A' term is fully
|
| 3548 |
-
# detached to isolate routing_integral_weight from load_balance_loss.
|
| 3549 |
-
load_balancing_logits = (
|
| 3550 |
-
load_balancing_logits
|
| 3551 |
-
+ F.linear(u_load, self.routing_integral_weight).detach()
|
| 3552 |
-
+ F.linear(u_load, self.balance_integral_weight)
|
| 3553 |
-
)
|
| 3554 |
-
|
| 3555 |
-
return {
|
| 3556 |
-
"routing_logits": routing_logits,
|
| 3557 |
-
"balance_logits": balance_logits,
|
| 3558 |
-
"semantic_logits": semantic_logits,
|
| 3559 |
-
"load_balancing_logits": load_balancing_logits,
|
| 3560 |
-
}
|
| 3561 |
-
|
| 3562 |
-
@staticmethod
|
| 3563 |
-
def _compute_bias_diagnostics(
|
| 3564 |
-
routing_logits: torch.Tensor,
|
| 3565 |
-
balance_logits: torch.Tensor,
|
| 3566 |
-
semantic_logits: torch.Tensor,
|
| 3567 |
-
) -> dict[str, torch.Tensor]:
|
| 3568 |
-
"""Compute detached diagnostic scalars characterising the two routing pathways.
|
| 3569 |
-
|
| 3570 |
-
All scalars must be computed from pre-capacity logits; balance_capacity
|
| 3571 |
-
applies -1e8 sentinels that would corrupt std and cosine similarity.
|
| 3572 |
-
Extracted from forward to keep the main body free of reduction logic.
|
| 3573 |
-
|
| 3574 |
-
Args:
|
| 3575 |
-
routing_logits: A·x, routing pathway output, shape (B, N, L).
|
| 3576 |
-
balance_logits: B·x, balance pathway output, shape (B, N, L).
|
| 3577 |
-
semantic_logits: A·x + (B·x).detach(), combined signal, shape (B, N, L).
|
| 3578 |
-
|
| 3579 |
-
Returns:
|
| 3580 |
-
Dict with keys:
|
| 3581 |
-
- ``raw_logit_std``: Mean per-token std of routing_logits. Natural
|
| 3582 |
-
routing preference scale; reference baseline for
|
| 3583 |
-
interpreting bias_std.
|
| 3584 |
-
- ``bias_std``: Mean per-token std of balance_logits. Near-zero
|
| 3585 |
-
means balance corrections have not built up
|
| 3586 |
-
relative to the routing scale.
|
| 3587 |
-
- ``logit_std``: Mean per-token std of semantic_logits. Lower than
|
| 3588 |
-
raw_logit_std indicates balance is flattening
|
| 3589 |
-
preferences (healthy correction signal).
|
| 3590 |
-
- ``bias_alignment``: Mean cosine similarity of routing_logits vs
|
| 3591 |
-
balance_logits per token. Range [-1, 1]. Negative
|
| 3592 |
-
means balance opposes routing direction (healthy
|
| 3593 |
-
correction); positive means runaway reinforcement.
|
| 3594 |
-
"""
|
| 3595 |
-
return {
|
| 3596 |
-
"raw_logit_std": routing_logits.std(dim=-1).mean().detach(),
|
| 3597 |
-
"bias_std": balance_logits.std(dim=-1).mean().detach(),
|
| 3598 |
-
"logit_std": semantic_logits.std(dim=-1).mean().detach(),
|
| 3599 |
-
"bias_alignment": F.cosine_similarity(
|
| 3600 |
-
routing_logits, balance_logits, dim=-1
|
| 3601 |
-
).mean().detach(),
|
| 3602 |
-
}
|
| 3603 |
|
| 3604 |
@staticmethod
|
| 3605 |
def _compute_max_vio(
|
|
@@ -4137,30 +4091,15 @@ class ShramModel(nn.Module):
|
|
| 4137 |
- ``"max_vio"``: detached scalar maximum routing-imbalance across
|
| 4138 |
all decoder layers. Zero means perfectly balanced routing across
|
| 4139 |
every layer; higher values identify the worst-case head imbalance.
|
| 4140 |
-
- ``"bias_std"``: detached scalar — mean across layers of the std
|
| 4141 |
-
of each layer's expert bias vector. Near-zero means corrections
|
| 4142 |
-
have not built up; large relative to ``raw_logit_std`` means the
|
| 4143 |
-
bias dominates routing.
|
| 4144 |
-
- ``"raw_logit_std"``: detached scalar — mean across layers of the
|
| 4145 |
-
per-token routing logit spread before bias addition. Baseline
|
| 4146 |
-
natural routing preference scale.
|
| 4147 |
- ``"logit_std"``: detached scalar — mean across layers of the
|
| 4148 |
-
per-token
|
| 4149 |
-
|
| 4150 |
-
amplification.
|
| 4151 |
-
- ``"bias_alignment"``: detached scalar — mean across layers of the
|
| 4152 |
-
per-token cosine similarity between the expert bias vector and the
|
| 4153 |
-
routing logits. Negative is healthy correction; positive is
|
| 4154 |
-
runaway feedback.
|
| 4155 |
"""
|
| 4156 |
hidden_states = inputs_embeds
|
| 4157 |
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
| 4158 |
total_load_balance_loss = inputs_embeds.new_zeros(())
|
| 4159 |
max_vio = inputs_embeds.new_zeros(())
|
| 4160 |
-
total_bias_std = inputs_embeds.new_zeros(())
|
| 4161 |
-
total_raw_logit_std = inputs_embeds.new_zeros(())
|
| 4162 |
total_logit_std = inputs_embeds.new_zeros(())
|
| 4163 |
-
total_bias_alignment = inputs_embeds.new_zeros(())
|
| 4164 |
|
| 4165 |
for layer_idx, layer in enumerate(self.layers):
|
| 4166 |
layer_cache = None if cache is None else cache.layers[layer_idx]
|
|
@@ -4172,10 +4111,7 @@ class ShramModel(nn.Module):
|
|
| 4172 |
)
|
| 4173 |
total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
|
| 4174 |
max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
|
| 4175 |
-
total_bias_std = total_bias_std + layer_diagnostics["bias_std"]
|
| 4176 |
-
total_raw_logit_std = total_raw_logit_std + layer_diagnostics["raw_logit_std"]
|
| 4177 |
total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
|
| 4178 |
-
total_bias_alignment = total_bias_alignment + layer_diagnostics["bias_alignment"]
|
| 4179 |
|
| 4180 |
if output_hidden_states:
|
| 4181 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
@@ -4189,10 +4125,7 @@ class ShramModel(nn.Module):
|
|
| 4189 |
"hidden_states": all_hidden_states,
|
| 4190 |
"load_balance_loss": total_load_balance_loss,
|
| 4191 |
"max_vio": max_vio,
|
| 4192 |
-
"bias_std": total_bias_std / num_layers,
|
| 4193 |
-
"raw_logit_std": total_raw_logit_std / num_layers,
|
| 4194 |
"logit_std": total_logit_std / num_layers,
|
| 4195 |
-
"bias_alignment": total_bias_alignment / num_layers,
|
| 4196 |
}
|
| 4197 |
|
| 4198 |
|
|
@@ -4209,17 +4142,14 @@ class ShramCausalLMOutput(CausalLMOutputWithPast):
|
|
| 4209 |
## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all
|
| 4210 |
## fields to None, which forces every subclass field to also carry a default.
|
| 4211 |
## The = None below is a language constraint, not a semantic statement. In
|
| 4212 |
-
## practice, load_balance_loss, max_vio,
|
| 4213 |
-
##
|
| 4214 |
-
##
|
| 4215 |
|
| 4216 |
ce_loss: torch.FloatTensor | None = None
|
| 4217 |
load_balance_loss: torch.FloatTensor | None = None
|
| 4218 |
max_vio: torch.FloatTensor | None = None
|
| 4219 |
-
bias_std: torch.Tensor | None = None
|
| 4220 |
-
raw_logit_std: torch.Tensor | None = None
|
| 4221 |
logit_std: torch.Tensor | None = None
|
| 4222 |
-
bias_alignment: torch.Tensor | None = None
|
| 4223 |
|
| 4224 |
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
| 4225 |
"""HuggingFace-facing causal language model wrapper for SHRAM.
|
|
@@ -4668,9 +4598,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4668 |
- ``hidden_states`` when requested,
|
| 4669 |
- ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
|
| 4670 |
- ``max_vio`` — detached worst-case routing imbalance across layers,
|
| 4671 |
-
- ``
|
| 4672 |
-
detached load-balance health scalars averaged across decoder layers;
|
| 4673 |
-
see ``ShramModel`` for interpretation.
|
| 4674 |
"""
|
| 4675 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 4676 |
output_hidden_states = (
|
|
@@ -4777,8 +4705,5 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 4777 |
hidden_states=backbone_outputs["hidden_states"],
|
| 4778 |
load_balance_loss=backbone_outputs["load_balance_loss"],
|
| 4779 |
max_vio=backbone_outputs["max_vio"],
|
| 4780 |
-
bias_std=backbone_outputs["bias_std"],
|
| 4781 |
-
raw_logit_std=backbone_outputs["raw_logit_std"],
|
| 4782 |
logit_std=backbone_outputs["logit_std"],
|
| 4783 |
-
bias_alignment=backbone_outputs["bias_alignment"],
|
| 4784 |
)
|
|
|
|
| 178 |
correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
|
| 179 |
Default 10.
|
| 180 |
load_balance_loss_type: Formula used for the load-balance auxiliary loss.
|
| 181 |
+
One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
|
| 182 |
+
``"temporal_overcapacity"`` is the default; it fires only when an expert
|
| 183 |
+
exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
|
| 184 |
+
and shuts off automatically once routing is balanced, allowing it to be
|
| 185 |
+
used with a strong weight without interfering with task training during
|
| 186 |
+
balanced routing. Default ``"temporal_overcapacity"``.
|
| 187 |
+
maximum_expert_overclaim: Maximum number of tokens an expert may receive above
|
| 188 |
+
its ideal allocation trajectory before the temporal overcapacity loss
|
| 189 |
+
fires. A value of 0 means violations trigger immediately at any imbalance.
|
| 190 |
+
Larger values permit short-lived semantic specialization before correction.
|
| 191 |
+
Only used when ``load_balance_loss_type="temporal_overcapacity"``.
|
| 192 |
+
Must be non-negative. Default 20.
|
| 193 |
"""
|
| 194 |
|
| 195 |
model_type = "shram"
|
|
|
|
| 224 |
tie_word_embeddings: bool = False,
|
| 225 |
mosrah_overallocation_factor: float = 2.0,
|
| 226 |
max_bid_rounds: int = 10,
|
| 227 |
+
load_balance_loss_type: str = "temporal_overcapacity",
|
| 228 |
+
maximum_expert_overclaim: int = 20,
|
| 229 |
**kwargs
|
| 230 |
):
|
| 231 |
if head_dim % 2 != 0:
|
|
|
|
| 266 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
| 267 |
)
|
| 268 |
|
| 269 |
+
if maximum_expert_overclaim < 0:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"maximum_expert_overclaim must be non-negative, "
|
| 272 |
+
f"got {maximum_expert_overclaim}."
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
_supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
|
| 276 |
if load_balance_loss_type not in _supported_loss_types:
|
| 277 |
supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
|
| 278 |
raise ValueError(
|
|
|
|
| 280 |
f"got {load_balance_loss_type!r}."
|
| 281 |
)
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
self.vocab_size = vocab_size
|
| 284 |
self.embedding_width = embedding_width
|
| 285 |
self.mlp_width = mlp_width
|
|
|
|
| 300 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 301 |
self.max_bid_rounds = max_bid_rounds
|
| 302 |
self.load_balance_loss_type = load_balance_loss_type
|
| 303 |
+
self.maximum_expert_overclaim = maximum_expert_overclaim
|
| 304 |
self.attention_dropout = attention_dropout
|
| 305 |
self.use_cache = use_cache
|
| 306 |
|
|
|
|
| 1478 |
- "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
|
| 1479 |
- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
|
| 1480 |
- "max_vio": detached scalar maximum routing-imbalance across all decoder layers
|
| 1481 |
+
- "logit_std": detached scalar mean per-layer per-token routing logit spread
|
|
|
|
|
|
|
|
|
|
| 1482 |
"""
|
| 1483 |
|
| 1484 |
|
|
|
|
| 2722 |
paper. Given an input hidden state x, the router produces two outputs used downstream:
|
| 2723 |
|
| 2724 |
- selected_heads (I): which K of the L available expert heads each token routes to,
|
| 2725 |
+
determined by TopK over capacity-balanced routing scores.
|
| 2726 |
- routing_probs (P): the weights used for the weighted output reduction, gathered from
|
| 2727 |
+
the routing scores at the selected indices and renormalized to sum to 1 per token.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2728 |
|
| 2729 |
+
Routing uses a single learnable projection:
|
|
|
|
|
|
|
| 2730 |
|
| 2731 |
+
- routing_weight: shape (L, embedding_width). Maps input to per-head routing scores.
|
| 2732 |
+
Both task loss and load_balance_loss train this parameter directly — there is no
|
| 2733 |
+
gradient isolation between the two signals.
|
| 2734 |
|
| 2735 |
+
This coupled design is intentional. SHRAM has an unusually strong task-level incentive
|
| 2736 |
+
to concentrate tokens into the same expert bucket (sparse attention only occurs among
|
| 2737 |
+
tokens routed to the same expert), so any indirect balancing pathway will be outlearned.
|
| 2738 |
+
Coupling the gradients allows the load balance loss to act with full strength directly
|
| 2739 |
+
on the parameter that determines routing.
|
| 2740 |
|
| 2741 |
+
routing_weight is nn.Parameter so that HuggingFace _init_weights does not override
|
| 2742 |
+
its kaiming initialization at construction.
|
| 2743 |
|
| 2744 |
+
routing_probs are computed before balance_capacity applies -1e8 sentinels. Post-capacity
|
| 2745 |
+
softmax would corrupt routing_probs for over-capacity experts (near-zero probability
|
| 2746 |
+
after masking does not reflect genuine routing preference).
|
|
|
|
| 2747 |
|
| 2748 |
+
The router computes and returns:
|
| 2749 |
+
- load_balance_loss: scalar auxiliary loss (see load_balance_loss.py); gradient flows
|
| 2750 |
+
to routing_weight.
|
| 2751 |
+
- max_vio: detached scalar summarising routing imbalance:
|
| 2752 |
+
MaxVio = mean_b( L · max_l(f_bl − 1/L) )
|
| 2753 |
+
where f_bl is the per-batch-item realised routing frequency of head l. Zero means
|
| 2754 |
+
perfect balance; 1.0 means the most loaded head received double its fair share.
|
| 2755 |
+
- logit_std: detached scalar; mean per-token standard deviation of routing logits.
|
| 2756 |
+
Monitoring metric for routing sharpness.
|
|
|
|
|
|
|
|
|
|
| 2757 |
|
| 2758 |
Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
| 2759 |
"""
|
|
|
|
| 2768 |
# -----------
|
| 2769 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2770 |
|
| 2771 |
+
This module provides four load-balance loss formulations, two token-reduction
|
| 2772 |
helpers, and a factory that selects among the formulations. All formulations
|
| 2773 |
share the same external contract:
|
| 2774 |
|
|
|
|
| 2778 |
active_mask: Tensor[B, N],
|
| 2779 |
) -> scalar Tensor
|
| 2780 |
|
| 2781 |
+
logits: Pre-softmax routing scores, shape (B, N, L). Gradient flows
|
| 2782 |
+
through this tensor.
|
|
|
|
| 2783 |
assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
|
| 2784 |
is 1.0 if token (b, n) was assigned to head l. Dead tokens
|
| 2785 |
should carry zero entries.
|
|
|
|
| 2789 |
Token reduction is split into two helpers with distinct roles:
|
| 2790 |
|
| 2791 |
reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L).
|
| 2792 |
+
Called by gshard, ce, and bce. Output is detached; f_bl carries no gradient.
|
| 2793 |
|
| 2794 |
reduce_probability_tokens — produces per-batch-item mean assignment probabilities
|
| 2795 |
+
p_bl (B, L). Called only by gshard and bce. Gradient flows through the
|
| 2796 |
+
internal softmax over logits.
|
| 2797 |
|
| 2798 |
CE delegates probability computation to F.cross_entropy, which handles its own
|
| 2799 |
log_softmax and operates directly on the raw (B, N, L) logits.
|
| 2800 |
|
| 2801 |
+
``make_load_balance_loss`` is the sole public entry point. The individual loss
|
| 2802 |
+
functions are internal implementation details; their signatures may change between
|
| 2803 |
+
units. Callers and tests must construct loss callables through the factory, not by
|
| 2804 |
+
importing or invoking the loss functions directly.
|
| 2805 |
"""
|
| 2806 |
|
| 2807 |
|
|
|
|
| 2975 |
"""
|
| 2976 |
f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
|
| 2977 |
p_bl = reduce_probability_tokens(logits, active_mask)
|
| 2978 |
+
# Clamp for numerical safety: softmax outputs are strictly positive in
|
| 2979 |
+
# normal operation; the clamp guards the all-dead-tokens edge case where
|
| 2980 |
+
# the mean defaults to zero. log1p(-p) avoids cancellation near p=1.
|
| 2981 |
+
p = p_bl.clamp(min=1e-7, max=1.0 - 1e-7)
|
| 2982 |
+
target = 1.0 - f_bl
|
| 2983 |
+
return -(target * torch.log(p) + (1.0 - target) * torch.log1p(-p)).mean()
|
| 2984 |
+
|
| 2985 |
+
|
| 2986 |
+
def _temporal_overcapacity_loss(
    logits: torch.Tensor,
    assignment_mask: torch.Tensor,
    active_mask: torch.Tensor,
    expected_tokens_rate: float,
    maximum_expert_overclaim: int,
) -> torch.Tensor:
    """Temporal overcapacity loss for MoSRAH load balancing.

    Penalises routing decisions that select a head already overloaded relative to
    its ideal allocation trajectory. A head is considered overloaded at position n
    when the number of active tokens strictly before n assigned to that head exceeds
    S * M + C, where S is the number of active tokens strictly before n, M is the
    expected_tokens_rate (K/L) and C is the maximum_expert_overclaim slack.

    Loss is exactly zero when no active token selects an overloaded head, making it
    safe to weight strongly: it stays out of the way when routing is balanced.

    Args:
        logits: Pre-softmax routing scores, shape (B, N, L).
        assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
            1.0 if token (b, n) is assigned to head l.
        active_mask: Boolean active-token mask, shape (B, N).
        expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed
            by the factory so the division is not repeated each forward pass.
        maximum_expert_overclaim (C): Slack above the ideal trajectory before
            imbalance fires. Larger C tolerates more deviation.

    Returns:
        Scalar loss tensor. Exactly 0.0 when no active token selects a head that
        exceeds its allowed trajectory.
    """
    # ── Algorithm overview ──────────────────────────────────────────────────────
    #
    # Problem: token routing is stateless. Each token's TopK selection is blind to
    # how many times each expert has already been chosen earlier in the sequence,
    # so a router with a strong preference can overload experts far beyond their
    # K/L fair share with no correction signal at the moment of selection.
    #
    # Approach: track per-head assignment history as exclusive cumulative counts
    # (assignments by all active tokens strictly before position n) and compare
    # against an ideal trajectory S·M, where S is the exclusive cumulative active
    # token count and M is the per-token assignment rate expected under ideal
    # balancing. A head is overloaded when its prior count exceeds that trajectory
    # by more than C. When a token selects an already-overloaded head, the loss
    # moment, mean(violating logits) minus mean(non-overloaded logits), penalises
    # the gap and pushes future routing toward underloaded alternatives.

    # ── Routing history and imbalance threshold ─────────────────────────────────
    #
    # Both running totals are exclusive: they reflect only what was known when
    # token n was being routed. Dead tokens contribute to neither total.
    active_float = active_mask.float()                                    # (B, N)
    active_assignments = assignment_mask * active_float.unsqueeze(-1)     # (B, N, L)

    # Exclusive cumsum = inclusive cumsum minus the position's own contribution.
    prior_assignment_counts = active_assignments.cumsum(dim=1) - active_assignments  # (B, N, L)
    cumulative_active_tokens = active_float.cumsum(dim=1) - active_float             # (B, N)

    maximum_supportable_assignments = (
        cumulative_active_tokens.unsqueeze(-1) * expected_tokens_rate
        + maximum_expert_overclaim
    )  # (B, N, 1) → broadcasts to (B, N, L)

    # ── Mask construction ───────────────────────────────────────────────────────
    #
    #   imbalance_mask:           any head exceeding its trajectory.
    #   violating_selection_mask: selected AND imbalanced (the penalty target).
    #   non_overloaded_head_mask: NOT imbalanced, regardless of selection.
    #
    # Masking is deliberately asymmetric. The problem case is a head that is over
    # capacity AND chosen by TopK; its logit mass can only be transferred to heads
    # that are not themselves over capacity.
    imbalance_mask = prior_assignment_counts > maximum_supportable_assignments  # (B, N, L)
    violating_selection_mask = assignment_mask.bool() & imbalance_mask          # (B, N, L)
    non_overloaded_head_mask = ~imbalance_mask                                  # (B, N, L)
    has_violation_mask = violating_selection_mask.any(dim=-1)                   # (B, N)

    # ── Loss moment ─────────────────────────────────────────────────────────────
    #
    # clamp(min=1.0) on the count denominators guards against division by zero
    # when a position has no violating or no non-overloaded heads. Positions
    # without a violation are removed at the gating step below, so the clamped
    # denominator never contributes to the loss there.
    #
    # When at least one head is not overloaded, the gradient lowers the violating
    # logits and raises the non-overloaded logits by equal total magnitude:
    # routing is redirected, not suppressed.
    violating_weights = violating_selection_mask.float()
    non_overloaded_weights = non_overloaded_head_mask.float()
    violation_count = violating_weights.sum(dim=-1).clamp(min=1.0)            # (B, N)
    non_overloaded_count = non_overloaded_weights.sum(dim=-1).clamp(min=1.0)  # (B, N)
    mean_violating_logit = (violating_weights * logits).sum(dim=-1) / violation_count                # (B, N)
    mean_non_overloaded_logit = (non_overloaded_weights * logits).sum(dim=-1) / non_overloaded_count  # (B, N)
    raw_loss = mean_violating_logit - mean_non_overloaded_logit               # (B, N)

    # ── Loss reduction ──────────────────────────────────────────────────────────
    #
    # Only active positions with a violating selection contribute. The gate is a
    # selection (torch.where), not a multiplication: 0 * nan is nan, so a
    # non-finite logit at a dead or non-violating position would otherwise poison
    # the whole scalar. For finite logits the two forms are identical.
    #
    # The denominator counts active tokens per sequence; clamp(min=1.0) makes the
    # all-dead-tokens case evaluate to 0 / 1 = 0.
    penalised_positions = active_mask.bool() & has_violation_mask             # (B, N)
    gated_loss = torch.where(
        penalised_positions, raw_loss, torch.zeros_like(raw_loss)
    )                                                                         # (B, N)
    active_count_per_seq = active_float.sum(dim=1).clamp(min=1.0)             # (B,)
    sequence_loss = gated_loss.sum(dim=1) / active_count_per_seq              # (B,)
    return sequence_loss.mean()
|
| 3104 |
|
| 3105 |
|
| 3106 |
# ---------------------------------------------------------------------------
|
| 3107 |
# Factory
|
| 3108 |
# ---------------------------------------------------------------------------
|
| 3109 |
|
| 3110 |
+
def _gshard_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return the GShard loss callable; construction kwargs are accepted and ignored."""
    del kwargs  # Accepted only so every factory can be called with the same kwargs.
    return gshard_loss
|
| 3112 |
+
|
| 3113 |
+
|
| 3114 |
+
def _ce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return the cross-entropy loss callable; construction kwargs are accepted and ignored."""
    del kwargs  # Accepted only so every factory can be called with the same kwargs.
    return ce_loss
|
| 3116 |
+
|
| 3117 |
+
|
| 3118 |
+
def _bce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Return the binary cross-entropy loss callable; construction kwargs are accepted and ignored."""
    del kwargs  # Accepted only so every factory can be called with the same kwargs.
    return bce_loss
|
| 3120 |
+
|
| 3121 |
+
|
| 3122 |
+
def _temporal_overcapacity_factory(
    num_selected_heads: int,
    num_total_heads: int,
    maximum_expert_overclaim: int,
    **kwargs: object,
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Build a temporal-overcapacity loss callable with its parameters bound.

    Args:
        num_selected_heads: Heads chosen per token (K).
        num_total_heads: Heads available (L).
        maximum_expert_overclaim: Slack forwarded unchanged to the loss.
        **kwargs: Accepted and ignored.

    Returns:
        Callable taking ``(logits, assignment_mask, active_mask)`` and returning
        the temporal overcapacity loss.
    """
    del kwargs
    # K/L is fixed at construction time; compute it once rather than per call.
    ideal_rate = num_selected_heads / num_total_heads

    def _runtime(
        logits: torch.Tensor,
        assignment_mask: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> torch.Tensor:
        return _temporal_overcapacity_loss(
            logits,
            assignment_mask,
            active_mask,
            expected_tokens_rate=ideal_rate,
            maximum_expert_overclaim=maximum_expert_overclaim,
        )

    return _runtime
|
| 3140 |
+
|
| 3141 |
+
|
| 3142 |
+
# Maps each supported load_balance_loss_type string to the factory that builds
# the corresponding loss callable.
_LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = dict(
    gshard=_gshard_factory,
    ce=_ce_factory,
    bce=_bce_factory,
    temporal_overcapacity=_temporal_overcapacity_factory,
)
|
| 3148 |
|
| 3149 |
|
| 3150 |
def make_load_balance_loss(
|
| 3151 |
loss_type: str,
|
| 3152 |
+
**loss_parameters: object,
|
| 3153 |
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 3154 |
"""Return a load-balance loss callable for the requested formulation.
|
| 3155 |
|
|
|
|
| 3161 |
active_mask: Tensor[B, N],
|
| 3162 |
) -> scalar Tensor
|
| 3163 |
|
| 3164 |
+
Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
|
| 3165 |
+
factories silently ignore all kwargs; this allows callers to pass loss-type-specific
|
| 3166 |
+
parameters (e.g. for temporal_overcapacity) without branching on loss_type.
|
| 3167 |
|
| 3168 |
Args:
|
| 3169 |
+
loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``, or
|
| 3170 |
+
``"temporal_overcapacity"``.
|
| 3171 |
+
**loss_parameters: Construction-time parameters forwarded to the factory.
|
| 3172 |
|
| 3173 |
Returns:
|
| 3174 |
Loss callable matching the shared contract.
|
|
|
|
| 3181 |
raise ValueError(
|
| 3182 |
f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
|
| 3183 |
)
|
| 3184 |
+
return _LOSS_REGISTRY[loss_type](**loss_parameters)
|
| 3185 |
|
| 3186 |
|
| 3187 |
class MoSRAHRouter(nn.Module):
|
| 3188 |
"""Token-choice router for MoSRAH sparse attention.
|
| 3189 |
|
| 3190 |
+
Each input token independently selects K of the L available expert heads.
|
| 3191 |
+
A single routing projection maps input hidden states to per-head scores; both
|
| 3192 |
+
task loss and load_balance_loss train this projection directly.
|
|
|
|
| 3193 |
|
| 3194 |
+
routing_weight is nn.Parameter rather than nn.Linear so that HuggingFace
|
| 3195 |
+
_init_weights does not override its kaiming initialization at construction.
|
|
|
|
| 3196 |
|
| 3197 |
Attributes:
|
| 3198 |
+
routing_weight: Shape (L, embedding_width). Maps input hidden states to
|
| 3199 |
+
per-head routing scores. Receives gradients from both task loss and
|
| 3200 |
+
load_balance_loss.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3201 |
|
| 3202 |
Args:
|
| 3203 |
config: Model configuration. Must expose ``embedding_width``,
|
| 3204 |
+
``num_mosrah_heads`` (L), ``num_selected_heads`` (K),
|
| 3205 |
+
``load_balance_loss_type``, ``maximum_expert_overclaim``, ``max_bid_rounds``,
|
| 3206 |
+
``use_cache``, ``mosrah_cache_length``, and ``mosrah_packed_length``.
|
| 3207 |
"""
|
| 3208 |
|
| 3209 |
def __init__(self, config: ShramConfig) -> None:
|
|
|
|
| 3216 |
self.capacity = config.mosrah_packed_length
|
| 3217 |
|
| 3218 |
self.max_bid_rounds = config.max_bid_rounds
|
| 3219 |
+
self._load_balance_loss = make_load_balance_loss(
|
| 3220 |
+
config.load_balance_loss_type,
|
| 3221 |
+
num_selected_heads=config.num_selected_heads,
|
| 3222 |
+
num_total_heads=config.num_mosrah_heads,
|
| 3223 |
+
maximum_expert_overclaim=config.maximum_expert_overclaim,
|
|
|
|
|
|
|
|
|
|
| 3224 |
)
|
|
|
|
| 3225 |
|
| 3226 |
+
# Routing projection: maps input (B, N, d) to per-head routing scores (B, N, L).
|
|
|
|
| 3227 |
# nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
|
| 3228 |
+
self.routing_weight = nn.Parameter(
|
| 3229 |
torch.empty(config.num_mosrah_heads, config.embedding_width)
|
| 3230 |
)
|
| 3231 |
+
nn.init.kaiming_normal_(self.routing_weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3232 |
|
| 3233 |
@staticmethod
|
| 3234 |
def get_best_proposals(
|
|
|
|
| 3473 |
Returns:
|
| 3474 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3475 |
Each token's K selected head indices, determined by TopK on
|
| 3476 |
+
capacity-balanced routing scores.
|
| 3477 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3478 |
+
num_selected_heads). Gathered from pre-capacity routing softmax at
|
| 3479 |
selected_heads indices and renormalized to sum to 1 per token.
|
| 3480 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3481 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3482 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3483 |
+
- ``logit_std``: detached mean per-token std of routing logits;
|
| 3484 |
+
monitoring metric for routing sharpness.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3485 |
"""
|
| 3486 |
B, N, _ = x.shape
|
| 3487 |
L = self.num_mosrah_heads
|
| 3488 |
K = self.num_selected_heads
|
| 3489 |
|
| 3490 |
+
# ── Phase: pre-capacity scoring ───────────────────────────────────────
|
| 3491 |
+
#
|
| 3492 |
+
# Establishes the clean pre-sentinel distribution that all downstream
|
| 3493 |
+
# consumers draw from. logit_std must be captured here — balance_capacity
|
| 3494 |
+
# injects -1e8 sentinels that would corrupt the standard deviation.
|
| 3495 |
+
# routing_scores is the pre-capacity probability distribution; both the
|
| 3496 |
+
# load balance signal and the final routing_probs gather from it.
|
| 3497 |
+
routing_logits = self._compute_routing_logits(x) # (B, N, L)
|
| 3498 |
+
logit_std = routing_logits.std(dim=-1).mean().detach()
|
| 3499 |
+
routing_scores = F.softmax(routing_logits, dim=-1) # (B, N, L)
|
| 3500 |
+
|
| 3501 |
+
# ── Phase: load balance signal ────────────────────────────────────────
|
| 3502 |
+
#
|
| 3503 |
+
# The loss must observe the unconstrained routing decision — the genuine
|
| 3504 |
+
# routing pressure before capacity enforcement masks any imbalance.
|
| 3505 |
+
# pre_cap_heads and assignment_mask exist solely to give the loss this
|
| 3506 |
+
# honest view; nothing downstream uses them.
|
| 3507 |
+
pre_cap_heads = routing_scores.topk(K, dim=-1).indices # (B, N, K)
|
| 3508 |
+
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3509 |
+
assignment_mask.scatter_(-1, pre_cap_heads, 1.0)
|
| 3510 |
|
| 3511 |
+
load_balance_loss = self._load_balance_loss(
|
| 3512 |
+
routing_logits, assignment_mask, active_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3513 |
)
|
| 3514 |
|
| 3515 |
+
# ── Phase: capacity enforcement and final selection ───────────────────
|
| 3516 |
+
#
|
| 3517 |
+
# Produces the capacity-enforced routing that all downstream consumers
|
| 3518 |
+
# depend on. max_vio is computed here because it measures realized routing
|
| 3519 |
+
# imbalance — the actual post-capacity assignment, not the unconstrained
|
| 3520 |
+
# preference. routing_probs are gathered from the pre-capacity routing_scores
|
| 3521 |
+
# (not the balanced distribution) to avoid sentinel corruption — overloaded
|
| 3522 |
+
# experts would otherwise receive near-zero probability regardless of genuine
|
| 3523 |
+
# routing preference.
|
| 3524 |
+
balanced_logits = self.balance_capacity(
|
| 3525 |
+
routing_logits,
|
| 3526 |
used_capacity,
|
| 3527 |
self.capacity,
|
| 3528 |
self.num_selected_heads,
|
| 3529 |
self.max_bid_rounds,
|
| 3530 |
)
|
| 3531 |
+
selected_heads = F.softmax(balanced_logits, dim=-1).topk(K, dim=-1).indices # (B, N, K)
|
| 3532 |
|
| 3533 |
+
realized_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3534 |
+
realized_mask.scatter_(-1, selected_heads, 1.0)
|
| 3535 |
+
max_vio = self._compute_max_vio(realized_mask, active_mask, L)
|
| 3536 |
|
|
|
|
|
|
|
| 3537 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
|
| 3538 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 3539 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3540 |
router_diagnostics = {
|
| 3541 |
"load_balance_loss": load_balance_loss,
|
| 3542 |
+
"max_vio": max_vio,
|
| 3543 |
+
"logit_std": logit_std,
|
| 3544 |
}
|
| 3545 |
return selected_heads, routing_probs, router_diagnostics
|
| 3546 |
|
| 3547 |
+
def _compute_routing_logits(self, x: torch.Tensor) -> torch.Tensor:
    """Project hidden states onto the routing weights to obtain per-head logits.

    The projection is a bias-free linear map through ``self.routing_weight``.

    Args:
        x: Input hidden states, shape (batch, seq_len, embedding_width).

    Returns:
        Routing logits, shape (batch, seq_len, num_mosrah_heads).
    """
    # No bias term: the logits are purely the weight projection of x.
    routing_logits = F.linear(x, self.routing_weight, bias=None)  # (B, N, L)
    return routing_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3557 |
|
| 3558 |
@staticmethod
|
| 3559 |
def _compute_max_vio(
|
|
|
|
| 4091 |
- ``"max_vio"``: detached scalar maximum routing-imbalance across
|
| 4092 |
all decoder layers. Zero means perfectly balanced routing across
|
| 4093 |
every layer; higher values identify the worst-case head imbalance.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4094 |
- ``"logit_std"``: detached scalar — mean across layers of the
|
| 4095 |
+
per-token routing logit spread. Monitoring metric for routing
|
| 4096 |
+
sharpness.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4097 |
"""
|
| 4098 |
hidden_states = inputs_embeds
|
| 4099 |
all_hidden_states = (hidden_states,) if output_hidden_states else None
|
| 4100 |
total_load_balance_loss = inputs_embeds.new_zeros(())
|
| 4101 |
max_vio = inputs_embeds.new_zeros(())
|
|
|
|
|
|
|
| 4102 |
total_logit_std = inputs_embeds.new_zeros(())
|
|
|
|
| 4103 |
|
| 4104 |
for layer_idx, layer in enumerate(self.layers):
|
| 4105 |
layer_cache = None if cache is None else cache.layers[layer_idx]
|
|
|
|
| 4111 |
)
|
| 4112 |
total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
|
| 4113 |
max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
|
|
|
|
|
|
|
| 4114 |
total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
|
|
|
|
| 4115 |
|
| 4116 |
if output_hidden_states:
|
| 4117 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
| 4125 |
"hidden_states": all_hidden_states,
|
| 4126 |
"load_balance_loss": total_load_balance_loss,
|
| 4127 |
"max_vio": max_vio,
|
|
|
|
|
|
|
| 4128 |
"logit_std": total_logit_std / num_layers,
|
|
|
|
| 4129 |
}
|
| 4130 |
|
| 4131 |
|
|
|
|
| 4142 |
## Python dataclass inheritance constraint: CausalLMOutputWithPast defaults all
|
| 4143 |
## fields to None, which forces every subclass field to also carry a default.
|
| 4144 |
## The = None below is a language constraint, not a semantic statement. In
|
| 4145 |
+
## practice, load_balance_loss, max_vio, and logit_std are always populated
|
| 4146 |
+
## by ShramForCausalLM.forward(). ce_loss is genuinely optional — present
|
| 4147 |
+
## only when labels are supplied.
|
| 4148 |
|
| 4149 |
ce_loss: torch.FloatTensor | None = None
|
| 4150 |
load_balance_loss: torch.FloatTensor | None = None
|
| 4151 |
max_vio: torch.FloatTensor | None = None
|
|
|
|
|
|
|
| 4152 |
logit_std: torch.Tensor | None = None
|
|
|
|
| 4153 |
|
| 4154 |
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
|
| 4155 |
"""HuggingFace-facing causal language model wrapper for SHRAM.
|
|
|
|
| 4598 |
- ``hidden_states`` when requested,
|
| 4599 |
- ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
|
| 4600 |
- ``max_vio`` — detached worst-case routing imbalance across layers,
|
| 4601 |
+
- ``logit_std`` — detached mean per-token routing logit spread across layers.
|
|
|
|
|
|
|
| 4602 |
"""
|
| 4603 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 4604 |
output_hidden_states = (
|
|
|
|
| 4705 |
hidden_states=backbone_outputs["hidden_states"],
|
| 4706 |
load_balance_loss=backbone_outputs["load_balance_loss"],
|
| 4707 |
max_vio=backbone_outputs["max_vio"],
|
|
|
|
|
|
|
| 4708 |
logit_std=backbone_outputs["logit_std"],
|
|
|
|
| 4709 |
)
|