Fix UniversalTransformerCache.get_mask_sizes for batched generation
Problem
Batched generation with HuggingFace Transformers produces corrupted output for
all sequences except the longest (unpadded) one in the batch.
Root cause
UniversalTransformerCache inherits Cache.get_mask_sizes, which falls back to returning (cache_position.shape[0], 0) when layer_idx >= len(self.layers).
Because UniversalTransformerCache manages its own flat key_cache / value_cache lists and keeps self.layers empty ([]), this fallback always fires. During the prefill step the result happens to be correct, because cache_position spans the full input length. During autoregressive decoding, however, cache_position has length 1, so the 4D attention mask is built for kv_length=1 instead of cached_length + 1.
The undersized mask is then broadcast across the full KV cache, losing all per-position padding information. This corrupts every padded sequence in the batch.
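To make the failure concrete, here is a simplified, dependency-free model of the slicing and broadcasting described above. The helpers slice_padding_mask and broadcast_keys are hypothetical and only approximate what transformers does internally. In this toy, the single key column kept when kv_length=1 is position 0, and a size-1 key dimension is repeated across every cached position.

```python
# Simplified, hypothetical model of slicing a 2D padding mask for one decoding
# step. This is NOT transformers' masking code, only an illustration.

def slice_padding_mask(attention_mask_2d, kv_length, kv_offset=0):
    # Keep key positions kv_offset .. kv_offset + kv_length - 1 of each row.
    return [[bool(row[kv_offset + j]) for j in range(kv_length)] for row in attention_mask_2d]


def broadcast_keys(mask, total_kv_length):
    # A key dimension of size 1 is repeated over every cached position;
    # any other size must already match the KV cache length.
    result = []
    for row in mask:
        if len(row) == 1:
            result.append(row * total_kv_length)
        elif len(row) == total_kv_length:
            result.append(list(row))
        else:
            raise ValueError(f"cannot broadcast {len(row)} keys to {total_kv_length}")
    return result


# Decoding step: 4 tokens already cached + 1 new token = 5 key positions.
# Row 0 is the shorter, left-padded prompt; row 1 is the longest prompt.
attention_mask = [
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
]
total_kv = 5

correct = broadcast_keys(slice_padding_mask(attention_mask, kv_length=total_kv), total_kv)
buggy = broadcast_keys(slice_padding_mask(attention_mask, kv_length=1), total_kv)

print(correct[0])  # [False, False, True, True, True]
print(buggy[0])    # [False, False, False, False, False]
```

With the correct kv_length, row 0 still distinguishes the two padding slots from real tokens; with kv_length=1, every position in the row gets the same value, so that distinction is gone.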
Fix
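Override get_mask_sizes to return the correct (seq_length + query_length, 0), where seq_length is the number of tokens already in the cache and query_length is cache_position.shape[0], matching the semantics of DynamicCacheLayer.get_mask_sizes.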
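A minimal sketch of the override follows, written against a hypothetical, dependency-free stand-in (ToyUniversalCache) rather than the real UniversalTransformerCache, so that it runs without torch or the model's remote code. It assumes a get_seq_length(layer_idx) method that counts cached tokens, and it uses len(cache_position) where the real code would use cache_position.shape[0].

```python
# Hypothetical stand-in for UniversalTransformerCache: flat key_cache /
# value_cache lists, with self.layers left empty as described above.
class ToyUniversalCache:
    def __init__(self):
        self.layers = []       # never populated, so Cache's fallback would always fire
        self.key_cache = []    # key_cache[layer_idx] holds one entry per cached token
        self.value_cache = []

    def update(self, keys, values, layer_idx):
        # Append new per-token entries for this layer, creating it on first use.
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append([])
            self.value_cache.append([])
        self.key_cache[layer_idx].extend(keys)
        self.value_cache[layer_idx].extend(values)

    def get_seq_length(self, layer_idx=0):
        # Number of tokens already cached for this layer (0 before prefill).
        if layer_idx >= len(self.key_cache):
            return 0
        return len(self.key_cache[layer_idx])

    def get_mask_sizes(self, cache_position, layer_idx=0):
        # The fix: kv_length = cached tokens + current queries. kv_offset is 0,
        # assuming (like a dynamic cache) every past token is kept from position 0.
        query_length = len(cache_position)  # cache_position.shape[0] for a real tensor
        return self.get_seq_length(layer_idx) + query_length, 0


cache = ToyUniversalCache()
prefill_sizes = cache.get_mask_sizes(range(0, 4))   # 4 prompt tokens, empty cache
cache.update(["k0", "k1", "k2", "k3"], ["v0", "v1", "v2", "v3"], layer_idx=0)
decode_sizes = cache.get_mask_sizes(range(4, 5))    # 1 new token, 4 cached
print(prefill_sizes, decode_sizes)  # (4, 0) (5, 0)
```

Prefill gives the same answer as the inherited fallback, while the decoding step now yields 5 key positions instead of 1. This relies on get_mask_sizes being called before the new keys are written to the cache, which is what the cached_length + 1 expectation in the root cause implies.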
Reproduction
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_name = "ByteDance/Ouro-1.4B-Thinking"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
    attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0 # <|endoftext|>
# Two prompts of different lengths
prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Without fix: shorter prompt output is corrupted
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
for i, p in enumerate(prompts):
    tokens = outputs[i][batch["input_ids"].shape[1]:]
    print(f"[{i}] {p!r} -> {tokenizer.decode(tokens, skip_special_tokens=False)[:100]}")
With batch_size > 1, generation does not work properly with attn_implementation="eager" (the output is mostly whitespace) and crashes outright with "sdpa". The "flash_attention_2" backend works fine.