smithblack-0 commited on
Commit
28f3eff
Β·
verified Β·
1 Parent(s): b9eaded

Update architecture and tokenizer

Browse files
Files changed (1) hide show
  1. huggingface.py +10 -22
huggingface.py CHANGED
@@ -2922,9 +2922,9 @@ class MoSRAHRouter(nn.Module):
2922
 
2923
  Tokens propose experts in descending preference order; experts provisionally
2924
  accept their top-``remaining_capacity`` proposed tokens each round. Proposals
2925
- are monotone (never retracted). Runs for exactly ``max_rounds`` iterations;
2926
- each round is skipped via ``torch.cond`` once all tokens are satisfied, so
2927
- subsequent iterations are no-ops without data-dependent Python control flow.
2928
 
2929
  Both the column bound (per-expert token count ≀ remaining_capacity) and the
2930
  row bound (per-token expert count β‰₯ min_choices) are satisfied simultaneously
@@ -2946,15 +2946,14 @@ class MoSRAHRouter(nn.Module):
2946
  proposals = torch.zeros_like(logits, dtype=torch.bool)
2947
  acceptances = torch.zeros_like(logits, dtype=torch.bool)
2948
 
2949
- # Branch functions defined once so Dynamo sees stable function objects
2950
- # across all loop iterations. logits, remaining_capacity, min_choices, and
2951
- # capacity_scalar are captured read-only from the enclosing scope.
2952
- def body_fn(proposals: torch.Tensor, acceptances: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
2953
  # ── token proposal step ───────────────────────────────────────────
2954
  #
2955
  # Tokens with fewer than min_choices accepted experts propose their
2956
  # next-best unproposed expert(s). The deficit determines how many new
2957
- # proposals each token makes; satisfied tokens propose nothing.
 
 
2958
  accepted_per_token = acceptances.sum(dim=-1) # (B, N)
2959
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2960
 
@@ -2962,28 +2961,17 @@ class MoSRAHRouter(nn.Module):
2962
  new_proposals = cls.get_mask(
2963
  unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
2964
  )
2965
- updated_proposals = proposals | new_proposals
2966
 
2967
  # ── expert acceptance step ────────────────────────────────────────
2968
  #
2969
  # Each expert accepts its top-remaining_capacity proposed tokens.
2970
  # Acceptances are recomputed from scratch each round so that a
2971
  # stronger new proposal can displace a weaker prior one.
2972
- proposed_logits = logits.masked_fill(~updated_proposals, float('-inf'))
2973
- updated_acceptances = cls.get_mask(
2974
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2975
  )
2976
- return updated_proposals, updated_acceptances
2977
-
2978
- def skip_fn(proposals: torch.Tensor, acceptances: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
2979
- # Already converged β€” return clones so torch.cond aliasing rule is satisfied.
2980
- return proposals.clone(), acceptances.clone()
2981
-
2982
- for _ in range(max_rounds):
2983
- # Skip this round if every token already has min_choices accepted experts.
2984
- # torch.cond avoids data-dependent Python branches in compiled graphs.
2985
- not_done = ~(acceptances.sum(dim=-1) >= min_choices).all()
2986
- proposals, acceptances = torch.cond(not_done, body_fn, skip_fn, [proposals, acceptances])
2987
 
2988
  return acceptances
2989
 
 
2922
 
2923
  Tokens propose experts in descending preference order; experts provisionally
2924
  accept their top-``remaining_capacity`` proposed tokens each round. Proposals
2925
+ are monotone (never retracted), so once all tokens are satisfied, subsequent
2926
+ iterations are no-ops. Runs unconditionally for exactly ``max_rounds`` iterations
2927
+ to keep the compiled graph flat and free of data-dependent control flow.
2928
 
2929
  Both the column bound (per-expert token count ≀ remaining_capacity) and the
2930
  row bound (per-token expert count β‰₯ min_choices) are satisfied simultaneously
 
2946
  proposals = torch.zeros_like(logits, dtype=torch.bool)
2947
  acceptances = torch.zeros_like(logits, dtype=torch.bool)
2948
 
2949
+ for _ in range(max_rounds):
 
 
 
2950
  # ── token proposal step ───────────────────────────────────────────
2951
  #
2952
  # Tokens with fewer than min_choices accepted experts propose their
2953
  # next-best unproposed expert(s). The deficit determines how many new
2954
+ # proposals each token makes; satisfied tokens propose nothing
2955
+ # (deficit = 0 β†’ get_mask returns all-False). Proposals are monotone:
2956
+ # once all tokens are satisfied, subsequent iterations are no-ops.
2957
  accepted_per_token = acceptances.sum(dim=-1) # (B, N)
2958
  choices_deficit = (min_choices - accepted_per_token).clamp_min(0)
2959
 
 
2961
  new_proposals = cls.get_mask(
2962
  unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices,
2963
  )
2964
+ proposals = proposals | new_proposals
2965
 
2966
  # ── expert acceptance step ────────────────────────────────────────
2967
  #
2968
  # Each expert accepts its top-remaining_capacity proposed tokens.
2969
  # Acceptances are recomputed from scratch each round so that a
2970
  # stronger new proposal can displace a weaker prior one.
2971
+ proposed_logits = logits.masked_fill(~proposals, float('-inf'))
2972
+ acceptances = cls.get_mask(
2973
  proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar,
2974
  )
 
 
 
 
 
 
 
 
 
 
 
2975
 
2976
  return acceptances
2977