# sllm / test_checkpoint.py
# geeteshcodes's picture
# Initial commit
# 7f974df verified
"""
test_checkpoint.py — Load a checkpoint and run inference / inspect it.
QUICK START: Edit the variables in the "EDIT THESE VARIABLES" section below, then run:
python test_checkpoint.py
Modes (set MODE to one of these):
"interactive" — Chat loop: type prompts, model responds.
"sample" — Generate one completion for each prompt in SAMPLE_PROMPTS and exit.
"inspect" — Just print checkpoint info (no generation).
"""
import os
import sys
import torch
from torch.amp import autocast
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import SLLM_100M, SLLM_150M, ModelConfig
from model.model import SLLM
# ================================================================== #
# ✏️ EDIT THESE VARIABLES
# ================================================================== #
# --- Checkpoint to load -------------------------------------------
# Point to any .pt file inside a runs/ subfolder.
# Examples:
# RUN_DIR = "runs/sllm_150m" # loads latest .pt in this folder
# CKPT_FILE = None # set to a specific filename to override
# CKPT_FILE = "ckpt_0002000.pt" # or pick a specific step
# When auto-picking, only files named ckpt_*.pt inside RUN_DIR are considered.
RUN_DIR = "runs/sllm_150m"
CKPT_FILE = None # None = auto-pick latest checkpoint in RUN_DIR
# --- Model config --------------------------------------------------
# Must match what you trained with: "100M" or "150M"
# If the checkpoint stores its own config_name, load_model() uses that one
# instead and prints a warning when it differs from this value.
CONFIG = "150M"
# --- Generation settings ------------------------------------------
MAX_NEW_TOKENS = 100 # tokens to generate per prompt
TEMPERATURE = 0.8 # 0.0 = greedy argmax; 1.0 = sample the unscaled distribution; < 1.0 = sharper/safer
TOP_K = 50 # keep only top-k logits (0 = disabled)
TOP_P = 0.95 # nucleus sampling threshold (1.0 = disabled)
# TOP_K and TOP_P are ignored when TEMPERATURE == 0.0 (greedy decoding).
# --- Mode ---------------------------------------------------------
# "interactive" : chat loop in the terminal
# "sample" : run SAMPLE_PROMPTS list and exit
# "inspect" : just print checkpoint metadata, no generation
MODE = "sample"
# --- Prompts for SAMPLE mode --------------------------------------
# One completion is generated (and printed) per prompt.
SAMPLE_PROMPTS = [
    "Once upon a time",
    "The meaning of life is",
    "In the year 2050,",
]
# --- dtype --------------------------------------------------------
# "bf16" (recommended on RTX cards), "fp16", or "fp32"
# bf16 needs a CUDA GPU that reports bf16 support and fp16 needs CUDA;
# otherwise main() falls back to fp32 with autocast disabled.
DTYPE = "bf16"
# ================================================================== #
# INTERNALS (no need to edit below)
# ================================================================== #
def resolve_checkpoint(run_dir: str, ckpt_file) -> str:
    """Return the full path of the checkpoint file to load.

    Args:
        run_dir: Directory containing ``ckpt_*.pt`` files.
        ckpt_file: Filename inside ``run_dir`` to load, or ``None`` to pick the
            latest checkpoint automatically.

    Returns:
        ``os.path.join(run_dir, <chosen file>)``.

    Raises:
        FileNotFoundError: if the requested file, the run directory, or any
            ``ckpt_*.pt`` file in it does not exist.
    """
    if ckpt_file is not None:
        path = os.path.join(run_dir, ckpt_file)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    # Auto-pick latest
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory not found: {run_dir}")

    prefix, suffix = "ckpt_", ".pt"

    def _recency(name: str):
        # Order numbered checkpoints by their integer step, so ckpt_10.pt beats
        # ckpt_9.pt even if zero-padding widths differ. Non-numeric names
        # (e.g. ckpt_final.pt) rank after every numbered one, which is where a
        # plain lexicographic sort also put them.
        stem = name[len(prefix):-len(suffix)]
        if stem.isascii() and stem.isdigit():
            return (0, int(stem), name)
        return (1, 0, name)

    ckpts = [
        name for name in os.listdir(run_dir)
        if name.startswith(prefix) and name.endswith(suffix)
        and os.path.isfile(os.path.join(run_dir, name))
    ]
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in: {run_dir}")
    return os.path.join(run_dir, max(ckpts, key=_recency))
def load_model(ckpt_path: str, config_name: str, device, dtype_torch):
    """Build an SLLM and load its weights from a training checkpoint.

    Args:
        ckpt_path: Path to a checkpoint file. It must hold a
            ``"model_state_dict"`` entry and may hold ``"config_name"``,
            ``"step"`` and ``"loss"``.
        config_name: Preset key ("100M" or "150M") used when the checkpoint
            does not record its own ``config_name``.
        device: Device the model is moved to.
        dtype_torch: Not used here; weights are not cast in this function
            (mixed precision is applied via autocast in ``generate``). Kept so
            existing call sites keep working.

    Returns:
        ``(model, cfg, step, loss)``: the model in eval mode, the config it
        was built with, and the checkpoint's ``step`` / ``loss`` entries
        ("?" / NaN when missing).

    Raises:
        ValueError: if the effective config name is not a known preset.
    """
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}

    # Read the checkpoint *before* building the model, so the model is
    # constructed exactly once, with the config it was trained with.
    # Tensors are loaded to CPU: load_state_dict copies the weights onto the
    # model's device, and anything else stored in the file never has to
    # occupy GPU memory.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # A config_name stored in the checkpoint takes precedence over CONFIG.
    ckpt_cfg_name = ckpt.get("config_name", config_name)
    if ckpt_cfg_name != config_name:
        print(f" [WARN] Checkpoint config_name='{ckpt_cfg_name}' "
              f"differs from CONFIG='{config_name}'. "
              f"Using checkpoint's config: '{ckpt_cfg_name}'")
    if ckpt_cfg_name not in cfg_map:
        raise ValueError(
            f"Unknown model config {ckpt_cfg_name!r}; "
            f"expected one of {sorted(cfg_map)}"
        )
    cfg = cfg_map[ckpt_cfg_name]
    print(f"\n Config : {cfg}")

    model = SLLM(cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, cfg, step, loss
@torch.no_grad()
def generate(model, prompt_ids: list[int], cfg: ModelConfig, device,
dtype_torch, use_amp: bool,
max_new_tokens: int, temperature: float,
top_k: int, top_p: float) -> list[int]:
"""Token-by-token autoregressive generation."""
ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
ctx_len = cfg.context_length
for _ in range(max_new_tokens):
# Crop to context window
ids_crop = ids[:, -ctx_len:]
with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
logits, _ = model(ids_crop)
# Logits for the last position
logits = logits[:, -1, :] # (1, vocab)
if temperature == 0.0:
# Greedy
next_id = logits.argmax(dim=-1, keepdim=True)
else:
logits = logits / temperature
# Top-K filtering
if top_k > 0:
vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < vals[:, [-1]]] = float("-inf")
# Top-P (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative prob > top_p
sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf")
logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
probs = torch.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
ids = torch.cat([ids, next_id], dim=1)
return ids[0].tolist()
def char_tokenize(text: str, vocab_size: int = 32_000) -> list[int]:
    """Fallback character-level tokenizer.

    Your model uses a real tokenizer — swap this out with yours if available.
    Each character maps to its Unicode code point, clamped to
    ``vocab_size - 1`` so every id stays inside the embedding table. The
    default ``vocab_size=32_000`` reproduces the previous fixed cap of 31_999.

    Raises:
        ValueError: if ``vocab_size`` is less than 1.
    """
    if vocab_size < 1:
        raise ValueError(f"vocab_size must be >= 1, got {vocab_size}")
    max_id = vocab_size - 1
    return [min(ord(ch), max_id) for ch in text]
def char_detokenize(ids: list[int]) -> str:
    """Best-effort inverse of char_tokenize.

    Ids in the printable-ASCII range [32, 127) map back to their character.
    Every other id (control characters such as newline, non-ASCII code
    points, negative values) is rendered as "?".
    """
    def _render(code: int) -> str:
        is_printable_ascii = 32 <= code < 127
        return chr(code) if is_printable_ascii else "?"

    return "".join(map(_render, ids))
def try_load_sentencepiece(tokenizer_dir="tokenizer/fineweb_edu_tokenizer"):
    """Return an ``(encode, decode)`` pair for the training tokenizer.

    Despite the name, this loads a HuggingFace ``PreTrainedTokenizerFast``
    from ``tokenizer_dir``. If importing ``transformers`` or loading the
    tokenizer raises any exception, a warning is printed and the
    character-level fallback ``(char_tokenize, char_detokenize)`` is returned.
    """
    try:
        from transformers import PreTrainedTokenizerFast
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

        def encode(text):
            return tokenizer.encode(text)

        def decode(token_ids):
            # Special tokens are stripped from the rendered text.
            return tokenizer.decode(token_ids, skip_special_tokens=True)

        print(f" Tokenizer: HuggingFace tokenizer loaded from '{tokenizer_dir}'")
        print(f" vocab_size={tokenizer.vocab_size:,} eos_id={tokenizer.eos_token_id}")
        return encode, decode
    except Exception as err:
        print(f" Tokenizer: Could not load HuggingFace tokenizer ({err})")
        print(" Falling back to char tokenizer — output will be garbled!")
        return char_tokenize, char_detokenize
def run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Prompt/response loop on the terminal.

    Each line read from stdin is encoded, passed to ``generate`` with the
    module-level MAX_NEW_TOKENS / TEMPERATURE / TOP_K / TOP_P, and only the
    newly generated continuation is printed. The loop ends on 'quit' or
    'exit' (any case), an empty line, EOF, or Ctrl+C at the prompt.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(" INTERACTIVE MODE (type 'quit' or 'exit' to stop)")
    print(banner)
    print(f" max_new_tokens : {MAX_NEW_TOKENS}")
    print(f" temperature : {TEMPERATURE}")
    print(f" top_k / top_p : {TOP_K} / {TOP_P}")
    print()

    while True:
        try:
            prompt = input("Prompt> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n Exiting.")
            return

        if not prompt or prompt.lower() in {"quit", "exit"}:
            print(" Exiting.")
            return

        encoded = encode(prompt)
        full_ids = generate(
            model, encoded, cfg, device, dtype_torch, use_amp,
            MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
        )
        continuation = full_ids[len(encoded):]
        print(f"\nGenerated: {decode(continuation)}\n")
def run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Generate and print one continuation for every prompt in SAMPLE_PROMPTS.

    Uses the module-level MAX_NEW_TOKENS / TEMPERATURE / TOP_K / TOP_P and
    prints only the tokens produced after each prompt.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" SAMPLE MODE")
    print(rule)

    sampling = (MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P)
    for number, prompt in enumerate(SAMPLE_PROMPTS, start=1):
        print(f"\n[{number}] Prompt : {prompt!r}")
        encoded = encode(prompt)
        full_ids = generate(model, encoded, cfg, device, dtype_torch, use_amp, *sampling)
        print(f" Output : {decode(full_ids[len(encoded):])}")
def run_inspect(ckpt_path, step, loss, cfg):
    """Print checkpoint metadata: path, step, loss, config and parameter count.

    ``loss`` is printed with 4 decimals when it is numeric (a Python int or
    float, or a single-element tensor); any other value is printed as-is.
    """
    print("\n" + "=" * 60)
    print(" INSPECT MODE")
    print("=" * 60)
    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step}")
    # Unwrap a single-element tensor so it formats like a plain number.
    if isinstance(loss, torch.Tensor) and loss.numel() == 1:
        loss = loss.item()
    if isinstance(loss, (int, float)) and not isinstance(loss, bool):
        print(f" Loss : {loss:.4f}")
    else:
        # Same label as the numeric branch (previously printed " Loss: ").
        print(f" Loss : {loss}")
    print(f" Config : {cfg}")
    print(f" Params : {cfg.count_params()/1e6:.1f}M")
    print()
print()
def _select_dtype(device):
    """Map the DTYPE setting to ``(torch dtype, use_amp)`` for ``device``.

    bf16 needs a CUDA device that reports bf16 support and fp16 needs CUDA.
    Anything else (including every CPU run) falls back to fp32 with autocast
    disabled.
    """
    if DTYPE == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16, True
    if DTYPE == "fp16" and device.type == "cuda":
        return torch.float16, True
    return torch.float32, False


def _format_loss(loss) -> str:
    """Render a checkpoint loss for display without crashing on non-numbers."""
    if isinstance(loss, torch.Tensor) and loss.numel() == 1:
        loss = loss.item()
    if isinstance(loss, (int, float)) and not isinstance(loss, bool):
        return f"{loss:.4f}"
    return str(loss)


def main():
    """Script entry point: pick device/dtype, load the checkpoint, run MODE."""
    # Validate MODE before the (slow) checkpoint load so a typo fails fast.
    if MODE not in ("interactive", "sample", "inspect"):
        print(f" [ERROR] Unknown MODE: '{MODE}'. Use 'interactive', 'sample', or 'inspect'.")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # dtype setup: report the precision actually used, not only the request.
    dtype_torch, use_amp = _select_dtype(device)
    effective = {torch.bfloat16: "bf16", torch.float16: "fp16"}.get(dtype_torch, "fp32")
    if effective == DTYPE:
        print(f"dtype : {DTYPE}")
    else:
        print(f"dtype : {effective} (requested '{DTYPE}' is not available here)")

    # Resolve checkpoint path
    ckpt_path = resolve_checkpoint(RUN_DIR, CKPT_FILE)
    print(f"\nCheckpoint: {ckpt_path}")

    # Load model
    model, cfg, step, loss = load_model(ckpt_path, CONFIG, device, dtype_torch)
    # loss may be missing or non-numeric in the checkpoint; don't crash on it.
    print(f" Loaded : step={step}, loss={_format_loss(loss)}")
    print(f" Params : {model.count_params()/1e6:.1f}M")

    if MODE == "inspect":
        run_inspect(ckpt_path, step, loss, cfg)
        return

    # Load tokenizer (only needed for generation modes)
    encode, decode = try_load_sentencepiece()
    runner = run_interactive if MODE == "interactive" else run_sample
    runner(model, cfg, device, dtype_torch, use_amp, encode, decode)


if __name__ == "__main__":
    main()