Commit 7b9f90a (verified) by KristianS7 · Parent(s): 85a074d

Fix UniversalTransformerCache.get_mask_sizes for batched generation


## Problem

Batched generation with HuggingFace Transformers produces corrupted output for
all sequences except the longest (unpadded) one in the batch.

## Root cause

`UniversalTransformerCache` inherits `Cache.get_mask_sizes`, which falls back to
`return cache_position.shape[0], 0` when `layer_idx >= len(self.layers)`.

Because `UniversalTransformerCache` manages its own flat `key_cache` /
`value_cache` lists and keeps `self.layers` empty (`[]`), this fallback **always**
fires. During the prefill step this happens to be correct (`cache_position`
spans the full input length), but during autoregressive decoding
`cache_position` has length 1, so the 4D attention mask is built for
`kv_length=1` instead of `cached_length + 1`.

The undersized mask is then broadcast across the full KV cache, discarding all
per-position padding information. This corrupts every padded sequence in the
batch.
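
The fallback's behavior can be sketched with a minimal standalone function (`fallback_mask_sizes` is hypothetical and mirrors only the inherited fallback path, not the full `Cache` class):

```python
import torch

def fallback_mask_sizes(cache_position: torch.Tensor) -> tuple[int, int]:
    # Inherited Cache.get_mask_sizes fallback: kv_length is just the number
    # of query positions, ignoring anything already stored in the KV cache.
    return cache_position.shape[0], 0

# Prefill: cache_position spans the whole 10-token prompt, so the fallback
# happens to be correct.
print(fallback_mask_sizes(torch.arange(10)))   # (10, 0) -- correct

# Decode step: cache_position holds only the new token's position, but the
# mask must cover the 10 cached positions plus the new one.
print(fallback_mask_sizes(torch.tensor([10]))) # (1, 0) -- should be (11, 0)
```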

## Fix

Override `get_mask_sizes` to return the correct `(seq_length + query_length, 0)`,
matching the semantics of `DynamicCacheLayer.get_mask_sizes`.
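
The corrected computation can be sketched with a hypothetical minimal cache (`MiniCache` and its `cached_length` bookkeeping are illustrations, not the real class):

```python
import torch

class MiniCache:
    """Hypothetical stand-in for UniversalTransformerCache's bookkeeping."""

    def __init__(self) -> None:
        self.cached_length = 0  # tokens already in the flat key/value cache

    def get_seq_length(self, layer_idx: int = 0) -> int:
        return self.cached_length

    def get_mask_sizes(
        self, cache_position: torch.Tensor, layer_idx: int = 0
    ) -> tuple[int, int]:
        # kv_length must cover the cached tokens plus the new query tokens.
        query_length = cache_position.shape[0]
        return self.get_seq_length(layer_idx) + query_length, 0

cache = MiniCache()
# Prefill with a 10-token prompt: nothing cached yet.
print(cache.get_mask_sizes(torch.arange(10)))   # (10, 0)
cache.cached_length = 10
# Decode one token: mask spans all 10 cached positions plus the new one.
print(cache.get_mask_sizes(torch.tensor([10]))) # (11, 0)
```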

## Reproduction

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "ByteDance/Ouro-1.4B-Thinking"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0 # <|endoftext|>

# Two prompts of different lengths
prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# Without fix: shorter prompt output is corrupted
outputs = model.generate(
    **batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0
)
for i, p in enumerate(prompts):
    tokens = outputs[i][batch["input_ids"].shape[1]:]
    print(f"[{i}] {p!r} -> {tokenizer.decode(tokens, skip_special_tokens=False)[:100]}")
```

Files changed (1):
  1. modeling_ouro.py (+22, -0)
modeling_ouro.py CHANGED
```diff
@@ -195,6 +195,28 @@ class UniversalTransformerCache(Cache):
             return 0
         return cached.shape[2]
 
+    def get_mask_sizes(
+        self, cache_position: torch.Tensor, layer_idx: int = 0
+    ) -> tuple[int, int]:
+        """Return (kv_length, kv_offset) for attention mask creation.
+
+        The inherited ``Cache.get_mask_sizes`` falls back to
+        ``(cache_position.shape[0], 0)`` when ``layer_idx >= len(self.layers)``.
+        Because ``UniversalTransformerCache`` manages its own flat
+        ``key_cache`` / ``value_cache`` lists and keeps ``self.layers`` empty,
+        the fallback always fires. During the prefill step this happens to be
+        correct (``cache_position`` spans the full input), but during
+        autoregressive decoding ``cache_position`` has length 1, so the mask is
+        built for ``kv_length=1`` instead of ``cached_length + 1``. As a
+        result the 4D attention mask is too small and padding information is
+        lost, corrupting batched generation for every sequence except the
+        longest (unpadded) one.
+        """
+        query_length = cache_position.shape[0]
+        seq_length = self.get_seq_length(layer_idx)
+        kv_length = seq_length + query_length
+        return kv_length, 0
+
     def get_max_length(self) -> Optional[int]:
         return None
```