| """ |
| Egitilmis modelden ornek metin uret. V3 / V4 / V5 ile uyumlu. |
| |
| Automatic detection: |
| - ckpt['version'] starts with 'v5' -> V5 (200M, 32K vocab, T=2048, theta=100K) |
| - ckpt['version'] starts with 'v4' -> V4 (50M) |
| - otherwise, ckpt['config'] has no 'rope_theta' -> V3 |
| - otherwise, config 'vocab_size' >= 24000 -> V5, else V4 |
| - tokenizer auto-detect (v5 -> tokenizer-tr-v5, v3/v4 -> tokenizer-tr-16k) |
| |
| Kullanim: |
| python 06_sample.py # V5 default (best ckpt) |
| python 06_sample.py --version v4 # V4'ten sample |
| python 06_sample.py --prompt "İstanbul" --max-tokens 200 |
| python 06_sample.py --latest # latest_ckpt.pt |
| python 06_sample.py --ckpt runs/tr-200m-v5/best_ckpt.pt |
| python 06_sample.py --num-samples 5 --temperature 0.7 |
| python 06_sample.py --chat --prompt "Türkiye nedir?" # SFT modeli için |
| """ |
|
|
| import argparse |
| import os |
| from pathlib import Path |
|
|
| import torch |
| from tokenizers import Tokenizer |
|
|
| |
| |
# Default NANOGPT_NO_LIGER to "1" unless the caller already exported a value.
# This runs before any of the model modules below are imported.
# NOTE(review): presumably the model modules read this flag at import time to
# skip Liger kernels -- confirm against model_v*.py.
os.environ.setdefault("NANOGPT_NO_LIGER", "1")

# Each model generation lives in its own optional module. A missing module is
# not an error here: the matching HAS_* flag is cleared and the class names
# are bound to None so later code can check availability before use.
try:
    from model import GPT, GPTConfig
except ImportError:
    GPT = GPTConfig = None
    HAS_V3 = False
else:
    HAS_V3 = True

try:
    from model_v4 import GPTV4, GPTConfigV4
except ImportError:
    GPTV4 = GPTConfigV4 = None
    HAS_V4 = False
else:
    HAS_V4 = True

try:
    from model_v5 import GPTV5, GPTConfigV5
except ImportError:
    GPTV5 = GPTConfigV5 = None
    HAS_V5 = False
else:
    HAS_V5 = True
|
|
|
|
# All locations are resolved relative to the directory holding this script.

# Directory that contains the tokenizer JSON files.
DATA_DIR = Path(__file__).parent.joinpath("data")

# Model generation -> training run directory (where best/latest checkpoints
# are looked up when no explicit --ckpt path is given).
RUN_DIRS = {
    tag: Path(__file__).parent.joinpath("runs", run_name)
    for tag, run_name in (
        ("v3", "tr-50m-v3"),
        ("v4", "tr-50m-v4"),
        ("v5", "tr-200m-v5"),
    )
}

# Model generation -> tokenizer file name inside DATA_DIR.
TOKENIZERS = dict(
    v3="tokenizer-tr-16k.json",
    v4="tokenizer-tr-16k.json",
    v5="tokenizer-tr-v5.json",
)
|
|
|
|
def detect_version(ckpt: dict) -> str:
    """Infer the model generation ("v3", "v4" or "v5") of a checkpoint.

    A string stored under ``ckpt["version"]`` decides when it starts with
    "v5" or "v4". In every other case the ``ckpt["config"]`` mapping is
    inspected: a config without a "rope_theta" key means "v3"; otherwise a
    "vocab_size" of at least 24000 means "v5" and anything smaller (or a
    missing "vocab_size") means "v4".
    """
    tag = ckpt.get("version", "")
    if isinstance(tag, str):
        # "v5" is checked before "v4", same precedence as a prefix chain.
        for known in ("v5", "v4"):
            if tag.startswith(known):
                return known

    # No usable version tag: fall back to heuristics on the config.
    config = ckpt.get("config", {})
    if "rope_theta" not in config:
        return "v3"

    return "v5" if config.get("vocab_size", 0) >= 24000 else "v4"
|
|
|
|
def build_model(version: str, ckpt: dict, device: str):
    """Instantiate the model class that matches ``version``.

    The config object is built from ``ckpt["config"]`` (expanded as keyword
    arguments) and the freshly constructed model is moved to ``device``.
    Weights are NOT loaded here.

    Returns:
        A ``(model, cfg, description)`` tuple, where ``description`` is a
        short human-readable architecture label.

    Raises:
        KeyError: ``ckpt`` has no "config" entry.
        ValueError: ``version`` is not one of "v3", "v4", "v5".
        ImportError: the module for the requested version could not be
            imported when this script was loaded.
    """
    cfg_dict = ckpt["config"]

    if version not in ("v3", "v4", "v5"):
        raise ValueError(f"Bilinmeyen version: {version}")

    # Pick the per-version pieces first, then run one shared build path.
    if version == "v5":
        available, config_cls, model_cls = HAS_V5, GPTConfigV5, GPTV5
        missing_msg = "V5 checkpoint ama model_v5.py yok."
        description = "V5 (RoPE+RMSNorm+SwiGLU+QK-norm+softcap, 200M)"
    elif version == "v4":
        available, config_cls, model_cls = HAS_V4, GPTConfigV4, GPTV4
        missing_msg = "V4 checkpoint ama model_v4.py yok."
        description = "V4 (RoPE+RMSNorm+SwiGLU+QK-norm, 50M)"
    else:
        available, config_cls, model_cls = HAS_V3, GPTConfig, GPT
        missing_msg = "V3 checkpoint ama model.py yok."
        description = "V3 (LayerNorm+GELU+learned PE)"

    if not available:
        raise ImportError(missing_msg)

    cfg = config_cls(**cfg_dict)
    model = model_cls(cfg).to(device)
    return model, cfg, description
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the sampling script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Türkiye")
    parser.add_argument("--max-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--repetition-penalty", type=float, default=1.15)
    parser.add_argument("--no-repeat-ngram", type=int, default=3)
    parser.add_argument("--num-samples", type=int, default=3)
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Tam checkpoint yolu (yoksa --version + best/latest)")
    parser.add_argument("--version", type=str, default="v5",
                        choices=["v3", "v4", "v5"],
                        help="Hangi run dizini (--ckpt verilmediyse)")
    parser.add_argument("--latest", action="store_true",
                        help="best yerine latest checkpoint'i kullan")
    parser.add_argument("--chat", action="store_true",
                        help="SFT/Instruct ChatML formatı uygula (otomatik tespit de var)")
    parser.add_argument("--instruction", type=str, default=None,
                        help="ChatML için ayrı instruction (input ile birlikte)")
    parser.add_argument("--tokenizer", type=str, default=None,
                        help="Tokenizer dosya yolu (yoksa version'a göre seç)")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--device", type=str, default=None,
                        choices=["cuda", "cpu"])
    return parser


def _resolve_checkpoint_path(args: argparse.Namespace) -> Path:
    """Return the checkpoint to load: --ckpt if given, else best/latest of the run."""
    if args.ckpt:
        return Path(args.ckpt)

    run_dir = RUN_DIRS[args.version]
    name = "latest_ckpt.pt" if args.latest else "best_ckpt.pt"
    ckpt_path = run_dir / name
    if ckpt_path.exists():
        return ckpt_path

    # The requested flavour is missing: fall back to the other one if present.
    alt = run_dir / ("best_ckpt.pt" if args.latest else "latest_ckpt.pt")
    if not alt.exists():
        raise FileNotFoundError(
            f"Checkpoint yok: {ckpt_path} (run_dir={run_dir})"
        )
    print(f" ({name} yok, {alt.name} kullanılıyor)")
    return alt


def _truncate_prompt(ids: list, max_ctx: int, max_new_tokens: int) -> list:
    """Crop an over-long prompt to its most recent tokens; shorter prompts pass through."""
    if len(ids) < max_ctx:
        return ids
    # Normally reserve room for the requested continuation inside the window.
    keep = max_ctx - max_new_tokens
    if keep < 1:
        # The continuation alone fills (or exceeds) the window, so no room can
        # be reserved. A zero/negative count must not reach the slice below:
        # ids[-0:] keeps everything and ids[-(-k):] drops the first k tokens.
        # Keep as much recent context as fits instead.
        # NOTE(review): assumes model.generate crops its own context once the
        # window is full -- confirm against model*.py.
        keep = max(1, max_ctx - 1)
    return ids[-keep:]


def main():
    """Load a checkpoint, build the matching model and print text samples.

    Steps: parse CLI options, pick the device, resolve and load the
    checkpoint, detect the model version, build the model and load its
    weights, load the tokenizer, optionally wrap the prompt in the chat
    template, then generate and print ``--num-samples`` continuations.

    Raises:
        FileNotFoundError: no checkpoint or tokenizer file was found.
        ValueError: the prompt encodes to zero tokens.
    """
    import contextlib  # local import: only needed for the no-autocast case

    args = _build_arg_parser().parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    ckpt_path = _resolve_checkpoint_path(args)
    print(f"Checkpoint: {ckpt_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    version = detect_version(ckpt)
    # Only warn about a mismatch when the run dir was chosen via --version;
    # with an explicit --ckpt the --version flag is irrelevant.
    if version != args.version and not args.ckpt:
        print(f" ! Algılanan version={version}, --version={args.version}")
    print(f"Version: {version}")

    model, cfg, desc = build_model(version, ckpt, device)

    # Drop the "_orig_mod." key prefix before loading.
    # NOTE(review): presumably left by torch.compile at save time -- confirm.
    state = {k.replace("_orig_mod.", ""): v for k, v in ckpt["model"].items()}
    model.load_state_dict(state)
    model.eval()

    step = ckpt.get("step", "?")
    val = ckpt.get("best_val", None)
    val_str = f", val={val:.4f}" if val is not None and val != float("inf") else ""
    if hasattr(model, "num_params"):
        n_params = model.num_params()
    else:
        n_params = sum(p.numel() for p in model.parameters())
    print(f"Model: {desc}")
    print(f" {n_params/1e6:.2f}M param (step={step}, "
          f"version={ckpt.get('version','?')}{val_str})")

    tok_path = args.tokenizer or str(DATA_DIR / TOKENIZERS[version])
    if not Path(tok_path).exists():
        raise FileNotFoundError(f"Tokenizer yok: {tok_path}")
    tokenizer = Tokenizer.from_file(tok_path)
    print(f"Tokenizer: {Path(tok_path).name} "
          f"(vocab={tokenizer.get_vocab_size()})")

    # Chat formatting is on when requested explicitly or when the checkpoint's
    # version tag contains one of the fine-tuning markers below.
    raw_version = ckpt.get("version", "")
    auto_chat = any(t in str(raw_version) for t in ("instruct", "sft", "dpo", "chat"))
    use_chat = args.chat or auto_chat

    if use_chat:
        user_msg = (f"{args.instruction}\n{args.prompt}"
                    if args.instruction else args.prompt)
        formatted = f"<|user|>\n{user_msg}\n<|assistant|>\n"
        print(f"\nChatML format AKTİF (version={raw_version})")
        print(f"User prompt: {user_msg!r}")
    else:
        formatted = args.prompt
        print(f"\nRaw prompt: {args.prompt!r}")

    print(f"Settings: max={args.max_tokens}, temp={args.temperature}, "
          f"top_k={args.top_k}, rep_pen={args.repetition_penalty}, "
          f"no_rep_ngram={args.no_repeat_ngram}")
    print("=" * 70)

    ids = tokenizer.encode(formatted).ids
    if not ids:
        # A (1, 0) input tensor would only fail later inside generate().
        raise ValueError("Prompt boş: tokenizer hiç token üretmedi.")

    # bf16 autocast only where the GPU supports it. Everywhere else run in
    # plain float32 without autocast: a float32 autocast region is at best a
    # no-op and some torch versions warn about it.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()

    # The prompt is identical for every sample, so crop (and warn) once.
    max_ctx = cfg.block_size
    if len(ids) >= max_ctx:
        print(f" ! Prompt {len(ids)} token, "
              f"context {max_ctx} → kırpılıyor")
    prompt_ids = _truncate_prompt(ids, max_ctx, args.max_tokens)

    for sample_idx in range(args.num_samples):
        # Fresh input tensor per sample so no state is shared between runs.
        x_i = torch.tensor([prompt_ids], dtype=torch.long, device=device)

        if use_bf16:
            amp_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
        else:
            amp_ctx = contextlib.nullcontext()

        with torch.no_grad(), amp_ctx:
            out = model.generate(
                x_i,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                repetition_penalty=args.repetition_penalty,
                no_repeat_ngram_size=args.no_repeat_ngram,
            )
        text = tokenizer.decode(out[0].tolist())
        print(f"\n--- Sample {sample_idx + 1} ---")
        print(text)
        print()


if __name__ == "__main__":
    main()
|
|