Fix UniversalTransformerCache.get_mask_sizes for batched generation

#4
by KristianS7 - opened

Problem

Batched generation (batch_size > 1) produces corrupted output for every sequence in the batch except the longest (unpadded) one.

UniversalTransformerCache inherits Cache.get_mask_sizes, which falls back to return cache_position.shape[0], 0 when layer_idx >= len(self.layers).

Since UniversalTransformerCache manages its own flat key_cache/value_cache lists and keeps self.layers empty ([]), this fallback always fires:

  • Prefill: works correctly (cache_position spans full input length)
  • Autoregressive decoding: fails; cache_position has length 1, so the 4D attention mask is built for kv_length=1 instead of cached_length + 1

This undersized mask is then broadcast across the full KV cache, losing per-position padding information and corrupting every padded sequence in the batch.
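The failure mode can be reproduced in isolation. The sketch below (hypothetical function name; it mirrors the fallback behavior described above, not the actual transformers source) shows how an empty self.layers list makes the fallback return the query length as the mask size on every call:

```python
# Standalone sketch of the inherited fallback (hypothetical name). With
# self.layers kept empty, layer_idx >= len(layers) is always true, so
# kv_length collapses to the query length and ignores cached tokens.
def fallback_get_mask_sizes(cache_position_len, layers, layer_idx):
    if layer_idx >= len(layers):
        return cache_position_len, 0  # fallback path: no cache accounted for
    raise NotImplementedError  # real per-layer path never reached here

layers = []  # UniversalTransformerCache keeps self.layers empty

# Prefill with a 10-token prompt: the query spans the full input, so the
# returned size happens to be right and prefill "works".
assert fallback_get_mask_sizes(10, layers, 0) == (10, 0)

# Decode step: cache_position has length 1, so the 4D mask is built for
# kv_length=1 even though 10 tokens sit in the KV cache.
assert fallback_get_mask_sizes(1, layers, 0) == (1, 0)
```

This is why prefill output looks fine while every subsequent decode step attends through a mask that is far too small.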

Manifests differently by attention implementation:

  • "eager": garbled whitespace output
  • "sdpa": RuntimeError: (*bias): last dimension must be contiguous
  • "flash_attention_2": works fine (ignores 4D mask)

Fix

Override get_mask_sizes to return the correct (seq_length + query_length, 0), matching the semantics of DynamicCacheLayer.get_mask_sizes.
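A minimal, self-contained sketch of the corrected semantics (class and attribute names are hypothetical stand-ins; the real override would read the cached length from the flat key_cache tensors rather than a counter):

```python
class UniversalTransformerCacheSketch:
    """Toy stand-in for UniversalTransformerCache; tracks only cached length."""

    def __init__(self):
        self.seq_length = 0  # tokens already stored in the flat key/value caches

    def update(self, num_new_tokens):
        self.seq_length += num_new_tokens

    def get_mask_sizes(self, cache_position, layer_idx):
        # Mirrors the described DynamicCacheLayer semantics: the 4D mask must
        # cover every cached token plus the current query positions.
        query_length = len(cache_position)
        kv_length = self.seq_length + query_length
        kv_offset = 0
        return kv_length, kv_offset

cache = UniversalTransformerCacheSketch()
assert cache.get_mask_sizes(range(10), 0) == (10, 0)      # prefill: 10 tokens
cache.update(10)
assert cache.get_mask_sizes(range(10, 11), 0) == (11, 0)  # decode: cached + 1
```

With this in place, each decode step builds the mask for the full cached_length + 1 window, so per-position padding survives for every sequence in the batch.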

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "ByteDance/Ouro-1.4B"  # or Ouro-2.6B
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
    attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0

prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)

See also the same fix merged for the Thinking variant: https://huggingface.co/ByteDance/Ouro-1.4B-Thinking/discussions/5
