manueldeprada HF Staff committed on
Commit
e4e1114
·
verified ·
1 Parent(s): 005232e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +3 -3
custom_generate/generate.py CHANGED
@@ -88,8 +88,8 @@ def _dola_decoding(
88
  logits_processor: LogitsProcessorList,
89
  stopping_criteria: StoppingCriteriaList,
90
  generation_config: GenerationConfig,
91
- synced_gpus: bool,
92
- streamer: "BaseStreamer",
93
  **model_kwargs,
94
  ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
95
  r"""
@@ -113,7 +113,7 @@ def _dola_decoding(
113
  used to tell if the generation loop should stop.
114
  generation_config ([`~generation.GenerationConfig`]):
115
  The generation configuration to be used as parametrization of the decoding method.
116
- synced_gpus (`bool`):
117
  Whether to continue running the while loop until max_length (needed to avoid deadlocking with
118
  `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
119
  streamer (`BaseStreamer`, *optional*):
 
88
  logits_processor: LogitsProcessorList,
89
  stopping_criteria: StoppingCriteriaList,
90
  generation_config: GenerationConfig,
91
+ synced_gpus: bool = False,
92
+ streamer: "BaseStreamer" = None,
93
  **model_kwargs,
94
  ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
95
  r"""
 
113
  used to tell if the generation loop should stop.
114
  generation_config ([`~generation.GenerationConfig`]):
115
  The generation configuration to be used as parametrization of the decoding method.
116
+ synced_gpus (`bool`, *optional*, defaults to `False`):
117
  Whether to continue running the while loop until max_length (needed to avoid deadlocking with
118
  `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
119
  streamer (`BaseStreamer`, *optional*):