| """ |
| test_checkpoint.py — Load a checkpoint and run inference / inspect it. |
| |
| QUICK START: Edit the variables in the CONFIG section below, then run: |
| python test_checkpoint.py |
| |
| Modes: |
| INTERACTIVE — Chat loop: type prompts, model responds. |
| SAMPLE — Auto-generate N samples from fixed prompts and exit. |
| INSPECT — Just print checkpoint info (no generation). |
| """ |
|
|
| import os |
| import sys |
| import torch |
| from torch.amp import autocast |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.config import SLLM_100M, SLLM_150M, ModelConfig |
| from model.model import SLLM |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
# ---------------------------------------------------------------------------
# CONFIG: edit these values, then run the script.
# ---------------------------------------------------------------------------

# Directory that holds the checkpoint files.
RUN_DIR = "runs/sllm_150m"
# File name inside RUN_DIR; None selects the last `ckpt_*.pt` file in sorted
# order.
CKPT_FILE = None

# Model size key; load_model() accepts "100M" or "150M". A `config_name`
# stored in the checkpoint takes precedence over this value.
CONFIG = "150M"

# Sampling settings passed to generate().
MAX_NEW_TOKENS = 100  # tokens appended after the prompt
TEMPERATURE = 0.8     # 0.0 switches to greedy argmax decoding
TOP_K = 50            # values <= 0 disable top-k filtering
TOP_P = 0.95          # values >= 1.0 disable nucleus filtering

# One of "interactive", "sample" or "inspect".
MODE = "sample"

# Prompts used when MODE is "sample".
SAMPLE_PROMPTS = [
    "Once upon a time",
    "The meaning of life is",
    "In the year 2050,",
]

# "bf16" or "fp16" enable autocast on CUDA; anything else runs in float32.
DTYPE = "bf16"
|
|
| |
| |
| |
|
|
def resolve_checkpoint(run_dir: str, ckpt_file) -> str:
    """Return the full path of the checkpoint to load.

    Args:
        run_dir: Directory that contains the checkpoint files.
        ckpt_file: File name relative to ``run_dir``, or ``None`` to pick the
            latest ``ckpt_*.pt`` file automatically.

    Returns:
        Path to an existing checkpoint file.

    Raises:
        FileNotFoundError: If the named file, the run directory, or any
            ``ckpt_*.pt`` file cannot be found.
    """
    import re  # local import: the module header does not import `re`

    if ckpt_file is not None:
        path = os.path.join(run_dir, ckpt_file)
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory not found: {run_dir}")

    def step_order(name: str):
        # Compare digit runs numerically so that "ckpt_1000.pt" ranks after
        # "ckpt_999.pt"; a plain string sort would pick the wrong file when
        # step numbers are not zero-padded. re.split with a capturing group
        # puts the digit runs at the odd indices.
        parts = re.split(r"([0-9]+)", name)
        key = [int(p) if i % 2 else p for i, p in enumerate(parts)]
        # The raw name breaks ties such as "ckpt_010.pt" vs "ckpt_10.pt".
        return key, name

    candidates = [
        name for name in os.listdir(run_dir)
        if name.startswith("ckpt_") and name.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found in: {run_dir}")
    return os.path.join(run_dir, max(candidates, key=step_order))
|
|
|
|
def load_model(ckpt_path: str, config_name: str, device, dtype_torch):
    """Build the model described by the checkpoint and load its weights.

    The checkpoint is read first so the model is constructed only once, with
    the configuration that actually matches the stored weights.

    Args:
        ckpt_path: Path to the checkpoint file.
        config_name: Requested config key ("100M" or "150M"). A
            ``config_name`` entry stored in the checkpoint overrides it.
        device: Target device for the weights and the model.
        dtype_torch: Not used here; kept so existing call sites keep working.

    Returns:
        Tuple ``(model, cfg, step, loss)``. ``step`` defaults to ``"?"`` and
        ``loss`` to NaN when the checkpoint does not store them.

    Raises:
        KeyError: If the effective config name is unknown, or the checkpoint
            has no ``"model_state_dict"`` entry.
    """
    cfg_map = {"100M": SLLM_100M, "150M": SLLM_150M}

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Prefer the config recorded at training time over the CONFIG setting.
    ckpt_cfg_name = ckpt.get("config_name", config_name)
    if ckpt_cfg_name != config_name:
        print(f" [WARN] Checkpoint config_name='{ckpt_cfg_name}' "
              f"differs from CONFIG='{config_name}'. "
              f"Using checkpoint's config: '{ckpt_cfg_name}'")
    if ckpt_cfg_name not in cfg_map:
        raise KeyError(
            f"Unknown model config '{ckpt_cfg_name}'; "
            f"expected one of {sorted(cfg_map)}"
        )
    cfg = cfg_map[ckpt_cfg_name]

    print(f"\n Config : {cfg}")
    model = SLLM(cfg).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, cfg, step, loss
|
|
|
|
| @torch.no_grad() |
| def generate(model, prompt_ids: list[int], cfg: ModelConfig, device, |
| dtype_torch, use_amp: bool, |
| max_new_tokens: int, temperature: float, |
| top_k: int, top_p: float) -> list[int]: |
| """Token-by-token autoregressive generation.""" |
| ids = torch.tensor([prompt_ids], dtype=torch.long, device=device) |
| ctx_len = cfg.context_length |
|
|
| for _ in range(max_new_tokens): |
| |
| ids_crop = ids[:, -ctx_len:] |
|
|
| with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp): |
| logits, _ = model(ids_crop) |
|
|
| |
| logits = logits[:, -1, :] |
|
|
| if temperature == 0.0: |
| |
| next_id = logits.argmax(dim=-1, keepdim=True) |
| else: |
| logits = logits / temperature |
|
|
| |
| if top_k > 0: |
| vals, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
| logits[logits < vals[:, [-1]]] = float("-inf") |
|
|
| |
| if top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(logits, descending=True) |
| cumprobs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) |
| |
| sorted_logits[cumprobs - torch.softmax(sorted_logits, dim=-1) > top_p] = float("-inf") |
| logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits) |
|
|
| probs = torch.softmax(logits, dim=-1) |
| next_id = torch.multinomial(probs, num_samples=1) |
|
|
| ids = torch.cat([ids, next_id], dim=1) |
|
|
| return ids[0].tolist() |
|
|
|
|
def char_tokenize(text: str, vocab_size: int = 32_000) -> list[int]:
    """Fallback character-level tokenizer.

    The model was trained with a real tokenizer; this stand-in only keeps the
    script runnable when that tokenizer cannot be loaded.

    Args:
        text: Text to encode.
        vocab_size: Size of the id space. Code points at or above it are
            clamped to ``vocab_size - 1`` so every id stays in range.

    Returns:
        One id per character: its Unicode code point, clamped as described.

    Raises:
        ValueError: If ``vocab_size`` is smaller than 1.
    """
    if vocab_size < 1:
        raise ValueError(f"vocab_size must be >= 1, got {vocab_size}")
    max_id = vocab_size - 1
    return [min(ord(ch), max_id) for ch in text]
|
|
|
|
def char_detokenize(ids: list[int]) -> str:
    """Decode ids produced by the character fallback tokenizer.

    The mapping is lossy: ids 32..126 (printable ASCII) become their
    character, and every other id is rendered as "?".
    """
    pieces = []
    for token in ids:
        if token < 32 or token >= 127:
            pieces.append("?")
        else:
            pieces.append(chr(token))
    return "".join(pieces)
|
|
|
|
def try_load_sentencepiece(tokenizer_dir="tokenizer/fineweb_edu_tokenizer"):
    """Return ``(encode, decode)`` callables for the training tokenizer.

    Tries to load a HuggingFace ``PreTrainedTokenizerFast`` from
    ``tokenizer_dir``. On any failure (missing package, missing files, ...)
    the reason is printed and the character-level fallback pair is returned
    instead, so the script still runs.
    """
    try:
        from transformers import PreTrainedTokenizerFast

        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)

        def encode(text):
            return tokenizer.encode(text)

        def decode(ids):
            return tokenizer.decode(ids, skip_special_tokens=True)

        print(f" Tokenizer: HuggingFace tokenizer loaded from '{tokenizer_dir}'")
        print(f" vocab_size={tokenizer.vocab_size:,} eos_id={tokenizer.eos_token_id}")
    except Exception as err:
        # Deliberate best-effort: any loading problem falls back to the
        # character tokenizer rather than aborting the run.
        print(f" Tokenizer: Could not load HuggingFace tokenizer ({err})")
        print(" Falling back to char tokenizer — output will be garbled!")
        return char_tokenize, char_detokenize
    return encode, decode
|
|
|
|
def run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Run a prompt loop: read a line, generate, print only the new text.

    The loop ends on "quit", "exit", an empty line, end-of-input, or Ctrl-C
    at the prompt. Sampling settings come from the module-level constants.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" INTERACTIVE MODE (type 'quit' or 'exit' to stop)")
    print(rule)
    print(f" max_new_tokens : {MAX_NEW_TOKENS}")
    print(f" temperature : {TEMPERATURE}")
    print(f" top_k / top_p : {TOP_K} / {TOP_P}")
    print()

    def read_prompt():
        # Returns the stripped prompt, or None when the session should end.
        try:
            text = input("Prompt> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n Exiting.")
            return None
        if text.lower() in ("quit", "exit", ""):
            print(" Exiting.")
            return None
        return text

    while (prompt := read_prompt()) is not None:
        token_ids = encode(prompt)
        generated = generate(
            model, token_ids, cfg, device, dtype_torch, use_amp,
            MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
        )
        # generate() echoes the prompt ids first; show only the continuation.
        completion = generated[len(token_ids):]
        print(f"\nGenerated: {decode(completion)}\n")
|
|
|
|
def run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode):
    """Generate one completion for each entry of SAMPLE_PROMPTS and print it.

    Sampling settings come from the module-level constants.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" SAMPLE MODE")
    print(rule)

    for number, text in enumerate(SAMPLE_PROMPTS, start=1):
        print(f"\n[{number}] Prompt : {text!r}")
        token_ids = encode(text)
        generated = generate(
            model, token_ids, cfg, device, dtype_torch, use_amp,
            MAX_NEW_TOKENS, TEMPERATURE, TOP_K, TOP_P,
        )
        # Drop the echoed prompt so only newly generated text is shown.
        completion = generated[len(token_ids):]
        print(f" Output : {decode(completion)}")
|
|
|
|
def run_inspect(ckpt_path, step, loss, cfg):
    """Print a summary of the checkpoint without generating any text.

    ``loss`` is shown with four decimals when it is a float and verbatim
    otherwise. ``cfg`` must provide ``count_params()``.
    """
    rule = "=" * 60
    for line in ("\n" + rule, " INSPECT MODE", rule):
        print(line)

    print(f" Checkpoint : {ckpt_path}")
    print(f" Step : {step}")
    if isinstance(loss, float):
        print(f" Loss : {loss:.4f}")
    else:
        print(f" Loss: {loss}")
    print(f" Config : {cfg}")

    millions = cfg.count_params() / 1e6
    print(f" Params : {millions:.1f}M")
    print()
|
|
|
|
def main():
    """Entry point: load the configured checkpoint and run the chosen mode.

    Reads the module-level settings (RUN_DIR, CKPT_FILE, CONFIG, MODE,
    DTYPE). MODE and DTYPE are matched case-insensitively. An unknown MODE
    is reported before anything expensive is loaded.
    """
    mode = MODE.strip().lower()
    if mode not in ("interactive", "sample", "inspect"):
        print(f" [ERROR] Unknown MODE: '{MODE}'. Use 'interactive', 'sample', or 'inspect'.")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")
        print(f"VRAM : {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Reduced precision is only used on CUDA; everything else runs in fp32.
    requested_dtype = DTYPE.strip().lower()
    use_amp = False
    dtype_torch = torch.float32
    dtype_name = "fp32"
    if device.type == "cuda":
        if requested_dtype == "bf16" and torch.cuda.is_bf16_supported():
            dtype_torch, dtype_name, use_amp = torch.bfloat16, "bf16", True
        elif requested_dtype == "fp16":
            dtype_torch, dtype_name, use_amp = torch.float16, "fp16", True
    # Report the dtype that is actually used, not just the requested one.
    print(f"dtype : {dtype_name}")
    if requested_dtype != dtype_name:
        print(f" [WARN] DTYPE='{DTYPE}' is unsupported or unavailable on "
              f"{device.type}; falling back to fp32.")

    ckpt_path = resolve_checkpoint(RUN_DIR, CKPT_FILE)
    print(f"\nCheckpoint: {ckpt_path}")

    model, cfg, step, loss = load_model(ckpt_path, CONFIG, device, dtype_torch)
    # The stored loss may not be numeric; fall back to its plain text form
    # instead of crashing on the format spec.
    try:
        loss_text = f"{float(loss):.4f}"
    except (TypeError, ValueError):
        loss_text = str(loss)
    print(f" Loaded : step={step}, loss={loss_text}")
    print(f" Params : {model.count_params()/1e6:.1f}M")

    if mode == "inspect":
        run_inspect(ckpt_path, step, loss, cfg)
        return

    # The tokenizer is only needed by the modes that generate text.
    encode, decode = try_load_sentencepiece()

    if mode == "interactive":
        run_interactive(model, cfg, device, dtype_torch, use_amp, encode, decode)
    else:
        run_sample(model, cfg, device, dtype_torch, use_amp, encode, decode)
|
|
|
|
# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|