Fix UniversalTransformerCache.get_mask_sizes for batched generation

#4
by KristianS7 - opened
Files changed (1) hide show
  1. modeling_ouro.py +12 -0
modeling_ouro.py CHANGED
@@ -214,6 +214,18 @@ class UniversalTransformerCache(Cache):
214
  self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
215
  self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
216
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  @property
218
  def is_compileable(self) -> bool:
219
  return False
 
214
  self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
215
  self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
216
 
217
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int = 0) -> tuple[int, int]:
    """Compute the attention-mask sizes as ``(kv_length, kv_offset)``.

    Overrides the inherited ``Cache.get_mask_sizes``, which inspects
    ``self.layers`` — always empty for ``UniversalTransformerCache`` — and
    therefore reports ``(query_length, 0)`` rather than
    ``(cached_length + query_length, 0)`` during autoregressive decoding.

    Args:
        cache_position: 1-D tensor of positions for the tokens being
            processed in this forward pass; its length is the query length.
        layer_idx: Layer whose cached length to consult (default ``0``).

    Returns:
        Tuple of the total key/value length (cached tokens plus the new
        query tokens) and a key/value offset of ``0``.
    """
    new_tokens = cache_position.shape[0]
    cached_tokens = self.get_seq_length(layer_idx)
    return cached_tokens + new_tokens, 0
229
  @property
230
  def is_compileable(self) -> bool:
231
  return False