Upload folder using huggingface_hub
Browse files
custom_generate/generate.py
CHANGED
|
@@ -88,8 +88,8 @@ def _dola_decoding(
|
|
| 88 |
logits_processor: LogitsProcessorList,
|
| 89 |
stopping_criteria: StoppingCriteriaList,
|
| 90 |
generation_config: GenerationConfig,
|
| 91 | -    synced_gpus: bool,
|
| 92 | -    streamer: "BaseStreamer",
|
| 93 |
**model_kwargs,
|
| 94 |
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 95 |
r"""
|
|
@@ -113,7 +113,7 @@ def _dola_decoding(
|
|
| 113 |
used to tell if the generation loop should stop.
|
| 114 |
generation_config ([`~generation.GenerationConfig`]):
|
| 115 |
The generation configuration to be used as parametrization of the decoding method.
|
| 116 | -        synced_gpus (`bool`):
|
| 117 |
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 118 |
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 119 |
streamer (`BaseStreamer`, *optional*):
|
|
|
|
| 88 |
logits_processor: LogitsProcessorList,
|
| 89 |
stopping_criteria: StoppingCriteriaList,
|
| 90 |
generation_config: GenerationConfig,
|
| 91 | +    synced_gpus: bool = False,
|
| 92 | +    streamer: "BaseStreamer" = None,
|
| 93 |
**model_kwargs,
|
| 94 |
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
| 95 |
r"""
|
|
|
|
| 113 |
used to tell if the generation loop should stop.
|
| 114 |
generation_config ([`~generation.GenerationConfig`]):
|
| 115 |
The generation configuration to be used as parametrization of the decoding method.
|
| 116 | +        synced_gpus (`bool`, *optional*, defaults to `False`):
|
| 117 |
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 118 |
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 119 |
streamer (`BaseStreamer`, *optional*):
|