ridger KristianS7 committed on
Commit
3201be3
·
1 Parent(s): 85a074d

Fix UniversalTransformerCache.get_mask_sizes for batched generation (#8)

Browse files

- Fix UniversalTransformerCache.get_mask_sizes for batched generation (7b9f90ad5eb766f3c3559124ad2fe67f885a2a5c)


Co-authored-by: Kristian Shw <KristianS7@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_ouro.py +22 -0
modeling_ouro.py CHANGED
@@ -195,6 +195,28 @@ class UniversalTransformerCache(Cache):
195
  return 0
196
  return cached.shape[2]
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def get_max_length(self) -> Optional[int]:
    """Maximum number of tokens this cache can hold; ``None`` = unbounded."""
    return None
 
195
  return 0
196
  return cached.shape[2]
197
 
198
def get_mask_sizes(
    self, cache_position: torch.Tensor, layer_idx: int = 0
) -> tuple[int, int]:
    """Return ``(kv_length, kv_offset)`` for attention-mask construction.

    The base ``Cache.get_mask_sizes`` falls back to
    ``(cache_position.shape[0], 0)`` whenever ``layer_idx`` is not covered
    by ``self.layers``. Because ``UniversalTransformerCache`` tracks its
    own flat ``key_cache`` / ``value_cache`` lists and leaves
    ``self.layers`` empty, that fallback always fires. It happens to be
    right during prefill (``cache_position`` spans the whole input), but
    during autoregressive decoding ``cache_position`` has length 1, so the
    4D mask would be built for ``kv_length=1`` rather than
    ``cached_length + 1`` — padding information is lost and batched
    generation is corrupted for every sequence except the longest
    (unpadded) one. Sizing the mask from the actual cached length avoids
    that.
    """
    # kv covers everything already cached plus the incoming query tokens;
    # this cache always starts at position 0, hence offset 0.
    num_new_tokens = cache_position.shape[0]
    cached_tokens = self.get_seq_length(layer_idx)
    return cached_tokens + num_new_tokens, 0
220
  def get_max_length(self) -> Optional[int]:
221
  return None
222