""" finetune/chat.py Interactive CLI chat with the fine-tuned SLLM-150M chat model. Loads the latest SFT checkpoint from --run_dir, formats your input as a ChatML prompt, generates a response token-by-token, and stops at the <|im_end|> token. Usage: python finetune/chat.py python finetune/chat.py --run_dir runs/sllm_150m_chat python finetune/chat.py --temperature 0.7 --top_k 40 In-chat commands: /reset clear conversation history (start fresh) /system change the system prompt /quit exit """ import os import sys import argparse from pathlib import Path import torch import torch.nn as nn from transformers import PreTrainedTokenizerFast SCRIPT_DIR = Path(__file__).resolve().parent PROJECT_ROOT = SCRIPT_DIR.parent DATA_DIR = SCRIPT_DIR / "data" sys.path.insert(0, str(PROJECT_ROOT)) from model.config import SLLM_150M from model.model import SLLM DEFAULT_SYSTEM = "You are a helpful, concise assistant." DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat") # ------------------------------------------------------------------ # # HELPERS # ------------------------------------------------------------------ # def find_latest_ckpt(run_dir: str) -> str: """Returns path to the most recent ckpt_sft_*.pt in run_dir.""" ckpts = sorted([ f for f in os.listdir(run_dir) if f.startswith("ckpt_sft_") and f.endswith(".pt") ]) if not ckpts: raise FileNotFoundError( f"No SFT checkpoints found in '{run_dir}'.\n" f"Run sft_train.py first." ) return os.path.join(run_dir, ckpts[-1]) def resize_token_embeddings(model: SLLM, new_vocab_size: int): """Same resize logic as sft_train.py — kept local to avoid circular imports.""" old_size = model.config.vocab_size if new_vocab_size == old_size: return d_model = model.config.d_model device = model.token_emb.weight.device dtype = model.token_emb.weight.dtype old_weight = model.token_emb.weight.data.clone() mean_vec = old_weight.mean(dim=0) new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device) new_weight[:old_size] = old_weight new_weight[old_size:] = mean_vec.unsqueeze(0).expand(new_vocab_size - old_size, -1) new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype) new_emb.weight.data = new_weight model.token_emb = new_emb model.lm_head.weight = model.token_emb.weight model.config.vocab_size = new_vocab_size def load_model_and_tokenizer(run_dir: str, device: torch.device): """Loads tokenizer (from data dir) and fine-tuned model (from run_dir).""" # ---- Tokenizer ------------------------------------------------- # tok_path = str(DATA_DIR) if os.path.exists(os.path.join(tok_path, "tokenizer.json")): tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path) else: # Fallback: base tokenizer + manual special token add base_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer") tokenizer = PreTrainedTokenizerFast.from_pretrained(base_dir) tokenizer.add_special_tokens({ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"] }) # ---- Checkpoint ------------------------------------------------ # ckpt_path = find_latest_ckpt(run_dir) ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) # ---- Model ----------------------------------------------------- # model = SLLM(SLLM_150M).to(device) saved_vocab = ckpt.get("vocab_size", len(tokenizer)) resize_token_embeddings(model, saved_vocab) model.load_state_dict(ckpt["model_state_dict"]) model.eval() return model, tokenizer, ckpt_path, ckpt.get("step", "?"), ckpt.get("loss", float("nan")) # ------------------------------------------------------------------ # # PROMPT BUILDING # ------------------------------------------------------------------ # def build_prompt(history: list[dict], system_prompt: str, tokenizer: PreTrainedTokenizerFast) -> torch.Tensor: """ Formats conversation history as ChatML and tokenises it. Template: <|im_start|>system {system}<|im_end|> <|im_start|>user {user}<|im_end|> <|im_start|>assistant {assistant}<|im_end|> ... <|im_start|>assistant\\n ← left open for the model to complete Returns: input_ids : (1, T) LongTensor """ text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n" for turn in history: text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n" # Prime the model to generate as assistant text += "<|im_start|>assistant\n" ids = tokenizer.encode(text, add_special_tokens=False) return torch.tensor([ids], dtype=torch.long) # ------------------------------------------------------------------ # # GENERATION # ------------------------------------------------------------------ # @torch.no_grad() def generate_response( model: SLLM, input_ids: torch.Tensor, tokenizer: PreTrainedTokenizerFast, max_new_tokens: int = 300, temperature: float = 0.8, top_k: int = 50, device: torch.device = None, ) -> str: """ Autoregressively generates tokens until: - <|im_end|> is produced (clean stop), or - eos_token_id is produced, or - max_new_tokens is reached Returns the decoded response string (special tokens stripped). """ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") eos_id = tokenizer.eos_token_id ids = input_ids.to(device) generated = [] for _ in range(max_new_tokens): # Crop to context window ctx = ids if ids.shape[1] <= model.config.context_length \ else ids[:, -model.config.context_length:] logits, _ = model(ctx) # (1, T, V) logits = logits[:, -1, :] / max(temperature, 1e-8) # Top-k filtering if top_k and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float("-inf") probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # (1, 1) tok_id = next_token.item() # Stop conditions if tok_id == im_end_id or tok_id == eos_id: break generated.append(tok_id) ids = torch.cat([ids, next_token], dim=1) return tokenizer.decode(generated, skip_special_tokens=True).strip() # ------------------------------------------------------------------ # # MAIN # ------------------------------------------------------------------ # def parse_args(): p = argparse.ArgumentParser(description="SLLM-150M Chat") p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR) p.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature (lower = more focused)") p.add_argument("--top_k", type=int, default=50, help="Top-k sampling (0 = disabled)") p.add_argument("--max_new_tokens", type=int, default=300, help="Max tokens per assistant response") p.add_argument("--system", type=str, default=DEFAULT_SYSTEM, help="System prompt") return p.parse_args() def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("\n" + "=" * 60) print(" SLLM-150M Chat") print("=" * 60) print(f" Device : {device}") if device.type == "cuda": print(f" GPU : {torch.cuda.get_device_name(0)}") # ---- Load ------------------------------------------------------ # print("\nLoading model...") model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device) print(f" Checkpoint : {ckpt_path}") print(f" Step : {step} Loss: {loss:.4f}") print(f" Vocab size : {len(tokenizer):,}") # ---- Chat loop ------------------------------------------------- # system_prompt = args.system history: list[dict] = [] print(f"\n System : {system_prompt}") print(" Commands: /reset | /system | /quit") print("─" * 60 + "\n") while True: try: user_input = input("You: ").strip() except (EOFError, KeyboardInterrupt): print("\nBye!") break if not user_input: continue # ---- Commands ---------------------------------------------- # if user_input.lower() in ("/quit", "/exit", "quit", "exit"): print("Bye!") break if user_input.lower() == "/reset": history = [] print(" [Conversation cleared]\n") continue if user_input.lower().startswith("/system "): new_sys = user_input[8:].strip() if new_sys: system_prompt = new_sys history = [] print(f" [System prompt updated. Conversation cleared.]\n") continue # ---- Build prompt ------------------------------------------ # history.append({"role": "user", "content": user_input}) input_ids = build_prompt(history, system_prompt, tokenizer) # Trim history if prompt is getting close to context limit while input_ids.shape[1] > model.config.context_length - args.max_new_tokens - 10: if len(history) > 2: history = history[2:] # drop oldest user+assistant pair input_ids = build_prompt(history, system_prompt, tokenizer) else: break # can't trim further — just truncate in generation # ---- Generate ---------------------------------------------- # print("SLLM: ", end="", flush=True) response = generate_response( model, input_ids, tokenizer, max_new_tokens = args.max_new_tokens, temperature = args.temperature, top_k = args.top_k, device = device, ) print(response + "\n") history.append({"role": "assistant", "content": response}) if __name__ == "__main__": main()