# -*- coding: utf-8 -*-
# Author: Antonín Tomeček
# Date: 10 Jan 2026
# Description: Standalone text generation from GPT-style checkpoint 500k

import os

import torch
import sentencepiece as spm

# Import the model and helper classes from the training module
from train import Transformer, ModelArgs, generate_text  # adjust to the actual file name

# =========================
# CONFIG
# =========================
CHECKPOINT_PATH = "checkpoints/best.pt"
TOKENIZER_MODEL_PATH = "tokenizer.model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.8
TOP_P = 0.95
EOS_ID = 1  # depends on the tokenizer; usually 1

# =========================
# Allow ModelArgs during checkpoint unpickling
# =========================
# This allow-list only takes effect together with weights_only=True below;
# it lets torch.load deserialize the stored ModelArgs instance safely.
torch.serialization.add_safe_globals([ModelArgs])

# =========================
# LOAD TOKENIZER
# =========================
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
vocab_size = tokenizer.vocab_size()

# =========================
# LOAD CHECKPOINT
# =========================
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")

# weights_only=True restricts unpickling to tensors and allow-listed types
# (ModelArgs was registered above), instead of arbitrary pickle execution.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)

# Rebuild the model from the stored args (fall back to defaults if absent)
model_args = checkpoint.get("model_args", ModelArgs())
model_args.vocab_size = vocab_size  # keep vocab size in sync with the tokenizer

model = Transformer(model_args).to(DEVICE)

# Load the weights and switch to inference mode (disables dropout etc.)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")

# =========================
# PROMPTS
# =========================
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France",
]

# =========================
# GENERATE TEXT
# =========================
results = generate_text(
    model,
    tokenizer,
    prompts,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    eos_id=EOS_ID,
)

# =========================
# PRINT RESULTS
# =========================
for prompt, text in results.items():
    print("=" * 50)
    print(f"Prompt: {prompt}")
    print(f"Generated: {text}")