File size: 5,152 Bytes
267f903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
SamplingRunner: wires model, tokenizer, schedule, and sampler together.
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Callable, Optional

import torch
import yaml

from .loader import ModelWrapper, load_checkpoint
from .vocab import build_masked_input, decode_output, load_tokenizer
from ..samplers import get_sampler
from ..samplers.base import StepCallback


@dataclass
class RunConfig:
    """All settings for one generation run.

    Only ``ckpt_path`` and ``config_path`` are required; every other field has
    a default. ``SamplingRunner`` forwards the sampler-specific fields below to
    the sampler selected by ``sampler`` as a plain dict.

    Raises:
        ValueError: if ``steps`` < 1 or ``max_new_tokens`` < 0.
    """
    # Required
    ckpt_path: str      # checkpoint file handed to load_checkpoint
    config_path: str    # YAML model config; must define model.max_seq_len

    # Tokenizer
    tokenizer_path: str = "whale-tokenizer"

    # Sampler
    sampler: str = "standard"       # name resolved via get_sampler
    steps: int = 64                 # number of points in the timestep schedule
    max_new_tokens: int = 256
    temperature: float = 0.0
    top_k: int = 0
    p: float = 0.9                  # presumably a top-p threshold — confirm in ..samplers

    # Jump sampler extras
    jump_last_steps: int = 10
    jump_frac: float = 0.10
    jump_min_tokens: int = 1
    no_mask_jump: bool = True

    # GIDD sampler extras
    gidd_eps: float = 1e-4
    gidd_min_p: float = 0.0
    posterior_temperature: float = 1.0
    suppress_mask_clean: bool = False
    rho_mode: str = "w1_train_like"
    gidd_exact_mode: bool = False
    fail_on_negative_mass: bool = False

    # Runtime
    # Torch device string. SamplingRunner auto-detects CUDA/CPU only when this
    # is empty, so the "cuda" default requires a CUDA-capable machine.
    device: str = "cuda"
    dtype: str = "bf16"             # accepted values are defined by load_checkpoint
    use_ema: bool = True
    strict: bool = False
    seed: Optional[int] = 1234      # None disables re-seeding before each run

    def __post_init__(self) -> None:
        # The schedule is built with torch.linspace(..., steps): zero steps
        # would give the sampler an empty schedule (masked tokens never get
        # filled) and negative values only fail deep inside torch.
        if self.steps < 1:
            raise ValueError(f"steps must be >= 1, got {self.steps}")
        if self.max_new_tokens < 0:
            raise ValueError(
                f"max_new_tokens must be >= 0, got {self.max_new_tokens}"
            )


@dataclass
class GenerationResult:
    """Outcome of a single ``SamplingRunner.run()`` call.

    As filled in by the runner, ``total_tokens`` equals
    ``prompt_tokens + generated_tokens``.
    """

    # Decoded text, as returned by decode_output
    full_text: str          # decoded text of the whole sequence
    new_text: str           # decoded text of the generated region

    # Token accounting
    prompt_tokens: int
    generated_tokens: int
    total_tokens: int

    # Run metadata
    elapsed_s: float        # wall-clock seconds measured around the sampler call
    sampler: str            # sampler name taken from RunConfig.sampler
    steps_run: int          # schedule length taken from RunConfig.steps


class SamplingRunner:
    """
    Load once, call ``run()`` many times with different prompts.

    ``__init__`` loads the YAML config, the checkpoint (via ``load_checkpoint``)
    and the tokenizer; every ``run()`` call reuses them.

    Usage::

        runner = SamplingRunner(cfg)
        result = runner.run("The quick brown")
    """

    def __init__(self, cfg: RunConfig):
        """
        Load config, checkpoint and tokenizer described by *cfg*.

        Args:
            cfg: Run settings. ``cfg.device`` names the torch device; an empty
                value or ``"auto"`` selects CUDA when available, else CPU.

        Raises:
            KeyError: if the YAML config lacks ``model.max_seq_len``.
        """
        self.cfg = cfg

        config = _load_yaml(cfg.config_path)
        device = self._resolve_device(cfg.device)

        self.model_wrapper = load_checkpoint(
            ckpt_path=cfg.ckpt_path,
            config=config,
            device=device,
            dtype=cfg.dtype,
            use_ema=cfg.use_ema,
            strict=cfg.strict,
        )

        self.tokenizer = load_tokenizer(cfg.tokenizer_path)
        self.config = config
        self.device = device
        # Token-level constants come from the loaded model wrapper.
        self.mask_token_id = self.model_wrapper.mask_token_id
        self.vocab_size = self.model_wrapper.vocab_size
        self.max_seq_len = int(config["model"]["max_seq_len"])

    @staticmethod
    def _resolve_device(name: Optional[str]) -> torch.device:
        """Turn a device string into a ``torch.device``; empty/"auto" auto-detects."""
        if not name or name == "auto":
            name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(name)

    def _sampler_cfg(self) -> dict:
        """Collect the sampler-specific options from ``self.cfg`` into a dict."""
        cfg = self.cfg
        return {
            "temperature": cfg.temperature,
            "top_k": cfg.top_k,
            "p": cfg.p,
            "jump_last_steps": cfg.jump_last_steps,
            "jump_frac": cfg.jump_frac,
            "jump_min_tokens": cfg.jump_min_tokens,
            "no_mask_jump": cfg.no_mask_jump,
            "gidd_eps": cfg.gidd_eps,
            "gidd_min_p": cfg.gidd_min_p,
            "posterior_temperature": cfg.posterior_temperature,
            "suppress_mask_clean": cfg.suppress_mask_clean,
            "rho_mode": cfg.rho_mode,
            "gidd_exact_mode": cfg.gidd_exact_mode,
            "fail_on_negative_mass": cfg.fail_on_negative_mass,
        }

    def _synchronize(self) -> None:
        """Block until queued CUDA work on ``self.device`` completes; no-op on CPU."""
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def run(
        self,
        prompt: str = "",
        callback: Optional[StepCallback] = None,
    ) -> GenerationResult:
        """
        Generate text for *prompt* with the configured sampler.

        Args:
            prompt: Text placed before the masked region (may be empty).
            callback: Optional per-step hook, forwarded to the sampler as-is.

        Returns:
            A ``GenerationResult`` with the decoded text, token counts and the
            wall-clock time spent inside the sampler call.
        """
        cfg = self.cfg

        # Re-seed on every call so repeated runs start from the same RNG state.
        if cfg.seed is not None:
            torch.manual_seed(cfg.seed)
            if self.device.type == "cuda":
                torch.cuda.manual_seed_all(cfg.seed)

        x_init, prefix_len = build_masked_input(
            tokenizer=self.tokenizer,
            prompt=prompt,
            max_new_tokens=cfg.max_new_tokens,
            max_seq_len=self.max_seq_len,
            mask_token_id=self.mask_token_id,
            device=self.device,
        )
        total_len = x_init.shape[1]

        # Linearly decreasing schedule from t=1.0 down to t=1e-4.
        timesteps = torch.linspace(1.0, 1e-4, cfg.steps, device=self.device)
        sampler_cfg = self._sampler_cfg()
        sampler_fn = get_sampler(cfg.sampler)

        # CUDA kernels execute asynchronously: drain already-queued work
        # (input construction, schedule) BEFORE starting the clock so it is not
        # billed to the sampler, and again before stopping it so the sampler's
        # own kernels are fully counted.
        self._synchronize()
        t0 = time.perf_counter()

        x_final = sampler_fn(
            model_fn=self.model_wrapper,
            x_init=x_init,
            prefix_len=prefix_len,
            vocab_size=self.vocab_size,
            mask_token_id=self.mask_token_id,
            timesteps=timesteps,
            cfg=sampler_cfg,
            callback=callback,
        )

        self._synchronize()
        elapsed = time.perf_counter() - t0

        full_text, new_text = decode_output(self.tokenizer, x_final, prefix_len)
        generated_tokens = total_len - prefix_len

        return GenerationResult(
            full_text=full_text,
            new_text=new_text,
            prompt_tokens=prefix_len,
            generated_tokens=generated_tokens,
            total_tokens=total_len,
            elapsed_s=elapsed,
            sampler=cfg.sampler,
            steps_run=cfg.steps,
        )


def _load_yaml(path: str) -> dict:
    """
    Read the YAML file at *path* and return its top-level mapping.

    An empty (or otherwise falsy) document yields ``{}``. Parsing goes through
    ``yaml.safe_load``, which does not construct arbitrary Python objects.

    Raises:
        TypeError: if the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    # A list/scalar document would otherwise surface later as a confusing
    # indexing error on config["model"]; fail here with the file path instead.
    if not isinstance(data, dict):
        raise TypeError(
            f"expected a mapping at the top level of {path!r}, "
            f"got {type(data).__name__}"
        )
    return data