Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| from einops import rearrange | |
| from transformers import ( | |
| LogitsProcessorList, | |
| TemperatureLogitsWarper, | |
| TopKLogitsWarper, | |
| TopPLogitsWarper, | |
| ) | |
| def sample_logits( | |
| logits: Tensor, | |
| temperature: float = 0.8, | |
| top_k: int = 20, | |
| top_p: float = 0.9, | |
| generator: torch.Generator | None = None, | |
| ) -> Tensor: | |
| """ | |
| logits: (B, T, V) | |
| return: (B, T) | |
| """ | |
| B, T, _ = logits.shape | |
| flat_logits = rearrange(logits, "b t v -> (b t) v") | |
| processors = LogitsProcessorList( | |
| [ | |
| TemperatureLogitsWarper(temperature=temperature), | |
| TopKLogitsWarper(top_k=top_k), | |
| TopPLogitsWarper(top_p=top_p), | |
| ] | |
| ) | |
| dummy_input_ids = torch.zeros( | |
| B * T, | |
| 1, | |
| dtype=torch.long, | |
| device=logits.device, | |
| ) | |
| flat_logits = processors(dummy_input_ids, flat_logits) | |
| probs = torch.softmax(flat_logits, dim=-1) | |
| sampled = torch.multinomial(probs, num_samples=1, generator=generator) | |
| return rearrange(sampled, "(b t) 1 -> b t", b=B, t=T) | |