Fix UniversalTransformerCache.get_mask_sizes for batched generation (#4)
by KristianS7 — opened
Files changed: modeling_ouro.py (+12, -0)
modeling_ouro.py
CHANGED
|
@@ -214,6 +214,18 @@ class UniversalTransformerCache(Cache):
 214         self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
 215         self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
 216
 217     @property
 218     def is_compileable(self) -> bool:
 219         return False
 214         self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
 215         self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
 216
 217 +   def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int = 0) -> tuple[int, int]:
 218 +       """Return (kv_length, kv_offset) accounting for cached tokens.
 219 +
 220 +       The inherited Cache.get_mask_sizes checks ``self.layers`` which is
 221 +       always empty for UniversalTransformerCache, causing it to return
 222 +       ``(query_length, 0)`` instead of ``(cached_length + query_length, 0)``
 223 +       during autoregressive decoding.
 224 +       """
 225 +       query_length = cache_position.shape[0]
 226 +       seq_length = self.get_seq_length(layer_idx)
 227 +       return seq_length + query_length, 0
 228 +
 229     @property
 230     def is_compileable(self) -> bool:
 231         return False