smithblack-0 committed on
Commit
e56b8fd
·
verified ·
1 Parent(s): f192edb

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. config.json +2 -2
  3. configuration.py +13 -11
  4. huggingface.py +171 -15
README.md CHANGED
@@ -82,7 +82,7 @@ contains no weights. All values are overridable via kwargs.
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
- | `load_balance_loss_type` | temporal_overcapacity |
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
88
  | `maximum_expert_overclaim` | 20 |
 
82
  | `embedding_width` | 512 |
83
  | `head_dim` | 16 |
84
  | `inference_sequence_length` | 1024 |
85
+ | `load_balance_loss_type` | causal_overcapacity |
86
  | `local_rope_theta` | 10000.0 |
87
  | `max_bid_rounds` | 10 |
88
  | `maximum_expert_overclaim` | 20 |
config.json CHANGED
@@ -9,7 +9,7 @@
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
- "load_balance_loss_type": "temporal_overcapacity",
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
15
  "maximum_expert_overclaim": 20,
@@ -25,7 +25,7 @@
25
  "rope_mode": "main_sequence",
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
- "transformers_version": "5.10.2",
29
  "use_cache": true,
30
  "vocab_size": 50277,
31
  "window_size": 128
 
9
  "embedding_width": 512,
10
  "head_dim": 16,
11
  "inference_sequence_length": 1024,
12
+ "load_balance_loss_type": "causal_overcapacity",
13
  "local_rope_theta": 10000.0,
14
  "max_bid_rounds": 10,
15
  "maximum_expert_overclaim": 20,
 
25
  "rope_mode": "main_sequence",
26
  "tie_word_embeddings": false,
27
  "training_sequence_length": 1024,
28
+ "transformers_version": "5.11.0",
29
  "use_cache": true,
30
  "vocab_size": 50277,
31
  "window_size": 128
configuration.py CHANGED
@@ -91,17 +91,19 @@ class ShramConfig(PretrainedConfig):
91
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
92
  Default 10.
93
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
94
- One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
95
- ``"temporal_overcapacity"`` is the default; it fires only when an expert
96
- exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
97
- and shuts off automatically once routing is balanced, allowing it to be
98
- used with a strong weight without interfering with task training during
99
- balanced routing. Default ``"temporal_overcapacity"``.
 
 
100
  maximum_expert_overclaim: Maximum number of tokens an expert may receive above
101
- its ideal allocation trajectory before the temporal overcapacity loss
102
- fires. A value of 0 means violations trigger immediately at any imbalance.
103
  Larger values permit short-lived semantic specialization before correction.
104
- Only used when ``load_balance_loss_type="temporal_overcapacity"``.
105
  Must be non-negative. Default 20.
106
  """
107
 
@@ -137,7 +139,7 @@ class ShramConfig(PretrainedConfig):
137
  tie_word_embeddings: bool = False,
138
  mosrah_overallocation_factor: float = 2.0,
139
  max_bid_rounds: int = 10,
140
- load_balance_loss_type: str = "temporal_overcapacity",
141
  maximum_expert_overclaim: int = 20,
142
  **kwargs
143
  ):
@@ -185,7 +187,7 @@ class ShramConfig(PretrainedConfig):
185
  f"got {maximum_expert_overclaim}."
186
  )
187
 
188
- _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
189
  if load_balance_loss_type not in _supported_loss_types:
190
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
191
  raise ValueError(
 
91
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
92
  Default 10.
93
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
94
+ One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or
95
+ ``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes
96
+ violations to the causal trajectory that produced them — each expert
97
+ accumulates a running mean of its selection log-probability and the loss
98
+ penalises the gap between overloaded and typical trajectories. Like
99
+ ``"temporal_overcapacity"``, it fires only when a violation exists and shuts
100
+ off automatically, making it safe to weight strongly. Default
101
+ ``"causal_overcapacity"``.
102
  maximum_expert_overclaim: Maximum number of tokens an expert may receive above
103
+ its ideal allocation trajectory before either overcapacity loss fires.
104
+ A value of 0 means violations trigger immediately at any imbalance.
105
  Larger values permit short-lived semantic specialization before correction.
106
+ Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``.
107
  Must be non-negative. Default 20.
108
  """
109
 
 
139
  tie_word_embeddings: bool = False,
140
  mosrah_overallocation_factor: float = 2.0,
141
  max_bid_rounds: int = 10,
142
+ load_balance_loss_type: str = "causal_overcapacity",
143
  maximum_expert_overclaim: int = 20,
144
  **kwargs
145
  ):
 
187
  f"got {maximum_expert_overclaim}."
188
  )
189
 
190
+ _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"}
191
  if load_balance_loss_type not in _supported_loss_types:
192
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
193
  raise ValueError(
huggingface.py CHANGED
@@ -178,17 +178,19 @@ class ShramConfig(PretrainedConfig):
178
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
179
  Default 10.
180
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
181
- One of ``"gshard"``, ``"ce"``, ``"bce"``, or ``"temporal_overcapacity"``.
182
- ``"temporal_overcapacity"`` is the default; it fires only when an expert
183
- exceeds its allowed trajectory (controlled by ``maximum_expert_overclaim``)
184
- and shuts off automatically once routing is balanced, allowing it to be
185
- used with a strong weight without interfering with task training during
186
- balanced routing. Default ``"temporal_overcapacity"``.
 
 
187
  maximum_expert_overclaim: Maximum number of tokens an expert may receive above
188
- its ideal allocation trajectory before the temporal overcapacity loss
189
- fires. A value of 0 means violations trigger immediately at any imbalance.
190
  Larger values permit short-lived semantic specialization before correction.
191
- Only used when ``load_balance_loss_type="temporal_overcapacity"``.
192
  Must be non-negative. Default 20.
193
  """
194
 
@@ -224,7 +226,7 @@ class ShramConfig(PretrainedConfig):
224
  tie_word_embeddings: bool = False,
225
  mosrah_overallocation_factor: float = 2.0,
226
  max_bid_rounds: int = 10,
227
- load_balance_loss_type: str = "temporal_overcapacity",
228
  maximum_expert_overclaim: int = 20,
229
  **kwargs
230
  ):
@@ -272,7 +274,7 @@ class ShramConfig(PretrainedConfig):
272
  f"got {maximum_expert_overclaim}."
273
  )
274
 
275
- _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity"}
276
  if load_balance_loss_type not in _supported_loss_types:
277
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
278
  raise ValueError(
@@ -2768,7 +2770,7 @@ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
2768
  # -----------
2769
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2770
 
2771
- This module provides four load-balance loss formulations, two token-reduction
2772
  helpers, and a factory that selects among the formulations. All formulations
2773
  share the same external contract:
2774
 
@@ -3103,6 +3105,139 @@ def _temporal_overcapacity_loss(
3103
  return final_loss
3104
 
3105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3106
  # ---------------------------------------------------------------------------
3107
  # Factory
3108
  # ---------------------------------------------------------------------------
@@ -3139,11 +3274,32 @@ def _temporal_overcapacity_factory(
3139
  return _runtime
3140
 
3141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3142
  _LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = {
3143
  "gshard": _gshard_factory,
3144
  "ce": _ce_factory,
3145
  "bce": _bce_factory,
3146
  "temporal_overcapacity": _temporal_overcapacity_factory,
 
3147
  }
3148
 
3149
 
@@ -3163,11 +3319,11 @@ def make_load_balance_loss(
3163
 
3164
  Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
3165
  factories silently ignore all kwargs; this allows callers to pass loss-type-specific
3166
- parameters (e.g. for temporal_overcapacity) without branching on loss_type.
3167
 
3168
  Args:
3169
- loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``, or
3170
- ``"temporal_overcapacity"``.
3171
  **loss_parameters: Construction-time parameters forwarded to the factory.
3172
 
3173
  Returns:
 
178
  correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1.
179
  Default 10.
180
  load_balance_loss_type: Formula used for the load-balance auxiliary loss.
181
+ One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or
182
+ ``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes
183
+ violations to the causal trajectory that produced them — each expert
184
+ accumulates a running mean of its selection log-probability and the loss
185
+ penalises the gap between overloaded and typical trajectories. Like
186
+ ``"temporal_overcapacity"``, it fires only when a violation exists and shuts
187
+ off automatically, making it safe to weight strongly. Default
188
+ ``"causal_overcapacity"``.
189
  maximum_expert_overclaim: Maximum number of tokens an expert may receive above
190
+ its ideal allocation trajectory before either overcapacity loss fires.
191
+ A value of 0 means violations trigger immediately at any imbalance.
192
  Larger values permit short-lived semantic specialization before correction.
193
+ Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``.
194
  Must be non-negative. Default 20.
195
  """
196
 
 
226
  tie_word_embeddings: bool = False,
227
  mosrah_overallocation_factor: float = 2.0,
228
  max_bid_rounds: int = 10,
229
+ load_balance_loss_type: str = "causal_overcapacity",
230
  maximum_expert_overclaim: int = 20,
231
  **kwargs
232
  ):
 
274
  f"got {maximum_expert_overclaim}."
275
  )
276
 
277
+ _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"}
278
  if load_balance_loss_type not in _supported_loss_types:
279
  supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types))
280
  raise ValueError(
 
2770
  # -----------
2771
  """Log-probability auxiliary loss functions for MoSRAH load balancing.
2772
 
2773
+ This module provides five load-balance loss formulations, two token-reduction
2774
  helpers, and a factory that selects among the formulations. All formulations
2775
  share the same external contract:
2776
 
 
3105
  return final_loss
3106
 
3107
 
3108
+ def _causal_overcapacity_loss(
3109
+ logits: torch.Tensor,
3110
+ assignment_mask: torch.Tensor,
3111
+ active_mask: torch.Tensor,
3112
+ expected_tokens_rate: float,
3113
+ maximum_expert_overclaim: int,
3114
+ ) -> torch.Tensor:
3115
+ """Causal overcapacity loss for MoSRAH load balancing.
3116
+
3117
+ Penalises selected expert trajectories that exceed their ideal cumulative
3118
+ allocation budget. A selected expert assignment is over capacity when its
3119
+ inclusive active assignment count exceeds cumulative_active_tokens * M + C,
3120
+ where M is the expected_tokens_rate (K/L) and C is the
3121
+ maximum_expert_overclaim slack.
3122
+
3123
+ The loss consumes discrete TopK assignment structure but only routes gradients
3124
+ through logits. It returns an fp32 scalar and is exactly inactive when no active
3125
+ selected expert exceeds its allowed trajectory.
3126
+
3127
+ Args:
3128
+ logits: Pre-softmax routing scores, shape (B, N, L).
3129
+ Gradient flows through this tensor.
3130
+ assignment_mask: Per-token head-assignment indicators, shape (B, N, L).
3131
+ 1.0 if token (b, n) is assigned to head l.
3132
+ active_mask: Boolean active-token mask, shape (B, N).
3133
+ expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed
3134
+ by the factory so the division is not repeated each
3135
+ forward pass.
3136
+ maximum_expert_overclaim (C): Slack above the ideal trajectory before
3137
+ overcapacity fires. Larger C tolerates more deviation.
3138
+
3139
+ Returns:
3140
+ Scalar fp32 loss tensor. Exactly 0.0 when no active selected expert exceeds
3141
+ its allowed trajectory. Can be interpreted as the difference in nats of preference
3142
+ between the violating and typical paths.
3143
+ """
3144
+ # ── Algorithm overview ──────────────────────────────────────────────────────
3145
+ #
3146
+ # Expert selections form causal trajectories through the sequence. Each trajectory
3147
+ # is scored by the mean signed nats of the selected routing events that produced
3148
+ # it: larger trajectory nats mean the router preferred that path more strongly.
3149
+ #
3150
+ # When a selected trajectory exceeds its cumulative budget, the loss forms a
3151
+ # preference contrast between the violating trajectory field and the baseline
3152
+ # trajectory field. Minimizing that contrast suppresses the over-preferred path
3153
+ # while lifting alternatives through the router softmax.
3154
+ #
3155
+ # This is not precisely equivalent to log likihood due to the selection
3156
+ # of multiple experts per round, but we deem this issue to be insignificant.
3157
+
3158
+ # ── Process setup ────────────────────────────────────────────────────────────
3159
+ #
3160
+ # A small amount of standardization is needed before the loss-specific trajectory
3161
+ # logic begins. Active selected assignments define the event structure. Routing
3162
+ # log-probabilities remain the only differentiable source and are computed in fp32
3163
+ # so the downstream trajectory accumulation does not inherit reduced precision.
3164
+
3165
+ selected_assignment_mask = assignment_mask.bool() # (B, N, L)
3166
+ active_assignment_mask = selected_assignment_mask & active_mask.unsqueeze(-1) # (B, N, L)
3167
+ routing_log_probability = F.log_softmax(logits.float(), dim=-1) # (B, N, L)
3168
+
3169
+ # ── Mask construction ────────────────────────────────────────────────────────
3170
+ #
3171
+ # The corrective target set is defined by active selected assignments whose
3172
+ # inclusive count crosses the allowed causal budget. Position and sequence masks
3173
+ # identify where that target set exists; they are reduction structure, not a
3174
+ # separate source of gradient.
3175
+
3176
+ inclusive_assignment_count = active_assignment_mask.to(torch.int32).cumsum(dim=1) # (B, N, L)
3177
+ inclusive_active_token_count = active_mask.to(torch.int32).cumsum(dim=1) # (B, N)
3178
+
3179
+ maximum_allowed_assignment_count = (
3180
+ inclusive_active_token_count.float().unsqueeze(-1) * expected_tokens_rate
3181
+ + maximum_expert_overclaim
3182
+ ) # (B, N, 1) β†’ broadcasts to (B, N, L)
3183
+
3184
+ violating_assignment_mask = ( # (B, N, L)
3185
+ active_assignment_mask
3186
+ & (inclusive_assignment_count.float() > maximum_allowed_assignment_count)
3187
+ )
3188
+ has_violation_at_position = violating_assignment_mask.any(dim=-1) # (B, N)
3189
+ has_violation_in_sequence = has_violation_at_position.any(dim=-1) # (B,)
3190
+
3191
+ # ── Trajectory construction ──────────────────────────────────────────────────
3192
+ #
3193
+ # The current selection is part of the trajectory being judged, so the trajectory
3194
+ # score is inclusive. Empty histories intentionally receive the neutral zero score;
3195
+ # this keeps the later baseline compact without introducing a second eligibility
3196
+ # system.
3197
+
3198
+ selected_trajectory_nat_sum = ( # (B, N, L)
3199
+ active_assignment_mask.float() * routing_log_probability
3200
+ ).cumsum(dim=1)
3201
+ mean_selected_trajectory_nats = ( # (B, N, L)
3202
+ selected_trajectory_nat_sum
3203
+ / inclusive_assignment_count.clamp(min=1).float()
3204
+ )
3205
+
3206
+ # ── Contrast construction ────────────────────────────────────────────────────
3207
+ #
3208
+ # This is the correction moment. The violating trajectory field is compared to
3209
+ # the baseline trajectory field at the same sequence position, producing a signed
3210
+ # preference contrast measured in nats.
3211
+
3212
+ violating_assignment_count = violating_assignment_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N)
3213
+ mean_violating_trajectory_nats = ( # (B, N)
3214
+ (violating_assignment_mask.float() * mean_selected_trajectory_nats).sum(dim=-1)
3215
+ / violating_assignment_count
3216
+ )
3217
+ mean_baseline_trajectory_nats = mean_selected_trajectory_nats.mean(dim=-1) # (B, N)
3218
+ contrastive_preference_nats = ( # (B, N)
3219
+ mean_violating_trajectory_nats
3220
+ - mean_baseline_trajectory_nats
3221
+ )
3222
+
3223
+ # ── Violation-only reduction ─────────────────────────────────────────────────
3224
+ #
3225
+ # Non-violating positions and sequences are not anchors for this loss. The scalar
3226
+ # is an average violation contrast, not total violation mass, and the entire loss
3227
+ # remains exactly inactive when no corrective target exists.
3228
+
3229
+ violation_position_count = has_violation_at_position.float().sum(dim=-1).clamp(min=1.0) # (B,)
3230
+ sequence_preference_nats = ( # (B,)
3231
+ (contrastive_preference_nats * has_violation_at_position.float()).sum(dim=-1)
3232
+ / violation_position_count
3233
+ )
3234
+ violating_sequence_count = has_violation_in_sequence.float().sum().clamp(min=1.0) # scalar
3235
+ final_loss = ( # scalar
3236
+ sequence_preference_nats * has_violation_in_sequence.float()
3237
+ ).sum() / violating_sequence_count
3238
+ return final_loss
3239
+
3240
+
3241
  # ---------------------------------------------------------------------------
3242
  # Factory
3243
  # ---------------------------------------------------------------------------
 
3274
  return _runtime
3275
 
3276
 
3277
def _causal_overcapacity_factory(
    num_selected_heads: int,
    num_total_heads: int,
    maximum_expert_overclaim: int,
    **kwargs: object,
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Build the runtime callable for the causal overcapacity loss.

    Args:
        num_selected_heads: K, the number of heads selected per token.
        num_total_heads: L, the total number of heads.
        maximum_expert_overclaim: Slack forwarded unchanged to
            ``_causal_overcapacity_loss``.
        **kwargs: Accepted and ignored, so callers can pass parameters meant
            for other loss types without branching on the loss type.

    Returns:
        A callable ``(logits, assignment_mask, active_mask) -> loss`` with the
        construction-time parameters already bound.
    """
    # Ideal per-head allocation rate K / L, computed once at construction time
    # instead of on every forward pass.
    ideal_allocation_rate = num_selected_heads / num_total_heads

    def _runtime(
        logits: torch.Tensor,
        assignment_mask: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> torch.Tensor:
        return _causal_overcapacity_loss(
            logits=logits,
            assignment_mask=assignment_mask,
            active_mask=active_mask,
            expected_tokens_rate=ideal_allocation_rate,
            maximum_expert_overclaim=maximum_expert_overclaim,
        )

    return _runtime
3295
+
3296
+
3297
  _LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = {
3298
  "gshard": _gshard_factory,
3299
  "ce": _ce_factory,
3300
  "bce": _bce_factory,
3301
  "temporal_overcapacity": _temporal_overcapacity_factory,
3302
+ "causal_overcapacity": _causal_overcapacity_factory,
3303
  }
3304
 
3305
 
 
3319
 
3320
  Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce
3321
  factories silently ignore all kwargs; this allows callers to pass loss-type-specific
3322
+ parameters (e.g. for overcapacity losses) without branching on loss_type.
3323
 
3324
  Args:
3325
+ loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``,
3326
+ ``"temporal_overcapacity"``, or ``"causal_overcapacity"``.
3327
  **loss_parameters: Construction-time parameters forwarded to the factory.
3328
 
3329
  Returns: