smithblack-0 commited on
Commit
72e7455
·
verified ·
1 Parent(s): 03ca6a0

Update architecture and tokenizer

Browse files
README.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: mit
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ tags:
8
+ - pytorch
9
+ - research
10
+ - sparse-attention
11
+ - mixture-of-experts
12
+ ---
13
+
14
+ # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
+
16
+ A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
+ Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
+ the Hub and instantiate a freshly initialised model from config. Every parameter is
19
+ overridable at instantiation time via kwargs.
20
+
21
+ > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
+ > source files from the Hub and imports them into your Python process. Review the
23
+ > source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use.
24
+
25
+ ## Architecture
26
+
27
+ SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
28
+
29
+ - **h_l** — local sliding-window causal attention path.
30
+ - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
31
+ via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
32
+
33
+ All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
34
+
35
+ ## Usage
36
+
37
+ This repository contains no pretrained weights. The intended workflow is: pull the
38
+ architecture config from the Hub, instantiate a model with fresh random weights, then
39
+ train it yourself.
40
+
41
+ ```python
42
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
43
+
44
+ # Step 1: pull the architecture config from the Hub.
45
+ # AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
46
+ # Override any parameter via kwargs.
47
+ config = AutoConfig.from_pretrained(
48
+ "smithblack-0/SHRAM",
49
+ trust_remote_code=True,
50
+ num_hidden_layers=16, # example override
51
+ num_mosrah_heads=32, # example override
52
+ )
53
+
54
+ # Step 2: instantiate with fresh random weights.
55
+ # from_config never loads a checkpoint — it always produces a randomly initialised model.
56
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
57
+
58
+ # Step 3: load the tokenizer.
59
+ tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
60
+ ```
61
+
62
+ After training your own checkpoint, save and reload it in the standard way:
63
+
64
+ ```python
65
+ model.save_pretrained("./my-checkpoint")
66
+ model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
67
+ ```
68
+
69
+ ## Constructor Defaults
70
+
71
+ The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
72
+ no overrides. They are not the parameters of a pretrained model — this repository
73
+ contains no weights. All values are overridable via kwargs.
74
+
75
+ | Parameter | Default |
76
+ |-----------|---------|
77
+ | `alpha` | 1.0 |
78
+ | `attention_dropout` | 0.0 |
79
+ | `beta` | 32.0 |
80
+ | `dtype` | None |
81
+ | `head_dim` | 16 |
82
+ | `hidden_size` | 512 |
83
+ | `inference_sequence_length` | 1024 |
84
+ | `intermediate_size` | 1366 |
85
+ | `local_rope_theta` | 10000.0 |
86
+ | `mosrah_rope_theta` | 10000.0 |
87
+ | `num_hidden_layers` | 12 |
88
+ | `num_mosrah_heads` | 16 |
89
+ | `num_selected_heads` | 16 |
90
+ | `num_sliding_window_heads` | 16 |
91
+ | `output_hidden_states` | False |
92
+ | `rms_norm_eps` | 1e-05 |
93
+ | `rope_mode` | main_sequence |
94
+ | `tie_word_embeddings` | False |
95
+ | `training_sequence_length` | 1024 |
96
+ | `use_cache` | True |
97
+ | `vocab_size` | 50277 |
98
+ | `window_size` | 128 |
99
+
100
+ ## License
101
+
102
+ MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
103
+ (`EleutherAI/gpt-neox-20b`, Apache 2.0).
__attention__bottlenecked_ensemble_attention.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
+
3
+ BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
+ It consumes packed expert-choice tensors, a supplied position tensor, an active
5
+ token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
+ same packed expert-choice space expected by later unpacking.
7
+
8
+ BEA does not compute positions and does not choose packed-position semantics.
9
+ Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
+ (K̃) and raw values (V) into the sparse cache and attends against the
11
+ accumulated cached state returned by that cache.
12
+ """
13
+
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
+
20
+ from .configuration import ShramConfig
21
+ from .__cache__mosrah_cache import MoSRAHCache
22
+ from .rope import RotaryEmbedding
23
+
24
+
25
+ class BottleneckedEnsembleAttention(nn.Module):
26
+ """
27
+ Packed expert-choice attention operator for the MoSRAH sparse path.
28
+ Operates per-head independently on an ensemble of tokens.
29
+ FlexAttention saves flops on dead tokens.
30
+
31
+ Architectural properties:
32
+ - consumes packed expert-choice tensors of shape (B, L, T, d)
33
+ - uses independent per-head Q/K/V/O projection parameters
34
+ - applies YaRN-capable RoPE using supplied position_ids
35
+ - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
+ - uses a fast fused attention path
37
+ - returns outputs in the same packed expert-choice space (B, L, T, d)
38
+
39
+ Args:
40
+ config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
+ `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
+ `inference_sequence_length`, `alpha`, and `beta`.
43
+ """
44
+
45
+ def __init__(self, config: ShramConfig) -> None:
46
+ super().__init__()
47
+
48
+ self.hidden_size = config.hidden_size
49
+ self.num_heads = config.num_mosrah_heads
50
+ self.head_dim = config.head_dim
51
+
52
+ # Independent per-head projections. No cross-head parameter sharing.
53
+ self.q_proj = nn.Parameter(
54
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
+ )
56
+ self.k_proj = nn.Parameter(
57
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
+ )
59
+ self.v_proj = nn.Parameter(
60
+ torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
+ )
62
+ self.o_proj = nn.Parameter(
63
+ torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
+ )
65
+
66
+ self._reset_parameters()
67
+
68
+ # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
+ # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
+ # no yarn dilation occurs.
71
+ self.rope = RotaryEmbedding(
72
+ mode="yarn",
73
+ head_dim=self.head_dim,
74
+ theta=config.mosrah_rope_theta,
75
+ initial_seq_length=config.training_sequence_length,
76
+ dilation=config.scale,
77
+ alpha=config.alpha,
78
+ beta=config.beta,
79
+ )
80
+
81
+ def forward(
82
+ self,
83
+ packed_embeddings: torch.Tensor,
84
+ position_ids: torch.Tensor,
85
+ active_mask: torch.Tensor,
86
+ cache: MoSRAHCache | None = None,
87
+ ) -> torch.Tensor:
88
+ """Apply BEA to packed expert-choice tensors.
89
+
90
+ Args:
91
+ packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
92
+ position_ids: Supplied packed positions of shape (B, L, T).
93
+ active_mask: Boolean active-token mask of shape (B, L, T).
94
+ cache: Optional layer-local MoSRAH cache.
95
+
96
+ Returns:
97
+ Packed expert-choice output tensor of shape (B, L, T, d).
98
+ """
99
+ batch_size, _, query_length, _ = packed_embeddings.shape
100
+ self._validate_tensor_shape(packed_embeddings)
101
+ self._validate_position_shape(packed_embeddings, position_ids)
102
+ self._validate_active_mask_shape(packed_embeddings, active_mask)
103
+
104
+ # Independent per-head projections:
105
+ # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
106
+ query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
107
+ key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
108
+ value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
109
+
110
+ rotated_query_states, rotated_key_states, attention_scaling = self.rope(
111
+ query_states,
112
+ key_states,
113
+ position_ids,
114
+ )
115
+
116
+ if cache is not None:
117
+ # In cached execution, the current query tensor uses local tensor rows
118
+ # 0..Q-1, but the key tensor returned by the cache is the full accumulated
119
+ # packed sequence for each (batch, head) slot. The only additional data
120
+ # needed to align those two views is the pre-update cached prefix length.
121
+ # which will indicate how many queries were processed before now.
122
+ num_tokens_processed = cache.get_heads_lengths().clone()
123
+ key_states, value_states, key_active_mask = cache.update(
124
+ rotated_key_states,
125
+ value_states,
126
+ active_mask,
127
+ )
128
+ else:
129
+ num_tokens_processed = torch.zeros(
130
+ batch_size,
131
+ self.num_heads,
132
+ dtype=torch.long,
133
+ device=packed_embeddings.device,
134
+ )
135
+ key_states = rotated_key_states
136
+ key_active_mask = active_mask
137
+
138
+ block_mask = self._make_block_mask(
139
+ query_active_mask=active_mask,
140
+ key_active_mask=key_active_mask,
141
+ num_tokens_processed=num_tokens_processed,
142
+ query_length=query_length,
143
+ key_length=key_states.shape[2],
144
+ device=packed_embeddings.device,
145
+ )
146
+ attended_states = flex_attention(
147
+ rotated_query_states,
148
+ key_states,
149
+ value_states,
150
+ block_mask=block_mask,
151
+ scale=attention_scaling / math.sqrt(self.head_dim),
152
+ )
153
+
154
+ # Project back to model width:
155
+ # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
156
+ return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
157
+
158
+ def _reset_parameters(self) -> None:
159
+ """Initialize per-head projection weights."""
160
+ for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
161
+ nn.init.xavier_uniform_(weight)
162
+
163
+ def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
164
+ """Validate the local packed-embedding shape contract required by BEA."""
165
+ if packed_embeddings.shape[1] != self.num_heads:
166
+ raise ValueError(
167
+ f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
168
+ f"got {packed_embeddings.shape[1]}."
169
+ )
170
+
171
+ if packed_embeddings.shape[-1] != self.hidden_size:
172
+ raise ValueError(
173
+ f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
174
+ f"got {packed_embeddings.shape[-1]}."
175
+ )
176
+
177
+ def _validate_position_shape(
178
+ self,
179
+ packed_embeddings: torch.Tensor,
180
+ position_ids: torch.Tensor,
181
+ ) -> None:
182
+ """Validate the supplied packed-position tensor shape."""
183
+ if position_ids.shape != packed_embeddings.shape[:3]:
184
+ raise ValueError(
185
+ f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
186
+ f"got {tuple(position_ids.shape)}."
187
+ )
188
+
189
+ def _validate_active_mask_shape(
190
+ self,
191
+ packed_embeddings: torch.Tensor,
192
+ active_mask: torch.Tensor,
193
+ ) -> None:
194
+ """Validate the supplied active-token mask shape."""
195
+ if active_mask.shape != packed_embeddings.shape[:3]:
196
+ raise ValueError(
197
+ f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
198
+ f"got {tuple(active_mask.shape)}."
199
+ )
200
+
201
+ def _make_block_mask(
202
+ self,
203
+ query_active_mask: torch.Tensor,
204
+ key_active_mask: torch.Tensor,
205
+ num_tokens_processed: torch.Tensor,
206
+ query_length: int,
207
+ key_length: int,
208
+ device: torch.device,
209
+ ):
210
+ """Create the packed-sequence causal mask for FlexAttention.
211
+
212
+ At the root, causality is still triangular. The only nuance is cached
213
+ execution: query rows are indexed locally as 0..Q-1 inside the current
214
+ query tensor, but the key tensor may already contain a cached prefix for
215
+ that (batch, head) slot. The causal horizon for query tensor row q is
216
+ therefore:
217
+
218
+ cached_prefix_lengths[b, h] + q
219
+
220
+ Query and key activity masks are then composed with that triangular rule
221
+ so FlexAttention can skip padded query rows and ignore inactive key slots.
222
+ """
223
+ batch_size, num_heads, _ = query_active_mask.shape
224
+
225
+ # Build the per-(batch, head, query_row) triangular horizon from a simple
226
+ # arange over query rows plus the cached prefix lengths for each slot.
227
+ relative_query_positions = torch.arange(
228
+ query_length,
229
+ device=device,
230
+ dtype=torch.long,
231
+ ).view(1, 1, query_length)
232
+ causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
233
+
234
+ def packed_causal_mask(
235
+ batch_idx: torch.Tensor,
236
+ head_idx: torch.Tensor,
237
+ query_idx: torch.Tensor,
238
+ key_idx: torch.Tensor,
239
+ ) -> torch.Tensor:
240
+ query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
241
+ key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
242
+ is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
243
+ return query_is_active & key_is_active & is_causal
244
+
245
+ return create_block_mask(
246
+ packed_causal_mask,
247
+ B=batch_size,
248
+ H=num_heads,
249
+ Q_LEN=query_length,
250
+ KV_LEN=key_length,
251
+ device=device,
252
+ )
__attention__expert_packing.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Expert packing and unpacking for the MoSRAH path.
2
+
3
+ This module implements the low-level token-choice -> expert-choice -> token-choice
4
+ conversion boundary specified in the paper. The externally visible behavior is fixed:
5
+
6
+ - setup_packing() prepares the auxiliary ordering data.
7
+ - pack_experts() converts routed token-choice state into packed expert-choice state.
8
+ - unpack_experts() restores token-choice ordering afterward.
9
+
10
+ Stable sort is a correctness requirement. It preserves causal ordering inside each
11
+ expert bucket, which is the foundation on which BEA's later triangular causal mask
12
+ is correct.
13
+
14
+ pack_experts() returns two distinct masks that serve different roles and must not
15
+ be interchanged:
16
+
17
+ - unpacking_mask: marks every packed slot that contains a routed token copy,
18
+ live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
+ so its reshape invariant holds regardless of outer token liveness.
20
+ - active_mask: marks only the packed slots whose source token was semantically
21
+ live. This is what BEA consumes for attention gating. Dead outer tokens must
22
+ not influence sparse attention outputs.
23
+ """
24
+
25
+ import torch
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Setup
30
+ # ---------------------------------------------------------------------------
31
+
32
+ def setup_packing(
33
+ selected_heads: torch.Tensor,
34
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
+ """Prepare the auxiliary ordering data used by pack/unpack.
36
+
37
+ Routing produces token-choice state I of shape (B, N, K): for each token, which
38
+ K experts were selected. Packing needs the same routed token copies reordered into
39
+ expert-major order so each expert bucket becomes contiguous.
40
+
41
+ The paper's setup step does this by flattening (N, K) into one axis to produce
42
+ H in token-major order, then computing a stable argsort permutation Pi over the
43
+ expert indices stored in H. Applying Pi reorders the flattened routed copies into
44
+ expert-major order while preserving their original token order *within* each expert
45
+ bucket. That preservation is why stable sort is required for causality.
46
+
47
+ Args:
48
+ selected_heads: Routed token-choice head selections I of shape (B, N, K).
49
+
50
+ Returns:
51
+ Tuple of:
52
+ - flattened_selected_heads: H of shape (B, N*K)
53
+ - permutation: stable expert-major permutation Pi of shape (B, N*K)
54
+ - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
55
+ """
56
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
57
+ flattened_selected_heads = selected_heads.reshape(
58
+ batch_size,
59
+ sequence_length * num_selected_heads,
60
+ )
61
+
62
+ permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
63
+ inverse_permutation = torch.argsort(permutation, dim=-1)
64
+
65
+ return flattened_selected_heads, permutation, inverse_permutation
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Packing
70
+ # ---------------------------------------------------------------------------
71
+
72
+ def pack_experts(
73
+ hidden_states: torch.Tensor,
74
+ position_ids: torch.Tensor,
75
+ selected_heads: torch.Tensor,
76
+ num_experts: int,
77
+ flattened_selected_heads: torch.Tensor,
78
+ permutation: torch.Tensor,
79
+ outer_active_mask: torch.Tensor,
80
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
+ """Pack token-choice hidden states into expert-choice padded form.
82
+
83
+ The paper's packing path has two jobs:
84
+
85
+ 1. Convert routed token-choice copies into expert-major order.
86
+ 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
87
+
88
+ The routed hidden-state copies are not stored explicitly in token-choice form.
89
+ Instead, the same token hidden state is conceptually copied once per selected expert.
90
+ The packing step reconstructs those copies by expanding local source-token indices,
91
+ reordering those indices with Pi, then gathering hidden states, positions, and outer
92
+ liveness in that packed order. All three are carried through the same expert-major
93
+ rearrangement so they remain aligned in the packed frame.
94
+
95
+ Packed positions are sourced from the authoritative upstream position_ids tensor
96
+ rather than synthesized locally from arange(N). This preserves advanced positions
97
+ correctly during cached inference while leaving training/full-sequence behavior
98
+ unchanged when position_ids is the ordinary sequential token positions.
99
+
100
+ Args:
101
+ hidden_states: Token-choice hidden states x of shape (B, N, d).
102
+ position_ids: Authoritative upstream token positions J of shape (B, N).
103
+ selected_heads: Routed head selections I of shape (B, N, K).
104
+ num_experts: Total number of experts L.
105
+ flattened_selected_heads: H from setup_packing(), shape (B, N*K).
106
+ permutation: Pi from setup_packing(), shape (B, N*K).
107
+ outer_active_mask: Current-chunk active mask of shape (B, N), where True
108
+ means the token is semantically live. Dead tokens do not become
109
+ semantically active in the packed sparse representation.
110
+
111
+ Returns:
112
+ Tuple of:
113
+ - packed_hidden_states: x' of shape (B, L, T, d)
114
+ - packed_positions: J' of shape (B, L, T)
115
+ - unpacking_mask: of shape (B, L, T). True where a slot contains any
116
+ routed token copy, live or dead. Always has exactly B*N*K True entries.
117
+ Pass this to unpack_experts — not active_mask.
118
+ - active_mask: of shape (B, L, T). True only where a slot contains a
119
+ copy of a live outer token. Pass this to BEA for attention gating.
120
+ """
121
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
122
+ _, _, num_selected_heads = selected_heads.shape
123
+
124
+ # -----------------------------------------------------------------------
125
+ # Reconstruct routed local source-token indices in token-choice order.
126
+ #
127
+ # The internal arange(N) is no longer the packed position tensor. It is only
128
+ # the local source-row index object used to gather from the current chunk
129
+ # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
130
+ # token-major routed-copy order.
131
+ # -----------------------------------------------------------------------
132
+ source_token_indices = torch.arange(
133
+ sequence_length,
134
+ device=hidden_states.device,
135
+ dtype=torch.long,
136
+ ).view(1, sequence_length, 1).expand(
137
+ batch_size,
138
+ sequence_length,
139
+ num_selected_heads,
140
+ )
141
+ flattened_source_indices = source_token_indices.reshape(
142
+ batch_size,
143
+ sequence_length * num_selected_heads,
144
+ )
145
+
146
+ # -----------------------------------------------------------------------
147
+ # Reorder source-token indices into expert-major order.
148
+ #
149
+ # Applying Pi yields the local source-token rows in the packed expert-major
150
+ # order required by the paper. Those same reordered source indices are then
151
+ # used to gather hidden states, authoritative upstream positions, and outer
152
+ # liveness so all three remain aligned under the exact same packing
153
+ # transformation.
154
+ # -----------------------------------------------------------------------
155
+ sorted_source_indices = flattened_source_indices.gather(
156
+ dim=1,
157
+ index=permutation,
158
+ )
159
+ sorted_hidden_states = hidden_states.gather(
160
+ dim=1,
161
+ index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
162
+ )
163
+ sorted_positions = position_ids.gather(
164
+ dim=1,
165
+ index=sorted_source_indices,
166
+ )
167
+ sorted_active_mask = outer_active_mask.gather(
168
+ dim=1,
169
+ index=sorted_source_indices,
170
+ )
171
+
172
+ # -----------------------------------------------------------------------
173
+ # Count how many routed copies land in each expert bucket.
174
+ #
175
+ # S[b, l] is the number of routed token copies assigned to expert l in batch b.
176
+ # T is the maximum such count across all batches and experts; it determines the
177
+ # padded expert-length dimension of the packed representation.
178
+ # -----------------------------------------------------------------------
179
+ tokens_per_expert = _bincount_rows(
180
+ values=flattened_selected_heads,
181
+ num_bins=num_experts,
182
+ )
183
+ max_tokens_per_expert = int(tokens_per_expert.max().item())
184
+
185
+ # -----------------------------------------------------------------------
186
+ # Construct the active-token mask M.
187
+ #
188
+ # Each expert bucket is left-justified: if S[b, l] = s, then slots
189
+ # t = 0, ..., s-1 are active and all later slots are padding. The resulting
190
+ # mask therefore both identifies real packed tokens and enforces left-justified
191
+ # packing. This is the unpacking_mask — it marks slot occupancy regardless of
192
+ # outer token liveness, and always has exactly B*N*K True entries.
193
+ # -----------------------------------------------------------------------
194
+ time_axis = torch.arange(
195
+ max_tokens_per_expert,
196
+ device=hidden_states.device,
197
+ dtype=torch.long,
198
+ ).view(1, 1, max_tokens_per_expert)
199
+ unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
200
+
201
+ # -----------------------------------------------------------------------
202
+ # Materialize the padded packed tensors.
203
+ #
204
+ # The packed hidden states x', packed original-token positions J', and packed
205
+ # active-token mask are allocated as zero-filled tensors. Active entries are
206
+ # then written into those buffers in the expert-major order established above.
207
+ # Padding remains zero / inactive.
208
+ # -----------------------------------------------------------------------
209
+ packed_hidden_states = hidden_states.new_zeros(
210
+ batch_size,
211
+ num_experts,
212
+ max_tokens_per_expert,
213
+ hidden_dim,
214
+ )
215
+ packed_positions = position_ids.new_zeros(
216
+ batch_size,
217
+ num_experts,
218
+ max_tokens_per_expert,
219
+ )
220
+ active_mask = torch.zeros(
221
+ batch_size,
222
+ num_experts,
223
+ max_tokens_per_expert,
224
+ dtype=torch.bool,
225
+ device=hidden_states.device,
226
+ )
227
+
228
+ packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
229
+ packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
230
+ active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)
231
+
232
+ return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
+
234
+
235
+ # ---------------------------------------------------------------------------
236
+ # Unpacking
237
+ # ---------------------------------------------------------------------------
238
+
239
+ def unpack_experts(
240
+ expert_outputs: torch.Tensor,
241
+ selected_heads: torch.Tensor,
242
+ unpacking_mask: torch.Tensor,
243
+ inverse_permutation: torch.Tensor,
244
+ ) -> torch.Tensor:
245
+ """Restore token-choice ordering from BEA expert-choice output.
246
+
247
+ Unpacking inverts the packing path only on occupied entries. Padding does not
248
+ participate: the output tensor is first filtered by unpacking_mask to recover
249
+ only the real routed-token copies in expert-major order, then Pi^{-1} restores
250
+ the original token-choice ordering, and finally the tensor is reshaped back to
251
+ (B, N, K, d).
252
+
253
+ The unpacking_mask — not active_mask — must be used here. Even copies of dead
254
+ outer tokens occupy slots and must be un-scattered correctly for the inverse
255
+ permutation to hold. The total True entry count in unpacking_mask is always
256
+ B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
257
+
258
+ Args:
259
+ expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
260
+ selected_heads: Routed head selections I of shape (B, N, K).
261
+ unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
262
+ occupied packed slots regardless of outer token liveness.
263
+ inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).
264
+
265
+ Returns:
266
+ Restored token-choice tensor y_tilde of shape (B, N, K, d).
267
+ """
268
+ batch_size, sequence_length, num_selected_heads = selected_heads.shape
269
+ hidden_dim = expert_outputs.shape[-1]
270
+
271
+ active_outputs = expert_outputs[unpacking_mask]
272
+ sorted_token_choice_outputs = active_outputs.reshape(
273
+ batch_size,
274
+ sequence_length * num_selected_heads,
275
+ hidden_dim,
276
+ )
277
+ restored_outputs = sorted_token_choice_outputs.gather(
278
+ dim=1,
279
+ index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
280
+ )
281
+
282
+ return restored_outputs.reshape(
283
+ batch_size,
284
+ sequence_length,
285
+ num_selected_heads,
286
+ hidden_dim,
287
+ )
288
+
289
+
290
+ # ---------------------------------------------------------------------------
291
+ # Helpers
292
+ # ---------------------------------------------------------------------------
293
+
294
+ def _bincount_rows(
295
+ values: torch.Tensor,
296
+ num_bins: int,
297
+ ) -> torch.Tensor:
298
+ """Count per-row integer occurrences for a 2D tensor.
299
+
300
+ torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
+ one bincount per batch row. The trick used here is to shift each row into its
302
+ own disjoint bin range before flattening:
303
+
304
+ row 0 uses bins [0, ..., num_bins - 1]
305
+ row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
+ row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
+ ...
308
+
309
+ After that shift, one global torch.bincount produces all row-local counts at
310
+ once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
+
312
+ This is a vectorized implementation detail only; externally visible behavior
313
+ remains exactly the paper's S tensor of per-batch per-expert token counts.
314
+
315
+ Args:
316
+ values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
+ num_bins: Number of bins.
318
+
319
+ Returns:
320
+ Counts tensor of shape (B, num_bins).
321
+ """
322
+ batch_size = values.shape[0]
323
+
324
+ row_offsets = torch.arange(
325
+ batch_size,
326
+ device=values.device,
327
+ dtype=values.dtype,
328
+ ).unsqueeze(1) * num_bins
329
+ shifted_values = values + row_offsets
330
+
331
+ counts = torch.bincount(
332
+ shifted_values.reshape(-1),
333
+ minlength=batch_size * num_bins,
334
+ )
335
+ return counts.reshape(batch_size, num_bins)
__attention__load_balance_loss.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
+
3
+ This module implements the custom autograd Function H(b, f) described in the paper's
4
+ Implementation Concerns section. The operator bridges two requirements that are in
5
+ tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
+ multiplication) so that existing training loops remain compatible, while simultaneously
7
+ implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
+ path through the router weights.
9
+
10
+ The resolution is a custom backward pass. The forward emits the load balance imbalance
11
+ as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
+ its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
+ then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
+ correction without a standalone SGD update step.
15
+
16
+ Paper ref: Appendix A.Implementation Concerns.
17
+ """
18
+
19
+ import torch
20
+
21
+
22
+ class LoadBalanceLoss(torch.autograd.Function):
23
+ """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
+
25
+ Forward computes the load balance imbalance:
26
+
27
+ L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
+
29
+ Backward emits a bias-correction gradient to expert_bias:
30
+
31
+ grad_b = L_grad * sign(f_l - 1/L)
32
+
33
+ expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
+ in the computation graph and routes gradients through it. routing_freqs (f) receives
35
+ no gradient — its origin is the discrete TopK operation which has no gradient, so
36
+ defining a gradient for f here would be mathematically incorrect.
37
+
38
+ Paper ref: Appendix A.Implementation Concerns.
39
+ """
40
+
41
+ @staticmethod
42
+ def forward(
43
+ ctx: torch.autograd.function.FunctionCtx,
44
+ expert_bias: torch.Tensor,
45
+ routing_freqs: torch.Tensor,
46
+ ) -> torch.Tensor:
47
+ """Compute the load balance loss.
48
+
49
+ Args:
50
+ ctx: Autograd context for saving state needed in backward.
51
+ expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
+ PyTorch tracks it as a computation graph node needing a gradient.
53
+ routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
+ from the discrete TopK selection — not differentiable.
55
+
56
+ Returns:
57
+ Scalar loss equal to sum_l |f_l - 1/L|.
58
+ """
59
+ L = expert_bias.shape[0]
60
+ # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
+ # underloaded. Saved for backward where sign(imbalance) determines the direction
62
+ # of the bias-correction update.
63
+ imbalance = routing_freqs - 1.0 / L
64
+ ctx.save_for_backward(imbalance)
65
+ return imbalance.abs().sum()
66
+
67
+ @staticmethod
68
+ def backward(
69
+ ctx: torch.autograd.function.FunctionCtx,
70
+ grad_output: torch.Tensor,
71
+ ) -> tuple[torch.Tensor, None]:
72
+ """Emit the DeepSeek-style bias-correction gradient.
73
+
74
+ Args:
75
+ ctx: Autograd context carrying imbalance saved in forward.
76
+ grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
+ by the training loop arrives here and is propagated to grad_b, so the
78
+ correction magnitude is proportional to the loss weight chosen by the
79
+ consumer.
80
+
81
+ Returns:
82
+ Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
+ None for routing_freqs: no gradient is defined for the discrete routing
84
+ frequency.
85
+ """
86
+ (imbalance,) = ctx.saved_tensors
87
+ grad_expert_bias = grad_output * imbalance.sign()
88
+ return grad_expert_bias, None
__attention__mosrah.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full MoSRAH sparse path for SHRAM.
2
+
3
+ This module coordinates the routed sparse attention path used inside the SHRAM
4
+ hybrid attention layer. The underlying mechanics already live in verified
5
+ subunits. The responsibility here is to connect those subunits without
6
+ corrupting their bridge contracts.
7
+
8
+ In particular, this path must preserve three architectural distinctions:
9
+
10
+ - selected head indices are not routing probabilities
11
+ - packed position semantics are chosen before BEA, not inside it
12
+ - weighted reduction must consume the router's unbiased renormalized
13
+ probabilities after token-choice order has been restored
14
+ """
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from .__cache__mosrah_cache import MoSRAHCache
20
+ from .configuration import ShramConfig
21
+ from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
+ from .__attention__expert_packing import (
23
+ pack_experts,
24
+ setup_packing,
25
+ unpack_experts,
26
+ )
27
+ from .__attention__router import MoSRAHRouter
28
+ from .__attention__positions_converter import SparseMoSRAHPositions
29
+
30
+
31
+ class MoSRAHLayer(nn.Module):
32
+ """Full routed sparse attention path for SHRAM.
33
+
34
+ The MoSRAH path consumes model-space hidden states together with
35
+ authoritative per-token positions and returns the model-space sparse-path
36
+ contribution, the router's load-balance loss, and the router's MaxVio
37
+ routing-imbalance scalar.
38
+ """
39
+
40
+ def __init__(self, config: ShramConfig) -> None:
41
+ super().__init__()
42
+ self.num_experts = config.num_mosrah_heads
43
+
44
+ self.router = MoSRAHRouter(config)
45
+ self.positions = SparseMoSRAHPositions(config)
46
+ self.bea = BottleneckedEnsembleAttention(config)
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: torch.Tensor,
51
+ position_ids: torch.Tensor,
52
+ active_mask: torch.Tensor,
53
+ cache: MoSRAHCache | None,
54
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
+ """Run the full MoSRAH sparse path.
56
+
57
+ Args:
58
+ hidden_states: Model-space hidden states x of shape (B, N, d).
59
+ position_ids: Authoritative per-token positions of shape (B, N).
60
+ active_mask: Current-chunk active mask of shape (B, N), where True
61
+ means the token is semantically live. Forwarded to the router
62
+ so dead tokens are excluded from routing statistics, and to
63
+ pack_experts so dead outer tokens do not become semantically
64
+ active packed entries.
65
+ cache: Optional layer-local MoSRAH cache. Pass None for uncached
66
+ execution and the layer-local cache instance for cached execution.
67
+
68
+ Returns:
69
+ sparse_output: Model-space sparse-path output of shape (B, N, d).
70
+ load_balance_loss: Scalar router load-balance loss.
71
+ max_vio: Detached scalar routing-imbalance summary. Passed through
72
+ unchanged from the router; see MoSRAHRouter for semantics.
73
+ """
74
+
75
+ # -------------------------------------------------------------------
76
+ # The first transition moves from model-space token-choice input into
77
+ # the packed expert-choice sparse-attention state. Routing decides both
78
+ # which experts each token uses and which unbiased probabilities must be
79
+ # reserved for the final reduction. The active mask is forwarded to the
80
+ # router so dead tokens are excluded from routing statistics, and to
81
+ # pack_experts so outer liveness is faithfully carried into the packed
82
+ # frame. Packing returns both the unpacking mask (slot occupancy, always
83
+ # B*N*K True entries) and the packed active mask (live slots only);
84
+ # active_mask is rebound to the packed form after this point.
85
+ # -------------------------------------------------------------------
86
+ selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
87
+ hidden_states, active_mask
88
+ )
89
+
90
+ flattened_selected_heads, permutation, inverse_permutation = setup_packing(
91
+ selected_heads
92
+ )
93
+ packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
94
+ hidden_states=hidden_states,
95
+ position_ids=position_ids,
96
+ selected_heads=selected_heads,
97
+ num_experts=self.num_experts,
98
+ flattened_selected_heads=flattened_selected_heads,
99
+ permutation=permutation,
100
+ outer_active_mask=active_mask,
101
+ )
102
+
103
+ # -------------------------------------------------------------------
104
+ # Sparse attention runs entirely in the packed expert-choice frame, so
105
+ # the RoPE position semantics must also be chosen in that frame. The
106
+ # position layer therefore decides whether BEA should see packed
107
+ # original-token positions or packed local-slot positions. BEA then
108
+ # consumes that packed position tensor together with the packed hidden
109
+ # states and the layer-local sparse cache, which it owns directly.
110
+ # -------------------------------------------------------------------
111
+ bea_positions = self.positions(
112
+ packed_positions=packed_positions,
113
+ cache=cache,
114
+ )
115
+ packed_outputs = self.bea(
116
+ packed_embeddings=packed_hidden_states,
117
+ position_ids=bea_positions,
118
+ active_mask=active_mask,
119
+ cache=cache,
120
+ )
121
+
122
+ # -------------------------------------------------------------------
123
+ # The final transition restores token-choice meaning and only then
124
+ # collapses the K routed copies back into model space. This ordering is
125
+ # required because routing_probs live in token-choice space, whereas BEA
126
+ # returns expert-choice packed outputs. The reduction must therefore
127
+ # happen after unpacking, and it must use the router's unbiased
128
+ # renormalized probabilities rather than any biased selection scores.
129
+ # -------------------------------------------------------------------
130
+ token_choice_outputs = unpack_experts(
131
+ expert_outputs=packed_outputs,
132
+ selected_heads=selected_heads,
133
+ unpacking_mask=unpacking_mask,
134
+ inverse_permutation=inverse_permutation,
135
+ )
136
+ final_output = (
137
+ token_choice_outputs * routing_probs.unsqueeze(-1)
138
+ ).sum(dim=2)
139
+
140
+ return final_output, load_balance_loss, max_vio
__attention__positions_converter.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Position computation for the MoSRAH sparse path.
2
+
3
+ This layer computes the packed position tensor P consumed by BEA.
4
+
5
+ - In main-sequence mode, P is the packed original-token position tensor from the
6
+ packing path.
7
+ - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
+ expert-choice layout, optionally offset by the current sparse-cache occupancies
9
+ during cached inference.
10
+ """
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from .configuration import ShramConfig
16
+ from .__cache__mosrah_cache import MoSRAHCache
17
+
18
+
19
+ class SparseMoSRAHPositions(nn.Module):
20
+ """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
+
22
+ This layer operates in the packed expert-choice frame used by BEA. The input
23
+ packed_positions tensor is always the packed original-token position tensor
24
+ produced by the packing path. The configured rope_mode determines whether that
25
+ tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
+ """
27
+
28
+ def __init__(self, config: ShramConfig) -> None:
29
+ super().__init__()
30
+ self.rope_mode = config.rope_mode
31
+
32
+ def forward(
33
+ self,
34
+ packed_positions: torch.Tensor,
35
+ cache: MoSRAHCache | None,
36
+ ) -> torch.Tensor:
37
+ """Compute the packed position tensor P consumed by BEA.
38
+
39
+ Args:
40
+ packed_positions: Packed original-token positions J' of shape (B, L, T).
41
+ cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
+ mode, the current per-head occupancies offset the local packed sequence.
43
+
44
+ Returns:
45
+ Packed position tensor P of shape (B, L, T).
46
+ """
47
+ if self.rope_mode == "main_sequence":
48
+ return self._main_sequence_positions(packed_positions)
49
+
50
+ if self.rope_mode == "semantic_sequence":
51
+ return self._semantic_sequence_positions(packed_positions, cache)
52
+
53
+ raise NotImplementedError(
54
+ f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
+ )
56
+
57
+ def _main_sequence_positions(
58
+ self,
59
+ packed_positions: torch.Tensor,
60
+ ) -> torch.Tensor:
61
+ """Forward packed original-token positions unchanged."""
62
+ return packed_positions
63
+
64
+ def _semantic_sequence_positions(
65
+ self,
66
+ packed_positions: torch.Tensor,
67
+ cache: MoSRAHCache | None,
68
+ ) -> torch.Tensor:
69
+ """Compute semantic-sequence packed positions in expert-choice space.
70
+
71
+ Without a sparse cache, semantic positions are the local packed sequence
72
+ 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
73
+ same local sequence is offset by the current per-(batch, expert) occupancies
74
+ returned by get_heads_lengths().
75
+ """
76
+ batch_size, num_experts, packed_length = packed_positions.shape
77
+
78
+ # -------------------------------------------------------------------
79
+ # Construct the local packed sequence 0, 1, 2, ... over the expert-local
80
+ # sequence dimension T. This is then broadcast across batch and experts.
81
+ # -------------------------------------------------------------------
82
+ local_positions = torch.arange(
83
+ packed_length,
84
+ device=packed_positions.device,
85
+ dtype=packed_positions.dtype,
86
+ ).view(1, 1, packed_length).expand(
87
+ batch_size,
88
+ num_experts,
89
+ packed_length,
90
+ )
91
+
92
+ # -------------------------------------------------------------------
93
+ # In cached semantic-sequence mode, positions continue from the current
94
+ # sparse-cache occupancies rather than restarting at zero for the local
95
+ # chunk.
96
+ # -------------------------------------------------------------------
97
+ if cache is None:
98
+ return local_positions
99
+
100
+ cached_lengths = cache.get_heads_lengths().to(
101
+ device=packed_positions.device,
102
+ dtype=packed_positions.dtype,
103
+ ).unsqueeze(-1)
104
+
105
+ return local_positions + cached_lengths
__attention__router.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token-choice router for the MoSRAH sparse attention path.
2
+
3
+ This module implements the routing mechanism described in Appendix A.Routing of the
4
+ paper. Given an input hidden state x, the router produces two outputs used downstream:
5
+
6
+ - selected_heads (I): which K of the L available expert heads each token routes to,
7
+ determined by TopK over biased routing scores.
8
+ - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
+ *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
+ bias b must not influence P.
11
+
12
+ This separation is architecturally critical: expert_bias drives selection (and thus load
13
+ balancing) but does not corrupt the gradient path from the output through routing_probs
14
+ back to the routing projection weights.
15
+
16
+ The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
+ autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
+ loop can weight and add to the language modeling loss.
19
+
20
+ The router additionally computes and returns MaxVio, a detached scalar summarising
21
+ routing imbalance for the current forward pass:
22
+
23
+ MaxVio = L · max_l(f_l − 1/L)
24
+
25
+ where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
+ target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
+
28
+ Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
+ """
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+
35
+ from .configuration import ShramConfig
36
+ from .__attention__load_balance_loss import LoadBalanceLoss
37
+
38
+
39
+ class MoSRAHRouter(nn.Module):
40
+ """Token-choice router for MoSRAH sparse attention.
41
+
42
+ Each input token independently selects K of the L available expert heads. Selection
43
+ is driven by biased routing scores to enable load balancing, but the routing
44
+ probabilities used for output reduction are computed from unbiased scores so that
45
+ the expert bias does not interfere with the gradient path to the router weights.
46
+
47
+ The routing projection W_r has no bias term — the paper specifies xW_r with no
48
+ additional projection bias. The only bias-like parameter is expert_bias (b), which
49
+ has an entirely separate role and update mechanism.
50
+
51
+ Args:
52
+ config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
+ (L), and ``num_selected_heads`` (K).
54
+ """
55
+
56
+ def __init__(self, config: ShramConfig) -> None:
57
+ super().__init__()
58
+ self.num_mosrah_heads = config.num_mosrah_heads
59
+ self.num_selected_heads = config.num_selected_heads
60
+
61
+ # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
+ self.routing_projection = nn.Linear(
63
+ config.hidden_size, config.num_mosrah_heads, bias=False
64
+ )
65
+
66
+ # b: learned per-head bias for load balancing. Initialized to zero so that all
67
+ # heads start with equal selection probability. Updated by the main optimizer
68
+ # via the LoadBalanceLoss custom backward.
69
+ self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
70
+
71
+ def forward(
72
+ self, x: torch.Tensor, active_mask: torch.Tensor
73
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
+ """Route input tokens to K expert heads each and compute routing probabilities.
75
+
76
+ Args:
77
+ x: Input hidden states of shape (batch, seq_len, hidden_size).
78
+ active_mask: Current-chunk active mask of shape (batch, seq_len), where
79
+ True means the token is semantically live. Dead tokens do not
80
+ contribute to routing frequencies, load_balance_loss, or max_vio.
81
+
82
+ Returns:
83
+ selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
84
+ Each token's K selected head indices, determined by TopK on biased scores.
85
+ routing_probs: Routing probabilities P of shape (batch, seq_len,
86
+ num_selected_heads). Gathered from unbiased scores at selected_heads
87
+ indices and renormalized to sum to 1 per token.
88
+ load_balance_loss: Scalar load balance imbalance loss for this forward pass.
89
+ Training loop scales this by a weight and adds it to the main loss.
90
+ max_vio: Detached scalar routing-imbalance summary for this forward pass.
91
+ Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
92
+ never contributes gradients.
93
+ """
94
+ B, N, _ = x.shape
95
+ L = self.num_mosrah_heads
96
+ K = self.num_selected_heads
97
+
98
+ # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
99
+ # compute routing_probs — expert_bias must not influence them.
100
+ logits = self.routing_projection(x) # (B, N, L)
101
+ routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
102
+
103
+ # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
104
+ # selection. expert_bias is added to logits before softmax so that the bias
105
+ # shifts selection probability without rescaling the unbiased distribution.
106
+ biased_routing_scores = F.softmax( # R̂, (B, N, L)
107
+ logits + self.expert_bias, dim=-1
108
+ )
109
+
110
+ # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
111
+ selected_heads = biased_routing_scores.topk(K, dim=-1).indices
112
+
113
+ # Routing probabilities P: gathered from unbiased R at selected_heads indices,
114
+ # then renormalized so they sum to 1 per token. Gathering from routing_scores
115
+ # (not biased_routing_scores) is the invariant that keeps the gradient path from
116
+ # the output back to the router weights free of expert_bias influence.
117
+ gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
+ routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
+
120
+ # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
+ # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
+ # assignment mask before reduction. Normalization uses the active assignment
123
+ # count so frequencies remain properly scaled regardless of how many tokens
124
+ # are live in this chunk.
125
+ assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
+ assignment_mask.scatter_(-1, selected_heads, 1.0)
127
+ active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
+ num_active_assignments = active_mask.sum() * K
129
+ routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
130
+
131
+ # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
+ # registers it as a graph node; the custom backward writes the DeepSeek-style
133
+ # correction gradient to expert_bias.grad for the optimizer to consume.
134
+ load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
+
136
+ # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
+ # not contribute gradients under any circumstance, so it is detached at the
138
+ # point of computation rather than left to callers to detach.
139
+ max_vio = self._compute_max_vio(routing_freqs, L)
140
+
141
+ return selected_heads, routing_probs, load_balance_loss, max_vio
142
+
143
+ @staticmethod
144
+ def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
145
+ """Compute the MaxVio routing-imbalance scalar.
146
+
147
+ MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
+ head l and 1/L is the perfectly balanced target. A value of zero indicates
149
+ perfect balance; a value of 1 means the most overloaded head received exactly
150
+ double its fair share.
151
+
152
+ The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
+ and must never contribute gradients to any parameter.
154
+
155
+ Args:
156
+ routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
+ num_heads: Total number of MoSRAH heads L.
158
+
159
+ Returns:
160
+ Detached scalar MaxVio tensor.
161
+ """
162
+ return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
__attention__shram.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM hybrid attention layer.
2
+
3
+ This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
+ used at one decoder attention slot in SHRAM.
5
+
6
+ The local sliding-window path and the MoSRAH sparse path are already verified
7
+ independently. The responsibility here is therefore not to introduce new
8
+ attention logic, but to preserve the bridge contracts between them: both paths
9
+ must consume the same input hidden state, each path must receive the sub-cache
10
+ it actually owns, the two model-space outputs must be summed directly, and the
11
+ sparse-path load-balance loss must remain visible to the caller.
12
+ """
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+ from .__cache__shram_layer_cache import ShramLayerCache
18
+ from .configuration import ShramConfig
19
+ from .__attention__sliding_window_attention import SlidingWindowAttention
20
+ from .__attention__mosrah import MoSRAHLayer
21
+
22
+
23
+ class SHRAMHybridLayer(nn.Module):
24
+ """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
+
26
+ The local path preserves nearby-token behavior through sliding-window causal
27
+ attention. The sparse path is the theorem-facing MoSRAH routed attention
28
+ path. Both operate over the same model-space hidden state and return
29
+ model-space outputs, so the hybrid composition is a direct sum in model
30
+ space.
31
+ """
32
+
33
+ def __init__(self, config: ShramConfig) -> None:
34
+ super().__init__()
35
+ self.local_attention = SlidingWindowAttention(config)
36
+ self.sparse_attention = MoSRAHLayer(config)
37
+
38
+ def forward(
39
+ self,
40
+ hidden_states: torch.Tensor,
41
+ position_ids: torch.Tensor,
42
+ active_mask: torch.Tensor,
43
+ cache: ShramLayerCache | None,
44
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
+ """Apply the SHRAM hybrid attention layer.
46
+
47
+ Args:
48
+ hidden_states: Input hidden states of shape (B, N, d).
49
+ position_ids: Authoritative token positions of shape (B, N).
50
+ active_mask: Current-chunk active mask of shape (B, N), where True
51
+ means the token is semantically live. Forwarded unchanged to
52
+ both the local path and the sparse path.
53
+ cache: Optional per-layer SHRAM cache. When provided, the owned
54
+ sliding-window and MoSRAH sub-caches are dispatched directly to
55
+ their corresponding attention paths.
56
+
57
+ Returns:
58
+ hybrid_output: Model-space hybrid attention output of shape (B, N, d).
59
+ load_balance_loss: Scalar sparse-path load-balance loss.
60
+ max_vio: Detached scalar routing-imbalance summary. Passed through
61
+ unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
62
+ """
63
+ # ------------------------------------------------
64
+ # It is not possible, due to how bea constructs its block mask,
65
+ # for the model to process a sequence that does not start at zero
66
+ # without a cache to track the per-head offsets
67
+ # ------------------------------------------------
68
+
69
+ if cache is None and torch.any(position_ids[:, 0] != 0):
70
+ raise ValueError(
71
+ "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
72
+ "Either provide a matching ShramLayerCache populated by the prefix for "
73
+ "continued decoding, or rebase the uncached sequence to start at 0."
74
+ )
75
+
76
+ # -------------------------------------------------------------------
77
+ # The hybrid layer's first responsibility is cache dispatch. The layer
78
+ # cache already owns the concrete sub-cache objects required by each
79
+ # path, so this unit should forward those exact references rather than
80
+ # reinterpret cache ownership or invent a composite update protocol here.
81
+ # -------------------------------------------------------------------
82
+ if cache is None:
83
+ sliding_window_cache = None
84
+ mosrah_cache = None
85
+ else:
86
+ sliding_window_cache = cache.sliding_window_cache
87
+ mosrah_cache = cache.mosrah_cache
88
+
89
+ # -------------------------------------------------------------------
90
+ # Both attention paths must see the same model-space hidden state for
91
+ # the current decoder layer. The local path preserves short-range
92
+ # structure, while the sparse path provides the routed long-range
93
+ # contribution and emits the load-balance signal used by training.
94
+ # -------------------------------------------------------------------
95
+ local_output = self.local_attention(
96
+ x=hidden_states,
97
+ position_ids=position_ids,
98
+ active_mask=active_mask,
99
+ cache=sliding_window_cache,
100
+ )
101
+ sparse_output, load_balance_loss, max_vio = self.sparse_attention(
102
+ hidden_states=hidden_states,
103
+ position_ids=position_ids,
104
+ active_mask=active_mask,
105
+ cache=mosrah_cache,
106
+ )
107
+
108
+ # -------------------------------------------------------------------
109
+ # The composition rule is intentionally simple at this boundary. Both
110
+ # sublayers already return model-space tensors of matching shape, so the
111
+ # correct hybrid behavior is their direct sum with no additional mixing
112
+ # logic introduced here.
113
+ # -------------------------------------------------------------------
114
+ hybrid_output = local_output + sparse_output
115
+
116
+ return hybrid_output, load_balance_loss, max_vio
__attention__sliding_window_attention.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/shram/model/attention/sliding_window_attention.py
2
+
3
+ """Local sliding-window attention path for SHRAM.
4
+
5
+ This file defines `SlidingWindowAttention`, the local short-range attention path
6
+ used inside the SHRAM hybrid layer.
7
+
8
+ In the masked-continuation variant, the local cache no longer returns a
9
+ semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
+ returns:
11
+
12
+ - the retained local window memory concatenated with the current chunk
13
+ - an aligned active mask over that returned frame
14
+
15
+ This module consumes that returned frame directly and constructs effective local
16
+ causal/window visibility from the mask. It does not own cache retention policy;
17
+ it owns only local attention semantics.
18
+ """
19
+
20
+ import math
21
+ from typing import Any
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
+
27
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
28
+ from .configuration import ShramConfig
29
+ from .rope import RotaryEmbedding
30
+
31
+
32
+ class SlidingWindowAttention(nn.Module):
33
+ """Causal local sliding-window attention for one SHRAM layer.
34
+
35
+ Args:
36
+ config: SHRAM config. Must expose `hidden_size`,
37
+ `num_sliding_window_heads`, `head_dim`, `window_size`,
38
+ `attention_dropout`, and `local_rope_theta`.
39
+
40
+ Raises:
41
+ NotImplementedError: If `attention_dropout != 0.0`.
42
+ """
43
+
44
+ def __init__(self, config: ShramConfig) -> None:
45
+ super().__init__()
46
+
47
+ self.hidden_size = config.hidden_size
48
+ self.num_heads = config.num_sliding_window_heads
49
+ self.head_dim = config.head_dim
50
+ self.window_size = config.window_size
51
+ self.attention_dropout = config.attention_dropout
52
+
53
+ if self.attention_dropout != 0.0:
54
+ raise NotImplementedError(
55
+ "SlidingWindowAttention currently supports only "
56
+ "attention_dropout == 0.0."
57
+ )
58
+
59
+ self.inner_dim = self.num_heads * self.head_dim
60
+
61
+ # Standard MHA projections for the local path.
62
+ self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
+ self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
+ self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
+ self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
+
67
+ # The local path always uses default-mode RoPE with its own theta.
68
+ self.rope = RotaryEmbedding(
69
+ mode="default",
70
+ head_dim=self.head_dim,
71
+ theta=config.local_rope_theta,
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ x: torch.Tensor,
77
+ position_ids: torch.Tensor,
78
+ active_mask: torch.Tensor,
79
+ cache: LocalSlidingWindowLayerCache | None = None,
80
+ ) -> torch.Tensor:
81
+ """Apply local causal sliding-window attention.
82
+
83
+ Args:
84
+ x: Input tensor of shape `(B, N, hidden_size)`.
85
+ position_ids: Position tensor of shape `(B, N)`.
86
+ active_mask: Current-chunk active mask of shape `(B, N)`, where
87
+ `True` means active.
88
+ cache: Optional `LocalSlidingWindowLayerCache`.
89
+
90
+ Returns:
91
+ Output tensor of shape `(B, N, hidden_size)`.
92
+ """
93
+ batch_size, query_len, _ = x.shape
94
+
95
+ self._validate_position_shape(x, position_ids)
96
+ self._validate_active_mask_shape(x, active_mask)
97
+
98
+ # (B, N, H*D) -> (B, H, N, D)
99
+ q = self.q_proj(x).view(
100
+ batch_size,
101
+ query_len,
102
+ self.num_heads,
103
+ self.head_dim,
104
+ ).transpose(1, 2)
105
+ k = self.k_proj(x).view(
106
+ batch_size,
107
+ query_len,
108
+ self.num_heads,
109
+ self.head_dim,
110
+ ).transpose(1, 2)
111
+ v = self.v_proj(x).view(
112
+ batch_size,
113
+ query_len,
114
+ self.num_heads,
115
+ self.head_dim,
116
+ ).transpose(1, 2)
117
+
118
+ q, k, attention_scaling = self.rope(q, k, position_ids)
119
+
120
+ # The cache returns the current-step visible local frame, not merely the
121
+ # retained next-step cache buffer.
122
+ if cache is not None:
123
+ k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
124
+ else:
125
+ k_full, v_full, full_active_mask = k, v, active_mask
126
+
127
+ block_mask = self._make_block_mask(
128
+ active_mask=full_active_mask,
129
+ batch_size=batch_size,
130
+ num_heads=self.num_heads,
131
+ query_len=query_len,
132
+ kv_len=k_full.shape[-2],
133
+ window_size=self.window_size,
134
+ device=x.device,
135
+ )
136
+
137
+ attn_output = flex_attention(
138
+ q,
139
+ k_full,
140
+ v_full,
141
+ block_mask=block_mask,
142
+ scale=attention_scaling / math.sqrt(self.head_dim),
143
+ )
144
+
145
+ # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
146
+ attn_output = (
147
+ attn_output.transpose(1, 2)
148
+ .contiguous()
149
+ .view(batch_size, query_len, self.inner_dim)
150
+ )
151
+
152
+ return self.o_proj(attn_output)
153
+
154
+ def _validate_position_shape(
155
+ self,
156
+ x: torch.Tensor,
157
+ position_ids: torch.Tensor,
158
+ ) -> None:
159
+ """Validate the position tensor shape expected by local RoPE."""
160
+ if position_ids.shape != x.shape[:2]:
161
+ raise ValueError(
162
+ f"position_ids must have shape {tuple(x.shape[:2])}, "
163
+ f"got {tuple(position_ids.shape)}."
164
+ )
165
+
166
+ def _validate_active_mask_shape(
167
+ self,
168
+ x: torch.Tensor,
169
+ active_mask: torch.Tensor,
170
+ ) -> None:
171
+ """Validate the current-chunk active-mask contract."""
172
+ if active_mask.shape != x.shape[:2]:
173
+ raise ValueError(
174
+ f"active_mask must have shape {tuple(x.shape[:2])}, "
175
+ f"got {tuple(active_mask.shape)}."
176
+ )
177
+ if active_mask.dtype != torch.bool:
178
+ raise ValueError(
179
+ f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
180
+ )
181
+
182
+ def _make_block_mask(
183
+ self,
184
+ active_mask: torch.Tensor,
185
+ batch_size: int,
186
+ num_heads: int,
187
+ query_len: int,
188
+ kv_len: int,
189
+ window_size: int,
190
+ device: torch.device,
191
+ ) -> Any:
192
+ """Create the FlexAttention block mask for masked local continuation.
193
+
194
+ The returned local frame is chronological in raw buffer order, but dead
195
+ positions may remain inside it. Effective local order is therefore
196
+ recovered from the active mask itself by taking a cumulative count over
197
+ active positions.
198
+
199
+ Queries still occupy the tail of the returned frame, so raw buffer order
200
+ is used to locate query rows. Semantic active-token positions are then
201
+ used to decide causality and sliding-window distance.
202
+ """
203
+ query_offset = kv_len - query_len
204
+ semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
+
206
+ def sliding_window_mask(
207
+ batch_idx: torch.Tensor,
208
+ head_idx: torch.Tensor,
209
+ q_idx: torch.Tensor,
210
+ kv_idx: torch.Tensor,
211
+ ) -> torch.Tensor:
212
+
213
+ q_abs = query_offset + q_idx
214
+
215
+ query_is_active = active_mask[batch_idx, q_abs]
216
+ key_is_active = active_mask[batch_idx, kv_idx]
217
+
218
+ q_sem = semantic_positions[batch_idx, q_abs]
219
+ k_sem = semantic_positions[batch_idx, kv_idx]
220
+
221
+ is_causal = k_sem <= q_sem
222
+ in_window = (q_sem - k_sem) < window_size
223
+
224
+ return query_is_active & key_is_active & is_causal & in_window
225
+
226
+ return create_block_mask(
227
+ sliding_window_mask,
228
+ B=batch_size,
229
+ H=num_heads,
230
+ Q_LEN=query_len,
231
+ KV_LEN=kv_len,
232
+ device=device,
233
+ )
__cache__mosrah_cache.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MoSRAH sparse KV cache — single-layer implementation.
2
+
3
+ MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
+ by head rather than by sequence position. The routing is dynamic and produces a ragged
5
+ distribution of token counts across (batch, head) slots — different batch items may
6
+ route different numbers of tokens to the same head, and different heads accumulate at
7
+ different rates. DynamicCache cannot represent this correctly: it concatenates along
8
+ the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
+ therefore uses a custom buffer design.
10
+
11
+ Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
+ attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
+ heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
+ head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
+ valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
+ buffer_capacity property and is derived directly from self.keys rather than tracked
17
+ as a separate variable.
18
+
19
+ The primary interface is update(key_states, value_states, active_mask), which accepts
20
+ expert-choice layout, stores only active entries in causal order, and returns the full
21
+ accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
+ active_mask identifies valid cached positions; everything beyond each slot's count is
23
+ junk data that downstream attention must exclude.
24
+
25
+ BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
+ exposed by get_heads_lengths() must be read before update() if the caller needs the
27
+ pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
+ in-place and the pre-update values are not recoverable afterward.
29
+
30
+ All buffers are allocated at construction time. MoSRAHCache is constructed by
31
+ ShramLayerCache, which has access to batch size, device, and all model config parameters
32
+ needed to fully specify the storage layout upfront.
33
+ """
34
+
35
+ import torch
36
+ from transformers.cache_utils import CacheLayerMixin
37
+
38
+
39
+ class MoSRAHCache(CacheLayerMixin):
40
+ """KV cache for the MoSRAH sparse attention path — single decoder layer.
41
+
42
+ Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
43
+ Stores keys and values in the mixin-standard self.keys and self.values attributes
44
+ using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
45
+ which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.
46
+
47
+ All storage is allocated at construction time and is_initialized is True
48
+ immediately. The caller (ShramLayerCache) provides batch size, device, and model
49
+ config parameters so no lazy allocation is needed.
50
+
51
+ Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
52
+ (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
53
+ This matches the packed representation produced by expert packing in the MoSRAH
54
+ forward pass, where BEA has already applied RoPE before calling update().
55
+
56
+ Args:
57
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
58
+ second dimension of all storage tensors.
59
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
60
+ dimension of all storage tensors.
61
+ batch_size: Number of sequences in the batch. Determines the first dimension
62
+ of all storage tensors.
63
+ device: Device on which to allocate all tensors. Should match the model device.
64
+ initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
65
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation
66
+ during prompt processing.
67
+ """
68
+
69
+ is_compileable = False
70
+ is_sliding = False
71
+
72
+ def __init__(
73
+ self,
74
+ num_mosrah_heads: int,
75
+ head_dim: int,
76
+ batch_size: int,
77
+ device: torch.device,
78
+ initial_buffer_size: int = 64,
79
+ ) -> None:
80
+ super().__init__()
81
+ self.num_mosrah_heads = num_mosrah_heads
82
+ self.head_dim = head_dim
83
+ self.batch_size = batch_size
84
+ self.device = device
85
+
86
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
87
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
88
+ # tracks valid occupancy per (batch, head) slot.
89
+ self.keys: torch.Tensor = torch.zeros(
90
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
91
+ )
92
+ self.values: torch.Tensor = torch.zeros(
93
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
94
+ )
95
+ self._counts: torch.Tensor = torch.zeros(
96
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
97
+ )
98
+
99
+ # Storage is fully allocated at construction — the cache is initialized.
100
+ self.is_initialized = True
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Properties
104
+ # ---------------------------------------------------------------------------
105
+
106
+ @property
107
+ def buffer_capacity(self) -> int:
108
+ """Current number of slots allocated per (batch, head) pair.
109
+
110
+ Derived directly from self.keys rather than tracked separately, so it is
111
+ always consistent with the actual buffer after expansion.
112
+ """
113
+ return self.keys.shape[2]
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Primary API
117
+ # ---------------------------------------------------------------------------
118
+
119
+ def update( # type: ignore[override]
120
+ self,
121
+ key_states: torch.Tensor,
122
+ value_states: torch.Tensor,
123
+ active_mask: torch.Tensor,
124
+ cache_kwargs: dict | None = None,
125
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
+ """Scatter active key/value states into the buffer and return the full cache state.
127
+
128
+ Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
129
+ active_mask is (B, L, T) bool with True marking real tokens. Only active
130
+ positions are written; inactive positions are ignored.
131
+
132
+ Uses a cumsum construction to derive the absolute buffer position for each
133
+ active token without any Python loops. For a given (batch, head) slot,
134
+ positions are assigned in the order tokens appear along the T dimension,
135
+ preserving causal ordering.
136
+
137
+ Returns the full accumulated (keys, values, active_mask) across the cached
138
+ sparse sequence. The returned active_mask is True exactly for slots t <
139
+ counts[b, l]; everything beyond is junk data that BEA must exclude.
140
+
141
+ Note: get_heads_lengths() must be called before update() if the caller needs
142
+ the pre-update occupancy for position computation (Unit 10.A). update()
143
+ increments counts in-place and the pre-update values are not recoverable.
144
+
145
+ Args:
146
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
147
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
148
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
149
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
150
+
151
+ Returns:
152
+ Tuple of (keys, values, active_mask):
153
+ keys: (B, L, T, u) float — full key buffer including junk slots.
154
+ values: (B, L, T, u) float — full value buffer including junk slots.
155
+ active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
156
+ """
157
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
158
+
159
+ if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
160
+ self._expand()
161
+
162
+ # Cumulative count of active positions along T for each (b, l) slot. Entry
163
+ # [b, l, t] is the 1-based rank of position t among all active positions in
164
+ # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
165
+ # produce a negative value, which is excluded by the mask gate below.
166
+ within_slot = active_mask.long().cumsum(dim=2) - 1 # (B, L, T)
167
+
168
+ # Add the pre-update count to get the absolute buffer position for each
169
+ # active token.
170
+ abs_pos = within_slot + self._counts.unsqueeze(-1) # (B, L, T)
171
+
172
+ # Scatter key and value vectors at all active positions.
173
+ b_idx, l_idx, t_idx = torch.where(active_mask)
174
+ self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
175
+ key_states[b_idx, l_idx, t_idx]
176
+ )
177
+ self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
178
+ value_states[b_idx, l_idx, t_idx]
179
+ )
180
+
181
+ self._counts += incoming_delta
182
+
183
+ return self.keys, self.values, self._make_active_mask()
184
+
185
+ def get_heads_lengths(self) -> torch.Tensor:
186
+ """Return the per-(batch, head) token count for this layer.
187
+
188
+ This is the authoritative occupancy tensor consumed by BEA for attention
189
+ masking and by position computation (Unit 10.A) for semantic-sequence
190
+ position computation.
191
+
192
+ Note: in the MoSRAH forward pass, this must be called before update() if the
193
+ caller needs the pre-update occupancy. update() increments these counts in-place.
194
+
195
+ Returns:
196
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
197
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
198
+ """
199
+ return self._counts
200
+
201
+ # ---------------------------------------------------------------------------
202
+ # CacheLayerMixin — overridden coordination methods
203
+ # ---------------------------------------------------------------------------
204
+
205
+ def reset(self) -> None:
206
+ """Clear all cached key and value tensors.
207
+
208
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
209
+ and is_initialized remains True — only the contents are cleared.
210
+ """
211
+ self.keys.zero_()
212
+ self.values.zero_()
213
+ self._counts.zero_()
214
+
215
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
216
+ """Reorder the batch dimension of all cached tensors for beam search.
217
+
218
+ Applied atomically across self.keys, self.values, and _counts. Beam search
219
+ must reorder all three together or the occupancy counts and buffer contents
220
+ will correspond to different beam hypotheses.
221
+
222
+ Overrides the parent because the parent's implementation calls get_seq_length(),
223
+ which is not supported for this cache.
224
+
225
+ Args:
226
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
227
+ search algorithm.
228
+ """
229
+ self.keys = self.keys[beam_idx]
230
+ self.values = self.values[beam_idx]
231
+ self._counts = self._counts[beam_idx]
232
+
233
+ def batch_repeat_interleave(self, repeats: int) -> None:
234
+ """Expand the batch dimension by repeating each entry repeats times.
235
+
236
+ Used at beam search initialisation to expand the cache from batch size B to
237
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
238
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
239
+
240
+ Args:
241
+ repeats: Number of times to repeat each batch entry.
242
+ """
243
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
244
+ self.values = self.values.repeat_interleave(repeats, dim=0)
245
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
246
+ self.batch_size = self.batch_size * repeats
247
+
248
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
249
+ """Select a subset of batch entries by index.
250
+
251
+ Used in contrastive search to retain only the selected candidate entries.
252
+ Applied atomically across keys, values, and _counts; batch_size is updated
253
+ to reflect the number of retained entries.
254
+
255
+ Args:
256
+ indices: 1-D integer tensor of batch indices to retain.
257
+ """
258
+ self.keys = self.keys[indices]
259
+ self.values = self.values[indices]
260
+ self._counts = self._counts[indices]
261
+ self.batch_size = indices.shape[0]
262
+
263
+ def offload(self) -> None:
264
+ """Offload all cached tensors to CPU.
265
+
266
+ Extends the parent to also offload _counts, which the parent does not know
267
+ about. All three tensors are moved atomically so device state remains consistent.
268
+ """
269
+ super().offload()
270
+ self._counts = self._counts.to("cpu", non_blocking=True)
271
+
272
+ def prefetch(self) -> None:
273
+ """Move all cached tensors back to the model device ahead of time.
274
+
275
+ Extends the parent to also prefetch _counts, which the parent does not know
276
+ about. _counts is synced to self.keys.device after the parent moves keys and
277
+ values, so all three remain consistent.
278
+ """
279
+ super().prefetch()
280
+ if self._counts.device != self.keys.device:
281
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
282
+
283
+ def lazy_initialization( # type: ignore[override]
284
+ self, key_states: torch.Tensor, value_states: torch.Tensor
285
+ ) -> None:
286
+ """No-op — storage is fully allocated at construction time."""
287
+ pass
288
+
289
+ # ---------------------------------------------------------------------------
290
+ # CacheLayerMixin — unsupported abstract methods
291
+ # ---------------------------------------------------------------------------
292
+
293
+ def get_seq_length(self) -> int: # type: ignore[override]
294
+ """Not supported — no single sequence length represents this cache's state.
295
+
296
+ MoSRAH heads accumulate independently; (batch, head) slots have different
297
+ lengths depending on routing history. There is no meaningful scalar summary.
298
+ Use get_heads_lengths() for per-head occupancy.
299
+ """
300
+ raise NotImplementedError(
301
+ "MoSRAHCache has no single sequence length. "
302
+ "Use get_heads_lengths() for per-head occupancy."
303
+ )
304
+
305
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
306
+ """Not supported — MoSRAHCache is dynamic and unbounded."""
307
+ raise NotImplementedError(
308
+ "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
309
+ )
310
+
311
+ def get_mask_sizes( # type: ignore[override]
312
+ self,
313
+ cache_position: torch.Tensor,
314
+ ) -> tuple[int, int]:
315
+ """Not supported — MoSRAHCache does not participate in HF mask construction."""
316
+ raise NotImplementedError(
317
+ "MoSRAHCache does not support get_mask_sizes()."
318
+ )
319
+
320
+ # ---------------------------------------------------------------------------
321
+ # Internal helpers
322
+ # ---------------------------------------------------------------------------
323
+
324
+ def _make_active_mask(self) -> torch.Tensor:
325
+ """Construct the (B, L, T) active mask from current counts.
326
+
327
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
328
+ has been written. Positions at or beyond the count are junk and must be
329
+ excluded by downstream attention.
330
+ """
331
+ cap = self.buffer_capacity
332
+ return (
333
+ torch.arange(cap, device=self.keys.device)
334
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
335
+ < self._counts.unsqueeze(-1)
336
+ )
337
+
338
+ def _expand(self) -> None:
339
+ """Double the buffer capacity, preserving existing data.
340
+
341
+ Called by update() when an incoming batch of tokens would cause any
342
+ (batch, head) slot to exceed the current buffer capacity. All existing
343
+ key and value data is copied into the low half of the new buffer; the
344
+ high half is zero-initialised and will be filled by subsequent writes.
345
+ After reassignment, buffer_capacity reflects the new size automatically.
346
+ """
347
+ old_cap = self.buffer_capacity
348
+ new_cap = old_cap * 2
349
+ dev = self.keys.device
350
+ new_keys = torch.zeros(
351
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
352
+ )
353
+ new_values = torch.zeros(
354
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
355
+ )
356
+ new_keys[:, :, :old_cap, :] = self.keys
357
+ new_values[:, :, :old_cap, :] = self.values
358
+ self.keys = new_keys
359
+ self.values = new_values
__cache__shram_cache.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
+
3
+ The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
+ CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
+ lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
+ ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
+ presents them through the Cache interface, and transparently forwards model-wide operations
8
+ across all of them.
9
+
10
+ ShramCache does not define a composite update() interface. The two attention paths inside each
11
+ SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
+ nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
+ relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
+ coordination of the layer caches — not routing attention inputs.
15
+
16
+ Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
+ specified layer, which tracks the cumulative count of token positions processed. This is
18
+ what HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import Cache
23
+
24
+ from .__cache__shram_layer_cache import ShramLayerCache
25
+
26
+
27
+ class ShramCache(Cache):
28
+ """Top-level cache for the full SHRAM model.
29
+
30
+ Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
31
+ role and transparently forwards reset, reorder, and sequence-length queries across all
32
+ owned layer caches.
33
+
34
+ No composite update() interface is provided. The two attention paths inside each SHRAM
35
+ layer have materially different update semantics; callers must update sub-caches directly
36
+ via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
37
+
38
+ Args:
39
+ num_hidden_layers: Number of SHRAM decoder layers. Determines how many
40
+ ShramLayerCache objects are constructed.
41
+ sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
42
+ num_local_heads: Number of local attention heads per layer.
43
+ local_head_dim: Per-head embedding width for the local path.
44
+ num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
45
+ mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
46
+ batch_size: Number of sequences in the batch.
47
+ device: Device on which to allocate cache tensors.
48
+ initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
49
+ Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
50
+ during prompt processing.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ num_hidden_layers: int,
56
+ sliding_window: int,
57
+ num_local_heads: int,
58
+ local_head_dim: int,
59
+ num_mosrah_heads: int,
60
+ mosrah_head_dim: int,
61
+ batch_size: int,
62
+ device: torch.device,
63
+ initial_buffer_size: int = 64,
64
+ ) -> None:
65
+ layers = [
66
+ ShramLayerCache(
67
+ sliding_window=sliding_window,
68
+ num_local_heads=num_local_heads,
69
+ local_head_dim=local_head_dim,
70
+ num_mosrah_heads=num_mosrah_heads,
71
+ mosrah_head_dim=mosrah_head_dim,
72
+ batch_size=batch_size,
73
+ device=device,
74
+ initial_buffer_size=initial_buffer_size,
75
+ )
76
+ for _ in range(num_hidden_layers)
77
+ ]
78
+ super().__init__(layers=layers)
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Cache — composite-meaningful methods
82
+ # ---------------------------------------------------------------------------
83
+ #
84
+ # reset(): Inherited. Iterates all layer caches and calls reset() on each.
85
+ #
86
+ # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
87
+ #
88
+ # is_initialized: Inherited property. True iff all layer caches are initialized.
89
+ # Since ShramLayerCache.is_initialized is True from construction, this is True
90
+ # immediately after ShramCache.__init__ returns.
91
+
92
+ def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
93
+ """Return the cumulative sequence length for the specified layer.
94
+
95
+ Delegates to the layer cache at layer_idx, which in turn delegates to the
96
+ local sliding-window sub-cache. That sub-cache is authoritative for sequence
97
+ progress: it sees every token presented to the layer and accumulates a truthful
98
+ total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
99
+ """
100
+ return self.layers[layer_idx].get_seq_length()
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # Cache — unsupported methods
104
+ # ---------------------------------------------------------------------------
105
+
106
+ def update( # type: ignore[override]
107
+ self,
108
+ key_states: torch.Tensor,
109
+ value_states: torch.Tensor,
110
+ layer_idx: int,
111
+ cache_kwargs: dict | None = None,
112
+ ) -> tuple[torch.Tensor, torch.Tensor]:
113
+ """Not supported — ShramCache has no composite update interface.
114
+
115
+ The two attention paths inside each SHRAM layer have different update semantics.
116
+ Callers must update sub-caches directly:
117
+ cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
118
+ cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
119
+ """
120
+ raise NotImplementedError(
121
+ "ShramCache has no composite update interface. "
122
+ "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
123
+ )
124
+
125
+ def crop(self, max_length: int) -> None:
126
+ """Not supported — ShramCache layers do not implement crop()."""
127
+ raise NotImplementedError("ShramCache does not support crop().")
128
+
129
+ @property
130
+ def max_batch_size(self) -> int:
131
+ """Not supported — ShramCache does not track a uniform batch size across layers."""
132
+ raise NotImplementedError("ShramCache does not expose max_batch_size.")
133
+
134
+ @property
135
+ def max_cache_len(self) -> int:
136
+ """Not supported — ShramCache has no single maximum cache length.
137
+
138
+ The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
139
+ No truthful scalar maximum represents the composite.
140
+ """
141
+ raise NotImplementedError("ShramCache does not expose max_cache_len.")
__cache__shram_layer_cache.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
+
3
+ A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
+ local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
+ different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
+ the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
+ path can interact with it without indirection.
8
+
9
+ ShramLayerCache does not define a composite update() interface. The two paths have materially
10
+ different update semantics — the local side uses chunk-local key/value/mask concatenation
11
+ while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
+ behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
+ each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
+ ownership, coordination, and reset/reorder boundary for one decoder layer.
15
+
16
+ Sequence length at this boundary is reported by delegating to the local sliding-window
17
+ sub-cache, which tracks the cumulative count of token positions processed. This is the
18
+ quantity HuggingFace generation reads through get_seq_length().
19
+ """
20
+
21
+ import torch
22
+ from transformers.cache_utils import CacheLayerMixin
23
+
24
+ from .__cache__mosrah_cache import MoSRAHCache
25
+ from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
26
+
27
+
28
+ class ShramLayerCache(CacheLayerMixin):
29
+ """Cache subsystem for one SHRAM decoder layer.
30
+
31
+ Owns and coordinates two sub-caches:
32
+ - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
33
+ - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.
34
+
35
+ Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
36
+ exposed directly for their downstream attention paths — no composite update() interface is
37
+ provided, because the two paths have materially different update semantics.
38
+
39
+ Sequence length is reported by delegating to the local sliding-window sub-cache, which
40
+ tracks the cumulative count of token positions processed across all update() calls.
41
+
42
+ Args:
43
+ sliding_window: Number of tokens retained by the local sliding-window cache.
44
+ num_local_heads: Number of local attention heads.
45
+ local_head_dim: Per-head embedding width for the local path.
46
+ num_mosrah_heads: Total number of MoSRAH expert heads (L).
47
+ mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
48
+ batch_size: Number of sequences in the batch.
49
+ device: Device on which to allocate cache tensors.
50
+ initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
51
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation during
52
+ prompt processing.
53
+ """
54
+
55
+ is_compileable = False
56
+ is_sliding = False
57
+
58
+ def __init__(
59
+ self,
60
+ sliding_window: int,
61
+ num_local_heads: int,
62
+ local_head_dim: int,
63
+ num_mosrah_heads: int,
64
+ mosrah_head_dim: int,
65
+ batch_size: int,
66
+ device: torch.device,
67
+ initial_buffer_size: int = 64,
68
+ ) -> None:
69
+ super().__init__()
70
+ self.sliding_window_cache = LocalSlidingWindowLayerCache(
71
+ sliding_window=sliding_window,
72
+ num_heads=num_local_heads,
73
+ head_dim=local_head_dim,
74
+ batch_size=batch_size,
75
+ device=device,
76
+ )
77
+ self.mosrah_cache = MoSRAHCache(
78
+ num_mosrah_heads=num_mosrah_heads,
79
+ head_dim=mosrah_head_dim,
80
+ batch_size=batch_size,
81
+ device=device,
82
+ initial_buffer_size=initial_buffer_size,
83
+ )
84
+
85
+ # ---------------------------------------------------------------------------
86
+ # Properties
87
+ # ---------------------------------------------------------------------------
88
+
89
+ @property
90
+ def is_initialized(self) -> bool:
91
+ """True iff both sub-caches have allocated their storage.
92
+
93
+ Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
94
+ so this is True immediately after ShramLayerCache.__init__ returns.
95
+ """
96
+ return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized
97
+
98
+ @is_initialized.setter
99
+ def is_initialized(self, value: bool) -> None:
100
+ # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
101
+ # attribute. Since property is a data descriptor it takes precedence, but Python
102
+ # still routes the assignment through __set__. Absorb it silently — state is
103
+ # derived from sub-caches, not stored here.
104
+ pass
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # CacheLayerMixin — composite-meaningful methods
108
+ # ---------------------------------------------------------------------------
109
+
110
+ def get_seq_length(self) -> int: # type: ignore[override]
111
+ """Return the cumulative sequence length from the local sliding-window path.
112
+
113
+ The local path is authoritative for sequence progress: it sees every token
114
+ presented to this layer and accumulates a truthful total. Delegates to
115
+ sliding_window_cache.get_seq_length().
116
+ """
117
+ return self.sliding_window_cache.get_seq_length()
118
+
119
+ def reset(self) -> None:
120
+ """Clear both sub-caches.
121
+
122
+ Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
123
+ state and MoSRAH sparse state remain consistent.
124
+ """
125
+ self.sliding_window_cache.reset()
126
+ self.mosrah_cache.reset()
127
+
128
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
129
+ """Reorder the batch dimension of both sub-caches for beam search.
130
+
131
+ Delegates to each sub-cache. Both are reordered atomically so the sliding-window
132
+ and MoSRAH state correspond to the same beam hypotheses after reordering.
133
+
134
+ Args:
135
+ beam_idx: Permutation indices of shape (batch,) produced by beam search.
136
+ """
137
+ self.sliding_window_cache.reorder_cache(beam_idx)
138
+ self.mosrah_cache.reorder_cache(beam_idx)
139
+
140
+ def batch_repeat_interleave(self, repeats: int) -> None:
141
+ """Expand the batch dimension of both sub-caches for beam search initialisation.
142
+
143
+ Delegates atomically to each sub-cache. Both must be expanded together so the
144
+ sliding-window and MoSRAH state correspond to the same beam candidates.
145
+
146
+ Args:
147
+ repeats: Number of times to repeat each batch entry.
148
+ """
149
+ self.sliding_window_cache.batch_repeat_interleave(repeats)
150
+ self.mosrah_cache.batch_repeat_interleave(repeats)
151
+
152
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
153
+ """Select a subset of batch entries in both sub-caches for contrastive search.
154
+
155
+ Delegates atomically to each sub-cache. Both must be trimmed together so the
156
+ sliding-window and MoSRAH state remain consistent.
157
+
158
+ Args:
159
+ indices: 1-D integer tensor of batch indices to retain.
160
+ """
161
+ self.sliding_window_cache.batch_select_indices(indices)
162
+ self.mosrah_cache.batch_select_indices(indices)
163
+
164
+ def offload(self) -> None:
165
+ """Offload both sub-caches to CPU.
166
+
167
+ Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
168
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
169
+ """
170
+ self.sliding_window_cache.offload()
171
+ self.mosrah_cache.offload()
172
+
173
+ def prefetch(self) -> None:
174
+ """Move both sub-caches back to their model device ahead of time.
175
+
176
+ Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
177
+ does not own self.keys/self.values directly; all cached data lives in the sub-caches.
178
+ """
179
+ self.sliding_window_cache.prefetch()
180
+ self.mosrah_cache.prefetch()
181
+
182
+ def lazy_initialization( # type: ignore[override]
183
+ self, key_states: torch.Tensor, value_states: torch.Tensor
184
+ ) -> None:
185
+ """No-op — both sub-caches handle their own initialization."""
186
+ pass
187
+
188
+ # ---------------------------------------------------------------------------
189
+ # CacheLayerMixin — unsupported abstract methods
190
+ # ---------------------------------------------------------------------------
191
+
192
+ def update( # type: ignore[override]
193
+ self,
194
+ key_states: torch.Tensor,
195
+ value_states: torch.Tensor,
196
+ cache_kwargs: dict | None = None,
197
+ ) -> tuple[torch.Tensor, torch.Tensor]:
198
+ """Not supported — ShramLayerCache has no composite update interface.
199
+
200
+ The two sub-caches have materially different update semantics: the sliding-window
201
+ side uses standard key/value concatenation while the MoSRAH side uses expert-choice
202
+ scatter with an active mask. Callers must update each sub-cache directly via
203
+ sliding_window_cache.update() or mosrah_cache.update().
204
+ """
205
+ raise NotImplementedError(
206
+ "ShramLayerCache has no composite update interface. "
207
+ "Update sliding_window_cache or mosrah_cache directly."
208
+ )
209
+
210
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
211
+ """Not supported — the composite cache has no single maximum shape.
212
+
213
+ The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
214
+ unbounded. No truthful scalar maximum represents the composite.
215
+ """
216
+ raise NotImplementedError(
217
+ "ShramLayerCache has no single maximum cache shape. "
218
+ "Query sliding_window_cache or mosrah_cache directly."
219
+ )
220
+
221
+ def get_mask_sizes( # type: ignore[override]
222
+ self,
223
+ cache_position: torch.Tensor,
224
+ ) -> tuple[int, int]:
225
+ """Not supported — ShramLayerCache does not participate in HF mask construction.
226
+
227
+ The two sub-caches have different mask semantics and their respective attention
228
+ paths handle masking directly.
229
+ """
230
+ raise NotImplementedError(
231
+ "ShramLayerCache does not support get_mask_sizes(). "
232
+ "Query sliding_window_cache or mosrah_cache directly."
233
+ )
__cache__sliding_window_cache.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/shram/model/cache/sliding_window_cache.py
2
+
3
+ """Local sliding-window cache for the SHRAM local attention path.
4
+
5
+ This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
+ `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
+
8
+ Its job is narrow:
9
+
10
+ - accept the current chunk's local key/value tensors and active mask
11
+ - return the current-step local frame consumed by local attention
12
+ - separately retain the next-step sliding-window cache state
13
+
14
+ It does not decide local causal visibility. That is owned by
15
+ `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
+ constructs the effective local attention mask from it.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class LocalSlidingWindowLayerCache(CacheLayerMixin):
24
+ """Fixed-width local cache for one SHRAM decoder layer.
25
+
26
+ The cache keeps a retained local sliding-window buffer and an aligned active
27
+ mask. On update, it returns the current-step local frame formed by
28
+ concatenating retained cache state with the new chunk, then remembers only
29
+ the last `sliding_window` positions for the next step.
30
+
31
+ Dead positions are allowed to remain in both the returned frame and the
32
+ retained cache. Correctness is carried by the aligned active mask.
33
+
34
+ Args:
35
+ sliding_window: Width of the retained local sliding-window buffer.
36
+ num_heads: Number of local attention heads.
37
+ head_dim: Per-head embedding width for the local path.
38
+ batch_size: Number of sequences in the batch.
39
+ device: Device on which to allocate cache storage.
40
+ """
41
+
42
+ is_compileable = False
43
+ is_sliding = True
44
+
45
+ def __init__(
46
+ self,
47
+ sliding_window: int,
48
+ num_heads: int,
49
+ head_dim: int,
50
+ batch_size: int,
51
+ device: torch.device,
52
+ ) -> None:
53
+ super().__init__()
54
+
55
+ if sliding_window < 1:
56
+ raise ValueError(
57
+ f"sliding_window must be >= 1, got {sliding_window}."
58
+ )
59
+ if num_heads < 1:
60
+ raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
61
+ if head_dim < 1:
62
+ raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
63
+ if batch_size < 1:
64
+ raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
65
+
66
+ self.sliding_window = sliding_window
67
+ self.num_heads = num_heads
68
+ self.head_dim = head_dim
69
+ self.batch_size = batch_size
70
+ self.device = device
71
+
72
+ # Retained next-step local cache state. Storage is fixed-width from the
73
+ # start; semantic validity is carried by `active_mask`.
74
+ self.keys = torch.zeros(
75
+ batch_size,
76
+ num_heads,
77
+ sliding_window,
78
+ head_dim,
79
+ device=device,
80
+ )
81
+ self.values = torch.zeros(
82
+ batch_size,
83
+ num_heads,
84
+ sliding_window,
85
+ head_dim,
86
+ device=device,
87
+ )
88
+ self.active_mask = torch.zeros(
89
+ batch_size,
90
+ sliding_window,
91
+ dtype=torch.bool,
92
+ device=device,
93
+ )
94
+
95
+ self.is_initialized = True
96
+
97
+ # Cumulative count of all token positions presented through update() for
98
+ # this cache instance. This is the quantity HuggingFace generation reads
99
+ # through get_seq_length() to track how far along the sequence we are.
100
+ self._total_processed: int = 0
101
+
102
+ def update( # type: ignore[override]
103
+ self,
104
+ key_states: torch.Tensor,
105
+ value_states: torch.Tensor,
106
+ active_mask: torch.Tensor,
107
+ cache_kwargs: dict | None = None,
108
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
+ """Return the current-step local frame and retain the next-step window.
110
+
111
+ Args:
112
+ key_states: Shape `(B, H, T_new, D)` local key vectors for the
113
+ current chunk.
114
+ value_states: Shape `(B, H, T_new, D)` local value vectors for the
115
+ current chunk.
116
+ active_mask: Shape `(B, T_new)` bool. `True` means the
117
+ corresponding token position in the current chunk is active.
118
+ cache_kwargs: Present only to satisfy the `CacheLayerMixin`
119
+ interface. Unused by this cache.
120
+
121
+ Returns:
122
+ Tuple of:
123
+ - visible_keys: `(B, H, sliding_window + T_new, D)`
124
+ - visible_values: `(B, H, sliding_window + T_new, D)`
125
+ - visible_active_mask: `(B, sliding_window + T_new)`
126
+
127
+ These are the tensors the local attention path should consume
128
+ directly for the current step.
129
+ """
130
+ self._ensure_state_compatibility(
131
+ key_states=key_states,
132
+ value_states=value_states,
133
+ )
134
+
135
+ # The current-step local frame is just retained cache state followed by
136
+ # the current chunk in chronological order.
137
+ composite_keys, composite_values, composite_mask = self._make_composite_frame(
138
+ key_states=key_states,
139
+ value_states=value_states,
140
+ active_mask=active_mask,
141
+ )
142
+
143
+ # The cache remembers only the last raw sliding-window positions of that
144
+ # composite frame for the next step. Dead positions are allowed to
145
+ # survive; downstream local attention will ignore them using the mask.
146
+ self._retain_next_window(
147
+ composite_keys=composite_keys,
148
+ composite_values=composite_values,
149
+ composite_mask=composite_mask,
150
+ )
151
+
152
+ self._total_processed += key_states.shape[2]
153
+
154
+ return composite_keys, composite_values, composite_mask
155
+
156
+ def _ensure_state_compatibility(
157
+ self,
158
+ key_states: torch.Tensor,
159
+ value_states: torch.Tensor,
160
+ ) -> None:
161
+ """Keep retained cache buffers compatible with the incoming update tensors.
162
+
163
+ The cache is allocated eagerly for simplicity. If later updates arrive on
164
+ a different device or in a different floating dtype, move the retained
165
+ state to match while preserving its contents.
166
+ """
167
+ if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
168
+ self.keys = self.keys.to(
169
+ device=key_states.device,
170
+ dtype=key_states.dtype,
171
+ )
172
+
173
+ if (
174
+ self.values.dtype != value_states.dtype
175
+ or self.values.device != value_states.device
176
+ ):
177
+ self.values = self.values.to(
178
+ device=value_states.device,
179
+ dtype=value_states.dtype,
180
+ )
181
+
182
+ if self.active_mask.device != key_states.device:
183
+ self.active_mask = self.active_mask.to(
184
+ key_states.device,
185
+ non_blocking=True,
186
+ )
187
+
188
+ def _make_composite_frame(
189
+ self,
190
+ key_states: torch.Tensor,
191
+ value_states: torch.Tensor,
192
+ active_mask: torch.Tensor,
193
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
+ """Build the current-step local frame in chronological order."""
195
+ return (
196
+ torch.cat([self.keys, key_states], dim=-2),
197
+ torch.cat([self.values, value_states], dim=-2),
198
+ torch.cat([self.active_mask, active_mask], dim=-1),
199
+ )
200
+
201
+ def _retain_next_window(
202
+ self,
203
+ composite_keys: torch.Tensor,
204
+ composite_values: torch.Tensor,
205
+ composite_mask: torch.Tensor,
206
+ ) -> None:
207
+ """Remember the next-step retained local state.
208
+
209
+ This is a raw positional trim to the last `sliding_window` positions, not
210
+ a semantic live-token trim.
211
+ """
212
+ self.keys = composite_keys[:, :, -self.sliding_window :, :]
213
+ self.values = composite_values[:, :, -self.sliding_window :, :]
214
+ self.active_mask = composite_mask[:, -self.sliding_window :]
215
+
216
+ def get_seq_length(self) -> int:
217
+ """Return the cumulative number of token positions processed by this cache.
218
+
219
+ This is the total count of token positions presented across all update()
220
+ calls since construction or the last reset(). It is the quantity HuggingFace
221
+ generation reads to track sequence progress and is not the same as active-token
222
+ count or current window occupancy.
223
+ """
224
+ return self._total_processed
225
+
226
+ def get_max_cache_shape(self) -> int:
227
+ return self.sliding_window
228
+
229
+ def get_mask_sizes( # type: ignore[override]
230
+ self,
231
+ cache_position: torch.Tensor,
232
+ ) -> tuple[int, int]:
233
+ raise NotImplementedError(
234
+ "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
235
+ )
236
+
237
+ def reset(self) -> None:
238
+ """Restore fresh-cache behavior."""
239
+ self.keys.zero_()
240
+ self.values.zero_()
241
+ self.active_mask.zero_()
242
+ self._total_processed = 0
243
+
244
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
245
+ """Reorder the batch dimension for beam search."""
246
+ self.keys = self.keys[beam_idx]
247
+ self.values = self.values[beam_idx]
248
+ self.active_mask = self.active_mask[beam_idx]
249
+
250
+ def batch_repeat_interleave(self, repeats: int) -> None:
251
+ """Expand the batch dimension for beam-search initialisation."""
252
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
253
+ self.values = self.values.repeat_interleave(repeats, dim=0)
254
+ self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
255
+ self.batch_size = self.batch_size * repeats
256
+
257
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
258
+ """Select a subset of batch entries for contrastive search."""
259
+ self.keys = self.keys[indices]
260
+ self.values = self.values[indices]
261
+ self.active_mask = self.active_mask[indices]
262
+ self.batch_size = int(indices.shape[0])
263
+
264
+ def offload(self) -> None:
265
+ """Offload cache tensors to CPU."""
266
+ super().offload()
267
+ self.active_mask = self.active_mask.to("cpu", non_blocking=True)
268
+
269
+ def prefetch(self) -> None:
270
+ """Move cache tensors back to the model device ahead of time."""
271
+ super().prefetch()
272
+ if self.active_mask.device != self.keys.device:
273
+ self.active_mask = self.active_mask.to(
274
+ self.keys.device,
275
+ non_blocking=True,
276
+ )
277
+
278
+ def crop(self, max_length: int) -> None:
279
+ raise NotImplementedError(
280
+ "LocalSlidingWindowLayerCache does not support crop()."
281
+ )
282
+
283
+ def lazy_initialization(
284
+ self,
285
+ key_states: torch.Tensor,
286
+ value_states: torch.Tensor,
287
+ ) -> None:
288
+ """No-op — this cache allocates its fixed buffers at construction time."""
289
+ return
__cache__slow_mosrah_cache.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
+
3
+ This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
+ interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
+ (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
+ position's key and value are written to the next available slot for that (batch, head)
7
+ pair, in the order positions appear along the T dimension, which directly enforces
8
+ causal ordering without any index arithmetic to verify.
9
+
10
+ SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
+ trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
+ Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
+ vectorized implementation is validated by asserting exact agreement with this one on all
14
+ test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
+ (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
+ an oracle.
17
+ """
18
+
19
+ import torch
20
+ from transformers.cache_utils import CacheLayerMixin
21
+
22
+
23
+ class SlowMoSRAHCache(CacheLayerMixin):
24
+ """Unvectorized reference implementation of the MoSRAH KV cache.
25
+
26
+ Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
+ mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
+ with the same constructor signature and the same CacheLayerMixin protocol methods.
29
+ The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
+ triples rather than vectorized index arithmetic.
31
+
32
+ This class is not used in the model path. It exists so that MoSRAHCache.update()
33
+ can be validated by asserting exact agreement with this implementation on all test
34
+ inputs. See module docstring for the trust chain this enables.
35
+
36
+ Args:
37
+ num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
+ second dimension of all storage tensors.
39
+ head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
+ dimension of all storage tensors.
41
+ batch_size: Number of sequences in the batch. Determines the first dimension
42
+ of all storage tensors.
43
+ device: Device on which to allocate all tensors. Should match the model device.
44
+ initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
+ when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
+ during prompt processing.
47
+ """
48
+
49
+ is_compileable = False
50
+ is_sliding = False
51
+
52
+ def __init__(
53
+ self,
54
+ num_mosrah_heads: int,
55
+ head_dim: int,
56
+ batch_size: int,
57
+ device: torch.device,
58
+ initial_buffer_size: int = 64,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.num_mosrah_heads = num_mosrah_heads
62
+ self.head_dim = head_dim
63
+ self.batch_size = batch_size
64
+ self.device = device
65
+
66
+ # Allocate primary storage into the mixin-standard self.keys / self.values so
67
+ # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
+ # tracks valid occupancy per (batch, head) slot.
69
+ self.keys: torch.Tensor = torch.zeros(
70
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
+ )
72
+ self.values: torch.Tensor = torch.zeros(
73
+ batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
+ )
75
+ self._counts: torch.Tensor = torch.zeros(
76
+ batch_size, num_mosrah_heads, dtype=torch.long, device=device
77
+ )
78
+
79
+ # Storage is fully allocated at construction — the cache is initialized.
80
+ self.is_initialized = True
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # Properties
84
+ # ---------------------------------------------------------------------------
85
+
86
+ @property
87
+ def buffer_capacity(self) -> int:
88
+ """Current number of slots allocated per (batch, head) pair.
89
+
90
+ Derived directly from self.keys rather than tracked separately, so it is
91
+ always consistent with the actual buffer after expansion.
92
+ """
93
+ return self.keys.shape[2]
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Primary API
97
+ # ---------------------------------------------------------------------------
98
+
99
+ def update( # type: ignore[override]
100
+ self,
101
+ key_states: torch.Tensor,
102
+ value_states: torch.Tensor,
103
+ active_mask: torch.Tensor,
104
+ cache_kwargs: dict | None = None,
105
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
106
+ """Scatter active key/value states using an explicit loop; return full cache state.
107
+
108
+ Iterates over every (b, l, t) triple. For each position where active_mask is
109
+ True, the key and value are written to the next available slot for that
110
+ (batch, head) pair and the count is incremented. Causal ordering is guaranteed
111
+ because the t dimension is traversed from 0 to T-1 and counts are updated
112
+ immediately after each write.
113
+
114
+ Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
+ the incoming tokens would cause any slot to overflow the current capacity.
116
+
117
+ Args:
118
+ key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
119
+ value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
120
+ active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
121
+ cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
122
+
123
+ Returns:
124
+ Tuple of (keys, values, active_mask):
125
+ keys: (B, L, T, u) float — full key buffer including junk slots.
126
+ values: (B, L, T, u) float — full value buffer including junk slots.
127
+ active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
+ """
129
+ B, L, T = active_mask.shape
130
+
131
+ # Expansion check uses the total active tokens per slot, same as the
132
+ # vectorized implementation, so both expand under identical conditions.
133
+ incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
+ if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
+ self._expand()
136
+
137
+ # Write each active position into the next available slot for its (batch, head)
138
+ # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
139
+ for b in range(B):
140
+ for l in range(L):
141
+ for t in range(T):
142
+ if active_mask[b, l, t]:
143
+ pos = self._counts[b, l].item()
144
+ self.keys[b, l, pos, :] = key_states[b, l, t, :]
145
+ self.values[b, l, pos, :] = value_states[b, l, t, :]
146
+ self._counts[b, l] += 1
147
+
148
+ return self.keys, self.values, self._make_active_mask()
149
+
150
+ def get_heads_lengths(self) -> torch.Tensor:
151
+ """Return the per-(batch, head) token count for this layer.
152
+
153
+ This is the authoritative occupancy tensor consumed by BEA for attention
154
+ masking and by position computation (Unit 10.A) for semantic-sequence
155
+ position computation.
156
+
157
+ Returns:
158
+ Integer tensor of shape (B, L) where entry [b, h] is the number of valid
159
+ tokens stored in the (b, h) slot. Zero for slots with no writes yet.
160
+ """
161
+ return self._counts
162
+
163
+ # ---------------------------------------------------------------------------
164
+ # CacheLayerMixin — overridden coordination methods
165
+ # ---------------------------------------------------------------------------
166
+
167
+ def reset(self) -> None:
168
+ """Clear all cached key and value tensors.
169
+
170
+ Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
171
+ and is_initialized remains True — only the contents are cleared.
172
+ """
173
+ self.keys.zero_()
174
+ self.values.zero_()
175
+ self._counts.zero_()
176
+
177
+ def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
178
+ """Reorder the batch dimension of all cached tensors for beam search.
179
+
180
+ Applied atomically across self.keys, self.values, and _counts. Beam search
181
+ must reorder all three together or the occupancy counts and buffer contents
182
+ will correspond to different beam hypotheses.
183
+
184
+ Overrides the parent because the parent's implementation calls get_seq_length(),
185
+ which is not supported for this cache.
186
+
187
+ Args:
188
+ beam_idx: Permutation indices of shape (batch,) produced by the beam
189
+ search algorithm.
190
+ """
191
+ self.keys = self.keys[beam_idx]
192
+ self.values = self.values[beam_idx]
193
+ self._counts = self._counts[beam_idx]
194
+
195
+ def batch_repeat_interleave(self, repeats: int) -> None:
196
+ """Expand the batch dimension by repeating each entry repeats times.
197
+
198
+ Used at beam search initialisation to expand the cache from batch size B to
199
+ B * repeats, matching the expanded beam candidate batch. Applied atomically
200
+ across keys, values, and _counts; batch_size is updated to reflect the new size.
201
+
202
+ Args:
203
+ repeats: Number of times to repeat each batch entry.
204
+ """
205
+ self.keys = self.keys.repeat_interleave(repeats, dim=0)
206
+ self.values = self.values.repeat_interleave(repeats, dim=0)
207
+ self._counts = self._counts.repeat_interleave(repeats, dim=0)
208
+ self.batch_size = self.batch_size * repeats
209
+
210
+ def batch_select_indices(self, indices: torch.Tensor) -> None:
211
+ """Select a subset of batch entries by index.
212
+
213
+ Used in contrastive search to retain only the selected candidate entries.
214
+ Applied atomically across keys, values, and _counts; batch_size is updated
215
+ to reflect the number of retained entries.
216
+
217
+ Args:
218
+ indices: 1-D integer tensor of batch indices to retain.
219
+ """
220
+ self.keys = self.keys[indices]
221
+ self.values = self.values[indices]
222
+ self._counts = self._counts[indices]
223
+ self.batch_size = indices.shape[0]
224
+
225
+ def offload(self) -> None:
226
+ """Offload all cached tensors to CPU.
227
+
228
+ Extends the parent to also offload _counts, which the parent does not know
229
+ about. All three tensors are moved atomically so device state remains consistent.
230
+ """
231
+ super().offload()
232
+ self._counts = self._counts.to("cpu", non_blocking=True)
233
+
234
+ def prefetch(self) -> None:
235
+ """Move all cached tensors back to the model device ahead of time.
236
+
237
+ Extends the parent to also prefetch _counts, which the parent does not know
238
+ about. _counts is synced to self.keys.device after the parent moves keys and
239
+ values, so all three remain consistent.
240
+ """
241
+ super().prefetch()
242
+ if self._counts.device != self.keys.device:
243
+ self._counts = self._counts.to(self.keys.device, non_blocking=True)
244
+
245
+ def lazy_initialization( # type: ignore[override]
246
+ self, key_states: torch.Tensor, value_states: torch.Tensor
247
+ ) -> None:
248
+ """No-op — storage is fully allocated at construction time."""
249
+ pass
250
+
251
+ # ---------------------------------------------------------------------------
252
+ # CacheLayerMixin — unsupported abstract methods
253
+ # ---------------------------------------------------------------------------
254
+
255
+ def get_seq_length(self) -> int: # type: ignore[override]
256
+ """Not supported — no single sequence length represents this cache's state.
257
+
258
+ MoSRAH heads accumulate independently; (batch, head) slots have different
259
+ lengths depending on routing history. There is no meaningful scalar summary.
260
+ Use get_heads_lengths() for per-head occupancy.
261
+ """
262
+ raise NotImplementedError(
263
+ "SlowMoSRAHCache has no single sequence length. "
264
+ "Use get_heads_lengths() for per-head occupancy."
265
+ )
266
+
267
+ def get_max_cache_shape(self) -> int: # type: ignore[override]
268
+ """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
269
+ raise NotImplementedError(
270
+ "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
271
+ )
272
+
273
+ def get_mask_sizes( # type: ignore[override]
274
+ self,
275
+ cache_position: torch.Tensor,
276
+ ) -> tuple[int, int]:
277
+ """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
278
+ raise NotImplementedError(
279
+ "SlowMoSRAHCache does not support get_mask_sizes()."
280
+ )
281
+
282
+ # ---------------------------------------------------------------------------
283
+ # Internal helpers
284
+ # ---------------------------------------------------------------------------
285
+
286
+ def _make_active_mask(self) -> torch.Tensor:
287
+ """Construct the (B, L, T) active mask from current counts.
288
+
289
+ Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
290
+ has been written. Positions at or beyond the count are junk and must be
291
+ excluded by downstream attention.
292
+ """
293
+ cap = self.buffer_capacity
294
+ return (
295
+ torch.arange(cap, device=self.keys.device)
296
+ .expand(self.batch_size, self.num_mosrah_heads, cap)
297
+ < self._counts.unsqueeze(-1)
298
+ )
299
+
300
+ def _expand(self) -> None:
301
+ """Double the buffer capacity, preserving existing data.
302
+
303
+ Called by update() when an incoming batch of tokens would cause any
304
+ (batch, head) slot to exceed the current buffer capacity. All existing
305
+ key and value data is copied into the low half of the new buffer; the
306
+ high half is zero-initialised and will be filled by subsequent writes.
307
+ After reassignment, buffer_capacity reflects the new size automatically.
308
+ """
309
+ old_cap = self.buffer_capacity
310
+ new_cap = old_cap * 2
311
+ dev = self.keys.device
312
+ new_keys = torch.zeros(
313
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
314
+ )
315
+ new_values = torch.zeros(
316
+ self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
317
+ )
318
+ new_keys[:, :, :old_cap, :] = self.keys
319
+ new_values[:, :, :old_cap, :] = self.values
320
+ self.keys = new_keys
321
+ self.values = new_values
__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .configuration import ShramConfig
2
+ from .decoder_layer import DecoderLayer
3
+ from .huggingface import ShramForCausalLM
4
+ from .__attention__load_balance_loss import LoadBalanceLoss
5
+ from .mlp import SwiGLUMLP
6
+ from .model import ShramModel
7
+ from .rope import RotaryEmbedding
8
+ from .__attention__router import MoSRAHRouter
9
+ from .__cache__mosrah_cache import MoSRAHCache
10
+
11
+ __all__ = [
12
+ "DecoderLayer",
13
+ "LoadBalanceLoss",
14
+ "MoSRAHCache",
15
+ "MoSRAHRouter",
16
+ "ShramConfig",
17
+ "ShramForCausalLM",
18
+ "ShramModel",
19
+ "RotaryEmbedding",
20
+ "SwiGLUMLP",
21
+ ]
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha": 1.0,
3
+ "attention_dropout": 0.0,
4
+ "auto_map": {
5
+ "AutoConfig": "configuration.ShramConfig",
6
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
+ },
8
+ "beta": 32.0,
9
+ "head_dim": 16,
10
+ "hidden_size": 512,
11
+ "inference_sequence_length": 1024,
12
+ "intermediate_size": 1366,
13
+ "local_rope_theta": 10000.0,
14
+ "model_type": "shram",
15
+ "mosrah_rope_theta": 10000.0,
16
+ "num_hidden_layers": 12,
17
+ "num_mosrah_heads": 16,
18
+ "num_selected_heads": 16,
19
+ "num_sliding_window_heads": 16,
20
+ "rms_norm_eps": 1e-05,
21
+ "rope_mode": "main_sequence",
22
+ "tie_word_embeddings": false,
23
+ "training_sequence_length": 1024,
24
+ "transformers_version": "5.3.0",
25
+ "use_cache": true,
26
+ "vocab_size": 50277,
27
+ "window_size": 128
28
+ }
configuration.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration for the SHRAM transformer.
2
+
3
+ All architectural parameters that vary across model scales or are meaningful research
4
+ variables are expressed here. Architectural constants (no bias in linear layers,
5
+ SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
+ documented at the point of use — they are not config parameters because they do not
7
+ vary and changing them produces a different architecture, not a different scale.
8
+
9
+ RoPE configuration is owned entirely by this config. Each attention path reads its
10
+ parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
+ HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
+ """
13
+
14
+ from transformers import PretrainedConfig
15
+
16
+
17
+ class ShramConfig(PretrainedConfig):
18
+ """Configuration class for the SHRAM decoder-only transformer.
19
+
20
+ SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
21
+ attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
22
+ local sliding-window causal attention path and h_s is the MoSRAH sparse routed
23
+ path. All other components follow the Llama 3 baseline.
24
+
25
+ This config is the single source of truth for every architectural dimension of the
26
+ model. Nothing in the architecture may use a literal number that belongs here.
27
+
28
+ Two independent RoPE configurations exist — one per attention path:
29
+
30
+ - h_l always uses standard RoPE with ``local_rope_theta``.
31
+ - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
32
+ ``inference_sequence_length``, ``alpha``, and ``beta``. When
33
+ ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
34
+ ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
35
+ and the correct setting for experiments that do not require context extension.
36
+
37
+ Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
38
+
39
+ config = AutoConfig.from_pretrained(
40
+ "your-namespace/advanced-transformers-lib",
41
+ trust_remote_code=True,
42
+ num_hidden_layers=12,
43
+ )
44
+ model = AutoModelForCausalLM.from_config(config)
45
+
46
+ Args:
47
+ vocab_size: Vocabulary size. Controls the embedding table and output logits
48
+ dimension. Must match the tokenizer.
49
+ embedding_width: Model width ``d``. The dimension of the residual stream.
50
+ mlp_width: FFN hidden dimension.
51
+ num_decoder_layers: Number of transformer blocks stacked in sequence.
52
+ num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
53
+ num_mosrah_heads: Total MoSRAH expert heads available ``L``.
54
+ num_selected_heads: MoSRAH heads each token selects ``K``.
55
+ head_dim: Per-head dimension, shared by both attention paths. Must be even
56
+ (RoPE rotates dimensions in pairs). Paper uses 16.
57
+ window_size: Sliding window size for h_l. Paper uses 128.
58
+ rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
59
+ original sequence positions; ``"semantic_sequence"`` supplies local slot
60
+ indices. Both are required; experimentally correct mode is undetermined
61
+ (paper §4). Default ``"main_sequence"``.
62
+ rms_norm_eps: Epsilon for RMSNorm layers.
63
+ local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
64
+ Paper uses b=10000.
65
+ mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
66
+ b=10000.
67
+ training_sequence_length: Context length ``C_train`` the model was or will be
68
+ trained at. Used to compute the YaRN scale factor for BEA.
69
+ inference_sequence_length: Context length ``C_target`` the model must support
70
+ at inference. When equal to ``training_sequence_length``, scale ``s=1``
71
+ and YaRN reduces to standard RoPE.
72
+ alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
73
+ ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
74
+ beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
75
+ ``r(d) > beta`` are left unscaled. Paper value: 32.0.
76
+ attention_dropout: Dropout probability on attention weights. Default 0.0.
77
+ use_cache: Whether to return past_key_values for KV caching.
78
+ output_hidden_states: Whether to return hidden states after each layer.
79
+ tie_word_embeddings: Whether input embedding and LM head share weights.
80
+ """
81
+
82
+ model_type = "shram"
83
+
84
+ auto_map = {
85
+ "AutoConfig": "configuration.ShramConfig",
86
+ "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ vocab_size: int = 50277,
92
+ embedding_width: int = 512,
93
+ mlp_width: int = 1366,
94
+ num_decoder_layers: int = 12,
95
+ num_sliding_window_heads: int = 16,
96
+ num_mosrah_heads: int = 16,
97
+ num_selected_heads: int = 16,
98
+ head_dim: int = 16,
99
+ window_size: int = 128,
100
+ rope_mode: str = "main_sequence",
101
+ rms_norm_eps: float = 1e-5,
102
+ local_rope_theta: float = 10000.0,
103
+ mosrah_rope_theta: float = 10000.0,
104
+ training_sequence_length: int = 1024,
105
+ alpha: float = 1.0,
106
+ beta: float = 32.0,
107
+ attention_dropout: float = 0.0,
108
+ use_cache: bool = True,
109
+ output_hidden_states: bool = False,
110
+ tie_word_embeddings: bool = False,
111
+ **kwargs,
112
+ ):
113
+ if head_dim % 2 != 0:
114
+ raise ValueError(
115
+ f"head_dim must be even (RoPE rotates dimensions in pairs). "
116
+ f"Got head_dim={head_dim}."
117
+ )
118
+
119
+ if rope_mode not in {"main_sequence", "semantic_sequence"}:
120
+ raise ValueError(
121
+ f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
122
+ f"got '{rope_mode}'."
123
+ )
124
+
125
+ if training_sequence_length <= 0:
126
+ raise ValueError(
127
+ f"training_sequence_length must be positive, "
128
+ f"got {training_sequence_length}."
129
+ )
130
+
131
+ # inference_sequence_length is not a constructor parameter. It defaults to
132
+ # training_sequence_length (scale=1.0, standard RoPE). If a saved config
133
+ # carries the field through kwargs (after set_inference_context() was called
134
+ # before saving), restore it here with validation.
135
+ saved_inference_length = kwargs.pop("inference_sequence_length", training_sequence_length)
136
+ if saved_inference_length <= 0:
137
+ raise ValueError(
138
+ f"inference_sequence_length must be positive, "
139
+ f"got {saved_inference_length}."
140
+ )
141
+
142
+ self.vocab_size = vocab_size
143
+ self.hidden_size = embedding_width
144
+ self.intermediate_size = mlp_width
145
+ self.num_hidden_layers = num_decoder_layers
146
+ self.num_sliding_window_heads = num_sliding_window_heads
147
+ self.num_mosrah_heads = num_mosrah_heads
148
+ self.num_selected_heads = num_selected_heads
149
+ self.head_dim = head_dim
150
+ self.window_size = window_size
151
+ self.rope_mode = rope_mode
152
+ self.rms_norm_eps = rms_norm_eps
153
+ self.local_rope_theta = local_rope_theta
154
+ self.mosrah_rope_theta = mosrah_rope_theta
155
+ self.training_sequence_length = training_sequence_length
156
+ self.inference_sequence_length = saved_inference_length
157
+ self.alpha = alpha
158
+ self.beta = beta
159
+ self.attention_dropout = attention_dropout
160
+ self.use_cache = use_cache
161
+
162
+ super().__init__(
163
+ tie_word_embeddings=tie_word_embeddings,
164
+ output_hidden_states=output_hidden_states,
165
+ **kwargs,
166
+ )
167
+
168
+ # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
169
+ # serialises it into config.json.
170
+ self.auto_map = type(self).auto_map
171
+
172
+ @property
173
+ def scale(self) -> float:
174
+ """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
175
+
176
+ When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
177
+ adjustments cancel and A_rope = 1. This is the default state.
178
+ """
179
+ return self.inference_sequence_length / self.training_sequence_length
180
+
181
+ def set_inference_context(self, inference_sequence_length: int) -> None:
182
+ """Set the inference context length for YaRN context extension.
183
+
184
+ This is the only supported way to set inference_sequence_length. At construction
185
+ the inference context defaults to training_sequence_length (scale=1.0, standard
186
+ RoPE). Call this method to configure a longer inference context, which causes
187
+ YaRN to interpolate frequencies and extend the effective context window.
188
+
189
+ Args:
190
+ inference_sequence_length: Target inference context length. Must be positive.
191
+ Values equal to training_sequence_length produce scale=1.0 (standard RoPE).
192
+ Values greater than training_sequence_length enable YaRN extrapolation.
193
+ """
194
+ if inference_sequence_length <= 0:
195
+ raise ValueError(
196
+ f"inference_sequence_length must be positive, "
197
+ f"got {inference_sequence_length}."
198
+ )
199
+ self.inference_sequence_length = inference_sequence_length
decoder_layer.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Decoder layer — a single transformer block.
2
+
3
+ Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
4
+ residual connections around both sublayers:
5
+
6
+ normed_attn = RMSNorm(x)
7
+ attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
8
+ h = x + attn_out
9
+
10
+ normed_mlp = RMSNorm(h)
11
+ mlp_out = SwiGLUMLP(normed_mlp)
12
+ out = h + mlp_out
13
+
14
+ Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
15
+ through unnormalised residuals at depth, and each sublayer receives a stable,
16
+ normalised view of the signal.
17
+
18
+ Two independent RMSNorm instances are used — one before attention, one before
19
+ MLP. They learn different scalings because they precede layers with different
20
+ dynamic ranges. Sharing them would be wrong.
21
+
22
+ torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
23
+ subtraction, is faster than LayerNorm, and proved more stable at scale.
24
+ """
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ from .__attention__shram import SHRAMHybridLayer
30
+ from .__cache__shram_layer_cache import ShramLayerCache
31
+ from .configuration import ShramConfig
32
+ from .mlp import SwiGLUMLP
33
+
34
+
35
+ class DecoderLayer(nn.Module):
36
+ """A single pre-norm SHRAM decoder block.
37
+
38
+ Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
39
+ independent RMSNorm instances on each sublayer input.
40
+
41
+ Args:
42
+ config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
43
+ in addition to the fields required by SHRAMHybridLayer and
44
+ SwiGLUMLP.
45
+ """
46
+
47
+ def __init__(self, config: ShramConfig) -> None:
48
+ super().__init__()
49
+ self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
+ self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
51
+ self.attention = SHRAMHybridLayer(config)
52
+ self.mlp = SwiGLUMLP(config)
53
+
54
+ def forward(
55
+ self,
56
+ x: torch.Tensor,
57
+ position_ids: torch.Tensor,
58
+ active_mask: torch.Tensor,
59
+ cache: ShramLayerCache | None = None,
60
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
61
+ """Apply one decoder block to the input.
62
+
63
+ Args:
64
+ x: Input of shape (batch, seq_len, hidden_size).
65
+ position_ids: Authoritative positions of shape (batch, seq_len).
66
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
67
+ where True means the token is semantically live. Forwarded
68
+ unchanged to the hybrid attention layer.
69
+ cache: Optional per-layer SHRAM cache passed through to the hybrid
70
+ attention layer unchanged.
71
+
72
+ Returns:
73
+ output: Tensor of shape (batch, seq_len, hidden_size).
74
+ load_balance_loss: Scalar sparse-path load-balance loss propagated
75
+ from SHRAMHybridLayer.
76
+ max_vio: Detached scalar routing-imbalance summary. Passed through
77
+ unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
78
+ """
79
+ attn_out, load_balance_loss, max_vio = self.attention(
80
+ hidden_states=self.attn_norm(x),
81
+ position_ids=position_ids,
82
+ active_mask=active_mask,
83
+ cache=cache,
84
+ )
85
+ hidden_states = x + attn_out
86
+ output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
87
+ return output, load_balance_loss, max_vio
huggingface.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace causal-LM wrapper for SHRAM.
2
+
3
+ ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
4
+ It owns token embedding lookup, LM-head projection, wrapper-level next-token
5
+ cross-entropy loss, config-controlled tied embeddings, and generation/cache
6
+ orchestration at the wrapper boundary.
7
+
8
+ The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
9
+ hidden states together with current position IDs, a current active mask, and an
10
+ optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
11
+ or causal-LM loss.
12
+
13
+ HuggingFace generation reaches this wrapper with two different tensor
14
+ conventions:
15
+
16
+ - ``position_ids`` is a current-step tensor. GenerationMixin updates the total
17
+ sequence state between steps, then slices position-bearing tensors back down
18
+ before calling ``forward()``.
19
+ - ``attention_mask`` is a full 2D mask over the total sequence so far. This
20
+ wrapper slices its recent chunk to produce the current semantic liveness mask
21
+ expected by the backbone.
22
+
23
+ Generation-created caches are handled in ``_prepare_cache_for_generation``.
24
+ That hook ensures HuggingFace generation uses ShramCache rather than a generic
25
+ dynamic cache. The direct ``forward()`` path does not silently create caches;
26
+ when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
27
+ """
28
+
29
+ from dataclasses import dataclass
30
+ from typing import Any
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ from transformers import GenerationMixin, PreTrainedModel
35
+ from transformers.cache_utils import Cache
36
+ from transformers.generation.configuration_utils import GenerationMode
37
+ from transformers.modeling_outputs import CausalLMOutputWithPast
38
+
39
+ from .__cache__shram_cache import ShramCache
40
+ from .configuration import ShramConfig
41
+ from .model import ShramModel
42
+
43
+
44
+ @dataclass
45
+ class ShramCausalLMOutput(CausalLMOutputWithPast):
46
+ """SHRAM causal-LM wrapper output.
47
+
48
+ This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
49
+ Dataclass inheritance is sufficient here: all standard causal-LM fields and
50
+ ModelOutput behavior are inherited from the parent, and this subclass adds
51
+ only the SHRAM-specific wrapper outputs.
52
+ """
53
+
54
+ ce_loss: torch.FloatTensor | None = None
55
+ load_balance_loss: torch.FloatTensor | None = None
56
+ max_vio: torch.FloatTensor | None = None
57
+
58
+
59
+ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
60
+ """HuggingFace-facing causal language model wrapper for SHRAM.
61
+
62
+ Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
63
+ tied embedding configuration, and generation/cache boundary behavior.
64
+ Delegates all transformer computation to ``ShramModel``.
65
+
66
+ Args:
67
+ config: SHRAM model configuration.
68
+ """
69
+
70
+ config_class = ShramConfig
71
+ base_model_prefix = "model"
72
+ _no_split_modules = ["DecoderLayer"]
73
+ supports_gradient_checkpointing = True
74
+
75
+ def __init__(self, config: ShramConfig) -> None:
76
+ super().__init__(config)
77
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
78
+ self.model = ShramModel(config)
79
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
80
+ self._configure_tied_embeddings()
81
+ self.post_init()
82
+
83
+ def _configure_tied_embeddings(self) -> None:
84
+ """Apply config-controlled tied embedding behavior on this instance."""
85
+ if self.config.tie_word_embeddings:
86
+ self.lm_head.weight = self.embed_tokens.weight
87
+ self._tied_weights_keys = {
88
+ "lm_head.weight": "embed_tokens.weight",
89
+ }
90
+ else:
91
+ self._tied_weights_keys = {}
92
+
93
+ def get_input_embeddings(self) -> nn.Embedding:
94
+ """Return the token embedding matrix."""
95
+ return self.embed_tokens
96
+
97
+ def set_input_embeddings(self, value: nn.Embedding) -> None:
98
+ """Replace the token embedding matrix."""
99
+ self.embed_tokens = value
100
+ self._configure_tied_embeddings()
101
+
102
+ def get_output_embeddings(self) -> nn.Linear:
103
+ """Return the LM head."""
104
+ return self.lm_head
105
+
106
+ def set_output_embeddings(self, value: nn.Linear) -> None:
107
+ """Replace the LM head."""
108
+ self.lm_head = value
109
+ self._configure_tied_embeddings()
110
+
111
+ def _build_shram_cache(
112
+ self,
113
+ batch_size: int,
114
+ device: torch.device,
115
+ ) -> ShramCache:
116
+ """Construct a fresh top-level SHRAM cache."""
117
+ return ShramCache(
118
+ num_hidden_layers=self.config.num_hidden_layers,
119
+ sliding_window=self.config.window_size,
120
+ num_local_heads=self.config.num_sliding_window_heads,
121
+ local_head_dim=self.config.head_dim,
122
+ num_mosrah_heads=self.config.num_mosrah_heads,
123
+ mosrah_head_dim=self.config.hidden_size // self.config.num_selected_heads,
124
+ batch_size=batch_size,
125
+ device=device,
126
+ )
127
+
128
+ def _validate_generation_cache_request(
129
+ self,
130
+ generation_config: Any,
131
+ model_kwargs: dict[str, Any],
132
+ generation_mode: GenerationMode,
133
+ ) -> None:
134
+ """Validate SHRAM's generation-side cache policy."""
135
+ if generation_mode in {
136
+ GenerationMode.ASSISTED_GENERATION,
137
+ GenerationMode.CONTRASTIVE_SEARCH,
138
+ }:
139
+ raise NotImplementedError(
140
+ "ShramForCausalLM does not currently support assisted generation "
141
+ "or contrastive search because ShramCache does not support crop()."
142
+ )
143
+
144
+ user_defined_cache = model_kwargs.get("past_key_values")
145
+ if user_defined_cache is not None:
146
+ if generation_config.cache_implementation is not None:
147
+ raise ValueError(
148
+ "Passing both `cache_implementation` and `past_key_values` "
149
+ "is unsupported. Please use only one."
150
+ )
151
+ if isinstance(user_defined_cache, tuple):
152
+ raise ValueError(
153
+ "Passing a tuple of `past_key_values` is not supported. "
154
+ "Please use a `ShramCache` instance."
155
+ )
156
+ if not isinstance(user_defined_cache, ShramCache):
157
+ raise TypeError(
158
+ "ShramForCausalLM requires `past_key_values` to be a "
159
+ "`ShramCache` instance."
160
+ )
161
+
162
+ if (
163
+ user_defined_cache is None
164
+ and generation_config.use_cache
165
+ and generation_config.cache_implementation is not None
166
+ ):
167
+ raise ValueError(
168
+ "ShramForCausalLM does not support `cache_implementation`. "
169
+ "Generation-created caches must be `ShramCache` objects."
170
+ )
171
+
172
+ def _prepare_cache_for_generation(
173
+ self,
174
+ generation_config: Any,
175
+ model_kwargs: dict[str, Any],
176
+ generation_mode: GenerationMode,
177
+ batch_size: int,
178
+ max_cache_length: int,
179
+ ) -> None:
180
+ """Ensure HuggingFace generation uses ShramCache.
181
+
182
+ This is the SHRAM-specific generation hook. The rest of the default
183
+ generation plumbing is kept intact as much as possible.
184
+
185
+ Args:
186
+ generation_config: Active generation configuration.
187
+ model_kwargs: Generation kwargs, updated in place.
188
+ generation_mode: HuggingFace generation mode.
189
+ batch_size: Effective generation batch size.
190
+ max_cache_length: Requested cache length. Accepted but unused here.
191
+ """
192
+ self._validate_generation_cache_request(
193
+ generation_config=generation_config,
194
+ model_kwargs=model_kwargs,
195
+ generation_mode=generation_mode,
196
+ )
197
+
198
+ if model_kwargs.get("past_key_values") is not None:
199
+ return
200
+
201
+ if not generation_config.use_cache:
202
+ return
203
+
204
+ num_repeats = max(
205
+ generation_config.num_beams or 1,
206
+ generation_config.num_return_sequences or 1,
207
+ )
208
+ model_kwargs["past_key_values"] = self._build_shram_cache(
209
+ batch_size=batch_size*num_repeats,
210
+ device=self.embed_tokens.weight.device,
211
+ )
212
+
213
+ def _reorder_cache(
214
+ self,
215
+ past_key_values: Cache,
216
+ beam_idx: torch.Tensor,
217
+ ) -> Cache:
218
+ """Reorder the cache in place for beam search."""
219
+ past_key_values.reorder_cache(beam_idx)
220
+ return past_key_values
221
+
222
+ def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
223
+ """Validate token IDs at the wrapper boundary."""
224
+ if input_ids.ndim != 2:
225
+ raise ValueError("input_ids must have shape (batch, seq_len).")
226
+ if input_ids.shape[1] == 0:
227
+ raise ValueError("input_ids sequence length must be nonzero.")
228
+ if input_ids.dtype != torch.long:
229
+ raise TypeError("input_ids must be an long int tensor.")
230
+
231
+ def _validate_attention_mask(
232
+ self,
233
+ input_ids: torch.Tensor,
234
+ attention_mask: torch.Tensor | None,
235
+ ) -> None:
236
+ """Validate the full-sequence attention mask."""
237
+ if attention_mask is None:
238
+ return
239
+ if attention_mask.ndim != 2:
240
+ raise ValueError("attention_mask must have shape (batch, total_seq_len).")
241
+ if attention_mask.shape[0] != input_ids.shape[0]:
242
+ raise ValueError("attention_mask batch dimension must match input_ids.")
243
+ if attention_mask.shape[1] < input_ids.shape[1]:
244
+ raise ValueError(
245
+ "attention_mask must be at least as long as the current input_ids chunk."
246
+ )
247
+
248
+ def _validate_position_ids(
249
+ self,
250
+ input_ids: torch.Tensor,
251
+ position_ids: torch.Tensor | None,
252
+ ) -> None:
253
+ """Validate current-step position IDs."""
254
+ if position_ids is None:
255
+ return
256
+ if position_ids.ndim != 2:
257
+ raise ValueError("position_ids must have shape (batch, seq_len).")
258
+ if position_ids.shape != input_ids.shape:
259
+ raise ValueError(
260
+ "position_ids must match the current input_ids shape exactly."
261
+ )
262
+ if input_ids.dtype != torch.long:
263
+ raise TypeError("position_ids must be an long tensor.")
264
+
265
+ def _validate_labels(
266
+ self,
267
+ input_ids: torch.Tensor,
268
+ labels: torch.Tensor | None,
269
+ ) -> None:
270
+ """Validate label shape at the wrapper boundary."""
271
+ if labels is None:
272
+ return
273
+ if labels.ndim != 2:
274
+ raise ValueError("labels must have shape (batch, seq_len).")
275
+ if labels.shape != input_ids.shape:
276
+ raise ValueError("labels must have the same shape as input_ids.")
277
+ if input_ids.dtype != torch.long:
278
+ raise TypeError("labels must be a long tensor.")
279
+
280
+ def _validate_cache_inputs(
281
+ self,
282
+ use_cache: bool,
283
+ past_key_values: Cache | None,
284
+ ) -> None:
285
+ """Validate cache policy for direct wrapper calls."""
286
+ if use_cache:
287
+ if past_key_values is None:
288
+ raise ValueError(
289
+ "use_cache=True requires an explicit ShramCache. During "
290
+ "generate(), HuggingFace should supply this through "
291
+ "_prepare_cache_for_generation()."
292
+ )
293
+ if not isinstance(past_key_values, ShramCache):
294
+ raise TypeError(
295
+ "past_key_values must be a ShramCache when use_cache=True."
296
+ )
297
+ return
298
+
299
+ if past_key_values is not None:
300
+ raise ValueError("past_key_values was provided while use_cache=False.")
301
+
302
+ def _validate_position_sources(
303
+ self,
304
+ use_cache: bool,
305
+ attention_mask: torch.Tensor | None,
306
+ position_ids: torch.Tensor | None,
307
+ ) -> None:
308
+ """Validate that cached forward has a truthful source of positions."""
309
+ if use_cache and attention_mask is None and position_ids is None:
310
+ raise ValueError(
311
+ "Cached forward requires either position_ids or attention_mask."
312
+ )
313
+
314
+ def _validate_hf_boundary(
315
+ self,
316
+ output_attentions: bool | None,
317
+ return_dict: bool | None,
318
+ inputs_embeds: torch.Tensor | None,
319
+ cache_position: torch.Tensor | None,
320
+ extra_kwargs: dict[str, Any],
321
+ ) -> None:
322
+ """Validate unsupported HuggingFace-facing wrapper inputs."""
323
+ if output_attentions:
324
+ raise NotImplementedError(
325
+ "ShramForCausalLM does not expose output_attentions."
326
+ )
327
+ if return_dict is False:
328
+ raise ValueError(
329
+ "return_dict=False is not supported. "
330
+ "ShramForCausalLM always returns ShramCausalLMOutput."
331
+ )
332
+ if inputs_embeds is not None:
333
+ raise ValueError(
334
+ "inputs_embeds is not supported at the SHRAM wrapper boundary. "
335
+ "Pass input_ids instead."
336
+ )
337
+ if extra_kwargs:
338
+ unsupported = ", ".join(sorted(extra_kwargs))
339
+ raise TypeError(
340
+ f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
341
+ )
342
+
343
+ def _standardize_full_attention_mask(
344
+ self,
345
+ input_ids: torch.Tensor,
346
+ attention_mask: torch.Tensor | None,
347
+ ) -> torch.BoolTensor:
348
+ """Return a concrete full-sequence boolean attention mask."""
349
+ if attention_mask is None:
350
+ return torch.ones_like(input_ids, dtype=torch.bool)
351
+ return attention_mask.to(dtype=torch.bool)
352
+
353
+ def _resolve_current_position_ids(
354
+ self,
355
+ input_ids: torch.Tensor,
356
+ position_ids: torch.Tensor | None,
357
+ full_attention_mask: torch.BoolTensor,
358
+ ) -> torch.LongTensor:
359
+ """Resolve concrete current-step position IDs for the backbone."""
360
+ if position_ids is not None:
361
+ return position_ids.to(dtype=torch.long)
362
+
363
+ full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
364
+ full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
365
+ current_length = input_ids.shape[1]
366
+ return full_position_ids[:, -current_length:]
367
+
368
+ def forward(
369
+ self,
370
+ input_ids: torch.Tensor,
371
+ attention_mask: torch.Tensor | None = None,
372
+ position_ids: torch.Tensor | None = None,
373
+ past_key_values: Cache | None = None,
374
+ use_cache: bool | None = None,
375
+ output_hidden_states: bool | None = None,
376
+ labels: torch.Tensor | None = None,
377
+ return_dict: bool | None = None,
378
+ ce_weight: float = 1.0,
379
+ load_balance_weight: float = 0.01,
380
+ **kwargs: Any,
381
+ ) -> ShramCausalLMOutput:
382
+ """Run the SHRAM causal language model wrapper.
383
+
384
+ Args:
385
+ input_ids: Current token IDs of shape ``(batch, seq_len)``.
386
+ attention_mask: Optional full 2D mask of shape
387
+ ``(batch, total_seq_len)``. The wrapper slices its recent chunk
388
+ to produce the current semantic liveness mask expected by the
389
+ backbone.
390
+ position_ids: Optional current-step position IDs of shape
391
+ ``(batch, seq_len)``. In ordinary HuggingFace generation this is
392
+ already the current-step tensor when it reaches ``forward()``.
393
+ past_key_values: Optional SHRAM cache. Required when
394
+ ``use_cache=True``.
395
+ use_cache: Whether to use and return a cache. Defaults to
396
+ ``config.use_cache``.
397
+ output_hidden_states: Whether to return backbone hidden states.
398
+ Defaults to ``config.output_hidden_states``.
399
+ labels: Optional target token IDs of shape ``(batch, seq_len)``.
400
+ return_dict: Must be ``True`` or ``None``.
401
+ ce_weight: Weight applied to the cross-entropy loss when combining with
402
+ the load-balance loss. Default 1.0.
403
+ load_balance_weight: Weight applied to the load-balance auxiliary loss.
404
+ Default 0.01, matching the paper's recommendation.
405
+ **kwargs: Unsupported HuggingFace kwargs fail explicitly.
406
+
407
+ Returns:
408
+ ``ShramCausalLMOutput`` with:
409
+ - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
410
+ - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
411
+ when labels are provided (``None`` otherwise),
412
+ - ``ce_loss`` — raw unweighted cross-entropy loss for logging,
413
+ - ``past_key_values`` as the active ``ShramCache`` or ``None``,
414
+ - ``hidden_states`` when requested,
415
+ - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
416
+ - detached ``max_vio`` from the backbone.
417
+ """
418
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
419
+ output_hidden_states = (
420
+ output_hidden_states
421
+ if output_hidden_states is not None
422
+ else self.config.output_hidden_states
423
+ )
424
+
425
+ inputs_embeds = kwargs.pop("inputs_embeds", None)
426
+ output_attentions = kwargs.pop("output_attentions", None)
427
+ cache_position = kwargs.pop("cache_position", None)
428
+
429
+ # ------------------------------------------------------------------
430
+ # Validation zone.
431
+ #
432
+ # The wrapper boundary is where HuggingFace-facing inputs are judged
433
+ # for truthfulness before any internal work begins. These checks are
434
+ # intentionally front-loaded so the core logic below can assume one
435
+ # coherent interpretation of the call rather than defensively checking
436
+ # shapes, cache policy, or unsupported HF knobs at the point of use.
437
+ # This keeps the main sequence readable while ensuring invalid states
438
+ # fail before they can silently contaminate backbone execution.
439
+ # ------------------------------------------------------------------
440
+ self._validate_input_ids(input_ids)
441
+ self._validate_attention_mask(input_ids, attention_mask)
442
+ self._validate_position_ids(input_ids, position_ids)
443
+ self._validate_labels(input_ids, labels)
444
+ self._validate_cache_inputs(use_cache, past_key_values)
445
+ self._validate_position_sources(use_cache, attention_mask, position_ids)
446
+ self._validate_hf_boundary(
447
+ output_attentions=output_attentions,
448
+ return_dict=return_dict,
449
+ inputs_embeds=inputs_embeds,
450
+ cache_position=cache_position,
451
+ extra_kwargs=kwargs,
452
+ )
453
+
454
+ # ------------------------------------------------------------------
455
+ # Standardization zone.
456
+ #
457
+ # HuggingFace and SHRAM use different boundary conventions: generation
458
+ # carries a full-sequence 2D attention mask, while the SHRAM backbone
459
+ # wants a current-step active mask and concrete current position IDs.
460
+ # This zone collapses those wrapper-facing conventions into one valid
461
+ # backbone-facing state. After this point the core no longer reasons
462
+ # about optional or ambiguous input forms; it works only with concrete
463
+ # tensors whose semantics are already fixed.
464
+ # ------------------------------------------------------------------
465
+ full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
466
+ input_ids=input_ids,
467
+ attention_mask=attention_mask,
468
+ )
469
+ current_length: int = input_ids.shape[1]
470
+ current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
471
+ current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
472
+ input_ids=input_ids,
473
+ position_ids=position_ids,
474
+ full_attention_mask=full_attention_mask,
475
+ )
476
+ shram_cache: ShramCache | None = past_key_values if use_cache else None
477
+
478
+ # ------------------------------------------------------------------
479
+ # Core wrapper responsibilities.
480
+ #
481
+ # The wrapper's primary job is kept visible here: convert token IDs to
482
+ # embeddings, delegate transformer computation to ShramModel, project
483
+ # hidden states back to vocabulary logits, optionally compute the
484
+ # wrapper-level shifted next-token loss, and return the HuggingFace-
485
+ # facing output object. The backbone remains responsible only for
486
+ # transformer semantics; token/vocabulary/loss concerns stay here.
487
+ # ------------------------------------------------------------------
488
+ token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
489
+ backbone_outputs = self.model(
490
+ inputs_embeds=token_embeddings,
491
+ position_ids=current_position_ids,
492
+ active_mask=current_active_mask,
493
+ cache=shram_cache,
494
+ output_hidden_states=output_hidden_states,
495
+ )
496
+
497
+ logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])
498
+
499
+ ce_loss: torch.FloatTensor | None = None
500
+ loss: torch.FloatTensor | None = None
501
+ if labels is not None:
502
+ shift_logits = logits[:, :-1, :].contiguous()
503
+ shift_labels = labels[:, 1:].contiguous()
504
+ ce_loss = nn.functional.cross_entropy(
505
+ shift_logits.view(-1, self.config.vocab_size),
506
+ shift_labels.view(-1),
507
+ )
508
+ loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"]
509
+
510
+ return ShramCausalLMOutput(
511
+ loss=loss,
512
+ ce_loss=ce_loss,
513
+ logits=logits,
514
+ past_key_values=backbone_outputs["past_key_values"],
515
+ hidden_states=backbone_outputs["hidden_states"],
516
+ load_balance_loss=backbone_outputs["load_balance_loss"],
517
+ max_vio=backbone_outputs["max_vio"],
518
+ )
mlp.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SwiGLU feed-forward sublayer.
2
+
3
+ SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
4
+ element-wise against a separate up-projection:
5
+
6
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
7
+
8
+ The gating mechanism gives the network more expressive control over which features
9
+ to propagate than a plain two-matrix FFN. It requires three weight matrices instead
10
+ of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
11
+ typical of two-matrix FFNs — the total parameter count remains comparable.
12
+
13
+ SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
14
+ — a fixed architectural choice.
15
+ """
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from transformers import PretrainedConfig
21
+
22
+
23
+ class SwiGLUMLP(nn.Module):
24
+ """SwiGLU feed-forward sublayer.
25
+
26
+ Implements the three-matrix SwiGLU FFN used in Llama 3:
27
+
28
+ output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
29
+
30
+ No bias on any projection. SiLU as the gate activation is an architectural
31
+ constant — it is what defines SwiGLU specifically.
32
+
33
+ Args:
34
+ config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
35
+ """
36
+
37
+ def __init__(self, config: PretrainedConfig) -> None:
38
+ super().__init__()
39
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
40
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
41
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
42
+
43
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
44
+ """Apply the SwiGLU feed-forward transformation.
45
+
46
+ Args:
47
+ x: Input tensor of shape (batch, seq_len, hidden_size).
48
+
49
+ Returns:
50
+ Output tensor of shape (batch, seq_len, hidden_size).
51
+ """
52
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
model.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Transformer backbone for Shram.
2
+
3
+ ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
4
+ by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
5
+ representations. It has no knowledge of tokens, vocabulary, generation, or the
6
+ HuggingFace causal-LM wrapper contract.
7
+
8
+ Keeping the embedding out of the backbone is the correct convention and makes
9
+ the backbone genuinely modality-agnostic. The token interface — embedding lookup,
10
+ LM head, weight tying, and generation-facing naming conventions — belongs on the
11
+ task wrapper (ShramForCausalLM), which is the only class that knows this
12
+ backbone is being used for language modelling.
13
+
14
+ The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
15
+ each sublayer normalises its own input, leaving the residual stream itself
16
+ unnormalised. After many layers of accumulated residuals, that stream arrives at
17
+ the top with uncontrolled magnitude. The final norm brings it to a well-scaled
18
+ state before any projection. Without it, the LM head would receive signals of
19
+ arbitrary scale.
20
+
21
+ Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
22
+ corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
23
+ same top-level ShramCache object in the output dict. If None is provided, no
24
+ caching occurs.
25
+
26
+ Returns a plain dict with keys:
27
+ - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
28
+ - "past_key_values": the ShramCache object passed in, or None
29
+ - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
30
+ - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
31
+ - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
32
+ """
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+
37
+ from .__cache__shram_cache import ShramCache
38
+ from .configuration import ShramConfig
39
+ from .decoder_layer import DecoderLayer
40
+
41
+
42
+ class ShramModel(nn.Module):
43
+ """Pure transformer backbone: decoder stack and final normalisation.
44
+
45
+ Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size)
46
+ and returns contextual representations of the same shape. No token embedding,
47
+ vocabulary projection, or causal-LM lifecycle concerns.
48
+
49
+ RoPE is applied inside each attention layer. Positional information is
50
+ encoded in the relationship between Q and K, not added to the residual
51
+ stream, so the backbone is agnostic to how positions are represented.
52
+
53
+ Args:
54
+ config: Model configuration. Must be a ``ShramConfig`` instance.
55
+ """
56
+
57
+ def __init__(self, config: ShramConfig) -> None:
58
+ super().__init__()
59
+ self.config = config
60
+ self.layers = nn.ModuleList(
61
+ [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
62
+ )
63
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
64
+
65
+ def forward(
66
+ self,
67
+ inputs_embeds: torch.Tensor,
68
+ position_ids: torch.Tensor,
69
+ active_mask: torch.Tensor,
70
+ cache: ShramCache | None = None,
71
+ output_hidden_states: bool = False,
72
+ ) -> dict:
73
+ """Run the transformer stack over a batch of pre-embedded sequences.
74
+
75
+ Args:
76
+ inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
77
+ position_ids: Absolute positions of shape (batch, seq_len). Required.
78
+ Must be provided explicitly by the caller — this module does not
79
+ infer positions from cache state.
80
+ active_mask: Current-chunk active mask of shape (batch, seq_len),
81
+ where True means the token is semantically live. Forwarded
82
+ unchanged to every decoder layer.
83
+ cache: Optional top-level ShramCache. When provided, each DecoderLayer
84
+ receives its own layer-local cache via ``cache.layers[layer_idx]``.
85
+ The top-level cache object is updated in place and returned unchanged.
86
+ output_hidden_states: When True, the output dict includes a tuple of
87
+ per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out),
88
+ collected before the final norm.
89
+
90
+ Returns:
91
+ Plain dict with keys:
92
+ - ``"last_hidden_state"``: normed backbone output,
93
+ shape (batch, seq_len, hidden_size).
94
+ - ``"past_key_values"``: the cache object passed in, or None.
95
+ - ``"hidden_states"``: tuple of per-layer activations (including
96
+ inputs_embeds as position 0) if ``output_hidden_states`` is True,
97
+ else None. Collected before the final norm so each entry reflects the
98
+ unnormalised residual stream at that depth.
99
+ - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
100
+ load-balance losses.
101
+ - ``"max_vio"``: detached scalar maximum routing-imbalance across
102
+ all decoder layers. Zero means perfectly balanced routing across
103
+ every layer; higher values identify the worst-case head imbalance.
104
+ """
105
+ hidden_states = inputs_embeds
106
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
107
+ total_load_balance_loss = inputs_embeds.new_zeros(())
108
+ max_vio = inputs_embeds.new_zeros(())
109
+
110
+ for layer_idx, layer in enumerate(self.layers):
111
+ layer_cache = None if cache is None else cache.layers[layer_idx]
112
+ hidden_states, layer_load_balance_loss, layer_max_vio = layer(
113
+ hidden_states,
114
+ position_ids,
115
+ active_mask,
116
+ cache=layer_cache,
117
+ )
118
+ total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
119
+ max_vio = torch.maximum(max_vio, layer_max_vio)
120
+
121
+ if output_hidden_states:
122
+ all_hidden_states = all_hidden_states + (hidden_states,)
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+
126
+ return {
127
+ "last_hidden_state": hidden_states,
128
+ "past_key_values": cache,
129
+ "hidden_states": all_hidden_states,
130
+ "load_balance_loss": total_load_balance_loss,
131
+ "max_vio": max_vio,
132
+ }
rope.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Rotary Position Embeddings (RoPE).
2
+
3
+ RoPE encodes position in the *relationship* between query and key vectors. When the
4
+ attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
5
+ a score that depends only on the relative distance — not on absolute positions.
6
+
7
+ Two modes are supported:
8
+
9
+ default Standard RoPE with base frequency b. Each dimension pair d is assigned
10
+ frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
11
+ scaling A_rope = 1.
12
+
13
+ yarn YaRN frequency interpolation for long-context extrapolation (Peng et al.,
14
+ "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
15
+ §A.2). Three frequency regimes:
16
+ - Low-frequency dimensions (r < α): fully interpolated by scale s.
17
+ These dimensions have long wavelengths relative to the training window
18
+ and must be compressed to avoid out-of-distribution positions.
19
+ - High-frequency dimensions (r > β): left unchanged. Short-wavelength
20
+ dimensions already encode relative position accurately at any scale.
21
+ - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
22
+ Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
23
+ standard RoPE.
24
+
25
+ Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
26
+ parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
+
28
+ Cache sharing: all instances with identical parameters share one cos/sin table via a
29
+ class-level registry. The first instance that needs a particular (parameters, seq_len,
30
+ device, dtype) combination builds the table; all subsequent instances reference it
31
+ directly. This avoids redundant builds across the num_hidden_layers instances that
32
+ share the same parametrisation.
33
+ """
34
+
35
+ import math
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Rotation helper
43
+ # ---------------------------------------------------------------------------
44
+
45
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
+ """Apply the 90° rotation used in the RoPE update formula.
47
+
48
+ Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
49
+ Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
50
+ on each consecutive pair of dimensions, matching the block-diagonal operator
51
+ R^u_{Θ,p} in the paper.
52
+ """
53
+ d = x.shape[-1] // 2
54
+ x1, x2 = x[..., :d], x[..., d:]
55
+ return torch.cat([-x2, x1], dim=-1)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # RotaryEmbedding
60
+ # ---------------------------------------------------------------------------
61
+
62
+ class RotaryEmbedding(nn.Module):
63
+ """Rotary Position Embeddings with explicit mode and parameter control.
64
+
65
+ Each caller constructs its own instance with the exact parameters it needs.
66
+ h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
+ config object is read inside this module.
68
+
69
+ The cos/sin cache is built lazily on the first forward call and extended
70
+ automatically when a longer sequence is encountered. Instances with identical
71
+ parameters share one cache via the class-level ``_cache`` registry,
72
+ avoiding redundant computation across decoder layers.
73
+
74
+ Args:
75
+ mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
76
+ head_dim: Per-head embedding dimension ``u``. Must be even.
77
+ theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
78
+ initial_seq_length: ``C_train`` — context length the model was trained at.
79
+ Required for ``mode="yarn"``.
80
+ dilation: Scale factor ``s = C_target / C_train`` — how much the context
81
+ window is extended beyond training length. Required for ``mode="yarn"``.
82
+ When ``dilation=1.0``, YaRN reduces to standard RoPE.
83
+ alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
84
+ interpolated. Required for ``mode="yarn"``.
85
+ beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
86
+ unchanged. Required for ``mode="yarn"``.
87
+ device: Optional device for initial buffer placement.
88
+
89
+ Raises:
90
+ NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
91
+ ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
92
+ ``dilation``, ``alpha``, ``beta`` are absent.
93
+ """
94
+
95
+ # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
96
+ # Shared across all RotaryEmbedding instances in the process. Keys include device
97
+ # and dtype so that tables built on different devices or in different precisions
98
+ # are stored independently.
99
+ _cache: dict = {}
100
+
101
+ def __init__(
102
+ self,
103
+ mode: str,
104
+ head_dim: int,
105
+ theta: float,
106
+ initial_seq_length: int | None = None,
107
+ dilation: float | None = None,
108
+ alpha: float | None = None,
109
+ beta: float | None = None,
110
+ device: torch.device | None = None,
111
+ ) -> None:
112
+ super().__init__()
113
+
114
+ self._validate_mode(mode)
115
+ self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
116
+ self.mode = mode
117
+
118
+ # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
119
+ # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
120
+ # so rotation_freqs has head_dim/2 entries.
121
+ d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
122
+ base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u}
123
+
124
+ if mode == "default":
125
+ rotation_freqs = base_freqs
126
+ self.attention_scaling: float = 1.0
127
+
128
+ else: # yarn
129
+ s = dilation
130
+
131
+ # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
132
+ # function to classify each dimension into one of three regimes.
133
+ normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
134
+
135
+ # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
136
+ # linear blend between α and β.
137
+ blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
138
+
139
+ # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
140
+ rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
141
+
142
+ # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
143
+ self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
144
+
145
+ # freq_key uniquely identifies the parameter set that produced rotation_freqs.
146
+ # Used as the primary component of the cache registry key.
147
+ if mode == "default":
148
+ self._freq_key: tuple = ("default", head_dim, float(theta))
149
+ else:
150
+ self._freq_key = (
151
+ "yarn", head_dim, float(theta),
152
+ int(initial_seq_length), float(dilation),
153
+ float(alpha), float(beta),
154
+ )
155
+
156
+ # rotation_freqs is a non-persistent buffer so it moves with the model across
157
+ # devices via .to() / .cuda() without appearing in saved checkpoints.
158
+ # It is stored per-instance rather than in the shared cache because it is
159
+ # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
160
+ # it is used to build. The meaningful sharing win is on those tables.
161
+ self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)
162
+
163
+ # Cache tensors are plain instance attributes (not registered buffers) so that
164
+ # sharing across identically-parametrised instances survives .to() calls.
165
+ # Registered buffers are copied on device move; plain attributes are aliased,
166
+ # preserving the shared-tensor identity that the cache design depends on.
167
+ self._cos_cached: torch.Tensor | None = None
168
+ self._sin_cached: torch.Tensor | None = None
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Validation helpers
172
+ # ---------------------------------------------------------------------------
173
+
174
+ @staticmethod
175
+ def _validate_mode(mode: str) -> None:
176
+ """Raise NotImplementedError if mode is not a supported value."""
177
+ if mode not in {"default", "yarn"}:
178
+ raise NotImplementedError(
179
+ f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
180
+ )
181
+
182
+ @staticmethod
183
+ def _validate_yarn_params(
184
+ mode: str,
185
+ initial_seq_length: int | None,
186
+ dilation: float | None,
187
+ alpha: float | None,
188
+ beta: float | None,
189
+ ) -> None:
190
+ """Raise ValueError if mode='yarn' and any required parameter is absent."""
191
+ if mode != "yarn":
192
+ return
193
+ missing = [
194
+ name for name, val in [
195
+ ("initial_seq_length", initial_seq_length),
196
+ ("dilation", dilation),
197
+ ("alpha", alpha),
198
+ ("beta", beta),
199
+ ]
200
+ if val is None
201
+ ]
202
+ if missing:
203
+ raise ValueError(f"mode='yarn' requires {missing}.")
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Cache management
207
+ # ---------------------------------------------------------------------------
208
+
209
+ def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
210
+ """Build the cos/sin table to cover positions [0, seq_len).
211
+
212
+ Checks the class-level registry first. If a table already exists for this
213
+ exact (parameters, seq_len, device, dtype) combination it is reused directly;
214
+ otherwise it is computed and stored. The instance attributes are pointed at
215
+ the registry entry so that all layers sharing the same parametrisation
216
+ reference the same tensor.
217
+ """
218
+ cache_key = (self._freq_key, seq_len, str(device), str(dtype))
219
+
220
+ if cache_key not in RotaryEmbedding._cache:
221
+ positions = torch.arange(seq_len, device=device, dtype=torch.float32)
222
+ # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
223
+ freqs = torch.outer(
224
+ positions,
225
+ self.rotation_freqs.to(device=device, dtype=torch.float32),
226
+ )
227
+ angle_embedding = torch.cat((freqs, freqs), dim=-1)
228
+ RotaryEmbedding._cache[cache_key] = (
229
+ angle_embedding.cos().to(dtype),
230
+ angle_embedding.sin().to(dtype),
231
+ )
232
+
233
+ self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
234
+
235
+ def forward(
236
+ self,
237
+ q: torch.Tensor,
238
+ k: torch.Tensor,
239
+ position_ids: torch.Tensor,
240
+ ) -> tuple[torch.Tensor, torch.Tensor, float]:
241
+ """Apply rotary embeddings to query and key tensors.
242
+
243
+ The cos/sin cache is extended lazily when position_ids reference positions
244
+ beyond its current length, or when the device or dtype has changed.
245
+
246
+ ``position_ids`` may be any integer tensor shape. Its values are valid
247
+ position indices into the cos/sin cache:
248
+
249
+ - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
250
+ - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
251
+
252
+ When q/k have head dimensions absent from position_ids, broadcast dimensions
253
+ are inserted automatically at dim 1.
254
+
255
+ Args:
256
+ q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
257
+ k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
258
+ position_ids: Integer positions of shape (batch, *pos_dims).
259
+
260
+ Returns:
261
+ Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
262
+ 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
263
+ apply to attention logits before softmax.
264
+ """
265
+ seq_len = int(position_ids.max().item()) + 1
266
+
267
+ # The cache is valid when it exists, covers all positions referenced by
268
+ # position_ids, and matches q's dtype and device. Each condition is named
269
+ # separately so the rebuild trigger is readable rather than a compound predicate.
270
+ cache_missing = self._cos_cached is None
271
+ cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
272
+ wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
273
+ wrong_device = not cache_missing and self._cos_cached.device != q.device
274
+
275
+ if cache_missing or cache_too_short or wrong_dtype or wrong_device:
276
+ self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
277
+
278
+ cos = self._cos_cached[position_ids]
279
+ sin = self._sin_cached[position_ids]
280
+
281
+ # Insert broadcast dimensions for any head axes present in q/k but absent
282
+ # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
283
+ # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
284
+ while cos.ndim < q.ndim:
285
+ cos = cos.unsqueeze(1)
286
+ sin = sin.unsqueeze(1)
287
+
288
+ q_rotated = q * cos + _rotate_half(q) * sin
289
+ k_rotated = k * cos + _rotate_half(k) * sin
290
+
291
+ return q_rotated, k_rotated, self.attention_scaling
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|endoftext|>",
5
+ "eos_token": "<|endoftext|>",
6
+ "errors": "replace",
7
+ "is_local": false,
8
+ "model_max_length": 1000000000000000019884624838656,
9
+ "pad_token": "<|padding|>",
10
+ "tokenizer_class": "GPTNeoXTokenizerFast",
11
+ "trim_offsets": true,
12
+ "unk_token": "<|endoftext|>"
13
+ }