File size: 1,233 Bytes
1315cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from __future__ import annotations

import torch


def sample_token(
    logits: torch.Tensor,
    *,
    temp: float,
    top_k: int = 0,
) -> torch.Tensor:
    """Choose one index along the last axis of ``logits`` per leading position.

    Args:
        logits: Unnormalised scores; the last dimension is the axis that is
            sampled over. The tensor is cast to float32 before any arithmetic.
        temp: Sampling temperature. Values ``<= 0`` select the argmax
            (no randomness); positive values are floored at ``1e-6`` before
            the scores are divided by them.
        top_k: When ``0 < top_k < size of the last axis``, sampling is
            restricted to the ``top_k`` most probable entries. Any other
            value samples over the full axis.

    Returns:
        A tensor of selected indices with the same leading dimensions as
        ``logits`` and a trailing dimension of size 1.
    """
    scores = logits.to(torch.float32)

    # Non-positive temperature means greedy selection.
    if temp <= 0.0:
        return torch.argmax(scores, dim=-1, keepdim=True)

    scaled = scores / max(temp, 1e-6)
    dist = torch.softmax(scaled, dim=-1)
    # Scrub NaN/inf entries and negatives so the sampler gets valid weights.
    dist = torch.nan_to_num(dist, nan=0.0, posinf=0.0, neginf=0.0)
    dist = torch.clamp_min(dist, 0.0)

    # Collapse all leading dimensions into a single row axis.
    lead_shape = dist.shape[:-1]
    vocab = dist.shape[-1]
    rows = dist.reshape(-1, vocab)

    # Renormalise each row; rows with no mass left are flagged first.
    totals = rows.sum(dim=-1, keepdim=True)
    dead_rows = totals <= 0
    rows = rows / totals.clamp_min(1e-12)

    if dead_rows.any():
        # A row without any probability mass falls back to a one-hot
        # distribution on index 0 so that sampling still succeeds.
        fallback = torch.zeros_like(rows)
        fallback[..., 0] = 1.0
        rows = torch.where(dead_rows.expand_as(rows), fallback, rows)

    if not (0 < top_k < vocab):
        chosen = torch.multinomial(rows, num_samples=1)
    else:
        kept_probs, kept_ids = torch.topk(rows, top_k, dim=-1)
        kept_mass = kept_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        kept_probs = kept_probs / kept_mass
        # Sample a position within the kept set, then map it back to the
        # original index along the last axis.
        local = torch.multinomial(kept_probs, num_samples=1)
        chosen = torch.gather(kept_ids, dim=-1, index=local)

    return chosen.reshape(*lead_shape, 1)