Fix UniversalTransformerCache.get_mask_sizes for batched generation

#5
by KristianS7 - opened

Problem

Batched generation with HuggingFace Transformers produces corrupted output for
all sequences except the longest (unpadded) one in the batch.

Root cause

UniversalTransformerCache inherits Cache.get_mask_sizes, which falls back to
returning (cache_position.shape[0], 0) whenever layer_idx >= len(self.layers).

Because UniversalTransformerCache manages its own flat key_cache /
value_cache lists and keeps self.layers empty ([]), this fallback always
fires. During the prefill step this happens to be correct (cache_position
spans the full input length), but during autoregressive decoding
cache_position has length 1, so the 4D attention mask is built for
kv_length=1 instead of cached_length + 1.

The undersized mask gets broadcasted across the full KV cache, losing all
per-position padding information. This corrupts every padded sequence in the
batch.
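The failure mode above can be seen with a toy function that mimics the inherited fallback (this is an illustration of the described behavior, not the library's actual code):

```python
import torch

def base_class_fallback(cache_position):
    # Mimics the inherited `return cache_position.shape[0], 0` fallback,
    # which UniversalTransformerCache always hits because self.layers is [].
    return cache_position.shape[0], 0

# Prefill: cache_position spans the whole 12-token prompt, so kv_length is right.
print(base_class_fallback(torch.arange(12)))   # (12, 0)

# Decode: cache_position holds a single position, so kv_length collapses to 1
# even though 12 tokens are already cached -- the mask no longer covers them.
print(base_class_fallback(torch.tensor([12]))) # (1, 0)
```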

Fix

Override get_mask_sizes to return the correct (seq_length + query_length, 0),
i.e. already-cached tokens plus new query tokens, matching the semantics of
DynamicCacheLayer.get_mask_sizes.
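A minimal sketch of the override, assuming the cache's flat key_cache list holds one [batch, heads, seq, head_dim] tensor per layer as described above (class and attribute names follow the description, not the exact patch):

```python
import torch

class UniversalTransformerCacheSketch:
    """Toy stand-in for the cache: only the parts get_mask_sizes touches."""

    def __init__(self):
        self.key_cache = []  # one tensor per layer, grown during generation

    def get_mask_sizes(self, cache_position, layer_idx):
        # kv_length must span cached tokens + new query tokens so the 4D
        # attention mask keeps per-position padding info for the whole cache.
        query_length = cache_position.shape[0]
        if layer_idx < len(self.key_cache):
            past_seen_tokens = self.key_cache[layer_idx].shape[-2]
        else:
            past_seen_tokens = 0  # prefill: nothing cached for this layer yet
        return past_seen_tokens + query_length, 0  # (kv_length, kv_offset)

cache = UniversalTransformerCacheSketch()
print(cache.get_mask_sizes(torch.arange(10), 0))     # prefill -> (10, 0)
cache.key_cache.append(torch.zeros(2, 4, 10, 8))     # 10 tokens now cached
print(cache.get_mask_sizes(torch.tensor([10]), 0))   # decode  -> (11, 0)
```

With this, the decode-step mask is built for the full cached length plus the new token instead of kv_length=1.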

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "ByteDance/Ouro-1.4B-Thinking"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto",
    attn_implementation="eager",
)

tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0  # <|endoftext|>

# Two prompts of different lengths
prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Without fix: shorter prompt output is corrupted
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
for i, p in enumerate(prompts):
    tokens = outputs[i][batch["input_ids"].shape[1]:]
    print(f"[{i}] {p!r} -> {tokenizer.decode(tokens, skip_special_tokens=False)[:100]}")

With batch_size > 1, generation did not work properly with attn_implementation="eager" (output degenerates into long runs of whitespace) or "sdpa" (a hard crash); the "flash_attention_2" backend worked fine.

ridger changed pull request status to merged
