Pramodith commited on
Commit
52a5378
·
1 Parent(s): 2fa5de0
Files changed (1) hide show
  1. custom_generate/generate.py +1 -0
custom_generate/generate.py CHANGED
@@ -45,6 +45,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
 
48
 
49
  while cur_length < max_length:
50
  logits = model(input_ids).logits
 
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
  max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
48
+ print(f"Starting generation with max_length: {max_length}, current length: {cur_length}")
49
 
50
  while cur_length < max_length:
51
  logits = model(input_ids).logits