File size: 1,562 Bytes
267f903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
whale4b -- open-source inference for the W1-4B diffusion language model.
"""
from __future__ import annotations

from pathlib import Path
from typing import Optional

_PKG_DIR = Path(__file__).resolve().parent

# Re-export key classes for advanced users
from .core.model import LangDiT, create_model
from .core.loader import load_checkpoint, ModelWrapper
from .core.runner import SamplingRunner, RunConfig, GenerationResult
from .samplers import get_sampler, list_samplers


def generate(
    checkpoint: str,
    prompt: str = "",
    *,
    sampler: str = "standard",
    steps: int = 64,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_k: int = 0,
    device: str = "",
    dtype: str = "bf16",
    seed: Optional[int] = 1234,
    config: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
) -> str:
    """
    Run a single generation end to end and return the continuation text.

    A :class:`RunConfig` is assembled from the arguments, a fresh
    :class:`SamplingRunner` is built from it, the runner is invoked on
    ``prompt``, and the ``new_text`` attribute of its result is returned.

    A new runner is constructed on every call.  When generating repeatedly
    with the same model, create a :class:`SamplingRunner` yourself and reuse
    it so the weights are not reloaded each time.

    Args:
        checkpoint: Forwarded to ``RunConfig`` as ``ckpt_path``.
        prompt: Text handed to ``SamplingRunner.run``.
        sampler, steps, max_new_tokens, temperature, top_k, device, dtype,
        seed: Forwarded unchanged to the ``RunConfig`` fields of the same
            name.
        config: Forwarded as ``config_path``.  When ``None`` or empty, the
            ``configs/whale3b.yaml`` path inside this package is used.
        tokenizer_path: Forwarded as ``tokenizer_path``.  When ``None`` or
            empty, the ``whale-tokenizer`` path inside this package is used.

    Returns:
        ``new_text`` from the result object returned by the runner.
    """
    # Truthiness (not ``is None``) decides the fallback, so an empty string
    # also selects the packaged default.
    # NOTE(review): the module docstring names the W1-4B model, yet the
    # default config file is "whale3b.yaml" -- confirm this is intended.
    if config:
        config_path = config
    else:
        config_path = str(_PKG_DIR / "configs" / "whale3b.yaml")

    if tokenizer_path:
        tok_path = tokenizer_path
    else:
        tok_path = str(_PKG_DIR / "whale-tokenizer")

    settings = {
        "ckpt_path": checkpoint,
        "config_path": config_path,
        "tokenizer_path": tok_path,
        "sampler": sampler,
        "steps": steps,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "device": device,
        "dtype": dtype,
        "seed": seed,
    }
    return SamplingRunner(RunConfig(**settings)).run(prompt).new_text