"""Shared sampling utilities for chat.py / chat_eval.py.

Stateless helpers: given a 1-D logits tensor of shape (vocab_size,), filter
the logits and/or sample a single token id. No model/tokenizer knowledge here.
"""
from __future__ import annotations

from typing import Iterable, Optional

import torch


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Divide positive logits of recent tokens by `penalty`; multiply negative ones by it.

    Mutates `logits` in place, so callers should pass a clone if they need the
    original values. `recent_tokens` may be any iterable of ints; duplicates
    are deduplicated internally.
    """
    if penalty == 1.0 or recent_tokens is None:
        return logits
    seen = {int(t) for t in recent_tokens}
    if not seen:
        return logits
    idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
    vals = logits.index_select(0, idx)
    # Positive logits shrink toward zero; negative logits are pushed further down.
    vals = torch.where(vals > 0, vals / penalty, vals * penalty)
    logits.index_copy_(0, idx, vals)
    return logits
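

# Illustrative only (hypothetical numbers): with penalty=1.2, a recent token's
# logit of +3.0 becomes 3.0 / 1.2 = 2.5, and a logit of -1.0 becomes
# -1.0 * 1.2 = -1.2; both moves make the token less likely to be re-sampled.
#
#   logits = torch.tensor([3.0, -1.0, 0.5])
#   apply_repetition_penalty(logits, recent_tokens=[0, 1], penalty=1.2)
#   # -> tensor([ 2.5000, -1.2000,  0.5000])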


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top-k logits; set the rest to -inf.

    top_k <= 0 or top_k >= vocab_size disables the filter.
    """
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    topk_vals, topk_idx = logits.topk(top_k)
    # Start from an all -inf tensor and scatter the surviving logits back in.
    mask = torch.full_like(logits, float("-inf"))
    mask.scatter_(0, topk_idx, topk_vals)
    return mask
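

# Illustrative only (hypothetical numbers): with top_k=2, only the two largest
# logits survive; everything else is set to -inf and gets probability 0 after
# softmax.
#
#   apply_top_k(torch.tensor([1.0, 3.0, 2.0, 0.0]), top_k=2)
#   # -> tensor([-inf, 3., 2., -inf])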


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep the smallest set of top tokens whose cumulative
    probability exceeds `top_p`; set the rest to -inf.

    top_p <= 0.0 or top_p >= 1.0 disables the filter.
    """
    if top_p >= 1.0 or top_p <= 0.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    mask = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold survives, which also
    # guarantees at least one token is always kept.
    mask[1:] = mask[:-1].clone()
    mask[0] = False
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    # Scatter the filtered values back into vocabulary order.
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, sorted_idx, sorted_logits)
    return out
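

# Illustrative only (hypothetical numbers): for logits [2.0, 1.0, 0.0] the
# sorted probabilities are roughly [0.665, 0.245, 0.090], so the cumulative
# sum first exceeds top_p=0.9 at the second token; the shift keeps that token
# and only the last one is dropped.
#
#   apply_top_p(torch.tensor([2.0, 1.0, 0.0]), top_p=0.9)
#   # -> tensor([2., 1., -inf])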


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Return a single sampled token id (Python int).

    logits: 1-D float tensor of shape (vocab_size,), fp32 or upcast-safe.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")
    # Work in fp32 on a clone so the caller's tensor is unchanged.
    work = logits.detach().to(torch.float32).clone()
    if repetition_penalty != 1.0 and recent_tokens is not None:
        work = apply_repetition_penalty(work, recent_tokens, repetition_penalty)
    # Temperature scaling; temperature <= 0 means greedy decoding.
    if temperature <= 0.0:
        return int(work.argmax().item())
    work = work / max(temperature, 1e-6)
    work = apply_top_k(work, top_k)
    work = apply_top_p(work, top_p)
    # Guard against all-(-inf) logits (possible if top_k/top_p filter everything out).
    if torch.isinf(work).all():
        return int(logits.argmax().item())
    probs = torch.softmax(work, dim=-1)
    # Numerical safety: replace any NaN with 0 and renormalize.
    if torch.isnan(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0)
    s = probs.sum()
    if s <= 0:
        return int(logits.argmax().item())
    probs = probs / s
    tok = torch.multinomial(probs, num_samples=1)
    return int(tok.item())
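

if __name__ == "__main__":
    # Minimal smoke test (not part of the callers' API): sample from a tiny
    # hypothetical 5-token vocabulary and sanity-check greedy decoding.
    torch.manual_seed(0)
    fake_logits = torch.tensor([0.1, 2.0, 0.3, -1.0, 0.8])

    # temperature <= 0 falls back to greedy decoding, i.e. the argmax (id 1).
    assert sample_token(fake_logits, temperature=0.0) == 1

    # Stochastic sampling with every filter engaged at once.
    recent = [1, 1, 4]  # pretend these ids were generated recently
    tok = sample_token(
        fake_logits,
        temperature=0.8,
        top_k=3,
        top_p=0.95,
        repetition_penalty=1.3,
        recent_tokens=recent,
    )
    print(f"sampled token id: {tok}")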