"""
Interactive chat with the 1B Transformer.
Runs in an infinite conversation loop from the terminal.

Usage:
    python chat.py                                              # auto-find latest checkpoint
    python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt  # specific checkpoint
"""
| |
|
| | import sys |
| | import os |
| | import glob |
| | import time |
| | import torch |
| | import torch.nn.functional as F |
| | import readline |
| |
|
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from model.config import ModelConfig |
| | from model.transformer import Transformer |
| | from model.data import get_tokenizer |
| |
|
| |
|
def find_latest_checkpoint(dpo_dir="/jfs/deepak-kumar/checkpoints_dpo",
                           sft_dir="/jfs/deepak-kumar/checkpoints_sft",
                           pt_dir="/jfs/deepak-kumar/checkpoints"):
    """Find the best available checkpoint, preferring DPO > SFT > pretrained.

    Within each stage a ``*_final.pt`` file wins outright; otherwise the
    ``*step_<N>.pt`` file with the highest step number is used.

    Args:
        dpo_dir: Directory searched first (DPO checkpoints).
        sft_dir: Directory searched second (SFT checkpoints).
        pt_dir: Directory searched last (pretraining checkpoints).

    Returns:
        Tuple ``(path, is_chat)`` where ``path`` is the checkpoint file (or
        ``None`` if nothing was found) and ``is_chat`` is True for DPO/SFT
        checkpoints that expect the chat template.
    """
    def _step(path, prefix):
        # "<prefix><N>.pt" -> N, parsed from the basename so directory
        # names can never confuse the split.
        return int(os.path.basename(path)[len(prefix):].split(".")[0])

    # (final filename or None, step-file prefix, directory, chat-tuned?)
    stages = [
        ("dpo_final.pt", "dpo_step_", dpo_dir, True),
        ("sft_final.pt", "sft_step_", sft_dir, True),
        (None, "step_", pt_dir, False),
    ]
    for final_name, prefix, directory, is_chat in stages:
        if final_name is not None:
            final_path = os.path.join(directory, final_name)
            if os.path.exists(final_path):
                return final_path, is_chat
        candidates = glob.glob(os.path.join(directory, prefix + "*.pt"))
        if candidates:
            return max(candidates, key=lambda f: _step(f, prefix)), is_chat

    return None, False
| |
|
| |
|
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
    """Load a Transformer checkpoint onto ``device`` in bfloat16 eval mode.

    The checkpoint is read first so the model is constructed exactly once
    with the correct vocab size (the original built the ~1B-param model,
    then rebuilt it whenever the checkpoint carried a larger vocab).

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint with a "model"
            state dict and optional "vocab_size"/"step"/"loss" entries.
        tokenizer: Accepted for interface compatibility; not used here.
        device: Target device string (e.g. "cuda:0").

    Returns:
        Tuple ``(model, config, step, loss)``; ``step``/``loss`` fall back
        to "?" when absent from the checkpoint.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config = ModelConfig()
    # Chat checkpoints may have grown the vocab with special tokens; adopt
    # the saved size before building the model so the embedding shapes match.
    saved_vocab = ckpt.get("vocab_size", config.vocab_size)
    if saved_vocab > config.vocab_size:
        config.vocab_size = saved_vocab

    model = Transformer(config)
    model.load_state_dict(ckpt["model"])
    model = model.to(device).bfloat16().eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", "?")
    # Drop the CPU copy of the weights before chat starts.
    del ckpt
    torch.cuda.empty_cache()
    return model, config, step, loss
| |
|
| |
|
@torch.no_grad()
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
                    temperature=0.8, top_k=50, top_p=0.9,
                    repetition_penalty=1.15, device="cuda:0",
                    stop_token_ids=None):
    """Generate tokens one at a time, yielding decoded text for streaming.

    Sampling pipeline per step: repetition penalty -> temperature ->
    top-k filter -> nucleus (top-p) filter -> multinomial sample.
    Stop tokens (and EOS) terminate generation and are never yielded.

    Args:
        model: Callable returning ``(logits, _)`` for a ``(1, T)`` id tensor;
            must expose ``model.config.max_seq_len``.
        tokenizer: Provides ``encode``/``decode`` and ``eos_token_id``.
        prompt: Text to condition on.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Softmax temperature; values <= 0 are clamped to a tiny
            positive value (near-greedy) instead of dividing by zero.
        top_k: Keep only the k most likely tokens (0 disables).
        top_p: Nucleus sampling threshold (1.0 disables).
        repetition_penalty: >1.0 discourages previously generated tokens.
        device: Device for the input ids; also selects the autocast backend.
        stop_token_ids: Extra token ids that end generation. EOS is always
            a stop token (the original only honored it when this arg was
            given, so base-model runs never stopped at EOS).

    Yields:
        Newly decoded text chunks (may be empty while a multi-token
        character sequence is still incomplete).
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = []
    prev_decoded_len = 0

    # Normalize the stop set; EOS always terminates generation.
    stop_token_ids = set(stop_token_ids) if stop_token_ids else set()
    if tokenizer.eos_token_id is not None:
        stop_token_ids.add(tokenizer.eos_token_id)

    # Autocast on the backend actually in use; hard-coding "cuda" breaks
    # CPU inference.
    autocast_device = "cuda" if str(device).startswith("cuda") else "cpu"

    for _ in range(max_new_tokens):
        # Never feed a sequence longer than the model's context window.
        if input_ids.shape[1] >= model.config.max_seq_len:
            break

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        logits = logits[:, -1, :]

        # Vectorized repetition penalty (CTRL-style): shrink positive
        # logits, amplify negative ones for tokens already generated.
        if repetition_penalty != 1.0 and generated_ids:
            prev_tokens = torch.tensor(generated_ids, device=device).unique()
            vals = logits[0, prev_tokens]
            logits[0, prev_tokens] = torch.where(
                vals > 0, vals / repetition_penalty, vals * repetition_penalty
            )

        # Guard against temperature == 0 (division by zero).
        logits = logits / max(temperature, 1e-6)

        if top_k > 0:
            # Clamp k so small vocabularies don't crash torch.topk.
            k = min(top_k, logits.size(-1))
            topk_vals, _ = torch.topk(logits, k)
            logits[logits < topk_vals[:, -1:]] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cum_probs = torch.cumsum(sorted_probs, dim=-1)
            # Mask tokens whose cumulative mass *before* them already
            # reaches top_p; the subtraction keeps at least one token.
            mask = cum_probs - sorted_probs >= top_p
            sorted_logits[mask] = float("-inf")
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()

        # Stop tokens end the stream and are not emitted.
        if token_id in stop_token_ids:
            break

        generated_ids.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Re-decode the whole generation and yield only the new suffix so
        # multi-token characters decode correctly across steps.
        full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        new_text = full_decoded[prev_decoded_len:]
        prev_decoded_len = len(full_decoded)
        yield new_text
| |
|
| |
|
def print_banner(step, loss, device):
    """Print the startup banner: checkpoint info plus the command help."""
    bar = "=" * 60
    print("\033[1;36m")
    print(bar)
    print(" 1B TRANSFORMER — Interactive Chat")
    print(bar)
    print(f"\033[0m Checkpoint : step {step}")
    print(f" Loss : {loss}")
    print(f" Device : {device}")
    print(" Parameters : 1.106B")
    print()
    print(" \033[90mCommands:\033[0m")
    # (command, description) pairs rendered one per line.
    help_entries = (
        ("/quit", "exit"),
        ("/clear", "clear conversation context"),
        ("/temp N", "set temperature (default 0.8)"),
        ("/tokens N", "set max tokens (default 512)"),
        ("/topp N", "set top-p (default 0.9)"),
        ("/topk N", "set top-k (default 50)"),
        ("/rep N", "set repetition penalty (default 1.15)"),
    )
    for command, description in help_entries:
        print(f" \033[33m{command}\033[0m — {description}")
    print()
    print("\033[90m" + "─" * 60 + "\033[0m")
| |
|
| |
|
def main():
    """Run the interactive chat REPL.

    Resolves a checkpoint (CLI argument or newest on disk), loads model and
    tokenizer, then loops: read user input, stream a generation, and fold
    the exchange back into the conversation context until /quit or EOF.
    """
    device = "cuda:0"

    # Checkpoint selection: an explicit CLI path wins; otherwise search disk.
    is_sft = False
    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        # NOTE(review): a CLI-passed DPO path ("dpo" but not "sft" in the
        # name) will NOT enable chat mode here, while auto-discovery below
        # treats DPO checkpoints as chat-tuned — confirm this is intended.
        is_sft = "sft" in checkpoint.lower()
    else:
        result = find_latest_checkpoint()
        if result[0] is None:
            print("No checkpoint found!")
            sys.exit(1)
        checkpoint, is_sft = result

    tokenizer = get_tokenizer()

    # Chat checkpoints were trained with role-marker tokens; register any
    # the tokenizer is missing so the chat template round-trips.
    if is_sft:
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    print(f"\n Loading model from {checkpoint}...")
    print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
    model, config, step, loss = load_model(checkpoint, tokenizer, device)
    print(f" Model loaded!\n")

    print_banner(step, loss, device)
    if is_sft:
        print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")

    # Sampling defaults; all adjustable at runtime via the slash commands.
    temperature = 0.7 if is_sft else 0.8
    max_tokens = 512
    top_p = 0.9
    top_k = 50
    rep_penalty = 1.15
    context = ""  # accumulated conversation history fed back as the prompt

    # Chat-template markers used in SFT mode (must match training format).
    USER_START = "<|user|>\n"
    ASST_START = "<|assistant|>\n"
    TURN_END = "\n<|end|>\n"

    # Token ids that should end an assistant turn in SFT mode.
    sft_stop_ids = []
    if is_sft:
        vocab = tokenizer.get_vocab()
        for tok_str in ["<|end|>", "<|user|>"]:
            if tok_str in vocab:
                sft_stop_ids.append(vocab[tok_str])

    while True:
        try:
            user_input = input("\n\033[1;32mYou:\033[0m ").strip()
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D exits cleanly.
            print("\n\nGoodbye!")
            break

        if not user_input:
            continue

        # Slash commands tweak settings without restarting the chat.
        if user_input.startswith("/"):
            cmd = user_input.lower().split()
            if cmd[0] == "/quit":
                print("Goodbye!")
                break
            elif cmd[0] == "/clear":
                context = ""
                print("\033[90m [Context cleared]\033[0m")
                continue
            elif cmd[0] == "/temp" and len(cmd) > 1:
                temperature = float(cmd[1])
                print(f"\033[90m [Temperature set to {temperature}]\033[0m")
                continue
            elif cmd[0] == "/tokens" and len(cmd) > 1:
                max_tokens = int(cmd[1])
                print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
                continue
            elif cmd[0] == "/topp" and len(cmd) > 1:
                top_p = float(cmd[1])
                print(f"\033[90m [Top-p set to {top_p}]\033[0m")
                continue
            elif cmd[0] == "/topk" and len(cmd) > 1:
                top_k = int(cmd[1])
                print(f"\033[90m [Top-k set to {top_k}]\033[0m")
                continue
            elif cmd[0] == "/rep" and len(cmd) > 1:
                rep_penalty = float(cmd[1])
                print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
                continue
            else:
                print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
                continue

        # Build the model prompt: chat template in SFT mode, plain
        # newline-joined continuation otherwise.
        if is_sft:
            prompt = context + USER_START + user_input + TURN_END + ASST_START
        else:
            if context:
                prompt = context + "\n" + user_input
            else:
                prompt = user_input

        # Trim oldest history until the prompt leaves room for max_tokens
        # of generation within the model's context window. Note: this
        # re-encodes the full prompt on every iteration.
        while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
            if is_sft:
                # Drop the two oldest turns (one user + one assistant).
                parts = context.split(TURN_END)
                if len(parts) <= 2:
                    break
                context = TURN_END.join(parts[2:])
                prompt = context + USER_START + user_input + TURN_END + ASST_START
            else:
                # Drop the oldest line of the raw completion prompt.
                lines = prompt.split("\n")
                if len(lines) <= 2:
                    break
                prompt = "\n".join(lines[1:])

        # Stream the response token-by-token, tracking throughput.
        print("\033[1;34mModel:\033[0m ", end="", flush=True)
        t0 = time.time()
        full_response = ""
        token_count = 0

        for token_text in generate_stream(
            model, tokenizer, prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=rep_penalty,
            device=device,
            stop_token_ids=sft_stop_ids if is_sft else None,
        ):
            print(token_text, end="", flush=True)
            full_response += token_text
            token_count += 1

        elapsed = time.time() - t0
        tps = token_count / max(elapsed, 1e-9)
        print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")

        # Fold this exchange back into the running context.
        if is_sft:
            context = (context + USER_START + user_input + TURN_END +
                       ASST_START + full_response.strip() + TURN_END)
        else:
            context = prompt + full_response
| |
|
| |
|
# Script entry point.
if __name__ == "__main__":
    main()
| |
|