File size: 2,040 Bytes
c32c359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Per-request sampling.  Temperature, top-p, top-k, greedy."""
from __future__ import annotations

from typing import Optional

import torch

from .config import SamplingParams


class Sampler:
    """Pick one next-token id per sequence from a batch of logits.

    Each row is handled independently according to its own ``SamplingParams``:
    greedy rows take the argmax; all other rows are temperature-scaled,
    optionally top-k and/or top-p (nucleus) filtered, and then sampled with
    ``torch.multinomial``.
    """

    def __init__(self, device: torch.device) -> None:
        # Kept for callers; ``sample`` itself does not read it.
        self.device = device

    @staticmethod
    def _apply_top_k(row: torch.Tensor, top_k: int) -> torch.Tensor:
        """Set every logit below the k-th largest to -inf.

        A non-positive ``top_k`` or one >= the vocab size disables the filter.
        Logits tied with the k-th largest value survive, so more than ``top_k``
        tokens may remain.
        """
        if top_k <= 0 or top_k >= row.size(-1):
            return row
        kth_largest = torch.topk(row, top_k).values[-1]
        return row.masked_fill(row < kth_largest, float("-inf"))

    @staticmethod
    def _apply_top_p(row: torch.Tensor, top_p: float) -> torch.Tensor:
        """Nucleus filter: keep the highest-probability tokens up to and
        including the first one whose cumulative probability exceeds ``top_p``;
        set all others to -inf.

        ``top_p`` outside the open interval (0, 1) disables the filter.
        """
        if not 0.0 < top_p < 1.0:
            return row
        sorted_logits, sorted_idx = torch.sort(row, descending=True)
        cumprobs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Token i is dropped when the mass of the tokens ranked above it
        # (cumprobs[i - 1]) already exceeds top_p. Token 0 is always kept, and
        # so is the token that crosses the threshold.
        drop = torch.zeros_like(cumprobs, dtype=torch.bool)
        drop[1:] = cumprobs[:-1] > top_p
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: write each (possibly masked) logit back to its vocab slot.
        return torch.full_like(row, float("-inf")).scatter_(0, sorted_idx, sorted_logits)

    def sample(
        self,
        logits: torch.Tensor,                 # [num_seqs, vocab]
        params: list[SamplingParams],
        generators: Optional[list[Optional[torch.Generator]]] = None,
    ) -> list[int]:
        """Sample one token id for each entry in ``params``.

        Args:
            logits: Per-sequence logits; row ``i`` is paired with ``params[i]``.
            params: Sampling settings for each row (greedy flag, temperature,
                top_k, top_p).
            generators: Optional per-row RNGs passed to ``torch.multinomial``.
                When ``None`` or empty, the global RNG is used for every row.

        Returns:
            A list of token ids, one per entry in ``params``.

        Raises:
            IndexError: if ``generators`` is non-empty but shorter than
                ``params``, or ``logits`` has fewer rows than ``params``.
        """
        out: list[int] = []
        for i, p in enumerate(params):
            row = logits[i]
            if p.is_greedy:
                out.append(int(row.argmax().item()))
                continue

            # Upcast 16-bit logits: softmax + cumsum over a large vocab in half
            # precision is lossy enough to shift the top-p boundary. float64
            # input is left as float64.
            row = row.to(torch.promote_types(row.dtype, torch.float32))
            # Clamp the divisor so a zero/negative temperature cannot divide by
            # zero or invert the logit ordering.
            row = row / max(p.temperature, 1e-5)
            row = self._apply_top_k(row, p.top_k)
            row = self._apply_top_p(row, p.top_p)

            probs = torch.softmax(row, dim=-1)
            gen = generators[i] if generators else None
            token = torch.multinomial(probs, num_samples=1, generator=gen)
            out.append(int(token.item()))
        return out