Fix UniversalTransformerCache.get_mask_sizes for batched generation
Browse files## Problem
Batched generation with HuggingFace Transformers produces corrupted output for
all sequences except the longest (unpadded) one in the batch.
## Root cause
`UniversalTransformerCache` inherits `Cache.get_mask_sizes`, which falls back to
`return cache_position.shape[0], 0` when `layer_idx >= len(self.layers)`.
Because `UniversalTransformerCache` manages its own flat `key_cache` /
`value_cache` lists and keeps `self.layers` empty (`[]`), this fallback **always**
fires. During the prefill step this happens to be correct (`cache_position`
spans the full input length), but during autoregressive decoding
`cache_position` has length 1, so the 4D attention mask is built for
`kv_length=1` instead of `cached_length + 1`.
The undersized mask gets broadcasted across the full KV cache, losing all
per-position padding information. This corrupts every padded sequence in the
batch.
## Fix
Override `get_mask_sizes` to return the correct `(seq_length + query_length, 0)`,
matching the semantics of `DynamicCacheLayer.get_mask_sizes`.
## Reproduction
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "ByteDance/Ouro-1.4B-Thinking"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0 # <|endoftext|>
# Two prompts of different lengths
prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Without fix: shorter prompt output is corrupted
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
for i, p in enumerate(prompts):
tokens = outputs[i][batch["input_ids"].shape[1]:]
print(f"[{i}] {p!r} -> {tokenizer.decode(tokens, skip_special_tokens=False)[:100]}")
```
- modeling_ouro.py +22 -0
|
@@ -195,6 +195,28 @@ class UniversalTransformerCache(Cache):
|
|
| 195 |
return 0
|
| 196 |
return cached.shape[2]
|
| 197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
def get_max_length(self) -> Optional[int]:
|
| 199 |
return None
|
| 200 |
|
|
|
|
| 195 |
return 0
|
| 196 |
return cached.shape[2]
|
| 197 |
|
| 198 |
+
def get_mask_sizes(
|
| 199 |
+
self, cache_position: torch.Tensor, layer_idx: int = 0
|
| 200 |
+
) -> tuple[int, int]:
|
| 201 |
+
"""Return (kv_length, kv_offset) for attention mask creation.
|
| 202 |
+
|
| 203 |
+
The inherited ``Cache.get_mask_sizes`` falls back to
|
| 204 |
+
``(cache_position.shape[0], 0)`` when ``layer_idx >= len(self.layers)``.
|
| 205 |
+
Because ``UniversalTransformerCache`` manages its own flat
|
| 206 |
+
``key_cache`` / ``value_cache`` lists and keeps ``self.layers`` empty,
|
| 207 |
+
the fallback always fires. During the prefill step this happens to be
|
| 208 |
+
correct (``cache_position`` spans the full input), but during
|
| 209 |
+
autoregressive decoding ``cache_position`` has length 1, so the mask is
|
| 210 |
+
built for ``kv_length=1`` instead of ``cached_length + 1``. As a
|
| 211 |
+
result the 4D attention mask is too small and padding information is
|
| 212 |
+
lost, corrupting batched generation for every sequence except the
|
| 213 |
+
longest (unpadded) one.
|
| 214 |
+
"""
|
| 215 |
+
query_length = cache_position.shape[0]
|
| 216 |
+
seq_length = self.get_seq_length(layer_idx)
|
| 217 |
+
kv_length = seq_length + query_length
|
| 218 |
+
return kv_length, 0
|
| 219 |
+
|
| 220 |
def get_max_length(self) -> Optional[int]:
|
| 221 |
return None
|
| 222 |
|