Pramodith commited on
Commit
2c23dfb
·
1 Parent(s): 95d9c07

Minor bug fix.

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +2 -2
custom_generate/generate.py CHANGED
@@ -16,7 +16,7 @@ def top_n_sigma_sampling(logits:torch.Tensor, temperature:float, n_sigma:float)
16
  max_logit_score = torch.max(logits, dim=-1, keepdim=True).values
17
  std = torch.std(logits, dim=-1, keepdim=True)
18
  threshold = max_logit_score - n_sigma * std
19
- filtered_logits = torch.where(logits >= threshold, logits, torch.tensor(float('-inf')))
20
  return filtered_logits
21
 
22
  @torch.inference_mode()
@@ -39,7 +39,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
39
  input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
40
  generation_config (optional): Configuration for generation, such as max_length, pad_token_id,
41
  and max_new_tokens.
42
- n_sigma (int): The number of standard deviations to use for topN-sigma sampling.
43
  **kwargs: Additional keyword arguments.
44
  """
45
  generation_config = generation_config or model.generation_config # default to the model generation config
 
16
  max_logit_score = torch.max(logits, dim=-1, keepdim=True).values
17
  std = torch.std(logits, dim=-1, keepdim=True)
18
  threshold = max_logit_score - n_sigma * std
19
+ filtered_logits = torch.where(logits >= threshold, logits, torch.full_like(logits, float('-inf')))
20
  return filtered_logits
21
 
22
  @torch.inference_mode()
 
39
  input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
40
  generation_config (optional): Configuration for generation, such as max_length, pad_token_id,
41
  and max_new_tokens.
42
+ n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
43
  **kwargs: Additional keyword arguments.
44
  """
45
  generation_config = generation_config or model.generation_config # default to the model generation config