dkumar15 commited on
Commit
5200189
·
verified ·
1 Parent(s): f099982

Upload training_code/chat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_code/chat.py +318 -0
training_code/chat.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Interactive chat with the 1B Transformer.
4
+ Runs in an infinite conversation loop from the terminal.
5
+
6
+ Usage:
7
+ python chat.py # auto-find latest checkpoint
8
+ python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt # specific checkpoint
9
+ """
10
+
11
+ import sys
12
+ import os
13
+ import glob
14
+ import time
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import readline # enables arrow keys and history in input()
18
+
19
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
20
+ from model.config import ModelConfig
21
+ from model.transformer import Transformer
22
+ from model.data import get_tokenizer
23
+
24
+
25
def find_latest_checkpoint():
    """Locate the most relevant checkpoint: DPO > SFT > pretrained.

    Returns:
        tuple[str | None, bool]: ``(path, is_chat)`` where ``path`` is the
        checkpoint file (or ``None`` if nothing was found) and ``is_chat``
        is True for a chat-tuned checkpoint (DPO/SFT), False for a base
        pretrained one.
    """

    def _latest(directory, prefix, check_final=True):
        # Prefer an explicit "<prefix>final.pt"; otherwise pick the file
        # with the highest numeric step among "<prefix>step_*.pt".
        if check_final:
            final = os.path.join(directory, f"{prefix}final.pt")
            if os.path.exists(final):
                return final
        candidates = glob.glob(os.path.join(directory, f"{prefix}step_*.pt"))
        if candidates:
            return max(
                candidates,
                key=lambda f: int(
                    os.path.basename(f).split("step_")[1].split(".")[0]
                ),
            )
        return None

    # Chat-tuned checkpoints first; DPO beats SFT.
    for directory, prefix in (
        ("/jfs/deepak-kumar/checkpoints_dpo", "dpo_"),
        ("/jfs/deepak-kumar/checkpoints_sft", "sft_"),
    ):
        path = _latest(directory, prefix)
        if path is not None:
            return path, True

    # Fall back to a base pretrained checkpoint (no "final.pt" convention
    # exists for pretraining, so skip that check).
    path = _latest("/jfs/deepak-kumar/checkpoints", "", check_final=False)
    if path is not None:
        return path, False

    return None, False
55
+
56
+
57
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
    """Build a Transformer, restore it from a checkpoint, and prepare it
    for inference on *device* in bfloat16.

    The ``tokenizer`` argument is accepted for interface compatibility but
    is not consulted here; vocab sizing comes from the checkpoint itself.

    Returns:
        (model, config, step, loss) — the eval-mode model, its config, and
        the training step / loss recorded in the checkpoint ("?" if absent).
    """
    config = ModelConfig()
    model = Transformer(config)

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted locations.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # SFT runs may have expanded the vocab with chat tokens; rebuild the
    # model with the larger embedding table so the weights fit.
    ckpt_vocab = state.get("vocab_size", config.vocab_size)
    if ckpt_vocab > config.vocab_size:
        config.vocab_size = ckpt_vocab
        model = Transformer(config)

    model.load_state_dict(state["model"])
    model = model.to(device).bfloat16().eval()

    step, loss = state.get("step", "?"), state.get("loss", "?")

    # Drop the CPU copy of the weights before returning.
    del state
    torch.cuda.empty_cache()

    return model, config, step, loss
75
+
76
+
77
+ @torch.no_grad()
78
+ def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
79
+ temperature=0.8, top_k=50, top_p=0.9,
80
+ repetition_penalty=1.15, device="cuda:0",
81
+ stop_token_ids=None):
82
+ """Generate tokens one at a time, yielding each for streaming output."""
83
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
84
+ generated_ids = []
85
+ prev_decoded_len = 0
86
+
87
+ if stop_token_ids is None:
88
+ stop_token_ids = set()
89
+ else:
90
+ stop_token_ids = set(stop_token_ids)
91
+ stop_token_ids.add(tokenizer.eos_token_id)
92
+
93
+ for _ in range(max_new_tokens):
94
+ if input_ids.shape[1] >= model.config.max_seq_len:
95
+ break
96
+
97
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
98
+ logits, _ = model(input_ids)
99
+
100
+ logits = logits[:, -1, :]
101
+
102
+ if repetition_penalty != 1.0 and generated_ids:
103
+ prev_tokens = torch.tensor(generated_ids, device=device).unique()
104
+ for token_id in prev_tokens:
105
+ if logits[0, token_id] > 0:
106
+ logits[0, token_id] /= repetition_penalty
107
+ else:
108
+ logits[0, token_id] *= repetition_penalty
109
+
110
+ logits = logits / temperature
111
+
112
+ if top_k > 0:
113
+ topk_vals, _ = torch.topk(logits, top_k)
114
+ logits[logits < topk_vals[:, -1:]] = float("-inf")
115
+
116
+ if top_p < 1.0:
117
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
118
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
119
+ mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
120
+ sorted_logits[mask] = float("-inf")
121
+ logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
122
+
123
+ probs = F.softmax(logits, dim=-1)
124
+ next_token = torch.multinomial(probs, num_samples=1)
125
+ token_id = next_token.item()
126
+
127
+ # Stop on any stop token (EOS, <|end|>, <|user|>)
128
+ if token_id in stop_token_ids:
129
+ break
130
+
131
+ generated_ids.append(token_id)
132
+ input_ids = torch.cat([input_ids, next_token], dim=1)
133
+
134
+ full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
135
+ new_text = full_decoded[prev_decoded_len:]
136
+ prev_decoded_len = len(full_decoded)
137
+ yield new_text
138
+
139
+ return
140
+
141
+
142
def print_banner(step, loss, device):
    """Print the startup banner: checkpoint metadata plus the command list."""
    separator = "=" * 60
    banner = [
        "\033[1;36m",  # switch to bold cyan for the header box
        separator,
        " 1B TRANSFORMER — Interactive Chat",
        separator,
        f"\033[0m Checkpoint : step {step}",
        f" Loss : {loss}",
        f" Device : {device}",
        " Parameters : 1.106B",
        "",
        " \033[90mCommands:\033[0m",
        " \033[33m/quit\033[0m — exit",
        " \033[33m/clear\033[0m — clear conversation context",
        " \033[33m/temp N\033[0m — set temperature (default 0.8)",
        " \033[33m/tokens N\033[0m — set max tokens (default 512)",
        " \033[33m/topp N\033[0m — set top-p (default 0.9)",
        " \033[33m/topk N\033[0m — set top-k (default 50)",
        " \033[33m/rep N\033[0m — set repetition penalty (default 1.15)",
        "",
        "\033[90m" + "─" * 60 + "\033[0m",
    ]
    print("\n".join(banner))
162
+
163
+
164
def main():
    """Run the interactive chat REPL.

    Picks a checkpoint (CLI arg or auto-discovery), loads the model, then
    loops forever reading user input, handling slash-commands, and
    streaming generated responses while maintaining multi-turn context.
    """
    device = "cuda:0"

    # Checkpoint selection: explicit path from argv wins; otherwise search.
    # SFT/DPO detection from the path name decides chat vs completion mode.
    is_sft = False
    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        is_sft = "sft" in checkpoint.lower()
    else:
        result = find_latest_checkpoint()
        if result[0] is None:
            print("No checkpoint found!")
            sys.exit(1)
        checkpoint, is_sft = result

    tokenizer = get_tokenizer()

    # Add chat tokens for SFT models (only those not already in the vocab,
    # so re-running against an already-extended tokenizer is a no-op).
    if is_sft:
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    print(f"\n Loading model from {checkpoint}...")
    print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
    model, config, step, loss = load_model(checkpoint, tokenizer, device)
    print(f" Model loaded!\n")

    print_banner(step, loss, device)
    if is_sft:
        print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")

    # Sampling settings (mutable via slash-commands below). SFT models get
    # a slightly lower default temperature for more focused replies.
    temperature = 0.7 if is_sft else 0.8
    max_tokens = 512
    top_p = 0.9
    top_k = 50
    rep_penalty = 1.15
    context = ""  # accumulated multi-turn transcript

    # Chat template tokens for SFT
    USER_START = "<|user|>\n"
    ASST_START = "<|assistant|>\n"
    TURN_END = "\n<|end|>\n"

    # Build stop token IDs for generation: <|end|> finishes the assistant
    # turn; <|user|> catches the model hallucinating the next user turn.
    sft_stop_ids = []
    if is_sft:
        vocab = tokenizer.get_vocab()
        for tok_str in ["<|end|>", "<|user|>"]:
            if tok_str in vocab:
                sft_stop_ids.append(vocab[tok_str])

    while True:
        try:
            user_input = input("\n\033[1;32mYou:\033[0m ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D exit cleanly instead of tracebacking.
            print("\n\nGoodbye!")
            break

        if not user_input:
            continue

        # Handle commands (anything starting with "/"). Each command either
        # breaks out of the loop or continues to the next prompt.
        if user_input.startswith("/"):
            cmd = user_input.lower().split()
            if cmd[0] == "/quit":
                print("Goodbye!")
                break
            elif cmd[0] == "/clear":
                context = ""
                print("\033[90m [Context cleared]\033[0m")
                continue
            elif cmd[0] == "/temp" and len(cmd) > 1:
                # NOTE(review): float()/int() on user input below will raise
                # ValueError on bad values and crash the loop — confirm
                # whether that is acceptable for this internal tool.
                temperature = float(cmd[1])
                print(f"\033[90m [Temperature set to {temperature}]\033[0m")
                continue
            elif cmd[0] == "/tokens" and len(cmd) > 1:
                max_tokens = int(cmd[1])
                print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
                continue
            elif cmd[0] == "/topp" and len(cmd) > 1:
                top_p = float(cmd[1])
                print(f"\033[90m [Top-p set to {top_p}]\033[0m")
                continue
            elif cmd[0] == "/topk" and len(cmd) > 1:
                top_k = int(cmd[1])
                print(f"\033[90m [Top-k set to {top_k}]\033[0m")
                continue
            elif cmd[0] == "/rep" and len(cmd) > 1:
                rep_penalty = float(cmd[1])
                print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
                continue
            else:
                print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
                continue

        # Build prompt: SFT uses the chat template; base models get plain
        # newline-joined completion context.
        if is_sft:
            prompt = context + USER_START + user_input + TURN_END + ASST_START
        else:
            if context:
                prompt = context + "\n" + user_input
            else:
                prompt = user_input

        # Trim context if too long: drop the oldest exchange (SFT) or the
        # oldest line (base) until the prompt leaves room for max_tokens.
        while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
            if is_sft:
                parts = context.split(TURN_END)
                if len(parts) <= 2:
                    break  # nothing left to drop; proceed with what we have
                # Dropping 2 parts removes one user+assistant exchange.
                context = TURN_END.join(parts[2:])
                prompt = context + USER_START + user_input + TURN_END + ASST_START
            else:
                lines = prompt.split("\n")
                if len(lines) <= 2:
                    break
                prompt = "\n".join(lines[1:])

        # Generate with streaming: print each text delta as it arrives.
        print("\033[1;34mModel:\033[0m ", end="", flush=True)
        t0 = time.time()
        full_response = ""
        token_count = 0

        for token_text in generate_stream(
            model, tokenizer, prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=rep_penalty,
            device=device,
            stop_token_ids=sft_stop_ids if is_sft else None,
        ):
            print(token_text, end="", flush=True)
            full_response += token_text
            token_count += 1

        elapsed = time.time() - t0
        tps = token_count / max(elapsed, 1e-9)
        print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")

        # Append to context for multi-turn. SFT re-serializes the exchange
        # with template tokens; base mode just keeps the raw transcript.
        if is_sft:
            context = (context + USER_START + user_input + TURN_END +
                       ASST_START + full_response.strip() + TURN_END)
        else:
            context = prompt + full_response


if __name__ == "__main__":
    main()