File size: 4,336 Bytes
a19b01b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
Inference script for the 1B Transformer — Single GPU.

Usage:
  python inference.py                          # auto-finds latest checkpoint
  python inference.py /path/to/checkpoint.pt   # specific checkpoint
"""

import sys
import os
import glob
import time
import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer


def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
    """Return the newest ``step_*.pt`` checkpoint in *checkpoint_dir*.

    Falls back to ``final.pt`` when no step checkpoints exist; returns
    None when neither is present.
    """
    def step_number(path):
        # "step_1234.pt" -> 1234
        return int(os.path.basename(path).split("_")[1].split(".")[0])

    candidates = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
    if candidates:
        return max(candidates, key=step_number)

    fallback = os.path.join(checkpoint_dir, "final.pt")
    return fallback if os.path.exists(fallback) else None


def load_model(checkpoint_path, device="cuda:0"):
    """Build the Transformer, load weights from *checkpoint_path*, and
    return it on *device* in bfloat16 eval mode.

    Returns (model, config).
    """
    config = ModelConfig()
    model = Transformer(config)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): weights_only=False runs the full pickle machinery —
    # only safe because these checkpoints come from our own training runs.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(state["model"])
    model = model.to(device).bfloat16().eval()

    print(f"  Step: {state.get('step', '?')} | Loss: {state.get('loss', '?')}")
    print(f"  Params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Device: {device}")

    # Drop the CPU copy of the weights before inference starts.
    del state
    torch.cuda.empty_cache()
    return model, config


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=200,
             temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    t0 = time.time()

    for i in range(max_new_tokens):
        if input_ids.shape[1] >= model.config.max_seq_len:
            break

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        logits = logits[:, -1, :] / temperature

        if top_k > 0:
            topk_vals, _ = torch.topk(logits, top_k)
            logits[logits < topk_vals[:, -1:]] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
            sorted_logits[mask] = float("-inf")
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)

    elapsed = time.time() - t0
    gen_tokens = input_ids.shape[1] - len(tokenizer.encode(prompt))
    tok_per_sec = gen_tokens / max(elapsed, 1e-9)

    text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return text, gen_tokens, tok_per_sec


def main():
    """Entry point: resolve a checkpoint, load the model, sample prompts."""
    device = "cuda:0"

    # An explicit CLI path wins; otherwise pick the newest checkpoint.
    checkpoint = sys.argv[1] if len(sys.argv) > 1 else find_latest_checkpoint()
    if checkpoint is None:
        print("No checkpoint found!")
        sys.exit(1)

    model, _config = load_model(checkpoint, device)
    tokenizer = get_tokenizer()

    prompts = [
        "The meaning of life is",
        "In machine learning, a neural network",
        "The capital of France is",
        "Once upon a time, there was a",
        "To solve a quadratic equation, you need to",
        "The theory of relativity explains that",
        "Python is a programming language that",
        "The sun rises in the east and",
    ]

    banner = "=" * 70
    rule = "─" * 60

    print("\n" + banner)
    print("  INFERENCE — 1B Transformer (Single GPU)")
    print(banner)

    for prompt in prompts:
        print("\n" + rule)
        print(f"PROMPT: {prompt}")
        print(rule)
        text, n_tok, tps = generate(model, tokenizer, prompt,
                                    max_new_tokens=150, temperature=0.8,
                                    top_k=50, device=device)
        print("OUTPUT:" + text[len(prompt):])
        print(f"  [{n_tok} tokens, {tps:.1f} tok/s]")

    print("\n" + banner)


if __name__ == "__main__":
    main()