| """ |
| finetune/chat.py |
| |
| Interactive CLI chat with the fine-tuned SLLM-150M chat model. |
| |
| Loads the latest SFT checkpoint from --run_dir, formats your input |
| as a ChatML prompt, generates a response token-by-token, and stops |
| at the <|im_end|> token. |
| |
| Usage: |
| python finetune/chat.py |
| python finetune/chat.py --run_dir runs/sllm_150m_chat |
| python finetune/chat.py --temperature 0.7 --top_k 40 |
| |
| In-chat commands: |
| /reset clear conversation history (start fresh) |
| /system <text> change the system prompt |
| /quit exit |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedTokenizerFast |
|
|
# Filesystem anchors, all derived from this file's own location so the
# script behaves the same regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parents[0]
PROJECT_ROOT = SCRIPT_DIR.parents[0]
DATA_DIR = SCRIPT_DIR.joinpath("data")

# Put the project root at the front of sys.path so project-level packages
# can be imported when this file is executed directly as a script.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from model.config import SLLM_150M |
| from model.model import SLLM |
|
|
# Fallback values for the --system and --run_dir command-line flags.
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
DEFAULT_RUN_DIR = str(PROJECT_ROOT.joinpath("runs", "sllm_150m_chat"))
|
|
|
|
| |
| |
| |
|
|
def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the newest ``ckpt_sft_*.pt`` file in ``run_dir``.

    "Newest" is decided from the file name, not from modification time.
    Names are ranked with a natural sort, so runs of digits compare as
    integers. A plain string sort would rank ``ckpt_sft_900.pt`` above
    ``ckpt_sft_1000.pt`` and silently pick a stale checkpoint whenever the
    step number is not zero-padded.

    Args:
        run_dir: Directory that holds the SFT checkpoints.

    Returns:
        ``run_dir`` joined with the highest-ranked checkpoint file name.

    Raises:
        FileNotFoundError: If ``run_dir`` holds no matching file. The same
            exception type propagates from ``os.listdir`` when ``run_dir``
            itself does not exist.
    """
    import re  # local import: this is the only place the module is needed

    def natural_key(name: str) -> list:
        # Splitting on a captured digit run yields text, digits, text, ...
        # and always starts with a (possibly empty) text piece. Odd indexes
        # are therefore digit runs, so two keys never compare a str against
        # an int at the same position.
        pieces = re.split(r"([0-9]+)", name)
        return [int(piece) if idx % 2 else piece
                for idx, piece in enumerate(pieces)]

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith("ckpt_sft_") and f.endswith(".pt")
    ]
    if not ckpts:
        raise FileNotFoundError(
            f"No SFT checkpoints found in '{run_dir}'.\n"
            f"Run sft_train.py first."
        )

    # The raw name is a tie-breaker so the result does not depend on the
    # arbitrary order of os.listdir when two names have equal natural keys
    # (e.g. "ckpt_sft_01.pt" vs "ckpt_sft_1.pt").
    latest = max(ckpts, key=lambda name: (natural_key(name), name))
    return os.path.join(run_dir, latest)
|
|
|
|
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Resize the model's token embedding (and tied LM head) in place.

    Per the original note, the grow path mirrors the resize logic in
    sft_train.py and is duplicated here to avoid a circular import.

    Behaviour:
      * same size   -> no-op.
      * growing     -> existing rows are kept; each new row is initialised
                       to the mean of the existing embedding rows.
      * shrinking   -> the first ``new_vocab_size`` rows are kept and the
                       rest are dropped (previously this crashed with a
                       shape mismatch).

    Args:
        model: Model exposing ``config.vocab_size``, ``config.d_model``,
            ``token_emb`` (an embedding with a ``weight``) and ``lm_head``.
        new_vocab_size: Target number of embedding rows.

    Returns:
        None. ``model.token_emb``, ``model.lm_head.weight`` and
        ``model.config.vocab_size`` are updated as side effects.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.data
    device = old_weight.device
    dtype = old_weight.dtype

    # Only the overlapping rows can be carried over, whichever direction
    # the resize goes.
    kept = min(old_size, new_vocab_size)
    new_weight = torch.empty(new_vocab_size, d_model, dtype=dtype, device=device)
    new_weight[:kept] = old_weight[:kept]
    if new_vocab_size > old_size:
        # Broadcasts the (d_model,) mean vector across every added row.
        new_weight[old_size:] = old_weight.mean(dim=0)

    new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
    new_emb.weight.data = new_weight
    model.token_emb = new_emb
    # Re-tie the output projection to the freshly created embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
|
|
|
|
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Build the tokenizer and restore the fine-tuned model for chatting.

    Tokenizer: loaded from ``DATA_DIR`` when a ``tokenizer.json`` exists
    there. Otherwise the base tokenizer under
    ``PROJECT_ROOT/tokenizer/fineweb_edu_tokenizer`` is loaded and the two
    ChatML markers are registered as additional special tokens.

    Model: a fresh ``SLLM(SLLM_150M)`` is moved to ``device``, its embedding
    is resized to the checkpoint's ``"vocab_size"`` entry (falling back to
    ``len(tokenizer)`` when the entry is absent), the saved state dict is
    loaded, and the model is switched to eval mode.

    Args:
        run_dir: Directory searched for the latest SFT checkpoint.
        device: Device used both for ``torch.load`` mapping and the model.

    Returns:
        ``(model, tokenizer, ckpt_path, step, loss)`` where ``step`` is
        ``"?"`` and ``loss`` is NaN when the checkpoint lacks those keys.
    """
    local_tokenizer_file = os.path.join(str(DATA_DIR), "tokenizer.json")
    if os.path.exists(local_tokenizer_file):
        tokenizer = PreTrainedTokenizerFast.from_pretrained(str(DATA_DIR))
    else:
        fallback_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(fallback_dir)
        chatml_markers = ["<|im_start|>", "<|im_end|>"]
        tokenizer.add_special_tokens({"additional_special_tokens": chatml_markers})

    ckpt_path = find_latest_ckpt(run_dir)
    # NOTE(review): weights_only=False performs a full unpickle, which can
    # execute arbitrary code — only point --run_dir at trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    model = SLLM(SLLM_150M).to(device)
    # NOTE(review): the resize writes model.config.vocab_size; if SLLM keeps
    # a reference to SLLM_150M rather than a copy, the module-level config is
    # mutated too — confirm against model/model.py.
    resize_token_embeddings(model, ckpt.get("vocab_size", len(tokenizer)))
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, tokenizer, ckpt_path, step, loss
|
|
|
|
| |
| |
| |
|
|
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """
    Render the conversation as ChatML text and tokenise it.

    Layout (one block per message, each closed by <|im_end|> + newline):
        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>{role}
        {content}<|im_end|>
        ...
        <|im_start|>assistant      <- header only, left open for the model

    Args:
        history       : list of {"role": ..., "content": ...} dicts, oldest first
        system_prompt : text placed in the leading system block
        tokenizer     : object whose encode(text, add_special_tokens=False)
                        returns a list of token ids

    Returns:
        input_ids : (1, T) LongTensor
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
        for turn in history
    )
    # Open assistant header: generation continues from here.
    segments.append("<|im_start|>assistant\n")

    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def generate_response(
    model: SLLM,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 300,
    temperature: float = 0.8,
    top_k: int = 50,
    device: torch.device = None,
) -> str:
    """
    Autoregressively generate a reply for a single prompt.

    Generation stops when:
      - <|im_end|> is produced (clean stop), or
      - eos_token_id is produced, or
      - max_new_tokens is reached

    Args:
        model          : callable returning (logits, _) for a (1, T) id tensor,
                         with config.context_length set
        input_ids      : (1, T) prompt ids; batch size must be 1 because the
                         sampled token is read back with .item()
        tokenizer      : provides convert_tokens_to_ids, eos_token_id, decode
        max_new_tokens : upper bound on generated tokens
        temperature    : sampling temperature; values <= 0 select greedy
                         (argmax) decoding instead of dividing by ~0
        top_k          : keep only the k most likely tokens (0/None = off);
                         ignored in greedy mode
        device         : where to run; when None, the device of the model's
                         first parameter is used (or input_ids' own device
                         if the model has no parameters)

    Returns:
        The decoded response string (special tokens stripped, whitespace
        trimmed). The stop token itself is never included.
    """
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    stop_ids = {im_end_id, tokenizer.eos_token_id}

    if device is None:
        # Follow the model so prompt and weights always share a device.
        first_param = next(model.parameters(), None)
        device = input_ids.device if first_param is None else first_param.device

    context_length = model.config.context_length
    greedy = temperature <= 0

    ids = input_ids.to(device)
    generated = []

    for _ in range(max_new_tokens):
        # Sliding window: feed only the most recent context_length tokens.
        logits, _ = model(ids[:, -context_length:])
        logits = logits[:, -1, :]

        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / max(temperature, 1e-8)
            if top_k and top_k > 0:
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Anything below the k-th largest logit is excluded.
                logits = logits.masked_fill(logits < top_values[:, [-1]], float("-inf"))
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        tok_id = next_token.item()
        if tok_id in stop_ids:
            break

        generated.append(tok_id)
        ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """Parse the chat CLI flags from sys.argv and return the argparse namespace.

    Flags: --run_dir, --temperature, --top_k, --max_new_tokens, --system.
    """
    parser = argparse.ArgumentParser(description="SLLM-150M Chat")

    # (flag, add_argument keyword options) — registered in this order.
    flag_table = (
        ("--run_dir", {"type": str, "default": DEFAULT_RUN_DIR}),
        ("--temperature", {"type": float, "default": 0.8,
                           "help": "Sampling temperature (lower = more focused)"}),
        ("--top_k", {"type": int, "default": 50,
                     "help": "Top-k sampling (0 = disabled)"}),
        ("--max_new_tokens", {"type": int, "default": 300,
                              "help": "Max tokens per assistant response"}),
        ("--system", {"type": str, "default": DEFAULT_SYSTEM,
                      "help": "System prompt"}),
    )
    for flag, options in flag_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
def main():
    """Run the interactive chat loop.

    Loads the model and tokenizer, prints a short banner, then repeatedly
    reads a line from stdin and either handles it as a command or sends the
    conversation to the model:

      /quit, /exit, quit, exit   leave the program
      /reset                     clear the conversation history
      /system <text>             replace the system prompt and clear history
      /system (no text)          print a usage hint; previously this line
                                 fell through and was sent to the model

    Before each generation the oldest user/assistant pair is dropped until
    the prompt fits in context_length - max_new_tokens - 10 tokens, or only
    the newest turn(s) remain. EOF or Ctrl-C at the prompt exits cleanly.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("\n" + "=" * 60)
    print(" SLLM-150M Chat")
    print("=" * 60)
    print(f" Device : {device}")
    if device.type == "cuda":
        print(f" GPU : {torch.cuda.get_device_name(0)}")

    print("\nLoading model...")
    model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step} Loss: {loss:.4f}")
    print(f" Vocab size : {len(tokenizer):,}")

    system_prompt = args.system
    history: list[dict] = []

    print(f"\n System : {system_prompt}")
    print(" Commands: /reset | /system <new prompt> | /quit")
    print("─" * 60 + "\n")

    # Token budget for the prompt: leave room for the reply plus a small
    # safety margin. Loop-invariant, so computed once.
    prompt_budget = model.config.context_length - args.max_new_tokens - 10

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        lowered = user_input.lower()

        if lowered in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break

        if lowered == "/reset":
            history = []
            print(" [Conversation cleared]\n")
            continue

        # user_input is non-empty and stripped, so words[0] always exists.
        # Splitting on any whitespace also catches a bare "/system" and
        # "/system<TAB>text", which the old startswith("/system ") missed.
        words = user_input.split(maxsplit=1)
        if words[0].lower() == "/system":
            new_sys = words[1].strip() if len(words) == 2 else ""
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. Conversation cleared.]\n")
            else:
                print(" [Usage: /system <new prompt>]\n")
            continue

        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Drop the oldest user/assistant pair until the prompt fits; stop
        # once two or fewer turns remain so the newest message is never lost.
        while input_ids.shape[1] > prompt_budget:
            if len(history) <= 2:
                break
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        print("SLLM: ", end="", flush=True)
        response = generate_response(
            model, input_ids, tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device,
        )
        print(response + "\n")

        history.append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()
|
|