DoLa custom generation produces gibberish with use_cache=True

#1
by lavrenko - opened

Problem

transformers-community/dola produces different and degenerate output when
use_cache=True, while the same deterministic generation call with
use_cache=False produces a coherent continuation.

Since the KV cache should only be an optimization, enabling use_cache should
not materially change the generated token sequence under deterministic decoding.
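To make the "cache is only an optimization" argument concrete, here is a toy sketch. The `key` and `next_token_from_keys` functions are hypothetical and are not the DoLa or transformers code. It is a deterministic greedy decoder that either recomputes a per-token "key" for the whole sequence at every step, or caches the keys and computes one only for the newest token. As long as each cached entry is computed once, at the correct position, both loops emit the same tokens.

```python
VOCAB = 50

def key(token: int, pos: int) -> int:
    # Hypothetical stand-in for a per-position key: depends on token and position.
    return (token * 31 + pos * 7) % 97

def next_token_from_keys(keys: list[int]) -> int:
    # Deterministic "argmax": a pure function of all keys seen so far.
    return sum(keys) % VOCAB

def greedy_uncached(prompt: list[int], steps: int) -> list[int]:
    seq = list(prompt)
    for _ in range(steps):
        # No cache: recompute keys for the entire sequence at every step.
        keys = [key(t, i) for i, t in enumerate(seq)]
        seq.append(next_token_from_keys(keys))
    return seq[len(prompt):]

def greedy_cached(prompt: list[int], steps: int) -> list[int]:
    seq = list(prompt)
    # Prefill: compute keys for the full prompt once.
    cache = [key(t, i) for i, t in enumerate(seq)]
    for _ in range(steps):
        tok = next_token_from_keys(cache)
        seq.append(tok)
        # Decoding step: only the newly generated token needs a new key.
        cache.append(key(tok, len(seq) - 1))
    return seq[len(prompt):]

print(greedy_cached([3, 14, 15], 20) == greedy_uncached([3, 14, 15], 20))
```

The cached loop does less work per step, but it is a pure optimization: it feeds the same keys into the same deterministic decision.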

This looks similar to the recently reported and fixed cache-handling issue in
transformers-community/group-beam-search:

Minimal reproduction

import platform

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

print("python:", platform.python_version())
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("cuda:", torch.cuda.is_available())

m = "Qwen/Qwen3-0.6B"
custom_generate = "transformers-community/dola"

tok = AutoTokenizer.from_pretrained(m)
model = AutoModelForCausalLM.from_pretrained(
    m,
    device_map="auto",
)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

print("model:", m)
print("custom_generate:", custom_generate)
print("device:", model.device)

inputs = tok(
    "The most popular ways of using transformers are",
    return_tensors="pt",
).to(model.device)

outs = []
for use_cache in [False, True]:
    out = model.generate(
        **inputs,
        custom_generate=custom_generate,
        trust_remote_code=True,
        dola_layers="low",
        do_sample=False,
        min_new_tokens=10,
        max_new_tokens=30,
        use_cache=use_cache,
        output_hidden_states=True,
        repetition_penalty=1.2,
        pad_token_id=tok.pad_token_id,
    )

    new = out[0, inputs["input_ids"].shape[-1]:]
    outs.append(new.tolist())

    print("\nuse_cache =", use_cache)
    print("ids:", new.tolist())
    print("text:", repr(tok.decode(new, skip_special_tokens=True)))

assert outs[0] == outs[1]

Environment and observed output

python: 3.12.13
torch: 2.11.0+cu128
transformers: 5.12.0
cuda: True
Loading weights: 100% 311/311 [00:00<00:00, 918.98it/s]
model: Qwen/Qwen3-0.6B
custom_generate: transformers-community/dola
device: cuda:0
[transformers] The following generation flags are not valid and may be ignored:
['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.

use_cache = False
ids: [304, 279, 2355, 7982, 1849, 11, 323, 1083, 304, 12785,
8357, 13, 576, 6028, 990, 374, 311, 8317, 9072, 4802, 504,
825, 1992, 311, 2441, 13, 80532, 646, 387, 1483]
text: ' in the power distribution system, and also in industrial applications. The primary use is to transfer electric energy from one place to another. Transformers can be used'

use_cache = True
ids: [304, 1112, 1447, 2146, 25, 3110, 271, 716, 311, 1112,
198, 279, 11, 30, 264, 220, 4701, 4226, 1429, 13, 320, 279,
374, 537, 576, 279, 369, 1667, 86870, 7611]
text: ' in...:\n\n...\n\n: example\n\n _ to...\n the,? a transform answer most. ( the is not The the for using transformers devices'
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_9209/2819012582.py in <cell line: 0>()
     53     print("text:", repr(tok.decode(new, skip_special_tokens=True)))
     54 
---> 55 assert outs[0] == outs[1]

AssertionError:
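A small helper, written only for this report and not part of transformers, shows exactly where the two runs split. It uses the token ids copied from the output above:

```python
from typing import List, Optional

def first_divergence(a: List[int], b: List[int]) -> Optional[int]:
    """Return the first index where two token-id lists differ, or None if identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        # One list is a strict prefix of the other.
        return min(len(a), len(b))
    return None

# Generated token ids copied from the observed output above.
uncached = [304, 279, 2355, 7982, 1849, 11, 323, 1083, 304, 12785,
            8357, 13, 576, 6028, 990, 374, 311, 8317, 9072, 4802, 504,
            825, 1992, 311, 2441, 13, 80532, 646, 387, 1483]
cached = [304, 1112, 1447, 2146, 25, 3110, 271, 716, 311, 1112,
          198, 279, 11, 30, 264, 220, 4701, 4226, 1429, 13, 320, 279,
          374, 537, 576, 279, 369, 1667, 86870, 7611]

idx = first_divergence(uncached, cached)
print("first divergence at generated position:", idx)  # -> 1
```

Both runs produce token 304 first and split at generated position 1. In the cached run, the first token comes from the prefill forward pass over the prompt, and every later token comes from a cached decoding step. This pattern is consistent with a problem that starts after prefill, though it does not prove it.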

Expected behavior

use_cache=True and use_cache=False should produce the same generated token
sequence for this deterministic generation call. At a minimum, enabling the
cache should not turn a coherent continuation into a nonsensical one.

Why this looks cache-related

The only generation argument that differs between the two calls in the
reproduction above is use_cache.

The uncached path produces a coherent continuation:

in the power distribution system, and also in industrial applications. The
primary use is to transfer electric energy from one place to another.
Transformers can be used

The cached path produces nonsensical / degenerate text:

in...:

...

: example

 _ to...
 the,? a transform answer most. ( the is not The the for using transformers
devices

This suggests that, after the prefill step, the cached decoding path may be
passing the model an incorrect token sequence or cache position.

This appears to be the same class of issue recently found in
transformers-community/group-beam-search, where after the prefill step the
custom generation loop needed to distinguish the first iteration from subsequent
cached decoding iterations.
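The following toy sketch shows this class of bug. It assumes, without confirming it for the DoLa code, that the loop keeps passing the full sequence to the model after prefill. `toy_forward` is a hypothetical stand-in for a decoder forward pass: it appends one cache entry per token it receives, at positions continuing from the current cache length. Without a first-iteration check, every decoding step re-appends the whole sequence at shifted positions. With the check, only the prefill sees the full prompt, and later steps pass just the newest token.

```python
VOCAB = 50

def key(token: int, pos: int) -> int:
    # Hypothetical stand-in for a per-position key: depends on token and position.
    return (token * 31 + pos * 7) % 97

def toy_forward(new_tokens: list[int], cache: list[int]) -> int:
    # Hypothetical decoder forward with a cache: the tokens passed in are
    # assumed to continue right after what is already cached.
    start = len(cache)
    for offset, t in enumerate(new_tokens):
        cache.append(key(t, start + offset))
    # Deterministic "argmax" over everything in the cache.
    return sum(cache) % VOCAB

def generate_uncached(prompt: list[int], steps: int) -> list[int]:
    seq = list(prompt)
    for _ in range(steps):
        # Fresh cache every step: always a full forward pass over the sequence.
        seq.append(toy_forward(seq, []))
    return seq[len(prompt):]

def generate_cached(prompt: list[int], steps: int,
                    first_iteration_check: bool) -> tuple[list[int], list[int]]:
    seq = list(prompt)
    cache: list[int] = []
    for step in range(steps):
        if first_iteration_check and step > 0:
            model_input = seq[-1:]  # cached decoding: only the newest token
        else:
            model_input = seq       # prefill; without the check, also every later step
        seq.append(toy_forward(model_input, cache))
    return seq[len(prompt):], cache

prompt = [3, 14, 15]
reference = generate_uncached(prompt, 5)
fixed, fixed_cache = generate_cached(prompt, 5, first_iteration_check=True)
buggy, buggy_cache = generate_cached(prompt, 5, first_iteration_check=False)

print(fixed == reference)                  # True
print(buggy[0] == reference[0])            # True: the prefill step is unaffected
print(buggy[1] == reference[1])            # False: the first cached step diverges
print(len(fixed_cache), len(buggy_cache))  # 7 25
```

In this toy, the version without the check matches the reference on the first generated token and differs at the second. This resembles the logged ids above, where both runs start with token 304 and split immediately after. The resemblance is only suggestive of the root cause in the DoLa loop.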

Relevant prior discussion and fix:

Transformers Community org

Merged

RaushanTurganbay changed discussion status to closed
