# nanogpt-tr-v5-code / 06_sample.py
# Uploaded by musabc ("upload 06_sample.py", commit e2e6604, verified), 9.39 kB
"""
Egitilmis modelden ornek metin uret. V3 / V4 / V5 ile uyumlu.
Otomatik tespit:
- ckpt['version'] = 'v5' -> V5 (200M, 32K vocab, T=2048, theta=100K)
- ckpt['version'] = 'v4*' -> V4 (50M)
- ckpt['config'] icinde 'rope_theta' yoksa -> V3
- tokenizer auto-detect (v5 -> tokenizer-tr-v5, v4 -> tokenizer-tr-16k)
Kullanim:
python 06_sample.py # V5 default (best ckpt)
python 06_sample.py --version v4 # V4'ten sample
python 06_sample.py --prompt "İstanbul" --max-tokens 200
python 06_sample.py --latest # latest_ckpt.pt
python 06_sample.py --ckpt runs/tr-200m-v5/best_ckpt.pt
python 06_sample.py --num-samples 5 --temperature 0.7
python 06_sample.py --chat --prompt "Türkiye nedir?" # SFT modeli için
"""
import argparse
import os
from pathlib import Path
import torch
from tokenizers import Tokenizer
# Turn the Liger kernel off while sampling: a single-token forward pass has no
# use for Liger's chunked cross-entropy, and JIT-compiling the fused kernels
# only slows sampling down. The default is set ahead of the model imports
# below (presumably they read it at import time -- TODO confirm).
os.environ.setdefault("NANOGPT_NO_LIGER", "1")

# Model imports are optional: each generation lives in its own module, and a
# missing module only disables that generation (flag False, names set to None).
try:
    from model import GPT, GPTConfig
except ImportError:
    GPT = GPTConfig = None
    HAS_V3 = False
else:
    HAS_V3 = True

try:
    from model_v4 import GPTV4, GPTConfigV4
except ImportError:
    GPTV4 = GPTConfigV4 = None
    HAS_V4 = False
else:
    HAS_V4 = True

try:
    from model_v5 import GPTV5, GPTConfigV5
except ImportError:
    GPTV5 = GPTConfigV5 = None
    HAS_V5 = False
else:
    HAS_V5 = True
# Every path below is resolved relative to the directory holding this script.
_SCRIPT_DIR = Path(__file__).parent

# Directory that holds the tokenizer JSON files.
DATA_DIR = _SCRIPT_DIR / "data"

# Training-run directory (where best/latest checkpoints live) per generation.
RUN_DIRS = {
    version: _SCRIPT_DIR / "runs" / run_name
    for version, run_name in (
        ("v3", "tr-50m-v3"),
        ("v4", "tr-50m-v4"),
        ("v5", "tr-200m-v5"),
    )
}

# Tokenizer file name (inside DATA_DIR) per generation; V3 and V4 share one.
TOKENIZERS = dict(
    v3="tokenizer-tr-16k.json",
    v4="tokenizer-tr-16k.json",
    v5="tokenizer-tr-v5.json",
)
def detect_version(ckpt: dict) -> str:
    """Infer the model generation ("v3", "v4" or "v5") from a checkpoint dict.

    Detection order:
      1. A string ``ckpt["version"]`` starting with "v5" or "v4" wins
         (suffixes such as "v5-instruct" are accepted).
      2. Otherwise the config decides: no ``rope_theta`` key means "v3".
      3. With ``rope_theta`` present, ``vocab_size >= 24000`` means "v5"
         (V5 = 32000, V4 = 16000), anything smaller means "v4".

    A missing or ``None`` ``config`` is treated as an empty config and a
    missing or ``None`` ``vocab_size`` as 0, so a sparse checkpoint yields an
    answer instead of a ``TypeError``.

    Args:
        ckpt: Loaded checkpoint mapping; only the "version" and "config"
            keys are read.

    Returns:
        One of "v3", "v4", "v5".
    """
    tag = ckpt.get("version", "")
    if isinstance(tag, str):
        for known in ("v5", "v4"):
            if tag.startswith(known):
                return known

    # Config-based fallback. `or {}` also covers an explicit None value,
    # which the `.get(key, default)` form would pass through unchanged.
    cfg = ckpt.get("config") or {}
    if "rope_theta" not in cfg:
        return "v3"

    # V5 vs V4 differ in vocabulary size (V5 = 32000, V4 = 16000).
    vocab_size = cfg.get("vocab_size") or 0
    return "v5" if vocab_size >= 24000 else "v4"
def build_model(version: str, ckpt: dict, device: str):
    """Instantiate the model class matching ``version`` on ``device``.

    The config object is built from ``ckpt["config"]`` by keyword expansion;
    weights are NOT loaded here.

    Args:
        version: "v3", "v4" or "v5".
        ckpt: Checkpoint mapping; only ``ckpt["config"]`` is read.
        device: Device the freshly built model is moved to.

    Returns:
        Tuple ``(model, cfg, description)`` where ``description`` is a short
        human-readable architecture summary.

    Raises:
        KeyError: ``ckpt`` has no "config" entry.
        ImportError: The module for the requested generation is unavailable.
        ValueError: ``version`` is not one of the known generations.
    """
    cfg_dict = ckpt["config"]

    # Select the classes first; construction is shared by all generations.
    if version == "v5":
        if not HAS_V5:
            raise ImportError("V5 checkpoint ama model_v5.py yok.")
        config_cls, model_cls = GPTConfigV5, GPTV5
        desc = "V5 (RoPE+RMSNorm+SwiGLU+QK-norm+softcap, 200M)"
    elif version == "v4":
        if not HAS_V4:
            raise ImportError("V4 checkpoint ama model_v4.py yok.")
        config_cls, model_cls = GPTConfigV4, GPTV4
        desc = "V4 (RoPE+RMSNorm+SwiGLU+QK-norm, 50M)"
    elif version == "v3":
        if not HAS_V3:
            raise ImportError("V3 checkpoint ama model.py yok.")
        config_cls, model_cls = GPTConfig, GPT
        desc = "V3 (LayerNorm+GELU+learned PE)"
    else:
        raise ValueError(f"Bilinmeyen version: {version}")

    cfg = config_cls(**cfg_dict)
    model = model_cls(cfg).to(device)
    return model, cfg, desc
def _resolve_checkpoint_path(args) -> Path:
    """Return the checkpoint file selected by ``--ckpt`` or ``--version``.

    With ``--ckpt`` the given path is used as-is. Otherwise the run directory
    for ``--version`` is searched for ``best_ckpt.pt`` (or ``latest_ckpt.pt``
    with ``--latest``), falling back to the other file when the preferred one
    is missing.

    Raises:
        FileNotFoundError: No usable checkpoint file exists.
    """
    if args.ckpt:
        explicit = Path(args.ckpt)
        if not explicit.exists():
            # Fail early with a clear message instead of inside torch.load.
            raise FileNotFoundError(f"Checkpoint yok: {explicit}")
        return explicit

    run_dir = RUN_DIRS[args.version]
    name = "latest_ckpt.pt" if args.latest else "best_ckpt.pt"
    preferred = run_dir / name
    if preferred.exists():
        return preferred

    # Fallback: try the other checkpoint flavour in the same run directory.
    alt = run_dir / ("best_ckpt.pt" if args.latest else "latest_ckpt.pt")
    if not alt.exists():
        raise FileNotFoundError(
            f"Checkpoint yok: {preferred} (run_dir={run_dir})"
        )
    print(f" ({name} yok, {alt.name} kullanılıyor)")
    return alt


def _fit_prompt_to_context(ids: list, max_ctx: int, max_new_tokens: int) -> list:
    """Trim an over-long prompt so that it leaves room for generation.

    Prompts shorter than ``max_ctx`` are returned unchanged. Longer ones are
    cut to their last ``max_ctx - max_new_tokens`` tokens. The kept length is
    clamped to at least 1: when ``max_new_tokens >= max_ctx`` the unclamped
    bound is zero or negative, and a slice start of ``-0`` or a positive index
    would keep the whole prompt or drop tokens from the wrong end.
    """
    if len(ids) < max_ctx:
        return ids
    print(f" ! Prompt {len(ids)} token, "
          f"context {max_ctx} → kırpılıyor")
    keep = max(1, max_ctx - max_new_tokens)
    return ids[-keep:]


def main():
    """Command-line entry point: load a checkpoint and print sampled text.

    Steps: parse arguments, pick the device, locate and load the checkpoint,
    detect the model generation, build the model and load its weights, load
    the tokenizer, optionally wrap the prompt in the chat template, then
    generate and print ``--num-samples`` completions.

    Raises:
        FileNotFoundError: The checkpoint or tokenizer file does not exist.
        ValueError: The prompt encodes to zero tokens.
    """
    # Local import: keeps this function independent of the file-level imports.
    from contextlib import nullcontext

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Türkiye")
    parser.add_argument("--max-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--repetition-penalty", type=float, default=1.15)
    parser.add_argument("--no-repeat-ngram", type=int, default=3)
    parser.add_argument("--num-samples", type=int, default=3)
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Tam checkpoint yolu (yoksa --version + best/latest)")
    parser.add_argument("--version", type=str, default="v5",
                        choices=["v3", "v4", "v5"],
                        help="Hangi run dizini (--ckpt verilmediyse)")
    parser.add_argument("--latest", action="store_true",
                        help="best yerine latest checkpoint'i kullan")
    parser.add_argument("--chat", action="store_true",
                        help="SFT/Instruct ChatML formatı uygula (otomatik tespit de var)")
    parser.add_argument("--instruction", type=str, default=None,
                        help="ChatML için ayrı instruction (input ile birlikte)")
    parser.add_argument("--tokenizer", type=str, default=None,
                        help="Tokenizer dosya yolu (yoksa version'a göre seç)")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--device", type=str, default=None,
                        choices=["cuda", "cpu"])
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Locate and load the checkpoint.
    ckpt_path = _resolve_checkpoint_path(args)
    print(f"Checkpoint: {ckpt_path}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Detect the generation; warn when it disagrees with --version (only
    # meaningful when the run directory was chosen through --version).
    version = detect_version(ckpt)
    if version != args.version and not args.ckpt:
        print(f" ! Algılanan version={version}, --version={args.version}")
    print(f"Version: {version}")

    # Build the model and load the weights.
    model, cfg, desc = build_model(version, ckpt, device)
    # Strip the "_orig_mod." prefix that torch.compile adds to parameter names.
    state = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
    model.load_state_dict(state)
    model.eval()

    step = ckpt.get("step", "?")
    val = ckpt.get("best_val", None)
    val_str = f", val={val:.4f}" if val is not None and val != float("inf") else ""
    if hasattr(model, "num_params"):
        n_params = model.num_params()
    else:
        n_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {desc}")
    print(f" {n_params/1e6:.2f}M param (step={step}, "
          f"version={ckpt.get('version','?')}{val_str})")

    # Tokenizer: explicit path wins, otherwise the default for this version.
    tok_path = args.tokenizer or str(DATA_DIR / TOKENIZERS[version])
    if not Path(tok_path).exists():
        raise FileNotFoundError(f"Tokenizer yok: {tok_path}")
    tokenizer = Tokenizer.from_file(tok_path)
    print(f"Tokenizer: {Path(tok_path).name} "
          f"(vocab={tokenizer.get_vocab_size()})")

    # Chat template: forced by --chat, or enabled automatically when the
    # checkpoint's version tag mentions an instruction-tuning stage.
    raw_version = ckpt.get("version", "")
    auto_chat = any(t in str(raw_version) for t in ("instruct", "sft", "dpo", "chat"))
    use_chat = args.chat or auto_chat
    if use_chat:
        user_msg = (f"{args.instruction}\n{args.prompt}"
                    if args.instruction else args.prompt)
        formatted = f"<|user|>\n{user_msg}\n<|assistant|>\n"
        print(f"\nChatML format AKTİF (version={raw_version})")
        print(f"User prompt: {user_msg!r}")
    else:
        formatted = args.prompt
        print(f"\nRaw prompt: {args.prompt!r}")

    print(f"Settings: max={args.max_tokens}, temp={args.temperature}, "
          f"top_k={args.top_k}, rep_pen={args.repetition_penalty}, "
          f"no_rep_ngram={args.no_repeat_ngram}")
    print("=" * 70)

    ids = tokenizer.encode(formatted).ids
    if not ids:
        # An empty prompt would hand generate() a zero-length sequence.
        raise ValueError("Prompt boş: tokenizer hiç token üretmedi.")

    # The prompt is identical for every sample, so fit it to the context
    # window once, ahead of the sampling loop.
    prompt_ids = _fit_prompt_to_context(ids, cfg.block_size, args.max_tokens)

    # CUDA autocast only supports reduced-precision dtypes; use it with bf16
    # and run in plain float32 otherwise.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()

    for i in range(args.num_samples):
        # Fresh input tensor per sample, in case generate() mutates its input.
        x_i = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        amp_ctx = (torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
                   if use_bf16 else nullcontext())
        with amp_ctx, torch.no_grad():
            out = model.generate(
                x_i,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                repetition_penalty=args.repetition_penalty,
                no_repeat_ngram_size=args.no_repeat_ngram,
            )
        text = tokenizer.decode(out[0].tolist())
        print(f"\n--- Sample {i+1} ---")
        print(text)
        print()


if __name__ == "__main__":
    main()