smithblack-0 committed on
Commit
e15c0d5
·
verified ·
1 Parent(s): 4fedccb

Update architecture and tokenizer

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. huggingface.py +93 -12
config.json CHANGED
@@ -24,7 +24,7 @@
24
  "rope_mode": "main_sequence",
25
  "tie_word_embeddings": false,
26
  "training_sequence_length": 1024,
27
- "transformers_version": "5.9.0",
28
  "use_cache": true,
29
  "vocab_size": 50277,
30
  "window_size": 128
 
24
  "rope_mode": "main_sequence",
25
  "tie_word_embeddings": false,
26
  "training_sequence_length": 1024,
27
+ "transformers_version": "5.10.1",
28
  "use_cache": true,
29
  "vocab_size": 50277,
30
  "window_size": 128
huggingface.py CHANGED
@@ -1284,6 +1284,13 @@ class ShramCache(Cache):
1284
  layer have materially different update semantics; callers must update sub-caches directly
1285
  via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
1286
 
 
 
 
 
 
 
 
1287
  Args:
1288
  config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
1289
  dimensions are derived from config so that a single source of truth governs
@@ -1310,11 +1317,19 @@ class ShramCache(Cache):
1310
  ]
1311
  super().__init__(layers=layers)
1312
 
 
 
 
 
 
 
 
1313
  # ---------------------------------------------------------------------------
1314
  # Cache — composite-meaningful methods
1315
  # ---------------------------------------------------------------------------
1316
  #
1317
- # reset(): Inherited. Iterates all layer caches and calls reset() on each.
 
1318
  #
1319
  # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
1320
  #
@@ -1322,6 +1337,40 @@ class ShramCache(Cache):
1322
  # Since ShramLayerCache.is_initialized is True from construction, this is True
1323
  # immediately after ShramCache.__init__ returns.
1324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1325
  def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
1326
  """Return the cumulative sequence length for the specified layer.
1327
 
@@ -2191,6 +2240,7 @@ class BottleneckedEnsembleAttention(nn.Module):
2191
  key_length=key_states.shape[2],
2192
  device=packed_embeddings.device,
2193
  )
 
2194
  attended_states = flex_attention(
2195
  rotated_query_states,
2196
  key_states,
@@ -2836,7 +2886,7 @@ class MoSRAHRouter(nn.Module):
2836
  outputs.
2837
  capacity_scalar: Static upper bound on n; used to derive topk k as
2838
  min(tensor.shape[dim], capacity_scalar). Must be a Python int
2839
- for compile compatibility.
2840
 
2841
  Returns:
2842
  Boolean mask of the same shape as tensor.
@@ -4055,19 +4105,49 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4055
  return attention_mask.to(dtype=torch.bool)
4056
 
4057
  def _resolve_current_position_ids(
4058
- self,
4059
- input_ids: torch.Tensor,
4060
- position_ids: torch.Tensor | None,
4061
- full_attention_mask: torch.BoolTensor,
 
4062
  ) -> torch.LongTensor:
4063
- """Resolve concrete current-step position IDs for the backbone."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4064
  if position_ids is not None:
4065
  return position_ids.to(dtype=torch.long)
4066
 
4067
- full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
4068
- full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
4069
  current_length = input_ids.shape[1]
4070
- return full_position_ids[:, -current_length:]
 
 
 
 
 
 
 
 
 
 
 
4071
 
4072
  def forward(
4073
  self,
@@ -4172,12 +4252,13 @@ class ShramForCausalLM(PreTrainedModel, GenerationMixin):
4172
  )
4173
  current_length: int = input_ids.shape[1]
4174
  current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
 
4175
  current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
4176
  input_ids=input_ids,
4177
  position_ids=position_ids,
4178
- full_attention_mask=full_attention_mask,
 
4179
  )
4180
- shram_cache: ShramCache | None = past_key_values if use_cache else None
4181
 
4182
  if shram_cache is None:
4183
  positions_start_sane = torch.all(current_position_ids[:, 0] == 0)
 
1284
  layer have materially different update semantics; callers must update sub-caches directly
1285
  via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
1286
 
1287
+ ShramCache also tracks per-batch cumulative active token counts via
1288
+ ``_active_token_counts``. ``total_active_tokens(active_mask)`` returns the accumulated
1289
+ count before the current step and updates the buffer in-place; the caller uses this as a
1290
+ per-batch position bias for contiguous arange-based position ID resolution. All counter
1291
+ updates are in-place to satisfy CUDAGraph fixed-memory requirements. ``reset()``
1292
+ zeroes the buffer along with all layer caches.
1293
+
1294
  Args:
1295
  config: ShramConfig instance. All layer counts, buffer sizes, and sub-cache
1296
  dimensions are derived from config so that a single source of truth governs
 
1317
  ]
1318
  super().__init__(layers=layers)
1319
 
1320
+ # Active token counter for position ID resolution (Unit 23.B). Pre-allocated
1321
+ # at construction so all updates remain in-place across forward passes,
1322
+ # satisfying CUDAGraph fixed-memory requirements.
1323
+ self._active_token_counts: torch.Tensor = torch.zeros(
1324
+ batch_size, dtype=torch.long, device=device
1325
+ )
1326
+
1327
  # ---------------------------------------------------------------------------
1328
  # Cache — composite-meaningful methods
1329
  # ---------------------------------------------------------------------------
1330
  #
1331
+ # reset(): Overridden. Zeroes _active_token_counts in-place, then delegates to
1332
+ # the inherited implementation to reset all layer caches.
1333
  #
1334
  # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
1335
  #
 
1337
  # Since ShramLayerCache.is_initialized is True from construction, this is True
1338
  # immediately after ShramCache.__init__ returns.
1339
 
1340
+ def total_active_tokens(self, active_mask: torch.BoolTensor) -> torch.Tensor:
1341
+ """Return the per-batch accumulated active token count before this step, then update.
1342
+
1343
+ Reads the current per-batch accumulated count as a position bias for the caller,
1344
+ then increments the internal counter in-place by the number of active tokens in
1345
+ ``active_mask`` for each batch item. The pre-update count is returned so the
1346
+ caller can offset an arange-based position tensor to the correct starting position
1347
+ for this forward pass.
1348
+
1349
+ All updates are in-place to satisfy CUDAGraph fixed-memory requirements. The
1350
+ counter persists across forward passes until ``reset()`` is called.
1351
+
1352
+ Args:
1353
+ active_mask: Boolean mask of shape ``(B, N)`` for the current forward step,
1354
+ where True marks an active (non-padding) token position.
1355
+
1356
+ Returns:
1357
+ Integer tensor of shape ``(B,)`` — the accumulated count before this update.
1358
+ """
1359
+ prior_counts = self._active_token_counts.clone()
1360
+ self._active_token_counts.add_(active_mask.sum(dim=-1))
1361
+ return prior_counts
1362
+
1363
+ def reset(self) -> None:
1364
+ """Clear all layer caches and reset the active token counter.
1365
+
1366
+ Zeroes ``_active_token_counts`` in-place, then delegates to the inherited
1367
+ implementation to reset all ShramLayerCache instances. In-place mutation of
1368
+ the counter is required for CUDAGraph compatibility — the buffer must remain
1369
+ at the same memory address across steps.
1370
+ """
1371
+ self._active_token_counts.zero_()
1372
+ super().reset()
1373
+
1374
  def get_seq_length(self, layer_idx: int = 0) -> int: # type: ignore[override]
1375
  """Return the cumulative sequence length for the specified layer.
1376
 
 
2240
  key_length=key_states.shape[2],
2241
  device=packed_embeddings.device,
2242
  )
2243
+
2244
  attended_states = flex_attention(
2245
  rotated_query_states,
2246
  key_states,
 
2886
  outputs.
2887
  capacity_scalar: Static upper bound on n; used to derive topk k as
2888
  min(tensor.shape[dim], capacity_scalar). Must be a Python int
2889
+ for compile compatibility.
2890
 
2891
  Returns:
2892
  Boolean mask of the same shape as tensor.
 
4105
  return attention_mask.to(dtype=torch.bool)
4106
 
4107
  def _resolve_current_position_ids(
4108
+ self,
4109
+ input_ids: torch.Tensor,
4110
+ position_ids: torch.Tensor | None,
4111
+ current_active_mask: torch.BoolTensor,
4112
+ cache: ShramCache | None,
4113
  ) -> torch.LongTensor:
4114
+ """Resolve concrete current-step position IDs for the backbone.
4115
+
4116
+ Builds a fresh contiguous allocation via arange + per-batch bias. No cumsum
4117
+ or stride-based views are produced; the returned tensor is always a new
4118
+ allocation safe for Inductor tracing at the FlexAttention boundary.
4119
+
4120
+ When a cache is present, ``total_active_tokens()`` provides the per-batch
4121
+ accumulated active token count as a position bias. Uncached calls use a zero
4122
+ bias. In both cases positions are ``bias + arange(current_length)``, with
4123
+ inactive positions masked to 0.
4124
+
4125
+ Args:
4126
+ input_ids: Current token IDs of shape ``(B, N)``.
4127
+ position_ids: Explicit positions if supplied by the caller; returned
4128
+ unchanged (cast to long). Bias computation is skipped entirely.
4129
+ current_active_mask: Boolean mask of shape ``(B, N)`` for the current step.
4130
+ cache: Active ``ShramCache``, or ``None`` for uncached forward passes.
4131
+
4132
+ Returns:
4133
+ Long tensor of shape ``(B, N)`` — position index per token, 0 for inactive.
4134
+ """
4135
  if position_ids is not None:
4136
  return position_ids.to(dtype=torch.long)
4137
 
 
 
4138
  current_length = input_ids.shape[1]
4139
+
4140
+ if cache is not None:
4141
+ position_bias = cache.total_active_tokens(current_active_mask)
4142
+ else:
4143
+ position_bias = torch.zeros(
4144
+ input_ids.shape[0], dtype=torch.long, device=input_ids.device
4145
+ )
4146
+
4147
+ positions = position_bias.unsqueeze(1) + torch.arange(
4148
+ current_length, device=input_ids.device, dtype=torch.long
4149
+ )
4150
+ return positions.masked_fill(~current_active_mask, 0)
4151
 
4152
  def forward(
4153
  self,
 
4252
  )
4253
  current_length: int = input_ids.shape[1]
4254
  current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
4255
+ shram_cache: ShramCache | None = past_key_values if use_cache else None
4256
  current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
4257
  input_ids=input_ids,
4258
  position_ids=position_ids,
4259
+ current_active_mask=current_active_mask,
4260
+ cache=shram_cache,
4261
  )
 
4262
 
4263
  if shram_cache is None:
4264
  positions_start_sane = torch.all(current_position_ids[:, 0] == 0)