Fix UniversalTransformerCache.get_mask_sizes for batched generation
## Problem
Batched generation (batch_size > 1) produces corrupted output for all sequences except the longest (unpadded) one in the batch.
`UniversalTransformerCache` inherits `Cache.get_mask_sizes`, which falls back to `return cache_position.shape[0], 0` when `layer_idx >= len(self.layers)`.
Since `UniversalTransformerCache` manages its own flat `key_cache`/`value_cache` lists and keeps `self.layers` empty (`[]`), this fallback **always** fires:
- **Prefill**: works correctly (`cache_position` spans full input length)
- **Autoregressive decoding**: **fails** — `cache_position` has length 1, so the 4D attention mask is built for `kv_length=1` instead of `cached_length + 1`
This undersized mask gets broadcasted across the full KV cache, losing per-position padding information and corrupting every padded sequence in the batch.
The failure manifests differently depending on the attention implementation:
- `"eager"`: garbled whitespace output
- `"sdpa"`: `RuntimeError: (*bias): last dimension must be contiguous`
- `"flash_attention_2"`: works fine (ignores 4D mask)
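The size mismatch during decoding can be seen with a small standalone sketch (hypothetical values, not the actual `Cache` class):

```python
import torch

# During autoregressive decoding, cache_position holds only the new token's
# index, so a fallback that returns cache_position.shape[0] yields kv_length == 1.
cached_length = 7                    # tokens already stored in the KV cache
cache_position = torch.tensor([7])   # decode step: a single new position

# Inherited fallback (fires because self.layers is empty):
fallback_kv_length = cache_position.shape[0]      # 1 -- too small

# Correct semantics: the mask must cover the cache plus the new query token.
query_length = cache_position.shape[0]
correct_kv_length = cached_length + query_length  # cached_length + 1

assert fallback_kv_length == 1
assert correct_kv_length == 8
```

A mask built for `kv_length=1` then gets broadcast across all 8 cached positions, which is exactly how the per-position padding information is lost.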
## Fix
Override `get_mask_sizes` to return the correct `(seq_length + query_length, 0)`, matching the semantics of `DynamicCacheLayer.get_mask_sizes`.
## Reproduction
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "ByteDance/Ouro-1.4B" # or Ouro-2.6B
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0
prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
# The first (padded) prompt's completion comes back corrupted under "eager".
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
See also: same fix merged for the Thinking variant — https://huggingface.co/ByteDance/Ouro-1.4B-Thinking/discussions/5
- modeling_ouro.py +12 -0

```diff
@@ -214,6 +214,18 @@ class UniversalTransformerCache(Cache):
         self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
         self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int = 0) -> tuple[int, int]:
+        """Return (kv_length, kv_offset) accounting for cached tokens.
+
+        The inherited Cache.get_mask_sizes checks ``self.layers`` which is
+        always empty for UniversalTransformerCache, causing it to return
+        ``(query_length, 0)`` instead of ``(cached_length + query_length, 0)``
+        during autoregressive decoding.
+        """
+        query_length = cache_position.shape[0]
+        seq_length = self.get_seq_length(layer_idx)
+        return seq_length + query_length, 0
+
     @property
     def is_compileable(self) -> bool:
         return False
```