import torch


def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    # Mask out every logit strictly smaller than the k-th largest value in its row.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-inf"))
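

# A minimal usage sketch (the `_example_*` helpers below are illustrative
# additions, not part of the sampling API): with top_k=2, only the two largest
# logits in the row survive; everything else becomes -inf, i.e. zero
# probability after softmax.
def _example_top_k_filtering():
    x = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    modify_logits_for_top_k_filtering(x, top_k=2)
    # x is now [[-inf, 3., 2., -inf]]
    return x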


def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # Sort ascending and compute the cumulative probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop tokens whose cumulative probability (counting from the low end)
    # stays below 1 - top_p; this keeps the smallest head of the distribution
    # whose total mass is at least top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the mask back to the original (unsorted) token order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float("-inf"))
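

# A minimal usage sketch (illustrative, hypothetical helper name): with
# top_p=0.9, the low-probability tail is masked while the head of the
# distribution, holding at least 90% of the mass, is kept.
def _example_top_p_filtering():
    x = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
    modify_logits_for_top_p_filtering(x, top_p=0.9)
    # Only the least likely token is dropped here:
    # x is now [[2., 1., 0., -inf]]
    return x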


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from logits with optional top-k / top-p filtering and temperature.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    Returns:
        Tensor of sampled token indices of shape (batch_size,)
    """
    # Sanitize non-finite inputs: nan_to_num replaces NaN with 0 and clamps
    # +/-inf to the largest finite values, so no further inf handling is needed.
    logits = torch.nan_to_num(logits)
    if top_k == 1:  # Greedy decoding: just take the argmax.
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Guard against top_k > vocab size.
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            # Sample among the top-k candidates, then map back to vocab indices.
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            # Work on a copy so the in-place top-p filtering leaves `logits` untouched.
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
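

# A minimal usage sketch (illustrative, hypothetical helper name): greedy
# decoding via the default top_k=1, and stochastic decoding with combined
# top-k / top-p filtering and a temperature below 1.
def _example_sample():
    torch.manual_seed(0)
    logits = torch.randn(2, 10)
    greedy = sample(logits)  # top_k=1: plain argmax per row
    stochastic = sample(logits, top_k=5, top_p=0.9, temperature=0.8)
    return greedy, stochastic  # two index tensors of shape (2,)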