Spaces:
Sleeping
Sleeping
File size: 1,089 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 | import torch
from torch import Tensor
from einops import rearrange
from transformers import (
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
def sample_logits(
logits: Tensor,
temperature: float = 0.8,
top_k: int = 20,
top_p: float = 0.9,
generator: torch.Generator | None = None,
) -> Tensor:
"""
logits: (B, T, V)
return: (B, T)
"""
B, T, _ = logits.shape
flat_logits = rearrange(logits, "b t v -> (b t) v")
processors = LogitsProcessorList(
[
TemperatureLogitsWarper(temperature=temperature),
TopKLogitsWarper(top_k=top_k),
TopPLogitsWarper(top_p=top_p),
]
)
dummy_input_ids = torch.zeros(
B * T,
1,
dtype=torch.long,
device=logits.device,
)
flat_logits = processors(dummy_input_ids, flat_logits)
probs = torch.softmax(flat_logits, dim=-1)
sampled = torch.multinomial(probs, num_samples=1, generator=generator)
return rearrange(sampled, "(b t) 1 -> b t", b=B, t=T)
|