smithblack-0 commited on
Commit
0c295ed
·
verified ·
1 Parent(s): 78610c2

Update architecture and tokenizer

Browse files
__attention__sliding_window_attention.py CHANGED
@@ -120,12 +120,15 @@ class SlidingWindowAttention(nn.Module):
120
  # The cache returns the current-step visible local frame, not merely the
121
  # retained next-step cache buffer.
122
  if cache is not None:
123
- k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
 
 
124
  else:
125
- k_full, v_full, full_active_mask = k, v, active_mask
126
 
127
  block_mask = self._make_block_mask(
128
  active_mask=full_active_mask,
 
129
  batch_size=batch_size,
130
  num_heads=self.num_heads,
131
  query_len=query_len,
@@ -182,6 +185,7 @@ class SlidingWindowAttention(nn.Module):
182
  def _make_block_mask(
183
  self,
184
  active_mask: torch.Tensor,
 
185
  batch_size: int,
186
  num_heads: int,
187
  query_len: int,
@@ -191,17 +195,14 @@ class SlidingWindowAttention(nn.Module):
191
  ) -> Any:
192
  """Create the FlexAttention block mask for masked local continuation.
193
 
194
- The returned local frame is chronological in raw buffer order, but dead
195
- positions may remain inside it. Effective local order is therefore
196
- recovered from the active mask itself by taking a cumulative count over
197
- active positions.
198
-
199
- Queries still occupy the tail of the returned frame, so raw buffer order
200
- is used to locate query rows. Semantic active-token positions are then
201
- used to decide causality and sliding-window distance.
202
  """
203
  query_offset = kv_len - query_len
204
- semantic_positions = active_mask.long().cumsum(dim=-1) - 1
205
 
206
  def sliding_window_mask(
207
  batch_idx: torch.Tensor,
@@ -215,11 +216,11 @@ class SlidingWindowAttention(nn.Module):
215
  query_is_active = active_mask[batch_idx, q_abs]
216
  key_is_active = active_mask[batch_idx, kv_idx]
217
 
218
- q_sem = semantic_positions[batch_idx, q_abs]
219
- k_sem = semantic_positions[batch_idx, kv_idx]
220
 
221
- is_causal = k_sem <= q_sem
222
- in_window = (q_sem - k_sem) < window_size
223
 
224
  return query_is_active & key_is_active & is_causal & in_window
225
 
 
120
  # The cache returns the current-step visible local frame, not merely the
121
  # retained next-step cache buffer.
122
  if cache is not None:
123
+ k_full, v_full, full_active_mask, full_positions = cache.update(
124
+ k, v, active_mask, position_ids
125
+ )
126
  else:
127
+ k_full, v_full, full_active_mask, full_positions = k, v, active_mask, position_ids
128
 
129
  block_mask = self._make_block_mask(
130
  active_mask=full_active_mask,
131
+ positions=full_positions,
132
  batch_size=batch_size,
133
  num_heads=self.num_heads,
134
  query_len=query_len,
 
185
  def _make_block_mask(
186
  self,
187
  active_mask: torch.Tensor,
188
+ positions: torch.Tensor,
189
  batch_size: int,
190
  num_heads: int,
191
  query_len: int,
 
195
  ) -> Any:
196
  """Create the FlexAttention block mask for masked local continuation.
197
 
198
+ The returned local frame is chronological in raw buffer order; dead
199
+ positions may remain inside it. Liveness is carried by `active_mask`.
200
+ Causality and window distance are determined from `positions`, which
201
+ holds the absolute sequence position of every slot in the composite
202
+ frame. Using absolute positions rather than a cumsum over the active
203
+ mask eliminates the data-dependent computation that blocks torch.compile.
 
 
204
  """
205
  query_offset = kv_len - query_len
 
206
 
207
  def sliding_window_mask(
208
  batch_idx: torch.Tensor,
 
216
  query_is_active = active_mask[batch_idx, q_abs]
217
  key_is_active = active_mask[batch_idx, kv_idx]
218
 
219
+ q_pos = positions[batch_idx, q_abs]
220
+ k_pos = positions[batch_idx, kv_idx]
221
 
222
+ is_causal = k_pos <= q_pos
223
+ in_window = (q_pos - k_pos) < window_size
224
 
225
  return query_is_active & key_is_active & is_causal & in_window
226
 
__cache__sliding_window_cache.py CHANGED
@@ -92,6 +92,15 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
92
  device=device,
93
  )
94
 
 
 
 
 
 
 
 
 
 
95
  self.is_initialized = True
96
 
97
  # Cumulative count of all token positions presented through update() for
@@ -104,8 +113,9 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
104
  key_states: torch.Tensor,
105
  value_states: torch.Tensor,
106
  active_mask: torch.Tensor,
 
107
  cache_kwargs: dict | None = None,
108
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
  """Return the current-step local frame and retain the next-step window.
110
 
111
  Args:
@@ -115,6 +125,8 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
115
  current chunk.
116
  active_mask: Shape `(B, T_new)` bool. `True` means the
117
  corresponding token position in the current chunk is active.
 
 
118
  cache_kwargs: Present only to satisfy the `CacheLayerMixin`
119
  interface. Unused by this cache.
120
 
@@ -123,6 +135,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
123
  - visible_keys: `(B, H, sliding_window + T_new, D)`
124
  - visible_values: `(B, H, sliding_window + T_new, D)`
125
  - visible_active_mask: `(B, sliding_window + T_new)`
 
126
 
127
  These are the tensors the local attention path should consume
128
  directly for the current step.
@@ -134,10 +147,11 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
134
 
135
  # The current-step local frame is just retained cache state followed by
136
  # the current chunk in chronological order.
137
- composite_keys, composite_values, composite_mask = self._make_composite_frame(
138
  key_states=key_states,
139
  value_states=value_states,
140
  active_mask=active_mask,
 
141
  )
142
 
143
  # The cache remembers only the last raw sliding-window positions of that
@@ -147,11 +161,12 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
147
  composite_keys=composite_keys,
148
  composite_values=composite_values,
149
  composite_mask=composite_mask,
 
150
  )
151
 
152
  self._total_processed += key_states.shape[2]
153
 
154
- return composite_keys, composite_values, composite_mask
155
 
156
  def _ensure_state_compatibility(
157
  self,
@@ -185,17 +200,25 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
185
  non_blocking=True,
186
  )
187
 
 
 
 
 
 
 
188
  def _make_composite_frame(
189
  self,
190
  key_states: torch.Tensor,
191
  value_states: torch.Tensor,
192
  active_mask: torch.Tensor,
193
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
194
  """Build the current-step local frame in chronological order."""
195
  return (
196
  torch.cat([self.keys, key_states], dim=-2),
197
  torch.cat([self.values, value_states], dim=-2),
198
  torch.cat([self.active_mask, active_mask], dim=-1),
 
199
  )
200
 
201
  def _retain_next_window(
@@ -203,15 +226,17 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
203
  composite_keys: torch.Tensor,
204
  composite_values: torch.Tensor,
205
  composite_mask: torch.Tensor,
 
206
  ) -> None:
207
  """Remember the next-step retained local state.
208
 
209
  This is a raw positional trim to the last `sliding_window` positions, not
210
  a semantic live-token trim.
211
  """
212
- self.keys = composite_keys[:, :, -self.sliding_window :, :]
213
- self.values = composite_values[:, :, -self.sliding_window :, :]
214
- self.active_mask = composite_mask[:, -self.sliding_window :]
 
215
 
216
  def get_seq_length(self) -> int:
217
  """Return the cumulative number of token positions processed by this cache.
@@ -239,6 +264,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
239
  self.keys.zero_()
240
  self.values.zero_()
241
  self.active_mask.zero_()
 
242
  self._total_processed = 0
243
 
244
  def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
@@ -246,12 +272,14 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
246
  self.keys = self.keys[beam_idx]
247
  self.values = self.values[beam_idx]
248
  self.active_mask = self.active_mask[beam_idx]
 
249
 
250
  def batch_repeat_interleave(self, repeats: int) -> None:
251
  """Expand the batch dimension for beam-search initialisation."""
252
  self.keys = self.keys.repeat_interleave(repeats, dim=0)
253
  self.values = self.values.repeat_interleave(repeats, dim=0)
254
  self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
 
255
  self.batch_size = self.batch_size * repeats
256
 
257
  def batch_select_indices(self, indices: torch.Tensor) -> None:
@@ -259,12 +287,14 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
259
  self.keys = self.keys[indices]
260
  self.values = self.values[indices]
261
  self.active_mask = self.active_mask[indices]
 
262
  self.batch_size = int(indices.shape[0])
263
 
264
  def offload(self) -> None:
265
  """Offload cache tensors to CPU."""
266
  super().offload()
267
  self.active_mask = self.active_mask.to("cpu", non_blocking=True)
 
268
 
269
  def prefetch(self) -> None:
270
  """Move cache tensors back to the model device ahead of time."""
@@ -274,6 +304,10 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
274
  self.keys.device,
275
  non_blocking=True,
276
  )
 
 
 
 
277
 
278
  def crop(self, max_length: int) -> None:
279
  raise NotImplementedError(
 
92
  device=device,
93
  )
94
 
95
+ # Absolute sequence positions of each retained slot. Inactive slots
96
+ # remain zero; correctness is carried by active_mask.
97
+ self.positions = torch.zeros(
98
+ batch_size,
99
+ sliding_window,
100
+ dtype=torch.long,
101
+ device=device,
102
+ )
103
+
104
  self.is_initialized = True
105
 
106
  # Cumulative count of all token positions presented through update() for
 
113
  key_states: torch.Tensor,
114
  value_states: torch.Tensor,
115
  active_mask: torch.Tensor,
116
+ positions: torch.Tensor,
117
  cache_kwargs: dict | None = None,
118
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
119
  """Return the current-step local frame and retain the next-step window.
120
 
121
  Args:
 
125
  current chunk.
126
  active_mask: Shape `(B, T_new)` bool. `True` means the
127
  corresponding token position in the current chunk is active.
128
+ positions: Shape `(B, T_new)` long. Absolute sequence position of
129
+ each token in the current chunk.
130
  cache_kwargs: Present only to satisfy the `CacheLayerMixin`
131
  interface. Unused by this cache.
132
 
 
135
  - visible_keys: `(B, H, sliding_window + T_new, D)`
136
  - visible_values: `(B, H, sliding_window + T_new, D)`
137
  - visible_active_mask: `(B, sliding_window + T_new)`
138
+ - visible_positions: `(B, sliding_window + T_new)`
139
 
140
  These are the tensors the local attention path should consume
141
  directly for the current step.
 
147
 
148
  # The current-step local frame is just retained cache state followed by
149
  # the current chunk in chronological order.
150
+ composite_keys, composite_values, composite_mask, composite_positions = self._make_composite_frame(
151
  key_states=key_states,
152
  value_states=value_states,
153
  active_mask=active_mask,
154
+ positions=positions,
155
  )
156
 
157
  # The cache remembers only the last raw sliding-window positions of that
 
161
  composite_keys=composite_keys,
162
  composite_values=composite_values,
163
  composite_mask=composite_mask,
164
+ composite_positions=composite_positions,
165
  )
166
 
167
  self._total_processed += key_states.shape[2]
168
 
169
+ return composite_keys, composite_values, composite_mask, composite_positions
170
 
171
  def _ensure_state_compatibility(
172
  self,
 
200
  non_blocking=True,
201
  )
202
 
203
+ if self.positions.device != key_states.device:
204
+ self.positions = self.positions.to(
205
+ key_states.device,
206
+ non_blocking=True,
207
+ )
208
+
209
  def _make_composite_frame(
210
  self,
211
  key_states: torch.Tensor,
212
  value_states: torch.Tensor,
213
  active_mask: torch.Tensor,
214
+ positions: torch.Tensor,
215
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
216
  """Build the current-step local frame in chronological order."""
217
  return (
218
  torch.cat([self.keys, key_states], dim=-2),
219
  torch.cat([self.values, value_states], dim=-2),
220
  torch.cat([self.active_mask, active_mask], dim=-1),
221
+ torch.cat([self.positions, positions], dim=-1),
222
  )
223
 
224
  def _retain_next_window(
 
226
  composite_keys: torch.Tensor,
227
  composite_values: torch.Tensor,
228
  composite_mask: torch.Tensor,
229
+ composite_positions: torch.Tensor,
230
  ) -> None:
231
  """Remember the next-step retained local state.
232
 
233
  This is a raw positional trim to the last `sliding_window` positions, not
234
  a semantic live-token trim.
235
  """
236
+ self.keys[:] = composite_keys[:, :, -self.sliding_window :, :]
237
+ self.values[:] = composite_values[:, :, -self.sliding_window :, :]
238
+ self.active_mask[:] = composite_mask[:, -self.sliding_window :]
239
+ self.positions[:] = composite_positions[:, -self.sliding_window :]
240
 
241
  def get_seq_length(self) -> int:
242
  """Return the cumulative number of token positions processed by this cache.
 
264
  self.keys.zero_()
265
  self.values.zero_()
266
  self.active_mask.zero_()
267
+ self.positions.zero_()
268
  self._total_processed = 0
269
 
270
  def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
 
272
  self.keys = self.keys[beam_idx]
273
  self.values = self.values[beam_idx]
274
  self.active_mask = self.active_mask[beam_idx]
275
+ self.positions = self.positions[beam_idx]
276
 
277
  def batch_repeat_interleave(self, repeats: int) -> None:
278
  """Expand the batch dimension for beam-search initialisation."""
279
  self.keys = self.keys.repeat_interleave(repeats, dim=0)
280
  self.values = self.values.repeat_interleave(repeats, dim=0)
281
  self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
282
+ self.positions = self.positions.repeat_interleave(repeats, dim=0)
283
  self.batch_size = self.batch_size * repeats
284
 
285
  def batch_select_indices(self, indices: torch.Tensor) -> None:
 
287
  self.keys = self.keys[indices]
288
  self.values = self.values[indices]
289
  self.active_mask = self.active_mask[indices]
290
+ self.positions = self.positions[indices]
291
  self.batch_size = int(indices.shape[0])
292
 
293
  def offload(self) -> None:
294
  """Offload cache tensors to CPU."""
295
  super().offload()
296
  self.active_mask = self.active_mask.to("cpu", non_blocking=True)
297
+ self.positions = self.positions.to("cpu", non_blocking=True)
298
 
299
  def prefetch(self) -> None:
300
  """Move cache tensors back to the model device ahead of time."""
 
304
  self.keys.device,
305
  non_blocking=True,
306
  )
307
+ self.positions = self.positions.to(
308
+ self.keys.device,
309
+ non_blocking=True,
310
+ )
311
 
312
  def crop(self, max_length: int) -> None:
313
  raise NotImplementedError(
config.json CHANGED
@@ -21,7 +21,7 @@
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
- "transformers_version": "5.7.0",
25
  "use_cache": true,
26
  "vocab_size": 50277,
27
  "window_size": 128
 
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
+ "transformers_version": "5.8.0",
25
  "use_cache": true,
26
  "vocab_size": 50277,
27
  "window_size": 128