Radi Akbar commited on
Commit
83e18c2
·
1 Parent(s): 0730db2

New generate function

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +4 -0
custom_generate/generate.py CHANGED
@@ -559,6 +559,8 @@ def generate(
559
  generation_config is None or generation_config.min_length is None
560
  )
561
  generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
 
 
562
 
563
  generation_mode = generation_config.get_generation_mode(assistant_model)
564
  decoding_method = _speculative_cascades
@@ -721,6 +723,8 @@ def generate(
721
  logits_processor=prepared_logits_processor,
722
  stopping_criteria=prepared_stopping_criteria,
723
  generation_config=generation_config,
 
 
724
  **generation_mode_kwargs,
725
  **model_kwargs,
726
  )
 
559
  generation_config is None or generation_config.min_length is None
560
  )
561
  generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
562
+ alpha = model_kwargs.pop('alpha')
563
+ deferral = model_kwargs.pop('deferral')
564
 
565
  generation_mode = generation_config.get_generation_mode(assistant_model)
566
  decoding_method = _speculative_cascades
 
723
  logits_processor=prepared_logits_processor,
724
  stopping_criteria=prepared_stopping_criteria,
725
  generation_config=generation_config,
726
+ alpha=alpha,
727
+ deferral=deferral,
728
  **generation_mode_kwargs,
729
  **model_kwargs,
730
  )