smithblack-0 committed on
Commit
6ad9cf3
·
verified ·
1 Parent(s): 651c51b

Update architecture and tokenizer

Browse files
Files changed (1) hide show
  1. huggingface.py +34 -36
huggingface.py CHANGED
@@ -2877,34 +2877,25 @@ class MoSRAHRouter(nn.Module):
2877
  return mask
2878
 
2879
  @staticmethod
2880
- def _check_bidding_converged(converged: torch.Tensor, max_rounds: int) -> None:
 
 
2881
  """Raise if the bidding loop exhausted max_rounds without satisfying all tokens.
2882
 
2883
- In compiled mode ``torch._check`` fires a C++ assertion
2884
- (``capture_scalar_outputs=True`` is a precondition — see Unit 19.F.1).
2885
- In eager mode raises ``RuntimeError`` directly.
2886
-
2887
- Exhausting ``max_rounds`` indicates an extreme routing density case or an
2888
- infeasible configuration where total capacity is insufficient for N * K
2889
- demands. In normal training this should never occur; the default
2890
- ``max_bid_rounds=10`` covers approximately the 98th percentile of routing
2891
- densities.
2892
-
2893
  Args:
2894
- converged: Scalar bool tensor True if all tokens have >= K accepted experts.
2895
- max_rounds: The iteration ceiling that was applied, for the error message.
 
 
 
 
2896
  """
2897
- if torch.compiler.is_compiling():
2898
- torch._check(converged)
2899
- else:
2900
- if not converged.item():
2901
- raise RuntimeError(
2902
- f"balance_capacity bidding did not converge within {max_rounds} rounds. "
2903
- f"All tokens must have at least K accepted experts before the loop exits. "
2904
- f"This indicates either an infeasible configuration (total remaining "
2905
- f"capacity < N * K) or an extreme routing density. "
2906
- f"Increase mosrah_overallocation_factor or max_bid_rounds."
2907
- )
2908
 
2909
  @classmethod
2910
  def _run_bidding(
@@ -2984,9 +2975,6 @@ class MoSRAHRouter(nn.Module):
2984
  proposals, acceptances, _ = torch.while_loop(
2985
  cond_fn, body_fn, (proposals, acceptances, round_count),
2986
  )
2987
-
2988
- converged = (acceptances.sum(dim=-1) >= min_choices).all()
2989
- cls._check_bidding_converged(converged, max_rounds)
2990
  return acceptances
2991
 
2992
  @classmethod
@@ -3069,17 +3057,27 @@ class MoSRAHRouter(nn.Module):
3069
  # also satisfies the row bound, both constraints hold simultaneously.
3070
  # Mask computation runs under no_grad: the boolean mask is a hard routing
3071
  # decision and must not accumulate gradient memory through the solver.
3072
- with torch.no_grad():
3073
- col_capacity_mask = cls.get_mask(logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity)
3074
- if (col_capacity_mask.sum(dim=-1) >= min_choices).all():
3075
- return logits.masked_fill(~col_capacity_mask, mask_value)
 
 
 
 
 
 
 
3076
 
3077
- # Column-capacity mask violates the row bound: routing is concentrated
3078
- # enough that per-expert capacity limits leave some tokens with fewer
3079
- # than min_choices choices. The bidding solver handles this jointly.
3080
  with torch.no_grad():
3081
- accepted = cls._run_bidding(logits, remaining_capacity, min_choices, max_rounds, capacity)
3082
- return logits.masked_fill(~accepted, mask_value)
 
 
 
 
 
 
3083
  def forward(
3084
  self,
3085
  x: torch.Tensor,
 
2877
  return mask
2878
 
2879
  @staticmethod
2880
+ def _check_bidding_converged(acceptances: torch.Tensor,
2881
+ min_choices: int,
2882
+ max_rounds: int) -> None:
2883
  """Raise if the bidding loop exhausted max_rounds without satisfying all tokens.
2884
 
 
 
 
 
 
 
 
 
 
 
2885
  Args:
2886
+ acceptances: bool tensor of shape (B, N, L) indicating which experts (axis L)
2887
+ accepted which tokens (axis N).
2888
+ min_choices: Convergence has been reached if a sum of acceptances along the last
2889
+ axis (L) is at least min_choices for every (B, N) entry.
2890
+ max_rounds: The iteration ceiling that was applied. Used only for reporting
2891
+ in the error message.
2892
  """
2893
+ msg = (
2894
+ f"balance_capacity bidding did not converge within {max_rounds} rounds. "
2895
+ f"Increase mosrah_overallocation_factor or max_bid_rounds."
2896
+ )
2897
+ converged = (acceptances.sum(dim=-1) >= min_choices).all()
2898
+ torch._assert_async(converged, msg)
 
 
 
 
 
2899
 
2900
  @classmethod
2901
  def _run_bidding(
 
2975
  proposals, acceptances, _ = torch.while_loop(
2976
  cond_fn, body_fn, (proposals, acceptances, round_count),
2977
  )
 
 
 
2978
  return acceptances
2979
 
2980
  @classmethod
 
3057
  # also satisfies the row bound, both constraints hold simultaneously.
3058
  # Mask computation runs under no_grad: the boolean mask is a hard routing
3059
  # decision and must not accumulate gradient memory through the solver.
3060
+ def skip(mask: torch.Tensor, logits: torch.Tensor)->torch.Tensor:
3061
+ """Skip bidding on the mask"""
3062
+ return mask.clone()
3063
+
3064
+ def resolve_mask(mask: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
3065
+ """Execute full bidding process"""
3066
+ return cls._run_bidding(logits,
3067
+ remaining_capacity,
3068
+ min_choices,
3069
+ max_rounds,
3070
+ capacity)
3071
 
 
 
 
3072
  with torch.no_grad():
3073
+ col_capacity_mask = cls.get_mask(logits,
3074
+ dim=-2,
3075
+ n=remaining_capacity,
3076
+ capacity_scalar=capacity)
3077
+ mask_sufficient = (col_capacity_mask.sum(dim=-1) >= min_choices).all()
3078
+ final_mask = torch.cond(mask_sufficient, skip, resolve_mask, [col_capacity_mask, logits])
3079
+ cls._check_bidding_converged(final_mask, min_choices, max_rounds)
3080
+ return logits.masked_fill(~final_mask, mask_value)
3081
  def forward(
3082
  self,
3083
  x: torch.Tensor,