# Fix `UniversalTransformerCache.get_mask_sizes` for batched generation

## Problem
Batched generation (batch_size > 1) produces corrupted output for all sequences except the longest (unpadded) one in the batch.
`UniversalTransformerCache` inherits `Cache.get_mask_sizes`, which falls back to `return cache_position.shape[0], 0` when `layer_idx >= len(self.layers)`.
Since `UniversalTransformerCache` manages its own flat `key_cache`/`value_cache` lists and leaves `self.layers` empty (`[]`), this fallback always fires:
- Prefill: works correctly (`cache_position` spans the full input length)
- Autoregressive decoding: fails, because `cache_position` has length 1, so the 4D attention mask is built for `kv_length=1` instead of `cached_length + 1`
This undersized mask is then broadcast across the full KV cache, losing per-position padding information and corrupting every padded sequence in the batch.
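The size mismatch can be sketched standalone. The helper names below are hypothetical, and `cached_length` stands for the number of tokens already in the KV cache:

```python
def fallback_mask_sizes(cache_position_len):
    # Base Cache fallback: kv_length derived from cache_position alone.
    return cache_position_len, 0

def correct_mask_sizes(cached_length, cache_position_len):
    # Correct semantics: the mask must cover cached tokens plus new tokens.
    return cached_length + cache_position_len, 0

# Prefill: cache is empty and cache_position spans the full input, so both agree.
assert fallback_mask_sizes(10) == correct_mask_sizes(0, 10) == (10, 0)

# Decoding step: 10 cached tokens, one new token.
assert fallback_mask_sizes(1) == (1, 0)      # undersized mask
assert correct_mask_sizes(10, 1) == (11, 0)  # covers the padded prefix
```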
The failure manifests differently by attention implementation:
- `"eager"`: garbled whitespace output
- `"sdpa"`: `RuntimeError: (*bias): last dimension must be contiguous`
- `"flash_attention_2"`: works fine (ignores the 4D mask)
## Fix
Override `get_mask_sizes` to return the correct `(seq_length + query_length, 0)`, matching the semantics of `DynamicCacheLayer.get_mask_sizes`.
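A minimal sketch of such an override, using a simplified stand-in class. Attribute and method names like `get_seq_length` follow the description above but are assumptions, not the exact repo code:

```python
class UniversalTransformerCache:
    def __init__(self):
        self.key_cache = []    # flat list of per-layer key tensors
        self.value_cache = []  # flat list of per-layer value tensors

    def get_seq_length(self, layer_idx=0):
        # Number of tokens already cached for this layer (0 if empty).
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, cache_position, layer_idx):
        # kv_length must cover cached tokens plus the current query tokens,
        # matching DynamicCacheLayer.get_mask_sizes; kv_offset stays 0.
        kv_length = self.get_seq_length(layer_idx) + cache_position.shape[0]
        return kv_length, 0
```

With this override in place, a decoding step with 10 cached tokens and a length-1 `cache_position` yields `kv_length=11` rather than 1, so the 4D mask keeps the per-position padding information.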
## Reproduction
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "ByteDance/Ouro-1.4B"  # or Ouro-2.6B
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
    attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0

prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
# Without the fix, the shorter (padded) prompt's output is corrupted.
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
See also: the same fix was merged for the Thinking variant: https://huggingface.co/ByteDance/Ouro-1.4B-Thinking/discussions/5