smithblack-0 committed on
Commit
f192edb
·
verified ·
1 Parent(s): 7a9910f

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. config.json +2 -2
  3. configuration.py +22 -22
  4. huggingface.py +301 -376
README.md CHANGED
@@ -82,9 +82,10 @@ contains no weights. All values are overridable via kwargs.
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
- | `load_balance_loss_type` | ce |
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
 
88
  | `mlp_width` | 1366 |
89
  | `mosrah_overallocation_factor` | 2.0 |
90
  | `mosrah_rope_theta` | 10000.0 |
@@ -95,7 +96,6 @@ contains no weights. All values are overridable via kwargs.
95
  | `output_hidden_states` | False |
96
  | `rms_norm_eps` | 1e-05 |
97
  | `rope_mode` | main_sequence |
98
- | `routing_mode` | integral |
99
  | `tie_word_embeddings` | False |
100
  | `training_sequence_length` | 1024 |
101
  | `use_cache` | True |
 
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
+ | `load_balance_loss_type` | temporal_overcapacity |
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
88
+ | `maximum_expert_overclaim` | 20 |
89
  | `mlp_width` | 1366 |
90
  | `mosrah_overallocation_factor` | 2.0 |
91
  | `mosrah_rope_theta` | 10000.0 |
 
96
  | `output_hidden_states` | False |
97
  | `rms_norm_eps` | 1e-05 |
98
  | `rope_mode` | main_sequence |
 
99
  | `tie_word_embeddings` | False |
100
  | `training_sequence_length` | 1024 |
101
  | `use_cache` | True |
config.json CHANGED
@@ -9,9 +9,10 @@
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
- "load_balance_loss_type": "ce",
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
 
15
  "mlp_width": 1366,
16
  "model_type": "shram",
17
  "mosrah_overallocation_factor": 2.0,
@@ -22,7 +23,6 @@
22
  "num_sliding_window_heads": 16,
23
  "rms_norm_eps": 1e-05,
24
  "rope_mode": "main_sequence",
25
- "routing_mode": "integral",
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
  "transformers_version": "5.10.2",
 
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
+ "load_balance_loss_type": "temporal_overcapacity",
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
15
+ "maximum_expert_overclaim": 20,
16
  "mlp_width": 1366,
17
  "model_type": "shram",
18
  "mosrah_overallocation_factor": 2.0,
 
23
  "num_sliding_window_heads": 16,
24
  "rms_norm_eps": 1e-05,
25
  "rope_mode": "main_sequence",
 
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
  "transformers_version": "5.10.2",
configuration.py CHANGED
@@ -91,17 +91,18 @@ class ShramConfig(PretrainedConfig):
91
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
92
  Default 10.
93
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
94
- One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
95
- is the default; its log-probability signal scales with violation severity
96
- and makes correction magnitude proportional to routing imbalance.
97
- Default ``"ce"``.
98
- routing_mode: Routing computation mode. ``"integral"`` (default) enables the
99
- integral routing extension: the exclusive cumsum of routing logits along
100
- the sequence dimension is mapped through two additional (L, L) parameter
101
- matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
102
- B') and added as corrections to both logit pathways. This gives each
103
- token a read on the cumulative routing history so far in the sequence.
104
- ``"default"`` disables the extension; A' and B' are not created.
 
105
  """
106
 
107
  model_type = "shram"
@@ -136,8 +137,8 @@ class ShramConfig(PretrainedConfig):
136
  tie_word_embeddings: bool = False,
137
  mosrah_overallocation_factor: float = 2.0,
138
  max_bid_rounds: int = 10,
139
- load_balance_loss_type: str = "ce",
140
- routing_mode: str = "integral",
141
  **kwargs
142
  ):
143
  if head_dim % 2 != 0:
@@ -178,7 +179,13 @@ class ShramConfig(PretrainedConfig):
178
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
179
  )
180
 
181
- _supported_loss_types = {"gshard", "ce", "bce"}
 
 
 
 
 
 
182
  if load_balance_loss_type not in _supported_loss_types:
183
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
184
  raise ValueError(
@@ -186,13 +193,6 @@ class ShramConfig(PretrainedConfig):
186
  f"got {load_balance_loss_type!r}."
187
  )
188
 
189
- _supported_routing_modes = {"default", "integral"}
190
- if routing_mode not in _supported_routing_modes:
191
- supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
192
- raise ValueError(
193
- f"routing_mode must be one of {supported}, got {routing_mode!r}."
194
- )
195
-
196
  self.vocab_size = vocab_size
197
  self.embedding_width = embedding_width
198
  self.mlp_width = mlp_width
@@ -213,7 +213,7 @@ class ShramConfig(PretrainedConfig):
213
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
214
  self.max_bid_rounds = max_bid_rounds
215
  self.load_balance_loss_type = load_balance_loss_type
216
- self.routing_mode = routing_mode
217
  self.attention_dropout = attention_dropout
218
  self.use_cache = use_cache
219
 
 
91
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
92
  Default 10.
93
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
94
+ One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
95
+ ``"temporal_overcapacity"`` is the default; it fires only when an expert
96
+ exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
97
+ and shuts off automatically once routing is balanced, allowing it to be
98
+ used with a strong weight without interfering with task training during
99
+ balanced routing. Default ``"temporal_overcapacity"``.
100
+ maximum_expert_overclaim: Maximum number of tokens an expert may receive above
101
+ its ideal allocation trajectory before the temporal overcapacity loss
102
+ fires. A value of 0 means violations trigger immediately at any imbalance.
103
+ Larger values permit short-lived semantic specialization before correction.
104
+ Only used when ``load_balance_loss_type="temporal_overcapacity"``.
105
+ Must be non-negative. Default 20.
106
  """
107
 
108
  model_type = "shram"
 
137
  tie_word_embeddings: bool = False,
138
  mosrah_overallocation_factor: float = 2.0,
139
  max_bid_rounds: int = 10,
140
+ load_balance_loss_type: str = "temporal_overcapacity",
141
+ maximum_expert_overclaim: int = 20,
142
  **kwargs
143
  ):
144
  if head_dim % 2 != 0:
 
179
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
180
  )
181
 
182
+ if maximum_expert_overclaim < 0:
183
+ raise ValueError(
184
+ f"maximum_expert_overclaim must be non-negative, "
185
+ f"got {maximum_expert_overclaim}."
186
+ )
187
+
188
+ _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
189
  if load_balance_loss_type not in _supported_loss_types:
190
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
191
  raise ValueError(
 
193
  f"got {load_balance_loss_type!r}."
194
  )
195
 
 
 
 
 
 
 
 
196
  self.vocab_size = vocab_size
197
  self.embedding_width = embedding_width
198
  self.mlp_width = mlp_width
 
213
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
214
  self.max_bid_rounds = max_bid_rounds
215
  self.load_balance_loss_type = load_balance_loss_type
216
+ self.maximum_expert_overclaim = maximum_expert_overclaim
217
  self.attention_dropout = attention_dropout
218
  self.use_cache = use_cache
219
 
huggingface.py CHANGED
@@ -178,17 +178,18 @@ class ShramConfig(PretrainedConfig):
178
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
179
  Default 10.
180
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
181
- One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
182
- is the default; its log-probability signal scales with violation severity
183
- and makes correction magnitude proportional to routing imbalance.
184
- Default ``"ce"``.
185
- routing_mode: Routing computation mode. ``"integral"`` (default) enables the
186
- integral routing extension: the exclusive cumsum of routing logits along
187
- the sequence dimension is mapped through two additional (L, L) parameter
188
- matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
189
- B') and added as corrections to both logit pathways. This gives each
190
- token a read on the cumulative routing history so far in the sequence.
191
- ``"default"`` disables the extension; A' and B' are not created.
 
192
  """
193
 
194
  model_type = "shram"
@@ -223,8 +224,8 @@ class ShramConfig(PretrainedConfig):
223
  tie_word_embeddings: bool = False,
224
  mosrah_overallocation_factor: float = 2.0,
225
  max_bid_rounds: int = 10,
226
- load_balance_loss_type: str = "ce",
227
- routing_mode: str = "integral",
228
  **kwargs
229
  ):
230
  if head_dim % 2 != 0:
@@ -265,7 +266,13 @@ class ShramConfig(PretrainedConfig):
265
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
266
  )
267
 
268
- _supported_loss_types = {"gshard", "ce", "bce"}
 
 
 
 
 
 
269
  if load_balance_loss_type not in _supported_loss_types:
270
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
271
  raise ValueError(
@@ -273,13 +280,6 @@ class ShramConfig(PretrainedConfig):
273
  f"got {load_balance_loss_type!r}."
274
  )
275
 
276
- _supported_routing_modes = {"default", "integral"}
277
- if routing_mode not in _supported_routing_modes:
278
- supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
279
- raise ValueError(
280
- f"routing_mode must be one of {supported}, got {routing_mode!r}."
281
- )
282
-
283
  self.vocab_size = vocab_size
284
  self.embedding_width = embedding_width
285
  self.mlp_width = mlp_width
@@ -300,7 +300,7 @@ class ShramConfig(PretrainedConfig):
300
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
301
  self.max_bid_rounds = max_bid_rounds
302
  self.load_balance_loss_type = load_balance_loss_type
303
- self.routing_mode = routing_mode
304
  self.attention_dropout = attention_dropout
305
  self.use_cache = use_cache
306
 
@@ -1478,10 +1478,7 @@ Returns a plain dict with keys:
1478
  - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
1479
  - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
1480
  - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
1481
- - "bias_std": detached scalar mean per-layer std of the expert bias vector
1482
- - "raw_logit_std": detached scalar mean per-layer per-token routing logit spread
1483
- - "logit_std": detached scalar mean per-layer per-token combined (logit + bias) spread
1484
- - "bias_alignment": detached scalar mean per-layer cosine similarity of bias vs logits
1485
  """
1486
 
1487
 
@@ -2725,71 +2722,38 @@ This module implements the routing mechanism described in Appendix A.Routing of
2725
  paper. Given an input hidden state x, the router produces two outputs used downstream:
2726
 
2727
  - selected_heads (I): which K of the L available expert heads each token routes to,
2728
- determined by TopK over capacity-balanced semantic routing scores.
2729
  - routing_probs (P): the weights used for the weighted output reduction, gathered from
2730
- the semantic routing scores at the selected indices and renormalized to sum to 1
2731
- per token.
2732
-
2733
- Base routing uses two learnable projection matrices and two gradient-isolated pathways:
2734
-
2735
- - routing_weight (A): shape (L, embedding_width). Maps input to per-head routing
2736
- scores. Receives gradients from task loss; balance_weight is isolated.
2737
- - balance_weight (B): shape (L, embedding_width). Maps input to per-head load-balance
2738
- correction scores. Receives gradients from load_balance_loss; routing_weight is
2739
- isolated.
2740
-
2741
- The two gradient-isolated base pathways over numerically identical values:
2742
-
2743
- - semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
2744
- balance_weight is isolated from task loss.
2745
- - load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance gradients
2746
- reach balance_weight; routing_weight and x are isolated from load balance loss.
2747
-
2748
- Integral routing extension (routing_mode == "integral"):
2749
-
2750
- Standard routing is parallel — each token routes based on its own hidden state alone,
2751
- with no direct read on what earlier tokens in the sequence have already selected.
2752
- Integral routing adds a cumulative-sum signal that gives each token a view of the
2753
- prior routing history within the sequence.
2754
-
2755
- Two additional (L, L) parameter matrices are introduced:
2756
-
2757
- - routing_integral_weight (A'): shape (L, L). Maps the cumulative logit history to
2758
- per-head semantic corrections. Receives gradients from task loss.
2759
- - balance_integral_weight (B'): shape (L, L). Maps the cumulative logit history to
2760
- per-head load-balance corrections. Receives gradients from load_balance_loss.
2761
 
2762
- The cumulative history signal u is the exclusive cumsum of the base logits along the
2763
- sequence dimension: u[n] = sum(logits[0..n-1]), shape (B, N, L). Position 0 receives
2764
- zeros (no prior history). The same gradient isolation pattern as A/B applies:
2765
 
2766
- - semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
2767
- - lb_logits += (A'·u_load).detach() + B'·u_load
 
2768
 
2769
- Detaching the full B'·u_semantic result (rather than just B') mirrors the
2770
- (B·x).detach() pattern in the base pathway and prevents double-counting the
2771
- cumsum gradient path back to routing_weight.
 
 
2772
 
2773
- Both base matrices and both integral matrices are nn.Parameter so that HuggingFace
2774
- _init_weights does not override their kaiming initialization at construction.
2775
 
2776
- Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
2777
- Post-capacity softmax would invert the load balance gradient for over-capacity experts
2778
- (near-zero probability after masking signals "increase corrections" for an already-
2779
- overloaded expert).
2780
 
2781
- The router also computes and returns the load balance loss via a log-probability auxiliary
2782
- loss (see load_balance_loss.py). The loss formulation is selected by config; the default
2783
- is cross-entropy.
2784
-
2785
- The router additionally computes and returns MaxVio, a detached scalar summarising
2786
- routing imbalance for the current forward pass:
2787
-
2788
- MaxVio = mean_b( L · max_l(f_bl - 1/L) )
2789
-
2790
- where f_bl is the per-batch-item realised routing frequency of head l and 1/L is the
2791
- perfectly balanced target. MaxVio is averaged over batch items and is a monitoring
2792
- quantity only; it never contributes gradients.
2793
 
2794
  Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2795
  """
@@ -2804,7 +2768,7 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2804
  # -----------
2805
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2806
 
2807
- This module provides three load-balance loss formulations, two token-reduction
2808
  helpers, and a factory that selects among the formulations. All formulations
2809
  share the same external contract:
2810
 
@@ -2814,9 +2778,8 @@ share the same external contract:
2814
  active_mask: Tensor[B, N],
2815
  ) -> scalar Tensor
2816
 
2817
- logits: Load-balancing logits, shape (B, N, L). These are the raw
2818
- pre-softmax scores from logits.detach() + expert_bias.
2819
- Gradient flows to expert_bias through this tensor.
2820
  assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
2821
  is 1.0 if token (b, n) was assigned to head l. Dead tokens
2822
  should carry zero entries.
@@ -2826,17 +2789,19 @@ share the same external contract:
2826
  Token reduction is split into two helpers with distinct roles:
2827
 
2828
  reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L).
2829
- Called by all three formulations. Output is detached; f_bl carries no gradient.
2830
 
2831
  reduce_probability_tokens — produces per-batch-item mean assignment probabilities
2832
- p_bl (B, L). Called only by gshard and bce. Gradient flows to expert_bias
2833
- through the internal softmax over logits.
2834
 
2835
  CE delegates probability computation to F.cross_entropy, which handles its own
2836
  log_softmax and operates directly on the raw (B, N, L) logits.
2837
 
2838
- The factory is the intended entry point. MoSRAHRouter constructs the loss callable
2839
- once at init and invokes it each forward pass.
 
 
2840
  """
2841
 
2842
 
@@ -3010,30 +2975,181 @@ def bce_loss(
3010
  """
3011
  f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
3012
  p_bl = reduce_probability_tokens(logits, active_mask)
3013
- # Clamp p_bl for numerical safety: F.binary_cross_entropy requires input in
3014
- # (0, 1) and will produce inf for exactly 0 or 1. Softmax outputs are
3015
- # strictly positive in normal operation; the clamp guards the all-dead-tokens
3016
- # edge case where the mean defaults to zero.
3017
- return F.binary_cross_entropy(
3018
- p_bl.clamp(min=1e-7, max=1.0 - 1e-7),
3019
- 1.0 - f_bl,
3020
- reduction='mean',
3021
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3022
 
3023
 
3024
  # ---------------------------------------------------------------------------
3025
  # Factory
3026
  # ---------------------------------------------------------------------------
3027
 
3028
- _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = {
3029
- "gshard": gshard_loss,
3030
- "ce": ce_loss,
3031
- "bce": bce_loss,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3032
  }
3033
 
3034
 
3035
  def make_load_balance_loss(
3036
  loss_type: str,
 
3037
  ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3038
  """Return a load-balance loss callable for the requested formulation.
3039
 
@@ -3045,11 +3161,14 @@ def make_load_balance_loss(
3045
  active_mask: Tensor[B, N],
3046
  ) -> scalar Tensor
3047
 
3048
- The caller is responsible for computing logits as logits.detach() + expert_bias
3049
- to ensure gradient isolation to expert_bias.
 
3050
 
3051
  Args:
3052
- loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
 
 
3053
 
3054
  Returns:
3055
  Loss callable matching the shared contract.
@@ -3062,34 +3181,29 @@ def make_load_balance_loss(
3062
  raise ValueError(
3063
  f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
3064
  )
3065
- return _LOSS_REGISTRY[loss_type]
3066
 
3067
 
3068
  class MoSRAHRouter(nn.Module):
3069
  """Token-choice router for MoSRAH sparse attention.
3070
 
3071
- Each input token independently selects K of the L available expert heads. Both
3072
- selection and routing_probs incorporate balance_weight via two gradient-isolated
3073
- pathways over numerically identical values. See module docstring for the
3074
- two-pathway architecture and the integral routing extension.
3075
 
3076
- All four learnable matrices are nn.Parameter rather than nn.Linear so that
3077
- HuggingFace _init_weights does not override their kaiming initialization at
3078
- construction.
3079
 
3080
  Attributes:
3081
- routing_weight: A, shape (L, embedding_width). Task-loss pathway.
3082
- balance_weight: B, shape (L, embedding_width). Load-balance pathway.
3083
- routing_integral_weight: A', shape (L, L). Integral task-loss pathway.
3084
- Present only when ``routing_mode == "integral"``.
3085
- balance_integral_weight: B', shape (L, L). Integral load-balance pathway.
3086
- Present only when ``routing_mode == "integral"``.
3087
- routing_mode: ``"integral"`` or ``"default"``, from config.
3088
 
3089
  Args:
3090
  config: Model configuration. Must expose ``embedding_width``,
3091
- ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and
3092
- ``routing_mode``.
 
3093
  """
3094
 
3095
  def __init__(self, config: ShramConfig) -> None:
@@ -3102,40 +3216,19 @@ class MoSRAHRouter(nn.Module):
3102
  self.capacity = config.mosrah_packed_length
3103
 
3104
  self.max_bid_rounds = config.max_bid_rounds
3105
- self.routing_mode = config.routing_mode
3106
- self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
3107
-
3108
- # W_r (A): semantic routing matrix. Maps input (B, N, d) to per-head routing
3109
- # scores (B, N, L) for selection and routing_probs. nn.Parameter ensures
3110
- # HuggingFace _init_weights does not override kaiming initialization.
3111
- self.routing_weight = nn.Parameter(
3112
- torch.empty(config.num_mosrah_heads, config.embedding_width)
3113
  )
3114
- nn.init.kaiming_uniform_(self.routing_weight)
3115
 
3116
- # W_b (B): load-balancing projection matrix. Maps input (B, N, d) to per-head
3117
- # correction scores (B, N, L). Receives gradients only from load_balance_loss.
3118
  # nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
3119
- self.balance_weight = nn.Parameter(
3120
  torch.empty(config.num_mosrah_heads, config.embedding_width)
3121
  )
3122
- nn.init.kaiming_uniform_(self.balance_weight)
3123
-
3124
- if self.routing_mode == "integral":
3125
- L = config.num_mosrah_heads
3126
- # A': integral semantic matrix. Maps cumulative logit history (B, N, L) to
3127
- # per-head semantic corrections (B, N, L). Shape (L, L). Receives gradients
3128
- # from task loss; balance_integral_weight is isolated from task loss.
3129
- # Zero-initialized so that corrections start at zero and grow from gradient
3130
- # updates — kaiming init produces corrections that immediately overwhelm the
3131
- # base routing signal via the cumsum feedback path.
3132
- self.routing_integral_weight = nn.Parameter(torch.zeros(L, L))
3133
-
3134
- # B': integral load-balance matrix. Maps cumulative logit history (B, N, L)
3135
- # to per-head load-balance corrections (B, N, L). Shape (L, L). Receives
3136
- # gradients from load_balance_loss; routing_integral_weight is isolated.
3137
- # Zero-initialized for the same reason as routing_integral_weight.
3138
- self.balance_integral_weight = nn.Parameter(torch.zeros(L, L))
3139
 
3140
  @staticmethod
3141
  def get_best_proposals(
@@ -3380,226 +3473,87 @@ class MoSRAHRouter(nn.Module):
3380
  Returns:
3381
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
3382
  Each token's K selected head indices, determined by TopK on
3383
- capacity-balanced semantic scores.
3384
  routing_probs: Routing probabilities P of shape (batch, seq_len,
3385
- num_selected_heads). Gathered from pre-capacity semantic softmax at
3386
  selected_heads indices and renormalized to sum to 1 per token.
3387
  router_diagnostics: Dict of routing feedback scalars. Keys:
3388
  - ``load_balance_loss``: scalar load-balance loss with gradient.
3389
  - ``max_vio``: detached scalar routing-imbalance summary.
3390
- - ``raw_logit_std``: mean per-token std of routing_logits; natural
3391
- routing preference scale and baseline for interpreting bias_std.
3392
- - ``bias_std``: mean per-token std of balance_logits; near-zero
3393
- means balance corrections have not built up relative to routing scale.
3394
- - ``logit_std``: mean per-token std of semantic_logits; lower than
3395
- raw_logit_std means balance is flattening preferences (healthy correction).
3396
- - ``bias_alignment``: mean cosine similarity of routing_logits vs
3397
- balance_logits per token. Negative means balance opposes routing direction
3398
- (healthy correction); positive means runaway reinforcement.
3399
  """
3400
  B, N, _ = x.shape
3401
  L = self.num_mosrah_heads
3402
  K = self.num_selected_heads
3403
 
3404
- logits = self._compute_routing_logits(x, active_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3405
 
3406
- # Diagnostic scalars characterising the two routing pathways. Must be computed
3407
- # before balance_capacity injects -1e8 sentinels that would corrupt std and
3408
- # cosine similarity. Extracted to _compute_bias_diagnostics to keep the forward
3409
- # body free of non-(B,N,L) reduction logic.
3410
- bias_diagnostics = self._compute_bias_diagnostics(
3411
- logits["routing_logits"], logits["balance_logits"], logits["semantic_logits"]
3412
  )
3413
 
3414
- # Pre-capacity semantic softmax for gathering routing_probs. Computed before
3415
- # balance_capacity so that gathered probabilities reflect genuine preference
3416
- # magnitudes rather than hard-masked sentinel values.
3417
- routing_scores = F.softmax(logits["semantic_logits"], dim=-1) # (B, N, L)
3418
-
3419
- # Capacity-balanced semantic logits for selection. Injects -1e8 into positions
3420
- # that would exceed per-expert token budget, enforcing the packing constraint.
3421
- balanced_semantic_logits = self.balance_capacity(
3422
- logits["semantic_logits"],
 
 
3423
  used_capacity,
3424
  self.capacity,
3425
  self.num_selected_heads,
3426
  self.max_bid_rounds,
3427
  )
3428
- selection_scores = F.softmax(balanced_semantic_logits, dim=-1) # (B, N, L)
3429
 
3430
- # selected_heads I = TopK over capacity-balanced semantic scores.
3431
- selected_heads = selection_scores.topk(K, dim=-1).indices # (B, N, K)
 
3432
 
3433
- # Routing probabilities P: gathered from pre-capacity semantic softmax at
3434
- # selected_heads positions, renormalized so they sum to 1 per token.
3435
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
3436
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
3437
 
3438
- # assignment_mask: (B, N, L) float — 1.0 at each token's K selected heads, 0 elsewhere.
3439
- # The discrete routing decision; no gradient flows through it. Passed alongside
3440
- # load_balancing_logits and active_mask to the loss and max_vio methods, which
3441
- # own all frequency aggregation and reduction internally.
3442
- assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
3443
- assignment_mask.scatter_(-1, selected_heads, 1.0)
3444
-
3445
- load_balance_loss = self._load_balance_loss(
3446
- logits["load_balancing_logits"], assignment_mask, active_mask
3447
- )
3448
-
3449
- # MaxVio: detached monitoring scalar averaged over batch items. Computed from
3450
- # the same (B, N, L) assignment_mask so frequencies are consistent with the loss.
3451
- max_vio = self._compute_max_vio(assignment_mask, active_mask, L)
3452
-
3453
  router_diagnostics = {
3454
  "load_balance_loss": load_balance_loss,
3455
- "max_vio": max_vio,
3456
- **bias_diagnostics,
3457
  }
3458
  return selected_heads, routing_probs, router_diagnostics
3459
 
3460
- @staticmethod
3461
- def exclusive_cumsum(logits: torch.Tensor) -> torch.Tensor:
3462
- """Compute the exclusive cumulative sum along the sequence dimension.
3463
-
3464
- u[n] = sum(logits[0..n-1]): position n receives the accumulated sum of all
3465
- prior positions, giving it a read on the routing preferences expressed by
3466
- earlier tokens in the sequence. Position 0 always receives zeros — no prior
3467
- history exists at the first position.
3468
-
3469
- Args:
3470
- logits: Shape (B, N, L). Any per-head score tensor along a sequence.
3471
-
3472
- Returns:
3473
- Exclusive cumsum, shape (B, N, L). Same dtype and device as input.
3474
- """
3475
- shifted = torch.cat(
3476
- [torch.zeros_like(logits[:, :1, :]), logits[:, :-1, :]], dim=1
3477
- )
3478
- return shifted.cumsum(dim=1)
3479
-
3480
- def _compute_routing_logits(
3481
- self, x: torch.Tensor, active_mask: torch.Tensor
3482
- ) -> dict[str, torch.Tensor]:
3483
- """Compute the gradient-isolated logit pathways from input hidden states.
3484
-
3485
- Base pathways (both modes):
3486
-
3487
- Two gradient-isolated pathways over numerically identical values:
3488
- - semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
3489
- balance_weight is isolated from task loss.
3490
- - load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance
3491
- gradients reach balance_weight; routing_weight and x are isolated.
3492
-
3493
- Integral extension (routing_mode == "integral"):
3494
-
3495
- Dead tokens are zeroed out of the logits before computing the cumsum, so
3496
- inactive positions do not contribute to the routing history of downstream
3497
- live tokens. u_semantic and u_load therefore represent history from live
3498
- tokens only.
3499
-
3500
- u_semantic = exclusive_cumsum(semantic_logits * active_mask) — (B, N, L)
3501
- u_load = exclusive_cumsum(load_balancing_logits * active_mask) — (B, N, L)
3502
-
3503
- semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
3504
- load_balancing_logits += (A'·u_load).detach() + B'·u_load
3505
-
3506
- Detaching the full (B'·u_semantic) result mirrors the (B·x).detach() base
3507
- pattern: it isolates balance_integral_weight from task loss AND prevents
3508
- double-counting the cumsum gradient path back to routing_weight.
3509
- The same reasoning applies to (A'·u_load).detach() in the load-balance
3510
- pathway — u_load already has no path to routing_weight (routing_logits is
3511
- detached in load_balancing_logits), and the detach additionally blocks
3512
- routing_integral_weight.
3513
 
3514
  Args:
3515
  x: Input hidden states, shape (batch, seq_len, embedding_width).
3516
- active_mask: Boolean active-token mask, shape (batch, seq_len). Dead tokens
3517
- are excluded from the cumsum history in integral mode.
3518
 
3519
  Returns:
3520
- Dict with keys:
3521
- - ``routing_logits``: A·x, shape (B, N, L).
3522
- - ``balance_logits``: B·x, shape (B, N, L).
3523
- - ``semantic_logits``: combined task-loss pathway, shape (B, N, L).
3524
- - ``load_balancing_logits``: combined load-balance pathway, shape (B, N, L).
3525
  """
3526
- routing_logits = F.linear(x, self.routing_weight) # (B, N, L)
3527
- balance_logits = F.linear(x, self.balance_weight) # (B, N, L)
3528
- semantic_logits = routing_logits + balance_logits.detach()
3529
- load_balancing_logits = routing_logits.detach() + F.linear(x.detach(), self.balance_weight)
3530
-
3531
- if self.routing_mode == "integral":
3532
- # Zero out dead token positions before cumsum so inactive tokens do not
3533
- # contaminate the routing history of subsequent live tokens.
3534
- live = active_mask.unsqueeze(-1) # (B, N, 1)
3535
- u_semantic = self.exclusive_cumsum(semantic_logits * live) # (B, N, L)
3536
- u_load = self.exclusive_cumsum(load_balancing_logits * live) # (B, N, L)
3537
-
3538
- # Semantic pathway: A' trains on task loss; B' term is fully detached to
3539
- # isolate balance_integral_weight from task loss and prevent double-counting
3540
- # the cumsum gradient path back to routing_weight.
3541
- semantic_logits = (
3542
- semantic_logits
3543
- + F.linear(u_semantic, self.routing_integral_weight)
3544
- + F.linear(u_semantic, self.balance_integral_weight).detach()
3545
- )
3546
-
3547
- # Load-balance pathway: B' trains on load_balance_loss; A' term is fully
3548
- # detached to isolate routing_integral_weight from load_balance_loss.
3549
- load_balancing_logits = (
3550
- load_balancing_logits
3551
- + F.linear(u_load, self.routing_integral_weight).detach()
3552
- + F.linear(u_load, self.balance_integral_weight)
3553
- )
3554
-
3555
- return {
3556
- "routing_logits": routing_logits,
3557
- "balance_logits": balance_logits,
3558
- "semantic_logits": semantic_logits,
3559
- "load_balancing_logits": load_balancing_logits,
3560
- }
3561
-
3562
- @staticmethod
3563
- def _compute_bias_diagnostics(
3564
- routing_logits: torch.Tensor,
3565
- balance_logits: torch.Tensor,
3566
- semantic_logits: torch.Tensor,
3567
- ) -> dict[str, torch.Tensor]:
3568
- """Compute detached diagnostic scalars characterising the two routing pathways.
3569
-
3570
- All scalars must be computed from pre-capacity logits; balance_capacity
3571
- applies -1e8 sentinels that would corrupt std and cosine similarity.
3572
- Extracted from forward to keep the main body free of reduction logic.
3573
-
3574
- Args:
3575
- routing_logits: A·x, routing pathway output, shape (B, N, L).
3576
- balance_logits: B·x, balance pathway output, shape (B, N, L).
3577
- semantic_logits: A·x + (B·x).detach(), combined signal, shape (B, N, L).
3578
-
3579
- Returns:
3580
- Dict with keys:
3581
- - ``raw_logit_std``: Mean per-token std of routing_logits. Natural
3582
- routing preference scale; reference baseline for
3583
- interpreting bias_std.
3584
- - ``bias_std``: Mean per-token std of balance_logits. Near-zero
3585
- means balance corrections have not built up
3586
- relative to the routing scale.
3587
- - ``logit_std``: Mean per-token std of semantic_logits. Lower than
3588
- raw_logit_std indicates balance is flattening
3589
- preferences (healthy correction signal).
3590
- - ``bias_alignment``: Mean cosine similarity of routing_logits vs
3591
- balance_logits per token. Range [-1, 1]. Negative
3592
- means balance opposes routing direction (healthy
3593
- correction); positive means runaway reinforcement.
3594
- """
3595
- return {
3596
- "raw_logit_std": routing_logits.std(dim=-1).mean().detach(),
3597
- "bias_std": balance_logits.std(dim=-1).mean().detach(),
3598
- "logit_std": semantic_logits.std(dim=-1).mean().detach(),
3599
- "bias_alignment": F.cosine_similarity(
3600
- routing_logits, balance_logits, dim=-1
3601
- ).mean().detach(),
3602
- }
3603
 
3604
  @staticmethod
3605
  def _compute_max_vio(
@@ -4137,30 +4091,15 @@ class ShramModel(nn.Module):
4137
  - ``"max_vio"``: detached scalar maximum routing-imbalance across
4138
  all decoder layers. Zero means perfectly balanced routing across
4139
  every layer; higher values identify the worst-case head imbalance.
4140
- - ``"bias_std"``: detached scalar — mean across layers of the std
4141
- of each layer's expert bias vector. Near-zero means corrections
4142
- have not built up; large relative to ``raw_logit_std`` means the
4143
- bias dominates routing.
4144
- - ``"raw_logit_std"``: detached scalar — mean across layers of the
4145
- per-token routing logit spread before bias addition. Baseline
4146
- natural routing preference scale.
4147
  - ``"logit_std"``: detached scalar — mean across layers of the
4148
- per-token combined (logit + bias) spread. Lower than
4149
- ``raw_logit_std`` indicates healthy flattening; higher indicates
4150
- amplification.
4151
- - ``"bias_alignment"``: detached scalar — mean across layers of the
4152
- per-token cosine similarity between the expert bias vector and the
4153
- routing logits. Negative is healthy correction; positive is
4154
- runaway feedback.
4155
  """
4156
  hidden_states = inputs_embeds
4157
  all_hidden_states = (hidden_states,) if output_hidden_states else None
4158
  total_load_balance_loss = inputs_embeds.new_zeros(())
4159
  max_vio = inputs_embeds.new_zeros(())
4160
- total_bias_std = inputs_embeds.new_zeros(())
4161
- total_raw_logit_std = inputs_embeds.new_zeros(())
4162
  total_logit_std = inputs_embeds.new_zeros(())
4163
- total_bias_alignment = inputs_embeds.new_zeros(())
4164
 
4165
  for layer_idx, layer in enumerate(self.layers):
4166
  layer_cache = None if cache is None else cache.layers[layer_idx]
@@ -4172,10 +4111,7 @@ class ShramModel(nn.Module):
4172
  )
4173
  total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
4174
  max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
4175
- total_bias_std = total_bias_std + layer_diagnostics["bias_std"]
4176
- total_raw_logit_std = total_raw_logit_std + layer_diagnostics["raw_logit_std"]
4177
  total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
4178
- total_bias_alignment = total_bias_alignment + layer_diagnostics["bias_alignment"]
4179
 
4180
  if output_hidden_states:
4181
  all_hidden_states = all_hidden_states + (hidden_states,)
@@ -4189,10 +4125,7 @@ class ShramModel(nn.Module):
4189
  "hidden_states": all_hidden_states,
4190
  "load_balance_loss": total_load_balance_loss,
4191
  "max_vio": max_vio,
4192
- "bias_std": total_bias_std / num_layers,
4193
- "raw_logit_std": total_raw_logit_std / num_layers,
4194
  "logit_std": total_logit_std / num_layers,
4195
- "bias_alignment": total_bias_alignment / num_layers,
4196
  }
4197
 
4198
 
@@ -4209,17 +4142,14 @@ class ShramCausalLMOutput(CausalLMOutputWithPast):
4209
  ## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all
4210
  ## fields to None, which forces every subclass field to also carry a default.
4211
  ## The = None below is a language constraint, not a semantic statement. In
4212
- ## practice, load_balance_loss, max_vio, bias_std, raw_logit_std, logit_std,
4213
- ## and bias_alignment are always populated by ShramForCausalLM.forward().
4214
- ## ce_loss is genuinely optional — present only when labels are supplied.
4215
 
4216
  ce_loss: torch.FloatTensor | None = None
4217
  load_balance_loss: torch.FloatTensor | None = None
4218
  max_vio: torch.FloatTensor | None = None
4219
- bias_std: torch.Tensor | None = None
4220
- raw_logit_std: torch.Tensor | None = None
4221
  logit_std: torch.Tensor | None = None
4222
- bias_alignment: torch.Tensor | None = None
4223
 
4224
  class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4225
  """HuggingFace-facing causal language model wrapper for SHRAM.
@@ -4668,9 +4598,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4668
  - ``hidden_states`` when requested,
4669
  - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
4670
  - ``max_vio`` — detached worst-case routing imbalance across layers,
4671
- - ``bias_std``, ``raw_logit_std``, ``logit_std``, ``bias_alignment``
4672
- detached load-balance health scalars averaged across decoder layers;
4673
- see ``ShramModel`` for interpretation.
4674
  """
4675
  use_cache = use_cache if use_cache is not None else self.config.use_cache
4676
  output_hidden_states = (
@@ -4777,8 +4705,5 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4777
  hidden_states=backbone_outputs["hidden_states"],
4778
  load_balance_loss=backbone_outputs["load_balance_loss"],
4779
  max_vio=backbone_outputs["max_vio"],
4780
- bias_std=backbone_outputs["bias_std"],
4781
- raw_logit_std=backbone_outputs["raw_logit_std"],
4782
  logit_std=backbone_outputs["logit_std"],
4783
- bias_alignment=backbone_outputs["bias_alignment"],
4784
  )
 
178
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
179
  Default 10.
180
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
181
+ One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
182
+ ``"temporal_overcapacity"`` is the default; it fires only when an expert
183
+ exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
184
+ and shuts off automatically once routing is balanced, allowing it to be
185
+ used with a strong weight without interfering with task training during
186
+ balanced routing. Default ``"temporal_overcapacity"``.
187
+ maximum_expert_overclaim: Maximum number of tokens an expert may receive above
188
+ its ideal allocation trajectory before the temporal overcapacity loss
189
+ fires. A value of 0 means violations trigger immediately at any imbalance.
190
+ Larger values permit short-lived semantic specialization before correction.
191
+ Only used when ``load_balance_loss_type="temporal_overcapacity"``.
192
+ Must be non-negative. Default 20.
193
  """
194
 
195
  model_type = "shram"
 
224
  tie_word_embeddings: bool = False,
225
  mosrah_overallocation_factor: float = 2.0,
226
  max_bid_rounds: int = 10,
227
+ load_balance_loss_type: str = "temporal_overcapacity",
228
+ maximum_expert_overclaim: int = 20,
229
  **kwargs
230
  ):
231
  if head_dim % 2 != 0:
 
266
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
267
  )
268
 
269
+ if maximum_expert_overclaim < 0:
270
+ raise ValueError(
271
+ f"maximum_expert_overclaim must be non-negative, "
272
+ f"got {maximum_expert_overclaim}."
273
+ )
274
+
275
+ _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
276
  if load_balance_loss_type not in _supported_loss_types:
277
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
278
  raise ValueError(
 
280
  f"got {load_balance_loss_type!r}."
281
  )
282
 
 
 
 
 
 
 
 
283
  self.vocab_size = vocab_size
284
  self.embedding_width = embedding_width
285
  self.mlp_width = mlp_width
 
300
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
301
  self.max_bid_rounds = max_bid_rounds
302
  self.load_balance_loss_type = load_balance_loss_type
303
+ self.maximum_expert_overclaim = maximum_expert_overclaim
304
  self.attention_dropout = attention_dropout
305
  self.use_cache = use_cache
306
 
 
1478
  - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
1479
  - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
1480
  - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
1481
+ - "logit_std": detached scalar mean per-layer per-token routing logit spread
 
 
 
1482
  """
1483
 
1484
 
 
2722
  paper. Given an input hidden state x, the router produces two outputs used downstream:
2723
 
2724
  - selected_heads (I): which K of the L available expert heads each token routes to,
2725
+ determined by TopK over capacity-balanced routing scores.
2726
  - routing_probs (P): the weights used for the weighted output reduction, gathered from
2727
+ the routing scores at the selected indices and renormalized to sum to 1 per token.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2728
 
2729
+ Routing uses a single learnable projection:
 
 
2730
 
2731
+ - routing_weight: shape (L, embedding_width). Maps input to per-head routing scores.
2732
+ Both task loss and load_balance_loss train this parameter directly — there is no
2733
+ gradient isolation between the two signals.
2734
 
2735
+ This coupled design is intentional. SHRAM has an unusually strong task-level incentive
2736
+ to concentrate tokens into the same expert bucket (sparse attention only occurs among
2737
+ tokens routed to the same expert), so any indirect balancing pathway will be outlearned.
2738
+ Coupling the gradients allows the load balance loss to act with full strength directly
2739
+ on the parameter that determines routing.
2740
 
2741
+ routing_weight is nn.Parameter so that HuggingFace _init_weights does not override
2742
+ its kaiming initialization at construction.
2743
 
2744
+ routing_probs are computed before balance_capacity applies -1e8 sentinels. Post-capacity
2745
+ softmax would corrupt routing_probs for over-capacity experts (near-zero probability
2746
+ after masking does not reflect genuine routing preference).
 
2747
 
2748
+ The router computes and returns:
2749
+ - load_balance_loss: scalar auxiliary loss (see load_balance_loss.py); gradient flows
2750
+ to routing_weight.
2751
+ - max_vio: detached scalar summarising routing imbalance:
2752
+ MaxVio = mean_b( L · max_l(f_bl − 1/L) )
2753
+ where f_bl is the per-batch-item realised routing frequency of head l. Zero means
2754
+ perfect balance; 1.0 means the most loaded head received double its fair share.
2755
+ - logit_std: detached scalar; mean per-token standard deviation of routing logits.
2756
+ Monitoring metric for routing sharpness.
 
 
 
2757
 
2758
  Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2759
  """
 
2768
  # -----------
2769
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2770
 
2771
+ This module provides four load-balance loss formulations, two token-reduction
2772
  helpers, and a factory that selects among the formulations. All formulations
2773
  share the same external contract:
2774
 
 
2778
  active_mask: Tensor[B, N],
2779
  ) -> scalar Tensor
2780
 
2781
+ logits: Pre-softmax routing scores, shape (B, N, L). Gradient flows
2782
+ through this tensor.
 
2783
  assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
2784
  is 1.0 if token (b, n) was assigned to head l. Dead tokens
2785
  should carry zero entries.
 
2789
  Token reduction is split into two helpers with distinct roles:
2790
 
2791
  reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L).
2792
+ Called by gshard, ce, and bce. Output is detached; f_bl carries no gradient.
2793
 
2794
  reduce_probability_tokens — produces per-batch-item mean assignment probabilities
2795
+ p_bl (B, L). Called only by gshard and bce. Gradient flows through the
2796
+ internal softmax over logits.
2797
 
2798
  CE delegates probability computation to F.cross_entropy, which handles its own
2799
  log_softmax and operates directly on the raw (B, N, L) logits.
2800
 
2801
+ ``make_load_balance_loss`` is the sole public entry point. The individual loss
2802
+ functions are internal implementation details; their signatures may change between
2803
+ units. Callers and tests must construct loss callables through the factory, not by
2804
+ importing or invoking the loss functions directly.
2805
  """
2806
 
2807
 
 
2975
  """
2976
  f_bl = reduce_frequency_tokens(assignment_mask, active_mask)
2977
  p_bl = reduce_probability_tokens(logits, active_mask)
2978
+ # Clamp for numerical safety: softmax outputs are strictly positive in
2979
+ # normal operation; the clamp guards the all-dead-tokens edge case where
2980
+ # the mean defaults to zero. log1p(-p) avoids cancellation near p=1.
2981
+ p = p_bl.clamp(min=1e-7, max=1.0 - 1e-7)
2982
+ target = 1.0 - f_bl
2983
+ return -(target * torch.log(p) + (1.0 - target) * torch.log1p(-p)).mean()
2984
+
2985
+
2986
+ def _temporal_overcapacity_loss(
2987
+ logits: torch.Tensor,
2988
+ assignment_mask: torch.Tensor,
2989
+ active_mask: torch.Tensor,
2990
+ expected_tokens_rate: float,
2991
+ maximum_expert_overclaim: int,
2992
+ ) -> torch.Tensor:
2993
+ """Temporal overcapacity loss for MoSRAH load balancing.
2994
+
2995
+ Penalises routing decisions that select a head already overloaded relative to
2996
+ its ideal allocation trajectory. A head is considered overloaded when the number
2997
+ of active tokens before position n assigned to that head exceeds
2998
+ cumulative_active_tokens * M + C, where M is the expected_tokens_rate (K/L) and
2999
+ C is the maximum_expert_overclaim slack.
3000
+
3001
+ Loss is exactly zero when no head exceeds its trajectory, making it safe to
3002
+ weight strongly — it stays out of the way when routing is balanced.
3003
+
3004
+ Args:
3005
+ logits: Pre-softmax routing scores, shape (B, N, L).
3006
+ assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
3007
+ 1.0 if token (b, n) is assigned to head l.
3008
+ active_mask: Boolean active-token mask, shape (B, N).
3009
+ expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed
3010
+ by the factory so the division is not repeated each
3011
+ forward pass.
3012
+ maximum_expert_overclaim (C): Slack above the ideal trajectory before
3013
+ imbalance fires. Larger C tolerates more deviation.
3014
+
3015
+ Returns:
3016
+ Scalar loss tensor. Exactly 0.0 when no head exceeds its allowed trajectory.
3017
+ """
3018
+ # ── Algorithm overview ──────────────────────────────────────────────────────
3019
+ #
3020
+ # Problem: token routing is stateless — each token's TopK selection is blind to
3021
+ # how many times each expert has already been chosen earlier in the sequence. A
3022
+ # router that develops a strong preference for certain experts will overload them
3023
+ # far beyond their K/L fair share with no correction signal at the moment of
3024
+ # selection.
3025
+ #
3026
+ # Approach: track per-head assignment history as exclusive cumulative counts
3027
+ # (assignments by all active tokens strictly before position n) and compare
3028
+ # against an ideal trajectory S·M, where S is the exclusive cumulative active
3029
+ # token count and M is the ideal per-head allocation rate K/L per token.
3030
+ # A head is overloaded when its prior count exceeds that trajectory
3031
+ # by more than C. When a token selects an already-overloaded head, the loss
3032
+ # moment — mean(violating logits) minus mean(non-overloaded logits) — penalises
3033
+ # the gap and pushes future routing toward underloaded alternatives.
3034
+
3035
+ # ── Routing history and imbalance threshold ──────────────────────────────────
3036
+ #
3037
+ # prior_assignment_counts is the exclusive routing history at each position:
3038
+ # active assignments to each head by all tokens strictly before position n.
3039
+ # Exclusive because it reflects only what was known when token n was being routed.
3040
+ # cumulative_active_tokens grows by 1 per active token; the ideal per-head
3041
+ # allocation at n is S·M. Exceeding that by more than C triggers imbalance.
3042
+
3043
+ active_float = active_mask.float() # (B, N)
3044
+ active_assignments = assignment_mask * active_float.unsqueeze(-1) # (B, N, L)
3045
+
3046
+ # exclusive cumsums: subtract self to exclude position n
3047
+ prior_assignment_counts = active_assignments.cumsum(dim=1) - active_assignments # (B, N, L)
3048
+ cumulative_active_tokens = active_float.cumsum(dim=1) - active_float # (B, N)
3049
+
3050
+ maximum_supportable_assignments = (
3051
+ cumulative_active_tokens.unsqueeze(-1) * expected_tokens_rate
3052
+ + maximum_expert_overclaim
3053
+ ) # (B, N, 1) → broadcasts to (B, N, L)
3054
+
3055
+ # ── Mask construction ────────────────────────────────────────────────────────
3056
+ #
3057
+ # Three derived masks:
3058
+ # imbalance_mask: any head exceeding its trajectory.
3059
+ # violating_selection_mask: selected AND imbalanced — the penalty target.
3060
+ # non_overloaded_head_mask: NOT imbalanced, regardless of selection.
3061
+ #
3062
+ # Masking is deliberately asymmetric. We have a problem when something is over
3063
+ # capacity AND gets chosen by topk. We can transfer it elsewhere only if the
3064
+ # destination head is not itself over capacity.
3065
+
3066
+ imbalance_mask = prior_assignment_counts > maximum_supportable_assignments # (B, N, L)
3067
+ violating_selection_mask = assignment_mask.bool() & imbalance_mask # (B, N, L)
3068
+ non_overloaded_head_mask = ~imbalance_mask # (B, N, L)
3069
+ has_violation_mask = violating_selection_mask.any(dim=-1) # (B, N)
3070
+
3071
+ # ── Loss moment ────────────────────────────────────────────────────────
3072
+ #
3073
+ # clamp(min=1.0) on the count denominators guards against 0/0 when a position has
3074
+ # no violating selections or no non-overloaded heads. has_violation_mask zeros
3075
+ # positions with no violations at the gating step, so the clamped violation_count
3076
+ # never contributes to the loss there.
3077
+ #
3078
+ # One notable property of this moment is it keeps the amount of transferred
3079
+ # logit mass constant. That is, the gradient reduces violating logits and increases
3080
+ # non-overloaded logits by equal magnitude. Routing is redirected, not suppressed.
3081
+
3082
+ violation_count = violating_selection_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N)
3083
+ non_overloaded_count = non_overloaded_head_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N)
3084
+ mean_violating_logit = (violating_selection_mask.float() * logits).sum(dim=-1) / violation_count # (B, N)
3085
+ mean_non_overloaded_logit = (non_overloaded_head_mask.float() * logits).sum(dim=-1) / non_overloaded_count # (B, N)
3086
+ raw_loss = mean_violating_logit - mean_non_overloaded_logit # (B, N)
3087
+
3088
+ # ── Loss reduction ───────────────────────────────────────────────────────────
3089
+ #
3090
+ # Reduction is over active positions only; dead tokens are excluded from both
3091
+ # numerator (gated by active_float) and denominator (active_count_per_seq).
3092
+ # clamp(min=1.0) handles the all-dead-tokens edge case: gated_loss is zero
3093
+ # there since active_float gates it, so the result is 0/1 = 0.
3094
+ #
3095
+ # Exact-zero guarantee: when no head exceeds its trajectory, has_violation_mask
3096
+ # is all-False, gated_loss is zeroed everywhere, and the scalar return is
3097
+ # exactly 0.0. The loss is inert when routing is balanced.
3098
+
3099
+ gated_loss = active_float * has_violation_mask.float() * raw_loss # (B, N)
3100
+ active_count_per_seq = active_float.sum(dim=1).clamp(min=1.0) # (B,)
3101
+ sequence_loss = gated_loss.sum(dim=1) / active_count_per_seq # (B,)
3102
+ final_loss = sequence_loss.mean()
3103
+ return final_loss
3104
 
3105
 
3106
  # ---------------------------------------------------------------------------
3107
  # Factory
3108
  # ---------------------------------------------------------------------------
3109
 
3110
+ def _gshard_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3111
+ return gshard_loss
3112
+
3113
+
3114
+ def _ce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3115
+ return ce_loss
3116
+
3117
+
3118
+ def _bce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3119
+ return bce_loss
3120
+
3121
+
3122
+ def _temporal_overcapacity_factory(
3123
+ num_selected_heads: int,
3124
+ num_total_heads: int,
3125
+ maximum_expert_overclaim: int,
3126
+ **kwargs: object,
3127
+ ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3128
+ expected_tokens_rate = num_selected_heads / num_total_heads
3129
+ def _runtime(
3130
+ logits: torch.Tensor,
3131
+ assignment_mask: torch.Tensor,
3132
+ active_mask: torch.Tensor,
3133
+ ) -> torch.Tensor:
3134
+ return _temporal_overcapacity_loss(
3135
+ logits, assignment_mask, active_mask,
3136
+ expected_tokens_rate=expected_tokens_rate,
3137
+ maximum_expert_overclaim=maximum_expert_overclaim,
3138
+ )
3139
+ return _runtime
3140
+
3141
+
3142
+ _LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = {
3143
+ "gshard": _gshard_factory,
3144
+ "ce": _ce_factory,
3145
+ "bce": _bce_factory,
3146
+ "temporal_overcapacity": _temporal_overcapacity_factory,
3147
  }
3148
 
3149
 
3150
  def make_load_balance_loss(
3151
  loss_type: str,
3152
+ **loss_parameters: object,
3153
  ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3154
  """Return a load-balance loss callable for the requested formulation.
3155
 
 
3161
  active_mask: Tensor[B, N],
3162
  ) -> scalar Tensor
3163
 
3164
+ Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
3165
+ factories silently ignore all kwargs; this allows callers to pass loss-type-specific
3166
+ parameters (e.g. for temporal_overcapacity) without branching on loss_type.
3167
 
3168
  Args:
3169
+ loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``, or
3170
+ ``"temporal_overcapacity"``.
3171
+ **loss_parameters: Construction-time parameters forwarded to the factory.
3172
 
3173
  Returns:
3174
  Loss callable matching the shared contract.
 
3181
  raise ValueError(
3182
  f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
3183
  )
3184
+ return _LOSS_REGISTRY[loss_type](**loss_parameters)
3185
 
3186
 
3187
  class MoSRAHRouter(nn.Module):
3188
  """Token-choice router for MoSRAH sparse attention.
3189
 
3190
+ Each input token independently selects K of the L available expert heads.
3191
+ A single routing projection maps input hidden states to per-head scores; both
3192
+ task loss and load_balance_loss train this projection directly.
 
3193
 
3194
+ routing_weight is nn.Parameter rather than nn.Linear so that HuggingFace
3195
+ _init_weights does not override its kaiming initialization at construction.
 
3196
 
3197
  Attributes:
3198
+ routing_weight: Shape (L, embedding_width). Maps input hidden states to
3199
+ per-head routing scores. Receives gradients from both task loss and
3200
+ load_balance_loss.
 
 
 
 
3201
 
3202
  Args:
3203
  config: Model configuration. Must expose ``embedding_width``,
3204
+ ``num_mosrah_heads`` (L), ``num_selected_heads`` (K),
3205
+ ``load_balance_loss_type``, ``maximum_expert_overclaim``, ``max_bid_rounds``,
3206
+ ``use_cache``, ``mosrah_cache_length``, and ``mosrah_packed_length``.
3207
  """
3208
 
3209
  def __init__(self, config: ShramConfig) -> None:
 
3216
  self.capacity = config.mosrah_packed_length
3217
 
3218
  self.max_bid_rounds = config.max_bid_rounds
3219
+ self._load_balance_loss = make_load_balance_loss(
3220
+ config.load_balance_loss_type,
3221
+ num_selected_heads=config.num_selected_heads,
3222
+ num_total_heads=config.num_mosrah_heads,
3223
+ maximum_expert_overclaim=config.maximum_expert_overclaim,
 
 
 
3224
  )
 
3225
 
3226
+ # Routing projection: maps input (B, N, d) to per-head routing scores (B, N, L).
 
3227
  # nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
3228
+ self.routing_weight = nn.Parameter(
3229
  torch.empty(config.num_mosrah_heads, config.embedding_width)
3230
  )
3231
+ nn.init.kaiming_normal_(self.routing_weight)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3232
 
3233
  @staticmethod
3234
  def get_best_proposals(
 
3473
  Returns:
3474
  selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
3475
  Each token's K selected head indices, determined by TopK on
3476
+ capacity-balanced routing scores.
3477
  routing_probs: Routing probabilities P of shape (batch, seq_len,
3478
+ num_selected_heads). Gathered from pre-capacity routing softmax at
3479
  selected_heads indices and renormalized to sum to 1 per token.
3480
  router_diagnostics: Dict of routing feedback scalars. Keys:
3481
  - ``load_balance_loss``: scalar load-balance loss with gradient.
3482
  - ``max_vio``: detached scalar routing-imbalance summary.
3483
+ - ``logit_std``: detached mean per-token std of routing logits;
3484
+ monitoring metric for routing sharpness.
 
 
 
 
 
 
 
3485
  """
3486
  B, N, _ = x.shape
3487
  L = self.num_mosrah_heads
3488
  K = self.num_selected_heads
3489
 
3490
+ # ── Phase: pre-capacity scoring ───────────────────────────────────────
3491
+ #
3492
+ # Establishes the clean pre-sentinel distribution that all downstream
3493
+ # consumers draw from. logit_std must be captured here — balance_capacity
3494
+ # injects -1e8 sentinels that would corrupt the standard deviation.
3495
+ # routing_scores is the pre-capacity probability distribution; both the
3496
+ # load balance signal and the final routing_probs gather from it.
3497
+ routing_logits = self._compute_routing_logits(x) # (B, N, L)
3498
+ logit_std = routing_logits.std(dim=-1).mean().detach()
3499
+ routing_scores = F.softmax(routing_logits, dim=-1) # (B, N, L)
3500
+
3501
+ # ── Phase: load balance signal ────────────────────────────────────────
3502
+ #
3503
+ # The loss must observe the unconstrained routing decision — the genuine
3504
+ # routing pressure before capacity enforcement masks any imbalance.
3505
+ # pre_cap_heads and assignment_mask exist solely to give the loss this
3506
+ # honest view; nothing downstream uses them.
3507
+ pre_cap_heads = routing_scores.topk(K, dim=-1).indices # (B, N, K)
3508
+ assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
3509
+ assignment_mask.scatter_(-1, pre_cap_heads, 1.0)
3510
 
3511
+ load_balance_loss = self._load_balance_loss(
3512
+ routing_logits, assignment_mask, active_mask
 
 
 
 
3513
  )
3514
 
3515
+ # ── Phase: capacity enforcement and final selection ───────────────────
3516
+ #
3517
+ # Produces the capacity-enforced routing that all downstream consumers
3518
+ # depend on. max_vio is computed here because it measures realized routing
3519
+ # imbalance — the actual post-capacity assignment, not the unconstrained
3520
+ # preference. routing_probs are gathered from the pre-capacity routing_scores
3521
+ # (not the balanced distribution) to avoid sentinel corruption — overloaded
3522
+ # experts would otherwise receive near-zero probability regardless of genuine
3523
+ # routing preference.
3524
+ balanced_logits = self.balance_capacity(
3525
+ routing_logits,
3526
  used_capacity,
3527
  self.capacity,
3528
  self.num_selected_heads,
3529
  self.max_bid_rounds,
3530
  )
3531
+ selected_heads = F.softmax(balanced_logits, dim=-1).topk(K, dim=-1).indices # (B, N, K)
3532
 
3533
+ realized_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
3534
+ realized_mask.scatter_(-1, selected_heads, 1.0)
3535
+ max_vio = self._compute_max_vio(realized_mask, active_mask, L)
3536
 
 
 
3537
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
3538
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
3539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3540
  router_diagnostics = {
3541
  "load_balance_loss": load_balance_loss,
3542
+ "max_vio": max_vio,
3543
+ "logit_std": logit_std,
3544
  }
3545
  return selected_heads, routing_probs, router_diagnostics
3546
 
3547
+ def _compute_routing_logits(self, x: torch.Tensor) -> torch.Tensor:
3548
+ """Compute per-head routing logits from input hidden states.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3549
 
3550
  Args:
3551
  x: Input hidden states, shape (batch, seq_len, embedding_width).
 
 
3552
 
3553
  Returns:
3554
+ Routing logits, shape (batch, seq_len, num_mosrah_heads).
 
 
 
 
3555
  """
3556
+ return F.linear(x, self.routing_weight) # (B, N, L)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3557
 
3558
  @staticmethod
3559
  def _compute_max_vio(
 
4091
  - ``"max_vio"``: detached scalar maximum routing-imbalance across
4092
  all decoder layers. Zero means perfectly balanced routing across
4093
  every layer; higher values identify the worst-case head imbalance.
 
 
 
 
 
 
 
4094
  - ``"logit_std"``: detached scalar — mean across layers of the
4095
+ per-token routing logit spread. Monitoring metric for routing
4096
+ sharpness.
 
 
 
 
 
4097
  """
4098
  hidden_states = inputs_embeds
4099
  all_hidden_states = (hidden_states,) if output_hidden_states else None
4100
  total_load_balance_loss = inputs_embeds.new_zeros(())
4101
  max_vio = inputs_embeds.new_zeros(())
 
 
4102
  total_logit_std = inputs_embeds.new_zeros(())
 
4103
 
4104
  for layer_idx, layer in enumerate(self.layers):
4105
  layer_cache = None if cache is None else cache.layers[layer_idx]
 
4111
  )
4112
  total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"]
4113
  max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"])
 
 
4114
  total_logit_std = total_logit_std + layer_diagnostics["logit_std"]
 
4115
 
4116
  if output_hidden_states:
4117
  all_hidden_states = all_hidden_states + (hidden_states,)
 
4125
  "hidden_states": all_hidden_states,
4126
  "load_balance_loss": total_load_balance_loss,
4127
  "max_vio": max_vio,
 
 
4128
  "logit_std": total_logit_std / num_layers,
 
4129
  }
4130
 
4131
 
 
4142
  ## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all
4143
  ## fields to None, which forces every subclass field to also carry a default.
4144
  ## The = None below is a language constraint, not a semantic statement. In
4145
+ ## practice, load_balance_loss, max_vio, and logit_std are always populated
4146
+ ## by ShramForCausalLM.forward(). ce_loss is genuinely optional — present
4147
+ ## only when labels are supplied.
4148
 
4149
  ce_loss: torch.FloatTensor | None = None
4150
  load_balance_loss: torch.FloatTensor | None = None
4151
  max_vio: torch.FloatTensor | None = None
 
 
4152
  logit_std: torch.Tensor | None = None
 
4153
 
4154
  class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4155
  """HuggingFace-facing causal language model wrapper for SHRAM.
 
4598
  - ``hidden_states`` when requested,
4599
  - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
4600
  - ``max_vio`` — detached worst-case routing imbalance across layers,
4601
+ - ``logit_std`` — detached mean per-token routing logit spread across layers.
 
 
4602
  """
4603
  use_cache = use_cache if use_cache is not None else self.config.use_cache
4604
  output_hidden_states = (
 
4705
  hidden_states=backbone_outputs["hidden_states"],
4706
  load_balance_loss=backbone_outputs["load_balance_loss"],
4707
  max_vio=backbone_outputs["max_vio"],
 
 
4708
  logit_std=backbone_outputs["logit_std"],
 
4709
  )