"""
test_chatmodel.py — Interactive CLI chat and evaluation for the fine-tuned SLLM chat model.

Usage:
    python test_chatmodel.py --run_dir runs/sllm_150m_chat
    python test_chatmodel.py --run_dir runs/sllm_150m_chat --mode sample

In interactive mode:
    Type your message and press Enter.
    Special commands:
        /reset            Clear conversation history
        /system <text>    Change the system prompt (this also clears the history)
        /quit             Exit the chat (/exit, quit and exit work as well)
"""
|
|
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from torch.amp import autocast |
| from transformers import PreTrainedTokenizerFast |
|
|
| |
# Directory that contains this script. It is pushed to the front of sys.path
# so the project-local `model` package imports below resolve regardless of the
# directory the script is launched from.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
sys.path.insert(0, os.fspath(PROJECT_ROOT))
|
|
| from model.config import SLLM_150M |
| from model.model import SLLM |
|
|
# Defaults for the --system and --run_dir command-line options.
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
DEFAULT_RUN_DIR = str(PROJECT_ROOT.joinpath("runs", "sllm_150m_chat"))
|
|
|
|
| |
| |
| |
|
|
def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the most recent checkpoint stored in ``run_dir``.

    Checkpoints are entries named ``ckpt_*.pt``; this also covers the
    ``ckpt_sft_*.pt`` fine-tuning checkpoints.  Names are ranked in natural
    order, i.e. runs of digits are compared as integers, so ``ckpt_1000.pt``
    is newer than ``ckpt_900.pt`` even when step numbers are not zero-padded.
    ``ckpt_sft_*`` names still rank after plain ``ckpt_<step>`` names, so a
    fine-tuned checkpoint is preferred when both kinds are present.

    Args:
        run_dir: Directory to scan (not searched recursively).

    Returns:
        ``run_dir`` joined with the highest-ranked checkpoint file name.

    Raises:
        FileNotFoundError: If ``run_dir`` is not a directory or contains no
            matching checkpoint.
    """
    import re  # local import: only this helper needs it

    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory '{run_dir}' does not exist.")

    def natural_key(name: str):
        # re.split with a capturing group alternates text / digit chunks, so
        # odd positions are always digit runs and two keys never compare an
        # int against a str.  The raw name is a deterministic tie-breaker.
        chunks = re.split(r"(\d+)", name)
        return [int(c) if i % 2 else c for i, c in enumerate(chunks)], name

    candidates = [
        entry for entry in os.listdir(run_dir)
        if entry.startswith("ckpt_") and entry.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoints found in '{run_dir}'.\n"
            f"Please ensure you have trained the model or point to the correct folder."
        )
    return os.path.join(run_dir, max(candidates, key=natural_key))
|
|
|
|
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Resize the model's token embedding table in place.

    Rows present in both the old and the new table are copied unchanged.
    When the table grows, every added row is initialised to the mean of the
    existing embeddings; when it shrinks, the trailing rows are dropped.
    The LM head weight is re-pointed at the new embedding weight (the two are
    tied) and ``model.config.vocab_size`` is updated.

    Args:
        model: Object exposing ``token_emb`` (an ``nn.Embedding``),
            ``lm_head`` (whose ``weight`` is tied to the embedding) and
            ``config`` with ``vocab_size`` and ``d_model``.
        new_vocab_size: Desired number of embedding rows; must be positive.

    Raises:
        ValueError: If ``new_vocab_size`` is not positive.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size <= 0:
        raise ValueError(f"new_vocab_size must be positive, got {new_vocab_size}")

    old_weight = model.token_emb.weight.detach()
    kept = min(old_weight.shape[0], new_vocab_size)

    # new_empty keeps the dtype and device of the existing table.
    new_weight = old_weight.new_empty((new_vocab_size, model.config.d_model))
    new_weight[:kept] = old_weight[:kept]
    if new_vocab_size > kept:
        # Mean initialisation keeps the new rows inside the learned
        # embedding distribution; the row vector broadcasts over the slice.
        new_weight[kept:] = old_weight.mean(dim=0)

    # from_pretrained wraps the prepared tensor directly, so no throwaway
    # randomly-initialised table is allocated; freeze=False keeps it trainable.
    model.token_emb = nn.Embedding.from_pretrained(new_weight, freeze=False)
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
    print(f" [INFO] Resized model vocab embedding from {old_size:,} to {new_vocab_size:,}")
|
|
|
|
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Load the tokenizer and the most recent model checkpoint.

    Tokenizer lookup order:
      1. ``finetune/data`` (used when it contains ``tokenizer.json``), loaded
         as-is.
      2. ``tokenizer/fineweb_edu_tokenizer``, to which the ChatML markers
         ``<|im_start|>`` / ``<|im_end|>`` are added as special tokens.

    The checkpoint is taken from ``run_dir``; if none is found there, the
    ``runs/sllm_150m`` directory is tried instead.

    The embedding table is sized from the checkpoint's own
    ``token_emb.weight`` so the state dict always fits, and is grown
    afterwards if the tokenizer has more entries than the checkpoint (for
    example a base checkpoint combined with the ChatML-extended tokenizer).

    Args:
        run_dir: Directory expected to hold ``ckpt_*.pt`` files.
        device: Device the checkpoint tensors and the model are placed on.

    Returns:
        Tuple ``(model, tokenizer, ckpt_path, step, loss)``; ``step`` is
        ``"?"`` and ``loss`` is NaN when the checkpoint does not record them.

    Raises:
        FileNotFoundError: If no tokenizer directory exists, or no checkpoint
            is found in either run directory.
    """
    data_tok_dir = PROJECT_ROOT / "finetune" / "data"
    base_tok_dir = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"

    if (data_tok_dir / "tokenizer.json").exists():
        tok_path = str(data_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Tokenizer: Loaded extended tokenizer from '{tok_path}'")
    elif base_tok_dir.exists():
        tok_path = str(base_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
        })
        print(f" Tokenizer: Loaded base tokenizer from '{tok_path}' and added ChatML tokens")
    else:
        raise FileNotFoundError("Could not find a tokenizer directory.")

    try:
        ckpt_path = find_latest_ckpt(run_dir)
    except FileNotFoundError:
        print(f" [WARN] No checkpoint found in '{run_dir}'. Trying pretraining base run...")
        base_dir = PROJECT_ROOT / "runs" / "sllm_150m"
        ckpt_path = find_latest_ckpt(str(base_dir))

    print(f" Loading checkpoint: {ckpt_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this script at checkpoints you produced or otherwise trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = ckpt["model_state_dict"]

    model = SLLM(SLLM_150M).to(device)

    # Prefer the row count of the stored embedding matrix over the optional
    # "vocab_size" entry: it is what load_state_dict actually has to match.
    stored_emb = state_dict.get("token_emb.weight")
    if stored_emb is not None:
        ckpt_vocab = stored_emb.shape[0]
    else:
        ckpt_vocab = ckpt.get("vocab_size", len(tokenizer))
    resize_token_embeddings(model, ckpt_vocab)

    model.load_state_dict(state_dict)

    # A tokenizer with more entries than the checkpoint would otherwise emit
    # ids past the end of the embedding table; add mean-initialised rows.
    if len(tokenizer) > model.config.vocab_size:
        resize_token_embeddings(model, len(tokenizer))

    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, tokenizer, ckpt_path, step, loss
|
|
|
|
| |
| |
| |
|
|
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """Render a conversation as a ChatML prompt and return its token ids.

    The prompt is the system turn, then every entry of ``history`` in order,
    then an opened ``assistant`` turn for the model to complete.

    Args:
        history: Conversation turns; each dict supplies ``"role"`` and
            ``"content"``.
        system_prompt: Text placed in the leading ``system`` turn.
        tokenizer: Tokenizer used to encode the rendered text; it is called
            with ``add_special_tokens=False``.

    Returns:
        A ``torch.long`` tensor of shape ``(1, num_tokens)``.
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
        for msg in history
    )
    # Leave the assistant turn open so generation continues from here.
    segments.append("<|im_start|>assistant\n")

    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
|
|
|
|
| |
| |
| |
|
|
| @torch.no_grad() |
| def generate_response( |
| model: SLLM, |
| input_ids: torch.Tensor, |
| tokenizer: PreTrainedTokenizerFast, |
| max_new_tokens: int = 200, |
| temperature: float = 0.7, |
| top_k: int = 40, |
| top_p: float = 0.9, |
| device: torch.device = None, |
| dtype_torch: torch.dtype = torch.float32, |
| use_amp: bool = False, |
| ) -> str: |
| """Generates a response from the model using top-k/top-p sampling.""" |
| im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") |
| eos_id = tokenizer.eos_token_id |
|
|
| ids = input_ids.to(device) |
| generated = [] |
|
|
| for _ in range(max_new_tokens): |
| |
| ctx = ids if ids.shape[1] <= model.config.context_length \ |
| else ids[:, -model.config.context_length:] |
|
|
| with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp): |
| logits, _ = model(ctx) |
| |
| |
| logits = logits[:, -1, :] |
| |
| if temperature == 0.0: |
| |
| next_token = logits.argmax(dim=-1, keepdim=True) |
| else: |
| logits = logits / max(temperature, 1e-8) |
|
|
| |
| if top_k and top_k > 0: |
| v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
| logits[logits < v[:, [-1]]] = float("-inf") |
|
|
| |
| if top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(logits, descending=True) |
| cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
| sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf") |
| logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits) |
|
|
| probs = torch.softmax(logits, dim=-1) |
| next_token = torch.multinomial(probs, num_samples=1) |
|
|
| tok_id = next_token.item() |
|
|
| |
| if tok_id == im_end_id or tok_id == eos_id: |
| break |
|
|
| generated.append(tok_id) |
| ids = torch.cat([ids, next_token], dim=1) |
|
|
| return tokenizer.decode(generated, skip_special_tokens=True).strip() |
|
|
|
|
| |
| |
| |
|
|
def run_interactive(model, tokenizer, device, dtype_torch, use_amp, args):
    """Run the interactive chat loop on stdin/stdout.

    Each non-command line is appended to the conversation, rendered with
    ``build_prompt`` and answered with ``generate_response``.  Supported
    commands: ``/reset`` clears the history, ``/system <text>`` replaces the
    system prompt (and clears the history), a bare ``/system`` shows the
    current prompt, and ``/quit`` / ``/exit`` / ``quit`` / ``exit`` leave the
    loop.  Ctrl-C or EOF at the prompt also exits; Ctrl-C while a reply is
    being generated discards that turn and returns to the prompt.

    Args:
        model: Model passed to ``generate_response``; ``config.context_length``
            is read for history truncation.
        tokenizer: Tokenizer used for prompt building and decoding.
        device: Device forwarded to ``generate_response``.
        dtype_torch: Autocast dtype forwarded to ``generate_response``.
        use_amp: Whether generation runs under autocast.
        args: Parsed CLI options; reads ``system``, ``max_new_tokens``,
            ``temperature``, ``top_k`` and ``top_p``.
    """
    system_prompt = args.system
    history = []

    print("\n" + "=" * 60)
    print(" CHAT MODE (Interactive)")
    print("=" * 60)
    print(f" System prompt : {system_prompt}")
    print(" Commands : /reset to clear memory | /system <prompt> | /quit to exit")
    print("─" * 60 + "\n")

    # Leave room for the reply plus a small safety margin.
    prompt_budget = model.config.context_length - args.max_new_tokens - 10

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        command = user_input.lower()

        if command in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break

        if command == "/reset":
            history = []
            print(" [Conversation history reset]\n")
            continue

        if command == "/system" or command.startswith("/system "):
            new_sys = user_input[len("/system"):].strip()
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. History cleared.]\n")
            else:
                # A bare "/system" used to be sent to the model as a chat
                # message; show the current prompt and the usage instead.
                print(f" [Current system prompt: {system_prompt}]")
                print(" [Usage: /system <text>]\n")
            continue

        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Drop the oldest user/assistant pair until the prompt fits; the
        # newest user message is always kept.
        while input_ids.shape[1] > prompt_budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        print("SLLM: ", end="", flush=True)
        try:
            response = generate_response(
                model, input_ids, tokenizer,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                device=device,
                dtype_torch=dtype_torch,
                use_amp=use_amp,
            )
        except KeyboardInterrupt:
            # Remove the unanswered user turn so the history keeps
            # alternating user/assistant.
            history.pop()
            print("\n [Generation interrupted]\n")
            continue

        print(response + "\n")
        history.append({"role": "assistant", "content": response})
|
|
|
|
def run_sample(model, tokenizer, device, dtype_torch, use_amp, args):
    """Answer a fixed list of sample prompts and print each reply.

    Every prompt is sent as a fresh single-turn conversation under
    ``args.system``; no history is carried from one prompt to the next.

    Args:
        model: Model passed to ``generate_response``.
        tokenizer: Tokenizer used for prompt building and decoding.
        device: Device forwarded to ``generate_response``.
        dtype_torch: Autocast dtype forwarded to ``generate_response``.
        use_amp: Whether generation runs under autocast.
        args: Parsed CLI options; reads ``system``, ``max_new_tokens``,
            ``temperature``, ``top_k`` and ``top_p``.
    """
    questions = (
        "Hello! Who are you?",
        "What is the capital of France?",
        "Write a quick, 3-line poem about a small robot learning to speak.",
        "Explain gravity in one simple sentence.",
    )
    banner = "=" * 60
    rule = "─" * 60

    print("\n" + banner)
    print(" SAMPLE EVALUATION MODE")
    print(banner)
    print(f" System prompt: {args.system}")
    print(rule)

    # Sampling settings are identical for every prompt.
    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_k": args.top_k,
        "top_p": args.top_p,
        "device": device,
        "dtype_torch": dtype_torch,
        "use_amp": use_amp,
    }

    for question in questions:
        print(f"\n[PROMPT] : {question}")
        single_turn = [{"role": "user", "content": question}]
        prompt_ids = build_prompt(single_turn, args.system, tokenizer)

        print("[SLLM] : ", end="", flush=True)
        reply = generate_response(model, prompt_ids, tokenizer, **generation_kwargs)
        print(reply)

    print("\n" + rule + "\n")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Parse CLI options, load the model and run the selected mode.

    Options: ``--run_dir``, ``--mode`` (``interactive`` or ``sample``),
    ``--temperature``, ``--top_k``, ``--top_p``, ``--max_new_tokens``,
    ``--system`` and ``--dtype`` (``fp32``, ``fp16`` or ``bf16``).

    Reduced precision is only used on CUDA (and, for bf16, only when the GPU
    supports it); otherwise the run falls back to fp32 and says so.  If the
    model or tokenizer cannot be loaded, the error is printed and the process
    exits with status 1.
    """
    import math  # local import: only needed for the NaN check below

    parser = argparse.ArgumentParser(description="SLLM Chat Checker")
    parser.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
    parser.add_argument("--mode", type=str, default="interactive",
                        choices=["interactive", "sample"])
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_new_tokens", type=int, default=200)
    parser.add_argument("--system", type=str, default=DEFAULT_SYSTEM)
    parser.add_argument("--dtype", type=str, default="bf16",
                        choices=["fp32", "fp16", "bf16"])
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")

    on_cuda = device.type == "cuda"
    if args.dtype == "bf16" and on_cuda and torch.cuda.is_bf16_supported():
        dtype_torch, use_amp, effective_dtype = torch.bfloat16, True, "bf16"
    elif args.dtype == "fp16" and on_cuda:
        dtype_torch, use_amp, effective_dtype = torch.float16, True, "fp16"
    else:
        dtype_torch, use_amp, effective_dtype = torch.float32, False, "fp32"

    # Report the precision actually in use, not merely the requested one.
    if effective_dtype == args.dtype:
        print(f"dtype : {args.dtype}")
    else:
        print(f"dtype : {effective_dtype} (requested {args.dtype} is unavailable on this device)")

    # Top-level boundary: any load failure is reported and ends the run.
    try:
        model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    except Exception as e:
        print(f"\n[ERROR] Failed to load chat model: {e}")
        sys.exit(1)

    print(f" Step : {step}")
    # The stored loss may be missing, a tensor, or not numeric at all.
    try:
        loss_value = float(loss)
    except (TypeError, ValueError):
        loss_value = math.nan
    if not math.isnan(loss_value):
        print(f" Loss : {loss_value:.4f}")

    if args.mode == "interactive":
        run_interactive(model, tokenizer, device, dtype_torch, use_amp, args)
    elif args.mode == "sample":
        run_sample(model, tokenizer, device, dtype_torch, use_amp, args)


if __name__ == "__main__":
    main()
|
|