smithblack-0 commited on
Commit
03ca6a0
·
verified ·
1 Parent(s): 02b5fbb

Clear repository contents

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -35
  2. README.md +0 -103
  3. __attention__bottlenecked_ensemble_attention.py +0 -252
  4. __attention__expert_packing.py +0 -335
  5. __attention__load_balance_loss.py +0 -88
  6. __attention__mosrah.py +0 -140
  7. __attention__positions_converter.py +0 -105
  8. __attention__router.py +0 -162
  9. __attention__shram.py +0 -116
  10. __attention__sliding_window_attention.py +0 -233
  11. __cache__mosrah_cache.py +0 -359
  12. __cache__shram_cache.py +0 -141
  13. __cache__shram_layer_cache.py +0 -233
  14. __cache__sliding_window_cache.py +0 -289
  15. __cache__slow_mosrah_cache.py +0 -321
  16. __init__.py +0 -21
  17. architecture_core/README.md +0 -95
  18. architecture_core/__init__.py +0 -21
  19. architecture_core/attention/__init__.py +0 -0
  20. architecture_core/attention/bottlenecked_ensemble_attention.py +0 -252
  21. architecture_core/attention/expert_packing.py +0 -335
  22. architecture_core/attention/load_balance_loss.py +0 -88
  23. architecture_core/attention/mosrah.py +0 -140
  24. architecture_core/attention/positions_converter.py +0 -105
  25. architecture_core/attention/router.py +0 -162
  26. architecture_core/attention/shram.py +0 -116
  27. architecture_core/attention/sliding_window_attention.py +0 -233
  28. architecture_core/cache/__init__.py +0 -6
  29. architecture_core/cache/mosrah_cache.py +0 -359
  30. architecture_core/cache/shram_cache.py +0 -141
  31. architecture_core/cache/shram_layer_cache.py +0 -233
  32. architecture_core/cache/sliding_window_cache.py +0 -289
  33. architecture_core/cache/slow_mosrah_cache.py +0 -321
  34. architecture_core/config.json +0 -28
  35. architecture_core/configuration.py +0 -175
  36. architecture_core/decoder_layer.py +0 -87
  37. architecture_core/huggingface.py +0 -506
  38. architecture_core/mlp.py +0 -52
  39. architecture_core/model.py +0 -132
  40. architecture_core/rope.py +0 -291
  41. architecture_core/tokenizer.json +0 -0
  42. architecture_core/tokenizer_config.json +0 -13
  43. attention/__init__.py +0 -0
  44. attention/bottlenecked_ensemble_attention.py +0 -252
  45. attention/expert_packing.py +0 -335
  46. attention/load_balance_loss.py +0 -88
  47. attention/mosrah.py +0 -140
  48. attention/positions_converter.py +0 -105
  49. attention/router.py +0 -162
  50. attention/shram.py +0 -116
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,103 +0,0 @@
1
- ---
2
- language:
3
- - en
4
- license: mit
5
- library_name: transformers
6
- pipeline_tag: text-generation
7
- tags:
8
- - pytorch
9
- - research
10
- - sparse-attention
11
- - mixture-of-experts
12
- ---
13
-
14
- # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
-
16
- A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
- Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
- the Hub and instantiate a freshly initialised model from config. Every parameter is
19
- overridable at instantiation time via kwargs.
20
-
21
- > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
- > source files from the Hub and imports them into your Python process. Review the
23
- > source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use.
24
-
25
- ## Architecture
26
-
27
- SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
28
-
29
- - **h_l** — local sliding-window causal attention path.
30
- - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
31
- via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
32
-
33
- All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
34
-
35
- ## Usage
36
-
37
- This repository contains no pretrained weights. The intended workflow is: pull the
38
- architecture config from the Hub, instantiate a model with fresh random weights, then
39
- train it yourself.
40
-
41
- ```python
42
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
43
-
44
- # Step 1: pull the architecture config from the Hub.
45
- # AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
46
- # Override any parameter via kwargs.
47
- config = AutoConfig.from_pretrained(
48
- "smithblack-0/SHRAM",
49
- trust_remote_code=True,
50
- num_hidden_layers=16, # example override
51
- num_mosrah_heads=32, # example override
52
- )
53
-
54
- # Step 2: instantiate with fresh random weights.
55
- # from_config never loads a checkpoint — it always produces a randomly initialised model.
56
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
57
-
58
- # Step 3: load the tokenizer.
59
- tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
60
- ```
61
-
62
- After training your own checkpoint, save and reload it in the standard way:
63
-
64
- ```python
65
- model.save_pretrained("./my-checkpoint")
66
- model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
67
- ```
68
-
69
- ## Constructor Defaults
70
-
71
- The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
72
- no overrides. They are not the parameters of a pretrained model — this repository
73
- contains no weights. All values are overridable via kwargs.
74
-
75
- | Parameter | Default |
76
- |-----------|---------|
77
- | `alpha` | 1.0 |
78
- | `attention_dropout` | 0.0 |
79
- | `beta` | 32.0 |
80
- | `dtype` | None |
81
- | `head_dim` | 16 |
82
- | `hidden_size` | 512 |
83
- | `inference_sequence_length` | 1024 |
84
- | `intermediate_size` | 1366 |
85
- | `local_rope_theta` | 10000.0 |
86
- | `mosrah_rope_theta` | 10000.0 |
87
- | `num_hidden_layers` | 12 |
88
- | `num_mosrah_heads` | 16 |
89
- | `num_selected_heads` | 16 |
90
- | `num_sliding_window_heads` | 16 |
91
- | `output_hidden_states` | False |
92
- | `rms_norm_eps` | 1e-05 |
93
- | `rope_mode` | main_sequence |
94
- | `tie_word_embeddings` | False |
95
- | `training_sequence_length` | 1024 |
96
- | `use_cache` | True |
97
- | `vocab_size` | 50277 |
98
- | `window_size` | 128 |
99
-
100
- ## License
101
-
102
- MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
103
- (`EleutherAI/gpt-neox-20b`, Apache 2.0).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__bottlenecked_ensemble_attention.py DELETED
@@ -1,252 +0,0 @@
1
- """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
-
3
- BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
- It consumes packed expert-choice tensors, a supplied position tensor, an active
5
- token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
- same packed expert-choice space expected by later unpacking.
7
-
8
- BEA does not compute positions and does not choose packed-position semantics.
9
- Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
- (K̃) and raw values (V) into the sparse cache and attends against the
11
- accumulated cached state returned by that cache.
12
- """
13
-
14
- import math
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
-
20
- from .configuration import ShramConfig
21
- from .__cache__mosrah_cache import MoSRAHCache
22
- from .rope import RotaryEmbedding
23
-
24
-
25
- class BottleneckedEnsembleAttention(nn.Module):
26
- """
27
- Packed expert-choice attention operator for the MoSRAH sparse path.
28
- Operates per-head independently on an ensemble of tokens.
29
- FlexAttention saves flops on dead tokens.
30
-
31
- Architectural properties:
32
- - consumes packed expert-choice tensors of shape (B, L, T, d)
33
- - uses independent per-head Q/K/V/O projection parameters
34
- - applies YaRN-capable RoPE using supplied position_ids
35
- - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
- - uses a fast fused attention path
37
- - returns outputs in the same packed expert-choice space (B, L, T, d)
38
-
39
- Args:
40
- config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
- `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
- `inference_sequence_length`, `alpha`, and `beta`.
43
- """
44
-
45
- def __init__(self, config: ShramConfig) -> None:
46
- super().__init__()
47
-
48
- self.hidden_size = config.hidden_size
49
- self.num_heads = config.num_mosrah_heads
50
- self.head_dim = config.head_dim
51
-
52
- # Independent per-head projections. No cross-head parameter sharing.
53
- self.q_proj = nn.Parameter(
54
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
- )
56
- self.k_proj = nn.Parameter(
57
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
- )
59
- self.v_proj = nn.Parameter(
60
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
- )
62
- self.o_proj = nn.Parameter(
63
- torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
- )
65
-
66
- self._reset_parameters()
67
-
68
- # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
- # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
- # no yarn dilation occurs.
71
- self.rope = RotaryEmbedding(
72
- mode="yarn",
73
- head_dim=self.head_dim,
74
- theta=config.mosrah_rope_theta,
75
- initial_seq_length=config.training_sequence_length,
76
- dilation=config.scale,
77
- alpha=config.alpha,
78
- beta=config.beta,
79
- )
80
-
81
- def forward(
82
- self,
83
- packed_embeddings: torch.Tensor,
84
- position_ids: torch.Tensor,
85
- active_mask: torch.Tensor,
86
- cache: MoSRAHCache | None = None,
87
- ) -> torch.Tensor:
88
- """Apply BEA to packed expert-choice tensors.
89
-
90
- Args:
91
- packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
92
- position_ids: Supplied packed positions of shape (B, L, T).
93
- active_mask: Boolean active-token mask of shape (B, L, T).
94
- cache: Optional layer-local MoSRAH cache.
95
-
96
- Returns:
97
- Packed expert-choice output tensor of shape (B, L, T, d).
98
- """
99
- batch_size, _, query_length, _ = packed_embeddings.shape
100
- self._validate_tensor_shape(packed_embeddings)
101
- self._validate_position_shape(packed_embeddings, position_ids)
102
- self._validate_active_mask_shape(packed_embeddings, active_mask)
103
-
104
- # Independent per-head projections:
105
- # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
106
- query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
107
- key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
108
- value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
109
-
110
- rotated_query_states, rotated_key_states, attention_scaling = self.rope(
111
- query_states,
112
- key_states,
113
- position_ids,
114
- )
115
-
116
- if cache is not None:
117
- # In cached execution, the current query tensor uses local tensor rows
118
- # 0..Q-1, but the key tensor returned by the cache is the full accumulated
119
- # packed sequence for each (batch, head) slot. The only additional data
120
- # needed to align those two views is the pre-update cached prefix length.
121
- # which will indicate how many queries were processed before now.
122
- num_tokens_processed = cache.get_heads_lengths().clone()
123
- key_states, value_states, key_active_mask = cache.update(
124
- rotated_key_states,
125
- value_states,
126
- active_mask,
127
- )
128
- else:
129
- num_tokens_processed = torch.zeros(
130
- batch_size,
131
- self.num_heads,
132
- dtype=torch.long,
133
- device=packed_embeddings.device,
134
- )
135
- key_states = rotated_key_states
136
- key_active_mask = active_mask
137
-
138
- block_mask = self._make_block_mask(
139
- query_active_mask=active_mask,
140
- key_active_mask=key_active_mask,
141
- num_tokens_processed=num_tokens_processed,
142
- query_length=query_length,
143
- key_length=key_states.shape[2],
144
- device=packed_embeddings.device,
145
- )
146
- attended_states = flex_attention(
147
- rotated_query_states,
148
- key_states,
149
- value_states,
150
- block_mask=block_mask,
151
- scale=attention_scaling / math.sqrt(self.head_dim),
152
- )
153
-
154
- # Project back to model width:
155
- # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
156
- return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
157
-
158
- def _reset_parameters(self) -> None:
159
- """Initialize per-head projection weights."""
160
- for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
161
- nn.init.xavier_uniform_(weight)
162
-
163
- def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
164
- """Validate the local packed-embedding shape contract required by BEA."""
165
- if packed_embeddings.shape[1] != self.num_heads:
166
- raise ValueError(
167
- f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
168
- f"got {packed_embeddings.shape[1]}."
169
- )
170
-
171
- if packed_embeddings.shape[-1] != self.hidden_size:
172
- raise ValueError(
173
- f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
174
- f"got {packed_embeddings.shape[-1]}."
175
- )
176
-
177
- def _validate_position_shape(
178
- self,
179
- packed_embeddings: torch.Tensor,
180
- position_ids: torch.Tensor,
181
- ) -> None:
182
- """Validate the supplied packed-position tensor shape."""
183
- if position_ids.shape != packed_embeddings.shape[:3]:
184
- raise ValueError(
185
- f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
186
- f"got {tuple(position_ids.shape)}."
187
- )
188
-
189
- def _validate_active_mask_shape(
190
- self,
191
- packed_embeddings: torch.Tensor,
192
- active_mask: torch.Tensor,
193
- ) -> None:
194
- """Validate the supplied active-token mask shape."""
195
- if active_mask.shape != packed_embeddings.shape[:3]:
196
- raise ValueError(
197
- f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
198
- f"got {tuple(active_mask.shape)}."
199
- )
200
-
201
- def _make_block_mask(
202
- self,
203
- query_active_mask: torch.Tensor,
204
- key_active_mask: torch.Tensor,
205
- num_tokens_processed: torch.Tensor,
206
- query_length: int,
207
- key_length: int,
208
- device: torch.device,
209
- ):
210
- """Create the packed-sequence causal mask for FlexAttention.
211
-
212
- At the root, causality is still triangular. The only nuance is cached
213
- execution: query rows are indexed locally as 0..Q-1 inside the current
214
- query tensor, but the key tensor may already contain a cached prefix for
215
- that (batch, head) slot. The causal horizon for query tensor row q is
216
- therefore:
217
-
218
- cached_prefix_lengths[b, h] + q
219
-
220
- Query and key activity masks are then composed with that triangular rule
221
- so FlexAttention can skip padded query rows and ignore inactive key slots.
222
- """
223
- batch_size, num_heads, _ = query_active_mask.shape
224
-
225
- # Build the per-(batch, head, query_row) triangular horizon from a simple
226
- # arange over query rows plus the cached prefix lengths for each slot.
227
- relative_query_positions = torch.arange(
228
- query_length,
229
- device=device,
230
- dtype=torch.long,
231
- ).view(1, 1, query_length)
232
- causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
233
-
234
- def packed_causal_mask(
235
- batch_idx: torch.Tensor,
236
- head_idx: torch.Tensor,
237
- query_idx: torch.Tensor,
238
- key_idx: torch.Tensor,
239
- ) -> torch.Tensor:
240
- query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
241
- key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
242
- is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
243
- return query_is_active & key_is_active & is_causal
244
-
245
- return create_block_mask(
246
- packed_causal_mask,
247
- B=batch_size,
248
- H=num_heads,
249
- Q_LEN=query_length,
250
- KV_LEN=key_length,
251
- device=device,
252
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__expert_packing.py DELETED
@@ -1,335 +0,0 @@
1
- """Expert packing and unpacking for the MoSRAH path.
2
-
3
- This module implements the low-level token-choice -> expert-choice -> token-choice
4
- conversion boundary specified in the paper. The externally visible behavior is fixed:
5
-
6
- - setup_packing() prepares the auxiliary ordering data.
7
- - pack_experts() converts routed token-choice state into packed expert-choice state.
8
- - unpack_experts() restores token-choice ordering afterward.
9
-
10
- Stable sort is a correctness requirement. It preserves causal ordering inside each
11
- expert bucket, which is the foundation on which BEA's later triangular causal mask
12
- is correct.
13
-
14
- pack_experts() returns two distinct masks that serve different roles and must not
15
- be interchanged:
16
-
17
- - unpacking_mask: marks every packed slot that contains a routed token copy,
18
- live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
- so its reshape invariant holds regardless of outer token liveness.
20
- - active_mask: marks only the packed slots whose source token was semantically
21
- live. This is what BEA consumes for attention gating. Dead outer tokens must
22
- not influence sparse attention outputs.
23
- """
24
-
25
- import torch
26
-
27
-
28
- # ---------------------------------------------------------------------------
29
- # Setup
30
- # ---------------------------------------------------------------------------
31
-
32
- def setup_packing(
33
- selected_heads: torch.Tensor,
34
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
35
- """Prepare the auxiliary ordering data used by pack/unpack.
36
-
37
- Routing produces token-choice state I of shape (B, N, K): for each token, which
38
- K experts were selected. Packing needs the same routed token copies reordered into
39
- expert-major order so each expert bucket becomes contiguous.
40
-
41
- The paper's setup step does this by flattening (N, K) into one axis to produce
42
- H in token-major order, then computing a stable argsort permutation Pi over the
43
- expert indices stored in H. Applying Pi reorders the flattened routed copies into
44
- expert-major order while preserving their original token order *within* each expert
45
- bucket. That preservation is why stable sort is required for causality.
46
-
47
- Args:
48
- selected_heads: Routed token-choice head selections I of shape (B, N, K).
49
-
50
- Returns:
51
- Tuple of:
52
- - flattened_selected_heads: H of shape (B, N*K)
53
- - permutation: stable expert-major permutation Pi of shape (B, N*K)
54
- - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
55
- """
56
- batch_size, sequence_length, num_selected_heads = selected_heads.shape
57
- flattened_selected_heads = selected_heads.reshape(
58
- batch_size,
59
- sequence_length * num_selected_heads,
60
- )
61
-
62
- permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
63
- inverse_permutation = torch.argsort(permutation, dim=-1)
64
-
65
- return flattened_selected_heads, permutation, inverse_permutation
66
-
67
-
68
- # ---------------------------------------------------------------------------
69
- # Packing
70
- # ---------------------------------------------------------------------------
71
-
72
- def pack_experts(
73
- hidden_states: torch.Tensor,
74
- position_ids: torch.Tensor,
75
- selected_heads: torch.Tensor,
76
- num_experts: int,
77
- flattened_selected_heads: torch.Tensor,
78
- permutation: torch.Tensor,
79
- outer_active_mask: torch.Tensor,
80
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
81
- """Pack token-choice hidden states into expert-choice padded form.
82
-
83
- The paper's packing path has two jobs:
84
-
85
- 1. Convert routed token-choice copies into expert-major order.
86
- 2. Materialize that expert-major order into a padded tensor layout BEA can consume.
87
-
88
- The routed hidden-state copies are not stored explicitly in token-choice form.
89
- Instead, the same token hidden state is conceptually copied once per selected expert.
90
- The packing step reconstructs those copies by expanding local source-token indices,
91
- reordering those indices with Pi, then gathering hidden states, positions, and outer
92
- liveness in that packed order. All three are carried through the same expert-major
93
- rearrangement so they remain aligned in the packed frame.
94
-
95
- Packed positions are sourced from the authoritative upstream position_ids tensor
96
- rather than synthesized locally from arange(N). This preserves advanced positions
97
- correctly during cached inference while leaving training/full-sequence behavior
98
- unchanged when position_ids is the ordinary sequential token positions.
99
-
100
- Args:
101
- hidden_states: Token-choice hidden states x of shape (B, N, d).
102
- position_ids: Authoritative upstream token positions J of shape (B, N).
103
- selected_heads: Routed head selections I of shape (B, N, K).
104
- num_experts: Total number of experts L.
105
- flattened_selected_heads: H from setup_packing(), shape (B, N*K).
106
- permutation: Pi from setup_packing(), shape (B, N*K).
107
- outer_active_mask: Current-chunk active mask of shape (B, N), where True
108
- means the token is semantically live. Dead tokens do not become
109
- semantically active in the packed sparse representation.
110
-
111
- Returns:
112
- Tuple of:
113
- - packed_hidden_states: x' of shape (B, L, T, d)
114
- - packed_positions: J' of shape (B, L, T)
115
- - unpacking_mask: of shape (B, L, T). True where a slot contains any
116
- routed token copy, live or dead. Always has exactly B*N*K True entries.
117
- Pass this to unpack_experts — not active_mask.
118
- - active_mask: of shape (B, L, T). True only where a slot contains a
119
- copy of a live outer token. Pass this to BEA for attention gating.
120
- """
121
- batch_size, sequence_length, hidden_dim = hidden_states.shape
122
- _, _, num_selected_heads = selected_heads.shape
123
-
124
- # -----------------------------------------------------------------------
125
- # Reconstruct routed local source-token indices in token-choice order.
126
- #
127
- # The internal arange(N) is no longer the packed position tensor. It is only
128
- # the local source-row index object used to gather from the current chunk
129
- # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
130
- # token-major routed-copy order.
131
- # -----------------------------------------------------------------------
132
- source_token_indices = torch.arange(
133
- sequence_length,
134
- device=hidden_states.device,
135
- dtype=torch.long,
136
- ).view(1, sequence_length, 1).expand(
137
- batch_size,
138
- sequence_length,
139
- num_selected_heads,
140
- )
141
- flattened_source_indices = source_token_indices.reshape(
142
- batch_size,
143
- sequence_length * num_selected_heads,
144
- )
145
-
146
- # -----------------------------------------------------------------------
147
- # Reorder source-token indices into expert-major order.
148
- #
149
- # Applying Pi yields the local source-token rows in the packed expert-major
150
- # order required by the paper. Those same reordered source indices are then
151
- # used to gather hidden states, authoritative upstream positions, and outer
152
- # liveness so all three remain aligned under the exact same packing
153
- # transformation.
154
- # -----------------------------------------------------------------------
155
- sorted_source_indices = flattened_source_indices.gather(
156
- dim=1,
157
- index=permutation,
158
- )
159
- sorted_hidden_states = hidden_states.gather(
160
- dim=1,
161
- index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
162
- )
163
- sorted_positions = position_ids.gather(
164
- dim=1,
165
- index=sorted_source_indices,
166
- )
167
- sorted_active_mask = outer_active_mask.gather(
168
- dim=1,
169
- index=sorted_source_indices,
170
- )
171
-
172
- # -----------------------------------------------------------------------
173
- # Count how many routed copies land in each expert bucket.
174
- #
175
- # S[b, l] is the number of routed token copies assigned to expert l in batch b.
176
- # T is the maximum such count across all batches and experts; it determines the
177
- # padded expert-length dimension of the packed representation.
178
- # -----------------------------------------------------------------------
179
- tokens_per_expert = _bincount_rows(
180
- values=flattened_selected_heads,
181
- num_bins=num_experts,
182
- )
183
- max_tokens_per_expert = int(tokens_per_expert.max().item())
184
-
185
- # -----------------------------------------------------------------------
186
- # Construct the active-token mask M.
187
- #
188
- # Each expert bucket is left-justified: if S[b, l] = s, then slots
189
- # t = 0, ..., s-1 are active and all later slots are padding. The resulting
190
- # mask therefore both identifies real packed tokens and enforces left-justified
191
- # packing. This is the unpacking_mask — it marks slot occupancy regardless of
192
- # outer token liveness, and always has exactly B*N*K True entries.
193
- # -----------------------------------------------------------------------
194
- time_axis = torch.arange(
195
- max_tokens_per_expert,
196
- device=hidden_states.device,
197
- dtype=torch.long,
198
- ).view(1, 1, max_tokens_per_expert)
199
- unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)
200
-
201
- # -----------------------------------------------------------------------
202
- # Materialize the padded packed tensors.
203
- #
204
- # The packed hidden states x', packed original-token positions J', and packed
205
- # active-token mask are allocated as zero-filled tensors. Active entries are
206
- # then written into those buffers in the expert-major order established above.
207
- # Padding remains zero / inactive.
208
- # -----------------------------------------------------------------------
209
- packed_hidden_states = hidden_states.new_zeros(
210
- batch_size,
211
- num_experts,
212
- max_tokens_per_expert,
213
- hidden_dim,
214
- )
215
- packed_positions = position_ids.new_zeros(
216
- batch_size,
217
- num_experts,
218
- max_tokens_per_expert,
219
- )
220
- active_mask = torch.zeros(
221
- batch_size,
222
- num_experts,
223
- max_tokens_per_expert,
224
- dtype=torch.bool,
225
- device=hidden_states.device,
226
- )
227
-
228
- packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
229
- packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
230
- active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)
231
-
232
- return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
-
234
-
235
- # ---------------------------------------------------------------------------
236
- # Unpacking
237
- # ---------------------------------------------------------------------------
238
-
239
- def unpack_experts(
240
- expert_outputs: torch.Tensor,
241
- selected_heads: torch.Tensor,
242
- unpacking_mask: torch.Tensor,
243
- inverse_permutation: torch.Tensor,
244
- ) -> torch.Tensor:
245
- """Restore token-choice ordering from BEA expert-choice output.
246
-
247
- Unpacking inverts the packing path only on occupied entries. Padding does not
248
- participate: the output tensor is first filtered by unpacking_mask to recover
249
- only the real routed-token copies in expert-major order, then Pi^{-1} restores
250
- the original token-choice ordering, and finally the tensor is reshaped back to
251
- (B, N, K, d).
252
-
253
- The unpacking_mask — not active_mask — must be used here. Even copies of dead
254
- outer tokens occupy slots and must be un-scattered correctly for the inverse
255
- permutation to hold. The total True entry count in unpacking_mask is always
256
- B*N*K, which is exactly what the reshape to (B, N*K, d) requires.
257
-
258
- Args:
259
- expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
260
- selected_heads: Routed head selections I of shape (B, N, K).
261
- unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
262
- occupied packed slots regardless of outer token liveness.
263
- inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).
264
-
265
- Returns:
266
- Restored token-choice tensor y_tilde of shape (B, N, K, d).
267
- """
268
- batch_size, sequence_length, num_selected_heads = selected_heads.shape
269
- hidden_dim = expert_outputs.shape[-1]
270
-
271
- active_outputs = expert_outputs[unpacking_mask]
272
- sorted_token_choice_outputs = active_outputs.reshape(
273
- batch_size,
274
- sequence_length * num_selected_heads,
275
- hidden_dim,
276
- )
277
- restored_outputs = sorted_token_choice_outputs.gather(
278
- dim=1,
279
- index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
280
- )
281
-
282
- return restored_outputs.reshape(
283
- batch_size,
284
- sequence_length,
285
- num_selected_heads,
286
- hidden_dim,
287
- )
288
-
289
-
290
- # ---------------------------------------------------------------------------
291
- # Helpers
292
- # ---------------------------------------------------------------------------
293
-
294
- def _bincount_rows(
295
- values: torch.Tensor,
296
- num_bins: int,
297
- ) -> torch.Tensor:
298
- """Count per-row integer occurrences for a 2D tensor.
299
-
300
- torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
- one bincount per batch row. The trick used here is to shift each row into its
302
- own disjoint bin range before flattening:
303
-
304
- row 0 uses bins [0, ..., num_bins - 1]
305
- row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
- row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
- ...
308
-
309
- After that shift, one global torch.bincount produces all row-local counts at
310
- once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
-
312
- This is a vectorized implementation detail only; externally visible behavior
313
- remains exactly the paper's S tensor of per-batch per-expert token counts.
314
-
315
- Args:
316
- values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
- num_bins: Number of bins.
318
-
319
- Returns:
320
- Counts tensor of shape (B, num_bins).
321
- """
322
- batch_size = values.shape[0]
323
-
324
- row_offsets = torch.arange(
325
- batch_size,
326
- device=values.device,
327
- dtype=values.dtype,
328
- ).unsqueeze(1) * num_bins
329
- shifted_values = values + row_offsets
330
-
331
- counts = torch.bincount(
332
- shifted_values.reshape(-1),
333
- minlength=batch_size * num_bins,
334
- )
335
- return counts.reshape(batch_size, num_bins)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__load_balance_loss.py DELETED
@@ -1,88 +0,0 @@
1
- """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
-
3
- This module implements the custom autograd Function H(b, f) described in the paper's
4
- Implementation Concerns section. The operator bridges two requirements that are in
5
- tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
- multiplication) so that existing training loops remain compatible, while simultaneously
7
- implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
- path through the router weights.
9
-
10
- The resolution is a custom backward pass. The forward emits the load balance imbalance
11
- as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
- its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
- then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
- correction without a standalone SGD update step.
15
-
16
- Paper ref: Appendix A.Implementation Concerns.
17
- """
18
-
19
- import torch
20
-
21
-
22
- class LoadBalanceLoss(torch.autograd.Function):
23
- """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
-
25
- Forward computes the load balance imbalance:
26
-
27
- L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
-
29
- Backward emits a bias-correction gradient to expert_bias:
30
-
31
- grad_b = L_grad * sign(f_l - 1/L)
32
-
33
- expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
- in the computation graph and routes gradients through it. routing_freqs (f) receives
35
- no gradient — its origin is the discrete TopK operation which has no gradient, so
36
- defining a gradient for f here would be mathematically incorrect.
37
-
38
- Paper ref: Appendix A.Implementation Concerns.
39
- """
40
-
41
- @staticmethod
42
- def forward(
43
- ctx: torch.autograd.function.FunctionCtx,
44
- expert_bias: torch.Tensor,
45
- routing_freqs: torch.Tensor,
46
- ) -> torch.Tensor:
47
- """Compute the load balance loss.
48
-
49
- Args:
50
- ctx: Autograd context for saving state needed in backward.
51
- expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
- PyTorch tracks it as a computation graph node needing a gradient.
53
- routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
- from the discrete TopK selection — not differentiable.
55
-
56
- Returns:
57
- Scalar loss equal to sum_l |f_l - 1/L|.
58
- """
59
- L = expert_bias.shape[0]
60
- # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
- # underloaded. Saved for backward where sign(imbalance) determines the direction
62
- # of the bias-correction update.
63
- imbalance = routing_freqs - 1.0 / L
64
- ctx.save_for_backward(imbalance)
65
- return imbalance.abs().sum()
66
-
67
- @staticmethod
68
- def backward(
69
- ctx: torch.autograd.function.FunctionCtx,
70
- grad_output: torch.Tensor,
71
- ) -> tuple[torch.Tensor, None]:
72
- """Emit the DeepSeek-style bias-correction gradient.
73
-
74
- Args:
75
- ctx: Autograd context carrying imbalance saved in forward.
76
- grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
- by the training loop arrives here and is propagated to grad_b, so the
78
- correction magnitude is proportional to the loss weight chosen by the
79
- consumer.
80
-
81
- Returns:
82
- Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
- None for routing_freqs: no gradient is defined for the discrete routing
84
- frequency.
85
- """
86
- (imbalance,) = ctx.saved_tensors
87
- grad_expert_bias = grad_output * imbalance.sign()
88
- return grad_expert_bias, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__mosrah.py DELETED
@@ -1,140 +0,0 @@
1
- """Full MoSRAH sparse path for SHRAM.
2
-
3
- This module coordinates the routed sparse attention path used inside the SHRAM
4
- hybrid attention layer. The underlying mechanics already live in verified
5
- subunits. The responsibility here is to connect those subunits without
6
- corrupting their bridge contracts.
7
-
8
- In particular, this path must preserve three architectural distinctions:
9
-
10
- - selected head indices are not routing probabilities
11
- - packed position semantics are chosen before BEA, not inside it
12
- - weighted reduction must consume the router's unbiased renormalized
13
- probabilities after token-choice order has been restored
14
- """
15
-
16
- import torch
17
- from torch import nn
18
-
19
- from .__cache__mosrah_cache import MoSRAHCache
20
- from .configuration import ShramConfig
21
- from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
- from .__attention__expert_packing import (
23
- pack_experts,
24
- setup_packing,
25
- unpack_experts,
26
- )
27
- from .__attention__router import MoSRAHRouter
28
- from .__attention__positions_converter import SparseMoSRAHPositions
29
-
30
-
31
- class MoSRAHLayer(nn.Module):
32
- """Full routed sparse attention path for SHRAM.
33
-
34
- The MoSRAH path consumes model-space hidden states together with
35
- authoritative per-token positions and returns the model-space sparse-path
36
- contribution, the router's load-balance loss, and the router's MaxVio
37
- routing-imbalance scalar.
38
- """
39
-
40
- def __init__(self, config: ShramConfig) -> None:
41
- super().__init__()
42
- self.num_experts = config.num_mosrah_heads
43
-
44
- self.router = MoSRAHRouter(config)
45
- self.positions = SparseMoSRAHPositions(config)
46
- self.bea = BottleneckedEnsembleAttention(config)
47
-
48
- def forward(
49
- self,
50
- hidden_states: torch.Tensor,
51
- position_ids: torch.Tensor,
52
- active_mask: torch.Tensor,
53
- cache: MoSRAHCache | None,
54
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
- """Run the full MoSRAH sparse path.
56
-
57
- Args:
58
- hidden_states: Model-space hidden states x of shape (B, N, d).
59
- position_ids: Authoritative per-token positions of shape (B, N).
60
- active_mask: Current-chunk active mask of shape (B, N), where True
61
- means the token is semantically live. Forwarded to the router
62
- so dead tokens are excluded from routing statistics, and to
63
- pack_experts so dead outer tokens do not become semantically
64
- active packed entries.
65
- cache: Optional layer-local MoSRAH cache. Pass None for uncached
66
- execution and the layer-local cache instance for cached execution.
67
-
68
- Returns:
69
- sparse_output: Model-space sparse-path output of shape (B, N, d).
70
- load_balance_loss: Scalar router load-balance loss.
71
- max_vio: Detached scalar routing-imbalance summary. Passed through
72
- unchanged from the router; see MoSRAHRouter for semantics.
73
- """
74
-
75
- # -------------------------------------------------------------------
76
- # The first transition moves from model-space token-choice input into
77
- # the packed expert-choice sparse-attention state. Routing decides both
78
- # which experts each token uses and which unbiased probabilities must be
79
- # reserved for the final reduction. The active mask is forwarded to the
80
- # router so dead tokens are excluded from routing statistics, and to
81
- # pack_experts so outer liveness is faithfully carried into the packed
82
- # frame. Packing returns both the unpacking mask (slot occupancy, always
83
- # B*N*K True entries) and the packed active mask (live slots only);
84
- # active_mask is rebound to the packed form after this point.
85
- # -------------------------------------------------------------------
86
- selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
87
- hidden_states, active_mask
88
- )
89
-
90
- flattened_selected_heads, permutation, inverse_permutation = setup_packing(
91
- selected_heads
92
- )
93
- packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
94
- hidden_states=hidden_states,
95
- position_ids=position_ids,
96
- selected_heads=selected_heads,
97
- num_experts=self.num_experts,
98
- flattened_selected_heads=flattened_selected_heads,
99
- permutation=permutation,
100
- outer_active_mask=active_mask,
101
- )
102
-
103
- # -------------------------------------------------------------------
104
- # Sparse attention runs entirely in the packed expert-choice frame, so
105
- # the RoPE position semantics must also be chosen in that frame. The
106
- # position layer therefore decides whether BEA should see packed
107
- # original-token positions or packed local-slot positions. BEA then
108
- # consumes that packed position tensor together with the packed hidden
109
- # states and the layer-local sparse cache, which it owns directly.
110
- # -------------------------------------------------------------------
111
- bea_positions = self.positions(
112
- packed_positions=packed_positions,
113
- cache=cache,
114
- )
115
- packed_outputs = self.bea(
116
- packed_embeddings=packed_hidden_states,
117
- position_ids=bea_positions,
118
- active_mask=active_mask,
119
- cache=cache,
120
- )
121
-
122
- # -------------------------------------------------------------------
123
- # The final transition restores token-choice meaning and only then
124
- # collapses the K routed copies back into model space. This ordering is
125
- # required because routing_probs live in token-choice space, whereas BEA
126
- # returns expert-choice packed outputs. The reduction must therefore
127
- # happen after unpacking, and it must use the router's unbiased
128
- # renormalized probabilities rather than any biased selection scores.
129
- # -------------------------------------------------------------------
130
- token_choice_outputs = unpack_experts(
131
- expert_outputs=packed_outputs,
132
- selected_heads=selected_heads,
133
- unpacking_mask=unpacking_mask,
134
- inverse_permutation=inverse_permutation,
135
- )
136
- final_output = (
137
- token_choice_outputs * routing_probs.unsqueeze(-1)
138
- ).sum(dim=2)
139
-
140
- return final_output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__positions_converter.py DELETED
@@ -1,105 +0,0 @@
1
- """Position computation for the MoSRAH sparse path.
2
-
3
- This layer computes the packed position tensor P consumed by BEA.
4
-
5
- - In main-sequence mode, P is the packed original-token position tensor from the
6
- packing path.
7
- - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
- expert-choice layout, optionally offset by the current sparse-cache occupancies
9
- during cached inference.
10
- """
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from .configuration import ShramConfig
16
- from .__cache__mosrah_cache import MoSRAHCache
17
-
18
-
19
- class SparseMoSRAHPositions(nn.Module):
20
- """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
-
22
- This layer operates in the packed expert-choice frame used by BEA. The input
23
- packed_positions tensor is always the packed original-token position tensor
24
- produced by the packing path. The configured rope_mode determines whether that
25
- tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
- """
27
-
28
- def __init__(self, config: ShramConfig) -> None:
29
- super().__init__()
30
- self.rope_mode = config.rope_mode
31
-
32
- def forward(
33
- self,
34
- packed_positions: torch.Tensor,
35
- cache: MoSRAHCache | None,
36
- ) -> torch.Tensor:
37
- """Compute the packed position tensor P consumed by BEA.
38
-
39
- Args:
40
- packed_positions: Packed original-token positions J' of shape (B, L, T).
41
- cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
- mode, the current per-head occupancies offset the local packed sequence.
43
-
44
- Returns:
45
- Packed position tensor P of shape (B, L, T).
46
- """
47
- if self.rope_mode == "main_sequence":
48
- return self._main_sequence_positions(packed_positions)
49
-
50
- if self.rope_mode == "semantic_sequence":
51
- return self._semantic_sequence_positions(packed_positions, cache)
52
-
53
- raise NotImplementedError(
54
- f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
- )
56
-
57
- def _main_sequence_positions(
58
- self,
59
- packed_positions: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """Forward packed original-token positions unchanged."""
62
- return packed_positions
63
-
64
- def _semantic_sequence_positions(
65
- self,
66
- packed_positions: torch.Tensor,
67
- cache: MoSRAHCache | None,
68
- ) -> torch.Tensor:
69
- """Compute semantic-sequence packed positions in expert-choice space.
70
-
71
- Without a sparse cache, semantic positions are the local packed sequence
72
- 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
73
- same local sequence is offset by the current per-(batch, expert) occupancies
74
- returned by get_heads_lengths().
75
- """
76
- batch_size, num_experts, packed_length = packed_positions.shape
77
-
78
- # -------------------------------------------------------------------
79
- # Construct the local packed sequence 0, 1, 2, ... over the expert-local
80
- # sequence dimension T. This is then broadcast across batch and experts.
81
- # -------------------------------------------------------------------
82
- local_positions = torch.arange(
83
- packed_length,
84
- device=packed_positions.device,
85
- dtype=packed_positions.dtype,
86
- ).view(1, 1, packed_length).expand(
87
- batch_size,
88
- num_experts,
89
- packed_length,
90
- )
91
-
92
- # -------------------------------------------------------------------
93
- # In cached semantic-sequence mode, positions continue from the current
94
- # sparse-cache occupancies rather than restarting at zero for the local
95
- # chunk.
96
- # -------------------------------------------------------------------
97
- if cache is None:
98
- return local_positions
99
-
100
- cached_lengths = cache.get_heads_lengths().to(
101
- device=packed_positions.device,
102
- dtype=packed_positions.dtype,
103
- ).unsqueeze(-1)
104
-
105
- return local_positions + cached_lengths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__router.py DELETED
@@ -1,162 +0,0 @@
1
- """Token-choice router for the MoSRAH sparse attention path.
2
-
3
- This module implements the routing mechanism described in Appendix A.Routing of the
4
- paper. Given an input hidden state x, the router produces two outputs used downstream:
5
-
6
- - selected_heads (I): which K of the L available expert heads each token routes to,
7
- determined by TopK over biased routing scores.
8
- - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
- *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
- bias b must not influence P.
11
-
12
- This separation is architecturally critical: expert_bias drives selection (and thus load
13
- balancing) but does not corrupt the gradient path from the output through routing_probs
14
- back to the routing projection weights.
15
-
16
- The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
- autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
- loop can weight and add to the language modeling loss.
19
-
20
- The router additionally computes and returns MaxVio, a detached scalar summarising
21
- routing imbalance for the current forward pass:
22
-
23
- MaxVio = L · max_l(f_l − 1/L)
24
-
25
- where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
- target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
-
28
- Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
- """
30
-
31
- import torch
32
- import torch.nn as nn
33
- import torch.nn.functional as F
34
-
35
- from .configuration import ShramConfig
36
- from .__attention__load_balance_loss import LoadBalanceLoss
37
-
38
-
39
- class MoSRAHRouter(nn.Module):
40
- """Token-choice router for MoSRAH sparse attention.
41
-
42
- Each input token independently selects K of the L available expert heads. Selection
43
- is driven by biased routing scores to enable load balancing, but the routing
44
- probabilities used for output reduction are computed from unbiased scores so that
45
- the expert bias does not interfere with the gradient path to the router weights.
46
-
47
- The routing projection W_r has no bias term — the paper specifies xW_r with no
48
- additional projection bias. The only bias-like parameter is expert_bias (b), which
49
- has an entirely separate role and update mechanism.
50
-
51
- Args:
52
- config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
- (L), and ``num_selected_heads`` (K).
54
- """
55
-
56
- def __init__(self, config: ShramConfig) -> None:
57
- super().__init__()
58
- self.num_mosrah_heads = config.num_mosrah_heads
59
- self.num_selected_heads = config.num_selected_heads
60
-
61
- # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
- self.routing_projection = nn.Linear(
63
- config.hidden_size, config.num_mosrah_heads, bias=False
64
- )
65
-
66
- # b: learned per-head bias for load balancing. Initialized to zero so that all
67
- # heads start with equal selection probability. Updated by the main optimizer
68
- # via the LoadBalanceLoss custom backward.
69
- self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
70
-
71
- def forward(
72
- self, x: torch.Tensor, active_mask: torch.Tensor
73
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
- """Route input tokens to K expert heads each and compute routing probabilities.
75
-
76
- Args:
77
- x: Input hidden states of shape (batch, seq_len, hidden_size).
78
- active_mask: Current-chunk active mask of shape (batch, seq_len), where
79
- True means the token is semantically live. Dead tokens do not
80
- contribute to routing frequencies, load_balance_loss, or max_vio.
81
-
82
- Returns:
83
- selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
84
- Each token's K selected head indices, determined by TopK on biased scores.
85
- routing_probs: Routing probabilities P of shape (batch, seq_len,
86
- num_selected_heads). Gathered from unbiased scores at selected_heads
87
- indices and renormalized to sum to 1 per token.
88
- load_balance_loss: Scalar load balance imbalance loss for this forward pass.
89
- Training loop scales this by a weight and adds it to the main loss.
90
- max_vio: Detached scalar routing-imbalance summary for this forward pass.
91
- Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
92
- never contributes gradients.
93
- """
94
- B, N, _ = x.shape
95
- L = self.num_mosrah_heads
96
- K = self.num_selected_heads
97
-
98
- # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
99
- # compute routing_probs — expert_bias must not influence them.
100
- logits = self.routing_projection(x) # (B, N, L)
101
- routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
102
-
103
- # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
104
- # selection. expert_bias is added to logits before softmax so that the bias
105
- # shifts selection probability without rescaling the unbiased distribution.
106
- biased_routing_scores = F.softmax( # R̂, (B, N, L)
107
- logits + self.expert_bias, dim=-1
108
- )
109
-
110
- # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
111
- selected_heads = biased_routing_scores.topk(K, dim=-1).indices
112
-
113
- # Routing probabilities P: gathered from unbiased R at selected_heads indices,
114
- # then renormalized so they sum to 1 per token. Gathering from routing_scores
115
- # (not biased_routing_scores) is the invariant that keeps the gradient path from
116
- # the output back to the router weights free of expert_bias influence.
117
- gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
- routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
-
120
- # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
- # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
- # assignment mask before reduction. Normalization uses the active assignment
123
- # count so frequencies remain properly scaled regardless of how many tokens
124
- # are live in this chunk.
125
- assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
- assignment_mask.scatter_(-1, selected_heads, 1.0)
127
- active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
- num_active_assignments = active_mask.sum() * K
129
- routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
130
-
131
- # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
- # registers it as a graph node; the custom backward writes the DeepSeek-style
133
- # correction gradient to expert_bias.grad for the optimizer to consume.
134
- load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
-
136
- # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
- # not contribute gradients under any circumstance, so it is detached at the
138
- # point of computation rather than left to callers to detach.
139
- max_vio = self._compute_max_vio(routing_freqs, L)
140
-
141
- return selected_heads, routing_probs, load_balance_loss, max_vio
142
-
143
- @staticmethod
144
- def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
145
- """Compute the MaxVio routing-imbalance scalar.
146
-
147
- MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
- head l and 1/L is the perfectly balanced target. A value of zero indicates
149
- perfect balance; a value of 1 means the most overloaded head received exactly
150
- double its fair share.
151
-
152
- The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
- and must never contribute gradients to any parameter.
154
-
155
- Args:
156
- routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
- num_heads: Total number of MoSRAH heads L.
158
-
159
- Returns:
160
- Detached scalar MaxVio tensor.
161
- """
162
- return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__shram.py DELETED
@@ -1,116 +0,0 @@
1
- """SHRAM hybrid attention layer.
2
-
3
- This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
- used at one decoder attention slot in SHRAM.
5
-
6
- The local sliding-window path and the MoSRAH sparse path are already verified
7
- independently. The responsibility here is therefore not to introduce new
8
- attention logic, but to preserve the bridge contracts between them: both paths
9
- must consume the same input hidden state, each path must receive the sub-cache
10
- it actually owns, the two model-space outputs must be summed directly, and the
11
- sparse-path load-balance loss must remain visible to the caller.
12
- """
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from .__cache__shram_layer_cache import ShramLayerCache
18
- from .configuration import ShramConfig
19
- from .__attention__sliding_window_attention import SlidingWindowAttention
20
- from .__attention__mosrah import MoSRAHLayer
21
-
22
-
23
- class SHRAMHybridLayer(nn.Module):
24
- """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
-
26
- The local path preserves nearby-token behavior through sliding-window causal
27
- attention. The sparse path is the theorem-facing MoSRAH routed attention
28
- path. Both operate over the same model-space hidden state and return
29
- model-space outputs, so the hybrid composition is a direct sum in model
30
- space.
31
- """
32
-
33
- def __init__(self, config: ShramConfig) -> None:
34
- super().__init__()
35
- self.local_attention = SlidingWindowAttention(config)
36
- self.sparse_attention = MoSRAHLayer(config)
37
-
38
- def forward(
39
- self,
40
- hidden_states: torch.Tensor,
41
- position_ids: torch.Tensor,
42
- active_mask: torch.Tensor,
43
- cache: ShramLayerCache | None,
44
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
- """Apply the SHRAM hybrid attention layer.
46
-
47
- Args:
48
- hidden_states: Input hidden states of shape (B, N, d).
49
- position_ids: Authoritative token positions of shape (B, N).
50
- active_mask: Current-chunk active mask of shape (B, N), where True
51
- means the token is semantically live. Forwarded unchanged to
52
- both the local path and the sparse path.
53
- cache: Optional per-layer SHRAM cache. When provided, the owned
54
- sliding-window and MoSRAH sub-caches are dispatched directly to
55
- their corresponding attention paths.
56
-
57
- Returns:
58
- hybrid_output: Model-space hybrid attention output of shape (B, N, d).
59
- load_balance_loss: Scalar sparse-path load-balance loss.
60
- max_vio: Detached scalar routing-imbalance summary. Passed through
61
- unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
62
- """
63
- # ------------------------------------------------
64
- # It is not possible, due to how bea constructs its block mask,
65
- # for the model to process a sequence that does not start at zero
66
- # without a cache to track the per-head offsets
67
- # ------------------------------------------------
68
-
69
- if cache is None and torch.any(position_ids[:, 0] != 0):
70
- raise ValueError(
71
- "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
72
- "Either provide a matching ShramLayerCache populated by the prefix for "
73
- "continued decoding, or rebase the uncached sequence to start at 0."
74
- )
75
-
76
- # -------------------------------------------------------------------
77
- # The hybrid layer's first responsibility is cache dispatch. The layer
78
- # cache already owns the concrete sub-cache objects required by each
79
- # path, so this unit should forward those exact references rather than
80
- # reinterpret cache ownership or invent a composite update protocol here.
81
- # -------------------------------------------------------------------
82
- if cache is None:
83
- sliding_window_cache = None
84
- mosrah_cache = None
85
- else:
86
- sliding_window_cache = cache.sliding_window_cache
87
- mosrah_cache = cache.mosrah_cache
88
-
89
- # -------------------------------------------------------------------
90
- # Both attention paths must see the same model-space hidden state for
91
- # the current decoder layer. The local path preserves short-range
92
- # structure, while the sparse path provides the routed long-range
93
- # contribution and emits the load-balance signal used by training.
94
- # -------------------------------------------------------------------
95
- local_output = self.local_attention(
96
- x=hidden_states,
97
- position_ids=position_ids,
98
- active_mask=active_mask,
99
- cache=sliding_window_cache,
100
- )
101
- sparse_output, load_balance_loss, max_vio = self.sparse_attention(
102
- hidden_states=hidden_states,
103
- position_ids=position_ids,
104
- active_mask=active_mask,
105
- cache=mosrah_cache,
106
- )
107
-
108
- # -------------------------------------------------------------------
109
- # The composition rule is intentionally simple at this boundary. Both
110
- # sublayers already return model-space tensors of matching shape, so the
111
- # correct hybrid behavior is their direct sum with no additional mixing
112
- # logic introduced here.
113
- # -------------------------------------------------------------------
114
- hybrid_output = local_output + sparse_output
115
-
116
- return hybrid_output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__attention__sliding_window_attention.py DELETED
@@ -1,233 +0,0 @@
1
- # src/shram/model/attention/sliding_window_attention.py
2
-
3
- """Local sliding-window attention path for SHRAM.
4
-
5
- This file defines `SlidingWindowAttention`, the local short-range attention path
6
- used inside the SHRAM hybrid layer.
7
-
8
- In the masked-continuation variant, the local cache no longer returns a
9
- semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
- returns:
11
-
12
- - the retained local window memory concatenated with the current chunk
13
- - an aligned active mask over that returned frame
14
-
15
- This module consumes that returned frame directly and constructs effective local
16
- causal/window visibility from the mask. It does not own cache retention policy;
17
- it owns only local attention semantics.
18
- """
19
-
20
- import math
21
- from typing import Any
22
-
23
- import torch
24
- import torch.nn as nn
25
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
-
27
- from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
28
- from .configuration import ShramConfig
29
- from .rope import RotaryEmbedding
30
-
31
-
32
- class SlidingWindowAttention(nn.Module):
33
- """Causal local sliding-window attention for one SHRAM layer.
34
-
35
- Args:
36
- config: SHRAM config. Must expose `hidden_size`,
37
- `num_sliding_window_heads`, `head_dim`, `window_size`,
38
- `attention_dropout`, and `local_rope_theta`.
39
-
40
- Raises:
41
- NotImplementedError: If `attention_dropout != 0.0`.
42
- """
43
-
44
- def __init__(self, config: ShramConfig) -> None:
45
- super().__init__()
46
-
47
- self.hidden_size = config.hidden_size
48
- self.num_heads = config.num_sliding_window_heads
49
- self.head_dim = config.head_dim
50
- self.window_size = config.window_size
51
- self.attention_dropout = config.attention_dropout
52
-
53
- if self.attention_dropout != 0.0:
54
- raise NotImplementedError(
55
- "SlidingWindowAttention currently supports only "
56
- "attention_dropout == 0.0."
57
- )
58
-
59
- self.inner_dim = self.num_heads * self.head_dim
60
-
61
- # Standard MHA projections for the local path.
62
- self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
- self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
- self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
- self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
-
67
- # The local path always uses default-mode RoPE with its own theta.
68
- self.rope = RotaryEmbedding(
69
- mode="default",
70
- head_dim=self.head_dim,
71
- theta=config.local_rope_theta,
72
- )
73
-
74
- def forward(
75
- self,
76
- x: torch.Tensor,
77
- position_ids: torch.Tensor,
78
- active_mask: torch.Tensor,
79
- cache: LocalSlidingWindowLayerCache | None = None,
80
- ) -> torch.Tensor:
81
- """Apply local causal sliding-window attention.
82
-
83
- Args:
84
- x: Input tensor of shape `(B, N, hidden_size)`.
85
- position_ids: Position tensor of shape `(B, N)`.
86
- active_mask: Current-chunk active mask of shape `(B, N)`, where
87
- `True` means active.
88
- cache: Optional `LocalSlidingWindowLayerCache`.
89
-
90
- Returns:
91
- Output tensor of shape `(B, N, hidden_size)`.
92
- """
93
- batch_size, query_len, _ = x.shape
94
-
95
- self._validate_position_shape(x, position_ids)
96
- self._validate_active_mask_shape(x, active_mask)
97
-
98
- # (B, N, H*D) -> (B, H, N, D)
99
- q = self.q_proj(x).view(
100
- batch_size,
101
- query_len,
102
- self.num_heads,
103
- self.head_dim,
104
- ).transpose(1, 2)
105
- k = self.k_proj(x).view(
106
- batch_size,
107
- query_len,
108
- self.num_heads,
109
- self.head_dim,
110
- ).transpose(1, 2)
111
- v = self.v_proj(x).view(
112
- batch_size,
113
- query_len,
114
- self.num_heads,
115
- self.head_dim,
116
- ).transpose(1, 2)
117
-
118
- q, k, attention_scaling = self.rope(q, k, position_ids)
119
-
120
- # The cache returns the current-step visible local frame, not merely the
121
- # retained next-step cache buffer.
122
- if cache is not None:
123
- k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
124
- else:
125
- k_full, v_full, full_active_mask = k, v, active_mask
126
-
127
- block_mask = self._make_block_mask(
128
- active_mask=full_active_mask,
129
- batch_size=batch_size,
130
- num_heads=self.num_heads,
131
- query_len=query_len,
132
- kv_len=k_full.shape[-2],
133
- window_size=self.window_size,
134
- device=x.device,
135
- )
136
-
137
- attn_output = flex_attention(
138
- q,
139
- k_full,
140
- v_full,
141
- block_mask=block_mask,
142
- scale=attention_scaling / math.sqrt(self.head_dim),
143
- )
144
-
145
- # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
146
- attn_output = (
147
- attn_output.transpose(1, 2)
148
- .contiguous()
149
- .view(batch_size, query_len, self.inner_dim)
150
- )
151
-
152
- return self.o_proj(attn_output)
153
-
154
- def _validate_position_shape(
155
- self,
156
- x: torch.Tensor,
157
- position_ids: torch.Tensor,
158
- ) -> None:
159
- """Validate the position tensor shape expected by local RoPE."""
160
- if position_ids.shape != x.shape[:2]:
161
- raise ValueError(
162
- f"position_ids must have shape {tuple(x.shape[:2])}, "
163
- f"got {tuple(position_ids.shape)}."
164
- )
165
-
166
- def _validate_active_mask_shape(
167
- self,
168
- x: torch.Tensor,
169
- active_mask: torch.Tensor,
170
- ) -> None:
171
- """Validate the current-chunk active-mask contract."""
172
- if active_mask.shape != x.shape[:2]:
173
- raise ValueError(
174
- f"active_mask must have shape {tuple(x.shape[:2])}, "
175
- f"got {tuple(active_mask.shape)}."
176
- )
177
- if active_mask.dtype != torch.bool:
178
- raise ValueError(
179
- f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
180
- )
181
-
182
- def _make_block_mask(
183
- self,
184
- active_mask: torch.Tensor,
185
- batch_size: int,
186
- num_heads: int,
187
- query_len: int,
188
- kv_len: int,
189
- window_size: int,
190
- device: torch.device,
191
- ) -> Any:
192
- """Create the FlexAttention block mask for masked local continuation.
193
-
194
- The returned local frame is chronological in raw buffer order, but dead
195
- positions may remain inside it. Effective local order is therefore
196
- recovered from the active mask itself by taking a cumulative count over
197
- active positions.
198
-
199
- Queries still occupy the tail of the returned frame, so raw buffer order
200
- is used to locate query rows. Semantic active-token positions are then
201
- used to decide causality and sliding-window distance.
202
- """
203
- query_offset = kv_len - query_len
204
- semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
-
206
- def sliding_window_mask(
207
- batch_idx: torch.Tensor,
208
- head_idx: torch.Tensor,
209
- q_idx: torch.Tensor,
210
- kv_idx: torch.Tensor,
211
- ) -> torch.Tensor:
212
-
213
- q_abs = query_offset + q_idx
214
-
215
- query_is_active = active_mask[batch_idx, q_abs]
216
- key_is_active = active_mask[batch_idx, kv_idx]
217
-
218
- q_sem = semantic_positions[batch_idx, q_abs]
219
- k_sem = semantic_positions[batch_idx, kv_idx]
220
-
221
- is_causal = k_sem <= q_sem
222
- in_window = (q_sem - k_sem) < window_size
223
-
224
- return query_is_active & key_is_active & is_causal & in_window
225
-
226
- return create_block_mask(
227
- sliding_window_mask,
228
- B=batch_size,
229
- H=num_heads,
230
- Q_LEN=query_len,
231
- KV_LEN=kv_len,
232
- device=device,
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__cache__mosrah_cache.py DELETED
@@ -1,359 +0,0 @@
1
- """MoSRAH sparse KV cache — single-layer implementation.
2
-
3
- MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
- by head rather than by sequence position. The routing is dynamic and produces a ragged
5
- distribution of token counts across (batch, head) slots — different batch items may
6
- route different numbers of tokens to the same head, and different heads accumulate at
7
- different rates. DynamicCache cannot represent this correctly: it concatenates along
8
- the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
- therefore uses a custom buffer design.
10
-
11
- Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
- attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
- heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
- head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
- valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
- buffer_capacity property and is derived directly from self.keys rather than tracked
17
- as a separate variable.
18
-
19
- The primary interface is update(key_states, value_states, active_mask), which accepts
20
- expert-choice layout, stores only active entries in causal order, and returns the full
21
- accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
- active_mask identifies valid cached positions; everything beyond each slot's count is
23
- junk data that downstream attention must exclude.
24
-
25
- BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
- exposed by get_heads_lengths() must be read before update() if the caller needs the
27
- pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
- in-place and the pre-update values are not recoverable afterward.
29
-
30
- All buffers are allocated at construction time. MoSRAHCache is constructed by
31
- ShramLayerCache, which has access to batch size, device, and all model config parameters
32
- needed to fully specify the storage layout upfront.
33
- """
34
-
35
- import torch
36
- from transformers.cache_utils import CacheLayerMixin
37
-
38
-
39
class MoSRAHCache(CacheLayerMixin):
    """KV cache for the MoSRAH sparse attention path — single decoder layer.

    Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
    Stores keys and values in the mixin-standard self.keys and self.values
    attributes using a custom (B, L, T, u) layout rather than delegating to
    DynamicCache, which cannot represent MoSRAH's ragged per-(batch, head)
    token counts correctly.

    All storage is allocated at construction time and is_initialized is True
    immediately. The caller (ShramLayerCache) provides batch size, device, and
    model config parameters so no lazy allocation is needed.

    Input is expected in expert-choice layout: (B, L, T, u) key/value tensors
    with a (B, L, T) boolean active_mask. Only positions where active_mask is
    True are written. BEA has already applied RoPE before calling update().

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines
            the second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first
            dimension of all storage tensors.
        device: Device on which to allocate all tensors. Should match the
            model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot.
            Grown geometrically when any slot overflows. Defaults to 64 to
            avoid repeated reallocation during prompt processing.
        dtype: Storage dtype for keys and values. Defaults to the global
            default dtype (previous behavior); pass the model compute dtype
            (e.g. torch.bfloat16) to avoid silent up/down-casting on writes.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Backward compatible: with dtype=None, torch.zeros would have used
        # the global default dtype anyway.
        buffer_dtype = torch.get_default_dtype() if dtype is None else dtype

        # Allocate primary storage into the mixin-standard self.keys /
        # self.values so that inherited methods (offload, prefetch) operate on
        # real tensors. _counts tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size,
            num_mosrah_heads,
            initial_buffer_size,
            head_dim,
            device=device,
            dtype=buffer_dtype,
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size,
            num_mosrah_heads,
            initial_buffer_size,
            head_dim,
            device=device,
            dtype=buffer_dtype,
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it
        is always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states into the buffer and return the full cache state.

        Accepts expert-choice layout: key_states and value_states are
        (B, L, T, u); active_mask is (B, L, T) bool with True marking real
        tokens. Only active positions are written; inactive positions are
        ignored.

        Uses a cumsum construction to derive the absolute buffer position for
        each active token without any Python loops. For a given (batch, head)
        slot, positions are assigned in the order tokens appear along the T
        dimension, preserving causal ordering.

        Returns the full accumulated (keys, values, active_mask) across the
        cached sparse sequence. The returned active_mask is True exactly for
        slots t < counts[b, l]; everything beyond is junk data that BEA must
        exclude.

        Note: get_heads_lengths() must be called before update() if the caller
        needs the pre-update occupancy for position computation (Unit 10.A).
        update() increments counts in-place and the pre-update values are not
        recoverable.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in
                expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice
                layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False
                for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin
                signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, T, u) float — full key buffer including junk slots.
                values: (B, L, T, u) float — full value buffer including junk slots.
                active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
        """
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)

        # BUGFIX: a single doubling was not guaranteed to be sufficient when
        # an incoming chunk is larger than the current capacity; grow to the
        # exact required capacity so the scatter below never writes out of
        # bounds.
        required_capacity = int((self._counts + incoming_delta).max().item())
        if required_capacity > self.buffer_capacity:
            self._expand(required_capacity)

        # Cumulative count of active positions along T for each (b, l) slot.
        # Entry [b, l, t] is the 1-based rank of position t among all active
        # positions in that slot. Subtract 1 for a zero-based within-slot
        # index. Inactive positions produce a stale value, which is excluded
        # by the mask gate below.
        within_slot = active_mask.long().cumsum(dim=2) - 1  # (B, L, T)

        # Add the pre-update count to get the absolute buffer position for
        # each active token.
        abs_pos = within_slot + self._counts.unsqueeze(-1)  # (B, L, T)

        # Scatter key and value vectors at all active positions.
        b_idx, l_idx, t_idx = torch.where(active_mask)
        self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            key_states[b_idx, l_idx, t_idx]
        )
        self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            value_states[b_idx, l_idx, t_idx]
        )

        self._counts += incoming_delta

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for
        attention masking and by position computation (Unit 10.A).

        Note: in the MoSRAH forward pass, this must be called before update()
        if the caller needs the pre-update occupancy. update() increments
        these counts in-place.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of
            valid tokens stored in the (b, h) slot. Zero for slots with no
            writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains
        allocated and is_initialized remains True — only the contents are
        cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam
        search must reorder all three together or the occupancy counts and
        buffer contents will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls
        get_seq_length(), which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the
                beam search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size
        B to B * repeats, matching the expanded beam candidate batch. Applied
        atomically across keys, values, and _counts; batch_size is updated to
        reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate
        entries. Applied atomically across keys, values, and _counts;
        batch_size is updated to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not
        know about. All three tensors are moved atomically so device state
        remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not
        know about. _counts is synced to self.keys.device after the parent
        moves keys and values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have
        different lengths depending on routing history. There is no meaningful
        scalar summary. Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "MoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — MoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — MoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "MoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the
        slot has been written. Positions at or beyond the count are junk and
        must be excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self, min_capacity: int | None = None) -> None:
        """Grow the buffer capacity geometrically, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. Capacity is
        doubled repeatedly until it reaches at least `min_capacity` (a single
        doubling by default, for backward compatibility), so one call is
        always sufficient even for chunks larger than the current buffer.
        All existing key and value data is copied into the low half of the
        new buffer; the high slots are zero-initialised and will be filled by
        subsequent writes. After reassignment, buffer_capacity reflects the
        new size automatically.

        Args:
            min_capacity: Minimum required capacity after growth. If None,
                the capacity is simply doubled once.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        if min_capacity is not None:
            while new_cap < min_capacity:
                new_cap *= 2
        dev = self.keys.device
        # BUGFIX: preserve the storage dtype across expansion; the previous
        # implementation reallocated with the default dtype, silently
        # changing e.g. bf16 buffers to float32.
        new_keys = torch.zeros(
            self.batch_size,
            self.num_mosrah_heads,
            new_cap,
            self.head_dim,
            device=dev,
            dtype=self.keys.dtype,
        )
        new_values = torch.zeros(
            self.batch_size,
            self.num_mosrah_heads,
            new_cap,
            self.head_dim,
            device=dev,
            dtype=self.values.dtype,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__cache__shram_cache.py DELETED
@@ -1,141 +0,0 @@
1
- """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
-
3
- The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
- CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
- lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
- ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
- presents them through the Cache interface, and transparently forwards model-wide operations
8
- across all of them.
9
-
10
- ShramCache does not define a composite update() interface. The two attention paths inside each
11
- SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
- nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
- relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
- coordination of the layer caches — not routing attention inputs.
15
-
16
- Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
- specified layer, which tracks the cumulative count of token positions processed. This is
18
- what HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import Cache
23
-
24
- from .__cache__shram_layer_cache import ShramLayerCache
25
-
26
-
27
class ShramCache(Cache):
    """Top-level cache for the full SHRAM model.

    Holds one ShramLayerCache per decoder layer and fulfils the HuggingFace
    top-level Cache role. Model-wide operations (reset, reorder,
    sequence-length queries) are forwarded across the owned layer caches.

    There is deliberately no composite update() here: the two attention paths
    inside a SHRAM layer have incompatible update semantics, so callers must
    update the relevant sub-cache directly via
    cache.layers[layer_idx].sliding_window_cache or
    cache.layers[layer_idx].mosrah_cache.

    Args:
        num_hidden_layers: Number of SHRAM decoder layers; one ShramLayerCache
            is built per layer.
        sliding_window: Token window size handed to each layer's
            LocalSlidingWindowLayerCache.
        num_local_heads: Number of local attention heads per layer.
        local_head_dim: Per-head embedding width for the local path.
        num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
        mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH
            path.
        batch_size: Number of sequences in the batch.
        device: Device on which cache tensors are allocated.
        initial_buffer_size: Starting per-(batch, head) capacity of each
            MoSRAHCache, doubled on overflow. Defaults to 64 so prompt
            processing does not trigger repeated reallocation.
    """

    def __init__(
        self,
        num_hidden_layers: int,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        # Build one identically-configured layer cache per decoder layer.
        layer_caches: list[ShramLayerCache] = []
        for _ in range(num_hidden_layers):
            layer_caches.append(
                ShramLayerCache(
                    sliding_window=sliding_window,
                    num_local_heads=num_local_heads,
                    local_head_dim=local_head_dim,
                    num_mosrah_heads=num_mosrah_heads,
                    mosrah_head_dim=mosrah_head_dim,
                    batch_size=batch_size,
                    device=device,
                    initial_buffer_size=initial_buffer_size,
                )
            )
        super().__init__(layers=layer_caches)

    # ---------------------------------------------------------------------------
    # Cache — composite-meaningful methods
    # ---------------------------------------------------------------------------
    #
    # reset(): Inherited. Iterates all layer caches and calls reset() on each.
    #
    # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
    #
    # is_initialized: Inherited property. True iff all layer caches are initialized.
    # Since ShramLayerCache.is_initialized is True from construction, this is True
    # immediately after ShramCache.__init__ returns.

    def get_seq_length(self, layer_idx: int = 0) -> int:  # type: ignore[override]
        """Return the cumulative sequence length for the specified layer.

        Forwards to the layer cache at `layer_idx`, which itself forwards to
        its local sliding-window sub-cache — the authoritative source of
        sequence progress, since it sees every token presented to the layer.
        Layer 0 is the default, which is what HuggingFace generation reads.
        """
        return self.layers[layer_idx].get_seq_length()

    # ---------------------------------------------------------------------------
    # Cache — unsupported methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported — there is no composite update interface.

        The local and MoSRAH paths have different update semantics; call
        update() on the relevant sub-cache instead:
            cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
            cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
        """
        raise NotImplementedError(
            "ShramCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
        )

    def crop(self, max_length: int) -> None:
        """Not supported — the owned layer caches do not implement crop()."""
        raise NotImplementedError("ShramCache does not support crop().")

    @property
    def max_batch_size(self) -> int:
        """Not supported — no uniform batch size is tracked across layers."""
        raise NotImplementedError("ShramCache does not expose max_batch_size.")

    @property
    def max_cache_len(self) -> int:
        """Not supported — there is no single maximum cache length.

        The local side is bounded by sliding_window while the MoSRAH side is
        unbounded, so no truthful scalar maximum exists for the composite.
        """
        raise NotImplementedError("ShramCache does not expose max_cache_len.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__cache__shram_layer_cache.py DELETED
@@ -1,233 +0,0 @@
1
- """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
-
3
- A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
- local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
- different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
- the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
- path can interact with it without indirection.
8
-
9
- ShramLayerCache does not define a composite update() interface. The two paths have materially
10
- different update semantics — the local side uses chunk-local key/value/mask concatenation
11
- while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
- behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
- each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
- ownership, coordination, and reset/reorder boundary for one decoder layer.
15
-
16
- Sequence length at this boundary is reported by delegating to the local sliding-window
17
- sub-cache, which tracks the cumulative count of token positions processed. This is the
18
- quantity HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import CacheLayerMixin
23
-
24
- from .__cache__mosrah_cache import MoSRAHCache
25
- from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
26
-
27
-
28
class ShramLayerCache(CacheLayerMixin):
    """Cache subsystem for one SHRAM decoder layer.

    Owns the two sub-caches backing the layer's two attention pathways:
      - sliding_window_cache: LocalSlidingWindowLayerCache, local sliding-window path.
      - mosrah_cache: MoSRAHCache, MoSRAH sparse attention path.

    Fulfils the HuggingFace per-layer cache role (CacheLayerMixin) while
    deliberately offering no composite update(): the two paths have materially
    different update semantics (chunk-local concatenation vs. expert-choice
    scatter with an active mask), so each attention path calls update() on the
    sub-cache it owns. This class is purely the ownership, coordination, and
    reset/reorder boundary.

    Sequence length is delegated to the local sliding-window sub-cache, which
    counts every token position this layer has processed — the quantity
    HuggingFace generation reads through get_seq_length().

    Args:
        sliding_window: Number of tokens retained by the local sliding-window cache.
        num_local_heads: Number of local attention heads.
        local_head_dim: Per-head embedding width for the local path.
        num_mosrah_heads: Total number of MoSRAH expert heads (L).
        mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache tensors.
        initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache,
            doubled when any slot overflows. Defaults to 64 so prompt
            processing does not trigger repeated reallocation.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.sliding_window_cache = LocalSlidingWindowLayerCache(
            sliding_window=sliding_window,
            num_heads=num_local_heads,
            head_dim=local_head_dim,
            batch_size=batch_size,
            device=device,
        )
        self.mosrah_cache = MoSRAHCache(
            num_mosrah_heads=num_mosrah_heads,
            head_dim=mosrah_head_dim,
            batch_size=batch_size,
            device=device,
            initial_buffer_size=initial_buffer_size,
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _sub_caches(self) -> tuple:
        """Both sub-caches in a fixed order: local sliding-window first, then MoSRAH."""
        return (self.sliding_window_cache, self.mosrah_cache)

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def is_initialized(self) -> bool:
        """True iff every sub-cache has allocated its storage.

        Both sub-caches pre-allocate in their constructors, so this holds as
        soon as ShramLayerCache.__init__ has returned.
        """
        return all(cache.is_initialized for cache in self._sub_caches())

    @is_initialized.setter
    def is_initialized(self, value: bool) -> None:
        # CacheLayerMixin.__init__ writes `self.is_initialized = False`; the
        # data descriptor intercepts that write. Absorb it silently — the
        # initialization state is derived from the sub-caches, never stored here.
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — composite-meaningful methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Cumulative sequence length, read from the local sliding-window path.

        The local path sees every token presented to this layer, so its running
        total is the authoritative sequence progress. Delegates to
        sliding_window_cache.get_seq_length().
        """
        return self.sliding_window_cache.get_seq_length()

    def reset(self) -> None:
        """Clear both sub-caches together so their states remain consistent."""
        for cache in self._sub_caches():
            cache.reset()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of both sub-caches for beam search.

        Both are reordered together so sliding-window and MoSRAH state keep
        corresponding to the same beam hypotheses.

        Args:
            beam_idx: Permutation indices of shape (batch,) from beam search.
        """
        for cache in self._sub_caches():
            cache.reorder_cache(beam_idx)

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension of both sub-caches for beam-search init.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        for cache in self._sub_caches():
            cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Keep only the given batch entries in both sub-caches (contrastive search).

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        for cache in self._sub_caches():
            cache.batch_select_indices(indices)

    def offload(self) -> None:
        """Offload both sub-caches to CPU.

        super() is deliberately not called: this class owns no keys/values of
        its own — all cached data lives inside the sub-caches.
        """
        for cache in self._sub_caches():
            cache.offload()

    def prefetch(self) -> None:
        """Move both sub-caches back to their model device ahead of time.

        super() is deliberately not called: this class owns no keys/values of
        its own — all cached data lives inside the sub-caches.
        """
        for cache in self._sub_caches():
            cache.prefetch()

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — each sub-cache handles its own initialization."""
        return

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Unsupported: no composite update exists.

        The sliding-window side concatenates key/value state while the MoSRAH
        side performs expert-choice scatter with an active mask; callers must
        invoke sliding_window_cache.update() or mosrah_cache.update() directly.
        """
        raise NotImplementedError(
            "ShramLayerCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache directly."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Unsupported: the sliding-window side is bounded, the MoSRAH side is not,
        so no truthful scalar maximum represents the composite."""
        raise NotImplementedError(
            "ShramLayerCache has no single maximum cache shape. "
            "Query sliding_window_cache or mosrah_cache directly."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Unsupported: each attention path builds its own mask with its own semantics."""
        raise NotImplementedError(
            "ShramLayerCache does not support get_mask_sizes(). "
            "Query sliding_window_cache or mosrah_cache directly."
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__cache__sliding_window_cache.py DELETED
@@ -1,289 +0,0 @@
1
- # src/shram/model/cache/sliding_window_cache.py
2
-
3
- """Local sliding-window cache for the SHRAM local attention path.
4
-
5
- This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
- `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
-
8
- Its job is narrow:
9
-
10
- - accept the current chunk's local key/value tensors and active mask
11
- - return the current-step local frame consumed by local attention
12
- - separately retain the next-step sliding-window cache state
13
-
14
- It does not decide local causal visibility. That is owned by
15
- `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
- constructs the effective local attention mask from it.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
class LocalSlidingWindowLayerCache(CacheLayerMixin):
    """Fixed-width local cache for one SHRAM decoder layer.

    Holds a retained sliding-window key/value buffer plus an aligned boolean
    active mask. Each update() returns the current-step local frame — retained
    state followed by the incoming chunk, in chronological order — then keeps
    only the trailing `sliding_window` raw positions for the next step.

    Dead (inactive) positions may linger both in the returned frame and in the
    retained buffer; correctness is carried by the aligned active mask. Local
    causal visibility is decided downstream by SlidingWindowAttention, not here.

    Args:
        sliding_window: Width of the retained local sliding-window buffer.
        num_heads: Number of local attention heads.
        head_dim: Per-head embedding width for the local path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache storage.
    """

    is_compileable = False
    is_sliding = True

    def __init__(
        self,
        sliding_window: int,
        num_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
    ) -> None:
        super().__init__()

        # Every sizing parameter must be positive; report the first offender.
        for name, value in (
            ("sliding_window", sliding_window),
            ("num_heads", num_heads),
            ("head_dim", head_dim),
            ("batch_size", batch_size),
        ):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}.")

        self.sliding_window = sliding_window
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Retained next-step state. Storage is fixed-width from construction;
        # per-position semantic validity is carried by `active_mask`.
        kv_shape = (batch_size, num_heads, sliding_window, head_dim)
        self.keys = torch.zeros(kv_shape, device=device)
        self.values = torch.zeros(kv_shape, device=device)
        self.active_mask = torch.zeros(
            batch_size,
            sliding_window,
            dtype=torch.bool,
            device=device,
        )

        self.is_initialized = True

        # Running total of token positions presented through update() — this is
        # what HuggingFace generation reads via get_seq_length().
        self._total_processed: int = 0

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the current-step local frame and retain the next-step window.

        Args:
            key_states: `(B, H, T_new, D)` local key vectors for the current chunk.
            value_states: `(B, H, T_new, D)` local value vectors for the current chunk.
            active_mask: `(B, T_new)` bool; True marks active token positions.
            cache_kwargs: Unused; present only for the CacheLayerMixin interface.

        Returns:
            Tuple of:
                - visible_keys: `(B, H, sliding_window + T_new, D)`
                - visible_values: `(B, H, sliding_window + T_new, D)`
                - visible_active_mask: `(B, sliding_window + T_new)`
            These are consumed directly by the local attention path.
        """
        self._match_incoming(key_states, value_states)

        # Current-step frame: retained window followed by the new chunk, in
        # chronological order.
        frame_keys = torch.cat([self.keys, key_states], dim=-2)
        frame_values = torch.cat([self.values, value_states], dim=-2)
        frame_mask = torch.cat([self.active_mask, active_mask], dim=-1)

        # Retain only the trailing raw window for the next step. This is a
        # positional trim, not a live-token trim — dead positions may survive
        # and downstream attention skips them via the mask.
        window = self.sliding_window
        self.keys = frame_keys[:, :, -window:, :]
        self.values = frame_values[:, :, -window:, :]
        self.active_mask = frame_mask[:, -window:]

        self._total_processed += key_states.shape[2]

        return frame_keys, frame_values, frame_mask

    def _match_incoming(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> None:
        """Migrate retained buffers to the device/dtype of the incoming tensors.

        Buffers are allocated eagerly at construction; if later updates arrive
        on another device or in another floating dtype, move the retained state
        to match while preserving its contents.
        """
        if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
            self.keys = self.keys.to(
                device=key_states.device,
                dtype=key_states.dtype,
            )

        if (
            self.values.dtype != value_states.dtype
            or self.values.device != value_states.device
        ):
            self.values = self.values.to(
                device=value_states.device,
                dtype=value_states.dtype,
            )

        if self.active_mask.device != key_states.device:
            self.active_mask = self.active_mask.to(
                key_states.device,
                non_blocking=True,
            )

    def get_seq_length(self) -> int:
        """Cumulative token positions presented through update().

        Counts everything since construction or the last reset(). It is the
        quantity HuggingFace generation uses to track sequence progress — not
        the active-token count nor the current window occupancy.
        """
        return self._total_processed

    def get_max_cache_shape(self) -> int:
        """Width of the retained window — the only bound this cache has."""
        return self.sliding_window

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        raise NotImplementedError(
            "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
        )

    def reset(self) -> None:
        """Restore fresh-cache behavior without reallocating storage."""
        for buffer in (self.keys, self.values, self.active_mask):
            buffer.zero_()
        self._total_processed = 0

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Permute the batch dimension for beam search."""
        self.keys, self.values, self.active_mask = (
            self.keys[beam_idx],
            self.values[beam_idx],
            self.active_mask[beam_idx],
        )

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat each batch entry `repeats` times for beam-search initialisation."""
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
        self.batch_size *= repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Retain only the given batch entries (contrastive search)."""
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self.active_mask = self.active_mask[indices]
        self.batch_size = int(indices.shape[0])

    def offload(self) -> None:
        """Offload cache tensors to CPU; the parent moves keys/values."""
        super().offload()
        self.active_mask = self.active_mask.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Bring cache tensors back to the model device ahead of use."""
        super().prefetch()
        if self.active_mask.device != self.keys.device:
            self.active_mask = self.active_mask.to(
                self.keys.device,
                non_blocking=True,
            )

    def crop(self, max_length: int) -> None:
        raise NotImplementedError(
            "LocalSlidingWindowLayerCache does not support crop()."
        )

    def lazy_initialization(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> None:
        """No-op — the fixed buffers are allocated at construction time."""
        return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__cache__slow_mosrah_cache.py DELETED
@@ -1,321 +0,0 @@
1
- """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
-
3
- This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
- interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
- (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
- position's key and value are written to the next available slot for that (batch, head)
7
- pair, in the order positions appear along the T dimension, which directly enforces
8
- causal ordering without any index arithmetic to verify.
9
-
10
- SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
- trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
- Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
- vectorized implementation is validated by asserting exact agreement with this one on all
14
- test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
- (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
- an oracle.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
- class SlowMoSRAHCache(CacheLayerMixin):
24
- """Unvectorized reference implementation of the MoSRAH KV cache.
25
-
26
- Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
- mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
- with the same constructor signature and the same CacheLayerMixin protocol methods.
29
- The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
- triples rather than vectorized index arithmetic.
31
-
32
- This class is not used in the model path. It exists so that MoSRAHCache.update()
33
- can be validated by asserting exact agreement with this implementation on all test
34
- inputs. See module docstring for the trust chain this enables.
35
-
36
- Args:
37
- num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
- second dimension of all storage tensors.
39
- head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
- dimension of all storage tensors.
41
- batch_size: Number of sequences in the batch. Determines the first dimension
42
- of all storage tensors.
43
- device: Device on which to allocate all tensors. Should match the model device.
44
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
- during prompt processing.
47
- """
48
-
49
- is_compileable = False
50
- is_sliding = False
51
-
52
- def __init__(
53
- self,
54
- num_mosrah_heads: int,
55
- head_dim: int,
56
- batch_size: int,
57
- device: torch.device,
58
- initial_buffer_size: int = 64,
59
- ) -> None:
60
- super().__init__()
61
- self.num_mosrah_heads = num_mosrah_heads
62
- self.head_dim = head_dim
63
- self.batch_size = batch_size
64
- self.device = device
65
-
66
- # Allocate primary storage into the mixin-standard self.keys / self.values so
67
- # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
- # tracks valid occupancy per (batch, head) slot.
69
- self.keys: torch.Tensor = torch.zeros(
70
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
- )
72
- self.values: torch.Tensor = torch.zeros(
73
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
- )
75
- self._counts: torch.Tensor = torch.zeros(
76
- batch_size, num_mosrah_heads, dtype=torch.long, device=device
77
- )
78
-
79
- # Storage is fully allocated at construction — the cache is initialized.
80
- self.is_initialized = True
81
-
82
- # ---------------------------------------------------------------------------
83
- # Properties
84
- # ---------------------------------------------------------------------------
85
-
86
- @property
87
- def buffer_capacity(self) -> int:
88
- """Current number of slots allocated per (batch, head) pair.
89
-
90
- Derived directly from self.keys rather than tracked separately, so it is
91
- always consistent with the actual buffer after expansion.
92
- """
93
- return self.keys.shape[2]
94
-
95
- # ---------------------------------------------------------------------------
96
- # Primary API
97
- # ---------------------------------------------------------------------------
98
-
99
- def update( # type: ignore[override]
100
- self,
101
- key_states: torch.Tensor,
102
- value_states: torch.Tensor,
103
- active_mask: torch.Tensor,
104
- cache_kwargs: dict | None = None,
105
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
106
- """Scatter active key/value states using an explicit loop; return full cache state.
107
-
108
- Iterates over every (b, l, t) triple. For each position where active_mask is
109
- True, the key and value are written to the next available slot for that
110
- (batch, head) pair and the count is incremented. Causal ordering is guaranteed
111
- because the t dimension is traversed from 0 to T-1 and counts are updated
112
- immediately after each write.
113
-
114
- Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
- the incoming tokens would cause any slot to overflow the current capacity.
116
-
117
- Args:
118
- key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
119
- value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
120
- active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
121
- cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
122
-
123
- Returns:
124
- Tuple of (keys, values, active_mask):
125
- keys: (B, L, T, u) float — full key buffer including junk slots.
126
- values: (B, L, T, u) float — full value buffer including junk slots.
127
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
- """
129
- B, L, T = active_mask.shape
130
-
131
- # Expansion check uses the total active tokens per slot, same as the
132
- # vectorized implementation, so both expand under identical conditions.
133
- incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
- self._expand()
136
-
137
- # Write each active position into the next available slot for its (batch, head)
138
- # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
139
- for b in range(B):
140
- for l in range(L):
141
- for t in range(T):
142
- if active_mask[b, l, t]:
143
- pos = self._counts[b, l].item()
144
- self.keys[b, l, pos, :] = key_states[b, l, t, :]
145
- self.values[b, l, pos, :] = value_states[b, l, t, :]
146
- self._counts[b, l] += 1
147
-
148
- return self.keys, self.values, self._make_active_mask()
149
-
150
- def get_heads_lengths(self) -> torch.Tensor:
151
- """Return the per-(batch, head) token count for this layer.
152
-
153
- This is the authoritative occupancy tensor consumed by BEA for attention
154
- masking and by position computation (Unit 10.A) for semantic-sequence
155
- position computation.
156
-
157
- Returns:
158
- Integer tensor of shape (B, L) where entry [b, h] is the number of valid
159
- tokens stored in the (b, h) slot. Zero for slots with no writes yet.
160
- """
161
- return self._counts
162
-
163
- # ---------------------------------------------------------------------------
164
- # CacheLayerMixin — overridden coordination methods
165
- # ---------------------------------------------------------------------------
166
-
167
- def reset(self) -> None:
168
- """Clear all cached key and value tensors.
169
-
170
- Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
171
- and is_initialized remains True — only the contents are cleared.
172
- """
173
- self.keys.zero_()
174
- self.values.zero_()
175
- self._counts.zero_()
176
-
177
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
178
- """Reorder the batch dimension of all cached tensors for beam search.
179
-
180
- Applied atomically across self.keys, self.values, and _counts. Beam search
181
- must reorder all three together or the occupancy counts and buffer contents
182
- will correspond to different beam hypotheses.
183
-
184
- Overrides the parent because the parent's implementation calls get_seq_length(),
185
- which is not supported for this cache.
186
-
187
- Args:
188
- beam_idx: Permutation indices of shape (batch,) produced by the beam
189
- search algorithm.
190
- """
191
- self.keys = self.keys[beam_idx]
192
- self.values = self.values[beam_idx]
193
- self._counts = self._counts[beam_idx]
194
-
195
- def batch_repeat_interleave(self, repeats: int) -> None:
196
- """Expand the batch dimension by repeating each entry repeats times.
197
-
198
- Used at beam search initialisation to expand the cache from batch size B to
199
- B * repeats, matching the expanded beam candidate batch. Applied atomically
200
- across keys, values, and _counts; batch_size is updated to reflect the new size.
201
-
202
- Args:
203
- repeats: Number of times to repeat each batch entry.
204
- """
205
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
206
- self.values = self.values.repeat_interleave(repeats, dim=0)
207
- self._counts = self._counts.repeat_interleave(repeats, dim=0)
208
- self.batch_size = self.batch_size * repeats
209
-
210
- def batch_select_indices(self, indices: torch.Tensor) -> None:
211
- """Select a subset of batch entries by index.
212
-
213
- Used in contrastive search to retain only the selected candidate entries.
214
- Applied atomically across keys, values, and _counts; batch_size is updated
215
- to reflect the number of retained entries.
216
-
217
- Args:
218
- indices: 1-D integer tensor of batch indices to retain.
219
- """
220
- self.keys = self.keys[indices]
221
- self.values = self.values[indices]
222
- self._counts = self._counts[indices]
223
- self.batch_size = indices.shape[0]
224
-
225
def offload(self) -> None:
    """Offload all cached tensors to CPU.

    Extends the parent to also offload _counts, which the parent does not know
    about. All three tensors are moved atomically so device state remains consistent.
    """
    # Parent implementation moves self.keys and self.values to CPU.
    super().offload()
    # non_blocking=True lets the device-to-host copy overlap with compute
    # when the source tensor lives in CUDA/pinned memory; it is a no-op cost
    # for tensors already on CPU.
    self._counts = self._counts.to("cpu", non_blocking=True)
233
-
234
def prefetch(self) -> None:
    """Move all cached tensors back to the model device ahead of time.

    Extends the parent to also prefetch _counts, which the parent does not know
    about. _counts is synced to self.keys.device after the parent moves keys and
    values, so all three remain consistent.
    """
    # Parent moves keys/values first; afterwards self.keys.device is the
    # authoritative target device for the auxiliary occupancy tensor.
    super().prefetch()
    if self._counts.device != self.keys.device:
        # Guarded move: skip the copy when _counts is already co-located.
        self._counts = self._counts.to(self.keys.device, non_blocking=True)
244
-
245
def lazy_initialization(  # type: ignore[override]
    self, key_states: torch.Tensor, value_states: torch.Tensor
) -> None:
    """Intentional no-op: all storage is pre-allocated at construction time.

    The HF cache protocol calls this hook before the first update; this cache
    never defers allocation, so there is nothing to initialise here.
    """
    return None
250
-
251
- # ---------------------------------------------------------------------------
252
- # CacheLayerMixin — unsupported abstract methods
253
- # ---------------------------------------------------------------------------
254
-
255
def get_seq_length(self) -> int:  # type: ignore[override]
    """Not supported — no single sequence length represents this cache's state.

    MoSRAH heads accumulate independently, so each (batch, head) slot has its
    own length determined by routing history. No scalar summary is meaningful;
    callers should use get_heads_lengths() for per-head occupancy.

    Raises:
        NotImplementedError: always.
    """
    message = (
        "SlowMoSRAHCache has no single sequence length. "
        "Use get_heads_lengths() for per-head occupancy."
    )
    raise NotImplementedError(message)
266
-
267
def get_max_cache_shape(self) -> int:  # type: ignore[override]
    """Not supported — this cache grows dynamically and has no fixed bound.

    Raises:
        NotImplementedError: always.
    """
    message = (
        "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
    )
    raise NotImplementedError(message)
272
-
273
def get_mask_sizes(  # type: ignore[override]
    self,
    cache_position: torch.Tensor,
) -> tuple[int, int]:
    """Not supported — this cache does not participate in HF mask construction.

    Args:
        cache_position: Ignored; present only to match the protocol signature.

    Raises:
        NotImplementedError: always.
    """
    message = "SlowMoSRAHCache does not support get_mask_sizes()."
    raise NotImplementedError(message)
281
-
282
- # ---------------------------------------------------------------------------
283
- # Internal helpers
284
- # ---------------------------------------------------------------------------
285
-
286
def _make_active_mask(self) -> torch.Tensor:
    """Construct the (B, L, T) active mask from the current occupancy counts.

    Entry [b, l, t] is True exactly when t < _counts[b, l], i.e. slot t of
    head l in batch b has been written. Slots at or beyond the count hold
    junk and must be masked out by downstream attention.

    Returns:
        Boolean tensor of shape (batch_size, num_mosrah_heads, buffer_capacity).
    """
    capacity = self.buffer_capacity
    # One slot index per buffer position, broadcast over (batch, head).
    slot_index = torch.arange(capacity, device=self.keys.device)
    slot_index = slot_index.expand(
        self.batch_size, self.num_mosrah_heads, capacity
    )
    occupancy = self._counts.unsqueeze(-1)
    return slot_index < occupancy
299
-
300
def _expand(self) -> None:
    """Double the buffer capacity, preserving existing data.

    Called by update() when an incoming batch of tokens would cause any
    (batch, head) slot to exceed the current buffer capacity. All existing
    key and value data is copied into the low half of the new buffer; the
    high half is zero-initialised and will be filled by subsequent writes.
    After reassignment, buffer_capacity reflects the new size automatically.

    Fix: the replacement buffers are now allocated with the dtype of the
    existing tensors. The previous implementation called torch.zeros()
    without a dtype, so a half-precision (fp16/bf16) cache was silently
    upcast to float32 on its first expansion.
    """
    old_cap = self.buffer_capacity
    new_cap = old_cap * 2
    new_keys = torch.zeros(
        self.batch_size,
        self.num_mosrah_heads,
        new_cap,
        self.head_dim,
        device=self.keys.device,
        dtype=self.keys.dtype,
    )
    new_values = torch.zeros(
        self.batch_size,
        self.num_mosrah_heads,
        new_cap,
        self.head_dim,
        device=self.values.device,
        dtype=self.values.dtype,
    )
    # Preserve all previously written slots in the low half of the buffer.
    new_keys[:, :, :old_cap, :] = self.keys
    new_values[:, :, :old_cap, :] = self.values
    self.keys = new_keys
    self.values = new_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
__init__.py DELETED
@@ -1,21 +0,0 @@
1
"""Top-level package exports for the SHRAM architecture.

Re-exports the public configuration, model, routing, loss, and cache classes
so callers can import them directly from the package root. Module paths here
use the flattened `__attention__*` / `__cache__*` file names present at this
level of the repository.
"""

from .configuration import ShramConfig
from .decoder_layer import DecoderLayer
from .huggingface import ShramForCausalLM
from .__attention__load_balance_loss import LoadBalanceLoss
from .mlp import SwiGLUMLP
from .model import ShramModel
from .rope import RotaryEmbedding
from .__attention__router import MoSRAHRouter
from .__cache__mosrah_cache import MoSRAHCache

# Explicit public API of the package root.
__all__ = [
    "DecoderLayer",
    "LoadBalanceLoss",
    "MoSRAHCache",
    "MoSRAHRouter",
    "ShramConfig",
    "ShramForCausalLM",
    "ShramModel",
    "RotaryEmbedding",
    "SwiGLUMLP",
]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/README.md DELETED
@@ -1,95 +0,0 @@
1
- ---
2
- language:
3
- - en
4
- license: mit
5
- library_name: transformers
6
- pipeline_tag: text-generation
7
- tags:
8
- - pytorch
9
- - research
10
- - sparse-attention
11
- - mixture-of-experts
12
- ---
13
-
14
- # SHRAM — Sparse Hybrid Token Routed Attention Mixture
15
-
16
- A research baseline implementing the SHRAM architecture from "An Examination of Sparse
17
- Attention for Long Context Purposes." No pretrained weights — pull the architecture from
18
- the Hub and instantiate a freshly initialised model from config. Every parameter is
19
- overridable at instantiation time via kwargs.
20
-
21
- > **Important:** `trust_remote_code=True` is required. It downloads the architecture
22
- > source files from the Hub and imports them into your Python process. Review the
23
- > source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use.
24
-
25
- ## Architecture
26
-
27
- SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:
28
-
29
- - **h_l** — local sliding-window causal attention path.
30
- - **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
31
- via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.
32
-
33
- All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).
34
-
35
- ## Usage
36
-
37
- This repository contains no pretrained weights. The intended workflow is: pull the
38
- architecture config from the Hub, instantiate a model with fresh random weights, then
39
- train it yourself.
40
-
41
- ```python
42
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
43
-
44
- # Step 1: pull the architecture config from the Hub.
45
- # AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
46
- # Override any parameter via kwargs.
47
- config = AutoConfig.from_pretrained(
48
- "smithblack-0/SHRAM",
49
- trust_remote_code=True,
50
- num_hidden_layers=16, # example override
51
- num_mosrah_heads=32, # example override
52
- )
53
-
54
- # Step 2: instantiate with fresh random weights.
55
- # from_config never loads a checkpoint — it always produces a randomly initialised model.
56
- model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
57
-
58
- # Step 3: load the tokenizer.
59
- tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
60
- ```
61
-
62
- After training your own checkpoint, save and reload it in the standard way:
63
-
64
- ```python
65
- model.save_pretrained("./my-checkpoint")
66
- model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
67
- ```
68
-
69
- ## Constructor Defaults
70
-
71
- The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
72
- no overrides. They are not the parameters of a pretrained model — this repository
73
- contains no weights. All values are overridable via kwargs.
74
-
75
- | Parameter | Default |
76
- |-----------|---------|
77
- | `vocab_size` | 50277 |
78
- | `hidden_size` | 512 |
79
- | `intermediate_size` | 1366 |
80
- | `num_hidden_layers` | 12 |
81
- | `num_sliding_window_heads` | 16 |
82
- | `num_mosrah_heads` | 16 |
83
- | `num_selected_heads` | 16 |
84
- | `head_dim` | 16 |
85
- | `window_size` | 128 |
86
- | `rope_mode` | main_sequence |
87
- | `local_rope_theta` | 10000.0 |
88
- | `mosrah_rope_theta` | 10000.0 |
89
- | `training_sequence_length` | 8192 |
90
- | `inference_sequence_length` | 8192 |
91
-
92
- ## License
93
-
94
- MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
95
- (`EleutherAI/gpt-neox-20b`, Apache 2.0).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/__init__.py DELETED
@@ -1,21 +0,0 @@
1
"""Package exports for the architecture_core SHRAM implementation."""

from .configuration import ShramConfig
from .decoder_layer import DecoderLayer
from .huggingface import ShramForCausalLM
# Fixed: these two were absolute `src.shram.model.attention.*` imports left
# over from the development layout. They break as soon as the package is
# loaded outside that layout (e.g. downloaded from the Hub via
# trust_remote_code). The modules live inside this package at
# architecture_core/attention/, so use relative imports like every other
# line in this file.
from .attention.load_balance_loss import LoadBalanceLoss
from .mlp import SwiGLUMLP
from .model import ShramModel
from .rope import RotaryEmbedding
from .attention.router import MoSRAHRouter
from .cache import MoSRAHCache

# Explicit public API of the package.
__all__ = [
    "DecoderLayer",
    "LoadBalanceLoss",
    "MoSRAHCache",
    "MoSRAHRouter",
    "ShramConfig",
    "ShramForCausalLM",
    "ShramModel",
    "RotaryEmbedding",
    "SwiGLUMLP",
]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/__init__.py DELETED
File without changes
architecture_core/attention/bottlenecked_ensemble_attention.py DELETED
@@ -1,252 +0,0 @@
1
- """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
-
3
- BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
- It consumes packed expert-choice tensors, a supplied position tensor, an active
5
- token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
- same packed expert-choice space expected by later unpacking.
7
-
8
- BEA does not compute positions and does not choose packed-position semantics.
9
- Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
- (K̃) and raw values (V) into the sparse cache and attends against the
11
- accumulated cached state returned by that cache.
12
- """
13
-
14
- import math
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
-
20
- from ..configuration import ShramConfig
21
- from ..cache.mosrah_cache import MoSRAHCache
22
- from ..rope import RotaryEmbedding
23
-
24
-
25
class BottleneckedEnsembleAttention(nn.Module):
    """
    Packed expert-choice attention operator for the MoSRAH sparse path.
    Operates per-head independently on an ensemble of tokens.
    FlexAttention saves flops on dead tokens.

    Architectural properties:
    - consumes packed expert-choice tensors of shape (B, L, T, d)
    - uses independent per-head Q/K/V/O projection parameters
    - applies YaRN-capable RoPE using supplied position_ids
    - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
    - uses a fast fused attention path
    - returns outputs in the same packed expert-choice space (B, L, T, d)

    Args:
        config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
            `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
            `inference_sequence_length`, `alpha`, `beta`, and `scale`
            (the constructor reads `config.scale` as the RoPE dilation).
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_mosrah_heads
        self.head_dim = config.head_dim

        # Independent per-head projections. No cross-head parameter sharing.
        # Each parameter is a stack of num_heads matrices applied via einsum.
        self.q_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.k_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.v_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.o_proj = nn.Parameter(
            torch.empty(self.num_heads, self.head_dim, self.hidden_size)
        )

        self._reset_parameters()

        # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
        # this unit only consumes it. In training modes, dilation will be 1.0 and so
        # no yarn dilation occurs.
        # NOTE(review): presumably config.scale is derived from
        # inference_sequence_length / training_sequence_length — confirm in
        # ShramConfig; this module only forwards it as the dilation.
        self.rope = RotaryEmbedding(
            mode="yarn",
            head_dim=self.head_dim,
            theta=config.mosrah_rope_theta,
            initial_seq_length=config.training_sequence_length,
            dilation=config.scale,
            alpha=config.alpha,
            beta=config.beta,
        )

    def forward(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None = None,
    ) -> torch.Tensor:
        """Apply BEA to packed expert-choice tensors.

        Args:
            packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
            position_ids: Supplied packed positions of shape (B, L, T).
            active_mask: Boolean active-token mask of shape (B, L, T).
            cache: Optional layer-local MoSRAH cache.

        Returns:
            Packed expert-choice output tensor of shape (B, L, T, d).
        """
        batch_size, _, query_length, _ = packed_embeddings.shape
        self._validate_tensor_shape(packed_embeddings)
        self._validate_position_shape(packed_embeddings, position_ids)
        self._validate_active_mask_shape(packed_embeddings, active_mask)

        # Independent per-head projections:
        # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
        query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
        key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
        value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)

        # RoPE rotates Q and K in place of position encoding; the third return
        # value is the YaRN attention-temperature scaling factor applied below.
        rotated_query_states, rotated_key_states, attention_scaling = self.rope(
            query_states,
            key_states,
            position_ids,
        )

        if cache is not None:
            # In cached execution, the current query tensor uses local tensor rows
            # 0..Q-1, but the key tensor returned by the cache is the full accumulated
            # packed sequence for each (batch, head) slot. The only additional data
            # needed to align those two views is the pre-update cached prefix length.
            # which will indicate how many queries were processed before now.
            num_tokens_processed = cache.get_heads_lengths().clone()
            # Post-RoPE keys and raw values are appended to the cache; the cache
            # returns the full accumulated K/V and the matching key activity mask.
            key_states, value_states, key_active_mask = cache.update(
                rotated_key_states,
                value_states,
                active_mask,
            )
        else:
            # Uncached path: no prefix has been processed, and the current
            # chunk's keys/mask are the whole attendable sequence.
            num_tokens_processed = torch.zeros(
                batch_size,
                self.num_heads,
                dtype=torch.long,
                device=packed_embeddings.device,
            )
            key_states = rotated_key_states
            key_active_mask = active_mask

        block_mask = self._make_block_mask(
            query_active_mask=active_mask,
            key_active_mask=key_active_mask,
            num_tokens_processed=num_tokens_processed,
            query_length=query_length,
            key_length=key_states.shape[2],
            device=packed_embeddings.device,
        )
        # Fold the YaRN scaling into the softmax scale so a single fused
        # kernel handles both temperature and 1/sqrt(d) normalisation.
        attended_states = flex_attention(
            rotated_query_states,
            key_states,
            value_states,
            block_mask=block_mask,
            scale=attention_scaling / math.sqrt(self.head_dim),
        )

        # Project back to model width:
        # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
        return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)

    def _reset_parameters(self) -> None:
        """Initialize per-head projection weights with Xavier-uniform."""
        for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            nn.init.xavier_uniform_(weight)

    def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
        """Validate the local packed-embedding shape contract required by BEA.

        Raises:
            ValueError: if the head axis or feature axis does not match config.
        """
        if packed_embeddings.shape[1] != self.num_heads:
            raise ValueError(
                f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
                f"got {packed_embeddings.shape[1]}."
            )

        if packed_embeddings.shape[-1] != self.hidden_size:
            raise ValueError(
                f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
                f"got {packed_embeddings.shape[-1]}."
            )

    def _validate_position_shape(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> None:
        """Validate the supplied packed-position tensor shape.

        Raises:
            ValueError: if position_ids is not (B, L, T).
        """
        if position_ids.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(position_ids.shape)}."
            )

    def _validate_active_mask_shape(
        self,
        packed_embeddings: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> None:
        """Validate the supplied active-token mask shape.

        Raises:
            ValueError: if active_mask is not (B, L, T).
        """
        if active_mask.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(active_mask.shape)}."
            )

    def _make_block_mask(
        self,
        query_active_mask: torch.Tensor,
        key_active_mask: torch.Tensor,
        num_tokens_processed: torch.Tensor,
        query_length: int,
        key_length: int,
        device: torch.device,
    ):
        """Create the packed-sequence causal mask for FlexAttention.

        At the root, causality is still triangular. The only nuance is cached
        execution: query rows are indexed locally as 0..Q-1 inside the current
        query tensor, but the key tensor may already contain a cached prefix for
        that (batch, head) slot. The causal horizon for query tensor row q is
        therefore:

            cached_prefix_lengths[b, h] + q

        Query and key activity masks are then composed with that triangular rule
        so FlexAttention can skip padded query rows and ignore inactive key slots.

        Returns:
            A FlexAttention BlockMask covering (B, H, Q_LEN, KV_LEN).
        """
        batch_size, num_heads, _ = query_active_mask.shape

        # Build the per-(batch, head, query_row) triangular horizon from a simple
        # arange over query rows plus the cached prefix lengths for each slot.
        relative_query_positions = torch.arange(
            query_length,
            device=device,
            dtype=torch.long,
        ).view(1, 1, query_length)
        causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions

        def packed_causal_mask(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            query_idx: torch.Tensor,
            key_idx: torch.Tensor,
        ) -> torch.Tensor:
            # All three conditions must hold: live query row, live key slot,
            # and the key not past this query's causal horizon.
            query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
            key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
            is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
            return query_is_active & key_is_active & is_causal

        return create_block_mask(
            packed_causal_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=query_length,
            KV_LEN=key_length,
            device=device,
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/expert_packing.py DELETED
@@ -1,335 +0,0 @@
1
- """Expert packing and unpacking for the MoSRAH path.
2
-
3
- This module implements the low-level token-choice -> expert-choice -> token-choice
4
- conversion boundary specified in the paper. The externally visible behavior is fixed:
5
-
6
- - setup_packing() prepares the auxiliary ordering data.
7
- - pack_experts() converts routed token-choice state into packed expert-choice state.
8
- - unpack_experts() restores token-choice ordering afterward.
9
-
10
- Stable sort is a correctness requirement. It preserves causal ordering inside each
11
- expert bucket, which is the foundation on which BEA's later triangular causal mask
12
- is correct.
13
-
14
- pack_experts() returns two distinct masks that serve different roles and must not
15
- be interchanged:
16
-
17
- - unpacking_mask: marks every packed slot that contains a routed token copy,
18
- live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
- so its reshape invariant holds regardless of outer token liveness.
20
- - active_mask: marks only the packed slots whose source token was semantically
21
- live. This is what BEA consumes for attention gating. Dead outer tokens must
22
- not influence sparse attention outputs.
23
- """
24
-
25
- import torch
26
-
27
-
28
- # ---------------------------------------------------------------------------
29
- # Setup
30
- # ---------------------------------------------------------------------------
31
-
32
def setup_packing(
    selected_heads: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare the ordering data used by pack_experts/unpack_experts.

    Routing yields token-choice selections I of shape (B, N, K): for each
    token, the K experts it chose. Packing needs those routed copies in
    expert-major order so each expert bucket is contiguous. This helper
    flattens (N, K) into a single token-major axis H, then computes a STABLE
    argsort permutation Pi over the expert indices in H. Stability is a
    correctness requirement: it preserves original token order inside each
    expert bucket, which is what makes BEA's triangular causal mask valid.

    Args:
        selected_heads: Routed token-choice head selections I of shape (B, N, K).

    Returns:
        Tuple of:
        - flattened_selected_heads: H of shape (B, N*K)
        - permutation: stable expert-major permutation Pi of shape (B, N*K)
        - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
    """
    batch_size, num_tokens, top_k = selected_heads.shape
    routed_copies = num_tokens * top_k

    # Token-major view: one entry per routed (token, expert) copy.
    flattened_selected_heads = selected_heads.reshape(batch_size, routed_copies)

    # Stable sort groups copies by expert while keeping token order per bucket.
    permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
    # Argsort of a permutation is its inverse.
    inverse_permutation = torch.argsort(permutation, dim=-1)

    return flattened_selected_heads, permutation, inverse_permutation
-
67
-
68
- # ---------------------------------------------------------------------------
69
- # Packing
70
- # ---------------------------------------------------------------------------
71
-
72
def pack_experts(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    selected_heads: torch.Tensor,
    num_experts: int,
    flattened_selected_heads: torch.Tensor,
    permutation: torch.Tensor,
    outer_active_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pack token-choice hidden states into expert-choice padded form.

    The paper's packing path has two jobs:

    1. Convert routed token-choice copies into expert-major order.
    2. Materialize that expert-major order into a padded tensor layout BEA can consume.

    The routed hidden-state copies are not stored explicitly in token-choice form.
    Instead, the same token hidden state is conceptually copied once per selected expert.
    The packing step reconstructs those copies by expanding local source-token indices,
    reordering those indices with Pi, then gathering hidden states, positions, and outer
    liveness in that packed order. All three are carried through the same expert-major
    rearrangement so they remain aligned in the packed frame.

    Packed positions are sourced from the authoritative upstream position_ids tensor
    rather than synthesized locally from arange(N). This preserves advanced positions
    correctly during cached inference while leaving training/full-sequence behavior
    unchanged when position_ids is the ordinary sequential token positions.

    Args:
        hidden_states: Token-choice hidden states x of shape (B, N, d).
        position_ids: Authoritative upstream token positions J of shape (B, N).
        selected_heads: Routed head selections I of shape (B, N, K).
        num_experts: Total number of experts L.
        flattened_selected_heads: H from setup_packing(), shape (B, N*K).
        permutation: Pi from setup_packing(), shape (B, N*K).
        outer_active_mask: Current-chunk active mask of shape (B, N), where True
            means the token is semantically live. Dead tokens do not become
            semantically active in the packed sparse representation.

    Returns:
        Tuple of:
        - packed_hidden_states: x' of shape (B, L, T, d)
        - packed_positions: J' of shape (B, L, T)
        - unpacking_mask: of shape (B, L, T). True where a slot contains any
          routed token copy, live or dead. Always has exactly B*N*K True entries.
          Pass this to unpack_experts — not active_mask.
        - active_mask: of shape (B, L, T). True only where a slot contains a
          copy of a live outer token. Pass this to BEA for attention gating.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    _, _, num_selected_heads = selected_heads.shape

    # -----------------------------------------------------------------------
    # Reconstruct routed local source-token indices in token-choice order.
    #
    # The internal arange(N) is no longer the packed position tensor. It is only
    # the local source-row index object used to gather from the current chunk
    # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
    # token-major routed-copy order.
    # -----------------------------------------------------------------------
    source_token_indices = torch.arange(
        sequence_length,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, sequence_length, 1).expand(
        batch_size,
        sequence_length,
        num_selected_heads,
    )
    flattened_source_indices = source_token_indices.reshape(
        batch_size,
        sequence_length * num_selected_heads,
    )

    # -----------------------------------------------------------------------
    # Reorder source-token indices into expert-major order.
    #
    # Applying Pi yields the local source-token rows in the packed expert-major
    # order required by the paper. Those same reordered source indices are then
    # used to gather hidden states, authoritative upstream positions, and outer
    # liveness so all three remain aligned under the exact same packing
    # transformation.
    # -----------------------------------------------------------------------
    sorted_source_indices = flattened_source_indices.gather(
        dim=1,
        index=permutation,
    )
    sorted_hidden_states = hidden_states.gather(
        dim=1,
        index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
    )
    sorted_positions = position_ids.gather(
        dim=1,
        index=sorted_source_indices,
    )
    sorted_active_mask = outer_active_mask.gather(
        dim=1,
        index=sorted_source_indices,
    )

    # -----------------------------------------------------------------------
    # Count how many routed copies land in each expert bucket.
    #
    # S[b, l] is the number of routed token copies assigned to expert l in batch b.
    # T is the maximum such count across all batches and experts; it determines the
    # padded expert-length dimension of the packed representation.
    # -----------------------------------------------------------------------
    tokens_per_expert = _bincount_rows(
        values=flattened_selected_heads,
        num_bins=num_experts,
    )
    # .item() forces a host sync; acceptable here because T sizes the
    # allocation below and must be a Python int.
    max_tokens_per_expert = int(tokens_per_expert.max().item())

    # -----------------------------------------------------------------------
    # Construct the active-token mask M.
    #
    # Each expert bucket is left-justified: if S[b, l] = s, then slots
    # t = 0, ..., s-1 are active and all later slots are padding. The resulting
    # mask therefore both identifies real packed tokens and enforces left-justified
    # packing. This is the unpacking_mask — it marks slot occupancy regardless of
    # outer token liveness, and always has exactly B*N*K True entries.
    # -----------------------------------------------------------------------
    time_axis = torch.arange(
        max_tokens_per_expert,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, 1, max_tokens_per_expert)
    unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)

    # -----------------------------------------------------------------------
    # Materialize the padded packed tensors.
    #
    # The packed hidden states x', packed original-token positions J', and packed
    # active-token mask are allocated as zero-filled tensors. Active entries are
    # then written into those buffers in the expert-major order established above.
    # Padding remains zero / inactive. new_zeros() inherits dtype and device
    # from the source tensors.
    # -----------------------------------------------------------------------
    packed_hidden_states = hidden_states.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        hidden_dim,
    )
    packed_positions = position_ids.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
    )
    active_mask = torch.zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        dtype=torch.bool,
        device=hidden_states.device,
    )

    # Boolean-mask assignment visits True slots in row-major order, which is
    # exactly the left-justified expert-major order of the sorted_* tensors.
    packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
    packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
    active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)

    return packed_hidden_states, packed_positions, unpacking_mask, active_mask
-
234
-
235
- # ---------------------------------------------------------------------------
236
- # Unpacking
237
- # ---------------------------------------------------------------------------
238
-
239
def unpack_experts(
    expert_outputs: torch.Tensor,
    selected_heads: torch.Tensor,
    unpacking_mask: torch.Tensor,
    inverse_permutation: torch.Tensor,
) -> torch.Tensor:
    """Restore token-choice ordering from BEA expert-choice output.

    Inverts the packing path on occupied entries only. The padded output is
    first filtered with unpacking_mask to recover the real routed-token copies
    in expert-major order, then Pi^{-1} restores the original token-choice
    ordering, and finally the result is reshaped to (B, N, K, d).

    Use unpacking_mask here — never active_mask. Copies of dead outer tokens
    still occupy packed slots and must be un-scattered for the inverse
    permutation to line up; unpacking_mask always contains exactly B*N*K True
    entries, which is precisely what the (B, N*K, d) reshape requires.

    Args:
        expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
        selected_heads: Routed head selections I of shape (B, N, K).
        unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
            occupied packed slots regardless of outer token liveness.
        inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).

    Returns:
        Restored token-choice tensor y_tilde of shape (B, N, K, d).
    """
    batch_size, num_tokens, top_k = selected_heads.shape
    feature_dim = expert_outputs.shape[-1]
    routed_copies = num_tokens * top_k

    # Mask selection visits True slots in row-major order, yielding the
    # occupied copies in expert-major order with exactly B*N*K rows.
    occupied_outputs = expert_outputs[unpacking_mask]
    expert_major_outputs = occupied_outputs.reshape(
        batch_size, routed_copies, feature_dim
    )

    # Pi^{-1} undoes the expert-major permutation, restoring token order.
    gather_index = inverse_permutation.unsqueeze(-1).expand(-1, -1, feature_dim)
    token_major_outputs = expert_major_outputs.gather(dim=1, index=gather_index)

    return token_major_outputs.reshape(batch_size, num_tokens, top_k, feature_dim)
- )
288
-
289
-
290
- # ---------------------------------------------------------------------------
291
- # Helpers
292
- # ---------------------------------------------------------------------------
293
-
294
- def _bincount_rows(
295
- values: torch.Tensor,
296
- num_bins: int,
297
- ) -> torch.Tensor:
298
- """Count per-row integer occurrences for a 2D tensor.
299
-
300
- torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
- one bincount per batch row. The trick used here is to shift each row into its
302
- own disjoint bin range before flattening:
303
-
304
- row 0 uses bins [0, ..., num_bins - 1]
305
- row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
- row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
- ...
308
-
309
- After that shift, one global torch.bincount produces all row-local counts at
310
- once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
-
312
- This is a vectorized implementation detail only; externally visible behavior
313
- remains exactly the paper's S tensor of per-batch per-expert token counts.
314
-
315
- Args:
316
- values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
- num_bins: Number of bins.
318
-
319
- Returns:
320
- Counts tensor of shape (B, num_bins).
321
- """
322
- batch_size = values.shape[0]
323
-
324
- row_offsets = torch.arange(
325
- batch_size,
326
- device=values.device,
327
- dtype=values.dtype,
328
- ).unsqueeze(1) * num_bins
329
- shifted_values = values + row_offsets
330
-
331
- counts = torch.bincount(
332
- shifted_values.reshape(-1),
333
- minlength=batch_size * num_bins,
334
- )
335
- return counts.reshape(batch_size, num_bins)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/load_balance_loss.py DELETED
@@ -1,88 +0,0 @@
1
- """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
-
3
- This module implements the custom autograd Function H(b, f) described in the paper's
4
- Implementation Concerns section. The operator bridges two requirements that are in
5
- tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
- multiplication) so that existing training loops remain compatible, while simultaneously
7
- implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
- path through the router weights.
9
-
10
- The resolution is a custom backward pass. The forward emits the load balance imbalance
11
- as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
- its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
- then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
- correction without a standalone SGD update step.
15
-
16
- Paper ref: Appendix A.Implementation Concerns.
17
- """
18
-
19
- import torch
20
-
21
-
22
- class LoadBalanceLoss(torch.autograd.Function):
23
- """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.
24
-
25
- Forward computes the load balance imbalance:
26
-
27
- L_load_balance = H(b, f) = sum_l | f_l - 1/L |
28
-
29
- Backward emits a bias-correction gradient to expert_bias:
30
-
31
- grad_b = L_grad * sign(f_l - 1/L)
32
-
33
- expert_bias (b) is included as a forward input so PyTorch registers it as a node
34
- in the computation graph and routes gradients through it. routing_freqs (f) receives
35
- no gradient — its origin is the discrete TopK operation which has no gradient, so
36
- defining a gradient for f here would be mathematically incorrect.
37
-
38
- Paper ref: Appendix A.Implementation Concerns.
39
- """
40
-
41
- @staticmethod
42
- def forward(
43
- ctx: torch.autograd.function.FunctionCtx,
44
- expert_bias: torch.Tensor,
45
- routing_freqs: torch.Tensor,
46
- ) -> torch.Tensor:
47
- """Compute the load balance loss.
48
-
49
- Args:
50
- ctx: Autograd context for saving state needed in backward.
51
- expert_bias: Learned per-head bias b, shape (L,). Included as an input so
52
- PyTorch tracks it as a computation graph node needing a gradient.
53
- routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
54
- from the discrete TopK selection — not differentiable.
55
-
56
- Returns:
57
- Scalar loss equal to sum_l |f_l - 1/L|.
58
- """
59
- L = expert_bias.shape[0]
60
- # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
61
- # underloaded. Saved for backward where sign(imbalance) determines the direction
62
- # of the bias-correction update.
63
- imbalance = routing_freqs - 1.0 / L
64
- ctx.save_for_backward(imbalance)
65
- return imbalance.abs().sum()
66
-
67
- @staticmethod
68
- def backward(
69
- ctx: torch.autograd.function.FunctionCtx,
70
- grad_output: torch.Tensor,
71
- ) -> tuple[torch.Tensor, None]:
72
- """Emit the DeepSeek-style bias-correction gradient.
73
-
74
- Args:
75
- ctx: Autograd context carrying imbalance saved in forward.
76
- grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
77
- by the training loop arrives here and is propagated to grad_b, so the
78
- correction magnitude is proportional to the loss weight chosen by the
79
- consumer.
80
-
81
- Returns:
82
- Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
83
- None for routing_freqs: no gradient is defined for the discrete routing
84
- frequency.
85
- """
86
- (imbalance,) = ctx.saved_tensors
87
- grad_expert_bias = grad_output * imbalance.sign()
88
- return grad_expert_bias, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/mosrah.py DELETED
@@ -1,140 +0,0 @@
1
- """Full MoSRAH sparse path for SHRAM.
2
-
3
- This module coordinates the routed sparse attention path used inside the SHRAM
4
- hybrid attention layer. The underlying mechanics already live in verified
5
- subunits. The responsibility here is to connect those subunits without
6
- corrupting their bridge contracts.
7
-
8
- In particular, this path must preserve three architectural distinctions:
9
-
10
- - selected head indices are not routing probabilities
11
- - packed position semantics are chosen before BEA, not inside it
12
- - weighted reduction must consume the router's unbiased renormalized
13
- probabilities after token-choice order has been restored
14
- """
15
-
16
- import torch
17
- from torch import nn
18
-
19
- from src.shram.model.cache.mosrah_cache import MoSRAHCache
20
- from src.shram.model.configuration import ShramConfig
21
- from src.shram.model.attention.bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
- from src.shram.model.attention.expert_packing import (
23
- pack_experts,
24
- setup_packing,
25
- unpack_experts,
26
- )
27
- from src.shram.model.attention.router import MoSRAHRouter
28
- from src.shram.model.attention.positions_converter import SparseMoSRAHPositions
29
-
30
-
31
- class MoSRAHLayer(nn.Module):
32
- """Full routed sparse attention path for SHRAM.
33
-
34
- The MoSRAH path consumes model-space hidden states together with
35
- authoritative per-token positions and returns the model-space sparse-path
36
- contribution, the router's load-balance loss, and the router's MaxVio
37
- routing-imbalance scalar.
38
- """
39
-
40
- def __init__(self, config: ShramConfig) -> None:
41
- super().__init__()
42
- self.num_experts = config.num_mosrah_heads
43
-
44
- self.router = MoSRAHRouter(config)
45
- self.positions = SparseMoSRAHPositions(config)
46
- self.bea = BottleneckedEnsembleAttention(config)
47
-
48
- def forward(
49
- self,
50
- hidden_states: torch.Tensor,
51
- position_ids: torch.Tensor,
52
- active_mask: torch.Tensor,
53
- cache: MoSRAHCache | None,
54
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
55
- """Run the full MoSRAH sparse path.
56
-
57
- Args:
58
- hidden_states: Model-space hidden states x of shape (B, N, d).
59
- position_ids: Authoritative per-token positions of shape (B, N).
60
- active_mask: Current-chunk active mask of shape (B, N), where True
61
- means the token is semantically live. Forwarded to the router
62
- so dead tokens are excluded from routing statistics, and to
63
- pack_experts so dead outer tokens do not become semantically
64
- active packed entries.
65
- cache: Optional layer-local MoSRAH cache. Pass None for uncached
66
- execution and the layer-local cache instance for cached execution.
67
-
68
- Returns:
69
- sparse_output: Model-space sparse-path output of shape (B, N, d).
70
- load_balance_loss: Scalar router load-balance loss.
71
- max_vio: Detached scalar routing-imbalance summary. Passed through
72
- unchanged from the router; see MoSRAHRouter for semantics.
73
- """
74
-
75
- # -------------------------------------------------------------------
76
- # The first transition moves from model-space token-choice input into
77
- # the packed expert-choice sparse-attention state. Routing decides both
78
- # which experts each token uses and which unbiased probabilities must be
79
- # reserved for the final reduction. The active mask is forwarded to the
80
- # router so dead tokens are excluded from routing statistics, and to
81
- # pack_experts so outer liveness is faithfully carried into the packed
82
- # frame. Packing returns both the unpacking mask (slot occupancy, always
83
- # B*N*K True entries) and the packed active mask (live slots only);
84
- # active_mask is rebound to the packed form after this point.
85
- # -------------------------------------------------------------------
86
- selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
87
- hidden_states, active_mask
88
- )
89
-
90
- flattened_selected_heads, permutation, inverse_permutation = setup_packing(
91
- selected_heads
92
- )
93
- packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
94
- hidden_states=hidden_states,
95
- position_ids=position_ids,
96
- selected_heads=selected_heads,
97
- num_experts=self.num_experts,
98
- flattened_selected_heads=flattened_selected_heads,
99
- permutation=permutation,
100
- outer_active_mask=active_mask,
101
- )
102
-
103
- # -------------------------------------------------------------------
104
- # Sparse attention runs entirely in the packed expert-choice frame, so
105
- # the RoPE position semantics must also be chosen in that frame. The
106
- # position layer therefore decides whether BEA should see packed
107
- # original-token positions or packed local-slot positions. BEA then
108
- # consumes that packed position tensor together with the packed hidden
109
- # states and the layer-local sparse cache, which it owns directly.
110
- # -------------------------------------------------------------------
111
- bea_positions = self.positions(
112
- packed_positions=packed_positions,
113
- cache=cache,
114
- )
115
- packed_outputs = self.bea(
116
- packed_embeddings=packed_hidden_states,
117
- position_ids=bea_positions,
118
- active_mask=active_mask,
119
- cache=cache,
120
- )
121
-
122
- # -------------------------------------------------------------------
123
- # The final transition restores token-choice meaning and only then
124
- # collapses the K routed copies back into model space. This ordering is
125
- # required because routing_probs live in token-choice space, whereas BEA
126
- # returns expert-choice packed outputs. The reduction must therefore
127
- # happen after unpacking, and it must use the router's unbiased
128
- # renormalized probabilities rather than any biased selection scores.
129
- # -------------------------------------------------------------------
130
- token_choice_outputs = unpack_experts(
131
- expert_outputs=packed_outputs,
132
- selected_heads=selected_heads,
133
- unpacking_mask=unpacking_mask,
134
- inverse_permutation=inverse_permutation,
135
- )
136
- final_output = (
137
- token_choice_outputs * routing_probs.unsqueeze(-1)
138
- ).sum(dim=2)
139
-
140
- return final_output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/positions_converter.py DELETED
@@ -1,105 +0,0 @@
1
- """Position computation for the MoSRAH sparse path.
2
-
3
- This layer computes the packed position tensor P consumed by BEA.
4
-
5
- - In main-sequence mode, P is the packed original-token position tensor from the
6
- packing path.
7
- - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
- expert-choice layout, optionally offset by the current sparse-cache occupancies
9
- during cached inference.
10
- """
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from src.shram.model.configuration import ShramConfig
16
- from src.shram.model.cache.mosrah_cache import MoSRAHCache
17
-
18
-
19
- class SparseMoSRAHPositions(nn.Module):
20
- """Compute the packed RoPE position tensor for the MoSRAH sparse path.
21
-
22
- This layer operates in the packed expert-choice frame used by BEA. The input
23
- packed_positions tensor is always the packed original-token position tensor
24
- produced by the packing path. The configured rope_mode determines whether that
25
- tensor is forwarded directly or replaced by a semantic local-slot sequence.
26
- """
27
-
28
- def __init__(self, config: ShramConfig) -> None:
29
- super().__init__()
30
- self.rope_mode = config.rope_mode
31
-
32
- def forward(
33
- self,
34
- packed_positions: torch.Tensor,
35
- cache: MoSRAHCache | None,
36
- ) -> torch.Tensor:
37
- """Compute the packed position tensor P consumed by BEA.
38
-
39
- Args:
40
- packed_positions: Packed original-token positions J' of shape (B, L, T).
41
- cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
42
- mode, the current per-head occupancies offset the local packed sequence.
43
-
44
- Returns:
45
- Packed position tensor P of shape (B, L, T).
46
- """
47
- if self.rope_mode == "main_sequence":
48
- return self._main_sequence_positions(packed_positions)
49
-
50
- if self.rope_mode == "semantic_sequence":
51
- return self._semantic_sequence_positions(packed_positions, cache)
52
-
53
- raise NotImplementedError(
54
- f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
55
- )
56
-
57
- def _main_sequence_positions(
58
- self,
59
- packed_positions: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """Forward packed original-token positions unchanged."""
62
- return packed_positions
63
-
64
- def _semantic_sequence_positions(
65
- self,
66
- packed_positions: torch.Tensor,
67
- cache: MoSRAHCache | None,
68
- ) -> torch.Tensor:
69
- """Compute semantic-sequence packed positions in expert-choice space.
70
-
71
- Without a sparse cache, semantic positions are the local packed sequence
72
- 0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
73
- same local sequence is offset by the current per-(batch, expert) occupancies
74
- returned by get_heads_lengths().
75
- """
76
- batch_size, num_experts, packed_length = packed_positions.shape
77
-
78
- # -------------------------------------------------------------------
79
- # Construct the local packed sequence 0, 1, 2, ... over the expert-local
80
- # sequence dimension T. This is then broadcast across batch and experts.
81
- # -------------------------------------------------------------------
82
- local_positions = torch.arange(
83
- packed_length,
84
- device=packed_positions.device,
85
- dtype=packed_positions.dtype,
86
- ).view(1, 1, packed_length).expand(
87
- batch_size,
88
- num_experts,
89
- packed_length,
90
- )
91
-
92
- # -------------------------------------------------------------------
93
- # In cached semantic-sequence mode, positions continue from the current
94
- # sparse-cache occupancies rather than restarting at zero for the local
95
- # chunk.
96
- # -------------------------------------------------------------------
97
- if cache is None:
98
- return local_positions
99
-
100
- cached_lengths = cache.get_heads_lengths().to(
101
- device=packed_positions.device,
102
- dtype=packed_positions.dtype,
103
- ).unsqueeze(-1)
104
-
105
- return local_positions + cached_lengths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/router.py DELETED
@@ -1,162 +0,0 @@
1
- """Token-choice router for the MoSRAH sparse attention path.
2
-
3
- This module implements the routing mechanism described in Appendix A.Routing of the
4
- paper. Given an input hidden state x, the router produces two outputs used downstream:
5
-
6
- - selected_heads (I): which K of the L available expert heads each token routes to,
7
- determined by TopK over biased routing scores.
8
- - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
- *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
- bias b must not influence P.
11
-
12
- This separation is architecturally critical: expert_bias drives selection (and thus load
13
- balancing) but does not corrupt the gradient path from the output through routing_probs
14
- back to the routing projection weights.
15
-
16
- The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
- autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
- loop can weight and add to the language modeling loss.
19
-
20
- The router additionally computes and returns MaxVio, a detached scalar summarising
21
- routing imbalance for the current forward pass:
22
-
23
- MaxVio = L · max_l(f_l − 1/L)
24
-
25
- where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
- target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
-
28
- Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
- """
30
-
31
- import torch
32
- import torch.nn as nn
33
- import torch.nn.functional as F
34
-
35
- from src.shram.model.configuration import ShramConfig
36
- from src.shram.model.attention.load_balance_loss import LoadBalanceLoss
37
-
38
-
39
- class MoSRAHRouter(nn.Module):
40
- """Token-choice router for MoSRAH sparse attention.
41
-
42
- Each input token independently selects K of the L available expert heads. Selection
43
- is driven by biased routing scores to enable load balancing, but the routing
44
- probabilities used for output reduction are computed from unbiased scores so that
45
- the expert bias does not interfere with the gradient path to the router weights.
46
-
47
- The routing projection W_r has no bias term — the paper specifies xW_r with no
48
- additional projection bias. The only bias-like parameter is expert_bias (b), which
49
- has an entirely separate role and update mechanism.
50
-
51
- Args:
52
- config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
53
- (L), and ``num_selected_heads`` (K).
54
- """
55
-
56
- def __init__(self, config: ShramConfig) -> None:
57
- super().__init__()
58
- self.num_mosrah_heads = config.num_mosrah_heads
59
- self.num_selected_heads = config.num_selected_heads
60
-
61
- # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
62
- self.routing_projection = nn.Linear(
63
- config.hidden_size, config.num_mosrah_heads, bias=False
64
- )
65
-
66
- # b: learned per-head bias for load balancing. Initialized to zero so that all
67
- # heads start with equal selection probability. Updated by the main optimizer
68
- # via the LoadBalanceLoss custom backward.
69
- self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))
70
-
71
- def forward(
72
- self, x: torch.Tensor, active_mask: torch.Tensor
73
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
74
- """Route input tokens to K expert heads each and compute routing probabilities.
75
-
76
- Args:
77
- x: Input hidden states of shape (batch, seq_len, hidden_size).
78
- active_mask: Current-chunk active mask of shape (batch, seq_len), where
79
- True means the token is semantically live. Dead tokens do not
80
- contribute to routing frequencies, load_balance_loss, or max_vio.
81
-
82
- Returns:
83
- selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
84
- Each token's K selected head indices, determined by TopK on biased scores.
85
- routing_probs: Routing probabilities P of shape (batch, seq_len,
86
- num_selected_heads). Gathered from unbiased scores at selected_heads
87
- indices and renormalized to sum to 1 per token.
88
- load_balance_loss: Scalar load balance imbalance loss for this forward pass.
89
- Training loop scales this by a weight and adds it to the main loss.
90
- max_vio: Detached scalar routing-imbalance summary for this forward pass.
91
- Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
92
- never contributes gradients.
93
- """
94
- B, N, _ = x.shape
95
- L = self.num_mosrah_heads
96
- K = self.num_selected_heads
97
-
98
- # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
99
- # compute routing_probs — expert_bias must not influence them.
100
- logits = self.routing_projection(x) # (B, N, L)
101
- routing_scores = F.softmax(logits, dim=-1) # R, (B, N, L)
102
-
103
- # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
104
- # selection. expert_bias is added to logits before softmax so that the bias
105
- # shifts selection probability without rescaling the unbiased distribution.
106
- biased_routing_scores = F.softmax( # R̂, (B, N, L)
107
- logits + self.expert_bias, dim=-1
108
- )
109
-
110
- # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
111
- selected_heads = biased_routing_scores.topk(K, dim=-1).indices
112
-
113
- # Routing probabilities P: gathered from unbiased R at selected_heads indices,
114
- # then renormalized so they sum to 1 per token. Gathering from routing_scores
115
- # (not biased_routing_scores) is the invariant that keeps the gradient path from
116
- # the output back to the router weights free of expert_bias influence.
117
- gathered = routing_scores.gather(dim=-1, index=selected_heads) # V, (B, N, K)
118
- routing_probs = gathered / gathered.sum(dim=-1, keepdim=True) # P, (B, N, K)
119
-
120
- # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
121
- # assigned to each head. Dead tokens are excluded by zeroing their rows in the
122
- # assignment mask before reduction. Normalization uses the active assignment
123
- # count so frequencies remain properly scaled regardless of how many tokens
124
- # are live in this chunk.
125
- assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
126
- assignment_mask.scatter_(-1, selected_heads, 1.0)
127
- active_assignments = assignment_mask * active_mask.unsqueeze(-1)
128
- num_active_assignments = active_mask.sum() * K
129
- routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments # f, (L,)
130
-
131
- # Load balance loss via custom autograd. expert_bias is an input so PyTorch
132
- # registers it as a graph node; the custom backward writes the DeepSeek-style
133
- # correction gradient to expert_bias.grad for the optimizer to consume.
134
- load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)
135
-
136
- # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
137
- # not contribute gradients under any circumstance, so it is detached at the
138
- # point of computation rather than left to callers to detach.
139
- max_vio = self._compute_max_vio(routing_freqs, L)
140
-
141
- return selected_heads, routing_probs, load_balance_loss, max_vio
142
-
143
- @staticmethod
144
- def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
145
- """Compute the MaxVio routing-imbalance scalar.
146
-
147
- MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
148
- head l and 1/L is the perfectly balanced target. A value of zero indicates
149
- perfect balance; a value of 1 means the most overloaded head received exactly
150
- double its fair share.
151
-
152
- The result is detached from the autograd graph — MaxVio is a monitoring scalar
153
- and must never contribute gradients to any parameter.
154
-
155
- Args:
156
- routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
157
- num_heads: Total number of MoSRAH heads L.
158
-
159
- Returns:
160
- Detached scalar MaxVio tensor.
161
- """
162
- return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/shram.py DELETED
@@ -1,116 +0,0 @@
1
- """SHRAM hybrid attention layer.
2
-
3
- This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
- used at one decoder attention slot in SHRAM.
5
-
6
- The local sliding-window path and the MoSRAH sparse path are already verified
7
- independently. The responsibility here is therefore not to introduce new
8
- attention logic, but to preserve the bridge contracts between them: both paths
9
- must consume the same input hidden state, each path must receive the sub-cache
10
- it actually owns, the two model-space outputs must be summed directly, and the
11
- sparse-path load-balance loss must remain visible to the caller.
12
- """
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from src.shram.model.cache.shram_layer_cache import ShramLayerCache
18
- from src.shram.model.configuration import ShramConfig
19
- from src.shram.model.attention.sliding_window_attention import SlidingWindowAttention
20
- from src.shram.model.attention.mosrah import MoSRAHLayer
21
-
22
-
23
- class SHRAMHybridLayer(nn.Module):
24
- """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.
25
-
26
- The local path preserves nearby-token behavior through sliding-window causal
27
- attention. The sparse path is the theorem-facing MoSRAH routed attention
28
- path. Both operate over the same model-space hidden state and return
29
- model-space outputs, so the hybrid composition is a direct sum in model
30
- space.
31
- """
32
-
33
- def __init__(self, config: ShramConfig) -> None:
34
- super().__init__()
35
- self.local_attention = SlidingWindowAttention(config)
36
- self.sparse_attention = MoSRAHLayer(config)
37
-
38
- def forward(
39
- self,
40
- hidden_states: torch.Tensor,
41
- position_ids: torch.Tensor,
42
- active_mask: torch.Tensor,
43
- cache: ShramLayerCache | None,
44
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
45
- """Apply the SHRAM hybrid attention layer.
46
-
47
- Args:
48
- hidden_states: Input hidden states of shape (B, N, d).
49
- position_ids: Authoritative token positions of shape (B, N).
50
- active_mask: Current-chunk active mask of shape (B, N), where True
51
- means the token is semantically live. Forwarded unchanged to
52
- both the local path and the sparse path.
53
- cache: Optional per-layer SHRAM cache. When provided, the owned
54
- sliding-window and MoSRAH sub-caches are dispatched directly to
55
- their corresponding attention paths.
56
-
57
- Returns:
58
- hybrid_output: Model-space hybrid attention output of shape (B, N, d).
59
- load_balance_loss: Scalar sparse-path load-balance loss.
60
- max_vio: Detached scalar routing-imbalance summary. Passed through
61
- unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
62
- """
63
- # ------------------------------------------------
64
- # It is not possible, due to how bea constructs its block mask,
65
- # for the model to process a sequence that does not start at zero
66
- # without a cache to track the per-head offsets
67
- # ------------------------------------------------
68
-
69
- if cache is None and torch.any(position_ids[:, 0] != 0):
70
- raise ValueError(
71
- "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
72
- "Either provide a matching ShramLayerCache populated by the prefix for "
73
- "continued decoding, or rebase the uncached sequence to start at 0."
74
- )
75
-
76
- # -------------------------------------------------------------------
77
- # The hybrid layer's first responsibility is cache dispatch. The layer
78
- # cache already owns the concrete sub-cache objects required by each
79
- # path, so this unit should forward those exact references rather than
80
- # reinterpret cache ownership or invent a composite update protocol here.
81
- # -------------------------------------------------------------------
82
- if cache is None:
83
- sliding_window_cache = None
84
- mosrah_cache = None
85
- else:
86
- sliding_window_cache = cache.sliding_window_cache
87
- mosrah_cache = cache.mosrah_cache
88
-
89
- # -------------------------------------------------------------------
90
- # Both attention paths must see the same model-space hidden state for
91
- # the current decoder layer. The local path preserves short-range
92
- # structure, while the sparse path provides the routed long-range
93
- # contribution and emits the load-balance signal used by training.
94
- # -------------------------------------------------------------------
95
- local_output = self.local_attention(
96
- x=hidden_states,
97
- position_ids=position_ids,
98
- active_mask=active_mask,
99
- cache=sliding_window_cache,
100
- )
101
- sparse_output, load_balance_loss, max_vio = self.sparse_attention(
102
- hidden_states=hidden_states,
103
- position_ids=position_ids,
104
- active_mask=active_mask,
105
- cache=mosrah_cache,
106
- )
107
-
108
- # -------------------------------------------------------------------
109
- # The composition rule is intentionally simple at this boundary. Both
110
- # sublayers already return model-space tensors of matching shape, so the
111
- # correct hybrid behavior is their direct sum with no additional mixing
112
- # logic introduced here.
113
- # -------------------------------------------------------------------
114
- hybrid_output = local_output + sparse_output
115
-
116
- return hybrid_output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/attention/sliding_window_attention.py DELETED
@@ -1,233 +0,0 @@
1
- # src/shram/model/attention/sliding_window_attention.py
2
-
3
- """Local sliding-window attention path for SHRAM.
4
-
5
- This file defines `SlidingWindowAttention`, the local short-range attention path
6
- used inside the SHRAM hybrid layer.
7
-
8
- In the masked-continuation variant, the local cache no longer returns a
9
- semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
10
- returns:
11
-
12
- - the retained local window memory concatenated with the current chunk
13
- - an aligned active mask over that returned frame
14
-
15
- This module consumes that returned frame directly and constructs effective local
16
- causal/window visibility from the mask. It does not own cache retention policy;
17
- it owns only local attention semantics.
18
- """
19
-
20
- import math
21
- from typing import Any
22
-
23
- import torch
24
- import torch.nn as nn
25
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
26
-
27
- from ..cache.sliding_window_cache import LocalSlidingWindowLayerCache
28
- from ..configuration import ShramConfig
29
- from ..rope import RotaryEmbedding
30
-
31
-
32
- class SlidingWindowAttention(nn.Module):
33
- """Causal local sliding-window attention for one SHRAM layer.
34
-
35
- Args:
36
- config: SHRAM config. Must expose `hidden_size`,
37
- `num_sliding_window_heads`, `head_dim`, `window_size`,
38
- `attention_dropout`, and `local_rope_theta`.
39
-
40
- Raises:
41
- NotImplementedError: If `attention_dropout != 0.0`.
42
- """
43
-
44
- def __init__(self, config: ShramConfig) -> None:
45
- super().__init__()
46
-
47
- self.hidden_size = config.hidden_size
48
- self.num_heads = config.num_sliding_window_heads
49
- self.head_dim = config.head_dim
50
- self.window_size = config.window_size
51
- self.attention_dropout = config.attention_dropout
52
-
53
- if self.attention_dropout != 0.0:
54
- raise NotImplementedError(
55
- "SlidingWindowAttention currently supports only "
56
- "attention_dropout == 0.0."
57
- )
58
-
59
- self.inner_dim = self.num_heads * self.head_dim
60
-
61
- # Standard MHA projections for the local path.
62
- self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
63
- self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
64
- self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
65
- self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
66
-
67
- # The local path always uses default-mode RoPE with its own theta.
68
- self.rope = RotaryEmbedding(
69
- mode="default",
70
- head_dim=self.head_dim,
71
- theta=config.local_rope_theta,
72
- )
73
-
74
- def forward(
75
- self,
76
- x: torch.Tensor,
77
- position_ids: torch.Tensor,
78
- active_mask: torch.Tensor,
79
- cache: LocalSlidingWindowLayerCache | None = None,
80
- ) -> torch.Tensor:
81
- """Apply local causal sliding-window attention.
82
-
83
- Args:
84
- x: Input tensor of shape `(B, N, hidden_size)`.
85
- position_ids: Position tensor of shape `(B, N)`.
86
- active_mask: Current-chunk active mask of shape `(B, N)`, where
87
- `True` means active.
88
- cache: Optional `LocalSlidingWindowLayerCache`.
89
-
90
- Returns:
91
- Output tensor of shape `(B, N, hidden_size)`.
92
- """
93
- batch_size, query_len, _ = x.shape
94
-
95
- self._validate_position_shape(x, position_ids)
96
- self._validate_active_mask_shape(x, active_mask)
97
-
98
- # (B, N, H*D) -> (B, H, N, D)
99
- q = self.q_proj(x).view(
100
- batch_size,
101
- query_len,
102
- self.num_heads,
103
- self.head_dim,
104
- ).transpose(1, 2)
105
- k = self.k_proj(x).view(
106
- batch_size,
107
- query_len,
108
- self.num_heads,
109
- self.head_dim,
110
- ).transpose(1, 2)
111
- v = self.v_proj(x).view(
112
- batch_size,
113
- query_len,
114
- self.num_heads,
115
- self.head_dim,
116
- ).transpose(1, 2)
117
-
118
- q, k, attention_scaling = self.rope(q, k, position_ids)
119
-
120
- # The cache returns the current-step visible local frame, not merely the
121
- # retained next-step cache buffer.
122
- if cache is not None:
123
- k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
124
- else:
125
- k_full, v_full, full_active_mask = k, v, active_mask
126
-
127
- block_mask = self._make_block_mask(
128
- active_mask=full_active_mask,
129
- batch_size=batch_size,
130
- num_heads=self.num_heads,
131
- query_len=query_len,
132
- kv_len=k_full.shape[-2],
133
- window_size=self.window_size,
134
- device=x.device,
135
- )
136
-
137
- attn_output = flex_attention(
138
- q,
139
- k_full,
140
- v_full,
141
- block_mask=block_mask,
142
- scale=attention_scaling / math.sqrt(self.head_dim),
143
- )
144
-
145
- # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
146
- attn_output = (
147
- attn_output.transpose(1, 2)
148
- .contiguous()
149
- .view(batch_size, query_len, self.inner_dim)
150
- )
151
-
152
- return self.o_proj(attn_output)
153
-
154
- def _validate_position_shape(
155
- self,
156
- x: torch.Tensor,
157
- position_ids: torch.Tensor,
158
- ) -> None:
159
- """Validate the position tensor shape expected by local RoPE."""
160
- if position_ids.shape != x.shape[:2]:
161
- raise ValueError(
162
- f"position_ids must have shape {tuple(x.shape[:2])}, "
163
- f"got {tuple(position_ids.shape)}."
164
- )
165
-
166
- def _validate_active_mask_shape(
167
- self,
168
- x: torch.Tensor,
169
- active_mask: torch.Tensor,
170
- ) -> None:
171
- """Validate the current-chunk active-mask contract."""
172
- if active_mask.shape != x.shape[:2]:
173
- raise ValueError(
174
- f"active_mask must have shape {tuple(x.shape[:2])}, "
175
- f"got {tuple(active_mask.shape)}."
176
- )
177
- if active_mask.dtype != torch.bool:
178
- raise ValueError(
179
- f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
180
- )
181
-
182
- def _make_block_mask(
183
- self,
184
- active_mask: torch.Tensor,
185
- batch_size: int,
186
- num_heads: int,
187
- query_len: int,
188
- kv_len: int,
189
- window_size: int,
190
- device: torch.device,
191
- ) -> Any:
192
- """Create the FlexAttention block mask for masked local continuation.
193
-
194
- The returned local frame is chronological in raw buffer order, but dead
195
- positions may remain inside it. Effective local order is therefore
196
- recovered from the active mask itself by taking a cumulative count over
197
- active positions.
198
-
199
- Queries still occupy the tail of the returned frame, so raw buffer order
200
- is used to locate query rows. Semantic active-token positions are then
201
- used to decide causality and sliding-window distance.
202
- """
203
- query_offset = kv_len - query_len
204
- semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
-
206
- def sliding_window_mask(
207
- batch_idx: torch.Tensor,
208
- head_idx: torch.Tensor,
209
- q_idx: torch.Tensor,
210
- kv_idx: torch.Tensor,
211
- ) -> torch.Tensor:
212
-
213
- q_abs = query_offset + q_idx
214
-
215
- query_is_active = active_mask[batch_idx, q_abs]
216
- key_is_active = active_mask[batch_idx, kv_idx]
217
-
218
- q_sem = semantic_positions[batch_idx, q_abs]
219
- k_sem = semantic_positions[batch_idx, kv_idx]
220
-
221
- is_causal = k_sem <= q_sem
222
- in_window = (q_sem - k_sem) < window_size
223
-
224
- return query_is_active & key_is_active & is_causal & in_window
225
-
226
- return create_block_mask(
227
- sliding_window_mask,
228
- B=batch_size,
229
- H=num_heads,
230
- Q_LEN=query_len,
231
- KV_LEN=kv_len,
232
- device=device,
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/cache/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .mosrah_cache import MoSRAHCache
2
- from .shram_cache import ShramCache
3
- from .shram_layer_cache import ShramLayerCache
4
- from .slow_mosrah_cache import SlowMoSRAHCache
5
-
6
- __all__ = ["MoSRAHCache", "ShramCache", "ShramLayerCache", "SlowMoSRAHCache"]
 
 
 
 
 
 
 
architecture_core/cache/mosrah_cache.py DELETED
@@ -1,359 +0,0 @@
1
- """MoSRAH sparse KV cache — single-layer implementation.
2
-
3
- MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
4
- by head rather than by sequence position. The routing is dynamic and produces a ragged
5
- distribution of token counts across (batch, head) slots — different batch items may
6
- route different numbers of tokens to the same head, and different heads accumulate at
7
- different rates. DynamicCache cannot represent this correctly: it concatenates along
8
- the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
9
- therefore uses a custom buffer design.
10
-
11
- Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
12
- attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
13
- heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
14
- head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
15
- valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
16
- buffer_capacity property and is derived directly from self.keys rather than tracked
17
- as a separate variable.
18
-
19
- The primary interface is update(key_states, value_states, active_mask), which accepts
20
- expert-choice layout, stores only active entries in causal order, and returns the full
21
- accumulated (keys, values, active_mask) for immediate use by BEA. The returned
22
- active_mask identifies valid cached positions; everything beyond each slot's count is
23
- junk data that downstream attention must exclude.
24
-
25
- BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
26
- exposed by get_heads_lengths() must be read before update() if the caller needs the
27
- pre-update occupancy for position computation (Unit 10.A). update() increments counts
28
- in-place and the pre-update values are not recoverable afterward.
29
-
30
- All buffers are allocated at construction time. MoSRAHCache is constructed by
31
- ShramLayerCache, which has access to batch size, device, and all model config parameters
32
- needed to fully specify the storage layout upfront.
33
- """
34
-
35
- import torch
36
- from transformers.cache_utils import CacheLayerMixin
37
-
38
-
39
- class MoSRAHCache(CacheLayerMixin):
40
- """KV cache for the MoSRAH sparse attention path — single decoder layer.
41
-
42
- Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
43
- Stores keys and values in the mixin-standard self.keys and self.values attributes
44
- using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
45
- which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.
46
-
47
- All storage is allocated at construction time and is_initialized is True
48
- immediately. The caller (ShramLayerCache) provides batch size, device, and model
49
- config parameters so no lazy allocation is needed.
50
-
51
- Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
52
- (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
53
- This matches the packed representation produced by expert packing in the MoSRAH
54
- forward pass, where BEA has already applied RoPE before calling update().
55
-
56
- Args:
57
- num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
58
- second dimension of all storage tensors.
59
- head_dim: Bottlenecked head embedding width (u). Determines the fourth
60
- dimension of all storage tensors.
61
- batch_size: Number of sequences in the batch. Determines the first dimension
62
- of all storage tensors.
63
- device: Device on which to allocate all tensors. Should match the model device.
64
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
65
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
66
- during prompt processing.
67
- """
68
-
69
- is_compileable = False
70
- is_sliding = False
71
-
72
- def __init__(
73
- self,
74
- num_mosrah_heads: int,
75
- head_dim: int,
76
- batch_size: int,
77
- device: torch.device,
78
- initial_buffer_size: int = 64,
79
- ) -> None:
80
- super().__init__()
81
- self.num_mosrah_heads = num_mosrah_heads
82
- self.head_dim = head_dim
83
- self.batch_size = batch_size
84
- self.device = device
85
-
86
- # Allocate primary storage into the mixin-standard self.keys / self.values so
87
- # that inherited methods (offload, prefetch) operate on real tensors. _counts
88
- # tracks valid occupancy per (batch, head) slot.
89
- self.keys: torch.Tensor = torch.zeros(
90
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
91
- )
92
- self.values: torch.Tensor = torch.zeros(
93
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
94
- )
95
- self._counts: torch.Tensor = torch.zeros(
96
- batch_size, num_mosrah_heads, dtype=torch.long, device=device
97
- )
98
-
99
- # Storage is fully allocated at construction — the cache is initialized.
100
- self.is_initialized = True
101
-
102
- # ---------------------------------------------------------------------------
103
- # Properties
104
- # ---------------------------------------------------------------------------
105
-
106
- @property
107
- def buffer_capacity(self) -> int:
108
- """Current number of slots allocated per (batch, head) pair.
109
-
110
- Derived directly from self.keys rather than tracked separately, so it is
111
- always consistent with the actual buffer after expansion.
112
- """
113
- return self.keys.shape[2]
114
-
115
- # ---------------------------------------------------------------------------
116
- # Primary API
117
- # ---------------------------------------------------------------------------
118
-
119
- def update( # type: ignore[override]
120
- self,
121
- key_states: torch.Tensor,
122
- value_states: torch.Tensor,
123
- active_mask: torch.Tensor,
124
- cache_kwargs: dict | None = None,
125
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
126
- """Scatter active key/value states into the buffer and return the full cache state.
127
-
128
- Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
129
- active_mask is (B, L, T) bool with True marking real tokens. Only active
130
- positions are written; inactive positions are ignored.
131
-
132
- Uses a cumsum construction to derive the absolute buffer position for each
133
- active token without any Python loops. For a given (batch, head) slot,
134
- positions are assigned in the order tokens appear along the T dimension,
135
- preserving causal ordering.
136
-
137
- Returns the full accumulated (keys, values, active_mask) across the cached
138
- sparse sequence. The returned active_mask is True exactly for slots t <
139
- counts[b, l]; everything beyond is junk data that BEA must exclude.
140
-
141
- Note: get_heads_lengths() must be called before update() if the caller needs
142
- the pre-update occupancy for position computation (Unit 10.A). update()
143
- increments counts in-place and the pre-update values are not recoverable.
144
-
145
- Args:
146
- key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
147
- value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
148
- active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
149
- cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
150
-
151
- Returns:
152
- Tuple of (keys, values, active_mask):
153
- keys: (B, L, T, u) float — full key buffer including junk slots.
154
- values: (B, L, T, u) float — full value buffer including junk slots.
155
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
156
- """
157
- incoming_delta = active_mask.long().sum(dim=2) # (B, L)
158
-
159
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
160
- self._expand()
161
-
162
- # Cumulative count of active positions along T for each (b, l) slot. Entry
163
- # [b, l, t] is the 1-based rank of position t among all active positions in
164
- # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
165
- # produce a negative value, which is excluded by the mask gate below.
166
- within_slot = active_mask.long().cumsum(dim=2) - 1 # (B, L, T)
167
-
168
- # Add the pre-update count to get the absolute buffer position for each
169
- # active token.
170
- abs_pos = within_slot + self._counts.unsqueeze(-1) # (B, L, T)
171
-
172
- # Scatter key and value vectors at all active positions.
173
- b_idx, l_idx, t_idx = torch.where(active_mask)
174
- self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
175
- key_states[b_idx, l_idx, t_idx]
176
- )
177
- self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
178
- value_states[b_idx, l_idx, t_idx]
179
- )
180
-
181
- self._counts += incoming_delta
182
-
183
- return self.keys, self.values, self._make_active_mask()
184
-
185
- def get_heads_lengths(self) -> torch.Tensor:
186
- """Return the per-(batch, head) token count for this layer.
187
-
188
- This is the authoritative occupancy tensor consumed by BEA for attention
189
- masking and by position computation (Unit 10.A) for semantic-sequence
190
- position computation.
191
-
192
- Note: in the MoSRAH forward pass, this must be called before update() if the
193
- caller needs the pre-update occupancy. update() increments these counts in-place.
194
-
195
- Returns:
196
- Integer tensor of shape (B, L) where entry [b, h] is the number of valid
197
- tokens stored in the (b, h) slot. Zero for slots with no writes yet.
198
- """
199
- return self._counts
200
-
201
- # ---------------------------------------------------------------------------
202
- # CacheLayerMixin — overridden coordination methods
203
- # ---------------------------------------------------------------------------
204
-
205
- def reset(self) -> None:
206
- """Clear all cached key and value tensors.
207
-
208
- Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
209
- and is_initialized remains True — only the contents are cleared.
210
- """
211
- self.keys.zero_()
212
- self.values.zero_()
213
- self._counts.zero_()
214
-
215
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
216
- """Reorder the batch dimension of all cached tensors for beam search.
217
-
218
- Applied atomically across self.keys, self.values, and _counts. Beam search
219
- must reorder all three together or the occupancy counts and buffer contents
220
- will correspond to different beam hypotheses.
221
-
222
- Overrides the parent because the parent's implementation calls get_seq_length(),
223
- which is not supported for this cache.
224
-
225
- Args:
226
- beam_idx: Permutation indices of shape (batch,) produced by the beam
227
- search algorithm.
228
- """
229
- self.keys = self.keys[beam_idx]
230
- self.values = self.values[beam_idx]
231
- self._counts = self._counts[beam_idx]
232
-
233
- def batch_repeat_interleave(self, repeats: int) -> None:
234
- """Expand the batch dimension by repeating each entry repeats times.
235
-
236
- Used at beam search initialisation to expand the cache from batch size B to
237
- B * repeats, matching the expanded beam candidate batch. Applied atomically
238
- across keys, values, and _counts; batch_size is updated to reflect the new size.
239
-
240
- Args:
241
- repeats: Number of times to repeat each batch entry.
242
- """
243
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
244
- self.values = self.values.repeat_interleave(repeats, dim=0)
245
- self._counts = self._counts.repeat_interleave(repeats, dim=0)
246
- self.batch_size = self.batch_size * repeats
247
-
248
- def batch_select_indices(self, indices: torch.Tensor) -> None:
249
- """Select a subset of batch entries by index.
250
-
251
- Used in contrastive search to retain only the selected candidate entries.
252
- Applied atomically across keys, values, and _counts; batch_size is updated
253
- to reflect the number of retained entries.
254
-
255
- Args:
256
- indices: 1-D integer tensor of batch indices to retain.
257
- """
258
- self.keys = self.keys[indices]
259
- self.values = self.values[indices]
260
- self._counts = self._counts[indices]
261
- self.batch_size = indices.shape[0]
262
-
263
- def offload(self) -> None:
264
- """Offload all cached tensors to CPU.
265
-
266
- Extends the parent to also offload _counts, which the parent does not know
267
- about. All three tensors are moved atomically so device state remains consistent.
268
- """
269
- super().offload()
270
- self._counts = self._counts.to("cpu", non_blocking=True)
271
-
272
- def prefetch(self) -> None:
273
- """Move all cached tensors back to the model device ahead of time.
274
-
275
- Extends the parent to also prefetch _counts, which the parent does not know
276
- about. _counts is synced to self.keys.device after the parent moves keys and
277
- values, so all three remain consistent.
278
- """
279
- super().prefetch()
280
- if self._counts.device != self.keys.device:
281
- self._counts = self._counts.to(self.keys.device, non_blocking=True)
282
-
283
- def lazy_initialization( # type: ignore[override]
284
- self, key_states: torch.Tensor, value_states: torch.Tensor
285
- ) -> None:
286
- """No-op — storage is fully allocated at construction time."""
287
- pass
288
-
289
- # ---------------------------------------------------------------------------
290
- # CacheLayerMixin — unsupported abstract methods
291
- # ---------------------------------------------------------------------------
292
-
293
- def get_seq_length(self) -> int: # type: ignore[override]
294
- """Not supported — no single sequence length represents this cache's state.
295
-
296
- MoSRAH heads accumulate independently; (batch, head) slots have different
297
- lengths depending on routing history. There is no meaningful scalar summary.
298
- Use get_heads_lengths() for per-head occupancy.
299
- """
300
- raise NotImplementedError(
301
- "MoSRAHCache has no single sequence length. "
302
- "Use get_heads_lengths() for per-head occupancy."
303
- )
304
-
305
- def get_max_cache_shape(self) -> int: # type: ignore[override]
306
- """Not supported — MoSRAHCache is dynamic and unbounded."""
307
- raise NotImplementedError(
308
- "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
309
- )
310
-
311
- def get_mask_sizes( # type: ignore[override]
312
- self,
313
- cache_position: torch.Tensor,
314
- ) -> tuple[int, int]:
315
- """Not supported — MoSRAHCache does not participate in HF mask construction."""
316
- raise NotImplementedError(
317
- "MoSRAHCache does not support get_mask_sizes()."
318
- )
319
-
320
- # ---------------------------------------------------------------------------
321
- # Internal helpers
322
- # ---------------------------------------------------------------------------
323
-
324
- def _make_active_mask(self) -> torch.Tensor:
325
- """Construct the (B, L, T) active mask from current counts.
326
-
327
- Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
328
- has been written. Positions at or beyond the count are junk and must be
329
- excluded by downstream attention.
330
- """
331
- cap = self.buffer_capacity
332
- return (
333
- torch.arange(cap, device=self.keys.device)
334
- .expand(self.batch_size, self.num_mosrah_heads, cap)
335
- < self._counts.unsqueeze(-1)
336
- )
337
-
338
- def _expand(self) -> None:
339
- """Double the buffer capacity, preserving existing data.
340
-
341
- Called by update() when an incoming batch of tokens would cause any
342
- (batch, head) slot to exceed the current buffer capacity. All existing
343
- key and value data is copied into the low half of the new buffer; the
344
- high half is zero-initialised and will be filled by subsequent writes.
345
- After reassignment, buffer_capacity reflects the new size automatically.
346
- """
347
- old_cap = self.buffer_capacity
348
- new_cap = old_cap * 2
349
- dev = self.keys.device
350
- new_keys = torch.zeros(
351
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
352
- )
353
- new_values = torch.zeros(
354
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
355
- )
356
- new_keys[:, :, :old_cap, :] = self.keys
357
- new_values[:, :, :old_cap, :] = self.values
358
- self.keys = new_keys
359
- self.values = new_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/cache/shram_cache.py DELETED
@@ -1,141 +0,0 @@
1
- """SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
2
-
3
- The HuggingFace Cache protocol expects a single top-level Cache object that owns one
4
- CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
5
- lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
6
- ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
7
- presents them through the Cache interface, and transparently forwards model-wide operations
8
- across all of them.
9
-
10
- ShramCache does not define a composite update() interface. The two attention paths inside each
11
- SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
12
- nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
13
- relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
14
- coordination of the layer caches — not routing attention inputs.
15
-
16
- Sequence length is reported by delegating to the local sliding-window sub-cache of the
17
- specified layer, which tracks the cumulative count of token positions processed. This is
18
- what HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import Cache
23
-
24
- from .shram_layer_cache import ShramLayerCache
25
-
26
-
27
- class ShramCache(Cache):
28
- """Top-level cache for the full SHRAM model.
29
-
30
- Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
31
- role and transparently forwards reset, reorder, and sequence-length queries across all
32
- owned layer caches.
33
-
34
- No composite update() interface is provided. The two attention paths inside each SHRAM
35
- layer have materially different update semantics; callers must update sub-caches directly
36
- via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
37
-
38
- Args:
39
- num_hidden_layers: Number of SHRAM decoder layers. Determines how many
40
- ShramLayerCache objects are constructed.
41
- sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
42
- num_local_heads: Number of local attention heads per layer.
43
- local_head_dim: Per-head embedding width for the local path.
44
- num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
45
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
46
- batch_size: Number of sequences in the batch.
47
- device: Device on which to allocate cache tensors.
48
- initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
49
- Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
50
- during prompt processing.
51
- """
52
-
53
- def __init__(
54
- self,
55
- num_hidden_layers: int,
56
- sliding_window: int,
57
- num_local_heads: int,
58
- local_head_dim: int,
59
- num_mosrah_heads: int,
60
- mosrah_head_dim: int,
61
- batch_size: int,
62
- device: torch.device,
63
- initial_buffer_size: int = 64,
64
- ) -> None:
65
- layers = [
66
- ShramLayerCache(
67
- sliding_window=sliding_window,
68
- num_local_heads=num_local_heads,
69
- local_head_dim=local_head_dim,
70
- num_mosrah_heads=num_mosrah_heads,
71
- mosrah_head_dim=mosrah_head_dim,
72
- batch_size=batch_size,
73
- device=device,
74
- initial_buffer_size=initial_buffer_size,
75
- )
76
- for _ in range(num_hidden_layers)
77
- ]
78
- super().__init__(layers=layers)
79
-
80
- # ---------------------------------------------------------------------------
81
- # Cache — composite-meaningful methods
82
- # ---------------------------------------------------------------------------
83
- #
84
- # reset(): Inherited. Iterates all layer caches and calls reset() on each.
85
- #
86
- # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
87
- #
88
- # is_initialized: Inherited property. True iff all layer caches are initialized.
89
- # Since ShramLayerCache.is_initialized is True from construction, this is True
90
- # immediately after ShramCache.__init__ returns.
91
-
92
- def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
93
- """Return the cumulative sequence length for the specified layer.
94
-
95
- Delegates to the layer cache at layer_idx, which in turn delegates to the
96
- local sliding-window sub-cache. That sub-cache is authoritative for sequence
97
- progress: it sees every token presented to the layer and accumulates a truthful
98
- total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
99
- """
100
- return self.layers[layer_idx].get_seq_length()
101
-
102
- # ---------------------------------------------------------------------------
103
- # Cache — unsupported methods
104
- # ---------------------------------------------------------------------------
105
-
106
- def update( # type: ignore[override]
107
- self,
108
- key_states: torch.Tensor,
109
- value_states: torch.Tensor,
110
- layer_idx: int,
111
- cache_kwargs: dict | None = None,
112
- ) -> tuple[torch.Tensor, torch.Tensor]:
113
- """Not supported — ShramCache has no composite update interface.
114
-
115
- The two attention paths inside each SHRAM layer have different update semantics.
116
- Callers must update sub-caches directly:
117
- cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
118
- cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
119
- """
120
- raise NotImplementedError(
121
- "ShramCache has no composite update interface. "
122
- "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
123
- )
124
-
125
- def crop(self, max_length: int) -> None:
126
- """Not supported — ShramCache layers do not implement crop()."""
127
- raise NotImplementedError("ShramCache does not support crop().")
128
-
129
- @property
130
- def max_batch_size(self) -> int:
131
- """Not supported — ShramCache does not track a uniform batch size across layers."""
132
- raise NotImplementedError("ShramCache does not expose max_batch_size.")
133
-
134
- @property
135
- def max_cache_len(self) -> int:
136
- """Not supported — ShramCache has no single maximum cache length.
137
-
138
- The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
139
- No truthful scalar maximum represents the composite.
140
- """
141
- raise NotImplementedError("ShramCache does not expose max_cache_len.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/cache/shram_layer_cache.py DELETED
@@ -1,233 +0,0 @@
1
- """SHRAM per-layer cache — composite owner for one SHRAM decoder layer.
2
-
3
- A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
4
- local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
5
- different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
6
- the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
7
- path can interact with it without indirection.
8
-
9
- ShramLayerCache does not define a composite update() interface. The two paths have materially
10
- different update semantics — the local side uses chunk-local key/value/mask concatenation
11
- while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
12
- behind a single update() would hide those differences behind a misleading abstraction. Instead,
13
- each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
14
- ownership, coordination, and reset/reorder boundary for one decoder layer.
15
-
16
- Sequence length at this boundary is reported by delegating to the local sliding-window
17
- sub-cache, which tracks the cumulative count of token positions processed. This is the
18
- quantity HuggingFace generation reads through get_seq_length().
19
- """
20
-
21
- import torch
22
- from transformers.cache_utils import CacheLayerMixin
23
-
24
- from .mosrah_cache import MoSRAHCache
25
- from .sliding_window_cache import LocalSlidingWindowLayerCache
26
-
27
-
28
- class ShramLayerCache(CacheLayerMixin):
29
- """Cache subsystem for one SHRAM decoder layer.
30
-
31
- Owns and coordinates two sub-caches:
32
- - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
33
- - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.
34
-
35
- Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
36
- exposed directly for their downstream attention paths — no composite update() interface is
37
- provided, because the two paths have materially different update semantics.
38
-
39
- Sequence length is reported by delegating to the local sliding-window sub-cache, which
40
- tracks the cumulative count of token positions processed across all update() calls.
41
-
42
- Args:
43
- sliding_window: Number of tokens retained by the local sliding-window cache.
44
- num_local_heads: Number of local attention heads.
45
- local_head_dim: Per-head embedding width for the local path.
46
- num_mosrah_heads: Total number of MoSRAH expert heads (L).
47
- mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
48
- batch_size: Number of sequences in the batch.
49
- device: Device on which to allocate cache tensors.
50
- initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
51
- when any slot overflows. Defaults to 64 to avoid repeated reallocation during
52
- prompt processing.
53
- """
54
-
55
- is_compileable = False
56
- is_sliding = False
57
-
58
- def __init__(
59
- self,
60
- sliding_window: int,
61
- num_local_heads: int,
62
- local_head_dim: int,
63
- num_mosrah_heads: int,
64
- mosrah_head_dim: int,
65
- batch_size: int,
66
- device: torch.device,
67
- initial_buffer_size: int = 64,
68
- ) -> None:
69
- super().__init__()
70
- self.sliding_window_cache = LocalSlidingWindowLayerCache(
71
- sliding_window=sliding_window,
72
- num_heads=num_local_heads,
73
- head_dim=local_head_dim,
74
- batch_size=batch_size,
75
- device=device,
76
- )
77
- self.mosrah_cache = MoSRAHCache(
78
- num_mosrah_heads=num_mosrah_heads,
79
- head_dim=mosrah_head_dim,
80
- batch_size=batch_size,
81
- device=device,
82
- initial_buffer_size=initial_buffer_size,
83
- )
84
-
85
- # ---------------------------------------------------------------------------
86
- # Properties
87
- # ---------------------------------------------------------------------------
88
-
89
- @property
90
- def is_initialized(self) -> bool:
91
- """True iff both sub-caches have allocated their storage.
92
-
93
- Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
94
- so this is True immediately after ShramLayerCache.__init__ returns.
95
- """
96
- return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized
97
-
98
- @is_initialized.setter
99
- def is_initialized(self, value: bool) -> None:
100
- # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
101
- # attribute. Since property is a data descriptor it takes precedence, but Python
102
- # still routes the assignment through __set__. Absorb it silently — state is
103
- # derived from sub-caches, not stored here.
104
- pass
105
-
106
- # ---------------------------------------------------------------------------
107
- # CacheLayerMixin — composite-meaningful methods
108
- # ---------------------------------------------------------------------------
109
-
110
- def get_seq_length(self) -> int: # type: ignore[override]
111
- """Return the cumulative sequence length from the local sliding-window path.
112
-
113
- The local path is authoritative for sequence progress: it sees every token
114
- presented to this layer and accumulates a truthful total. Delegates to
115
- sliding_window_cache.get_seq_length().
116
- """
117
- return self.sliding_window_cache.get_seq_length()
118
-
119
- def reset(self) -> None:
120
- """Clear both sub-caches.
121
-
122
- Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
123
- state and MoSRAH sparse state remain consistent.
124
- """
125
- self.sliding_window_cache.reset()
126
- self.mosrah_cache.reset()
127
-
128
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
129
- """Reorder the batch dimension of both sub-caches for beam search.
130
-
131
- Delegates to each sub-cache. Both are reordered atomically so the sliding-window
132
- and MoSRAH state correspond to the same beam hypotheses after reordering.
133
-
134
- Args:
135
- beam_idx: Permutation indices of shape (batch,) produced by beam search.
136
- """
137
- self.sliding_window_cache.reorder_cache(beam_idx)
138
- self.mosrah_cache.reorder_cache(beam_idx)
139
-
140
- def batch_repeat_interleave(self, repeats: int) -> None:
141
- """Expand the batch dimension of both sub-caches for beam search initialisation.
142
-
143
- Delegates atomically to each sub-cache. Both must be expanded together so the
144
- sliding-window and MoSRAH state correspond to the same beam candidates.
145
-
146
- Args:
147
- repeats: Number of times to repeat each batch entry.
148
- """
149
- self.sliding_window_cache.batch_repeat_interleave(repeats)
150
- self.mosrah_cache.batch_repeat_interleave(repeats)
151
-
152
- def batch_select_indices(self, indices: torch.Tensor) -> None:
153
- """Select a subset of batch entries in both sub-caches for contrastive search.
154
-
155
- Delegates atomically to each sub-cache. Both must be trimmed together so the
156
- sliding-window and MoSRAH state remain consistent.
157
-
158
- Args:
159
- indices: 1-D integer tensor of batch indices to retain.
160
- """
161
- self.sliding_window_cache.batch_select_indices(indices)
162
- self.mosrah_cache.batch_select_indices(indices)
163
-
164
- def offload(self) -> None:
165
- """Offload both sub-caches to CPU.
166
-
167
- Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
168
- does not own self.keys/self.values directly; all cached data lives in the sub-caches.
169
- """
170
- self.sliding_window_cache.offload()
171
- self.mosrah_cache.offload()
172
-
173
- def prefetch(self) -> None:
174
- """Move both sub-caches back to their model device ahead of time.
175
-
176
- Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
177
- does not own self.keys/self.values directly; all cached data lives in the sub-caches.
178
- """
179
- self.sliding_window_cache.prefetch()
180
- self.mosrah_cache.prefetch()
181
-
182
- def lazy_initialization( # type: ignore[override]
183
- self, key_states: torch.Tensor, value_states: torch.Tensor
184
- ) -> None:
185
- """No-op — both sub-caches handle their own initialization."""
186
- pass
187
-
188
- # ---------------------------------------------------------------------------
189
- # CacheLayerMixin — unsupported abstract methods
190
- # ---------------------------------------------------------------------------
191
-
192
- def update( # type: ignore[override]
193
- self,
194
- key_states: torch.Tensor,
195
- value_states: torch.Tensor,
196
- cache_kwargs: dict | None = None,
197
- ) -> tuple[torch.Tensor, torch.Tensor]:
198
- """Not supported — ShramLayerCache has no composite update interface.
199
-
200
- The two sub-caches have materially different update semantics: the sliding-window
201
- side uses standard key/value concatenation while the MoSRAH side uses expert-choice
202
- scatter with an active mask. Callers must update each sub-cache directly via
203
- sliding_window_cache.update() or mosrah_cache.update().
204
- """
205
- raise NotImplementedError(
206
- "ShramLayerCache has no composite update interface. "
207
- "Update sliding_window_cache or mosrah_cache directly."
208
- )
209
-
210
- def get_max_cache_shape(self) -> int: # type: ignore[override]
211
- """Not supported — the composite cache has no single maximum shape.
212
-
213
- The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
214
- unbounded. No truthful scalar maximum represents the composite.
215
- """
216
- raise NotImplementedError(
217
- "ShramLayerCache has no single maximum cache shape. "
218
- "Query sliding_window_cache or mosrah_cache directly."
219
- )
220
-
221
- def get_mask_sizes( # type: ignore[override]
222
- self,
223
- cache_position: torch.Tensor,
224
- ) -> tuple[int, int]:
225
- """Not supported — ShramLayerCache does not participate in HF mask construction.
226
-
227
- The two sub-caches have different mask semantics and their respective attention
228
- paths handle masking directly.
229
- """
230
- raise NotImplementedError(
231
- "ShramLayerCache does not support get_mask_sizes(). "
232
- "Query sliding_window_cache or mosrah_cache directly."
233
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/cache/sliding_window_cache.py DELETED
@@ -1,289 +0,0 @@
1
- # src/shram/model/cache/sliding_window_cache.py
2
-
3
- """Local sliding-window cache for the SHRAM local attention path.
4
-
5
- This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
6
- `ShramLayerCache` and consumed by `SlidingWindowAttention`.
7
-
8
- Its job is narrow:
9
-
10
- - accept the current chunk's local key/value tensors and active mask
11
- - return the current-step local frame consumed by local attention
12
- - separately retain the next-step sliding-window cache state
13
-
14
- It does not decide local causal visibility. That is owned by
15
- `SlidingWindowAttention`, which consumes the returned key/value/mask frame and
16
- constructs the effective local attention mask from it.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
- class LocalSlidingWindowLayerCache(CacheLayerMixin):
24
- """Fixed-width local cache for one SHRAM decoder layer.
25
-
26
- The cache keeps a retained local sliding-window buffer and an aligned active
27
- mask. On update, it returns the current-step local frame formed by
28
- concatenating retained cache state with the new chunk, then remembers only
29
- the last `sliding_window` positions for the next step.
30
-
31
- Dead positions are allowed to remain in both the returned frame and the
32
- retained cache. Correctness is carried by the aligned active mask.
33
-
34
- Args:
35
- sliding_window: Width of the retained local sliding-window buffer.
36
- num_heads: Number of local attention heads.
37
- head_dim: Per-head embedding width for the local path.
38
- batch_size: Number of sequences in the batch.
39
- device: Device on which to allocate cache storage.
40
- """
41
-
42
- is_compileable = False
43
- is_sliding = True
44
-
45
- def __init__(
46
- self,
47
- sliding_window: int,
48
- num_heads: int,
49
- head_dim: int,
50
- batch_size: int,
51
- device: torch.device,
52
- ) -> None:
53
- super().__init__()
54
-
55
- if sliding_window < 1:
56
- raise ValueError(
57
- f"sliding_window must be >= 1, got {sliding_window}."
58
- )
59
- if num_heads < 1:
60
- raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
61
- if head_dim < 1:
62
- raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
63
- if batch_size < 1:
64
- raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
65
-
66
- self.sliding_window = sliding_window
67
- self.num_heads = num_heads
68
- self.head_dim = head_dim
69
- self.batch_size = batch_size
70
- self.device = device
71
-
72
- # Retained next-step local cache state. Storage is fixed-width from the
73
- # start; semantic validity is carried by `active_mask`.
74
- self.keys = torch.zeros(
75
- batch_size,
76
- num_heads,
77
- sliding_window,
78
- head_dim,
79
- device=device,
80
- )
81
- self.values = torch.zeros(
82
- batch_size,
83
- num_heads,
84
- sliding_window,
85
- head_dim,
86
- device=device,
87
- )
88
- self.active_mask = torch.zeros(
89
- batch_size,
90
- sliding_window,
91
- dtype=torch.bool,
92
- device=device,
93
- )
94
-
95
- self.is_initialized = True
96
-
97
- # Cumulative count of all token positions presented through update() for
98
- # this cache instance. This is the quantity HuggingFace generation reads
99
- # through get_seq_length() to track how far along the sequence we are.
100
- self._total_processed: int = 0
101
-
102
- def update( # type: ignore[override]
103
- self,
104
- key_states: torch.Tensor,
105
- value_states: torch.Tensor,
106
- active_mask: torch.Tensor,
107
- cache_kwargs: dict | None = None,
108
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
- """Return the current-step local frame and retain the next-step window.
110
-
111
- Args:
112
- key_states: Shape `(B, H, T_new, D)` local key vectors for the
113
- current chunk.
114
- value_states: Shape `(B, H, T_new, D)` local value vectors for the
115
- current chunk.
116
- active_mask: Shape `(B, T_new)` bool. `True` means the
117
- corresponding token position in the current chunk is active.
118
- cache_kwargs: Present only to satisfy the `CacheLayerMixin`
119
- interface. Unused by this cache.
120
-
121
- Returns:
122
- Tuple of:
123
- - visible_keys: `(B, H, sliding_window + T_new, D)`
124
- - visible_values: `(B, H, sliding_window + T_new, D)`
125
- - visible_active_mask: `(B, sliding_window + T_new)`
126
-
127
- These are the tensors the local attention path should consume
128
- directly for the current step.
129
- """
130
- self._ensure_state_compatibility(
131
- key_states=key_states,
132
- value_states=value_states,
133
- )
134
-
135
- # The current-step local frame is just retained cache state followed by
136
- # the current chunk in chronological order.
137
- composite_keys, composite_values, composite_mask = self._make_composite_frame(
138
- key_states=key_states,
139
- value_states=value_states,
140
- active_mask=active_mask,
141
- )
142
-
143
- # The cache remembers only the last raw sliding-window positions of that
144
- # composite frame for the next step. Dead positions are allowed to
145
- # survive; downstream local attention will ignore them using the mask.
146
- self._retain_next_window(
147
- composite_keys=composite_keys,
148
- composite_values=composite_values,
149
- composite_mask=composite_mask,
150
- )
151
-
152
- self._total_processed += key_states.shape[2]
153
-
154
- return composite_keys, composite_values, composite_mask
155
-
156
- def _ensure_state_compatibility(
157
- self,
158
- key_states: torch.Tensor,
159
- value_states: torch.Tensor,
160
- ) -> None:
161
- """Keep retained cache buffers compatible with the incoming update tensors.
162
-
163
- The cache is allocated eagerly for simplicity. If later updates arrive on
164
- a different device or in a different floating dtype, move the retained
165
- state to match while preserving its contents.
166
- """
167
- if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
168
- self.keys = self.keys.to(
169
- device=key_states.device,
170
- dtype=key_states.dtype,
171
- )
172
-
173
- if (
174
- self.values.dtype != value_states.dtype
175
- or self.values.device != value_states.device
176
- ):
177
- self.values = self.values.to(
178
- device=value_states.device,
179
- dtype=value_states.dtype,
180
- )
181
-
182
- if self.active_mask.device != key_states.device:
183
- self.active_mask = self.active_mask.to(
184
- key_states.device,
185
- non_blocking=True,
186
- )
187
-
188
- def _make_composite_frame(
189
- self,
190
- key_states: torch.Tensor,
191
- value_states: torch.Tensor,
192
- active_mask: torch.Tensor,
193
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
194
- """Build the current-step local frame in chronological order."""
195
- return (
196
- torch.cat([self.keys, key_states], dim=-2),
197
- torch.cat([self.values, value_states], dim=-2),
198
- torch.cat([self.active_mask, active_mask], dim=-1),
199
- )
200
-
201
- def _retain_next_window(
202
- self,
203
- composite_keys: torch.Tensor,
204
- composite_values: torch.Tensor,
205
- composite_mask: torch.Tensor,
206
- ) -> None:
207
- """Remember the next-step retained local state.
208
-
209
- This is a raw positional trim to the last `sliding_window` positions, not
210
- a semantic live-token trim.
211
- """
212
- self.keys = composite_keys[:, :, -self.sliding_window :, :]
213
- self.values = composite_values[:, :, -self.sliding_window :, :]
214
- self.active_mask = composite_mask[:, -self.sliding_window :]
215
-
216
- def get_seq_length(self) -> int:
217
- """Return the cumulative number of token positions processed by this cache.
218
-
219
- This is the total count of token positions presented across all update()
220
- calls since construction or the last reset(). It is the quantity HuggingFace
221
- generation reads to track sequence progress and is not the same as active-token
222
- count or current window occupancy.
223
- """
224
- return self._total_processed
225
-
226
- def get_max_cache_shape(self) -> int:
227
- return self.sliding_window
228
-
229
- def get_mask_sizes( # type: ignore[override]
230
- self,
231
- cache_position: torch.Tensor,
232
- ) -> tuple[int, int]:
233
- raise NotImplementedError(
234
- "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
235
- )
236
-
237
- def reset(self) -> None:
238
- """Restore fresh-cache behavior."""
239
- self.keys.zero_()
240
- self.values.zero_()
241
- self.active_mask.zero_()
242
- self._total_processed = 0
243
-
244
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
245
- """Reorder the batch dimension for beam search."""
246
- self.keys = self.keys[beam_idx]
247
- self.values = self.values[beam_idx]
248
- self.active_mask = self.active_mask[beam_idx]
249
-
250
- def batch_repeat_interleave(self, repeats: int) -> None:
251
- """Expand the batch dimension for beam-search initialisation."""
252
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
253
- self.values = self.values.repeat_interleave(repeats, dim=0)
254
- self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
255
- self.batch_size = self.batch_size * repeats
256
-
257
- def batch_select_indices(self, indices: torch.Tensor) -> None:
258
- """Select a subset of batch entries for contrastive search."""
259
- self.keys = self.keys[indices]
260
- self.values = self.values[indices]
261
- self.active_mask = self.active_mask[indices]
262
- self.batch_size = int(indices.shape[0])
263
-
264
- def offload(self) -> None:
265
- """Offload cache tensors to CPU."""
266
- super().offload()
267
- self.active_mask = self.active_mask.to("cpu", non_blocking=True)
268
-
269
- def prefetch(self) -> None:
270
- """Move cache tensors back to the model device ahead of time."""
271
- super().prefetch()
272
- if self.active_mask.device != self.keys.device:
273
- self.active_mask = self.active_mask.to(
274
- self.keys.device,
275
- non_blocking=True,
276
- )
277
-
278
- def crop(self, max_length: int) -> None:
279
- raise NotImplementedError(
280
- "LocalSlidingWindowLayerCache does not support crop()."
281
- )
282
-
283
- def lazy_initialization(
284
- self,
285
- key_states: torch.Tensor,
286
- value_states: torch.Tensor,
287
- ) -> None:
288
- """No-op — this cache allocates its fixed buffers at construction time."""
289
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/cache/slow_mosrah_cache.py DELETED
@@ -1,321 +0,0 @@
1
- """Unvectorized reference implementation of the MoSRAH sparse KV cache.
2
-
3
- This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
4
- interface and storage layout as MoSRAHCache but uses an explicit Python loop over
5
- (b, l, t) triples in update(). The loop is obviously correct by inspection: each active
6
- position's key and value are written to the next available slot for that (batch, head)
7
- pair, in the order positions appear along the T dimension, which directly enforces
8
- causal ordering without any index arithmetic to verify.
9
-
10
- SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
11
- trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
12
- Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
13
- vectorized implementation is validated by asserting exact agreement with this one on all
14
- test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
15
- (test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
16
- an oracle.
17
- """
18
-
19
- import torch
20
- from transformers.cache_utils import CacheLayerMixin
21
-
22
-
23
- class SlowMoSRAHCache(CacheLayerMixin):
24
- """Unvectorized reference implementation of the MoSRAH KV cache.
25
-
26
- Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
27
- mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
28
- with the same constructor signature and the same CacheLayerMixin protocol methods.
29
- The sole difference is update(), which uses an explicit Python loop over (b, l, t)
30
- triples rather than vectorized index arithmetic.
31
-
32
- This class is not used in the model path. It exists so that MoSRAHCache.update()
33
- can be validated by asserting exact agreement with this implementation on all test
34
- inputs. See module docstring for the trust chain this enables.
35
-
36
- Args:
37
- num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
38
- second dimension of all storage tensors.
39
- head_dim: Bottlenecked head embedding width (u). Determines the fourth
40
- dimension of all storage tensors.
41
- batch_size: Number of sequences in the batch. Determines the first dimension
42
- of all storage tensors.
43
- device: Device on which to allocate all tensors. Should match the model device.
44
- initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
45
- when any slot overflows. Defaults to 64 to avoid repeated reallocation
46
- during prompt processing.
47
- """
48
-
49
- is_compileable = False
50
- is_sliding = False
51
-
52
- def __init__(
53
- self,
54
- num_mosrah_heads: int,
55
- head_dim: int,
56
- batch_size: int,
57
- device: torch.device,
58
- initial_buffer_size: int = 64,
59
- ) -> None:
60
- super().__init__()
61
- self.num_mosrah_heads = num_mosrah_heads
62
- self.head_dim = head_dim
63
- self.batch_size = batch_size
64
- self.device = device
65
-
66
- # Allocate primary storage into the mixin-standard self.keys / self.values so
67
- # that inherited methods (offload, prefetch) operate on real tensors. _counts
68
- # tracks valid occupancy per (batch, head) slot.
69
- self.keys: torch.Tensor = torch.zeros(
70
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
71
- )
72
- self.values: torch.Tensor = torch.zeros(
73
- batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
74
- )
75
- self._counts: torch.Tensor = torch.zeros(
76
- batch_size, num_mosrah_heads, dtype=torch.long, device=device
77
- )
78
-
79
- # Storage is fully allocated at construction — the cache is initialized.
80
- self.is_initialized = True
81
-
82
- # ---------------------------------------------------------------------------
83
- # Properties
84
- # ---------------------------------------------------------------------------
85
-
86
- @property
87
- def buffer_capacity(self) -> int:
88
- """Current number of slots allocated per (batch, head) pair.
89
-
90
- Derived directly from self.keys rather than tracked separately, so it is
91
- always consistent with the actual buffer after expansion.
92
- """
93
- return self.keys.shape[2]
94
-
95
- # ---------------------------------------------------------------------------
96
- # Primary API
97
- # ---------------------------------------------------------------------------
98
-
99
- def update( # type: ignore[override]
100
- self,
101
- key_states: torch.Tensor,
102
- value_states: torch.Tensor,
103
- active_mask: torch.Tensor,
104
- cache_kwargs: dict | None = None,
105
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
106
- """Scatter active key/value states using an explicit loop; return full cache state.
107
-
108
- Iterates over every (b, l, t) triple. For each position where active_mask is
109
- True, the key and value are written to the next available slot for that
110
- (batch, head) pair and the count is incremented. Causal ordering is guaranteed
111
- because the t dimension is traversed from 0 to T-1 and counts are updated
112
- immediately after each write.
113
-
114
- Buffer expansion (doubling buffer_capacity) is triggered before any writes if
115
- the incoming tokens would cause any slot to overflow the current capacity.
116
-
117
- Args:
118
- key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
119
- value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
120
- active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
121
- cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.
122
-
123
- Returns:
124
- Tuple of (keys, values, active_mask):
125
- keys: (B, L, T, u) float — full key buffer including junk slots.
126
- values: (B, L, T, u) float — full value buffer including junk slots.
127
- active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
128
- """
129
- B, L, T = active_mask.shape
130
-
131
- # Expansion check uses the total active tokens per slot, same as the
132
- # vectorized implementation, so both expand under identical conditions.
133
- incoming_delta = active_mask.long().sum(dim=2) # (B, L)
134
- if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
135
- self._expand()
136
-
137
- # Write each active position into the next available slot for its (batch, head)
138
- # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
139
- for b in range(B):
140
- for l in range(L):
141
- for t in range(T):
142
- if active_mask[b, l, t]:
143
- pos = self._counts[b, l].item()
144
- self.keys[b, l, pos, :] = key_states[b, l, t, :]
145
- self.values[b, l, pos, :] = value_states[b, l, t, :]
146
- self._counts[b, l] += 1
147
-
148
- return self.keys, self.values, self._make_active_mask()
149
-
150
- def get_heads_lengths(self) -> torch.Tensor:
151
- """Return the per-(batch, head) token count for this layer.
152
-
153
- This is the authoritative occupancy tensor consumed by BEA for attention
154
- masking and by position computation (Unit 10.A) for semantic-sequence
155
- position computation.
156
-
157
- Returns:
158
- Integer tensor of shape (B, L) where entry [b, h] is the number of valid
159
- tokens stored in the (b, h) slot. Zero for slots with no writes yet.
160
- """
161
- return self._counts
162
-
163
- # ---------------------------------------------------------------------------
164
- # CacheLayerMixin — overridden coordination methods
165
- # ---------------------------------------------------------------------------
166
-
167
- def reset(self) -> None:
168
- """Clear all cached key and value tensors.
169
-
170
- Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
171
- and is_initialized remains True — only the contents are cleared.
172
- """
173
- self.keys.zero_()
174
- self.values.zero_()
175
- self._counts.zero_()
176
-
177
- def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
178
- """Reorder the batch dimension of all cached tensors for beam search.
179
-
180
- Applied atomically across self.keys, self.values, and _counts. Beam search
181
- must reorder all three together or the occupancy counts and buffer contents
182
- will correspond to different beam hypotheses.
183
-
184
- Overrides the parent because the parent's implementation calls get_seq_length(),
185
- which is not supported for this cache.
186
-
187
- Args:
188
- beam_idx: Permutation indices of shape (batch,) produced by the beam
189
- search algorithm.
190
- """
191
- self.keys = self.keys[beam_idx]
192
- self.values = self.values[beam_idx]
193
- self._counts = self._counts[beam_idx]
194
-
195
- def batch_repeat_interleave(self, repeats: int) -> None:
196
- """Expand the batch dimension by repeating each entry repeats times.
197
-
198
- Used at beam search initialisation to expand the cache from batch size B to
199
- B * repeats, matching the expanded beam candidate batch. Applied atomically
200
- across keys, values, and _counts; batch_size is updated to reflect the new size.
201
-
202
- Args:
203
- repeats: Number of times to repeat each batch entry.
204
- """
205
- self.keys = self.keys.repeat_interleave(repeats, dim=0)
206
- self.values = self.values.repeat_interleave(repeats, dim=0)
207
- self._counts = self._counts.repeat_interleave(repeats, dim=0)
208
- self.batch_size = self.batch_size * repeats
209
-
210
- def batch_select_indices(self, indices: torch.Tensor) -> None:
211
- """Select a subset of batch entries by index.
212
-
213
- Used in contrastive search to retain only the selected candidate entries.
214
- Applied atomically across keys, values, and _counts; batch_size is updated
215
- to reflect the number of retained entries.
216
-
217
- Args:
218
- indices: 1-D integer tensor of batch indices to retain.
219
- """
220
- self.keys = self.keys[indices]
221
- self.values = self.values[indices]
222
- self._counts = self._counts[indices]
223
- self.batch_size = indices.shape[0]
224
-
225
- def offload(self) -> None:
226
- """Offload all cached tensors to CPU.
227
-
228
- Extends the parent to also offload _counts, which the parent does not know
229
- about. All three tensors are moved atomically so device state remains consistent.
230
- """
231
- super().offload()
232
- self._counts = self._counts.to("cpu", non_blocking=True)
233
-
234
- def prefetch(self) -> None:
235
- """Move all cached tensors back to the model device ahead of time.
236
-
237
- Extends the parent to also prefetch _counts, which the parent does not know
238
- about. _counts is synced to self.keys.device after the parent moves keys and
239
- values, so all three remain consistent.
240
- """
241
- super().prefetch()
242
- if self._counts.device != self.keys.device:
243
- self._counts = self._counts.to(self.keys.device, non_blocking=True)
244
-
245
- def lazy_initialization( # type: ignore[override]
246
- self, key_states: torch.Tensor, value_states: torch.Tensor
247
- ) -> None:
248
- """No-op — storage is fully allocated at construction time."""
249
- pass
250
-
251
- # ---------------------------------------------------------------------------
252
- # CacheLayerMixin — unsupported abstract methods
253
- # ---------------------------------------------------------------------------
254
-
255
- def get_seq_length(self) -> int: # type: ignore[override]
256
- """Not supported — no single sequence length represents this cache's state.
257
-
258
- MoSRAH heads accumulate independently; (batch, head) slots have different
259
- lengths depending on routing history. There is no meaningful scalar summary.
260
- Use get_heads_lengths() for per-head occupancy.
261
- """
262
- raise NotImplementedError(
263
- "SlowMoSRAHCache has no single sequence length. "
264
- "Use get_heads_lengths() for per-head occupancy."
265
- )
266
-
267
- def get_max_cache_shape(self) -> int: # type: ignore[override]
268
- """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
269
- raise NotImplementedError(
270
- "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
271
- )
272
-
273
- def get_mask_sizes( # type: ignore[override]
274
- self,
275
- cache_position: torch.Tensor,
276
- ) -> tuple[int, int]:
277
- """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
278
- raise NotImplementedError(
279
- "SlowMoSRAHCache does not support get_mask_sizes()."
280
- )
281
-
282
- # ---------------------------------------------------------------------------
283
- # Internal helpers
284
- # ---------------------------------------------------------------------------
285
-
286
- def _make_active_mask(self) -> torch.Tensor:
287
- """Construct the (B, L, T) active mask from current counts.
288
-
289
- Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
290
- has been written. Positions at or beyond the count are junk and must be
291
- excluded by downstream attention.
292
- """
293
- cap = self.buffer_capacity
294
- return (
295
- torch.arange(cap, device=self.keys.device)
296
- .expand(self.batch_size, self.num_mosrah_heads, cap)
297
- < self._counts.unsqueeze(-1)
298
- )
299
-
300
- def _expand(self) -> None:
301
- """Double the buffer capacity, preserving existing data.
302
-
303
- Called by update() when an incoming batch of tokens would cause any
304
- (batch, head) slot to exceed the current buffer capacity. All existing
305
- key and value data is copied into the low half of the new buffer; the
306
- high half is zero-initialised and will be filled by subsequent writes.
307
- After reassignment, buffer_capacity reflects the new size automatically.
308
- """
309
- old_cap = self.buffer_capacity
310
- new_cap = old_cap * 2
311
- dev = self.keys.device
312
- new_keys = torch.zeros(
313
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
314
- )
315
- new_values = torch.zeros(
316
- self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
317
- )
318
- new_keys[:, :, :old_cap, :] = self.keys
319
- new_values[:, :, :old_cap, :] = self.values
320
- self.keys = new_keys
321
- self.values = new_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/config.json DELETED
@@ -1,28 +0,0 @@
1
- {
2
- "alpha": 1.0,
3
- "attention_dropout": 0.0,
4
- "auto_map": {
5
- "AutoConfig": "configuration.ShramConfig",
6
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
7
- },
8
- "beta": 32.0,
9
- "head_dim": 16,
10
- "hidden_size": 512,
11
- "inference_sequence_length": 8192,
12
- "intermediate_size": 1366,
13
- "local_rope_theta": 10000.0,
14
- "model_type": "shram",
15
- "mosrah_rope_theta": 10000.0,
16
- "num_hidden_layers": 12,
17
- "num_mosrah_heads": 16,
18
- "num_selected_heads": 16,
19
- "num_sliding_window_heads": 16,
20
- "rms_norm_eps": 1e-05,
21
- "rope_mode": "main_sequence",
22
- "tie_word_embeddings": false,
23
- "training_sequence_length": 8192,
24
- "transformers_version": "5.3.0",
25
- "use_cache": true,
26
- "vocab_size": 50277,
27
- "window_size": 128
28
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/configuration.py DELETED
@@ -1,175 +0,0 @@
1
- """Configuration for the SHRAM transformer.
2
-
3
- All architectural parameters that vary across model scales or are meaningful research
4
- variables are expressed here. Architectural constants (no bias in linear layers,
5
- SwiGLU activation with SiLU gate) are implemented in the relevant modules and
6
- documented at the point of use — they are not config parameters because they do not
7
- vary and changing them produces a different architecture, not a different scale.
8
-
9
- RoPE configuration is owned entirely by this config. Each attention path reads its
10
- parameters directly and constructs its own RotaryEmbedding instance explicitly — no
11
- HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
12
- """
13
-
14
- from transformers import PretrainedConfig
15
-
16
-
17
- class ShramConfig(PretrainedConfig):
18
- """Configuration class for the SHRAM decoder-only transformer.
19
-
20
- SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
21
- attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
22
- local sliding-window causal attention path and h_s is the MoSRAH sparse routed
23
- path. All other components follow the Llama 3 baseline.
24
-
25
- This config is the single source of truth for every architectural dimension of the
26
- model. Nothing in the architecture may use a literal number that belongs here.
27
-
28
- Two independent RoPE configurations exist — one per attention path:
29
-
30
- - h_l always uses standard RoPE with ``local_rope_theta``.
31
- - BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
32
- ``inference_sequence_length``, ``alpha``, and ``beta``. When
33
- ``inference_sequence_length == training_sequence_length`` the YaRN scale factor
34
- ``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
35
- and the correct setting for experiments that do not require context extension.
36
-
37
- Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
38
-
39
- config = AutoConfig.from_pretrained(
40
- "your-namespace/advanced-transformers-lib",
41
- trust_remote_code=True,
42
- num_hidden_layers=12,
43
- )
44
- model = AutoModelForCausalLM.from_config(config)
45
-
46
- Args:
47
- vocab_size: Vocabulary size. Controls the embedding table and output logits
48
- dimension. Must match the tokenizer.
49
- hidden_size: Model width ``d``. The dimension of the residual stream.
50
- intermediate_size: FFN hidden dimension.
51
- num_hidden_layers: Number of transformer blocks stacked in sequence.
52
- num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
53
- num_mosrah_heads: Total MoSRAH expert heads available ``L``.
54
- num_selected_heads: MoSRAH heads each token selects ``K``.
55
- head_dim: Per-head dimension, shared by both attention paths. Must be even
56
- (RoPE rotates dimensions in pairs). Paper uses 16.
57
- window_size: Sliding window size for h_l. Paper uses 128.
58
- rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
59
- original sequence positions; ``"semantic_sequence"`` supplies local slot
60
- indices. Both are required; experimentally correct mode is undetermined
61
- (paper §4). Default ``"main_sequence"``.
62
- rms_norm_eps: Epsilon for RMSNorm layers.
63
- local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
64
- Paper uses b=10000.
65
- mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
66
- b=10000.
67
- training_sequence_length: Context length ``C_train`` the model was or will be
68
- trained at. Used to compute the YaRN scale factor for BEA.
69
- inference_sequence_length: Context length ``C_target`` the model must support
70
- at inference. When equal to ``training_sequence_length``, scale ``s=1``
71
- and YaRN reduces to standard RoPE.
72
- alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
73
- ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
74
- beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
75
- ``r(d) > beta`` are left unscaled. Paper value: 32.0.
76
- attention_dropout: Dropout probability on attention weights. Default 0.0.
77
- use_cache: Whether to return past_key_values for KV caching.
78
- output_hidden_states: Whether to return hidden states after each layer.
79
- tie_word_embeddings: Whether input embedding and LM head share weights.
80
- """
81
-
82
- model_type = "shram"
83
-
84
- auto_map = {
85
- "AutoConfig": "configuration.ShramConfig",
86
- "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
87
- }
88
-
89
- def __init__(
90
- self,
91
- vocab_size: int = 50277,
92
- hidden_size: int = 512,
93
- intermediate_size: int = 1366,
94
- num_hidden_layers: int = 12,
95
- num_sliding_window_heads: int = 16,
96
- num_mosrah_heads: int = 16,
97
- num_selected_heads: int = 16,
98
- head_dim: int = 16,
99
- window_size: int = 128,
100
- rope_mode: str = "main_sequence",
101
- rms_norm_eps: float = 1e-5,
102
- local_rope_theta: float = 10000.0,
103
- mosrah_rope_theta: float = 10000.0,
104
- training_sequence_length: int = 8192,
105
- inference_sequence_length: int = 8192,
106
- alpha: float = 1.0,
107
- beta: float = 32.0,
108
- attention_dropout: float = 0.0,
109
- use_cache: bool = True,
110
- output_hidden_states: bool = False,
111
- tie_word_embeddings: bool = False,
112
- **kwargs,
113
- ):
114
- if head_dim % 2 != 0:
115
- raise ValueError(
116
- f"head_dim must be even (RoPE rotates dimensions in pairs). "
117
- f"Got head_dim={head_dim}."
118
- )
119
-
120
- if rope_mode not in {"main_sequence", "semantic_sequence"}:
121
- raise ValueError(
122
- f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
123
- f"got '{rope_mode}'."
124
- )
125
-
126
- if training_sequence_length <= 0:
127
- raise ValueError(
128
- f"training_sequence_length must be positive, "
129
- f"got {training_sequence_length}."
130
- )
131
-
132
- if inference_sequence_length <= 0:
133
- raise ValueError(
134
- f"inference_sequence_length must be positive, "
135
- f"got {inference_sequence_length}."
136
- )
137
-
138
- self.vocab_size = vocab_size
139
- self.hidden_size = hidden_size
140
- self.intermediate_size = intermediate_size
141
- self.num_hidden_layers = num_hidden_layers
142
- self.num_sliding_window_heads = num_sliding_window_heads
143
- self.num_mosrah_heads = num_mosrah_heads
144
- self.num_selected_heads = num_selected_heads
145
- self.head_dim = head_dim
146
- self.window_size = window_size
147
- self.rope_mode = rope_mode
148
- self.rms_norm_eps = rms_norm_eps
149
- self.local_rope_theta = local_rope_theta
150
- self.mosrah_rope_theta = mosrah_rope_theta
151
- self.training_sequence_length = training_sequence_length
152
- self.inference_sequence_length = inference_sequence_length
153
- self.alpha = alpha
154
- self.beta = beta
155
- self.attention_dropout = attention_dropout
156
- self.use_cache = use_cache
157
-
158
- super().__init__(
159
- tie_word_embeddings=tie_word_embeddings,
160
- output_hidden_states=output_hidden_states,
161
- **kwargs,
162
- )
163
-
164
- # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
165
- # serialises it into config.json.
166
- self.auto_map = type(self).auto_map
167
-
168
- @property
169
- def scale(self) -> float:
170
- """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.
171
-
172
- When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
173
- adjustments cancel and A_rope = 1. This is the default state.
174
- """
175
- return self.inference_sequence_length / self.training_sequence_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/decoder_layer.py DELETED
@@ -1,87 +0,0 @@
1
- """Decoder layer — a single transformer block.
2
-
3
- Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
4
- residual connections around both sublayers:
5
-
6
- normed_attn = RMSNorm(x)
7
- attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
8
- h = x + attn_out
9
-
10
- normed_mlp = RMSNorm(h)
11
- mlp_out = SwiGLUMLP(normed_mlp)
12
- out = h + mlp_out
13
-
14
- Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
15
- through unnormalised residuals at depth, and each sublayer receives a stable,
16
- normalised view of the signal.
17
-
18
- Two independent RMSNorm instances are used — one before attention, one before
19
- MLP. They learn different scalings because they precede layers with different
20
- dynamic ranges. Sharing them would be wrong.
21
-
22
- torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
23
- subtraction, is faster than LayerNorm, and proved more stable at scale.
24
- """
25
-
26
- import torch
27
- import torch.nn as nn
28
-
29
- from .attention.shram import SHRAMHybridLayer
30
- from .cache.shram_layer_cache import ShramLayerCache
31
- from .configuration import ShramConfig
32
- from .mlp import SwiGLUMLP
33
-
34
-
35
- class DecoderLayer(nn.Module):
36
- """A single pre-norm SHRAM decoder block.
37
-
38
- Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
39
- independent RMSNorm instances on each sublayer input.
40
-
41
- Args:
42
- config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
43
- in addition to the fields required by SHRAMHybridLayer and
44
- SwiGLUMLP.
45
- """
46
-
47
- def __init__(self, config: ShramConfig) -> None:
48
- super().__init__()
49
- self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
50
- self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
51
- self.attention = SHRAMHybridLayer(config)
52
- self.mlp = SwiGLUMLP(config)
53
-
54
- def forward(
55
- self,
56
- x: torch.Tensor,
57
- position_ids: torch.Tensor,
58
- active_mask: torch.Tensor,
59
- cache: ShramLayerCache | None = None,
60
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
61
- """Apply one decoder block to the input.
62
-
63
- Args:
64
- x: Input of shape (batch, seq_len, hidden_size).
65
- position_ids: Authoritative positions of shape (batch, seq_len).
66
- active_mask: Current-chunk active mask of shape (batch, seq_len),
67
- where True means the token is semantically live. Forwarded
68
- unchanged to the hybrid attention layer.
69
- cache: Optional per-layer SHRAM cache passed through to the hybrid
70
- attention layer unchanged.
71
-
72
- Returns:
73
- output: Tensor of shape (batch, seq_len, hidden_size).
74
- load_balance_loss: Scalar sparse-path load-balance loss propagated
75
- from SHRAMHybridLayer.
76
- max_vio: Detached scalar routing-imbalance summary. Passed through
77
- unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
78
- """
79
- attn_out, load_balance_loss, max_vio = self.attention(
80
- hidden_states=self.attn_norm(x),
81
- position_ids=position_ids,
82
- active_mask=active_mask,
83
- cache=cache,
84
- )
85
- hidden_states = x + attn_out
86
- output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
87
- return output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/huggingface.py DELETED
@@ -1,506 +0,0 @@
1
- """HuggingFace causal-LM wrapper for SHRAM.
2
-
3
- ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
4
- It owns token embedding lookup, LM-head projection, wrapper-level next-token
5
- cross-entropy loss, config-controlled tied embeddings, and generation/cache
6
- orchestration at the wrapper boundary.
7
-
8
- The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
9
- hidden states together with current position IDs, a current active mask, and an
10
- optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
11
- or causal-LM loss.
12
-
13
- HuggingFace generation reaches this wrapper with two different tensor
14
- conventions:
15
-
16
- - ``position_ids`` is a current-step tensor. GenerationMixin updates the total
17
- sequence state between steps, then slices position-bearing tensors back down
18
- before calling ``forward()``.
19
- - ``attention_mask`` is a full 2D mask over the total sequence so far. This
20
- wrapper slices its recent chunk to produce the current semantic liveness mask
21
- expected by the backbone.
22
-
23
- Generation-created caches are handled in ``_prepare_cache_for_generation``.
24
- That hook ensures HuggingFace generation uses ShramCache rather than a generic
25
- dynamic cache. The direct ``forward()`` path does not silently create caches;
26
- when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
27
- """
28
-
29
- from dataclasses import dataclass
30
- from typing import Any
31
-
32
- import torch
33
- import torch.nn as nn
34
- from transformers import GenerationMixin, PreTrainedModel
35
- from transformers.cache_utils import Cache
36
- from transformers.generation.configuration_utils import GenerationMode
37
- from transformers.modeling_outputs import CausalLMOutputWithPast
38
-
39
- from .cache.shram_cache import ShramCache
40
- from .configuration import ShramConfig
41
- from .model import ShramModel
42
-
43
-
44
@dataclass
class ShramCausalLMOutput(CausalLMOutputWithPast):
    """SHRAM causal-LM wrapper output.

    This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
    Dataclass inheritance is sufficient here: all standard causal-LM fields and
    ModelOutput behavior are inherited from the parent, and this subclass adds
    only the SHRAM-specific wrapper outputs.
    """

    # Scalar sparse-path load-balance loss propagated from the backbone
    # (summed over decoder layers); None when the backbone did not report one.
    load_balance_loss: torch.FloatTensor | None = None
    # Detached scalar routing-imbalance summary propagated from the backbone
    # (per the backbone contract, the max across decoder layers).
    max_vio: torch.FloatTensor | None = None
56
-
57
-
58
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-facing causal language model wrapper for SHRAM.

    Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
    tied embedding configuration, and generation/cache boundary behavior.
    Delegates all transformer computation to ``ShramModel``.

    Args:
        config: SHRAM model configuration.
    """

    config_class = ShramConfig
    base_model_prefix = "model"
    _no_split_modules = ["DecoderLayer"]
    supports_gradient_checkpointing = True

    def __init__(self, config: ShramConfig) -> None:
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = ShramModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self._configure_tied_embeddings()
        self.post_init()

    def _configure_tied_embeddings(self) -> None:
        """Apply config-controlled tied embedding behavior on this instance."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
            self._tied_weights_keys = {
                "lm_head.weight": "embed_tokens.weight",
            }
        else:
            self._tied_weights_keys = {}

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding matrix."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding matrix."""
        self.embed_tokens = value
        self._configure_tied_embeddings()

    def get_output_embeddings(self) -> nn.Linear:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, value: nn.Linear) -> None:
        """Replace the LM head."""
        self.lm_head = value
        self._configure_tied_embeddings()

    def _build_shram_cache(
        self,
        batch_size: int,
        device: torch.device,
    ) -> ShramCache:
        """Construct a fresh top-level SHRAM cache."""
        return ShramCache(
            num_hidden_layers=self.config.num_hidden_layers,
            sliding_window=self.config.window_size,
            num_local_heads=self.config.num_sliding_window_heads,
            local_head_dim=self.config.head_dim,
            num_mosrah_heads=self.config.num_mosrah_heads,
            # NOTE(review): divides by num_selected_heads while sizing the
            # MoSRAH head dim — verify this is intended rather than
            # num_mosrah_heads.
            mosrah_head_dim=self.config.hidden_size // self.config.num_selected_heads,
            batch_size=batch_size,
            device=device,
        )

    def _validate_generation_cache_request(
        self,
        generation_config: Any,
        model_kwargs: dict[str, Any],
        generation_mode: GenerationMode,
    ) -> None:
        """Validate SHRAM's generation-side cache policy."""
        if generation_mode in {
            GenerationMode.ASSISTED_GENERATION,
            GenerationMode.CONTRASTIVE_SEARCH,
        }:
            raise NotImplementedError(
                "ShramForCausalLM does not currently support assisted generation "
                "or contrastive search because ShramCache does not support crop()."
            )

        user_defined_cache = model_kwargs.get("past_key_values")
        if user_defined_cache is not None:
            if generation_config.cache_implementation is not None:
                raise ValueError(
                    "Passing both `cache_implementation` and `past_key_values` "
                    "is unsupported. Please use only one."
                )
            if isinstance(user_defined_cache, tuple):
                raise ValueError(
                    "Passing a tuple of `past_key_values` is not supported. "
                    "Please use a `ShramCache` instance."
                )
            if not isinstance(user_defined_cache, ShramCache):
                raise TypeError(
                    "ShramForCausalLM requires `past_key_values` to be a "
                    "`ShramCache` instance."
                )

        if (
            user_defined_cache is None
            and generation_config.use_cache
            and generation_config.cache_implementation is not None
        ):
            raise ValueError(
                "ShramForCausalLM does not support `cache_implementation`. "
                "Generation-created caches must be `ShramCache` objects."
            )

    def _prepare_cache_for_generation(
        self,
        generation_config: Any,
        model_kwargs: dict[str, Any],
        generation_mode: GenerationMode,
        batch_size: int,
        max_cache_length: int,
    ) -> None:
        """Ensure HuggingFace generation uses ShramCache.

        This is the SHRAM-specific generation hook. The rest of the default
        generation plumbing is kept intact as much as possible.

        Args:
            generation_config: Active generation configuration.
            model_kwargs: Generation kwargs, updated in place.
            generation_mode: HuggingFace generation mode.
            batch_size: Effective generation batch size.
            max_cache_length: Requested cache length. Accepted but unused here.
        """
        self._validate_generation_cache_request(
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            generation_mode=generation_mode,
        )

        if model_kwargs.get("past_key_values") is not None:
            return

        if not generation_config.use_cache:
            return

        # Beam search / multi-return generation expands the batch dimension;
        # size the cache for the expanded batch up front.
        num_repeats = max(
            generation_config.num_beams,
            generation_config.num_return_sequences,
        )
        model_kwargs["past_key_values"] = self._build_shram_cache(
            batch_size=batch_size * num_repeats,
            device=self.embed_tokens.weight.device,
        )

    def _reorder_cache(
        self,
        past_key_values: Cache,
        beam_idx: torch.Tensor,
    ) -> Cache:
        """Reorder the cache in place for beam search."""
        past_key_values.reorder_cache(beam_idx)
        return past_key_values

    def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
        """Validate token IDs at the wrapper boundary."""
        if input_ids.ndim != 2:
            raise ValueError("input_ids must have shape (batch, seq_len).")
        if input_ids.shape[1] == 0:
            raise ValueError("input_ids sequence length must be nonzero.")
        if input_ids.dtype != torch.long:
            raise TypeError("input_ids must be a long int tensor.")

    def _validate_attention_mask(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> None:
        """Validate the full-sequence attention mask."""
        if attention_mask is None:
            return
        if attention_mask.ndim != 2:
            raise ValueError("attention_mask must have shape (batch, total_seq_len).")
        if attention_mask.shape[0] != input_ids.shape[0]:
            raise ValueError("attention_mask batch dimension must match input_ids.")
        if attention_mask.shape[1] < input_ids.shape[1]:
            raise ValueError(
                "attention_mask must be at least as long as the current input_ids chunk."
            )

    def _validate_position_ids(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor | None,
    ) -> None:
        """Validate current-step position IDs."""
        if position_ids is None:
            return
        if position_ids.ndim != 2:
            raise ValueError("position_ids must have shape (batch, seq_len).")
        if position_ids.shape != input_ids.shape:
            raise ValueError(
                "position_ids must match the current input_ids shape exactly."
            )
        # Fix: check position_ids' dtype, not input_ids' (input_ids dtype is
        # already validated separately in _validate_input_ids).
        if position_ids.dtype != torch.long:
            raise TypeError("position_ids must be a long tensor.")

    def _validate_labels(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None,
    ) -> None:
        """Validate label shape at the wrapper boundary."""
        if labels is None:
            return
        if labels.ndim != 2:
            raise ValueError("labels must have shape (batch, seq_len).")
        if labels.shape != input_ids.shape:
            raise ValueError("labels must have the same shape as input_ids.")
        # Fix: check labels' dtype, not input_ids' (copy-paste defect in the
        # original; input_ids dtype is validated in _validate_input_ids).
        if labels.dtype != torch.long:
            raise TypeError("labels must be a long tensor.")

    def _validate_cache_inputs(
        self,
        use_cache: bool,
        past_key_values: Cache | None,
    ) -> None:
        """Validate cache policy for direct wrapper calls."""
        if use_cache:
            if past_key_values is None:
                raise ValueError(
                    "use_cache=True requires an explicit ShramCache. During "
                    "generate(), HuggingFace should supply this through "
                    "_prepare_cache_for_generation()."
                )
            if not isinstance(past_key_values, ShramCache):
                raise TypeError(
                    "past_key_values must be a ShramCache when use_cache=True."
                )
            return

        if past_key_values is not None:
            raise ValueError("past_key_values was provided while use_cache=False.")

    def _validate_position_sources(
        self,
        use_cache: bool,
        attention_mask: torch.Tensor | None,
        position_ids: torch.Tensor | None,
    ) -> None:
        """Validate that cached forward has a truthful source of positions."""
        if use_cache and attention_mask is None and position_ids is None:
            raise ValueError(
                "Cached forward requires either position_ids or attention_mask."
            )

    def _validate_hf_boundary(
        self,
        output_attentions: bool | None,
        return_dict: bool | None,
        inputs_embeds: torch.Tensor | None,
        cache_position: torch.Tensor | None,
        extra_kwargs: dict[str, Any],
    ) -> None:
        """Validate unsupported HuggingFace-facing wrapper inputs.

        ``cache_position`` is accepted but deliberately unchecked here —
        presumably because HF generation always supplies it and the wrapper
        derives positions itself; TODO(review) confirm this is intentional.
        """
        if output_attentions:
            raise NotImplementedError(
                "ShramForCausalLM does not expose output_attentions."
            )
        if return_dict is False:
            raise ValueError(
                "return_dict=False is not supported. "
                "ShramForCausalLM always returns ShramCausalLMOutput."
            )
        if inputs_embeds is not None:
            raise ValueError(
                "inputs_embeds is not supported at the SHRAM wrapper boundary. "
                "Pass input_ids instead."
            )
        if extra_kwargs:
            unsupported = ", ".join(sorted(extra_kwargs))
            raise TypeError(
                f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
            )

    def _standardize_full_attention_mask(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.BoolTensor:
        """Return a concrete full-sequence boolean attention mask."""
        if attention_mask is None:
            return torch.ones_like(input_ids, dtype=torch.bool)
        return attention_mask.to(dtype=torch.bool)

    def _resolve_current_position_ids(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor | None,
        full_attention_mask: torch.BoolTensor,
    ) -> torch.LongTensor:
        """Resolve concrete current-step position IDs for the backbone."""
        if position_ids is not None:
            return position_ids.to(dtype=torch.long)

        # Derive positions from the mask: live tokens get consecutive
        # positions, padded tokens are forced to 0, then slice the current
        # chunk off the end.
        full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
        full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
        current_length = input_ids.shape[1]
        return full_position_ids[:, -current_length:]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs: Any,
    ) -> ShramCausalLMOutput:
        """Run the SHRAM causal language model wrapper.

        Args:
            input_ids: Current token IDs of shape ``(batch, seq_len)``.
            attention_mask: Optional full 2D mask of shape
                ``(batch, total_seq_len)``. The wrapper slices its recent chunk
                to produce the current semantic liveness mask expected by the
                backbone.
            position_ids: Optional current-step position IDs of shape
                ``(batch, seq_len)``. In ordinary HuggingFace generation this is
                already the current-step tensor when it reaches ``forward()``.
            past_key_values: Optional SHRAM cache. Required when
                ``use_cache=True``.
            use_cache: Whether to use and return a cache. Defaults to
                ``config.use_cache``.
            output_hidden_states: Whether to return backbone hidden states.
                Defaults to ``config.output_hidden_states``.
            labels: Optional target token IDs of shape ``(batch, seq_len)``.
            return_dict: Must be ``True`` or ``None``.
            **kwargs: Unsupported HuggingFace kwargs fail explicitly.

        Returns:
            ``ShramCausalLMOutput`` with:
                - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
                - ``loss`` when labels are provided,
                - ``past_key_values`` as the active ``ShramCache`` or ``None``,
                - ``hidden_states`` when requested,
                - ``load_balance_loss`` from the backbone,
                - detached ``max_vio`` from the backbone.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        inputs_embeds = kwargs.pop("inputs_embeds", None)
        output_attentions = kwargs.pop("output_attentions", None)
        cache_position = kwargs.pop("cache_position", None)

        # ------------------------------------------------------------------
        # Validation zone.
        #
        # The wrapper boundary is where HuggingFace-facing inputs are judged
        # for truthfulness before any internal work begins. These checks are
        # intentionally front-loaded so the core logic below can assume one
        # coherent interpretation of the call rather than defensively checking
        # shapes, cache policy, or unsupported HF knobs at the point of use.
        # This keeps the main sequence readable while ensuring invalid states
        # fail before they can silently contaminate backbone execution.
        # ------------------------------------------------------------------
        self._validate_input_ids(input_ids)
        self._validate_attention_mask(input_ids, attention_mask)
        self._validate_position_ids(input_ids, position_ids)
        self._validate_labels(input_ids, labels)
        self._validate_cache_inputs(use_cache, past_key_values)
        self._validate_position_sources(use_cache, attention_mask, position_ids)
        self._validate_hf_boundary(
            output_attentions=output_attentions,
            return_dict=return_dict,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            extra_kwargs=kwargs,
        )

        # ------------------------------------------------------------------
        # Standardization zone.
        #
        # HuggingFace and SHRAM use different boundary conventions: generation
        # carries a full-sequence 2D attention mask, while the SHRAM backbone
        # wants a current-step active mask and concrete current position IDs.
        # This zone collapses those wrapper-facing conventions into one valid
        # backbone-facing state. After this point the core no longer reasons
        # about optional or ambiguous input forms; it works only with concrete
        # tensors whose semantics are already fixed.
        # ------------------------------------------------------------------
        full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        current_length: int = input_ids.shape[1]
        current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
        current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
            input_ids=input_ids,
            position_ids=position_ids,
            full_attention_mask=full_attention_mask,
        )
        shram_cache: ShramCache | None = past_key_values if use_cache else None

        # ------------------------------------------------------------------
        # Core wrapper responsibilities.
        #
        # The wrapper's primary job is kept visible here: convert token IDs to
        # embeddings, delegate transformer computation to ShramModel, project
        # hidden states back to vocabulary logits, optionally compute the
        # wrapper-level shifted next-token loss, and return the HuggingFace-
        # facing output object. The backbone remains responsible only for
        # transformer semantics; token/vocabulary/loss concerns stay here.
        # ------------------------------------------------------------------
        token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
        backbone_outputs = self.model(
            inputs_embeds=token_embeddings,
            position_ids=current_position_ids,
            active_mask=current_active_mask,
            cache=shram_cache,
            output_hidden_states=output_hidden_states,
        )

        logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])

        loss: torch.FloatTensor | None = None
        if labels is not None:
            # Standard next-token shift: position t predicts token t+1.
            # F.cross_entropy ignores label -100 by default, matching the HF
            # padding convention.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return ShramCausalLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=backbone_outputs["past_key_values"],
            hidden_states=backbone_outputs["hidden_states"],
            load_balance_loss=backbone_outputs["load_balance_loss"],
            max_vio=backbone_outputs["max_vio"],
        )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/mlp.py DELETED
@@ -1,52 +0,0 @@
1
- """SwiGLU feed-forward sublayer.
2
-
3
- SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
4
- element-wise against a separate up-projection:
5
-
6
- output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))
7
-
8
- The gating mechanism gives the network more expressive control over which features
9
- to propagate than a plain two-matrix FFN. It requires three weight matrices instead
10
- of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
11
- typical of two-matrix FFNs — the total parameter count remains comparable.
12
-
13
- SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
14
- — a fixed architectural choice.
15
- """
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- from transformers import PretrainedConfig
21
-
22
-
23
class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward sublayer.

    Three-matrix gated FFN as used in Llama 3:

        output = W_down(SiLU(W_gate(x)) * W_up(x))

    All projections are bias-free. SiLU as the gate activation is what makes
    this SwiGLU specifically — it is an architectural constant, not a knob.

    Args:
        config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transformation.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Output tensor of shape (batch, seq_len, hidden_size).
        """
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/model.py DELETED
@@ -1,132 +0,0 @@
1
- """Transformer backbone for Shram.
2
-
3
- ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
4
- by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
5
- representations. It has no knowledge of tokens, vocabulary, generation, or the
6
- HuggingFace causal-LM wrapper contract.
7
-
8
- Keeping the embedding out of the backbone is the correct convention and makes
9
- the backbone genuinely modality-agnostic. The token interface — embedding lookup,
10
- LM head, weight tying, and generation-facing naming conventions — belongs on the
11
- task wrapper (ShramForCausalLM), which is the only class that knows this
12
- backbone is being used for language modelling.
13
-
14
- The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
15
- each sublayer normalises its own input, leaving the residual stream itself
16
- unnormalised. After many layers of accumulated residuals, that stream arrives at
17
- the top with uncontrolled magnitude. The final norm brings it to a well-scaled
18
- state before any projection. Without it, the LM head would receive signals of
19
- arbitrary scale.
20
-
21
- Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
22
- corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
23
- same top-level ShramCache object in the output dict. If None is provided, no
24
- caching occurs.
25
-
26
- Returns a plain dict with keys:
27
- - "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
28
- - "past_key_values": the ShramCache object passed in, or None
29
- - "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
30
- - "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
31
- - "max_vio": detached scalar maximum routing-imbalance across all decoder layers
32
- """
33
-
34
- import torch
35
- import torch.nn as nn
36
-
37
- from .cache.shram_cache import ShramCache
38
- from .configuration import ShramConfig
39
- from .decoder_layer import DecoderLayer
40
-
41
-
42
class ShramModel(nn.Module):
    """Pure transformer backbone: decoder stack followed by a final RMSNorm.

    Consumes pre-embedded hidden states of shape (batch, seq_len, hidden_size)
    and produces contextual representations of the same shape. Token lookup,
    vocabulary projection, and causal-LM lifecycle concerns belong to the
    wrapper, not to this module.

    Positions are consumed inside each attention layer (RoPE), so nothing
    positional is ever added to the residual stream here.

    Args:
        config: Model configuration. Must be a ``ShramConfig`` instance.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.config = config
        blocks = [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)
        # Final norm: the pre-norm stack leaves the residual stream itself
        # unnormalised, so it must be rescaled before any head projection.
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: ShramCache | None = None,
        output_hidden_states: bool = False,
    ) -> dict:
        """Run the decoder stack over a batch of pre-embedded sequences.

        Args:
            inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
            position_ids: Absolute positions of shape (batch, seq_len). Required;
                positions are never inferred from cache state here.
            active_mask: Current-chunk active mask of shape (batch, seq_len),
                True meaning the token is semantically live. Forwarded
                unchanged to every decoder layer.
            cache: Optional top-level ShramCache. Each decoder layer receives
                its own entry from ``cache.layers``; the top-level object is
                mutated in place and echoed back in the result.
            output_hidden_states: When True, the result includes the tuple
                (inputs_embeds, layer_0_out, ..., layer_N_out), each captured
                before the final norm.

        Returns:
            Plain dict with keys:
                - ``"last_hidden_state"``: normed output,
                  shape (batch, seq_len, hidden_size).
                - ``"past_key_values"``: the cache passed in, or None.
                - ``"hidden_states"``: per-depth activations when requested
                  (pre-final-norm residual stream), else None.
                - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
                  load-balance losses.
                - ``"max_vio"``: detached scalar maximum routing-imbalance
                  across all decoder layers; zero means perfectly balanced.
        """
        stream = inputs_embeds
        collected = (stream,) if output_hidden_states else None
        lb_total = inputs_embeds.new_zeros(())
        worst_vio = inputs_embeds.new_zeros(())

        for idx, block in enumerate(self.layers):
            block_cache = cache.layers[idx] if cache is not None else None
            stream, block_lb, block_vio = block(
                stream,
                position_ids,
                active_mask,
                cache=block_cache,
            )
            lb_total = lb_total + block_lb
            worst_vio = torch.maximum(worst_vio, block_vio)
            if output_hidden_states:
                collected = collected + (stream,)

        return {
            "last_hidden_state": self.norm(stream),
            "past_key_values": cache,
            "hidden_states": collected,
            "load_balance_loss": lb_total,
            "max_vio": worst_vio,
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/rope.py DELETED
@@ -1,291 +0,0 @@
1
- """Rotary Position Embeddings (RoPE).
2
-
3
- RoPE encodes position in the *relationship* between query and key vectors. When the
4
- attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
5
- a score that depends only on the relative distance — not on absolute positions.
6
-
7
- Two modes are supported:
8
-
9
- default Standard RoPE with base frequency b. Each dimension pair d is assigned
10
- frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
11
- scaling A_rope = 1.
12
-
13
- yarn YaRN frequency interpolation for long-context extrapolation (Peng et al.,
14
- "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
15
- §A.2). Three frequency regimes:
16
- - Low-frequency dimensions (r < α): fully interpolated by scale s.
17
- These dimensions have long wavelengths relative to the training window
18
- and must be compressed to avoid out-of-distribution positions.
19
- - High-frequency dimensions (r > β): left unchanged. Short-wavelength
20
- dimensions already encode relative position accurately at any scale.
21
- - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
22
- Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
23
- standard RoPE.
24
-
25
- Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
26
- parameters — no shared instance, no config reading. See Unit 5.A design decisions.
27
-
28
- Cache sharing: all instances with identical parameters share one cos/sin table via a
29
- class-level registry. The first instance that needs a particular (parameters, seq_len,
30
- device, dtype) combination builds the table; all subsequent instances reference it
31
- directly. This avoids redundant builds across the num_hidden_layers instances that
32
- share the same parametrisation.
33
- """
34
-
35
- import math
36
-
37
- import torch
38
- import torch.nn as nn
39
-
40
-
41
- # ---------------------------------------------------------------------------
42
- # Rotation helper
43
- # ---------------------------------------------------------------------------
44
-
45
- def _rotate_half(x: torch.Tensor) -> torch.Tensor:
46
- """Apply the 90° rotation used in the RoPE update formula.
47
-
48
- Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
49
- Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
50
- on each consecutive pair of dimensions, matching the block-diagonal operator
51
- R^u_{Θ,p} in the paper.
52
- """
53
- d = x.shape[-1] // 2
54
- x1, x2 = x[..., :d], x[..., d:]
55
- return torch.cat([-x2, x1], dim=-1)
56
-
57
-
58
- # ---------------------------------------------------------------------------
59
- # RotaryEmbedding
60
- # ---------------------------------------------------------------------------
61
-
62
- class RotaryEmbedding(nn.Module):
63
- """Rotary Position Embeddings with explicit mode and parameter control.
64
-
65
- Each caller constructs its own instance with the exact parameters it needs.
66
- h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
67
- config object is read inside this module.
68
-
69
- The cos/sin cache is built lazily on the first forward call and extended
70
- automatically when a longer sequence is encountered. Instances with identical
71
- parameters share one cache via the class-level ``_cache`` registry,
72
- avoiding redundant computation across decoder layers.
73
-
74
- Args:
75
- mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
76
- head_dim: Per-head embedding dimension ``u``. Must be even.
77
- theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
78
- initial_seq_length: ``C_train`` — context length the model was trained at.
79
- Required for ``mode="yarn"``.
80
- dilation: Scale factor ``s = C_target / C_train`` — how much the context
81
- window is extended beyond training length. Required for ``mode="yarn"``.
82
- When ``dilation=1.0``, YaRN reduces to standard RoPE.
83
- alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
84
- interpolated. Required for ``mode="yarn"``.
85
- beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
86
- unchanged. Required for ``mode="yarn"``.
87
- device: Optional device for initial buffer placement.
88
-
89
- Raises:
90
- NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
91
- ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
92
- ``dilation``, ``alpha``, ``beta`` are absent.
93
- """
94
-
95
- # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
96
- # Shared across all RotaryEmbedding instances in the process. Keys include device
97
- # and dtype so that tables built on different devices or in different precisions
98
- # are stored independently.
99
- _cache: dict = {}
100
-
101
- def __init__(
102
- self,
103
- mode: str,
104
- head_dim: int,
105
- theta: float,
106
- initial_seq_length: int | None = None,
107
- dilation: float | None = None,
108
- alpha: float | None = None,
109
- beta: float | None = None,
110
- device: torch.device | None = None,
111
- ) -> None:
112
- super().__init__()
113
-
114
- self._validate_mode(mode)
115
- self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
116
- self.mode = mode
117
-
118
- # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
119
- # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
120
- # so rotation_freqs has head_dim/2 entries.
121
- d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
122
- base_freqs = 1.0 / (theta ** (d_index / head_dim)) # θ_d = b^{-2d/u}
123
-
124
- if mode == "default":
125
- rotation_freqs = base_freqs
126
- self.attention_scaling: float = 1.0
127
-
128
- else: # yarn
129
- s = dilation
130
-
131
- # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
132
- # function to classify each dimension into one of three regimes.
133
- normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
134
-
135
- # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
136
- # linear blend between α and β.
137
- blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
138
-
139
- # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
140
- rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
141
-
142
- # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
143
- self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
144
-
145
- # freq_key uniquely identifies the parameter set that produced rotation_freqs.
146
- # Used as the primary component of the cache registry key.
147
- if mode == "default":
148
- self._freq_key: tuple = ("default", head_dim, float(theta))
149
- else:
150
- self._freq_key = (
151
- "yarn", head_dim, float(theta),
152
- int(initial_seq_length), float(dilation),
153
- float(alpha), float(beta),
154
- )
155
-
156
- # rotation_freqs is a non-persistent buffer so it moves with the model across
157
- # devices via .to() / .cuda() without appearing in saved checkpoints.
158
- # It is stored per-instance rather than in the shared cache because it is
159
- # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
160
- # it is used to build. The meaningful sharing win is on those tables.
161
- self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)
162
-
163
- # Cache tensors are plain instance attributes (not registered buffers) so that
164
- # sharing across identically-parametrised instances survives .to() calls.
165
- # Registered buffers are copied on device move; plain attributes are aliased,
166
- # preserving the shared-tensor identity that the cache design depends on.
167
- self._cos_cached: torch.Tensor | None = None
168
- self._sin_cached: torch.Tensor | None = None
169
-
170
- # ---------------------------------------------------------------------------
171
- # Validation helpers
172
- # ---------------------------------------------------------------------------
173
-
174
- @staticmethod
175
- def _validate_mode(mode: str) -> None:
176
- """Raise NotImplementedError if mode is not a supported value."""
177
- if mode not in {"default", "yarn"}:
178
- raise NotImplementedError(
179
- f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
180
- )
181
-
182
- @staticmethod
183
- def _validate_yarn_params(
184
- mode: str,
185
- initial_seq_length: int | None,
186
- dilation: float | None,
187
- alpha: float | None,
188
- beta: float | None,
189
- ) -> None:
190
- """Raise ValueError if mode='yarn' and any required parameter is absent."""
191
- if mode != "yarn":
192
- return
193
- missing = [
194
- name for name, val in [
195
- ("initial_seq_length", initial_seq_length),
196
- ("dilation", dilation),
197
- ("alpha", alpha),
198
- ("beta", beta),
199
- ]
200
- if val is None
201
- ]
202
- if missing:
203
- raise ValueError(f"mode='yarn' requires {missing}.")
204
-
205
- # ---------------------------------------------------------------------------
206
- # Cache management
207
- # ---------------------------------------------------------------------------
208
-
209
- def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
210
- """Build the cos/sin table to cover positions [0, seq_len).
211
-
212
- Checks the class-level registry first. If a table already exists for this
213
- exact (parameters, seq_len, device, dtype) combination it is reused directly;
214
- otherwise it is computed and stored. The instance attributes are pointed at
215
- the registry entry so that all layers sharing the same parametrisation
216
- reference the same tensor.
217
- """
218
- cache_key = (self._freq_key, seq_len, str(device), str(dtype))
219
-
220
- if cache_key not in RotaryEmbedding._cache:
221
- positions = torch.arange(seq_len, device=device, dtype=torch.float32)
222
- # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
223
- freqs = torch.outer(
224
- positions,
225
- self.rotation_freqs.to(device=device, dtype=torch.float32),
226
- )
227
- angle_embedding = torch.cat((freqs, freqs), dim=-1)
228
- RotaryEmbedding._cache[cache_key] = (
229
- angle_embedding.cos().to(dtype),
230
- angle_embedding.sin().to(dtype),
231
- )
232
-
233
- self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
234
-
235
- def forward(
236
- self,
237
- q: torch.Tensor,
238
- k: torch.Tensor,
239
- position_ids: torch.Tensor,
240
- ) -> tuple[torch.Tensor, torch.Tensor, float]:
241
- """Apply rotary embeddings to query and key tensors.
242
-
243
- The cos/sin cache is extended lazily when position_ids reference positions
244
- beyond its current length, or when the device or dtype has changed.
245
-
246
- ``position_ids`` may be any integer tensor shape. Its values are valid
247
- position indices into the cos/sin cache:
248
-
249
- - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
250
- - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).
251
-
252
- When q/k have head dimensions absent from position_ids, broadcast dimensions
253
- are inserted automatically at dim 1.
254
-
255
- Args:
256
- q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
257
- k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
258
- position_ids: Integer positions of shape (batch, *pos_dims).
259
-
260
- Returns:
261
- Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
262
- 1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
263
- apply to attention logits before softmax.
264
- """
265
- seq_len = int(position_ids.max().item()) + 1
266
-
267
- # The cache is valid when it exists, covers all positions referenced by
268
- # position_ids, and matches q's dtype and device. Each condition is named
269
- # separately so the rebuild trigger is readable rather than a compound predicate.
270
- cache_missing = self._cos_cached is None
271
- cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
272
- wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
273
- wrong_device = not cache_missing and self._cos_cached.device != q.device
274
-
275
- if cache_missing or cache_too_short or wrong_dtype or wrong_device:
276
- self._extend_cache(seq_len, device=q.device, dtype=q.dtype)
277
-
278
- cos = self._cos_cached[position_ids]
279
- sin = self._sin_cached[position_ids]
280
-
281
- # Insert broadcast dimensions for any head axes present in q/k but absent
282
- # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
283
- # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
284
- while cos.ndim < q.ndim:
285
- cos = cos.unsqueeze(1)
286
- sin = sin.unsqueeze(1)
287
-
288
- q_rotated = q * cos + _rotate_half(q) * sin
289
- k_rotated = k * cos + _rotate_half(k) * sin
290
-
291
- return q_rotated, k_rotated, self.attention_scaling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
architecture_core/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
architecture_core/tokenizer_config.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "add_prefix_space": false,
3
- "backend": "tokenizers",
4
- "bos_token": "<|endoftext|>",
5
- "eos_token": "<|endoftext|>",
6
- "errors": "replace",
7
- "is_local": false,
8
- "model_max_length": 1000000000000000019884624838656,
9
- "pad_token": "<|padding|>",
10
- "tokenizer_class": "GPTNeoXTokenizerFast",
11
- "trim_offsets": true,
12
- "unk_token": "<|endoftext|>"
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/__init__.py DELETED
File without changes
attention/bottlenecked_ensemble_attention.py DELETED
@@ -1,252 +0,0 @@
1
- """Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.
2
-
3
- BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
4
- It consumes packed expert-choice tensors, a supplied position tensor, an active
5
- token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
6
- same packed expert-choice space expected by later unpacking.
7
-
8
- BEA does not compute positions and does not choose packed-position semantics.
9
- Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
10
- (K̃) and raw values (V) into the sparse cache and attends against the
11
- accumulated cached state returned by that cache.
12
- """
13
-
14
- import math
15
-
16
- import torch
17
- import torch.nn as nn
18
- from torch.nn.attention.flex_attention import create_block_mask, flex_attention
19
-
20
- from ..configuration import ShramConfig
21
- from ..cache.mosrah_cache import MoSRAHCache
22
- from ..rope import RotaryEmbedding
23
-
24
-
25
- class BottleneckedEnsembleAttention(nn.Module):
26
- """
27
- Packed expert-choice attention operator for the MoSRAH sparse path.
28
- Operates per-head independently on an ensemble of tokens.
29
- FlexAttention saves flops on dead tokens.
30
-
31
- Architectural properties:
32
- - consumes packed expert-choice tensors of shape (B, L, T, d)
33
- - uses independent per-head Q/K/V/O projection parameters
34
- - applies YaRN-capable RoPE using supplied position_ids
35
- - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
36
- - uses a fast fused attention path
37
- - returns outputs in the same packed expert-choice space (B, L, T, d)
38
-
39
- Args:
40
- config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
41
- `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
42
- `inference_sequence_length`, `alpha`, and `beta`.
43
- """
44
-
45
- def __init__(self, config: ShramConfig) -> None:
46
- super().__init__()
47
-
48
- self.hidden_size = config.hidden_size
49
- self.num_heads = config.num_mosrah_heads
50
- self.head_dim = config.head_dim
51
-
52
- # Independent per-head projections. No cross-head parameter sharing.
53
- self.q_proj = nn.Parameter(
54
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
55
- )
56
- self.k_proj = nn.Parameter(
57
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
58
- )
59
- self.v_proj = nn.Parameter(
60
- torch.empty(self.num_heads, self.hidden_size, self.head_dim)
61
- )
62
- self.o_proj = nn.Parameter(
63
- torch.empty(self.num_heads, self.head_dim, self.hidden_size)
64
- )
65
-
66
- self._reset_parameters()
67
-
68
- # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
69
- # this unit only consumes it. In training modes, dilation will be 1.0 and so
70
- # no yarn dilation occurs.
71
- self.rope = RotaryEmbedding(
72
- mode="yarn",
73
- head_dim=self.head_dim,
74
- theta=config.mosrah_rope_theta,
75
- initial_seq_length=config.training_sequence_length,
76
- dilation=config.scale,
77
- alpha=config.alpha,
78
- beta=config.beta,
79
- )
80
-
81
- def forward(
82
- self,
83
- packed_embeddings: torch.Tensor,
84
- position_ids: torch.Tensor,
85
- active_mask: torch.Tensor,
86
- cache: MoSRAHCache | None = None,
87
- ) -> torch.Tensor:
88
- """Apply BEA to packed expert-choice tensors.
89
-
90
- Args:
91
- packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
92
- position_ids: Supplied packed positions of shape (B, L, T).
93
- active_mask: Boolean active-token mask of shape (B, L, T).
94
- cache: Optional layer-local MoSRAH cache.
95
-
96
- Returns:
97
- Packed expert-choice output tensor of shape (B, L, T, d).
98
- """
99
- batch_size, _, query_length, _ = packed_embeddings.shape
100
- self._validate_tensor_shape(packed_embeddings)
101
- self._validate_position_shape(packed_embeddings, position_ids)
102
- self._validate_active_mask_shape(packed_embeddings, active_mask)
103
-
104
- # Independent per-head projections:
105
- # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
106
- query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
107
- key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
108
- value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)
109
-
110
- rotated_query_states, rotated_key_states, attention_scaling = self.rope(
111
- query_states,
112
- key_states,
113
- position_ids,
114
- )
115
-
116
- if cache is not None:
117
- # In cached execution, the current query tensor uses local tensor rows
118
- # 0..Q-1, but the key tensor returned by the cache is the full accumulated
119
- # packed sequence for each (batch, head) slot. The only additional data
120
- # needed to align those two views is the pre-update cached prefix length.
121
- # which will indicate how many queries were processed before now.
122
- num_tokens_processed = cache.get_heads_lengths().clone()
123
- key_states, value_states, key_active_mask = cache.update(
124
- rotated_key_states,
125
- value_states,
126
- active_mask,
127
- )
128
- else:
129
- num_tokens_processed = torch.zeros(
130
- batch_size,
131
- self.num_heads,
132
- dtype=torch.long,
133
- device=packed_embeddings.device,
134
- )
135
- key_states = rotated_key_states
136
- key_active_mask = active_mask
137
-
138
- block_mask = self._make_block_mask(
139
- query_active_mask=active_mask,
140
- key_active_mask=key_active_mask,
141
- num_tokens_processed=num_tokens_processed,
142
- query_length=query_length,
143
- key_length=key_states.shape[2],
144
- device=packed_embeddings.device,
145
- )
146
- attended_states = flex_attention(
147
- rotated_query_states,
148
- key_states,
149
- value_states,
150
- block_mask=block_mask,
151
- scale=attention_scaling / math.sqrt(self.head_dim),
152
- )
153
-
154
- # Project back to model width:
155
- # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
156
- return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)
157
-
158
- def _reset_parameters(self) -> None:
159
- """Initialize per-head projection weights."""
160
- for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
161
- nn.init.xavier_uniform_(weight)
162
-
163
- def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
164
- """Validate the local packed-embedding shape contract required by BEA."""
165
- if packed_embeddings.shape[1] != self.num_heads:
166
- raise ValueError(
167
- f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
168
- f"got {packed_embeddings.shape[1]}."
169
- )
170
-
171
- if packed_embeddings.shape[-1] != self.hidden_size:
172
- raise ValueError(
173
- f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
174
- f"got {packed_embeddings.shape[-1]}."
175
- )
176
-
177
- def _validate_position_shape(
178
- self,
179
- packed_embeddings: torch.Tensor,
180
- position_ids: torch.Tensor,
181
- ) -> None:
182
- """Validate the supplied packed-position tensor shape."""
183
- if position_ids.shape != packed_embeddings.shape[:3]:
184
- raise ValueError(
185
- f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
186
- f"got {tuple(position_ids.shape)}."
187
- )
188
-
189
- def _validate_active_mask_shape(
190
- self,
191
- packed_embeddings: torch.Tensor,
192
- active_mask: torch.Tensor,
193
- ) -> None:
194
- """Validate the supplied active-token mask shape."""
195
- if active_mask.shape != packed_embeddings.shape[:3]:
196
- raise ValueError(
197
- f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
198
- f"got {tuple(active_mask.shape)}."
199
- )
200
-
201
- def _make_block_mask(
202
- self,
203
- query_active_mask: torch.Tensor,
204
- key_active_mask: torch.Tensor,
205
- num_tokens_processed: torch.Tensor,
206
- query_length: int,
207
- key_length: int,
208
- device: torch.device,
209
- ):
210
- """Create the packed-sequence causal mask for FlexAttention.
211
-
212
- At the root, causality is still triangular. The only nuance is cached
213
- execution: query rows are indexed locally as 0..Q-1 inside the current
214
- query tensor, but the key tensor may already contain a cached prefix for
215
- that (batch, head) slot. The causal horizon for query tensor row q is
216
- therefore:
217
-
218
- cached_prefix_lengths[b, h] + q
219
-
220
- Query and key activity masks are then composed with that triangular rule
221
- so FlexAttention can skip padded query rows and ignore inactive key slots.
222
- """
223
- batch_size, num_heads, _ = query_active_mask.shape
224
-
225
- # Build the per-(batch, head, query_row) triangular horizon from a simple
226
- # arange over query rows plus the cached prefix lengths for each slot.
227
- relative_query_positions = torch.arange(
228
- query_length,
229
- device=device,
230
- dtype=torch.long,
231
- ).view(1, 1, query_length)
232
- causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions
233
-
234
- def packed_causal_mask(
235
- batch_idx: torch.Tensor,
236
- head_idx: torch.Tensor,
237
- query_idx: torch.Tensor,
238
- key_idx: torch.Tensor,
239
- ) -> torch.Tensor:
240
- query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
241
- key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
242
- is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
243
- return query_is_active & key_is_active & is_causal
244
-
245
- return create_block_mask(
246
- packed_causal_mask,
247
- B=batch_size,
248
- H=num_heads,
249
- Q_LEN=query_length,
250
- KV_LEN=key_length,
251
- device=device,
252
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/expert_packing.py DELETED
@@ -1,335 +0,0 @@
1
- """Expert packing and unpacking for the MoSRAH path.
2
-
3
- This module implements the low-level token-choice -> expert-choice -> token-choice
4
- conversion boundary specified in the paper. The externally visible behavior is fixed:
5
-
6
- - setup_packing() prepares the auxiliary ordering data.
7
- - pack_experts() converts routed token-choice state into packed expert-choice state.
8
- - unpack_experts() restores token-choice ordering afterward.
9
-
10
- Stable sort is a correctness requirement. It preserves causal ordering inside each
11
- expert bucket, which is the foundation on which BEA's later triangular causal mask
12
- is correct.
13
-
14
- pack_experts() returns two distinct masks that serve different roles and must not
15
- be interchanged:
16
-
17
- - unpacking_mask: marks every packed slot that contains a routed token copy,
18
- live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
19
- so its reshape invariant holds regardless of outer token liveness.
20
- - active_mask: marks only the packed slots whose source token was semantically
21
- live. This is what BEA consumes for attention gating. Dead outer tokens must
22
- not influence sparse attention outputs.
23
- """
24
-
25
- import torch
26
-
27
-
28
- # ---------------------------------------------------------------------------
29
- # Setup
30
- # ---------------------------------------------------------------------------
31
-
32
def setup_packing(
    selected_heads: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the ordering helpers consumed by pack_experts / unpack_experts.

    Routing emits token-choice selections I of shape (B, N, K): each token names
    the K experts it routes to. Packing needs those same routed copies grouped by
    expert instead, so this step flattens (N, K) into a single token-major axis H
    and computes a *stable* argsort permutation Pi over the expert indices in H.

    Stability is the load-bearing property: applying Pi groups routed copies by
    expert while keeping each expert bucket in original token order, which is the
    precondition for the downstream triangular causal mask to be correct.

    Args:
        selected_heads: Routed token-choice head selections I of shape (B, N, K).

    Returns:
        Tuple of:
        - flattened_selected_heads: H of shape (B, N*K)
        - permutation: stable expert-major permutation Pi of shape (B, N*K)
        - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    routed_copies = sequence_length * num_selected_heads

    # Token-major flattening of the routed selections: H has shape (B, N*K).
    flattened_selected_heads = selected_heads.reshape(batch_size, routed_copies)

    # Stable argsort groups copies expert-major while preserving the original
    # token order inside every expert bucket (causality prerequisite).
    permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)

    # Argsorting a permutation yields its inverse.
    inverse_permutation = torch.argsort(permutation, dim=-1)

    return flattened_selected_heads, permutation, inverse_permutation
66
-
67
-
68
- # ---------------------------------------------------------------------------
69
- # Packing
70
- # ---------------------------------------------------------------------------
71
-
72
def pack_experts(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    selected_heads: torch.Tensor,
    num_experts: int,
    flattened_selected_heads: torch.Tensor,
    permutation: torch.Tensor,
    outer_active_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pack token-choice hidden states into expert-choice padded form.

    The paper's packing path has two jobs:

    1. Convert routed token-choice copies into expert-major order.
    2. Materialize that expert-major order into a padded tensor layout BEA can consume.

    The routed hidden-state copies are not stored explicitly in token-choice form.
    Instead, the same token hidden state is conceptually copied once per selected expert.
    The packing step reconstructs those copies by expanding local source-token indices,
    reordering those indices with Pi, then gathering hidden states, positions, and outer
    liveness in that packed order. All three are carried through the same expert-major
    rearrangement so they remain aligned in the packed frame.

    Packed positions are sourced from the authoritative upstream position_ids tensor
    rather than synthesized locally from arange(N). This preserves advanced positions
    correctly during cached inference while leaving training/full-sequence behavior
    unchanged when position_ids is the ordinary sequential token positions.

    Args:
        hidden_states: Token-choice hidden states x of shape (B, N, d).
        position_ids: Authoritative upstream token positions J of shape (B, N).
        selected_heads: Routed head selections I of shape (B, N, K).
        num_experts: Total number of experts L.
        flattened_selected_heads: H from setup_packing(), shape (B, N*K).
        permutation: Pi from setup_packing(), shape (B, N*K).
        outer_active_mask: Current-chunk active mask of shape (B, N), where True
            means the token is semantically live. Dead tokens do not become
            semantically active in the packed sparse representation.

    Returns:
        Tuple of:
        - packed_hidden_states: x' of shape (B, L, T, d)
        - packed_positions: J' of shape (B, L, T)
        - unpacking_mask: of shape (B, L, T). True where a slot contains any
          routed token copy, live or dead. Always has exactly B*N*K True entries.
          Pass this to unpack_experts — not active_mask.
        - active_mask: of shape (B, L, T). True only where a slot contains a
          copy of a live outer token. Pass this to BEA for attention gating.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    _, _, num_selected_heads = selected_heads.shape

    # -----------------------------------------------------------------------
    # Reconstruct routed local source-token indices in token-choice order.
    #
    # The internal arange(N) is no longer the packed position tensor. It is only
    # the local source-row index object used to gather from the current chunk
    # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
    # token-major routed-copy order: entry n*K + k is the source row of token n's
    # k-th routed copy.
    # -----------------------------------------------------------------------
    source_token_indices = torch.arange(
        sequence_length,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, sequence_length, 1).expand(
        batch_size,
        sequence_length,
        num_selected_heads,
    )
    flattened_source_indices = source_token_indices.reshape(
        batch_size,
        sequence_length * num_selected_heads,
    )

    # -----------------------------------------------------------------------
    # Reorder source-token indices into expert-major order.
    #
    # Applying Pi yields the local source-token rows in the packed expert-major
    # order required by the paper. Those same reordered source indices are then
    # used to gather hidden states, authoritative upstream positions, and outer
    # liveness so all three remain aligned under the exact same packing
    # transformation. All four tensors below have shape (B, N*K) (hidden states
    # carry an extra trailing d).
    # -----------------------------------------------------------------------
    sorted_source_indices = flattened_source_indices.gather(
        dim=1,
        index=permutation,
    )
    sorted_hidden_states = hidden_states.gather(
        dim=1,
        index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
    )
    sorted_positions = position_ids.gather(
        dim=1,
        index=sorted_source_indices,
    )
    sorted_active_mask = outer_active_mask.gather(
        dim=1,
        index=sorted_source_indices,
    )

    # -----------------------------------------------------------------------
    # Count how many routed copies land in each expert bucket.
    #
    # S[b, l] is the number of routed token copies assigned to expert l in batch b.
    # T is the maximum such count across all batches and experts; it determines the
    # padded expert-length dimension of the packed representation. The .item()
    # call forces a host sync; T is needed as a Python int to size the buffers.
    # -----------------------------------------------------------------------
    tokens_per_expert = _bincount_rows(
        values=flattened_selected_heads,
        num_bins=num_experts,
    )
    max_tokens_per_expert = int(tokens_per_expert.max().item())

    # -----------------------------------------------------------------------
    # Construct the active-token mask M.
    #
    # Each expert bucket is left-justified: if S[b, l] = s, then slots
    # t = 0, ..., s-1 are active and all later slots are padding. The resulting
    # mask therefore both identifies real packed tokens and enforces left-justified
    # packing. This is the unpacking_mask — it marks slot occupancy regardless of
    # outer token liveness, and always has exactly B*N*K True entries.
    # -----------------------------------------------------------------------
    time_axis = torch.arange(
        max_tokens_per_expert,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, 1, max_tokens_per_expert)
    unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)

    # -----------------------------------------------------------------------
    # Materialize the padded packed tensors.
    #
    # The packed hidden states x', packed original-token positions J', and packed
    # active-token mask are allocated as zero-filled tensors. Active entries are
    # then written into those buffers in the expert-major order established above.
    # Padding remains zero / inactive.
    #
    # Correctness of the masked writes below hinges on an ordering invariant:
    # boolean-mask assignment fills True positions in row-major (B, L, T) order,
    # which — because buckets are left-justified — is exactly the expert-major
    # order of the sorted_* tensors. Do not replace these writes with an
    # unordered scatter.
    # -----------------------------------------------------------------------
    packed_hidden_states = hidden_states.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        hidden_dim,
    )
    packed_positions = position_ids.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
    )
    active_mask = torch.zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        dtype=torch.bool,
        device=hidden_states.device,
    )

    packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
    packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
    active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)

    return packed_hidden_states, packed_positions, unpacking_mask, active_mask
233
-
234
-
235
- # ---------------------------------------------------------------------------
236
- # Unpacking
237
- # ---------------------------------------------------------------------------
238
-
239
def unpack_experts(
    expert_outputs: torch.Tensor,
    selected_heads: torch.Tensor,
    unpacking_mask: torch.Tensor,
    inverse_permutation: torch.Tensor,
) -> torch.Tensor:
    """Restore token-choice ordering from BEA expert-choice output.

    This inverts the packing path on occupied slots only. Padding never
    participates: boolean indexing with unpacking_mask first extracts exactly
    the real routed-token copies in expert-major order, Pi^{-1} then restores
    token-choice order, and a final reshape recovers (B, N, K, d).

    unpacking_mask — not active_mask — is required here. Copies of dead outer
    tokens still occupy slots and must be un-scattered for the inverse
    permutation to line up; unpacking_mask always has exactly B*N*K True
    entries, which is precisely what the (B, N*K, d) reshape needs.

    Args:
        expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
        selected_heads: Routed head selections I of shape (B, N, K).
        unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
            occupied packed slots regardless of outer token liveness.
        inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).

    Returns:
        Restored token-choice tensor y_tilde of shape (B, N, K, d).
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    hidden_dim = expert_outputs.shape[-1]
    routed_copies = sequence_length * num_selected_heads

    # Boolean indexing walks True positions in row-major order, yielding the
    # B*N*K occupied rows in expert-major order; the reshape is thus exact.
    occupied_outputs = expert_outputs[unpacking_mask]
    expert_major_outputs = occupied_outputs.reshape(
        batch_size,
        routed_copies,
        hidden_dim,
    )

    # Undo the packing permutation to recover token-major routed-copy order.
    restore_index = inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim)
    token_major_outputs = expert_major_outputs.gather(dim=1, index=restore_index)

    return token_major_outputs.reshape(
        batch_size,
        sequence_length,
        num_selected_heads,
        hidden_dim,
    )
288
-
289
-
290
- # ---------------------------------------------------------------------------
291
- # Helpers
292
- # ---------------------------------------------------------------------------
293
-
294
- def _bincount_rows(
295
- values: torch.Tensor,
296
- num_bins: int,
297
- ) -> torch.Tensor:
298
- """Count per-row integer occurrences for a 2D tensor.
299
-
300
- torch.bincount operates on a flat 1D vector, but the packing algorithm needs
301
- one bincount per batch row. The trick used here is to shift each row into its
302
- own disjoint bin range before flattening:
303
-
304
- row 0 uses bins [0, ..., num_bins - 1]
305
- row 1 uses bins [num_bins, ..., 2*num_bins - 1]
306
- row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
307
- ...
308
-
309
- After that shift, one global torch.bincount produces all row-local counts at
310
- once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.
311
-
312
- This is a vectorized implementation detail only; externally visible behavior
313
- remains exactly the paper's S tensor of per-batch per-expert token counts.
314
-
315
- Args:
316
- values: Integer tensor of shape (B, M) with entries in [0, num_bins).
317
- num_bins: Number of bins.
318
-
319
- Returns:
320
- Counts tensor of shape (B, num_bins).
321
- """
322
- batch_size = values.shape[0]
323
-
324
- row_offsets = torch.arange(
325
- batch_size,
326
- device=values.device,
327
- dtype=values.dtype,
328
- ).unsqueeze(1) * num_bins
329
- shifted_values = values + row_offsets
330
-
331
- counts = torch.bincount(
332
- shifted_values.reshape(-1),
333
- minlength=batch_size * num_bins,
334
- )
335
- return counts.reshape(batch_size, num_bins)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/load_balance_loss.py DELETED
@@ -1,88 +0,0 @@
1
- """Auxiliary-loss-free load balancing operator for MoSRAH routing.
2
-
3
- This module implements the custom autograd Function H(b, f) described in the paper's
4
- Implementation Concerns section. The operator bridges two requirements that are in
5
- tension: it must behave like a standard auxiliary loss (scalar output, scalable via
6
- multiplication) so that existing training loops remain compatible, while simultaneously
7
- implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
8
- path through the router weights.
9
-
10
- The resolution is a custom backward pass. The forward emits the load balance imbalance
11
- as a scalar loss. The backward, instead of differentiating that scalar with respect to
12
- its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
13
- then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
14
- correction without a standalone SGD update step.
15
-
16
- Paper ref: Appendix A.Implementation Concerns.
17
- """
18
-
19
- import torch
20
-
21
-
22
class LoadBalanceLoss(torch.autograd.Function):
    """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.

    Forward computes the load balance imbalance:

        L_load_balance = H(b, f) = sum_l | f_l - 1/L |

    Backward emits a bias-correction gradient to expert_bias:

        grad_b = L_grad * sign(f_l - 1/L)

    expert_bias (b) is a forward input purely so PyTorch registers it as a
    computation-graph node and routes gradients through it. routing_freqs (f)
    receives no gradient: it originates from the discrete TopK selection, which
    has no gradient, so inventing one here would be mathematically incorrect.

    Paper ref: Appendix A.Implementation Concerns.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        expert_bias: torch.Tensor,
        routing_freqs: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the load balance loss.

        Args:
            ctx: Autograd context for saving state needed in backward.
            expert_bias: Learned per-head bias b, shape (L,). Present only so
                PyTorch tracks it as a node needing a gradient.
            routing_freqs: Realized routing frequency f_l per head, shape (L,).
                Derived from the discrete TopK selection — not differentiable.

        Returns:
            Scalar loss equal to sum_l |f_l - 1/L|.
        """
        num_heads = expert_bias.shape[0]
        # Deviation from the perfectly balanced target 1/L: positive means the
        # head is overloaded, negative underloaded. Saved for backward, where
        # sign(deviation) sets the direction of the bias correction.
        deviation = routing_freqs - 1.0 / num_heads
        ctx.save_for_backward(deviation)
        return deviation.abs().sum()

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Emit the DeepSeek-style bias-correction gradient.

        Args:
            ctx: Autograd context carrying the deviation saved in forward.
            grad_output: Incoming gradient L_grad (scalar). Any loss rescaling
                chosen by the training loop arrives here and scales grad_b, so
                the correction magnitude tracks the consumer's loss weight.

        Returns:
            Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,),
            and None for routing_freqs (no gradient for discrete routing).
        """
        (deviation,) = ctx.saved_tensors
        return grad_output * torch.sign(deviation), None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/mosrah.py DELETED
@@ -1,140 +0,0 @@
1
- """Full MoSRAH sparse path for SHRAM.
2
-
3
- This module coordinates the routed sparse attention path used inside the SHRAM
4
- hybrid attention layer. The underlying mechanics already live in verified
5
- subunits. The responsibility here is to connect those subunits without
6
- corrupting their bridge contracts.
7
-
8
- In particular, this path must preserve three architectural distinctions:
9
-
10
- - selected head indices are not routing probabilities
11
- - packed position semantics are chosen before BEA, not inside it
12
- - weighted reduction must consume the router's unbiased renormalized
13
- probabilities after token-choice order has been restored
14
- """
15
-
16
- import torch
17
- from torch import nn
18
-
19
- from src.shram.model.cache.mosrah_cache import MoSRAHCache
20
- from src.shram.model.configuration import ShramConfig
21
- from src.shram.model.attention.bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
22
- from src.shram.model.attention.expert_packing import (
23
- pack_experts,
24
- setup_packing,
25
- unpack_experts,
26
- )
27
- from src.shram.model.attention.router import MoSRAHRouter
28
- from src.shram.model.attention.positions_converter import SparseMoSRAHPositions
29
-
30
-
31
class MoSRAHLayer(nn.Module):
    """Full routed sparse attention path for SHRAM.

    The MoSRAH path consumes model-space hidden states together with
    authoritative per-token positions and returns the model-space sparse-path
    contribution, the router's load-balance loss, and the router's MaxVio
    routing-imbalance scalar.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_experts = config.num_mosrah_heads

        self.router = MoSRAHRouter(config)
        self.positions = SparseMoSRAHPositions(config)
        self.bea = BottleneckedEnsembleAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full MoSRAH sparse path.

        Args:
            hidden_states: Model-space hidden states x of shape (B, N, d).
            position_ids: Authoritative per-token positions of shape (B, N).
            active_mask: Current-chunk active mask of shape (B, N); True means
                the token is semantically live. Forwarded to the router so dead
                tokens are excluded from routing statistics, and to
                pack_experts so dead outer tokens never become semantically
                active packed entries.
            cache: Optional layer-local MoSRAH cache. None means uncached
                execution; otherwise the layer-local cache instance.

        Returns:
            sparse_output: Model-space sparse-path output of shape (B, N, d).
            load_balance_loss: Scalar router load-balance loss.
            max_vio: Detached scalar routing-imbalance summary, passed through
                unchanged from the router; see MoSRAHRouter for semantics.
        """
        # Stage 1 — token-choice model space into packed expert-choice state.
        # Routing picks each token's experts and reserves the unbiased
        # renormalized probabilities needed for the final reduction. Packing
        # yields both the unpacking mask (slot occupancy, always B*N*K True
        # entries) and the packed active mask (live slots only); attention
        # gating must use the packed form, so it is bound to its own name.
        selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
            hidden_states, active_mask
        )
        flattened_heads, permutation, inverse_permutation = setup_packing(
            selected_heads
        )
        (
            packed_hidden_states,
            packed_positions,
            unpacking_mask,
            packed_active_mask,
        ) = pack_experts(
            hidden_states=hidden_states,
            position_ids=position_ids,
            selected_heads=selected_heads,
            num_experts=self.num_experts,
            flattened_selected_heads=flattened_heads,
            permutation=permutation,
            outer_active_mask=active_mask,
        )

        # Stage 2 — sparse attention runs entirely in the packed expert-choice
        # frame, so the RoPE semantics must be decided in that frame too: the
        # position layer chooses between packed original-token positions and
        # packed local-slot positions before BEA consumes anything. BEA owns
        # the layer-local sparse cache directly.
        bea_positions = self.positions(
            packed_positions=packed_positions,
            cache=cache,
        )
        packed_outputs = self.bea(
            packed_embeddings=packed_hidden_states,
            position_ids=bea_positions,
            active_mask=packed_active_mask,
            cache=cache,
        )

        # Stage 3 — restore token-choice meaning first, then collapse the K
        # routed copies into model space. routing_probs live in token-choice
        # space while BEA output is expert-choice packed, so the weighted
        # reduction must follow unpacking and must use the router's unbiased
        # renormalized probabilities, never biased selection scores.
        token_choice_outputs = unpack_experts(
            expert_outputs=packed_outputs,
            selected_heads=selected_heads,
            unpacking_mask=unpacking_mask,
            inverse_permutation=inverse_permutation,
        )
        copy_weights = routing_probs.unsqueeze(-1)
        final_output = (token_choice_outputs * copy_weights).sum(dim=2)

        return final_output, load_balance_loss, max_vio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/positions_converter.py DELETED
@@ -1,105 +0,0 @@
1
- """Position computation for the MoSRAH sparse path.
2
-
3
- This layer computes the packed position tensor P consumed by BEA.
4
-
5
- - In main-sequence mode, P is the packed original-token position tensor from the
6
- packing path.
7
- - In semantic-sequence mode, P is a per-expert local sequence over the packed
8
- expert-choice layout, optionally offset by the current sparse-cache occupancies
9
- during cached inference.
10
- """
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from src.shram.model.configuration import ShramConfig
16
- from src.shram.model.cache.mosrah_cache import MoSRAHCache
17
-
18
-
19
class SparseMoSRAHPositions(nn.Module):
    """Compute the packed RoPE position tensor for the MoSRAH sparse path.

    This layer lives in the packed expert-choice frame consumed by BEA. The
    incoming packed_positions tensor is always the packed original-token
    position tensor produced by the packing path; the configured rope_mode
    decides whether it is forwarded as-is or replaced by a semantic
    local-slot sequence (optionally offset by sparse-cache occupancies).
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.rope_mode = config.rope_mode

    def forward(
        self,
        packed_positions: torch.Tensor,
        cache: MoSRAHCache | None,
    ) -> torch.Tensor:
        """Compute the packed position tensor P consumed by BEA.

        Args:
            packed_positions: Packed original-token positions J' of shape (B, L, T).
            cache: Optional layer-local MoSRAH cache. In semantic-sequence mode
                with a cache present, current per-head occupancies offset the
                local packed sequence.

        Returns:
            Packed position tensor P of shape (B, L, T).

        Raises:
            NotImplementedError: If the configured rope_mode is unknown.
        """
        mode = self.rope_mode
        if mode == "main_sequence":
            return self._main_sequence_positions(packed_positions)
        if mode == "semantic_sequence":
            return self._semantic_sequence_positions(packed_positions, cache)
        raise NotImplementedError(
            f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
        )

    def _main_sequence_positions(
        self,
        packed_positions: torch.Tensor,
    ) -> torch.Tensor:
        """Pass packed original-token positions through untouched."""
        return packed_positions

    def _semantic_sequence_positions(
        self,
        packed_positions: torch.Tensor,
        cache: MoSRAHCache | None,
    ) -> torch.Tensor:
        """Compute semantic-sequence packed positions in expert-choice space.

        Without a sparse cache, semantic positions are simply the local packed
        sequence 0, 1, 2, ... along the expert-local T axis. With a cache, that
        local sequence continues from the per-(batch, expert) occupancies
        reported by get_heads_lengths() instead of restarting at zero.
        """
        batch_size, num_experts, packed_length = packed_positions.shape

        # Local slot sequence 0..T-1, broadcast over (batch, expert).
        local_positions = torch.arange(
            packed_length,
            device=packed_positions.device,
            dtype=packed_positions.dtype,
        ).expand(batch_size, num_experts, packed_length)

        if cache is None:
            return local_positions

        # Cached semantic mode: shift each (batch, expert) row by its current
        # sparse-cache occupancy so positions continue across chunks.
        occupancy_offsets = cache.get_heads_lengths().to(
            device=packed_positions.device,
            dtype=packed_positions.dtype,
        ).unsqueeze(-1)

        return local_positions + occupancy_offsets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/router.py DELETED
@@ -1,162 +0,0 @@
1
- """Token-choice router for the MoSRAH sparse attention path.
2
-
3
- This module implements the routing mechanism described in Appendix A.Routing of the
4
- paper. Given an input hidden state x, the router produces two outputs used downstream:
5
-
6
- - selected_heads (I): which K of the L available expert heads each token routes to,
7
- determined by TopK over biased routing scores.
8
- - routing_probs (P): the weights used for the weighted output reduction, gathered from
9
- *unbiased* routing scores at the selected indices and renormalized. The learned expert
10
- bias b must not influence P.
11
-
12
- This separation is architecturally critical: expert_bias drives selection (and thus load
13
- balancing) but does not corrupt the gradient path from the output through routing_probs
14
- back to the routing projection weights.
15
-
16
- The router also computes and returns the load balance loss via the LoadBalanceLoss custom
17
- autograd operator (see load_balance_loss.py). This loss is a scalar that the training
18
- loop can weight and add to the language modeling loss.
19
-
20
- The router additionally computes and returns MaxVio, a detached scalar summarising
21
- routing imbalance for the current forward pass:
22
-
23
- MaxVio = L · max_l(f_l − 1/L)
24
-
25
- where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
26
- target. MaxVio is a monitoring quantity only; it never contributes gradients.
27
-
28
- Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
29
- """
30
-
31
- import torch
32
- import torch.nn as nn
33
- import torch.nn.functional as F
34
-
35
- from src.shram.model.configuration import ShramConfig
36
- from src.shram.model.attention.load_balance_loss import LoadBalanceLoss
37
-
38
-
39
class MoSRAHRouter(nn.Module):
    """Token-choice router for MoSRAH sparse attention.

    Each input token independently selects K of the L available expert heads. Selection
    is driven by biased routing scores to enable load balancing, but the routing
    probabilities used for output reduction are computed from unbiased scores so that
    the expert bias does not interfere with the gradient path to the router weights.

    The routing projection W_r has no bias term — the paper specifies xW_r with no
    additional projection bias. The only bias-like parameter is expert_bias (b), which
    has an entirely separate role and update mechanism.

    Args:
        config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
            (L), and ``num_selected_heads`` (K).
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_mosrah_heads = config.num_mosrah_heads
        self.num_selected_heads = config.num_selected_heads

        # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
        self.routing_projection = nn.Linear(
            config.hidden_size, config.num_mosrah_heads, bias=False
        )

        # b: learned per-head bias for load balancing. Initialized to zero so that all
        # heads start with equal selection probability. Updated by the main optimizer
        # via the LoadBalanceLoss custom backward.
        self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))

    def forward(
        self, x: torch.Tensor, active_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route input tokens to K expert heads each and compute routing probabilities.

        Args:
            x: Input hidden states of shape (batch, seq_len, hidden_size).
            active_mask: Current-chunk active mask of shape (batch, seq_len), where
                True means the token is semantically live. Dead tokens do not
                contribute to routing frequencies, load_balance_loss, or max_vio.

        Returns:
            selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
                Each token's K selected head indices, determined by TopK on biased scores.
            routing_probs: Routing probabilities P of shape (batch, seq_len,
                num_selected_heads). Gathered from unbiased scores at selected_heads
                indices and renormalized to sum to 1 per token.
            load_balance_loss: Scalar load balance imbalance loss for this forward pass.
                Training loop scales this by a weight and adds it to the main loss.
            max_vio: Detached scalar routing-imbalance summary for this forward pass.
                Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
                never contributes gradients.
        """
        B, N, _ = x.shape
        L = self.num_mosrah_heads
        K = self.num_selected_heads

        # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
        # compute routing_probs — expert_bias must not influence them.
        logits = self.routing_projection(x)  # (B, N, L)
        routing_scores = F.softmax(logits, dim=-1)  # R, (B, N, L)

        # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
        # selection. expert_bias is added to logits before softmax so that the bias
        # shifts selection probability without rescaling the unbiased distribution.
        biased_routing_scores = F.softmax(  # R̂, (B, N, L)
            logits + self.expert_bias, dim=-1
        )

        # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
        selected_heads = biased_routing_scores.topk(K, dim=-1).indices

        # Routing probabilities P: gathered from unbiased R at selected_heads indices,
        # then renormalized so they sum to 1 per token. Gathering from routing_scores
        # (not biased_routing_scores) is the invariant that keeps the gradient path from
        # the output back to the router weights free of expert_bias influence.
        gathered = routing_scores.gather(dim=-1, index=selected_heads)  # V, (B, N, K)
        routing_probs = gathered / gathered.sum(dim=-1, keepdim=True)  # P, (B, N, K)

        # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
        # assigned to each head. Dead tokens are excluded by zeroing their rows in the
        # assignment mask before reduction. Normalization uses the active assignment
        # count so frequencies remain properly scaled regardless of how many tokens
        # are live in this chunk.
        assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
        assignment_mask.scatter_(-1, selected_heads, 1.0)
        active_assignments = assignment_mask * active_mask.unsqueeze(-1)
        # Guard the degenerate all-dead chunk: without the clamp, a fully masked
        # chunk yields a zero denominator and NaN routing_freqs, which would then
        # poison load_balance_loss and max_vio. With the clamp, frequencies (and
        # therefore both outputs) are exactly zero in that case instead.
        num_active_assignments = (active_mask.sum() * K).clamp(min=1)
        routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments  # f, (L,)

        # Load balance loss via custom autograd. expert_bias is an input so PyTorch
        # registers it as a graph node; the custom backward writes the DeepSeek-style
        # correction gradient to expert_bias.grad for the optimizer to consume.
        load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)

        # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
        # not contribute gradients under any circumstance, so it is detached at the
        # point of computation rather than left to callers to detach.
        max_vio = self._compute_max_vio(routing_freqs, L)

        return selected_heads, routing_probs, load_balance_loss, max_vio

    @staticmethod
    def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Compute the MaxVio routing-imbalance scalar.

        MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
        head l and 1/L is the perfectly balanced target. A value of zero indicates
        perfect balance; a value of 1 means the most overloaded head received exactly
        double its fair share.

        The result is detached from the autograd graph — MaxVio is a monitoring scalar
        and must never contribute gradients to any parameter.

        Args:
            routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
            num_heads: Total number of MoSRAH heads L.

        Returns:
            Detached scalar MaxVio tensor.
        """
        return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
attention/shram.py DELETED
@@ -1,116 +0,0 @@
1
- """SHRAM hybrid attention layer.
2
-
3
- This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
4
- used at one decoder attention slot in SHRAM.
5
-
6
- The local sliding-window path and the MoSRAH sparse path are already verified
7
- independently. The responsibility here is therefore not to introduce new
8
- attention logic, but to preserve the bridge contracts between them: both paths
9
- must consume the same input hidden state, each path must receive the sub-cache
10
- it actually owns, the two model-space outputs must be summed directly, and the
11
- sparse-path load-balance loss must remain visible to the caller.
12
- """
13
-
14
- import torch
15
- from torch import nn
16
-
17
- from src.shram.model.cache.shram_layer_cache import ShramLayerCache
18
- from src.shram.model.configuration import ShramConfig
19
- from src.shram.model.attention.sliding_window_attention import SlidingWindowAttention
20
- from src.shram.model.attention.mosrah import MoSRAHLayer
21
-
22
-
23
class SHRAMHybridLayer(nn.Module):
    """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.

    The local sliding-window path preserves nearby-token behavior, while the
    sparse path is the theorem-facing MoSRAH routed attention path. Both read
    the same model-space hidden state and emit model-space tensors, so the
    hybrid composition is a direct elementwise sum.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.local_attention = SlidingWindowAttention(config)
        self.sparse_attention = MoSRAHLayer(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: ShramLayerCache | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the SHRAM hybrid attention layer.

        Args:
            hidden_states: Input hidden states of shape (B, N, d).
            position_ids: Authoritative token positions of shape (B, N).
            active_mask: Current-chunk active mask of shape (B, N); True marks
                a semantically live token. Forwarded unchanged to both paths.
            cache: Optional per-layer SHRAM cache. When present, its owned
                sliding-window and MoSRAH sub-caches are handed directly to
                the corresponding attention paths.

        Returns:
            hybrid_output: Model-space hybrid attention output of shape (B, N, d).
            load_balance_loss: Scalar sparse-path load-balance loss.
            max_vio: Detached scalar routing-imbalance summary, passed through
                unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.

        Raises:
            ValueError: If no cache is supplied and any sequence starts at a
                nonzero position — the BEA block-mask construction cannot
                recover per-head offsets without a cache.
        """
        # Without a cache there is no record of per-head offsets, so sequences
        # must begin at position zero for the block mask to be constructible.
        if cache is None and torch.any(position_ids[:, 0] != 0):
            raise ValueError(
                "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
                "Either provide a matching ShramLayerCache populated by the prefix for "
                "continued decoding, or rebase the uncached sequence to start at 0."
            )

        # Cache dispatch: the layer cache already owns the concrete sub-cache
        # objects each path requires, so forward those exact references rather
        # than reinterpreting ownership here.
        sliding_window_cache = None if cache is None else cache.sliding_window_cache
        mosrah_cache = None if cache is None else cache.mosrah_cache

        # Both paths consume the same model-space hidden state. The local path
        # keeps short-range structure; the sparse path adds the routed
        # long-range contribution and surfaces the load-balance signal.
        local_output = self.local_attention(
            x=hidden_states,
            position_ids=position_ids,
            active_mask=active_mask,
            cache=sliding_window_cache,
        )
        sparse_output, load_balance_loss, max_vio = self.sparse_attention(
            hidden_states=hidden_states,
            position_ids=position_ids,
            active_mask=active_mask,
            cache=mosrah_cache,
        )

        # The composition rule at this boundary is deliberately a plain sum:
        # both sublayers return model-space tensors of matching shape, and no
        # extra mixing logic belongs here.
        return local_output + sparse_output, load_balance_loss, max_vio