Update architecture and tokenizer
- __attention__sliding_window_attention.py +16 -15
- __cache__sliding_window_cache.py +41 -7
- config.json +1 -1
__attention__sliding_window_attention.py
CHANGED

@@ -120,12 +120,15 @@ class SlidingWindowAttention(nn.Module):
         # The cache returns the current-step visible local frame, not merely the
         # retained next-step cache buffer.
         if cache is not None:
-            k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
+            k_full, v_full, full_active_mask, full_positions = cache.update(
+                k, v, active_mask, position_ids
+            )
         else:
-            k_full, v_full, full_active_mask = k, v, active_mask
+            k_full, v_full, full_active_mask, full_positions = k, v, active_mask, position_ids
 
         block_mask = self._make_block_mask(
             active_mask=full_active_mask,
+            positions=full_positions,
             batch_size=batch_size,
             num_heads=self.num_heads,
             query_len=query_len,
@@ -182,6 +185,7 @@ class SlidingWindowAttention(nn.Module):
     def _make_block_mask(
         self,
         active_mask: torch.Tensor,
+        positions: torch.Tensor,
         batch_size: int,
         num_heads: int,
         query_len: int,
@@ -191,17 +195,14 @@ class SlidingWindowAttention(nn.Module):
     ) -> Any:
         """Create the FlexAttention block mask for masked local continuation.
 
-        The returned local frame is chronological in raw buffer order; dead
-        positions may remain inside it.
-
-
-
-
-        is used to locate query rows. Semantic active-token positions are then
-        used to decide causality and sliding-window distance.
+        The returned local frame is chronological in raw buffer order; dead
+        positions may remain inside it. Liveness is carried by `active_mask`.
+        Causality and window distance are determined from `positions`, which
+        holds the absolute sequence position of every slot in the composite
+        frame. Using absolute positions rather than a cumsum over the active
+        mask eliminates the data-dependent computation that blocks torch.compile.
         """
         query_offset = kv_len - query_len
-        semantic_positions = active_mask.long().cumsum(dim=-1) - 1
 
         def sliding_window_mask(
             batch_idx: torch.Tensor,
@@ -215,11 +216,11 @@ class SlidingWindowAttention(nn.Module):
             query_is_active = active_mask[batch_idx, q_abs]
             key_is_active = active_mask[batch_idx, kv_idx]
 
-            q_sem = semantic_positions[batch_idx, q_abs]
-            k_sem = semantic_positions[batch_idx, kv_idx]
+            q_pos = positions[batch_idx, q_abs]
+            k_pos = positions[batch_idx, kv_idx]
 
-            is_causal = k_sem <= q_sem
-            in_window = (q_sem - k_sem) < window_size
+            is_causal = k_pos <= q_pos
+            in_window = (q_pos - k_pos) < window_size
 
             return query_is_active & key_is_active & is_causal & in_window
 
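For readers unfamiliar with the FlexAttention side of this change: the new predicate only gathers from precomputed tensors, so the whole mask construction stays traceable. Below is a minimal runnable sketch of the same idea, assuming the `torch.nn.attention.flex_attention.create_block_mask` API from PyTorch 2.5+; the sizes, the random `active_mask`, and the contiguous `positions` are illustrative stand-ins for what the cache returns, not this repository's actual wiring.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

B = 2
window_size = 128                  # matches config.json's "window_size"
query_len = 128                    # T_new, the current chunk
kv_len = window_size + query_len   # composite frame: retained window + chunk
query_offset = kv_len - query_len

# Stand-ins for what the cache returns: a liveness flag and the absolute
# sequence position of every slot in the composite frame. Real frames may
# hold stale zeros in inactive slots; active_mask is what makes that safe.
active_mask = torch.rand(B, kv_len) > 0.1
positions = torch.arange(kv_len).repeat(B, 1)

def sliding_window_mask(batch_idx, head_idx, q_idx, kv_idx):
    q_abs = q_idx + query_offset  # query rows sit at the tail of the frame
    query_is_active = active_mask[batch_idx, q_abs]
    key_is_active = active_mask[batch_idx, kv_idx]
    # Plain gathers from precomputed tensors; no cumsum, nothing data-dependent.
    q_pos = positions[batch_idx, q_abs]
    k_pos = positions[batch_idx, kv_idx]
    is_causal = k_pos <= q_pos
    in_window = (q_pos - k_pos) < window_size
    return query_is_active & key_is_active & is_causal & in_window

block_mask = create_block_mask(
    sliding_window_mask, B=B, H=None, Q_LEN=query_len, KV_LEN=kv_len, device="cpu"
)
```

The resulting `block_mask` is what a `flex_attention(q, k, v, block_mask=block_mask)` call would consume. With the old approach, `semantic_positions` had to be recomputed from the mask's values each step via cumsum, which is exactly the data-dependent computation the new docstring says was blocking torch.compile.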
__cache__sliding_window_cache.py
CHANGED

@@ -92,6 +92,15 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
             device=device,
         )
 
+        # Absolute sequence positions of each retained slot. Inactive slots
+        # retain zero; correctness is carried by active_mask.
+        self.positions = torch.zeros(
+            batch_size,
+            sliding_window,
+            dtype=torch.long,
+            device=device,
+        )
+
         self.is_initialized = True
 
         # Cumulative count of all token positions presented through update() for
@@ -104,8 +113,9 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         active_mask: torch.Tensor,
+        positions: torch.Tensor,
         cache_kwargs: dict | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Return the current-step local frame and retain the next-step window.
 
         Args:
@@ -115,6 +125,8 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
                 current chunk.
             active_mask: Shape `(B, T_new)` bool. `True` means the
                 corresponding token position in the current chunk is active.
+            positions: Shape `(B, T_new)` long. Absolute sequence position of
+                each token in the current chunk.
             cache_kwargs: Present only to satisfy the `CacheLayerMixin`
                 interface. Unused by this cache.
 
@@ -123,6 +135,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
             - visible_keys: `(B, H, sliding_window + T_new, D)`
             - visible_values: `(B, H, sliding_window + T_new, D)`
             - visible_active_mask: `(B, sliding_window + T_new)`
+            - visible_positions: `(B, sliding_window + T_new)`
 
         These are the tensors the local attention path should consume
         directly for the current step.
@@ -134,10 +147,11 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
 
         # The current-step local frame is just retained cache state followed by
         # the current chunk in chronological order.
-        composite_keys, composite_values, composite_mask = self._make_composite_frame(
+        composite_keys, composite_values, composite_mask, composite_positions = self._make_composite_frame(
            key_states=key_states,
            value_states=value_states,
            active_mask=active_mask,
+            positions=positions,
        )
 
         # The cache remembers only the last raw sliding-window positions of that
@@ -147,11 +161,12 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
             composite_keys=composite_keys,
             composite_values=composite_values,
             composite_mask=composite_mask,
+            composite_positions=composite_positions,
         )
 
         self._total_processed += key_states.shape[2]
 
-        return composite_keys, composite_values, composite_mask
+        return composite_keys, composite_values, composite_mask, composite_positions
 
     def _ensure_state_compatibility(
         self,
@@ -185,17 +200,25 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
                 non_blocking=True,
             )
 
+        if self.positions.device != key_states.device:
+            self.positions = self.positions.to(
+                key_states.device,
+                non_blocking=True,
+            )
+
     def _make_composite_frame(
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         active_mask: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        positions: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """Build the current-step local frame in chronological order."""
         return (
             torch.cat([self.keys, key_states], dim=-2),
             torch.cat([self.values, value_states], dim=-2),
             torch.cat([self.active_mask, active_mask], dim=-1),
+            torch.cat([self.positions, positions], dim=-1),
         )
 
     def _retain_next_window(
@@ -203,15 +226,17 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         composite_keys: torch.Tensor,
         composite_values: torch.Tensor,
         composite_mask: torch.Tensor,
+        composite_positions: torch.Tensor,
     ) -> None:
         """Remember the next-step retained local state.
 
         This is a raw positional trim to the last `sliding_window` positions, not
         a semantic live-token trim.
         """
-        self.keys = composite_keys[:, :, -self.sliding_window :, :]
-        self.values = composite_values[:, :, -self.sliding_window :, :]
-        self.active_mask = composite_mask[:, -self.sliding_window :]
+        self.keys[:] = composite_keys[:, :, -self.sliding_window :, :]
+        self.values[:] = composite_values[:, :, -self.sliding_window :, :]
+        self.active_mask[:] = composite_mask[:, -self.sliding_window :]
+        self.positions[:] = composite_positions[:, -self.sliding_window :]
 
     def get_seq_length(self) -> int:
         """Return the cumulative number of token positions processed by this cache.
@@ -239,6 +264,7 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         self.keys.zero_()
         self.values.zero_()
         self.active_mask.zero_()
+        self.positions.zero_()
         self._total_processed = 0
 
     def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
@@ -246,12 +272,14 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         self.keys = self.keys[beam_idx]
         self.values = self.values[beam_idx]
         self.active_mask = self.active_mask[beam_idx]
+        self.positions = self.positions[beam_idx]
 
     def batch_repeat_interleave(self, repeats: int) -> None:
         """Expand the batch dimension for beam-search initialisation."""
         self.keys = self.keys.repeat_interleave(repeats, dim=0)
         self.values = self.values.repeat_interleave(repeats, dim=0)
         self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
+        self.positions = self.positions.repeat_interleave(repeats, dim=0)
         self.batch_size = self.batch_size * repeats
 
     def batch_select_indices(self, indices: torch.Tensor) -> None:
@@ -259,12 +287,14 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
         self.keys = self.keys[indices]
         self.values = self.values[indices]
         self.active_mask = self.active_mask[indices]
+        self.positions = self.positions[indices]
         self.batch_size = int(indices.shape[0])
 
     def offload(self) -> None:
         """Offload cache tensors to CPU."""
         super().offload()
         self.active_mask = self.active_mask.to("cpu", non_blocking=True)
+        self.positions = self.positions.to("cpu", non_blocking=True)
 
     def prefetch(self) -> None:
         """Move cache tensors back to the model device ahead of time."""
@@ -274,6 +304,10 @@ class LocalSlidingWindowLayerCache(CacheLayerMixin):
             self.keys.device,
             non_blocking=True,
         )
+        self.positions = self.positions.to(
+            self.keys.device,
+            non_blocking=True,
+        )
 
     def crop(self, max_length: int) -> None:
         raise NotImplementedError(
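A second, easy-to-miss change in this file: `_retain_next_window` now writes with slice assignment (`self.keys[:] = ...`) instead of rebinding the attributes. The commit does not say why, but a plausible reading is that in-place writes keep the preallocated `(B, H, sliding_window, D)` buffers at a stable storage address across decode steps, which torch.compile and CUDA-graph-style capture benefit from. A small sketch of the difference, with illustrative shapes:

```python
import torch

B, H, D = 1, 4, 64
sliding_window, t_new = 128, 16

# Preallocated retained buffer, as in __init__ (zero-filled at full length).
keys = torch.zeros(B, H, sliding_window, D)

# One update step: composite frame = retained buffer + current chunk.
composite = torch.cat([keys, torch.randn(B, H, t_new, D)], dim=-2)

# Old behaviour: rebinding. `keys` becomes a view into `composite`, i.e. a
# fresh allocation every step.
rebound = composite[:, :, -sliding_window:, :]
assert rebound.data_ptr() != keys.data_ptr()

# New behaviour: in-place slice assignment. Data is copied back into the
# original buffer; its storage address never changes.
ptr_before = keys.data_ptr()
keys[:] = composite[:, :, -sliding_window:, :]
assert keys.data_ptr() == ptr_before
```

This only works because the retained slice always spans exactly `sliding_window` slots: the buffers are zero-initialised at full window length, so the composite frame is never shorter than the window and the trailing slice always matches the buffer shape.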
config.json
CHANGED

@@ -21,7 +21,7 @@
   "rope_mode": "main_sequence",
   "tie_word_embeddings": false,
   "training_sequence_length": 1024,
-  "transformers_version": "5.
+  "transformers_version": "5.8.0",
   "use_cache": true,
   "vocab_size": 50277,
   "window_size": 128