smithblack-0 committed on
Commit
7a9910f
·
verified ·
1 Parent(s): 40cae06

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +1 -2
  2. config.json +1 -2
  3. configuration.py +13 -23
  4. huggingface.py +474 -208
README.md CHANGED
@@ -83,7 +83,6 @@ contains no weights. All values are overridable via kwargs.
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
  | `load_balance_loss_type` | ce |
86
- | `load_balance_p` | 1.0 |
87
  | `local_rope_theta` | 10000.0 |
88
  | `max_bid_rounds` | 10 |
89
  | `mlp_width` | 1366 |
@@ -96,7 +95,7 @@ contains no weights. All values are overridable via kwargs.
96
  | `output_hidden_states` | False |
97
  | `rms_norm_eps` | 1e-05 |
98
  | `rope_mode` | main_sequence |
99
- | `router_init_scale` | 0.0001 |
100
  | `tie_word_embeddings` | False |
101
  | `training_sequence_length` | 1024 |
102
  | `use_cache` | True |
 
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
  | `load_balance_loss_type` | ce |
 
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
88
  | `mlp_width` | 1366 |
 
95
  | `output_hidden_states` | False |
96
  | `rms_norm_eps` | 1e-05 |
97
  | `rope_mode` | main_sequence |
98
+ | `routing_mode` | integral |
99
  | `tie_word_embeddings` | False |
100
  | `training_sequence_length` | 1024 |
101
  | `use_cache` | True |
config.json CHANGED
@@ -10,7 +10,6 @@
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
  "load_balance_loss_type": "ce",
13
- "load_balance_p": 1.0,
14
  "local_rope_theta": 10000.0,
15
  "max_bid_rounds": 10,
16
  "mlp_width": 1366,
@@ -23,7 +22,7 @@
23
  "num_sliding_window_heads": 16,
24
  "rms_norm_eps": 1e-05,
25
  "rope_mode": "main_sequence",
26
- "router_init_scale": 0.0001,
27
  "tie_word_embeddings": false,
28
  "training_sequence_length": 1024,
29
  "transformers_version": "5.10.2",
 
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
  "load_balance_loss_type": "ce",
 
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
15
  "mlp_width": 1366,
 
22
  "num_sliding_window_heads": 16,
23
  "rms_norm_eps": 1e-05,
24
  "rope_mode": "main_sequence",
25
+ "routing_mode": "integral",
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
  "transformers_version": "5.10.2",
configuration.py CHANGED
@@ -84,10 +84,6 @@ class ShramConfig(PretrainedConfig):
84
  num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
  Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
  baseline. Default 2.0.
87
- load_balance_p: Exponent p for the p-mean aggregation of per-item routing
88
- frequencies into the load balance signal. Higher p weights aggregation
89
- toward the worst-case batch item, making the correction signal more
90
- sensitive to per-item allocation spikes. Must be positive. Default 1.0.
91
  max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
92
  solver in ``balance_capacity``. 10 covers convergence at approximately
93
  the 98th percentile of routing densities; the top 2% of extreme-density
@@ -99,11 +95,13 @@ class ShramConfig(PretrainedConfig):
99
  is the default; its log-probability signal scales with violation severity
100
  and makes correction magnitude proportional to routing imbalance.
101
  Default ``"ce"``.
102
- router_init_scale: Initial standard deviation for the ``routing_scale``
103
- scalar gate on routing logits. Brings routing logit magnitude to
104
- ``expert_bias`` scale at initialisation so load balancing is operative
105
- from step one. Must be positive. Default ``1e-4``. Note lower values
106
- may require more bidding rounds to converge and more overcapacity to support.
 
 
107
  """
108
 
109
  model_type = "shram"
@@ -137,10 +135,9 @@ class ShramConfig(PretrainedConfig):
137
  output_hidden_states: bool = False,
138
  tie_word_embeddings: bool = False,
139
  mosrah_overallocation_factor: float = 2.0,
140
- load_balance_p: float = 1.0,
141
  max_bid_rounds: int = 10,
142
  load_balance_loss_type: str = "ce",
143
- router_init_scale: float = 1e-4,
144
  **kwargs
145
  ):
146
  if head_dim % 2 != 0:
@@ -176,11 +173,6 @@ class ShramConfig(PretrainedConfig):
176
  f"Got {mosrah_overallocation_factor}."
177
  )
178
 
179
- if load_balance_p <= 0.0:
180
- raise ValueError(
181
- f"load_balance_p must be positive, got {load_balance_p}."
182
- )
183
-
184
  if max_bid_rounds < 1:
185
  raise ValueError(
186
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
@@ -193,13 +185,12 @@ class ShramConfig(PretrainedConfig):
193
  f"load_balance_loss_type must be one of {supported}, "
194
  f"got {load_balance_loss_type!r}."
195
  )
196
- if load_balance_loss_type == "ce" and load_balance_p != 1.0:
197
- raise ValueError("In cross entropy mode, aggregation of "
198
- "frequencies must be with mean 1.0")
199
 
200
- if router_init_scale <= 0.0:
 
 
201
  raise ValueError(
202
- f"router_init_scale must be positive, got {router_init_scale}."
203
  )
204
 
205
  self.vocab_size = vocab_size
@@ -220,10 +211,9 @@ class ShramConfig(PretrainedConfig):
220
  self.alpha = alpha
221
  self.beta = beta
222
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
223
- self.load_balance_p = load_balance_p
224
  self.max_bid_rounds = max_bid_rounds
225
  self.load_balance_loss_type = load_balance_loss_type
226
- self.router_init_scale = router_init_scale
227
  self.attention_dropout = attention_dropout
228
  self.use_cache = use_cache
229
 
 
84
  num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
  Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
  baseline. Default 2.0.
 
 
 
 
87
  max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
88
  solver in ``balance_capacity``. 10 covers convergence at approximately
89
  the 98th percentile of routing densities; the top 2% of extreme-density
 
95
  is the default; its log-probability signal scales with violation severity
96
  and makes correction magnitude proportional to routing imbalance.
97
  Default ``"ce"``.
98
+ routing_mode: Routing computation mode. ``"integral"`` (default) enables the
99
+ integral routing extension: the exclusive cumsum of routing logits along
100
+ the sequence dimension is mapped through two additional (L, L) parameter
101
+ matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
102
+ B') and added as corrections to both logit pathways. This gives each
103
+ token a read on the cumulative routing history so far in the sequence.
104
+ ``"default"`` disables the extension; A' and B' are not created.
105
  """
106
 
107
  model_type = "shram"
 
135
  output_hidden_states: bool = False,
136
  tie_word_embeddings: bool = False,
137
  mosrah_overallocation_factor: float = 2.0,
 
138
  max_bid_rounds: int = 10,
139
  load_balance_loss_type: str = "ce",
140
+ routing_mode: str = "integral",
141
  **kwargs
142
  ):
143
  if head_dim % 2 != 0:
 
173
  f"Got {mosrah_overallocation_factor}."
174
  )
175
 
 
 
 
 
 
176
  if max_bid_rounds < 1:
177
  raise ValueError(
178
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
 
185
  f"load_balance_loss_type must be one of {supported}, "
186
  f"got {load_balance_loss_type!r}."
187
  )
 
 
 
188
 
189
+ _supported_routing_modes = {"default", "integral"}
190
+ if routing_mode not in _supported_routing_modes:
191
+ supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
192
  raise ValueError(
193
+ f"routing_mode must be one of {supported}, got {routing_mode!r}."
194
  )
195
 
196
  self.vocab_size = vocab_size
 
211
  self.alpha = alpha
212
  self.beta = beta
213
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
 
214
  self.max_bid_rounds = max_bid_rounds
215
  self.load_balance_loss_type = load_balance_loss_type
216
+ self.routing_mode = routing_mode
217
  self.attention_dropout = attention_dropout
218
  self.use_cache = use_cache
219
 
huggingface.py CHANGED
@@ -45,7 +45,6 @@ from torch.nn.attention.flex_attention import create_block_mask
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
47
  from typing import Callable
48
- from typing import Optional
49
 
50
 
51
 
@@ -172,10 +171,6 @@ class ShramConfig(PretrainedConfig):
172
  num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
173
  Must be > 1.0 to guarantee a buffer larger than the balanced-routing
174
  baseline. Default 2.0.
175
- load_balance_p: Exponent p for the p-mean aggregation of per-item routing
176
- frequencies into the load balance signal. Higher p weights aggregation
177
- toward the worst-case batch item, making the correction signal more
178
- sensitive to per-item allocation spikes. Must be positive. Default 1.0.
179
  max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
180
  solver in ``balance_capacity``. 10 covers convergence at approximately
181
  the 98th percentile of routing densities; the top 2% of extreme-density
@@ -187,11 +182,13 @@ class ShramConfig(PretrainedConfig):
187
  is the default; its log-probability signal scales with violation severity
188
  and makes correction magnitude proportional to routing imbalance.
189
  Default ``"ce"``.
190
- router_init_scale: Initial standard deviation for the ``routing_scale``
191
- scalar gate on routing logits. Brings routing logit magnitude to
192
- ``expert_bias`` scale at initialisation so load balancing is operative
193
- from step one. Must be positive. Default ``1e-4``. Note lower values
194
- may require more bidding rounds to converge and more overcapacity to support.
 
 
195
  """
196
 
197
  model_type = "shram"
@@ -225,10 +222,9 @@ class ShramConfig(PretrainedConfig):
225
  output_hidden_states: bool = False,
226
  tie_word_embeddings: bool = False,
227
  mosrah_overallocation_factor: float = 2.0,
228
- load_balance_p: float = 1.0,
229
  max_bid_rounds: int = 10,
230
  load_balance_loss_type: str = "ce",
231
- router_init_scale: float = 1e-4,
232
  **kwargs
233
  ):
234
  if head_dim % 2 != 0:
@@ -264,11 +260,6 @@ class ShramConfig(PretrainedConfig):
264
  f"Got {mosrah_overallocation_factor}."
265
  )
266
 
267
- if load_balance_p <= 0.0:
268
- raise ValueError(
269
- f"load_balance_p must be positive, got {load_balance_p}."
270
- )
271
-
272
  if max_bid_rounds < 1:
273
  raise ValueError(
274
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
@@ -281,13 +272,12 @@ class ShramConfig(PretrainedConfig):
281
  f"load_balance_loss_type must be one of {supported}, "
282
  f"got {load_balance_loss_type!r}."
283
  )
284
- if load_balance_loss_type == "ce" and load_balance_p != 1.0:
285
- raise ValueError("In cross entropy mode, aggregation of "
286
- "frequencies must be with mean 1.0")
287
 
288
- if router_init_scale <= 0.0:
 
 
289
  raise ValueError(
290
- f"router_init_scale must be positive, got {router_init_scale}."
291
  )
292
 
293
  self.vocab_size = vocab_size
@@ -308,10 +298,9 @@ class ShramConfig(PretrainedConfig):
308
  self.alpha = alpha
309
  self.beta = beta
310
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
311
- self.load_balance_p = load_balance_p
312
  self.max_bid_rounds = max_bid_rounds
313
  self.load_balance_loss_type = load_balance_loss_type
314
- self.router_init_scale = router_init_scale
315
  self.attention_dropout = attention_dropout
316
  self.use_cache = use_cache
317
 
@@ -2741,24 +2730,53 @@ paper. Given an input hidden state x, the router produces two outputs used downs
2741
  the semantic routing scores at the selected indices and renormalized to sum to 1
2742
  per token.
2743
 
2744
- Routing computation uses two gradient-isolated pathways over numerically identical
2745
- biased values:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2746
 
2747
- - semantic_logits = logits + expert_bias.detach(): drives selection and routing_probs.
2748
- Task gradients reach routing_projection.weight; expert_bias is isolated from task loss.
2749
- - load_balancing_logits = logits.detach() + expert_bias: drives assignment_probs.
2750
- Load balance gradients reach expert_bias; routing_projection.weight is isolated from
2751
- load balance loss.
2752
 
2753
- No unbiased routing computation exists. All routing uses biased values. The separation
2754
- of gradient paths replaces the previous biased/unbiased split, closing the loophole where
2755
- a bias-redirected expert could be selected but contribute negligibly to the output because
2756
- its unbiased preference — and thus its routing_prob — remained near zero.
2757
 
2758
  Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
2759
  Post-capacity softmax would invert the load balance gradient for over-capacity experts
2760
- (near-zero probability after masking signals "increase bias" for an already-overloaded
2761
- expert).
2762
 
2763
  The router also computes and returns the load balance loss via a log-probability auxiliary
2764
  loss (see load_balance_loss.py). The loss formulation is selected by config; the default
@@ -2767,10 +2785,11 @@ is cross-entropy.
2767
  The router additionally computes and returns MaxVio, a detached scalar summarising
2768
  routing imbalance for the current forward pass:
2769
 
2770
- MaxVio = L · max_l(f_l − 1/L)
2771
 
2772
- where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
2773
- target. MaxVio is a monitoring quantity only; it never contributes gradients.
 
2774
 
2775
  Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2776
  """
@@ -2785,130 +2804,228 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2785
  # -----------
2786
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2787
 
2788
- This module provides three load-balance loss formulations and a factory that selects
2789
- among them. All formulations share the same external contract and the same gradient
2790
- isolation property: assignment probabilities are computed from detached logits plus
2791
- expert_bias, so only expert_bias receives gradients from the loss signal. The routing
2792
- projection weights are not reachable from any returned loss.
2793
 
2794
- The factory is the intended entry point. The caller (MoSRAHRouter) constructs the
2795
- loss callable once at init and invokes it each forward pass.
 
 
 
2796
 
2797
- Log-probability formulations (ce, bce) are preferred over linear ones (gshard) because
2798
- their gradient magnitude scales with how far the distribution deviates from the target.
2799
- A linear signal can be outrun by routing concentrations that diverge nonlinearly; a
2800
- log-probability signal cannot.
 
 
 
 
2801
 
2802
- The external contract for all returned callables is:
2803
 
2804
- loss_fn(routing_freqs, assignment_probs) -> scalar Tensor
 
2805
 
2806
- routing_freqs: (L,) realized routing frequencies f_i, detached.
2807
- assignment_probs: (L,) soft assignment probabilities p_i with gradient through
2808
- expert_bias. Caller must compute these via
2809
- softmax(logits.detach() + expert_bias) to preserve isolation.
 
 
 
 
 
2810
  """
2811
 
2812
 
2813
 
2814
 
2815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2816
  # ---------------------------------------------------------------------------
2817
  # Loss functions
2818
  # ---------------------------------------------------------------------------
2819
 
2820
  def gshard_loss(
2821
- routing_freqs: torch.Tensor,
2822
- assignment_probs: torch.Tensor,
 
2823
  ) -> torch.Tensor:
2824
  """GShard-style linear load-balance loss.
2825
 
2826
- Computes (1/L) * Σ_i f_i * p_i, where L is the number of expert heads,
2827
- f_i is the realized routing frequency for head i, and p_i is the soft
2828
- assignment probability for head i.
2829
 
2830
- The fixed point of this loss under gradient descent is uniform routing:
2831
- when p_i = 1/L for all i, the loss is minimized at 1/L (independent of f_i).
2832
- The linear signal is the weakest of the three formulations — gradient magnitude
2833
- does not grow with deviation from the target. Provided for comparison.
2834
 
2835
  Args:
2836
- routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2837
- assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2838
- flows to expert_bias through this tensor.
2839
 
2840
  Returns:
2841
  Scalar loss tensor.
2842
  """
2843
- L = routing_freqs.shape[0]
2844
- return (routing_freqs * assignment_probs).sum() / L
 
 
2845
 
2846
 
2847
  def ce_loss(
2848
- routing_freqs: torch.Tensor,
2849
- assignment_probs: torch.Tensor,
 
2850
  ) -> torch.Tensor:
2851
  """Cross-entropy load-balance loss.
2852
 
2853
- Computes -(1/(L-1)) * Σ_i (1 - f_i) * log(p_i), where the weight (1 - f_i)
2854
- suppresses the signal for overloaded heads (high f_i → weight near zero) and
2855
- amplifies it for underloaded heads (low f_i → weight near 1). This makes the
2856
- loss push probability mass toward under-utilized experts.
 
 
 
 
2857
 
2858
- The (1/(L-1)) normalization makes the coefficient interpretable as a controller
2859
- strength independent of expert count. The log-probability signal grows as p_i
2860
- deviates from the target, providing correction that scales with violation severity.
2861
 
2862
  Args:
2863
- routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2864
- assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2865
- flows to expert_bias through this tensor.
2866
 
2867
  Returns:
2868
  Scalar loss tensor.
2869
  """
2870
- L = routing_freqs.shape[0]
2871
- # Numerical stability: torch.log is safe here because softmax outputs are
2872
- # strictly positive. The (1 - f_i) weight goes to zero exactly when f_i = 1,
2873
- # which can only occur with a single head, so the 0 * (-inf) degenerate case
2874
- # does not arise in practice.
2875
- return -(((1.0 - routing_freqs) * torch.log(assignment_probs)).sum()) / (L - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
2876
 
2877
 
2878
  def bce_loss(
2879
- routing_freqs: torch.Tensor,
2880
- assignment_probs: torch.Tensor,
 
2881
  ) -> torch.Tensor:
2882
  """Binary cross-entropy load-balance loss.
2883
 
2884
- Computes -(1/L) * Σ_i [(1 - f_i) * log(p_i) + f_i * log(1 - p_i)], where
2885
- each head is treated as an independent binary target. Unlike CE, BCE maintains
2886
- a repulsion signal from saturated experts: when f_i → 1, the weight on
2887
- log(1 - p_i) drives p_i away from 1, preventing runaway concentration.
 
 
 
2888
 
2889
- log(1 - p_i) is computed as log1p(-p_i) for numerical safety near p_i = 1.
 
 
2890
 
2891
  Args:
2892
- routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2893
- assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2894
- flows to expert_bias through this tensor.
2895
 
2896
  Returns:
2897
  Scalar loss tensor.
2898
  """
2899
- L = routing_freqs.shape[0]
2900
- positive_term = (1.0 - routing_freqs) * torch.log(assignment_probs)
2901
- # log1p(-p) instead of log(1-p): avoids catastrophic cancellation when p is
2902
- # close to 1, where (1 - p) loses precision and log produces large errors.
2903
- negative_term = routing_freqs * torch.log1p(-assignment_probs)
2904
- return -(positive_term + negative_term).sum() / L
 
 
 
 
 
2905
 
2906
 
2907
  # ---------------------------------------------------------------------------
2908
  # Factory
2909
  # ---------------------------------------------------------------------------
2910
 
2911
- _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
2912
  "gshard": gshard_loss,
2913
  "ce": ce_loss,
2914
  "bce": bce_loss,
@@ -2917,15 +3034,19 @@ _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
2917
 
2918
  def make_load_balance_loss(
2919
  loss_type: str,
2920
- ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
2921
  """Return a load-balance loss callable for the requested formulation.
2922
 
2923
- All returned callables share the same external contract:
2924
 
2925
- loss_fn(routing_freqs: Tensor, assignment_probs: Tensor) -> scalar Tensor
 
 
 
 
2926
 
2927
- The caller is responsible for computing assignment_probs via
2928
- softmax(logits.detach() + expert_bias) to ensure gradient isolation.
2929
 
2930
  Args:
2931
  loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
@@ -2944,55 +3065,77 @@ def make_load_balance_loss(
2944
  return _LOSS_REGISTRY[loss_type]
2945
 
2946
 
2947
-
2948
-
2949
  class MoSRAHRouter(nn.Module):
2950
  """Token-choice router for MoSRAH sparse attention.
2951
 
2952
  Each input token independently selects K of the L available expert heads. Both
2953
- selection and routing_probs incorporate expert_bias via two gradient-isolated
2954
- pathways over numerically identical biased values. See module docstring for the
2955
- two-pathway architecture.
2956
-
2957
- The routing projection W_r has no bias term — the paper specifies xW_r with no
2958
- additional projection bias. The only bias-like parameter is expert_bias (b), which
2959
- has an entirely separate role and gradient path.
 
 
 
 
 
 
 
 
 
2960
 
2961
  Args:
2962
- config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
2963
- (L), and ``num_selected_heads`` (K).
 
2964
  """
2965
 
2966
  def __init__(self, config: ShramConfig) -> None:
2967
  super().__init__()
2968
  self.num_mosrah_heads = config.num_mosrah_heads
2969
  self.num_selected_heads = config.num_selected_heads
2970
- self.load_balance_p = config.load_balance_p
2971
  if config.use_cache:
2972
  self.capacity = config.mosrah_cache_length
2973
  else:
2974
  self.capacity = config.mosrah_packed_length
2975
 
2976
  self.max_bid_rounds = config.max_bid_rounds
 
2977
  self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
2978
 
2979
- # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
2980
- self.routing_projection = nn.Linear(
2981
- config.embedding_width, config.num_mosrah_heads, bias=False
 
 
2982
  )
 
2983
 
2984
- # Scalar gate on routing logits. As an nn.Parameter it is exempt from
2985
- # HuggingFace _init_weights, so its near-zero initial value is preserved
2986
- # after from_config construction. Near-zero initialization ensures routing
2987
- # starts near-uniform and expert_bias has leverage over logits from step one.
2988
- self.routing_scale = nn.Parameter(
2989
- torch.full((1,), config.router_init_scale)
2990
  )
2991
-
2992
- # b: learned per-head bias for load balancing. Initialized to zero so that all
2993
- # heads start with equal selection probability. Updated by the main optimizer
2994
- # via gradients from the load balance loss through load_balancing_logits.
2995
- self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
 
 
 
 
 
 
 
 
 
 
 
 
2996
 
2997
  @staticmethod
2998
  def get_best_proposals(
@@ -3228,7 +3371,7 @@ class MoSRAHRouter(nn.Module):
3228
  """Route input tokens to K expert heads each and compute routing probabilities.
3229
 
3230
  Args:
3231
- x: Input hidden states of shape (batch, seq_len, hidden_size).
3232
  active_mask: Current-chunk active mask of shape (batch, seq_len), where
3233
  True means the token is semantically live. Dead tokens do not
3234
  contribute to routing frequencies, load_balance_loss, or max_vio.
@@ -3244,56 +3387,39 @@ class MoSRAHRouter(nn.Module):
3244
  router_diagnostics: Dict of routing feedback scalars. Keys:
3245
  - ``load_balance_loss``: scalar load-balance loss with gradient.
3246
  - ``max_vio``: detached scalar routing-imbalance summary.
3247
- - ``bias_std``: std of expert_bias; near-zero means corrections have not built up.
3248
- - ``raw_logit_std``: mean per-token std of scaled logits; the natural routing scale.
 
 
3249
  - ``logit_std``: mean per-token std of semantic_logits; lower than
3250
- raw_logit_std means bias is flattening preferences (healthy correction).
3251
- - ``bias_alignment``: mean cosine similarity of expert_bias against per-token
3252
- logits. Negative means bias opposes routing direction (healthy correction);
3253
- positive means runaway reinforcement.
3254
  """
3255
  B, N, _ = x.shape
3256
  L = self.num_mosrah_heads
3257
  K = self.num_selected_heads
3258
 
3259
- # Scaled logits. routing_scale is a near-zero nn.Parameter exempt from
3260
- # HuggingFace _init_weights, so routing starts near-uniform and expert_bias
3261
- # has leverage from step one.
3262
- logits = self.routing_projection(x) * self.routing_scale # (B, N, L)
3263
-
3264
- # Two gradient-isolated pathways over numerically identical biased values.
3265
- # semantic_logits: task gradients reach routing_projection; expert_bias isolated.
3266
- # load_balancing_logits: load balance gradients reach expert_bias; routing_projection isolated.
3267
- semantic_logits = logits + self.expert_bias.detach() # (B, N, L)
3268
- load_balancing_logits = logits.detach() + self.expert_bias # (B, N, L)
3269
-
3270
- # Diagnostic scalars characterising the load-balance mechanism. Must be
3271
- # computed here — before balance_capacity injects -1e8 sentinels that
3272
- # would corrupt std and cosine similarity.
3273
- bias_std = self.expert_bias.std().detach()
3274
- raw_logit_std = logits.std(dim=-1).mean().detach()
3275
- logit_std = semantic_logits.std(dim=-1).mean().detach()
3276
- bias_alignment = F.cosine_similarity(
3277
- logits, self.expert_bias.expand_as(logits), dim=-1
3278
- ).mean().detach()
3279
-
3280
- # Assignment probabilities for load balance loss. Computed from load_balancing_logits
3281
- # before balance_capacity so that -1e8 sentinels do not invert the load balance
3282
- # gradient for over-capacity experts. active_float is reused below for routing freqs.
3283
- active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
3284
- lb_softmax = F.softmax(load_balancing_logits, dim=-1) # (B, N, L)
3285
- assignment_probs = (lb_softmax * active_float).sum(dim=(0, 1)) # (L,) unnorm
3286
- assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
3287
 
3288
  # Pre-capacity semantic softmax for gathering routing_probs. Computed before
3289
  # balance_capacity so that gathered probabilities reflect genuine preference
3290
  # magnitudes rather than hard-masked sentinel values.
3291
- routing_scores = F.softmax(semantic_logits, dim=-1) # (B, N, L)
3292
 
3293
  # Capacity-balanced semantic logits for selection. Injects -1e8 into positions
3294
  # that would exceed per-expert token budget, enforcing the packing constraint.
3295
  balanced_semantic_logits = self.balance_capacity(
3296
- semantic_logits,
3297
  used_capacity,
3298
  self.capacity,
3299
  self.num_selected_heads,
@@ -3309,61 +3435,201 @@ class MoSRAHRouter(nn.Module):
3309
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
3310
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
3311
 
3312
- # Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
3313
- # fraction of that item's active K assignments over all tokens go to head l.
3314
- # Dead tokens are excluded before reduction. Normalization is per batch item so
3315
- # each item's frequencies sum to 1 independently of other items in the batch.
3316
  assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
3317
  assignment_mask.scatter_(-1, selected_heads, 1.0)
3318
- active_assignments = assignment_mask * active_mask.unsqueeze(-1)
3319
- per_item_counts = active_assignments.sum(dim=1) # (B, L)
3320
- per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
3321
- per_item_freqs = per_item_counts / per_item_total # (B, L)
3322
-
3323
- # p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
3324
- # p-mean weights aggregation toward the worst-case batch item relative to
3325
- # arithmetic mean, making the load balance signal sensitive to per-item spikes
3326
- # that cause packing overflow.
3327
- p = self.load_balance_p
3328
- routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
3329
 
3330
- load_balance_loss = self._load_balance_loss(routing_freqs, assignment_probs)
 
 
3331
 
3332
- # MaxVio is a detached monitoring scalar following the paper's formula
3333
- # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
3334
- max_vio = self._compute_max_vio(routing_freqs, L)
3335
 
3336
  router_diagnostics = {
3337
  "load_balance_loss": load_balance_loss,
3338
  "max_vio": max_vio,
3339
- "bias_std": bias_std,
3340
- "raw_logit_std": raw_logit_std,
3341
- "logit_std": logit_std,
3342
- "bias_alignment": bias_alignment,
3343
  }
3344
  return selected_heads, routing_probs, router_diagnostics
3345
 
3346
  @staticmethod
3347
- def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3348
  """Compute the MaxVio routing-imbalance scalar.
3349
 
3350
- MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
3351
- head l and 1/L is the perfectly balanced target. Follows the paper's definition
3352
- (Wang et al.) applied to routing_freqs. A value of zero indicates perfect
3353
- balance; a value of 0.5 means the most overloaded head received 50% more routed
3354
- tokens than ideal.
 
3355
 
3356
- The result is detached from the autograd graph — MaxVio is a monitoring scalar
3357
- and must never contribute gradients to any parameter.
3358
 
3359
  Args:
3360
- routing_freqs: Per-head routing frequencies of shape (L,).
3361
- num_heads: Total number of MoSRAH heads L.
 
3362
 
3363
  Returns:
3364
  Detached scalar MaxVio tensor.
3365
  """
3366
- return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
 
 
3367
 
3368
  # -----------
3369
  # Inlined from: positions_converter.py
 
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
47
  from typing import Callable
 
48
 
49
 
50
 
 
171
  num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
172
  Must be > 1.0 to guarantee a buffer larger than the balanced-routing
173
  baseline. Default 2.0.
 
 
 
 
174
  max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity
175
  solver in ``balance_capacity``. 10 covers convergence at approximately
176
  the 98th percentile of routing densities; the top 2% of extreme-density
 
182
  is the default; its log-probability signal scales with violation severity
183
  and makes correction magnitude proportional to routing imbalance.
184
  Default ``"ce"``.
185
+ routing_mode: Routing computation mode. ``"integral"`` (default) enables the
186
+ integral routing extension: the exclusive cumsum of routing logits along
187
+ the sequence dimension is mapped through two additional (L, L) parameter
188
+ matrices (``routing_integral_weight`` A' and ``balance_integral_weight``
189
+ B') and added as corrections to both logit pathways. This gives each
190
+ token a read on the cumulative routing history so far in the sequence.
191
+ ``"default"`` disables the extension; A' and B' are not created.
192
  """
193
 
194
  model_type = "shram"
 
222
  output_hidden_states: bool = False,
223
  tie_word_embeddings: bool = False,
224
  mosrah_overallocation_factor: float = 2.0,
 
225
  max_bid_rounds: int = 10,
226
  load_balance_loss_type: str = "ce",
227
+ routing_mode: str = "integral",
228
  **kwargs
229
  ):
230
  if head_dim % 2 != 0:
 
260
  f"Got {mosrah_overallocation_factor}."
261
  )
262
 
 
 
 
 
 
263
  if max_bid_rounds < 1:
264
  raise ValueError(
265
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
 
272
  f"load_balance_loss_type must be one of {supported}, "
273
  f"got {load_balance_loss_type!r}."
274
  )
 
 
 
275
 
276
+ _supported_routing_modes = {"default", "integral"}
277
+ if routing_mode not in _supported_routing_modes:
278
+ supported = ", ".join(f'"{m}"' for m in sorted(_supported_routing_modes))
279
  raise ValueError(
280
+ f"routing_mode must be one of {supported}, got {routing_mode!r}."
281
  )
282
 
283
  self.vocab_size = vocab_size
 
298
  self.alpha = alpha
299
  self.beta = beta
300
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
 
301
  self.max_bid_rounds = max_bid_rounds
302
  self.load_balance_loss_type = load_balance_loss_type
303
+ self.routing_mode = routing_mode
304
  self.attention_dropout = attention_dropout
305
  self.use_cache = use_cache
306
 
 
2730
  the semantic routing scores at the selected indices and renormalized to sum to 1
2731
  per token.
2732
 
2733
+ Base routing uses two learnable projection matrices and two gradient-isolated pathways:
2734
+
2735
+ - routing_weight (A): shape (L, embedding_width). Maps input to per-head routing
2736
+ scores. Receives gradients from task loss; balance_weight is isolated.
2737
+ - balance_weight (B): shape (L, embedding_width). Maps input to per-head load-balance
2738
+ correction scores. Receives gradients from load_balance_loss; routing_weight is
2739
+ isolated.
2740
+
2741
+ The two gradient-isolated base pathways over numerically identical values:
2742
+
2743
+ - semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
2744
+ balance_weight is isolated from task loss.
2745
+ - load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance gradients
2746
+ reach balance_weight; routing_weight and x are isolated from load balance loss.
2747
+
2748
+ Integral routing extension (routing_mode == "integral"):
2749
+
2750
+ Standard routing is parallel — each token routes based on its own hidden state alone,
2751
+ with no direct read on what earlier tokens in the sequence have already selected.
2752
+ Integral routing adds a cumulative-sum signal that gives each token a view of the
2753
+ prior routing history within the sequence.
2754
+
2755
+ Two additional (L, L) parameter matrices are introduced:
2756
+
2757
+ - routing_integral_weight (A'): shape (L, L). Maps the cumulative logit history to
2758
+ per-head semantic corrections. Receives gradients from task loss.
2759
+ - balance_integral_weight (B'): shape (L, L). Maps the cumulative logit history to
2760
+ per-head load-balance corrections. Receives gradients from load_balance_loss.
2761
+
2762
+ The cumulative history signal u is the exclusive cumsum of the base logits along the
2763
+ sequence dimension: u[n] = sum(logits[0..n-1]), shape (B, N, L). Position 0 receives
2764
+ zeros (no prior history). The same gradient isolation pattern as A/B applies:
2765
+
2766
+ - semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
2767
+ - lb_logits += (A'·u_load).detach() + B'·u_load
2768
 
2769
+ Detaching the full B'·u_semantic result (rather than just B') mirrors the
2770
+ (B·x).detach() pattern in the base pathway and prevents double-counting the
2771
+ cumsum gradient path back to routing_weight.
 
 
2772
 
2773
+ Both base matrices and both integral matrices are nn.Parameter so that HuggingFace
2774
+ _init_weights does not override their construction-time initialization (kaiming for A and B, zeros for A' and B').
 
 
2775
 
2776
  Assignment probabilities are computed before balance_capacity applies -1e8 sentinels.
2777
  Post-capacity softmax would invert the load balance gradient for over-capacity experts
2778
+ (near-zero probability after masking signals "increase corrections" for an already-
2779
+ overloaded expert).
2780
 
2781
  The router also computes and returns the load balance loss via a log-probability auxiliary
2782
  loss (see load_balance_loss.py). The loss formulation is selected by config; the default
 
2785
  The router additionally computes and returns MaxVio, a detached scalar summarising
2786
  routing imbalance for the current forward pass:
2787
 
2788
+ MaxVio = mean_b( L · max_l(f_bl − 1/L) )
2789
 
2790
+ where f_bl is the per-batch-item realised routing frequency of head l and 1/L is the
2791
+ perfectly balanced target. MaxVio is averaged over batch items and is a monitoring
2792
+ quantity only; it never contributes gradients.
2793
 
2794
  Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2795
  """
 
2804
  # -----------
2805
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2806
 
2807
+ This module provides three load-balance loss formulations, two token-reduction
2808
+ helpers, and a factory that selects among the formulations. All formulations
2809
+ share the same external contract:
 
 
2810
 
2811
+ loss_fn(
2812
+ logits: Tensor[B, N, L],
2813
+ assignment_mask: Tensor[B, N, L],
2814
+ active_mask: Tensor[B, N],
2815
+ ) -> scalar Tensor
2816
 
2817
+ logits: Load-balancing logits, shape (B, N, L). These are the raw
2818
+ pre-softmax scores of the router's load-balance pathway,
2819
+ (A·x).detach() + B·(x.detach()) plus any integral terms. Gradient flows only to the balance matrices through this tensor.
2820
+ assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l]
2821
+ is 1.0 if token (b, n) was assigned to head l. Dead tokens
2822
+ should carry zero entries.
2823
+ active_mask: Boolean mask, shape (B, N). True means the token is
2824
+ semantically live.
2825
 
2826
+ Token reduction is split into two helpers with distinct roles:
2827
 
2828
+ reduce_frequency_tokens produces per-batch-item routing frequencies f_bl (B, L).
2829
+ Called by all three formulations. Output is detached; f_bl carries no gradient.
2830
 
2831
+ reduce_probability_tokens produces per-batch-item mean assignment probabilities
2832
+ p_bl (B, L). Called only by gshard and bce. Gradient flows to expert_bias
2833
+ through the internal softmax over logits.
2834
+
2835
+ CE delegates probability computation to F.cross_entropy, which handles its own
2836
+ log_softmax and operates directly on the raw (B, N, L) logits.
2837
+
2838
+ The factory is the intended entry point. MoSRAHRouter constructs the loss callable
2839
+ once at init and invokes it each forward pass.
2840
  """
2841
 
2842
 
2843
 
2844
 
2845
 
2846
+
2847
+ # ---------------------------------------------------------------------------
2848
+ # Token-reduction helpers
2849
+ # ---------------------------------------------------------------------------
2850
+
2851
def reduce_frequency_tokens(
    assignment_mask: torch.Tensor,
    active_mask: torch.Tensor,
) -> torch.Tensor:
    """Collapse per-token head assignments into per-batch-item routing frequencies.

    For batch item b, f_bl[b, l] is the share of that item's live-token
    assignments that landed on head l, so a row sums to 1 whenever the item
    has at least one live assignment.

    Frequencies describe discrete selections, so the result is returned
    detached and never carries gradient.

    A batch item with no live assignments would divide by zero; its total is
    clamped to 1, which leaves its frequency row at all zeros.

    Args:
        assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
        active_mask: Boolean active-token mask, shape (B, N).

    Returns:
        f_bl: Per-batch-item routing frequencies, shape (B, L). Detached.
    """
    live_tokens = active_mask.float()[..., None]  # (B, N, 1)
    live_assignments = assignment_mask * live_tokens  # (B, N, L)

    per_head = live_assignments.sum(dim=1)  # (B, L)
    per_item = live_assignments.sum(dim=(1, 2)).clamp(min=1.0)  # (B,)

    return (per_head / per_item.unsqueeze(-1)).detach()  # (B, L)
2878
+
2879
+
2880
def reduce_probability_tokens(
    logits: torch.Tensor,
    active_mask: torch.Tensor,
) -> torch.Tensor:
    """Average per-token head probabilities over the live tokens of each batch item.

    p_bl[b, l] is the mean softmax probability of head l taken over the active
    tokens of batch item b, so a row sums to 1 whenever the item has at least
    one active token. The softmax is computed here, so gradient flows from the
    result back into ``logits``.

    The active-token count is clamped to 1, so a batch item with no active
    tokens yields an all-zero row instead of a division by zero.

    Args:
        logits: Load-balancing logits, shape (B, N, L). Gradient flows through.
        active_mask: Boolean active-token mask, shape (B, N).

    Returns:
        p_bl: Per-batch-item mean assignment probabilities, shape (B, L).
    """
    live_tokens = active_mask.float()  # (B, N)
    head_probs = F.softmax(logits, dim=-1)  # (B, N, L)

    live_prob_sum = (head_probs * live_tokens.unsqueeze(-1)).sum(dim=1)  # (B, L)
    live_count = live_tokens.sum(dim=1, keepdim=True).clamp(min=1.0)  # (B, 1)

    return live_prob_sum / live_count  # (B, L)
2903
+
2904
+
2905
  # ---------------------------------------------------------------------------
2906
  # Loss functions
2907
  # ---------------------------------------------------------------------------
2908
 
2909
def gshard_loss(
    logits: torch.Tensor,
    assignment_mask: torch.Tensor,
    active_mask: torch.Tensor,
) -> torch.Tensor:
    """GShard-style linear load-balance loss.

    For every batch item the realised routing frequencies f_bl (from
    reduce_frequency_tokens) are multiplied head-wise with the mean assignment
    probabilities p_bl (from reduce_probability_tokens) and summed over heads.
    The per-item sums are averaged over the batch and divided by the number
    of heads L:

        loss = mean_b( Σ_l f_bl · p_bl ) / L

    Args:
        logits: Load-balancing logits, shape (B, N, L).
        assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
        active_mask: Boolean active-token mask, shape (B, N).

    Returns:
        Scalar loss tensor.
    """
    num_heads = logits.size(-1)
    frequencies = reduce_frequency_tokens(assignment_mask, active_mask)  # (B, L)
    probabilities = reduce_probability_tokens(logits, active_mask)  # (B, L)

    per_item = (frequencies * probabilities).sum(dim=-1)  # (B,)
    return per_item.mean() / num_heads
2934
 
2935
 
2936
def ce_loss(
    logits: torch.Tensor,
    assignment_mask: torch.Tensor,
    active_mask: torch.Tensor,
) -> torch.Tensor:
    """Cross-entropy load-balance loss.

    Builds one soft target distribution per batch item from the realised
    routing frequencies and evaluates F.cross_entropy on the raw (B, N, L)
    logits against it.

    The soft target for head l in batch item b is (1 - f_bl) / (L - 1). When
    the frequencies of an item sum to 1 this target also sums to 1
    (Σ_l (1 - f_bl) = L - 1), and heads with a low realised frequency receive
    a larger target than heads with a high one.

    Inactive tokens get an all-zero target row, so they add nothing to the
    summed loss. The sum is divided by the number of active tokens (clamped
    to at least 1) rather than by B * N, so inactive positions do not dilute
    the value.

    With a single head (L == 1) the target denominator L - 1 would be zero and
    the loss would become NaN; the denominator is therefore floored at 1,
    which yields a zero target whenever the single head carries all
    assignments.

    Args:
        logits: Load-balancing logits, shape (B, N, L).
        assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
        active_mask: Boolean active-token mask, shape (B, N).

    Returns:
        Scalar loss tensor.
    """
    _, seq_len, num_heads = logits.shape
    f_bl = reduce_frequency_tokens(assignment_mask, active_mask)  # (B, L)
    live_tokens = active_mask.float()  # (B, N)
    active_count = live_tokens.sum().clamp(min=1.0)

    # Floor the denominator so a one-head configuration cannot produce 0 / 0.
    # For L >= 2 this is exactly (1 - f_bl) / (L - 1).
    target = (1.0 - f_bl) / max(num_heads - 1, 1)  # (B, L)

    # Broadcast the per-item target to every token, then zero inactive rows.
    target_per_token = (
        target.unsqueeze(1).expand(-1, seq_len, -1)  # (B, N, L)
        * live_tokens.unsqueeze(-1)  # zero inactive
    )

    # F.cross_entropy expects the class dimension at dim 1, so move L there:
    # (B, N, L) -> (B, L, N) for both the logits and the soft targets.
    summed = F.cross_entropy(
        logits.permute(0, 2, 1),  # (B, L, N)
        target_per_token.permute(0, 2, 1),  # (B, L, N)
        reduction="sum",
    )
    return summed / active_count
2982
 
2983
 
2984
def bce_loss(
    logits: torch.Tensor,
    assignment_mask: torch.Tensor,
    active_mask: torch.Tensor,
) -> torch.Tensor:
    """Binary cross-entropy load-balance loss.

    Every head is scored as an independent binary problem: the prediction is
    the per-batch-item mean assignment probability p_bl and the label is
    (1 - f_bl), where f_bl is the realised routing frequency. A head that
    already receives most assignments (f_bl near 1) therefore gets a label
    near 0, which pushes its probability down.

    Inactive tokens are already excluded by reduce_frequency_tokens and
    reduce_probability_tokens, so no further masking is applied to the
    (B, L) tensors here.

    NOTE(review): a batch item with no active tokens yields f_bl = 0 and
    p_bl = 0, i.e. label 1 against the clamped probability 1e-7. That adds a
    constant term to the mean; confirm this is acceptable for logging.

    Args:
        logits: Load-balancing logits, shape (B, N, L).
        assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
        active_mask: Boolean active-token mask, shape (B, N).

    Returns:
        Scalar loss tensor.
    """
    frequencies = reduce_frequency_tokens(assignment_mask, active_mask)  # (B, L)
    probabilities = reduce_probability_tokens(logits, active_mask)  # (B, L)

    # Keep predictions strictly inside (0, 1) so both log terms stay bounded,
    # including the all-dead-tokens case where the mean probability is zero.
    lower, upper = 1e-7, 1.0 - 1e-7
    bounded_probabilities = probabilities.clamp(min=lower, max=upper)
    labels = 1.0 - frequencies

    return F.binary_cross_entropy(bounded_probabilities, labels, reduction="mean")
3022
 
3023
 
3024
  # ---------------------------------------------------------------------------
3025
  # Factory
3026
  # ---------------------------------------------------------------------------
3027
 
3028
+ _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]] = {
3029
  "gshard": gshard_loss,
3030
  "ce": ce_loss,
3031
  "bce": bce_loss,
 
3034
 
3035
  def make_load_balance_loss(
3036
  loss_type: str,
3037
+ ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
3038
  """Return a load-balance loss callable for the requested formulation.
3039
 
3040
+ All returned callables share the external contract:
3041
 
3042
+ loss_fn(
3043
+ logits: Tensor[B, N, L],
3044
+ assignment_mask: Tensor[B, N, L],
3045
+ active_mask: Tensor[B, N],
3046
+ ) -> scalar Tensor
3047
 
3048
+ The caller is responsible for computing logits as logits.detach() + expert_bias
3049
+ to ensure gradient isolation to expert_bias.
3050
 
3051
  Args:
3052
  loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
 
3065
  return _LOSS_REGISTRY[loss_type]
3066
 
3067
 
 
 
3068
  class MoSRAHRouter(nn.Module):
3069
  """Token-choice router for MoSRAH sparse attention.
3070
 
3071
  Each input token independently selects K of the L available expert heads. Both
3072
+ selection and routing_probs incorporate balance_weight via two gradient-isolated
3073
+ pathways over numerically identical values. See module docstring for the
3074
+ two-pathway architecture and the integral routing extension.
3075
+
3076
+ All four learnable matrices are nn.Parameter rather than nn.Linear so that
3077
+ HuggingFace _init_weights does not override their construction-time initialization
3078
+ (kaiming for A and B, zeros for A' and B').
3079
+
3080
+ Attributes:
3081
+ routing_weight: A, shape (L, embedding_width). Task-loss pathway.
3082
+ balance_weight: B, shape (L, embedding_width). Load-balance pathway.
3083
+ routing_integral_weight: A', shape (L, L). Integral task-loss pathway.
3084
+ Present only when ``routing_mode == "integral"``.
3085
+ balance_integral_weight: B', shape (L, L). Integral load-balance pathway.
3086
+ Present only when ``routing_mode == "integral"``.
3087
+ routing_mode: ``"integral"`` or ``"default"``, from config.
3088
 
3089
  Args:
3090
+ config: Model configuration. Must expose ``embedding_width``,
3091
+ ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and
3092
+ ``routing_mode``.
3093
  """
3094
 
3095
  def __init__(self, config: ShramConfig) -> None:
3096
  super().__init__()
3097
  self.num_mosrah_heads = config.num_mosrah_heads
3098
  self.num_selected_heads = config.num_selected_heads
 
3099
  if config.use_cache:
3100
  self.capacity = config.mosrah_cache_length
3101
  else:
3102
  self.capacity = config.mosrah_packed_length
3103
 
3104
  self.max_bid_rounds = config.max_bid_rounds
3105
+ self.routing_mode = config.routing_mode
3106
  self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
3107
 
3108
+ # W_r (A): semantic routing matrix. Maps input (B, N, d) to per-head routing
3109
+ # scores (B, N, L) for selection and routing_probs. nn.Parameter ensures
3110
+ # HuggingFace _init_weights does not override kaiming initialization.
3111
+ self.routing_weight = nn.Parameter(
3112
+ torch.empty(config.num_mosrah_heads, config.embedding_width)
3113
  )
3114
+ nn.init.kaiming_uniform_(self.routing_weight)
3115
 
3116
+ # W_b (B): load-balancing projection matrix. Maps input (B, N, d) to per-head
3117
+ # correction scores (B, N, L). Receives gradients only from load_balance_loss.
3118
+ # nn.Parameter ensures HuggingFace _init_weights does not override kaiming init.
3119
+ self.balance_weight = nn.Parameter(
3120
+ torch.empty(config.num_mosrah_heads, config.embedding_width)
 
3121
  )
3122
+ nn.init.kaiming_uniform_(self.balance_weight)
3123
+
3124
+ if self.routing_mode == "integral":
3125
+ L = config.num_mosrah_heads
3126
+ # A': integral semantic matrix. Maps cumulative logit history (B, N, L) to
3127
+ # per-head semantic corrections (B, N, L). Shape (L, L). Receives gradients
3128
+ # from task loss; balance_integral_weight is isolated from task loss.
3129
+ # Zero-initialized so that corrections start at zero and grow from gradient
3130
+ # updates — kaiming init produces corrections that immediately overwhelm the
3131
+ # base routing signal via the cumsum feedback path.
3132
+ self.routing_integral_weight = nn.Parameter(torch.zeros(L, L))
3133
+
3134
+ # B': integral load-balance matrix. Maps cumulative logit history (B, N, L)
3135
+ # to per-head load-balance corrections (B, N, L). Shape (L, L). Receives
3136
+ # gradients from load_balance_loss; routing_integral_weight is isolated.
3137
+ # Zero-initialized for the same reason as routing_integral_weight.
3138
+ self.balance_integral_weight = nn.Parameter(torch.zeros(L, L))
3139
 
3140
  @staticmethod
3141
  def get_best_proposals(
 
3371
  """Route input tokens to K expert heads each and compute routing probabilities.
3372
 
3373
  Args:
3374
+ x: Input hidden states of shape (batch, seq_len, embedding_width).
3375
  active_mask: Current-chunk active mask of shape (batch, seq_len), where
3376
  True means the token is semantically live. Dead tokens do not
3377
  contribute to routing frequencies, load_balance_loss, or max_vio.
 
3387
  router_diagnostics: Dict of routing feedback scalars. Keys:
3388
  - ``load_balance_loss``: scalar load-balance loss with gradient.
3389
  - ``max_vio``: detached scalar routing-imbalance summary.
3390
+ - ``raw_logit_std``: mean per-token std of routing_logits; natural
3391
+ routing preference scale and baseline for interpreting bias_std.
3392
+ - ``bias_std``: mean per-token std of balance_logits; near-zero
3393
+ means balance corrections have not built up relative to routing scale.
3394
  - ``logit_std``: mean per-token std of semantic_logits; lower than
3395
+ raw_logit_std means balance is flattening preferences (healthy correction).
3396
+ - ``bias_alignment``: mean cosine similarity of routing_logits vs
3397
+ balance_logits per token. Negative means balance opposes routing direction
3398
+ (healthy correction); positive means runaway reinforcement.
3399
  """
3400
  B, N, _ = x.shape
3401
  L = self.num_mosrah_heads
3402
  K = self.num_selected_heads
3403
 
3404
+ logits = self._compute_routing_logits(x, active_mask)
3405
+
3406
+ # Diagnostic scalars characterising the two routing pathways. Must be computed
3407
+ # before balance_capacity injects -1e8 sentinels that would corrupt std and
3408
+ # cosine similarity. Extracted to _compute_bias_diagnostics to keep the forward
3409
+ # body free of non-(B,N,L) reduction logic.
3410
+ bias_diagnostics = self._compute_bias_diagnostics(
3411
+ logits["routing_logits"], logits["balance_logits"], logits["semantic_logits"]
3412
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3413
 
3414
  # Pre-capacity semantic softmax for gathering routing_probs. Computed before
3415
  # balance_capacity so that gathered probabilities reflect genuine preference
3416
  # magnitudes rather than hard-masked sentinel values.
3417
+ routing_scores = F.softmax(logits["semantic_logits"], dim=-1) # (B, N, L)
3418
 
3419
  # Capacity-balanced semantic logits for selection. Injects -1e8 into positions
3420
  # that would exceed per-expert token budget, enforcing the packing constraint.
3421
  balanced_semantic_logits = self.balance_capacity(
3422
+ logits["semantic_logits"],
3423
  used_capacity,
3424
  self.capacity,
3425
  self.num_selected_heads,
 
3435
  gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K)
3436
  routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
3437
 
3438
+ # assignment_mask: (B, N, L) float 1.0 at each token's K selected heads, 0 elsewhere.
3439
+ # The discrete routing decision; no gradient flows through it. Passed alongside
3440
+ # load_balancing_logits and active_mask to the loss and max_vio methods, which
3441
+ # own all frequency aggregation and reduction internally.
3442
  assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
3443
  assignment_mask.scatter_(-1, selected_heads, 1.0)
 
 
 
 
 
 
 
 
 
 
 
3444
 
3445
+ load_balance_loss = self._load_balance_loss(
3446
+ logits["load_balancing_logits"], assignment_mask, active_mask
3447
+ )
3448
 
3449
+ # MaxVio: detached monitoring scalar averaged over batch items. Computed from
3450
+ # the same (B, N, L) assignment_mask so frequencies are consistent with the loss.
3451
+ max_vio = self._compute_max_vio(assignment_mask, active_mask, L)
3452
 
3453
  router_diagnostics = {
3454
  "load_balance_loss": load_balance_loss,
3455
  "max_vio": max_vio,
3456
+ **bias_diagnostics,
 
 
 
3457
  }
3458
  return selected_heads, routing_probs, router_diagnostics
3459
 
3460
  @staticmethod
3461
+ def exclusive_cumsum(logits: torch.Tensor) -> torch.Tensor:
3462
+ """Compute the exclusive cumulative sum along the sequence dimension.
3463
+
3464
+ u[n] = sum(logits[0..n-1]): position n receives the accumulated sum of all
3465
+ prior positions, giving it a read on the routing preferences expressed by
3466
+ earlier tokens in the sequence. Position 0 always receives zeros — no prior
3467
+ history exists at the first position.
3468
+
3469
+ Args:
3470
+ logits: Shape (B, N, L). Any per-head score tensor along a sequence.
3471
+
3472
+ Returns:
3473
+ Exclusive cumsum, shape (B, N, L). Same dtype and device as input.
3474
+ """
3475
+ shifted = torch.cat(
3476
+ [torch.zeros_like(logits[:, :1, :]), logits[:, :-1, :]], dim=1
3477
+ )
3478
+ return shifted.cumsum(dim=1)
3479
+
3480
+ def _compute_routing_logits(
3481
+ self, x: torch.Tensor, active_mask: torch.Tensor
3482
+ ) -> dict[str, torch.Tensor]:
3483
+ """Compute the gradient-isolated logit pathways from input hidden states.
3484
+
3485
+ Base pathways (both modes):
3486
+
3487
+ Two gradient-isolated pathways over numerically identical values:
3488
+ - semantic_logits = A·x + (B·x).detach(): task gradients reach routing_weight;
3489
+ balance_weight is isolated from task loss.
3490
+ - load_balancing_logits = (A·x).detach() + B·(x.detach()): load balance
3491
+ gradients reach balance_weight; routing_weight and x are isolated.
3492
+
3493
+ Integral extension (routing_mode == "integral"):
3494
+
3495
+ Dead tokens are zeroed out of the logits before computing the cumsum, so
3496
+ inactive positions do not contribute to the routing history of downstream
3497
+ live tokens. u_semantic and u_load therefore represent history from live
3498
+ tokens only.
3499
+
3500
+ u_semantic = exclusive_cumsum(semantic_logits * active_mask) — (B, N, L)
3501
+ u_load = exclusive_cumsum(load_balancing_logits * active_mask) — (B, N, L)
3502
+
3503
+ semantic_logits += A'·u_semantic + (B'·u_semantic).detach()
3504
+ load_balancing_logits += (A'·u_load).detach() + B'·u_load
3505
+
3506
+ Detaching the full (B'·u_semantic) result mirrors the (B·x).detach() base
3507
+ pattern: it isolates balance_integral_weight from task loss AND prevents
3508
+ double-counting the cumsum gradient path back to routing_weight.
3509
+ The same reasoning applies to (A'·u_load).detach() in the load-balance
3510
+ pathway — u_load already has no path to routing_weight (routing_logits is
3511
+ detached in load_balancing_logits), and the detach additionally blocks
3512
+ routing_integral_weight.
3513
+
3514
+ Args:
3515
+ x: Input hidden states, shape (batch, seq_len, embedding_width).
3516
+ active_mask: Boolean active-token mask, shape (batch, seq_len). Dead tokens
3517
+ are excluded from the cumsum history in integral mode.
3518
+
3519
+ Returns:
3520
+ Dict with keys:
3521
+ - ``routing_logits``: A·x, shape (B, N, L).
3522
+ - ``balance_logits``: B·x, shape (B, N, L).
3523
+ - ``semantic_logits``: combined task-loss pathway, shape (B, N, L).
3524
+ - ``load_balancing_logits``: combined load-balance pathway, shape (B, N, L).
3525
+ """
3526
+ routing_logits = F.linear(x, self.routing_weight) # (B, N, L)
3527
+ balance_logits = F.linear(x, self.balance_weight) # (B, N, L)
3528
+ semantic_logits = routing_logits + balance_logits.detach()
3529
+ load_balancing_logits = routing_logits.detach() + F.linear(x.detach(), self.balance_weight)
3530
+
3531
+ if self.routing_mode == "integral":
3532
+ # Zero out dead token positions before cumsum so inactive tokens do not
3533
+ # contaminate the routing history of subsequent live tokens.
3534
+ live = active_mask.unsqueeze(-1) # (B, N, 1)
3535
+ u_semantic = self.exclusive_cumsum(semantic_logits * live) # (B, N, L)
3536
+ u_load = self.exclusive_cumsum(load_balancing_logits * live) # (B, N, L)
3537
+
3538
+ # Semantic pathway: A' trains on task loss; B' term is fully detached to
3539
+ # isolate balance_integral_weight from task loss and prevent double-counting
3540
+ # the cumsum gradient path back to routing_weight.
3541
+ semantic_logits = (
3542
+ semantic_logits
3543
+ + F.linear(u_semantic, self.routing_integral_weight)
3544
+ + F.linear(u_semantic, self.balance_integral_weight).detach()
3545
+ )
3546
+
3547
+ # Load-balance pathway: B' trains on load_balance_loss; A' term is fully
3548
+ # detached to isolate routing_integral_weight from load_balance_loss.
3549
+ load_balancing_logits = (
3550
+ load_balancing_logits
3551
+ + F.linear(u_load, self.routing_integral_weight).detach()
3552
+ + F.linear(u_load, self.balance_integral_weight)
3553
+ )
3554
+
3555
+ return {
3556
+ "routing_logits": routing_logits,
3557
+ "balance_logits": balance_logits,
3558
+ "semantic_logits": semantic_logits,
3559
+ "load_balancing_logits": load_balancing_logits,
3560
+ }
3561
+
3562
+ @staticmethod
3563
+ def _compute_bias_diagnostics(
3564
+ routing_logits: torch.Tensor,
3565
+ balance_logits: torch.Tensor,
3566
+ semantic_logits: torch.Tensor,
3567
+ ) -> dict[str, torch.Tensor]:
3568
+ """Compute detached diagnostic scalars characterising the two routing pathways.
3569
+
3570
+ All scalars must be computed from pre-capacity logits; balance_capacity
3571
+ applies -1e8 sentinels that would corrupt std and cosine similarity.
3572
+ Extracted from forward to keep the main body free of reduction logic.
3573
+
3574
+ Args:
3575
+ routing_logits: A·x, routing pathway output, shape (B, N, L).
3576
+ balance_logits: B·x, balance pathway output, shape (B, N, L).
3577
+ semantic_logits: A·x + (B·x).detach(), combined signal, shape (B, N, L).
3578
+
3579
+ Returns:
3580
+ Dict with keys:
3581
+ - ``raw_logit_std``: Mean per-token std of routing_logits. Natural
3582
+ routing preference scale; reference baseline for
3583
+ interpreting bias_std.
3584
+ - ``bias_std``: Mean per-token std of balance_logits. Near-zero
3585
+ means balance corrections have not built up
3586
+ relative to the routing scale.
3587
+ - ``logit_std``: Mean per-token std of semantic_logits. Lower than
3588
+ raw_logit_std indicates balance is flattening
3589
+ preferences (healthy correction signal).
3590
+ - ``bias_alignment``: Mean cosine similarity of routing_logits vs
3591
+ balance_logits per token. Range [-1, 1]. Negative
3592
+ means balance opposes routing direction (healthy
3593
+ correction); positive means runaway reinforcement.
3594
+ """
3595
+ return {
3596
+ "raw_logit_std": routing_logits.std(dim=-1).mean().detach(),
3597
+ "bias_std": balance_logits.std(dim=-1).mean().detach(),
3598
+ "logit_std": semantic_logits.std(dim=-1).mean().detach(),
3599
+ "bias_alignment": F.cosine_similarity(
3600
+ routing_logits, balance_logits, dim=-1
3601
+ ).mean().detach(),
3602
+ }
3603
+
3604
+ @staticmethod
3605
+ def _compute_max_vio(
3606
+ assignment_mask: torch.Tensor,
3607
+ active_mask: torch.Tensor,
3608
+ num_heads: int,
3609
+ ) -> torch.Tensor:
3610
  """Compute the MaxVio routing-imbalance scalar.
3611
 
3612
+ MaxVio = mean_b( L · max_l(f_bl − 1/L) ), where f_bl is the per-batch-item
3613
+ realised routing frequency of head l. Uses reduce_frequency_tokens for consistent
3614
+ per-batch-item frequency computation with dead tokens excluded, matching how the
3615
+ load balance loss computes frequencies. A value of zero indicates perfect balance;
3616
+ a value of 0.5 means the most overloaded head in the average batch item received
3617
+ 50% more routed tokens than ideal.
3618
 
3619
+ The result is detached — MaxVio is a monitoring scalar and must not contribute
3620
+ gradients to any parameter.
3621
 
3622
  Args:
3623
+ assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
3624
+ active_mask: Boolean active-token mask, shape (B, N).
3625
+ num_heads: Total number of MoSRAH heads L.
3626
 
3627
  Returns:
3628
  Detached scalar MaxVio tensor.
3629
  """
3630
+ f_bl = reduce_frequency_tokens(assignment_mask, active_mask) # (B, L)
3631
+ per_item_max_vio = num_heads * (f_bl - 1.0 / num_heads).max(dim=-1).values # (B,)
3632
+ return per_item_max_vio.mean().detach()
3633
 
3634
  # -----------
3635
  # Inlined from: positions_converter.py