NicolasBFR commited on
Commit
86a6a26
·
verified ·
1 Parent(s): 3081418

Fix to avoid error with synced_gpus

Browse files

Add a default value to the argument `synced_gpus` to avoid this error:

```
Traceback (most recent call last):
File "myFolder/src/translationProject/main.py", line 12, in <module>
outputs = model.generate(
File "myFolder/.venv/lib64/python3.9/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "myFolder/.venv/lib/python3.9/site-packages/transformers/generation/utils.py", line 2367, in generate
return custom_generate_function(model=self, **generate_arguments)
File "/root/.cache/huggingface/modules/transformers_modules/transformers_hyphen_community/constrained_hyphen_beam_hyphen_search/3081418faf290f61bc253e649b0033adf877e655/custom_generate/generate.py", line 338, in generate
generation_outputs = GenerationMixin.generate(
File "myFolder/.venv/lib64/python3.9/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "myFolder/.venv/lib/python3.9/site-packages/transformers/generation/utils.py", line 2564, in generate
result = decoding_method(
TypeError: _constrained_beam_search() missing 1 required positional argument: 'synced_gpus'
```

Files changed (1) hide show
  1. custom_generate/generate.py +1 -1
custom_generate/generate.py CHANGED
@@ -26,7 +26,7 @@ def _constrained_beam_search(
26
  logits_processor: LogitsProcessorList,
27
  stopping_criteria: StoppingCriteriaList,
28
  generation_config: GenerationConfig,
29
- synced_gpus: bool,
30
  streamer: Optional["BaseStreamer"] = None,
31
  **model_kwargs,
32
  ) -> Union[GenerateBeamOutput, torch.LongTensor]:
 
26
  logits_processor: LogitsProcessorList,
27
  stopping_criteria: StoppingCriteriaList,
28
  generation_config: GenerationConfig,
29
+ synced_gpus: bool = False,
30
  streamer: Optional["BaseStreamer"] = None,
31
  **model_kwargs,
32
  ) -> Union[GenerateBeamOutput, torch.LongTensor]: