Fix Transformers v5 cached decoding in DoLa

#2
by lavrenko - opened
Files changed (1)
  1. custom_generate/generate.py +17 -2
custom_generate/generate.py CHANGED
@@ -229,9 +229,24 @@ def _dola_decoding(
229
  if lm_head is None:
230
  raise ValueError("DoLa is not supported for models that don't have output embeddings.")
231
 
 
232
  while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
233
- # prepare model inputs
234
- model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # forward pass to get next token
237
  outputs = model(**model_inputs, return_dict=True)
 
229
  if lm_head is None:
230
  raise ValueError("DoLa is not supported for models that don't have output embeddings.")
231
 
232
+ is_first_iteration = True
233
  while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
234
+ # Transformers v5 cache protocol: prefill uses the full prompt; later cached
235
+ # steps use only the newest token. Uncached decoding keeps the full prefix.
236
+ next_sequence_length = (
237
+ None
238
+ if is_first_iteration or not model_kwargs.get("use_cache", True)
239
+ else 1
240
+ )
241
+
242
+ model_inputs = model.prepare_inputs_for_generation(
243
+ input_ids,
244
+ next_sequence_length=next_sequence_length,
245
+ is_first_iteration=is_first_iteration,
246
+ **model_kwargs,
247
+ )
248
+
249
+ is_first_iteration = False
250
 
251
  # forward pass to get next token
252
  outputs = model(**model_inputs, return_dict=True)