Radi Akbar commited on
Commit ·
83e18c2
1
Parent(s): 0730db2
New generate function
Browse files
custom_generate/generate.py
CHANGED
|
@@ -559,6 +559,8 @@ def generate(
|
|
| 559 |
generation_config is None or generation_config.min_length is None
|
| 560 |
)
|
| 561 |
generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
|
|
|
|
|
|
|
| 562 |
|
| 563 |
generation_mode = generation_config.get_generation_mode(assistant_model)
|
| 564 |
decoding_method = _speculative_cascades
|
|
@@ -721,6 +723,8 @@ def generate(
|
|
| 721 |
logits_processor=prepared_logits_processor,
|
| 722 |
stopping_criteria=prepared_stopping_criteria,
|
| 723 |
generation_config=generation_config,
|
|
|
|
|
|
|
| 724 |
**generation_mode_kwargs,
|
| 725 |
**model_kwargs,
|
| 726 |
)
|
|
|
|
| 559 |
generation_config is None or generation_config.min_length is None
|
| 560 |
)
|
| 561 |
generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
|
| 562 |
+
alpha = model_kwargs.pop('alpha')
|
| 563 |
+
deferral = model_kwargs.pop('deferral')
|
| 564 |
|
| 565 |
generation_mode = generation_config.get_generation_mode(assistant_model)
|
| 566 |
decoding_method = _speculative_cascades
|
|
|
|
| 723 |
logits_processor=prepared_logits_processor,
|
| 724 |
stopping_criteria=prepared_stopping_criteria,
|
| 725 |
generation_config=generation_config,
|
| 726 |
+
alpha=alpha,
|
| 727 |
+
deferral=deferral,
|
| 728 |
**generation_mode_kwargs,
|
| 729 |
**model_kwargs,
|
| 730 |
)
|