File size: 5,029 Bytes
c342850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Quick test of model quality with diverse prompts."""

import os, sys, time, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer

# Checkpoint selection: prefer the DPO checkpoint when it exists,
# otherwise fall back to the SFT checkpoint.
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
DEVICE = "cuda:0"  # single-GPU inference device

# Chat-template markers used to frame each conversation turn.
# main() registers these with the tokenizer as special tokens.
USER_START = "<|user|>\n"
ASST_START = "<|assistant|>\n"
TURN_END = "\n<|end|>\n"

# Diverse smoke-test questions: chit-chat, factual, ELI5, creative
# writing, and how-to prompts.
TEST_PROMPTS = [
    "Hi! How are you?",
    "What is photosynthesis?",
    "Explain gravity to a 5-year-old.",
    "Write a short poem about the ocean.",
    "What are the three states of matter?",
    "How does a computer work?",
    "What is the capital of France and why is it famous?",
    "Give me 3 tips for learning a new language.",
    "What is machine learning in simple terms?",
]


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=256,
             temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15):
    """Autoregressively sample a completion for *prompt*.

    Token-by-token loop (full forward pass each step; no KV cache) with
    repetition penalty, temperature scaling, top-k and nucleus (top-p)
    filtering.  Sampling stops on EOS / ``<|end|>`` / ``<|user|>``, after
    *max_new_tokens* tokens, or when the context exceeds 2048 tokens.

    Returns the decoded completion text with special tokens stripped.
    """
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=DEVICE)
    generated = []

    # Build the stop-token set: EOS plus the first sub-token of each chat
    # marker.  NOTE(review): assumes "<|end|>" / "<|user|>" each encode to
    # a single token — true once main() registers them as added tokens.
    stop_ids = set()
    if tokenizer.eos_token_id is not None:
        stop_ids.add(tokenizer.eos_token_id)
    for marker in ("<|end|>", "<|user|>"):
        marker_ids = tokenizer.encode(marker, add_special_tokens=False)
        if marker_ids:
            stop_ids.add(marker_ids[0])

    for _ in range(max_new_tokens):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        # Sample from the last position's logits in float32 for stability.
        logits = logits[:, -1, :].float()

        # Repetition penalty, vectorized (the per-token Python loop did
        # two tensor-index ops per previously-seen token every step).
        # Positive logits are divided, non-positive multiplied, so the
        # penalty always pushes probability down.
        if repetition_penalty != 1.0 and generated:
            seen = torch.tensor(sorted(set(generated)), dtype=torch.long,
                                device=logits.device)
            vals = logits[0, seen]
            logits[0, seen] = torch.where(vals > 0,
                                          vals / repetition_penalty,
                                          vals * repetition_penalty)

        # Temperature scaling, guarded against division by zero.
        logits = logits / max(temperature, 1e-5)

        # Top-k: mask everything strictly below the k-th largest logit.
        if top_k > 0:
            topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < topk_vals[:, -1:]] = float('-inf')

        # Nucleus (top-p): keep the smallest sorted prefix whose mass
        # exceeds top_p.  Subtracting the current probability shifts the
        # cutoff so the first token crossing the threshold is retained.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)  # computed once, reused
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            remove = cumulative - sorted_probs > top_p
            sorted_logits[remove] = float('-inf')
            # Scatter filtered values back to vocabulary order (sorted_idx
            # is a full permutation, so every position is overwritten).
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        token_id = next_token.item()

        if token_id in stop_ids:
            break

        generated.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Hard context cap — presumably the model's trained context
        # length; TODO confirm against ModelConfig.
        if input_ids.size(1) > 2048:
            break

    return tokenizer.decode(generated, skip_special_tokens=True)


def main():
    """Load the best available checkpoint and print answers to TEST_PROMPTS."""
    ckpt_name = "DPO" if "dpo" in CHECKPOINT else "SFT"
    print("=" * 70)
    print(f"  {ckpt_name} MODEL TEST")
    print("=" * 70)

    # Ensure the chat-template markers exist in the vocab before sizing
    # the model's embedding table.
    tokenizer = get_tokenizer()
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    config = ModelConfig()
    config.vocab_size = len(tokenizer)
    model = Transformer(config)

    print()
    print(f"Loading checkpoint: {CHECKPOINT}")
    # NOTE(security): weights_only=False unpickles arbitrary objects —
    # acceptable only because the checkpoint path is local and trusted.
    ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    step = ckpt.get("step", "?")
    del ckpt  # free the CPU copy before moving the model to GPU

    model = model.to(DEVICE).bfloat16().eval()
    print(f"Model loaded ({ckpt_name} step {step}, vocab {config.vocab_size})")
    mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
    print(f"GPU memory: {round(mem, 1)} GB")
    print("-" * 70)

    for i, question in enumerate(TEST_PROMPTS, 1):
        # Frame the question with the chat template, ending on the
        # assistant marker so the model writes the answer.
        prompt = USER_START + question + TURN_END + ASST_START

        print()
        print(f"[Test {i}/{len(TEST_PROMPTS)}]")
        print(f"  Q: {question}")

        t0 = time.time()
        response = generate(model, tokenizer, prompt)
        dt = time.time() - t0
        # Approximate token count by re-encoding the decoded text.
        tokens = len(tokenizer.encode(response, add_special_tokens=False))

        # Defensive strip: drop anything after a stray marker the decode
        # did not remove.
        response = response.split("<|end|>")[0].split("<|user|>")[0].strip()

        print(f"  A: {response}")
        tps = int(tokens / max(dt, 0.01))
        print(f"  [{tokens} tokens, {round(dt, 1)}s, {tps} tok/s]")
        print("-" * 70)

    print()
    print("Done!")


if __name__ == "__main__":
    main()