# NanoGPT-X_Base / test-checkpoint.py
# (Hugging Face page metadata — uploaded by luxopes; commit 711e74d verified,
#  "Rename test-checkpoints.py to test-checkpoint.py")
# -*- coding: utf-8 -*-
# Author: Antonín Tomeček
# Date: 10 Jan 2026
# Description: Standalone text generation from GPT-style checkpoint 500k
import os
import torch
import sentencepiece as spm
# import the model and helper classes from your training script
from train import Transformer, ModelArgs, generate_text # uprav podle názvu souboru
# =========================
# CONFIG
# =========================
CHECKPOINT_PATH = "checkpoints/best.pt"   # trained model checkpoint to load
TOKENIZER_MODEL_PATH = "tokenizer.model"  # SentencePiece model file
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 200  # generation length cap per prompt
TEMPERATURE = 0.8     # sampling temperature
TOP_P = 0.95          # nucleus-sampling cutoff
EOS_ID = 1 # per the tokenizer; usually id 1 is </s>
# =========================
# Allowlist ModelArgs for unpickling (torch safe-globals mechanism,
# effective when torch.load is called with weights_only=True)
# =========================
torch.serialization.add_safe_globals([ModelArgs])
# =========================
# LOAD TOKENIZER
# =========================
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
vocab_size = tokenizer.vocab_size()  # used below to size the model's vocabulary
# =========================
# LOAD CHECKPOINT
# =========================
import pickle  # local import: only needed for the weights_only fallback below

if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")

# Prefer the safe loader: add_safe_globals([ModelArgs]) above allowlists the
# one custom class the checkpoint needs, so weights_only=True should work.
# (The original weights_only=False made that allowlist a no-op.)
try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
except pickle.UnpicklingError:
    # SECURITY: weights_only=False executes arbitrary pickle code from the
    # file. Only acceptable here because the checkpoint is one you trained
    # yourself; never use it on a checkpoint from an untrusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

# Rebuild the model from the stored hyper-parameters (defaults if absent),
# then force the vocab size to match the tokenizer loaded above.
model_args = checkpoint.get("model_args", ModelArgs())
model_args.vocab_size = vocab_size
model = Transformer(model_args).to(DEVICE)

# Restore the weights and switch to inference mode (disables dropout etc.).
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")
# =========================
# PROMPTS
# =========================
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France",
]
# =========================
# GENERATE TEXT
# =========================
# Sampling settings all come from the CONFIG section at the top of the script.
sampling_kwargs = dict(
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    eos_id=EOS_ID,
)
results = generate_text(model, tokenizer, prompts, **sampling_kwargs)
# =========================
# PRINT RESULTS
# =========================
separator = "=" * 50
for prompt, text in results.items():
    print(separator)
    print(f"Prompt: {prompt}")
    print(f"Generated: {text}")