Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,233 Bytes
aa16b75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from __future__ import annotations
import torch
def sample_token(
logits: torch.Tensor,
*,
temp: float,
top_k: int = 0,
) -> torch.Tensor:
logits32 = logits.to(torch.float32)
if temp <= 0.0:
return torch.argmax(logits32, dim=-1, keepdim=True)
probs = torch.softmax(logits32 / max(temp, 1e-6), dim=-1)
probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
probs = torch.clamp_min(probs, 0.0)
flat = probs.reshape(-1, probs.shape[-1])
norm = flat.sum(dim=-1, keepdim=True)
zero_mask = norm <= 0
norm = norm.clamp_min(1e-12)
flat = flat / norm
if zero_mask.any():
filler = torch.zeros_like(flat)
filler[..., 0] = 1.0
mask = zero_mask.expand_as(flat)
flat = torch.where(mask, filler, flat)
vocab = flat.shape[-1]
if top_k > 0 and top_k < vocab:
topv, indices = torch.topk(flat, top_k, dim=-1)
topv = topv / topv.sum(dim=-1, keepdim=True).clamp_min(1e-12)
draws = torch.multinomial(topv, num_samples=1)
picks = torch.gather(indices, dim=-1, index=draws)
else:
picks = torch.multinomial(flat, num_samples=1)
picks = picks.reshape(*probs.shape[:-1], 1)
return picks
|