# -*- coding: utf-8 -*-
# Author: Antonín Tomeček
# Date: 10 Jan 2026
# Description: Standalone text generation from GPT-style checkpoint 500k
import os
import torch
import sentencepiece as spm
# Import the model and helper classes from your training script
from train import Transformer, ModelArgs, generate_text # adjust to your training file's name
# =========================
# CONFIG
# =========================
CHECKPOINT_PATH = "checkpoints/best.pt"
TOKENIZER_MODEL_PATH = "tokenizer.model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 200  # generation length cap per prompt
TEMPERATURE = 0.8     # softmax temperature for sampling
TOP_P = 0.95          # nucleus-sampling cutoff
EOS_ID = 1 # per the tokenizer; usually 1 is </s>
# =========================
# Allow ModelArgs when unpickling the checkpoint
# =========================
# NOTE(review): this allow-list only takes effect when torch.load is called
# with weights_only=True — confirm the load call below matches that intent.
torch.serialization.add_safe_globals([ModelArgs])
# =========================
# LOAD TOKENIZER
# =========================
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
# Vocabulary size of the loaded tokenizer; used below to size the model's
# embedding so it matches this tokenizer rather than the training-time value.
vocab_size = tokenizer.vocab_size()
# =========================
# LOAD CHECKPOINT
# =========================
# Fail fast with a clear message rather than letting torch.load raise.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")
# weights_only=True refuses to unpickle arbitrary objects from the checkpoint
# (a weights_only=False load would execute any code embedded in an untrusted
# pickle). ModelArgs is explicitly allow-listed via
# torch.serialization.add_safe_globals above, so the stored args still load.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
# Rebuild the model from the args saved at training time (defaults if absent).
model_args = checkpoint.get("model_args", ModelArgs())
# Override with the size of the tokenizer actually loaded above.
model_args.vocab_size = vocab_size
model = Transformer(model_args).to(DEVICE)
# Load the trained weights
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables dropout etc.
print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")
# =========================
# PROMPTS
# =========================
# Fixed sample prompts used to spot-check the model's generations.
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France"
]
# =========================
# GENERATE TEXT
# =========================
# generate_text comes from the training script; judging by the print loop
# below, it returns a mapping of prompt -> generated text.
results = generate_text(
    model,
    tokenizer,
    prompts,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    eos_id=EOS_ID
)
# =========================
# PRINT RESULTS
# =========================
# Pretty-print each prompt with its generation, separated by a rule.
separator = "=" * 50
for prompt, text in results.items():
    print(separator)
    print(f"Prompt: {prompt}")
    print(f"Generated: {text}")