sllm / test_chatmodel.py
geeteshcodes's picture
Initial commit
7f974df verified
"""
test_chatmodel.py — Interactive CLI chat and evaluation for the fine-tuned SLLM chat model.
Usage:
python test_chatmodel.py --run_dir runs/sllm_150m_chat
python test_chatmodel.py --run_dir runs/sllm_150m_chat --mode sample
In interactive mode:
Type your message and press Enter.
Special commands:
/reset Clear conversation history
/system <text> Change the system prompt
/quit Exit the chat
"""
import os
import sys
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import PreTrainedTokenizerFast
# Put this script's own directory first on sys.path so the project-local
# `model` package imported just below resolves no matter which working
# directory the script is launched from.
PROJECT_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))
from model.config import SLLM_150M
from model.model import SLLM
# System prompt used when --system is not given on the command line.
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
# Checkpoint directory used when --run_dir is not given on the command line.
DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
# ------------------------------------------------------------------ #
# HELPERS
# ------------------------------------------------------------------ #
def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the most recent checkpoint stored in ``run_dir``.

    Candidates are files named ``ckpt_*.pt`` (which also covers
    ``ckpt_sft_*.pt``). Names are ranked with a natural sort, so runs of
    digits compare numerically: ``ckpt_10000.pt`` ranks after
    ``ckpt_9000.pt`` even when step numbers are not zero-padded. The text
    prefix ``"ckpt_sft_"`` still sorts after ``"ckpt_"``, so an SFT checkpoint
    wins over a base checkpoint when both live in the same directory.

    Args:
        run_dir: Directory to scan (not recursive).

    Returns:
        ``run_dir`` joined with the highest-ranked checkpoint file name.

    Raises:
        FileNotFoundError: If ``run_dir`` is not a directory or contains no
            matching checkpoint file.
    """
    import re  # local import: only needed for the natural-sort key

    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory '{run_dir}' does not exist.")

    def natural_key(name: str):
        # re.split with a capturing group alternates text / digits / text...,
        # so odd positions are always digit runs. Every key therefore holds
        # the same type at the same position and keys compare safely.
        chunks = [
            int(chunk) if pos % 2 else chunk
            for pos, chunk in enumerate(re.split(r"([0-9]+)", name))
        ]
        # The raw name breaks ties such as "ckpt_01.pt" vs "ckpt_1.pt".
        return chunks, name

    candidates = [
        name for name in os.listdir(run_dir)
        if name.startswith("ckpt_") and name.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoints found in '{run_dir}'.\n"
            f"Please ensure you have trained the model or point to the correct folder."
        )
    return os.path.join(run_dir, max(candidates, key=natural_key))
def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Resize the model's token-embedding matrix to ``new_vocab_size`` rows.

    Existing rows are copied over unchanged. When the vocabulary grows, every
    added row is initialised to the mean of the existing embedding rows. When
    it shrinks, the trailing rows are dropped (the original implementation
    failed with a shape-mismatch error in that case). The LM head is re-tied
    to the new embedding weight and ``model.config.vocab_size`` is updated.

    Does nothing when ``new_vocab_size`` already equals the configured size.

    Args:
        model: Model exposing ``token_emb``, ``lm_head`` and ``config``.
        new_vocab_size: Target number of embedding rows; must be positive.

    Raises:
        ValueError: If ``new_vocab_size`` is not positive.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return
    if new_vocab_size <= 0:
        raise ValueError(f"new_vocab_size must be positive, got {new_vocab_size}")

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.data
    rows_kept = min(old_size, new_vocab_size)

    new_emb = nn.Embedding(new_vocab_size, d_model).to(
        device=old_weight.device, dtype=old_weight.dtype
    )
    with torch.no_grad():
        new_emb.weight[:rows_kept] = old_weight[:rows_kept]
        if new_vocab_size > old_size:
            # Mean of the existing rows, broadcast over every added row.
            new_emb.weight[old_size:] = old_weight.mean(dim=0)

    model.token_emb = new_emb
    # Keep the output projection tied to the (new) input embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
    print(f" [INFO] Resized model vocab embedding from {old_size:,} to {new_vocab_size:,}")
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Load the tokenizer and the latest model checkpoint.

    Tokenizer lookup order:
      1. ``finetune/data`` if it contains ``tokenizer.json`` (extended tokenizer).
      2. ``tokenizer/fineweb_edu_tokenizer`` (base tokenizer); the ChatML
         tokens ``<|im_start|>`` / ``<|im_end|>`` are then added on the fly.

    Checkpoint lookup: the latest checkpoint in ``run_dir``; if none is found
    there, the latest one in ``runs/sllm_150m``.

    The model's embedding is sized to the checkpoint before the weights are
    loaded. The size comes from the checkpoint's ``vocab_size`` entry, else
    from the stored ``token_emb.weight`` tensor, else from the tokenizer
    length. If the tokenizer is larger than the loaded vocabulary, the
    embedding is grown afterwards so every token id has a row.

    Args:
        run_dir: Directory expected to hold the chat checkpoints.
        device: Device the checkpoint and model are placed on.

    Returns:
        Tuple ``(model, tokenizer, ckpt_path, step, loss)``; ``step`` is
        ``"?"`` and ``loss`` is NaN when the checkpoint does not store them.

    Raises:
        FileNotFoundError: If no tokenizer directory or no checkpoint exists.
    """
    # ---- Tokenizer ------------------------------------------------- #
    # Look in finetune/data or tokenizer/fineweb_edu_tokenizer
    data_tok_dir = PROJECT_ROOT / "finetune" / "data"
    base_tok_dir = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"
    if (data_tok_dir / "tokenizer.json").exists():
        tok_path = str(data_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Tokenizer: Loaded extended tokenizer from '{tok_path}'")
    elif base_tok_dir.exists():
        tok_path = str(base_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
        })
        print(f" Tokenizer: Loaded base tokenizer from '{tok_path}' and added ChatML tokens")
    else:
        raise FileNotFoundError("Could not find a tokenizer directory.")

    # ---- Checkpoint ------------------------------------------------ #
    try:
        ckpt_path = find_latest_ckpt(run_dir)
    except FileNotFoundError:
        # Fall back to base pretraining checkpoint if SFT directory is empty
        print(f" [WARN] No checkpoint found in '{run_dir}'. Trying pretraining base run...")
        base_dir = PROJECT_ROOT / "runs" / "sllm_150m"
        ckpt_path = find_latest_ckpt(str(base_dir))
    print(f" Loading checkpoint: {ckpt_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints you produced yourself or otherwise trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    state_dict = ckpt["model_state_dict"]

    # ---- Model ----------------------------------------------------- #
    model = SLLM(SLLM_150M).to(device)
    saved_vocab = ckpt.get("vocab_size")
    if saved_vocab is None:
        # Without an explicit size, trust the stored embedding tensor. Sizing
        # to the tokenizer would break load_state_dict for a base checkpoint
        # that predates the added ChatML tokens.
        emb_weight = state_dict.get("token_emb.weight")
        saved_vocab = emb_weight.shape[0] if emb_weight is not None else len(tokenizer)
    resize_token_embeddings(model, saved_vocab)
    model.load_state_dict(state_dict)
    if len(tokenizer) > model.config.vocab_size:
        # Tokenizer knows ids the checkpoint has no embedding row for; add
        # rows so those ids cannot index out of range during generation.
        resize_token_embeddings(model, len(tokenizer))
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, tokenizer, ckpt_path, step, loss
# ------------------------------------------------------------------ #
# PROMPT BUILDING
# ------------------------------------------------------------------ #
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """Render a conversation as ChatML text and return its token ids.

    The text starts with the system message, continues with one
    ``<|im_start|>role ... <|im_end|>`` segment per entry of ``history``
    (each entry provides ``"role"`` and ``"content"``), and ends with an open
    assistant header so the model continues as the assistant.

    Returns:
        A ``torch.long`` tensor holding the encoded prompt as a single row.
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
        for turn in history
    )
    # Open assistant header: generation picks up right after it.
    segments.append("<|im_start|>assistant\n")

    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
# ------------------------------------------------------------------ #
# GENERATION
# ------------------------------------------------------------------ #
def _sample_next_token(logits: torch.Tensor, temperature: float,
                       top_k: int, top_p: float) -> torch.Tensor:
    """Choose one next-token id per row of ``logits`` (greedy or top-k/top-p sampling)."""
    if temperature == 0.0:
        # Greedy decoding
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / max(temperature, 1e-8)

    # Top-k: drop everything below the k-th best logit.
    if top_k and top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, [-1]]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    # Top-p (nucleus): drop a token once the probability mass of the tokens
    # ranked before it already exceeds top_p. The best token always survives.
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
        # Undo the sort so positions line up with token ids again.
        logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (1, 1)


@torch.no_grad()
def generate_response(
    model: SLLM,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 200,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.9,
    device: torch.device = None,
    dtype_torch: torch.dtype = torch.float32,
    use_amp: bool = False,
) -> str:
    """Generate an assistant reply for an already-tokenized prompt.

    Tokens are produced one at a time; every step re-runs the model on the
    (context-cropped) sequence so far. Generation stops after
    ``max_new_tokens`` tokens or as soon as ``<|im_end|>`` or the tokenizer's
    EOS token is produced (the stop token is not part of the reply).

    Args:
        model: Model called as ``model(ids)`` and returning ``(logits, _)``.
        input_ids: Prompt token ids with a batch dimension of 1.
        tokenizer: Tokenizer used to look up stop ids and decode the reply.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; ``0.0`` selects greedy decoding.
        top_k: Keep only the ``top_k`` best logits (``0``/``None`` disables).
        top_p: Nucleus sampling threshold (``>= 1.0`` disables).
        device: Device to run on; defaults to the device of the model's
            parameters when ``None``.
        dtype_torch: Autocast dtype, used only when ``use_amp`` is true.
        use_amp: Whether to run the forward pass under autocast.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    if device is None:
        # The declared default would otherwise crash on ``device.type`` below.
        device = next(model.parameters()).device

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    eos_id = tokenizer.eos_token_id
    context_length = model.config.context_length

    ids = input_ids.to(device)
    generated = []
    for _ in range(max_new_tokens):
        # Crop context to model window
        ctx = ids if ids.shape[1] <= context_length else ids[:, -context_length:]
        with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
            logits, _ = model(ctx)  # (1, T, V)

        # Last position only. Cast to float32 so the softmax / cumulative sum
        # used for sampling is not done in the reduced autocast precision.
        last_logits = logits[:, -1, :].float()
        next_token = _sample_next_token(last_logits, temperature, top_k, top_p)

        tok_id = next_token.item()
        # Stop if end of message or end of stream token is generated
        if tok_id == im_end_id or tok_id == eos_id:
            break
        generated.append(tok_id)
        ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# ------------------------------------------------------------------ #
# MODES
# ------------------------------------------------------------------ #
def run_interactive(model, tokenizer, device, dtype_torch, use_amp, args):
    """Run the interactive chat loop on stdin/stdout.

    Keeps a running conversation history and rebuilds the ChatML prompt from
    it on every turn. Supported commands:

      * ``/quit``, ``/exit``, ``quit``, ``exit`` — leave the chat
      * ``/reset`` — clear the conversation history
      * ``/system <prompt>`` — replace the system prompt and clear the history;
        a bare ``/system`` prints a usage hint instead of being sent to the
        model as a chat message

    EOF or Ctrl-C at the input prompt also ends the chat.

    Args:
        model, tokenizer: Loaded model and its tokenizer.
        device, dtype_torch, use_amp: Forwarded to ``generate_response``.
        args: Parsed CLI arguments; reads ``system``, ``max_new_tokens``,
            ``temperature``, ``top_k`` and ``top_p``.
    """
    system_prompt = args.system
    history = []

    print("\n" + "=" * 60)
    print(" CHAT MODE (Interactive)")
    print("=" * 60)
    print(f" System prompt : {system_prompt}")
    print(" Commands : /reset to clear memory | /system <prompt> | /quit to exit")
    print("─" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input:
            continue

        # Check for commands
        command = user_input.lower()
        if command in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break
        if command == "/reset":
            history = []
            print(" [Conversation history reset]\n")
            continue
        if command == "/system" or command.startswith("/system "):
            new_sys = user_input[len("/system"):].strip()
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. History cleared.]\n")
            else:
                print(" [Usage: /system <prompt>]\n")
            continue

        # Add to history and build ChatML prompt
        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Trim the conversation while the prompt leaves too little room for
        # the reply: drop the oldest user + assistant pair each round, but
        # never the newest message.
        token_budget = model.config.context_length - args.max_new_tokens - 10
        while input_ids.shape[1] > token_budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        print("SLLM: ", end="", flush=True)
        response = generate_response(
            model, input_ids, tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            device=device,
            dtype_torch=dtype_torch,
            use_amp=use_amp,
        )
        print(response + "\n")
        history.append({"role": "assistant", "content": response})
def run_sample(model, tokenizer, device, dtype_torch, use_amp, args):
    """Answer a fixed set of sample prompts and print each reply.

    Every prompt is sent as a fresh single-turn conversation using
    ``args.system`` as the system prompt; no history is carried between
    prompts. Sampling settings are read from ``args``.
    """
    questions = (
        "Hello! Who are you?",
        "What is the capital of France?",
        "Write a quick, 3-line poem about a small robot learning to speak.",
        "Explain gravity in one simple sentence.",
    )
    thick_rule = "=" * 60
    thin_rule = "─" * 60

    print("\n" + thick_rule)
    print(" SAMPLE EVALUATION MODE")
    print(thick_rule)
    print(f" System prompt: {args.system}")
    print(thin_rule)

    for question in questions:
        print(f"\n[PROMPT] : {question}")
        single_turn = [{"role": "user", "content": question}]
        prompt_ids = build_prompt(single_turn, args.system, tokenizer)
        print("[SLLM] : ", end="", flush=True)
        reply = generate_response(
            model,
            prompt_ids,
            tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            device=device,
            dtype_torch=dtype_torch,
            use_amp=use_amp,
        )
        print(reply)

    print("\n" + thin_rule + "\n")
# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #
def main():
    """Parse CLI arguments, load the model and run the selected mode.

    Precision: ``bf16`` is used only on a CUDA device that supports it and
    ``fp16`` only on CUDA; otherwise the run falls back to ``fp32`` and the
    printed dtype line says so.

    Returns:
        ``0`` on success, ``1`` when the model or tokenizer could not be
        loaded (the error is printed).
    """
    import math  # local import: only needed for the NaN check on the loss

    p = argparse.ArgumentParser(description="SLLM Chat Checker")
    p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
    p.add_argument("--mode", type=str, default="interactive", choices=["interactive", "sample"])
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_k", type=int, default=40)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--max_new_tokens", type=int, default=200)
    p.add_argument("--system", type=str, default=DEFAULT_SYSTEM)
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")

    # Precision setup
    use_amp = False
    dtype_torch = torch.float32
    effective_dtype = "fp32"
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch = torch.bfloat16
        use_amp = True
        effective_dtype = "bf16"
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch = torch.float16
        use_amp = True
        effective_dtype = "fp16"

    if effective_dtype == args.dtype:
        print(f"dtype : {args.dtype}")
    else:
        # Report the precision actually in use, not just the requested one.
        print(f"dtype : {effective_dtype} (requested {args.dtype} is not available on this device)")

    # Load Model and Tokenizer
    try:
        model, tokenizer, _ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"\n[ERROR] Failed to load chat model: {e}")
        return 1

    print(f" Step : {step}")
    # The stored loss may be missing (NaN) or not a plain number; only print
    # it when it converts to a real float.
    try:
        loss_value = float(loss)
    except (TypeError, ValueError):
        loss_value = math.nan
    if not math.isnan(loss_value):
        print(f" Loss : {loss_value:.4f}")

    if args.mode == "interactive":
        run_interactive(model, tokenizer, device, dtype_torch, use_amp, args)
    elif args.mode == "sample":
        run_sample(model, tokenizer, device, dtype_torch, use_amp, args)
    return 0


if __name__ == "__main__":
    sys.exit(main())