# sllm / finetune / chat.py
# geeteshcodes: Initial commit (7f974df, verified)
"""
finetune/chat.py
Interactive CLI chat with the fine-tuned SLLM-150M chat model.
Loads the latest SFT checkpoint from --run_dir, formats your input
as a ChatML prompt, generates a response token-by-token, and stops
at the <|im_end|> token.

Usage:
    python finetune/chat.py
    python finetune/chat.py --run_dir runs/sllm_150m_chat
    python finetune/chat.py --temperature 0.7 --top_k 40

In-chat commands:
    /reset            clear conversation history (start fresh)
    /system <text>    change the system prompt (also clears the history)
    /quit             exit (aliases: /exit, quit, exit)
"""
import os
import sys
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
# Paths are resolved relative to this file, not the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
# First place load_model_and_tokenizer() looks for a tokenizer.json.
DATA_DIR = SCRIPT_DIR / "data"
# The project root must be on sys.path *before* the two `model.*` imports
# below, so keep this statement ahead of them.
sys.path.insert(0, str(PROJECT_ROOT))
from model.config import SLLM_150M
from model.model import SLLM
# Defaults for the --system and --run_dir command-line options.
DEFAULT_SYSTEM = "You are a helpful, concise assistant."
DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")
# ------------------------------------------------------------------ #
# HELPERS
# ------------------------------------------------------------------ #
def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the most recent ``ckpt_sft_*.pt`` file in ``run_dir``.

    Candidates are ranked with a natural sort: every run of digits in the
    file name is compared as an integer, so ``ckpt_sft_1000.pt`` ranks after
    ``ckpt_sft_500.pt``. A plain string sort would return ``ckpt_sft_500.pt``
    whenever the numbers are not zero-padded to the same width. Names that
    are zero-padded rank exactly as they would under a string sort.

    Args:
        run_dir: Directory to scan (not searched recursively).

    Returns:
        ``run_dir`` joined with the name of the highest-ranked checkpoint.

    Raises:
        FileNotFoundError: If ``run_dir`` contains no matching checkpoint
            (``os.listdir`` raises the same error if ``run_dir`` is missing).
    """
    import re  # local import: ``re`` is not imported at the top of this module

    def natural_key(name: str) -> list:
        # re.split with a capturing group alternates text / digit chunks, so
        # odd positions are always digit runs and keys never mix int with str
        # at the same position.
        chunks = re.split(r"(\d+)", name)
        return [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]

    ckpts = [
        f for f in os.listdir(run_dir)
        if f.startswith("ckpt_sft_") and f.endswith(".pt")
    ]
    if not ckpts:
        raise FileNotFoundError(
            f"No SFT checkpoints found in '{run_dir}'.\n"
            f"Run sft_train.py first."
        )
    # The raw name is a tie-breaker so the result does not depend on listdir
    # order when two names share a natural key (e.g. "01" vs "1").
    latest = max(ckpts, key=lambda name: (natural_key(name), name))
    return os.path.join(run_dir, latest)
def resize_token_embeddings(model: "SLLM", new_vocab_size: int):
    """Resize the model's token embedding (and tied output head) in place.

    Same resize logic as sft_train.py, kept local to avoid circular imports.
    Rows shared by the old and new vocabularies are copied over. When the
    vocabulary grows, each added row is initialised to the mean of the old
    embedding rows. When it shrinks, the trailing rows are dropped (the
    earlier version raised a shape-mismatch error in that case).

    Args:
        model: Model exposing ``config.vocab_size``, ``config.d_model``,
            ``token_emb`` (an ``nn.Embedding``) and ``lm_head``.
        new_vocab_size: Target number of embedding rows.

    Returns:
        None. ``model.token_emb``, ``model.lm_head.weight`` and
        ``model.config.vocab_size`` are updated; nothing happens when the
        size is already ``new_vocab_size``.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.data
    device, dtype = old_weight.device, old_weight.dtype

    # Copy only the overlapping rows so both growing and shrinking work.
    kept = min(old_size, new_vocab_size)
    new_weight = torch.zeros(new_vocab_size, d_model, dtype=dtype, device=device)
    new_weight[:kept] = old_weight[:kept]
    if new_vocab_size > old_size:
        # Broadcast the mean embedding into every newly added row.
        new_weight[old_size:] = old_weight.mean(dim=0)

    new_emb = nn.Embedding(new_vocab_size, d_model).to(device=device, dtype=dtype)
    new_emb.weight.data = new_weight
    model.token_emb = new_emb
    # Re-tie the output projection to the freshly created embedding matrix.
    model.lm_head.weight = model.token_emb.weight
    model.config.vocab_size = new_vocab_size
def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Load the chat tokenizer and the newest fine-tuned model from ``run_dir``.

    The tokenizer is read from DATA_DIR when a ``tokenizer.json`` exists
    there; otherwise the base tokenizer is loaded and the two ChatML marker
    tokens are added to it by hand.

    Args:
        run_dir: Directory holding the ``ckpt_sft_*.pt`` checkpoints.
        device: Device the checkpoint tensors and the model are placed on.

    Returns:
        Tuple ``(model, tokenizer, ckpt_path, step, loss)``. ``step`` falls
        back to ``"?"`` and ``loss`` to NaN when the checkpoint lacks them.
    """
    # ---- Tokenizer ------------------------------------------------- #
    sft_tok_dir = str(DATA_DIR)
    has_sft_tokenizer = os.path.exists(os.path.join(sft_tok_dir, "tokenizer.json"))
    if not has_sft_tokenizer:
        # Fallback: base tokenizer + manual special token add
        base_tok_dir = str(PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(base_tok_dir)
        chatml_markers = ["<|im_start|>", "<|im_end|>"]
        tokenizer.add_special_tokens({"additional_special_tokens": chatml_markers})
    else:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(sft_tok_dir)

    # ---- Checkpoint ------------------------------------------------ #
    ckpt_path = find_latest_ckpt(run_dir)
    # NOTE: weights_only=False performs full unpickling, so only point this
    # at checkpoints you trust.
    state = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---- Model ----------------------------------------------------- #
    model = SLLM(SLLM_150M).to(device)
    # Match the embedding size used at save time before loading the weights;
    # fall back to the tokenizer length when the checkpoint did not record it.
    vocab_at_save = state.get("vocab_size", len(tokenizer))
    resize_token_embeddings(model, vocab_at_save)
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    step = state.get("step", "?")
    loss = state.get("loss", float("nan"))
    return model, tokenizer, ckpt_path, step, loss
# ------------------------------------------------------------------ #
# PROMPT BUILDING
# ------------------------------------------------------------------ #
def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """
    Render the conversation as a ChatML prompt and tokenise it.

    Layout of the rendered text:
        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>{role}
        {content}<|im_end|>
        ...                        (one block per entry in ``history``)
        <|im_start|>assistant      (plus a newline; left open for the model)

    Args:
        history: Turns in order, each a dict with ``role`` and ``content``.
        system_prompt: Text placed in the leading system block.
        tokenizer: Encodes the text; special tokens are not auto-added.

    Returns:
        input_ids : (1, T) LongTensor
    """
    segments = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    segments.extend(
        f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
        for msg in history
    )
    # Open assistant header: the model continues from here.
    segments.append("<|im_start|>assistant\n")

    token_ids = tokenizer.encode("".join(segments), add_special_tokens=False)
    return torch.tensor([token_ids], dtype=torch.long)
# ------------------------------------------------------------------ #
# GENERATION
# ------------------------------------------------------------------ #
@torch.no_grad()
def generate_response(
    model: "SLLM",
    input_ids: torch.Tensor,
    tokenizer: "PreTrainedTokenizerFast",
    max_new_tokens: int = 300,
    temperature: float = 0.8,
    top_k: int = 50,
    device: "torch.device | None" = None,
) -> str:
    """
    Autoregressively generates tokens until:
      - <|im_end|> is produced (clean stop), or
      - eos_token_id is produced, or
      - max_new_tokens is reached

    Args:
        model: Callable returning ``(logits, _)`` with logits of shape
            (1, T, V); must expose ``config.context_length``.
        input_ids: (1, T) prompt token ids.
        tokenizer: Provides the stop-token ids and decodes the result.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Softmax temperature. Values <= 0 select the most likely
            token at every step (greedy decoding) instead of sampling.
        top_k: Sample only among the k most likely tokens (0 disables).
        device: Where to run generation. When None, the device of the
            model's first parameter is used so the prompt and the weights
            always live on the same device.

    Returns:
        The decoded response string (special tokens stripped).
    """
    if device is None:
        # Follow the model's weights; a parameter-less model falls back to
        # wherever the prompt already lives.
        first_param = next(iter(model.parameters()), None)
        device = input_ids.device if first_param is None else first_param.device

    # eos_token_id may be None; a sampled int id never equals None.
    stop_ids = {tokenizer.convert_tokens_to_ids("<|im_end|>"), tokenizer.eos_token_id}
    context_length = model.config.context_length
    greedy = temperature <= 0

    ids = input_ids.to(device)
    generated = []
    for _ in range(max_new_tokens):
        # Feed at most the last `context_length` tokens to the model.
        logits, _unused = model(ids[:, -context_length:])   # (1, T, V)
        logits = logits[:, -1, :]

        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)   # (1, 1)
        else:
            logits = logits / max(temperature, 1e-8)
            # Top-k filtering: everything below the k-th best logit is masked.
            if top_k and top_k > 0:
                top_vals, _unused = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < top_vals[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)      # (1, 1)

        tok_id = next_token.item()
        # Stop conditions (the stop token itself is not kept)
        if tok_id in stop_ids:
            break
        generated.append(tok_id)
        ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()
# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #
def parse_args():
    """Parse the chat script's command-line options.

    Returns:
        argparse.Namespace with ``run_dir``, ``temperature``, ``top_k``,
        ``max_new_tokens`` and ``system``.
    """
    parser = argparse.ArgumentParser(description="SLLM-150M Chat")
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--run_dir", {"type": str, "default": DEFAULT_RUN_DIR}),
        ("--temperature", {"type": float, "default": 0.8,
                           "help": "Sampling temperature (lower = more focused)"}),
        ("--top_k", {"type": int, "default": 50,
                     "help": "Top-k sampling (0 = disabled)"}),
        ("--max_new_tokens", {"type": int, "default": 300,
                              "help": "Max tokens per assistant response"}),
        ("--system", {"type": str, "default": DEFAULT_SYSTEM,
                      "help": "System prompt"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def main():
    """Run the interactive chat session.

    Parses the command-line options, loads the newest SFT checkpoint and its
    tokenizer, then reads user input in a loop until EOF, Ctrl-C or a quit
    command. Each user message is appended to the history, rendered as a
    ChatML prompt, answered by the model, and the answer is stored back in
    the history.

    In-chat commands:
        /reset           clear the conversation history
        /system <text>   replace the system prompt and clear the history;
                         a bare ``/system`` prints a usage hint instead of
                         being sent to the model as a chat message
        /quit            exit (aliases: /exit, quit, exit)
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("\n" + "=" * 60)
    print(" SLLM-150M Chat")
    print("=" * 60)
    print(f" Device : {device}")
    if device.type == "cuda":
        print(f" GPU : {torch.cuda.get_device_name(0)}")

    # ---- Load ------------------------------------------------------ #
    print("\nLoading model...")
    model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step} Loss: {loss:.4f}")
    print(f" Vocab size : {len(tokenizer):,}")

    # ---- Chat loop ------------------------------------------------- #
    system_prompt = args.system
    history: list[dict] = []
    # Longest prompt that still leaves room for a full-length reply
    # (plus a small safety margin of 10 tokens).
    prompt_budget = model.config.context_length - args.max_new_tokens - 10

    print(f"\n System : {system_prompt}")
    print(" Commands: /reset | /system <new prompt> | /quit")
    print("─" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input:
            continue

        # ---- Commands ---------------------------------------------- #
        lowered = user_input.lower()
        if lowered in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break
        if lowered == "/reset":
            history = []
            print(" [Conversation cleared]\n")
            continue
        # Split off the first word so a bare "/system" is recognised as a
        # command too, while e.g. "/systematic ..." is still a chat message.
        command, _, argument = user_input.partition(" ")
        if command.lower() == "/system":
            new_sys = argument.strip()
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. Conversation cleared.]\n")
            else:
                print(" [Usage: /system <new prompt>]\n")
            continue

        # ---- Build prompt ------------------------------------------ #
        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)
        # Trim history if prompt is getting close to context limit: drop the
        # oldest user+assistant pair until it fits. Once only the newest
        # turns remain, generation itself crops to the context window.
        while input_ids.shape[1] > prompt_budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        # ---- Generate ---------------------------------------------- #
        print("SLLM: ", end="", flush=True)
        response = generate_response(
            model, input_ids, tokenizer,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            device=device,
        )
        print(response + "\n")
        history.append({"role": "assistant", "content": response})
# Start the chat loop only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()