File size: 5,134 Bytes
e2bfccc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Generate a few text samples from a saved checkpoint."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import torch

# Put the repository's ``src`` directory at the front of ``sys.path`` so the
# ``taoTrain`` imports below resolve when this script is run directly.
# NOTE(review): ``parents[2]`` assumes the script lives two directory levels
# below the repository root (``<repo>/<dir>/<subdir>/<script>.py``) — confirm.
REPO_ROOT = Path(__file__).resolve().parents[2]
SRC_ROOT = REPO_ROOT.joinpath("src")
if not any(entry == str(SRC_ROOT) for entry in sys.path):
    sys.path.insert(0, str(SRC_ROOT))

from taoTrain.checkpointing.checkpoint import CheckpointManager
from taoTrain.config import ModelConfig
from taoTrain.inference.inferencer import Inferencer
from taoTrain.models import get_model


def clear_kernel_caches(model) -> None:
    """Reset per-module kernel caches across ``model``.

    Walks ``model.modules()`` and calls ``clear_kernel_cache()`` on every
    module that exposes it as a callable attribute. Modules without the
    attribute, or with a non-callable value under that name, are skipped.
    """
    hooks = (getattr(submodule, "clear_kernel_cache", None) for submodule in model.modules())
    for hook in hooks:
        if not callable(hook):
            continue
        hook()


def generate_once(
    model,
    tokenizer,
    prompt: str,
    *,
    device: torch.device,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    dtype: torch.dtype,
) -> str:
    """Sample one completion for ``prompt`` using temperature + nucleus (top-p) sampling.

    The whole sequence is re-fed to the model at every step (no KV cache).
    ``clear_kernel_caches`` is called before every forward pass and once more
    after the loop.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., labels=None)``;
            its output must be indexable with ``"logits"``.
        tokenizer: Must provide ``encode(prompt, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``; ``eos_token_id`` is optional.
        prompt: Text to continue.
        device: Device the encoded prompt is moved to.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Logit divisor, clamped to at least ``1e-6`` so that 0
            gives (near-)greedy decoding instead of a division by zero.
        top_p: Nucleus threshold; values ``>= 1.0`` disable the filter.
        dtype: Autocast dtype; autocast is only enabled on CUDA for
            float16/bfloat16.

    Returns:
        The decoded completion, without the prompt. Generation stops before
        the first EOS token, which is not included.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated: list[int] = []
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    model.eval()
    device_type = "cuda" if device.type == "cuda" else "cpu"
    autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16}
    with torch.inference_mode(), torch.autocast(
        device_type=device_type, dtype=dtype, enabled=autocast_enabled
    ):
        for _ in range(max_new_tokens):
            clear_kernel_caches(model)
            outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None)
            # Upcast before scaling: under float16 autocast a small temperature
            # (down to the 1e-6 clamp) overflows the logits to inf, which turns
            # the softmax into NaN and makes torch.multinomial raise.
            logits = outputs["logits"][:, -1, :].float() / max(temperature, 1e-6)
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                # Drop tokens once the cumulative mass exceeds top_p. Shifting
                # the mask right keeps the token that crosses the threshold, and
                # the most likely token is never removed.
                remove_sorted = cumulative > top_p
                remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
                remove_sorted[..., 0] = False
                # Map the mask back to vocabulary order row by row instead of
                # flattening the indices and writing only into batch row 0.
                remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
                logits = logits.masked_fill(remove, float("-inf"))
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            token_id = int(next_token.item())
            if eos_token_id is not None and token_id == eos_token_id:
                break
            generated.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
    clear_kernel_caches(model)
    return tokenizer.decode(generated, skip_special_tokens=True)


def main() -> None:
    """Sample completions from a saved checkpoint and store them as JSON.

    Command-line entry point:
    - Loads the tokenizer via ``Inferencer._load_tokenizer`` and builds the
      model from the checkpoint's ``config["model"]`` section.
    - Generates one completion per ``--prompt``, or per built-in prompt when
      none are given.
    - Writes the run settings plus samples to ``--output`` and echoes the same
      JSON to stdout.

    The warnings emitted here go to stderr so that stdout stays valid JSON.
    """
    parser = argparse.ArgumentParser(description="Generate a few text samples from a saved checkpoint.")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--tokenizer-path", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument(
        "--prompt", action="append", default=[], help="May be repeated; defaults to built-in prompts."
    )
    parser.add_argument("--max-new-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16")
    parser.add_argument(
        "--seed", type=int, default=None, help="Seed torch's RNG for reproducible sampling."
    )
    args = parser.parse_args()

    prompts = args.prompt or [
        "The purpose of artificial intelligence is",
        "In a small village,",
        "<user>Hello, who are you?<assistant>",
    ]
    # Any non-"cpu" device request falls back to CPU when CUDA is unavailable;
    # say so instead of silently changing the device.
    if args.device != "cpu" and not torch.cuda.is_available():
        print(
            f"[sample] CUDA is not available; falling back to CPU (requested {args.device!r}).",
            file=sys.stderr,
        )
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    dtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[args.dtype]
    if args.seed is not None:
        torch.manual_seed(args.seed)

    tokenizer = Inferencer._load_tokenizer(args.tokenizer_path)
    checkpoint_path = Path(args.checkpoint)
    checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device)
    model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {}))
    model = get_model(model_config, device=device)
    incompatible = model.load_state_dict(checkpoint["model_state"], strict=False)
    # strict=False tolerates key mismatches silently. Surface them so a
    # partially loaded model is not mistaken for the trained one.
    missing = list(getattr(incompatible, "missing_keys", None) or [])
    unexpected = list(getattr(incompatible, "unexpected_keys", None) or [])
    if missing or unexpected:
        print(
            f"[sample] warning: {len(missing)} missing and {len(unexpected)} unexpected keys "
            f"while loading {checkpoint_path} (first missing: {missing[:5]}, "
            f"first unexpected: {unexpected[:5]})",
            file=sys.stderr,
        )

    samples = []
    for prompt in prompts:
        text = generate_once(
            model,
            tokenizer,
            prompt,
            device=device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            dtype=dtype,
        )
        samples.append({"prompt": prompt, "completion": text})

    result = {
        "checkpoint": args.checkpoint,
        "tokenizer_path": args.tokenizer_path,
        "device": str(device),
        "dtype": str(dtype),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "seed": args.seed,
        "samples": samples,
    }
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8")
    print(json.dumps(result, indent=2, ensure_ascii=False))


# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()