Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "smithblack-0/SHRAM-dev" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files- README.md +1 -2
- config.json +1 -2
- configuration.py +13 -23
- huggingface.py +474 -208
README.md
CHANGED
|
@@ -83,7 +83,6 @@ contains no weights. All values are overridable via kwargs.
|
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
| `load_balance_loss_type` | ce |
|
| 86 |
-
| `load_balance_p` | 1.0 |
|
| 87 |
| `local_rope_theta` | 10000.0 |
|
| 88 |
| `max_bid_rounds` | 10 |
|
| 89 |
| `mlp_width` | 1366 |
|
|
@@ -96,7 +95,7 @@ contains no weights. All values are overridable via kwargs.
|
|
| 96 |
| `output_hidden_states` | False |
|
| 97 |
| `rms_norm_eps` | 1e-05 |
|
| 98 |
| `rope_mode` | main_sequence |
|
| 99 |
-
| `
|
| 100 |
| `tie_word_embeddings` | False |
|
| 101 |
| `training_sequence_length` | 1024 |
|
| 102 |
| `use_cache` | True |
|
|
|
|
| 83 |
| `head_dim` | 16 |
|
| 84 |
| `inference_sequence_length` | 1024 |
|
| 85 |
| `load_balance_loss_type` | ce |
|
|
|
|
| 86 |
| `local_rope_theta` | 10000.0 |
|
| 87 |
| `max_bid_rounds` | 10 |
|
| 88 |
| `mlp_width` | 1366 |
|
|
|
|
| 95 |
| `output_hidden_states` | False |
|
| 96 |
| `rms_norm_eps` | 1e-05 |
|
| 97 |
| `rope_mode` | main_sequence |
|
| 98 |
+
| `routing_mode` | integral |
|
| 99 |
| `tie_word_embeddings` | False |
|
| 100 |
| `training_sequence_length` | 1024 |
|
| 101 |
| `use_cache` | True |
|
config.json
CHANGED
|
@@ -10,7 +10,6 @@
|
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
"load_balance_loss_type": "ce",
|
| 13 |
-
"load_balance_p": 1.0,
|
| 14 |
"local_rope_theta": 10000.0,
|
| 15 |
"max_bid_rounds": 10,
|
| 16 |
"mlp_width": 1366,
|
|
@@ -23,7 +22,7 @@
|
|
| 23 |
"num_sliding_window_heads": 16,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"rope_mode": "main_sequence",
|
| 26 |
-
"
|
| 27 |
"tie_word_embeddings": false,
|
| 28 |
"training_sequence_length": 1024,
|
| 29 |
"transformers_version": "5.10.2",
|
|
|
|
| 10 |
"head_dim": 16,
|
| 11 |
"inference_sequence_length": 1024,
|
| 12 |
"load_balance_loss_type": "ce",
|
|
|
|
| 13 |
"local_rope_theta": 10000.0,
|
| 14 |
"max_bid_rounds": 10,
|
| 15 |
"mlp_width": 1366,
|
|
|
|
| 22 |
"num_sliding_window_heads": 16,
|
| 23 |
"rms_norm_eps": 1e-05,
|
| 24 |
"rope_mode": "main_sequence",
|
| 25 |
+
"routing_mode": "integral",
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
"transformers_version": "5.10.2",
|
configuration.py
CHANGED
|
@@ -84,10 +84,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 84 |
num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
|
| 85 |
Must be > 1.0 to guarantee a buffer larger than the balanced-routing
|
| 86 |
baseline. Default 2.0.
|
| 87 |
-
load_balance_p: Exponent p for the p-mean aggregation of per-item routing
|
| 88 |
-
frequencies into the load balance signal. Higher p weights aggregation
|
| 89 |
-
toward the worst-case batch item, making the correction signal more
|
| 90 |
-
sensitive to per-item allocation spikes. Must be positive. Default 2.0.
|
| 91 |
max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
|
| 92 |
solver in ``balance_capacity``. 10 covers convergence at approximately
|
| 93 |
the 98th percentile of routing densities; the top 2% of extreme-density
|
|
@@ -99,11 +95,13 @@ class ShramConfig(PretrainedConfig):
|
|
| 99 |
is the default; its log-probability signal scales with violation severity
|
| 100 |
and makes correction magnitude proportional to routing imbalance.
|
| 101 |
Default ``"ce"``.
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
"""
|
| 108 |
|
| 109 |
model_type = "shram"
|
|
@@ -137,10 +135,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 137 |
output_hidden_states: bool = False,
|
| 138 |
tie_word_embeddings: bool = False,
|
| 139 |
mosrah_overallocation_factor: float = 2.0,
|
| 140 |
-
load_balance_p: float = 1.0,
|
| 141 |
max_bid_rounds: int = 10,
|
| 142 |
load_balance_loss_type: str = "ce",
|
| 143 |
-
|
| 144 |
**kwargs
|
| 145 |
):
|
| 146 |
if head_dim % 2 != 0:
|
|
@@ -176,11 +173,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 176 |
f"Got {mosrah_overallocation_factor}."
|
| 177 |
)
|
| 178 |
|
| 179 |
-
if load_balance_p <= 0.0:
|
| 180 |
-
raise ValueError(
|
| 181 |
-
f"load_balance_p must be positive, got {load_balance_p}."
|
| 182 |
-
)
|
| 183 |
-
|
| 184 |
if max_bid_rounds < 1:
|
| 185 |
raise ValueError(
|
| 186 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
|
@@ -193,13 +185,12 @@ class ShramConfig(PretrainedConfig):
|
|
| 193 |
f"load_balance_loss_type must be one of {supported}, "
|
| 194 |
f"got {load_balance_loss_type!r}."
|
| 195 |
)
|
| 196 |
-
if load_balance_loss_type == "ce" and load_balance_p != 1.0:
|
| 197 |
-
raise ValueError("In cross entropy mode, aggregation of "
|
| 198 |
-
"frequencies must be with mean 1.0")
|
| 199 |
|
| 200 |
-
|
|
|
|
|
|
|
| 201 |
raise ValueError(
|
| 202 |
-
f"
|
| 203 |
)
|
| 204 |
|
| 205 |
self.vocab_size = vocab_size
|
|
@@ -220,10 +211,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 220 |
self.alpha = alpha
|
| 221 |
self.beta = beta
|
| 222 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 223 |
-
self.load_balance_p = load_balance_p
|
| 224 |
self.max_bid_rounds = max_bid_rounds
|
| 225 |
self.load_balance_loss_type = load_balance_loss_type
|
| 226 |
-
self.
|
| 227 |
self.attention_dropout = attention_dropout
|
| 228 |
self.use_cache = use_cache
|
| 229 |
|
|
|
|
| 84 |
num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
|
| 85 |
Must be > 1.0 to guarantee a buffer larger than the balanced-routing
|
| 86 |
baseline. Default 2.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
|
| 88 |
solver in ``balance_capacity``. 10 covers convergence at approximately
|
| 89 |
the 98th percentile of routing densities; the top 2% of extreme-density
|
|
|
|
| 95 |
is the default; its log-probability signal scales with violation severity
|
| 96 |
and makes correction magnitude proportional to routing imbalance.
|
| 97 |
Default ``"ce"``.
|
| 98 |
+
routing_mode: Routing computation mode. ``"integral"`` (default) enables the
|
| 99 |
+
integral routing extension: the exclusive cumsum of routing logits along
|
| 100 |
+
the sequence dimension is mapped through two additional (L, L) parameter
|
| 101 |
+
matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
|
| 102 |
+
B') and added as corrections to both logit pathways. This gives each
|
| 103 |
+
token a read on the cumulative routing history so far in the sequence.
|
| 104 |
+
``"default"`` disables the extension; A' and B' are not created.
|
| 105 |
"""
|
| 106 |
|
| 107 |
model_type = "shram"
|
|
|
|
| 135 |
output_hidden_states: bool = False,
|
| 136 |
tie_word_embeddings: bool = False,
|
| 137 |
mosrah_overallocation_factor: float = 2.0,
|
|
|
|
| 138 |
max_bid_rounds: int = 10,
|
| 139 |
load_balance_loss_type: str = "ce",
|
| 140 |
+
routing_mode: str = "integral",
|
| 141 |
**kwargs
|
| 142 |
):
|
| 143 |
if head_dim % 2 != 0:
|
|
|
|
| 173 |
f"Got {mosrah_overallocation_factor}."
|
| 174 |
)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
if max_bid_rounds < 1:
|
| 177 |
raise ValueError(
|
| 178 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
|
|
|
| 185 |
f"load_balance_loss_type must be one of {supported}, "
|
| 186 |
f"got {load_balance_loss_type!r}."
|
| 187 |
)
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
_supported_routing_modes = {"default", "integral"}
|
| 190 |
+
if routing_mode not in _supported_routing_modes:
|
| 191 |
+
supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
|
| 192 |
raise ValueError(
|
| 193 |
+
f"routing_mode must be one of {supported}, got {routing_mode!r}."
|
| 194 |
)
|
| 195 |
|
| 196 |
self.vocab_size = vocab_size
|
|
|
|
| 211 |
self.alpha = alpha
|
| 212 |
self.beta = beta
|
| 213 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
|
|
|
| 214 |
self.max_bid_rounds = max_bid_rounds
|
| 215 |
self.load_balance_loss_type = load_balance_loss_type
|
| 216 |
+
self.routing_mode = routing_mode
|
| 217 |
self.attention_dropout = attention_dropout
|
| 218 |
self.use_cache = use_cache
|
| 219 |
|
huggingface.py
CHANGED
|
@@ -45,7 +45,6 @@ from torch.nn.attention.flex_attention import create_block_mask
|
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
| 47 |
from typing import Callable
|
| 48 |
-
from typing import Optional
|
| 49 |
|
| 50 |
|
| 51 |
|
|
@@ -172,10 +171,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 172 |
num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
|
| 173 |
Must be > 1.0 to guarantee a buffer larger than the balanced-routing
|
| 174 |
baseline. Default 2.0.
|
| 175 |
-
load_balance_p: Exponent p for the p-mean aggregation of per-item routing
|
| 176 |
-
frequencies into the load balance signal. Higher p weights aggregation
|
| 177 |
-
toward the worst-case batch item, making the correction signal more
|
| 178 |
-
sensitive to per-item allocation spikes. Must be positive. Default 2.0.
|
| 179 |
max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
|
| 180 |
solver in ``balance_capacity``. 10 covers convergence at approximately
|
| 181 |
the 98th percentile of routing densities; the top 2% of extreme-density
|
|
@@ -187,11 +182,13 @@ class ShramConfig(PretrainedConfig):
|
|
| 187 |
is the default; its log-probability signal scales with violation severity
|
| 188 |
and makes correction magnitude proportional to routing imbalance.
|
| 189 |
Default ``"ce"``.
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
|
|
|
|
|
|
| 195 |
"""
|
| 196 |
|
| 197 |
model_type = "shram"
|
|
@@ -225,10 +222,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 225 |
output_hidden_states: bool = False,
|
| 226 |
tie_word_embeddings: bool = False,
|
| 227 |
mosrah_overallocation_factor: float = 2.0,
|
| 228 |
-
load_balance_p: float = 1.0,
|
| 229 |
max_bid_rounds: int = 10,
|
| 230 |
load_balance_loss_type: str = "ce",
|
| 231 |
-
|
| 232 |
**kwargs
|
| 233 |
):
|
| 234 |
if head_dim % 2 != 0:
|
|
@@ -264,11 +260,6 @@ class ShramConfig(PretrainedConfig):
|
|
| 264 |
f"Got {mosrah_overallocation_factor}."
|
| 265 |
)
|
| 266 |
|
| 267 |
-
if load_balance_p <= 0.0:
|
| 268 |
-
raise ValueError(
|
| 269 |
-
f"load_balance_p must be positive, got {load_balance_p}."
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
if max_bid_rounds < 1:
|
| 273 |
raise ValueError(
|
| 274 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
|
@@ -281,13 +272,12 @@ class ShramConfig(PretrainedConfig):
|
|
| 281 |
f"load_balance_loss_type must be one of {supported}, "
|
| 282 |
f"got {load_balance_loss_type!r}."
|
| 283 |
)
|
| 284 |
-
if load_balance_loss_type == "ce" and load_balance_p != 1.0:
|
| 285 |
-
raise ValueError("In cross entropy mode, aggregation of "
|
| 286 |
-
"frequencies must be with mean 1.0")
|
| 287 |
|
| 288 |
-
|
|
|
|
|
|
|
| 289 |
raise ValueError(
|
| 290 |
-
f"
|
| 291 |
)
|
| 292 |
|
| 293 |
self.vocab_size = vocab_size
|
|
@@ -308,10 +298,9 @@ class ShramConfig(PretrainedConfig):
|
|
| 308 |
self.alpha = alpha
|
| 309 |
self.beta = beta
|
| 310 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
| 311 |
-
self.load_balance_p = load_balance_p
|
| 312 |
self.max_bid_rounds = max_bid_rounds
|
| 313 |
self.load_balance_loss_type = load_balance_loss_type
|
| 314 |
-
self.
|
| 315 |
self.attention_dropout = attention_dropout
|
| 316 |
self.use_cache = use_cache
|
| 317 |
|
|
@@ -2741,24 +2730,53 @@ paper. Given an input hidden state x, the router produces two outputs used downs
|
|
| 2741 |
the semantic routing scores at the selected indices and renormalized to sum to 1
|
| 2742 |
per token.
|
| 2743 |
|
| 2744 |
-
|
| 2745 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2746 |
|
| 2747 |
-
|
| 2748 |
-
|
| 2749 |
-
|
| 2750 |
-
Load balance gradients reach expert_bias; routing_projection.weight is isolated from
|
| 2751 |
-
load balance loss.
|
| 2752 |
|
| 2753 |
-
|
| 2754 |
-
|
| 2755 |
-
a bias-redirected expert could be selected but contribute negligibly to the output because
|
| 2756 |
-
its unbiased preference — and thus its routing_prob — remained near zero.
|
| 2757 |
|
| 2758 |
Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
|
| 2759 |
Post-capacity softmax would invert the load balance gradient for over-capacity experts
|
| 2760 |
-
(near-zero probability after masking signals "increase
|
| 2761 |
-
expert).
|
| 2762 |
|
| 2763 |
The router also computes and returns the load balance loss via a log-probability auxiliary
|
| 2764 |
loss (see load_balance_loss.py). The loss formulation is selected by config; the default
|
|
@@ -2767,10 +2785,11 @@ is cross-entropy.
|
|
| 2767 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2768 |
routing imbalance for the current forward pass:
|
| 2769 |
|
| 2770 |
-
MaxVio = L · max_l(
|
| 2771 |
|
| 2772 |
-
where
|
| 2773 |
-
target. MaxVio is
|
|
|
|
| 2774 |
|
| 2775 |
Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
| 2776 |
"""
|
|
@@ -2785,130 +2804,228 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
|
| 2785 |
# -----------
|
| 2786 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2787 |
|
| 2788 |
-
This module provides three load-balance loss formulations
|
| 2789 |
-
|
| 2790 |
-
|
| 2791 |
-
expert_bias, so only expert_bias receives gradients from the loss signal. The routing
|
| 2792 |
-
projection weights are not reachable from any returned loss.
|
| 2793 |
|
| 2794 |
-
|
| 2795 |
-
|
|
|
|
|
|
|
|
|
|
| 2796 |
|
| 2797 |
-
|
| 2798 |
-
|
| 2799 |
-
|
| 2800 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2801 |
|
| 2802 |
-
|
| 2803 |
|
| 2804 |
-
|
|
|
|
| 2805 |
|
| 2806 |
-
|
| 2807 |
-
|
| 2808 |
-
|
| 2809 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2810 |
"""
|
| 2811 |
|
| 2812 |
|
| 2813 |
|
| 2814 |
|
| 2815 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2816 |
# ---------------------------------------------------------------------------
|
| 2817 |
# Loss functions
|
| 2818 |
# ---------------------------------------------------------------------------
|
| 2819 |
|
| 2820 |
def gshard_loss(
|
| 2821 |
-
|
| 2822 |
-
|
|
|
|
| 2823 |
) -> torch.Tensor:
|
| 2824 |
"""GShard-style linear load-balance loss.
|
| 2825 |
|
| 2826 |
-
Computes (1/L) * Σ
|
| 2827 |
-
|
| 2828 |
-
assignment probability for head i.
|
| 2829 |
|
| 2830 |
-
The
|
| 2831 |
-
|
| 2832 |
-
The linear signal is the weakest of the three formulations — gradient magnitude
|
| 2833 |
-
does not grow with deviation from the target. Provided for comparison.
|
| 2834 |
|
| 2835 |
Args:
|
| 2836 |
-
|
| 2837 |
-
|
| 2838 |
-
|
| 2839 |
|
| 2840 |
Returns:
|
| 2841 |
Scalar loss tensor.
|
| 2842 |
"""
|
| 2843 |
-
L =
|
| 2844 |
-
|
|
|
|
|
|
|
| 2845 |
|
| 2846 |
|
| 2847 |
def ce_loss(
|
| 2848 |
-
|
| 2849 |
-
|
|
|
|
| 2850 |
) -> torch.Tensor:
|
| 2851 |
"""Cross-entropy load-balance loss.
|
| 2852 |
|
| 2853 |
-
|
| 2854 |
-
|
| 2855 |
-
|
| 2856 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2857 |
|
| 2858 |
-
The
|
| 2859 |
-
|
| 2860 |
-
deviates from the target, providing correction that scales with violation severity.
|
| 2861 |
|
| 2862 |
Args:
|
| 2863 |
-
|
| 2864 |
-
|
| 2865 |
-
|
| 2866 |
|
| 2867 |
Returns:
|
| 2868 |
Scalar loss tensor.
|
| 2869 |
"""
|
| 2870 |
-
L =
|
| 2871 |
-
|
| 2872 |
-
|
| 2873 |
-
|
| 2874 |
-
#
|
| 2875 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2876 |
|
| 2877 |
|
| 2878 |
def bce_loss(
|
| 2879 |
-
|
| 2880 |
-
|
|
|
|
| 2881 |
) -> torch.Tensor:
|
| 2882 |
"""Binary cross-entropy load-balance loss.
|
| 2883 |
|
| 2884 |
-
|
| 2885 |
-
|
| 2886 |
-
|
| 2887 |
-
|
|
|
|
|
|
|
|
|
|
| 2888 |
|
| 2889 |
-
|
|
|
|
|
|
|
| 2890 |
|
| 2891 |
Args:
|
| 2892 |
-
|
| 2893 |
-
|
| 2894 |
-
|
| 2895 |
|
| 2896 |
Returns:
|
| 2897 |
Scalar loss tensor.
|
| 2898 |
"""
|
| 2899 |
-
|
| 2900 |
-
|
| 2901 |
-
#
|
| 2902 |
-
#
|
| 2903 |
-
|
| 2904 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2905 |
|
| 2906 |
|
| 2907 |
# ---------------------------------------------------------------------------
|
| 2908 |
# Factory
|
| 2909 |
# ---------------------------------------------------------------------------
|
| 2910 |
|
| 2911 |
-
_LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
|
| 2912 |
"gshard": gshard_loss,
|
| 2913 |
"ce": ce_loss,
|
| 2914 |
"bce": bce_loss,
|
|
@@ -2917,15 +3034,19 @@ _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
|
|
| 2917 |
|
| 2918 |
def make_load_balance_loss(
|
| 2919 |
loss_type: str,
|
| 2920 |
-
) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 2921 |
"""Return a load-balance loss callable for the requested formulation.
|
| 2922 |
|
| 2923 |
-
All returned callables share the
|
| 2924 |
|
| 2925 |
-
loss_fn(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2926 |
|
| 2927 |
-
The caller is responsible for computing
|
| 2928 |
-
|
| 2929 |
|
| 2930 |
Args:
|
| 2931 |
loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
|
|
@@ -2944,55 +3065,77 @@ def make_load_balance_loss(
|
|
| 2944 |
return _LOSS_REGISTRY[loss_type]
|
| 2945 |
|
| 2946 |
|
| 2947 |
-
|
| 2948 |
-
|
| 2949 |
class MoSRAHRouter(nn.Module):
|
| 2950 |
"""Token-choice router for MoSRAH sparse attention.
|
| 2951 |
|
| 2952 |
Each input token independently selects K of the L available expert heads. Both
|
| 2953 |
-
selection and routing_probs incorporate
|
| 2954 |
-
pathways over numerically identical
|
| 2955 |
-
two-pathway architecture.
|
| 2956 |
-
|
| 2957 |
-
|
| 2958 |
-
|
| 2959 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2960 |
|
| 2961 |
Args:
|
| 2962 |
-
config: Model configuration. Must expose ``
|
| 2963 |
-
(L),
|
|
|
|
| 2964 |
"""
|
| 2965 |
|
| 2966 |
def __init__(self, config: ShramConfig) -> None:
|
| 2967 |
super().__init__()
|
| 2968 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 2969 |
self.num_selected_heads = config.num_selected_heads
|
| 2970 |
-
self.load_balance_p = config.load_balance_p
|
| 2971 |
if config.use_cache:
|
| 2972 |
self.capacity = config.mosrah_cache_length
|
| 2973 |
else:
|
| 2974 |
self.capacity = config.mosrah_packed_length
|
| 2975 |
|
| 2976 |
self.max_bid_rounds = config.max_bid_rounds
|
|
|
|
| 2977 |
self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
|
| 2978 |
|
| 2979 |
-
# W_r: routing
|
| 2980 |
-
|
| 2981 |
-
|
|
|
|
|
|
|
| 2982 |
)
|
|
|
|
| 2983 |
|
| 2984 |
-
#
|
| 2985 |
-
#
|
| 2986 |
-
#
|
| 2987 |
-
|
| 2988 |
-
|
| 2989 |
-
torch.full((1,), config.router_init_scale)
|
| 2990 |
)
|
| 2991 |
-
|
| 2992 |
-
|
| 2993 |
-
|
| 2994 |
-
|
| 2995 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2996 |
|
| 2997 |
@staticmethod
|
| 2998 |
def get_best_proposals(
|
|
@@ -3228,7 +3371,7 @@ class MoSRAHRouter(nn.Module):
|
|
| 3228 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 3229 |
|
| 3230 |
Args:
|
| 3231 |
-
x: Input hidden states of shape (batch, seq_len,
|
| 3232 |
active_mask: Current-chunk active mask of shape (batch, seq_len), where
|
| 3233 |
True means the token is semantically live. Dead tokens do not
|
| 3234 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
|
@@ -3244,56 +3387,39 @@ class MoSRAHRouter(nn.Module):
|
|
| 3244 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3245 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3246 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3247 |
-
- ``
|
| 3248 |
-
|
|
|
|
|
|
|
| 3249 |
- ``logit_std``: mean per-token std of semantic_logits; lower than
|
| 3250 |
-
raw_logit_std means
|
| 3251 |
-
- ``bias_alignment``: mean cosine similarity of
|
| 3252 |
-
|
| 3253 |
-
positive means runaway reinforcement.
|
| 3254 |
"""
|
| 3255 |
B, N, _ = x.shape
|
| 3256 |
L = self.num_mosrah_heads
|
| 3257 |
K = self.num_selected_heads
|
| 3258 |
|
| 3259 |
-
|
| 3260 |
-
|
| 3261 |
-
#
|
| 3262 |
-
|
| 3263 |
-
|
| 3264 |
-
#
|
| 3265 |
-
|
| 3266 |
-
|
| 3267 |
-
|
| 3268 |
-
load_balancing_logits = logits.detach() + self.expert_bias # (B, N, L)
|
| 3269 |
-
|
| 3270 |
-
# Diagnostic scalars characterising the load-balance mechanism. Must be
|
| 3271 |
-
# computed here — before balance_capacity injects -1e8 sentinels that
|
| 3272 |
-
# would corrupt std and cosine similarity.
|
| 3273 |
-
bias_std = self.expert_bias.std().detach()
|
| 3274 |
-
raw_logit_std = logits.std(dim=-1).mean().detach()
|
| 3275 |
-
logit_std = semantic_logits.std(dim=-1).mean().detach()
|
| 3276 |
-
bias_alignment = F.cosine_similarity(
|
| 3277 |
-
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3278 |
-
).mean().detach()
|
| 3279 |
-
|
| 3280 |
-
# Assignment probabilities for load balance loss. Computed from load_balancing_logits
|
| 3281 |
-
# before balance_capacity so that -1e8 sentinels do not invert the load balance
|
| 3282 |
-
# gradient for over-capacity experts. active_float is reused below for routing freqs.
|
| 3283 |
-
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 3284 |
-
lb_softmax = F.softmax(load_balancing_logits, dim=-1) # (B, N, L)
|
| 3285 |
-
assignment_probs = (lb_softmax * active_float).sum(dim=(0, 1)) # (L,) unnorm
|
| 3286 |
-
assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
|
| 3287 |
|
| 3288 |
# Pre-capacity semantic softmax for gathering routing_probs. Computed before
|
| 3289 |
# balance_capacity so that gathered probabilities reflect genuine preference
|
| 3290 |
# magnitudes rather than hard-masked sentinel values.
|
| 3291 |
-
routing_scores = F.softmax(semantic_logits, dim=-1) # (B, N, L)
|
| 3292 |
|
| 3293 |
# Capacity-balanced semantic logits for selection. Injects -1e8 into positions
|
| 3294 |
# that would exceed per-expert token budget, enforcing the packing constraint.
|
| 3295 |
balanced_semantic_logits = self.balance_capacity(
|
| 3296 |
-
semantic_logits,
|
| 3297 |
used_capacity,
|
| 3298 |
self.capacity,
|
| 3299 |
self.num_selected_heads,
|
|
@@ -3309,61 +3435,201 @@ class MoSRAHRouter(nn.Module):
|
|
| 3309 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
|
| 3310 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 3311 |
|
| 3312 |
-
#
|
| 3313 |
-
#
|
| 3314 |
-
#
|
| 3315 |
-
#
|
| 3316 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3317 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 3318 |
-
active_assignments = assignment_mask * active_mask.unsqueeze(-1)
|
| 3319 |
-
per_item_counts = active_assignments.sum(dim=1) # (B, L)
|
| 3320 |
-
per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
|
| 3321 |
-
per_item_freqs = per_item_counts / per_item_total # (B, L)
|
| 3322 |
-
|
| 3323 |
-
# p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
|
| 3324 |
-
# p-mean weights aggregation toward the worst-case batch item relative to
|
| 3325 |
-
# arithmetic mean, making the load balance signal sensitive to per-item spikes
|
| 3326 |
-
# that cause packing overflow.
|
| 3327 |
-
p = self.load_balance_p
|
| 3328 |
-
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 3329 |
|
| 3330 |
-
load_balance_loss = self._load_balance_loss(
|
|
|
|
|
|
|
| 3331 |
|
| 3332 |
-
# MaxVio
|
| 3333 |
-
#
|
| 3334 |
-
max_vio = self._compute_max_vio(
|
| 3335 |
|
| 3336 |
router_diagnostics = {
|
| 3337 |
"load_balance_loss": load_balance_loss,
|
| 3338 |
"max_vio": max_vio,
|
| 3339 |
-
|
| 3340 |
-
"raw_logit_std": raw_logit_std,
|
| 3341 |
-
"logit_std": logit_std,
|
| 3342 |
-
"bias_alignment": bias_alignment,
|
| 3343 |
}
|
| 3344 |
return selected_heads, routing_probs, router_diagnostics
|
| 3345 |
|
| 3346 |
@staticmethod
|
| 3347 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3348 |
"""Compute the MaxVio routing-imbalance scalar.
|
| 3349 |
|
| 3350 |
-
MaxVio = L · max_l(
|
| 3351 |
-
|
| 3352 |
-
|
| 3353 |
-
balance
|
| 3354 |
-
|
|
|
|
| 3355 |
|
| 3356 |
-
The result is detached
|
| 3357 |
-
|
| 3358 |
|
| 3359 |
Args:
|
| 3360 |
-
|
| 3361 |
-
|
|
|
|
| 3362 |
|
| 3363 |
Returns:
|
| 3364 |
Detached scalar MaxVio tensor.
|
| 3365 |
"""
|
| 3366 |
-
|
|
|
|
|
|
|
| 3367 |
|
| 3368 |
# -----------
|
| 3369 |
# Inlined from: positions_converter.py
|
|
|
|
| 45 |
from torch.nn.attention.flex_attention import flex_attention
|
| 46 |
import torch.nn.functional as F
|
| 47 |
from typing import Callable
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
|
|
|
|
| 171 |
num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
|
| 172 |
Must be > 1.0 to guarantee a buffer larger than the balanced-routing
|
| 173 |
baseline. Default 2.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
|
| 175 |
solver in ``balance_capacity``. 10 covers convergence at approximately
|
| 176 |
the 98th percentile of routing densities; the top 2% of extreme-density
|
|
|
|
| 182 |
is the default; its log-probability signal scales with violation severity
|
| 183 |
and makes correction magnitude proportional to routing imbalance.
|
| 184 |
Default ``"ce"``.
|
| 185 |
+
routing_mode: Routing computation mode. ``"integral"`` (default) enables the
|
| 186 |
+
integral routing extension: the exclusive cumsum of routing logits along
|
| 187 |
+
the sequence dimension is mapped through two additional (L, L) parameter
|
| 188 |
+
matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
|
| 189 |
+
B') and added as corrections to both logit pathways. This gives each
|
| 190 |
+
token a read on the cumulative routing history so far in the sequence.
|
| 191 |
+
``"default"`` disables the extension; A' and B' are not created.
|
| 192 |
"""
|
| 193 |
|
| 194 |
model_type = "shram"
|
|
|
|
| 222 |
output_hidden_states: bool = False,
|
| 223 |
tie_word_embeddings: bool = False,
|
| 224 |
mosrah_overallocation_factor: float = 2.0,
|
|
|
|
| 225 |
max_bid_rounds: int = 10,
|
| 226 |
load_balance_loss_type: str = "ce",
|
| 227 |
+
routing_mode: str = "integral",
|
| 228 |
**kwargs
|
| 229 |
):
|
| 230 |
if head_dim % 2 != 0:
|
|
|
|
| 260 |
f"Got {mosrah_overallocation_factor}."
|
| 261 |
)
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
if max_bid_rounds < 1:
|
| 264 |
raise ValueError(
|
| 265 |
f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
|
|
|
|
| 272 |
f"load_balance_loss_type must be one of {supported}, "
|
| 273 |
f"got {load_balance_loss_type!r}."
|
| 274 |
)
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
+
_supported_routing_modes = {"default", "integral"}
|
| 277 |
+
if routing_mode not in _supported_routing_modes:
|
| 278 |
+
supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
|
| 279 |
raise ValueError(
|
| 280 |
+
f"routing_mode must be one of {supported}, got {routing_mode!r}."
|
| 281 |
)
|
| 282 |
|
| 283 |
self.vocab_size = vocab_size
|
|
|
|
| 298 |
self.alpha = alpha
|
| 299 |
self.beta = beta
|
| 300 |
self.mosrah_overallocation_factor = mosrah_overallocation_factor
|
|
|
|
| 301 |
self.max_bid_rounds = max_bid_rounds
|
| 302 |
self.load_balance_loss_type = load_balance_loss_type
|
| 303 |
+
self.routing_mode = routing_mode
|
| 304 |
self.attention_dropout = attention_dropout
|
| 305 |
self.use_cache = use_cache
|
| 306 |
|
|
|
|
| 2730 |
the semantic routing scores at the selected indices and renormalized to sum to 1
|
| 2731 |
per token.
|
| 2732 |
|
| 2733 |
+
Base routing uses two learnable projection matrices and two gradient-isolated pathways:
|
| 2734 |
+
|
| 2735 |
+
- routing_weight (A): shape (L, embedding_width). Maps input to per-head routing
|
| 2736 |
+
scores. Receives gradients from task loss; balance_weight is isolated.
|
| 2737 |
+
- balance_weight (B): shape (L, embedding_width). Maps input to per-head load-balance
|
| 2738 |
+
correction scores. Receives gradients from load_balance_loss; routing_weight is
|
| 2739 |
+
isolated.
|
| 2740 |
+
|
| 2741 |
+
The two gradient-isolated base pathways over numerically identical values:
|
| 2742 |
+
|
| 2743 |
+
- semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
|
| 2744 |
+
balance_weight is isolated from task loss.
|
| 2745 |
+
- load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance gradients
|
| 2746 |
+
reach balance_weight; routing_weight and x are isolated from load balance loss.
|
| 2747 |
+
|
| 2748 |
+
Integral routing extension (routing_mode == "integral"):
|
| 2749 |
+
|
| 2750 |
+
Standard routing is parallel — each token routes based on its own hidden state alone,
|
| 2751 |
+
with no direct read on what earlier tokens in the sequence have already selected.
|
| 2752 |
+
Integral routing adds a cumulative-sum signal that gives each token a view of the
|
| 2753 |
+
prior routing history within the sequence.
|
| 2754 |
+
|
| 2755 |
+
Two additional (L, L) parameter matrices are introduced:
|
| 2756 |
+
|
| 2757 |
+
- routing_integral_weight (A'): shape (L, L). Maps the cumulative logit history to
|
| 2758 |
+
per-head semantic corrections. Receives gradients from task loss.
|
| 2759 |
+
- balance_integral_weight (B'): shape (L, L). Maps the cumulative logit history to
|
| 2760 |
+
per-head load-balance corrections. Receives gradients from load_balance_loss.
|
| 2761 |
+
|
| 2762 |
+
The cumulative history signal u is the exclusive cumsum of the base logits along the
|
| 2763 |
+
sequence dimension: u[n] = sum(logits[0..n-1]), shape (B, N, L). Position 0 receives
|
| 2764 |
+
zeros (no prior history). The same gradient isolation pattern as A/B applies:
|
| 2765 |
+
|
| 2766 |
+
- semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
|
| 2767 |
+
- lb_logits += (A'·u_load).detach() + B'·u_load
|
| 2768 |
|
| 2769 |
+
Detaching the full B'·u_semantic result (rather than just B') mirrors the
|
| 2770 |
+
(B·x).detach() pattern in the base pathway and prevents double-counting the
|
| 2771 |
+
cumsum gradient path back to routing_weight.
|
|
|
|
|
|
|
| 2772 |
|
| 2773 |
+
Both base matrices and both integral matrices are nn.Parameter so that HuggingFace
|
| 2774 |
+
_init_weights does not override their kaiming initialization at construction.
|
|
|
|
|
|
|
| 2775 |
|
| 2776 |
Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
|
| 2777 |
Post-capacity softmax would invert the load balance gradient for over-capacity experts
|
| 2778 |
+
(near-zero probability after masking signals "increase corrections" for an already-
|
| 2779 |
+
overloaded expert).
|
| 2780 |
|
| 2781 |
The router also computes and returns the load balance loss via a log-probability auxiliary
|
| 2782 |
loss (see load_balance_loss.py). The loss formulation is selected by config; the default
|
|
|
|
| 2785 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2786 |
routing imbalance for the current forward pass:
|
| 2787 |
|
| 2788 |
+
MaxVio = mean_b( L · max_l(f_bl − 1/L) )
|
| 2789 |
|
| 2790 |
+
where f_bl is the per-batch-item realised routing frequency of head l and 1/L is the
|
| 2791 |
+
perfectly balanced target. MaxVio is averaged over batch items and is a monitoring
|
| 2792 |
+
quantity only; it never contributes gradients.
|
| 2793 |
|
| 2794 |
Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
|
| 2795 |
"""
|
|
|
|
| 2804 |
# -----------
|
| 2805 |
"""Log-probability auxiliary loss functions for MoSRAH load balancing.
|
| 2806 |
|
| 2807 |
+
This module provides three load-balance loss formulations, two token-reduction
|
| 2808 |
+
helpers, and a factory that selects among the formulations. All formulations
|
| 2809 |
+
share the same external contract:
|
|
|
|
|
|
|
| 2810 |
|
| 2811 |
+
loss_fn(
|
| 2812 |
+
logits: Tensor[B, N, L],
|
| 2813 |
+
assignment_mask: Tensor[B, N, L],
|
| 2814 |
+
active_mask: Tensor[B, N],
|
| 2815 |
+
) -> scalar Tensor
|
| 2816 |
|
| 2817 |
+
logits: Load-balancing logits, shape (B, N, L). These are the raw
|
| 2818 |
+
pre-softmax scores from logits.detach() + expert_bias.
|
| 2819 |
+
Gradient flows to expert_bias through this tensor.
|
| 2820 |
+
assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
|
| 2821 |
+
is 1.0 if token (b, n) was assigned to head l. Dead tokens
|
| 2822 |
+
should carry zero entries.
|
| 2823 |
+
active_mask: Boolean mask, shape (B, N). True means the token is
|
| 2824 |
+
semantically live.
|
| 2825 |
|
| 2826 |
+
Token reduction is split into two helpers with distinct roles:
|
| 2827 |
|
| 2828 |
+
reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L).
|
| 2829 |
+
Called by all three formulations. Output is detached; f_bl carries no gradient.
|
| 2830 |
|
| 2831 |
+
reduce_probability_tokens — produces per-batch-item mean assignment probabilities
|
| 2832 |
+
p_bl (B, L). Called only by gshard and bce. Gradient flows to expert_bias
|
| 2833 |
+
through the internal softmax over logits.
|
| 2834 |
+
|
| 2835 |
+
CE delegates probability computation to F.cross_entropy, which handles its own
|
| 2836 |
+
log_softmax and operates directly on the raw (B, N, L) logits.
|
| 2837 |
+
|
| 2838 |
+
The factory is the intended entry point. MoSRAHRouter constructs the loss callable
|
| 2839 |
+
once at init and invokes it each forward pass.
|
| 2840 |
"""
|
| 2841 |
|
| 2842 |
|
| 2843 |
|
| 2844 |
|
| 2845 |
|
| 2846 |
+
|
| 2847 |
+
# ---------------------------------------------------------------------------
|
| 2848 |
+
# Token-reduction helpers
|
| 2849 |
+
# ---------------------------------------------------------------------------
|
| 2850 |
+
|
| 2851 |
+
def reduce_frequency_tokens(
|
| 2852 |
+
assignment_mask: torch.Tensor,
|
| 2853 |
+
active_mask: torch.Tensor,
|
| 2854 |
+
) -> torch.Tensor:
|
| 2855 |
+
"""Reduce per-token head assignments to per-batch-item routing frequencies.
|
| 2856 |
+
|
| 2857 |
+
f_bl[b, l] is the fraction of active-token assignments in batch item b going
|
| 2858 |
+
to head l. Values sum to 1 per batch item when routing is valid.
|
| 2859 |
+
|
| 2860 |
+
The output is detached from the autograd graph: routing frequencies are
|
| 2861 |
+
derived from discrete TopK selections and must not carry gradients.
|
| 2862 |
+
|
| 2863 |
+
Denominators are clamped to 1 to handle the all-dead-tokens edge case.
|
| 2864 |
+
|
| 2865 |
+
Args:
|
| 2866 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 2867 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 2868 |
+
|
| 2869 |
+
Returns:
|
| 2870 |
+
f_bl: Per-batch-item routing frequencies, shape (B, L). Detached.
|
| 2871 |
+
"""
|
| 2872 |
+
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 2873 |
+
active_assignments = assignment_mask * active_float # (B, N, L)
|
| 2874 |
+
assignment_totals = (
|
| 2875 |
+
active_assignments.sum(dim=(1, 2)).clamp(min=1.0).unsqueeze(-1) # (B, 1)
|
| 2876 |
+
)
|
| 2877 |
+
return (active_assignments.sum(dim=1) / assignment_totals).detach() # (B, L)
|
| 2878 |
+
|
| 2879 |
+
|
| 2880 |
+
def reduce_probability_tokens(
|
| 2881 |
+
logits: torch.Tensor,
|
| 2882 |
+
active_mask: torch.Tensor,
|
| 2883 |
+
) -> torch.Tensor:
|
| 2884 |
+
"""Reduce per-token load-balancing logits to per-batch-item assignment probabilities.
|
| 2885 |
+
|
| 2886 |
+
p_bl[b, l] is the mean softmax probability for head l over active tokens in
|
| 2887 |
+
batch item b. Values sum to 1 per batch item. Gradient flows to expert_bias
|
| 2888 |
+
through the internal softmax.
|
| 2889 |
+
|
| 2890 |
+
Denominators are clamped to 1 to handle the all-dead-tokens edge case.
|
| 2891 |
+
|
| 2892 |
+
Args:
|
| 2893 |
+
logits: Load-balancing logits, shape (B, N, L). Gradient flows through.
|
| 2894 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 2895 |
+
|
| 2896 |
+
Returns:
|
| 2897 |
+
p_bl: Per-batch-item mean assignment probabilities, shape (B, L).
|
| 2898 |
+
"""
|
| 2899 |
+
per_token_probs = F.softmax(logits, dim=-1) # (B, N, L)
|
| 2900 |
+
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 2901 |
+
active_count = active_mask.float().sum(dim=1, keepdim=True).clamp(min=1.0) # (B, 1)
|
| 2902 |
+
return (per_token_probs * active_float).sum(dim=1) / active_count # (B, L)
|
| 2903 |
+
|
| 2904 |
+
|
| 2905 |
# ---------------------------------------------------------------------------
|
| 2906 |
# Loss functions
|
| 2907 |
# ---------------------------------------------------------------------------
|
| 2908 |
|
| 2909 |
def gshard_loss(
|
| 2910 |
+
logits: torch.Tensor,
|
| 2911 |
+
assignment_mask: torch.Tensor,
|
| 2912 |
+
active_mask: torch.Tensor,
|
| 2913 |
) -> torch.Tensor:
|
| 2914 |
"""GShard-style linear load-balance loss.
|
| 2915 |
|
| 2916 |
+
Computes (1/L) * Σ_l f_bl * p_bl per batch item, averaged over B, where
|
| 2917 |
+
f_bl comes from reduce_frequency_tokens and p_bl from reduce_probability_tokens.
|
|
|
|
| 2918 |
|
| 2919 |
+
The linear signal is the weakest of the three formulations; gradient magnitude
|
| 2920 |
+
does not grow with violation severity. Provided for comparison.
|
|
|
|
|
|
|
| 2921 |
|
| 2922 |
Args:
|
| 2923 |
+
logits: Load-balancing logits, shape (B, N, L).
|
| 2924 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 2925 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 2926 |
|
| 2927 |
Returns:
|
| 2928 |
Scalar loss tensor.
|
| 2929 |
"""
|
| 2930 |
+
L = logits.shape[-1]
|
| 2931 |
+
f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
|
| 2932 |
+
p_bl = reduce_probability_tokens(logits, active_mask)
|
| 2933 |
+
return (f_bl * p_bl).sum(dim=-1).mean() / L
|
| 2934 |
|
| 2935 |
|
| 2936 |
def ce_loss(
|
| 2937 |
+
logits: torch.Tensor,
|
| 2938 |
+
assignment_mask: torch.Tensor,
|
| 2939 |
+
active_mask: torch.Tensor,
|
| 2940 |
) -> torch.Tensor:
|
| 2941 |
"""Cross-entropy load-balance loss.
|
| 2942 |
|
| 2943 |
+
Constructs per-batch-item soft target distributions from routing frequencies
|
| 2944 |
+
and delegates to F.cross_entropy operating directly on (B, N, L) logits.
|
| 2945 |
+
Inactive tokens receive all-zero targets, producing zero loss and zero gradient.
|
| 2946 |
+
|
| 2947 |
+
The soft target for head l in batch item b is (1 - f_bl) / (L - 1). This
|
| 2948 |
+
distribution sums to 1 per batch item (since Σ_l (1 - f_bl) = L - 1) and
|
| 2949 |
+
weights underloaded heads (low f_bl → high target) more strongly than
|
| 2950 |
+
overloaded ones.
|
| 2951 |
|
| 2952 |
+
The total CE over active tokens is normalised by the active token count rather
|
| 2953 |
+
than B*N to avoid dilution from inactive positions.
|
|
|
|
| 2954 |
|
| 2955 |
Args:
|
| 2956 |
+
logits: Load-balancing logits, shape (B, N, L).
|
| 2957 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 2958 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 2959 |
|
| 2960 |
Returns:
|
| 2961 |
Scalar loss tensor.
|
| 2962 |
"""
|
| 2963 |
+
B, N, L = logits.shape
|
| 2964 |
+
f_bl = reduce_frequency_tokens(assignment_mask, active_mask) # (B, L)
|
| 2965 |
+
active_count = active_mask.float().sum().clamp(min=1.0)
|
| 2966 |
+
|
| 2967 |
+
# Soft target: (1 - f_bl) / (L - 1) for active tokens, zeros for inactive.
|
| 2968 |
+
# Zeros give zero CE loss and zero gradient at inactive positions.
|
| 2969 |
+
target = (1.0 - f_bl) / (L - 1) # (B, L)
|
| 2970 |
+
target_per_token = (
|
| 2971 |
+
target.unsqueeze(1).expand(-1, N, -1) # (B, N, L)
|
| 2972 |
+
* active_mask.float().unsqueeze(-1) # zero inactive
|
| 2973 |
+
)
|
| 2974 |
+
|
| 2975 |
+
# F.cross_entropy requires the class dimension to be dim 1.
|
| 2976 |
+
# Permute (B, N, L) → (B, L, N) to satisfy the (N, C, d) contract.
|
| 2977 |
+
return F.cross_entropy(
|
| 2978 |
+
logits.permute(0, 2, 1), # (B, L, N)
|
| 2979 |
+
target_per_token.permute(0, 2, 1), # (B, L, N)
|
| 2980 |
+
reduction='sum',
|
| 2981 |
+
) / active_count
|
| 2982 |
|
| 2983 |
|
| 2984 |
def bce_loss(
|
| 2985 |
+
logits: torch.Tensor,
|
| 2986 |
+
assignment_mask: torch.Tensor,
|
| 2987 |
+
active_mask: torch.Tensor,
|
| 2988 |
) -> torch.Tensor:
|
| 2989 |
"""Binary cross-entropy load-balance loss.
|
| 2990 |
|
| 2991 |
+
Treats each head as an independent binary target with label (1 - f_bl).
|
| 2992 |
+
Uses reduce_probability_tokens to produce per-batch-item probabilities,
|
| 2993 |
+
then delegates to F.binary_cross_entropy over (B, L) tensors.
|
| 2994 |
+
|
| 2995 |
+
Unlike CE, BCE maintains a repulsion signal from saturated experts: when
|
| 2996 |
+
f_bl → 1 the target → 0, driving p_bl away from 1 and preventing runaway
|
| 2997 |
+
concentration.
|
| 2998 |
|
| 2999 |
+
Active masking is handled inside reduce_frequency_tokens and
|
| 3000 |
+
reduce_probability_tokens, so the (B, L) output tensors already exclude
|
| 3001 |
+
inactive tokens from both frequencies and probabilities.
|
| 3002 |
|
| 3003 |
Args:
|
| 3004 |
+
logits: Load-balancing logits, shape (B, N, L).
|
| 3005 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 3006 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 3007 |
|
| 3008 |
Returns:
|
| 3009 |
Scalar loss tensor.
|
| 3010 |
"""
|
| 3011 |
+
f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
|
| 3012 |
+
p_bl = reduce_probability_tokens(logits, active_mask)
|
| 3013 |
+
# Clamp p_bl for numerical safety: F.binary_cross_entropy requires input in
|
| 3014 |
+
# (0, 1) and will produce inf for exactly 0 or 1. Softmax outputs are
|
| 3015 |
+
# strictly positive in normal operation; the clamp guards the all-dead-tokens
|
| 3016 |
+
# edge case where the mean defaults to zero.
|
| 3017 |
+
return F.binary_cross_entropy(
|
| 3018 |
+
p_bl.clamp(min=1e-7, max=1.0 - 1e-7),
|
| 3019 |
+
1.0 - f_bl,
|
| 3020 |
+
reduction='mean',
|
| 3021 |
+
)
|
| 3022 |
|
| 3023 |
|
| 3024 |
# ---------------------------------------------------------------------------
|
| 3025 |
# Factory
|
| 3026 |
# ---------------------------------------------------------------------------
|
| 3027 |
|
| 3028 |
+
_LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = {
|
| 3029 |
"gshard": gshard_loss,
|
| 3030 |
"ce": ce_loss,
|
| 3031 |
"bce": bce_loss,
|
|
|
|
| 3034 |
|
| 3035 |
def make_load_balance_loss(
|
| 3036 |
loss_type: str,
|
| 3037 |
+
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
|
| 3038 |
"""Return a load-balance loss callable for the requested formulation.
|
| 3039 |
|
| 3040 |
+
All returned callables share the external contract:
|
| 3041 |
|
| 3042 |
+
loss_fn(
|
| 3043 |
+
logits: Tensor[B, N, L],
|
| 3044 |
+
assignment_mask: Tensor[B, N, L],
|
| 3045 |
+
active_mask: Tensor[B, N],
|
| 3046 |
+
) -> scalar Tensor
|
| 3047 |
|
| 3048 |
+
The caller is responsible for computing logits as logits.detach() + expert_bias
|
| 3049 |
+
to ensure gradient isolation to expert_bias.
|
| 3050 |
|
| 3051 |
Args:
|
| 3052 |
loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
|
|
|
|
| 3065 |
return _LOSS_REGISTRY[loss_type]
|
| 3066 |
|
| 3067 |
|
|
|
|
|
|
|
| 3068 |
class MoSRAHRouter(nn.Module):
|
| 3069 |
"""Token-choice router for MoSRAH sparse attention.
|
| 3070 |
|
| 3071 |
Each input token independently selects K of the L available expert heads. Both
|
| 3072 |
+
selection and routing_probs incorporate balance_weight via two gradient-isolated
|
| 3073 |
+
pathways over numerically identical values. See module docstring for the
|
| 3074 |
+
two-pathway architecture and the integral routing extension.
|
| 3075 |
+
|
| 3076 |
+
All four learnable matrices are nn.Parameter rather than nn.Linear so that
|
| 3077 |
+
HuggingFace _init_weights does not override their kaiming initialization at
|
| 3078 |
+
construction.
|
| 3079 |
+
|
| 3080 |
+
Attributes:
|
| 3081 |
+
routing_weight: A, shape (L, embedding_width). Task-loss pathway.
|
| 3082 |
+
balance_weight: B, shape (L, embedding_width). Load-balance pathway.
|
| 3083 |
+
routing_integral_weight: A', shape (L, L). Integral task-loss pathway.
|
| 3084 |
+
Present only when ``routing_mode == "integral"``.
|
| 3085 |
+
balance_integral_weight: B', shape (L, L). Integral load-balance pathway.
|
| 3086 |
+
Present only when ``routing_mode == "integral"``.
|
| 3087 |
+
routing_mode: ``"integral"`` or ``"default"``, from config.
|
| 3088 |
|
| 3089 |
Args:
|
| 3090 |
+
config: Model configuration. Must expose ``embedding_width``,
|
| 3091 |
+
``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and
|
| 3092 |
+
``routing_mode``.
|
| 3093 |
"""
|
| 3094 |
|
| 3095 |
def __init__(self, config: ShramConfig) -> None:
|
| 3096 |
super().__init__()
|
| 3097 |
self.num_mosrah_heads = config.num_mosrah_heads
|
| 3098 |
self.num_selected_heads = config.num_selected_heads
|
|
|
|
| 3099 |
if config.use_cache:
|
| 3100 |
self.capacity = config.mosrah_cache_length
|
| 3101 |
else:
|
| 3102 |
self.capacity = config.mosrah_packed_length
|
| 3103 |
|
| 3104 |
self.max_bid_rounds = config.max_bid_rounds
|
| 3105 |
+
self.routing_mode = config.routing_mode
|
| 3106 |
self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
|
| 3107 |
|
| 3108 |
+
# W_r (A): semantic routing matrix. Maps input (B, N, d) to per-head routing
|
| 3109 |
+
# scores (B, N, L) for selection and routing_probs. nn.Parameter ensures
|
| 3110 |
+
# HuggingFace _init_weights does not override kaiming initialization.
|
| 3111 |
+
self.routing_weight = nn.Parameter(
|
| 3112 |
+
torch.empty(config.num_mosrah_heads, config.embedding_width)
|
| 3113 |
)
|
| 3114 |
+
nn.init.kaiming_uniform_(self.routing_weight)
|
| 3115 |
|
| 3116 |
+
# W_b (B): load-balancing projection matrix. Maps input (B, N, d) to per-head
|
| 3117 |
+
# correction scores (B, N, L). Receives gradients only from load_balance_loss.
|
| 3118 |
+
# nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
|
| 3119 |
+
self.balance_weight = nn.Parameter(
|
| 3120 |
+
torch.empty(config.num_mosrah_heads, config.embedding_width)
|
|
|
|
| 3121 |
)
|
| 3122 |
+
nn.init.kaiming_uniform_(self.balance_weight)
|
| 3123 |
+
|
| 3124 |
+
if self.routing_mode == "integral":
|
| 3125 |
+
L = config.num_mosrah_heads
|
| 3126 |
+
# A': integral semantic matrix. Maps cumulative logit history (B, N, L) to
|
| 3127 |
+
# per-head semantic corrections (B, N, L). Shape (L, L). Receives gradients
|
| 3128 |
+
# from task loss; balance_integral_weight is isolated from task loss.
|
| 3129 |
+
# Zero-initialized so that corrections start at zero and grow from gradient
|
| 3130 |
+
# updates — kaiming init produces corrections that immediately overwhelm the
|
| 3131 |
+
# base routing signal via the cumsum feedback path.
|
| 3132 |
+
self.routing_integral_weight = nn.Parameter(torch.zeros(L, L))
|
| 3133 |
+
|
| 3134 |
+
# B': integral load-balance matrix. Maps cumulative logit history (B, N, L)
|
| 3135 |
+
# to per-head load-balance corrections (B, N, L). Shape (L, L). Receives
|
| 3136 |
+
# gradients from load_balance_loss; routing_integral_weight is isolated.
|
| 3137 |
+
# Zero-initialized for the same reason as routing_integral_weight.
|
| 3138 |
+
self.balance_integral_weight = nn.Parameter(torch.zeros(L, L))
|
| 3139 |
|
| 3140 |
@staticmethod
|
| 3141 |
def get_best_proposals(
|
|
|
|
| 3371 |
"""Route input tokens to K expert heads each and compute routing probabilities.
|
| 3372 |
|
| 3373 |
Args:
|
| 3374 |
+
x: Input hidden states of shape (batch, seq_len, embedding_width).
|
| 3375 |
active_mask: Current-chunk active mask of shape (batch, seq_len), where
|
| 3376 |
True means the token is semantically live. Dead tokens do not
|
| 3377 |
contribute to routing frequencies, load_balance_loss, or max_vio.
|
|
|
|
| 3387 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3388 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3389 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3390 |
+
- ``raw_logit_std``: mean per-token std of routing_logits; natural
|
| 3391 |
+
routing preference scale and baseline for interpreting bias_std.
|
| 3392 |
+
- ``bias_std``: mean per-token std of balance_logits; near-zero
|
| 3393 |
+
means balance corrections have not built up relative to routing scale.
|
| 3394 |
- ``logit_std``: mean per-token std of semantic_logits; lower than
|
| 3395 |
+
raw_logit_std means balance is flattening preferences (healthy correction).
|
| 3396 |
+
- ``bias_alignment``: mean cosine similarity of routing_logits vs
|
| 3397 |
+
balance_logits per token. Negative means balance opposes routing direction
|
| 3398 |
+
(healthy correction); positive means runaway reinforcement.
|
| 3399 |
"""
|
| 3400 |
B, N, _ = x.shape
|
| 3401 |
L = self.num_mosrah_heads
|
| 3402 |
K = self.num_selected_heads
|
| 3403 |
|
| 3404 |
+
logits = self._compute_routing_logits(x, active_mask)
|
| 3405 |
+
|
| 3406 |
+
# Diagnostic scalars characterising the two routing pathways. Must be computed
|
| 3407 |
+
# before balance_capacity injects -1e8 sentinels that would corrupt std and
|
| 3408 |
+
# cosine similarity. Extracted to _compute_bias_diagnostics to keep the forward
|
| 3409 |
+
# body free of non-(B,N,L) reduction logic.
|
| 3410 |
+
bias_diagnostics = self._compute_bias_diagnostics(
|
| 3411 |
+
logits["routing_logits"], logits["balance_logits"], logits["semantic_logits"]
|
| 3412 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3413 |
|
| 3414 |
# Pre-capacity semantic softmax for gathering routing_probs. Computed before
|
| 3415 |
# balance_capacity so that gathered probabilities reflect genuine preference
|
| 3416 |
# magnitudes rather than hard-masked sentinel values.
|
| 3417 |
+
routing_scores = F.softmax(logits["semantic_logits"], dim=-1) # (B, N, L)
|
| 3418 |
|
| 3419 |
# Capacity-balanced semantic logits for selection. Injects -1e8 into positions
|
| 3420 |
# that would exceed per-expert token budget, enforcing the packing constraint.
|
| 3421 |
balanced_semantic_logits = self.balance_capacity(
|
| 3422 |
+
logits["semantic_logits"],
|
| 3423 |
used_capacity,
|
| 3424 |
self.capacity,
|
| 3425 |
self.num_selected_heads,
|
|
|
|
| 3435 |
gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
|
| 3436 |
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 3437 |
|
| 3438 |
+
# assignment_mask: (B, N, L) float — 1.0 at each token's K selected heads, 0 elsewhere.
|
| 3439 |
+
# The discrete routing decision; no gradient flows through it. Passed alongside
|
| 3440 |
+
# load_balancing_logits and active_mask to the loss and max_vio methods, which
|
| 3441 |
+
# own all frequency aggregation and reduction internally.
|
| 3442 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3443 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3444 |
|
| 3445 |
+
load_balance_loss = self._load_balance_loss(
|
| 3446 |
+
logits["load_balancing_logits"], assignment_mask, active_mask
|
| 3447 |
+
)
|
| 3448 |
|
| 3449 |
+
# MaxVio: detached monitoring scalar averaged over batch items. Computed from
|
| 3450 |
+
# the same (B, N, L) assignment_mask so frequencies are consistent with the loss.
|
| 3451 |
+
max_vio = self._compute_max_vio(assignment_mask, active_mask, L)
|
| 3452 |
|
| 3453 |
router_diagnostics = {
|
| 3454 |
"load_balance_loss": load_balance_loss,
|
| 3455 |
"max_vio": max_vio,
|
| 3456 |
+
**bias_diagnostics,
|
|
|
|
|
|
|
|
|
|
| 3457 |
}
|
| 3458 |
return selected_heads, routing_probs, router_diagnostics
|
| 3459 |
|
| 3460 |
@staticmethod
|
| 3461 |
+
def exclusive_cumsum(logits: torch.Tensor) -> torch.Tensor:
|
| 3462 |
+
"""Compute the exclusive cumulative sum along the sequence dimension.
|
| 3463 |
+
|
| 3464 |
+
u[n] = sum(logits[0..n-1]): position n receives the accumulated sum of all
|
| 3465 |
+
prior positions, giving it a read on the routing preferences expressed by
|
| 3466 |
+
earlier tokens in the sequence. Position 0 always receives zeros — no prior
|
| 3467 |
+
history exists at the first position.
|
| 3468 |
+
|
| 3469 |
+
Args:
|
| 3470 |
+
logits: Shape (B, N, L). Any per-head score tensor along a sequence.
|
| 3471 |
+
|
| 3472 |
+
Returns:
|
| 3473 |
+
Exclusive cumsum, shape (B, N, L). Same dtype and device as input.
|
| 3474 |
+
"""
|
| 3475 |
+
shifted = torch.cat(
|
| 3476 |
+
[torch.zeros_like(logits[:, :1, :]), logits[:, :-1, :]], dim=1
|
| 3477 |
+
)
|
| 3478 |
+
return shifted.cumsum(dim=1)
|
| 3479 |
+
|
| 3480 |
+
def _compute_routing_logits(
|
| 3481 |
+
self, x: torch.Tensor, active_mask: torch.Tensor
|
| 3482 |
+
) -> dict[str, torch.Tensor]:
|
| 3483 |
+
"""Compute the gradient-isolated logit pathways from input hidden states.
|
| 3484 |
+
|
| 3485 |
+
Base pathways (both modes):
|
| 3486 |
+
|
| 3487 |
+
Two gradient-isolated pathways over numerically identical values:
|
| 3488 |
+
- semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
|
| 3489 |
+
balance_weight is isolated from task loss.
|
| 3490 |
+
- load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance
|
| 3491 |
+
gradients reach balance_weight; routing_weight and x are isolated.
|
| 3492 |
+
|
| 3493 |
+
Integral extension (routing_mode == "integral"):
|
| 3494 |
+
|
| 3495 |
+
Dead tokens are zeroed out of the logits before computing the cumsum, so
|
| 3496 |
+
inactive positions do not contribute to the routing history of downstream
|
| 3497 |
+
live tokens. u_semantic and u_load therefore represent history from live
|
| 3498 |
+
tokens only.
|
| 3499 |
+
|
| 3500 |
+
u_semantic = exclusive_cumsum(semantic_logits * active_mask) — (B, N, L)
|
| 3501 |
+
u_load = exclusive_cumsum(load_balancing_logits * active_mask) — (B, N, L)
|
| 3502 |
+
|
| 3503 |
+
semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
|
| 3504 |
+
load_balancing_logits += (A'·u_load).detach() + B'·u_load
|
| 3505 |
+
|
| 3506 |
+
Detaching the full (B'·u_semantic) result mirrors the (B·x).detach() base
|
| 3507 |
+
pattern: it isolates balance_integral_weight from task loss AND prevents
|
| 3508 |
+
double-counting the cumsum gradient path back to routing_weight.
|
| 3509 |
+
The same reasoning applies to (A'·u_load).detach() in the load-balance
|
| 3510 |
+
pathway — u_load already has no path to routing_weight (routing_logits is
|
| 3511 |
+
detached in load_balancing_logits), and the detach additionally blocks
|
| 3512 |
+
routing_integral_weight.
|
| 3513 |
+
|
| 3514 |
+
Args:
|
| 3515 |
+
x: Input hidden states, shape (batch, seq_len, embedding_width).
|
| 3516 |
+
active_mask: Boolean active-token mask, shape (batch, seq_len). Dead tokens
|
| 3517 |
+
are excluded from the cumsum history in integral mode.
|
| 3518 |
+
|
| 3519 |
+
Returns:
|
| 3520 |
+
Dict with keys:
|
| 3521 |
+
- ``routing_logits``: A·x, shape (B, N, L).
|
| 3522 |
+
- ``balance_logits``: B·x, shape (B, N, L).
|
| 3523 |
+
- ``semantic_logits``: combined task-loss pathway, shape (B, N, L).
|
| 3524 |
+
- ``load_balancing_logits``: combined load-balance pathway, shape (B, N, L).
|
| 3525 |
+
"""
|
| 3526 |
+
routing_logits = F.linear(x, self.routing_weight) # (B, N, L)
|
| 3527 |
+
balance_logits = F.linear(x, self.balance_weight) # (B, N, L)
|
| 3528 |
+
semantic_logits = routing_logits + balance_logits.detach()
|
| 3529 |
+
load_balancing_logits = routing_logits.detach() + F.linear(x.detach(), self.balance_weight)
|
| 3530 |
+
|
| 3531 |
+
if self.routing_mode == "integral":
|
| 3532 |
+
# Zero out dead token positions before cumsum so inactive tokens do not
|
| 3533 |
+
# contaminate the routing history of subsequent live tokens.
|
| 3534 |
+
live = active_mask.unsqueeze(-1) # (B, N, 1)
|
| 3535 |
+
u_semantic = self.exclusive_cumsum(semantic_logits * live) # (B, N, L)
|
| 3536 |
+
u_load = self.exclusive_cumsum(load_balancing_logits * live) # (B, N, L)
|
| 3537 |
+
|
| 3538 |
+
# Semantic pathway: A' trains on task loss; B' term is fully detached to
|
| 3539 |
+
# isolate balance_integral_weight from task loss and prevent double-counting
|
| 3540 |
+
# the cumsum gradient path back to routing_weight.
|
| 3541 |
+
semantic_logits = (
|
| 3542 |
+
semantic_logits
|
| 3543 |
+
+ F.linear(u_semantic, self.routing_integral_weight)
|
| 3544 |
+
+ F.linear(u_semantic, self.balance_integral_weight).detach()
|
| 3545 |
+
)
|
| 3546 |
+
|
| 3547 |
+
# Load-balance pathway: B' trains on load_balance_loss; A' term is fully
|
| 3548 |
+
# detached to isolate routing_integral_weight from load_balance_loss.
|
| 3549 |
+
load_balancing_logits = (
|
| 3550 |
+
load_balancing_logits
|
| 3551 |
+
+ F.linear(u_load, self.routing_integral_weight).detach()
|
| 3552 |
+
+ F.linear(u_load, self.balance_integral_weight)
|
| 3553 |
+
)
|
| 3554 |
+
|
| 3555 |
+
return {
|
| 3556 |
+
"routing_logits": routing_logits,
|
| 3557 |
+
"balance_logits": balance_logits,
|
| 3558 |
+
"semantic_logits": semantic_logits,
|
| 3559 |
+
"load_balancing_logits": load_balancing_logits,
|
| 3560 |
+
}
|
| 3561 |
+
|
| 3562 |
+
@staticmethod
|
| 3563 |
+
def _compute_bias_diagnostics(
|
| 3564 |
+
routing_logits: torch.Tensor,
|
| 3565 |
+
balance_logits: torch.Tensor,
|
| 3566 |
+
semantic_logits: torch.Tensor,
|
| 3567 |
+
) -> dict[str, torch.Tensor]:
|
| 3568 |
+
"""Compute detached diagnostic scalars characterising the two routing pathways.
|
| 3569 |
+
|
| 3570 |
+
All scalars must be computed from pre-capacity logits; balance_capacity
|
| 3571 |
+
applies -1e8 sentinels that would corrupt std and cosine similarity.
|
| 3572 |
+
Extracted from forward to keep the main body free of reduction logic.
|
| 3573 |
+
|
| 3574 |
+
Args:
|
| 3575 |
+
routing_logits: A·x, routing pathway output, shape (B, N, L).
|
| 3576 |
+
balance_logits: B·x, balance pathway output, shape (B, N, L).
|
| 3577 |
+
semantic_logits: A·x + (B·x).detach(), combined signal, shape (B, N, L).
|
| 3578 |
+
|
| 3579 |
+
Returns:
|
| 3580 |
+
Dict with keys:
|
| 3581 |
+
- ``raw_logit_std``: Mean per-token std of routing_logits. Natural
|
| 3582 |
+
routing preference scale; reference baseline for
|
| 3583 |
+
interpreting bias_std.
|
| 3584 |
+
- ``bias_std``: Mean per-token std of balance_logits. Near-zero
|
| 3585 |
+
means balance corrections have not built up
|
| 3586 |
+
relative to the routing scale.
|
| 3587 |
+
- ``logit_std``: Mean per-token std of semantic_logits. Lower than
|
| 3588 |
+
raw_logit_std indicates balance is flattening
|
| 3589 |
+
preferences (healthy correction signal).
|
| 3590 |
+
- ``bias_alignment``: Mean cosine similarity of routing_logits vs
|
| 3591 |
+
balance_logits per token. Range [-1, 1]. Negative
|
| 3592 |
+
means balance opposes routing direction (healthy
|
| 3593 |
+
correction); positive means runaway reinforcement.
|
| 3594 |
+
"""
|
| 3595 |
+
return {
|
| 3596 |
+
"raw_logit_std": routing_logits.std(dim=-1).mean().detach(),
|
| 3597 |
+
"bias_std": balance_logits.std(dim=-1).mean().detach(),
|
| 3598 |
+
"logit_std": semantic_logits.std(dim=-1).mean().detach(),
|
| 3599 |
+
"bias_alignment": F.cosine_similarity(
|
| 3600 |
+
routing_logits, balance_logits, dim=-1
|
| 3601 |
+
).mean().detach(),
|
| 3602 |
+
}
|
| 3603 |
+
|
| 3604 |
+
@staticmethod
|
| 3605 |
+
def _compute_max_vio(
|
| 3606 |
+
assignment_mask: torch.Tensor,
|
| 3607 |
+
active_mask: torch.Tensor,
|
| 3608 |
+
num_heads: int,
|
| 3609 |
+
) -> torch.Tensor:
|
| 3610 |
"""Compute the MaxVio routing-imbalance scalar.
|
| 3611 |
|
| 3612 |
+
MaxVio = mean_b( L · max_l(f_bl − 1/L) ), where f_bl is the per-batch-item
|
| 3613 |
+
realised routing frequency of head l. Uses reduce_frequency_tokens for consistent
|
| 3614 |
+
per-batch-item frequency computation with dead tokens excluded, matching how the
|
| 3615 |
+
load balance loss computes frequencies. A value of zero indicates perfect balance;
|
| 3616 |
+
a value of 0.5 means the most overloaded head in the average batch item received
|
| 3617 |
+
50% more routed tokens than ideal.
|
| 3618 |
|
| 3619 |
+
The result is detached — MaxVio is a monitoring scalar and must not contribute
|
| 3620 |
+
gradients to any parameter.
|
| 3621 |
|
| 3622 |
Args:
|
| 3623 |
+
assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
|
| 3624 |
+
active_mask: Boolean active-token mask, shape (B, N).
|
| 3625 |
+
num_heads: Total number of MoSRAH heads L.
|
| 3626 |
|
| 3627 |
Returns:
|
| 3628 |
Detached scalar MaxVio tensor.
|
| 3629 |
"""
|
| 3630 |
+
f_bl = reduce_frequency_tokens(assignment_mask, active_mask) # (B, L)
|
| 3631 |
+
per_item_max_vio = num_heads * (f_bl - 1.0 / num_heads).max(dim=-1).values # (B,)
|
| 3632 |
+
return per_item_max_vio.mean().detach()
|
| 3633 |
|
| 3634 |
# -----------
|
| 3635 |
# Inlined from: positions_converter.py
|