File size: 10,950 Bytes
5200189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
#!/usr/bin/env python3
"""
Interactive chat with the 1B Transformer.
Runs in an infinite conversation loop from the terminal.

Usage:
  python chat.py                                        # auto-find latest checkpoint
  python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt  # specific checkpoint
"""

import sys
import os
import glob
import time
import torch
import torch.nn.functional as F
import readline  # enables arrow keys and history in input()

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer


def _latest_numbered_checkpoint(directory, prefix):
    """Return the ``<prefix><N>.pt`` file in *directory* with the largest integer N, or None."""
    best_path, best_step = None, -1
    for path in glob.glob(os.path.join(glob.escape(directory), prefix + "*.pt")):
        step_str = os.path.basename(path)[len(prefix):-len(".pt")]
        try:
            step = int(step_str)
        except ValueError:
            # e.g. "sft_step_best.pt" -- not a numbered step, ignore it.
            continue
        if step > best_step:
            best_path, best_step = path, step
    return best_path


def find_latest_checkpoint(dpo_dir="/jfs/deepak-kumar/checkpoints_dpo",
                           sft_dir="/jfs/deepak-kumar/checkpoints_sft",
                           pt_dir="/jfs/deepak-kumar/checkpoints"):
    """Look for DPO > SFT > pretrained checkpoint.

    Within each stage a ``*_final.pt`` file (DPO/SFT only) wins; otherwise the
    numbered step file with the highest step is chosen. Step files whose
    suffix is not an integer are skipped rather than aborting the search.

    Args:
        dpo_dir: directory holding ``dpo_final.pt`` / ``dpo_step_<N>.pt``.
        sft_dir: directory holding ``sft_final.pt`` / ``sft_step_<N>.pt``.
        pt_dir: directory holding pretrained ``step_<N>.pt`` files.

    Returns:
        ``(path, is_chat)`` where ``is_chat`` is True for DPO/SFT checkpoints
        and False for pretrained ones, or ``(None, False)`` if nothing exists.
    """
    # (directory, final filename or None, numbered-file prefix, is_chat)
    stages = (
        (dpo_dir, "dpo_final.pt", "dpo_step_", True),
        (sft_dir, "sft_final.pt", "sft_step_", True),
        (pt_dir, None, "step_", False),
    )
    for directory, final_name, step_prefix, is_chat in stages:
        if final_name is not None:
            final_path = os.path.join(directory, final_name)
            if os.path.exists(final_path):
                return final_path, is_chat
        latest = _latest_numbered_checkpoint(directory, step_prefix)
        if latest is not None:
            return latest, is_chat

    return None, False


def load_model(checkpoint_path, tokenizer, device="cuda:0"):
    """Build the Transformer, load checkpoint weights, and move it to *device* in bfloat16.

    Args:
        checkpoint_path: path to a ``torch.save`` dict containing a ``"model"``
            state dict; optional ``"vocab_size"``, ``"step"`` and ``"loss"``
            entries are read when present.
        tokenizer: not used here; kept for call-site compatibility.
        device: target device string, e.g. ``"cuda:0"``.

    Returns:
        ``(model, config, step, loss)``: the eval-mode bf16 model, the
        ModelConfig it was built from (vocab possibly expanded), and the
        checkpoint's step/loss (``"?"`` when absent).
    """
    # NOTE(review): weights_only=False un-pickles arbitrary Python objects;
    # only load checkpoints from trusted locations.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config = ModelConfig()
    # SFT may add chat tokens, enlarging the embedding table. Size the config
    # from the checkpoint *before* building the model so the ~1B-parameter
    # model is constructed only once.
    saved_vocab = ckpt.get("vocab_size", config.vocab_size)
    if saved_vocab > config.vocab_size:
        config.vocab_size = saved_vocab

    model = Transformer(config)
    model.load_state_dict(ckpt["model"])
    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", "?")
    # Weights are copied into the model; drop the CPU checkpoint before the
    # device transfer to keep peak host memory down.
    del ckpt

    # Cast and move in one call so fp32 weights are never materialised on the
    # GPU (``.to(device).bfloat16()`` would briefly hold the full fp32 copy).
    model = model.to(device=device, dtype=torch.bfloat16).eval()
    torch.cuda.empty_cache()
    return model, config, step, loss


@torch.no_grad()
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
                    temperature=0.8, top_k=50, top_p=0.9,
                    repetition_penalty=1.15, device="cuda:0",
                    stop_token_ids=None):
    """Generate tokens one at a time, yielding each for streaming output.

    Every step re-runs the whole sequence through the model (no KV cache),
    adjusts the last position's logits (repetition penalty, temperature,
    top-k, top-p) and picks the next token.

    Args:
        model: callable returning ``(logits, _)`` and exposing
            ``model.config.max_seq_len``; logits are indexed as
            ``[batch, position, vocab]`` and batch size 1 is assumed.
        tokenizer: provides ``encode(prompt, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: text to continue.
        max_new_tokens: upper bound on generated tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: keep the k highest logits (``<= 0`` disables; clamped to vocab).
        top_p: nucleus threshold (``>= 1.0`` disables).
        repetition_penalty: divide positive / multiply negative logits of
            already-generated tokens by this factor (``1.0`` disables).
        device: device for the prompt tensor; its type also selects the
            autocast backend.
        stop_token_ids: extra token ids that end generation; EOS is always
            included. Stop tokens are never yielded.

    Yields:
        One string per generated token: the newly decoded text, or ``""``
        while a multi-byte character is still incomplete (so callers that
        count yields still count tokens).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = []
    prev_decoded_len = 0
    autocast_device = torch.device(device).type

    stop_ids = set(stop_token_ids) if stop_token_ids is not None else set()
    stop_ids.add(tokenizer.eos_token_id)

    for _ in range(max_new_tokens):
        # Never feed the model more positions than its context window.
        if input_ids.shape[1] >= model.config.max_seq_len:
            break

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        # Sample in fp32: bf16 is too coarse for a stable softmax/cumsum.
        logits = logits[:, -1, :].float()

        if repetition_penalty != 1.0 and generated_ids:
            # Vectorized penalty (avoids one GPU sync per previous token).
            seen = torch.tensor(sorted(set(generated_ids)),
                                device=logits.device).unsqueeze(0)
            scores = logits.gather(1, seen)
            scores = torch.where(scores > 0,
                                 scores / repetition_penalty,
                                 scores * repetition_penalty)
            logits = logits.scatter(1, seen, scores)

        if temperature <= 0:
            # Greedy decoding; dividing by a zero temperature would yield NaNs.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature

            if top_k > 0:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                # Drop a token once the probability mass strictly before it
                # already reaches top_p, so the best token is always kept.
                drop = torch.cumsum(sorted_probs, dim=-1) - sorted_probs >= top_p
                sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
                # Undo the sort: sorted_idx is a permutation of the vocab.
                logits = torch.empty_like(logits).scatter_(1, sorted_idx, sorted_logits)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()

        # Stop on any stop token (EOS, <|end|>, <|user|>)
        if token_id in stop_ids:
            break

        generated_ids.append(token_id)
        input_ids = torch.cat([input_ids, next_token.to(input_ids.dtype)], dim=1)

        full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        if full_decoded.endswith("\ufffd"):
            # Trailing replacement char = incomplete UTF-8 sequence; wait for
            # the next token instead of emitting garbage and losing the char.
            yield ""
            continue
        new_text = full_decoded[prev_decoded_len:]
        prev_decoded_len = len(full_decoded)
        yield new_text


def print_banner(step, loss, device):
    """Print the startup banner: checkpoint details followed by the slash-command help.

    Args:
        step: training step reported by the checkpoint (or "?").
        loss: training loss reported by the checkpoint (or "?").
        device: device string the model runs on.
    """
    print("\033[1;36m")  # cyan bold
    print("=" * 60)
    print("   1B TRANSFORMER — Interactive Chat")
    print("=" * 60)
    print(f"\033[0m  Checkpoint : step {step}")
    print(f"  Loss       : {loss}")
    print(f"  Device     : {device}")
    print("  Parameters : 1.106B")
    print()
    print("  \033[90mCommands:\033[0m")
    print("    \033[33m/quit\033[0m      — exit")
    print("    \033[33m/clear\033[0m     — clear conversation context")
    print("    \033[33m/temp N\033[0m    — set temperature (default 0.8)")
    print("    \033[33m/tokens N\033[0m  — set max tokens (default 512)")
    print("    \033[33m/topp N\033[0m    — set top-p (default 0.9)")
    print("    \033[33m/topk N\033[0m    — set top-k (default 50)")
    print("    \033[33m/rep N\033[0m     — set repetition penalty (default 1.15)")
    print()
    print("\033[90m" + "─" * 60 + "\033[0m")


def main():
    """Run the interactive terminal chat loop.

    Loads the checkpoint given as the first CLI argument, or the one chosen by
    ``find_latest_checkpoint``. Chat-tuned checkpoints (SFT or DPO) use the
    ``<|user|>`` / ``<|assistant|>`` / ``<|end|>`` template and stop on
    ``<|end|>`` / ``<|user|>``; base checkpoints do plain text continuation.
    Slash-commands change sampling settings at runtime. Ctrl-C / EOF at the
    prompt exits; Ctrl-C while the model is generating aborts only that reply.
    """
    device = "cuda:0"

    is_sft = False
    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        # DPO checkpoints share the SFT chat template, matching how
        # find_latest_checkpoint flags them as chat models.
        checkpoint_lower = checkpoint.lower()
        is_sft = "sft" in checkpoint_lower or "dpo" in checkpoint_lower
    else:
        result = find_latest_checkpoint()
        if result[0] is None:
            print("No checkpoint found!")
            sys.exit(1)
        checkpoint, is_sft = result

    tokenizer = get_tokenizer()

    # Add chat tokens for SFT/DPO models (only those the tokenizer lacks)
    if is_sft:
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    print(f"\n  Loading model from {checkpoint}...")
    print(f"  Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
    model, config, step, loss = load_model(checkpoint, tokenizer, device)
    print("  Model loaded!\n")

    print_banner(step, loss, device)
    if is_sft:
        print("  \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")

    # Settings (adjustable at runtime via slash-commands)
    temperature = 0.7 if is_sft else 0.8
    max_tokens = 512
    top_p = 0.9
    top_k = 50
    rep_penalty = 1.15
    context = ""  # conversation so far, prepended to each new prompt

    # Chat template tokens for SFT
    USER_START = "<|user|>\n"
    ASST_START = "<|assistant|>\n"
    TURN_END = "\n<|end|>\n"

    # Build stop token IDs for generation: end of the assistant turn, or the
    # model starting to write the next user turn itself.
    sft_stop_ids = []
    if is_sft:
        vocab = tokenizer.get_vocab()
        for tok_str in ["<|end|>", "<|user|>"]:
            if tok_str in vocab:
                sft_stop_ids.append(vocab[tok_str])

    while True:
        try:
            user_input = input("\n\033[1;32mYou:\033[0m ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            break

        if not user_input:
            continue

        # Handle commands. A malformed numeric argument (e.g. "/temp abc")
        # is reported instead of crashing the whole chat session.
        if user_input.startswith("/"):
            cmd = user_input.lower().split()
            try:
                if cmd[0] == "/quit":
                    print("Goodbye!")
                    break
                elif cmd[0] == "/clear":
                    context = ""
                    print("\033[90m  [Context cleared]\033[0m")
                    continue
                elif cmd[0] == "/temp" and len(cmd) > 1:
                    temperature = float(cmd[1])
                    print(f"\033[90m  [Temperature set to {temperature}]\033[0m")
                    continue
                elif cmd[0] == "/tokens" and len(cmd) > 1:
                    max_tokens = int(cmd[1])
                    print(f"\033[90m  [Max tokens set to {max_tokens}]\033[0m")
                    continue
                elif cmd[0] == "/topp" and len(cmd) > 1:
                    top_p = float(cmd[1])
                    print(f"\033[90m  [Top-p set to {top_p}]\033[0m")
                    continue
                elif cmd[0] == "/topk" and len(cmd) > 1:
                    top_k = int(cmd[1])
                    print(f"\033[90m  [Top-k set to {top_k}]\033[0m")
                    continue
                elif cmd[0] == "/rep" and len(cmd) > 1:
                    rep_penalty = float(cmd[1])
                    print(f"\033[90m  [Repetition penalty set to {rep_penalty}]\033[0m")
                    continue
                else:
                    print("\033[90m  Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
                    continue
            except ValueError:
                # Only the float()/int() conversions above can raise here, and
                # they run only when cmd[1] exists.
                print(f"\033[90m  [Invalid value for {cmd[0]}: {cmd[1]}]\033[0m")
                continue

        # Build prompt
        if is_sft:
            prompt = context + USER_START + user_input + TURN_END + ASST_START
        else:
            if context:
                prompt = context + "\n" + user_input
            else:
                prompt = user_input

        # Trim context if too long: leave room for max_tokens of output.
        while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
            if is_sft:
                # Drop the oldest user+assistant pair (two TURN_END-delimited parts).
                parts = context.split(TURN_END)
                if len(parts) <= 2:
                    break
                context = TURN_END.join(parts[2:])
                prompt = context + USER_START + user_input + TURN_END + ASST_START
            else:
                # Drop the oldest line of the running transcript.
                lines = prompt.split("\n")
                if len(lines) <= 2:
                    break
                prompt = "\n".join(lines[1:])

        # Generate with streaming
        print("\033[1;34mModel:\033[0m ", end="", flush=True)
        t0 = time.perf_counter()
        full_response = ""
        token_count = 0

        try:
            for token_text in generate_stream(
                model, tokenizer, prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=rep_penalty,
                device=device,
                stop_token_ids=sft_stop_ids if is_sft else None,
            ):
                print(token_text, end="", flush=True)
                full_response += token_text
                token_count += 1
        except KeyboardInterrupt:
            # Ctrl-C mid-reply aborts this generation only; the partial
            # response is kept in the context below.
            print("\n\033[90m  [Generation interrupted]\033[0m", end="")

        elapsed = time.perf_counter() - t0
        tps = token_count / max(elapsed, 1e-9)
        print(f"\n\033[90m  [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")

        # Append to context for multi-turn
        if is_sft:
            context = (context + USER_START + user_input + TURN_END +
                       ASST_START + full_response.strip() + TURN_END)
        else:
            context = prompt + full_response


if __name__ == "__main__":
    # Start the interactive chat loop only when run as a script, not on import.
    main()