KristianS7 committed on
Commit
9ff42b9
·
verified ·
1 Parent(s): 1ed0425

Fix UniversalTransformerCache.get_mask_sizes for batched generation

## Problem

Batched generation (batch_size > 1) produces corrupted output for all sequences except the longest (unpadded) one in the batch.

`UniversalTransformerCache` inherits `Cache.get_mask_sizes`, which falls back to `return cache_position.shape[0], 0` when `layer_idx >= len(self.layers)`.

Since `UniversalTransformerCache` manages its own flat `key_cache`/`value_cache` lists and keeps `self.layers` empty (`[]`), this fallback **always** fires:

- **Prefill**: works correctly (`cache_position` spans full input length)
- **Autoregressive decoding**: **fails** — `cache_position` has length 1, so the 4D attention mask is built for `kv_length=1` instead of `cached_length + 1`

This undersized mask gets broadcasted across the full KV cache, losing per-position padding information and corrupting every padded sequence in the batch.
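A self-contained toy (not the model's actual code, shapes assumed for illustration) showing why broadcasting an undersized mask destroys padding information: with left padding, the 2D attention mask flags pad positions per key column, but a mask built for `kv_length=1` carries only the last column, and broadcasting repeats it across every cached position.

```python
import torch

# Batch of 1, left-padded: two pad tokens followed by three real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
kv_length = attention_mask.shape[1]

# Correctly sized 4D mask (1, 1, 1, 5) vs. one built for kv_length=1.
full = attention_mask[:, None, None, :].float()
undersized = attention_mask[:, None, None, -1:].float()  # only the last column

# Additive masking of attention scores, as done before softmax.
scores = torch.zeros(1, 1, 1, kv_length)
ok = scores + (1 - full) * torch.finfo(torch.float32).min
bad = scores + (1 - undersized) * torch.finfo(torch.float32).min  # broadcasts

# The full mask leaves 3 attendable positions; the undersized one leaves all 5,
# so the pad tokens leak into attention.
print((ok == 0).sum().item(), (bad == 0).sum().item())  # 3 5
```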

The failure manifests differently depending on the attention implementation:
- `"eager"`: garbled whitespace output
- `"sdpa"`: `RuntimeError: (*bias): last dimension must be contiguous`
- `"flash_attention_2"`: unaffected (ignores the 4D mask)

## Fix

Override `get_mask_sizes` to return the correct `(seq_length + query_length, 0)`, matching the semantics of `DynamicCacheLayer.get_mask_sizes`.
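As a quick sanity check of the sizes involved, here is a minimal sketch with assumed toy values (not taken from the model) contrasting the inherited fallback with the overridden behavior:

```python
import torch

# Toy decoding step: 10 tokens already in the KV cache, 1 new token.
cached_length = 10
cache_position = torch.tensor([10])  # holds only the new token's position

# Inherited Cache.get_mask_sizes fallback (fires because self.layers is empty):
wrong_kv_length = cache_position.shape[0]                  # 1

# Overridden semantics, matching DynamicCacheLayer.get_mask_sizes:
fixed_kv_length = cached_length + cache_position.shape[0]  # 11

print(wrong_kv_length, fixed_kv_length)  # 1 11
```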

## Reproduction

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "ByteDance/Ouro-1.4B" # or Ouro-2.6B
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0

prompts = ["What is 2+2?", "Explain why the sky is blue in one sentence."]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**batch, max_new_tokens=64, do_sample=False, eos_token_id=2, pad_token_id=0)
```

See also: same fix merged for the Thinking variant — https://huggingface.co/ByteDance/Ouro-1.4B-Thinking/discussions/5

Files changed (1)
  1. modeling_ouro.py +12 -0
modeling_ouro.py CHANGED

```diff
@@ -214,6 +214,18 @@ class UniversalTransformerCache(Cache):
         self.key_cache[idx] = key_entry.index_select(0, beam_idx.to(device))
         self.value_cache[idx] = value_entry.index_select(0, beam_idx.to(device))
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int = 0) -> tuple[int, int]:
+        """Return (kv_length, kv_offset) accounting for cached tokens.
+
+        The inherited Cache.get_mask_sizes checks ``self.layers`` which is
+        always empty for UniversalTransformerCache, causing it to return
+        ``(query_length, 0)`` instead of ``(cached_length + query_length, 0)``
+        during autoregressive decoding.
+        """
+        query_length = cache_position.shape[0]
+        seq_length = self.get_seq_length(layer_idx)
+        return seq_length + query_length, 0
+
     @property
     def is_compileable(self) -> bool:
         return False
```