smithblack-0 commited on
Commit
1670228
·
verified ·
1 Parent(s): e9503f5

Update architecture and tokenizer

Browse files
README.md CHANGED
@@ -1,3 +1,106 @@
1
  ---
 
 
2
  license: mit
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language:
3
+ - en
4
  license: mit
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - pytorch
9
+ - research
10
+ - sparse-attention
11
+ - mixture-of-experts
12
  ---
13
+
14
+ # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
+
16
+ A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
+ Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
+ the Hub and instantiate a freshly initialised model from config. Every parameter is
19
+ overridable at instantiation time via kwargs.
20
+
21
+ > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
+ > source files from the Hub and imports them into your Python process. Review the
23
+ > source at [smithblack-0/SHRAM-dev](https://huggingface.co/smithblack-0/SHRAM-dev) before use. Those interested can also
24
+ > clone the git repository at https://github.com/smithblack-0/advanced-transformers-lib
25
+
26
+ ## Architecture
27
+
28
+ SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
29
+
30
+ - **h_l** — local sliding-window causal attention path.
31
+ - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
32
+ via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
33
+
34
+ All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
35
+
36
+ ## Usage
37
+
38
+ This repository contains no pretrained weights. The intended workflow is: pull the
39
+ architecture config from the Hub, instantiate a model with fresh random weights, then
40
+ train it yourself.
41
+
42
+ ```python
43
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
44
+
45
+ # Step 1: pull the architecture config from the Hub.
46
+ # AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
47
+ # Override any parameter via kwargs.
48
+ config = AutoConfig.from_pretrained(
49
+ "smithblack-0/SHRAM-dev",
50
+ trust_remote_code=True,
51
+ num_hidden_layers=16, # example override
52
+ num_mosrah_heads=32, # example override
53
+ )
54
+
55
+ # Step 2: instantiate with fresh random weights.
56
+ # from_config never loads a checkpoint — it always produces a randomly initialised model.
57
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
58
+
59
+ # Step 3: load the tokenizer.
60
+ tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM-dev")
61
+ ```
62
+
63
+ After training your own checkpoint, save and reload it in the standard way:
64
+
65
+ ```python
66
+ model.save_pretrained("./my-checkpoint")
67
+ model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
68
+ ```
69
+
70
+ ## Constructor Defaults
71
+
72
+ The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
73
+ no overrides. They are not the parameters of a pretrained model — this repository
74
+ contains no weights. All values are overridable via kwargs.
75
+
76
+ | Parameter | Default |
77
+ |-----------|---------|
78
+ | `alpha` | 1.0 |
79
+ | `attention_dropout` | 0.0 |
80
+ | `beta` | 32.0 |
81
+ | `dtype` | None |
82
+ | `embedding_width` | 512 |
83
+ | `head_dim` | 16 |
84
+ | `inference_sequence_length` | 1024 |
85
+ | `load_balance_p` | 2.0 |
86
+ | `local_rope_theta` | 10000.0 |
87
+ | `mlp_width` | 1366 |
88
+ | `mosrah_overallocation_factor` | 2.0 |
89
+ | `mosrah_rope_theta` | 10000.0 |
90
+ | `num_decoder_layers` | 12 |
91
+ | `num_mosrah_heads` | 16 |
92
+ | `num_selected_heads` | 16 |
93
+ | `num_sliding_window_heads` | 16 |
94
+ | `output_hidden_states` | False |
95
+ | `rms_norm_eps` | 1e-05 |
96
+ | `rope_mode` | main_sequence |
97
+ | `tie_word_embeddings` | False |
98
+ | `training_sequence_length` | 1024 |
99
+ | `use_cache` | True |
100
+ | `vocab_size` | 50277 |
101
+ | `window_size` | 128 |
102
+
103
+ ## License
104
+
105
+ MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
106
+ (`EleutherAI/gpt-neox-20b`, Apache 2.0).
__attention__bottlenecked_ensemble_attention.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
+
3
+ BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
+ It consumes packed expert-choice tensors, a supplied position tensor, an active
5
+ token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
+ same packed expert-choice space expected by later unpacking.
7
+
8
+ BEA does not compute positions and does not choose packed-position semantics.
9
+ Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
+ (K̃) and raw values (V) into the sparse cache and attends against the
11
+ accumulated cached state returned by that cache.
12
+ """
13
+
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
+
20
+ from .configuration import ShramConfig
21
+ from .__cache__mosrah_cache import MoSRAHCache
22
+ from .rope import RotaryEmbedding
23
+
24
+
25
+ class BottleneckedEnsembleAttention(nn.Module):
26
+ """
27
+ Packed expert-choice attention operator for the MoSRAH sparse path.
28
+ Operates per-head independently on an ensemble of tokens.
29
+ FlexAttention saves flops on dead tokens.
30
+
31
+ Architectural properties:
32
+ - consumes packed expert-choice tensors of shape (B, L, T, d)
33
+ - uses independent per-head Q/K/V/O projection parameters
34
+ - applies YaRN-capable RoPE using supplied position_ids
35
+ - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
+ - uses a fast fused attention path
37
+ - returns outputs in the same packed expert-choice space (B, L, T, d)
38
+
39
+ Args:
40
+ config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
+ `head_dim`, `mosrah_rope_theta`, `inference_sequence_length`,
42
+ `scale`, `alpha`, and `beta`.
43
+ """
44
+
45
+ def __init__(self, config: ShramConfig) -> None:
46
+ super().__init__()
47
+
48
+ self.hidden_size = config.embedding_width
49
+ self.num_heads = config.num_mosrah_heads
50
+ self.head_dim = config.head_dim
51
+
52
+ # Independent per-head projections. No cross-head parameter sharing.
53
+ self.q_proj = nn.Parameter(
54
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
+ )
56
+ self.k_proj = nn.Parameter(
57
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
+ )
59
+ self.v_proj = nn.Parameter(
60
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
+ )
62
+ self.o_proj = nn.Parameter(
63
+ torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
+ )
65
+
66
+ self._reset_parameters()
67
+
68
+ # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
+ # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
+ # no yarn dilation occurs.
71
+ #
72
+ # The required table size depends on position semantics:
73
+ # main_sequence — positions are original token positions, bounded by
74
+ # inference_sequence_length.
75
+ # semantic_sequence — positions are local per-expert slot indices, bounded
76
+ # by mosrah_packed_length.
77
+ maximum_rope_length = (
78
+ config.mosrah_packed_length
79
+ if config.rope_mode == "semantic_sequence"
80
+ else config.inference_sequence_length
81
+ )
82
+ self.rope = RotaryEmbedding(
83
+ mode="yarn",
84
+ head_dim=self.head_dim,
85
+ theta=config.mosrah_rope_theta,
86
+ maximum_sequence_length=maximum_rope_length,
87
+ dilation=config.scale,
88
+ alpha=config.alpha,
89
+ beta=config.beta,
90
+ )
91
+
92
+ def forward(
93
+ self,
94
+ packed_embeddings: torch.Tensor,
95
+ position_ids: torch.Tensor,
96
+ active_mask: torch.Tensor,
97
+ cache: MoSRAHCache | None = None,
98
+ ) -> torch.Tensor:
99
+ """Apply BEA to packed expert-choice tensors.
100
+
101
+ Args:
102
+ packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
103
+ position_ids: Supplied packed positions of shape (B, L, T).
104
+ active_mask: Boolean active-token mask of shape (B, L, T).
105
+ cache: Optional layer-local MoSRAH cache.
106
+
107
+ Returns:
108
+ Packed expert-choice output tensor of shape (B, L, T, d).
109
+ """
110
+ batch_size, _, query_length, _ = packed_embeddings.shape
111
+ self._validate_tensor_shape(packed_embeddings)
112
+ self._validate_position_shape(packed_embeddings, position_ids)
113
+ self._validate_active_mask_shape(packed_embeddings, active_mask)
114
+
115
+ # Independent per-head projections:
116
+ # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
117
+ query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
118
+ key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
119
+ value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
120
+
121
+ rotated_query_states, rotated_key_states, attention_scaling = self.rope(
122
+ query_states,
123
+ key_states,
124
+ position_ids,
125
+ )
126
+
127
+ if cache is not None:
128
+ # In cached execution, the current query tensor uses local tensor rows
129
+ # 0..Q-1, but the key tensor returned by the cache is the full accumulated
130
+ # packed sequence for each (batch, head) slot. The only additional data
131
+ # needed to align those two views is the pre-update cached prefix length.
132
+ # which will indicate how many queries were processed before now.
133
+ num_tokens_processed = cache.get_heads_lengths().clone()
134
+ key_states, value_states, key_active_mask = cache.update(
135
+ rotated_key_states,
136
+ value_states,
137
+ active_mask,
138
+ )
139
+ else:
140
+ num_tokens_processed = torch.zeros(
141
+ batch_size,
142
+ self.num_heads,
143
+ dtype=torch.long,
144
+ device=packed_embeddings.device,
145
+ )
146
+ key_states = rotated_key_states
147
+ key_active_mask = active_mask
148
+
149
+ block_mask = self._make_block_mask(
150
+ query_active_mask=active_mask,
151
+ key_active_mask=key_active_mask,
152
+ num_tokens_processed=num_tokens_processed,
153
+ query_length=query_length,
154
+ key_length=key_states.shape[2],
155
+ device=packed_embeddings.device,
156
+ )
157
+ attended_states = flex_attention(
158
+ rotated_query_states,
159
+ key_states,
160
+ value_states,
161
+ block_mask=block_mask,
162
+ scale=attention_scaling / math.sqrt(self.head_dim),
163
+ )
164
+
165
+ # Project back to model width:
166
+ # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
167
+ return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
168
+
169
+ def _reset_parameters(self) -> None:
170
+ """Initialize per-head projection weights."""
171
+ for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
172
+ nn.init.xavier_uniform_(weight)
173
+
174
+ def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
175
+ """Validate the local packed-embedding shape contract required by BEA."""
176
+ if packed_embeddings.shape[1] != self.num_heads:
177
+ raise ValueError(
178
+ f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
179
+ f"got {packed_embeddings.shape[1]}."
180
+ )
181
+
182
+ if packed_embeddings.shape[-1] != self.hidden_size:
183
+ raise ValueError(
184
+ f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
185
+ f"got {packed_embeddings.shape[-1]}."
186
+ )
187
+
188
+ def _validate_position_shape(
189
+ self,
190
+ packed_embeddings: torch.Tensor,
191
+ position_ids: torch.Tensor,
192
+ ) -> None:
193
+ """Validate the supplied packed-position tensor shape."""
194
+ if position_ids.shape != packed_embeddings.shape[:3]:
195
+ raise ValueError(
196
+ f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
197
+ f"got {tuple(position_ids.shape)}."
198
+ )
199
+
200
+ def _validate_active_mask_shape(
201
+ self,
202
+ packed_embeddings: torch.Tensor,
203
+ active_mask: torch.Tensor,
204
+ ) -> None:
205
+ """Validate the supplied active-token mask shape."""
206
+ if active_mask.shape != packed_embeddings.shape[:3]:
207
+ raise ValueError(
208
+ f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
209
+ f"got {tuple(active_mask.shape)}."
210
+ )
211
+
212
+ def _make_block_mask(
213
+ self,
214
+ query_active_mask: torch.Tensor,
215
+ key_active_mask: torch.Tensor,
216
+ num_tokens_processed: torch.Tensor,
217
+ query_length: int,
218
+ key_length: int,
219
+ device: torch.device,
220
+ ):
221
+ """Create the packed-sequence causal mask for FlexAttention.
222
+
223
+ At the root, causality is still triangular. The only nuance is cached
224
+ execution: query rows are indexed locally as 0..Q-1 inside the current
225
+ query tensor, but the key tensor may already contain a cached prefix for
226
+ that (batch, head) slot. The causal horizon for query tensor row q is
227
+ therefore:
228
+
229
+ cached_prefix_lengths[b, h] + q
230
+
231
+ Query and key activity masks are then composed with that triangular rule
232
+ so FlexAttention can skip padded query rows and ignore inactive key slots.
233
+ """
234
+ batch_size, num_heads, _ = query_active_mask.shape
235
+
236
+ # Build the per-(batch, head, query_row) triangular horizon from a simple
237
+ # arange over query rows plus the cached prefix lengths for each slot.
238
+ relative_query_positions = torch.arange(
239
+ query_length,
240
+ device=device,
241
+ dtype=torch.long,
242
+ ).view(1, 1, query_length)
243
+ causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
244
+
245
+ def packed_causal_mask(
246
+ batch_idx: torch.Tensor,
247
+ head_idx: torch.Tensor,
248
+ query_idx: torch.Tensor,
249
+ key_idx: torch.Tensor,
250
+ ) -> torch.Tensor:
251
+ query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
252
+ key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
253
+ is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
254
+ return query_is_active & key_is_active & is_causal
255
+
256
+ return create_block_mask(
257
+ packed_causal_mask,
258
+ B=batch_size,
259
+ H=num_heads,
260
+ Q_LEN=query_length,
261
+ KV_LEN=key_length,
262
+ device=device,
263
+ )
__attention__expert_packing.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Expert packing and unpacking for the MoSRAH path.
2
+
3
+ This module implements the low-level token-choice -> expert-choice -> token-choice
4
+ conversion boundary specified in the paper. The externally visible behavior is fixed:
5
+
6
+ - setup_packing() prepares the auxiliary ordering data and returns it as a dict
7
+ payload forwarded whole to pack_experts and unpack_experts.
8
+ - pack_experts() converts a dict of routed token-choice tensors into packed
9
+ expert-choice form. Each entry is paired with its intended padding value; all
10
+ entries undergo the same expert-major gather-scatter so they remain aligned.
11
+ - unpack_experts() restores token-choice ordering afterward.
12
+
13
+ Stable sort is a correctness requirement. It preserves causal ordering inside each
14
+ expert bucket, which is the foundation on which BEA's later triangular causal mask
15
+ is correct.
16
+
17
+ pack_experts() returns the packed entries dict together with a separate unpacking_mask.
18
+ Two masks serve different roles and must not be interchanged:
19
+
20
+ - unpacking_mask: marks every packed slot that contains a routed token copy,
21
+ live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
22
+ so its reshape invariant holds regardless of outer token liveness.
23
+ - active_mask (caller-supplied entry): marks only the packed slots whose source
24
+ token was semantically live. This is what BEA consumes for attention gating.
25
+ Dead outer tokens must not influence sparse attention outputs.
26
+ """
27
+
28
+ import torch
29
+ from typing import Any
30
+
31
+
32
+ # ---------------------------------------------------------------------------
33
+ # Setup
34
+ # ---------------------------------------------------------------------------
35
+
36
+ def setup_packing(
37
+ selected_heads: torch.Tensor,
38
+ ) -> dict[str, torch.Tensor]:
39
+ """Prepare the auxiliary ordering data used by pack/unpack.
40
+
41
+ Routing produces token-choice state I of shape (B, N, K): for each token, which
42
+ K experts were selected. Packing needs the same routed token copies reordered into
43
+ expert-major order so each expert bucket becomes contiguous.
44
+
45
+ The paper's setup step does this by flattening (N, K) into one axis to produce
46
+ H in token-major order, then computing a stable argsort permutation Pi over the
47
+ expert indices stored in H. Applying Pi reorders the flattened routed copies into
48
+ expert-major order while preserving their original token order *within* each expert
49
+ bucket. That preservation is why stable sort is required for causality.
50
+
51
+ Args:
52
+ selected_heads: Routed token-choice head selections I of shape (B, N, K).
53
+
54
+ Returns:
55
+ Auxiliary payload dict with keys:
56
+ - "flattened_selected_heads": H of shape (B, N*K)
57
+ - "permutation": stable expert-major permutation Pi of shape (B, N*K)
58
+ - "inverse_permutation": inverse permutation Pi^{-1} of shape (B, N*K)
59
+ This dict is forwarded whole to pack_experts and unpack_experts.
60
+ """
61
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
62
+ flattened_selected_heads = selected_heads.reshape(
63
+ batch_size,
64
+ sequence_length * num_selected_heads,
65
+ )
66
+
67
+ permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
68
+ inverse_permutation = torch.argsort(permutation, dim=-1)
69
+
70
+ return {
71
+ "flattened_selected_heads": flattened_selected_heads,
72
+ "permutation": permutation,
73
+ "inverse_permutation": inverse_permutation,
74
+ }
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Packing
79
+ # ---------------------------------------------------------------------------
80
+
81
+ def pack_experts(
82
+ entries: dict[str, tuple[torch.Tensor, Any]],
83
+ setup: dict[str, torch.Tensor],
84
+ selected_heads: torch.Tensor,
85
+ num_experts: int,
86
+ packed_length: int,
87
+ ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
88
+ """Pack token-choice tensors into expert-choice padded form.
89
+
90
+ The paper's packing path has two jobs:
91
+
92
+ 1. Convert routed token-choice copies into expert-major order.
93
+ 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
94
+
95
+ All entries in the provided dict undergo the same expert-major gather-scatter so
96
+ they remain mutually aligned in the packed frame. Each entry is paired with its
97
+ intended padding value, which fills slots that contain no routed token copy.
98
+
99
+ Packed positions are sourced from the authoritative upstream position_ids tensor
100
+ rather than synthesized locally from arange(N). This preserves advanced positions
101
+ correctly during cached inference while leaving training/full-sequence behavior
102
+ unchanged when position_ids is the ordinary sequential token positions.
103
+
104
+ Args:
105
+ entries: Mapping from string keys to (tensor, padding_value) pairs. Each
106
+ tensor has shape (B, N, ...) and is rearranged into expert-choice layout
107
+ (B, L, T, ...). The returned dict carries the same keys.
108
+ setup: Auxiliary payload returned by setup_packing().
109
+ selected_heads: Routed head selections I of shape (B, N, K).
110
+ num_experts: Total number of experts L.
111
+ packed_length: Static packed time dimension T. All per-expert buffers are
112
+ allocated to exactly this length. Use config.mosrah_packed_length as the
113
+ source of this value. Raises if any actual per-expert token count exceeds
114
+ this value.
115
+
116
+ Returns:
117
+ Tuple of:
118
+ - packed_entries: Dict with same keys as entries; each value is the
119
+ packed tensor of shape (B, L, T, ...).
120
+ - unpacking_mask: Boolean tensor of shape (B, L, T). True where a slot
121
+ contains any routed token copy, live or dead. Always has exactly
122
+ B*N*K True entries. Pass this to unpack_experts — not active_mask.
123
+ """
124
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
125
+
126
+ flattened_selected_heads = setup["flattened_selected_heads"]
127
+ permutation = setup["permutation"]
128
+
129
+ # -----------------------------------------------------------------------
130
+ # Reconstruct routed local source-token indices in token-choice order.
131
+ #
132
+ # The internal arange(N) is only the local source-row index object used to
133
+ # gather from the current chunk tensors. Flattening gives a (B, N*K) tensor
134
+ # aligned with H's token-major routed-copy order.
135
+ # -----------------------------------------------------------------------
136
+ source_token_indices = torch.arange(
137
+ sequence_length,
138
+ device=flattened_selected_heads.device,
139
+ dtype=torch.long,
140
+ ).view(1, sequence_length, 1).expand(
141
+ batch_size,
142
+ sequence_length,
143
+ num_selected_heads,
144
+ )
145
+ flattened_source_indices = source_token_indices.reshape(
146
+ batch_size,
147
+ sequence_length * num_selected_heads,
148
+ )
149
+
150
+ # -----------------------------------------------------------------------
151
+ # Reorder source-token indices into expert-major order.
152
+ #
153
+ # Applying Pi yields the local source-token rows in the packed expert-major
154
+ # order required by the paper. All entries are then gathered using these same
155
+ # reordered indices so they remain aligned under the exact same transformation.
156
+ # -----------------------------------------------------------------------
157
+ sorted_source_indices = flattened_source_indices.gather(
158
+ dim=1,
159
+ index=permutation,
160
+ )
161
+
162
+ # -----------------------------------------------------------------------
163
+ # Count how many routed copies land in each expert bucket and verify
164
+ # that no bucket exceeds the statically preallocated packed_length T.
165
+ #
166
+ # S[b, l] is the number of routed token copies assigned to expert l in
167
+ # batch b. T (packed_length) is a static allocation derived from config,
168
+ # not a data-dependent maximum. Overflow is detected here and raises in
169
+ # both eager and compiled modes.
170
+ # -----------------------------------------------------------------------
171
+ tokens_per_expert = _count_tokens_per_expert(flattened_selected_heads, num_experts)
172
+ max_count = tokens_per_expert.max().item()
173
+ no_overflow = max_count <= packed_length
174
+ _enforce_no_overflow(no_overflow)
175
+
176
+ # -----------------------------------------------------------------------
177
+ # Construct the unpacking mask.
178
+ #
179
+ # Each expert bucket is left-justified: if S[b, l] = s, then slots
180
+ # t = 0, ..., s-1 are occupied and all later slots are padding. The mask
181
+ # marks slot occupancy regardless of outer token liveness, and always has
182
+ # exactly B*N*K True entries.
183
+ # -----------------------------------------------------------------------
184
+ time_axis = torch.arange(
185
+ packed_length,
186
+ device=flattened_selected_heads.device,
187
+ dtype=torch.long,
188
+ ).view(1, 1, packed_length)
189
+ unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
190
+
191
+ # -----------------------------------------------------------------------
192
+ # Materialize all entries into the packed expert-choice frame.
193
+ #
194
+ # Each entry is gathered using the expert-major sorted source indices, then
195
+ # scattered into a padded buffer. The gather index is expanded to cover each
196
+ # tensor's trailing dimensions. Padding slots receive the caller-supplied fill
197
+ # value rather than an implicit zero.
198
+ # -----------------------------------------------------------------------
199
+ packed_entries: dict[str, torch.Tensor] = {}
200
+ for key, (tensor, padding_value) in entries.items():
201
+ extra_shape = tensor.shape[2:]
202
+
203
+ # Expand gather index to cover trailing dimensions, if any.
204
+ idx = sorted_source_indices.view(
205
+ batch_size,
206
+ sequence_length * num_selected_heads,
207
+ *(1,) * len(extra_shape),
208
+ ).expand(-1, -1, *extra_shape)
209
+ sorted_tensor = tensor.gather(dim=1, index=idx)
210
+
211
+ packed_tensor = tensor.new_full(
212
+ (batch_size, num_experts, packed_length, *extra_shape),
213
+ fill_value=padding_value,
214
+ )
215
+ packed_tensor[unpacking_mask] = sorted_tensor.reshape(-1, *extra_shape)
216
+ packed_entries[key] = packed_tensor
217
+
218
+ return packed_entries, unpacking_mask
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Unpacking
223
+ # ---------------------------------------------------------------------------
224
+
225
+ def unpack_experts(
226
+ expert_outputs: torch.Tensor,
227
+ setup: dict[str, torch.Tensor],
228
+ unpacking_mask: torch.Tensor,
229
+ selected_heads: torch.Tensor,
230
+ ) -> torch.Tensor:
231
+ """Restore token-choice ordering from BEA expert-choice output.
232
+
233
+ Unpacking inverts the packing path only on occupied entries. Padding does not
234
+ participate: the output tensor is first filtered by unpacking_mask to recover
235
+ only the real routed-token copies in expert-major order, then Pi^{-1} restores
236
+ the original token-choice ordering, and finally the tensor is reshaped back to
237
+ (B, N, K, d).
238
+
239
+ The unpacking_mask — not active_mask — must be used here. Even copies of dead
240
+ outer tokens occupy slots and must be un-scattered correctly for the inverse
241
+ permutation to hold. The total True entry count in unpacking_mask is always
242
+ B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
243
+
244
+ Args:
245
+ expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
246
+ setup: Auxiliary payload returned by setup_packing().
247
+ unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
248
+ occupied packed slots regardless of outer token liveness.
249
+ selected_heads: Routed head selections I of shape (B, N, K).
250
+
251
+ Returns:
252
+ Restored token-choice tensor y_tilde of shape (B, N, K, d).
253
+ """
254
+ inverse_permutation = setup["inverse_permutation"]
255
+
256
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
257
+ hidden_dim = expert_outputs.shape[-1]
258
+
259
+ active_outputs = expert_outputs[unpacking_mask]
260
+ sorted_token_choice_outputs = active_outputs.reshape(
261
+ batch_size,
262
+ sequence_length * num_selected_heads,
263
+ hidden_dim,
264
+ )
265
+ restored_outputs = sorted_token_choice_outputs.gather(
266
+ dim=1,
267
+ index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
268
+ )
269
+
270
+ return restored_outputs.reshape(
271
+ batch_size,
272
+ sequence_length,
273
+ num_selected_heads,
274
+ hidden_dim,
275
+ )
276
+
277
+
278
+ # ---------------------------------------------------------------------------
279
+ # Helpers
280
+ # ---------------------------------------------------------------------------
281
+
282
+ def _enforce_no_overflow(condition: bool) -> None:
283
+ """Enforce that no expert bucket exceeds the preallocated packed length.
284
+
285
+ This check fires when the number of tokens assigned to any expert in any
286
+ batch item exceeds mosrah_packed_length. When that limit is exceeded, the
287
+ packed buffer is too small to hold all assignments and data would be dropped.
288
+ Increase mosrah_overallocation_factor in ShramConfig to resolve.
289
+
290
+ The caller must derive condition via .item() on the max count tensor so that
291
+ dynamo captures a SymInt and the comparison produces a SymBool. Passing a
292
+ tensor comparison result directly bypasses the SymInt mechanism and prevents
293
+ the check from firing at compiled runtime.
294
+
295
+ Args:
296
+ condition: True means no overflow has occurred; False means at least one
297
+ expert bucket exceeds packed_length. In compiled mode this is a SymBool
298
+ produced by comparing a SymInt against the static packed_length.
299
+ """
300
+ if torch.compiler.is_compiling():
301
+ torch._check(condition)
302
+ else:
303
+ if not condition:
304
+ raise RuntimeError(
305
+ "Expert packing overflow: at least one expert bucket contains more "
306
+ "tokens than mosrah_packed_length allows. Increase "
307
+ "mosrah_overallocation_factor in ShramConfig to resolve."
308
+ )
309
+
310
+
311
+ def _count_tokens_per_expert(
312
+ flattened_selected_heads: torch.Tensor,
313
+ num_experts: int,
314
+ ) -> torch.Tensor:
315
+ """Count how many routed token copies are assigned to each expert per batch item.
316
+
317
+ Uses scatter_add into a pre-sized (B, num_experts) zero buffer, producing a
318
+ statically-shaped output that compiles without graph breaks. Each position in
319
+ flattened_selected_heads contributes one count to the corresponding expert slot.
320
+
321
+ Args:
322
+ flattened_selected_heads: Expert assignments of shape (B, N*K) with values
323
+ in [0, num_experts).
324
+ num_experts: Total number of experts L.
325
+
326
+ Returns:
327
+ Counts tensor of shape (B, num_experts).
328
+ """
329
+ batch_size = flattened_selected_heads.shape[0]
330
+ counts = torch.zeros(
331
+ batch_size,
332
+ num_experts,
333
+ device=flattened_selected_heads.device,
334
+ dtype=flattened_selected_heads.dtype,
335
+ )
336
+ counts.scatter_add_(
337
+ dim=1,
338
+ index=flattened_selected_heads,
339
+ src=torch.ones_like(flattened_selected_heads),
340
+ )
341
+ return counts
__attention__load_balance_loss.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
+
3
+ This module implements the custom autograd Function H(b, f) described in the paper's
4
+ Implementation Concerns section. The operator bridges two requirements that are in
5
+ tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
+ multiplication) so that existing training loops remain compatible, while simultaneously
7
+ implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
+ path through the router weights.
9
+
10
+ The resolution is a custom backward pass. The forward emits the load balance imbalance
11
+ as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
+ its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
+ then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
+ correction without a standalone SGD update step.
15
+
16
+ Paper ref: Appendix A.Implementation Concerns.
17
+ """
18
+
19
+ import torch
20
+
21
+
22
+ class LoadBalanceLoss(torch.autograd.Function):
23
+ """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
+
25
+ Forward computes the load balance imbalance:
26
+
27
+ L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
+
29
+ Backward emits a bias-correction gradient to expert_bias:
30
+
31
+ grad_b = L_grad * sign(f_l - 1/L)
32
+
33
+ expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
+ in the computation graph and routes gradients through it. routing_freqs (f) receives
35
+ no gradient — its origin is the discrete TopK operation which has no gradient, so
36
+ defining a gradient for f here would be mathematically incorrect.
37
+
38
+ Paper ref: Appendix A.Implementation Concerns.
39
+ """
40
+
41
+ @staticmethod
42
+ def forward(
43
+ ctx: torch.autograd.function.FunctionCtx,
44
+ expert_bias: torch.Tensor,
45
+ routing_freqs: torch.Tensor,
46
+ ) -> torch.Tensor:
47
+ """Compute the load balance loss.
48
+
49
+ Args:
50
+ ctx: Autograd context for saving state needed in backward.
51
+ expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
+ PyTorch tracks it as a computation graph node needing a gradient.
53
+ routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
+ from the discrete TopK selection — not differentiable.
55
+
56
+ Returns:
57
+ Scalar loss equal to sum_l |f_l - 1/L|.
58
+ """
59
+ L = expert_bias.shape[0]
60
+ # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
+ # underloaded. Saved for backward where sign(imbalance) determines the direction
62
+ # of the bias-correction update.
63
+ imbalance = routing_freqs - 1.0 / L
64
+ ctx.save_for_backward(imbalance)
65
+ return imbalance.abs().sum()
66
+
67
+ @staticmethod
68
+ def backward(
69
+ ctx: torch.autograd.function.FunctionCtx,
70
+ grad_output: torch.Tensor,
71
+ ) -> tuple[torch.Tensor, None]:
72
+ """Emit the DeepSeek-style bias-correction gradient.
73
+
74
+ Args:
75
+ ctx: Autograd context carrying imbalance saved in forward.
76
+ grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
+ by the training loop arrives here and is propagated to grad_b, so the
78
+ correction magnitude is proportional to the loss weight chosen by the
79
+ consumer.
80
+
81
+ Returns:
82
+ Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
+ None for routing_freqs: no gradient is defined for the discrete routing
84
+ frequency.
85
+ """
86
+ (imbalance,) = ctx.saved_tensors
87
+ grad_expert_bias = grad_output * imbalance.sign()
88
+ return grad_expert_bias, None
__attention__mosrah.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full MoSRAH sparse path for SHRAM.
2
+
3
+ This module coordinates the routed sparse attention path used inside the SHRAM
4
+ hybrid attention layer. The underlying mechanics already live in verified
5
+ subunits. The responsibility here is to connect those subunits without
6
+ corrupting their bridge contracts.
7
+
8
+ In particular, this path must preserve three architectural distinctions:
9
+
10
+ - selected head indices are not routing probabilities
11
+ - packed position semantics are chosen before BEA, not inside it
12
+ - weighted reduction must consume the router's unbiased renormalized
13
+ probabilities after token-choice order has been restored
14
+ """
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from .__cache__mosrah_cache import MoSRAHCache
20
+ from .configuration import ShramConfig
21
+ from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
+ from .__attention__expert_packing import (
23
+ pack_experts,
24
+ setup_packing,
25
+ unpack_experts,
26
+ )
27
+ from .__attention__router import MoSRAHRouter
28
+ from .__attention__positions_converter import SparseMoSRAHPositions
29
+
30
+
31
+ class MoSRAHLayer(nn.Module):
32
+ """Full routed sparse attention path for SHRAM.
33
+
34
+ The MoSRAH path consumes model-space hidden states together with
35
+ authoritative per-token positions and returns the model-space sparse-path
36
+ contribution, the router's load-balance loss, and the router's MaxVio
37
+ routing-imbalance scalar.
38
+ """
39
+
40
+ def __init__(self, config: ShramConfig) -> None:
41
+ super().__init__()
42
+ self.num_experts = config.num_mosrah_heads
43
+ self.packed_length = config.mosrah_packed_length
44
+
45
+ self.router = MoSRAHRouter(config)
46
+ self.positions = SparseMoSRAHPositions(config)
47
+ self.bea = BottleneckedEnsembleAttention(config)
48
+
49
+ def num_mosrah_parameters(self) -> int:
50
+ """Return the total number of trainable parameters in this MoSRAH layer."""
51
+ return sum(p.numel() for p in self.parameters())
52
+
53
+ def forward(
54
+ self,
55
+ hidden_states: torch.Tensor,
56
+ position_ids: torch.Tensor,
57
+ active_mask: torch.Tensor,
58
+ cache: MoSRAHCache | None,
59
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
60
+ """Run the full MoSRAH sparse path.
61
+
62
+ Args:
63
+ hidden_states: Model-space hidden states x of shape (B, N, d).
64
+ position_ids: Authoritative per-token positions of shape (B, N).
65
+ active_mask: Current-chunk active mask of shape (B, N), where True
66
+ means the token is semantically live. Forwarded to the router
67
+ so dead tokens are excluded from routing statistics, and to
68
+ pack_experts so dead outer tokens do not become semantically
69
+ active packed entries.
70
+ cache: Optional layer-local MoSRAH cache. Pass None for uncached
71
+ execution and the layer-local cache instance for cached execution.
72
+
73
+ Returns:
74
+ sparse_output: Model-space sparse-path output of shape (B, N, d).
75
+ load_balance_loss: Scalar router load-balance loss.
76
+ max_vio: Detached scalar routing-imbalance summary. Passed through
77
+ unchanged from the router; see MoSRAHRouter for semantics.
78
+ """
79
+
80
+ # -------------------------------------------------------------------
81
+ # The first transition moves from model-space token-choice input into
82
+ # the packed expert-choice sparse-attention state. Routing decides both
83
+ # which experts each token uses and which unbiased probabilities must be
84
+ # reserved for the final reduction. The active mask is forwarded to the
85
+ # router so dead tokens are excluded from routing statistics, and to
86
+ # pack_experts so outer liveness is faithfully carried into the packed
87
+ # frame. Packing returns both the unpacking mask (slot occupancy, always
88
+ # B*N*K True entries) and the packed active mask (live slots only);
89
+ # active_mask is rebound to the packed form after this point.
90
+ # -------------------------------------------------------------------
91
+ selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
92
+ hidden_states, active_mask
93
+ )
94
+
95
+ setup = setup_packing(selected_heads)
96
+ entries = {
97
+ "hidden_states": (hidden_states, 0.0),
98
+ "position_ids": (position_ids, 0),
99
+ "active_mask": (active_mask, False),
100
+ }
101
+ packed, unpacking_mask = pack_experts(entries, setup, selected_heads, self.num_experts, self.packed_length)
102
+ packed_hidden_states = packed["hidden_states"]
103
+ packed_positions = packed["position_ids"]
104
+ active_mask = packed["active_mask"]
105
+
106
+ # -------------------------------------------------------------------
107
+ # Sparse attention runs entirely in the packed expert-choice frame, so
108
+ # the RoPE position semantics must also be chosen in that frame. The
109
+ # position layer therefore decides whether BEA should see packed
110
+ # original-token positions or packed local-slot positions. BEA then
111
+ # consumes that packed position tensor together with the packed hidden
112
+ # states and the layer-local sparse cache, which it owns directly.
113
+ # -------------------------------------------------------------------
114
+ bea_positions = self.positions(
115
+ packed_positions=packed_positions,
116
+ active_mask=active_mask,
117
+ cache=cache,
118
+ )
119
+ packed_outputs = self.bea(
120
+ packed_embeddings=packed_hidden_states,
121
+ position_ids=bea_positions,
122
+ active_mask=active_mask,
123
+ cache=cache,
124
+ )
125
+
126
+ # -------------------------------------------------------------------
127
+ # The final transition restores token-choice meaning and only then
128
+ # collapses the K routed copies back into model space. This ordering is
129
+ # required because routing_probs live in token-choice space, whereas BEA
130
+ # returns expert-choice packed outputs. The reduction must therefore
131
+ # happen after unpacking, and it must use the router's unbiased
132
+ # renormalized probabilities rather than any biased selection scores.
133
+ # -------------------------------------------------------------------
134
+ token_choice_outputs = unpack_experts(
135
+ expert_outputs=packed_outputs,
136
+ setup=setup,
137
+ unpacking_mask=unpacking_mask,
138
+ selected_heads=selected_heads,
139
+ )
140
+ final_output = (
141
+ token_choice_outputs * routing_probs.unsqueeze(-1)
142
+ ).sum(dim=2)
143
+
144
+ return final_output, load_balance_loss, max_vio
__attention__positions_converter.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Position computation for the MoSRAH sparse path.
2
+
3
+ This layer computes the packed position tensor P consumed by BEA.
4
+
5
+ - In main-sequence mode, P is the packed original-token position tensor from the
6
+ packing path.
7
+ - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
+ expert-choice layout, optionally offset by the current sparse-cache occupancies
9
+ during cached inference.
10
+ """
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from .configuration import ShramConfig
16
+ from .__cache__mosrah_cache import MoSRAHCache
17
+
18
+
19
+ class SparseMoSRAHPositions(nn.Module):
20
+ """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
+
22
+ This layer operates in the packed expert-choice frame used by BEA. The input
23
+ packed_positions tensor is always the packed original-token position tensor
24
+ produced by the packing path. The configured rope_mode determines whether that
25
+ tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
+ """
27
+
28
+ def __init__(self, config: ShramConfig) -> None:
29
+ super().__init__()
30
+ self.rope_mode = config.rope_mode
31
+
32
+ def forward(
33
+ self,
34
+ packed_positions: torch.Tensor,
35
+ active_mask: torch.Tensor,
36
+ cache: MoSRAHCache | None,
37
+ ) -> torch.Tensor:
38
+ """Compute the packed position tensor P consumed by BEA.
39
+
40
+ Args:
41
+ packed_positions: Packed original-token positions J' of shape (B, L, T).
42
+ active_mask: Boolean active-token mask of shape (B, L, T). Inactive
43
+ positions are zeroed in the returned tensor regardless of mode —
44
+ their position value is semantically irrelevant and 0 is guaranteed
45
+ to be within any valid RoPE table.
46
+ cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
47
+ mode, the current per-head occupancies offset the local packed sequence.
48
+
49
+ Returns:
50
+ Packed position tensor P of shape (B, L, T).
51
+ """
52
+ if self.rope_mode == "main_sequence":
53
+ positions = self._main_sequence_positions(packed_positions)
54
+ elif self.rope_mode == "semantic_sequence":
55
+ positions = self._semantic_sequence_positions(packed_positions, cache)
56
+ else:
57
+ raise NotImplementedError(
58
+ f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
59
+ )
60
+
61
+ return torch.where(active_mask, positions, torch.zeros_like(positions))
62
+
63
+ def _main_sequence_positions(
64
+ self,
65
+ packed_positions: torch.Tensor,
66
+ ) -> torch.Tensor:
67
+ """Forward packed original-token positions unchanged."""
68
+ return packed_positions
69
+
70
+ def _semantic_sequence_positions(
71
+ self,
72
+ packed_positions: torch.Tensor,
73
+ cache: MoSRAHCache | None,
74
+ ) -> torch.Tensor:
75
+ """Compute semantic-sequence packed positions in expert-choice space.
76
+
77
+ Without a sparse cache, semantic positions are the local packed sequence
78
+ 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
79
+ same local sequence is offset by the current per-(batch, expert) occupancies
80
+ returned by get_heads_lengths().
81
+ """
82
+ batch_size, num_experts, packed_length = packed_positions.shape
83
+
84
+ # -------------------------------------------------------------------
85
+ # Construct the local packed sequence 0, 1, 2, ... over the expert-local
86
+ # sequence dimension T. This is then broadcast across batch and experts.
87
+ # -------------------------------------------------------------------
88
+ local_positions = torch.arange(
89
+ packed_length,
90
+ device=packed_positions.device,
91
+ dtype=packed_positions.dtype,
92
+ ).view(1, 1, packed_length).expand(
93
+ batch_size,
94
+ num_experts,
95
+ packed_length,
96
+ )
97
+
98
+ # -------------------------------------------------------------------
99
+ # In cached semantic-sequence mode, positions continue from the current
100
+ # sparse-cache occupancies rather than restarting at zero for the local
101
+ # chunk.
102
+ # -------------------------------------------------------------------
103
+ if cache is None:
104
+ return local_positions
105
+
106
+ cached_lengths = cache.get_heads_lengths().to(
107
+ device=packed_positions.device,
108
+ dtype=packed_positions.dtype,
109
+ ).unsqueeze(-1)
110
+
111
+ return local_positions + cached_lengths
__attention__router.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token-choice router for the MoSRAH sparse attention path.
2
+
3
+ This module implements the routing mechanism described in Appendix A.Routing of the
4
+ paper. Given an input hidden state x, the router produces two outputs used downstream:
5
+
6
+ - selected_heads (I): which K of the L available expert heads each token routes to,
7
+ determined by TopK over biased routing scores.
8
+ - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
+ *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
+ bias b must not influence P.
11
+
12
+ This separation is architecturally critical: expert_bias drives selection (and thus load
13
+ balancing) but does not corrupt the gradient path from the output through routing_probs
14
+ back to the routing projection weights.
15
+
16
+ The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
+ autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
+ loop can weight and add to the language modeling loss.
19
+
20
+ The router additionally computes and returns MaxVio, a detached scalar summarising
21
+ routing imbalance for the current forward pass:
22
+
23
+ MaxVio = L · max_l(f_l − 1/L)
24
+
25
+ where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
+ target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
+
28
+ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ from .configuration import ShramConfig
36
+ from .__attention__load_balance_loss import LoadBalanceLoss
37
+
38
+
39
+ class MoSRAHRouter(nn.Module):
40
+ """Token-choice router for MoSRAH sparse attention.
41
+
42
+ Each input token independently selects K of the L available expert heads. Selection
43
+ is driven by biased routing scores to enable load balancing, but the routing
44
+ probabilities used for output reduction are computed from unbiased scores so that
45
+ the expert bias does not interfere with the gradient path to the router weights.
46
+
47
+ The routing projection W_r has no bias term — the paper specifies xW_r with no
48
+ additional projection bias. The only bias-like parameter is expert_bias (b), which
49
+ has an entirely separate role and update mechanism.
50
+
51
+ Args:
52
+ config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
+ (L), and ``num_selected_heads`` (K).
54
+ """
55
+
56
+ def __init__(self, config: ShramConfig) -> None:
57
+ super().__init__()
58
+ self.num_mosrah_heads = config.num_mosrah_heads
59
+ self.num_selected_heads = config.num_selected_heads
60
+ self.load_balance_p = config.load_balance_p
61
+
62
+ # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
63
+ self.routing_projection = nn.Linear(
64
+ config.embedding_width, config.num_mosrah_heads, bias=False
65
+ )
66
+
67
+ # b: learned per-head bias for load balancing. Initialized to zero so that all
68
+ # heads start with equal selection probability. Updated by the main optimizer
69
+ # via the LoadBalanceLoss custom backward.
70
+ self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
71
+
72
+ def forward(
73
+ self, x: torch.Tensor, active_mask: torch.Tensor
74
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
75
+ """Route input tokens to K expert heads each and compute routing probabilities.
76
+
77
+ Args:
78
+ x: Input hidden states of shape (batch, seq_len, hidden_size).
79
+ active_mask: Current-chunk active mask of shape (batch, seq_len), where
80
+ True means the token is semantically live. Dead tokens do not
81
+ contribute to routing frequencies, load_balance_loss, or max_vio.
82
+
83
+ Returns:
84
+ selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
85
+ Each token's K selected head indices, determined by TopK on biased scores.
86
+ routing_probs: Routing probabilities P of shape (batch, seq_len,
87
+ num_selected_heads). Gathered from unbiased scores at selected_heads
88
+ indices and renormalized to sum to 1 per token.
89
+ load_balance_loss: Scalar load balance imbalance loss for this forward pass.
90
+ Training loop scales this by a weight and adds it to the main loss.
91
+ max_vio: Detached scalar routing-imbalance summary for this forward pass.
92
+ Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
93
+ never contributes gradients.
94
+ """
95
+ B, N, _ = x.shape
96
+ L = self.num_mosrah_heads
97
+ K = self.num_selected_heads
98
+
99
+ # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
100
+ # compute routing_probs — expert_bias must not influence them.
101
+ logits = self.routing_projection(x) # (B, N, L)
102
+ routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
103
+
104
+ # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
105
+ # selection. expert_bias is added to logits before softmax so that the bias
106
+ # shifts selection probability without rescaling the unbiased distribution.
107
+ biased_routing_scores = F.softmax( # R̂, (B, N, L)
108
+ logits + self.expert_bias, dim=-1
109
+ )
110
+
111
+ # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
112
+ selected_heads = biased_routing_scores.topk(K, dim=-1).indices
113
+
114
+ # Routing probabilities P: gathered from unbiased R at selected_heads indices,
115
+ # then renormalized so they sum to 1 per token. Gathering from routing_scores
116
+ # (not biased_routing_scores) is the invariant that keeps the gradient path from
117
+ # the output back to the router weights free of expert_bias influence.
118
+ gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
119
+ routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
120
+
121
+ # Per-item routing frequencies f_{b,l}: for each batch item b and head l, what
122
+ # fraction of that item's active K assignments over all tokens go to head l.
123
+ # Dead tokens are excluded before reduction. Normalization is per batch item so
124
+ # each item's frequencies sum to 1 independently of other items in the batch.
125
+ assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
+ assignment_mask.scatter_(-1, selected_heads, 1.0)
127
+ active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
+ per_item_counts = active_assignments.sum(dim=1) # (B, L)
129
+ per_item_total = active_mask.sum(dim=1, keepdim=True) * K # (B, 1)
130
+ per_item_freqs = per_item_counts / per_item_total # (B, L)
131
+
132
+ # p-mean of per_item_freqs over the batch dimension produces routing_freqs (L,).
133
+ # p-mean weights aggregation toward the worst-case batch item relative to
134
+ # arithmetic mean, making the load balance signal sensitive to per-item spikes
135
+ # that cause packing overflow.
136
+ p = self.load_balance_p
137
+ routing_freqs = (per_item_freqs ** p).mean(dim=0) ** (1.0 / p) # (L,)
138
+
139
+ # Load balance loss via custom autograd. expert_bias is an input so PyTorch
140
+ # registers it as a graph node; the custom backward writes the DeepSeek-style
141
+ # correction gradient to expert_bias.grad for the optimizer to consume.
142
+ load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
143
+
144
+ # MaxVio is a detached monitoring scalar following the paper's formula
145
+ # L · max_l(f_l − 1/L) applied to routing_freqs. Must not contribute gradients.
146
+ max_vio = self._compute_max_vio(routing_freqs, L)
147
+
148
+ return selected_heads, routing_probs, load_balance_loss, max_vio
149
+
150
+ @staticmethod
151
+ def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
152
+ """Compute the MaxVio routing-imbalance scalar.
153
+
154
+ MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
155
+ head l and 1/L is the perfectly balanced target. Follows the paper's definition
156
+ (Wang et al.) applied to routing_freqs. A value of zero indicates perfect
157
+ balance; a value of 0.5 means the most overloaded head received 50% more routed
158
+ tokens than ideal.
159
+
160
+ The result is detached from the autograd graph — MaxVio is a monitoring scalar
161
+ and must never contribute gradients to any parameter.
162
+
163
+ Args:
164
+ routing_freqs: Per-head routing frequencies of shape (L,).
165
+ num_heads: Total number of MoSRAH heads L.
166
+
167
+ Returns:
168
+ Detached scalar MaxVio tensor.
169
+ """
170
+ return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
__attention__shram.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM hybrid attention layer.
2
+
3
+ This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
+ used at one decoder attention slot in SHRAM.
5
+
6
+ The local sliding-window path and the MoSRAH sparse path are already verified
7
+ independently. The responsibility here is therefore not to introduce new
8
+ attention logic, but to preserve the bridge contracts between them: both paths
9
+ must consume the same input hidden state, each path must receive the sub-cache
10
+ it actually owns, the two model-space outputs must be summed directly, and the
11
+ sparse-path load-balance loss must remain visible to the caller.
12
+ """
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+ from .__cache__shram_layer_cache import ShramLayerCache
18
+ from .configuration import ShramConfig
19
+ from .__attention__sliding_window_attention import SlidingWindowAttention
20
+ from .__attention__mosrah import MoSRAHLayer
21
+
22
+
23
+ class SHRAMHybridLayer(nn.Module):
24
+ """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
+
26
+ The local path preserves nearby-token behavior through sliding-window causal
27
+ attention. The sparse path is the theorem-facing MoSRAH routed attention
28
+ path. Both operate over the same model-space hidden state and return
29
+ model-space outputs, so the hybrid composition is a direct sum in model
30
+ space.
31
+ """
32
+
33
+ def __init__(self, config: ShramConfig) -> None:
34
+ super().__init__()
35
+ self.local_attention = SlidingWindowAttention(config)
36
+ self.sparse_attention = MoSRAHLayer(config)
37
+
38
+ def num_mosrah_parameters(self) -> int:
39
+ """Return the total number of trainable parameters in the MoSRAH sparse path."""
40
+ return self.sparse_attention.num_mosrah_parameters()
41
+
42
+ def forward(
43
+ self,
44
+ hidden_states: torch.Tensor,
45
+ position_ids: torch.Tensor,
46
+ active_mask: torch.Tensor,
47
+ cache: ShramLayerCache | None,
48
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
49
+ """Apply the SHRAM hybrid attention layer.
50
+
51
+ Args:
52
+ hidden_states: Input hidden states of shape (B, N, d).
53
+ position_ids: Authoritative token positions of shape (B, N).
54
+ active_mask: Current-chunk active mask of shape (B, N), where True
55
+ means the token is semantically live. Forwarded unchanged to
56
+ both the local path and the sparse path.
57
+ cache: Optional per-layer SHRAM cache. When provided, the owned
58
+ sliding-window and MoSRAH sub-caches are dispatched directly to
59
+ their corresponding attention paths.
60
+
61
+ Returns:
62
+ hybrid_output: Model-space hybrid attention output of shape (B, N, d).
63
+ load_balance_loss: Scalar sparse-path load-balance loss.
64
+ max_vio: Detached scalar routing-imbalance summary. Passed through
65
+ unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
66
+ """
67
+ # -------------------------------------------------------------------
68
+ # The hybrid layer's first responsibility is cache dispatch. The layer
69
+ # cache already owns the concrete sub-cache objects required by each
70
+ # path, so this unit should forward those exact references rather than
71
+ # reinterpret cache ownership or invent a composite update protocol here.
72
+ # -------------------------------------------------------------------
73
+ if cache is None:
74
+ sliding_window_cache = None
75
+ mosrah_cache = None
76
+ else:
77
+ sliding_window_cache = cache.sliding_window_cache
78
+ mosrah_cache = cache.mosrah_cache
79
+
80
+ # -------------------------------------------------------------------
81
+ # Both attention paths must see the same model-space hidden state for
82
+ # the current decoder layer. The local path preserves short-range
83
+ # structure, while the sparse path provides the routed long-range
84
+ # contribution and emits the load-balance signal used by training.
85
+ # -------------------------------------------------------------------
86
+ local_output = self.local_attention(
87
+ x=hidden_states,
88
+ position_ids=position_ids,
89
+ active_mask=active_mask,
90
+ cache=sliding_window_cache,
91
+ )
92
+ sparse_output, load_balance_loss, max_vio = self.sparse_attention(
93
+ hidden_states=hidden_states,
94
+ position_ids=position_ids,
95
+ active_mask=active_mask,
96
+ cache=mosrah_cache,
97
+ )
98
+
99
+ # -------------------------------------------------------------------
100
+ # The composition rule is intentionally simple at this boundary. Both
101
+ # sublayers already return model-space tensors of matching shape, so the
102
+ # correct hybrid behavior is their direct sum with no additional mixing
103
+ # logic introduced here.
104
+ # -------------------------------------------------------------------
105
+ hybrid_output = local_output + sparse_output
106
+
107
+ return hybrid_output, load_balance_loss, max_vio
__attention__sliding_window_attention.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/shram/model/attention/sliding_window_attention.py
2
+
3
+ """Local sliding-window attention path for SHRAM.
4
+
5
+ This file defines `SlidingWindowAttention`, the local short-range attention path
6
+ used inside the SHRAM hybrid layer.
7
+
8
+ In the masked-continuation variant, the local cache no longer returns a
9
+ semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
+ returns:
11
+
12
+ - the retained local window memory concatenated with the current chunk
13
+ - an aligned active mask over that returned frame
14
+
15
+ This module consumes that returned frame directly and constructs effective local
16
+ causal/window visibility from the mask. It does not own cache retention policy;
17
+ it owns only local attention semantics.
18
+ """
19
+
20
+ import math
21
+ from typing import Any
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
+
27
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
28
+ from .configuration import ShramConfig
29
+ from .rope import RotaryEmbedding
30
+
31
+
32
+ class SlidingWindowAttention(nn.Module):
33
+ """Causal local sliding-window attention for one SHRAM layer.
34
+
35
+ Args:
36
+ config: SHRAM config. Must expose `hidden_size`,
37
+ `num_sliding_window_heads`, `head_dim`, `window_size`,
38
+ `attention_dropout`, and `local_rope_theta`.
39
+
40
+ Raises:
41
+ NotImplementedError: If `attention_dropout != 0.0`.
42
+ """
43
+
44
+ def __init__(self, config: ShramConfig) -> None:
45
+ super().__init__()
46
+
47
+ self.hidden_size = config.embedding_width
48
+ self.num_heads = config.num_sliding_window_heads
49
+ self.head_dim = config.head_dim
50
+ self.window_size = config.window_size
51
+ self.attention_dropout = config.attention_dropout
52
+
53
+ if self.attention_dropout != 0.0:
54
+ raise NotImplementedError(
55
+ "SlidingWindowAttention currently supports only "
56
+ "attention_dropout == 0.0."
57
+ )
58
+
59
+ self.inner_dim = self.num_heads * self.head_dim
60
+
61
+ # Standard MHA projections for the local path.
62
+ self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
+ self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
+ self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
+ self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
+
67
+ # The local path always uses default-mode RoPE with its own theta.
68
+ self.rope = RotaryEmbedding(
69
+ mode="default",
70
+ head_dim=self.head_dim,
71
+ theta=config.local_rope_theta,
72
+ maximum_sequence_length=config.inference_sequence_length,
73
+ )
74
+
75
+ def forward(
76
+ self,
77
+ x: torch.Tensor,
78
+ position_ids: torch.Tensor,
79
+ active_mask: torch.Tensor,
80
+ cache: LocalSlidingWindowLayerCache | None = None,
81
+ ) -> torch.Tensor:
82
+ """Apply local causal sliding-window attention.
83
+
84
+ Args:
85
+ x: Input tensor of shape `(B, N, hidden_size)`.
86
+ position_ids: Position tensor of shape `(B, N)`.
87
+ active_mask: Current-chunk active mask of shape `(B, N)`, where
88
+ `True` means active.
89
+ cache: Optional `LocalSlidingWindowLayerCache`.
90
+
91
+ Returns:
92
+ Output tensor of shape `(B, N, hidden_size)`.
93
+ """
94
+ batch_size, query_len, _ = x.shape
95
+
96
+ self._validate_position_shape(x, position_ids)
97
+ self._validate_active_mask_shape(x, active_mask)
98
+
99
+ # (B, N, H*D) -> (B, H, N, D)
100
+ q = self.q_proj(x).view(
101
+ batch_size,
102
+ query_len,
103
+ self.num_heads,
104
+ self.head_dim,
105
+ ).transpose(1, 2)
106
+ k = self.k_proj(x).view(
107
+ batch_size,
108
+ query_len,
109
+ self.num_heads,
110
+ self.head_dim,
111
+ ).transpose(1, 2)
112
+ v = self.v_proj(x).view(
113
+ batch_size,
114
+ query_len,
115
+ self.num_heads,
116
+ self.head_dim,
117
+ ).transpose(1, 2)
118
+
119
+ q, k, attention_scaling = self.rope(q, k, position_ids)
120
+
121
+ # The cache returns the current-step visible local frame, not merely the
122
+ # retained next-step cache buffer.
123
+ if cache is not None:
124
+ k_full, v_full, full_active_mask, full_positions = cache.update(
125
+ k, v, active_mask, position_ids
126
+ )
127
+ else:
128
+ k_full, v_full, full_active_mask, full_positions = k, v, active_mask, position_ids
129
+
130
+ block_mask = self._make_block_mask(
131
+ active_mask=full_active_mask,
132
+ positions=full_positions,
133
+ batch_size=batch_size,
134
+ num_heads=self.num_heads,
135
+ query_len=query_len,
136
+ kv_len=k_full.shape[-2],
137
+ window_size=self.window_size,
138
+ device=x.device,
139
+ )
140
+
141
+ attn_output = flex_attention(
142
+ q,
143
+ k_full,
144
+ v_full,
145
+ block_mask=block_mask,
146
+ scale=attention_scaling / math.sqrt(self.head_dim),
147
+ )
148
+
149
+ # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
150
+ attn_output = (
151
+ attn_output.transpose(1, 2)
152
+ .contiguous()
153
+ .view(batch_size, query_len, self.inner_dim)
154
+ )
155
+
156
+ return self.o_proj(attn_output)
157
+
158
+ def _validate_position_shape(
159
+ self,
160
+ x: torch.Tensor,
161
+ position_ids: torch.Tensor,
162
+ ) -> None:
163
+ """Validate the position tensor shape expected by local RoPE."""
164
+ if position_ids.shape != x.shape[:2]:
165
+ raise ValueError(
166
+ f"position_ids must have shape {tuple(x.shape[:2])}, "
167
+ f"got {tuple(position_ids.shape)}."
168
+ )
169
+
170
+ def _validate_active_mask_shape(
171
+ self,
172
+ x: torch.Tensor,
173
+ active_mask: torch.Tensor,
174
+ ) -> None:
175
+ """Validate the current-chunk active-mask contract."""
176
+ if active_mask.shape != x.shape[:2]:
177
+ raise ValueError(
178
+ f"active_mask must have shape {tuple(x.shape[:2])}, "
179
+ f"got {tuple(active_mask.shape)}."
180
+ )
181
+ if active_mask.dtype != torch.bool:
182
+ raise ValueError(
183
+ f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
184
+ )
185
+
186
+ def _make_block_mask(
187
+ self,
188
+ active_mask: torch.Tensor,
189
+ positions: torch.Tensor,
190
+ batch_size: int,
191
+ num_heads: int,
192
+ query_len: int,
193
+ kv_len: int,
194
+ window_size: int,
195
+ device: torch.device,
196
+ ) -> Any:
197
+ """Create the FlexAttention block mask for masked local continuation.
198
+
199
+ The returned local frame is chronological in raw buffer order; dead
200
+ positions may remain inside it. Liveness is carried by `active_mask`.
201
+ Causality and window distance are determined from `positions`, which
202
+ holds the absolute sequence position of every slot in the composite
203
+ frame. Using absolute positions rather than a cumsum over the active
204
+ mask eliminates the data-dependent computation that blocks torch.compile.
205
+ """
206
+ query_offset = kv_len - query_len
207
+
208
+ def sliding_window_mask(
209
+ batch_idx: torch.Tensor,
210
+ head_idx: torch.Tensor,
211
+ q_idx: torch.Tensor,
212
+ kv_idx: torch.Tensor,
213
+ ) -> torch.Tensor:
214
+
215
+ q_abs = query_offset + q_idx
216
+
217
+ query_is_active = active_mask[batch_idx, q_abs]
218
+ key_is_active = active_mask[batch_idx, kv_idx]
219
+
220
+ q_pos = positions[batch_idx, q_abs]
221
+ k_pos = positions[batch_idx, kv_idx]
222
+
223
+ is_causal = k_pos <= q_pos
224
+ in_window = (q_pos - k_pos) < window_size
225
+
226
+ return query_is_active & key_is_active & is_causal & in_window
227
+
228
+ return create_block_mask(
229
+ sliding_window_mask,
230
+ B=batch_size,
231
+ H=num_heads,
232
+ Q_LEN=query_len,
233
+ KV_LEN=kv_len,
234
+ device=device,
235
+ )
__cache__mosrah_cache.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MoSRAH sparse KV cache — single-layer implementation.
2
+
3
+ MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
+ by head rather than by sequence position. The routing is dynamic and produces a ragged
5
+ distribution of token counts across (batch, head) slots — different batch items may
6
+ route different numbers of tokens to the same head, and different heads accumulate at
7
+ different rates. DynamicCache cannot represent this correctly: it concatenates along
8
+ the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
+ therefore uses a custom buffer design.
10
+
11
+ Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
+ attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
+ heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
+ head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
+ valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
+ buffer_capacity property and is derived directly from self.keys rather than tracked
17
+ as a separate variable.
18
+
19
+ The primary interface is update(key_states, value_states, active_mask), which accepts
20
+ expert-choice layout, stores only active entries in causal order, and returns the full
21
+ accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
+ active_mask identifies valid cached positions; everything beyond each slot's count is
23
+ junk data that downstream attention must exclude.
24
+
25
+ BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
+ exposed by get_heads_lengths() must be read before update() if the caller needs the
27
+ pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
+ in-place and the pre-update values are not recoverable afterward.
29
+
30
+ All buffers are allocated at construction time. MoSRAHCache is constructed by
31
+ ShramLayerCache, which has access to batch size, device, and all model config parameters
32
+ needed to fully specify the storage layout upfront.
33
+ """
34
+
35
+ import torch
36
+ from transformers.cache_utils import CacheLayerMixin
37
+
38
+
39
+ class MoSRAHCache(CacheLayerMixin):
40
+ """KV cache for the MoSRAH sparse attention path — single decoder layer.
41
+
42
+ Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
43
+ Stores keys and values in the mixin-standard self.keys and self.values attributes
44
+ using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
45
+ which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.
46
+
47
+ All storage is allocated at construction time and is_initialized is True
48
+ immediately. The caller (ShramLayerCache) provides batch size, device, and model
49
+ config parameters so no lazy allocation is needed.
50
+
51
+ Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
52
+ (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
53
+ This matches the packed representation produced by expert packing in the MoSRAH
54
+ forward pass, where BEA has already applied RoPE before calling update().
55
+
56
+ Args:
57
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
58
+ second dimension of all storage tensors.
59
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
60
+ dimension of all storage tensors.
61
+ batch_size: Number of sequences in the batch. Determines the first dimension
62
+ of all storage tensors.
63
+ device: Device on which to allocate all tensors. Should match the model device.
64
+ mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
65
+ config.mosrah_cache_length. The buffer never grows; if any slot would exceed
66
+ this capacity, update() raises in both eager and compiled modes. Increase
67
+ mosrah_overallocation_factor in ShramConfig to resolve an overflow.
68
+ """
69
+
70
+ is_compileable = True
71
+ is_sliding = False
72
+
73
+ def __init__(
74
+ self,
75
+ num_mosrah_heads: int,
76
+ head_dim: int,
77
+ batch_size: int,
78
+ device: torch.device,
79
+ mosrah_cache_length: int,
80
+ ) -> None:
81
+ super().__init__()
82
+ self.num_mosrah_heads = num_mosrah_heads
83
+ self.head_dim = head_dim
84
+ self.batch_size = batch_size
85
+ self.device = device
86
+ self.mosrah_cache_length = mosrah_cache_length
87
+
88
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
89
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
90
+ # tracks valid occupancy per (batch, head) slot.
91
+ self.keys: torch.Tensor = torch.zeros(
92
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
93
+ )
94
+ self.values: torch.Tensor = torch.zeros(
95
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
96
+ )
97
+ self._counts: torch.Tensor = torch.zeros(
98
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
99
+ )
100
+
101
+ # Storage is fully allocated at construction — the cache is initialized.
102
+ self.is_initialized = True
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Properties
106
+ # ---------------------------------------------------------------------------
107
+
108
+ @property
109
+ def buffer_capacity(self) -> int:
110
+ """Current number of slots allocated per (batch, head) pair.
111
+
112
+ Equal to mosrah_cache_length as supplied at construction. Derived from
113
+ self.keys so it remains consistent with the actual buffer shape.
114
+ """
115
+ return self.keys.shape[2]
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Primary API
119
+ # ---------------------------------------------------------------------------
120
+
121
+ def update( # type: ignore[override]
122
+ self,
123
+ key_states: torch.Tensor,
124
+ value_states: torch.Tensor,
125
+ active_mask: torch.Tensor,
126
+ cache_kwargs: dict | None = None,
127
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128
+ """Scatter active key/value states into the buffer and return the full cache state.
129
+
130
+ Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
131
+ active_mask is (B, L, T) bool with True marking real tokens. Only active
132
+ positions are written; inactive positions are ignored.
133
+
134
+ Uses a fixed-shape destination mask constructed from per-slot write intervals
135
+ to transfer active tokens into the buffer without any data-dependent shape
136
+ operations. Active tokens are left-justified within each packed slot by the
137
+ packing machinery, so the destination positions are a contiguous range
138
+ starting at the current slot count — no cumsum or torch.where needed.
139
+
140
+ Returns the full accumulated (keys, values, active_mask) across the cached
141
+ sparse sequence. The returned active_mask is True exactly for slots t <
142
+ counts[b, l]; everything beyond is junk data that BEA must exclude.
143
+
144
+ Note: get_heads_lengths() must be called before update() if the caller needs
145
+ the pre-update occupancy for position computation (Unit 10.A). update()
146
+ increments counts in-place and the pre-update values are not recoverable.
147
+
148
+ Args:
149
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
150
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
151
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
152
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
153
+
154
+ Returns:
155
+ Tuple of (keys, values, active_mask):
156
+ keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
157
+ values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
158
+ active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
159
+ """
160
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
161
+
162
+ post_counts = self._counts + incoming_delta
163
+ self._check_no_overflow(post_counts.max(), self.mosrah_cache_length)
164
+
165
+ # Build a fixed-shape destination mask in cache space. Active tokens within
166
+ # each (b, l) slot are left-justified by the packing machinery, so they occupy
167
+ # positions 0..s-1 in their packed slot. The corresponding cache positions are
168
+ # write_start[b,l]..write_start[b,l]+write_count[b,l]-1. Broadcasting a
169
+ # time arange against these per-slot intervals selects exactly the target
170
+ # positions without any data-dependent shape query.
171
+ write_start = self._counts.unsqueeze(-1) # cache position where new tokens begin
172
+ write_count = incoming_delta.unsqueeze(-1) # number of new tokens arriving per slot
173
+ time_arange = torch.arange(
174
+ self.mosrah_cache_length, device=active_mask.device
175
+ )
176
+ dest_mask = (time_arange >= write_start) & (time_arange < write_start + write_count)
177
+ # dest_mask: (B, L, mosrah_cache_length)
178
+
179
+ # Transfer key and value vectors. Left-justification guarantees that
180
+ # dest_mask and active_mask have equal True counts per (b, l) slot, so the
181
+ # boolean-mask transfer is correct without any explicit count verification.
182
+ self.keys[dest_mask] = key_states[active_mask]
183
+ self.values[dest_mask] = value_states[active_mask]
184
+
185
+ self._counts = post_counts
186
+
187
+ return self.keys, self.values, self._make_active_mask()
188
+
189
+ def get_heads_lengths(self) -> torch.Tensor:
190
+ """Return the per-(batch, head) token count for this layer.
191
+
192
+ This is the authoritative occupancy tensor consumed by BEA for attention
193
+ masking and by position computation (Unit 10.A) for semantic-sequence
194
+ position computation.
195
+
196
+ Note: in the MoSRAH forward pass, this must be called before update() if the
197
+ caller needs the pre-update occupancy. update() increments these counts in-place.
198
+
199
+ Returns:
200
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
201
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
202
+ """
203
+ return self._counts
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # CacheLayerMixin — overridden coordination methods
207
+ # ---------------------------------------------------------------------------
208
+
209
+ def reset(self) -> None:
210
+ """Clear all cached key and value tensors.
211
+
212
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
213
+ and is_initialized remains True — only the contents are cleared.
214
+ """
215
+ self.keys.zero_()
216
+ self.values.zero_()
217
+ self._counts.zero_()
218
+
219
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
220
+ """Reorder the batch dimension of all cached tensors for beam search.
221
+
222
+ Applied atomically across self.keys, self.values, and _counts. Beam search
223
+ must reorder all three together or the occupancy counts and buffer contents
224
+ will correspond to different beam hypotheses.
225
+
226
+ Overrides the parent because the parent's implementation calls get_seq_length(),
227
+ which is not supported for this cache.
228
+
229
+ Args:
230
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
231
+ search algorithm.
232
+ """
233
+ self.keys = self.keys[beam_idx]
234
+ self.values = self.values[beam_idx]
235
+ self._counts = self._counts[beam_idx]
236
+
237
+ def batch_repeat_interleave(self, repeats: int) -> None:
238
+ """Expand the batch dimension by repeating each entry repeats times.
239
+
240
+ Used at beam search initialisation to expand the cache from batch size B to
241
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
242
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
243
+
244
+ Args:
245
+ repeats: Number of times to repeat each batch entry.
246
+ """
247
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
248
+ self.values = self.values.repeat_interleave(repeats, dim=0)
249
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
250
+ self.batch_size = self.batch_size * repeats
251
+
252
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
253
+ """Select a subset of batch entries by index.
254
+
255
+ Used in contrastive search to retain only the selected candidate entries.
256
+ Applied atomically across keys, values, and _counts; batch_size is updated
257
+ to reflect the number of retained entries.
258
+
259
+ Args:
260
+ indices: 1-D integer tensor of batch indices to retain.
261
+ """
262
+ self.keys = self.keys[indices]
263
+ self.values = self.values[indices]
264
+ self._counts = self._counts[indices]
265
+ self.batch_size = indices.shape[0]
266
+
267
+ def offload(self) -> None:
268
+ """Offload all cached tensors to CPU.
269
+
270
+ Extends the parent to also offload _counts, which the parent does not know
271
+ about. All three tensors are moved atomically so device state remains consistent.
272
+ """
273
+ super().offload()
274
+ self._counts = self._counts.to("cpu", non_blocking=True)
275
+
276
+ def prefetch(self) -> None:
277
+ """Move all cached tensors back to the model device ahead of time.
278
+
279
+ Extends the parent to also prefetch _counts, which the parent does not know
280
+ about. _counts is synced to self.keys.device after the parent moves keys and
281
+ values, so all three remain consistent.
282
+ """
283
+ super().prefetch()
284
+ if self._counts.device != self.keys.device:
285
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
286
+
287
+ def lazy_initialization( # type: ignore[override]
288
+ self, key_states: torch.Tensor, value_states: torch.Tensor
289
+ ) -> None:
290
+ """No-op — storage is fully allocated at construction time."""
291
+ pass
292
+
293
+ # ---------------------------------------------------------------------------
294
+ # CacheLayerMixin — unsupported abstract methods
295
+ # ---------------------------------------------------------------------------
296
+
297
+ def get_seq_length(self) -> int: # type: ignore[override]
298
+ """Not supported — no single sequence length represents this cache's state.
299
+
300
+ MoSRAH heads accumulate independently; (batch, head) slots have different
301
+ lengths depending on routing history. There is no meaningful scalar summary.
302
+ Use get_heads_lengths() for per-head occupancy.
303
+ """
304
+ raise NotImplementedError(
305
+ "MoSRAHCache has no single sequence length. "
306
+ "Use get_heads_lengths() for per-head occupancy."
307
+ )
308
+
309
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
310
+ """Return the static per-(batch, head) slot capacity of this cache.
311
+
312
+ Equal to mosrah_cache_length as supplied at construction, which is derived
313
+ from config.mosrah_cache_length. Required by the HuggingFace static cache
314
+ contract; generation machinery uses this to size attention masks.
315
+ """
316
+ return self.mosrah_cache_length
317
+
318
+ def get_mask_sizes( # type: ignore[override]
319
+ self,
320
+ cache_position: torch.Tensor,
321
+ ) -> tuple[int, int]:
322
+ """Not supported — MoSRAHCache does not participate in HF mask construction."""
323
+ raise NotImplementedError(
324
+ "MoSRAHCache does not support get_mask_sizes()."
325
+ )
326
+
327
+ # ---------------------------------------------------------------------------
328
+ # Internal helpers
329
+ # ---------------------------------------------------------------------------
330
+
331
+ def _make_active_mask(self) -> torch.Tensor:
332
+ """Construct the (B, L, T) active mask from current counts.
333
+
334
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
335
+ has been written. Positions at or beyond the count are junk and must be
336
+ excluded by downstream attention.
337
+ """
338
+ cap = self.buffer_capacity
339
+ return (
340
+ torch.arange(cap, device=self.keys.device)
341
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
342
+ < self._counts.unsqueeze(-1)
343
+ )
344
+
345
+ @staticmethod
346
+ def _check_no_overflow(max_count: torch.Tensor, capacity: int) -> None:
347
+ """Raise if any (batch, head) slot would exceed the static buffer capacity.
348
+
349
+ Uses the 19.F.1 pattern: branches on whether the graph is being compiled.
350
+ In compiled mode, `.item()` folds into the graph when capture_scalar_outputs=True
351
+ and `torch._check` issues a compile-time assertion. In eager mode, a plain
352
+ RuntimeError is raised with a descriptive message.
353
+
354
+ Args:
355
+ max_count: Scalar tensor — the maximum post-update count across all slots.
356
+ capacity: The static buffer capacity (mosrah_cache_length).
357
+ """
358
+ if torch.compiler.is_compiling():
359
+ torch._check(max_count.item() <= capacity)
360
+ else:
361
+ if max_count.item() > capacity:
362
+ raise RuntimeError(
363
+ f"MoSRAHCache overflow: a (batch, head) slot would reach "
364
+ f"{max_count.item()} tokens but the static buffer capacity is "
365
+ f"{capacity}. Increase mosrah_overallocation_factor in ShramConfig."
366
+ )
367
+
__cache__shram_cache.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
+
3
+ The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
+ CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
+ lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
+ ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
+ presents them through the Cache interface, and transparently forwards model-wide operations
8
+ across all of them.
9
+
10
+ ShramCache does not define a composite update() interface. The two attention paths inside each
11
+ SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
+ nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
+ relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
+ coordination of the layer caches — not routing attention inputs.
15
+
16
+ Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
+ specified layer, which tracks the cumulative count of token positions processed. This is
18
+ what HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import Cache
23
+
24
+ from .configuration import ShramConfig
25
+ from .__cache__shram_layer_cache import ShramLayerCache
26
+
27
+
28
+ class ShramCache(Cache):
29
+ """Top-level cache for the full SHRAM model.
30
+
31
+ Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
32
+ role and transparently forwards reset, reorder, and sequence-length queries across all
33
+ owned layer caches.
34
+
35
+ No composite update() interface is provided. The two attention paths inside each SHRAM
36
+ layer have materially different update semantics; callers must update sub-caches directly
37
+ via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
38
+
39
+ Args:
40
+ config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
41
+ dimensions are derived from config so that a single source of truth governs
42
+ every buffer size across the full cache stack.
43
+ batch_size: Number of sequences in the batch.
44
+ device: Device on which to allocate cache tensors.
45
+ """
46
+
47
+ is_compileable = True
48
+
49
+ def __init__(
50
+ self,
51
+ config: ShramConfig,
52
+ batch_size: int,
53
+ device: torch.device,
54
+ ) -> None:
55
+ layers = [
56
+ ShramLayerCache(
57
+ config=config,
58
+ batch_size=batch_size,
59
+ device=device,
60
+ )
61
+ for _ in range(config.num_decoder_layers)
62
+ ]
63
+ super().__init__(layers=layers)
64
+
65
+ # ---------------------------------------------------------------------------
66
+ # Cache — composite-meaningful methods
67
+ # ---------------------------------------------------------------------------
68
+ #
69
+ # reset(): Inherited. Iterates all layer caches and calls reset() on each.
70
+ #
71
+ # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
72
+ #
73
+ # is_initialized: Inherited property. True iff all layer caches are initialized.
74
+ # Since ShramLayerCache.is_initialized is True from construction, this is True
75
+ # immediately after ShramCache.__init__ returns.
76
+
77
+ def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
78
+ """Return the cumulative sequence length for the specified layer.
79
+
80
+ Delegates to the layer cache at layer_idx, which in turn delegates to the
81
+ local sliding-window sub-cache. That sub-cache is authoritative for sequence
82
+ progress: it sees every token presented to the layer and accumulates a truthful
83
+ total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
84
+ """
85
+ return self.layers[layer_idx].get_seq_length()
86
+
87
+ # ---------------------------------------------------------------------------
88
+ # Cache — unsupported methods
89
+ # ---------------------------------------------------------------------------
90
+
91
+ def update( # type: ignore[override]
92
+ self,
93
+ key_states: torch.Tensor,
94
+ value_states: torch.Tensor,
95
+ layer_idx: int,
96
+ cache_kwargs: dict | None = None,
97
+ ) -> tuple[torch.Tensor, torch.Tensor]:
98
+ """Not supported — ShramCache has no composite update interface.
99
+
100
+ The two attention paths inside each SHRAM layer have different update semantics.
101
+ Callers must update sub-caches directly:
102
+ cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
103
+ cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
104
+ """
105
+ raise NotImplementedError(
106
+ "ShramCache has no composite update interface. "
107
+ "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
108
+ )
109
+
110
+ def crop(self, max_length: int) -> None:
111
+ """Not supported — ShramCache layers do not implement crop()."""
112
+ raise NotImplementedError("ShramCache does not support crop().")
113
+
114
+ @property
115
+ def max_batch_size(self) -> int:
116
+ """Not supported — ShramCache does not track a uniform batch size across layers."""
117
+ raise NotImplementedError("ShramCache does not expose max_batch_size.")
118
+
119
+ @property
120
+ def max_cache_len(self) -> int:
121
+ """Return the maximum sequence length the cache can serve.
122
+
123
+ Delegates to layers[0].get_max_cache_shape(), which returns
124
+ config.inference_sequence_length. HuggingFace's static-cache machinery reads
125
+ this value to size generation loops and verify compileable cache contracts.
126
+ """
127
+ return self.layers[0].get_max_cache_shape()
__cache__shram_layer_cache.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
+
3
+ A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
+ local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
+ different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
+ the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
+ path can interact with it without indirection.
8
+
9
+ ShramLayerCache does not define a composite update() interface. The two paths have materially
10
+ different update semantics — the local side uses chunk-local key/value/mask concatenation
11
+ while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
+ behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
+ each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
+ ownership, coordination, and reset/reorder boundary for one decoder layer.
15
+
16
+ Sequence length at this boundary is reported by delegating to the local sliding-window
17
+ sub-cache, which tracks the cumulative count of token positions processed. This is the
18
+ quantity HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import CacheLayerMixin
23
+
24
+ from .configuration import ShramConfig
25
+ from .__cache__mosrah_cache import MoSRAHCache
26
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
27
+
28
+
29
+ class ShramLayerCache(CacheLayerMixin):
30
+ """Cache subsystem for one SHRAM decoder layer.
31
+
32
+ Owns and coordinates two sub-caches:
33
+ - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
34
+ - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.
35
+
36
+ Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
37
+ exposed directly for their downstream attention paths — no composite update() interface is
38
+ provided, because the two paths have materially different update semantics.
39
+
40
+ Sequence length is reported by delegating to the local sliding-window sub-cache, which
41
+ tracks the cumulative count of token positions processed across all update() calls.
42
+
43
+ Args:
44
+ config: ShramConfig instance. All sub-cache dimensions and capacities are derived
45
+ from config so that a single source of truth governs every buffer size.
46
+ batch_size: Number of sequences in the batch.
47
+ device: Device on which to allocate cache tensors.
48
+ """
49
+
50
+ is_compileable = True
51
+ is_sliding = False
52
+
53
+ def __init__(
54
+ self,
55
+ config: ShramConfig,
56
+ batch_size: int,
57
+ device: torch.device,
58
+ ) -> None:
59
+ super().__init__()
60
+ self._inference_sequence_length = config.inference_sequence_length
61
+ self.sliding_window_cache = LocalSlidingWindowLayerCache(
62
+ sliding_window=config.window_size,
63
+ num_heads=config.num_sliding_window_heads,
64
+ head_dim=config.head_dim,
65
+ batch_size=batch_size,
66
+ device=device,
67
+ )
68
+ self.mosrah_cache = MoSRAHCache(
69
+ num_mosrah_heads=config.num_mosrah_heads,
70
+ head_dim=config.head_dim,
71
+ batch_size=batch_size,
72
+ device=device,
73
+ mosrah_cache_length=config.mosrah_cache_length,
74
+ )
75
+
76
+ # ---------------------------------------------------------------------------
77
+ # Properties
78
+ # ---------------------------------------------------------------------------
79
+
80
+ @property
81
+ def is_initialized(self) -> bool:
82
+ """True iff both sub-caches have allocated their storage.
83
+
84
+ Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
85
+ so this is True immediately after ShramLayerCache.__init__ returns.
86
+ """
87
+ return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized
88
+
89
+ @is_initialized.setter
90
+ def is_initialized(self, value: bool) -> None:
91
+ # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
92
+ # attribute. Since property is a data descriptor it takes precedence, but Python
93
+ # still routes the assignment through __set__. Absorb it silently — state is
94
+ # derived from sub-caches, not stored here.
95
+ pass
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # CacheLayerMixin — composite-meaningful methods
99
+ # ---------------------------------------------------------------------------
100
+
101
+ def get_seq_length(self) -> int: # type: ignore[override]
102
+ """Return the cumulative sequence length from the local sliding-window path.
103
+
104
+ The local path is authoritative for sequence progress: it sees every token
105
+ presented to this layer and accumulates a truthful total. Delegates to
106
+ sliding_window_cache.get_seq_length().
107
+ """
108
+ return self.sliding_window_cache.get_seq_length()
109
+
110
+ def reset(self) -> None:
111
+ """Clear both sub-caches.
112
+
113
+ Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
114
+ state and MoSRAH sparse state remain consistent.
115
+ """
116
+ self.sliding_window_cache.reset()
117
+ self.mosrah_cache.reset()
118
+
119
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
120
+ """Reorder the batch dimension of both sub-caches for beam search.
121
+
122
+ Delegates to each sub-cache. Both are reordered atomically so the sliding-window
123
+ and MoSRAH state correspond to the same beam hypotheses after reordering.
124
+
125
+ Args:
126
+ beam_idx: Permutation indices of shape (batch,) produced by beam search.
127
+ """
128
+ self.sliding_window_cache.reorder_cache(beam_idx)
129
+ self.mosrah_cache.reorder_cache(beam_idx)
130
+
131
+ def batch_repeat_interleave(self, repeats: int) -> None:
132
+ """Expand the batch dimension of both sub-caches for beam search initialisation.
133
+
134
+ Delegates atomically to each sub-cache. Both must be expanded together so the
135
+ sliding-window and MoSRAH state correspond to the same beam candidates.
136
+
137
+ Args:
138
+ repeats: Number of times to repeat each batch entry.
139
+ """
140
+ self.sliding_window_cache.batch_repeat_interleave(repeats)
141
+ self.mosrah_cache.batch_repeat_interleave(repeats)
142
+
143
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
144
+ """Select a subset of batch entries in both sub-caches for contrastive search.
145
+
146
+ Delegates atomically to each sub-cache. Both must be trimmed together so the
147
+ sliding-window and MoSRAH state remain consistent.
148
+
149
+ Args:
150
+ indices: 1-D integer tensor of batch indices to retain.
151
+ """
152
+ self.sliding_window_cache.batch_select_indices(indices)
153
+ self.mosrah_cache.batch_select_indices(indices)
154
+
155
+ def offload(self) -> None:
156
+ """Offload both sub-caches to CPU.
157
+
158
+ Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
159
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
160
+ """
161
+ self.sliding_window_cache.offload()
162
+ self.mosrah_cache.offload()
163
+
164
+ def prefetch(self) -> None:
165
+ """Move both sub-caches back to their model device ahead of time.
166
+
167
+ Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
168
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
169
+ """
170
+ self.sliding_window_cache.prefetch()
171
+ self.mosrah_cache.prefetch()
172
+
173
+ def lazy_initialization( # type: ignore[override]
174
+ self, key_states: torch.Tensor, value_states: torch.Tensor
175
+ ) -> None:
176
+ """No-op — both sub-caches handle their own initialization."""
177
+ pass
178
+
179
+ # ---------------------------------------------------------------------------
180
+ # CacheLayerMixin — unsupported abstract methods
181
+ # ---------------------------------------------------------------------------
182
+
183
+ def update( # type: ignore[override]
184
+ self,
185
+ key_states: torch.Tensor,
186
+ value_states: torch.Tensor,
187
+ cache_kwargs: dict | None = None,
188
+ ) -> tuple[torch.Tensor, torch.Tensor]:
189
+ """Not supported — ShramLayerCache has no composite update interface.
190
+
191
+ The two sub-caches have materially different update semantics: the sliding-window
192
+ side uses standard key/value concatenation while the MoSRAH side uses expert-choice
193
+ scatter with an active mask. Callers must update each sub-cache directly via
194
+ sliding_window_cache.update() or mosrah_cache.update().
195
+ """
196
+ raise NotImplementedError(
197
+ "ShramLayerCache has no composite update interface. "
198
+ "Update sliding_window_cache or mosrah_cache directly."
199
+ )
200
+
201
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
202
+ """Return the maximum sequence length this layer cache can serve.
203
+
204
+ The authoritative upper bound is ``config.inference_sequence_length``, which
205
+ governs the full accumulated token history the model is configured to handle.
206
+ HuggingFace's static-cache machinery reads this value to determine whether the
207
+ cache is compileable and to size generation loops.
208
+ """
209
+ return self._inference_sequence_length
210
+
211
+ def get_mask_sizes( # type: ignore[override]
212
+ self,
213
+ cache_position: torch.Tensor,
214
+ ) -> tuple[int, int]:
215
+ """Return the KV dimensions for HuggingFace causal mask construction.
216
+
217
+ Returns (inference_sequence_length, 0): the full static cache capacity as
218
+ kv_length and zero offset. HuggingFace reads these values to size the causal
219
+ attention mask when is_compileable is True.
220
+ """
221
+ return self._inference_sequence_length, 0
__cache__sliding_window_cache.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/shram/model/cache/sliding_window_cache.py
2
+
3
+ """Local sliding-window cache for the SHRAM local attention path.
4
+
5
+ This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
+ `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
+
8
+ Its job is narrow:
9
+
10
+ - accept the current chunk's local key/value tensors and active mask
11
+ - return the current-step local frame consumed by local attention
12
+ - separately retain the next-step sliding-window cache state
13
+
14
+ It does not decide local causal visibility. That is owned by
15
+ `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
+ constructs the effective local attention mask from it.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class LocalSlidingWindowLayerCache(CacheLayerMixin):
24
+ """Fixed-width local cache for one SHRAM decoder layer.
25
+
26
+ The cache keeps a retained local sliding-window buffer and an aligned active
27
+ mask. On update, it returns the current-step local frame formed by
28
+ concatenating retained cache state with the new chunk, then remembers only
29
+ the last `sliding_window` positions for the next step.
30
+
31
+ Dead positions are allowed to remain in both the returned frame and the
32
+ retained cache. Correctness is carried by the aligned active mask.
33
+
34
+ Args:
35
+ sliding_window: Width of the retained local sliding-window buffer.
36
+ num_heads: Number of local attention heads.
37
+ head_dim: Per-head embedding width for the local path.
38
+ batch_size: Number of sequences in the batch.
39
+ device: Device on which to allocate cache storage.
40
+ """
41
+
42
+ is_compileable = True
43
+ is_sliding = True
44
+
45
+ def __init__(
46
+ self,
47
+ sliding_window: int,
48
+ num_heads: int,
49
+ head_dim: int,
50
+ batch_size: int,
51
+ device: torch.device,
52
+ ) -> None:
53
+ super().__init__()
54
+
55
+ if sliding_window < 1:
56
+ raise ValueError(
57
+ f"sliding_window must be >= 1, got {sliding_window}."
58
+ )
59
+ if num_heads < 1:
60
+ raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
61
+ if head_dim < 1:
62
+ raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
63
+ if batch_size < 1:
64
+ raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
65
+
66
+ self.sliding_window = sliding_window
67
+ self.num_heads = num_heads
68
+ self.head_dim = head_dim
69
+ self.batch_size = batch_size
70
+ self.device = device
71
+
72
+ # Retained next-step local cache state. Storage is fixed-width from the
73
+ # start; semantic validity is carried by `active_mask`.
74
+ self.keys = torch.zeros(
75
+ batch_size,
76
+ num_heads,
77
+ sliding_window,
78
+ head_dim,
79
+ device=device,
80
+ )
81
+ self.values = torch.zeros(
82
+ batch_size,
83
+ num_heads,
84
+ sliding_window,
85
+ head_dim,
86
+ device=device,
87
+ )
88
+ self.active_mask = torch.zeros(
89
+ batch_size,
90
+ sliding_window,
91
+ dtype=torch.bool,
92
+ device=device,
93
+ )
94
+
95
+ # Absolute sequence positions of each retained slot. Inactive slots
96
+ # retain zero; correctness is carried by active_mask.
97
+ self.positions = torch.zeros(
98
+ batch_size,
99
+ sliding_window,
100
+ dtype=torch.long,
101
+ device=device,
102
+ )
103
+
104
+ self.is_initialized = True
105
+
106
+ # Cumulative count of all token positions presented through update() for
107
+ # this cache instance. This is the quantity HuggingFace generation reads
108
+ # through get_seq_length() to track how far along the sequence we are.
109
+ self._total_processed: int = 0
110
+
111
+ def update( # type: ignore[override]
112
+ self,
113
+ key_states: torch.Tensor,
114
+ value_states: torch.Tensor,
115
+ active_mask: torch.Tensor,
116
+ positions: torch.Tensor,
117
+ cache_kwargs: dict | None = None,
118
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
119
+ """Return the current-step local frame and retain the next-step window.
120
+
121
+ Args:
122
+ key_states: Shape `(B, H, T_new, D)` local key vectors for the
123
+ current chunk.
124
+ value_states: Shape `(B, H, T_new, D)` local value vectors for the
125
+ current chunk.
126
+ active_mask: Shape `(B, T_new)` bool. `True` means the
127
+ corresponding token position in the current chunk is active.
128
+ positions: Shape `(B, T_new)` long. Absolute sequence position of
129
+ each token in the current chunk.
130
+ cache_kwargs: Present only to satisfy the `CacheLayerMixin`
131
+ interface. Unused by this cache.
132
+
133
+ Returns:
134
+ Tuple of:
135
+ - visible_keys: `(B, H, sliding_window + T_new, D)`
136
+ - visible_values: `(B, H, sliding_window + T_new, D)`
137
+ - visible_active_mask: `(B, sliding_window + T_new)`
138
+ - visible_positions: `(B, sliding_window + T_new)`
139
+
140
+ These are the tensors the local attention path should consume
141
+ directly for the current step.
142
+ """
143
+ self._ensure_state_compatibility(
144
+ key_states=key_states,
145
+ value_states=value_states,
146
+ )
147
+
148
+ # The current-step local frame is just retained cache state followed by
149
+ # the current chunk in chronological order.
150
+ composite_keys, composite_values, composite_mask, composite_positions = self._make_composite_frame(
151
+ key_states=key_states,
152
+ value_states=value_states,
153
+ active_mask=active_mask,
154
+ positions=positions,
155
+ )
156
+
157
+ # The cache remembers only the last raw sliding-window positions of that
158
+ # composite frame for the next step. Dead positions are allowed to
159
+ # survive; downstream local attention will ignore them using the mask.
160
+ self._retain_next_window(
161
+ composite_keys=composite_keys,
162
+ composite_values=composite_values,
163
+ composite_mask=composite_mask,
164
+ composite_positions=composite_positions,
165
+ )
166
+
167
+ self._total_processed += key_states.shape[2]
168
+
169
+ return composite_keys, composite_values, composite_mask, composite_positions
170
+
171
+ def _ensure_state_compatibility(
172
+ self,
173
+ key_states: torch.Tensor,
174
+ value_states: torch.Tensor,
175
+ ) -> None:
176
+ """Keep retained cache buffers compatible with the incoming update tensors.
177
+
178
+ The cache is allocated eagerly for simplicity. If later updates arrive on
179
+ a different device or in a different floating dtype, move the retained
180
+ state to match while preserving its contents.
181
+ """
182
+ if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
183
+ self.keys = self.keys.to(
184
+ device=key_states.device,
185
+ dtype=key_states.dtype,
186
+ )
187
+
188
+ if (
189
+ self.values.dtype != value_states.dtype
190
+ or self.values.device != value_states.device
191
+ ):
192
+ self.values = self.values.to(
193
+ device=value_states.device,
194
+ dtype=value_states.dtype,
195
+ )
196
+
197
+ if self.active_mask.device != key_states.device:
198
+ self.active_mask = self.active_mask.to(
199
+ key_states.device,
200
+ non_blocking=True,
201
+ )
202
+
203
+ if self.positions.device != key_states.device:
204
+ self.positions = self.positions.to(
205
+ key_states.device,
206
+ non_blocking=True,
207
+ )
208
+
209
+ def _make_composite_frame(
210
+ self,
211
+ key_states: torch.Tensor,
212
+ value_states: torch.Tensor,
213
+ active_mask: torch.Tensor,
214
+ positions: torch.Tensor,
215
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
216
+ """Build the current-step local frame in chronological order."""
217
+ return (
218
+ torch.cat([self.keys, key_states], dim=-2),
219
+ torch.cat([self.values, value_states], dim=-2),
220
+ torch.cat([self.active_mask, active_mask], dim=-1),
221
+ torch.cat([self.positions, positions], dim=-1),
222
+ )
223
+
224
+ def _retain_next_window(
225
+ self,
226
+ composite_keys: torch.Tensor,
227
+ composite_values: torch.Tensor,
228
+ composite_mask: torch.Tensor,
229
+ composite_positions: torch.Tensor,
230
+ ) -> None:
231
+ """Remember the next-step retained local state.
232
+
233
+ This is a raw positional trim to the last `sliding_window` positions, not
234
+ a semantic live-token trim.
235
+ """
236
+ self.keys[:] = composite_keys[:, :, -self.sliding_window :, :]
237
+ self.values[:] = composite_values[:, :, -self.sliding_window :, :]
238
+ self.active_mask[:] = composite_mask[:, -self.sliding_window :]
239
+ self.positions[:] = composite_positions[:, -self.sliding_window :]
240
+
241
+ def get_seq_length(self) -> int:
242
+ """Return the cumulative number of token positions processed by this cache.
243
+
244
+ This is the total count of token positions presented across all update()
245
+ calls since construction or the last reset(). It is the quantity HuggingFace
246
+ generation reads to track sequence progress and is not the same as active-token
247
+ count or current window occupancy.
248
+ """
249
+ return self._total_processed
250
+
251
+ def get_max_cache_shape(self) -> int:
252
+ return self.sliding_window
253
+
254
+ def get_mask_sizes( # type: ignore[override]
255
+ self,
256
+ cache_position: torch.Tensor,
257
+ ) -> tuple[int, int]:
258
+ raise NotImplementedError(
259
+ "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
260
+ )
261
+
262
+ def reset(self) -> None:
263
+ """Restore fresh-cache behavior."""
264
+ self.keys.zero_()
265
+ self.values.zero_()
266
+ self.active_mask.zero_()
267
+ self.positions.zero_()
268
+ self._total_processed = 0
269
+
270
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
271
+ """Reorder the batch dimension for beam search."""
272
+ self.keys = self.keys[beam_idx]
273
+ self.values = self.values[beam_idx]
274
+ self.active_mask = self.active_mask[beam_idx]
275
+ self.positions = self.positions[beam_idx]
276
+
277
+ def batch_repeat_interleave(self, repeats: int) -> None:
278
+ """Expand the batch dimension for beam-search initialisation."""
279
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
280
+ self.values = self.values.repeat_interleave(repeats, dim=0)
281
+ self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
282
+ self.positions = self.positions.repeat_interleave(repeats, dim=0)
283
+ self.batch_size = self.batch_size * repeats
284
+
285
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
286
+ """Select a subset of batch entries for contrastive search."""
287
+ self.keys = self.keys[indices]
288
+ self.values = self.values[indices]
289
+ self.active_mask = self.active_mask[indices]
290
+ self.positions = self.positions[indices]
291
+ self.batch_size = int(indices.shape[0])
292
+
293
+ def offload(self) -> None:
294
+ """Offload cache tensors to CPU."""
295
+ super().offload()
296
+ self.active_mask = self.active_mask.to("cpu", non_blocking=True)
297
+ self.positions = self.positions.to("cpu", non_blocking=True)
298
+
299
+ def prefetch(self) -> None:
300
+ """Move cache tensors back to the model device ahead of time."""
301
+ super().prefetch()
302
+ if self.active_mask.device != self.keys.device:
303
+ self.active_mask = self.active_mask.to(
304
+ self.keys.device,
305
+ non_blocking=True,
306
+ )
307
+ self.positions = self.positions.to(
308
+ self.keys.device,
309
+ non_blocking=True,
310
+ )
311
+
312
+ def crop(self, max_length: int) -> None:
313
+ raise NotImplementedError(
314
+ "LocalSlidingWindowLayerCache does not support crop()."
315
+ )
316
+
317
+ def lazy_initialization(
318
+ self,
319
+ key_states: torch.Tensor,
320
+ value_states: torch.Tensor,
321
+ ) -> None:
322
+ """No-op — this cache allocates its fixed buffers at construction time."""
323
+ return
__cache__slow_mosrah_cache.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
+
3
+ This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
+ interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
+ (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
+ position's key and value are written to the next available slot for that (batch, head)
7
+ pair, in the order positions appear along the T dimension, which directly enforces
8
+ causal ordering without any index arithmetic to verify.
9
+
10
+ SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
+ trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
+ Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
+ vectorized implementation is validated by asserting exact agreement with this one on all
14
+ test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
+ (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
+ an oracle.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class SlowMoSRAHCache(CacheLayerMixin):
24
+ """Unvectorized reference implementation of the MoSRAH KV cache.
25
+
26
+ Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
+ mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
+ with the same constructor signature and the same CacheLayerMixin protocol methods.
29
+ The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
+ triples rather than vectorized index arithmetic.
31
+
32
+ This class is not used in the model path. It exists so that MoSRAHCache.update()
33
+ can be validated by asserting exact agreement with this implementation on all test
34
+ inputs. See module docstring for the trust chain this enables.
35
+
36
+ Args:
37
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
+ second dimension of all storage tensors.
39
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
+ dimension of all storage tensors.
41
+ batch_size: Number of sequences in the batch. Determines the first dimension
42
+ of all storage tensors.
43
+ device: Device on which to allocate all tensors. Should match the model device.
44
+ mosrah_cache_length: Static sequence capacity per (batch, head) slot. Equal to
45
+ config.mosrah_cache_length. The buffer never grows; if any slot would exceed
46
+ this capacity, update() raises a RuntimeError.
47
+ """
48
+
49
+ is_compileable = False
50
+ is_sliding = False
51
+
52
+ def __init__(
53
+ self,
54
+ num_mosrah_heads: int,
55
+ head_dim: int,
56
+ batch_size: int,
57
+ device: torch.device,
58
+ mosrah_cache_length: int,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.num_mosrah_heads = num_mosrah_heads
62
+ self.head_dim = head_dim
63
+ self.batch_size = batch_size
64
+ self.device = device
65
+ self.mosrah_cache_length = mosrah_cache_length
66
+
67
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
68
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
69
+ # tracks valid occupancy per (batch, head) slot.
70
+ self.keys: torch.Tensor = torch.zeros(
71
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
72
+ )
73
+ self.values: torch.Tensor = torch.zeros(
74
+ batch_size, num_mosrah_heads, mosrah_cache_length, head_dim, device=device
75
+ )
76
+ self._counts: torch.Tensor = torch.zeros(
77
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
78
+ )
79
+
80
+ # Storage is fully allocated at construction — the cache is initialized.
81
+ self.is_initialized = True
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Properties
85
+ # ---------------------------------------------------------------------------
86
+
87
+ @property
88
+ def buffer_capacity(self) -> int:
89
+ """Current number of slots allocated per (batch, head) pair.
90
+
91
+ Equal to mosrah_cache_length as supplied at construction. Derived from
92
+ self.keys so it remains consistent with the actual buffer shape.
93
+ """
94
+ return self.keys.shape[2]
95
+
96
+ # ---------------------------------------------------------------------------
97
+ # Primary API
98
+ # ---------------------------------------------------------------------------
99
+
100
+ def update( # type: ignore[override]
101
+ self,
102
+ key_states: torch.Tensor,
103
+ value_states: torch.Tensor,
104
+ active_mask: torch.Tensor,
105
+ cache_kwargs: dict | None = None,
106
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
107
+ """Scatter active key/value states using an explicit loop; return full cache state.
108
+
109
+ Iterates over every (b, l, t) triple. For each position where active_mask is
110
+ True, the key and value are written to the next available slot for that
111
+ (batch, head) pair and the count is incremented. Causal ordering is guaranteed
112
+ because the t dimension is traversed from 0 to T-1 and counts are updated
113
+ immediately after each write.
114
+
115
+ Raises RuntimeError before any writes if the incoming tokens would cause any
116
+ slot to exceed the static mosrah_cache_length capacity.
117
+
118
+ Args:
119
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
120
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
121
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
122
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
123
+
124
+ Returns:
125
+ Tuple of (keys, values, active_mask):
126
+ keys: (B, L, mosrah_cache_length, u) float — full key buffer including junk slots.
127
+ values: (B, L, mosrah_cache_length, u) float — full value buffer including junk slots.
128
+ active_mask: (B, L, mosrah_cache_length) bool — True iff slot t has been written.
129
+ """
130
+ B, L, T = active_mask.shape
131
+
132
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
133
+ if (self._counts + incoming_delta).max().item() > self.mosrah_cache_length:
134
+ raise RuntimeError(
135
+ f"SlowMoSRAHCache overflow: a (batch, head) slot would exceed the "
136
+ f"static buffer capacity of {self.mosrah_cache_length}. Increase "
137
+ f"mosrah_overallocation_factor in ShramConfig."
138
+ )
139
+
140
+ # Write each active position into the next available slot for its (batch, head)
141
+ # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
142
+ for b in range(B):
143
+ for l in range(L):
144
+ for t in range(T):
145
+ if active_mask[b, l, t]:
146
+ pos = self._counts[b, l].item()
147
+ self.keys[b, l, pos, :] = key_states[b, l, t, :]
148
+ self.values[b, l, pos, :] = value_states[b, l, t, :]
149
+ self._counts[b, l] += 1
150
+
151
+ return self.keys, self.values, self._make_active_mask()
152
+
153
+ def get_heads_lengths(self) -> torch.Tensor:
154
+ """Return the per-(batch, head) token count for this layer.
155
+
156
+ This is the authoritative occupancy tensor consumed by BEA for attention
157
+ masking and by position computation (Unit 10.A) for semantic-sequence
158
+ position computation.
159
+
160
+ Returns:
161
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
162
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
163
+ """
164
+ return self._counts
165
+
166
+ # ---------------------------------------------------------------------------
167
+ # CacheLayerMixin — overridden coordination methods
168
+ # ---------------------------------------------------------------------------
169
+
170
+ def reset(self) -> None:
171
+ """Clear all cached key and value tensors.
172
+
173
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
174
+ and is_initialized remains True — only the contents are cleared.
175
+ """
176
+ self.keys.zero_()
177
+ self.values.zero_()
178
+ self._counts.zero_()
179
+
180
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
181
+ """Reorder the batch dimension of all cached tensors for beam search.
182
+
183
+ Applied atomically across self.keys, self.values, and _counts. Beam search
184
+ must reorder all three together or the occupancy counts and buffer contents
185
+ will correspond to different beam hypotheses.
186
+
187
+ Overrides the parent because the parent's implementation calls get_seq_length(),
188
+ which is not supported for this cache.
189
+
190
+ Args:
191
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
192
+ search algorithm.
193
+ """
194
+ self.keys = self.keys[beam_idx]
195
+ self.values = self.values[beam_idx]
196
+ self._counts = self._counts[beam_idx]
197
+
198
+ def batch_repeat_interleave(self, repeats: int) -> None:
199
+ """Expand the batch dimension by repeating each entry repeats times.
200
+
201
+ Used at beam search initialisation to expand the cache from batch size B to
202
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
203
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
204
+
205
+ Args:
206
+ repeats: Number of times to repeat each batch entry.
207
+ """
208
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
209
+ self.values = self.values.repeat_interleave(repeats, dim=0)
210
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
211
+ self.batch_size = self.batch_size * repeats
212
+
213
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
214
+ """Select a subset of batch entries by index.
215
+
216
+ Used in contrastive search to retain only the selected candidate entries.
217
+ Applied atomically across keys, values, and _counts; batch_size is updated
218
+ to reflect the number of retained entries.
219
+
220
+ Args:
221
+ indices: 1-D integer tensor of batch indices to retain.
222
+ """
223
+ self.keys = self.keys[indices]
224
+ self.values = self.values[indices]
225
+ self._counts = self._counts[indices]
226
+ self.batch_size = indices.shape[0]
227
+
228
+ def offload(self) -> None:
229
+ """Offload all cached tensors to CPU.
230
+
231
+ Extends the parent to also offload _counts, which the parent does not know
232
+ about. All three tensors are moved atomically so device state remains consistent.
233
+ """
234
+ super().offload()
235
+ self._counts = self._counts.to("cpu", non_blocking=True)
236
+
237
+ def prefetch(self) -> None:
238
+ """Move all cached tensors back to the model device ahead of time.
239
+
240
+ Extends the parent to also prefetch _counts, which the parent does not know
241
+ about. _counts is synced to self.keys.device after the parent moves keys and
242
+ values, so all three remain consistent.
243
+ """
244
+ super().prefetch()
245
+ if self._counts.device != self.keys.device:
246
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
247
+
248
+ def lazy_initialization( # type: ignore[override]
249
+ self, key_states: torch.Tensor, value_states: torch.Tensor
250
+ ) -> None:
251
+ """No-op — storage is fully allocated at construction time."""
252
+ pass
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # CacheLayerMixin — unsupported abstract methods
256
+ # ---------------------------------------------------------------------------
257
+
258
+ def get_seq_length(self) -> int: # type: ignore[override]
259
+ """Not supported — no single sequence length represents this cache's state.
260
+
261
+ MoSRAH heads accumulate independently; (batch, head) slots have different
262
+ lengths depending on routing history. There is no meaningful scalar summary.
263
+ Use get_heads_lengths() for per-head occupancy.
264
+ """
265
+ raise NotImplementedError(
266
+ "SlowMoSRAHCache has no single sequence length. "
267
+ "Use get_heads_lengths() for per-head occupancy."
268
+ )
269
+
270
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
271
+ """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
272
+ raise NotImplementedError(
273
+ "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
274
+ )
275
+
276
+ def get_mask_sizes( # type: ignore[override]
277
+ self,
278
+ cache_position: torch.Tensor,
279
+ ) -> tuple[int, int]:
280
+ """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
281
+ raise NotImplementedError(
282
+ "SlowMoSRAHCache does not support get_mask_sizes()."
283
+ )
284
+
285
+ # ---------------------------------------------------------------------------
286
+ # Internal helpers
287
+ # ---------------------------------------------------------------------------
288
+
289
+ def _make_active_mask(self) -> torch.Tensor:
290
+ """Construct the (B, L, T) active mask from current counts.
291
+
292
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
293
+ has been written. Positions at or beyond the count are junk and must be
294
+ excluded by downstream attention.
295
+ """
296
+ cap = self.buffer_capacity
297
+ return (
298
+ torch.arange(cap, device=self.keys.device)
299
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
300
+ < self._counts.unsqueeze(-1)
301
+ )
302
+
__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration import ShramConfig
2
+ from .decoder_layer import DecoderLayer
3
+ from .huggingface import ShramForCausalLM
4
+ from .__attention__load_balance_loss import LoadBalanceLoss
5
+ from .mlp import SwiGLUMLP
6
+ from .model import ShramModel
7
+ from .rope import RotaryEmbedding
8
+ from .__attention__router import MoSRAHRouter
9
+ from .__cache__mosrah_cache import MoSRAHCache
10
+
11
+ __all__ = [
12
+ "DecoderLayer",
13
+ "LoadBalanceLoss",
14
+ "MoSRAHCache",
15
+ "MoSRAHRouter",
16
+ "ShramConfig",
17
+ "ShramForCausalLM",
18
+ "ShramModel",
19
+ "RotaryEmbedding",
20
+ "SwiGLUMLP",
21
+ ]
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha": 1.0,
3
+ "attention_dropout": 0.0,
4
+ "auto_map": {
5
+ "AutoConfig": "configuration.ShramConfig",
6
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
+ },
8
+ "beta": 32.0,
9
+ "embedding_width": 512,
10
+ "head_dim": 16,
11
+ "inference_sequence_length": 1024,
12
+ "load_balance_p": 2.0,
13
+ "local_rope_theta": 10000.0,
14
+ "mlp_width": 1366,
15
+ "model_type": "shram",
16
+ "mosrah_overallocation_factor": 2.0,
17
+ "mosrah_rope_theta": 10000.0,
18
+ "num_decoder_layers": 12,
19
+ "num_mosrah_heads": 16,
20
+ "num_selected_heads": 16,
21
+ "num_sliding_window_heads": 16,
22
+ "rms_norm_eps": 1e-05,
23
+ "rope_mode": "main_sequence",
24
+ "tie_word_embeddings": false,
25
+ "training_sequence_length": 1024,
26
+ "transformers_version": "5.9.0",
27
+ "use_cache": true,
28
+ "vocab_size": 50277,
29
+ "window_size": 128
30
+ }
configuration.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration for the SHRAM transformer.
2
+
3
+ All architectural parameters that vary across model scales or are meaningful research
4
+ variables are expressed here. Architectural constants (no bias in linear layers,
5
+ SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
+ documented at the point of use — they are not config parameters because they do not
7
+ vary and changing them produces a different architecture, not a different scale.
8
+
9
+ RoPE configuration is owned entirely by this config. Each attention path reads its
10
+ parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
+ HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
+ """
13
+
14
+ import math
15
+
16
+ from transformers import PretrainedConfig
17
+
18
+
19
+ class ShramConfig(PretrainedConfig):
20
+ """Configuration class for the SHRAM decoder-only transformer.
21
+
22
+ SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
23
+ attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
24
+ local sliding-window causal attention path and h_s is the MoSRAH sparse routed
25
+ path. All other components follow the Llama 3 baseline.
26
+
27
+ This config is the single source of truth for every architectural dimension of the
28
+ model. Nothing in the architecture may use a literal number that belongs here.
29
+
30
+ Two independent RoPE configurations exist — one per attention path:
31
+
32
+ - h_l always uses standard RoPE with ``local_rope_theta``.
33
+ - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
34
+ ``inference_sequence_length``, ``alpha``, and ``beta``. When
35
+ ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
36
+ ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
37
+ and the correct setting for experiments that do not require context extension.
38
+
39
+ Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
40
+
41
+ config = AutoConfig.from_pretrained(
42
+ "your-namespace/advanced-transformers-lib",
43
+ trust_remote_code=True,
44
+ num_hidden_layers=12,
45
+ )
46
+ model = AutoModelForCausalLM.from_config(config)
47
+
48
+ Args:
49
+ vocab_size: Vocabulary size. Controls the embedding table and output logits
50
+ dimension. Must match the tokenizer.
51
+ embedding_width: Model width ``d``. The dimension of the residual stream.
52
+ mlp_width: FFN hidden dimension.
53
+ num_decoder_layers: Number of transformer blocks stacked in sequence.
54
+ num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
55
+ num_mosrah_heads: Total MoSRAH expert heads available ``L``.
56
+ num_selected_heads: MoSRAH heads each token selects ``K``.
57
+ head_dim: Per-head dimension, shared by both attention paths. Must be even
58
+ (RoPE rotates dimensions in pairs). Paper uses 16.
59
+ window_size: Sliding window size for h_l. Paper uses 128.
60
+ rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
61
+ original sequence positions; ``"semantic_sequence"`` supplies local slot
62
+ indices. Both are required; experimentally correct mode is undetermined
63
+ (paper §4). Default ``"main_sequence"``.
64
+ rms_norm_eps: Epsilon for RMSNorm layers.
65
+ local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
66
+ Paper uses b=10000.
67
+ mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
68
+ b=10000.
69
+ training_sequence_length: Context length ``C_train`` the model was or will be
70
+ trained at. Used to compute the YaRN scale factor for BEA.
71
+ inference_sequence_length: Context length ``C_target`` the model must support
72
+ at inference. Optional; defaults to ``training_sequence_length`` so that
73
+ ``scale=1`` and YaRN reduces to standard RoPE unless explicitly extended.
74
+ alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
75
+ ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
76
+ beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
77
+ ``r(d) > beta`` are left unscaled. Paper value: 32.0.
78
+ attention_dropout: Dropout probability on attention weights. Default 0.0.
79
+ use_cache: Whether to return past_key_values for KV caching.
80
+ output_hidden_states: Whether to return hidden states after each layer.
81
+ tie_word_embeddings: Whether input embedding and LM head share weights.
82
+ mosrah_overallocation_factor: Overallocation multiplier for the expert packing
83
+ buffer. ``mosrah_packed_length`` = ceil(training_sequence_length *
84
+ num_selected_heads / num_mosrah_heads * mosrah_overallocation_factor).
85
+ Must be > 1.0 to guarantee a buffer larger than the balanced-routing
86
+ baseline. Default 2.0.
87
+ load_balance_p: Exponent p for the p-mean aggregation of per-item routing
88
+ frequencies into the load balance signal. Higher p weights aggregation
89
+ toward the worst-case batch item, making the correction signal more
90
+ sensitive to per-item allocation spikes. Must be positive. Default 2.0.
91
+ """
92
+
93
+ model_type = "shram"
94
+
95
+ auto_map = {
96
+ "AutoConfig": "configuration.ShramConfig",
97
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
98
+ }
99
+
100
+ def __init__(
101
+ self,
102
+ vocab_size: int = 50277,
103
+ embedding_width: int = 512,
104
+ mlp_width: int = 1366,
105
+ num_decoder_layers: int = 12,
106
+ num_sliding_window_heads: int = 16,
107
+ num_mosrah_heads: int = 16,
108
+ num_selected_heads: int = 16,
109
+ head_dim: int = 16,
110
+ window_size: int = 128,
111
+ rope_mode: str = "main_sequence",
112
+ rms_norm_eps: float = 1e-5,
113
+ local_rope_theta: float = 10000.0,
114
+ mosrah_rope_theta: float = 10000.0,
115
+ training_sequence_length: int = 1024,
116
+ inference_sequence_length: int | None = None,
117
+ alpha: float = 1.0,
118
+ beta: float = 32.0,
119
+ attention_dropout: float = 0.0,
120
+ use_cache: bool = True,
121
+ output_hidden_states: bool = False,
122
+ tie_word_embeddings: bool = False,
123
+ mosrah_overallocation_factor: float = 2.0,
124
+ load_balance_p: float = 2.0,
125
+ **kwargs
126
+ ):
127
+ if head_dim % 2 != 0:
128
+ raise ValueError(
129
+ f"head_dim must be even (RoPE rotates dimensions in pairs). "
130
+ f"Got head_dim={head_dim}."
131
+ )
132
+
133
+ if rope_mode not in {"main_sequence", "semantic_sequence"}:
134
+ raise ValueError(
135
+ f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
136
+ f"got '{rope_mode}'."
137
+ )
138
+
139
+ if training_sequence_length <= 0:
140
+ raise ValueError(
141
+ f"training_sequence_length must be positive, "
142
+ f"got {training_sequence_length}."
143
+ )
144
+
145
+ if inference_sequence_length is None:
146
+ inference_sequence_length = training_sequence_length
147
+ if inference_sequence_length <= 0:
148
+ raise ValueError(
149
+ f"inference_sequence_length must be positive, "
150
+ f"got {inference_sequence_length}."
151
+ )
152
+
153
+ if mosrah_overallocation_factor <= 1.0:
154
+ raise ValueError(
155
+ f"mosrah_overallocation_factor must be > 1.0 to guarantee a packed "
156
+ f"buffer larger than the balanced-routing baseline. "
157
+ f"Got {mosrah_overallocation_factor}."
158
+ )
159
+
160
+ if load_balance_p <= 0.0:
161
+ raise ValueError(
162
+ f"load_balance_p must be positive, got {load_balance_p}."
163
+ )
164
+
165
+ self.vocab_size = vocab_size
166
+ self.embedding_width = embedding_width
167
+ self.mlp_width = mlp_width
168
+ self.num_decoder_layers = num_decoder_layers
169
+ self.num_sliding_window_heads = num_sliding_window_heads
170
+ self.num_mosrah_heads = num_mosrah_heads
171
+ self.num_selected_heads = num_selected_heads
172
+ self.head_dim = head_dim
173
+ self.window_size = window_size
174
+ self.rope_mode = rope_mode
175
+ self.rms_norm_eps = rms_norm_eps
176
+ self.local_rope_theta = local_rope_theta
177
+ self.mosrah_rope_theta = mosrah_rope_theta
178
+ self.training_sequence_length = training_sequence_length
179
+ self.inference_sequence_length = inference_sequence_length
180
+ self.alpha = alpha
181
+ self.beta = beta
182
+ self.mosrah_overallocation_factor = mosrah_overallocation_factor
183
+ self.load_balance_p = load_balance_p
184
+ self.attention_dropout = attention_dropout
185
+ self.use_cache = use_cache
186
+
187
+ super().__init__(
188
+ tie_word_embeddings=tie_word_embeddings,
189
+ output_hidden_states=output_hidden_states,
190
+ **kwargs
191
+ )
192
+
193
+ # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
194
+ # serialises it into config.json.
195
+ self.auto_map = type(self).auto_map
196
+
197
+ @property
198
+ def scale(self) -> float:
199
+ """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
200
+
201
+ When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
202
+ adjustments cancel and A_rope = 1. This is the default state.
203
+ """
204
+ return self.inference_sequence_length / self.training_sequence_length
205
+
206
+ @property
207
+ def mosrah_packed_length(self) -> int:
208
+ """Static packed time dimension T for expert packing.
209
+
210
+ The expected tokens per expert under perfectly balanced routing is
211
+ ``training_sequence_length * num_selected_heads / num_mosrah_heads``.
212
+ Multiplying by ``mosrah_overallocation_factor`` provides a buffer above
213
+ that baseline. The ceiling ensures T is always an integer >= 1.
214
+
215
+ All consumers of the packed buffer size must read this property rather
216
+ than deriving T independently.
217
+ """
218
+ return math.ceil(
219
+ self.training_sequence_length
220
+ * self.num_selected_heads
221
+ / self.num_mosrah_heads
222
+ * self.mosrah_overallocation_factor
223
+ )
224
+
225
+ @property
226
+ def mosrah_cache_length(self) -> int:
227
+ """Static per-(batch, head) slot capacity for the MoSRAH inference cache.
228
+
229
+ The expected tokens per expert over the full inference context under perfectly
230
+ balanced routing is ``inference_sequence_length * num_selected_heads /
231
+ num_mosrah_heads``. Multiplying by ``mosrah_overallocation_factor`` provides
232
+ a buffer above that baseline. The ceiling ensures the result is always an
233
+ integer >= 1.
234
+
235
+ Distinct from ``mosrah_packed_length``, which sizes the training packing buffer
236
+ using ``training_sequence_length``. This property uses
237
+ ``inference_sequence_length`` because the cache must hold the full accumulated
238
+ token history across the entire inference run.
239
+
240
+ All consumers of the MoSRAH cache buffer size must read this property rather
241
+ than deriving the capacity independently.
242
+ """
243
+ return math.ceil(
244
+ self.inference_sequence_length
245
+ * self.num_selected_heads
246
+ / self.num_mosrah_heads
247
+ * self.mosrah_overallocation_factor
248
+ )
249
+
decoder_layer.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Decoder layer — a single transformer block.
2
+
3
+ Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
4
+ residual connections around both sublayers:
5
+
6
+ normed_attn = RMSNorm(x)
7
+ attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
8
+ h = x + attn_out
9
+
10
+ normed_mlp = RMSNorm(h)
11
+ mlp_out = SwiGLUMLP(normed_mlp)
12
+ out = h + mlp_out
13
+
14
+ Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
15
+ through unnormalised residuals at depth, and each sublayer receives a stable,
16
+ normalised view of the signal.
17
+
18
+ Two independent RMSNorm instances are used — one before attention, one before
19
+ MLP. They learn different scalings because they precede layers with different
20
+ dynamic ranges. Sharing them would be wrong.
21
+
22
+ torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
23
+ subtraction, is faster than LayerNorm, and proved more stable at scale.
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ from .__attention__shram import SHRAMHybridLayer
30
+ from .__cache__shram_layer_cache import ShramLayerCache
31
+ from .configuration import ShramConfig
32
+ from .mlp import SwiGLUMLP
33
+
34
+
35
+ class DecoderLayer(nn.Module):
36
+ """A single pre-norm SHRAM decoder block.
37
+
38
+ Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
39
+ independent RMSNorm instances on each sublayer input.
40
+
41
+ Args:
42
+ config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
43
+ in addition to the fields required by SHRAMHybridLayer and
44
+ SwiGLUMLP.
45
+ """
46
+
47
+ def __init__(self, config: ShramConfig) -> None:
48
+ super().__init__()
49
+ self.attn_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
50
+ self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
51
+ self.attention = SHRAMHybridLayer(config)
52
+ self.mlp = SwiGLUMLP(config)
53
+
54
+ def num_mosrah_parameters(self) -> int:
55
+ """Return the total number of trainable MoSRAH parameters in this decoder layer."""
56
+ return self.attention.num_mosrah_parameters()
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ position_ids: torch.Tensor,
62
+ active_mask: torch.Tensor,
63
+ cache: ShramLayerCache | None = None,
64
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
65
+ """Apply one decoder block to the input.
66
+
67
+ Args:
68
+ x: Input of shape (batch, seq_len, hidden_size).
69
+ position_ids: Authoritative positions of shape (batch, seq_len).
70
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
71
+ where True means the token is semantically live. Forwarded
72
+ unchanged to the hybrid attention layer.
73
+ cache: Optional per-layer SHRAM cache passed through to the hybrid
74
+ attention layer unchanged.
75
+
76
+ Returns:
77
+ output: Tensor of shape (batch, seq_len, hidden_size).
78
+ load_balance_loss: Scalar sparse-path load-balance loss propagated
79
+ from SHRAMHybridLayer.
80
+ max_vio: Detached scalar routing-imbalance summary. Passed through
81
+ unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
82
+ """
83
+ attn_out, load_balance_loss, max_vio = self.attention(
84
+ hidden_states=self.attn_norm(x),
85
+ position_ids=position_ids,
86
+ active_mask=active_mask,
87
+ cache=cache,
88
+ )
89
+ hidden_states = x + attn_out
90
+ output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
91
+ return output, load_balance_loss, max_vio
huggingface.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace causal-LM wrapper for SHRAM.
2
+
3
+ ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
4
+ It owns token embedding lookup, LM-head projection, wrapper-level next-token
5
+ cross-entropy loss, config-controlled tied embeddings, and generation/cache
6
+ orchestration at the wrapper boundary.
7
+
8
+ The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
9
+ hidden states together with current position IDs, a current active mask, and an
10
+ optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
11
+ or causal-LM loss.
12
+
13
+ HuggingFace generation reaches this wrapper with two different tensor
14
+ conventions:
15
+
16
+ - ``position_ids`` is a current-step tensor. GenerationMixin updates the total
17
+ sequence state between steps, then slices position-bearing tensors back down
18
+ before calling ``forward()``.
19
+ - ``attention_mask`` is a full 2D mask over the total sequence so far. This
20
+ wrapper slices its recent chunk to produce the current semantic liveness mask
21
+ expected by the backbone.
22
+
23
+ Generation-created caches are handled in ``_prepare_cache_for_generation``.
24
+ That hook ensures HuggingFace generation uses ShramCache rather than a generic
25
+ dynamic cache. The direct ``forward()`` path does not silently create caches;
26
+ when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
27
+ """
28
+
29
+ ### manifest
30
+ # HuggingFace's check_imports only resolves one level of imports from the
31
+ # entry point file. Without the imports below, any module not directly
32
+ # imported here would be missing when loading from a local path via
33
+ # from_pretrained. These imports are intentionally non-functional (aliased
34
+ # to _ ) to make clear they exist solely for dependency resolution.
35
+ #### imports
36
+ from . import __attention__bottlenecked_ensemble_attention as _
37
+ from . import __attention__expert_packing as _
38
+ from . import __attention__load_balance_loss as _
39
+ from . import __attention__mosrah as _
40
+ from . import __attention__positions_converter as _
41
+ from . import __attention__router as _
42
+ from . import __attention__shram as _
43
+ from . import __attention__sliding_window_attention as _
44
+ from . import __cache__mosrah_cache as _
45
+ from . import __cache__shram_cache as _
46
+ from . import __cache__shram_layer_cache as _
47
+ from . import __cache__sliding_window_cache as _
48
+ from . import __cache__slow_mosrah_cache as _
49
+ from . import __init__ as _
50
+ from . import configuration as _
51
+ from . import decoder_layer as _
52
+ from . import mlp as _
53
+ from . import model as _
54
+ from . import rope as _
55
+ # end manifest
56
+
57
+ from dataclasses import dataclass
58
+ from typing import Any
59
+
60
+ import torch
61
+ import torch.nn as nn
62
+ from transformers import GenerationMixin, PreTrainedModel
63
+ from transformers.cache_utils import Cache
64
+ from transformers.generation.configuration_utils import GenerationMode
65
+ from transformers.modeling_outputs import CausalLMOutputWithPast
66
+
67
+ from .__cache__shram_cache import ShramCache
68
+ from .configuration import ShramConfig
69
+ from .model import ShramModel
70
+
71
+
72
+ @dataclass
73
+ class ShramCausalLMOutput(CausalLMOutputWithPast):
74
+ """SHRAM causal-LM wrapper output.
75
+
76
+ This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
77
+ Dataclass inheritance is sufficient here: all standard causal-LM fields and
78
+ ModelOutput behavior are inherited from the parent, and this subclass adds
79
+ only the SHRAM-specific wrapper outputs.
80
+ """
81
+
82
+ ce_loss: torch.FloatTensor | None = None
83
+ load_balance_loss: torch.FloatTensor | None = None
84
+ max_vio: torch.FloatTensor | None = None
85
+
86
+
87
+ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
88
+ """HuggingFace-facing causal language model wrapper for SHRAM.
89
+
90
+ Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
91
+ tied embedding configuration, and generation/cache boundary behavior.
92
+ Delegates all transformer computation to ``ShramModel``.
93
+
94
+ Args:
95
+ config: SHRAM model configuration.
96
+ """
97
+
98
+ config_class = ShramConfig
99
+ base_model_prefix = "model"
100
+ _no_split_modules = ["DecoderLayer"]
101
+ supports_gradient_checkpointing = True
102
+
103
+ def __init__(self, config: ShramConfig) -> None:
104
+ super().__init__(config)
105
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_width)
106
+ self.model = ShramModel(config)
107
+ self.lm_head = nn.Linear(config.embedding_width, config.vocab_size, bias=False)
108
+ self._configure_tied_embeddings()
109
+ self.post_init()
110
+
111
+ def _configure_tied_embeddings(self) -> None:
112
+ """Apply config-controlled tied embedding behavior on this instance."""
113
+ if self.config.tie_word_embeddings:
114
+ self.lm_head.weight = self.embed_tokens.weight
115
+ self._tied_weights_keys = {
116
+ "lm_head.weight": "embed_tokens.weight",
117
+ }
118
+ else:
119
+ self._tied_weights_keys = {}
120
+
121
+ def num_mosrah_parameters(self) -> int:
122
+ """Return the total number of trainable parameters belonging to MoSRAH layers.
123
+
124
+ Aggregates across all decoder layers. Excludes sliding-window path parameters,
125
+ FFN parameters, norms, and embeddings. Use this for experimental plotting of
126
+ MoSRAH parameter count versus performance.
127
+
128
+ Returns:
129
+ Total count of trainable MoSRAH parameters.
130
+ """
131
+ return self.model.num_mosrah_parameters()
132
+
133
+ def get_input_embeddings(self) -> nn.Embedding:
134
+ """Return the token embedding matrix."""
135
+ return self.embed_tokens
136
+
137
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
138
+ """Replace the token embedding matrix."""
139
+ self.embed_tokens = value
140
+ self._configure_tied_embeddings()
141
+
142
+ def get_output_embeddings(self) -> nn.Linear:
143
+ """Return the LM head."""
144
+ return self.lm_head
145
+
146
+ def set_output_embeddings(self, value: nn.Linear) -> None:
147
+ """Replace the LM head."""
148
+ self.lm_head = value
149
+ self._configure_tied_embeddings()
150
+
151
+ def _build_shram_cache(
152
+ self,
153
+ batch_size: int,
154
+ device: torch.device,
155
+ ) -> ShramCache:
156
+ """Construct a fresh top-level SHRAM cache."""
157
+ return ShramCache(
158
+ config=self.config,
159
+ batch_size=batch_size,
160
+ device=device,
161
+ )
162
+
163
+ def _validate_generation_cache_request(
164
+ self,
165
+ generation_config: Any,
166
+ model_kwargs: dict[str, Any],
167
+ generation_mode: GenerationMode,
168
+ ) -> None:
169
+ """Validate SHRAM's generation-side cache policy."""
170
+ if generation_mode in {
171
+ GenerationMode.ASSISTED_GENERATION,
172
+ GenerationMode.CONTRASTIVE_SEARCH,
173
+ }:
174
+ raise NotImplementedError(
175
+ "ShramForCausalLM does not currently support assisted generation "
176
+ "or contrastive search because ShramCache does not support crop()."
177
+ )
178
+
179
+ user_defined_cache = model_kwargs.get("past_key_values")
180
+ if user_defined_cache is not None:
181
+ if generation_config.cache_implementation is not None:
182
+ raise ValueError(
183
+ "Passing both `cache_implementation` and `past_key_values` "
184
+ "is unsupported. Please use only one."
185
+ )
186
+ if isinstance(user_defined_cache, tuple):
187
+ raise ValueError(
188
+ "Passing a tuple of `past_key_values` is not supported. "
189
+ "Please use a `ShramCache` instance."
190
+ )
191
+ if not isinstance(user_defined_cache, ShramCache):
192
+ raise TypeError(
193
+ "ShramForCausalLM requires `past_key_values` to be a "
194
+ "`ShramCache` instance."
195
+ )
196
+
197
+ if (
198
+ user_defined_cache is None
199
+ and generation_config.use_cache
200
+ and generation_config.cache_implementation is not None
201
+ ):
202
+ raise ValueError(
203
+ "ShramForCausalLM does not support `cache_implementation`. "
204
+ "Generation-created caches must be `ShramCache` objects."
205
+ )
206
+
207
+ def _prepare_cache_for_generation(
208
+ self,
209
+ generation_config: Any,
210
+ model_kwargs: dict[str, Any],
211
+ generation_mode: GenerationMode,
212
+ batch_size: int,
213
+ max_cache_length: int,
214
+ ) -> None:
215
+ """Ensure HuggingFace generation uses ShramCache.
216
+
217
+ This is the SHRAM-specific generation hook. The rest of the default
218
+ generation plumbing is kept intact as much as possible.
219
+
220
+ Args:
221
+ generation_config: Active generation configuration.
222
+ model_kwargs: Generation kwargs, updated in place.
223
+ generation_mode: HuggingFace generation mode.
224
+ batch_size: Effective generation batch size.
225
+ max_cache_length: Requested cache length. Accepted but unused here.
226
+ """
227
+ self._validate_generation_cache_request(
228
+ generation_config=generation_config,
229
+ model_kwargs=model_kwargs,
230
+ generation_mode=generation_mode,
231
+ )
232
+
233
+ if model_kwargs.get("past_key_values") is not None:
234
+ return
235
+
236
+ if not generation_config.use_cache:
237
+ return
238
+
239
+ num_repeats = max(
240
+ generation_config.num_beams or 1,
241
+ generation_config.num_return_sequences or 1,
242
+ )
243
+ model_kwargs["past_key_values"] = self._build_shram_cache(
244
+ batch_size=batch_size*num_repeats,
245
+ device=self.embed_tokens.weight.device,
246
+ )
247
+
248
+ def _reorder_cache(
249
+ self,
250
+ past_key_values: Cache,
251
+ beam_idx: torch.Tensor,
252
+ ) -> Cache:
253
+ """Reorder the cache in place for beam search."""
254
+ past_key_values.reorder_cache(beam_idx)
255
+ return past_key_values
256
+
257
+ @staticmethod
258
+ def create_masks_for_generate(
259
+ config: Any,
260
+ inputs_embeds: torch.Tensor,
261
+ attention_mask: torch.Tensor | None,
262
+ past_key_values: Cache | None,
263
+ position_ids: torch.Tensor | None = None,
264
+ **kwargs: Any,
265
+ ) -> torch.Tensor | None:
266
+ """Return the 2D attention_mask unchanged.
267
+
268
+ HuggingFace calls this during compiled generation to convert the 2D
269
+ attention mask into a 4D causal additive-bias mask. SHRAM uses flex
270
+ attention with custom masking and constructs causality internally; the
271
+ 4D format is incompatible with the SHRAM masking contract. Overriding
272
+ as a no-op restores symmetry between compiled and non-compiled pathways
273
+ without any loss of correctness or performance (see Unit 19.G.4).
274
+ """
275
+ return attention_mask
276
+
277
+ def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
278
+ """Validate token IDs at the wrapper boundary."""
279
+ if input_ids.ndim != 2:
280
+ raise ValueError("input_ids must have shape (batch, seq_len).")
281
+ if input_ids.shape[1] == 0:
282
+ raise ValueError("input_ids sequence length must be nonzero.")
283
+ if input_ids.dtype != torch.long:
284
+ raise TypeError("input_ids must be an long int tensor.")
285
+
286
+ def _validate_attention_mask(
287
+ self,
288
+ input_ids: torch.Tensor,
289
+ attention_mask: torch.Tensor | None,
290
+ ) -> None:
291
+ """Validate the full-sequence attention mask."""
292
+ if attention_mask is None:
293
+ return
294
+ if attention_mask.ndim != 2:
295
+ raise ValueError("attention_mask must have shape (batch, total_seq_len).")
296
+ if attention_mask.shape[0] != input_ids.shape[0]:
297
+ raise ValueError("attention_mask batch dimension must match input_ids.")
298
+ if attention_mask.shape[1] < input_ids.shape[1]:
299
+ raise ValueError(
300
+ "attention_mask must be at least as long as the current input_ids chunk."
301
+ )
302
+
303
+ def _validate_position_ids(
304
+ self,
305
+ input_ids: torch.Tensor,
306
+ position_ids: torch.Tensor | None,
307
+ ) -> None:
308
+ """Validate current-step position IDs."""
309
+ if position_ids is None:
310
+ return
311
+ if position_ids.ndim != 2:
312
+ raise ValueError("position_ids must have shape (batch, seq_len).")
313
+ if position_ids.shape != input_ids.shape:
314
+ raise ValueError(
315
+ "position_ids must match the current input_ids shape exactly."
316
+ )
317
+ if input_ids.dtype != torch.long:
318
+ raise TypeError("position_ids must be an long tensor.")
319
+
320
+ def _validate_labels(
321
+ self,
322
+ input_ids: torch.Tensor,
323
+ labels: torch.Tensor | None,
324
+ ) -> None:
325
+ """Validate label shape at the wrapper boundary."""
326
+ if labels is None:
327
+ return
328
+ if labels.ndim != 2:
329
+ raise ValueError("labels must have shape (batch, seq_len).")
330
+ if labels.shape != input_ids.shape:
331
+ raise ValueError("labels must have the same shape as input_ids.")
332
+ if input_ids.dtype != torch.long:
333
+ raise TypeError("labels must be a long tensor.")
334
+
335
+ def _validate_cache_inputs(
336
+ self,
337
+ use_cache: bool,
338
+ past_key_values: Cache | None,
339
+ ) -> None:
340
+ """Validate cache policy for direct wrapper calls."""
341
+ if use_cache:
342
+ if past_key_values is None:
343
+ raise ValueError(
344
+ "use_cache=True requires an explicit ShramCache. During "
345
+ "generate(), HuggingFace should supply this through "
346
+ "_prepare_cache_for_generation()."
347
+ )
348
+ if not isinstance(past_key_values, ShramCache):
349
+ raise TypeError(
350
+ "past_key_values must be a ShramCache when use_cache=True."
351
+ )
352
+ return
353
+
354
+ if past_key_values is not None:
355
+ raise ValueError("past_key_values was provided while use_cache=False.")
356
+
357
+ def _validate_position_sources(
358
+ self,
359
+ use_cache: bool,
360
+ attention_mask: torch.Tensor | None,
361
+ position_ids: torch.Tensor | None,
362
+ ) -> None:
363
+ """Validate that cached forward has a truthful source of positions."""
364
+ if use_cache and attention_mask is None and position_ids is None:
365
+ raise ValueError(
366
+ "Cached forward requires either position_ids or attention_mask."
367
+ )
368
+
369
+ def _validate_hf_boundary(
370
+ self,
371
+ output_attentions: bool | None,
372
+ return_dict: bool | None,
373
+ inputs_embeds: torch.Tensor | None,
374
+ cache_position: torch.Tensor | None,
375
+ extra_kwargs: dict[str, Any],
376
+ ) -> None:
377
+ """Validate unsupported HuggingFace-facing wrapper inputs."""
378
+ if output_attentions:
379
+ raise NotImplementedError(
380
+ "ShramForCausalLM does not expose output_attentions."
381
+ )
382
+ if return_dict is False:
383
+ raise ValueError(
384
+ "return_dict=False is not supported. "
385
+ "ShramForCausalLM always returns ShramCausalLMOutput."
386
+ )
387
+ if inputs_embeds is not None:
388
+ raise ValueError(
389
+ "inputs_embeds is not supported at the SHRAM wrapper boundary. "
390
+ "Pass input_ids instead."
391
+ )
392
+ if extra_kwargs:
393
+ unsupported = ", ".join(sorted(extra_kwargs))
394
+ raise TypeError(
395
+ f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
396
+ )
397
+
398
+ @staticmethod
399
+ def _enforce_uncached_starting_position(condition: torch.Tensor) -> None:
400
+ """Enforce that an uncached forward pass begins at position 0.
401
+
402
+ An uncached forward has no prior KV state. Nonzero starting positions
403
+ produce silently incorrect RoPE encoding and attention outputs with no
404
+ downstream diagnostic. This method intercepts that misuse at the
405
+ outermost boundary before any backbone computation runs.
406
+
407
+ To resolve a violation: either supply a ShramCache populated with the
408
+ prefix (for continued decoding), or rebase the sequence so positions
409
+ start at 0.
410
+
411
+ Args:
412
+ condition: Scalar bool tensor. True = all batch items start at 0
413
+ (valid); False = at least one batch item starts nonzero
414
+ (violated).
415
+ """
416
+ if torch.compiler.is_compiling():
417
+ # bool.item() is not captured as a SymBool by dynamo; converting to
418
+ # int first produces a SymInt, and the Python comparison (!=0) then
419
+ # yields a SymBool that torch._check folds into the compiled graph.
420
+ condition_as_int = condition.to(torch.int).item()
421
+ torch._check(condition_as_int != 0)
422
+ else:
423
+ if not condition.item():
424
+ raise RuntimeError(
425
+ "Uncached ShramForCausalLM forward does not support nonzero "
426
+ "starting positions. Either provide a ShramCache populated "
427
+ "with the prefix for continued decoding, or rebase the "
428
+ "uncached sequence to start at 0.",
429
+ )
430
+
431
+ @staticmethod
432
+ def _enforce_capture_scalar_outputs() -> None:
433
+ """Enforce that capture_scalar_outputs is enabled when compiling.
434
+
435
+ The safety checks in this model (e.g. position-zero constraint, packing
436
+ overflow detection) rely on torch._check folding into the compiled graph,
437
+ which requires torch._dynamo.config.capture_scalar_outputs = True. Without
438
+ it those checks are silently absent in the compiled model while appearing
439
+ to work in eager mode — a misconfiguration with no diagnostic output.
440
+
441
+ This method fires during dynamo tracing so the missing flag is surfaced
442
+ immediately at compile time rather than discovered from downstream failures.
443
+ """
444
+ if torch.compiler.is_compiling():
445
+ torch._check(
446
+ torch._dynamo.config.capture_scalar_outputs,
447
+ lambda: RuntimeError(
448
+ "ShramForCausalLM requires torch._dynamo.config.capture_scalar_outputs = True "
449
+ "when compiled. Without it, runtime safety checks (position constraints, "
450
+ "overflow detection) are silently absent in the compiled model. Set the flag "
451
+ "before calling torch.compile()."
452
+ ),
453
+ )
454
+
455
+ def _standardize_full_attention_mask(
456
+ self,
457
+ input_ids: torch.Tensor,
458
+ attention_mask: torch.Tensor | None,
459
+ ) -> torch.BoolTensor:
460
+ """Return a concrete full-sequence boolean attention mask."""
461
+ if attention_mask is None:
462
+ return torch.ones_like(input_ids, dtype=torch.bool)
463
+ return attention_mask.to(dtype=torch.bool)
464
+
465
+ def _resolve_current_position_ids(
466
+ self,
467
+ input_ids: torch.Tensor,
468
+ position_ids: torch.Tensor | None,
469
+ full_attention_mask: torch.BoolTensor,
470
+ ) -> torch.LongTensor:
471
+ """Resolve concrete current-step position IDs for the backbone."""
472
+ if position_ids is not None:
473
+ return position_ids.to(dtype=torch.long)
474
+
475
+ full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
476
+ full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
477
+ current_length = input_ids.shape[1]
478
+ return full_position_ids[:, -current_length:]
479
+
480
+ def forward(
481
+ self,
482
+ input_ids: torch.Tensor,
483
+ attention_mask: torch.Tensor | None = None,
484
+ position_ids: torch.Tensor | None = None,
485
+ past_key_values: Cache | None = None,
486
+ use_cache: bool | None = None,
487
+ output_hidden_states: bool | None = None,
488
+ labels: torch.Tensor | None = None,
489
+ return_dict: bool | None = None,
490
+ ce_weight: float = 1.0,
491
+ load_balance_weight: float = 0.01,
492
+ **kwargs: Any,
493
+ ) -> ShramCausalLMOutput:
494
+ """Run the SHRAM causal language model wrapper.
495
+
496
+ Args:
497
+ input_ids: Current token IDs of shape ``(batch, seq_len)``.
498
+ attention_mask: Optional full 2D mask of shape
499
+ ``(batch, total_seq_len)``. The wrapper slices its recent chunk
500
+ to produce the current semantic liveness mask expected by the
501
+ backbone.
502
+ position_ids: Optional current-step position IDs of shape
503
+ ``(batch, seq_len)``. In ordinary HuggingFace generation this is
504
+ already the current-step tensor when it reaches ``forward()``.
505
+ past_key_values: Optional SHRAM cache. Required when
506
+ ``use_cache=True``.
507
+ use_cache: Whether to use and return a cache. Defaults to
508
+ ``config.use_cache``.
509
+ output_hidden_states: Whether to return backbone hidden states.
510
+ Defaults to ``config.output_hidden_states``.
511
+ labels: Optional target token IDs of shape ``(batch, seq_len)``.
512
+ return_dict: Must be ``True`` or ``None``.
513
+ ce_weight: Weight applied to the cross-entropy loss when combining with
514
+ the load-balance loss. Default 1.0.
515
+ load_balance_weight: Weight applied to the load-balance auxiliary loss.
516
+ Default 0.01, matching the paper's recommendation.
517
+ **kwargs: Unsupported HuggingFace kwargs fail explicitly.
518
+
519
+ Returns:
520
+ ``ShramCausalLMOutput`` with:
521
+ - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
522
+ - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
523
+ when labels are provided (``None`` otherwise),
524
+ - ``ce_loss`` — raw unweighted cross-entropy loss for logging,
525
+ - ``past_key_values`` as the active ``ShramCache`` or ``None``,
526
+ - ``hidden_states`` when requested,
527
+ - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
528
+ - detached ``max_vio`` from the backbone.
529
+ """
530
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
531
+ output_hidden_states = (
532
+ output_hidden_states
533
+ if output_hidden_states is not None
534
+ else self.config.output_hidden_states
535
+ )
536
+
537
+ inputs_embeds = kwargs.pop("inputs_embeds", None)
538
+ output_attentions = kwargs.pop("output_attentions", None)
539
+ cache_position = kwargs.pop("cache_position", None)
540
+
541
+ # ------------------------------------------------------------------
542
+ # Validation zone.
543
+ #
544
+ # The wrapper boundary is where HuggingFace-facing inputs are judged
545
+ # for truthfulness before any internal work begins. These checks are
546
+ # intentionally front-loaded so the core logic below can assume one
547
+ # coherent interpretation of the call rather than defensively checking
548
+ # shapes, cache policy, or unsupported HF knobs at the point of use.
549
+ # This keeps the main sequence readable while ensuring invalid states
550
+ # fail before they can silently contaminate backbone execution.
551
+ # ------------------------------------------------------------------
552
+ self._enforce_capture_scalar_outputs()
553
+ self._validate_input_ids(input_ids)
554
+ self._validate_attention_mask(input_ids, attention_mask)
555
+ self._validate_position_ids(input_ids, position_ids)
556
+ self._validate_labels(input_ids, labels)
557
+ self._validate_cache_inputs(use_cache, past_key_values)
558
+ self._validate_position_sources(use_cache, attention_mask, position_ids)
559
+ self._validate_hf_boundary(
560
+ output_attentions=output_attentions,
561
+ return_dict=return_dict,
562
+ inputs_embeds=inputs_embeds,
563
+ cache_position=cache_position,
564
+ extra_kwargs=kwargs,
565
+ )
566
+
567
+ # ------------------------------------------------------------------
568
+ # Standardization zone.
569
+ #
570
+ # HuggingFace and SHRAM use different boundary conventions: generation
571
+ # carries a full-sequence 2D attention mask, while the SHRAM backbone
572
+ # wants a current-step active mask and concrete current position IDs.
573
+ # This zone collapses those wrapper-facing conventions into one valid
574
+ # backbone-facing state. After this point the core no longer reasons
575
+ # about optional or ambiguous input forms; it works only with concrete
576
+ # tensors whose semantics are already fixed.
577
+ # ------------------------------------------------------------------
578
+ full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
579
+ input_ids=input_ids,
580
+ attention_mask=attention_mask,
581
+ )
582
+ current_length: int = input_ids.shape[1]
583
+ current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
584
+ current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
585
+ input_ids=input_ids,
586
+ position_ids=position_ids,
587
+ full_attention_mask=full_attention_mask,
588
+ )
589
+ shram_cache: ShramCache | None = past_key_values if use_cache else None
590
+
591
+ if shram_cache is None:
592
+ positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
593
+ self._enforce_uncached_starting_position(positions_start_sane)
594
+
595
+ # ------------------------------------------------------------------
596
+ # Core wrapper responsibilities.
597
+ #
598
+ # The wrapper's primary job is kept visible here: convert token IDs to
599
+ # embeddings, delegate transformer computation to ShramModel, project
600
+ # hidden states back to vocabulary logits, optionally compute the
601
+ # wrapper-level shifted next-token loss, and return the HuggingFace-
602
+ # facing output object. The backbone remains responsible only for
603
+ # transformer semantics; token/vocabulary/loss concerns stay here.
604
+ # ------------------------------------------------------------------
605
+ token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
606
+ backbone_outputs = self.model(
607
+ inputs_embeds=token_embeddings,
608
+ position_ids=current_position_ids,
609
+ active_mask=current_active_mask,
610
+ cache=shram_cache,
611
+ output_hidden_states=output_hidden_states,
612
+ )
613
+
614
+ logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])
615
+
616
+ ce_loss: torch.FloatTensor | None = None
617
+ loss: torch.FloatTensor | None = None
618
+ if labels is not None:
619
+ shift_logits = logits[:, :-1, :].contiguous()
620
+ shift_labels = labels[:, 1:].contiguous()
621
+ ce_loss = nn.functional.cross_entropy(
622
+ shift_logits.view(-1, self.config.vocab_size),
623
+ shift_labels.view(-1),
624
+ )
625
+ loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"]
626
+
627
+ return ShramCausalLMOutput(
628
+ loss=loss,
629
+ ce_loss=ce_loss,
630
+ logits=logits,
631
+ past_key_values=backbone_outputs["past_key_values"],
632
+ hidden_states=backbone_outputs["hidden_states"],
633
+ load_balance_loss=backbone_outputs["load_balance_loss"],
634
+ max_vio=backbone_outputs["max_vio"],
635
+ )
mlp.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SwiGLU feed-forward sublayer.
2
+
3
+ SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
4
+ element-wise against a separate up-projection:
5
+
6
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
7
+
8
+ The gating mechanism gives the network more expressive control over which features
9
+ to propagate than a plain two-matrix FFN. It requires three weight matrices instead
10
+ of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
11
+ typical of two-matrix FFNs — the total parameter count remains comparable.
12
+
13
+ SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
14
+ — a fixed architectural choice.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig
21
+
22
+
23
+ class SwiGLUMLP(nn.Module):
24
+ """SwiGLU feed-forward sublayer.
25
+
26
+ Implements the three-matrix SwiGLU FFN used in Llama 3:
27
+
28
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
29
+
30
+ No bias on any projection. SiLU as the gate activation is an architectural
31
+ constant — it is what defines SwiGLU specifically.
32
+
33
+ Args:
34
+ config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
35
+ """
36
+
37
+ def __init__(self, config: PretrainedConfig) -> None:
38
+ super().__init__()
39
+ self.gate_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False)
40
+ self.up_proj = nn.Linear(config.embedding_width, config.mlp_width, bias=False)
41
+ self.down_proj = nn.Linear(config.mlp_width, config.embedding_width, bias=False)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ """Apply the SwiGLU feed-forward transformation.
45
+
46
+ Args:
47
+ x: Input tensor of shape (batch, seq_len, hidden_size).
48
+
49
+ Returns:
50
+ Output tensor of shape (batch, seq_len, hidden_size).
51
+ """
52
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
model.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer backbone for Shram.
2
+
3
+ ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
4
+ by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
5
+ representations. It has no knowledge of tokens, vocabulary, generation, or the
6
+ HuggingFace causal-LM wrapper contract.
7
+
8
+ Keeping the embedding out of the backbone is the correct convention and makes
9
+ the backbone genuinely modality-agnostic. The token interface — embedding lookup,
10
+ LM head, weight tying, and generation-facing naming conventions — belongs on the
11
+ task wrapper (ShramForCausalLM), which is the only class that knows this
12
+ backbone is being used for language modelling.
13
+
14
+ The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
15
+ each sublayer normalises its own input, leaving the residual stream itself
16
+ unnormalised. After many layers of accumulated residuals, that stream arrives at
17
+ the top with uncontrolled magnitude. The final norm brings it to a well-scaled
18
+ state before any projection. Without it, the LM head would receive signals of
19
+ arbitrary scale.
20
+
21
+ Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
22
+ corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
23
+ same top-level ShramCache object in the output dict. If None is provided, no
24
+ caching occurs.
25
+
26
+ Returns a plain dict with keys:
27
+ - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
28
+ - "past_key_values": the ShramCache object passed in, or None
29
+ - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
30
+ - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
31
+ - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+
37
+ from .__cache__shram_cache import ShramCache
38
+ from .configuration import ShramConfig
39
+ from .decoder_layer import DecoderLayer
40
+
41
+
42
+ class ShramModel(nn.Module):
43
+ """Pure transformer backbone: decoder stack and final normalisation.
44
+
45
+ Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size)
46
+ and returns contextual representations of the same shape. No token embedding,
47
+ vocabulary projection, or causal-LM lifecycle concerns.
48
+
49
+ RoPE is applied inside each attention layer. Positional information is
50
+ encoded in the relationship between Q and K, not added to the residual
51
+ stream, so the backbone is agnostic to how positions are represented.
52
+
53
+ Args:
54
+ config: Model configuration. Must be a ``ShramConfig`` instance.
55
+ """
56
+
57
+ def __init__(self, config: ShramConfig) -> None:
58
+ super().__init__()
59
+ self.config = config
60
+ self.layers = nn.ModuleList(
61
+ [DecoderLayer(config) for _ in range(config.num_decoder_layers)]
62
+ )
63
+ self.norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
64
+
65
+ def num_mosrah_parameters(self) -> int:
66
+ """Return the total number of trainable MoSRAH parameters across all decoder layers."""
67
+ return sum(layer.num_mosrah_parameters() for layer in self.layers)
68
+
69
+ def forward(
70
+ self,
71
+ inputs_embeds: torch.Tensor,
72
+ position_ids: torch.Tensor,
73
+ active_mask: torch.Tensor,
74
+ cache: ShramCache | None = None,
75
+ output_hidden_states: bool = False,
76
+ ) -> dict:
77
+ """Run the transformer stack over a batch of pre-embedded sequences.
78
+
79
+ Args:
80
+ inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
81
+ position_ids: Absolute positions of shape (batch, seq_len). Required.
82
+ Must be provided explicitly by the caller — this module does not
83
+ infer positions from cache state.
84
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
85
+ where True means the token is semantically live. Forwarded
86
+ unchanged to every decoder layer.
87
+ cache: Optional top-level ShramCache. When provided, each DecoderLayer
88
+ receives its own layer-local cache via ``cache.layers[layer_idx]``.
89
+ The top-level cache object is updated in place and returned unchanged.
90
+ output_hidden_states: When True, the output dict includes a tuple of
91
+ per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out),
92
+ collected before the final norm.
93
+
94
+ Returns:
95
+ Plain dict with keys:
96
+ - ``"last_hidden_state"``: normed backbone output,
97
+ shape (batch, seq_len, hidden_size).
98
+ - ``"past_key_values"``: the cache object passed in, or None.
99
+ - ``"hidden_states"``: tuple of per-layer activations (including
100
+ inputs_embeds as position 0) if ``output_hidden_states`` is True,
101
+ else None. Collected before the final norm so each entry reflects the
102
+ unnormalised residual stream at that depth.
103
+ - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
104
+ load-balance losses.
105
+ - ``"max_vio"``: detached scalar maximum routing-imbalance across
106
+ all decoder layers. Zero means perfectly balanced routing across
107
+ every layer; higher values identify the worst-case head imbalance.
108
+ """
109
+ hidden_states = inputs_embeds
110
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
111
+ total_load_balance_loss = inputs_embeds.new_zeros(())
112
+ max_vio = inputs_embeds.new_zeros(())
113
+
114
+ for layer_idx, layer in enumerate(self.layers):
115
+ layer_cache = None if cache is None else cache.layers[layer_idx]
116
+ hidden_states, layer_load_balance_loss, layer_max_vio = layer(
117
+ hidden_states,
118
+ position_ids,
119
+ active_mask,
120
+ cache=layer_cache,
121
+ )
122
+ total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
123
+ max_vio = torch.maximum(max_vio, layer_max_vio)
124
+
125
+ if output_hidden_states:
126
+ all_hidden_states = all_hidden_states + (hidden_states,)
127
+
128
+ hidden_states = self.norm(hidden_states)
129
+
130
+ return {
131
+ "last_hidden_state": hidden_states,
132
+ "past_key_values": cache,
133
+ "hidden_states": all_hidden_states,
134
+ "load_balance_loss": total_load_balance_loss,
135
+ "max_vio": max_vio,
136
+ }
rope.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Rotary Position Embeddings (RoPE).
2
+
3
+ RoPE encodes position in the *relationship* between query and key vectors. When the
4
+ attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
5
+ a score that depends only on the relative distance — not on absolute positions.
6
+
7
+ Two modes are supported:
8
+
9
+ default Standard RoPE with base frequency b. Each dimension pair d is assigned
10
+ frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
11
+ scaling A_rope = 1.
12
+
13
+ yarn YaRN frequency interpolation for long-context extrapolation (Peng et al.,
14
+ "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
15
+ §A.2). Three frequency regimes:
16
+ - Low-frequency dimensions (r < α): fully interpolated by scale s.
17
+ These dimensions have long wavelengths relative to the training window
18
+ and must be compressed to avoid out-of-distribution positions.
19
+ - High-frequency dimensions (r > β): left unchanged. Short-wavelength
20
+ dimensions already encode relative position accurately at any scale.
21
+ - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
22
+ Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
23
+ standard RoPE.
24
+
25
+ Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
26
+ parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
+
28
+ Cache sharing: all instances with identical parameters share one cos/sin table via a
29
+ class-level registry. The first instance that needs a particular (parameters, device,
30
+ dtype) combination builds the table; all subsequent instances reference it directly.
31
+ This avoids redundant builds across the num_hidden_layers instances that share the
32
+ same parametrisation.
33
+ """
34
+
35
+ import math
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Rotation helper
43
+ # ---------------------------------------------------------------------------
44
+
45
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
+ """Apply the 90° rotation used in the RoPE update formula.
47
+
48
+ Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
49
+ Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
50
+ on each consecutive pair of dimensions, matching the block-diagonal operator
51
+ R^u_{Θ,p} in the paper.
52
+ """
53
+ d = x.shape[-1] // 2
54
+ x1, x2 = x[..., :d], x[..., d:]
55
+ return torch.cat([-x2, x1], dim=-1)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # RotaryEmbedding
60
+ # ---------------------------------------------------------------------------
61
+
62
+ class RotaryEmbedding(nn.Module):
63
+ """Rotary Position Embeddings with explicit mode and parameter control.
64
+
65
+ Each caller constructs its own instance with the exact parameters it needs.
66
+ h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
+ config object is read inside this module.
68
+
69
+ The cos/sin table is built at construction time to cover all positions in
70
+ ``[0, maximum_sequence_length)``. In forward, the table is rebuilt only if
71
+ the query tensor's dtype or device has changed since construction.
72
+
73
+ Instances with identical parameters share one cos/sin table via the class-level
74
+ ``_cache`` registry, avoiding redundant computation across decoder layers.
75
+
76
+ Args:
77
+ mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
78
+ head_dim: Per-head embedding dimension ``u``. Must be even.
79
+ theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
80
+ maximum_sequence_length: Maximum number of positions the table must cover.
81
+ The cos/sin table is preallocated to this length at construction time.
82
+ For ``mode="yarn"``, the training context length C_train is derived
83
+ internally as ``round(maximum_sequence_length / dilation)``.
84
+ dilation: Scale factor ``s = C_target / C_train`` — how much the context
85
+ window is extended beyond training length. Required for ``mode="yarn"``.
86
+ When ``dilation=1.0``, YaRN reduces to standard RoPE.
87
+ alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
88
+ interpolated. Required for ``mode="yarn"``.
89
+ beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
90
+ unchanged. Required for ``mode="yarn"``.
91
+ device: Optional device for initial buffer placement.
92
+
93
+ Raises:
94
+ NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
95
+ ValueError: If ``mode="yarn"`` and any of ``dilation``, ``alpha``,
96
+ ``beta`` are absent.
97
+ """
98
+
99
+ # Maps (freq_key, device_str, dtype_str) → (cos_table, sin_table).
100
+ # Shared across all RotaryEmbedding instances in the process. Keys include device
101
+ # and dtype so that tables built on different devices or in different precisions
102
+ # are stored independently.
103
+ _cache: dict = {}
104
+
105
+ def __init__(
106
+ self,
107
+ mode: str,
108
+ head_dim: int,
109
+ theta: float,
110
+ maximum_sequence_length: int,
111
+ dilation: float | None = None,
112
+ alpha: float | None = None,
113
+ beta: float | None = None,
114
+ device: torch.device | None = None,
115
+ ) -> None:
116
+ super().__init__()
117
+
118
+ self._validate_mode(mode)
119
+ self._validate_yarn_params(mode, dilation, alpha, beta)
120
+ self.mode = mode
121
+ self._maximum_sequence_length = maximum_sequence_length
122
+
123
+ # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
124
+ # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
125
+ # so rotation_freqs has head_dim/2 entries.
126
+ d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
127
+ base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u}
128
+
129
+ if mode == "default":
130
+ rotation_freqs = base_freqs
131
+ self.attention_scaling: float = 1.0
132
+
133
+ else: # yarn
134
+ s = dilation
135
+
136
+ # C_train is the training context length, recovered from the inference
137
+ # context length and the dilation factor. round() guards against floating
138
+ # point error since both underlying quantities are integers.
139
+ c_train: int = round(maximum_sequence_length / dilation)
140
+
141
+ # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
142
+ # function to classify each dimension into one of three regimes.
143
+ normalized_freqs = c_train * base_freqs / (2.0 * math.pi)
144
+
145
+ # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
146
+ # linear blend between α and β.
147
+ blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
148
+
149
+ # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
150
+ rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
151
+
152
+ # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
153
+ self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
154
+
155
+ # freq_key uniquely identifies the parameter set that produced rotation_freqs,
156
+ # including maximum_sequence_length so instances with different table sizes
157
+ # do not collide in the registry.
158
+ if mode == "default":
159
+ self._freq_key: tuple = ("default", head_dim, theta, maximum_sequence_length)
160
+ else:
161
+ self._freq_key = ("yarn", head_dim, theta, maximum_sequence_length, dilation, alpha, beta)
162
+
163
+ # rotation_freqs is a non-persistent buffer so it moves with the model across
164
+ # devices via .to() / .cuda() without appearing in saved checkpoints.
165
+ # It is stored per-instance rather than in the shared cache because it is
166
+ # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
167
+ # it is used to build. The meaningful sharing win is on those tables.
168
+ self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)
169
+
170
+ # Cache tensors are plain instance attributes (not registered buffers) so that
171
+ # sharing across identically-parametrised instances survives .to() calls.
172
+ # Registered buffers are copied on device move; plain attributes are aliased,
173
+ # preserving the shared-tensor identity that the cache design depends on.
174
+ self._cos_cached: torch.Tensor | None = None
175
+ self._sin_cached: torch.Tensor | None = None
176
+
177
+ # Build the table at construction time. Forward rebuilds only on dtype or
178
+ # device change. If no device is specified, build on CPU as the default.
179
+ build_device = device if device is not None else torch.device("cpu")
180
+ self._build_cache(device=build_device, dtype=torch.float32)
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Validation helpers
184
+ # ---------------------------------------------------------------------------
185
+
186
+ @staticmethod
187
+ def _validate_mode(mode: str) -> None:
188
+ """Raise NotImplementedError if mode is not a supported value."""
189
+ if mode not in {"default", "yarn"}:
190
+ raise NotImplementedError(
191
+ f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
192
+ )
193
+
194
+ @staticmethod
195
+ def _validate_yarn_params(
196
+ mode: str,
197
+ dilation: float | None,
198
+ alpha: float | None,
199
+ beta: float | None,
200
+ ) -> None:
201
+ """Raise ValueError if mode='yarn' and any required parameter is absent."""
202
+ if mode != "yarn":
203
+ return
204
+ missing = [
205
+ name for name, val in [
206
+ ("dilation", dilation),
207
+ ("alpha", alpha),
208
+ ("beta", beta),
209
+ ]
210
+ if val is None
211
+ ]
212
+ if missing:
213
+ raise ValueError(f"mode='yarn' requires {missing}.")
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Cache management
217
+ # ---------------------------------------------------------------------------
218
+
219
+ def _build_cache(self, device: torch.device, dtype: torch.dtype) -> None:
220
+ """Build the cos/sin table to cover positions [0, maximum_sequence_length).
221
+
222
+ Checks the class-level registry first. If a table already exists for this
223
+ exact (parameters, device, dtype) combination it is reused directly;
224
+ otherwise it is computed and stored. The instance attributes are pointed at
225
+ the registry entry so that all layers sharing the same parametrisation
226
+ reference the same tensor.
227
+ """
228
+ cache_key = (self._freq_key, str(device), str(dtype))
229
+
230
+ if cache_key not in RotaryEmbedding._cache:
231
+ positions = torch.arange(
232
+ self._maximum_sequence_length, device=device, dtype=torch.float32
233
+ )
234
+ # outer product → (maximum_sequence_length, head_dim // 2);
235
+ # duplicate to (maximum_sequence_length, head_dim)
236
+ freqs = torch.outer(
237
+ positions,
238
+ self.rotation_freqs.to(device=device, dtype=torch.float32),
239
+ )
240
+ angle_embedding = torch.cat((freqs, freqs), dim=-1)
241
+ RotaryEmbedding._cache[cache_key] = (
242
+ angle_embedding.cos().to(dtype),
243
+ angle_embedding.sin().to(dtype),
244
+ )
245
+
246
+ self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
247
+
248
+ def forward(
249
+ self,
250
+ q: torch.Tensor,
251
+ k: torch.Tensor,
252
+ position_ids: torch.Tensor,
253
+ ) -> tuple[torch.Tensor, torch.Tensor, float]:
254
+ """Apply rotary embeddings to query and key tensors.
255
+
256
+ The cos/sin table is built at construction time. It is rebuilt here only
257
+ if ``q``'s dtype or device differs from the cached table — for example,
258
+ after moving the model to a different device via ``.cuda()``.
259
+
260
+ ``position_ids`` may be any integer tensor shape. Its values must be in
261
+ ``[0, maximum_sequence_length)``:
262
+
263
+ - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
264
+ - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
265
+
266
+ When q/k have head dimensions absent from position_ids, broadcast dimensions
267
+ are inserted automatically at dim 1.
268
+
269
+ Args:
270
+ q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
271
+ k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
272
+ position_ids: Integer positions of shape (batch, *pos_dims).
273
+
274
+ Returns:
275
+ Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
276
+ 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
277
+ apply to attention logits before softmax.
278
+ """
279
+ wrong_dtype = self._cos_cached.dtype != q.dtype
280
+ wrong_device = self._cos_cached.device != q.device
281
+
282
+ if wrong_dtype or wrong_device:
283
+ self._build_cache(device=q.device, dtype=q.dtype)
284
+
285
+ cos = self._cos_cached[position_ids]
286
+ sin = self._sin_cached[position_ids]
287
+
288
+ # Insert broadcast dimensions for any head axes present in q/k but absent
289
+ # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
290
+ # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
291
+ while cos.ndim < q.ndim:
292
+ cos = cos.unsqueeze(1)
293
+ sin = sin.unsqueeze(1)
294
+
295
+ q_rotated = q * cos + _rotate_half(q) * sin
296
+ k_rotated = k * cos + _rotate_half(k) * sin
297
+
298
+ return q_rotated, k_rotated, self.attention_scaling
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "local_files_only": false,
9
+ "model_max_length": 1000000000000000019884624838656,
10
+ "pad_token": "<|padding|>",
11
+ "tokenizer_class": "GPTNeoXTokenizerFast",
12
+ "trim_offsets": true,
13
+ "unk_token": "<|endoftext|>"
14
+ }