| | |
| | |
| | |
| | |
| |
|
| |
|
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
| | |
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al.
    (http://arxiv.org/abs/1904.09751). Adapted from
    https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    The input tensor is not modified; when any filtering is applied a new
    tensor is returned. When both filters are disabled the input tensor is
    returned as-is.

    Args:
        logits (torch.Tensor): Logits with the vocabulary on the last dimension,
            typically of shape (batch size, vocabulary size).
        top_k (int, optional): Keep only top k tokens with highest probability (top-k filtering).
            Set to 0 to disable. Defaults to 0.
        top_p (float, optional): Keep the top tokens with a cumulative probability >= top_p
            (nucleus filtering). Values >= 1.0 disable the filter. Defaults to 1.0.
        filter_value (float, optional): The value to assign to filtered logits.
            Defaults to -float('Inf').
        min_tokens_to_keep (int, optional): Ensure that at least this number of tokens
            are kept per batch example. Defaults to 1.

    Returns:
        torch.Tensor: The filtered logits, same shape as ``logits``.
    """
    if top_k > 0:
        # Never keep fewer than min_tokens_to_keep, never ask for more than the vocab size.
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Smallest logit among the k largest, kept as a trailing singleton
        # dimension so it broadcasts against the vocabulary axis.
        kth_largest = torch.topk(logits, k).values[..., -1, None]
        # masked_fill returns a new tensor, so the caller's logits stay intact.
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark (in sorted order) tokens whose cumulative probability exceeds top_p.
        sorted_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect the leading entries so at least min_tokens_to_keep survive
            # (the shift below additionally protects one more position).
            sorted_remove[..., :min_tokens_to_keep] = False
        # Shift right so the first token that crosses the threshold is also kept.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order. The scatter
        # target must be the boolean mask itself: scattering a bool source into
        # the int64 index tensor fails PyTorch's dtype check and would not
        # produce a boolean mask anyway.
        remove_mask = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove_mask, filter_value)

    return logits
| |
|
| |
|
def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0):
    """
    Draw one token index from logits after temperature scaling and top-k/top-p filtering.

    Args:
        logits (torch.Tensor): The logits to sample from.
        top_k (int, optional): Number of highest-scoring tokens retained by the
            top-k filter. Defaults to 50.
        top_p (float, optional): Cumulative probability threshold for nucleus
            filtering. Defaults to 1.0.
        temperature (float, optional): Divisor applied to the logits before
            filtering; skipped when equal to 1.0. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token index (one sample drawn via ``torch.multinomial``).
    """
    # Only rescale when the temperature actually changes the distribution.
    scaled = logits if temperature == 1.0 else logits / temperature

    # Restrict the candidate set before converting to probabilities.
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)

    # Normalise over the vocabulary axis and draw a single sample.
    probabilities = F.softmax(filtered, dim=-1)
    return torch.multinomial(probabilities, num_samples=1)
| |
|