Support Transformers v5 cache handling

#3
by lavrenko - opened

This PR addresses the cached-decoding regression reported in transformers-community/constrained-beam-search discussion #2 (https://huggingface.co/transformers-community/constrained-beam-search/discussions/2): deterministic constrained beam search produces correct output with use_cache=False but degenerate output with use_cache=True.

The cause appears to be the same Transformers v5 cache-handling incompatibility that was previously reported and fixed in transformers-community/group-beam-search: https://huggingface.co/transformers-community/group-beam-search/discussions/4

The fix makes the generation loop pass next_sequence_length and is_first_iteration into prepare_inputs_for_generation, matching the Transformers v5 cache protocol. This is the same compatibility pattern that was already accepted for group-beam-search.
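The toy sketch below illustrates why the loop has to tell the input-preparation hook which step it is on. It is not the Transformers implementation: select_step_inputs and trace_feeds are hypothetical names, and we assume, purely for illustration, that next_sequence_length means the number of trailing tokens the model still has to consume on the current step.

```python
from typing import List


def select_step_inputs(
    input_ids: List[int],
    is_first_iteration: bool,
    next_sequence_length: int,
) -> List[int]:
    """Hypothetical helper, not a Transformers API.

    On the first iteration nothing is cached yet, so the whole prompt must be
    processed. Afterwards the KV cache already covers the prefix, so only the
    newest tokens are fed. The meaning of `next_sequence_length` here is an
    assumption made for illustration only.
    """
    if is_first_iteration:
        return list(input_ids)
    return list(input_ids[-next_sequence_length:])


def trace_feeds(prompt: List[int], new_tokens: List[int]) -> List[List[int]]:
    """Record which tokens a cached loop would feed to the model at each step."""
    ids = list(prompt)
    feeds = []
    for step, tok in enumerate(new_tokens):
        # Only the loop knows whether this is the first step, so it passes the flag in.
        feeds.append(select_step_inputs(ids, step == 0, next_sequence_length=1))
        ids.append(tok)  # pretend the model produced `tok` on this step
    return feeds
```

For example, trace_feeds([10, 11, 12], [5, 6, 7]) returns [[10, 11, 12], [5], [6]]: the prompt is processed once and each later step feeds a single token. Whether a step is the first one is loop state, which is presumably why the flag is passed in explicitly rather than inferred by the hook.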

The reason for the change is that the KV cache should be an optimization only: with deterministic constrained beam search, use_cache=True should not change the generated result compared with use_cache=False.
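To make the "optimization only" requirement concrete, here is a self-contained toy check, assuming a deterministic model whose "cache" is just a running integer state (a hypothetical stand-in for real key/value tensors). Cached and uncached greedy decoding must agree. decode_buggy shows one plausible way a loop can break that, by re-feeding the whole prefix into a cache that already contains it; it illustrates the failure class, not the exact bug fixed in this PR.

```python
from typing import List, Optional

VOCAB = 50
MOD = 1009


def step(state: int, tokens: List[int]) -> int:
    """Fold tokens into a running state (toy stand-in for a KV cache)."""
    for tok in tokens:
        state = (state * 31 + tok + 7) % MOD
    return state


def decode(prompt: List[int], steps: int, use_cache: bool) -> List[int]:
    """Greedy decoding with a toy deterministic 'model'."""
    ids = list(prompt)
    cache: Optional[int] = None
    for _ in range(steps):
        if not use_cache:
            state = step(0, ids)           # no cache: recompute from scratch
        elif cache is None:
            state = step(0, ids)           # first iteration: process the full prompt
        else:
            state = step(cache, ids[-1:])  # later iterations: only the newest token
        if use_cache:
            cache = state
        ids.append(state % VOCAB)          # the toy model's "argmax" token
    return ids


def decode_buggy(prompt: List[int], steps: int) -> List[int]:
    """Hypothetical broken cached loop: re-feeds the full sequence every step."""
    ids = list(prompt)
    cache = 0  # starts empty, like a fresh cache
    for _ in range(steps):
        # Bug: the prefix is already folded into `cache`, but it is fed again.
        cache = step(cache, ids)
        ids.append(cache % VOCAB)
    return ids
```

Here decode([1, 2, 3], 3, use_cache=True) and decode([1, 2, 3], 3, use_cache=False) both return [1, 2, 3, 14, 3, 26], while decode_buggy([1, 2, 3], 3) returns [1, 2, 3, 14, 26, 1]: the first step matches, and the outputs diverge once the stale prefix is folded in a second time.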

I tested the PR in Colab with openai-community/gpt2, force_words_ids, num_beams=4, and do_sample=False.
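The comparison performed in the Colab runs below can be wrapped in a small helper. This is a sketch, not the script used for this PR: check_cache_invariance is a hypothetical name, and it assumes generate is any callable that takes use_cache and returns a flat list of token IDs.

```python
from typing import Callable, List, Sequence


def check_cache_invariance(generate: Callable[[bool], Sequence[int]]) -> List[int]:
    """Run a deterministic generator with and without the KV cache.

    Raises AssertionError if the two runs disagree, since the cache should
    only change speed, never the result. Returns the (shared) token IDs.
    """
    uncached = list(generate(False))
    cached = list(generate(True))
    if uncached != cached:
        # Locate the first differing position (or the shorter length if one is a prefix).
        first_diff = next(
            (i for i, (a, b) in enumerate(zip(uncached, cached)) if a != b),
            min(len(uncached), len(cached)),
        )
        raise AssertionError(
            f"use_cache changed the output at position {first_diff}: "
            f"{uncached[first_diff:first_diff + 5]} vs {cached[first_diff:first_diff + 5]}"
        )
    return cached
```

With a real model, generate could be a lambda wrapping model.generate(..., custom_generate=..., trust_remote_code=True, force_words_ids=..., num_beams=4, do_sample=False, use_cache=use_cache) and taking the first returned sequence as a list; the exact arguments used in the Colab run are not shown here.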

Transformers 5 test

python: 3.12.13
torch: 2.11.0+cu128
transformers: 5.12.0
cuda: True
model: openai-community/gpt2
custom_generate: /content/constrained-beam-search/
device: cuda:0

use_cache=False:

ids: [25, 198, 198, 16, 13, 5765, 262, 6121, 364, 287, 262, 976, 835, 345, 561, 779, 257, 3218, 5408, 13, 198, 198, 17, 13, 5765, 262, 6121, 364, 287, 11059]
text: ':\n\n1. Use the transformers in the same way you would use a regular expression.\n\n2. Use the transformers in translation'

use_cache=True:

ids: [25, 198, 198, 16, 13, 5765, 262, 6121, 364, 287, 262, 976, 835, 345, 561, 779, 257, 3218, 5408, 13, 198, 198, 17, 13, 5765, 262, 6121, 364, 287, 11059]
text: ':\n\n1. Use the transformers in the same way you would use a regular expression.\n\n2. Use the transformers in translation'

Transformers 4 compatibility test

python: 3.12.12
torch: 2.9.0+cu126
transformers: 4.57.6
cuda: True
model: openai-community/gpt2
custom_generate: /content/constrained-beam-search/
device: cuda:0

use_cache=False:

ids: [25, 198, 198, 16, 13, 5765, 262, 6121, 364, 287, 262, 976, 835, 345, 561, 779, 257, 3218, 5408, 13, 198, 198, 17, 13, 5765, 262, 6121, 364, 287, 11059]
text: ':\n\n1. Use the transformers in the same way you would use a regular expression.\n\n2. Use the transformers in translation'

use_cache=True:

ids: [25, 198, 198, 16, 13, 5765, 262, 6121, 364, 287, 262, 976, 835, 345, 561, 779, 257, 3218, 5408, 13, 198, 198, 17, 13, 5765, 262, 6121, 364, 287, 11059]
text: ':\n\n1. Use the transformers in the same way you would use a regular expression.\n\n2. Use the transformers in translation'

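All four runs (Transformers 5.12.0 and 4.57.6, each with use_cache=False and use_cache=True) produce identical token IDs. So the PR fixes the Transformers v5 cached-decoding regression and keeps the same deterministic output under Transformers 4.57.6.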

lavrenko changed pull request status to open
Transformers Community org

Yep, same issue and thanks for the PR

RaushanTurganbay changed pull request status to merged
