smithblack-0 committed on
Commit
b9eaded
·
verified ·
1 Parent(s): 6ad9cf3

Update architecture and tokenizer

Browse files
Files changed (1) hide show
  1. huggingface.py +45 -62
huggingface.py CHANGED
@@ -2872,8 +2872,20 @@ class MoSRAHRouter(nn.Module):
2872
  else:
2873
  element_included = ranks < n.unsqueeze(positive_dim)
2874
 
2875
- mask = torch.zeros_like(tensor, dtype=torch.bool)
2876
- mask.scatter_(dim, topk_indices, element_included.expand_as(topk_indices))
 
 
 
 
 
 
 
 
 
 
 
 
2877
  return mask
2878
 
2879
  @staticmethod
@@ -2910,8 +2922,9 @@ class MoSRAHRouter(nn.Module):
2910
 
2911
  Tokens propose experts in descending preference order; experts provisionally
2912
  accept their top-``remaining_capacity`` proposed tokens each round. Proposals
2913
- are monotone (never retracted). The loop continues until every token has at
2914
- least ``min_choices`` accepted experts or ``max_rounds`` is exhausted.
 
2915
 
2916
  Both the column bound (per-expert token count ≤ remaining_capacity) and the
2917
  row bound (per-token expert count ≥ min_choices) are satisfied simultaneously
@@ -2922,35 +2935,26 @@ class MoSRAHRouter(nn.Module):
2922
  remaining_capacity: Per-expert token budget. Scalar int for training;
2923
  (B, L) tensor for inference.
2924
  min_choices: Minimum experts each token must have accepted (K).
2925
- max_rounds: Iteration ceiling; raises via ``_check_bidding_converged``
2926
- if exhausted.
2927
  capacity_scalar: Static upper bound on remaining_capacity, passed to
2928
  ``get_mask`` as the topk k bound for the acceptance step.
2929
 
2930
  Returns:
2931
  accepted: (B, N, L) bool — True at positions accepted by the solver.
2932
  """
2933
- # ── initialise loop variables ─────────────────────────────────────────
2934
- #
2935
- # All three loop_vars must be tensors of fixed shape across iterations,
2936
- # as required by torch.while_loop. logits and remaining_capacity are
2937
- # captured read-only by the closures; they do not travel as loop_vars.
2938
- proposals = torch.zeros_like(logits, dtype=torch.bool)
2939
  acceptances = torch.zeros_like(logits, dtype=torch.bool)
2940
- round_count = torch.zeros((), device=logits.device, dtype=torch.int64)
2941
- max_rounds_t = torch.full((), max_rounds, device=logits.device, dtype=torch.int64)
2942
-
2943
- def cond_fn(proposals, acceptances, round_count):
2944
- all_satisfied = (acceptances.sum(dim=-1) >= min_choices).all()
2945
- return (round_count < max_rounds_t) & ~all_satisfied
2946
 
2947
- def body_fn(proposals, acceptances, round_count):
 
 
 
2948
  # ── token proposal step ───────────────────────────────────────────
2949
  #
2950
  # Tokens with fewer than min_choices accepted experts propose their
2951
  # next-best unproposed expert(s). The deficit determines how many new
2952
- # proposals each token makes this round; already-satisfied tokens
2953
- # propose nothing (deficit = 0 → get_mask returns all-False).
2954
  accepted_per_token = acceptances.sum(dim=-1) # (B, N)
2955
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2956
 
@@ -2969,12 +2973,18 @@ class MoSRAHRouter(nn.Module):
2969
  updated_acceptances = cls.get_mask(
2970
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2971
  )
 
2972
 
2973
- return updated_proposals, updated_acceptances, round_count + 1
 
 
 
 
 
 
 
 
2974
 
2975
- proposals, acceptances, _ = torch.while_loop(
2976
- cond_fn, body_fn, (proposals, acceptances, round_count),
2977
- )
2978
  return acceptances
2979
 
2980
  @classmethod
@@ -2993,13 +3003,10 @@ class MoSRAHRouter(nn.Module):
2993
  - Column bound: per-expert unmasked token count ≤ remaining_capacity.
2994
  - Row bound: per-token unmasked expert count ≥ min_choices.
2995
 
2996
- A training fast path and a column-capacity fast path are attempted before
2997
- falling back to the bidding solver:
2998
 
2999
  1. Training with N ≤ capacity: return logits unchanged.
3000
- 2. Column-capacity fast path: if the most permissive column-bound-satisfying
3001
- mask already gives every token at least min_choices choices, return it.
3002
- 3. Bidding fallback: deferred-acceptance solver guaranteeing both bounds.
3003
 
3004
  Args:
3005
  logits: Routing scores of shape (B, N, L).
@@ -3029,15 +3036,8 @@ class MoSRAHRouter(nn.Module):
3029
  # terminates when every token has min_choices accepted experts or
3030
  # max_bid_rounds is exhausted (RuntimeError in the latter case).
3031
  #
3032
- # Two cheaper paths precede the solver:
3033
- #
3034
- # Training fast path — when N ≤ capacity and all experts start empty,
3035
- # no expert can overflow regardless of routing. No masking is needed.
3036
- #
3037
- # Column-capacity fast path — the most permissive mask satisfying the
3038
- # column bound selects each expert's top-remaining_capacity tokens. If
3039
- # that mask also satisfies the row bound, both constraints hold and the
3040
- # solver is skipped entirely.
3041
 
3042
  # Training fast path: N ≤ capacity with empty experts → no overflow possible.
3043
  if used_capacity is None and logits.shape[-2] <= capacity:
@@ -3052,32 +3052,15 @@ class MoSRAHRouter(nn.Module):
3052
  else:
3053
  remaining_capacity = (capacity - used_capacity).clamp(min=0) # (B, L)
3054
 
3055
- # Column-capacity fast path: select each expert's top-remaining_capacity
3056
- # tokens β€” the most permissive mask satisfying the column bound. If it
3057
- # also satisfies the row bound, both constraints hold simultaneously.
3058
- # Mask computation runs under no_grad: the boolean mask is a hard routing
3059
- # decision and must not accumulate gradient memory through the solver.
3060
- def skip(mask: torch.Tensor, logits: torch.Tensor)->torch.Tensor:
3061
- """Skip bidding on the mask"""
3062
- return mask.clone()
3063
-
3064
- def resolve_mask(mask: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
3065
- """Execute full bidding process"""
3066
- return cls._run_bidding(logits,
3067
- remaining_capacity,
3068
- min_choices,
3069
- max_rounds,
3070
- capacity)
3071
-
3072
  with torch.no_grad():
3073
- col_capacity_mask = cls.get_mask(logits,
3074
- dim=-2,
3075
- n=remaining_capacity,
3076
- capacity_scalar=capacity)
3077
- mask_sufficient = (col_capacity_mask.sum(dim=-1) >= min_choices).all()
3078
- final_mask = torch.cond(mask_sufficient, skip, resolve_mask, [col_capacity_mask, logits])
3079
  cls._check_bidding_converged(final_mask, min_choices, max_rounds)
3080
  return logits.masked_fill(~final_mask, mask_value)
 
3081
  def forward(
3082
  self,
3083
  x: torch.Tensor,
 
2872
  else:
2873
  element_included = ranks < n.unsqueeze(positive_dim)
2874
 
2875
+ # Allocate from explicit logical shape rather than using zeros_like. This keeps
2876
+ # the output mask tied to tensor.shape, not to any stride/layout metadata carried
2877
+ # by tensor from earlier view operations or compiler lowering.
2878
+ mask = torch.zeros(
2879
+ tuple(tensor.shape),
2880
+ device=tensor.device,
2881
+ dtype=torch.bool,
2882
+ )
2883
+
2884
+ # Materialize the scatter source shape explicitly. This avoids passing a
2885
+ # broadcast-view source into scatter while preserving the same logical rule:
2886
+ # every selected top-k index receives True iff its rank is within budget.
2887
+ scatter_values = torch.broadcast_to(element_included, topk_indices.shape)
2888
+ mask = mask.scatter(dim, topk_indices, scatter_values)
2889
  return mask
2890
 
2891
  @staticmethod
 
2922
 
2923
  Tokens propose experts in descending preference order; experts provisionally
2924
  accept their top-``remaining_capacity`` proposed tokens each round. Proposals
2925
+ are monotone (never retracted). Runs for exactly ``max_rounds`` iterations;
2926
+ each round is skipped via ``torch.cond`` once all tokens are satisfied, so
2927
+ subsequent iterations are no-ops without data-dependent Python control flow.
2928
 
2929
  Both the column bound (per-expert token count ≤ remaining_capacity) and the
2930
  row bound (per-token expert count ≥ min_choices) are satisfied simultaneously
 
2935
  remaining_capacity: Per-expert token budget. Scalar int for training;
2936
  (B, L) tensor for inference.
2937
  min_choices: Minimum experts each token must have accepted (K).
2938
+ max_rounds: Number of iterations to run. Convergence is checked after
2939
+ all rounds via ``_check_bidding_converged``; raises if not met.
2940
  capacity_scalar: Static upper bound on remaining_capacity, passed to
2941
  ``get_mask`` as the topk k bound for the acceptance step.
2942
 
2943
  Returns:
2944
  accepted: (B, N, L) bool — True at positions accepted by the solver.
2945
  """
2946
+ proposals = torch.zeros_like(logits, dtype=torch.bool)
 
 
 
 
 
2947
  acceptances = torch.zeros_like(logits, dtype=torch.bool)
 
 
 
 
 
 
2948
 
2949
+ # Branch functions defined once so Dynamo sees stable function objects
2950
+ # across all loop iterations. logits, remaining_capacity, min_choices, and
2951
+ # capacity_scalar are captured read-only from the enclosing scope.
2952
+ def body_fn(proposals: torch.Tensor, acceptances: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
2953
  # ── token proposal step ───────────────────────────────────────────
2954
  #
2955
  # Tokens with fewer than min_choices accepted experts propose their
2956
  # next-best unproposed expert(s). The deficit determines how many new
2957
+ # proposals each token makes; satisfied tokens propose nothing.
 
2958
  accepted_per_token = acceptances.sum(dim=-1) # (B, N)
2959
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2960
 
 
2973
  updated_acceptances = cls.get_mask(
2974
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2975
  )
2976
+ return updated_proposals, updated_acceptances
2977
 
2978
+ def skip_fn(proposals: torch.Tensor, acceptances: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
2979
+ # Already converged — return clones so torch.cond aliasing rule is satisfied.
2980
+ return proposals.clone(), acceptances.clone()
2981
+
2982
+ for _ in range(max_rounds):
2983
+ # Skip this round if every token already has min_choices accepted experts.
2984
+ # torch.cond avoids data-dependent Python branches in compiled graphs.
2985
+ not_done = ~(acceptances.sum(dim=-1) >= min_choices).all()
2986
+ proposals, acceptances = torch.cond(not_done, body_fn, skip_fn, [proposals, acceptances])
2987
 
 
 
 
2988
  return acceptances
2989
 
2990
  @classmethod
 
3003
  - Column bound: per-expert unmasked token count ≤ remaining_capacity.
3004
  - Row bound: per-token unmasked expert count ≥ min_choices.
3005
 
3006
+ A training fast path is attempted before the bidding solver:
 
3007
 
3008
  1. Training with N ≤ capacity: return logits unchanged.
3009
+ 2. Bidding: deferred-acceptance solver guaranteeing both bounds simultaneously.
 
 
3010
 
3011
  Args:
3012
  logits: Routing scores of shape (B, N, L).
 
3036
  # terminates when every token has min_choices accepted experts or
3037
  # max_bid_rounds is exhausted (RuntimeError in the latter case).
3038
  #
3039
+ # Training fast path — when N ≤ capacity and all experts start empty,
3040
+ # no expert can overflow regardless of routing. No masking is needed.
 
 
 
 
 
 
 
3041
 
3042
  # Training fast path: N ≤ capacity with empty experts → no overflow possible.
3043
  if used_capacity is None and logits.shape[-2] <= capacity:
 
3052
  else:
3053
  remaining_capacity = (capacity - used_capacity).clamp(min=0) # (B, L)
3054
 
3055
+ # Bidding solver: jointly satisfies column and row bounds. Runs under
3056
+ # no_grad because the boolean mask is a hard routing decision and must
3057
+ # not accumulate gradient memory.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3058
  with torch.no_grad():
3059
+ final_mask = cls._run_bidding(logits, remaining_capacity,
3060
+ min_choices, max_rounds, capacity)
 
 
 
 
3061
  cls._check_bidding_converged(final_mask, min_choices, max_rounds)
3062
  return logits.masked_fill(~final_mask, mask_value)
3063
+
3064
  def forward(
3065
  self,
3066
  x: torch.Tensor,