""" test_checkpoint.py — Load a checkpoint and run inference / inspect it. QUICK START: Edit the variables in the CONFIG section below, then run: python test_checkpoint.py Modes: INTERACTIVE — Chat loop: type prompts, model responds. SAMPLE — Auto-generate N samples from fixed prompts and exit. INSPECT — Just print checkpoint info (no generation). """ import os import sys import torch from torch.amp import autocast sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model.config import SLLM_100M, SLLM_150M, ModelConfig from model.model import SLLM # ================================================================== # # ✏️ EDIT THESE VARIABLES # ================================================================== # # --- Checkpoint to load ------------------------------------------- # Point to any .pt file inside a runs/ subfolder. # Examples: # RUN_DIR = "runs/sllm_150m" # loads latest .pt in this folder # CKPT_FILE = None # set to a specific filename to override # CKPT_FILE = "ckpt_0002000.pt" # or pick a specific step RUN_DIR = "runs/sllm_150m" CKPT_FILE = None # None = auto-pick latest checkpoint in RUN_DIR # --- Model config -------------------------------------------------- # Must match what you trained with: "100M" or "150M" CONFIG = "150M" # --- Generation settings ------------------------------------------ MAX_NEW_TOKENS = 100 # tokens to generate per prompt TEMPERATURE = 0.8 # 0.0 = greedy, 1.0 = random, 0.8 = balanced TOP_K = 50 # keep only top-k logits (0 = disabled) TOP_P = 0.95 # nucleus sampling threshold (1.0 = disabled) # --- Mode --------------------------------------------------------- # "interactive" : chat loop in the terminal # "sample" : run SAMPLE_PROMPTS list and exit # "inspect" : just print checkpoint metadata, no generation MODE = "sample" # --- Prompts for SAMPLE mode -------------------------------------- SAMPLE_PROMPTS = [ "Once upon a time", "The meaning of life is", "In the year 2050,", ] # --- dtype -------------------------------------------------------- # "bf16" (recommended on RTX cards), "fp16", or "fp32" DTYPE = "bf16" # ================================================================== # # INTERNALS (no need to edit below) # ================================================================== # def resolve_checkpoint(run_dir: str, ckpt_file) -> str: """Return full path to the checkpoint file.""" if ckpt_file is not None: path = os.path.join(run_dir, ckpt_file) if not os.path.isfile(path): raise FileNotFoundError(f"Checkpoint not found: {path}") return path # Auto-pick latest if not os.path.isdir(run_dir): raise FileNotFoundError(f"Run directory not found: {run_dir}") ckpts = sorted([ f for f in os.listdir(run_dir) if f.startswith("ckpt_") and f.endswith(".pt") ]) if not ckpts: raise FileNotFoundError(f"No checkpoints found in: {run_dir}") return os.path.join(run_dir, ckpts[-1]) def load_model(ckpt_path: str, config_name: str, device, dtype_torch): """Load model weights from checkpoint.""" cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M} cfg = cfg_map[config_name] print(f"\n Config : {cfg}") model = SLLM(cfg).to(device) ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) # Prefer config_name stored in checkpoint (override CLI if available) ckpt_cfg_name = ckpt.get("config_name", config_name) if ckpt_cfg_name != config_name: print(f" [WARN] Checkpoint config_name='{ckpt_cfg_name}' " f"differs from CONFIG='{config_name}'. " f"Using checkpoint's config: '{ckpt_cfg_name}'") cfg = cfg_map[ckpt_cfg_name] model = SLLM(cfg).to(device) model.load_state_dict(ckpt["model_state_dict"]) model.eval() step = ckpt.get("step", "?") loss = ckpt.get("loss", float("nan")) return model, cfg, step, loss @torch.no_grad() def generate(model, prompt_ids: list[int], cfg: ModelConfig, device, dtype_torch, use_amp: bool, max_new_tokens: int, temperature: float, top_k: int, top_p: float) -> list[int]: """Token-by-token autoregressive generation.""" ids = torch.tensor([prompt_ids], dtype=torch.long, device=device) ctx_len = cfg.context_length for _ in range(max_new_tokens): # Crop to context window ids_crop = ids[:, -ctx_len:] with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp): logits, _ = model(ids_crop) # Logits for the last position logits = logits[:, -1, :] # (1, vocab) if temperature == 0.0: # Greedy next_id = logits.argmax(dim=-1, keepdim=True) else: logits = logits / temperature # Top-K filtering if top_k > 0: vals, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < vals[:, [-1]]] = float("-inf") # Top-P (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative prob > top_p sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf") logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits) probs = torch.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) ids = torch.cat([ids, next_id], dim=1) return ids[0].tolist() def char_tokenize(text: str) -> list[int]: """ Fallback character-level tokenizer. Your model uses a real tokenizer — swap this out with yours if available. Each char maps to its Unicode code point (capped at vocab_size - 1). """ return [min(ord(c), 31_999) for c in text] def char_detokenize(ids: list[int]) -> str: """Reverse of char_tokenize.""" return "".join(chr(i) if 32 <= i < 127 else "?" for i in ids) def try_load_sentencepiece(tokenizer_dir="tokenizer/fineweb_edu_tokenizer"): """Load the HuggingFace PreTrainedTokenizerFast used during training.""" try: from transformers import PreTrainedTokenizerFast tok = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir) encode = lambda text: tok.encode(text) decode = lambda ids: tok.decode(ids, skip_special_tokens=True) print(f" Tokenizer: HuggingFace tokenizer loaded from '{tokenizer_dir}'") print(f" vocab_size={tok.vocab_size:,} eos_id={tok.eos_token_id}") return encode, decode except Exception as e: print(f" Tokenizer: Could not load HuggingFace tokenizer ({e})") print(" Falling back to char tokenizer — output will be garbled!") return char_tokenize, char_detokenize def run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode): print("\n" + "="*60) print(" INTERACTIVE MODE (type 'quit' or 'exit' to stop)") print("="*60) print(f" max_new_tokens : {MAX_NEW_TOKENS}") print(f" temperature : {TEMPERATURE}") print(f" top_k / top_p : {TOP_K} / {TOP_P}") print() while True: try: prompt = input("Prompt> ").strip() except (EOFError, KeyboardInterrupt): print("\n Exiting.") break if prompt.lower() in ("quit", "exit", ""): print(" Exiting.") break prompt_ids = encode(prompt) output_ids = generate( model, prompt_ids, cfg, device, dtype_torch, use_amp, MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P, ) # Only show the newly generated tokens new_ids = output_ids[len(prompt_ids):] print(f"\nGenerated: {decode(new_ids)}\n") def run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode): print("\n" + "="*60) print(" SAMPLE MODE") print("="*60) for i, prompt in enumerate(SAMPLE_PROMPTS, 1): print(f"\n[{i}] Prompt : {prompt!r}") prompt_ids = encode(prompt) output_ids = generate( model, prompt_ids, cfg, device, dtype_torch, use_amp, MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P, ) new_ids = output_ids[len(prompt_ids):] print(f" Output : {decode(new_ids)}") def run_inspect(ckpt_path, step, loss, cfg): print("\n" + "="*60) print(" INSPECT MODE") print("="*60) print(f" Checkpoint : {ckpt_path}") print(f" Step : {step}") print(f" Loss : {loss:.4f}" if isinstance(loss, float) else f" Loss: {loss}") print(f" Config : {cfg}") print(f" Params : {cfg.count_params()/1e6:.1f}M") print() def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\nDevice : {device}") if device.type == "cuda": print(f"GPU : {torch.cuda.get_device_name(0)}") print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") # dtype setup use_amp = False if DTYPE == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported(): dtype_torch = torch.bfloat16 use_amp = True elif DTYPE == "fp16" and device.type == "cuda": dtype_torch = torch.float16 use_amp = True else: dtype_torch = torch.float32 print(f"dtype : {DTYPE}") # Resolve checkpoint path ckpt_path = resolve_checkpoint(RUN_DIR, CKPT_FILE) print(f"\nCheckpoint: {ckpt_path}") # Load model model, cfg, step, loss = load_model(ckpt_path, CONFIG, device, dtype_torch) print(f" Loaded : step={step}, loss={loss:.4f}") print(f" Params : {model.count_params()/1e6:.1f}M") if MODE == "inspect": run_inspect(ckpt_path, step, loss, cfg) return # Load tokenizer encode, decode = try_load_sentencepiece() if MODE == "interactive": run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode) elif MODE == "sample": run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode) else: print(f" [ERROR] Unknown MODE: '{MODE}'. Use 'interactive', 'sample', or 'inspect'.") if __name__ == "__main__": main()