Pramodith commited on
Commit
d031870
·
1 Parent(s): cf9e688

change max_length logic

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +1 -1
custom_generate/generate.py CHANGED
@@ -44,7 +44,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
44
  """
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
- if generation_config.max_length is None:
48
  max_length = cur_length + generation_config.max_new_tokens
49
  else:
50
  max_length = generation_config.max_length
 
44
  """
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
+ if generation_config.max_new_tokens:
48
  max_length = cur_length + generation_config.max_new_tokens
49
  else:
50
  max_length = generation_config.max_length