cyrilvallez HF Staff commited on
Commit
b9006ff
·
verified ·
1 Parent(s): 30636bc

Update custom_generate/generate.py

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +11 -55
custom_generate/generate.py CHANGED
@@ -282,11 +282,11 @@ def _contrastive_search(
282
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
283
  "for contrastive search."
284
  )
285
- # We now only use Cache classes, but a few models have custom cache class, so we use this check instead of an instance check
286
- elif not hasattr(past_key_values, "update"):
287
  raise ValueError(
288
- f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
289
- "used for contrastive search without further modifications."
290
  )
291
 
292
  # contrastive_search main logic start:
@@ -324,19 +324,7 @@ def _contrastive_search(
324
  del outputs
325
 
326
  if not sequential:
327
- # Replicates the new past_key_values to match the `top_k` candidates
328
- if isinstance(outputs["past_key_values"], DynamicCache) or (
329
- isinstance(outputs["past_key_values"], EncoderDecoderCache)
330
- and isinstance(
331
- outputs["past_key_values"].self_attention_cache, DynamicCache
332
- )
333
- ):
334
- model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
335
- else:
336
- raise ValueError(
337
- f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
338
- "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
339
- )
340
 
341
  if sequential:
342
  all_outputs = []
@@ -352,21 +340,10 @@ def _contrastive_search(
352
  output_hidden_states=True,
353
  output_attentions=output_attentions,
354
  )
355
- if isinstance(outputs["past_key_values"], DynamicCache) or (
356
- isinstance(outputs["past_key_values"], EncoderDecoderCache)
357
- and isinstance(
358
- outputs["past_key_values"].self_attention_cache, DynamicCache
359
- )
360
- ):
361
- # Remove past K-V from output since we don't need to stack later
362
- outputs["past_key_values"] = None
363
- # Remove last token from past K-V since we don't want to append it at this point
364
- model_kwargs["past_key_values"].crop(-1)
365
- else:
366
- raise ValueError(
367
- f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
368
- "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
369
- )
370
 
371
  all_outputs.append(outputs)
372
  outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
@@ -463,17 +440,7 @@ def _contrastive_search(
463
  next_past_key_values = next_past_key_values or getattr(
464
  outputs, possible_cache_name, None
465
  )
466
- # Do it in-place layer per layer to save memory
467
- if isinstance(next_past_key_values, DynamicCache) or (
468
- isinstance(next_past_key_values, EncoderDecoderCache)
469
- and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
470
- ):
471
- next_past_key_values.batch_select_indices(augmented_idx)
472
- else:
473
- raise ValueError(
474
- f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
475
- "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
476
- )
477
 
478
  logit_for_next_step = torch.stack(torch.split(logits, top_k))[
479
  range(batch_size), selected_idx, :
@@ -549,18 +516,7 @@ def _contrastive_search(
549
  # Contrastive search works by forward looking at the next token, so we need to exclude it from
550
  # `past_key_values` to be consistent with the other decoding methods
551
  if model_kwargs.get("past_key_values") is not None:
552
- if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
553
- isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
554
- and isinstance(
555
- model_kwargs["past_key_values"].self_attention_cache, DynamicCache
556
- )
557
- ):
558
- model_kwargs["past_key_values"].crop(-1)
559
- else:
560
- raise ValueError(
561
- f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
562
- "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
563
- )
564
 
565
  if model.config.is_encoder_decoder:
566
  return GenerateEncoderDecoderOutput(
 
282
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
283
  "for contrastive search."
284
  )
285
+ # Only those caches have the necessary methods
286
+ elif not (isinstance(past_key_values, DynamicCache) or (isinstance(past_key_values, EncoderDecoderCache) and isinstance(past_key_values.self_attention_cache, DynamicCache))):
287
  raise ValueError(
288
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
289
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
290
  )
291
 
292
  # contrastive_search main logic start:
 
324
  del outputs
325
 
326
  if not sequential:
327
+ model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  if sequential:
330
  all_outputs = []
 
340
  output_hidden_states=True,
341
  output_attentions=output_attentions,
342
  )
343
+ # Remove past K-V from output since we don't need to stack later
344
+ outputs["past_key_values"] = None
345
+ # Remove last token from past K-V since we don't want to append it at this point
346
+ model_kwargs["past_key_values"].crop(-1)
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  all_outputs.append(outputs)
349
  outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
 
440
  next_past_key_values = next_past_key_values or getattr(
441
  outputs, possible_cache_name, None
442
  )
443
+ next_past_key_values.batch_select_indices(augmented_idx)
 
 
 
 
 
 
 
 
 
 
444
 
445
  logit_for_next_step = torch.stack(torch.split(logits, top_k))[
446
  range(batch_size), selected_idx, :
 
516
  # Contrastive search works by forward looking at the next token, so we need to exclude it from
517
  # `past_key_values` to be consistent with the other decoding methods
518
  if model_kwargs.get("past_key_values") is not None:
519
+ model_kwargs["past_key_values"].crop(-1)
 
 
 
 
 
 
 
 
 
 
 
520
 
521
  if model.config.is_encoder_decoder:
522
  return GenerateEncoderDecoderOutput(