Tags: Text Generation, Transformers, PyTorch, English, shram, research, sparse-attention, mixture-of-experts, custom_code
How to use smithblack-0/SHRAM-dev with libraries, notebooks, and local apps.
- Libraries
- Transformers
How to use smithblack-0/SHRAM-dev with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM-dev", trust_remote_code=True)

# Load model directly (the repo's auto_map registers AutoModelForCausalLM, not AutoModel)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("smithblack-0/SHRAM-dev", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use smithblack-0/SHRAM-dev with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM-dev"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/smithblack-0/SHRAM-dev
- SGLang
How to use smithblack-0/SHRAM-dev with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "smithblack-0/SHRAM-dev" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "smithblack-0/SHRAM-dev" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM-dev",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use smithblack-0/SHRAM-dev with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM-dev
# This file is auto-generated by stage_for_hub.py from the source repository.
# Do not edit it directly — changes will be overwritten on the next release.
"""HuggingFace causal-LM wrapper for SHRAM.

ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
It owns token embedding lookup, LM-head projection, wrapper-level next-token
cross-entropy loss, config-controlled tied embeddings, and generation/cache
orchestration at the wrapper boundary.

The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
hidden states together with current position IDs, a current active mask, and an
optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
or causal-LM loss.

HuggingFace generation reaches this wrapper with two different tensor
conventions:

- ``position_ids`` is a current-step tensor. GenerationMixin updates the total
  sequence state between steps, then slices position-bearing tensors back down
  before calling ``forward()``.
- ``attention_mask`` is a full 2D mask over the total sequence so far. This
  wrapper slices its recent chunk to produce the current semantic liveness mask
  expected by the backbone.

Generation-created caches are handled in ``_prepare_cache_for_generation``.
That hook ensures HuggingFace generation uses ShramCache rather than a generic
dynamic cache. The direct ``forward()`` path does not silently create caches;
when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
"""
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
from transformers import GenerationMixin
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.generation.configuration_utils import GenerationMode
from transformers.modeling_outputs import CausalLMOutputWithPast
import math
from transformers import PretrainedConfig
from transformers.cache_utils import CacheLayerMixin
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import flex_attention
import torch.nn.functional as F
# -----------
# Inlined from: shram_cache.py
# -----------
"""SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.

The HuggingFace Cache protocol expects a single top-level Cache object that owns one
CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
presents them through the Cache interface, and transparently forwards model-wide operations
across all of them.

ShramCache does not define a composite update() interface. The two attention paths inside each
SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
coordination of the layer caches — not routing attention inputs.

Sequence length is reported by delegating to the local sliding-window sub-cache of the
specified layer, which tracks the cumulative count of token positions processed. This is
what HuggingFace generation reads through get_seq_length().
"""
# -----------
# Inlined from: configuration.py
# -----------
"""Configuration for the SHRAM transformer.

All architectural parameters that vary across model scales or are meaningful research
variables are expressed here. Architectural constants (no bias in linear layers,
SwiGLU activation with SiLU gate) are implemented in the relevant modules and
documented at the point of use — they are not config parameters because they do not
vary and changing them produces a different architecture, not a different scale.

RoPE configuration is owned entirely by this config. Each attention path reads its
parameters directly and constructs its own RotaryEmbedding instance explicitly — no
HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
"""
class ShramConfig(PretrainedConfig):
    """Configuration class for the SHRAM decoder-only transformer.

    SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
    attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
    local sliding-window causal attention path and h_s is the MoSRAH sparse routed
    path. All other components follow the Llama 3 baseline.

    This config is the single source of truth for every architectural dimension of the
    model. Nothing in the architecture may use a literal number that belongs here.

    Two independent RoPE configurations exist — one per attention path:

    - h_l always uses standard RoPE with ``local_rope_theta``.
    - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
      ``inference_sequence_length``, ``alpha``, and ``beta``. When
      ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
      ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
      and the correct setting for experiments that do not require context extension.

    Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::

        config = AutoConfig.from_pretrained(
            "your-namespace/advanced-transformers-lib",
            trust_remote_code=True,
            num_decoder_layers=12,
        )
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

    Args:
        vocab_size: Vocabulary size. Controls the embedding table and output logits
            dimension. Must match the tokenizer.
        embedding_width: Model width ``d``. The dimension of the residual stream.
        mlp_width: FFN hidden dimension.
        num_decoder_layers: Number of transformer blocks stacked in sequence.
        num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
        num_mosrah_heads: Total MoSRAH expert heads available ``L``.
        num_selected_heads: MoSRAH heads each token selects ``K``.
        head_dim: Per-head dimension, shared by both attention paths. Must be even
            (RoPE rotates dimensions in pairs). Paper uses 16.
        window_size: Sliding window size for h_l. Paper uses 128.
        rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
            original sequence positions; ``"semantic_sequence"`` supplies local slot
            indices. Both modes are supported because the experimentally correct one
            is undetermined (paper §4). Default ``"main_sequence"``.
        rms_norm_eps: Epsilon for RMSNorm layers.
        local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
            Paper uses b=10000.
        mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
            b=10000.
        training_sequence_length: Context length ``C_train`` the model was or will be
            trained at. Used to compute the YaRN scale factor for BEA.
        inference_sequence_length: Context length ``C_target`` the model must support
            at inference. Optional; defaults to ``training_sequence_length`` so that
            ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
        alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
            ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
        beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
            ``r(d) > beta`` are left unscaled. Paper value: 32.0.
        attention_dropout: Dropout probability on attention weights. Default 0.0.
        use_cache: Whether to return past_key_values for KV caching.
        output_hidden_states: Whether to return hidden states after each layer.
        tie_word_embeddings: Whether input embedding and LM head share weights.
        use_residual_gate: When True, each DecoderLayer gates its residual contributions
            with a learnable scalar parameter (init: zero). When False, uses a fixed
            ``1/√num_decoder_layers`` scale instead, which preserves O(1) residual
            variance at depth with no learnable gate. Default True.
    """

    model_type = "shram"
    auto_map = {
        "AutoConfig": "configuration.ShramConfig",
        "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
    }

    def __init__(
        self,
        vocab_size: int = 50277,
        embedding_width: int = 512,
        mlp_width: int = 1366,
        num_decoder_layers: int = 12,
        num_sliding_window_heads: int = 16,
        num_mosrah_heads: int = 16,
        num_selected_heads: int = 16,
        head_dim: int = 16,
        window_size: int = 128,
        rope_mode: str = "main_sequence",
        rms_norm_eps: float = 1e-5,
        local_rope_theta: float = 10000.0,
        mosrah_rope_theta: float = 10000.0,
        training_sequence_length: int = 1024,
        inference_sequence_length: int | None = None,
        alpha: float = 1.0,
        beta: float = 32.0,
        attention_dropout: float = 0.0,
        use_cache: bool = True,
        output_hidden_states: bool = False,
        tie_word_embeddings: bool = False,
        use_residual_gate: bool = True,
        **kwargs
    ):
        if head_dim % 2 != 0:
            raise ValueError(
                f"head_dim must be even (RoPE rotates dimensions in pairs). "
                f"Got head_dim={head_dim}."
            )
        if rope_mode not in {"main_sequence", "semantic_sequence"}:
            raise ValueError(
                f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
                f"got '{rope_mode}'."
            )
        if training_sequence_length <= 0:
            raise ValueError(
                f"training_sequence_length must be positive, "
                f"got {training_sequence_length}."
            )
        if inference_sequence_length is None:
            inference_sequence_length = training_sequence_length
        if inference_sequence_length <= 0:
            raise ValueError(
                f"inference_sequence_length must be positive, "
                f"got {inference_sequence_length}."
            )
        if num_mosrah_heads % num_selected_heads != 0:
            raise ValueError(
                f"num_mosrah_heads must be exactly divisible by num_selected_heads. "
                f"Mechanical load balancing partitions the sequence into blocks of "
                f"W = num_mosrah_heads // num_selected_heads tokens; each block covers "
                f"every expert exactly once, which requires an integer W. "
                f"Got num_mosrah_heads={num_mosrah_heads}, num_selected_heads={num_selected_heads}."
            )

        self.vocab_size = vocab_size
        self.embedding_width = embedding_width
        self.mlp_width = mlp_width
        self.num_decoder_layers = num_decoder_layers
        self.num_sliding_window_heads = num_sliding_window_heads
        self.num_mosrah_heads = num_mosrah_heads
        self.num_selected_heads = num_selected_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.rope_mode = rope_mode
        self.rms_norm_eps = rms_norm_eps
        self.local_rope_theta = local_rope_theta
        self.mosrah_rope_theta = mosrah_rope_theta
        self.training_sequence_length = training_sequence_length
        self.inference_sequence_length = inference_sequence_length
        self.alpha = alpha
        self.beta = beta
        self.attention_dropout = attention_dropout
        self.use_cache = use_cache
        self.use_residual_gate = use_residual_gate

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            output_hidden_states=output_hidden_states,
            **kwargs
        )
        # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
        # serialises it into config.json.
        self.auto_map = type(self).auto_map

    @property
    def scale(self) -> float:
        """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.

        When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
        adjustments cancel and A_rope = 1. This is the default state.
        """
        return self.inference_sequence_length / self.training_sequence_length
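Here is a sketch of how scale, alpha, and beta can enter the frequency computation. It follows the formulas of the YaRN paper:

- r(d) is the number of full rotations dimension d completes within C_train.
- A linear ramp between alpha and beta blends the interpolated frequency theta_d / s with the unscaled theta_d.
- Attention logits are scaled by 0.1 ln(s) + 1.

Whether SHRAM's BEA rotary module uses exactly this ramp and attention factor is an assumption. The function name yarn_inverse_frequencies is hypothetical.

```python
import math

import torch


def yarn_inverse_frequencies(head_dim: int, theta: float, training_sequence_length: int,
                             inference_sequence_length: int, alpha: float = 1.0,
                             beta: float = 32.0) -> tuple[torch.Tensor, float]:
    """YaRN-adjusted RoPE frequencies and attention scale (formulas from the YaRN paper)."""
    s = inference_sequence_length / training_sequence_length  # same ratio as ShramConfig.scale
    base = theta ** (-torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim)
    wavelength = 2 * math.pi / base
    r = training_sequence_length / wavelength  # rotations completed within C_train
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)  # 0 below alpha, 1 above beta
    inv_freq = (1 - gamma) * base / s + gamma * base  # interpolate low-frequency dims only
    attention_scale = 0.1 * math.log(s) + 1.0  # equals 1 exactly when s == 1
    return inv_freq, attention_scale


# Default state (s = 1): frequencies are untouched and the attention scale is 1.
freqs_default, attn_default = yarn_inverse_frequencies(16, 10000.0, 1024, 1024)
# 4x context extension: the lowest-frequency dimension is divided by s = 4.
freqs_ext, attn_ext = yarn_inverse_frequencies(16, 10000.0, 1024, 4096)
```

With head_dim=16 and b=10000, the highest-frequency pair completes about 163 rotations within 1024 tokens, which is above beta, so it is left unscaled. The lowest completes about 0.05, which is below alpha, so it is fully interpolated.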

    @property
    def mosrah_packed_length(self) -> int:
        """Static packed time dimension T for expert packing.

        Mechanical load balancing guarantees exactly
        ``training_sequence_length * num_selected_heads / num_mosrah_heads``
        tokens per expert. The ceiling handles non-integer results when
        training_sequence_length is not divisible by the block length W, and one
        extra block length W is added on top of that ceiling.

        All consumers of the packed buffer size must read this property rather
        than deriving T independently.
        """
        return math.ceil(
            self.training_sequence_length
            * self.num_selected_heads
            / self.num_mosrah_heads
        ) + self.block_length

    @property
    def mosrah_cache_length(self) -> int:
        """Static per-(batch, head) slot capacity for the MoSRAH inference cache.

        Mechanical load balancing guarantees exactly
        ``inference_sequence_length * num_selected_heads / num_mosrah_heads``
        tokens per expert over the full inference context. The ceiling handles
        non-integer results when inference_sequence_length is not divisible by
        the block length W, and one extra block length W is added on top of that
        ceiling.

        Distinct from ``mosrah_packed_length``, which sizes the training packing
        buffer using ``training_sequence_length``. This property uses
        ``inference_sequence_length`` because the cache must hold the full
        accumulated token history across the entire inference run.

        All consumers of the MoSRAH cache buffer size must read this property
        rather than deriving the capacity independently.
        """
        return math.ceil(
            self.inference_sequence_length
            * self.num_selected_heads
            / self.num_mosrah_heads
        ) + self.block_length

    @property
    def block_length(self) -> int:
        """Routing block length W = num_mosrah_heads // num_selected_heads.

        Within each block of W consecutive tokens every expert is used exactly once,
        giving perfect load balance by construction. The E % K == 0 constraint
        enforced at construction guarantees W is an exact integer.

        All consumers of the routing block length must read this property rather
        than deriving W independently.
        """
        return self.num_mosrah_heads // self.num_selected_heads
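The capacity arithmetic of mosrah_packed_length and mosrah_cache_length is easy to check by hand. The standalone functions below mirror the expressions above. They are hypothetical helpers, not part of ShramConfig. The same formula yields the packed length when given training_sequence_length and the cache length when given inference_sequence_length.

```python
import math


def block_length(num_mosrah_heads: int, num_selected_heads: int) -> int:
    """W = L // K, as in ShramConfig.block_length."""
    return num_mosrah_heads // num_selected_heads


def mosrah_capacity(sequence_length: int, num_mosrah_heads: int, num_selected_heads: int) -> int:
    """ceil(sequence_length * K / L) + W, matching the two properties above."""
    per_expert = math.ceil(sequence_length * num_selected_heads / num_mosrah_heads)
    return per_expert + block_length(num_mosrah_heads, num_selected_heads)


default_packed = mosrah_capacity(1024, 16, 16)   # defaults: W = 1 -> 1024 + 1 = 1025
wider_packed = mosrah_capacity(1024, 32, 8)      # W = 4 -> 256 + 4 = 260
ragged_packed = mosrah_capacity(1000, 24, 8)     # W = 3 -> ceil(333.33...) + 3 = 337
extended_cache = mosrah_capacity(4096, 16, 16)   # inference length 4096 with defaults -> 4097
```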
# -----------
# Inlined from: shram_layer_cache.py
# -----------
"""SHRAM per-layer cache — composite owner for one SHRAM decoder layer.

A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
path can interact with it without indirection.

ShramLayerCache does not define a composite update() interface. The two paths have materially
different update semantics — the local side uses chunk-local key/value/mask concatenation
while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
behind a single update() would hide those differences behind a misleading abstraction. Instead,
each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
ownership, coordination, and reset/reorder boundary for one decoder layer.

Sequence length at this boundary is reported by delegating to the local sliding-window
sub-cache, which tracks the cumulative count of token positions processed. This is the
quantity HuggingFace generation reads through get_seq_length().
"""
# -----------
# Inlined from: mosrah_cache.py
# -----------
"""MoSRAH sparse KV cache — single-layer implementation.

MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
by head rather than by sequence position. The routing is dynamic and produces a ragged
distribution of token counts across (batch, head) slots — different batch items may
route different numbers of tokens to the same head, and different heads accumulate at
different rates. DynamicCache cannot represent this correctly: it concatenates along
the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
therefore uses a custom buffer design.

Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
buffer_capacity property and is derived directly from self.keys rather than tracked
as a separate variable.

The primary interface is update(key_states, value_states, active_mask), which accepts
expert-choice layout, stores only active entries in causal order, and returns the full
accumulated (keys, values, active_mask) for immediate use by BEA. The returned
active_mask identifies valid cached positions; everything beyond each slot's count is
junk data that downstream attention must exclude.

BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
exposed by get_heads_lengths() must be read before update() if the caller needs the
pre-update occupancy for position computation (Unit 10.A). update() increments counts
in-place and the pre-update values are not recoverable afterward.

All buffers are allocated at construction time. MoSRAHCache is constructed by
ShramLayerCache, which has access to batch size, device, and all model config parameters
needed to fully specify the storage layout upfront.
"""
class MoSRAHCache(CacheLayerMixin):
    """KV cache for the MoSRAH sparse attention path — single decoder layer.

    Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
    Stores keys and values in the mixin-standard self.keys and self.values attributes
    using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
    which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.

    All storage is allocated at construction time and is_initialized is True
    immediately. The caller (ShramLayerCache) provides batch size, device, and model
    config parameters so no lazy allocation is needed.

    Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
    (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
    This matches the packed representation produced by expert packing in the MoSRAH
    forward pass, where BEA has already applied RoPE before calling update().

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
            config.mosrah_cache_length. The buffer never grows; if any slot would exceed
            this capacity, update() raises in both eager and compiled modes. Increase
            inference_sequence_length in ShramConfig, from which config.mosrah_cache_length
            is derived, to resolve an overflow.
    """

    is_compileable = True
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        mosrah_cache_length: int,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device
        self.mosrah_cache_length = mosrah_cache_length

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Equal to mosrah_cache_length as supplied at construction. Derived from
        self.keys so it remains consistent with the actual buffer shape.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states into the buffer and return the full cache state.

        Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
        active_mask is (B, L, T) bool with True marking real tokens. Only active
        positions are written; inactive positions are ignored.

        Uses a fixed-shape destination mask constructed from per-slot write intervals
        to transfer active tokens into the buffer without any data-dependent shape
        operations. Active tokens are left-justified within each packed slot by the
        packing machinery, so the destination positions are a contiguous range
        starting at the current slot count — no cumsum or torch.where needed.

        Returns the full accumulated (keys, values, active_mask) across the cached
        sparse sequence. The returned active_mask is True exactly for slots t <
        counts[b, l]; everything beyond is junk data that BEA must exclude.

        Note: get_heads_lengths() must be called before update() if the caller needs
        the pre-update occupancy for position computation (Unit 10.A). update()
        increments counts in-place and the pre-update values are not recoverable.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
                values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
                active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
        """
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        post_counts = self._counts + incoming_delta
        self._check_no_overflow(post_counts.max(), self.mosrah_cache_length)

        # Build a fixed-shape destination mask in cache space. Active tokens within
        # each (b, l) slot are left-justified by the packing machinery, so they occupy
        # positions 0..s-1 in their packed slot. The corresponding cache positions are
        # write_start[b,l]..write_start[b,l]+write_count[b,l]-1. Broadcasting a
        # time arange against these per-slot intervals selects exactly the target
        # positions without any data-dependent shape query.
        write_start = self._counts.unsqueeze(-1)  # cache position where new tokens begin
        write_count = incoming_delta.unsqueeze(-1)  # number of new tokens arriving per slot
        time_arange = torch.arange(
            self.mosrah_cache_length, device=active_mask.device
        )
        dest_mask = (time_arange >= write_start) & (time_arange < write_start + write_count)
        # dest_mask: (B, L, mosrah_cache_length)

        # Transfer key and value vectors. Left-justification guarantees that
        # dest_mask and active_mask have equal True counts per (b, l) slot, so the
        # boolean-mask transfer is correct without any explicit count verification.
        self.keys[dest_mask] = key_states[active_mask]
        self.values[dest_mask] = value_states[active_mask]
        self._counts[:] = post_counts[:]

        return self.keys, self.values, self._make_active_mask()
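The destination-mask transfer in update() can be reproduced on plain tensors. The function below is a hypothetical standalone sketch of the same steps, and it makes two assumptions. First, a plain RuntimeError stands in for _check_no_overflow, whose body is not shown in this section. Second, the returned validity mask is computed as t < counts, which is how the docstring describes _make_active_mask. Only keys are handled; values follow the same path.

```python
import torch


def scatter_into_cache(cache_keys: torch.Tensor, counts: torch.Tensor,
                       key_states: torch.Tensor, active_mask: torch.Tensor
                       ) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of MoSRAHCache.update() for keys only.

    cache_keys: (B, L, C, u) buffer, modified in place; counts: (B, L) long, modified in place.
    key_states: (B, L, T, u); active_mask: (B, L, T) bool.
    """
    capacity = cache_keys.shape[2]
    incoming = active_mask.long().sum(dim=2)  # (B, L) new tokens per slot
    post_counts = counts + incoming
    if int(post_counts.max()) > capacity:  # assumed stand-in for _check_no_overflow
        raise RuntimeError("MoSRAH cache overflow")
    write_start = counts.unsqueeze(-1)  # (B, L, 1)
    time = torch.arange(capacity, device=active_mask.device)  # (C,)
    dest_mask = (time >= write_start) & (time < write_start + incoming.unsqueeze(-1))
    # Each slot gets exactly `incoming` destination positions, and both masks are
    # traversed in the same row-major (b, l, t) order, so entries land in their own slot.
    cache_keys[dest_mask] = key_states[active_mask]
    counts[:] = post_counts
    valid = time < counts.unsqueeze(-1)  # (B, L, C), True for written slots
    return cache_keys, valid


# B = 1, L = 2, capacity C = 4, u = 1. Head 0 already holds one token (value 5).
cache = torch.zeros(1, 2, 4, 1)
cache[0, 0, 0, 0] = 5.0
counts = torch.tensor([[1, 0]])
new_keys = torch.tensor([[[[10.0], [11.0], [99.0]], [[20.0], [99.0], [99.0]]]])
active = torch.tensor([[[True, True, False], [True, False, False]]])
cache, valid = scatter_into_cache(cache, counts, new_keys, active)
# cache[..., 0] is now [[[5, 10, 11, 0], [20, 0, 0, 0]]] and counts is [[3, 1]].
```

The padded 99.0 entries are never written, because the active mask filters them out before the transfer.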

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by semantic-sequence position computation (Unit 10.A).

        Note: in the MoSRAH forward pass, this must be called before update() if the
        caller needs the pre-update occupancy. update() increments these counts in-place,
        and the returned tensor is the live _counts buffer, so clone it if the
        pre-update values must survive the update() call.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts
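get_heads_lengths() returns self._counts itself, and update() overwrites it with `self._counts[:] = ...`. Calling the getter before update() is therefore not enough on its own. A reference taken earlier still changes under the caller. The plain-tensor sketch below reproduces that aliasing; the variable names are illustrative only.

```python
import torch

counts = torch.tensor([[3, 1]])              # stands in for MoSRAHCache._counts
live_view = counts                           # what get_heads_lengths() hands back
snapshot = counts.clone()                    # an independent copy of the pre-update occupancy
counts[:] = counts + torch.tensor([[1, 2]])  # the in-place write that update() performs
```

After the write, live_view reads [[4, 3]] while snapshot still reads [[3, 1]]. Position computation that needs pre-update occupancy should consume the values immediately or work from a clone.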

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]
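The batch-manipulation methods above apply the same index operation to keys, values, and _counts. This sketch uses plain tensors to show why that keeps the buffers and the occupancy paired. The invariant (keys[b] stores the float of counts[b]) and the beam indices are hypothetical choices made for the illustration only.

```python
import torch

# Hypothetical invariant for this sketch: keys[b] always equals float(counts[b]).
counts = torch.tensor([[2], [7]])  # B = 2 original sequences
keys = counts.float()

# Beam setup with 2 beams, as in batch_repeat_interleave: copies of a sequence are adjacent.
counts = counts.repeat_interleave(2, dim=0)  # [[2], [2], [7], [7]]
keys = keys.repeat_interleave(2, dim=0)

# One decode step: hypotheses diverge, writing 0 or 1 new tokens per row.
step = torch.tensor([[0], [1], [1], [0]])
counts = counts + step
keys = keys + step.float()

# As in reorder_cache: the SAME beam_idx is applied to every tensor.
beam_idx = torch.tensor([1, 1, 2, 3])
counts, keys = counts[beam_idx], keys[beam_idx]
```

After the reorder, counts is [[3], [3], [8], [7]] and keys still matches it row for row. Reordering only one of the two would leave the occupancy counts describing a different hypothesis than the buffer contents.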

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "MoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Return the static per-(batch, head) slot capacity of this cache.

        Equal to mosrah_cache_length as supplied at construction, which is derived
        from config.mosrah_cache_length. Required by the HuggingFace static cache
        contract; generation machinery uses this to size attention masks.
        """
        return self.mosrah_cache_length

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — MoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "MoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
| # --------------------------------------------------------------------------- | |
| def _make_active_mask(self) -> torch.Tensor: | |
| """Construct the (B, L, T) active mask from current counts. | |
| Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot | |
| has been written. Positions at or beyond the count are junk and must be | |
| excluded by downstream attention. | |
| """ | |
| cap = self.buffer_capacity | |
| return ( | |
| torch.arange(cap, device=self.keys.device) | |
| .expand(self.batch_size, self.num_mosrah_heads, cap) | |
| < self._counts.unsqueeze(-1) | |
| ) | |
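The mask construction above is a broadcast comparison between a position ramp and the per-slot counts. A small standalone sketch, where `make_active_mask` is a hypothetical free-function version of the method:

```python
import torch

def make_active_mask(counts: torch.Tensor, capacity: int) -> torch.Tensor:
    """Return a (B, L, T) bool mask that is True where t < counts[b, l]."""
    batch_size, num_heads = counts.shape
    ramp = torch.arange(capacity, device=counts.device)        # (T,)
    # Expand the ramp to (B, L, T) and compare against (B, L, 1) counts.
    return ramp.expand(batch_size, num_heads, capacity) < counts.unsqueeze(-1)

counts = torch.tensor([[0, 2], [3, 1]])   # tokens written per (batch, head) slot
mask = make_active_mask(counts, capacity=3)
print(mask.int().tolist())  # [[[0, 0, 0], [1, 1, 0]], [[1, 1, 1], [1, 0, 0]]]
```

Every written slot contributes exactly one True, so the mask's total equals the sum of the counts.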
| @staticmethod | |
| def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None: | |
| """Raise if any (batch, head) slot would exceed the static buffer capacity. | |
| Branches on whether the graph is being compiled. In compiled mode, | |
| torch._assert_async fires asynchronously on the GPU when the condition | |
| tensor is False. In eager mode, a plain RuntimeError is raised with a | |
| descriptive message. | |
| Args: | |
| max_count: Scalar tensor — the maximum post-update count across all slots. | |
| capacity: The static buffer capacity (mosrah_cache_length). | |
| """ | |
| if torch.compiler.is_compiling(): | |
| torch._assert_async( | |
| max_count <= capacity, | |
| "MoSRAHCache overflow: buffer capacity exceeded. " | |
| "Increase mosrah_overallocation_factor in ShramConfig.", | |
| ) | |
| else: | |
| if max_count.item() > capacity: | |
| raise RuntimeError( | |
| f"MoSRAHCache overflow: a (batch, head) slot would reach " | |
| f"{max_count.item()} tokens but the static buffer capacity is " | |
| f"{capacity}. Increase mosrah_overallocation_factor in ShramConfig." | |
| ) | |
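In eager mode the check reduces to comparing the worst post-update count against the capacity. The sketch below reproduces only that eager branch (the compiled `torch._assert_async` branch is omitted); `check_no_overflow` and the toy count tensors are hypothetical.

```python
import torch

def check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
    # Eager-mode branch only: .item() forces a host sync, which is fine outside
    # torch.compile but would break a fullgraph compiled region.
    if max_count.item() > capacity:
        raise RuntimeError(
            f"overflow: a slot would reach {max_count.item()} tokens "
            f"but capacity is {capacity}"
        )

counts = torch.tensor([[2, 3], [1, 0]])       # tokens already stored per (b, head)
new_tokens = torch.tensor([[1, 1], [0, 2]])   # tokens routed to each slot this step
max_count = (counts + new_tokens).max()       # scalar tensor: worst post-update count

check_no_overflow(max_count, capacity=4)      # 4 <= 4, so no error

caught = None
try:
    check_no_overflow(max_count, capacity=3)  # 4 > 3, so this raises
except RuntimeError as err:
    caught = str(err)
```

Reducing to a single scalar first means only one value has to cross to the host, regardless of batch size or head count.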
| # ----------- | |
| # Inlined from: router_cache.py | |
| # ----------- | |
| """Block-state cache for the MoSRAH causal block-balanced router. | |
| The block-balanced router partitions the token sequence into non-overlapping blocks | |
| of W = L/K tokens. Within each block every expert is assigned exactly once, giving | |
| perfect load balance by construction. During training the full sequence is available | |
| and block state is managed locally in MoSRAHRouter.forward(). During inference, tokens | |
| arrive one at a time and the router must remember which experts have been claimed in | |
| the current partial block across decode steps. | |
| RouterCache holds two pieces of state across decode steps: | |
| - _used_in_block: Boolean mask (B, L) tracking which experts have been claimed by | |
| earlier tokens in the current block. The decode router masks these to -inf before | |
| TopK, preserving the one-usage-per-block invariant. | |
| - _step_in_block: Integer counter (B,) of how many tokens have been processed in | |
| the current block, ranging over [0, W - 1]. When the token at step W - 1 is recorded | |
| the block is complete, and both tensors are reset in-place for the next block. | |
| All decode-step operations (update_decode) use fixed-shape in-place tensor ops and | |
| are fully compileable under torch.compile(dynamic=False, fullgraph=True). The prefill | |
| update (update_prefill) may use data-dependent indexing and must not be called inside | |
| a compiled graph; prefill runs in eager mode before the compiled decode loop in | |
| standard HuggingFace generate(). | |
| RouterCache is constructed by ShramLayerCache and passed directly to | |
| MoSRAHRouter.forward(). ShramLayerCache.reset() clears the router state atomically | |
| with the KV caches it also owns. | |
| """ | |
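The one-usage-per-block invariant follows from the arithmetic W = L / K: W tokens each claim K experts, so a block makes exactly L claims, and masking already-claimed experts to -inf before TopK forbids repeats. The sketch below illustrates the decode-style masking with random router scores; the toy sizes are assumptions, and this is not the training-time block solver used in MoSRAHRouter.forward().

```python
import torch

num_experts, top_k = 6, 2                # L and K (toy values)
block_length = num_experts // top_k      # W = L / K = 3 tokens per block
torch.manual_seed(0)
scores = torch.randn(block_length, num_experts)   # router logits for one block

used = torch.zeros(num_experts, dtype=torch.bool)
assignments = []
for step in range(block_length):
    # Experts claimed earlier in this block are masked to -inf before TopK.
    masked = scores[step].masked_fill(used, float("-inf"))
    chosen = masked.topk(top_k).indices
    used[chosen] = True
    assignments.append(chosen)

claimed = torch.cat(assignments)
# W * K = L claims with no repeats, so every expert is used exactly once.
print(sorted(claimed.tolist()))  # [0, 1, 2, 3, 4, 5]
```

The result does not depend on the scores: on the last step only K experts remain finite, so TopK is forced to pick exactly those.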
| class RouterCache(CacheLayerMixin): | |
| """Block-state cache for the MoSRAH causal block-balanced router. | |
| Tracks which experts have been claimed in the current routing block and how | |
| far into that block the current decode step is. This allows the router to | |
| maintain its one-usage-per-block contract across decode steps without | |
| reprocessing the full accumulated sequence. | |
| All state is pre-allocated at construction time. The primary decode method | |
| (update_decode) uses only in-place fixed-shape operations and is fully | |
| compileable. | |
| Args: | |
| block_length: Tokens per routing block, W = num_mosrah_heads // num_selected_heads. | |
| The router resets block state after every W consecutive decode tokens. | |
| num_mosrah_heads: Total expert count L. Determines the width of the | |
| used-expert mask. | |
| batch_size: Number of sequences in the batch. | |
| device: Device on which to allocate state tensors. | |
| """ | |
| is_compileable = True | |
| is_sliding = False | |
| def __init__( | |
| self, | |
| block_length: int, | |
| num_mosrah_heads: int, | |
| batch_size: int, | |
| device: torch.device, | |
| ) -> None: | |
| super().__init__() | |
| self._block_length = block_length | |
| self._device = device | |
| # used_in_block: which experts are already claimed in the current block. | |
| # False = expert is still available for the next decode token that needs it. | |
| # Reset to all-False when step_in_block reaches block_length. | |
| self._used_in_block = torch.zeros( | |
| batch_size, num_mosrah_heads, dtype=torch.bool, device=device | |
| ) | |
| # step_in_block: how many tokens have been processed in the current block. | |
| # Range [0, block_length - 1]. Resets to 0 when a block completes. | |
| self._step_in_block = torch.zeros(batch_size, dtype=torch.int64, device=device) | |
| # --------------------------------------------------------------------------- | |
| # is_initialized — pre-allocated at construction, always True | |
| # --------------------------------------------------------------------------- | |
| @property | |
| def is_initialized(self) -> bool: | |
| """True always: RouterCache pre-allocates all state at construction.""" | |
| return True | |
| @is_initialized.setter | |
| def is_initialized(self, value: bool) -> None: | |
| # CacheLayerMixin.__init__ assigns self.is_initialized = False as an | |
| # instance attribute. Absorb it silently — state is always initialized. | |
| pass | |
| # --------------------------------------------------------------------------- | |
| # Public interface for the router | |
| # --------------------------------------------------------------------------- | |
| def get_used_in_block(self) -> torch.Tensor: | |
| """Return the current block's used-expert mask. | |
| Returns: | |
| Boolean mask of shape (B, L). True entries mark experts already claimed | |
| by earlier tokens in the current block and must be excluded from TopK. | |
| """ | |
| return self._used_in_block | |
| def update_decode(self, step_heads: torch.Tensor) -> None: | |
| """Record a single decode-step expert selection and advance the block counter. | |
| Marks the K selected experts as used in the current block, then either | |
| advances the per-batch step counter or resets both tensors in-place when | |
| the block completes. All operations are in-place and compile-compatible. | |
| Args: | |
| step_heads: Expert indices selected at this decode step, shape (B, K). | |
| """ | |
| # Mark the K selected experts as unavailable for the rest of this block. | |
| self._used_in_block.scatter_(-1, step_heads, True) | |
| # Detect block completion before incrementing: step was W-1 (0-indexed), | |
| # meaning this token is the last one in the current block. | |
| block_done = self._step_in_block.eq(self._block_length - 1) # (B,) bool | |
| # Advance step counter, then zero it for any batch item that just finished a block. | |
| self._step_in_block.add_(1) | |
| self._step_in_block.masked_fill_(block_done, 0) | |
| # Clear expert availability for batch items that completed a block, so the | |
| # next decode token for those items starts with a clean slate. | |
| self._used_in_block.masked_fill_(block_done.unsqueeze(-1), False) | |
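A standalone trace of the update_decode bookkeeping, using a hypothetical `ToyBlockState` class that mirrors the three in-place steps (claim, advance, reset) with W = 2:

```python
import torch

class ToyBlockState:
    """Hypothetical standalone mirror of the decode-step bookkeeping above."""

    def __init__(self, block_length: int, num_experts: int, batch_size: int):
        self.block_length = block_length
        self.used = torch.zeros(batch_size, num_experts, dtype=torch.bool)
        self.step = torch.zeros(batch_size, dtype=torch.int64)

    def update_decode(self, step_heads: torch.Tensor) -> None:
        self.used.scatter_(-1, step_heads, True)                  # claim the K experts
        block_done = self.step.eq(self.block_length - 1)          # last token of block?
        self.step.add_(1)
        self.step.masked_fill_(block_done, 0)                     # wrap the counter
        self.used.masked_fill_(block_done.unsqueeze(-1), False)   # start a fresh block

state = ToyBlockState(block_length=2, num_experts=4, batch_size=1)
state.update_decode(torch.tensor([[0, 3]]))     # first token of the block
after_first = (state.step.tolist(), state.used.tolist())
state.update_decode(torch.tensor([[1, 2]]))     # second (last) token of the block
after_second = (state.step.tolist(), state.used.tolist())
```

After the first token, experts 0 and 3 are claimed and the step is 1; after the second, the block completes, so the step wraps to 0 and every expert is available again.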
| def update_prefill( | |
| self, | |
| selected_heads_blocked: torch.Tensor, | |
| seq_len: int, | |
| ) -> None: | |
| """Record the partial block state left over at the end of a prefill pass. | |
| After processing a prefill sequence of length seq_len with the training-style | |
| block solver, the last block may be incomplete when seq_len is not a multiple | |
| of block_length. This method saves the partial block state so decode steps can | |
| continue the current block without a gap. | |
| Not compile-compatible: uses a data-dependent slice [:seq_mod] on the W | |
| dimension. Must only be called in eager mode. Standard HuggingFace generate() | |
| runs prefill in eager before entering the compiled decode loop. | |
| Args: | |
| selected_heads_blocked: Block-solver assignment output from the prefill pass, | |
| shape (B, num_blocks, W, K). The final block entry contains expert | |
| assignments for both real tokens (steps 0..seq_mod-1) and padding | |
| artefacts (steps seq_mod..W-1) which must be discarded. | |
| seq_len: Actual prefill sequence length before block padding. Determines | |
| how many steps of the last block contain real assignments. | |
| """ | |
| B = selected_heads_blocked.shape[0] | |
| seq_mod = seq_len % self._block_length | |
| self._used_in_block.zero_() | |
| if seq_mod == 0: | |
| # All blocks were complete — start fresh for the next decode token. | |
| self._step_in_block.zero_() | |
| else: | |
| # Last block is partial: only the first seq_mod steps are real assignments. | |
| # Rebuild the used-expert mask from those steps and record the step position. | |
| last_block_real_steps = selected_heads_blocked[:, -1, :seq_mod, :] # (B, seq_mod, K) | |
| real_experts_flat = last_block_real_steps.reshape(B, -1) # (B, seq_mod * K) | |
| self._used_in_block.scatter_(-1, real_experts_flat, True) | |
| self._step_in_block.fill_(seq_mod) | |
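A worked example of the partial-block recovery, assuming a prefill of seq_len = 5 with W = 2 and K = 2; the block-solver output tensor below is made up for illustration:

```python
import torch

block_length, num_experts = 2, 4      # W and L, so K = L // W = 2
seq_len = 5                           # prefill length, padded to 6 = 3 blocks
# Hypothetical block-solver output, shape (B=1, num_blocks=3, W=2, K=2).
selected_heads_blocked = torch.tensor([[
    [[0, 1], [2, 3]],
    [[1, 2], [0, 3]],
    [[3, 0], [1, 2]],   # last block: step 0 is real, step 1 is padding
]])

B = selected_heads_blocked.shape[0]
seq_mod = seq_len % block_length      # 1 real step in the final block
used = torch.zeros(B, num_experts, dtype=torch.bool)
step = torch.zeros(B, dtype=torch.int64)
if seq_mod != 0:
    real = selected_heads_blocked[:, -1, :seq_mod, :].reshape(B, -1)  # (B, seq_mod * K)
    used.scatter_(-1, real, True)
    step.fill_(seq_mod)

print(used.tolist(), step.tolist())  # [[True, False, False, True]] [1]
```

The padding step's assignment ([1, 2]) is discarded, so the next decode token can only claim experts 1 and 2, which completes the block correctly.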
| # --------------------------------------------------------------------------- | |
| # CacheLayerMixin — reset and beam-search coordination | |
| # --------------------------------------------------------------------------- | |
| def reset(self) -> None: | |
| """Clear block state for a new generation session. | |
| Zeros both state tensors in-place. Called by ShramLayerCache.reset() | |
| atomically with the KV cache reset. | |
| """ | |
| self._used_in_block.zero_() | |
| self._step_in_block.zero_() | |
| def reorder_cache(self, beam_idx: torch.LongTensor) -> None: | |
| """Reorder the batch dimension for beam search. | |
| Args: | |
| beam_idx: Permutation indices of shape (batch,). | |
| """ | |
| self._used_in_block = self._used_in_block[beam_idx] | |
| self._step_in_block = self._step_in_block[beam_idx] | |
| def batch_repeat_interleave(self, repeats: int) -> None: | |
| """Expand the batch dimension for beam search initialisation. | |
| Args: | |
| repeats: Number of times to repeat each batch entry along the batch dimension. | |
| """ | |
| self._used_in_block = self._used_in_block.repeat_interleave(repeats, dim=0) | |
| self._step_in_block = self._step_in_block.repeat_interleave(repeats, dim=0) | |
| def batch_select_indices(self, indices: torch.Tensor) -> None: | |
| """Select a subset of batch entries for contrastive search. | |
| Args: | |
| indices: 1-D integer tensor of batch indices to retain. | |
| """ | |
| self._used_in_block = self._used_in_block[indices] | |
| self._step_in_block = self._step_in_block[indices] | |
| def offload(self) -> None: | |
| """Move state tensors to CPU for memory management between decode steps.""" | |
| self._used_in_block = self._used_in_block.cpu() | |
| self._step_in_block = self._step_in_block.cpu() | |
| def prefetch(self) -> None: | |
| """Move state tensors back to model device ahead of the next decode step.""" | |
| self._used_in_block = self._used_in_block.to(self._device) | |
| self._step_in_block = self._step_in_block.to(self._device) | |
| # --------------------------------------------------------------------------- | |
| # CacheLayerMixin — unsupported abstract methods | |
| # --------------------------------------------------------------------------- | |
| def lazy_initialization( # type: ignore[override] | |
| self, key_states: torch.Tensor, value_states: torch.Tensor | |
| ) -> None: | |
| """No-op — RouterCache pre-allocates all state at construction.""" | |
| pass | |
| def update( # type: ignore[override] | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| cache_kwargs: dict | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Not supported — use update_decode() or update_prefill() instead.""" | |
| raise NotImplementedError( | |
| "RouterCache has no composite key/value update interface. " | |
| "Use update_decode() for single decode steps or update_prefill() after prefill." | |
| ) | |
| def get_seq_length(self) -> int: | |
| """Not supported — RouterCache tracks block position, not sequence length.""" | |
| raise NotImplementedError("RouterCache does not track sequence length.") | |
| def get_max_cache_shape(self) -> int: | |
| """Not supported — RouterCache does not hold KV pairs.""" | |
| raise NotImplementedError("RouterCache does not have a KV cache shape.") | |
| def get_mask_sizes( # type: ignore[override] | |
| self, | |
| cache_position: torch.Tensor, | |
| ) -> tuple[int, int]: | |
| """Not supported — RouterCache does not participate in KV attention masking.""" | |
| raise NotImplementedError("RouterCache does not participate in KV masking.") | |
| # ----------- | |
| # Inlined from: sliding_window_cache.py | |
| # ----------- | |
| # src/shram/model/cache/sliding_window_cache.py | |
| """Local sliding-window cache for the SHRAM local attention path. | |
| This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by | |
| `ShramLayerCache` and consumed by `SlidingWindowAttention`. | |
| Its job is narrow: | |
| - accept the current chunk's local key/value tensors and active mask | |
| - return the current-step local frame consumed by local attention | |
| - separately retain the next-step sliding-window cache state | |
| It does not decide local causal visibility. That is owned by | |
| `SlidingWindowAttention`, which consumes the returned key/value/mask frame and | |
| constructs the effective local attention mask from it. | |
| """ | |
| class LocalSlidingWindowLayerCache(CacheLayerMixin): | |
| """Fixed-width local cache for one SHRAM decoder layer. | |
| The cache keeps a retained local sliding-window buffer and an aligned active | |
| mask. On update, it returns the current-step local frame formed by | |
| concatenating retained cache state with the new chunk, then remembers only | |
| the last `sliding_window` positions for the next step. | |
| Dead positions are allowed to remain in both the returned frame and the | |
| retained cache. Correctness is carried by the aligned active mask. | |
| Args: | |
| sliding_window: Width of the retained local sliding-window buffer. | |
| num_heads: Number of local attention heads. | |
| head_dim: Per-head embedding width for the local path. | |
| batch_size: Number of sequences in the batch. | |
| device: Device on which to allocate cache storage. | |
| """ | |
| is_compileable = True | |
| is_sliding = True | |
| def __init__( | |
| self, | |
| sliding_window: int, | |
| num_heads: int, | |
| head_dim: int, | |
| batch_size: int, | |
| device: torch.device, | |
| ) -> None: | |
| super().__init__() | |
| if sliding_window < 1: | |
| raise ValueError( | |
| f"sliding_window must be >= 1, got {sliding_window}." | |
| ) | |
| if num_heads < 1: | |
| raise ValueError(f"num_heads must be >= 1, got {num_heads}.") | |
| if head_dim < 1: | |
| raise ValueError(f"head_dim must be >= 1, got {head_dim}.") | |
| if batch_size < 1: | |
| raise ValueError(f"batch_size must be >= 1, got {batch_size}.") | |
| self.sliding_window = sliding_window | |
| self.num_heads = num_heads | |
| self.head_dim = head_dim | |
| self.batch_size = batch_size | |
| self.device = device | |
| # Retained next-step local cache state. Storage is fixed-width from the | |
| # start; semantic validity is carried by `active_mask`. | |
| self.keys = torch.zeros( | |
| batch_size, | |
| num_heads, | |
| sliding_window, | |
| head_dim, | |
| device=device, | |
| ) | |
| self.values = torch.zeros( | |
| batch_size, | |
| num_heads, | |
| sliding_window, | |
| head_dim, | |
| device=device, | |
| ) | |
| self.active_mask = torch.zeros( | |
| batch_size, | |
| sliding_window, | |
| dtype=torch.bool, | |
| device=device, | |
| ) | |
| # Absolute sequence positions of each retained slot. Storage starts at zero | |
| # and inactive slots may hold stale values; correctness is carried by active_mask. | |
| self.positions = torch.zeros( | |
| batch_size, | |
| sliding_window, | |
| dtype=torch.long, | |
| device=device, | |
| ) | |
| self.is_initialized = True | |
| # Cumulative count of all token positions presented through update() for | |
| # this cache instance. This is the quantity HuggingFace generation reads | |
| # through get_seq_length() to track how far along the sequence we are. | |
| self._total_processed = torch.tensor(0) | |
| def update( # type: ignore[override] | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| positions: torch.Tensor, | |
| cache_kwargs: dict | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Return the current-step local frame and retain the next-step window. | |
| Args: | |
| key_states: Shape `(B, H, T_new, D)` local key vectors for the | |
| current chunk. | |
| value_states: Shape `(B, H, T_new, D)` local value vectors for the | |
| current chunk. | |
| active_mask: Shape `(B, T_new)` bool. `True` means the | |
| corresponding token position in the current chunk is active. | |
| positions: Shape `(B, T_new)` long. Absolute sequence position of | |
| each token in the current chunk. | |
| cache_kwargs: Present only to satisfy the `CacheLayerMixin` | |
| interface. Unused by this cache. | |
| Returns: | |
| Tuple of: | |
| - visible_keys: `(B, H, sliding_window + T_new, D)` | |
| - visible_values: `(B, H, sliding_window + T_new, D)` | |
| - visible_active_mask: `(B, sliding_window + T_new)` | |
| - visible_positions: `(B, sliding_window + T_new)` | |
| These are the tensors the local attention path should consume | |
| directly for the current step. | |
| """ | |
| self._ensure_state_compatibility( | |
| key_states=key_states, | |
| value_states=value_states, | |
| ) | |
| # The current-step local frame is just retained cache state followed by | |
| # the current chunk in chronological order. | |
| composite_keys, composite_values, composite_mask, composite_positions = self._make_composite_frame( | |
| key_states=key_states, | |
| value_states=value_states, | |
| active_mask=active_mask, | |
| positions=positions, | |
| ) | |
| # For the next step, the cache keeps only the last `sliding_window` raw | |
| # positions of that composite frame. Dead positions are allowed to | |
| # survive; downstream local attention ignores them via the mask. | |
| self._retain_next_window( | |
| composite_keys=composite_keys, | |
| composite_values=composite_values, | |
| composite_mask=composite_mask, | |
| composite_positions=composite_positions, | |
| ) | |
| self._total_processed += key_states.shape[2] | |
| return composite_keys, composite_values, composite_mask, composite_positions | |
| def _ensure_state_compatibility( | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| ) -> None: | |
| """Keep retained cache buffers compatible with the incoming update tensors. | |
| The cache is allocated eagerly for simplicity. If later updates arrive on | |
| a different device or in a different floating dtype, move the retained | |
| state to match while preserving its contents. | |
| """ | |
| if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device: | |
| self.keys = self.keys.to( | |
| device=key_states.device, | |
| dtype=key_states.dtype, | |
| ) | |
| if ( | |
| self.values.dtype != value_states.dtype | |
| or self.values.device != value_states.device | |
| ): | |
| self.values = self.values.to( | |
| device=value_states.device, | |
| dtype=value_states.dtype, | |
| ) | |
| if self.active_mask.device != key_states.device: | |
| self.active_mask = self.active_mask.to( | |
| key_states.device, | |
| non_blocking=True, | |
| ) | |
| if self.positions.device != key_states.device: | |
| self.positions = self.positions.to( | |
| key_states.device, | |
| non_blocking=True, | |
| ) | |
| def _make_composite_frame( | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| positions: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Build the current-step local frame in chronological order.""" | |
| return ( | |
| torch.cat([self.keys, key_states], dim=-2), | |
| torch.cat([self.values, value_states], dim=-2), | |
| torch.cat([self.active_mask, active_mask], dim=-1), | |
| torch.cat([self.positions, positions], dim=-1), | |
| ) | |
| def _retain_next_window( | |
| self, | |
| composite_keys: torch.Tensor, | |
| composite_values: torch.Tensor, | |
| composite_mask: torch.Tensor, | |
| composite_positions: torch.Tensor, | |
| ) -> None: | |
| """Remember the next-step retained local state. | |
| This is a raw positional trim to the last `sliding_window` positions, not | |
| a semantic live-token trim. | |
| """ | |
| self.keys[:] = composite_keys[:, :, -self.sliding_window :, :] | |
| self.values[:] = composite_values[:, :, -self.sliding_window :, :] | |
| self.active_mask[:] = composite_mask[:, -self.sliding_window :] | |
| self.positions[:] = composite_positions[:, -self.sliding_window :] | |
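The update() flow reduces to concatenate, return, and trim. The sketch below uses a hypothetical `ToySlidingWindow` with B = H = D = 1 and a window of 3 to show that the returned frame is sliding_window + T_new wide, that retention is a raw positional trim, and that the processed-token counter grows by T_new on each call:

```python
import torch

class ToySlidingWindow:
    """Hypothetical reduction of the update() flow: concat, return, trim."""

    def __init__(self, window: int):
        self.window = window
        self.keys = torch.zeros(1, 1, window, 1)             # (B, H, window, D)
        self.mask = torch.zeros(1, window, dtype=torch.bool)
        self.positions = torch.zeros(1, window, dtype=torch.long)
        self.total_processed = 0

    def update(self, new_keys, new_mask, new_positions):
        # Current-step frame: retained state followed by the new chunk.
        frame_keys = torch.cat([self.keys, new_keys], dim=-2)
        frame_mask = torch.cat([self.mask, new_mask], dim=-1)
        frame_pos = torch.cat([self.positions, new_positions], dim=-1)
        # Raw positional trim: keep the last `window` slots, dead or alive.
        self.keys[:] = frame_keys[:, :, -self.window:, :]
        self.mask[:] = frame_mask[:, -self.window:]
        self.positions[:] = frame_pos[:, -self.window:]
        self.total_processed += new_keys.shape[2]
        return frame_keys, frame_mask, frame_pos

cache = ToySlidingWindow(window=3)
# Prefill of 2 tokens (positions 0 and 1), then one decode token (position 2).
prefill = torch.tensor([[[[10.0], [11.0]]]])                 # (1, 1, 2, 1)
_, m1, p1 = cache.update(prefill, torch.tensor([[True, True]]), torch.tensor([[0, 1]]))
_, m2, p2 = cache.update(torch.tensor([[[[12.0]]]]), torch.tensor([[True]]), torch.tensor([[2]]))
```

After the prefill, the returned mask is [F, F, F, T, T]: the three leading slots are the empty initial window, which the mask marks dead rather than the cache removing them.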
| def get_seq_length(self) -> int: | |
| """Return the cumulative number of token positions processed by this cache. | |
| This is the total count of token positions presented across all update() | |
| calls since construction or the last reset(). It is the quantity HuggingFace | |
| generation reads to track sequence progress and is not the same as active-token | |
| count or current window occupancy. | |
| """ | |
| return int(self._total_processed) | |
| def get_max_cache_shape(self) -> int: | |
| return self.sliding_window | |
| def get_mask_sizes( # type: ignore[override] | |
| self, | |
| cache_position: torch.Tensor, | |
| ) -> tuple[int, int]: | |
| raise NotImplementedError( | |
| "LocalSlidingWindowLayerCache does not support get_mask_sizes()." | |
| ) | |
| def reset(self) -> None: | |
| """Restore fresh-cache behavior.""" | |
| self.keys.zero_() | |
| self.values.zero_() | |
| self.active_mask.zero_() | |
| self.positions.zero_() | |
| self._total_processed.zero_() | |
| def reorder_cache(self, beam_idx: torch.LongTensor) -> None: | |
| """Reorder the batch dimension for beam search.""" | |
| self.keys = self.keys[beam_idx] | |
| self.values = self.values[beam_idx] | |
| self.active_mask = self.active_mask[beam_idx] | |
| self.positions = self.positions[beam_idx] | |
| def batch_repeat_interleave(self, repeats: int) -> None: | |
| """Expand the batch dimension for beam-search initialisation.""" | |
| self.keys = self.keys.repeat_interleave(repeats, dim=0) | |
| self.values = self.values.repeat_interleave(repeats, dim=0) | |
| self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0) | |
| self.positions = self.positions.repeat_interleave(repeats, dim=0) | |
| self.batch_size = self.batch_size * repeats | |
| def batch_select_indices(self, indices: torch.Tensor) -> None: | |
| """Select a subset of batch entries for contrastive search.""" | |
| self.keys = self.keys[indices] | |
| self.values = self.values[indices] | |
| self.active_mask = self.active_mask[indices] | |
| self.positions = self.positions[indices] | |
| self.batch_size = int(indices.shape[0]) | |
| def offload(self) -> None: | |
| """Offload cache tensors to CPU.""" | |
| super().offload() | |
| self.active_mask = self.active_mask.to("cpu", non_blocking=True) | |
| self.positions = self.positions.to("cpu", non_blocking=True) | |
| def prefetch(self) -> None: | |
| """Move cache tensors back to the model device ahead of time.""" | |
| super().prefetch() | |
| if self.active_mask.device != self.keys.device: | |
| self.active_mask = self.active_mask.to( | |
| self.keys.device, | |
| non_blocking=True, | |
| ) | |
| if self.positions.device != self.keys.device: | |
| self.positions = self.positions.to( | |
| self.keys.device, | |
| non_blocking=True, | |
| ) | |
| def crop(self, max_length: int) -> None: | |
| raise NotImplementedError( | |
| "LocalSlidingWindowLayerCache does not support crop()." | |
| ) | |
| def lazy_initialization( | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| ) -> None: | |
| """No-op — this cache allocates its fixed buffers at construction time.""" | |
| return | |
| class ShramLayerCache(CacheLayerMixin): | |
| """Cache subsystem for one SHRAM decoder layer. | |
| Owns and coordinates three sub-caches: | |
| - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path. | |
| - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path. | |
| - router_cache: RouterCache for the block-balanced router's block state. | |
| Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The sub-caches are | |
| exposed directly for their downstream consumers — no composite update() interface is | |
| provided, because the paths have materially different update semantics. | |
| Sequence length is reported by delegating to the local sliding-window sub-cache, which | |
| tracks the cumulative count of token positions processed across all update() calls. | |
| Args: | |
| config: ShramConfig instance. All sub-cache dimensions and capacities are derived | |
| from config so that a single source of truth governs every buffer size. | |
| batch_size: Number of sequences in the batch. | |
| device: Device on which to allocate cache tensors. | |
| """ | |
| is_compileable = True | |
| is_sliding = False | |
| def __init__( | |
| self, | |
| config: ShramConfig, | |
| batch_size: int, | |
| device: torch.device, | |
| ) -> None: | |
| super().__init__() | |
| self._inference_sequence_length = config.inference_sequence_length | |
| self.sliding_window_cache = LocalSlidingWindowLayerCache( | |
| sliding_window=config.window_size, | |
| num_heads=config.num_sliding_window_heads, | |
| head_dim=config.head_dim, | |
| batch_size=batch_size, | |
| device=device, | |
| ) | |
| self.mosrah_cache = MoSRAHCache( | |
| num_mosrah_heads=config.num_mosrah_heads, | |
| head_dim=config.head_dim, | |
| batch_size=batch_size, | |
| device=device, | |
| mosrah_cache_length=config.mosrah_cache_length, | |
| ) | |
| self.router_cache = RouterCache( | |
| block_length=config.block_length, | |
| num_mosrah_heads=config.num_mosrah_heads, | |
| batch_size=batch_size, | |
| device=device, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Properties | |
| # --------------------------------------------------------------------------- | |
| @property | |
| def is_initialized(self) -> bool: | |
| """True iff all three sub-caches report initialized storage. | |
| LocalSlidingWindowLayerCache, MoSRAHCache, and RouterCache all pre-allocate at | |
| construction, so this is True immediately after ShramLayerCache.__init__ returns. | |
| """ | |
| return ( | |
| self.sliding_window_cache.is_initialized | |
| and self.mosrah_cache.is_initialized | |
| and self.router_cache.is_initialized | |
| ) | |
| @is_initialized.setter | |
| def is_initialized(self, value: bool) -> None: | |
| # CacheLayerMixin.__init__ assigns self.is_initialized = False. Because | |
| # is_initialized is a property with a setter (a data descriptor), Python routes | |
| # that assignment to this setter instead of creating an instance attribute. | |
| # Absorb it silently: state is derived from the sub-caches, not stored here. | |
| pass | |
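The property-plus-no-op-setter pattern exists because the base-class `__init__` assigns `is_initialized`. A minimal sketch with hypothetical `ToyMixin` and `Composite` classes shows that the setter absorbs the assignment, and that a property without a setter would raise AttributeError during construction:

```python
class ToyMixin:
    """Hypothetical stand-in for a base class that assigns the flag in __init__."""
    def __init__(self):
        self.is_initialized = False

class Part:
    is_initialized = True

class Composite(ToyMixin):
    def __init__(self, parts):
        super().__init__()    # this assignment is routed to the setter below
        self.parts = parts

    @property
    def is_initialized(self) -> bool:
        # Derived from the parts; nothing is stored on the composite itself.
        return all(part.is_initialized for part in self.parts)

    @is_initialized.setter
    def is_initialized(self, value: bool) -> None:
        pass                  # absorb the base-class assignment silently

class NoSetter(ToyMixin):
    @property
    def is_initialized(self) -> bool:
        return True

c = Composite([Part(), Part()])

error = None
try:
    NoSetter()                # base __init__ tries to assign a read-only property
except AttributeError as exc:
    error = exc
```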
| # --------------------------------------------------------------------------- | |
| # CacheLayerMixin — composite-meaningful methods | |
| # --------------------------------------------------------------------------- | |
| def get_seq_length(self) -> int: # type: ignore[override] | |
| """Return the cumulative sequence length from the local sliding-window path. | |
| The local path is authoritative for sequence progress: it sees every token | |
| presented to this layer and accumulates a truthful total. Delegates to | |
| sliding_window_cache.get_seq_length(). | |
| """ | |
| return self.sliding_window_cache.get_seq_length() | |
| def reset(self) -> None: | |
| """Clear all three sub-caches. | |
| Delegates reset to each sub-cache. They are cleared together so the sliding-window, | |
| MoSRAH sparse, and router block state remain consistent. | |
| """ | |
| self.sliding_window_cache.reset() | |
| self.mosrah_cache.reset() | |
| self.router_cache.reset() | |
| def reorder_cache(self, beam_idx: torch.LongTensor) -> None: | |
| """Reorder the batch dimension of all three sub-caches for beam search. | |
| Delegates to each sub-cache with the same beam_idx so the sliding-window, MoSRAH, | |
| and router state correspond to the same beam hypotheses after reordering. | |
| Args: | |
| beam_idx: Permutation indices of shape (batch,) produced by beam search. | |
| """ | |
| self.sliding_window_cache.reorder_cache(beam_idx) | |
| self.mosrah_cache.reorder_cache(beam_idx) | |
| self.router_cache.reorder_cache(beam_idx) | |
| def batch_repeat_interleave(self, repeats: int) -> None: | |
| """Expand the batch dimension of all three sub-caches for beam search initialisation. | |
| Delegates to each sub-cache with the same repeats. All three must be expanded together | |
| so the sliding-window, MoSRAH, and router state correspond to the same beam candidates. | |
| Args: | |
| repeats: Number of times to repeat each batch entry. | |
| """ | |
| self.sliding_window_cache.batch_repeat_interleave(repeats) | |
| self.mosrah_cache.batch_repeat_interleave(repeats) | |
| self.router_cache.batch_repeat_interleave(repeats) | |
| def batch_select_indices(self, indices: torch.Tensor) -> None: | |
| """Select a subset of batch entries in all three sub-caches for contrastive search. | |
| Delegates to each sub-cache with the same indices. All three must be trimmed together | |
| so the sliding-window, MoSRAH, and router state remain consistent. | |
| Args: | |
| indices: 1-D integer tensor of batch indices to retain. | |
| """ | |
| self.sliding_window_cache.batch_select_indices(indices) | |
| self.mosrah_cache.batch_select_indices(indices) | |
| self.router_cache.batch_select_indices(indices) | |
| def offload(self) -> None: | |
| """Offload all three sub-caches to CPU. | |
| Delegates to each sub-cache's offload method. Does not call super(): ShramLayerCache | |
| does not own self.keys/self.values directly; all cached data lives in the sub-caches. | |
| """ | |
| self.sliding_window_cache.offload() | |
| self.mosrah_cache.offload() | |
| self.router_cache.offload() | |
| def prefetch(self) -> None: | |
| """Move all three sub-caches back to their model device ahead of time. | |
| Delegates to each sub-cache's prefetch method. Does not call super(): ShramLayerCache | |
| does not own self.keys/self.values directly; all cached data lives in the sub-caches. | |
| """ | |
| self.sliding_window_cache.prefetch() | |
| self.mosrah_cache.prefetch() | |
| self.router_cache.prefetch() | |
| def lazy_initialization( # type: ignore[override] | |
| self, key_states: torch.Tensor, value_states: torch.Tensor | |
| ) -> None: | |
| """No-op: each sub-cache handles its own initialization.""" | |
| pass | |
| # --------------------------------------------------------------------------- | |
| # CacheLayerMixin — unsupported abstract methods | |
| # --------------------------------------------------------------------------- | |
| def update( # type: ignore[override] | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| cache_kwargs: dict | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Not supported — ShramLayerCache has no composite update interface. | |
| The sliding-window and MoSRAH sub-caches have materially different update semantics: the sliding-window | |
| side uses standard key/value concatenation while the MoSRAH side uses expert-choice | |
| scatter with an active mask. Callers must update each sub-cache directly via | |
| sliding_window_cache.update() or mosrah_cache.update(). | |
| """ | |
| raise NotImplementedError( | |
| "ShramLayerCache has no composite update interface. " | |
| "Update sliding_window_cache or mosrah_cache directly." | |
| ) | |
| def get_max_cache_shape(self) -> int: # type: ignore[override] | |
| """Return the maximum sequence length this layer cache can serve. | |
| The authoritative upper bound is ``config.inference_sequence_length``, which | |
| governs the full accumulated token history the model is configured to handle. | |
| HuggingFace's static-cache machinery reads this value to determine whether the | |
| cache is compileable and to size generation loops. | |
| """ | |
| return self._inference_sequence_length | |
| def get_mask_sizes( # type: ignore[override] | |
| self, | |
| cache_position: torch.Tensor, | |
| ) -> tuple[int, int]: | |
| """Return the KV dimensions for HuggingFace causal mask construction. | |
| Returns (inference_sequence_length, 0): the full static cache capacity as | |
| kv_length and zero offset. HuggingFace reads these values to size the causal | |
| attention mask when is_compileable is True. | |
| """ | |
| return self._inference_sequence_length, 0 | |
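The fan-out pattern used by every batch-level method of `ShramLayerCache` can be sketched in plain Python. `ToySubCache` and `ToyLayerCache` are hypothetical stand-ins that hold row labels instead of tensors. The sketch shows one point: a single composite call reorders, expands, or trims all three sub-caches with the same indices, so their batch rows cannot drift apart.

```python
class ToySubCache:
    """Hypothetical stand-in for one sub-cache: one label per batch row."""

    def __init__(self, rows):
        self.rows = list(rows)

    def reorder_cache(self, beam_idx):
        # Gather rows in the order chosen by beam search.
        self.rows = [self.rows[i] for i in beam_idx]

    def batch_repeat_interleave(self, repeats):
        # [a, b] with repeats=2 -> [a, a, b, b], like repeat_interleave on dim 0.
        self.rows = [row for row in self.rows for _ in range(repeats)]

    def batch_select_indices(self, indices):
        # Keep only the listed rows, as contrastive search does.
        self.rows = [self.rows[i] for i in indices]


class ToyLayerCache:
    """Hypothetical composite that forwards each batch operation to every sub-cache."""

    SUB_CACHE_NAMES = ("sliding_window", "mosrah", "router")

    def __init__(self, batch_labels):
        self.sub_caches = {
            name: ToySubCache(f"{name}:{label}" for label in batch_labels)
            for name in self.SUB_CACHE_NAMES
        }

    def _broadcast(self, method, *args):
        # Same call, same arguments, applied to every sub-cache in turn.
        for sub_cache in self.sub_caches.values():
            getattr(sub_cache, method)(*args)

    def reorder_cache(self, beam_idx):
        self._broadcast("reorder_cache", beam_idx)

    def batch_repeat_interleave(self, repeats):
        self._broadcast("batch_repeat_interleave", repeats)

    def batch_select_indices(self, indices):
        self._broadcast("batch_select_indices", indices)

    def row_owners(self, name):
        """Return the original batch label held by each row of one sub-cache."""
        return [row.split(":")[1] for row in self.sub_caches[name].rows]


cache = ToyLayerCache(["a", "b"])
cache.batch_repeat_interleave(2)   # every sub-cache: a, a, b, b
cache.reorder_cache([2, 0, 3, 1])  # every sub-cache: b, a, b, a
print({name: cache.row_owners(name) for name in cache.SUB_CACHE_NAMES})
```

If any one of the three delegations were skipped, that sub-cache would still hold rows in the old order. Its state would then belong to different hypotheses than the other two.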
| class ShramCache(Cache): | |
| """Top-level cache for the full SHRAM model. | |
| Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache | |
| role and transparently forwards reset, reorder, and sequence-length queries across all | |
| owned layer caches. | |
| No composite update() interface is provided. The two attention paths inside each SHRAM | |
| layer have materially different update semantics; callers must update sub-caches directly | |
| via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache. | |
| ShramCache also tracks per-batch cumulative active token counts via | |
| ``_active_token_counts``. ``total_active_tokens(active_mask)`` returns the accumulated | |
| count before the current step and updates the buffer in-place; the caller uses this as a | |
| per-batch position bias for contiguous arange-based position ID resolution. All counter | |
| updates are in-place to satisfy CUDAGraph fixed-memory requirements. ``reset()`` | |
| zeroes the buffer along with all layer caches. | |
| Args: | |
| config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache | |
| dimensions are derived from config so that a single source of truth governs | |
| every buffer size across the full cache stack. | |
| batch_size: Number of sequences in the batch. | |
| device: Device on which to allocate cache tensors. | |
| """ | |
| is_compileable = True | |
| def __init__( | |
| self, | |
| config: ShramConfig, | |
| batch_size: int, | |
| device: torch.device, | |
| ) -> None: | |
| layers = [ | |
| ShramLayerCache( | |
| config=config, | |
| batch_size=batch_size, | |
| device=device, | |
| ) | |
| for _ in range(config.num_decoder_layers) | |
| ] | |
| super().__init__(layers=layers) | |
| # Active token counter for position ID resolution (Unit 23.B). Pre-allocated | |
| # at construction so all updates remain in-place across forward passes, | |
| # satisfying CUDAGraph fixed-memory requirements. | |
| self._active_token_counts: torch.Tensor = torch.zeros( | |
| batch_size, dtype=torch.long, device=device | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Cache — composite-meaningful methods | |
| # --------------------------------------------------------------------------- | |
| # | |
| # reset(): Overridden. Zeroes _active_token_counts in-place, then delegates to | |
| # the inherited implementation to reset all layer caches. | |
| # | |
| # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each. | |
| # | |
| # is_initialized: Inherited property. True iff all layer caches are initialized. | |
| # Since ShramLayerCache.is_initialized is True from construction, this is True | |
| # immediately after ShramCache.__init__ returns. | |
| def total_active_tokens(self, active_mask: torch.BoolTensor) -> torch.Tensor: | |
| """Return the per-batch accumulated active token count before this step, then update. | |
| Reads the current per-batch accumulated count as a position bias for the caller, | |
| then increments the internal counter in-place by the number of active tokens in | |
| ``active_mask`` for each batch item. The pre-update count is returned so the | |
| caller can offset an arange-based position tensor to the correct starting position | |
| for this forward pass. | |
| All updates are in-place to satisfy CUDAGraph fixed-memory requirements. The | |
| counter persists across forward passes until ``reset()`` is called. | |
| Args: | |
| active_mask: Boolean mask of shape ``(B, N)`` for the current forward step, | |
| where True marks an active (non-padding) token position. | |
| Returns: | |
| Integer tensor of shape ``(B,)`` — the accumulated count before this update. | |
| """ | |
| prior_counts = self._active_token_counts.clone() | |
| self._active_token_counts.add_(active_mask.sum(dim=-1)) | |
| return prior_counts | |
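The next sketch mirrors the counter contract of `total_active_tokens` in pure Python: return a snapshot of the counts, then add each row's active-token total in place. `ToyActiveTokenCounter` is hypothetical. Reading the returned bias as "the position of the next real token" is an assumption about how the caller uses it; the caller's handling of padded slots is not shown in this section.

```python
class ToyActiveTokenCounter:
    """Hypothetical pure-Python mirror of ShramCache.total_active_tokens."""

    def __init__(self, batch_size):
        # Sized once, like the torch buffer allocated in ShramCache.__init__.
        self.counts = [0] * batch_size

    def total_active_tokens(self, active_mask):
        prior = list(self.counts)  # snapshot returned to the caller
        for b, row in enumerate(active_mask):
            # Mutate the existing list, analogous to Tensor.add_() on the buffer.
            self.counts[b] += sum(row)
        return prior

    def reset(self):
        # Zero in place, analogous to Tensor.zero_().
        for b in range(len(self.counts)):
            self.counts[b] = 0


counter = ToyActiveTokenCounter(batch_size=2)
# Prefill: row 0 is left-padded by one slot, row 1 is fully active.
prefill_bias = counter.total_active_tokens([[False, True, True], [True, True, True]])
# First decode step: one new active token per row.
decode_bias = counter.total_active_tokens([[True], [True]])
print(prefill_bias, decode_bias, counter.counts)  # [0, 0] [2, 3] [3, 4]
```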
| def reset(self) -> None: | |
| """Clear all layer caches and reset the active token counter. | |
| Zeroes ``_active_token_counts`` in-place, then delegates to the inherited | |
| implementation to reset all ShramLayerCache instances. In-place mutation of | |
| the counter is required for CUDAGraph compatibility — the buffer must remain | |
| at the same memory address across steps. | |
| """ | |
| self._active_token_counts.zero_() | |
| super().reset() | |
| def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override] | |
| """Return the cumulative sequence length for the specified layer. | |
| Delegates to the layer cache at layer_idx, which in turn delegates to the | |
| local sliding-window sub-cache. That sub-cache is authoritative for sequence | |
| progress: it sees every token presented to the layer and accumulates a truthful | |
| total count. Defaults to layer 0, which is sufficient for HuggingFace generation. | |
| """ | |
| return self.layers[layer_idx].get_seq_length() | |
| # --------------------------------------------------------------------------- | |
| # Cache — unsupported methods | |
| # --------------------------------------------------------------------------- | |
| def update( # type: ignore[override] | |
| self, | |
| key_states: torch.Tensor, | |
| value_states: torch.Tensor, | |
| layer_idx: int, | |
| cache_kwargs: dict | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Not supported — ShramCache has no composite update interface. | |
| The two attention paths inside each SHRAM layer have different update semantics. | |
| Callers must update sub-caches directly: | |
| cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states, active_mask, position_ids) | |
| cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask) | |
| """ | |
| raise NotImplementedError( | |
| "ShramCache has no composite update interface. " | |
| "Update sliding_window_cache or mosrah_cache on the relevant layer directly." | |
| ) | |
| def crop(self, max_length: int) -> None: | |
| """Not supported — ShramCache layers do not implement crop().""" | |
| raise NotImplementedError("ShramCache does not support crop().") | |
| @property | |
| def max_batch_size(self) -> int: | |
| """Not supported — ShramCache does not track a uniform batch size across layers.""" | |
| raise NotImplementedError("ShramCache does not expose max_batch_size.") | |
| @property | |
| def max_cache_len(self) -> int: | |
| """Return the maximum sequence length the cache can serve. | |
| Delegates to layers[0].get_max_cache_shape(), which returns | |
| config.inference_sequence_length. HuggingFace's static-cache machinery reads | |
| this value to size generation loops and verify compileable cache contracts. | |
| """ | |
| return self.layers[0].get_max_cache_shape() | |
| # ----------- | |
| # Inlined from: model.py | |
| # ----------- | |
| """Transformer backbone for Shram. | |
| ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed | |
| by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual | |
| representations. It has no knowledge of tokens, vocabulary, generation, or the | |
| HuggingFace causal-LM wrapper contract. | |
| Keeping the embedding out of the backbone is the correct convention and makes | |
| the backbone genuinely modality-agnostic. The token interface — embedding lookup, | |
| LM head, weight tying, and generation-facing naming conventions — belongs on the | |
| task wrapper (ShramForCausalLM), which is the only class that knows this | |
| backbone is being used for language modelling. | |
| The final RMSNorm is necessary because the decoder stack uses pre-norm throughout: | |
| each sublayer normalises its own input, leaving the residual stream itself | |
| unnormalised. After many layers of accumulated residuals, that stream arrives at | |
| the top with uncontrolled magnitude. The final norm brings it to a well-scaled | |
| state before any projection. Without it, the LM head would receive signals of | |
| arbitrary scale. | |
| Caching is caller-managed. If a ShramCache is provided, ShramModel threads the | |
| corresponding per-layer ShramLayerCache into each DecoderLayer and returns the | |
| same top-level ShramCache object in the output dict. If None is provided, no | |
| caching occurs. | |
| Returns a plain dict with keys: | |
| - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size) | |
| - "past_key_values": the ShramCache object passed in, or None | |
| - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None | |
| - "regret_loss": scalar sum of per-layer SHRAM regret losses | |
| - "logit_regret": detached scalar mean per-layer logit-space regret | |
| - "logit_std": detached scalar mean per-layer per-token routing logit spread | |
| """ | |
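The rescaling done by the final norm can be shown with a hand-written RMSNorm. It follows the formula documented for `torch.nn.RMSNorm`: y = x / sqrt(mean(x²) + eps) · weight. The `eps=1e-6` default and the unit `weight` are illustrative assumptions, not values read from `ShramConfig`.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm of one vector: no mean subtraction, only division by the root mean square."""
    rms_value = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms_value for v, w in zip(x, weight)]


def rms(x):
    """Root mean square of a vector, used here only to inspect the result."""
    return math.sqrt(sum(v * v for v in x) / len(x))


# A residual stream whose magnitude has grown across many pre-norm layers,
# and the same direction at a hundredth of the scale.
large = [300.0, -400.0, 0.0, 0.0]
small = [3.0, -4.0, 0.0, 0.0]
unit_weight = [1.0] * 4
print(rms(large), rms(rms_norm(large, unit_weight)))  # 250.0 and approximately 1.0
```

Both inputs come out at essentially the same unit-RMS vector. A downstream projection therefore sees a consistent scale however much the residual stream has grown.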
| # ----------- | |
| # Inlined from: decoder_layer.py | |
| # ----------- | |
| """Decoder layer — a single transformer block. | |
| Each block applies pre-norm hybrid attention followed by pre-norm MLP, with | |
| gated residual connections around both sublayers: | |
| normed_attn = RMSNorm(x) | |
| attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...) | |
| h = x + attn_residual_scale * attn_out | |
| normed_mlp = RMSNorm(h) | |
| mlp_out = SwiGLUMLP(normed_mlp) | |
| out = h + mlp_residual_scale * mlp_out | |
| ``attn_residual_scale`` and ``mlp_residual_scale`` are always present. Their nature | |
| depends on ``config.use_residual_gate``: | |
| - ``True`` (default): learnable scalar ``nn.Parameter`` initialised to zero. The layer | |
| is a pure identity at initialisation and the scales open during training. | |
| - ``False``: fixed buffer ``1/√num_decoder_layers``. No learnable parameter; residual | |
| variance sums to O(1) across depth by construction. | |
| Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly | |
| through unnormalised residuals at depth, and each sublayer receives a stable, | |
| normalised view of the signal. | |
| Two independent RMSNorm instances are used — one before attention, one before | |
| MLP. They learn different scalings because they precede layers with different | |
| dynamic ranges. Sharing them would be wrong. | |
| torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean | |
| subtraction, is faster than LayerNorm, and proved more stable at scale. | |
| """ | |
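The block pseudocode and the two `use_residual_gate` settings can be checked with plain lists. `toy_attn` and `toy_mlp` are hypothetical stand-ins for `SHRAMHybridLayer` and `SwiGLUMLP`. The norm carries no learned weight here, so one helper serves both positions; the real layer uses two separate `RMSNorm` instances. The sketch shows only two things: a zero gate makes the block an exact identity, and the fixed scale equals 1/√L.

```python
import math


def rms_norm(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def add_scaled(base, scale, update):
    # Residual update: base + scale * update, elementwise.
    return [b + scale * u for b, u in zip(base, update)]


def decoder_block(x, attn, mlp, attn_scale, mlp_scale):
    """Pre-norm block: each sublayer sees a normalised input, the residual does not."""
    h = add_scaled(x, attn_scale, attn(rms_norm(x)))
    return add_scaled(h, mlp_scale, mlp(rms_norm(h)))


def toy_attn(v):  # hypothetical stand-in for the attention sublayer
    return [2.0 * t for t in v]


def toy_mlp(v):  # hypothetical stand-in for the MLP sublayer
    return [t + 1.0 for t in v]


x = [1.0, -2.0, 3.0]
# use_residual_gate=True: learnable scales start at zero, so the block is the identity.
identity_out = decoder_block(x, toy_attn, toy_mlp, 0.0, 0.0)
# use_residual_gate=False: fixed scale 1/sqrt(num_decoder_layers).
num_decoder_layers = 16
fixed = 1.0 / math.sqrt(num_decoder_layers)
print(identity_out, fixed, decoder_block(x, toy_attn, toy_mlp, fixed, fixed))
```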
| # ----------- | |
| # Inlined from: shram.py | |
| # ----------- | |
| """SHRAM hybrid attention layer. | |
| This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x) | |
| used at one decoder attention slot in SHRAM. | |
| The local sliding-window path and the MoSRAH sparse path are already verified | |
| independently. The responsibility here is therefore not to introduce new | |
| attention logic, but to preserve the bridge contracts between them: both paths | |
| must consume the same input hidden state, each path must receive the sub-cache | |
| it actually owns, the two model-space outputs must be summed directly, and the | |
| sparse-path load-balance loss must remain visible to the caller. | |
| """ | |
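The bridge contract fits in a few lines. `hybrid_attention`, `local_path`, and `sparse_path` are hypothetical names used for illustration; the real `SHRAMHybridLayer` signature is not shown in this section.

```python
def hybrid_attention(x, local_path, sparse_path, local_cache=None, sparse_cache=None):
    """H(x) = h_l(x) + h_s(x) for one decoder attention slot (illustrative only)."""
    # Both paths read the same input, and each receives only the sub-cache it owns.
    local_out = local_path(x, local_cache)
    sparse_out, sparse_loss = sparse_path(x, sparse_cache)
    # Both outputs are already in model space, so they are summed directly.
    combined = [a + b for a, b in zip(local_out, sparse_out)]
    # The sparse-path auxiliary loss is passed through unchanged for the caller.
    return combined, sparse_loss


out, loss = hybrid_attention(
    [1.0, 2.0],
    local_path=lambda x, cache: [2.0 * v for v in x],
    sparse_path=lambda x, cache: ([v + 1.0 for v in x], 0.5),
)
print(out, loss)  # [4.0, 7.0] 0.5
```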
| # ----------- | |
| # Inlined from: sliding_window_attention.py | |
| # ----------- | |
| # src/shram/model/attention/sliding_window_attention.py | |
| """Local sliding-window attention path for SHRAM. | |
| This file defines `SlidingWindowAttention`, the local short-range attention path | |
| used inside the SHRAM hybrid layer. | |
| In the masked-continuation variant, the local cache no longer returns a | |
| semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache` | |
| returns: | |
| - the retained local window memory concatenated with the current chunk | |
| - an aligned active mask over that returned frame | |
| This module consumes that returned frame directly and constructs effective local | |
| causal/window visibility from the mask. It does not own cache retention policy; | |
| it owns only local attention semantics. | |
| """ | |
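The next sketch shows one plausible way effective visibility could be built from the returned frame, for a single batch row. The exact predicate lives in the `sliding_window_mask` function further down, which is cut off in this section. The rule used here is therefore an assumption: a slot is visible if it is active, is not in the future by absolute position, and lies fewer than `window_size` positions back. Placing the queries at the last slots of the frame matches `query_offset = kv_len - query_len`.

```python
def local_visibility(query_positions, kv_positions, kv_active, window_size):
    """Boolean visibility grid (queries x frame slots) for one batch row.

    Assumed rule: a frame slot is visible to a query when the slot is active,
    its absolute position is not after the query's, and it lies fewer than
    `window_size` positions back.
    """
    return [
        [
            is_active and kv_pos <= q_pos and q_pos - kv_pos < window_size
            for kv_pos, is_active in zip(kv_positions, kv_active)
        ]
        for q_pos in query_positions
    ]


# Returned frame: 3 retained memory slots (the middle one dead) + a 2-token chunk.
kv_positions = [4, 5, 6, 7, 8]
kv_active = [True, False, True, True, True]
query_positions = [7, 8]  # the current chunk occupies the last two frame slots
for row in local_visibility(query_positions, kv_positions, kv_active, window_size=3):
    print("".join("x" if visible else "." for visible in row))
```

The dead slot stays in the frame but is never visible. This is the sense in which liveness comes from the mask rather than from compacting the buffer.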
| # ----------- | |
| # Inlined from: rope.py | |
| # ----------- | |
| """Rotary Position Embeddings (RoPE). | |
| RoPE encodes position in the *relationship* between query and key vectors. When the | |
| attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce | |
| a score that depends only on the relative distance — not on absolute positions. | |
| Two modes are supported: | |
| default Standard RoPE with base frequency b. Each dimension pair d is assigned | |
| frequency θ_d = b^{-2d/u} where u is the head dimension. The attention | |
| scaling A_rope = 1. | |
| yarn YaRN frequency interpolation for long-context extrapolation (Peng et al., | |
| "YaRN: Efficient Context Window Extension of Large Language Models", 2023, | |
| §A.2). Three frequency regimes: | |
| - Low-frequency dimensions (r < α): fully interpolated by scale s. | |
| These dimensions have long wavelengths relative to the training window | |
| and must be compressed to avoid out-of-distribution positions. | |
| - High-frequency dimensions (r > β): left unchanged. Short-wavelength | |
| dimensions already encode relative position accurately at any scale. | |
| - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r). | |
| Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to | |
| standard RoPE. | |
| Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit | |
| parameters — no shared instance, no config reading. See Unit 5.A design decisions. | |
| Cache sharing: all instances with identical parameters share one cos/sin table via a | |
| class-level registry. The first instance that needs a particular (parameters, device, | |
| dtype) combination builds the table; all subsequent instances reference it directly. | |
| This avoids redundant builds across decoder layers: each of the config.num_decoder_layers | |
| layers constructs its own instances, but those with the same parametrisation share one table. | |
| """ | |
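The relative-distance property can be checked numerically on a single dimension pair. `rotate_pair` and `rope_score` are illustrative helpers, not part of the module. After rotating the query by q_pos·θ and the key by k_pos·θ, the dot product depends only on k_pos − q_pos.

```python
import math


def rotate_pair(x0, x1, angle):
    """Rotate one 2-D dimension pair by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


def rope_score(q, k, q_pos, k_pos, theta):
    """Dot product of a rotated 2-D query pair and key pair at frequency theta."""
    qr = rotate_pair(*q, q_pos * theta)
    kr = rotate_pair(*k, k_pos * theta)
    return qr[0] * kr[0] + qr[1] * kr[1]


q, k, theta = (0.3, -1.2), (0.8, 0.5), 0.1
# Same relative distance (key two positions before the query) at different absolute positions.
print(rope_score(q, k, 5, 3, theta), rope_score(q, k, 105, 103, theta))
```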
| # --------------------------------------------------------------------------- | |
| # Rotation helper | |
| # --------------------------------------------------------------------------- | |
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| """Apply the 90° rotation used in the RoPE update formula. | |
| Splits the last dimension into two halves [x1, x2] and returns [-x2, x1]. | |
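| Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation | |
| on each dimension pair (i, i + u/2), where u is the head dimension (the split-halves | |
| layout). This equals the block-diagonal operator R^u_{Θ,p} in the paper, which rotates | |
| consecutive pairs, up to a fixed permutation of the head dimensions that leaves q·k | |
| unchanged. | |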
| """ | |
| d = x.shape[-1] // 2 | |
| x1, x2 = x[..., :d], x[..., d:] | |
| return torch.cat([-x2, x1], dim=-1) | |
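The following is a pure-Python check that the split-halves update `x * cos + rotate_half(x) * sin` is the same as rotating each pair (i, i + u/2) explicitly. The angle table is duplicated across both halves, as `_build_cache` does with `torch.cat((freqs, freqs), dim=-1)`. The helper names are illustrative.

```python
import math


def rotate_half(x):
    """Pure-Python mirror of _rotate_half: [x1, x2] -> [-x2, x1]."""
    d = len(x) // 2
    return [-v for v in x[d:]] + x[:d]


def apply_rope_half(x, angles):
    """x * cos + rotate_half(x) * sin, with per-pair angles duplicated across both halves."""
    cos = [math.cos(a) for a in angles + angles]
    sin = [math.sin(a) for a in angles + angles]
    rh = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rh, cos, sin)]


def apply_rope_pairs(x, angles):
    """Explicit 2-D rotation of each pair (i, i + d) by angles[i]."""
    d = len(x) // 2
    out = list(x)
    for i, a in enumerate(angles):
        out[i] = x[i] * math.cos(a) - x[i + d] * math.sin(a)
        out[i + d] = x[i] * math.sin(a) + x[i + d] * math.cos(a)
    return out


x = [1.0, 2.0, 3.0, 4.0]
angles = [0.3, 0.7]  # one angle per dimension pair
print(apply_rope_half(x, angles))
print(apply_rope_pairs(x, angles))
```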
| # --------------------------------------------------------------------------- | |
| # RotaryEmbedding | |
| # --------------------------------------------------------------------------- | |
| class RotaryEmbedding(nn.Module): | |
| """Rotary Position Embeddings with explicit mode and parameter control. | |
| Each caller constructs its own instance with the exact parameters it needs. | |
| h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No | |
| config object is read inside this module. | |
| The cos/sin table is built at construction time to cover all positions in | |
| ``[0, maximum_sequence_length)``. In forward, the table is rebuilt only if | |
| the query tensor's dtype or device has changed since construction. | |
| Instances with identical parameters share one cos/sin table via the class-level | |
| ``_cache`` registry, avoiding redundant computation across decoder layers. | |
| Args: | |
| mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation. | |
| head_dim: Per-head embedding dimension ``u``. Must be even. | |
| theta: Base frequency ``b`` in θ_d = b^{-2d/u}. | |
| maximum_sequence_length: Maximum number of positions the table must cover. | |
| The cos/sin table is preallocated to this length at construction time. | |
| For ``mode="yarn"``, the training context length C_train is derived | |
| internally as ``round(maximum_sequence_length / dilation)``. | |
| dilation: Scale factor ``s = C_target / C_train`` — how much the context | |
| window is extended beyond training length. Required for ``mode="yarn"``. | |
| When ``dilation=1.0``, YaRN reduces to standard RoPE. | |
| alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully | |
| interpolated. Required for ``mode="yarn"``. | |
| beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left | |
| unchanged. Required for ``mode="yarn"``. | |
| device: Optional device for initial buffer placement. | |
| Raises: | |
| NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``. | |
| ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``, | |
| ``beta`` are absent. | |
| """ | |
| # Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table). | |
| # Shared across all RotaryEmbedding instances in the process. Keys include device | |
| # and dtype so that tables built on different devices or in different precisions | |
| # are stored independently. | |
| _cache: dict = {} | |
| def __init__( | |
| self, | |
| mode: str, | |
| head_dim: int, | |
| theta: float, | |
| maximum_sequence_length: int, | |
| dilation: float | None = None, | |
| alpha: float | None = None, | |
| beta: float | None = None, | |
| device: torch.device | None = None, | |
| ) -> None: | |
| super().__init__() | |
| self._validate_mode(mode) | |
| self._validate_yarn_params(mode, dilation, alpha, beta) | |
| self.mode = mode | |
| self._maximum_sequence_length = maximum_sequence_length | |
| device = torch.device("cpu") if device is None else device | |
| # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn). | |
| # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair, | |
| # so rotation_freqs has head_dim/2 entries. | |
| d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) | |
| base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u} | |
| if mode == "default": | |
| rotation_freqs = base_freqs | |
| self.attention_scaling: float = 1.0 | |
| else: # yarn | |
| s = dilation | |
| # C_train is the training context length, recovered from the inference | |
| # context length and the dilation factor. round() guards against floating | |
| # point error since both underlying quantities are integers. | |
| c_train: int = round(maximum_sequence_length / dilation) | |
| # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp | |
| # function to classify each dimension into one of three regimes. | |
| normalized_freqs = c_train * base_freqs / (2.0 * math.pi) | |
| # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged), | |
| # linear blend between α and β. | |
| blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0) | |
| # θ_d' = (1 − γ) · θ_d / s + γ · θ_d | |
| rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs | |
| # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller. | |
| self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2 | |
| # freq_key uniquely identifies the parameter set that produced rotation_freqs, | |
| # including maximum_sequence_length so instances with different table sizes | |
| # do not collide in the registry. | |
| if mode == "default": | |
| self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length) | |
| else: | |
| self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta) | |
| # rotation_freqs is a plain instance attribute, not a registered buffer. | |
| # This keeps it out of the state dict and prevents HuggingFace's fast-init | |
| # path from turning it into a meta tensor, which would break _build_cache. | |
| self.rotation_freqs = rotation_freqs | |
| # Cache tensors are plain instance attributes (not registered buffers) so that | |
| # sharing across identically-parametrised instances survives .to() calls. | |
| # Module.to() replaces each registered buffer with a per-module copy, which would | |
| # break sharing; it leaves plain attributes untouched, so they keep aliasing the | |
| # registry tensors. forward() detects any resulting device/dtype mismatch and | |
| # re-points them via _build_cache. | |
| self._cos_cached: torch.Tensor | None = None | |
| self._sin_cached: torch.Tensor | None = None | |
| # Build the table at construction time. Forward rebuilds only on dtype or | |
| # device change. If no device is specified, build on CPU as the default. | |
| self._build_cache(device=device, dtype=torch.float32) | |
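The three-regime blend is easier to see with concrete numbers. The sketch below restates the YaRN branch of `__init__` in plain Python. The values (`head_dim=64`, `theta=10000.0`, a 32768-token target with `dilation=8.0`, `alpha=1.0`, `beta=32.0`) are examples chosen for illustration, not SHRAM defaults. The highest-frequency pair comes out unchanged, the lowest-frequency pair is divided by s, and A_rope ≈ 1.459.

```python
import math


def yarn_frequencies(head_dim, theta, max_len, dilation, alpha, beta):
    """Pure-Python restatement of the YaRN branch in RotaryEmbedding.__init__."""
    # θ_d = b^{-2d/u}, one entry per dimension pair (loop index i = 2d).
    base = [1.0 / theta ** (i / head_dim) for i in range(0, head_dim, 2)]
    c_train = round(max_len / dilation)  # training context length C_train
    rotated = []
    for freq in base:
        r = c_train * freq / (2.0 * math.pi)  # normalized frequency r(d)
        gamma = min(max((r - alpha) / (beta - alpha), 0.0), 1.0)  # ramp γ(r)
        # θ_d' = (1 - γ) · θ_d / s + γ · θ_d
        rotated.append((1.0 - gamma) * freq / dilation + gamma * freq)
    scaling = (0.1 * math.log(dilation) + 1.0) ** 2  # A_rope
    return base, rotated, scaling


base, yarn, scale = yarn_frequencies(
    head_dim=64, theta=10000.0, max_len=32768, dilation=8.0, alpha=1.0, beta=32.0
)
print(yarn[0] / base[0], yarn[-1] / base[-1], round(scale, 4))
```

With `dilation=1.0`, every blended frequency equals its base frequency and the scaling is exactly 1. This is the "YaRN reduces to standard RoPE" case noted in the docstring.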
| # --------------------------------------------------------------------------- | |
| # Validation helpers | |
| # --------------------------------------------------------------------------- | |
| @staticmethod | |
| def _validate_mode(mode: str) -> None: | |
| """Raise NotImplementedError if mode is not a supported value.""" | |
| if mode not in {"default", "yarn"}: | |
| raise NotImplementedError( | |
| f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'." | |
| ) | |
| @staticmethod | |
| def _validate_yarn_params( | |
| mode: str, | |
| dilation: float | None, | |
| alpha: float | None, | |
| beta: float | None, | |
| ) -> None: | |
| """Raise ValueError if mode='yarn' and any required parameter is absent.""" | |
| if mode != "yarn": | |
| return | |
| missing = [ | |
| name for name, val in [ | |
| ("dilation", dilation), | |
| ("alpha", alpha), | |
| ("beta", beta), | |
| ] | |
| if val is None | |
| ] | |
| if missing: | |
| raise ValueError(f"mode='yarn' requires {missing}.") | |
| # --------------------------------------------------------------------------- | |
| # Cache management | |
| # --------------------------------------------------------------------------- | |
| def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None: | |
| """Build the cos/sin table to cover positions [0, maximum_sequence_length). | |
| Checks the class-level registry first. If a table already exists for this | |
| exact (parameters, device, dtype) combination it is reused directly; | |
| otherwise it is computed and stored. The instance attributes are pointed at | |
| the registry entry so that all layers sharing the same parametrisation | |
| reference the same tensor. | |
| """ | |
| cache_key = (self._freq_key, str(device), str(dtype)) | |
| if cache_key not in RotaryEmbedding._cache: | |
| positions = torch.arange( | |
| self._maximum_sequence_length, device=device, dtype=torch.float32 | |
| ) | |
| # outer product → (maximum_sequence_length, head_dim // 2); | |
| # duplicate to (maximum_sequence_length, head_dim) | |
| freqs = torch.outer( | |
| positions, | |
| self.rotation_freqs.to(device=device, dtype=torch.float32), | |
| ) | |
| angle_embedding = torch.cat((freqs, freqs), dim=-1) | |
| RotaryEmbedding._cache[cache_key] = ( | |
| angle_embedding.cos().to(dtype), | |
| angle_embedding.sin().to(dtype), | |
| ) | |
| self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key] | |
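The registry behaviour of `_build_cache` can be sketched without tensors. `ToyRotaryTable` is hypothetical and stores placeholder strings where the real class stores cos/sin tensors. It shows that identically parametrised instances share one table object, and that a new dtype or device triggers exactly one extra build.

```python
class ToyRotaryTable:
    """Hypothetical mirror of RotaryEmbedding's class-level cos/sin registry."""

    _cache = {}  # (freq_key, device, dtype) -> table, shared by all instances
    builds = 0   # how many tables were actually computed

    def __init__(self, head_dim, theta, max_len, device="cpu", dtype="float32"):
        self._freq_key = ("default", head_dim, theta, max_len)
        self._max_len = max_len
        self._build(device, dtype)

    def _build(self, device, dtype):
        key = (self._freq_key, device, dtype)
        if key not in ToyRotaryTable._cache:
            ToyRotaryTable.builds += 1
            # Placeholder rows stand in for the (max_len, head_dim) cos/sin tensors.
            ToyRotaryTable._cache[key] = [f"pos{p}" for p in range(self._max_len)]
        # Point this instance at the shared registry entry.
        self.table = ToyRotaryTable._cache[key]


layers = [ToyRotaryTable(64, 10000.0, 8) for _ in range(4)]  # e.g. four decoder layers
print(ToyRotaryTable.builds, all(layer.table is layers[0].table for layer in layers))
```

In the code shown here the registry is never cleared, so every table built in the process stays alive. That is the cost of avoiding per-layer rebuilds.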
| def forward( | |
| self, | |
| q: torch.Tensor, | |
| k: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor, float]: | |
| """Apply rotary embeddings to query and key tensors. | |
| The cos/sin table is built at construction time. It is rebuilt here only | |
| if ``q``'s dtype or device differs from the cached table — for example, | |
| after moving the model to a different device via ``.cuda()``. | |
| ``position_ids`` may be any integer tensor shape. Its values must be in | |
| ``[0, maximum_sequence_length)``: | |
| - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim). | |
| - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim). | |
| When q/k have head dimensions absent from position_ids, broadcast dimensions | |
| are inserted automatically at dim 1. | |
| Args: | |
| q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim). | |
| k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim). | |
| position_ids: Integer positions of shape (batch, *pos_dims). | |
| Returns: | |
| Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is | |
| 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must | |
| apply to attention logits before softmax. | |
| """ | |
| wrong_dtype = self._cos_cached.dtype != q.dtype | |
| wrong_device = self._cos_cached.device != q.device | |
| if wrong_dtype or wrong_device: | |
| self._build_cache(device=q.device, dtype=q.dtype) | |
| cos = self._cos_cached[position_ids] | |
| sin = self._sin_cached[position_ids] | |
| # Insert broadcast dimensions for any head axes present in q/k but absent | |
| # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once. | |
| # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed. | |
| while cos.ndim < q.ndim: | |
| cos = cos.unsqueeze(1) | |
| sin = sin.unsqueeze(1) | |
| q_rotated = q * cos + _rotate_half(q) * sin | |
| k_rotated = k * cos + _rotate_half(k) * sin | |
| return q_rotated, k_rotated, self.attention_scaling | |
| class SlidingWindowAttention(nn.Module): | |
| """Causal local sliding-window attention for one SHRAM layer. | |
| Args: | |
| config: SHRAM config. Must expose `embedding_width`, | |
| `num_sliding_window_heads`, `head_dim`, `window_size`, | |
| `attention_dropout`, `local_rope_theta`, and `inference_sequence_length`. | |
| Raises: | |
| NotImplementedError: If `attention_dropout != 0.0`. | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.hidden_size = config.embedding_width | |
| self.num_heads = config.num_sliding_window_heads | |
| self.head_dim = config.head_dim | |
| self.window_size = config.window_size | |
| self.attention_dropout = config.attention_dropout | |
| if self.attention_dropout != 0.0: | |
| raise NotImplementedError( | |
| "SlidingWindowAttention currently supports only " | |
| "attention_dropout == 0.0." | |
| ) | |
| self.inner_dim = self.num_heads * self.head_dim | |
| # Standard MHA projections for the local path. | |
| self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) | |
| self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) | |
| self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False) | |
| self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False) | |
| # The local path always uses default-mode RoPE with its own theta. | |
| self.rope = RotaryEmbedding( | |
| mode="default", | |
| head_dim=self.head_dim, | |
| theta=config.local_rope_theta, | |
| maximum_sequence_length=config.inference_sequence_length, | |
| ) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: LocalSlidingWindowLayerCache | None = None,
    ) -> torch.Tensor:
        """Apply local causal sliding-window attention.

        Args:
            x: Input tensor of shape `(B, N, hidden_size)`.
            position_ids: Position tensor of shape `(B, N)`.
            active_mask: Current-chunk active mask of shape `(B, N)`, where
                `True` means active.
            cache: Optional `LocalSlidingWindowLayerCache`.

        Returns:
            Output tensor of shape `(B, N, hidden_size)`.
        """
        batch_size, query_len, _ = x.shape
        self._validate_position_shape(x, position_ids)
        self._validate_active_mask_shape(x, active_mask)

        # (B, N, H*D) -> (B, H, N, D)
        q = self.q_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)
        k = self.k_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)
        v = self.v_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)
        q, k, attention_scaling = self.rope(q, k, position_ids)

        # The cache returns the current-step visible local frame, not merely the
        # retained next-step cache buffer.
        if cache is not None:
            k_full, v_full, full_active_mask, full_positions = cache.update(
                k, v, active_mask, position_ids
            )
        else:
            k_full, v_full, full_active_mask, full_positions = k, v, active_mask, position_ids

        block_mask = self._make_block_mask(
            active_mask=full_active_mask,
            positions=full_positions,
            batch_size=batch_size,
            num_heads=self.num_heads,
            query_len=query_len,
            kv_len=k_full.shape[-2],
            window_size=self.window_size,
            device=x.device,
        )
        attn_output = flex_attention(
            q,
            k_full,
            v_full,
            block_mask=block_mask,
            scale=attention_scaling / math.sqrt(self.head_dim),
        )

        # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, query_len, self.inner_dim)
        )
        return self.o_proj(attn_output)

    def _validate_position_shape(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> None:
        """Validate the position tensor shape expected by local RoPE."""
        if position_ids.shape != x.shape[:2]:
            raise ValueError(
                f"position_ids must have shape {tuple(x.shape[:2])}, "
                f"got {tuple(position_ids.shape)}."
            )

    def _validate_active_mask_shape(
        self,
        x: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> None:
        """Validate the current-chunk active-mask contract."""
        if active_mask.shape != x.shape[:2]:
            raise ValueError(
                f"active_mask must have shape {tuple(x.shape[:2])}, "
                f"got {tuple(active_mask.shape)}."
            )
        if active_mask.dtype != torch.bool:
            raise ValueError(
                f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
            )

    def _make_block_mask(
        self,
        active_mask: torch.Tensor,
        positions: torch.Tensor,
        batch_size: int,
        num_heads: int,
        query_len: int,
        kv_len: int,
        window_size: int,
        device: torch.device,
    ) -> Any:
        """Create the FlexAttention block mask for masked local continuation.

        The key/value frame (the cache output, or the current chunk when no
        cache is used) is chronological in raw buffer order; dead positions may
        remain inside it. Liveness is carried by `active_mask`. The current
        queries occupy the last `query_len` slots of that frame, so query row
        `q_idx` maps to frame slot `kv_len - query_len + q_idx`.

        Causality and window distance are determined from `positions`, which
        holds the absolute sequence position of every slot in the composite
        frame. Using absolute positions rather than a cumsum over the active
        mask eliminates the data-dependent computation that blocks torch.compile.
        """
        query_offset = kv_len - query_len

        def sliding_window_mask(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            q_abs = query_offset + q_idx
            query_is_active = active_mask[batch_idx, q_abs]
            key_is_active = active_mask[batch_idx, kv_idx]
            q_pos = positions[batch_idx, q_abs]
            k_pos = positions[batch_idx, kv_idx]
            is_causal = k_pos <= q_pos
            in_window = (q_pos - k_pos) < window_size
            return query_is_active & key_is_active & is_causal & in_window

        return create_block_mask(
            sliding_window_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=query_len,
            KV_LEN=kv_len,
            device=device,
        )
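

# Illustrative sketch (hypothetical helper, not used by the model): a plain-
# Python mirror of the `sliding_window_mask` predicate above for a single
# batch row, assuming the query chunk occupies the last `query_len` slots of
# the key/value frame. It returns the dense boolean visibility matrix that the
# FlexAttention block mask encodes. Because distance is measured on absolute
# positions, a dead slot inside the window stays invisible but still counts
# toward the window width.
def _example_sliding_window_visibility(
    positions: list[int],
    active: list[bool],
    window_size: int,
    query_len: int,
) -> list[list[bool]]:
    """Return a (query_len, kv_len) visibility matrix for one batch row."""
    kv_len = len(positions)
    query_offset = kv_len - query_len  # queries are the tail of the frame
    rows = []
    for q_idx in range(query_len):
        q_abs = query_offset + q_idx
        q_pos = positions[q_abs]
        row = []
        for kv_idx in range(kv_len):
            k_pos = positions[kv_idx]
            row.append(
                active[q_abs]
                and active[kv_idx]
                and k_pos <= q_pos  # causal
                and (q_pos - k_pos) < window_size  # inside the window
            )
        rows.append(row)
    return rows


# Example: six slots at positions 0..5, window 3, last two slots are queries.
# _example_sliding_window_visibility(list(range(6)), [True] * 6, 3, 2)
# -> [[False, False, True, True, True, False],
#     [False, False, False, True, True, True]]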


# -----------
# Inlined from: mosrah.py
# -----------
"""Full MoSRAH sparse path for SHRAM.

This module coordinates the routed sparse attention path used inside the SHRAM
hybrid attention layer. The underlying mechanics already live in verified
subunits. The responsibility here is to connect those subunits without
corrupting their bridge contracts.

In particular, this path must preserve three architectural distinctions:
- selected head indices are not routing probabilities
- packed position semantics are chosen before BEA, not inside it
- weighted reduction must consume the router's unbiased renormalized
  probabilities after token-choice order has been restored
"""
# -----------
# Inlined from: bottlenecked_ensemble_attention.py
# -----------
"""Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.

BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
It consumes packed expert-choice tensors, a supplied position tensor, an active
token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
same packed expert-choice space expected by later unpacking.

BEA does not compute positions and does not choose packed-position semantics.
Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
(K̃) and raw values (V) into the sparse cache and attends against the
accumulated cached state returned by that cache.
"""


class BottleneckedEnsembleAttention(nn.Module):
    """
    Packed expert-choice attention operator for the MoSRAH sparse path.

    Operates on each head independently over its packed ensemble of tokens.
    The FlexAttention block mask lets fully inactive blocks be skipped rather
    than computed.

    Architectural properties:
    - consumes packed expert-choice tensors of shape (B, L, T, d)
    - uses independent per-head Q/K/V/O projection parameters
    - applies YaRN-capable RoPE using supplied position_ids
    - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
    - computes attention with FlexAttention under a packed causal block mask
    - returns outputs in the same packed expert-choice space (B, L, T, d)

    Args:
        config: SHRAM config. Must expose `embedding_width`, `num_mosrah_heads`,
            `head_dim`, `mosrah_rope_theta`, `rope_mode`, `mosrah_packed_length`,
            `inference_sequence_length`, `scale`, `alpha`, and `beta`.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.hidden_size = config.embedding_width
        self.num_heads = config.num_mosrah_heads
        self.head_dim = config.head_dim

        # Independent per-head projections. No cross-head parameter sharing.
        self.q_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.k_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.v_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.o_proj = nn.Parameter(
            torch.empty(self.num_heads, self.head_dim, self.hidden_size)
        )
        self._reset_parameters()

        # BEA uses the YaRN-capable RoPE path. The caller supplies the position
        # tensor; this unit only consumes it. In training mode the dilation is
        # 1.0, so no YaRN dilation is applied.
        #
        # The required table size depends on position semantics:
        #   main_sequence: positions are original token positions, bounded by
        #       inference_sequence_length.
        #   semantic_sequence: positions are local per-expert slot indices,
        #       bounded by mosrah_packed_length.
        maximum_rope_length = (
            config.mosrah_packed_length
            if config.rope_mode == "semantic_sequence"
            else config.inference_sequence_length
        )
        self.rope = RotaryEmbedding(
            mode="yarn",
            head_dim=self.head_dim,
            theta=config.mosrah_rope_theta,
            maximum_sequence_length=maximum_rope_length,
            dilation=config.scale,
            alpha=config.alpha,
            beta=config.beta,
        )

    def forward(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None = None,
    ) -> torch.Tensor:
        """Apply BEA to packed expert-choice tensors.

        Args:
            packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
            position_ids: Supplied packed positions of shape (B, L, T).
            active_mask: Boolean active-token mask of shape (B, L, T).
            cache: Optional layer-local MoSRAH cache.

        Returns:
            Packed expert-choice output tensor of shape (B, L, T, d).
        """
        batch_size, _, query_length, _ = packed_embeddings.shape
        self._validate_tensor_shape(packed_embeddings)
        self._validate_position_shape(packed_embeddings, position_ids)
        self._validate_active_mask_shape(packed_embeddings, active_mask)

        # Independent per-head projections:
        # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
        query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
        key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
        value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
        rotated_query_states, rotated_key_states, attention_scaling = self.rope(
            query_states,
            key_states,
            position_ids,
        )

        if cache is not None:
            # In cached execution, the current query tensor uses local rows
            # 0..Q-1, but the key tensor returned by the cache is the full
            # accumulated packed sequence for each (batch, head) slot. The only
            # additional data needed to align those two views is the pre-update
            # cached prefix length, which records how many tokens each slot had
            # processed before this step. It is cloned before the update so the
            # pre-update value is preserved.
            num_tokens_processed = cache.get_heads_lengths().clone()
            key_states, value_states, key_active_mask = cache.update(
                rotated_key_states,
                value_states,
                active_mask,
            )
        else:
            num_tokens_processed = torch.zeros(
                batch_size,
                self.num_heads,
                dtype=torch.long,
                device=packed_embeddings.device,
            )
            key_states = rotated_key_states
            key_active_mask = active_mask

        block_mask = self._make_block_mask(
            query_active_mask=active_mask,
            key_active_mask=key_active_mask,
            num_tokens_processed=num_tokens_processed,
            query_length=query_length,
            key_length=key_states.shape[2],
            device=packed_embeddings.device,
        )
        attended_states = flex_attention(
            rotated_query_states,
            key_states,
            value_states,
            block_mask=block_mask,
            scale=attention_scaling / math.sqrt(self.head_dim),
        )

        # Project back to model width:
        # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
        return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)

    def _reset_parameters(self) -> None:
        """Initialize per-head projection weights."""
        for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            nn.init.xavier_uniform_(weight)

    def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
        """Validate the local packed-embedding shape contract required by BEA."""
        if packed_embeddings.shape[1] != self.num_heads:
            raise ValueError(
                f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
                f"got {packed_embeddings.shape[1]}."
            )
        if packed_embeddings.shape[-1] != self.hidden_size:
            raise ValueError(
                f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
                f"got {packed_embeddings.shape[-1]}."
            )

    def _validate_position_shape(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> None:
        """Validate the supplied packed-position tensor shape."""
        if position_ids.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(position_ids.shape)}."
            )

    def _validate_active_mask_shape(
        self,
        packed_embeddings: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> None:
        """Validate the supplied active-token mask shape."""
        if active_mask.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(active_mask.shape)}."
            )

    def _make_block_mask(
        self,
        query_active_mask: torch.Tensor,
        key_active_mask: torch.Tensor,
        num_tokens_processed: torch.Tensor,
        query_length: int,
        key_length: int,
        device: torch.device,
    ) -> Any:
        """Create the packed-sequence causal mask for FlexAttention.

        At root, causality is still triangular. The only nuance is cached
        execution: query rows are indexed locally as 0..Q-1 inside the current
        query tensor, but the key tensor may already contain a cached prefix for
        that (batch, head) slot. The causal horizon for query tensor row q is
        therefore:

            num_tokens_processed[b, h] + q

        Query and key activity masks are then composed with that triangular rule
        so FlexAttention can skip padded query rows and ignore inactive key slots.
        """
        batch_size, num_heads, _ = query_active_mask.shape

        # Build the per-(batch, head, query_row) triangular horizon from a simple
        # arange over query rows plus the cached prefix lengths for each slot.
        relative_query_positions = torch.arange(
            query_length,
            device=device,
            dtype=torch.long,
        ).view(1, 1, query_length)
        causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions

        def packed_causal_mask(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            query_idx: torch.Tensor,
            key_idx: torch.Tensor,
        ) -> torch.Tensor:
            query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
            key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
            is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
            return query_is_active & key_is_active & is_causal

        return create_block_mask(
            packed_causal_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=query_length,
            KV_LEN=key_length,
            device=device,
        )
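

# Illustrative sketch (hypothetical helper, not used by the model): the
# `packed_causal_mask` rule above for a single (batch, head) slot, in plain
# Python. With `num_tokens_processed` cached keys already in the slot, local
# query row q may attend to key indices 0..num_tokens_processed + q, provided
# both the query and the key are active.
def _example_packed_causal_visibility(
    num_tokens_processed: int,
    query_active: list[bool],
    key_active: list[bool],
) -> list[list[bool]]:
    """Return a (Q, K) visibility matrix for one (batch, head) slot."""
    rows = []
    for q_idx, q_live in enumerate(query_active):
        horizon = num_tokens_processed + q_idx  # causal horizon in key space
        rows.append([
            q_live and k_live and k_idx <= horizon
            for k_idx, k_live in enumerate(key_active)
        ])
    return rows


# Example: three cached keys plus one new query/key, all active.
# _example_packed_causal_visibility(3, [True], [True] * 4)
# -> [[True, True, True, True]]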


# -----------
# Inlined from: expert_packing.py
# -----------
"""Expert packing and unpacking for the MoSRAH path.

This module owns the token-choice -> expert-choice -> token-choice conversion
boundary used by the sparse routed attention path. Its public behavior is fixed:
- setup_packing() prepares the auxiliary ordering data forwarded through packing
  and unpacking.
- pack_experts() converts routed token-choice tensors into padded expert-choice
  tensors.
- unpack_experts() restores token-choice ordering from padded expert-choice output.

Packed expert-choice tensors are expert-major and left-justified. For each expert,
routed token copies occupy the prefix of that expert's packed block; padding occupies
the suffix. Every packed entry uses the same ordering and transfer artifact, so
hidden states, positions, masks, and probabilities remain aligned across the boundary.

pack_experts() returns a flat transfer index together with the packed entries. This
index replaces the old boolean unpacking artifact as the source of truth for
pack/unpack data movement: packing writes to those flat packed slots, and unpacking
reads from those same slots.
"""


# ---------------------------------------------------------------------------
# Setup
# ---------------------------------------------------------------------------
def setup_packing(
    selected_heads: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Prepare the auxiliary ordering data used by pack/unpack.

    Args:
        selected_heads: Routed token-choice head selections I of shape (B, N, K).

    Returns:
        Auxiliary payload dict with keys:
            - "flattened_selected_heads": H of shape (B, N*K)
            - "permutation": expert-major permutation Pi of shape (B, N*K)
            - "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
        This dict is forwarded whole to pack_experts and unpack_experts.
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    flattened_selected_heads = selected_heads.reshape(
        batch_size,
        sequence_length * num_selected_heads,
    )

    # -----------------------------------------------------------------------
    # Establish the expert-major ordering invariant.
    #
    # BEA later applies a triangular causal mask inside each expert bucket. That
    # mask is only meaningful if routed copies for the same expert preserve their
    # source-token order. Stable sorting by selected head establishes that order.
    # -----------------------------------------------------------------------
    permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
    inverse_permutation = torch.argsort(permutation, dim=-1)

    return {
        "flattened_selected_heads": flattened_selected_heads,
        "permutation": permutation,
        "inverse_permutation": inverse_permutation,
    }
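

# Illustrative sketch (hypothetical helper, not used by the model): the
# ordering data built by `setup_packing` for a single batch row, in plain
# Python. Python's `sorted` is guaranteed stable, so it plays the role of
# `torch.argsort(..., stable=True)`: routed copies that share an expert keep
# their source-token order, which BEA's per-expert causal mask relies on.
def _example_expert_major_permutation(
    selected_heads: list[list[int]],  # (N, K) head indices for one batch item
) -> tuple[list[int], list[int], list[int]]:
    """Return (flattened_heads, permutation, inverse_permutation)."""
    flattened = [head for token_heads in selected_heads for head in token_heads]
    # Stable argsort by expert index gives the expert-major permutation Pi.
    permutation = sorted(range(len(flattened)), key=flattened.__getitem__)
    # Pi^{-1}[Pi[j]] = j; equivalent to argsort(Pi).
    inverse_permutation = [0] * len(permutation)
    for sorted_pos, original_pos in enumerate(permutation):
        inverse_permutation[original_pos] = sorted_pos
    return flattened, permutation, inverse_permutation


# Example: tokens 0, 1, 2 route to experts (1, 0), (0, 2), (2, 1).
# _example_expert_major_permutation([[1, 0], [0, 2], [2, 1]])
# -> ([1, 0, 0, 2, 2, 1], [1, 2, 0, 5, 3, 4], [2, 0, 1, 4, 5, 3])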


# ---------------------------------------------------------------------------
# Packing
# ---------------------------------------------------------------------------
def pack_experts(
    entries: dict[str, tuple[torch.Tensor, Any]],
    setup: dict[str, torch.Tensor],
    selected_heads: torch.Tensor,
    num_experts: int,
    packed_length: int,
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """Pack token-choice tensors into expert-choice padded form.

    Args:
        entries: Mapping from string keys to (tensor, padding_value) pairs. Each
            tensor has shape (B, N, ...) and is rearranged into expert-choice layout
            (B, L, T, ...). The returned dict carries the same keys.
        setup: Auxiliary payload returned by setup_packing().
        selected_heads: Routed head selections I of shape (B, N, K).
        num_experts: Total number of experts L.
        packed_length: Static packed time dimension T. All per-expert buffers are
            allocated to exactly this length. Raises if any actual per-expert token
            count exceeds this value.

    Returns:
        Tuple of:
            - packed_entries: Dict with same keys as entries; each value is the
              packed tensor of shape (B, L, T, ...).
            - flat_packed_transfer_indices: Long tensor of shape (B*N*K,). Each value
              is the flattened padded expert-choice slot occupied by the corresponding
              routed-copy row. Pass this to unpack_experts().
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    num_routed_copies_per_batch = sequence_length * num_selected_heads
    num_routed_copies = batch_size * num_routed_copies_per_batch
    flattened_selected_heads = setup["flattened_selected_heads"]
    permutation = setup["permutation"]

    # -----------------------------------------------------------------------
    # Algorithm overview.
    #
    # Packing first builds one routed-copy row for each selected token/expert
    # pair, ordered by the stable expert-major permutation. Those rows contain
    # no padding. The final packed tensor reserves packed_length slots per expert.
    # The flat transfer index bridges those layouts by adding back the cumulative
    # padding skipped before each expert block.
    # -----------------------------------------------------------------------

    # -----------------------------------------------------------------------
    # Build the shared routed-copy source rows.
    #
    # This tensor identifies the source token row for each selected token/expert
    # pair after the stable expert-major permutation. Every packed entry uses this
    # same row plan, so all entries remain aligned before padded materialization.
    # -----------------------------------------------------------------------
    source_token_indices = torch.arange(
        sequence_length,
        device=flattened_selected_heads.device,
        dtype=torch.long,
    ).view(1, sequence_length, 1).expand(
        batch_size,
        sequence_length,
        num_selected_heads,
    )
    flattened_source_token_indices = source_token_indices.reshape(
        batch_size,
        num_routed_copies_per_batch,
    )
    sorted_source_token_indices = flattened_source_token_indices.gather(
        dim=1,
        index=permutation,
    )

    # -----------------------------------------------------------------------
    # Establish packed expert occupancy and capacity.
    #
    # tokens_per_expert tells how many routed-copy rows occupy the prefix of each
    # expert block. The padded layout is valid only when every prefix fits inside
    # the configured packed_length.
    # -----------------------------------------------------------------------
    tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
    _enforce_no_overflow(tokens_per_expert, packed_length)

    # -----------------------------------------------------------------------
    # Build the flat insertion points for the padded expert frame.
    #
    # Routed-copy rows omit padding, while the packed frame reserves packed_length
    # slots for every expert. The transfer index adds back the cumulative padding
    # skipped before each expert block, producing one flat destination slot for
    # every routed-copy row. This tensor is forwarded to unpack_experts so removal
    # uses the same positions that insertion used.
    # -----------------------------------------------------------------------
    flat_tokens_per_expert = tokens_per_expert.reshape(-1)
    flat_padding_per_expert = packed_length - flat_tokens_per_expert
    flat_padding_before_expert = (
        flat_padding_per_expert.cumsum(dim=0) - flat_padding_per_expert
    )
    flat_padding_for_routed_rows = torch.repeat_interleave(
        flat_padding_before_expert,
        flat_tokens_per_expert,
        output_size=num_routed_copies,
    )
    flat_routed_row_indices = torch.arange(
        num_routed_copies,
        device=flattened_selected_heads.device,
        dtype=torch.long,
    )
    flat_packed_transfer_indices = (
        flat_routed_row_indices + flat_padding_for_routed_rows
    )

    # -----------------------------------------------------------------------
    # Materialize each entry through the shared routing and transfer artifacts.
    #
    # Each entry first gathers into the shared routed-copy order. The flat packed
    # allocation supplies padding, and the transfer index writes each routed-copy
    # row into its padded expert slot before the public shape is restored.
    # -----------------------------------------------------------------------
    packed_entries: dict[str, torch.Tensor] = {}
    for key, (tensor, padding_value) in entries.items():
        extra_shape = tensor.shape[2:]
        # The sorted source index is shared across all entries; expanding it over
        # trailing dimensions lets the same routing/order plan apply to hidden
        # states, positions, masks, probabilities, and any other packed tensor.
        sorted_gather_indices = sorted_source_token_indices.view(
            batch_size,
            num_routed_copies_per_batch,
            *(1,) * len(extra_shape),
        ).expand(-1, -1, *extra_shape)
        sorted_tensor = tensor.gather(dim=1, index=sorted_gather_indices)
        packed_tensor = tensor.new_full(
            (batch_size * num_experts * packed_length, *extra_shape),
            fill_value=padding_value,
        )
        packed_tensor[flat_packed_transfer_indices] = sorted_tensor.reshape(
            num_routed_copies,
            *extra_shape,
        )
        packed_entries[key] = packed_tensor.reshape(
            batch_size,
            num_experts,
            packed_length,
            *extra_shape,
        )
    return packed_entries, flat_packed_transfer_indices
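

# Illustrative sketch (hypothetical helper, not used by the model): how
# `pack_experts` turns per-expert counts into flat destination slots, in plain
# Python. Routed-copy rows are numbered 0..B*N*K-1 with no padding; adding the
# padding accumulated before each (batch, expert) block moves every row into a
# left-justified slot of a buffer holding `packed_length` slots per expert.
# The inline overflow check stands in for `_enforce_no_overflow`.
def _example_flat_transfer_indices(
    tokens_per_expert: list[list[int]],  # (B, L) routed-copy counts
    packed_length: int,
) -> list[int]:
    """Return the flat padded slot for every routed-copy row, in order."""
    indices = []
    routed_row = 0
    padding_before = 0  # exclusive cumulative sum of per-block padding
    for batch_counts in tokens_per_expert:
        for count in batch_counts:
            if count > packed_length:
                raise ValueError("expert bucket exceeds packed_length")
            for _ in range(count):
                indices.append(routed_row + padding_before)
                routed_row += 1
            padding_before += packed_length - count
    return indices


# Example: B=2, L=3, packed_length=2. Block (b, e) starts at slot (b*L + e)*2.
# _example_flat_transfer_indices([[2, 1, 0], [0, 1, 2]], 2)
# -> [0, 1, 2, 8, 10, 11]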


# ---------------------------------------------------------------------------
# Unpacking
# ---------------------------------------------------------------------------
def unpack_experts(
    expert_outputs: torch.Tensor,
    setup: dict[str, torch.Tensor],
    flat_packed_transfer_indices: torch.Tensor,
    selected_heads: torch.Tensor,
) -> torch.Tensor:
    """Restore token-choice ordering from BEA expert-choice output.

    Args:
        expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
        setup: Auxiliary payload returned by setup_packing().
        flat_packed_transfer_indices: Transfer index returned by pack_experts().
            Each value identifies a routed-copy slot in the flattened padded
            expert-choice frame.
        selected_heads: Routed head selections I of shape (B, N, K).

    Returns:
        Restored token-choice tensor y_tilde of shape (B, N, K, d).
    """
    inverse_permutation = setup["inverse_permutation"]
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    num_routed_copies_per_batch = sequence_length * num_selected_heads
    hidden_dim = expert_outputs.shape[-1]

    # -----------------------------------------------------------------------
    # Recover routed-copy rows from the same packed slots used at insertion.
    #
    # Packing writes into the forwarded flat slots, and unpacking reads from those
    # same slots before applying the inverse routing permutation back to
    # token-choice order.
    # -----------------------------------------------------------------------
    flat_expert_outputs = expert_outputs.reshape(-1, hidden_dim)
    flat_routed_copy_outputs = flat_expert_outputs[flat_packed_transfer_indices]
    sorted_token_choice_outputs = flat_routed_copy_outputs.reshape(
        batch_size,
        num_routed_copies_per_batch,
        hidden_dim,
    )
    restored_outputs = sorted_token_choice_outputs.gather(
        dim=1,
        index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
    )
    return restored_outputs.reshape(
        batch_size,
        sequence_length,
        num_selected_heads,
        hidden_dim,
    )
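

# Illustrative sketch (hypothetical helper, not used by the model): a full
# pack -> identity "expert" -> unpack round trip for one batch row, mirroring
# the data movement of `setup_packing`, `pack_experts`, and `unpack_experts`
# with plain lists. Because unpacking reads the same flat slots that packing
# wrote and then applies the inverse permutation, each token gets its own
# payload back once per selected head.
def _example_pack_unpack_roundtrip(
    values: list[object],             # (N,) one payload per token
    selected_heads: list[list[int]],  # (N, K)
    num_experts: int,
    packed_length: int,
    padding_value: object = None,
) -> tuple[list[list[object]], list[list[object]]]:
    """Return (packed buckets of shape (L, T), restored values of shape (N, K))."""
    num_selected = len(selected_heads[0])
    flattened = [h for token_heads in selected_heads for h in token_heads]
    permutation = sorted(range(len(flattened)), key=flattened.__getitem__)
    inverse = [0] * len(permutation)
    for sorted_pos, original_pos in enumerate(permutation):
        inverse[original_pos] = sorted_pos

    # Gather payloads into expert-major routed-copy order (flat index // K = token).
    sorted_values = [values[p // num_selected] for p in permutation]

    # Count copies per expert and compute flat destination slots.
    counts = [0] * num_experts
    for h in flattened:
        counts[h] += 1
    if max(counts) > packed_length:
        raise ValueError("expert bucket exceeds packed_length")
    transfer, padding_before, row = [], 0, 0
    for count in counts:
        for _ in range(count):
            transfer.append(row + padding_before)
            row += 1
        padding_before += packed_length - count

    # Pack: write each routed copy into its slot; unused slots keep padding.
    flat_packed = [padding_value] * (num_experts * packed_length)
    for routed_row, slot in enumerate(transfer):
        flat_packed[slot] = sorted_values[routed_row]
    buckets = [
        flat_packed[e * packed_length:(e + 1) * packed_length]
        for e in range(num_experts)
    ]

    # Unpack: read the same slots, then undo the permutation.
    sorted_outputs = [flat_packed[slot] for slot in transfer]
    restored_flat = [sorted_outputs[inverse[i]] for i in range(len(inverse))]
    restored = [
        restored_flat[n * num_selected:(n + 1) * num_selected]
        for n in range(len(values))
    ]
    return buckets, restored


# Example: _example_pack_unpack_roundtrip(["a", "b", "c"], [[1, 0], [0, 2], [2, 1]], 3, 3)
# packs to left-justified buckets [["a", "b", None], ["a", "c", None], ["b", "c", None]]
# and restores [["a", "a"], ["b", "b"], ["c", "c"]].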


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _enforce_no_overflow(tokens_per_expert: torch.Tensor, packed_length: int) -> None:
    """Enforce that no expert bucket exceeds the preallocated packed length.

    This check fires when the number of tokens assigned to any expert in any batch
    item exceeds mosrah_packed_length. In that case the packed buffer is too small
    to hold all assignments, and the transfer indices of the overflowing expert
    would collide with slots belonging to the following expert block. Reduce the
    input sequence length or increase training_sequence_length (for training) or
    inference_sequence_length (for inference) in ShramConfig to resolve.

    Args:
        tokens_per_expert: Per-expert token counts, shape (B, num_experts).
        packed_length: The preallocated packed time dimension.
    """
    if torch.compiler.is_compiling():
        torch._assert_async(
            tokens_per_expert.max() <= packed_length,
            "Expert packing overflow: expert bucket exceeds mosrah_packed_length. "
            "Reduce sequence length or increase training_sequence_length / "
            "inference_sequence_length in ShramConfig.",
        )
    else:
        max_count = tokens_per_expert.max().item()
        if max_count > packed_length:
            raise RuntimeError(
                "Expert packing overflow: at least one expert bucket contains more "
                "tokens than mosrah_packed_length allows. Reduce sequence length or "
                "increase training_sequence_length / inference_sequence_length in "
                "ShramConfig to resolve.\n"
                f"Packed length: {packed_length}\n"
                f"Head lengths: {tokens_per_expert}\n"
            )


def _count_tokens_per_expert(
    flattened_selected_heads: torch.Tensor,
    num_experts: int,
) -> torch.Tensor:
    """Count how many routed token copies are assigned to each expert per batch item.

    Uses scatter_add into a pre-sized (B, num_experts) buffer. Each position in
    flattened_selected_heads contributes one count to the corresponding expert slot.

    Args:
        flattened_selected_heads: Expert assignments of shape (B, N*K) with values
            in [0, num_experts).
        num_experts: Total number of experts L.

    Returns:
        Counts tensor of shape (B, num_experts).
    """
    batch_size = flattened_selected_heads.shape[0]
    tokens_per_expert = torch.zeros(
        batch_size,
        num_experts,
        device=flattened_selected_heads.device,
        dtype=torch.long,
    )
    tokens_per_expert.scatter_add_(
        dim=1,
        index=flattened_selected_heads,
        src=torch.ones_like(flattened_selected_heads, dtype=torch.long),
    )
    return tokens_per_expert
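

# Illustrative sketch (hypothetical helper, not used by the model): the
# per-expert counting done by `_count_tokens_per_expert` (a scatter_add of
# ones) combined with the capacity rule enforced by `_enforce_no_overflow`,
# in plain Python for a (B, N*K) nested list of expert assignments.
def _example_count_and_check(
    flattened_selected_heads: list[list[int]],  # (B, N*K)
    num_experts: int,
    packed_length: int,
) -> list[list[int]]:
    """Return (B, num_experts) counts; raise if any bucket overflows."""
    counts = []
    for batch_heads in flattened_selected_heads:
        batch_counts = [0] * num_experts
        for head in batch_heads:
            batch_counts[head] += 1  # same effect as scatter_add_ with ones
        counts.append(batch_counts)
    if max(max(row) for row in counts) > packed_length:
        raise RuntimeError("Expert packing overflow")
    return counts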


# -----------
# Inlined from: router.py
# -----------
"""Token-choice router for the MoSRAH sparse attention path.

This module implements mechanically load-balanced routing for MoSRAH. Given an
input hidden state x, the router produces two outputs used downstream:
- selected_heads (I): which K of the L available expert heads each token
  routes to, determined by a block-balanced causal solver.
- routing_probs (P): the weights used for the weighted output reduction,
  gathered from the softmax routing scores at the selected indices and
  renormalized to sum to 1 per token.

Routing uses a single learnable projection:
- routing_weight: shape (L, embedding_width). Maps input to per-head routing
  scores. Task loss trains this parameter through routing_probs; regret_loss
  trains it to prefer expert assignments at positions of peak preference.

Block-balanced routing partitions the sequence into non-overlapping blocks of
W = L/K tokens. Within each block every expert is assigned to exactly one token,
guaranteeing perfect load balance by construction. The L % K == 0 compatibility
constraint (enforced in ShramConfig) makes W an exact integer.

Selection is causal within each block: at each of the W steps the current
token chooses its K experts from those not yet claimed by earlier tokens in
the same block. Each step runs in parallel across all blocks and the batch;
the W steps themselves are a Python for loop that is fully unrolled under
compilation, keeping the compiled graph flat.

Paper ref: Appendix A.Routing.
"""
| class MoSRAHRouter(nn.Module): | |
| """Token-choice router for MoSRAH sparse attention. | |
| Each input token independently selects K of the L available expert heads | |
| through a block-balanced causal solver. Within each block of W = L/K | |
| consecutive tokens every expert is used exactly once, giving perfect load | |
| balance by construction. | |
| routing_weight is nn.Parameter rather than nn.Linear so that HuggingFace | |
| _init_weights does not override its kaiming initialization at construction. | |
| Attributes: | |
| routing_weight: Shape (L, embedding_width). Maps input hidden states to | |
| per-head routing scores. | |
| block_length: Tokens per routing block W = L / K. Within each block | |
| every expert is used exactly once. | |
| Args: | |
| config: Model configuration. Must expose ``embedding_width``, | |
| ``num_mosrah_heads`` (L), ``num_selected_heads`` (K), and | |
| ``block_length`` (W). | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.num_mosrah_heads = config.num_mosrah_heads | |
| self.num_selected_heads = config.num_selected_heads | |
| self.block_length = config.block_length | |
| # Routing projection: maps input (B, N, d) to per-head routing scores (B, N, L). | |
| # nn.Parameter ensures HuggingFace _init_weights does not override kaiming init. | |
| self.routing_weight = nn.Parameter( | |
| torch.empty(config.num_mosrah_heads, config.embedding_width) | |
| ) | |
| nn.init.kaiming_normal_(self.routing_weight) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| router_cache: RouterCache | None = None, | |
| ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: | |
| """Route input tokens to K expert heads each and compute routing probabilities. | |
| Args: | |
| x: Input hidden states of shape (batch, seq_len, embedding_width). | |
| active_mask: Current-chunk active mask of shape (batch, seq_len), where | |
| True marks a semantically live token. Dead tokens do not contribute | |
| to regret_loss or logit_regret. | |
| Returns: | |
| selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads). | |
| Each token's K selected head indices from the block-balanced solver. | |
| routing_probs: Routing probabilities P of shape (batch, seq_len, | |
| num_selected_heads). Gathered from the pre-balance softmax at | |
| selected_heads and renormalized to sum to 1 per token. | |
| router_diagnostics: Dict of routing scalars: | |
| - ``regret_loss``: gradient-carrying mean regret, mean of | |
| max(p_max_active − p_chosen, 0) over live (B, num_blocks, L) | |
| entries. In [0, 1]. Zero when every expert is assigned at its | |
| peak-preference token within the block. | |
| - ``logit_regret``: detached logit-space regret; same formula | |
| applied to routing logits rather than softmax probabilities. | |
| In [0, ∞). Monitoring only. | |
| - ``logit_std``: detached mean per-token std of routing logits. | |
| """ | |
| # ── Algorithm overview ────────────────────────────────────────────────────── | |
| # | |
| # Problem: each token independently selects its top-K heads with no knowledge | |
| # of what other tokens in the same sequence will choose. Independent selection | |
| # means a single popular head can be chosen by every token while another is | |
| # never used — statistics-based corrections (auxiliary losses, bias vectors) | |
| # can only push routing probabilistically and have proven unstable when tuned | |
| # strongly enough to prevent degeneracy. | |
| # | |
| # Approach: the compatibility constraint L % K == 0 (enforced in ShramConfig) | |
| # makes W = L / K an exact integer. A block of W consecutive tokens contains | |
| # exactly W × K = L selection slots, one per expert. Enforcing that each | |
| # expert is used exactly once per block makes the block perfectly balanced by | |
| # construction, eliminating any need for auxiliary losses or correction steps. | |
| # Enforcement is causal: at each of the W steps the current position picks its | |
| # K experts from those not yet claimed earlier in the same block, by masking | |
| # claimed experts with -inf before top-K. The W steps run sequentially in a | |
| # Python for loop (fully unrolled at compile time); each step runs in parallel | |
| # across all blocks and the whole batch. | |
| B, N, _ = x.shape | |
| L = self.num_mosrah_heads | |
| K = self.num_selected_heads | |
| W = self.block_length | |
| # ── Phase: pre-balance scoring ───────────────────────────────────────── | |
| # | |
| # Establishes the clean routing distribution before any -inf masking. | |
| # logit_std is captured here because the block solver's masking would | |
| # corrupt the standard deviation. routing_scores is used both for | |
| # regret_loss and for the final routing_probs. | |
| routing_logits = self._compute_routing_logits(x) # (B, N, L) | |
| logit_std = routing_logits.std(dim=-1).mean().detach() | |
| routing_scores = F.softmax(routing_logits, dim=-1) # (B, N, L) | |
| # ── Phase: block-balanced causal selection ───────────────────────────── | |
| # | |
| # Three execution modes, distinguished by router_cache and sequence length: | |
| # | |
| # Training (router_cache is None): the full sequence is available. All W | |
| # steps of the block solver run simultaneously across every block in the | |
| # sequence. No cache interaction. | |
| # | |
| # Prefill (router_cache is not None, N > 1): identical to training, but | |
| # the partial last-block state is written to the cache so decode steps can | |
| # continue within the same block without a gap. | |
| # | |
| # Decode (router_cache is not None, N == 1): one token arrives at a known | |
| # position within the current block. The cached used_in_block mask is | |
| # applied before top-K to enforce the one-usage-per-block contract, then | |
| # the cache is updated in-place with this step's selections. | |
| if router_cache is not None and N == 1: | |
| # ── Decode mode ─────────────────────────────────────────────────── | |
| # | |
| # Single token; block position and claimed-expert state come from the | |
| # cache. No regret is computed in this mode: the regret phase below | |
| # returns zeros for decode steps, so no block-shaped assignment tensor | |
| # is built here. | |
| used_in_block = router_cache.get_used_in_block() # (B, L) | |
| step_logits = routing_logits[:, 0, :] # (B, L) | |
| available = step_logits.masked_fill(used_in_block, float('-inf')) | |
| step_heads = available.topk(K, dim=-1).indices # (B, K) | |
| router_cache.update_decode(step_heads) | |
| selected_heads = step_heads.unsqueeze(1) # (B, 1, K) | |
| else: | |
| # ── Training / prefill mode ─────────────────────────────────────── | |
| # | |
| # The full N-token sequence is available. Padding extends it to a | |
| # multiple of W; padded tokens occupy the tail of the last block. Because | |
| # selection is causal and the pad sits after every real token, each real | |
| # token has already made its pick before any pad position competes for an | |
| # expert, so the pad never takes an expert a real token needed. The pad is | |
| # discarded after the solver. | |
| num_blocks = (N + W - 1) // W | |
| N_pad = num_blocks * W | |
| pad_len = N_pad - N | |
| if pad_len > 0: | |
| padded_logits = torch.cat( | |
| [routing_logits, routing_logits.new_zeros(B, pad_len, L)], dim=1 | |
| ) # (B, N_pad, L) | |
| else: | |
| padded_logits = routing_logits | |
| blocked_logits = padded_logits.view(B, num_blocks, W, L) # (B, blk, W, L) | |
| # used_in_block tracks which experts have been claimed within each block. | |
| # No gradient here — expert availability is a hard structural constraint, | |
| # not a differentiable quantity. Gradient flows through routing_probs. | |
| used_in_block = torch.zeros(B, num_blocks, L, dtype=torch.bool, device=x.device) | |
| step_heads_list = [] | |
| for step in range(W): | |
| step_logits = blocked_logits[:, :, step, :] # (B, blk, L) | |
| # Claimed experts receive -inf so top-K never selects them. | |
| available = step_logits.masked_fill(used_in_block, float('-inf')) | |
| step_heads = available.topk(K, dim=-1).indices # (B, blk, K) | |
| step_heads_list.append(step_heads) | |
| # Mark the K chosen experts as unavailable for the rest of this block. | |
| used_in_block = used_in_block.scatter(-1, step_heads, True) | |
| # Stack W steps and reshape to (B, N_pad, K), then unpad. | |
| selected_heads_blocked = torch.stack(step_heads_list, dim=2) # (B, blk, W, K) | |
| selected_heads = selected_heads_blocked.view(B, N_pad, K)[:, :N, :] # (B, N, K) | |
| if router_cache is not None: | |
| # Prefill: persist the partial last-block state so decode steps | |
| # that follow can continue within the same block. | |
| router_cache.update_prefill(selected_heads_blocked, N) | |
| # ── Phase: regret loss ───────────────────────────────────────────────── | |
| # | |
| # Regret measures how much routing preference was sacrificed at each expert | |
| # assignment relative to the peak active preference within the same block. | |
| # A non-zero regret at expert l in block bl means some other active token | |
| # in that block would have preferred expert l more than the one assigned. | |
| # Minimising regret trains the router to save experts for the tokens that | |
| # want them most. | |
| # | |
| # Decode mode returns zeros: regret compares the tokens of a block that | |
| # are scored together in one call, and a single decode step contains only | |
| # one token. Backward is never called during inference, so the zero is a | |
| # correct no-op. | |
| if router_cache is not None and N == 1: | |
| regret_loss = routing_logits.new_zeros(()) | |
| logit_regret = routing_logits.new_zeros(()).detach() | |
| else: | |
| regret_loss, logit_regret = self._compute_regret( | |
| routing_scores, | |
| routing_logits, | |
| selected_heads_blocked, | |
| active_mask, | |
| ) | |
| # ── Phase: routing probabilities ──────────────────────────────────────── | |
| # | |
| # Gathered from the pre-balance routing_scores to reflect genuine routing | |
| # preference; renormalized so they sum to 1 per token. | |
| gathered = routing_scores.gather(dim=-1, index=selected_heads) # (B, N, K) | |
| routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # (B, N, K) | |
| router_diagnostics = { | |
| "regret_loss": regret_loss, | |
| "logit_regret": logit_regret, | |
| "logit_std": logit_std, | |
| } | |
| return selected_heads, routing_probs, router_diagnostics | |
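The causal block-balanced solver in `forward` can be illustrated without tensors. The following is a minimal plain-Python sketch for a single sequence, assuming distinct logits (ties may break differently from `torch.topk`). `block_balanced_select` and `DecodeRouterState` are hypothetical names that are not part of this module; `DecodeRouterState` only stands in for RouterCache, whose real interface and block-reset logic are not shown in this section.

```python
from typing import List, Sequence


def block_balanced_select(logits: Sequence[Sequence[float]], k: int) -> List[List[int]]:
    """Hypothetical sketch: causal block-balanced top-k selection for one sequence."""
    num_experts = len(logits[0])
    assert num_experts % k == 0, "requires L % K == 0"
    w = num_experts // k  # block length W = L / K
    n = len(logits)
    num_blocks = (n + w - 1) // w  # ceil(N / W)
    # Zero-logit pad rows fill the tail of the last block, as in the module.
    padded = [list(row) for row in logits] + [
        [0.0] * num_experts for _ in range(num_blocks * w - n)
    ]
    selected: List[List[int]] = []
    for b in range(num_blocks):
        used = [False] * num_experts  # experts claimed earlier in this block
        for step in range(w):
            row = padded[b * w + step]
            # Skipping claimed experts plays the role of the -inf mask before top-k.
            available = [e for e in range(num_experts) if not used[e]]
            top = sorted(available, key=lambda e: row[e], reverse=True)[:k]
            for e in top:
                used[e] = True
            selected.append(top)
    return selected[:n]  # discard the pad


class DecodeRouterState:
    """Hypothetical stand-in for the router cache: one token per call."""

    def __init__(self, num_experts: int, k: int) -> None:
        self.num_experts = num_experts
        self.k = k
        self.w = num_experts // k
        self.step = 0
        self.used = [False] * num_experts

    def decode(self, row: Sequence[float]) -> List[int]:
        if self.step == self.w:  # previous block is full: every expert is free again
            self.step = 0
            self.used = [False] * self.num_experts
        available = [e for e in range(self.num_experts) if not self.used[e]]
        top = sorted(available, key=lambda e: row[e], reverse=True)[: self.k]
        for e in top:
            self.used[e] = True
        self.step += 1
        return top


logits = [[0.9, 0.1, 0.5, 0.3], [0.8, 0.7, 0.2, 0.1], [0.4, 0.6, 0.9, 0.0]]
print(block_balanced_select(logits, k=2))  # [[0, 2], [1, 3], [2, 1]]
```

With L = 4 and K = 2 the block length is W = 2. Token 1 prefers expert 0, but token 0 already claimed it in the same block, so token 1 receives experts 1 and 3. Feeding the same rows one at a time through `DecodeRouterState` yields the same assignment, which is the property the prefill/decode cache path is meant to preserve.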
| def _compute_routing_logits(self, x: torch.Tensor) -> torch.Tensor: | |
| """Compute per-head routing logits from input hidden states. | |
| Args: | |
| x: Input hidden states, shape (batch, seq_len, embedding_width). | |
| Returns: | |
| Routing logits, shape (batch, seq_len, num_mosrah_heads). | |
| """ | |
| return F.linear(x, self.routing_weight) # (B, N, L) | |
| @staticmethod | |
| def _compute_regret( | |
| routing_scores: torch.Tensor, | |
| routing_logits: torch.Tensor, | |
| selected_heads_blocked: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Compute regret_loss and logit_regret from a completed block assignment. | |
| Regret at expert l in block bl = max(p_max_active − p_chosen, 0), where | |
| p_max_active is the highest routing probability any active token holds for | |
| expert l within the block, and p_chosen is the routing probability of the | |
| token actually assigned to expert l (0 if that token is dead). | |
| regret_loss is the mean over live (batch, block, expert) triples. A block is | |
| live iff it contains at least one active token; all L experts in a live block | |
| contribute. Result is in [0, 1]. | |
| logit_regret applies the same formula to routing_logits and is returned | |
| detached — it is a monitoring scalar only, in [0, ∞). | |
| Args: | |
| routing_scores: Softmax routing probabilities, shape (B, N, L). | |
| Gradient flows through this tensor into regret_loss. | |
| routing_logits: Pre-softmax routing logits, shape (B, N, L). | |
| Used only for the detached logit_regret. | |
| selected_heads_blocked: Expert assignments from the block solver, | |
| shape (B, num_blocks, W, K). Block geometry | |
| (num_blocks, W) is derived from this shape. | |
| active_mask: Boolean live-token mask, shape (B, N). | |
| Returns: | |
| regret_loss: Gradient-carrying scalar in [0, 1]. | |
| logit_regret: Detached scalar in [0, ∞). | |
| """ | |
| B, num_blocks, W, _K = selected_heads_blocked.shape | |
| L = routing_scores.shape[-1] | |
| N = routing_scores.shape[1] | |
| N_pad = num_blocks * W | |
| # ── Reshape into block form ───────────────────────────────────────── | |
| # | |
| # Block geometry is read from selected_heads_blocked — no recomputation | |
| # needed here. Padded tail positions receive zero scores and False | |
| # activity; they do not contribute to any block metric. | |
| if N_pad > N: | |
| pad_len = N_pad - N | |
| scores_blocked = torch.cat( | |
| [routing_scores, routing_scores.new_zeros(B, pad_len, L)], dim=1 | |
| ).view(B, num_blocks, W, L) # (B, nb, W, L) | |
| logits_blocked = torch.cat( | |
| [routing_logits, routing_logits.new_zeros(B, pad_len, L)], dim=1 | |
| ).view(B, num_blocks, W, L) # (B, nb, W, L) | |
| active_blocked = torch.cat( | |
| [active_mask, active_mask.new_zeros(B, pad_len)], dim=1 | |
| ).view(B, num_blocks, W) # (B, nb, W) | |
| else: | |
| scores_blocked = routing_scores.view(B, num_blocks, W, L) | |
| logits_blocked = routing_logits.view(B, num_blocks, W, L) | |
| active_blocked = active_mask.view(B, num_blocks, W) | |
| active_float = active_blocked.float() # (B, nb, W) | |
| block_active = active_blocked.any(dim=-1) # (B, nb) | |
| # ── Assignment mask ───────────────────────────────────────────────── | |
| # | |
| # One-hot indicator of which token was assigned to each expert. Block | |
| # balance guarantees exactly one entry per (b, bl, l) triple, so | |
| # summing over W recovers exactly one score value per expert. | |
| assignment_mask = scores_blocked.new_zeros(B, num_blocks, W, L) | |
| assignment_mask.scatter_(dim=-1, index=selected_heads_blocked, value=1.0) | |
| # (B, nb, W, L) | |
| # ── Prob regret (gradient flows through routing_scores) ───────────── | |
| # | |
| # p_chosen: routing score at the assigned token, gated by active_float | |
| # so dead assignments contribute 0 — the expert accrues full regret | |
| # against the active maximum rather than no penalty. | |
| # p_max: peak routing score over active tokens; dead tokens zeroed before | |
| # max (safe because softmax outputs are non-negative). | |
| p_chosen = (assignment_mask * active_float.unsqueeze(-1) * scores_blocked).sum(dim=2) | |
| # (B, nb, L) | |
| p_max = (active_float.unsqueeze(-1) * scores_blocked).max(dim=2).values | |
| # (B, nb, L) | |
| regret = (p_max - p_chosen).clamp(min=0.0) # (B, nb, L) | |
| # Mean over live (B, num_blocks, L) entries. The denominator is clamped | |
| # to at least 1 so the all-dead edge case (numerator already 0) yields 0 | |
| # instead of dividing by zero. | |
| num_live = block_active.float().sum() # scalar | |
| regret_loss = ( | |
| block_active.float().unsqueeze(-1) * regret | |
| ).sum() / num_live.mul(L).clamp(min=1.0) | |
| # ── Logit regret (detached monitoring) ────────────────────────────── | |
| # | |
| # Same formula applied to routing_logits. Dead tokens cannot be zeroed | |
| # before max (logits may be negative), so they are masked to -inf; | |
| # dead blocks are replaced with 0 before subtraction. Detached so it | |
| # never influences any parameter during backward. | |
| logit_chosen = ( | |
| assignment_mask * active_float.unsqueeze(-1) * logits_blocked | |
| ).sum(dim=2) # (B, nb, L) | |
| logit_max = logits_blocked.masked_fill( | |
| ~active_blocked.unsqueeze(-1), float('-inf') | |
| ).max(dim=2).values # (B, nb, L) | |
| logit_max = logit_max.masked_fill(~block_active.unsqueeze(-1), 0.0) | |
| logit_regret = ( | |
| block_active.float().unsqueeze(-1) * (logit_max - logit_chosen).clamp(min=0.0) | |
| ).sum() / num_live.mul(L).clamp(min=1.0) | |
| logit_regret = logit_regret.detach() | |
| return regret_loss, logit_regret | |
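The regret definition in `_compute_regret` can be checked on a single block in plain Python. This is a sketch under stated assumptions: `block_regret` is a hypothetical helper, it handles one (batch, block) pair rather than batched tensors, and it returns the per-block mean over the L experts. The module instead divides the sum over all live blocks by `num_live * L`, which equals the mean of these per-block values. Because the maximum is taken over active tokens only, the same helper also reproduces the `logit_regret` formula, which the module obtains by masking dead tokens with -inf.

```python
from typing import Sequence


def block_regret(
    values: Sequence[Sequence[float]],
    active: Sequence[bool],
    assignment: Sequence[Sequence[int]],
) -> float:
    """Hypothetical sketch: mean regret over the L experts of one block.

    values: W x L routing probabilities (or logits).
    active: W liveness flags.
    assignment: W x K expert indices produced by the block solver.
    """
    num_experts = len(values[0])
    if not any(active):
        return 0.0  # dead block: the module excludes it from the mean
    # Score of the token assigned to each expert; 0 if that token is dead.
    chosen = [0.0] * num_experts
    for t, experts in enumerate(assignment):
        for e in experts:
            chosen[e] = values[t][e] if active[t] else 0.0
    # Peak score any active token in the block holds for each expert.
    peak = [
        max(values[t][e] for t in range(len(values)) if active[t])
        for e in range(num_experts)
    ]
    regret = [max(peak[e] - chosen[e], 0.0) for e in range(num_experts)]
    return sum(regret) / num_experts


scores = [[0.4, 0.1, 0.3, 0.2], [0.5, 0.3, 0.1, 0.1]]
assignment = [[0, 2], [1, 3]]
print(block_regret(scores, [True, True], assignment))  # ≈ 0.05
```

Expert 0 went to token 0 (0.4) although token 1 held 0.5, and expert 3 went to token 1 (0.1) although token 0 held 0.2. Each contributes 0.1 of regret, so the block mean is 0.2 / 4. Marking token 1 dead instead zeroes its chosen scores, so experts 1 and 3 accrue full regret against token 0's scores.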
| # ----------- | |
| # Inlined from: positions_converter.py | |
| # ----------- | |
| """Position computation for the MoSRAH sparse path. | |
| This layer computes the packed position tensor P consumed by BEA. | |
| - In main-sequence mode, P is the packed original-token position tensor from the | |
| packing path. | |
| - In semantic-sequence mode, P is a per-expert local sequence over the packed | |
| expert-choice layout, optionally offset by the current sparse-cache occupancies | |
| during cached inference. | |
| """ | |
| class SparseMoSRAHPositions(nn.Module): | |
| """Compute the packed RoPE position tensor for the MoSRAH sparse path. | |
| This layer operates in the packed expert-choice frame used by BEA. The input | |
| packed_positions tensor is always the packed original-token position tensor | |
| produced by the packing path. The configured rope_mode determines whether that | |
| tensor is forwarded directly or replaced by a semantic local-slot sequence. | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.rope_mode = config.rope_mode | |
| def forward( | |
| self, | |
| packed_positions: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: MoSRAHCache | None, | |
| ) -> torch.Tensor: | |
| """Compute the packed position tensor P consumed by BEA. | |
| Args: | |
| packed_positions: Packed original-token positions J' of shape (B, L, T). | |
| active_mask: Boolean active-token mask of shape (B, L, T). Inactive | |
| positions are zeroed in the returned tensor regardless of mode — | |
| their position value is semantically irrelevant and 0 is guaranteed | |
| to be within any valid RoPE table. | |
| cache: Optional layer-local MoSRAH cache. When present in semantic-sequence | |
| mode, the current per-head occupancies offset the local packed sequence. | |
| Returns: | |
| Packed position tensor P of shape (B, L, T). | |
| """ | |
| if self.rope_mode == "main_sequence": | |
| positions = self._main_sequence_positions(packed_positions) | |
| elif self.rope_mode == "semantic_sequence": | |
| positions = self._semantic_sequence_positions(packed_positions, cache) | |
| else: | |
| raise NotImplementedError( | |
| f"Unsupported MoSRAH rope_mode '{self.rope_mode}'." | |
| ) | |
| return torch.where(active_mask, positions, torch.zeros_like(positions)) | |
| def _main_sequence_positions( | |
| self, | |
| packed_positions: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Forward packed original-token positions unchanged.""" | |
| return packed_positions | |
| def _semantic_sequence_positions( | |
| self, | |
| packed_positions: torch.Tensor, | |
| cache: MoSRAHCache | None, | |
| ) -> torch.Tensor: | |
| """Compute semantic-sequence packed positions in expert-choice space. | |
| Without a sparse cache, semantic positions are the local packed sequence | |
| 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that | |
| same local sequence is offset by the current per-(batch, expert) occupancies | |
| returned by get_heads_lengths(). | |
| """ | |
| batch_size, num_experts, packed_length = packed_positions.shape | |
| # ------------------------------------------------------------------- | |
| # Construct the local packed sequence 0, 1, 2, ... over the expert-local | |
| # sequence dimension T. This is then broadcast across batch and experts. | |
| # ------------------------------------------------------------------- | |
| local_positions = torch.arange( | |
| packed_length, | |
| device=packed_positions.device, | |
| dtype=packed_positions.dtype, | |
| ).view(1, 1, packed_length).expand( | |
| batch_size, | |
| num_experts, | |
| packed_length, | |
| ) | |
| # ------------------------------------------------------------------- | |
| # In cached semantic-sequence mode, positions continue from the current | |
| # sparse-cache occupancies rather than restarting at zero for the local | |
| # chunk. | |
| # ------------------------------------------------------------------- | |
| if cache is None: | |
| return local_positions | |
| cached_lengths = cache.get_heads_lengths().to( | |
| device=packed_positions.device, | |
| dtype=packed_positions.dtype, | |
| ).unsqueeze(-1) | |
| return local_positions + cached_lengths | |
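The two `rope_mode` branches of SparseMoSRAHPositions can be mirrored on nested lists. In this hypothetical sketch, `sparse_positions` takes (B, L, T) packed positions and an active mask, plus an optional (B, L) list that stands in for what `cache.get_heads_lengths()` returns. It is illustrative only and not the module's code path.

```python
from typing import List, Optional, Sequence

Grid = Sequence[Sequence[Sequence[int]]]


def sparse_positions(
    packed_positions: Grid,
    active: Sequence[Sequence[Sequence[bool]]],
    rope_mode: str,
    cached_lengths: Optional[Sequence[Sequence[int]]] = None,
) -> List[List[List[int]]]:
    """Hypothetical sketch of the packed position tensor P (B x L x T)."""
    if rope_mode not in ("main_sequence", "semantic_sequence"):
        raise NotImplementedError(f"Unsupported MoSRAH rope_mode '{rope_mode}'.")
    out = []
    for b, heads in enumerate(packed_positions):
        rows = []
        for h, slots in enumerate(heads):
            if rope_mode == "main_sequence":
                row = list(slots)  # original-token positions, unchanged
            else:
                # Local slot index 0, 1, 2, ... offset by the cached occupancy.
                offset = 0 if cached_lengths is None else cached_lengths[b][h]
                row = [offset + t for t in range(len(slots))]
            # Inactive slots are zeroed in every mode.
            rows.append([p if live else 0 for p, live in zip(row, active[b][h])])
        out.append(rows)
    return out


packed = [[[0, 3, 5], [1, 2, 0]]]  # B=1, L=2 experts, T=3 slots
live = [[[True, True, True], [True, True, False]]]
print(sparse_positions(packed, live, "main_sequence"))  # [[[0, 3, 5], [1, 2, 0]]]
print(sparse_positions(packed, live, "semantic_sequence"))  # [[[0, 1, 2], [0, 1, 0]]]
print(sparse_positions(packed, live, "semantic_sequence", [[4, 7]]))  # [[[4, 5, 6], [7, 8, 0]]]
```

In cached semantic-sequence mode the offset makes each expert's positions continue from its current occupancy instead of restarting at 0 for every chunk, while main-sequence mode ignores the cache entirely.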
| class MoSRAHLayer(nn.Module): | |
| """Full routed sparse attention path for SHRAM. | |
| The MoSRAH path consumes model-space hidden states together with | |
| authoritative per-token positions and returns the model-space sparse-path | |
| contribution together with the router's diagnostics dict (regret_loss, | |
| logit_regret, and logit_std; see MoSRAHRouter for semantics). | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.num_experts = config.num_mosrah_heads | |
| if config.use_cache: | |
| self.packed_length = config.mosrah_cache_length | |
| else: | |
| self.packed_length = config.mosrah_packed_length | |
| self.router = MoSRAHRouter(config) | |
| self.positions = SparseMoSRAHPositions(config) | |
| self.bea = BottleneckedEnsembleAttention(config) | |
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable parameters in this MoSRAH layer.""" | |
| return sum(p.numel() for p in self.parameters()) | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: MoSRAHCache | None, | |
| router_cache: RouterCache | None = None, | |
| ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | |
| """Run the full MoSRAH sparse path. | |
| Args: | |
| hidden_states: Model-space hidden states x of shape (B, N, d). | |
| position_ids: Authoritative per-token positions of shape (B, N). | |
| active_mask: Current-chunk active mask of shape (B, N), where True | |
| means the token is semantically live. Forwarded to the router | |
| so dead tokens are excluded from routing statistics, and to | |
| pack_experts so dead outer tokens do not become semantically | |
| active packed entries. | |
| cache: Optional layer-local MoSRAH cache. Pass None for uncached | |
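| execution and the layer-local cache instance for cached execution. | |
| router_cache: Optional layer-local router cache, forwarded to the router | |
| so block-balanced selection continues within the current block across | |
| prefill and decode calls. Pass None for uncached execution. | |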
| Returns: | |
| sparse_output: Model-space sparse-path output of shape (B, N, d). | |
| router_diagnostics: Dict of router feedback scalars. Keys: | |
| ``regret_loss`` (has grad), ``logit_regret`` (detached), | |
| ``logit_std`` (detached). See MoSRAHRouter for semantics. | |
| """ | |
| # ------------------------------------------------------------------- | |
| # The first transition moves from model-space token-choice input into | |
| # the packed expert-choice sparse-attention state. Routing decides both | |
| # which experts each token uses and which unbiased probabilities must be | |
| # reserved for the final reduction. The active mask is forwarded to the | |
| # router so dead tokens are excluded from routing statistics, and to | |
| # pack_experts so outer liveness is faithfully carried into the packed | |
| # frame. Packing returns the packed entries, including the packed active | |
| # mask (live slots only), together with an unpacking map of flat packed | |
| # transfer indices that unpack_experts later uses to restore token-choice | |
| # order. active_mask is rebound to the packed form after this point. | |
| # ------------------------------------------------------------------- | |
| selected_heads, routing_probs, router_diagnostics = self.router( | |
| hidden_states, active_mask, router_cache | |
| ) | |
| setup = setup_packing(selected_heads) | |
| entries = { | |
| "hidden_states": (hidden_states, 0.0), | |
| "position_ids": (position_ids, 0), | |
| "active_mask": (active_mask, False), | |
| } | |
| packed, unpacking_map = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length) | |
| packed_hidden_states = packed["hidden_states"] | |
| packed_positions = packed["position_ids"] | |
| active_mask = packed["active_mask"] | |
| # ------------------------------------------------------------------- | |
| # Sparse attention runs entirely in the packed expert-choice frame, so | |
| # the RoPE position semantics must also be chosen in that frame. The | |
| # position layer therefore decides whether BEA should see packed | |
| # original-token positions or packed local-slot positions. BEA then | |
| # consumes that packed position tensor together with the packed hidden | |
| # states and the layer-local sparse cache, which it owns directly. | |
| # ------------------------------------------------------------------- | |
| bea_positions = self.positions( | |
| packed_positions=packed_positions, | |
| active_mask=active_mask, | |
| cache=cache, | |
| ) | |
| packed_outputs = self.bea( | |
| packed_embeddings=packed_hidden_states, | |
| position_ids=bea_positions, | |
| active_mask=active_mask, | |
| cache=cache, | |
| ) | |
| # ------------------------------------------------------------------- | |
| # The final transition restores token-choice meaning and only then | |
| # collapses the K routed copies back into model space. This ordering is | |
| # required because routing_probs live in token-choice space, whereas BEA | |
| # returns expert-choice packed outputs. The reduction must therefore | |
| # happen after unpacking, and it must use the router's unbiased | |
| # renormalized probabilities rather than any biased selection scores. | |
| # ------------------------------------------------------------------- | |
| token_choice_outputs = unpack_experts( | |
| expert_outputs=packed_outputs, | |
| setup=setup, | |
| flat_packed_transfer_indices=unpacking_map, | |
| selected_heads=selected_heads, | |
| ) | |
| final_output = ( | |
| token_choice_outputs * routing_probs.unsqueeze(-1) | |
| ).sum(dim=2) | |
| return final_output, router_diagnostics | |
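The router's renormalized `routing_probs` and MoSRAHLayer's final weighted sum over the K routed copies can be sketched for a single token. `combine_routed_outputs` is a hypothetical helper, and the per-head outputs are plain lists standing in for the unpacked token-choice BEA outputs of shape (K, d).

```python
from typing import List, Sequence


def combine_routed_outputs(
    routing_scores: Sequence[float],
    selected_heads: Sequence[int],
    head_outputs: Sequence[Sequence[float]],
) -> List[float]:
    """Hypothetical sketch of the per-token reduction over K routed copies."""
    # Gather the pre-balance softmax scores of the selected heads...
    gathered = [routing_scores[h] for h in selected_heads]
    # ...and renormalize so the K weights sum to 1 for this token.
    total = sum(gathered)
    probs = [g / total for g in gathered]
    # Weighted sum of the K head outputs back into one model-space vector.
    width = len(head_outputs[0])
    return [sum(p * out[i] for p, out in zip(probs, head_outputs)) for i in range(width)]


scores = [0.1, 0.6, 0.2, 0.1]  # softmax over L = 4 heads
out = combine_routed_outputs(scores, [1, 2], [[1.0, 0.0], [0.0, 4.0]])
print(out)  # ≈ [0.75, 1.0]
```

Because the weights come from the pre-balance distribution, each selected head contributes in proportion to the token's actual preference for it, including heads the token received only because the balancing constraint ruled out its favourites.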
| class SHRAMHybridLayer(nn.Module): | |
| """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot. | |
| The local path preserves nearby-token behavior through sliding-window causal | |
| attention. The sparse path is the theorem-facing MoSRAH routed attention | |
| path. Both operate over the same model-space hidden state and return | |
| model-space outputs, so the hybrid composition is a direct sum in model | |
| space. | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.local_attention = SlidingWindowAttention(config) | |
| self.sparse_attention = MoSRAHLayer(config) | |
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable parameters in the MoSRAH sparse path.""" | |
| return self.sparse_attention.num_mosrah_parameters() | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: ShramLayerCache | None, | |
| ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | |
| """Apply the SHRAM hybrid attention layer. | |
| Args: | |
| hidden_states: Input hidden states of shape (B, N, d). | |
| position_ids: Authoritative token positions of shape (B, N). | |
| active_mask: Current-chunk active mask of shape (B, N), where True | |
| means the token is semantically live. Forwarded unchanged to | |
| both the local path and the sparse path. | |
| cache: Optional per-layer SHRAM cache. When provided, the owned | |
| sliding-window and MoSRAH sub-caches are dispatched directly to | |
| their corresponding attention paths. | |
| Returns: | |
| hybrid_output: Model-space hybrid attention output of shape (B, N, d). | |
| router_diagnostics: Dict of router feedback scalars passed through | |
| unchanged from MoSRAHLayer; see MoSRAHRouter for semantics. | |
| """ | |
| # ------------------------------------------------------------------- | |
| # The hybrid layer's first responsibility is cache dispatch. The layer | |
| # cache already owns the concrete sub-cache objects required by each | |
| # path, so this unit should forward those exact references rather than | |
| # reinterpret cache ownership or invent a composite update protocol here. | |
| # ------------------------------------------------------------------- | |
| if cache is None: | |
| sliding_window_cache = None | |
| mosrah_cache = None | |
| router_cache = None | |
| else: | |
| sliding_window_cache = cache.sliding_window_cache | |
| mosrah_cache = cache.mosrah_cache | |
| router_cache = cache.router_cache | |
| # ------------------------------------------------------------------- | |
| # Both attention paths must see the same model-space hidden state for | |
| # the current decoder layer. The local path preserves short-range | |
| # structure, while the sparse path provides the routed long-range | |
| # contribution and emits the router diagnostics, including the | |
| # gradient-carrying regret_loss used by training. | |
| # ------------------------------------------------------------------- | |
| local_output = self.local_attention( | |
| x=hidden_states, | |
| position_ids=position_ids, | |
| active_mask=active_mask, | |
| cache=sliding_window_cache, | |
| ) | |
| sparse_output, router_diagnostics = self.sparse_attention( | |
| hidden_states=hidden_states, | |
| position_ids=position_ids, | |
| active_mask=active_mask, | |
| cache=mosrah_cache, | |
| router_cache=router_cache, | |
| ) | |
| # ------------------------------------------------------------------- | |
| # The composition rule is intentionally simple at this boundary. Both | |
| # sublayers already return model-space tensors of matching shape, so the | |
| # correct hybrid behavior is their direct sum with no additional mixing | |
| # logic introduced here. | |
| # ------------------------------------------------------------------- | |
| hybrid_output = local_output + sparse_output | |
| return hybrid_output, router_diagnostics | |
| # ----------- | |
| # Inlined from: mlp.py | |
| # ----------- | |
| """SwiGLU feed-forward sublayer. | |
| SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection | |
| element-wise against a separate up-projection: | |
| output = W_down(SiLU(W_gate(x)) ⊙ W_up(x)) | |
| The gating mechanism gives the network more expressive control over which features | |
| to propagate than a plain two-matrix FFN. It requires three weight matrices instead | |
| of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier | |
| typical of two-matrix FFNs — the total parameter count remains comparable. | |
| SiLU is the gate activation because it is what defines SwiGLU; Llama 3 uses | |
| SwiGLU, so the activation is a fixed architectural choice, not a config option. | |
| """ | |
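The parameter-count trade-off described above can be made concrete. Assuming bias-free projections, model width $d$, a conventional two-matrix FFN of width $4d$, and a SwiGLU hidden width $m$ (the role `mlp_width` plays in this module), the two layer sizes match when:

$$
\begin{aligned}
P_{\text{FFN}} &= 2 \cdot d \cdot 4d = 8d^2, \\
P_{\text{SwiGLU}} &= 3 \cdot d \cdot m, \\
P_{\text{SwiGLU}} = P_{\text{FFN}} &\iff m = \tfrac{8}{3}\, d .
\end{aligned}
$$

For example, $d = 768$ gives $m = 2048$, and both variants then hold $4{,}718{,}592$ weights. This only illustrates the equal-parameter point; it is not a statement of the width Llama 3 or this model actually uses, and real configurations typically round $m$.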
| class SwiGLUMLP(nn.Module): | |
| """SwiGLU feed-forward sublayer. | |
| Implements the three-matrix SwiGLU FFN used in Llama 3: | |
| output = W_down(SiLU(W_gate(x)) ⊙ W_up(x)) | |
| No bias on any projection. SiLU as the gate activation is an architectural | |
| constant — it is what defines SwiGLU specifically. | |
| Args: | |
| config: Model config. Must expose ``embedding_width`` and ``mlp_width``. | |
| """ | |
| def __init__(self, config: PretrainedConfig) -> None: | |
| super().__init__() | |
| self.gate_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False) | |
| self.up_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False) | |
| self.down_proj = nn.Linear(config.mlp_width, config.embedding_width, bias=False) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Apply the SwiGLU feed-forward transformation. | |
| Args: | |
| x: Input tensor of shape (batch, seq_len, hidden_size). | |
| Returns: | |
| Output tensor of shape (batch, seq_len, hidden_size). | |
| """ | |
| return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) | |
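To see the gating arithmetic without torch, here is a hypothetical plain-Python sketch of `down_proj(silu(gate_proj(x)) * up_proj(x))` for one token vector. Weights follow the `nn.Linear` convention of shape (out_features, in_features) with no bias, and SiLU is written out as z * sigmoid(z). The helper names `linear`, `silu`, and `swiglu` are illustrative only.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def linear(weight: Matrix, x: Sequence[float]) -> List[float]:
    # Bias-free projection; weight has shape (out_features, in_features).
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def silu(z: float) -> float:
    # SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z)).
    return z / (1.0 + math.exp(-z))


def swiglu(x: Sequence[float], w_gate: Matrix, w_up: Matrix, w_down: Matrix) -> List[float]:
    # Gate branch through SiLU, multiplied element-wise with the up branch.
    hidden = [silu(g) * u for g, u in zip(linear(w_gate, x), linear(w_up, x))]
    # Project the gated hidden vector back to model width.
    return linear(w_down, hidden)


x = [1.0, 0.0]
w_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # (m=3, d=2)
w_up = [[2.0, 0.0], [0.0, 0.0], [1.0, 0.0]]    # (m=3, d=2)
w_down = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]    # (d=2, m=3)
print(swiglu(x, w_gate, w_up, w_down))  # ≈ [1.462, 0.731]
```

The middle hidden unit is zero because its gate pre-activation is 0 and SiLU(0) = 0, which shows how the gate can suppress a feature regardless of the up-projection.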
| class DecoderLayer(nn.Module): | |
| """A single pre-norm SHRAM decoder block. | |
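| Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and | |
| independent RMSNorm instances on each sublayer input. Each sublayer output | |
| is multiplied by a residual scale before being added to the stream: a | |
| learnable scalar initialized to 0 when ``use_residual_gate`` is set, | |
| otherwise a fixed buffer equal to 1 / sqrt(num_decoder_layers). | |
| Args: | |
| config: SHRAM config. Must expose ``embedding_width``, ``rms_norm_eps``, | |
| ``num_decoder_layers``, and ``use_residual_gate`` in addition to the | |
| fields required by SHRAMHybridLayer and SwiGLUMLP. | |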
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.attn_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps) | |
| self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps) | |
| self.attention = SHRAMHybridLayer(config) | |
| self.mlp = SwiGLUMLP(config) | |
| scale = 1.0 / math.sqrt(config.num_decoder_layers) | |
| if config.use_residual_gate: | |
| self.attn_residual_scale = nn.Parameter(torch.zeros(1)) | |
| self.mlp_residual_scale = nn.Parameter(torch.zeros(1)) | |
| else: | |
| self.register_buffer("attn_residual_scale", torch.full((1,), scale)) | |
| self.register_buffer("mlp_residual_scale", torch.full((1,), scale)) | |
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable MoSRAH parameters in this decoder layer.""" | |
| return self.attention.num_mosrah_parameters() | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: ShramLayerCache | None = None, | |
| ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: | |
| """Apply one decoder block to the input. | |
| Args: | |
| x: Input of shape (batch, seq_len, hidden_size). | |
| position_ids: Authoritative positions of shape (batch, seq_len). | |
| active_mask: Current-chunk active mask of shape (batch, seq_len), | |
| where True means the token is semantically live. Forwarded | |
| unchanged to the hybrid attention layer. | |
| cache: Optional per-layer SHRAM cache passed through to the hybrid | |
| attention layer unchanged. | |
| Returns: | |
| output: Tensor of shape (batch, seq_len, hidden_size). | |
| router_diagnostics: Dict of router feedback scalars passed through | |
| unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics. | |
| """ | |
| attn_out, router_diagnostics = self.attention( | |
| hidden_states=self.attn_norm(x), | |
| position_ids=position_ids, | |
| active_mask=active_mask, | |
| cache=cache, | |
| ) | |
| hidden_states = x + self.attn_residual_scale * attn_out | |
| output = hidden_states + self.mlp_residual_scale * self.mlp(self.mlp_norm(hidden_states)) | |
| return output, router_diagnostics | |
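The pre-norm residual update with per-sublayer scales can be sketched on a single vector. In this hypothetical plain-Python version, `attn` and `mlp` are arbitrary callables standing in for SHRAMHybridLayer and SwiGLUMLP, and the RMSNorm weight is taken as all ones (its initial value in `nn.RMSNorm`). The sketch shows why a zero-initialized residual gate makes the block an identity map at initialization, whereas the fixed buffer scales each sublayer's contribution by 1 / sqrt(num_decoder_layers).

```python
import math
from typing import Callable, List, Sequence

Vec = Sequence[float]


def rms_norm(x: Vec, eps: float = 1e-6) -> List[float]:
    # x / sqrt(mean(x^2) + eps), with a unit learnable weight.
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]


def decoder_step(
    x: Vec,
    attn: Callable[[List[float]], Vec],
    mlp: Callable[[List[float]], Vec],
    attn_scale: float,
    mlp_scale: float,
    eps: float = 1e-6,
) -> List[float]:
    # Attention sublayer: normalize the input, scale the output, add to the residual.
    a = attn(rms_norm(x, eps))
    h = [xi + attn_scale * ai for xi, ai in zip(x, a)]
    # MLP sublayer: same pattern on the updated residual stream.
    m = mlp(rms_norm(h, eps))
    return [hi + mlp_scale * mi for hi, mi in zip(h, m)]


def ones(v: List[float]) -> List[float]:
    return [1.0] * len(v)  # toy sublayer with a constant output


x = [0.5, -2.0, 3.0]
print(decoder_step(x, ones, ones, 0.0, 0.0))  # gate at init: [0.5, -2.0, 3.0]
fixed = 1.0 / math.sqrt(4)  # num_decoder_layers = 4
print(decoder_step(x, ones, ones, fixed, fixed))  # [1.5, -1.0, 4.0]
```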
| class ShramModel(nn.Module): | |
| """Pure transformer backbone: decoder stack and final normalisation. | |
| Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size) | |
| and returns contextual representations of the same shape. No token embedding, | |
| vocabulary projection, or causal-LM lifecycle concerns. | |
| RoPE is applied inside each attention layer. Positional information is | |
| encoded in the relationship between Q and K, not added to the residual | |
| stream, so the backbone is agnostic to how positions are represented. | |
| Args: | |
| config: Model configuration. Must be a ``ShramConfig`` instance. | |
| """ | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__() | |
| self.config = config | |
| self.layers = nn.ModuleList( | |
| [DecoderLayer(config) for _ in range(config.num_decoder_layers)] | |
| ) | |
| self.norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps) | |
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable MoSRAH parameters across all decoder layers.""" | |
| return sum(layer.num_mosrah_parameters() for layer in self.layers) | |
| def forward( | |
| self, | |
| inputs_embeds: torch.Tensor, | |
| position_ids: torch.Tensor, | |
| active_mask: torch.Tensor, | |
| cache: ShramCache | None = None, | |
| output_hidden_states: bool = False, | |
| ) -> dict: | |
| """Run the transformer stack over a batch of pre-embedded sequences. | |
| Args: | |
| inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size). | |
| position_ids: Absolute positions of shape (batch, seq_len). Required. | |
| Must be provided explicitly by the caller — this module does not | |
| infer positions from cache state. | |
| active_mask: Current-chunk active mask of shape (batch, seq_len), | |
| where True means the token is semantically live. Forwarded | |
| unchanged to every decoder layer. | |
| cache: Optional top-level ShramCache. When provided, each DecoderLayer | |
| receives its own layer-local cache via ``cache.layers[layer_idx]``. | |
| The top-level cache object is updated in place, and that same object is returned. | |
| output_hidden_states: When True, the output dict includes a tuple of | |
| per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_{N-1}_out) | |
| for N decoder layers, collected before the final norm. | |
| Returns: | |
| Plain dict with keys: | |
| - ``"last_hidden_state"``: normed backbone output, | |
| shape (batch, seq_len, hidden_size). | |
| - ``"past_key_values"``: the cache object passed in, or None. | |
| - ``"hidden_states"``: tuple of per-layer activations (including | |
| inputs_embeds as position 0) if ``output_hidden_states`` is True, | |
| else None. Collected before the final norm so each entry reflects the | |
| unnormalised residual stream at that depth. | |
| - ``"regret_loss"``: scalar sum of per-layer SHRAM regret losses. | |
| Gradient flows through this tensor into the router. | |
| - ``"logit_regret"``: detached scalar — mean across layers of the | |
| logit-space regret. Monitoring metric for assignment quality. | |
| - ``"logit_std"``: detached scalar — mean across layers of the | |
| per-token routing logit spread. Monitoring metric for routing | |
| sharpness. | |
| """ | |
| hidden_states = inputs_embeds | |
| all_hidden_states = (hidden_states,) if output_hidden_states else None | |
| total_regret_loss = inputs_embeds.new_zeros(()) | |
| total_logit_regret = inputs_embeds.new_zeros(()) | |
| total_logit_std = inputs_embeds.new_zeros(()) | |
| for layer_idx, layer in enumerate(self.layers): | |
| layer_cache = None if cache is None else cache.layers[layer_idx] | |
| hidden_states, layer_diagnostics = layer( | |
| hidden_states, | |
| position_ids, | |
| active_mask, | |
| cache=layer_cache, | |
| ) | |
| total_regret_loss = total_regret_loss + layer_diagnostics["regret_loss"] | |
| total_logit_regret = total_logit_regret + layer_diagnostics["logit_regret"] | |
| total_logit_std = total_logit_std + layer_diagnostics["logit_std"] | |
| if output_hidden_states: | |
| all_hidden_states = all_hidden_states + (hidden_states,) | |
| hidden_states = self.norm(hidden_states) | |
| num_layers = len(self.layers) | |
| return { | |
| "last_hidden_state": hidden_states, | |
| "past_key_values": cache, | |
| "hidden_states": all_hidden_states, | |
| "regret_loss": total_regret_loss, | |
| "logit_regret": total_logit_regret / num_layers, | |
| "logit_std": total_logit_std / num_layers, | |
| } | |
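`ShramModel.forward` reduces its three router diagnostics differently: `regret_loss` is summed over layers (it is the term that carries gradient), while `logit_regret` and `logit_std` are averaged. The sketch below restates that reduction with plain floats in place of scalar tensors; `aggregate_layer_diagnostics` is a hypothetical helper, not part of the SHRAM code.

```python
from typing import Dict, List

DIAGNOSTIC_KEYS = ("regret_loss", "logit_regret", "logit_std")


def aggregate_layer_diagnostics(per_layer: List[Dict[str, float]]) -> Dict[str, float]:
    """Reduce per-layer router diagnostics the way the backbone does."""
    num_layers = len(per_layer)
    if num_layers == 0:
        raise ValueError("at least one decoder layer is required")
    totals = {key: sum(layer[key] for layer in per_layer) for key in DIAGNOSTIC_KEYS}
    return {
        # Summed over layers, as in ShramModel.forward.
        "regret_loss": totals["regret_loss"],
        # Averaged over layers, so these metrics do not scale with depth.
        "logit_regret": totals["logit_regret"] / num_layers,
        "logit_std": totals["logit_std"] / num_layers,
    }
```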
| @dataclass | |
| class ShramCausalLMOutput(CausalLMOutputWithPast): | |
| """SHRAM causal-LM wrapper output. | |
| This subclasses HuggingFace's standard ``CausalLMOutputWithPast``. | |
| Dataclass inheritance is sufficient here: all standard causal-LM fields and | |
| ModelOutput behavior are inherited from the parent, and this subclass adds | |
| only the SHRAM-specific wrapper outputs. | |
| """ | |
| ## Python dataclass inheritance constraint: CausalLMOutputWithPast defaults all | |
| ## fields to None, which forces every subclass field to also carry a default. | |
| ## The = None below is a language constraint, not a semantic statement. In | |
| ## practice, regret_loss, logit_regret, and logit_std are always populated | |
| ## by ShramForCausalLM.forward(). ce_loss is genuinely optional: it is present | |
| ## only when labels are supplied. | |
| ce_loss: torch.FloatTensor | None = None | |
| regret_loss: torch.FloatTensor | None = None | |
| logit_regret: torch.Tensor | None = None | |
| logit_std: torch.Tensor | None = None | |
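The comment about `= None` defaults follows from standard-library dataclass rules, which can be checked without HuggingFace. In this sketch `ParentOutput` is a hypothetical stand-in for `CausalLMOutputWithPast` (all fields default to `None`). It shows that a subclass field without a default is rejected, and that a subclass only gains constructor arguments for its new fields when it is itself decorated with `@dataclass`.

```python
from dataclasses import dataclass, fields
from typing import List, Optional


@dataclass
class ParentOutput:
    # Hypothetical stand-in for CausalLMOutputWithPast: every field defaults to None.
    loss: Optional[float] = None
    logits: Optional[list] = None


@dataclass
class ChildOutput(ParentOutput):
    # New fields follow defaulted parent fields, so they need defaults too.
    ce_loss: Optional[float] = None
    regret_loss: Optional[float] = None


class Undecorated(ParentOutput):
    # No @dataclass: this annotation only creates a class attribute, and
    # __init__ is still the parent's, so ce_loss is not a constructor argument.
    ce_loss: Optional[float] = None


def field_names(cls: type) -> List[str]:
    return [f.name for f in fields(cls)]


def child_without_default_fails() -> bool:
    """Return True if a subclass field without a default is rejected."""
    try:
        @dataclass
        class Broken(ParentOutput):
            regret_loss: float  # no default after defaulted parent fields
    except TypeError:
        return True
    return False
```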
| class ShramForCausalLM(PreTrainedModel, GenerationMixin): | |
| """HuggingFace-facing causal language model wrapper for SHRAM. | |
| Owns token embeddings, LM-head projection, wrapper-level shifted CE loss, | |
| tied embedding configuration, and generation/cache boundary behavior. | |
| Delegates all transformer computation to ``ShramModel``. | |
| Args: | |
| config: SHRAM model configuration. | |
| """ | |
| config_class = ShramConfig | |
| base_model_prefix = "model" | |
| _no_split_modules = ["DecoderLayer"] | |
| supports_gradient_checkpointing = True | |
| _supports_assign_param_buffer = False | |
| def __init__(self, config: ShramConfig) -> None: | |
| super().__init__(config) | |
| self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_width) | |
| self.model = ShramModel(config) | |
| self.lm_head = nn.Linear(config.embedding_width, config.vocab_size, bias=False) | |
| self._configure_tied_embeddings() | |
| self.post_init() | |
| def _configure_tied_embeddings(self) -> None: | |
| """Apply config-controlled tied embedding behavior on this instance.""" | |
| if self.config.tie_word_embeddings: | |
| self.lm_head.weight = self.embed_tokens.weight | |
| self._tied_weights_keys = { | |
| "lm_head.weight": "embed_tokens.weight", | |
| } | |
| else: | |
| self._tied_weights_keys = {} | |
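Tying works by aliasing: with `tie_word_embeddings` set, `lm_head.weight` and `embed_tokens.weight` refer to one parameter object, which is why `set_input_embeddings` and `set_output_embeddings` re-run `_configure_tied_embeddings` after a swap. The toy class below is hypothetical and uses nested lists in place of parameters; it shows that after replacing the embedding matrix, the re-tie makes the head follow the new matrix instead of keeping the old one.

```python
from typing import Dict, List

Matrix = List[List[float]]


class TinyTiedModel:
    """Toy stand-in for the embedding / LM-head pair (hypothetical)."""

    def __init__(self, vocab_size: int, width: int, tie_word_embeddings: bool) -> None:
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_weight: Matrix = [[0.0] * width for _ in range(vocab_size)]
        self.lm_head_weight: Matrix = [[0.0] * width for _ in range(vocab_size)]
        self.tied_weights_keys: Dict[str, str] = {}
        self._configure_tied_embeddings()

    def _configure_tied_embeddings(self) -> None:
        if self.tie_word_embeddings:
            # Alias, not copy: both names refer to one matrix object.
            self.lm_head_weight = self.embed_weight
            self.tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}
        else:
            self.tied_weights_keys = {}

    def set_input_embeddings(self, weight: Matrix) -> None:
        self.embed_weight = weight
        # Re-tie so the head follows the replacement matrix.
        self._configure_tied_embeddings()
```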
| def num_mosrah_parameters(self) -> int: | |
| """Return the total number of trainable parameters belonging to MoSRAH layers. | |
| Aggregates across all decoder layers. Excludes sliding-window path parameters, | |
| FFN parameters, norms, and embeddings. Use this for experimental plotting of | |
| MoSRAH parameter count versus performance. | |
| Returns: | |
| Total count of trainable MoSRAH parameters. | |
| """ | |
| return self.model.num_mosrah_parameters() | |
| def get_input_embeddings(self) -> nn.Embedding: | |
| """Return the token embedding matrix.""" | |
| return self.embed_tokens | |
| def set_input_embeddings(self, value: nn.Embedding) -> None: | |
| """Replace the token embedding matrix.""" | |
| self.embed_tokens = value | |
| self._configure_tied_embeddings() | |
| def get_output_embeddings(self) -> nn.Linear: | |
| """Return the LM head.""" | |
| return self.lm_head | |
| def set_output_embeddings(self, value: nn.Linear) -> None: | |
| """Replace the LM head.""" | |
| self.lm_head = value | |
| self._configure_tied_embeddings() | |
| def _build_shram_cache( | |
| self, | |
| batch_size: int, | |
| device: torch.device, | |
| ) -> ShramCache: | |
| """Construct a fresh top-level SHRAM cache.""" | |
| return ShramCache( | |
| config=self.config, | |
| batch_size=batch_size, | |
| device=device, | |
| ) | |
| def _validate_generation_cache_request( | |
| self, | |
| generation_config: Any, | |
| model_kwargs: dict[str, Any], | |
| generation_mode: GenerationMode, | |
| ) -> None: | |
| """Validate SHRAM's generation-side cache policy.""" | |
| if generation_mode in { | |
| GenerationMode.ASSISTED_GENERATION, | |
| GenerationMode.CONTRASTIVE_SEARCH, | |
| }: | |
| raise NotImplementedError( | |
| "ShramForCausalLM does not currently support assisted generation " | |
| "or contrastive search because ShramCache does not support crop()." | |
| ) | |
| user_defined_cache = model_kwargs.get("past_key_values") | |
| if user_defined_cache is not None: | |
| if generation_config.cache_implementation is not None: | |
| raise ValueError( | |
| "Passing both `cache_implementation` and `past_key_values` " | |
| "is unsupported. Please use only one." | |
| ) | |
| if isinstance(user_defined_cache, tuple): | |
| raise ValueError( | |
| "Passing a tuple of `past_key_values` is not supported. " | |
| "Please use a `ShramCache` instance." | |
| ) | |
| if not isinstance(user_defined_cache, ShramCache): | |
| raise TypeError( | |
| "ShramForCausalLM requires `past_key_values` to be a " | |
| "`ShramCache` instance." | |
| ) | |
| if ( | |
| user_defined_cache is None | |
| and generation_config.use_cache | |
| and generation_config.cache_implementation is not None | |
| ): | |
| raise ValueError( | |
| "ShramForCausalLM does not support `cache_implementation`. " | |
| "Generation-created caches must be `ShramCache` objects." | |
| ) | |
| def _prepare_cache_for_generation( | |
| self, | |
| generation_config: Any, | |
| model_kwargs: dict[str, Any], | |
| generation_mode: GenerationMode, | |
| batch_size: int, | |
| max_cache_length: int, | |
| ) -> None: | |
| """Ensure HuggingFace generation uses ShramCache. | |
| This is the SHRAM-specific generation hook. The rest of the default | |
| generation plumbing is kept intact as much as possible. | |
| Args: | |
| generation_config: Active generation configuration. | |
| model_kwargs: Generation kwargs, updated in place. | |
| generation_mode: HuggingFace generation mode. | |
| batch_size: Effective generation batch size. | |
| max_cache_length: Requested cache length. Accepted but unused here. | |
| """ | |
| self._validate_generation_cache_request( | |
| generation_config=generation_config, | |
| model_kwargs=model_kwargs, | |
| generation_mode=generation_mode, | |
| ) | |
| if model_kwargs.get("past_key_values") is not None: | |
| return | |
| if not generation_config.use_cache: | |
| return | |
| num_repeats = max( | |
| generation_config.num_beams or 1, | |
| generation_config.num_return_sequences or 1, | |
| ) | |
| model_kwargs["past_key_values"] = self._build_shram_cache( | |
| batch_size=batch_size * num_repeats, | |
| device=self.embed_tokens.weight.device, | |
| ) | |
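`_prepare_cache_for_generation` only builds a cache when the caller has not supplied one and `use_cache` is on, and it sizes that cache for the expanded batch that beam search or multiple return sequences produce. A plain-Python restatement of that decision, with a hypothetical function name and without constructing a real `ShramCache`:

```python
from typing import Any, Dict, Optional


def planned_cache_batch_size(
    model_kwargs: Dict[str, Any],
    use_cache: bool,
    batch_size: int,
    num_beams: Optional[int] = None,
    num_return_sequences: Optional[int] = None,
) -> Optional[int]:
    """Return the batch size a fresh cache would need, or None if none is built."""
    if model_kwargs.get("past_key_values") is not None:
        return None  # caller-supplied cache is used as-is
    if not use_cache:
        return None  # generation runs without a cache
    # Same expansion factor as _prepare_cache_for_generation.
    num_repeats = max(num_beams or 1, num_return_sequences or 1)
    return batch_size * num_repeats
```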
| def _reorder_cache( | |
| self, | |
| past_key_values: Cache, | |
| beam_idx: torch.Tensor, | |
| ) -> Cache: | |
| """Reorder the cache in place for beam search.""" | |
| past_key_values.reorder_cache(beam_idx) | |
| return past_key_values | |
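`_reorder_cache` delegates to `ShramCache.reorder_cache`, whose body is not shown here. Assuming it behaves like other beam-search caches, reordering means gathering batch rows by `beam_idx`, so each surviving beam inherits the cached state of the beam it was extended from. A minimal list-based sketch of that gather (`reorder_rows` is hypothetical):

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def reorder_rows(rows: Sequence[T], beam_idx: Sequence[int]) -> List[T]:
    """Gather cached rows along the batch axis, as index_select would."""
    if any(i < 0 or i >= len(rows) for i in beam_idx):
        raise IndexError("beam_idx entry out of range")
    # Row k of the result is the state of the beam that new beam k came from.
    return [rows[i] for i in beam_idx]
```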
| @staticmethod | |
| def create_masks_for_generate( | |
| attention_mask: torch.Tensor | None, | |
| **kwargs: Any, | |
| ) -> torch.Tensor | None: | |
| """Return the 2D attention_mask unchanged. | |
| HuggingFace calls this during compiled generation to convert the 2D | |
| attention mask into a 4D causal additive-bias mask. SHRAM uses FlexAttention | |
| with custom masking and constructs causality internally; the | |
| 4D format is incompatible with the SHRAM masking contract. Overriding | |
| as a no-op restores symmetry between compiled and non-compiled pathways | |
| without any loss of correctness or performance (see Unit 19.G.4). | |
| """ | |
| return attention_mask | |
| def _validate_input_ids(self, input_ids: torch.Tensor) -> None: | |
| """Validate token IDs at the wrapper boundary.""" | |
| if input_ids.ndim != 2: | |
| raise ValueError("input_ids must have shape (batch, seq_len).") | |
| if input_ids.shape[1] == 0: | |
| raise ValueError("input_ids sequence length must be nonzero.") | |
| if input_ids.dtype != torch.long: | |
| raise TypeError("input_ids must be a long tensor.") | |
| def _validate_attention_mask( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None, | |
| ) -> None: | |
| """Validate the full-sequence attention mask.""" | |
| if attention_mask is None: | |
| return | |
| if attention_mask.ndim != 2: | |
| raise ValueError("attention_mask must have shape (batch, total_seq_len).") | |
| if attention_mask.shape[0] != input_ids.shape[0]: | |
| raise ValueError("attention_mask batch dimension must match input_ids.") | |
| if attention_mask.shape[1] < input_ids.shape[1]: | |
| raise ValueError( | |
| "attention_mask must be at least as long as the current input_ids chunk." | |
| ) | |
| def _validate_position_ids( | |
| self, | |
| input_ids: torch.Tensor, | |
| position_ids: torch.Tensor | None, | |
| ) -> None: | |
| """Validate current-step position IDs.""" | |
| if position_ids is None: | |
| return | |
| if position_ids.ndim != 2: | |
| raise ValueError("position_ids must have shape (batch, seq_len).") | |
| if position_ids.shape != input_ids.shape: | |
| raise ValueError( | |
| "position_ids must match the current input_ids shape exactly." | |
| ) | |
| if position_ids.dtype != torch.long: | |
| raise TypeError("position_ids must be a long tensor.") | |
| def _validate_labels( | |
| self, | |
| input_ids: torch.Tensor, | |
| labels: torch.Tensor | None, | |
| ) -> None: | |
| """Validate label shape at the wrapper boundary.""" | |
| if labels is None: | |
| return | |
| if labels.ndim != 2: | |
| raise ValueError("labels must have shape (batch, seq_len).") | |
| if labels.shape != input_ids.shape: | |
| raise ValueError("labels must have the same shape as input_ids.") | |
| if labels.dtype != torch.long: | |
| raise TypeError("labels must be a long tensor.") | |
| def _validate_cache_inputs( | |
| self, | |
| use_cache: bool, | |
| past_key_values: Cache | None, | |
| ) -> None: | |
| """Validate cache policy for direct wrapper calls.""" | |
| if use_cache: | |
| if past_key_values is None: | |
| raise ValueError( | |
| "use_cache=True requires an explicit ShramCache. During " | |
| "generate(), HuggingFace should supply this through " | |
| "_prepare_cache_for_generation()." | |
| ) | |
| if not isinstance(past_key_values, ShramCache): | |
| raise TypeError( | |
| "past_key_values must be a ShramCache when use_cache=True." | |
| ) | |
| return | |
| if past_key_values is not None: | |
| raise ValueError("past_key_values was provided while use_cache=False.") | |
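The rules in `_validate_cache_inputs` reduce to a small truth table over `use_cache` and `past_key_values`. This sketch returns the exception type the wrapper would raise (or `None` when the call is accepted), using a hypothetical `FakeShramCache` class in place of `ShramCache`:

```python
from typing import Optional, Type


class FakeShramCache:
    """Hypothetical stand-in for ShramCache, used only for isinstance checks."""


def cache_input_error(use_cache: bool, past_key_values: object) -> Optional[Type[Exception]]:
    if use_cache:
        if past_key_values is None:
            return ValueError  # caching requested but no cache supplied
        if not isinstance(past_key_values, FakeShramCache):
            return TypeError  # wrong cache class
        return None
    if past_key_values is not None:
        return ValueError  # cache supplied while caching is disabled
    return None
```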
| def _validate_position_sources( | |
| self, | |
| use_cache: bool, | |
| attention_mask: torch.Tensor | None, | |
| position_ids: torch.Tensor | None, | |
| ) -> None: | |
| """Validate that cached forward has a truthful source of positions.""" | |
| if use_cache and attention_mask is None and position_ids is None: | |
| raise ValueError( | |
| "Cached forward requires either position_ids or attention_mask." | |
| ) | |
| def _validate_hf_boundary( | |
| self, | |
| output_attentions: bool | None, | |
| return_dict: bool | None, | |
| inputs_embeds: torch.Tensor | None, | |
| cache_position: torch.Tensor | None, | |
| extra_kwargs: dict[str, Any], | |
| ) -> None: | |
| """Validate unsupported HuggingFace-facing wrapper inputs.""" | |
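| # cache_position is accepted for HuggingFace compatibility but is neither | |
| # checked nor used: positions come from position_ids or from the cache. | |
| if output_attentions: | |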
| raise NotImplementedError( | |
| "ShramForCausalLM does not expose output_attentions." | |
| ) | |
| if return_dict is False: | |
| raise ValueError( | |
| "return_dict=False is not supported. " | |
| "ShramForCausalLM always returns ShramCausalLMOutput." | |
| ) | |
| if inputs_embeds is not None: | |
| raise ValueError( | |
| "inputs_embeds is not supported at the SHRAM wrapper boundary. " | |
| "Pass input_ids instead." | |
| ) | |
| if extra_kwargs: | |
| unsupported = ", ".join(sorted(extra_kwargs)) | |
| raise TypeError( | |
| f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}" | |
| ) | |
| @staticmethod | |
| def _enforce_uncached_starting_position(condition: torch.Tensor) -> None: | |
| """Enforce that an uncached forward pass begins at position 0. | |
| An uncached forward has no prior KV state. Nonzero starting positions | |
| produce silently incorrect RoPE encoding and attention outputs with no | |
| downstream diagnostic. This method intercepts that misuse at the | |
| outermost boundary before any backbone computation runs. | |
| To resolve a violation: either supply a ShramCache populated with the | |
| prefix (for continued decoding), or rebase the sequence so positions | |
| start at 0. | |
| Args: | |
| condition: Scalar bool tensor. True = all batch items start at 0 | |
| (valid); False = at least one batch item starts nonzero | |
| (violated). | |
| """ | |
| if torch.compiler.is_compiling(): | |
| torch._assert_async( | |
| condition, | |
| "Uncached ShramForCausalLM: nonzero starting positions. " | |
| "Supply a ShramCache with prefix or rebase sequence to start at 0.", | |
| ) | |
| else: | |
| if not condition.item(): | |
| raise RuntimeError( | |
| "Uncached ShramForCausalLM forward does not support nonzero " | |
| "starting positions. Either provide a ShramCache populated " | |
| "with the prefix for continued decoding, or rebase the " | |
| "uncached sequence to start at 0.", | |
| ) | |
| def _standardize_full_attention_mask( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None, | |
| ) -> torch.BoolTensor: | |
| """Return a concrete full-sequence boolean attention mask.""" | |
| if attention_mask is None: | |
| return torch.ones_like(input_ids, dtype=torch.bool) | |
| return attention_mask.to(dtype=torch.bool) | |
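The wrapper turns HuggingFace's full-sequence mask into the current-chunk `active_mask` in two steps: `_standardize_full_attention_mask` substitutes an all-True mask when none is given and casts to bool, and `forward` then keeps the last `seq_len` columns. A list-based sketch of both steps (the helper name is hypothetical):

```python
from typing import List, Optional, Sequence


def current_active_mask(
    input_ids: Sequence[Sequence[int]],
    attention_mask: Optional[Sequence[Sequence[int]]] = None,
) -> List[List[bool]]:
    """Standardise the full mask, then keep only the current chunk's columns."""
    current_length = len(input_ids[0])
    if attention_mask is None:
        # No mask supplied: every current token is live.
        full = [[True] * len(row) for row in input_ids]
    else:
        full = [[bool(v) for v in row] for row in attention_mask]
    # Equivalent of full_attention_mask[:, -current_length:].
    return [row[-current_length:] for row in full]
```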
| def _resolve_current_position_ids( | |
| self, | |
| input_ids: torch.Tensor, | |
| position_ids: torch.Tensor | None, | |
| current_active_mask: torch.BoolTensor, | |
| cache: ShramCache | None, | |
| ) -> torch.LongTensor: | |
| """Resolve concrete current-step position IDs for the backbone. | |
| Builds a fresh contiguous allocation via arange + per-batch bias. No cumsum | |
| or stride-based views are produced; the returned tensor is always a new | |
| allocation safe for Inductor tracing at the FlexAttention boundary. | |
| When a cache is present, ``total_active_tokens()`` provides the per-batch | |
| accumulated active token count as a position bias. Uncached calls use a zero | |
| bias. In both cases positions are ``bias + arange(current_length)``, with | |
| inactive positions masked to 0. | |
| Args: | |
| input_ids: Current token IDs of shape ``(B, N)``. | |
| position_ids: Explicit positions if supplied by the caller; returned | |
| unchanged (cast to long). Bias computation is skipped entirely. | |
| current_active_mask: Boolean mask of shape ``(B, N)`` for the current step. | |
| cache: Active ``ShramCache``, or ``None`` for uncached forward passes. | |
| Returns: | |
| Long tensor of shape ``(B, N)`` — position index per token, 0 for inactive. | |
| """ | |
| if position_ids is not None: | |
| return position_ids.to(dtype=torch.long) | |
| current_length = input_ids.shape[1] | |
| if cache is not None: | |
| position_bias = cache.total_active_tokens(current_active_mask) | |
| else: | |
| position_bias = torch.zeros( | |
| input_ids.shape[0], dtype=torch.long, device=input_ids.device | |
| ) | |
| positions = position_bias.unsqueeze(1) + torch.arange( | |
| current_length, device=input_ids.device, dtype=torch.long | |
| ) | |
| return positions.masked_fill(~current_active_mask, 0) | |
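Concretely, `_resolve_current_position_ids` computes `bias + arange(seq_len)` per row and zeroes inactive slots, where the bias is 0 without a cache and `cache.total_active_tokens(current_active_mask)` with one. The sketch below reproduces that arithmetic on lists; the bias values are supplied by hand rather than read from a real `ShramCache`, and `resolve_positions` and `uncached_start_ok` are hypothetical names. The second helper mirrors the uncached guard in `forward`, which inspects only column 0.

```python
from typing import List, Sequence


def resolve_positions(
    active_mask: Sequence[Sequence[bool]], bias: Sequence[int]
) -> List[List[int]]:
    """Per-row bias + arange(seq_len), with inactive slots set to 0."""
    return [
        [b + j if active else 0 for j, active in enumerate(row)]
        for row, b in zip(active_mask, bias)
    ]


def uncached_start_ok(positions: Sequence[Sequence[int]]) -> bool:
    """Mirror of torch.all(position_ids[:, 0] == 0)."""
    return all(row[0] == 0 for row in positions)
```

In a left-padded row such as `[False, True, True]` with bias 0, the padded slot still consumes index 0, so the first live token receives position 1; because that slot is zeroed, the row also passes the column-0 check.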
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| position_ids: torch.Tensor | None = None, | |
| past_key_values: Cache | None = None, | |
| use_cache: bool | None = None, | |
| output_hidden_states: bool | None = None, | |
| labels: torch.Tensor | None = None, | |
| return_dict: bool | None = None, | |
| ce_weight: float = 1.0, | |
| load_balance_weight: float = 0.01, | |
| **kwargs: Any, | |
| ) -> ShramCausalLMOutput: | |
| """Run the SHRAM causal language model wrapper. | |
| Args: | |
| input_ids: Current token IDs of shape ``(batch, seq_len)``. | |
| attention_mask: Optional full 2D mask of shape | |
| ``(batch, total_seq_len)``. The wrapper slices its recent chunk | |
| to produce the current semantic liveness mask expected by the | |
| backbone. | |
| position_ids: Optional current-step position IDs of shape | |
| ``(batch, seq_len)``. In ordinary HuggingFace generation this is | |
| already the current-step tensor when it reaches ``forward()``. | |
| past_key_values: Optional SHRAM cache. Required when | |
| ``use_cache=True``. | |
| use_cache: Whether to use and return a cache. Defaults to | |
| ``config.use_cache``. | |
| output_hidden_states: Whether to return backbone hidden states. | |
| Defaults to ``config.output_hidden_states``. | |
| labels: Optional target token IDs of shape ``(batch, seq_len)``. | |
| Pass unshifted labels (same alignment as ``input_ids``). This | |
| wrapper shifts internally: ``logits[:, :-1]`` is compared | |
| against ``labels[:, 1:]``. Do not pre-shift labels on the caller side. | |
| return_dict: Must be ``True`` or ``None``. | |
| ce_weight: Weight applied to the cross-entropy loss when combining with | |
| the regret loss. Default 1.0. | |
| load_balance_weight: Weight applied to the regret loss. | |
| Default 0.01, matching the paper's recommendation. | |
| **kwargs: Unsupported HuggingFace kwargs fail explicitly. | |
| Returns: | |
| ``ShramCausalLMOutput`` with: | |
| - ``logits`` of shape ``(batch, seq_len, vocab_size)``, | |
| - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * regret_loss`` | |
| when labels are provided (``None`` otherwise), | |
| - ``ce_loss`` — raw unweighted cross-entropy loss for logging, | |
| - ``past_key_values`` as the active ``ShramCache`` or ``None``, | |
| - ``hidden_states`` when requested, | |
| - ``regret_loss`` — raw unweighted regret loss from the backbone, | |
| - ``logit_regret`` — detached mean logit-space regret across layers, | |
| - ``logit_std`` — detached mean per-token routing logit spread across layers. | |
| """ | |
| use_cache = use_cache if use_cache is not None else self.config.use_cache | |
| output_hidden_states = ( | |
| output_hidden_states | |
| if output_hidden_states is not None | |
| else self.config.output_hidden_states | |
| ) | |
| inputs_embeds = kwargs.pop("inputs_embeds", None) | |
| output_attentions = kwargs.pop("output_attentions", None) | |
| cache_position = kwargs.pop("cache_position", None) | |
| # ------------------------------------------------------------------ | |
| # Validation zone. | |
| # | |
| # The wrapper boundary is where HuggingFace-facing inputs are judged | |
| # for truthfulness before any internal work begins. These checks are | |
| # intentionally front-loaded so the core logic below can assume one | |
| # coherent interpretation of the call rather than defensively checking | |
| # shapes, cache policy, or unsupported HF knobs at the point of use. | |
| # This keeps the main sequence readable while ensuring invalid states | |
| # fail before they can silently contaminate backbone execution. | |
| # ------------------------------------------------------------------ | |
| self._validate_input_ids(input_ids) | |
| self._validate_attention_mask(input_ids, attention_mask) | |
| self._validate_position_ids(input_ids, position_ids) | |
| self._validate_labels(input_ids, labels) | |
| self._validate_cache_inputs(use_cache, past_key_values) | |
| self._validate_position_sources(use_cache, attention_mask, position_ids) | |
| self._validate_hf_boundary( | |
| output_attentions=output_attentions, | |
| return_dict=return_dict, | |
| inputs_embeds=inputs_embeds, | |
| cache_position=cache_position, | |
| extra_kwargs=kwargs, | |
| ) | |
| # ------------------------------------------------------------------ | |
| # Standardization zone. | |
| # | |
| # HuggingFace and SHRAM use different boundary conventions: generation | |
| # carries a full-sequence 2D attention mask, while the SHRAM backbone | |
| # wants a current-step active mask and concrete current position IDs. | |
| # This zone collapses those wrapper-facing conventions into one valid | |
| # backbone-facing state. After this point the core no longer reasons | |
| # about optional or ambiguous input forms; it works only with concrete | |
| # tensors whose semantics are already fixed. | |
| # ------------------------------------------------------------------ | |
| full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| ) | |
| current_length: int = input_ids.shape[1] | |
| current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:] | |
| shram_cache: ShramCache | None = past_key_values if use_cache else None | |
| current_position_ids: torch.LongTensor = self._resolve_current_position_ids( | |
| input_ids=input_ids, | |
| position_ids=position_ids, | |
| current_active_mask=current_active_mask, | |
| cache=shram_cache, | |
| ) | |
| if shram_cache is None: | |
| positions_start_sane = torch.all(current_position_ids[:, 0] == 0) | |
| self._enforce_uncached_starting_position(positions_start_sane) | |
| # ------------------------------------------------------------------ | |
| # Core wrapper responsibilities. | |
| # | |
| # The wrapper's primary job is kept visible here: convert token IDs to | |
| # embeddings, delegate transformer computation to ShramModel, project | |
| # hidden states back to vocabulary logits, optionally compute the | |
| # wrapper-level shifted next-token loss, and return the HuggingFace- | |
| # facing output object. The backbone remains responsible only for | |
| # transformer semantics; token/vocabulary/loss concerns stay here. | |
| # ------------------------------------------------------------------ | |
| token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids) | |
| backbone_outputs = self.model( | |
| inputs_embeds=token_embeddings, | |
| position_ids=current_position_ids, | |
| active_mask=current_active_mask, | |
| cache=shram_cache, | |
| output_hidden_states=output_hidden_states, | |
| ) | |
| logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"]) | |
| ce_loss: torch.FloatTensor | None = None | |
| loss: torch.FloatTensor | None = None | |
| if labels is not None: | |
| shift_logits = logits[:, :-1, :].contiguous() | |
| shift_labels = labels[:, 1:].contiguous() | |
| ce_loss = nn.functional.cross_entropy( | |
| shift_logits.view(-1, self.config.vocab_size), | |
| shift_labels.view(-1), | |
| ) | |
| loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["regret_loss"] | |
| return ShramCausalLMOutput( | |
| loss=loss, | |
| ce_loss=ce_loss, | |
| logits=logits, | |
| past_key_values=backbone_outputs["past_key_values"], | |
| hidden_states=backbone_outputs["hidden_states"], | |
| regret_loss=backbone_outputs["regret_loss"], | |
| logit_regret=backbone_outputs["logit_regret"], | |
| logit_std=backbone_outputs["logit_std"], | |
| ) |
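The training objective assembled at the end of `forward` is `ce_weight * ce_loss + load_balance_weight * regret_loss`, where `ce_loss` compares `logits[:, :-1]` with `labels[:, 1:]`. A dependency-free sketch of that shifted cross-entropy (mean over targets, skipping label `-100`, which is the default `ignore_index` of `torch.nn.functional.cross_entropy`) and of the weighted sum follows; both function names are hypothetical.

```python
import math
from typing import Sequence


def shifted_cross_entropy(
    logits: Sequence[Sequence[Sequence[float]]],  # [batch][seq][vocab]
    labels: Sequence[Sequence[int]],  # [batch][seq], unshifted
    ignore_index: int = -100,
) -> float:
    """Mean next-token CE: position t is scored against labels[t + 1]."""
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logits, labels):
        for t in range(len(seq_labels) - 1):
            target = seq_labels[t + 1]
            if target == ignore_index:
                continue
            row = seq_logits[t]
            # Numerically stable log-sum-exp over the vocabulary.
            m = max(row)
            log_z = m + math.log(sum(math.exp(v - m) for v in row))
            total += log_z - row[target]
            count += 1
    # Like a mean reduction with no valid targets, return NaN.
    return total / count if count else float("nan")


def combined_loss(
    ce_loss: float,
    regret_loss: float,
    ce_weight: float = 1.0,
    load_balance_weight: float = 0.01,
) -> float:
    return ce_weight * ce_loss + load_balance_weight * regret_loss
```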