|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
import sentencepiece as spm |
|
|
|
|
|
|
|
|
from train import Transformer, ModelArgs, generate_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Inference configuration -------------------------------------------------

# Trained checkpoint produced by train.py.
CHECKPOINT_PATH = "checkpoints/best.pt"

# SentencePiece tokenizer model file.
TOKENIZER_MODEL_PATH = "tokenizer.model"

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sampling hyper-parameters for text generation.
MAX_NEW_TOKENS = 200  # upper bound on tokens generated per prompt
TEMPERATURE = 0.8     # softmax temperature (< 1.0 sharpens the distribution)
TOP_P = 0.95          # nucleus-sampling cumulative-probability cutoff

# Token id treated as end-of-sequence during generation.
# NOTE(review): SentencePiece's default special ids are <unk>=0, <s>=1, </s>=2,
# so id 1 is BOS under default training settings — confirm this matches the
# tokenizer's actual eos_id() for this project.
EOS_ID = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register ModelArgs with torch's restricted unpickler so that a
# torch.load call running in weights_only=True mode is allowed to
# reconstruct ModelArgs instances stored inside a checkpoint.
torch.serialization.add_safe_globals([ModelArgs])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the SentencePiece tokenizer.  Fail fast with a clear error when the
# model file is missing, consistent with the checkpoint existence check
# performed later in this script (sentencepiece's own error is less direct).
if not os.path.exists(TOKENIZER_MODEL_PATH):
    raise FileNotFoundError(f"Tokenizer model {TOKENIZER_MODEL_PATH} not found")

tokenizer = spm.SentencePieceProcessor(model_file=TOKENIZER_MODEL_PATH)

# Vocabulary size drives the model's embedding / output-projection sizes below.
vocab_size = tokenizer.vocab_size()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Load checkpoint and rebuild the model -----------------------------------

if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint {CHECKPOINT_PATH} not found")

# Use weights_only=True: torch's restricted unpickler refuses arbitrary
# code execution during deserialization.  ModelArgs was registered via
# add_safe_globals above, so the checkpoint's "model_args" entry can still
# be reconstructed.  (Previously weights_only=False, which defeated that
# registration entirely.)
# NOTE(review): if the checkpoint also stores other custom classes, they
# must be allowlisted the same way — confirm against train.py's save code.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)

# Rebuild the architecture from the saved hyper-parameters, falling back to
# defaults when the checkpoint predates saving "model_args".
model_args = checkpoint.get("model_args", ModelArgs())
# The embedding/output sizes must match the tokenizer actually in use.
model_args.vocab_size = vocab_size
model = Transformer(model_args).to(DEVICE)

model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode for layers like dropout

print(f"[Info] Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
print(f"[Info] Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} params")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Evaluation prompts covering narrative, open-ended, procedural, and
# factual styles.
prompts = [
    "Once upon a time",
    "In a distant future",
    "Artificial intelligence will",
    "First step to build a rocket",
    "Capital city of France",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Sample a continuation for every prompt in a single call; the sampling
# knobs are gathered into one mapping for readability.
sampling_config = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "temperature": TEMPERATURE,
    "top_p": TOP_P,
    "eos_id": EOS_ID,
}
results = generate_text(model, tokenizer, prompts, **sampling_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Display each prompt alongside its sampled continuation.
divider = "=" * 50
for p, completion in results.items():
    print(divider)
    print(f"Prompt: {p}")
    print(f"Generated: {completion}")
|
|
|