smithblack-0 committed on
Commit
dea5927
·
verified ·
1 Parent(s): 46c9e9f

Update architecture and tokenizer

Browse files
Files changed (5) hide show
  1. README.md +2 -1
  2. config.json +2 -1
  3. configuration.py +19 -1
  4. huggingface.py +176 -79
  5. tokenizer_config.json +1 -1
README.md CHANGED
@@ -82,7 +82,8 @@ contains no weights. All values are overridable via kwargs.
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
- | `load_balance_p` | 2.0 |
 
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
88
  | `mlp_width` | 1366 |
 
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
+ | `load_balance_loss_type` | ce |
86
+ | `load_balance_p` | 1.0 |
87
  | `local_rope_theta` | 10000.0 |
88
  | `max_bid_rounds` | 10 |
89
  | `mlp_width` | 1366 |
config.json CHANGED
@@ -9,7 +9,8 @@
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
- "load_balance_p": 2.0,
 
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
15
  "mlp_width": 1366,
 
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
+ "load_balance_loss_type": "ce",
13
+ "load_balance_p": 1.0,
14
  "local_rope_theta": 10000.0,
15
  "max_bid_rounds": 10,
16
  "mlp_width": 1366,
configuration.py CHANGED
@@ -94,6 +94,11 @@ class ShramConfig(PretrainedConfig):
94
  cases are not expected under normal training. The bound exists as a
95
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
96
  Default 10.
 
 
 
 
 
97
  """
98
 
99
  model_type = "shram"
@@ -127,8 +132,9 @@ class ShramConfig(PretrainedConfig):
127
  output_hidden_states: bool = False,
128
  tie_word_embeddings: bool = False,
129
  mosrah_overallocation_factor: float = 2.0,
130
- load_balance_p: float = 2.0,
131
  max_bid_rounds: int = 10,
 
132
  **kwargs
133
  ):
134
  if head_dim % 2 != 0:
@@ -174,6 +180,17 @@ class ShramConfig(PretrainedConfig):
174
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
175
  )
176
 
 
 
 
 
 
 
 
 
 
 
 
177
  self.vocab_size = vocab_size
178
  self.embedding_width = embedding_width
179
  self.mlp_width = mlp_width
@@ -194,6 +211,7 @@ class ShramConfig(PretrainedConfig):
194
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
195
  self.load_balance_p = load_balance_p
196
  self.max_bid_rounds = max_bid_rounds
 
197
  self.attention_dropout = attention_dropout
198
  self.use_cache = use_cache
199
 
 
94
  cases are not expected under normal training. The bound exists as a
95
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
96
  Default 10.
97
+ load_balance_loss_type: Formula used for the load-balance auxiliary loss.
98
+ One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
99
+ is the default; its log-probability signal scales with violation severity
100
+ and makes correction magnitude proportional to routing imbalance.
101
+ Default ``"ce"``. ``"ce"`` requires ``load_balance_p == 1.0``.
102
  """
103
 
104
  model_type = "shram"
 
132
  output_hidden_states: bool = False,
133
  tie_word_embeddings: bool = False,
134
  mosrah_overallocation_factor: float = 2.0,
135
+ load_balance_p: float = 1.0,
136
  max_bid_rounds: int = 10,
137
+ load_balance_loss_type: str = "ce",
138
  **kwargs
139
  ):
140
  if head_dim % 2 != 0:
 
180
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
181
  )
182
 
183
+ _supported_loss_types = {"gshard", "ce", "bce"}
184
+ if load_balance_loss_type not in _supported_loss_types:
185
+ supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
186
+ raise ValueError(
187
+ f"load_balance_loss_type must be one of {supported}, "
188
+ f"got {load_balance_loss_type!r}."
189
+ )
190
+ if load_balance_loss_type == "ce" and load_balance_p != 1.0:
191
+ raise ValueError("In cross entropy mode, aggregation of "
192
+ "frequencies must be with mean 1.0")
193
+
194
  self.vocab_size = vocab_size
195
  self.embedding_width = embedding_width
196
  self.mlp_width = mlp_width
 
211
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
212
  self.load_balance_p = load_balance_p
213
  self.max_bid_rounds = max_bid_rounds
214
+ self.load_balance_loss_type = load_balance_loss_type
215
  self.attention_dropout = attention_dropout
216
  self.use_cache = use_cache
217
 
huggingface.py CHANGED
@@ -44,6 +44,7 @@ from torch import nn
44
  from torch.nn.attention.flex_attention import create_block_mask
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
 
47
  from typing import Optional
48
 
49
 
@@ -181,6 +182,11 @@ class ShramConfig(PretrainedConfig):
181
  cases are not expected under normal training. The bound exists as a
182
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
183
  Default 10.
 
 
 
 
 
184
  """
185
 
186
  model_type = "shram"
@@ -214,8 +220,9 @@ class ShramConfig(PretrainedConfig):
214
  output_hidden_states: bool = False,
215
  tie_word_embeddings: bool = False,
216
  mosrah_overallocation_factor: float = 2.0,
217
- load_balance_p: float = 2.0,
218
  max_bid_rounds: int = 10,
 
219
  **kwargs
220
  ):
221
  if head_dim % 2 != 0:
@@ -261,6 +268,17 @@ class ShramConfig(PretrainedConfig):
261
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
262
  )
263
 
 
 
 
 
 
 
 
 
 
 
 
264
  self.vocab_size = vocab_size
265
  self.embedding_width = embedding_width
266
  self.mlp_width = mlp_width
@@ -281,6 +299,7 @@ class ShramConfig(PretrainedConfig):
281
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
282
  self.load_balance_p = load_balance_p
283
  self.max_bid_rounds = max_bid_rounds
 
284
  self.attention_dropout = attention_dropout
285
  self.use_cache = use_cache
286
 
@@ -2714,9 +2733,10 @@ This separation is architecturally critical: expert_bias drives selection (and t
2714
  balancing) but does not corrupt the gradient path from the output through routing_probs
2715
  back to the routing projection weights.
2716
 
2717
- The router also computes and returns the load balance loss via the LoadBalanceLoss custom
2718
- autograd operator (see load_balance_loss.py). This loss is a scalar that the training
2719
- loop can weight and add to the language modeling loss.
 
2720
 
2721
  The router additionally computes and returns MaxVio, a detached scalar summarising
2722
  routing imbalance for the current forward pass:
@@ -2737,94 +2757,165 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2737
  # -----------
2738
  # Inlined from: load_balance_loss.py
2739
  # -----------
2740
- """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2741
-
2742
- This module implements the custom autograd Function H(b, f) described in the paper's
2743
- Implementation Concerns section. The operator bridges two requirements that are in
2744
- tension: it must behave like a standard auxiliary loss (scalar output, scalable via
2745
- multiplication) so that existing training loops remain compatible, while simultaneously
2746
- implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
2747
- path through the router weights.
2748
-
2749
- The resolution is a custom backward pass. The forward emits the load balance imbalance
2750
- as a scalar loss. The backward, instead of differentiating that scalar with respect to
2751
- its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
2752
- then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
2753
- correction without a standalone SGD update step.
2754
-
2755
- Paper ref: Appendix A.Implementation Concerns.
 
 
 
 
 
 
 
 
2756
  """
2757
 
2758
 
2759
 
2760
 
2761
- class LoadBalanceLoss(torch.autograd.Function):
2762
- """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
2763
 
2764
- Forward computes the load balance imbalance:
 
 
2765
 
2766
- L_load_balance = H(b, f) = sum_l | f_l - 1/L |
 
 
 
 
2767
 
2768
- Backward emits a bias-correction gradient to expert_bias:
 
 
2769
 
2770
- grad_b = L_grad * sign(f_l - 1/L)
 
 
 
2771
 
2772
- expert_bias (b) is included as a forward input so PyTorch registers it as a node
2773
- in the computation graph and routes gradients through it. routing_freqs (f) receives
2774
- no gradient — its origin is the discrete TopK operation which has no gradient, so
2775
- defining a gradient for f here would be mathematically incorrect.
2776
 
2777
- Paper ref: Appendix A.Implementation Concerns.
 
2778
  """
 
 
2779
 
2780
- @staticmethod
2781
- def forward(
2782
- ctx: torch.autograd.function.FunctionCtx,
2783
- expert_bias: torch.Tensor,
2784
- routing_freqs: torch.Tensor,
2785
- ) -> torch.Tensor:
2786
- """Compute the load balance loss.
2787
 
2788
- Args:
2789
- ctx: Autograd context for saving state needed in backward.
2790
- expert_bias: Learned per-head bias b, shape (L,). Included as an input so
2791
- PyTorch tracks it as a computation graph node needing a gradient.
2792
- routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
2793
- from the discrete TopK selection — not differentiable.
2794
 
2795
- Returns:
2796
- Scalar loss equal to sum_l |f_l - 1/L|.
2797
- """
2798
- L = expert_bias.shape[0]
2799
- # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
2800
- # underloaded. Saved for backward where sign(imbalance) determines the direction
2801
- # of the bias-correction update.
2802
- imbalance = routing_freqs - 1.0 / L
2803
- ctx.save_for_backward(imbalance)
2804
- return imbalance.abs().sum()
2805
 
2806
- @staticmethod
2807
- def backward(
2808
- ctx: torch.autograd.function.FunctionCtx,
2809
- grad_output: torch.Tensor,
2810
- ) -> tuple[torch.Tensor, None]:
2811
- """Emit the DeepSeek-style bias-correction gradient.
2812
 
2813
- Args:
2814
- ctx: Autograd context carrying imbalance saved in forward.
2815
- grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
2816
- by the training loop arrives here and is propagated to grad_b, so the
2817
- correction magnitude is proportional to the loss weight chosen by the
2818
- consumer.
2819
 
2820
- Returns:
2821
- Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
2822
- None for routing_freqs: no gradient is defined for the discrete routing
2823
- frequency.
2824
- """
2825
- (imbalance,) = ctx.saved_tensors
2826
- grad_expert_bias = grad_output * imbalance.sign()
2827
- return grad_expert_bias, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2828
 
2829
 
2830
 
@@ -2857,6 +2948,7 @@ class MoSRAHRouter(nn.Module):
2857
  self.capacity = config.mosrah_packed_length
2858
 
2859
  self.max_bid_rounds = config.max_bid_rounds
 
2860
 
2861
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
2862
  self.routing_projection = nn.Linear(
@@ -3143,12 +3235,13 @@ class MoSRAHRouter(nn.Module):
3143
  logits, self.expert_bias.expand_as(logits), dim=-1
3144
  ).mean().detach()
3145
 
 
3146
  routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
3147
 
3148
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
3149
  # selection. expert_bias is added to logits before softmax so that the bias
3150
  # shifts selection probability without rescaling the unbiased distribution.
3151
- biased_logits = logits + self.expert_bias
3152
  biased_logits = self.balance_capacity(
3153
  biased_logits,
3154
  used_capacity,
@@ -3189,10 +3282,14 @@ class MoSRAHRouter(nn.Module):
3189
  p = self.load_balance_p
3190
  routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
3191
 
3192
- # Load balance loss via custom autograd. expert_bias is an input so PyTorch
3193
- # registers it as a graph node; the custom backward writes the DeepSeek-style
3194
- # correction gradient to expert_bias.grad for the optimizer to consume.
3195
- load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
 
 
 
 
3196
 
3197
  # MaxVio is a detached monitoring scalar following the paper's formula
3198
  # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
 
44
  from torch.nn.attention.flex_attention import create_block_mask
45
  from torch.nn.attention.flex_attention import flex_attention
46
  import torch.nn.functional as F
47
+ from typing import Callable
48
  from typing import Optional
49
 
50
 
 
182
  cases are not expected under normal training. The bound exists as a
183
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
184
  Default 10.
185
+ load_balance_loss_type: Formula used for the load-balance auxiliary loss.
186
+ One of ``"gshard"``, ``"ce"``, or ``"bce"``. ``"ce"`` (cross-entropy)
187
+ is the default; its log-probability signal scales with violation severity
188
+ and makes correction magnitude proportional to routing imbalance.
189
+ Default ``"ce"``. ``"ce"`` requires ``load_balance_p == 1.0``.
190
  """
191
 
192
  model_type = "shram"
 
220
  output_hidden_states: bool = False,
221
  tie_word_embeddings: bool = False,
222
  mosrah_overallocation_factor: float = 2.0,
223
+ load_balance_p: float = 1.0,
224
  max_bid_rounds: int = 10,
225
+ load_balance_loss_type: str = "ce",
226
  **kwargs
227
  ):
228
  if head_dim % 2 != 0:
 
268
  f"max_bid_rounds must be at least 1, got {max_bid_rounds}."
269
  )
270
 
271
+ _supported_loss_types = {"gshard", "ce", "bce"}
272
+ if load_balance_loss_type not in _supported_loss_types:
273
+ supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
274
+ raise ValueError(
275
+ f"load_balance_loss_type must be one of {supported}, "
276
+ f"got {load_balance_loss_type!r}."
277
+ )
278
+ if load_balance_loss_type == "ce" and load_balance_p != 1.0:
279
+ raise ValueError("In cross entropy mode, aggregation of "
280
+ "frequencies must be with mean 1.0")
281
+
282
  self.vocab_size = vocab_size
283
  self.embedding_width = embedding_width
284
  self.mlp_width = mlp_width
 
299
  self.mosrah_overallocation_factor = mosrah_overallocation_factor
300
  self.load_balance_p = load_balance_p
301
  self.max_bid_rounds = max_bid_rounds
302
+ self.load_balance_loss_type = load_balance_loss_type
303
  self.attention_dropout = attention_dropout
304
  self.use_cache = use_cache
305
 
 
2733
  balancing) but does not corrupt the gradient path from the output through routing_probs
2734
  back to the routing projection weights.
2735
 
2736
+ The router also computes and returns the load balance loss via a log-probability auxiliary
2737
+ loss (see load_balance_loss.py). The loss formulation is selected by config; the default
2738
+ is cross-entropy. Gradients flow only to expert_bias — routing_projection.weight is
2739
+ isolated by detaching logits before computing assignment probabilities.
2740
 
2741
  The router additionally computes and returns MaxVio, a detached scalar summarising
2742
  routing imbalance for the current forward pass:
 
2757
  # -----------
2758
  # Inlined from: load_balance_loss.py
2759
  # -----------
2760
+ """Log-probability auxiliary loss functions for MoSRAH load balancing.
2761
+
2762
+ This module provides three load-balance loss formulations and a factory that selects
2763
+ among them. All formulations share the same external contract and the same gradient
2764
+ isolation property: assignment probabilities are computed from detached logits plus
2765
+ expert_bias, so only expert_bias receives gradients from the loss signal. The routing
2766
+ projection weights are not reachable from any returned loss.
2767
+
2768
+ The factory is the intended entry point. The caller (MoSRAHRouter) constructs the
2769
+ loss callable once at init and invokes it each forward pass.
2770
+
2771
+ Log-probability formulations (ce, bce) are preferred over linear ones (gshard) because
2772
+ their gradient magnitude scales with how far the distribution deviates from the target.
2773
+ A linear signal can be outrun by routing concentrations that diverge nonlinearly; a
2774
+ log-probability signal cannot.
2775
+
2776
+ The external contract for all returned callables is:
2777
+
2778
+ loss_fn(routing_freqs, assignment_probs) -> scalar Tensor
2779
+
2780
+ routing_freqs: (L,) realized routing frequencies f_i, detached.
2781
+ assignment_probs: (L,) soft assignment probabilities p_i with gradient through
2782
+ expert_bias. Caller must compute these via
2783
+ softmax(logits.detach() + expert_bias) to preserve isolation.
2784
  """
2785
 
2786
 
2787
 
2788
 
 
 
2789
 
2790
+ # ---------------------------------------------------------------------------
2791
+ # Loss functions
2792
+ # ---------------------------------------------------------------------------
2793
 
2794
+ def gshard_loss(
2795
+ routing_freqs: torch.Tensor,
2796
+ assignment_probs: torch.Tensor,
2797
+ ) -> torch.Tensor:
2798
+ """GShard-style linear load-balance loss.
2799
 
2800
+ Computes (1/L) * Σ_i f_i * p_i, where L is the number of expert heads,
2801
+ f_i is the realized routing frequency for head i, and p_i is the soft
2802
+ assignment probability for head i.
2803
 
2804
+ The fixed point of this loss under gradient descent is uniform routing:
2805
+ when p_i = 1/L for all i, the loss is minimized at 1/L (independent of f_i).
2806
+ The linear signal is the weakest of the three formulations — gradient magnitude
2807
+ does not grow with deviation from the target. Provided for comparison.
2808
 
2809
+ Args:
2810
+ routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2811
+ assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2812
+ flows to expert_bias through this tensor.
2813
 
2814
+ Returns:
2815
+ Scalar loss tensor.
2816
  """
2817
+ L = routing_freqs.shape[0]
2818
+ return (routing_freqs * assignment_probs).sum() / L
2819
 
 
 
 
 
 
 
 
2820
 
2821
+ def ce_loss(
2822
+ routing_freqs: torch.Tensor,
2823
+ assignment_probs: torch.Tensor,
2824
+ ) -> torch.Tensor:
2825
+ """Cross-entropy load-balance loss.
 
2826
 
2827
+ Computes -(1/(L-1)) * Σ_i (1 - f_i) * log(p_i), where the weight (1 - f_i)
2828
+ suppresses the signal for overloaded heads (high f_i → weight near zero) and
2829
+ amplifies it for underloaded heads (low f_i → weight near 1). This makes the
2830
+ loss push probability mass toward under-utilized experts.
 
 
 
 
 
 
2831
 
2832
+ The (1/(L-1)) normalization makes the coefficient interpretable as a controller
2833
+ strength independent of expert count. The log-probability signal grows as p_i
2834
+ deviates from the target, providing correction that scales with violation severity.
 
 
 
2835
 
2836
+ Args:
2837
+ routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2838
+ assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2839
+ flows to expert_bias through this tensor.
 
 
2840
 
2841
+ Returns:
2842
+ Scalar loss tensor.
2843
+ """
2844
+ L = routing_freqs.shape[0]
2845
+ # Numerical stability: torch.log is safe here because softmax outputs are
2846
+ # strictly positive. The (1 - f_i) weight goes to zero exactly when f_i = 1,
2847
+ # which can only occur with a single head, so the 0 * (-inf) degenerate case
2848
+ # does not arise in practice.
2849
+ return -(((1.0 - routing_freqs) * torch.log(assignment_probs)).sum()) / (L - 1)
2850
+
2851
+
2852
+ def bce_loss(
2853
+ routing_freqs: torch.Tensor,
2854
+ assignment_probs: torch.Tensor,
2855
+ ) -> torch.Tensor:
2856
+ """Binary cross-entropy load-balance loss.
2857
+
2858
+ Computes -(1/L) * Σ_i [(1 - f_i) * log(p_i) + f_i * log(1 - p_i)], where
2859
+ each head is treated as an independent binary target. Unlike CE, BCE maintains
2860
+ a repulsion signal from saturated experts: when f_i → 1, the weight on
2861
+ log(1 - p_i) drives p_i away from 1, preventing runaway concentration.
2862
+
2863
+ log(1 - p_i) is computed as log1p(-p_i) for numerical safety near p_i = 1.
2864
+
2865
+ Args:
2866
+ routing_freqs: Realized routing frequencies f_i, shape (L,). Detached.
2867
+ assignment_probs: Soft assignment probabilities p_i, shape (L,). Gradient
2868
+ flows to expert_bias through this tensor.
2869
+
2870
+ Returns:
2871
+ Scalar loss tensor.
2872
+ """
2873
+ L = routing_freqs.shape[0]
2874
+ positive_term = (1.0 - routing_freqs) * torch.log(assignment_probs)
2875
+ # log1p(-p) instead of log(1-p): avoids catastrophic cancellation when p is
2876
+ # close to 1, where (1 - p) loses precision and log produces large errors.
2877
+ negative_term = routing_freqs * torch.log1p(-assignment_probs)
2878
+ return -(positive_term + negative_term).sum() / L
2879
+
2880
+
2881
+ # ---------------------------------------------------------------------------
2882
+ # Factory
2883
+ # ---------------------------------------------------------------------------
2884
+
2885
+ _LOSS_REGISTRY: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
2886
+ "gshard": gshard_loss,
2887
+ "ce": ce_loss,
2888
+ "bce": bce_loss,
2889
+ }
2890
+
2891
+
2892
+ def make_load_balance_loss(
2893
+ loss_type: str,
2894
+ ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
2895
+ """Return a load-balance loss callable for the requested formulation.
2896
+
2897
+ All returned callables share the same external contract:
2898
+
2899
+ loss_fn(routing_freqs: Tensor, assignment_probs: Tensor) -> scalar Tensor
2900
+
2901
+ The caller is responsible for computing assignment_probs via
2902
+ softmax(logits.detach() + expert_bias) to ensure gradient isolation.
2903
+
2904
+ Args:
2905
+ loss_type: One of ``"gshard"``, ``"ce"``, or ``"bce"``.
2906
+
2907
+ Returns:
2908
+ Loss callable matching the shared contract.
2909
+
2910
+ Raises:
2911
+ ValueError: If loss_type is not one of the supported values.
2912
+ """
2913
+ if loss_type not in _LOSS_REGISTRY:
2914
+ supported = ", ".join(f'"{k}"' for k in _LOSS_REGISTRY)
2915
+ raise ValueError(
2916
+ f"load_balance_loss_type must be one of {supported}, got {loss_type!r}."
2917
+ )
2918
+ return _LOSS_REGISTRY[loss_type]
2919
 
2920
 
2921
 
 
2948
  self.capacity = config.mosrah_packed_length
2949
 
2950
  self.max_bid_rounds = config.max_bid_rounds
2951
+ self._load_balance_loss = make_load_balance_loss(config.load_balance_loss_type)
2952
 
2953
  # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
2954
  self.routing_projection = nn.Linear(
 
3235
  logits, self.expert_bias.expand_as(logits), dim=-1
3236
  ).mean().detach()
3237
 
3238
+ # Unbiased routing scores R: softmax over logits, with no expert_bias added.
3239
  routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
3240
 
3241
  # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
3242
  # selection. expert_bias is added to logits before softmax so that the bias
3243
  # shifts selection probability without rescaling the unbiased distribution.
3244
+ biased_logits = logits.detach() + self.expert_bias
3245
  biased_logits = self.balance_capacity(
3246
  biased_logits,
3247
  used_capacity,
 
3282
  p = self.load_balance_p
3283
  routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
3284
 
3285
+ # Active-token mean softmax probabilities. Detaching logits before softmax
3286
+ # ensures the only differentiable path into p is through expert_bias — the
3287
+ # load balance loss cannot reach routing_projection.weight.
3288
+ biased_probs = biased_routing_scores # (B, N, L)
3289
+ active_float = active_mask.float().unsqueeze(-1) # (B, N, 1)
3290
+ assignment_probs = (biased_probs * active_float).sum(dim=(0, 1)) # (L,) unnorm
3291
+ assignment_probs = assignment_probs / active_mask.float().sum() # (L,) norm
3292
+ load_balance_loss = self._load_balance_loss(routing_freqs, assignment_probs)
3293
 
3294
  # MaxVio is a detached monitoring scalar following the paper's formula
3295
  # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
tokenizer_config.json CHANGED
@@ -4,7 +4,7 @@
4
  "bos_token": "<|endoftext|>",
5
  "eos_token": "<|endoftext|>",
6
  "errors": "replace",
7
- "is_local": false,
8
  "local_files_only": false,
9
  "model_max_length": 1000000000000000019884624838656,
10
  "pad_token": "<|padding|>",
 
4
  "bos_token": "<|endoftext|>",
5
  "eos_token": "<|endoftext|>",
6
  "errors": "replace",
7
+ "is_local": true,
8
  "local_files_only": false,
9
  "model_max_length": 1000000000000000019884624838656,
10
  "pad_token": "<|padding|>",