Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM-dev with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "smithblack-0/SHRAM-dev" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "smithblack-0/SHRAM-dev" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "smithblack-0/SHRAM-dev", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
Update architecture and tokenizer
Browse files- README.md +1 -0
- config.json +1 -0
- configuration.py +12 -0
- huggingface.py +94 -56
README.md
CHANGED
|
@@ -96,6 +96,7 @@ contains no weights. All values are overridable via kwargs.
|
|
| 96 |
| `output_hidden_states` | False |
|
| 97 |
| `rms_norm_eps` | 1e-05 |
|
| 98 |
| `rope_mode` | main_sequence |
|
|
|
|
| 99 |
| `tie_word_embeddings` | False |
|
| 100 |
| `training_sequence_length` | 1024 |
|
| 101 |
| `use_cache` | True |
|
|
|
|
| 96 |
| `output_hidden_states` | False |
|
| 97 |
| `rms_norm_eps` | 1e-05 |
|
| 98 |
| `rope_mode` | main_sequence |
|
| 99 |
+
| `router_init_scale` | 0.0001 |
|
| 100 |
| `tie_word_embeddings` | False |
|
| 101 |
| `training_sequence_length` | 1024 |
|
| 102 |
| `use_cache` | True |
|
config.json
CHANGED
|
@@ -23,6 +23,7 @@
|
|
| 23 |
"num_sliding_window_heads": 16,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"rope_mode": "main_sequence",
|
|
|
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"training_sequence_length": 1024,
|
| 28 |
"transformers_version": "5.10.2",
|
|
|
|
| 23 |
"num_sliding_window_heads": 16,
|
| 24 |
"rms_norm_eps": 1e-05,
|
| 25 |
"rope_mode": "main_sequence",
|
| 26 |
+
"router_init_scale": 0.0001,
|
| 27 |
"tie_word_embeddings": false,
|
| 28 |
"training_sequence_length": 1024,
|
| 29 |
"transformers_version": "5.10.2",
|
configuration.py
CHANGED
|
@@ -99,6 +99,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 99 |
is the default; its log-probability signal scales with violation severity
|
| 100 |
and makes correction magnitude proportional to routing imbalance.
|
| 101 |
Default ``"ce"``.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
"""
|
| 103 |
|
| 104 |
model_type = "shram"
|
|
@@ -135,6 +140,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 135 |
load_balance_p: float = 1.0,
|
| 136 |
max_bid_rounds: int = 10,
|
| 137 |
load_balance_loss_type: str = "ce",
|
|
|
|
| 138 |
**kwargs
|
| 139 |
):
|
| 140 |
if head_dim % 2 != 0:
|
|
@@ -191,6 +197,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 191 |
raise ValueError("In cross entropy mode, aggregation of "
|
| 192 |
"frequencies must be with mean 1.0")
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
self.vocab_size = vocab_size
|
| 195 |
self.embedding_width = embedding_width
|
| 196 |
self.mlp_width = mlp_width
|
|
@@ -212,6 +223,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 212 |
self.load_balance_p = load_balance_p
|
| 213 |
self.max_bid_rounds = max_bid_rounds
|
| 214 |
self.load_balance_loss_type = load_balance_loss_type
|
|
|
|
| 215 |
self.attention_dropout = attention_dropout
|
| 216 |
self.use_cache = use_cache
|
| 217 |
|
|
|
|
| 99 |
is the default; its log-probability signal scales with violation severity
|
| 100 |
and makes correction magnitude proportional to routing imbalance.
|
| 101 |
Default ``"ce"``.
|
| 102 |
+
router_init_scale: Initial standard deviation for the ``routing_scale``
|
| 103 |
+
scalar gate on routing logits. Brings routing logit magnitude to
|
| 104 |
+
``expert_bias`` scale at initialisation so load balancing is operative
|
| 105 |
+
from step one. Must be positive. Default ``1e-4``. Note lower values
|
| 106 |
+
may require more bidding rounds to converge and more overcapacity to support.
|
| 107 |
"""
|
| 108 |
|
| 109 |
model_type = "shram"
|
|
|
|
| 140 |
load_balance_p: float = 1.0,
|
| 141 |
max_bid_rounds: int = 10,
|
| 142 |
load_balance_loss_type: str = "ce",
|
| 143 |
+
router_init_scale: float = 1e-4,
|
| 144 |
**kwargs
|
| 145 |
):
|
| 146 |
if head_dim % 2 != 0:
|
|
|
|
| 197 |
raise ValueError("In cross entropy mode, aggregation of "
|
| 198 |
"frequencies must be with mean 1.0")
|
| 199 |
|
| 200 |
+
if router_init_scale <= 0.0:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"router_init_scale must be positive, got {router_init_scale}."
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
self.vocab_size = vocab_size
|
| 206 |
self.embedding_width = embedding_width
|
| 207 |
self.mlp_width = mlp_width
|
|
|
|
| 223 |
self.load_balance_p = load_balance_p
|
| 224 |
self.max_bid_rounds = max_bid_rounds
|
| 225 |
self.load_balance_loss_type = load_balance_loss_type
|
| 226 |
+
self.router_init_scale = router_init_scale
|
| 227 |
self.attention_dropout = attention_dropout
|
| 228 |
self.use_cache = use_cache
|
| 229 |
|
huggingface.py
CHANGED
|
@@ -187,6 +187,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 187 |
is the default; its log-probability signal scales with violation severity
|
| 188 |
and makes correction magnitude proportional to routing imbalance.
|
| 189 |
Default ``"ce"``.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
"""
|
| 191 |
|
| 192 |
model_type = "shram"
|
|
@@ -223,6 +228,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 223 |
load_balance_p: float = 1.0,
|
| 224 |
max_bid_rounds: int = 10,
|
| 225 |
load_balance_loss_type: str = "ce",
|
|
|
|
| 226 |
**kwargs
|
| 227 |
):
|
| 228 |
if head_dim % 2 != 0:
|
|
@@ -279,6 +285,11 @@ class ShramConfig(PretrainedConfig):
|
|
| 279 |
raise ValueError("In cross entropy mode, aggregation of "
|
| 280 |
"frequencies must be with mean 1.0")
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
self.vocab_size = vocab_size
|
| 283 |
self.embedding_width = embedding_width
|
| 284 |
self.mlp_width = mlp_width
|
|
@@ -300,6 +311,7 @@ class ShramConfig(PretrainedConfig):
|
|
| 300 |
self.load_balance_p = load_balance_p
|
| 301 |
self.max_bid_rounds = max_bid_rounds
|
| 302 |
self.load_balance_loss_type = load_balance_loss_type
|
|
|
|
| 303 |
self.attention_dropout = attention_dropout
|
| 304 |
self.use_cache = use_cache
|
| 305 |
|
|
@@ -2724,19 +2736,33 @@ This module implements the routing mechanism described in Appendix A.Routing of
|
|
| 2724 |
paper. Given an input hidden state x, the router produces two outputs used downstream:
|
| 2725 |
|
| 2726 |
- selected_heads (I): which K of the L available expert heads each token routes to,
|
| 2727 |
-
determined by TopK over
|
| 2728 |
- routing_probs (P): the weights used for the weighted output reduction, gathered from
|
| 2729 |
-
|
| 2730 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2731 |
|
| 2732 |
-
|
| 2733 |
-
|
| 2734 |
-
|
|
|
|
| 2735 |
|
| 2736 |
The router also computes and returns the load balance loss via a log-probability auxiliary
|
| 2737 |
loss (see load_balance_loss.py). The loss formulation is selected by config; the default
|
| 2738 |
-
is cross-entropy.
|
| 2739 |
-
isolated by detaching logits before computing assignment probabilities.
|
| 2740 |
|
| 2741 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2742 |
routing imbalance for the current forward pass:
|
|
@@ -2923,14 +2949,14 @@ def make_load_balance_loss(
|
|
| 2923 |
class MoSRAHRouter(nn.Module):
|
| 2924 |
"""Token-choice router for MoSRAH sparse attention.
|
| 2925 |
|
| 2926 |
-
Each input token independently selects K of the L available expert heads.
|
| 2927 |
-
|
| 2928 |
-
|
| 2929 |
-
|
| 2930 |
|
| 2931 |
The routing projection W_r has no bias term — the paper specifies xW_r with no
|
| 2932 |
additional projection bias. The only bias-like parameter is expert_bias (b), which
|
| 2933 |
-
has an entirely separate role and
|
| 2934 |
|
| 2935 |
Args:
|
| 2936 |
config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
|
|
@@ -2955,9 +2981,17 @@ class MoSRAHRouter(nn.Module):
|
|
| 2955 |
config.embedding_width, config.num_mosrah_heads, bias=False
|
| 2956 |
)
|
| 2957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2958 |
# b: learned per-head bias for load balancing. Initialized to zero so that all
|
| 2959 |
# heads start with equal selection probability. Updated by the main optimizer
|
| 2960 |
-
# via the
|
| 2961 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2962 |
|
| 2963 |
@staticmethod
|
|
@@ -3202,16 +3236,17 @@ class MoSRAHRouter(nn.Module):
|
|
| 3202 |
|
| 3203 |
Returns:
|
| 3204 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3205 |
-
Each token's K selected head indices, determined by TopK on
|
|
|
|
| 3206 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3207 |
-
num_selected_heads). Gathered from
|
| 3208 |
-
indices and renormalized to sum to 1 per token.
|
| 3209 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3210 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3211 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3212 |
- ``bias_std``: std of expert_bias; near-zero means corrections have not built up.
|
| 3213 |
-
- ``raw_logit_std``: mean per-token std of
|
| 3214 |
-
- ``logit_std``: mean per-token std of
|
| 3215 |
raw_logit_std means bias is flattening preferences (healthy correction).
|
| 3216 |
- ``bias_alignment``: mean cosine similarity of expert_bias against per-token
|
| 3217 |
logits. Negative means bias opposes routing direction (healthy correction);
|
|
@@ -3221,48 +3256,58 @@ class MoSRAHRouter(nn.Module):
|
|
| 3221 |
L = self.num_mosrah_heads
|
| 3222 |
K = self.num_selected_heads
|
| 3223 |
|
| 3224 |
-
#
|
| 3225 |
-
#
|
| 3226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3227 |
|
| 3228 |
# Diagnostic scalars characterising the load-balance mechanism. Must be
|
| 3229 |
# computed here — before balance_capacity injects -1e8 sentinels that
|
| 3230 |
# would corrupt std and cosine similarity.
|
| 3231 |
-
bias_std
|
| 3232 |
-
raw_logit_std
|
| 3233 |
-
logit_std
|
| 3234 |
bias_alignment = F.cosine_similarity(
|
| 3235 |
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3236 |
).mean().detach()
|
| 3237 |
|
| 3238 |
-
#
|
| 3239 |
-
|
| 3240 |
-
|
| 3241 |
-
|
| 3242 |
-
|
| 3243 |
-
|
| 3244 |
-
|
| 3245 |
-
|
| 3246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3247 |
used_capacity,
|
| 3248 |
self.capacity,
|
| 3249 |
self.num_selected_heads,
|
| 3250 |
self.max_bid_rounds,
|
| 3251 |
)
|
| 3252 |
-
|
| 3253 |
-
biased_logits, dim=-1
|
| 3254 |
-
)
|
| 3255 |
|
| 3256 |
-
# selected_heads I = TopK
|
| 3257 |
-
|
| 3258 |
-
selected_heads = biased_routing_scores.topk(K, dim=-1).indices
|
| 3259 |
-
gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
|
| 3260 |
|
| 3261 |
-
# Routing probabilities P: gathered from
|
| 3262 |
-
#
|
| 3263 |
-
|
| 3264 |
-
|
| 3265 |
-
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
| 3266 |
|
| 3267 |
# Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
|
| 3268 |
# fraction of that item's active K assignments over all tokens go to head l.
|
|
@@ -3271,9 +3316,9 @@ class MoSRAHRouter(nn.Module):
|
|
| 3271 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3272 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 3273 |
active_assignments = assignment_mask * active_mask.unsqueeze(-1)
|
| 3274 |
-
per_item_counts = active_assignments.sum(dim=1)
|
| 3275 |
-
per_item_total
|
| 3276 |
-
per_item_freqs
|
| 3277 |
|
| 3278 |
# p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
|
| 3279 |
# p-mean weights aggregation toward the worst-case batch item relative to
|
|
@@ -3282,13 +3327,6 @@ class MoSRAHRouter(nn.Module):
|
|
| 3282 |
p = self.load_balance_p
|
| 3283 |
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 3284 |
|
| 3285 |
-
# Active-token mean softmax probabilities. Detaching logits before softmax
|
| 3286 |
-
# ensures the only differentiable path into p is through expert_bias — the
|
| 3287 |
-
# load balance loss cannot reach routing_projection.weight.
|
| 3288 |
-
biased_probs = biased_routing_scores # (B, N, L)
|
| 3289 |
-
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 3290 |
-
assignment_probs = (biased_probs * active_float).sum(dim=(0, 1)) # (L,) unnorm
|
| 3291 |
-
assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
|
| 3292 |
load_balance_loss = self._load_balance_loss(routing_freqs, assignment_probs)
|
| 3293 |
|
| 3294 |
# MaxVio is a detached monitoring scalar following the paper's formula
|
|
|
|
| 187 |
is the default; its log-probability signal scales with violation severity
|
| 188 |
and makes correction magnitude proportional to routing imbalance.
|
| 189 |
Default ``"ce"``.
|
| 190 |
+
router_init_scale: Initial standard deviation for the ``routing_scale``
|
| 191 |
+
scalar gate on routing logits. Brings routing logit magnitude to
|
| 192 |
+
``expert_bias`` scale at initialisation so load balancing is operative
|
| 193 |
+
from step one. Must be positive. Default ``1e-4``. Note lower values
|
| 194 |
+
may require more bidding rounds to converge and more overcapacity to support.
|
| 195 |
"""
|
| 196 |
|
| 197 |
model_type = "shram"
|
|
|
|
| 228 |
load_balance_p: float = 1.0,
|
| 229 |
max_bid_rounds: int = 10,
|
| 230 |
load_balance_loss_type: str = "ce",
|
| 231 |
+
router_init_scale: float = 1e-4,
|
| 232 |
**kwargs
|
| 233 |
):
|
| 234 |
if head_dim % 2 != 0:
|
|
|
|
| 285 |
raise ValueError("In cross entropy mode, aggregation of "
|
| 286 |
"frequencies must be with mean 1.0")
|
| 287 |
|
| 288 |
+
if router_init_scale <= 0.0:
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"router_init_scale must be positive, got {router_init_scale}."
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
self.vocab_size = vocab_size
|
| 294 |
self.embedding_width = embedding_width
|
| 295 |
self.mlp_width = mlp_width
|
|
|
|
| 311 |
self.load_balance_p = load_balance_p
|
| 312 |
self.max_bid_rounds = max_bid_rounds
|
| 313 |
self.load_balance_loss_type = load_balance_loss_type
|
| 314 |
+
self.router_init_scale = router_init_scale
|
| 315 |
self.attention_dropout = attention_dropout
|
| 316 |
self.use_cache = use_cache
|
| 317 |
|
|
|
|
| 2736 |
paper. Given an input hidden state x, the router produces two outputs used downstream:
|
| 2737 |
|
| 2738 |
- selected_heads (I): which K of the L available expert heads each token routes to,
|
| 2739 |
+
determined by TopK over capacity-balanced semantic routing scores.
|
| 2740 |
- routing_probs (P): the weights used for the weighted output reduction, gathered from
|
| 2741 |
+
the semantic routing scores at the selected indices and renormalized to sum to 1
|
| 2742 |
+
per token.
|
| 2743 |
+
|
| 2744 |
+
Routing computation uses two gradient-isolated pathways over numerically identical
|
| 2745 |
+
biased values:
|
| 2746 |
+
|
| 2747 |
+
- semantic_logits = logits + expert_bias.detach(): drives selection and routing_probs.
|
| 2748 |
+
Task gradients reach routing_projection.weight; expert_bias is isolated from task loss.
|
| 2749 |
+
- load_balancing_logits = logits.detach() + expert_bias: drives assignment_probs.
|
| 2750 |
+
Load balance gradients reach expert_bias; routing_projection.weight is isolated from
|
| 2751 |
+
load balance loss.
|
| 2752 |
+
|
| 2753 |
+
No unbiased routing computation exists. All routing uses biased values. The separation
|
| 2754 |
+
of gradient paths replaces the previous biased/unbiased split, closing the loophole where
|
| 2755 |
+
a bias-redirected expert could be selected but contribute negligibly to the output because
|
| 2756 |
+
its unbiased preference — and thus its routing_prob — remained near zero.
|
| 2757 |
|
| 2758 |
+
Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
|
| 2759 |
+
Post-capacity softmax would invert the load balance gradient for over-capacity experts
|
| 2760 |
+
(near-zero probability after masking signals "increase bias" for an already-overloaded
|
| 2761 |
+
expert).
|
| 2762 |
|
| 2763 |
The router also computes and returns the load balance loss via a log-probability auxiliary
|
| 2764 |
loss (see load_balance_loss.py). The loss formulation is selected by config; the default
|
| 2765 |
+
is cross-entropy.
|
|
|
|
| 2766 |
|
| 2767 |
The router additionally computes and returns MaxVio, a detached scalar summarising
|
| 2768 |
routing imbalance for the current forward pass:
|
|
|
|
| 2949 |
class MoSRAHRouter(nn.Module):
|
| 2950 |
"""Token-choice router for MoSRAH sparse attention.
|
| 2951 |
|
| 2952 |
+
Each input token independently selects K of the L available expert heads. Both
|
| 2953 |
+
selection and routing_probs incorporate expert_bias via two gradient-isolated
|
| 2954 |
+
pathways over numerically identical biased values. See module docstring for the
|
| 2955 |
+
two-pathway architecture.
|
| 2956 |
|
| 2957 |
The routing projection W_r has no bias term — the paper specifies xW_r with no
|
| 2958 |
additional projection bias. The only bias-like parameter is expert_bias (b), which
|
| 2959 |
+
has an entirely separate role and gradient path.
|
| 2960 |
|
| 2961 |
Args:
|
| 2962 |
config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
|
|
|
|
| 2981 |
config.embedding_width, config.num_mosrah_heads, bias=False
|
| 2982 |
)
|
| 2983 |
|
| 2984 |
+
# Scalar gate on routing logits. As an nn.Parameter it is exempt from
|
| 2985 |
+
# HuggingFace _init_weights, so its near-zero initial value is preserved
|
| 2986 |
+
# after from_config construction. Near-zero initialization ensures routing
|
| 2987 |
+
# starts near-uniform and expert_bias has leverage over logits from step one.
|
| 2988 |
+
self.routing_scale = nn.Parameter(
|
| 2989 |
+
torch.full((1,), config.router_init_scale)
|
| 2990 |
+
)
|
| 2991 |
+
|
| 2992 |
# b: learned per-head bias for load balancing. Initialized to zero so that all
|
| 2993 |
# heads start with equal selection probability. Updated by the main optimizer
|
| 2994 |
+
# via gradients from the load balance loss through load_balancing_logits.
|
| 2995 |
self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
|
| 2996 |
|
| 2997 |
@staticmethod
|
|
|
|
| 3236 |
|
| 3237 |
Returns:
|
| 3238 |
selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
|
| 3239 |
+
Each token's K selected head indices, determined by TopK on
|
| 3240 |
+
capacity-balanced semantic scores.
|
| 3241 |
routing_probs: Routing probabilities P of shape (batch, seq_len,
|
| 3242 |
+
num_selected_heads). Gathered from pre-capacity semantic softmax at
|
| 3243 |
+
selected_heads indices and renormalized to sum to 1 per token.
|
| 3244 |
router_diagnostics: Dict of routing feedback scalars. Keys:
|
| 3245 |
- ``load_balance_loss``: scalar load-balance loss with gradient.
|
| 3246 |
- ``max_vio``: detached scalar routing-imbalance summary.
|
| 3247 |
- ``bias_std``: std of expert_bias; near-zero means corrections have not built up.
|
| 3248 |
+
- ``raw_logit_std``: mean per-token std of scaled logits; the natural routing scale.
|
| 3249 |
+
- ``logit_std``: mean per-token std of semantic_logits; lower than
|
| 3250 |
raw_logit_std means bias is flattening preferences (healthy correction).
|
| 3251 |
- ``bias_alignment``: mean cosine similarity of expert_bias against per-token
|
| 3252 |
logits. Negative means bias opposes routing direction (healthy correction);
|
|
|
|
| 3256 |
L = self.num_mosrah_heads
|
| 3257 |
K = self.num_selected_heads
|
| 3258 |
|
| 3259 |
+
# Scaled logits. routing_scale is a near-zero nn.Parameter exempt from
|
| 3260 |
+
# HuggingFace _init_weights, so routing starts near-uniform and expert_bias
|
| 3261 |
+
# has leverage from step one.
|
| 3262 |
+
logits = self.routing_projection(x) * self.routing_scale # (B, N, L)
|
| 3263 |
+
|
| 3264 |
+
# Two gradient-isolated pathways over numerically identical biased values.
|
| 3265 |
+
# semantic_logits: task gradients reach routing_projection; expert_bias isolated.
|
| 3266 |
+
# load_balancing_logits: load balance gradients reach expert_bias; routing_projection isolated.
|
| 3267 |
+
semantic_logits = logits + self.expert_bias.detach() # (B, N, L)
|
| 3268 |
+
load_balancing_logits = logits.detach() + self.expert_bias # (B, N, L)
|
| 3269 |
|
| 3270 |
# Diagnostic scalars characterising the load-balance mechanism. Must be
|
| 3271 |
# computed here — before balance_capacity injects -1e8 sentinels that
|
| 3272 |
# would corrupt std and cosine similarity.
|
| 3273 |
+
bias_std = self.expert_bias.std().detach()
|
| 3274 |
+
raw_logit_std = logits.std(dim=-1).mean().detach()
|
| 3275 |
+
logit_std = semantic_logits.std(dim=-1).mean().detach()
|
| 3276 |
bias_alignment = F.cosine_similarity(
|
| 3277 |
logits, self.expert_bias.expand_as(logits), dim=-1
|
| 3278 |
).mean().detach()
|
| 3279 |
|
| 3280 |
+
# Assignment probabilities for load balance loss. Computed from load_balancing_logits
|
| 3281 |
+
# before balance_capacity so that -1e8 sentinels do not invert the load balance
|
| 3282 |
+
# gradient for over-capacity experts. active_float is reused below for routing freqs.
|
| 3283 |
+
active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
|
| 3284 |
+
lb_softmax = F.softmax(load_balancing_logits, dim=-1) # (B, N, L)
|
| 3285 |
+
assignment_probs = (lb_softmax * active_float).sum(dim=(0, 1)) # (L,) unnorm
|
| 3286 |
+
assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
|
| 3287 |
+
|
| 3288 |
+
# Pre-capacity semantic softmax for gathering routing_probs. Computed before
|
| 3289 |
+
# balance_capacity so that gathered probabilities reflect genuine preference
|
| 3290 |
+
# magnitudes rather than hard-masked sentinel values.
|
| 3291 |
+
routing_scores = F.softmax(semantic_logits, dim=-1) # (B, N, L)
|
| 3292 |
+
|
| 3293 |
+
# Capacity-balanced semantic logits for selection. Injects -1e8 into positions
|
| 3294 |
+
# that would exceed per-expert token budget, enforcing the packing constraint.
|
| 3295 |
+
balanced_semantic_logits = self.balance_capacity(
|
| 3296 |
+
semantic_logits,
|
| 3297 |
used_capacity,
|
| 3298 |
self.capacity,
|
| 3299 |
self.num_selected_heads,
|
| 3300 |
self.max_bid_rounds,
|
| 3301 |
)
|
| 3302 |
+
selection_scores = F.softmax(balanced_semantic_logits, dim=-1) # (B, N, L)
|
|
|
|
|
|
|
| 3303 |
|
| 3304 |
+
# selected_heads I = TopK over capacity-balanced semantic scores.
|
| 3305 |
+
selected_heads = selection_scores.topk(K, dim=-1).indices # (B, N, K)
|
|
|
|
|
|
|
| 3306 |
|
| 3307 |
+
# Routing probabilities P: gathered from pre-capacity semantic softmax at
|
| 3308 |
+
# selected_heads positions, renormalized so they sum to 1 per token.
|
| 3309 |
+
gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
|
| 3310 |
+
routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
|
|
|
|
| 3311 |
|
| 3312 |
# Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
|
| 3313 |
# fraction of that item's active K assignments over all tokens go to head l.
|
|
|
|
| 3316 |
assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
|
| 3317 |
assignment_mask.scatter_(-1, selected_heads, 1.0)
|
| 3318 |
active_assignments = assignment_mask * active_mask.unsqueeze(-1)
|
| 3319 |
+
per_item_counts = active_assignments.sum(dim=1) # (B, L)
|
| 3320 |
+
per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
|
| 3321 |
+
per_item_freqs = per_item_counts / per_item_total # (B, L)
|
| 3322 |
|
| 3323 |
# p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
|
| 3324 |
# p-mean weights aggregation toward the worst-case batch item relative to
|
|
|
|
| 3327 |
p = self.load_balance_p
|
| 3328 |
routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
|
| 3329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3330 |
load_balance_loss = self._load_balance_loss(routing_freqs, assignment_probs)
|
| 3331 |
|
| 3332 |
# MaxVio is a detached monitoring scalar following the paper's formula
|