#!/usr/bin/env python3
"""
Interactive chat with the 1B Transformer.

Runs in an infinite conversation loop from the terminal.

Usage:
    python chat.py                                               # auto-find latest checkpoint
    python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt   # specific checkpoint
"""
import sys
import os
import glob
import time

import torch
import torch.nn.functional as F
import readline  # noqa: F401 -- side-effect import: enables arrow keys and history in input()

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer


def find_latest_checkpoint():
    """Look for DPO > SFT > pretrained checkpoint.

    Returns:
        (path, is_chat): `path` is the chosen checkpoint file (or None if
        nothing was found); `is_chat` is True for DPO/SFT checkpoints, which
        expect the chat template, and False for the pretrained fallback.
    """
    dpo_dir = "/jfs/deepak-kumar/checkpoints_dpo"
    sft_dir = "/jfs/deepak-kumar/checkpoints_sft"
    pt_dir = "/jfs/deepak-kumar/checkpoints"

    def step_of(path, prefix):
        # "<prefix>12345.pt" -> 12345; used to pick the most recent step file.
        return int(os.path.basename(path).split(prefix)[1].split(".")[0])

    # Prefer DPO: a final snapshot first, else the highest-numbered step.
    dpo_final = os.path.join(dpo_dir, "dpo_final.pt")
    if os.path.exists(dpo_final):
        return dpo_final, True
    dpo_files = glob.glob(os.path.join(dpo_dir, "dpo_step_*.pt"))
    if dpo_files:
        return max(dpo_files, key=lambda f: step_of(f, "dpo_step_")), True

    # Then SFT, same scheme.
    sft_final = os.path.join(sft_dir, "sft_final.pt")
    if os.path.exists(sft_final):
        return sft_final, True
    sft_files = glob.glob(os.path.join(sft_dir, "sft_step_*.pt"))
    if sft_files:
        return max(sft_files, key=lambda f: step_of(f, "sft_step_")), True

    # Fall back to pretrained ("step_NNN.pt"); not a chat model.
    pt_files = glob.glob(os.path.join(pt_dir, "step_*.pt"))
    if pt_files:
        return max(pt_files, key=lambda f: step_of(f, "step_")), False

    return None, False


def load_model(checkpoint_path, tokenizer, device="cuda:0"):
    """Load a Transformer checkpoint onto `device` in bfloat16 eval mode.

    `tokenizer` is accepted for interface parity with the caller but is not
    consulted here; the vocab size comes from the checkpoint itself.

    Returns:
        (model, config, step, loss) — `step`/`loss` are "?" when the
        checkpoint does not record them.
    """
    # NOTE(review): weights_only=False deserializes arbitrary pickle objects —
    # only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config = ModelConfig()
    # SFT may have expanded the vocab with chat special tokens; size the model
    # to match BEFORE building it. (The original constructed Transformer twice
    # when the vocab was expanded; build it once with the final config.)
    saved_vocab = ckpt.get("vocab_size", config.vocab_size)
    if saved_vocab > config.vocab_size:
        config.vocab_size = saved_vocab
    model = Transformer(config)
    model.load_state_dict(ckpt["model"])
    model = model.to(device).bfloat16().eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", "?")
    del ckpt  # drop the CPU copy of the weights before chatting
    torch.cuda.empty_cache()
    return model, config, step, loss


@torch.no_grad()
def generate_stream(model, tokenizer, prompt, max_new_tokens=512, temperature=0.8,
                    top_k=50, top_p=0.9, repetition_penalty=1.15, device="cuda:0",
                    stop_token_ids=None):
    """Generate tokens one at a time, yielding each for streaming output.

    Yields the newly decoded text after each sampled token (may be empty while
    a multi-token character sequence is still incomplete). Generation stops on
    EOS, any id in `stop_token_ids`, `max_new_tokens`, or the model's context
    limit.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = []
    prev_decoded_len = 0  # chars of generated text already yielded

    stop_token_ids = set() if stop_token_ids is None else set(stop_token_ids)
    stop_token_ids.add(tokenizer.eos_token_id)

    for _ in range(max_new_tokens):
        if input_ids.shape[1] >= model.config.max_seq_len:
            break  # out of context window

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, _ = model(input_ids)
        logits = logits[:, -1, :]  # next-token logits only

        # CTRL-style repetition penalty: shrink positive logits, grow the
        # magnitude of negative ones, for every token generated so far.
        # (Vectorized with torch.where; the original looped over ids in Python.)
        if repetition_penalty != 1.0 and generated_ids:
            prev = torch.tensor(generated_ids, device=device).unique()
            vals = logits[0, prev]
            logits[0, prev] = torch.where(
                vals > 0, vals / repetition_penalty, vals * repetition_penalty
            )

        logits = logits / temperature

        # Top-k: drop everything below the k-th largest logit.
        if top_k > 0:
            topk_vals, _ = torch.topk(logits, top_k)
            logits[logits < topk_vals[:, -1:]] = float("-inf")

        # Top-p (nucleus): keep the smallest prefix of the sorted distribution
        # whose cumulative mass reaches top_p.
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)  # computed once (original did it twice)
            cum_probs = torch.cumsum(sorted_probs, dim=-1)
            # Mask tokens whose cumulative mass *before* them already >= top_p,
            # so the highest-probability token is always kept.
            sorted_logits[cum_probs - sorted_probs >= top_p] = float("-inf")
            # Scatter the filtered values back to vocabulary order; sorted_idx
            # is a permutation, so every position is overwritten.
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()

        # Stop on any stop token (EOS, <|end|>, <|user|>)
        if token_id in stop_token_ids:
            break

        generated_ids.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Re-decode the whole generation and emit only the new suffix — robust
        # to tokenizers where one token is not one character.
        full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        new_text = full_decoded[prev_decoded_len:]
        prev_decoded_len = len(full_decoded)
        yield new_text


def print_banner(step, loss, device):
    """Print the startup banner: checkpoint info and the command help."""
    print("\033[1;36m")  # cyan bold
    print("=" * 60)
    print(" 1B TRANSFORMER — Interactive Chat")
    print("=" * 60)
    print(f"\033[0m Checkpoint : step {step}")
    print(f" Loss : {loss}")
    print(f" Device : {device}")
    print(" Parameters : 1.106B")
    print()
    print(" \033[90mCommands:\033[0m")
    print(" \033[33m/quit\033[0m — exit")
    print(" \033[33m/clear\033[0m — clear conversation context")
    print(" \033[33m/temp N\033[0m — set temperature (default 0.8)")
    print(" \033[33m/tokens N\033[0m — set max tokens (default 512)")
    print(" \033[33m/topp N\033[0m — set top-p (default 0.9)")
    print(" \033[33m/topk N\033[0m — set top-k (default 50)")
    print(" \033[33m/rep N\033[0m — set repetition penalty (default 1.15)")
    print()
    print("\033[90m" + "─" * 60 + "\033[0m")


def main():
    """Run the interactive chat loop until /quit or Ctrl-C/Ctrl-D."""
    device = "cuda:0"
    is_sft = False
    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        is_sft = "sft" in checkpoint.lower()
    else:
        result = find_latest_checkpoint()
        if result[0] is None:
            print("No checkpoint found!")
            sys.exit(1)
        checkpoint, is_sft = result

    tokenizer = get_tokenizer()
    # Add chat tokens for SFT models (only the ones the tokenizer lacks).
    if is_sft:
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    print(f"\n Loading model from {checkpoint}...")
    print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
    model, config, step, loss = load_model(checkpoint, tokenizer, device)
    print(" Model loaded!\n")
    print_banner(step, loss, device)
    if is_sft:
        print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")

    # Sampling settings (adjustable at runtime via the / commands).
    temperature = 0.7 if is_sft else 0.8
    max_tokens = 512
    top_p = 0.9
    top_k = 50
    rep_penalty = 1.15
    context = ""

    # Chat template tokens for SFT
    USER_START = "<|user|>\n"
    ASST_START = "<|assistant|>\n"
    TURN_END = "\n<|end|>\n"

    # Build stop token IDs for generation (chat turn delimiters).
    sft_stop_ids = []
    if is_sft:
        vocab = tokenizer.get_vocab()
        for tok_str in ["<|end|>", "<|user|>"]:
            if tok_str in vocab:
                sft_stop_ids.append(vocab[tok_str])

    while True:
        try:
            user_input = input("\n\033[1;32mYou:\033[0m ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            break
        if not user_input:
            continue

        # Handle commands
        if user_input.startswith("/"):
            cmd = user_input.lower().split()
            if cmd[0] == "/quit":
                print("Goodbye!")
                break
            try:
                if cmd[0] == "/clear":
                    context = ""
                    print("\033[90m [Context cleared]\033[0m")
                elif cmd[0] == "/temp" and len(cmd) > 1:
                    temperature = float(cmd[1])
                    print(f"\033[90m [Temperature set to {temperature}]\033[0m")
                elif cmd[0] == "/tokens" and len(cmd) > 1:
                    max_tokens = int(cmd[1])
                    print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
                elif cmd[0] == "/topp" and len(cmd) > 1:
                    top_p = float(cmd[1])
                    print(f"\033[90m [Top-p set to {top_p}]\033[0m")
                elif cmd[0] == "/topk" and len(cmd) > 1:
                    top_k = int(cmd[1])
                    print(f"\033[90m [Top-k set to {top_k}]\033[0m")
                elif cmd[0] == "/rep" and len(cmd) > 1:
                    rep_penalty = float(cmd[1])
                    print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
                else:
                    print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
            except ValueError:
                # A bad numeric argument (e.g. "/temp abc") must not kill the REPL.
                print("\033[90m [Invalid number]\033[0m")
            continue

        # Build prompt
        if is_sft:
            prompt = context + USER_START + user_input + TURN_END + ASST_START
        else:
            if context:
                prompt = context + "\n" + user_input
            else:
                prompt = user_input

        # Trim context if too long: drop the oldest turn (SFT) or the oldest
        # line (base) until the prompt leaves room for max_tokens of output.
        while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
            if is_sft:
                parts = context.split(TURN_END)
                if len(parts) <= 2:
                    break
                context = TURN_END.join(parts[2:])
                prompt = context + USER_START + user_input + TURN_END + ASST_START
            else:
                lines = prompt.split("\n")
                if len(lines) <= 2:
                    break
                prompt = "\n".join(lines[1:])

        # Generate with streaming
        print("\033[1;34mModel:\033[0m ", end="", flush=True)
        t0 = time.time()
        full_response = ""
        token_count = 0
        for token_text in generate_stream(
            model, tokenizer, prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=rep_penalty,
            device=device,
            stop_token_ids=sft_stop_ids if is_sft else None,
        ):
            print(token_text, end="", flush=True)
            full_response += token_text
            token_count += 1
        elapsed = time.time() - t0
        tps = token_count / max(elapsed, 1e-9)
        print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")

        # Append to context for multi-turn
        if is_sft:
            context = (context + USER_START + user_input + TURN_END +
                       ASST_START + full_response.strip() + TURN_END)
        else:
            context = prompt + full_response


if __name__ == "__main__":
    main()