File size: 2,345 Bytes
8449341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711e74d
8449341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -- coding: utf-8 --
# Author: Antonín Tomeček
# Date: 10 Jan 2026
# Description: Standalone text generation from GPT-style checkpoint 500k

import os
import torch
import sentencepiece as spm

# importuj model a třídy z tvého tréninkového souboru
from train import Transformer, ModelArgs, generate_text  # uprav podle názvu souboru

# =========================
# CONFIG
# =========================
CHECKPOINT_PATH = "checkpoints/best.pt"
TOKENIZER_MODEL_PATH = "tokenizer.model"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_NEW_TOKENS = 200
TEMPERATURE = 0.8
TOP_P = 0.95
EOS_ID = 1  # podle tokenizeru, většinou 1 je </s>

# =========================
# Povolit ModelArgs při odpickle
# =========================
torch.serialization.add_safe_globals([ModelArgs])

# =========================
# LOAD TOKENIZER
# =========================
tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)
vocab_size = tokenizer.vocab_size()

# =========================
# LOAD CHECKPOINT
# =========================
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")

checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

# načteme model podle uložených args
model_args = checkpoint.get("model_args", ModelArgs())
model_args.vocab_size = vocab_size
model = Transformer(model_args).to(DEVICE)

# načteme váhy
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")

# =========================
# PROMPTS
# =========================
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France"
]

# =========================
# GENERATE TEXT
# =========================
results = generate_text(
    model,
    tokenizer,
    prompts,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    eos_id=EOS_ID
)

# =========================
# PRINT RESULTS
# =========================
for prompt, text in results.items():
    print("="*50)
    print(f"Prompt: {prompt}")
    print(f"Generated: {text}")