change max_length logic
Browse files
custom_generate/generate.py
CHANGED
|
@@ -44,7 +44,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 44 |
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
-
if generation_config.
|
| 48 |
max_length = cur_length + generation_config.max_new_tokens
|
| 49 |
else:
|
| 50 |
max_length = generation_config.max_length
|
|
|
|
| 44 |
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
+
if generation_config.max_new_tokens:
|
| 48 |
max_length = cur_length + generation_config.max_new_tokens
|
| 49 |
else:
|
| 50 |
max_length = generation_config.max_length
|