# tiny-vllm / tiny_vllm/sampler.py
# enCoder's picture
# minimal continuous-batching LLM engine
# c32c359
"""Per-request sampling. Temperature, top-p, top-k, greedy."""
from __future__ import annotations
from typing import Optional
import torch
from .config import SamplingParams
class Sampler:
    """Pick one next-token id per sequence from a batch of logits.

    Every row of the logits matrix is processed independently, so each
    sequence can use its own greedy / temperature / top-k / top-p settings.
    """

    def __init__(self, device: torch.device) -> None:
        """Remember the target device (``sample`` itself does not read it)."""
        self.device = device

    def sample(
        self,
        logits: torch.Tensor,  # [num_seqs, vocab]
        params: list[SamplingParams],
        generators: Optional[list[Optional[torch.Generator]]] = None,
    ) -> list[int]:
        """Sample one token id for each entry of ``params``.

        Args:
            logits: 2-D tensor; row ``i`` holds the unnormalised scores for
                the sequence described by ``params[i]``.
            params: Per-sequence sampling settings. The attributes read here
                are ``is_greedy``, ``temperature``, ``top_k`` and ``top_p``.
            generators: Optional per-sequence RNGs, indexed like ``params``.
                A ``None`` entry (or omitting the list) uses torch's global
                generator for that row.

        Returns:
            A list of Python ints, one sampled token id per sequence, in the
            same order as ``params``.
        """
        out: list[int] = []
        for i, p in enumerate(params):
            row = logits[i]
            if p.is_greedy:
                out.append(int(row.argmax().item()))
                continue

            # Temperature scaling; the floor avoids dividing by zero when a
            # non-greedy request carries a (near-)zero temperature.
            row = row / max(p.temperature, 1e-5)

            # top_k <= 0 or top_k >= vocab means "no top-k restriction".
            if p.top_k > 0 and p.top_k < row.size(-1):
                row = self._apply_top_k(row, p.top_k)

            # top_p outside the open interval (0, 1) disables nucleus sampling.
            if 0.0 < p.top_p < 1.0:
                row = self._apply_top_p(row, p.top_p)

            probs = torch.softmax(row, dim=-1)
            gen = generators[i] if generators else None
            token = torch.multinomial(probs, num_samples=1, generator=gen)
            out.append(int(token.item()))
        return out

    @staticmethod
    def _apply_top_k(row: torch.Tensor, k: int) -> torch.Tensor:
        """Mask to -inf every logit strictly below the k-th largest one."""
        topk_vals, _ = torch.topk(row, k)
        # topk returns values in descending order, so the last one is the
        # cut-off. Logits tied with the cut-off are kept.
        threshold = topk_vals[-1]
        return torch.where(
            row < threshold, torch.full_like(row, float("-inf")), row
        )

    @staticmethod
    def _apply_top_p(row: torch.Tensor, top_p: float) -> torch.Tensor:
        """Keep the smallest high-probability prefix whose mass reaches top_p."""
        sorted_logits, sorted_idx = torch.sort(row, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumprobs = probs.cumsum(dim=-1)

        # A token is dropped when the tokens ranked *before* it already cover
        # more than top_p. Shifting the mask right by one expresses exactly
        # that: the token that crosses the threshold is kept, and the most
        # likely token can never be dropped. An explicit slice shift is used
        # instead of roll() so nothing wraps around from the tail.
        exceeds = cumprobs > top_p
        drop = torch.zeros_like(exceeds)
        drop[1:] = exceeds[:-1]

        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: put each filtered logit back at its vocabulary index.
        filtered = torch.full_like(row, float("-inf"))
        filtered.scatter_(0, sorted_idx, sorted_logits)
        return filtered