diff --git "a/huggingface.py" "b/huggingface.py" --- "a/huggingface.py" +++ "b/huggingface.py" @@ -44,7 +44,6 @@ from torch import nn from torch.nn.attention.flex_attention import create_block_mask from torch.nn.attention.flex_attention import flex_attention import torch.nn.functional as F -from typing import Callable @@ -166,32 +165,6 @@ class ShramConfig(PretrainedConfig): use_cache: Whether to return past_key_values for KV caching. output_hidden_states: Whether to return hidden states after each layer. tie_word_embeddings: Whether input embedding and LM head share weights. - mosrah_overallocation_factor: Overallocation multiplier for the expert packing - buffer. ``mosrah_packed_length`` = ceil(training_sequence_length * - num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor). - Must be > 1.0 to guarantee a buffer larger than the balanced-routing - baseline. Default 2.0. - max_bid_rounds: Maximum bidding rounds for the deferred-acceptance capacity - solver in ``balance_capacity``. 10 covers convergence at approximately - the 98th percentile of routing densities; the top 2% of extreme-density - cases are not expected under normal training. The bound exists as a - correctness guard — exhausting it raises ``RuntimeError``. Must be >= 1. - Default 10. - load_balance_loss_type: Formula used for the load-balance auxiliary loss. - One of ``"gshard"``, ``"ce"``, ``"bce"``, ``"temporal_overcapacity"``, or - ``"causal_overcapacity"``. ``"causal_overcapacity"`` (the default) attributes - violations to the causal trajectory that produced them — each expert - accumulates a running mean of its selection log-probability and the loss - penalises the gap between overloaded and typical trajectories. Like - ``"temporal_overcapacity"``, it fires only when a violation exists and shuts - off automatically, making it safe to weight strongly. Default - ``"causal_overcapacity"``. 
- maximum_expert_overclaim: Maximum number of tokens an expert may receive above - its ideal allocation trajectory before either overcapacity loss fires. - A value of 0 means violations trigger immediately at any imbalance. - Larger values permit short-lived semantic specialization before correction. - Used by both ``"temporal_overcapacity"`` and ``"causal_overcapacity"``. - Must be non-negative. Default 20. """ model_type = "shram" @@ -224,10 +197,6 @@ class ShramConfig(PretrainedConfig): use_cache: bool = True, output_hidden_states: bool = False, tie_word_embeddings: bool = False, - mosrah_overallocation_factor: float = 2.0, - max_bid_rounds: int = 10, - load_balance_loss_type: str = "causal_overcapacity", - maximum_expert_overclaim: int = 20, **kwargs ): if head_dim % 2 != 0: @@ -256,30 +225,13 @@ class ShramConfig(PretrainedConfig): f"got {inference_sequence_length}." ) - if mosrah_overallocation_factor <= 1.0: + if num_mosrah_heads % num_selected_heads != 0: raise ValueError( - f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed " - f"buffer larger than the balanced-routing baseline. " - f"Got {mosrah_overallocation_factor}." - ) - - if max_bid_rounds < 1: - raise ValueError( - f"max_bid_rounds must be at least 1, got {max_bid_rounds}." - ) - - if maximum_expert_overclaim < 0: - raise ValueError( - f"maximum_expert_overclaim must be non-negative, " - f"got {maximum_expert_overclaim}." - ) - - _supported_loss_types = {"gshard", "ce", "bce", "temporal_overcapacity", "causal_overcapacity"} - if load_balance_loss_type not in _supported_loss_types: - supported = ", ".join(f'"{t}"' for t in sorted(_supported_loss_types)) - raise ValueError( - f"load_balance_loss_type must be one of {supported}, " - f"got {load_balance_loss_type!r}." + f"num_mosrah_heads must be exactly divisible by num_selected_heads. 
" + f"Mechanical load balancing partitions the sequence into blocks of " + f"W = num_mosrah_heads // num_selected_heads tokens; each block covers " + f"every expert exactly once, which requires an integer W. " + f"Got num_mosrah_heads={num_mosrah_heads}, num_selected_heads={num_selected_heads}." ) self.vocab_size = vocab_size @@ -299,10 +251,6 @@ class ShramConfig(PretrainedConfig): self.inference_sequence_length = inference_sequence_length self.alpha = alpha self.beta = beta - self.mosrah_overallocation_factor = mosrah_overallocation_factor - self.max_bid_rounds = max_bid_rounds - self.load_balance_loss_type = load_balance_loss_type - self.maximum_expert_overclaim = maximum_expert_overclaim self.attention_dropout = attention_dropout self.use_cache = use_cache @@ -329,10 +277,10 @@ class ShramConfig(PretrainedConfig): def mosrah_packed_length(self) -> int: """Static packed time dimension T for expert packing. - The expected tokens per expert under perfectly balanced routing is - ``training_sequence_length * num_selected_heads / num_mosrah_heads``. - Multiplying by ``mosrah_overallocation_factor`` provides a buffer above - that baseline. The ceiling ensures T is always an integer >= 1. + Mechanical load balancing guarantees exactly + ``training_sequence_length * num_selected_heads / num_mosrah_heads`` + tokens per expert. The ceiling handles non-integer results when + training_sequence_length is not divisible by the block length W. All consumers of the packed buffer size must read this property rather than deriving T independently. @@ -341,34 +289,44 @@ class ShramConfig(PretrainedConfig): self.training_sequence_length * self.num_selected_heads / self.num_mosrah_heads - * self.mosrah_overallocation_factor - ) + ) + self.block_length @property def mosrah_cache_length(self) -> int: """Static per-(batch, head) slot capacity for the MoSRAH inference cache. 
- The expected tokens per expert over the full inference context under perfectly - balanced routing is ``inference_sequence_length * num_selected_heads / - num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides - a buffer above that baseline. The ceiling ensures the result is always an - integer >= 1. + Mechanical load balancing guarantees exactly + ``inference_sequence_length * num_selected_heads / num_mosrah_heads`` + tokens per expert over the full inference context. The ceiling handles + non-integer results when inference_sequence_length is not divisible by + the block length W; ``block_length`` extra slots are then added on top of that ceiling. - Distinct from ``mosrah_packed_length``, which sizes the training packing buffer - using ``training_sequence_length``. This property uses - ``inference_sequence_length`` because the cache must hold the full accumulated - token history across the entire inference run. + Distinct from ``mosrah_packed_length``, which sizes the training packing + buffer using ``training_sequence_length``. This property uses + ``inference_sequence_length`` because the cache must hold the full + accumulated token history across the entire inference run. - All consumers of the MoSRAH cache buffer size must read this property rather - than deriving the capacity independently. + All consumers of the MoSRAH cache buffer size must read this property + rather than deriving the capacity independently. """ return math.ceil( self.inference_sequence_length * self.num_selected_heads / self.num_mosrah_heads - * self.mosrah_overallocation_factor - ) + ) + self.block_length + + @property + def block_length(self) -> int: + """Routing block length W = num_mosrah_heads // num_selected_heads. + + Within each block of W consecutive tokens every expert is used exactly once, + giving perfect load balance by construction. The L % K == 0 constraint + enforced at construction guarantees W is an exact integer.
+ All consumers of the routing block length must read this property rather + than deriving W independently. + """ + return self.num_mosrah_heads // self.num_selected_heads # ----------- # Inlined from: shram_layer_cache.py @@ -771,6 +729,269 @@ class MoSRAHCache(CacheLayerMixin): ) +# ----------- +# Inlined from: router_cache.py +# ----------- +"""Block-state cache for the MoSRAH causal block-balanced router. + +The block-balanced router partitions the token sequence into non-overlapping blocks +of W = L/K tokens. Within each block every expert is assigned exactly once, giving +perfect load balance by construction. During training the full sequence is available +and block state is managed locally in MoSRAHRouter.forward(). During inference tokens +arrive one at a time and the router must remember which experts have been claimed in +the current partial block across decode steps. + +RouterCache holds two pieces of state across decode steps: + + - _used_in_block: Boolean mask (B, L) tracking which experts have been claimed by + earlier tokens in the current block. The decode router masks these to -inf before + TopK, preserving the one-usage-per-block invariant. + + - _step_in_block: Integer counter (B,) of how many tokens have been processed in + the current block. Reaches block_length W when the block completes, at which + point both tensors are reset in-place for the next block. + +All decode-step operations (update_decode) use fixed-shape in-place tensor ops and +are fully compileable under torch.compile(dynamic=False, fullgraph=True). The prefill +update (update_prefill) may use data-dependent indexing and must not be called inside +a compiled graph; prefill runs in eager mode before the compiled decode loop in +standard HuggingFace generate(). + +RouterCache is constructed by ShramLayerCache and passed directly to +MoSRAHRouter.forward(). ShramLayerCache.reset() clears the router state atomically +with the KV caches it also owns. 
+""" + + + + + +class RouterCache(CacheLayerMixin): + """Block-state cache for the MoSRAH causal block-balanced router. + + Tracks which experts have been claimed in the current routing block and how + far into that block the current decode step is. This allows the router to + maintain its one-usage-per-block contract across decode steps without + reprocessing the full accumulated sequence. + + All state is pre-allocated at construction time. The primary decode method + (update_decode) uses only in-place fixed-shape operations and is fully + compileable. + + Args: + block_length: Tokens per routing block, W = num_mosrah_heads // num_selected_heads. + The router resets block state after every W consecutive decode tokens. + num_mosrah_heads: Total expert count L. Determines the width of the + used-expert mask. + batch_size: Number of sequences in the batch. + device: Device on which to allocate state tensors. + """ + + is_compileable = True + is_sliding = False + + def __init__( + self, + block_length: int, + num_mosrah_heads: int, + batch_size: int, + device: torch.device, + ) -> None: + super().__init__() + self._block_length = block_length + self._device = device + + # used_in_block: which experts are already claimed in the current block. + # False = expert is still available for the next decode token that needs it. + # Reset to all-False when step_in_block reaches block_length. + self._used_in_block = torch.zeros( + batch_size, num_mosrah_heads, dtype=torch.bool, device=device + ) + + # step_in_block: how many tokens have been processed in the current block. + # Range [0, block_length - 1]. Resets to 0 when a block completes. 
+ self._step_in_block = torch.zeros(batch_size, dtype=torch.int64, device=device) + + # --------------------------------------------------------------------------- + # is_initialized — pre-allocated at construction, always True + # --------------------------------------------------------------------------- + + @property + def is_initialized(self) -> bool: + """True always — RouterCache pre-allocates all state at construction.""" + return True + + @is_initialized.setter + def is_initialized(self, value: bool) -> None: + # CacheLayerMixin.__init__ assigns self.is_initialized = False as an + # instance attribute. Absorb it silently — state is always initialized. + pass + + # --------------------------------------------------------------------------- + # Public interface for the router + # --------------------------------------------------------------------------- + + def get_used_in_block(self) -> torch.Tensor: + """Return the current block's used-expert mask. + + Returns: + Boolean mask of shape (B, L). True entries mark experts already claimed + by earlier tokens in the current block and must be excluded from TopK. + """ + return self._used_in_block + + def update_decode(self, step_heads: torch.Tensor) -> None: + """Record a single decode-step expert selection and advance the block counter. + + Marks the K selected experts as used in the current block, then either + advances the per-batch step counter or resets both tensors in-place when + the block completes. All operations are in-place and compile-compatible. + + Args: + step_heads: Expert indices selected at this decode step, shape (B, K). + """ + # Mark the K selected experts as unavailable for the rest of this block. + self._used_in_block.scatter_(-1, step_heads, True) + + # Detect block completion before incrementing: step was W-1 (0-indexed), + # meaning this token is the last one in the current block. 
+ block_done = self._step_in_block.eq(self._block_length - 1) # (B,) bool + + # Advance step counter, then zero it for any batch item that just finished a block. + self._step_in_block.add_(1) + self._step_in_block.masked_fill_(block_done, 0) + + # Clear expert availability for batch items that completed a block, so the + # next decode token for those items starts with a clean slate. + self._used_in_block.masked_fill_(block_done.unsqueeze(-1), False) + + def update_prefill( + self, + selected_heads_blocked: torch.Tensor, + seq_len: int, + ) -> None: + """Record the partial block state left over at the end of a prefill pass. + + After processing a prefill sequence of length seq_len with the training-style + block solver, the last block may be incomplete when seq_len is not a multiple + of block_length. This method saves the partial block state so decode steps can + continue the current block without a gap. + + Not compile-compatible: uses a data-dependent slice [:seq_mod] on the W + dimension. Must only be called in eager mode. Standard HuggingFace generate() + runs prefill in eager before entering the compiled decode loop. + + Args: + selected_heads_blocked: Block-solver assignment output from the prefill pass, + shape (B, num_blocks, W, K). The final block entry contains expert + assignments for both real tokens (steps 0..seq_mod-1) and padding + artefacts (steps seq_mod..W-1) which must be discarded. + seq_len: Actual prefill sequence length before block padding. Determines + how many steps of the last block contain real assignments. + """ + B = selected_heads_blocked.shape[0] + seq_mod = seq_len % self._block_length + + self._used_in_block.zero_() + + if seq_mod == 0: + # All blocks were complete — start fresh for the next decode token. + self._step_in_block.zero_() + else: + # Last block is partial: only the first seq_mod steps are real assignments. + # Rebuild the used-expert mask from those steps and record the step position. 
+ last_block_real_steps = selected_heads_blocked[:, -1, :seq_mod, :] # (B, seq_mod, K) + real_experts_flat = last_block_real_steps.reshape(B, -1) # (B, seq_mod * K) + self._used_in_block.scatter_(-1, real_experts_flat, True) + self._step_in_block.fill_(seq_mod) + + # --------------------------------------------------------------------------- + # CacheLayerMixin — reset and beam-search coordination + # --------------------------------------------------------------------------- + + def reset(self) -> None: + """Clear block state for a new generation session. + + Zeros both state tensors in-place. Called by ShramLayerCache.reset() + atomically with the KV cache reset. + """ + self._used_in_block.zero_() + self._step_in_block.zero_() + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + """Reorder the batch dimension for beam search. + + Args: + beam_idx: Permutation indices of shape (batch,). + """ + self._used_in_block = self._used_in_block[beam_idx] + self._step_in_block = self._step_in_block[beam_idx] + + def batch_repeat_interleave(self, repeats: int) -> None: + """Expand the batch dimension for beam search initialisation. + + Args: + repeats: Number of times to repeat each batch entry along the batch dimension. + """ + self._used_in_block = self._used_in_block.repeat_interleave(repeats, dim=0) + self._step_in_block = self._step_in_block.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor) -> None: + """Select a subset of batch entries for contrastive search. + + Args: + indices: 1-D integer tensor of batch indices to retain. 
+ """ + self._used_in_block = self._used_in_block[indices] + self._step_in_block = self._step_in_block[indices] + + def offload(self) -> None: + """Move state tensors to CPU for memory management between decode steps.""" + self._used_in_block = self._used_in_block.cpu() + self._step_in_block = self._step_in_block.cpu() + + def prefetch(self) -> None: + """Move state tensors back to model device ahead of the next decode step.""" + self._used_in_block = self._used_in_block.to(self._device) + self._step_in_block = self._step_in_block.to(self._device) + + # --------------------------------------------------------------------------- + # CacheLayerMixin — unsupported abstract methods + # --------------------------------------------------------------------------- + + def lazy_initialization( # type: ignore[override] + self, key_states: torch.Tensor, value_states: torch.Tensor + ) -> None: + """No-op — RouterCache pre-allocates all state at construction.""" + pass + + def update( # type: ignore[override] + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: dict | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Not supported — use update_decode() or update_prefill() instead.""" + raise NotImplementedError( + "RouterCache has no composite key/value update interface. " + "Use update_decode() for single decode steps or update_prefill() after prefill." 
+ ) + + def get_seq_length(self) -> int: + """Not supported — RouterCache tracks block position, not sequence length.""" + raise NotImplementedError("RouterCache does not track sequence length.") + + def get_max_cache_shape(self) -> int: + """Not supported — RouterCache does not hold KV pairs.""" + raise NotImplementedError("RouterCache does not have a KV cache shape.") + + def get_mask_sizes( # type: ignore[override] + self, + cache_position: torch.Tensor, + ) -> tuple[int, int]: + """Not supported — RouterCache does not participate in KV attention masking.""" + raise NotImplementedError("RouterCache does not participate in KV masking.") + # ----------- # Inlined from: sliding_window_cache.py # ----------- @@ -1102,13 +1323,14 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin): class ShramLayerCache(CacheLayerMixin): """Cache subsystem for one SHRAM decoder layer. - Owns and coordinates two sub-caches: + Owns and coordinates three sub-caches: - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path. - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path. + - router_cache: RouterCache for the block-balanced router's block state. - Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are - exposed directly for their downstream attention paths — no composite update() interface is - provided, because the two paths have materially different update semantics. + Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The sub-caches are + exposed directly for their downstream consumers — no composite update() interface is + provided, because the paths have materially different update semantics. Sequence length is reported by delegating to the local sliding-window sub-cache, which tracks the cumulative count of token positions processed across all update() calls. 
@@ -1145,6 +1367,12 @@ class ShramLayerCache(CacheLayerMixin): device=device, mosrah_cache_length=config.mosrah_cache_length, ) + self.router_cache = RouterCache( + block_length=config.block_length, + num_mosrah_heads=config.num_mosrah_heads, + batch_size=batch_size, + device=device, + ) # --------------------------------------------------------------------------- # Properties @@ -1157,7 +1385,11 @@ class ShramLayerCache(CacheLayerMixin): Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction, so this is True immediately after ShramLayerCache.__init__ returns. """ - return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized + return ( + self.sliding_window_cache.is_initialized + and self.mosrah_cache.is_initialized + and self.router_cache.is_initialized + ) @is_initialized.setter def is_initialized(self, value: bool) -> None: @@ -1188,6 +1420,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.reset() self.mosrah_cache.reset() + self.router_cache.reset() def reorder_cache(self, beam_idx: torch.LongTensor) -> None: """Reorder the batch dimension of both sub-caches for beam search. @@ -1200,6 +1433,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.reorder_cache(beam_idx) self.mosrah_cache.reorder_cache(beam_idx) + self.router_cache.reorder_cache(beam_idx) def batch_repeat_interleave(self, repeats: int) -> None: """Expand the batch dimension of both sub-caches for beam search initialisation. @@ -1212,6 +1446,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.batch_repeat_interleave(repeats) self.mosrah_cache.batch_repeat_interleave(repeats) + self.router_cache.batch_repeat_interleave(repeats) def batch_select_indices(self, indices: torch.Tensor) -> None: """Select a subset of batch entries in both sub-caches for contrastive search. 
@@ -1224,6 +1459,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.batch_select_indices(indices) self.mosrah_cache.batch_select_indices(indices) + self.router_cache.batch_select_indices(indices) def offload(self) -> None: """Offload both sub-caches to CPU. @@ -1233,6 +1469,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.offload() self.mosrah_cache.offload() + self.router_cache.offload() def prefetch(self) -> None: """Move both sub-caches back to their model device ahead of time. @@ -1242,6 +1479,7 @@ class ShramLayerCache(CacheLayerMixin): """ self.sliding_window_cache.prefetch() self.mosrah_cache.prefetch() + self.router_cache.prefetch() def lazy_initialization( # type: ignore[override] self, key_states: torch.Tensor, value_states: torch.Tensor @@ -1478,8 +1716,8 @@ Returns a plain dict with keys: - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size) - "past_key_values": the ShramCache object passed in, or None - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None -- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses -- "max_vio": detached scalar maximum routing-imbalance across all decoder layers +- "regret_loss": scalar sum of per-layer SHRAM regret losses +- "logit_regret": detached scalar mean per-layer logit-space regret - "logit_std": detached scalar mean per-layer per-token routing logit spread """ @@ -2104,6 +2342,7 @@ In particular, this path must preserve three architectural distinctions: + # ----------- # Inlined from: bottlenecked_ensemble_attention.py # ----------- @@ -2660,8 +2899,9 @@ def _enforce_no_overflow(tokens_per_expert: torch.Tensor, packed_length: int) -> This check fires when the number of tokens assigned to any expert in any batch item exceeds mosrah_packed_length. When that limit is exceeded, the packed buffer - is too small to hold all assignments and data would be dropped. 
Increase - mosrah_overallocation_factor in ShramConfig to resolve. + is too small to hold all assignments and data would be dropped. Reduce the input + sequence length or increase training_sequence_length (for training) or + inference_sequence_length (for inference) in ShramConfig to resolve. Args: tokens_per_expert: Per-expert token counts, shape (B, num_experts). @@ -2671,15 +2911,17 @@ def _enforce_no_overflow(tokens_per_expert: torch.Tensor, packed_length: int) -> torch._assert_async( tokens_per_expert.max() <= packed_length, "Expert packing overflow: expert bucket exceeds mosrah_packed_length. " - "Increase mosrah_overallocation_factor in ShramConfig.", + "Reduce sequence length or increase training_sequence_length / " + "inference_sequence_length in ShramConfig.", ) else: max_count = tokens_per_expert.max().item() if max_count > packed_length: raise RuntimeError( "Expert packing overflow: at least one expert bucket contains more " - "tokens than mosrah_packed_length allows. Increase " - "mosrah_overallocation_factor in ShramConfig to resolve.\n" + "tokens than mosrah_packed_length allows. Reduce sequence length or " + "increase training_sequence_length / inference_sequence_length in " + "ShramConfig to resolve.\n" f"Packed length: {packed_length}\n" f"Head lengths: {tokens_per_expert}\n" ) @@ -2720,90 +2962,32 @@ def _count_tokens_per_expert( # ----------- """Token-choice router for the MoSRAH sparse attention path. -This module implements the routing mechanism described in Appendix A.Routing of the -paper. Given an input hidden state x, the router produces two outputs used downstream: +This module implements mechanically load-balanced routing for MoSRAH. Given an +input hidden state x, the router produces two outputs used downstream: - - selected_heads (I): which K of the L available expert heads each token routes to, - determined by TopK over capacity-balanced routing scores. 
- - routing_probs (P): the weights used for the weighted output reduction, gathered from - the routing scores at the selected indices and renormalized to sum to 1 per token. + - selected_heads (I): which K of the L available expert heads each token + routes to, determined by a block-balanced causal solver. + - routing_probs (P): the weights used for the weighted output reduction, + gathered from the softmax routing scores at the selected indices and + renormalized to sum to 1 per token. Routing uses a single learnable projection: - - routing_weight: shape (L, embedding_width). Maps input to per-head routing scores. - Both task loss and load_balance_loss train this parameter directly — there is no - gradient isolation between the two signals. - -This coupled design is intentional. SHRAM has an unusually strong task-level incentive -to concentrate tokens into the same expert bucket (sparse attention only occurs among -tokens routed to the same expert), so any indirect balancing pathway will be outlearned. -Coupling the gradients allows the load balance loss to act with full strength directly -on the parameter that determines routing. - -routing_weight is nn.Parameter so that HuggingFace _init_weights does not override -its kaiming initialization at construction. - -routing_probs are computed before balance_capacity applies -1e8 sentinels. Post-capacity -softmax would corrupt routing_probs for over-capacity experts (near-zero probability -after masking does not reflect genuine routing preference). - -The router computes and returns: - - load_balance_loss: scalar auxiliary loss (see load_balance_loss.py); gradient flows - to routing_weight. - - max_vio: detached scalar summarising routing imbalance: - MaxVio = mean_b( L · max_l(f_bl − 1/L) ) - where f_bl is the per-batch-item realised routing frequency of head l. Zero means - perfect balance; 1.0 means the most loaded head received double its fair share. 
- - logit_std: detached scalar; mean per-token standard deviation of routing logits. - Monitoring metric for routing sharpness. - -Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio. -""" - + - routing_weight: shape (L, embedding_width). Maps input to per-head routing + scores. Task loss trains this parameter through routing_probs; regret_loss + trains it to prefer expert assignments at positions of peak preference. +Block-balanced routing partitions the sequence into non-overlapping blocks of +W = L/K tokens. Within each block every expert is assigned to exactly one token, +guaranteeing perfect load balance by construction. The L % K == 0 compatibility +constraint (enforced in ShramConfig) makes W an exact integer. +Selection is causal within each block: at each of the W steps the current +token chooses its K experts from those not yet claimed by earlier tokens in +the same block. All W steps execute in parallel across blocks and batch via +a fully-unrolled Python for loop, keeping the compiled graph flat. - - -# ----------- -# Inlined from: load_balance_loss.py -# ----------- -"""Log-probability auxiliary loss functions for MoSRAH load balancing. - -This module provides five load-balance loss formulations, two token-reduction -helpers, and a factory that selects among the formulations. All formulations -share the same external contract: - - loss_fn( - logits: Tensor[B, N, L], - assignment_mask: Tensor[B, N, L], - active_mask: Tensor[B, N], - ) -> scalar Tensor - - logits: Pre-softmax routing scores, shape (B, N, L). Gradient flows - through this tensor. - assignment_mask: Per-token head-assignment indicators. assignment_mask[b, n, l] - is 1.0 if token (b, n) was assigned to head l. Dead tokens - should carry zero entries. - active_mask: Boolean mask, shape (B, N). True means the token is - semantically live. 
- -Token reduction is split into two helpers with distinct roles: - - reduce_frequency_tokens — produces per-batch-item routing frequencies f_bl (B, L). - Called by gshard, ce, and bce. Output is detached; f_bl carries no gradient. - - reduce_probability_tokens — produces per-batch-item mean assignment probabilities - p_bl (B, L). Called only by gshard and bce. Gradient flows through the - internal softmax over logits. - -CE delegates probability computation to F.cross_entropy, which handles its own -log_softmax and operates directly on the raw (B, N, L) logits. - -``make_load_balance_loss`` is the sole public entry point. The individual loss -functions are internal implementation details; their signatures may change between -units. Callers and tests must construct loss callables through the factory, not by -importing or invoking the loss functions directly. +Paper ref: Appendix A.Routing. """ @@ -2811,573 +2995,37 @@ importing or invoking the loss functions directly. -# --------------------------------------------------------------------------- -# Token-reduction helpers -# --------------------------------------------------------------------------- - -def reduce_frequency_tokens( - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, -) -> torch.Tensor: - """Reduce per-token head assignments to per-batch-item routing frequencies. - - f_bl[b, l] is the fraction of active-token assignments in batch item b going - to head l. Values sum to 1 per batch item when routing is valid. - - The output is detached from the autograd graph: routing frequencies are - derived from discrete TopK selections and must not carry gradients. - - Denominators are clamped to 1 to handle the all-dead-tokens edge case. - - Args: - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - active_mask: Boolean active-token mask, shape (B, N). - - Returns: - f_bl: Per-batch-item routing frequencies, shape (B, L). Detached. 
- """ - active_float = active_mask.float().unsqueeze(-1) # (B, N, 1) - active_assignments = assignment_mask * active_float # (B, N, L) - assignment_totals = ( - active_assignments.sum(dim=(1, 2)).clamp(min=1.0).unsqueeze(-1) # (B, 1) - ) - return (active_assignments.sum(dim=1) / assignment_totals).detach() # (B, L) - - -def reduce_probability_tokens( - logits: torch.Tensor, - active_mask: torch.Tensor, -) -> torch.Tensor: - """Reduce per-token load-balancing logits to per-batch-item assignment probabilities. - - p_bl[b, l] is the mean softmax probability for head l over active tokens in - batch item b. Values sum to 1 per batch item. Gradient flows to expert_bias - through the internal softmax. - - Denominators are clamped to 1 to handle the all-dead-tokens edge case. - - Args: - logits: Load-balancing logits, shape (B, N, L). Gradient flows through. - active_mask: Boolean active-token mask, shape (B, N). - - Returns: - p_bl: Per-batch-item mean assignment probabilities, shape (B, L). - """ - per_token_probs = F.softmax(logits, dim=-1) # (B, N, L) - active_float = active_mask.float().unsqueeze(-1) # (B, N, 1) - active_count = active_mask.float().sum(dim=1, keepdim=True).clamp(min=1.0) # (B, 1) - return (per_token_probs * active_float).sum(dim=1) / active_count # (B, L) - - -# --------------------------------------------------------------------------- -# Loss functions -# --------------------------------------------------------------------------- - -def gshard_loss( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, -) -> torch.Tensor: - """GShard-style linear load-balance loss. - - Computes (1/L) * Σ_l f_bl * p_bl per batch item, averaged over B, where - f_bl comes from reduce_frequency_tokens and p_bl from reduce_probability_tokens. - - The linear signal is the weakest of the three formulations; gradient magnitude - does not grow with violation severity. Provided for comparison. 
- - Args: - logits: Load-balancing logits, shape (B, N, L). - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - active_mask: Boolean active-token mask, shape (B, N). - - Returns: - Scalar loss tensor. - """ - L = logits.shape[-1] - f_bl = reduce_frequency_tokens(assignment_mask, active_mask) - p_bl = reduce_probability_tokens(logits, active_mask) - return (f_bl * p_bl).sum(dim=-1).mean() / L - - -def ce_loss( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, -) -> torch.Tensor: - """Cross-entropy load-balance loss. - Constructs per-batch-item soft target distributions from routing frequencies - and delegates to F.cross_entropy operating directly on (B, N, L) logits. - Inactive tokens receive all-zero targets, producing zero loss and zero gradient. - - The soft target for head l in batch item b is (1 - f_bl) / (L - 1). This - distribution sums to 1 per batch item (since Σ_l (1 - f_bl) = L - 1) and - weights underloaded heads (low f_bl → high target) more strongly than - overloaded ones. - - The total CE over active tokens is normalised by the active token count rather - than B*N to avoid dilution from inactive positions. - - Args: - logits: Load-balancing logits, shape (B, N, L). - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - active_mask: Boolean active-token mask, shape (B, N). - - Returns: - Scalar loss tensor. - """ - B, N, L = logits.shape - f_bl = reduce_frequency_tokens(assignment_mask, active_mask) # (B, L) - active_count = active_mask.float().sum().clamp(min=1.0) - - # Soft target: (1 - f_bl) / (L - 1) for active tokens, zeros for inactive. - # Zeros give zero CE loss and zero gradient at inactive positions. - target = (1.0 - f_bl) / (L - 1) # (B, L) - target_per_token = ( - target.unsqueeze(1).expand(-1, N, -1) # (B, N, L) - * active_mask.float().unsqueeze(-1) # zero inactive - ) - - # F.cross_entropy requires the class dimension to be dim 1. 
- # Permute (B, N, L) → (B, L, N) to satisfy the (N, C, d) contract. - return F.cross_entropy( - logits.permute(0, 2, 1), # (B, L, N) - target_per_token.permute(0, 2, 1), # (B, L, N) - reduction='sum', - ) / active_count - - -def bce_loss( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, -) -> torch.Tensor: - """Binary cross-entropy load-balance loss. - - Treats each head as an independent binary target with label (1 - f_bl). - Uses reduce_probability_tokens to produce per-batch-item probabilities, - then delegates to F.binary_cross_entropy over (B, L) tensors. - - Unlike CE, BCE maintains a repulsion signal from saturated experts: when - f_bl → 1 the target → 0, driving p_bl away from 1 and preventing runaway - concentration. - - Active masking is handled inside reduce_frequency_tokens and - reduce_probability_tokens, so the (B, L) output tensors already exclude - inactive tokens from both frequencies and probabilities. - - Args: - logits: Load-balancing logits, shape (B, N, L). - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - active_mask: Boolean active-token mask, shape (B, N). - - Returns: - Scalar loss tensor. - """ - f_bl = reduce_frequency_tokens(assignment_mask, active_mask) - p_bl = reduce_probability_tokens(logits, active_mask) - # Clamp for numerical safety: softmax outputs are strictly positive in - # normal operation; the clamp guards the all-dead-tokens edge case where - # the mean defaults to zero. log1p(-p) avoids cancellation near p=1. - p = p_bl.clamp(min=1e-7, max=1.0 - 1e-7) - target = 1.0 - f_bl - return -(target * torch.log(p) + (1.0 - target) * torch.log1p(-p)).mean() - - -def _temporal_overcapacity_loss( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, - expected_tokens_rate: float, - maximum_expert_overclaim: int, -) -> torch.Tensor: - """Temporal overcapacity loss for MoSRAH load balancing. 
- - Penalises routing decisions that select a head already overloaded relative to - its ideal allocation trajectory. A head is considered overloaded when the number - of active tokens before position n assigned to that head exceeds - cumulative_active_tokens * M + C, where M is the expected_tokens_rate (K/L) and - C is the maximum_expert_overclaim slack. - - Loss is exactly zero when no head exceeds its trajectory, making it safe to - weight strongly — it stays out of the way when routing is balanced. - - Args: - logits: Pre-softmax routing scores, shape (B, N, L). - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - 1.0 if token (b, n) is assigned to head l. - active_mask: Boolean active-token mask, shape (B, N). - expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed - by the factory so the division is not repeated each - forward pass. - maximum_expert_overclaim (C): Slack above the ideal trajectory before - imbalance fires. Larger C tolerates more deviation. - - Returns: - Scalar loss tensor. Exactly 0.0 when no head exceeds its allowed trajectory. - """ - # ── Algorithm overview ────────────────────────────────────────────────────── - # - # Problem: token routing is stateless — each token's TopK selection is blind to - # how many times each expert has already been chosen earlier in the sequence. A - # router that develops a strong preference for certain experts will overload them - # far beyond their K/L fair share with no correction signal at the moment of - # selection. - # - # Approach: track per-head assignment history as exclusive cumulative counts - # (assignments by all active tokens strictly before position n) and compare - # against an ideal trajectory S·M, where S is the inclusive cumulative active - # token count and M is the amount of tokens expected given ideal balancing - # A head is overloaded when its prior count exceeds that trajectory - # by more than C. 
When a token selects an already-overloaded head, the loss - # moment — mean(violating logits) minus mean(non-overloaded logits) — penalises - # the gap and pushes future routing toward underloaded alternatives. - - # ── Routing history and imbalance threshold ────────────────────────────────── - # - # prior_assignment_counts is the exclusive routing history at each position: - # active assignments to each head by all tokens strictly before position n. - # Exclusive because it reflects only what was known when token n was being routed. - # cumulative_active_tokens grows by 1 per active token; the ideal per-head - # allocation at n is S·M. Exceeding that by more than C triggers imbalance. - - active_float = active_mask.float() # (B, N) - active_assignments = assignment_mask * active_float.unsqueeze(-1) # (B, N, L) - - # exclusive cumsums: subtract self to exclude position n - prior_assignment_counts = active_assignments.cumsum(dim=1) - active_assignments # (B, N, L) - cumulative_active_tokens = active_float.cumsum(dim=1) - active_float # (B, N) - - maximum_supportable_assignments = ( - cumulative_active_tokens.unsqueeze(-1) * expected_tokens_rate - + maximum_expert_overclaim - ) # (B, N, 1) → broadcasts to (B, N, L) - - # ── Mask construction ──────────────────────────────────────────────────────── - # - # Three derived masks: - # imbalance_mask: any head exceeding its trajectory. - # violating_selection_mask: selected AND imbalanced — the penalty target. - # non_overloaded_head_mask: NOT imbalanced, regardless of selection. - # - # Masking is deliberately asymmetric. We have a problem when something is over - # capacity AND gets chosen by topk. We can transfer it elsewhere only if we - # are not overcapacity. 
- - imbalance_mask = prior_assignment_counts > maximum_supportable_assignments # (B, N, L) - violating_selection_mask = assignment_mask.bool() & imbalance_mask # (B, N, L) - non_overloaded_head_mask = ~imbalance_mask # (B, N, L) - has_violation_mask = violating_selection_mask.any(dim=-1) # (B, N) - - # ── Loss moment ──────────────────────────────────────────────────────── - # - # Epsilons on the count denominators guard against NaN when violation_count or - # non_overloaded_count is zero. has_violation_mask zeros positions with no - # violations at the gating step, so the epsilon-inflated denominator never - # contributes to the loss. - # - # One notable property of this moment is it keeps the amount of transferred - # logit mass constant. That is the gradient reduces violating logits and increases - # non-overloaded logits by equal magnitude. Routing is redirected, not suppressed. - - violation_count = violating_selection_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N) - non_overloaded_count = non_overloaded_head_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N) - mean_violating_logit = (violating_selection_mask.float() * logits).sum(dim=-1) / violation_count # (B, N) - mean_non_overloaded_logit = (non_overloaded_head_mask.float() * logits).sum(dim=-1) / non_overloaded_count # (B, N) - raw_loss = mean_violating_logit - mean_non_overloaded_logit # (B, N) - - # ── Loss reduction ─────────────────────────────────────────────────────────── - # - # Reduction is over active positions only; dead tokens are excluded from both - # numerator (gated by active_float) and denominator (active_count_per_seq). - # clamp(min=1.0) handles the all-dead-tokens edge case: gated_loss is zero - # there since active_float gates it, so the result is 0/1 = 0. - # - # Exact-zero guarantee: when no head exceeds its trajectory, has_violation_mask - # is all-False, gated_loss is zeroed everywhere, and the scalar return is - # exactly 0.0. The loss is inert when routing is balanced. 
- - gated_loss = active_float * has_violation_mask.float() * raw_loss # (B, N) - active_count_per_seq = active_float.sum(dim=1).clamp(min=1.0) # (B,) - sequence_loss = gated_loss.sum(dim=1) / active_count_per_seq # (B,) - final_loss = sequence_loss.mean() - return final_loss - - -def _causal_overcapacity_loss( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, - expected_tokens_rate: float, - maximum_expert_overclaim: int, -) -> torch.Tensor: - """Causal overcapacity loss for MoSRAH load balancing. - - Penalises selected expert trajectories that exceed their ideal cumulative - allocation budget. A selected expert assignment is over capacity when its - inclusive active assignment count exceeds cumulative_active_tokens * M + C, - where M is the expected_tokens_rate (K/L) and C is the - maximum_expert_overclaim slack. - - The loss consumes discrete TopK assignment structure but only routes gradients - through logits. It returns an fp32 scalar and is exactly inactive when no active - selected expert exceeds its allowed trajectory. - - Args: - logits: Pre-softmax routing scores, shape (B, N, L). - Gradient flows through this tensor. - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - 1.0 if token (b, n) is assigned to head l. - active_mask: Boolean active-token mask, shape (B, N). - expected_tokens_rate (M): Ideal per-head allocation rate K/L. Pre-computed - by the factory so the division is not repeated each - forward pass. - maximum_expert_overclaim (C): Slack above the ideal trajectory before - overcapacity fires. Larger C tolerates more deviation. - - Returns: - Scalar fp32 loss tensor. Exactly 0.0 when no active selected expert exceeds - its allowed trajectory. Can be interpreted as the difference in nats of preference - between the violating and typical paths. 
- """ - # ── Algorithm overview ────────────────────────────────────────────────────── - # - # Expert selections form causal trajectories through the sequence. Each trajectory - # is scored by the mean signed nats of the selected routing events that produced - # it: larger trajectory nats mean the router preferred that path more strongly. - # - # When a selected trajectory exceeds its cumulative budget, the loss forms a - # preference contrast between the violating trajectory field and the baseline - # trajectory field. Minimizing that contrast suppresses the over-preferred path - # while lifting alternatives through the router softmax. - # - # This is not precisely equivalent to log likelihood due to the selection - # of multiple experts per round, but we deem this issue to be insignificant. - - # ── Process setup ──────────────────────────────────────────────────────────── - # - # A small amount of standardization is needed before the loss-specific trajectory - # logic begins. Active selected assignments define the event structure. Routing - # log-probabilities remain the only differentiable source and are computed in fp32 - # so the downstream trajectory accumulation does not inherit reduced precision. - - selected_assignment_mask = assignment_mask.bool() # (B, N, L) - active_assignment_mask = selected_assignment_mask & active_mask.unsqueeze(-1) # (B, N, L) - routing_log_probability = F.log_softmax(logits.float(), dim=-1) # (B, N, L) - - # ── Mask construction ──────────────────────────────────────────────────────── - # - # The corrective target set is defined by active selected assignments whose - # inclusive count crosses the allowed causal budget. Position and sequence masks - # identify where that target set exists; they are reduction structure, not a - # separate source of gradient. 
- - inclusive_assignment_count = active_assignment_mask.to(torch.int32).cumsum(dim=1) # (B, N, L) - inclusive_active_token_count = active_mask.to(torch.int32).cumsum(dim=1) # (B, N) - - maximum_allowed_assignment_count = ( - inclusive_active_token_count.float().unsqueeze(-1) * expected_tokens_rate - + maximum_expert_overclaim - ) # (B, N, 1) → broadcasts to (B, N, L) - - violating_assignment_mask = ( # (B, N, L) - active_assignment_mask - & (inclusive_assignment_count.float() > maximum_allowed_assignment_count) - ) - has_violation_at_position = violating_assignment_mask.any(dim=-1) # (B, N) - has_violation_in_sequence = has_violation_at_position.any(dim=-1) # (B,) - - # ── Trajectory construction ────────────────────────────────────────────────── - # - # The current selection is part of the trajectory being judged, so the trajectory - # score is inclusive. Empty histories intentionally receive the neutral zero score; - # this keeps the later baseline compact without introducing a second eligibility - # system. - - selected_trajectory_nat_sum = ( # (B, N, L) - active_assignment_mask.float() * routing_log_probability - ).cumsum(dim=1) - mean_selected_trajectory_nats = ( # (B, N, L) - selected_trajectory_nat_sum - / inclusive_assignment_count.clamp(min=1).float() - ) - - # ── Contrast construction ──────────────────────────────────────────────────── - # - # This is the correction moment. The violating trajectory field is compared to - # the baseline trajectory field at the same sequence position, producing a signed - # preference contrast measured in nats. 
- - violating_assignment_count = violating_assignment_mask.float().sum(dim=-1).clamp(min=1.0) # (B, N) - mean_violating_trajectory_nats = ( # (B, N) - (violating_assignment_mask.float() * mean_selected_trajectory_nats).sum(dim=-1) - / violating_assignment_count - ) - mean_baseline_trajectory_nats = mean_selected_trajectory_nats.mean(dim=-1) # (B, N) - contrastive_preference_nats = ( # (B, N) - mean_violating_trajectory_nats - - mean_baseline_trajectory_nats - ) - - # ── Violation-only reduction ───────────────────────────────────────────────── - # - # Non-violating positions and sequences are not anchors for this loss. The scalar - # is an average violation contrast, not total violation mass, and the entire loss - # remains exactly inactive when no corrective target exists. - - violation_position_count = has_violation_at_position.float().sum(dim=-1).clamp(min=1.0) # (B,) - sequence_preference_nats = ( # (B,) - (contrastive_preference_nats * has_violation_at_position.float()).sum(dim=-1) - / violation_position_count - ) - violating_sequence_count = has_violation_in_sequence.float().sum().clamp(min=1.0) # scalar - final_loss = ( # scalar - sequence_preference_nats * has_violation_in_sequence.float() - ).sum() / violating_sequence_count - return final_loss - - -# --------------------------------------------------------------------------- -# Factory -# --------------------------------------------------------------------------- - -def _gshard_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - return gshard_loss - - -def _ce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - return ce_loss - - -def _bce_factory(**kwargs: object) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - return bce_loss - - -def _temporal_overcapacity_factory( - num_selected_heads: int, - num_total_heads: int, - maximum_expert_overclaim: int, - **kwargs: object, -) -> 
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - expected_tokens_rate = num_selected_heads / num_total_heads - def _runtime( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, - ) -> torch.Tensor: - return _temporal_overcapacity_loss( - logits, assignment_mask, active_mask, - expected_tokens_rate=expected_tokens_rate, - maximum_expert_overclaim=maximum_expert_overclaim, - ) - return _runtime - - -def _causal_overcapacity_factory( - num_selected_heads: int, - num_total_heads: int, - maximum_expert_overclaim: int, - **kwargs: object, -) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - expected_tokens_rate = num_selected_heads / num_total_heads - def _runtime( - logits: torch.Tensor, - assignment_mask: torch.Tensor, - active_mask: torch.Tensor, - ) -> torch.Tensor: - return _causal_overcapacity_loss( - logits, assignment_mask, active_mask, - expected_tokens_rate=expected_tokens_rate, - maximum_expert_overclaim=maximum_expert_overclaim, - ) - return _runtime - - -_LOSS_REGISTRY: dict[str, Callable[..., Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]]] = { - "gshard": _gshard_factory, - "ce": _ce_factory, - "bce": _bce_factory, - "temporal_overcapacity": _temporal_overcapacity_factory, - "causal_overcapacity": _causal_overcapacity_factory, -} - - -def make_load_balance_loss( - loss_type: str, - **loss_parameters: object, -) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: - """Return a load-balance loss callable for the requested formulation. - - All returned callables share the external contract: - - loss_fn( - logits: Tensor[B, N, L], - assignment_mask: Tensor[B, N, L], - active_mask: Tensor[B, N], - ) -> scalar Tensor - - Keyword arguments are forwarded to the selected factory. The gshard, ce, and bce - factories silently ignore all kwargs; this allows callers to pass loss-type-specific - parameters (e.g. 
for overcapacity losses) without branching on loss_type. - - Args: - loss_type: One of ``"gshard"``, ``"ce"``, ``"bce"``, - ``"temporal_overcapacity"``, or ``"causal_overcapacity"``. - **loss_parameters: Construction-time parameters forwarded to the factory. - - Returns: - Loss callable matching the shared contract. - - Raises: - ValueError: If loss_type is not one of the supported values. - """ - if loss_type not in _LOSS_REGISTRY: - supported = ", ".join(f'"{k}"' for k in _LOSS_REGISTRY) - raise ValueError( - f"load_balance_loss_type must be one of {supported}, got {loss_type!r}." - ) - return _LOSS_REGISTRY[loss_type](**loss_parameters) class MoSRAHRouter(nn.Module): """Token-choice router for MoSRAH sparse attention. - Each input token independently selects K of the L available expert heads. - A single routing projection maps input hidden states to per-head scores; both - task loss and load_balance_loss train this projection directly. + Each input token independently selects K of the L available expert heads + through a block-balanced causal solver. Within each block of W = L/K + consecutive tokens every expert is used exactly once, giving perfect load + balance by construction. routing_weight is nn.Parameter rather than nn.Linear so that HuggingFace _init_weights does not override its kaiming initialization at construction. Attributes: routing_weight: Shape (L, embedding_width). Maps input hidden states to - per-head routing scores. Receives gradients from both task loss and - load_balance_loss. + per-head routing scores. + block_length: Tokens per routing block W = L / K. Within each block + every expert is used exactly once. Args: config: Model configuration. Must expose ``embedding_width``, - ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), - ``load_balance_loss_type``, ``maximum_expert_overclaim``, ``max_bid_rounds``, - ``use_cache``, ``mosrah_cache_length``, and ``mosrah_packed_length``. 
+ ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and + ``block_length`` (W). """ def __init__(self, config: ShramConfig) -> None: super().__init__() - self.num_mosrah_heads = config.num_mosrah_heads + self.num_mosrah_heads = config.num_mosrah_heads self.num_selected_heads = config.num_selected_heads - if config.use_cache: - self.capacity = config.mosrah_cache_length - else: - self.capacity = config.mosrah_packed_length - - self.max_bid_rounds = config.max_bid_rounds - self._load_balance_loss = make_load_balance_loss( - config.load_balance_loss_type, - num_selected_heads=config.num_selected_heads, - num_total_heads=config.num_mosrah_heads, - maximum_expert_overclaim=config.maximum_expert_overclaim, - ) + self.block_length = config.block_length # Routing projection: maps input (B, N, d) to per-head routing scores (B, N, L). # nn.Parameter ensures HuggingFace _init_weights does not override kaiming init. @@ -3386,317 +3034,185 @@ class MoSRAHRouter(nn.Module): ) nn.init.kaiming_normal_(self.routing_weight) - @staticmethod - def get_best_proposals( - tensor: torch.Tensor, - dim: int, - n: int | torch.Tensor, - capacity_scalar: int, - ) -> torch.Tensor: - """Return a boolean mask selecting the top-n entries along dim. - - Uses topk to select exactly min(n_per_slice, dim_length) True entries - per slice along dim. Unlike a threshold comparison, this never - over-selects under tied logit values, which occurs when padding tokens - contribute identical scores to multiple expert slots. - - Args: - tensor: Input tensor. Higher values rank first. - dim: Dimension to select along. - n: Per-slice selection count. Scalar int or tensor broadcastable - to tensor with dim removed. Slices where n=0 produce all-False - outputs. - capacity_scalar: Static upper bound on n; used to derive topk k as - min(tensor.shape[dim], capacity_scalar). Must be a Python int - for compile compatibility. - - Returns: - Boolean mask of the same shape as tensor. 
- """ - positive_dim = dim % tensor.ndim - dim_length = tensor.shape[positive_dim] - k = min(dim_length, capacity_scalar) - - topk_indices = tensor.topk(k, dim=dim).indices - - # Rank tensor broadcast-compatible with topk_indices: rank r along dim - # corresponds to the (r+1)-th highest value in that slice. - rank_shape = [1] * tensor.ndim - rank_shape[positive_dim] = k - ranks = torch.arange(k, device=tensor.device, dtype=torch.long).view(rank_shape) - - # element_included: True where this rank falls within the per-slice budget. - # For scalar n all k ranks satisfy rank < n (since k = min(dim_length, n)). - # For tensor n per-slice budgets differ; rank >= n[slice] yields False, - # correctly excluding excess slots including those with n=0. - if isinstance(n, int): - element_included = ranks < n - else: - element_included = ranks < n.unsqueeze(positive_dim) - - # Allocate from explicit logical shape rather than using zeros_like. This keeps - # the output mask tied to tensor.shape, not to any stride/layout metadata carried - # by tensor from earlier view operations or compiler lowering. - mask = torch.zeros( - tuple(tensor.shape), - device=tensor.device, - dtype=torch.bool, - ) - - # Materialize the scatter source shape explicitly. This avoids passing a - # broadcast-view source into scatter while preserving the same logical rule: - # every selected top-k index receives True iff its rank is within budget. - scatter_values = torch.broadcast_to(element_included, topk_indices.shape) - mask = mask.scatter(dim, topk_indices, scatter_values) - return mask - - @staticmethod - def _check_bidding_converged(acceptances: torch.Tensor, - min_choices: int, - max_rounds: int) -> None: - """Raise if the bidding loop exhausted max_rounds without satisfying all tokens. - - Args: - acceptances: bool tensor of shape (B, N, L) indicating what experts L accepted - what tokens. 
- min_choices: Convergence has been reached if acceptances are such that a sum along - N always has at least min_choices choices. - max_rounds: The iteration ceiling that was applied, for the error message. Used - for reporting - """ - msg = ( - f"balance_capacity bidding did not converge within {max_rounds} rounds. " - f"Increase mosrah_overallocation_factor or max_bid_rounds." - ) - converged = (acceptances.sum(dim=-1) >= min_choices).all() - torch._assert_async(converged, msg) - - @classmethod - def _run_bidding( - cls, - logits: torch.Tensor, - remaining_capacity: int | torch.Tensor, - min_choices: int, - max_rounds: int, - capacity_scalar: int, - ) -> torch.Tensor: - """Deferred-acceptance (Gale-Shapley) bidding solver for joint capacity enforcement. - - Tokens propose experts in descending preference order; experts provisionally - accept their top-``remaining_capacity`` proposed tokens each round. Proposals - are monotone (never retracted), so once all tokens are satisfied, subsequent - iterations are no-ops. Runs unconditionally for exactly ``max_rounds`` iterations - to keep the compiled graph flat and free of data-dependent control flow. - - Both the column bound (per-expert token count ≤ remaining_capacity) and the - row bound (per-token expert count ≥ min_choices) are satisfied simultaneously - on the returned mask by construction. - - Args: - logits: Routing scores of shape (B, N, L). - remaining_capacity: Per-expert token budget. Scalar int for training; - (B, L) tensor for inference. - min_choices: Minimum experts each token must have accepted (K). - max_rounds: Number of iterations to run. Convergence is checked after - all rounds via ``_check_bidding_converged``; raises if not met. - capacity_scalar: Static upper bound on remaining_capacity, passed to - ``get_mask`` as the topk k bound for the acceptance step. - - Returns: - accepted: (B, N, L) bool — True at positions accepted by the solver. 
- """ - proposals = torch.zeros_like(logits, dtype=torch.bool) - acceptances = torch.zeros_like(logits, dtype=torch.bool) - - for _ in range(max_rounds): - # ── token proposal step ─────────────────────────────────────────── - # - # Tokens with fewer than min_choices accepted experts propose their - # next-best unproposed expert(s). The deficit determines how many new - # proposals each token makes; satisfied tokens propose nothing - # (deficit = 0 → get_mask returns all-False). Proposals are monotone: - # once all tokens are satisfied, subsequent iterations are no-ops. - accepted_per_token = acceptances.sum(dim=-1) # (B, N) - choices_deficit = (min_choices - accepted_per_token).clamp_min(0) - - unproposed_logits = logits.masked_fill(proposals, float('-inf')) - new_proposals = cls.get_best_proposals( - unproposed_logits, dim=-1, n=choices_deficit, capacity_scalar=min_choices, - ) - proposals = proposals | new_proposals - - # ── expert acceptance step ──────────────────────────────────────── - # - # Each expert accepts its top-remaining_capacity proposed tokens. - # Acceptances are recomputed from scratch each round so that a - # stronger new proposal can displace a weaker prior one. - proposed_logits = logits.masked_fill(~proposals, float('-inf')) - acceptances = cls.get_best_proposals( - proposed_logits, dim=-2, n=remaining_capacity, capacity_scalar=capacity_scalar, - ) - - return acceptances - - @classmethod - def balance_capacity( - cls, - logits: torch.Tensor, - used_capacity: torch.Tensor | None, - capacity: int, - min_choices: int, - max_rounds: int, - mask_value: float = -1e8, - ) -> torch.Tensor: - """Mask logits so both capacity constraints hold simultaneously on the output. - - Two constraints must hold: - - Column bound: per-expert unmasked token count ≤ remaining_capacity. - - Row bound: per-token unmasked expert count ≥ min_choices. - - A training fast path is attempted before the bidding solver: - - 1. 
Training with N ≤ capacity: return logits unchanged. - 2. Bidding: deferred-acceptance solver guaranteeing both bounds simultaneously. - - Args: - logits: Routing scores of shape (B, N, L). - used_capacity: Tokens already accumulated per expert, shape (B, L). - ``None`` during training (full capacity available). - capacity: Maximum tokens per expert (from config). - min_choices: Minimum experts each token must retain (K). - max_rounds: Bidding iteration ceiling (from config.max_bid_rounds). - mask_value: Value written to masked positions. Default -1e8. - - Returns: - Logits with unavailable positions set to ``mask_value``, shape (B, N, L). - """ - # ── Algorithm overview ──────────────────────────────────────────────── - # - # Problem: mask (B, N, L) logits so that both the column bound (each - # expert receives at most remaining_capacity tokens) and the row bound - # (each token retains at least min_choices expert choices) hold - # simultaneously. Satisfying either constraint greedily can violate the - # other, requiring a joint solver for the hard case. - # - # Approach: deferred-acceptance (Gale-Shapley) bidding. Each round, - # tokens that still lack min_choices accepted experts propose their - # next-best unproposed expert. Each expert then provisionally accepts its - # top-remaining_capacity proposed tokens, potentially displacing weaker - # prior acceptances. Proposals are monotone (never retracted). The loop - # terminates when every token has min_choices accepted experts or - # max_bid_rounds is exhausted (RuntimeError in the latter case). - # - # Training fast path — when N ≤ capacity and all experts start empty, - # no expert can overflow regardless of routing. No masking is needed. - - # Training fast path: N ≤ capacity with empty experts → no overflow possible. - if used_capacity is None and logits.shape[-2] <= capacity: - return logits - - # Compute per-expert remaining budget. 
- # Training (N > capacity path): scalar — all experts start with full capacity. - # Inference: subtract already-accumulated tokens; clamp prevents negatives - # when rounding causes used_capacity to slightly exceed capacity. - if used_capacity is None: - remaining_capacity = capacity - else: - remaining_capacity = (capacity - used_capacity).clamp(min=0) # (B, L) - - # Bidding solver: jointly satisfies column and row bounds. Runs under - # no_grad because the boolean mask is a hard routing decision and must - # not accumulate gradient memory. - with torch.no_grad(): - final_mask = cls._run_bidding(logits, remaining_capacity, - min_choices, max_rounds, capacity) - cls._check_bidding_converged(final_mask, min_choices, max_rounds) - return logits.masked_fill(~final_mask, mask_value) - def forward( self, x: torch.Tensor, active_mask: torch.Tensor, - used_capacity: torch.Tensor | None + router_cache: RouterCache | None = None, ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: """Route input tokens to K expert heads each and compute routing probabilities. Args: x: Input hidden states of shape (batch, seq_len, embedding_width). active_mask: Current-chunk active mask of shape (batch, seq_len), where - True means the token is semantically live. Dead tokens do not - contribute to routing frequencies, load_balance_loss, or max_vio. - used_capacity: Used for capacity management during inference, missing during training. + True marks a semantically live token. Dead tokens do not contribute + to regret_loss or logit_regret. Returns: selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads). - Each token's K selected head indices, determined by TopK on - capacity-balanced routing scores. + Each token's K selected head indices from the block-balanced solver. routing_probs: Routing probabilities P of shape (batch, seq_len, - num_selected_heads). 
Gathered from pre-capacity routing softmax at - selected_heads indices and renormalized to sum to 1 per token. - router_diagnostics: Dict of routing feedback scalars. Keys: - - ``load_balance_loss``: scalar load-balance loss with gradient. - - ``max_vio``: detached scalar routing-imbalance summary. - - ``logit_std``: detached mean per-token std of routing logits; - monitoring metric for routing sharpness. + num_selected_heads). Gathered from the pre-balance softmax at + selected_heads and renormalized to sum to 1 per token. + router_diagnostics: Dict of routing scalars: + - ``regret_loss``: gradient-carrying mean regret, mean of + max(p_max_active − p_chosen, 0) over live (B, num_blocks, L) + entries. In [0, 1]. Zero when every expert is assigned at its + peak-preference token within the block. + - ``logit_regret``: detached logit-space regret; same formula + applied to routing logits rather than softmax probabilities. + In [0, ∞). Monitoring only. + - ``logit_std``: detached mean per-token std of routing logits. """ + # ── Algorithm overview ────────────────────────────────────────────────────── + # + # Problem: each token independently selects its top-K heads with no knowledge + # of what other tokens in the same sequence will choose. Independent selection + # means a single popular head can be chosen by every token while another is + # never used — statistics-based corrections (auxiliary losses, bias vectors) + # can only push routing probabilistically and have proven unstable when tuned + # strongly enough to prevent degeneracy. + # + # Approach: the compatibility constraint E % K == 0 (enforced in ShramConfig) + # makes W = E / K an exact integer. A block of W consecutive tokens contains + # exactly W × K = E selection slots — one per expert. Enforcing that each + # expert is used exactly once per block makes the block perfectly balanced by + # construction, eliminating any need for auxiliary losses or correction steps. 
+ # Enforcement is causal: at each of the W steps the current position picks its + # K experts from those not yet claimed earlier in the same block, by masking + # claimed experts with -inf before top-K. Each step is evaluated for every + # block and batch item at once; the W steps themselves execute in order via + # a Python for loop that is fully unrolled at compile time. + B, N, _ = x.shape L = self.num_mosrah_heads K = self.num_selected_heads + W = self.block_length - # ── Phase: pre-capacity scoring ─────────────────────────────────────── + # ── Phase: pre-balance scoring ───────────────────────────────────────── # - # Establishes the clean pre-sentinel distribution that all downstream - # consumers draw from. logit_std must be captured here — balance_capacity - # injects -1e8 sentinels that would corrupt the standard deviation. - # routing_scores is the pre-capacity probability distribution; both the - # load balance signal and the final routing_probs gather from it. + # Establishes the clean routing distribution before any -inf masking. + # logit_std is captured here because the block solver's masking would + # corrupt the standard deviation. routing_scores is used both for + # regret_loss and for the final routing_probs. routing_logits = self._compute_routing_logits(x) # (B, N, L) logit_std = routing_logits.std(dim=-1).mean().detach() routing_scores = F.softmax(routing_logits, dim=-1) # (B, N, L) - # ── Phase: load balance signal ──────────────────────────────────────── + # ── Phase: block-balanced causal selection ───────────────────────────── # - # The loss must observe the unconstrained routing decision — the genuine - # routing pressure before capacity enforcement masks any imbalance. - # pre_cap_heads and assignment_mask exist solely to give the loss this - # honest view; nothing downstream uses them.
- pre_cap_heads = routing_scores.topk(K, dim=-1).indices # (B, N, K) - assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype) - assignment_mask.scatter_(-1, pre_cap_heads, 1.0) - - load_balance_loss = self._load_balance_loss( - routing_logits, assignment_mask, active_mask - ) - - # ── Phase: capacity enforcement and final selection ─────────────────── + # Three execution modes, distinguished by router_cache and sequence length: # - # Produces the capacity-enforced routing that all downstream consumers - # depend on. max_vio is computed here because it measures realized routing - # imbalance — the actual post-capacity assignment, not the unconstrained - # preference. routing_probs are gathered from the pre-capacity routing_scores - # (not the balanced distribution) to avoid sentinel corruption — overloaded - # experts would otherwise receive near-zero probability regardless of genuine - # routing preference. - balanced_logits = self.balance_capacity( - routing_logits, - used_capacity, - self.capacity, - self.num_selected_heads, - self.max_bid_rounds, - ) - selected_heads = F.softmax(balanced_logits, dim=-1).topk(K, dim=-1).indices # (B, N, K) + # Training (router_cache is None): the full sequence is available. All W + # steps of the block solver run simultaneously across every block in the + # sequence. No cache interaction. + # + # Prefill (router_cache is not None, N > 1): identical to training, but + # the partial last-block state is written to the cache so decode steps can + # continue within the same block without a gap. + # + # Decode (router_cache is not None, N == 1): one token arrives at a known + # position within the current block. The cached used_in_block mask is + # applied before TopK to enforce the one-usage-per-block contract, then + # the cache is updated in-place with this step's selections. 
- realized_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype) - realized_mask.scatter_(-1, selected_heads, 1.0) - max_vio = self._compute_max_vio(realized_mask, active_mask, L) + if router_cache is not None and N == 1: + # ── Decode mode ─────────────────────────────────────────────────── + # + # Single token; block position and claimed-expert state come from the + # cache. No blocked assignment tensor is built on this path: the regret + # phase below short-circuits to zeros whenever router_cache is set and + # N == 1, since a single decode step is not a complete W-token block + # and regret is only defined over complete blocks. + used_in_block = router_cache.get_used_in_block() # (B, L) + step_logits = routing_logits[:, 0, :] # (B, L) + available = step_logits.masked_fill(used_in_block, float('-inf')) + step_heads = available.topk(K, dim=-1).indices # (B, K) + + router_cache.update_decode(step_heads) + + selected_heads = step_heads.unsqueeze(1) # (B, 1, K) + else: + # ── Training / prefill mode ─────────────────────────────────────── + # + # The full N-token sequence is available. Padding extends it to a + # multiple of W; padded tokens occupy the tail of the last block and + # never consume experts needed by real tokens because the real tokens + # preceding them have already had their pick each step. The pad is + # discarded after the solver. + num_blocks = (N + W - 1) // W + N_pad = num_blocks * W + pad_len = N_pad - N + + if pad_len > 0: + padded_logits = torch.cat( + [routing_logits, routing_logits.new_zeros(B, pad_len, L)], dim=1 + ) # (B, N_pad, L) + else: + padded_logits = routing_logits + + blocked_logits = padded_logits.view(B, num_blocks, W, L) # (B, blk, W, L) + + # used_in_block tracks which experts have been claimed within each block. + # No gradient here — expert availability is a hard structural constraint, + # not a differentiable quantity. Gradient flows through routing_probs.
+ used_in_block = torch.zeros(B, num_blocks, L, dtype=torch.bool, device=x.device) + step_heads_list = [] + + for step in range(W): + step_logits = blocked_logits[:, :, step, :] # (B, blk, L) + + # Claimed experts receive -inf so top-K never selects them. + available = step_logits.masked_fill(used_in_block, float('-inf')) + step_heads = available.topk(K, dim=-1).indices # (B, blk, K) + step_heads_list.append(step_heads) + + # Mark the K chosen experts as unavailable for the rest of this block. + used_in_block = used_in_block.scatter(-1, step_heads, True) + + # Stack W steps and reshape to (B, N_pad, K), then unpad. + selected_heads_blocked = torch.stack(step_heads_list, dim=2) # (B, blk, W, K) + selected_heads = selected_heads_blocked.view(B, N_pad, K)[:, :N, :] # (B, N, K) + + if router_cache is not None: + # Prefill: persist the partial last-block state so decode steps + # that follow can continue within the same block. + router_cache.update_prefill(selected_heads_blocked, N) + + # ── Phase: regret loss ───────────────────────────────────────────────── + # + # Regret measures how much routing preference was sacrificed at each expert + # assignment relative to the peak active preference within the same block. + # A non-zero regret at expert l in block bl means some other active token + # in that block would have preferred expert l more than the one assigned. + # Minimising regret trains the router to save experts for the tokens that + # want them most. + # + # Decode mode returns zeros: regret is only defined over complete W-token + # blocks, and a single decode step is not a complete block. Backward is + # never called during inference so the zero is a correct no-op. 
+ if router_cache is not None and N == 1: + regret_loss = routing_logits.new_zeros(()) + logit_regret = routing_logits.new_zeros(()).detach() + else: + regret_loss, logit_regret = self._compute_regret( + routing_scores, + routing_logits, + selected_heads_blocked, + active_mask, + ) + # ── Phase: routing probabilities ──────────────────────────────────────── + # + # Gathered from the pre-balance routing_scores to reflect genuine routing + # preference; renormalized so they sum to 1 per token. gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K) - routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K) + routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # (B, N, K) router_diagnostics = { - "load_balance_loss": load_balance_loss, - "max_vio": max_vio, - "logit_std": logit_std, + "regret_loss": regret_loss, + "logit_regret": logit_regret, + "logit_std": logit_std, } return selected_heads, routing_probs, router_diagnostics @@ -3712,34 +3228,120 @@ class MoSRAHRouter(nn.Module): return F.linear(x, self.routing_weight) # (B, N, L) @staticmethod - def _compute_max_vio( - assignment_mask: torch.Tensor, + def _compute_regret( + routing_scores: torch.Tensor, + routing_logits: torch.Tensor, + selected_heads_blocked: torch.Tensor, active_mask: torch.Tensor, - num_heads: int, - ) -> torch.Tensor: - """Compute the MaxVio routing-imbalance scalar. + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute regret_loss and logit_regret from a completed block assignment. + + Regret at expert l in block bl = max(p_max_active − p_chosen, 0), where + p_max_active is the highest routing probability any active token holds for + expert l within the block, and p_chosen is the routing probability of the + token actually assigned to expert l (0 if that token is dead). - MaxVio = mean_b( L · max_l(f_bl − 1/L) ), where f_bl is the per-batch-item - realised routing frequency of head l. 
Uses reduce_frequency_tokens for consistent - per-batch-item frequency computation with dead tokens excluded, matching how the - load balance loss computes frequencies. A value of zero indicates perfect balance; - a value of 0.5 means the most overloaded head in the average batch item received - 50% more routed tokens than ideal. + regret_loss is the mean over live (batch, block, expert) triples. A block is + live iff it contains at least one active token; all L experts in a live block + contribute. Result is in [0, 1]. - The result is detached — MaxVio is a monitoring scalar and must not contribute - gradients to any parameter. + logit_regret applies the same formula to routing_logits and is returned + detached — it is a monitoring scalar only, in [0, ∞). Args: - assignment_mask: Per-token head-assignment indicators, shape (B, N, L). - active_mask: Boolean active-token mask, shape (B, N). - num_heads: Total number of MoSRAH heads L. + routing_scores: Softmax routing probabilities, shape (B, N, L). + Gradient flows through this tensor into regret_loss. + routing_logits: Pre-softmax routing logits, shape (B, N, L). + Used only for the detached logit_regret. + selected_heads_blocked: Expert assignments from the block solver, + shape (B, num_blocks, W, K). Block geometry + (num_blocks, W) is derived from this shape. + active_mask: Boolean live-token mask, shape (B, N). Returns: - Detached scalar MaxVio tensor. + regret_loss: Gradient-carrying scalar in [0, 1]. + logit_regret: Detached scalar in [0, ∞). 
""" - f_bl = reduce_frequency_tokens(assignment_mask, active_mask) # (B, L) - per_item_max_vio = num_heads * (f_bl - 1.0 / num_heads).max(dim=-1).values # (B,) - return per_item_max_vio.mean().detach() + B, num_blocks, W, _K = selected_heads_blocked.shape + L = routing_scores.shape[-1] + N = routing_scores.shape[1] + N_pad = num_blocks * W + + # ── Reshape into block form ───────────────────────────────────────── + # + # Block geometry is read from selected_heads_blocked — no recomputation + # needed here. Padded tail positions receive zero scores and False + # activity; they do not contribute to any block metric. + if N_pad > N: + pad_len = N_pad - N + scores_blocked = torch.cat( + [routing_scores, routing_scores.new_zeros(B, pad_len, L)], dim=1 + ).view(B, num_blocks, W, L) # (B, nb, W, L) + logits_blocked = torch.cat( + [routing_logits, routing_logits.new_zeros(B, pad_len, L)], dim=1 + ).view(B, num_blocks, W, L) # (B, nb, W, L) + active_blocked = torch.cat( + [active_mask, active_mask.new_zeros(B, pad_len)], dim=1 + ).view(B, num_blocks, W) # (B, nb, W) + else: + scores_blocked = routing_scores.view(B, num_blocks, W, L) + logits_blocked = routing_logits.view(B, num_blocks, W, L) + active_blocked = active_mask.view(B, num_blocks, W) + + active_float = active_blocked.float() # (B, nb, W) + block_active = active_blocked.any(dim=-1) # (B, nb) + + # ── Assignment mask ───────────────────────────────────────────────── + # + # One-hot indicator of which token was assigned to each expert. Block + # balance guarantees exactly one entry per (b, bl, l) triple, so + # summing over W recovers exactly one score value per expert. 
+ assignment_mask = scores_blocked.new_zeros(B, num_blocks, W, L) + assignment_mask.scatter_(dim=-1, index=selected_heads_blocked, value=1.0) + # (B, nb, W, L) + + # ── Prob regret (gradient flows through routing_scores) ───────────── + # + # p_chosen: routing score at the assigned token, gated by active_float + # so dead assignments contribute 0 — the expert accrues full regret + # against the active maximum rather than no penalty. + # p_max: peak routing score over active tokens; dead tokens zeroed before + # max (safe because softmax outputs are non-negative). + p_chosen = (assignment_mask * active_float.unsqueeze(-1) * scores_blocked).sum(dim=2) + # (B, nb, L) + p_max = (active_float.unsqueeze(-1) * scores_blocked).max(dim=2).values + # (B, nb, L) + + regret = (p_max - p_chosen).clamp(min=0.0) # (B, nb, L) + + # Mean over live (B, num_blocks, L) entries. Clamped to 1 for the + # all-dead edge case where the numerator is already 0. + num_live = block_active.float().sum() # scalar + regret_loss = ( + block_active.float().unsqueeze(-1) * regret + ).sum() / num_live.mul(L).clamp(min=1.0) + + # ── Logit regret (detached monitoring) ────────────────────────────── + # + # Same formula applied to routing_logits. Dead tokens cannot be zeroed + # before max (logits may be negative), so they are masked to -inf; + # dead blocks are replaced with 0 before subtraction. Detached so it + # never influences any parameter during backward. 
+ logit_chosen = ( + assignment_mask * active_float.unsqueeze(-1) * logits_blocked + ).sum(dim=2) # (B, nb, L) + + logit_max = logits_blocked.masked_fill( + ~active_blocked.unsqueeze(-1), float('-inf') + ).max(dim=2).values # (B, nb, L) + logit_max = logit_max.masked_fill(~block_active.unsqueeze(-1), 0.0) + + logit_regret = ( + block_active.float().unsqueeze(-1) * (logit_max - logit_chosen).clamp(min=0.0) + ).sum() / num_live.mul(L).clamp(min=1.0) + logit_regret = logit_regret.detach() + + return regret_loss, logit_regret # ----------- # Inlined from: positions_converter.py @@ -3889,6 +3491,7 @@ class MoSRAHLayer(nn.Module): position_ids: torch.Tensor, active_mask: torch.Tensor, cache: MoSRAHCache | None, + router_cache: RouterCache | None = None, ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """Run the full MoSRAH sparse path. @@ -3906,9 +3509,8 @@ class MoSRAHLayer(nn.Module): Returns: sparse_output: Model-space sparse-path output of shape (B, N, d). router_diagnostics: Dict of router feedback scalars. Keys: - ``load_balance_loss`` (has grad), ``max_vio``, ``bias_std``, - ``raw_logit_std``, ``logit_std``, ``bias_alignment`` (all - detached). See MoSRAHRouter for semantics. + ``regret_loss`` (has grad), ``logit_regret`` (detached), + ``logit_std`` (detached). See MoSRAHRouter for semantics. """ # ------------------------------------------------------------------- @@ -3922,9 +3524,8 @@ class MoSRAHLayer(nn.Module): # B*N*K True entries) and the packed active mask (live slots only); # active_mask is rebound to the packed form after this point. 
# ------------------------------------------------------------------- - used_capacity = cache.get_heads_lengths() if cache is not None else None selected_heads, routing_probs, router_diagnostics = self.router( - hidden_states, active_mask, used_capacity + hidden_states, active_mask, router_cache ) setup = setup_packing(selected_heads) @@ -4032,9 +3633,11 @@ class SHRAMHybridLayer(nn.Module): if cache is None: sliding_window_cache = None mosrah_cache = None + router_cache = None else: sliding_window_cache = cache.sliding_window_cache mosrah_cache = cache.mosrah_cache + router_cache = cache.router_cache # ------------------------------------------------------------------- # Both attention paths must see the same model-space hidden state for @@ -4053,6 +3656,7 @@ class SHRAMHybridLayer(nn.Module): position_ids=position_ids, active_mask=active_mask, cache=mosrah_cache, + router_cache=router_cache, ) # ------------------------------------------------------------------- @@ -4242,20 +3846,19 @@ class ShramModel(nn.Module): inputs_embeds as position 0) if ``output_hidden_states`` is True, else None. Collected before the final norm so each entry reflects the unnormalised residual stream at that depth. - - ``"load_balance_loss"``: scalar sum of per-layer SHRAM - load-balance losses. - - ``"max_vio"``: detached scalar maximum routing-imbalance across - all decoder layers. Zero means perfectly balanced routing across - every layer; higher values identify the worst-case head imbalance. + - ``"regret_loss"``: scalar sum of per-layer SHRAM regret losses. + Gradient flows through this tensor into the router. + - ``"logit_regret"``: detached scalar — mean across layers of the + logit-space regret. Monitoring metric for assignment quality. - ``"logit_std"``: detached scalar — mean across layers of the per-token routing logit spread. Monitoring metric for routing sharpness. 
""" hidden_states = inputs_embeds all_hidden_states = (hidden_states,) if output_hidden_states else None - total_load_balance_loss = inputs_embeds.new_zeros(()) - max_vio = inputs_embeds.new_zeros(()) - total_logit_std = inputs_embeds.new_zeros(()) + total_regret_loss = inputs_embeds.new_zeros(()) + total_logit_regret = inputs_embeds.new_zeros(()) + total_logit_std = inputs_embeds.new_zeros(()) for layer_idx, layer in enumerate(self.layers): layer_cache = None if cache is None else cache.layers[layer_idx] @@ -4265,9 +3868,9 @@ class ShramModel(nn.Module): active_mask, cache=layer_cache, ) - total_load_balance_loss = total_load_balance_loss + layer_diagnostics["load_balance_loss"] - max_vio = torch.maximum(max_vio, layer_diagnostics["max_vio"]) - total_logit_std = total_logit_std + layer_diagnostics["logit_std"] + total_regret_loss = total_regret_loss + layer_diagnostics["regret_loss"] + total_logit_regret = total_logit_regret + layer_diagnostics["logit_regret"] + total_logit_std = total_logit_std + layer_diagnostics["logit_std"] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -4279,9 +3882,9 @@ class ShramModel(nn.Module): "last_hidden_state": hidden_states, "past_key_values": cache, "hidden_states": all_hidden_states, - "load_balance_loss": total_load_balance_loss, - "max_vio": max_vio, - "logit_std": total_logit_std / num_layers, + "regret_loss": total_regret_loss, + "logit_regret": total_logit_regret / num_layers, + "logit_std": total_logit_std / num_layers, } @@ -4298,13 +3901,13 @@ class ShramCausalLMOutput(CausalLMOutputWithPast): ## Python dataclass inheritance violation: CausalLMOutputWithPast defaults all ## fields to None, which forces every subclass field to also carry a default. ## The = None below is a language constraint, not a semantic statement. 
In - ## practice, load_balance_loss, max_vio, and logit_std are always populated + ## practice, regret_loss, logit_regret, and logit_std are always populated ## by ShramForCausalLM.forward(). ce_loss is genuinely optional — present ## only when labels are supplied. ce_loss: torch.FloatTensor | None = None - load_balance_loss: torch.FloatTensor | None = None - max_vio: torch.FloatTensor | None = None + regret_loss: torch.FloatTensor | None = None + logit_regret: torch.Tensor | None = None logit_std: torch.Tensor | None = None class ShramForCausalLM(PreTrainedModel, GenerationMixin): @@ -4739,21 +4342,21 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin): against ``labels[:, 1:]``. Do not pre-shift the caller side. return_dict: Must be ``True`` or ``None``. ce_weight: Weight applied to the cross-entropy loss when combining with - the load-balance loss. Default 1.0. - load_balance_weight: Weight applied to the load-balance auxiliary loss. + the regret loss. Default 1.0. + load_balance_weight: Weight applied to the regret loss. Default 0.01, matching the paper's recommendation. **kwargs: Unsupported HuggingFace kwargs fail explicitly. Returns: ``ShramCausalLMOutput`` with: - ``logits`` of shape ``(batch, seq_len, vocab_size)``, - - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss`` + - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * regret_loss`` when labels are provided (``None`` otherwise), - ``ce_loss`` — raw unweighted cross-entropy loss for logging, - ``past_key_values`` as the active ``ShramCache`` or ``None``, - ``hidden_states`` when requested, - - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone, - - ``max_vio`` — detached worst-case routing imbalance across layers, + - ``regret_loss`` — raw unweighted regret loss from the backbone, + - ``logit_regret`` — detached mean logit-space regret across layers, - ``logit_std`` — detached mean per-token routing logit spread across layers. 
""" use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -4851,7 +4454,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin): shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ) - loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"] + loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["regret_loss"] return ShramCausalLMOutput( loss=loss, @@ -4859,7 +4462,7 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin): logits=logits, past_key_values=backbone_outputs["past_key_values"], hidden_states=backbone_outputs["hidden_states"], - load_balance_loss=backbone_outputs["load_balance_loss"], - max_vio=backbone_outputs["max_vio"], + regret_loss=backbone_outputs["regret_loss"], + logit_regret=backbone_outputs["logit_regret"], logit_std=backbone_outputs["logit_std"], ) \ No newline at end of file