Pramodith committed on
Commit
cf9e688
·
1 Parent(s): 52a5378

log debug temp

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +6 -4
custom_generate/generate.py CHANGED
@@ -41,11 +41,14 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
41
  and max_new_tokens.
42
  n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
43
  **kwargs: Additional keyword arguments.
44
- """
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
- max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
48
- print(f"Starting generation with max_length: {max_length}, current length: {cur_length}")
 
 
 
49
 
50
  while cur_length < max_length:
51
  logits = model(input_ids).logits
@@ -56,6 +59,5 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
56
  next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
57
  input_ids = torch.cat((input_ids, next_tokens), dim=-1)
58
  cur_length += 1
59
- print(f"Current length: {cur_length}, Next token: {next_tokens.item()}")
60
 
61
  return input_ids
 
41
  and max_new_tokens.
42
  n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
43
  **kwargs: Additional keyword arguments.
44
+ """
45
  generation_config = generation_config or model.generation_config # default to the model generation config
46
  cur_length = input_ids.shape[1]
47
+ if generation_config.max_length is None:
48
+ max_length = cur_length + generation_config.max_new_tokens
49
+ else:
50
+ max_length = generation_config.max_length
51
+ print(f"Starting generation with max_length: {max_length}, current length: {cur_length} and temperature: {generation_config.temperature}")
52
 
53
  while cur_length < max_length:
54
  logits = model(input_ids).logits
 
59
  next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
60
  input_ids = torch.cat((input_ids, next_tokens), dim=-1)
61
  cur_length += 1
 
62
 
63
  return input_ids