| """Per-request sampling. Temperature, top-p, top-k, greedy.""" |
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
|
|
| from .config import SamplingParams |
|
|
|
|
class Sampler:
    """Pick one next-token id per request from a batch of logits.

    Every request has its own ``SamplingParams``, so rows are processed one
    at a time: greedy requests take the argmax, all others go through
    temperature scaling, optional top-k and top-p filtering, and a
    multinomial draw.
    """

    def __init__(self, device: torch.device) -> None:
        """Remember the target device.

        Args:
            device: Stored on the instance; the methods visible in this
                class do not read it.
        """
        self.device = device

    def sample(
        self,
        logits: torch.Tensor,
        params: list[SamplingParams],
        generators: Optional[list[Optional[torch.Generator]]] = None,
    ) -> list[int]:
        """Sample one token id for each entry of ``params``.

        Args:
            logits: Unnormalised scores; ``logits[i]`` is the row used for
                ``params[i]``. Each row is treated as a 1-D vector
                (presumably over the vocabulary).
            params: Per-request settings. This method reads ``is_greedy``,
                ``temperature``, ``top_k`` and ``top_p`` from each entry.
            generators: Optional per-request RNGs, indexed like ``params``.
                ``None`` (or an empty list) falls back to torch's default
                generator; an individual entry may also be ``None``.

        Returns:
            A list of sampled token ids, one per entry of ``params``, in the
            same order.
        """
        out: list[int] = []
        for i, p in enumerate(params):
            row = logits[i]
            if p.is_greedy:
                out.append(int(row.argmax().item()))
                continue

            # Clamp the divisor so a zero/near-zero temperature on a
            # non-greedy request cannot produce inf/NaN logits.
            row = row / max(p.temperature, 1e-5)
            row = self._apply_top_k(row, p.top_k)
            row = self._apply_top_p(row, p.top_p)

            probs = torch.softmax(row, dim=-1)
            gen = generators[i] if generators else None
            token = torch.multinomial(probs, num_samples=1, generator=gen)
            out.append(int(token.item()))
        return out

    @staticmethod
    def _apply_top_k(row: torch.Tensor, top_k: int) -> torch.Tensor:
        """Mask every logit strictly below the ``top_k``-th largest to -inf.

        The filter is skipped when ``top_k`` is non-positive or not smaller
        than the row length. Logits tied with the k-th value are all kept.
        """
        if top_k <= 0 or top_k >= row.size(-1):
            return row
        topk_vals, _ = torch.topk(row, top_k)
        # torch.topk returns values in descending order, so the last one is
        # the admission threshold.
        return torch.where(
            row < topk_vals[-1], torch.full_like(row, float("-inf")), row
        )

    @staticmethod
    def _apply_top_p(row: torch.Tensor, top_p: float) -> torch.Tensor:
        """Keep the smallest high-probability prefix whose mass exceeds ``top_p``.

        The filter is skipped unless ``0 < top_p < 1``. The token whose
        cumulative probability first crosses ``top_p`` is kept, and the
        highest-probability token is always kept.
        """
        if not 0.0 < top_p < 1.0:
            return row

        sorted_logits, sorted_idx = torch.sort(row, descending=True)
        cumprobs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Shift the "over threshold" mask right by one so the crossing token
        # itself survives, then clear slot 0 (which received the wrapped-
        # around last flag). Slot 0 must NOT be cleared before the roll:
        # that flag would land on slot 1 and force the runner-up token to be
        # kept even when the top token alone already exceeds top_p.
        drop = (cumprobs > top_p).roll(shifts=1, dims=0)
        drop[0] = False

        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: write the filtered logits back to their original
        # positions.
        filtered = torch.full_like(row, float("-inf"))
        filtered.scatter_(0, sorted_idx, sorted_logits)
        return filtered
|
|