Fix for the test
Browse files
custom_generate/generate.py
CHANGED
|
@@ -6,7 +6,7 @@ def generate(model, input_ids, generation_config=None, left_padding=None, **kwar
|
|
| 6 |
|
| 7 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 8 |
cur_length = input_ids.shape[1]
|
| 9 |
-
max_length = generation_config.max_length or cur_length +
|
| 10 |
|
| 11 |
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
|
| 12 |
if left_padding is not None:
|
|
|
|
| 6 |
|
| 7 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 8 |
cur_length = input_ids.shape[1]
|
| 9 |
+
max_length = generation_config.max_length or cur_length + kwargs.get("max_new_tokens")
|
| 10 |
|
| 11 |
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
|
| 12 |
if left_padding is not None:
|