File size: 1,089 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from torch import Tensor
from einops import rearrange
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


def sample_logits(
    logits: Tensor,
    temperature: float = 0.8,
    top_k: int | None = 20,
    top_p: float | None = 0.9,
    generator: torch.Generator | None = None,
) -> Tensor:
    """Sample one token id per position from filtered logits.

    Applies temperature scaling, then top-k, then top-p (nucleus) filtering
    using the Hugging Face logits warpers, and draws one sample per (b, t)
    position with ``torch.multinomial``.

    Args:
        logits: Unnormalized scores of shape ``(B, T, V)``.
        temperature: Softmax temperature. ``0`` means greedy decoding
            (argmax); negative values are rejected. Ints are accepted.
        top_k: Keep only the ``top_k`` highest-scoring tokens per position.
            ``None`` or a value ``<= 0`` disables top-k filtering.
        top_p: Nucleus threshold in ``[0, 1]``; ``1.0`` keeps every token.
            ``None`` disables top-p filtering.
        generator: Optional RNG forwarded to ``torch.multinomial``.

    Returns:
        Long tensor of sampled token ids with shape ``(B, T)``.

    Raises:
        ValueError: If ``logits`` is not 3-D, ``temperature`` is negative,
            or ``top_p`` is outside ``[0, 1]`` (checked by the HF warper).
    """
    if logits.dim() != 3:
        raise ValueError(
            f"logits must have shape (B, T, V), got {tuple(logits.shape)}"
        )
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    B, T, _ = logits.shape

    flat_logits = rearrange(logits, "b t v -> (b t) v")
    # Filter and normalize in at least float32: fp16/bf16 softmax loses
    # probability mass in the tail (float64 input is left as-is).
    flat_logits = flat_logits.to(
        torch.promote_types(flat_logits.dtype, torch.float32)
    )

    if flat_logits.shape[0] == 0:
        # Nothing to sample; return an empty (B, T) result directly.
        return torch.zeros(B, T, dtype=torch.long, device=logits.device)

    if temperature == 0:
        # Greedy decoding (the T -> 0 limit of softmax(logits / T)).
        # Top-k / top-p always keep the best token, so they cannot change
        # the argmax and are skipped.
        greedy = flat_logits.argmax(dim=-1)
        return rearrange(greedy, "(b t) -> b t", b=B, t=T)

    # TemperatureLogitsWarper rejects non-float values (e.g. int 1), and
    # TopKLogitsWarper rejects non-int / non-positive values, so coerce and
    # treat top_k <= 0 as "disabled".
    warpers = [TemperatureLogitsWarper(temperature=float(temperature))]
    if top_k is not None and top_k > 0:
        warpers.append(TopKLogitsWarper(top_k=int(top_k)))
    if top_p is not None:
        warpers.append(TopPLogitsWarper(top_p=float(top_p)))
    processors = LogitsProcessorList(warpers)

    # The warpers' call signature requires input_ids; the values are unused
    # by temperature/top-k/top-p, so a zero placeholder suffices.
    dummy_input_ids = torch.zeros(
        B * T,
        1,
        dtype=torch.long,
        device=logits.device,
    )

    flat_logits = processors(dummy_input_ids, flat_logits)

    probs = torch.softmax(flat_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1, generator=generator)

    return rearrange(sampled, "(b t) 1 -> b t", b=B, t=T)