File size: 3,634 Bytes
c2bf4b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Shared sampling utilities for chat.py / chat_eval.py.

Pure functions: given a 1-D logits tensor (vocab_size,), return a single
sampled token id. No model/tokenizer knowledge here.
"""

from __future__ import annotations

from typing import Iterable, Optional

import torch


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Penalize token ids that appeared recently.

    Positive logits at those positions are divided by `penalty`; negative
    ones are multiplied by it (the CTRL-style scheme, so the push is always
    away from re-selection). Mutates `logits` in place — callers clone
    beforehand — and returns it. `recent_tokens` may be any iterable of
    ints; duplicates are collapsed before indexing.
    """
    if penalty == 1.0 or not recent_tokens:
        return logits
    unique_ids = {int(tok) for tok in recent_tokens}
    if not unique_ids:
        return logits
    positions = torch.tensor(sorted(unique_ids), device=logits.device, dtype=torch.long)
    selected = logits[positions]
    penalized = torch.where(selected > 0, selected / penalty, selected * penalty)
    logits[positions] = penalized
    return logits


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Mask all but the k highest logits with -inf.

    A non-positive `top_k`, or one at least the vocab size, disables the
    filter and returns `logits` unchanged.
    """
    vocab = logits.size(-1)
    if not 0 < top_k < vocab:
        return logits
    keep_vals, keep_idx = torch.topk(logits, top_k)
    filtered = torch.full_like(logits, float("-inf"))
    filtered[keep_idx] = keep_vals
    return filtered


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filter over a 1-D logits tensor.

    Keeps the highest-probability tokens whose cumulative softmax mass
    first reaches `top_p`; every other position becomes -inf. Values of
    `top_p` outside (0, 1) disable the filter.
    """
    if not 0.0 < top_p < 1.0:
        return logits
    ordered, order = torch.sort(logits, descending=True)
    cum_mass = torch.cumsum(torch.softmax(ordered, dim=-1), dim=-1)
    drop = cum_mass > top_p
    # Shift the cutoff one slot right so the single best token always survives.
    drop = torch.roll(drop, 1)
    drop[0] = False
    ordered = ordered.masked_fill(drop, float("-inf"))
    # Undo the sort: write the filtered values back to their original slots.
    result = torch.full_like(logits, float("-inf"))
    result[order] = ordered
    return result


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Sample a single token id from 1-D logits.

    Pipeline: repetition penalty -> temperature -> top-k -> top-p ->
    multinomial draw. A `temperature` of zero or below means greedy argmax
    (after the repetition penalty). The caller's tensor is never modified.

    Raises ValueError when `logits` is not 1-D. Returns a Python int.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")

    # All arithmetic happens on a private fp32 copy.
    scores = logits.detach().to(torch.float32).clone()

    if repetition_penalty != 1.0 and recent_tokens is not None:
        scores = apply_repetition_penalty(scores, recent_tokens, repetition_penalty)

    # Greedy decoding for non-positive temperature.
    if temperature <= 0.0:
        return int(scores.argmax().item())
    scores = scores / max(temperature, 1e-6)

    scores = apply_top_k(scores, top_k)
    scores = apply_top_p(scores, top_p)

    # If the filters knocked out every candidate, fall back to a plain
    # argmax over the untouched input.
    if torch.isinf(scores).all():
        return int(logits.argmax().item())

    probs = torch.softmax(scores, dim=-1)
    if torch.isnan(probs).any():
        # Scrub NaNs and renormalize; bail to argmax if no mass survives.
        probs = torch.nan_to_num(probs, nan=0.0)
        total = probs.sum()
        if total <= 0:
            return int(logits.argmax().item())
        probs = probs / total

    return int(torch.multinomial(probs, num_samples=1).item())