"""Shared sampling utilities for chat.py / chat_eval.py.

Given a 1-D logits tensor (vocab_size,), return a single sampled token id.
No model/tokenizer knowledge here. The functions are pure except
``apply_repetition_penalty``, which edits its ``logits`` argument in place
(``sample_token`` always hands it a private fp32 clone).
"""

from __future__ import annotations

from typing import Iterable, Optional

import torch


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Apply a repetition penalty to the logits of recently seen tokens.

    For every distinct id in ``recent_tokens``, a positive logit is divided by
    ``penalty`` and a non-positive logit is multiplied by it. A logit of
    exactly 0 is therefore unchanged.

    Args:
        logits: 1-D tensor of shape (vocab_size,). It is modified IN PLACE;
            clone it first if the caller's tensor must stay unchanged.
        recent_tokens: Any iterable of token ids, such as a list, deque,
            generator, or 1-D integer tensor. Duplicates are ignored. ``None``
            or an empty iterable disables the penalty. Ids must be valid
            indices into ``logits``; out-of-range ids make torch raise.
        penalty: Penalty factor. ``1.0`` is a no-op.

    Returns:
        ``logits`` itself (the same tensor object), after modification.
    """
    if penalty == 1.0 or recent_tokens is None:
        return logits
    # Materialize before testing for emptiness. Truth-testing the argument
    # directly is wrong for tensors: ``not torch.tensor([0])`` is True, so the
    # penalty was skipped. Multi-element tensors and arrays raise on bool().
    seen = {int(t) for t in recent_tokens}
    if not seen:
        return logits
    idx = torch.tensor(sorted(seen), device=logits.device, dtype=torch.long)
    vals = logits.index_select(0, idx)
    vals = torch.where(vals > 0, vals / penalty, vals * penalty)
    logits.index_copy_(0, idx, vals)
    return logits


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the ``top_k`` largest logits; set every other entry to -inf.

    ``logits`` is a 1-D tensor and is not modified; when filtering happens a
    new tensor is returned. ``top_k <= 0`` or ``top_k >= vocab`` disables the
    filter and returns ``logits`` itself.
    """
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    topk_vals, topk_idx = logits.topk(top_k)
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, topk_idx, topk_vals)
    return out


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Apply a nucleus (top-p) filter to a 1-D logits tensor.

    Tokens are ranked by probability. A token is kept when the cumulative
    probability of the tokens ranked strictly above it is <= ``top_p``.
    This keeps the smallest prefix whose total probability exceeds
    ``top_p``, including the token that crosses the threshold. The most
    likely token is always kept. Dropped entries become -inf in a new tensor.

    ``top_p >= 1.0`` or ``top_p <= 0.0`` disables the filter and returns
    ``logits`` itself.
    """
    if top_p >= 1.0 or top_p <= 0.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    remove = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold survives,
    # and never remove the top-ranked token.
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, sorted_idx, sorted_logits)
    return out


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Return a single sampled token id (Python int).

    Pipeline: repetition penalty, then either greedy decoding
    (temperature <= 0) or temperature scaling -> top-k -> top-p ->
    multinomial sampling.

    Args:
        logits: 1-D tensor of shape (vocab_size,). It is upcast to fp32 on a
            clone, so the caller's tensor is never modified.
        temperature: Softmax temperature. ``<= 0`` means greedy (argmax).
            Positive values are clamped below at 1e-6.
        top_k: Keep only the k highest-scoring tokens (``<= 0`` disables).
        top_p: Nucleus threshold (``>= 1.0`` or ``<= 0.0`` disables).
        repetition_penalty: Factor passed to ``apply_repetition_penalty``
            (``1.0`` disables).
        recent_tokens: Token ids the penalty applies to. Any iterable of ints
            or a 1-D integer tensor works.

    Returns:
        The chosen token id. If filtering or softmax leaves no usable
        distribution (everything -inf, or NaNs from inf/NaN inputs), falls
        back to the argmax of the penalized fp32 scores.

    Raises:
        ValueError: if ``logits`` is not 1-D.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")

    # Work in fp32 on a clone so the caller's tensor is unchanged.
    scores = logits.detach().to(torch.float32).clone()
    if repetition_penalty != 1.0 and recent_tokens is not None:
        scores = apply_repetition_penalty(scores, recent_tokens, repetition_penalty)

    # Greedy decoding. Top-k/top-p are skipped: both always retain the
    # top-scoring token.
    if temperature <= 0.0:
        return int(scores.argmax().item())

    # Division allocates a new tensor, so ``scores`` keeps the penalized,
    # unfiltered values for the fallbacks below.
    work = scores / max(temperature, 1e-6)
    work = apply_top_k(work, top_k)
    work = apply_top_p(work, top_p)

    # Nothing finite left (e.g. the input was all -inf). Fall back to greedy
    # on the penalized scores, consistent with the temperature <= 0 path.
    if torch.isinf(work).all():
        return int(scores.argmax().item())

    probs = torch.softmax(work, dim=-1)
    # Numerical safety: an inf or NaN in the input makes softmax emit NaNs.
    # Zero them and renormalize; if no mass remains, fall back to greedy.
    if torch.isnan(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0)
        total = probs.sum()
        if total <= 0:
            return int(scores.argmax().item())
        probs = probs / total

    tok = torch.multinomial(probs, num_samples=1)
    return int(tok.item())