| """Shared sampling utilities for chat.py / chat_eval.py. | |
| Pure functions: given a 1-D logits tensor (vocab_size,), return a single | |
| sampled token id. No model/tokenizer knowledge here. | |
| """ | |
| from __future__ import annotations | |
| from typing import Iterable, Optional | |
| import torch | |


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Divide logits of recent positive tokens by `penalty`, multiply negatives.

    Operates in place on `logits`; pass a clone if the original tensor must be
    preserved. `recent_tokens` may be any iterable of ints; duplicates are
    deduped internally.
    """
    if penalty == 1.0 or not recent_tokens:
        return logits
    seen = {int(t) for t in recent_tokens}
    if not seen:
        return logits
    idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
    vals = logits.index_select(0, idx)
    # Penalize in both directions: shrink positive logits, amplify negative ones.
    vals = torch.where(vals > 0, vals / penalty, vals * penalty)
    logits.index_copy_(0, idx, vals)
    return logits
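
# Illustrative sketch (comment only, not executed on import): with penalty=2.0
# and recent_tokens=[1], the positive logit at index 1 is divided by the
# penalty, making the recently emitted token less likely:
#
#   >>> apply_repetition_penalty(torch.tensor([1.0, 4.0, -1.0]), [1], 2.0)
#   tensor([ 1.,  2., -1.])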


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top-k logits; set the rest to -inf.

    top_k <= 0 or top_k >= vocab_size disables the filter.
    """
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    topk_vals, topk_idx = logits.topk(top_k)
    mask = torch.full_like(logits, float("-inf"))
    mask.scatter_(0, topk_idx, topk_vals)
    return mask
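
# Quick example (comment only): top_k=2 keeps the two largest logits and
# masks everything else to -inf:
#
#   >>> apply_top_k(torch.tensor([0.1, 2.0, 1.0, -3.0]), 2)
#   tensor([-inf, 2., 1., -inf])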


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep the smallest set of tokens whose cumulative
    probability exceeds `top_p`; mask the rest to -inf."""
    if top_p >= 1.0 or top_p <= 0.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    mask = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold survives and we
    # always keep at least one token.
    mask[1:] = mask[:-1].clone()
    mask[0] = False
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    # Undo the sort: scatter the filtered logits back to their original slots.
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, sorted_idx, sorted_logits)
    return out
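
# Worked example (comment only): with logits [0, 0, 10] nearly all probability
# mass sits on index 2, so top_p=0.5 keeps just that token:
#
#   >>> apply_top_p(torch.tensor([0.0, 0.0, 10.0]), 0.5)
#   tensor([-inf, -inf, 10.])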


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Return a single sampled token id (Python int).

    logits: 1-D float tensor of shape (vocab_size,), fp32 or upcast-safe.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")
    # Work in fp32 on a clone so the caller's tensor is unchanged.
    work = logits.detach().to(torch.float32).clone()
    if repetition_penalty != 1.0 and recent_tokens is not None:
        work = apply_repetition_penalty(work, recent_tokens, repetition_penalty)
    # Temperature scaling; greedy decoding when temperature <= 0.
    if temperature <= 0.0:
        return int(work.argmax().item())
    work = work / max(temperature, 1e-6)
    work = apply_top_k(work, top_k)
    work = apply_top_p(work, top_p)
    # Guard against all-(-inf) logits (possible if top_k/top_p filter everything out).
    if torch.isinf(work).all():
        return int(logits.argmax().item())
    probs = torch.softmax(work, dim=-1)
    # Numerical safety: replace any NaN with 0 and renormalize.
    if torch.isnan(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0)
        s = probs.sum()
        if s <= 0:
            return int(logits.argmax().item())
        probs = probs / s
    tok = torch.multinomial(probs, num_samples=1)
    return int(tok.item())
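

if __name__ == "__main__":
    # Minimal smoke test over a toy 8-token vocabulary; the seed and parameter
    # values below are illustrative, not anything chat.py itself uses.
    torch.manual_seed(0)
    toy_logits = torch.randn(8)
    greedy = sample_token(toy_logits, temperature=0.0)
    sampled = sample_token(
        toy_logits,
        temperature=0.8,
        top_k=4,
        top_p=0.9,
        repetition_penalty=1.2,
        recent_tokens=[greedy],
    )
    print(f"greedy={greedy} sampled={sampled}")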