Pramodith commited on
Commit
2fa5de0
·
1 Parent(s): 2c23dfb
Files changed (1) hide show
  1. custom_generate/generate.py +1 -0
custom_generate/generate.py CHANGED
@@ -55,5 +55,6 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
55
  next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
56
  input_ids = torch.cat((input_ids, next_tokens), dim=-1)
57
  cur_length += 1
 
58
 
59
  return input_ids
 
55
  next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
56
  input_ids = torch.cat((input_ids, next_tokens), dim=-1)
57
  cur_length += 1
58
+ print(f"Current length: {cur_length}, Next token: {next_tokens.item()}")
59
 
60
  return input_ids