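Minor bug fix: build the -inf mask with torch.full_like in top_n_sigma_sampling and complete the n_sigma docstring entry.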
Browse files
custom_generate/generate.py
CHANGED
|
@@ -16,7 +16,7 @@ def top_n_sigma_sampling(logits:torch.Tensor, temperature:float, n_sigma:float)
|
|
| 16 |
max_logit_score = torch.max(logits, dim=-1, keepdim=True).values
|
| 17 |
std = torch.std(logits, dim=-1, keepdim=True)
|
| 18 |
threshold = max_logit_score - n_sigma * std
|
| 19 |
-
filtered_logits = torch.where(logits >= threshold, logits, torch.
|
| 20 |
return filtered_logits
|
| 21 |
|
| 22 |
@torch.inference_mode()
|
|
@@ -39,7 +39,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 39 |
input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
|
| 40 |
generation_config (optional): Configuration for generation, such as max_length, pad_token_id,
|
| 41 |
and max_new_tokens.
|
| 42 |
-
n_sigma (
|
| 43 |
**kwargs: Additional keyword arguments.
|
| 44 |
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
|
|
|
| 16 |
max_logit_score = torch.max(logits, dim=-1, keepdim=True).values
|
| 17 |
std = torch.std(logits, dim=-1, keepdim=True)
|
| 18 |
threshold = max_logit_score - n_sigma * std
|
| 19 |
+
filtered_logits = torch.where(logits >= threshold, logits, torch.full_like(logits, float('-inf')))
|
| 20 |
return filtered_logits
|
| 21 |
|
| 22 |
@torch.inference_mode()
|
|
|
|
| 39 |
input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
|
| 40 |
generation_config (optional): Configuration for generation, such as max_length, pad_token_id,
|
| 41 |
and max_new_tokens.
|
| 42 |
+
n_sigma (float): The number of standard deviations below the maximum logit at which the cutoff threshold is placed for top-n-sigma sampling.
|
| 43 |
**kwargs: Additional keyword arguments.
|
| 44 |
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|