Commit ·
3201be3
1
Parent(s): 85a074d
Fix UniversalTransformerCache.get_mask_sizes for batched generation (#8)
Browse files- Fix UniversalTransformerCache.get_mask_sizes for batched generation (7b9f90ad5eb766f3c3559124ad2fe67f885a2a5c)
Co-authored-by: Kristian Shw <KristianS7@users.noreply.huggingface.co>
- modeling_ouro.py +22 -0
modeling_ouro.py
CHANGED
|
@@ -195,6 +195,28 @@ class UniversalTransformerCache(Cache):
|
|
| 195 |
return 0
|
| 196 |
return cached.shape[2]
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
def get_max_length(self) -> Optional[int]:
|
| 199 |
return None
|
| 200 |
|
|
|
|
| 195 |
return 0
|
| 196 |
return cached.shape[2]
|
| 197 |
|
| 198 |
+
def get_mask_sizes(
|
| 199 |
+
self, cache_position: torch.Tensor, layer_idx: int = 0
|
| 200 |
+
) -> tuple[int, int]:
|
| 201 |
+
"""Return (kv_length, kv_offset) for attention mask creation.
|
| 202 |
+
|
| 203 |
+
The inherited ``Cache.get_mask_sizes`` falls back to
|
| 204 |
+
``(cache_position.shape[0], 0)`` when ``layer_idx >= len(self.layers)``.
|
| 205 |
+
Because ``UniversalTransformerCache`` manages its own flat
|
| 206 |
+
``key_cache`` / ``value_cache`` lists and keeps ``self.layers`` empty,
|
| 207 |
+
the fallback always fires. During the prefill step this happens to be
|
| 208 |
+
correct (``cache_position`` spans the full input), but during
|
| 209 |
+
autoregressive decoding ``cache_position`` has length 1, so the mask is
|
| 210 |
+
built for ``kv_length=1`` instead of ``cached_length + 1``. As a
|
| 211 |
+
result the 4D attention mask is too small and padding information is
|
| 212 |
+
lost, corrupting batched generation for every sequence except the
|
| 213 |
+
longest (unpadded) one.
|
| 214 |
+
"""
|
| 215 |
+
query_length = cache_position.shape[0]
|
| 216 |
+
seq_length = self.get_seq_length(layer_idx)
|
| 217 |
+
kv_length = seq_length + query_length
|
| 218 |
+
return kv_length, 0
|
| 219 |
+
|
| 220 |
def get_max_length(self) -> Optional[int]:
|
| 221 |
return None
|
| 222 |
|