github-actions[bot]
Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
e0552b0
import torch
from torch import Tensor
from einops import rearrange
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
def sample_logits(
logits: Tensor,
temperature: float = 0.8,
top_k: int = 20,
top_p: float = 0.9,
generator: torch.Generator | None = None,
) -> Tensor:
"""
logits: (B, T, V)
return: (B, T)
"""
B, T, _ = logits.shape
flat_logits = rearrange(logits, "b t v -> (b t) v")
processors = LogitsProcessorList(
[
TemperatureLogitsWarper(temperature=temperature),
TopKLogitsWarper(top_k=top_k),
TopPLogitsWarper(top_p=top_p),
]
)
dummy_input_ids = torch.zeros(
B * T,
1,
dtype=torch.long,
device=logits.device,
)
flat_logits = processors(dummy_input_ids, flat_logits)
probs = torch.softmax(flat_logits, dim=-1)
sampled = torch.multinomial(probs, num_samples=1, generator=generator)
return rearrange(sampled, "(b t) 1 -> b t", b=B, t=T)