| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import random |
| import time |
| import unicodedata |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from rich.console import Console |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| from searshorai.model import GPT, GPTConfig |
| from searshorai.tokenizer import TextTokenizer |
|
|
|
|
# Module-wide rich console; the loader and the training loop print their
# status and progress messages through it.
console = Console()
|
|
|
|
@dataclass
class Example:
    """One tokenized SFT sample, as built by ``load_examples``.

    ``input_ids`` and ``labels`` are parallel lists of equal length.
    """

    # Token ids fed to the model (the full sequence without its last token).
    input_ids: list[int]
    # Next-token targets aligned with ``input_ids``. -100 marks positions that
    # carry no target (prompt tokens here, padding in ``make_batch``) --
    # presumably the ignore index of the model's loss; confirm in GPT.
    labels: list[int]
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line options of the fine-tuning script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The populated namespace; every path option is a ``pathlib.Path``.

    Raises:
        SystemExit: Raised by argparse for unknown or malformed options, and
            for numeric options outside the range the training loop can
            handle (see the validation at the end of this function).
    """
    parser = argparse.ArgumentParser(description="Stable supervised fine-tune for paragraph explanation.")

    # Input / output locations.
    parser.add_argument("--base_checkpoint", type=Path, default=Path("runs/wikitext-gpt/best.pt"))
    parser.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
    parser.add_argument("--sft_file", type=Path, default=Path("data/wikitext103/paragraph_sft.jsonl"))
    parser.add_argument("--out_dir", type=Path, default=Path("runs/paragraph-explainer"))

    # Optimisation schedule.
    parser.add_argument("--max_steps", type=int, default=8000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--min_lr", type=float, default=2e-6)
    parser.add_argument("--warmup_steps", type=int, default=300)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--grad_clip", type=float, default=1.0)

    # Dataset shaping.
    parser.add_argument("--max_answer_tokens", type=int, default=220)
    parser.add_argument("--min_answer_tokens", type=int, default=8)
    parser.add_argument("--val_ratio", type=float, default=0.02)

    # Evaluation, checkpointing and miscellaneous switches.
    parser.add_argument("--eval_interval", type=int, default=250)
    parser.add_argument("--eval_batches", type=int, default=40)
    parser.add_argument("--save_interval", type=int, default=500)
    parser.add_argument("--log_interval", type=int, default=20)
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--resume", type=Path, default=None)

    args = parser.parse_args(argv)

    # These options are used as divisors, modulo operands or sample sizes by
    # the training loop; a value below 1 would only fail much later (for
    # example with ZeroDivisionError), so reject it up front.
    for option in (
        "batch_size",
        "grad_accum",
        "eval_interval",
        "eval_batches",
        "save_interval",
        "log_interval",
    ):
        if getattr(args, option) < 1:
            parser.error(f"--{option} must be >= 1")
    if args.max_steps < 0:
        parser.error("--max_steps must be >= 0")
    # A ratio of 1.0 or more would leave no training examples after the split.
    if not 0.0 <= args.val_ratio < 1.0:
        parser.error("--val_ratio must be in [0, 1)")

    return args
|
|
|
|
def clean_text(text: Any) -> str:
    """Turn an arbitrary value into tidy, whitespace-collapsed text.

    ``None`` becomes the empty string; anything else is converted with
    ``str``. U+FFFD replacement characters are replaced by spaces, the text
    is NFKC-normalised, control characters below U+0020 (other than newline
    and tab) become spaces, whitespace runs inside each line collapse to a
    single space, and the result is stripped at both ends.
    """
    if text is None:
        return ""

    normalized = unicodedata.normalize("NFKC", str(text).replace("\ufffd", " "))

    # Blank out low control characters one by one; newline and tab survive
    # this pass (tabs are then collapsed like any other whitespace below).
    visible_chars = []
    for ch in normalized:
        keep = ch == "\n" or ch == "\t" or ord(ch) >= 32
        visible_chars.append(ch if keep else " ")

    tidy_lines = [" ".join(raw_line.split()) for raw_line in "".join(visible_chars).splitlines()]
    return "\n".join(tidy_lines).strip()
|
|
|
|
def get_special_id(tok: TextTokenizer, name: str) -> int | None:
    """Look up a special-token id attribute on the tokenizer.

    Returns the attribute ``name`` of ``tok`` as an ``int`` when it exists
    and is an ``int`` instance; returns ``None`` when the attribute is
    missing or holds any other type.
    """
    candidate = getattr(tok, name, None)
    if not isinstance(candidate, int):
        return None
    return int(candidate)
|
|
|
|
def ensure_eos(ids: list[int], eos_id: int | None) -> list[int]:
    """Return ``ids`` terminated by ``eos_id``.

    When ``eos_id`` is ``None`` or ``ids`` already ends with it, the very
    same list object is returned; otherwise a new list with ``eos_id``
    appended is returned. The input list is never modified.
    """
    needs_eos = eos_id is not None and (not ids or ids[-1] != eos_id)
    return ids + [eos_id] if needs_eos else ids
|
|
|
|
def get_lr(step: int, args: argparse.Namespace) -> float:
    """Learning rate for ``step``: linear warmup followed by cosine decay.

    During the first ``args.warmup_steps`` steps the rate rises linearly and
    reaches ``args.learning_rate`` on the last warmup step. Afterwards it
    follows a half cosine from ``args.learning_rate`` down to ``args.min_lr``
    at ``args.max_steps`` and stays at ``args.min_lr`` beyond that.
    """
    warmup = args.warmup_steps
    peak = args.learning_rate
    floor = args.min_lr

    if step < warmup:
        return peak * (step + 1) / max(1, warmup)

    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup) / max(1, args.max_steps - warmup)
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + cosine * (peak - floor)
|
|
|
|
def read_prompt_answer(row: dict[str, Any]) -> tuple[str, str]:
    """Extract a cleaned ``(prompt, answer)`` pair from one JSONL record.

    Supports these JSONL styles:
    {"prompt": "...", "answer": "..."}
    {"input": "...", "output": "..."}
    {"paragraph": "...", "explanation": "..."}
    {"text": "...", "answer": "..."}

    The prompt source is chosen by key presence, in the order ``prompt``,
    ``paragraph``, ``text``, ``input``. ``paragraph`` and ``text`` values are
    wrapped in the instruction template. The answer is the first value that
    is not ``None`` among ``answer``, ``output`` and ``explanation``.

    Args:
        row: One decoded JSON object from the SFT file.

    Returns:
        ``(prompt, answer)``, both passed through ``clean_text``. Either may
        be the empty string, which callers treat as "skip this row".
    """
    if "prompt" in row:
        prompt = row.get("prompt", "")
    elif "paragraph" in row or "text" in row:
        passage = row.get("paragraph") if "paragraph" in row else row.get("text")
        # A null or blank passage must produce an EMPTY prompt. Wrapping it in
        # the template would yield a non-empty prompt (a null value would even
        # be rendered as the word "None"), so the caller's empty-row filter
        # would never see it and the model would train on an empty paragraph.
        if clean_text(passage):
            prompt = f"Explain this paragraph in simple words:\n\n{passage}\n\nExplanation:\n"
        else:
            prompt = ""
    else:
        prompt = row.get("input", "")

    # First non-null value among the supported answer keys.
    answer = next(
        (row[key] for key in ("answer", "output", "explanation") if row.get(key) is not None),
        "",
    )

    return clean_text(prompt), clean_text(answer)
|
|
|
|
def load_examples(path: Path, tok: TextTokenizer, block_size: int, args: argparse.Namespace) -> list[Example]:
    """Read the SFT JSONL file and build one ``Example`` per usable row.

    Each row is tokenized as ``prompt + answer``. The answer is capped at
    ``args.max_answer_tokens`` and terminated with EOS (when the tokenizer
    exposes an integer ``eos_id``). The whole sequence is fitted into
    ``block_size + 1`` tokens, shortening the answer first only if it would
    leave fewer than 16 tokens for the prompt, and then keeping BOS plus the
    end of the prompt. Labels are the sequence shifted by one, with every
    prompt position set to -100 so that only answer tokens carry a target.

    Args:
        path: JSONL file with one record per line.
        tok: Tokenizer providing ``encode`` and the special-token ids.
        block_size: Maximum number of input positions of the model.
        args: Parsed options; ``min_answer_tokens`` and ``max_answer_tokens``
            are read here.

    Returns:
        The list of examples, in file order.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        RuntimeError: Fewer than 10 usable examples were found.
    """
    if not path.exists():
        raise FileNotFoundError(f"SFT file not found: {path}")

    eos_id = get_special_id(tok, "eos_id")
    # Looked up the same tolerant way as EOS, so a tokenizer without an
    # integer ``bos_id`` simply has no BOS to preserve instead of raising.
    bos_id = get_special_id(tok, "bos_id")
    # One token more than the block: inputs and targets are the same
    # sequence shifted by one position.
    max_len = block_size + 1
    examples: list[Example] = []

    skipped_empty = 0
    skipped_too_short = 0
    truncated_answers = 0
    bad_json = 0

    with path.open("r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                bad_json += 1
                continue
            # Valid JSON that is not an object (null, a list, a number, ...)
            # cannot hold prompt/answer keys; count it as a bad record rather
            # than letting the key lookups below raise.
            if not isinstance(row, dict):
                bad_json += 1
                continue

            prompt, answer = read_prompt_answer(row)
            if not prompt or not answer:
                skipped_empty += 1
                continue

            prompt_ids = tok.encode(prompt, add_bos=True, add_eos=False)

            answer_ids = tok.encode(answer, add_bos=False, add_eos=False)
            if len(answer_ids) < args.min_answer_tokens:
                skipped_too_short += 1
                continue
            answer_was_cut = len(answer_ids) > args.max_answer_tokens
            if answer_was_cut:
                answer_ids = answer_ids[: args.max_answer_tokens]
            answer_ids = ensure_eos(answer_ids, eos_id)

            room_for_prompt = max_len - len(answer_ids)
            if room_for_prompt < 16:
                # The answer alone nearly fills the block: shorten it further
                # so that part of the prompt still fits, then re-terminate it.
                keep = max(16, block_size - 32)
                length_before = len(answer_ids)
                answer_ids = ensure_eos(answer_ids[: keep - 1], eos_id)
                if len(answer_ids) < length_before:
                    answer_was_cut = True
                room_for_prompt = max_len - len(answer_ids)
            if answer_was_cut:
                truncated_answers += 1

            if len(prompt_ids) > room_for_prompt:
                # Truncate the prompt from the left: keep BOS (if present)
                # followed by the END of the prompt, which sits next to the
                # answer.
                has_bos = bool(prompt_ids) and bos_id is not None and prompt_ids[0] == bos_id
                bos = [prompt_ids[0]] if has_bos else []
                tail_len = room_for_prompt - len(bos)
                tail = prompt_ids[-tail_len:] if tail_len > 0 else []
                prompt_ids = bos + tail

            full_ids = prompt_ids + answer_ids

            if len(full_ids) > max_len:
                # Last-resort cut; make sure the sequence still ends with EOS.
                full_ids = full_ids[:max_len]
                if eos_id is not None and full_ids[-1] != eos_id:
                    full_ids[-1] = eos_id

            if len(full_ids) < 16:
                skipped_too_short += 1
                continue

            input_ids = full_ids[:-1]
            next_ids = full_ids[1:]

            # next_ids[i] is full_ids[i + 1], so it is an answer token exactly
            # when i + 1 >= len(prompt_ids); every other target is masked.
            prompt_len = len(prompt_ids)
            labels = [
                token_id if (position + 1) >= prompt_len else -100
                for position, token_id in enumerate(next_ids)
            ]

            # Rows without a single supervised position are dropped.
            if any(x != -100 for x in labels):
                examples.append(Example(input_ids=input_ids, labels=labels))

    console.print(
        f"Loaded {len(examples):,} examples | "
        f"empty={skipped_empty:,}, short={skipped_too_short:,}, "
        f"truncated_answers={truncated_answers:,}, bad_json={bad_json:,}"
    )
    if len(examples) < 10:
        raise RuntimeError("Too few valid SFT examples. Check your JSONL keys and tokenizer.")
    return examples
|
|
|
|
def make_batch(
    examples: list[Example],
    batch_size: int,
    pad_id: int,
    device: str,
    block_size: int,
):
    """Draw a random batch and return it as padded ``(x, y)`` long tensors.

    Examples are sampled without replacement when there are at least
    ``batch_size`` of them, otherwise with replacement. Every example is cut
    to ``block_size`` positions; inputs are right-padded with ``pad_id`` and
    labels with -100. Both tensors are moved to ``device``; for ``"cuda"``
    they are pinned first and copied with ``non_blocking=True``.
    """
    if len(examples) >= batch_size:
        picked = random.sample(examples, batch_size)
    else:
        picked = random.choices(examples, k=batch_size)

    input_rows = [torch.tensor(item.input_ids[:block_size], dtype=torch.long) for item in picked]
    label_rows = [torch.tensor(item.labels[:block_size], dtype=torch.long) for item in picked]

    x = pad_sequence(input_rows, batch_first=True, padding_value=pad_id)
    y = pad_sequence(label_rows, batch_first=True, padding_value=-100)

    if device != "cuda":
        return x.to(device), y.to(device)

    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)
    return x, y
|
|
|
|
@torch.no_grad()
def evaluate(model, examples, args, pad_id, device, autocast_ctx, block_size) -> float:
    """Estimate the mean loss of ``model`` over random batches of ``examples``.

    Runs ``args.eval_batches`` batches of ``args.batch_size`` examples in
    eval mode without gradients, and averages the losses that are finite.

    Args:
        model: Callable as ``model(x, y)`` returning ``(_, loss)``.
        examples: Pool of examples the batches are sampled from.
        args: Parsed options; ``eval_batches`` and ``batch_size`` are read.
        pad_id: Padding id for the input tensor.
        device: Device string the batches are moved to.
        autocast_ctx: Context manager entered around each forward pass.
        block_size: Maximum sequence length per example.

    Returns:
        The mean finite loss, or ``float("inf")`` when no batch produced a
        finite loss. Returning 0.0 in that case would make a diverged model
        look like the best one ever seen to a ``val_loss < best`` check.
    """
    was_training = model.training
    model.eval()
    total = 0.0
    finite_batches = 0
    try:
        for _ in range(args.eval_batches):
            x, y = make_batch(examples, args.batch_size, pad_id, device, block_size)
            with autocast_ctx:
                _, loss = model(x, y)
            if torch.isfinite(loss):
                total += float(loss.item())
                finite_batches += 1
    finally:
        # Restore the caller's mode even if a forward pass raised.
        model.train(was_training)

    if finite_batches == 0:
        return float("inf")
    return total / finite_batches
|
|
|
|
def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with one leading ``_orig_mod.`` removed per key.

    Keys that do not start with that prefix are copied unchanged; values are
    carried over as-is and the original key order is kept.
    """
    marker = "_orig_mod."
    return {
        (name[len(marker) :] if name.startswith(marker) else name): tensor
        for name, tensor in state_dict.items()
    }
|
|
|
|
def save_checkpoint(
    path: Path,
    model,
    optimizer,
    args: argparse.Namespace,
    step: int,
    best_val_loss: float,
    meta: dict[str, Any],
) -> None:
    """Write a training checkpoint to ``path``.

    The checkpoint stores the model and optimizer state dicts, the parsed
    arguments (paths converted to strings), the model config attributes, the
    step, the best validation loss, and ``meta`` extended with the
    task-specific keys set below.

    Args:
        path: Destination file.
        model: Model to save; when it carries an ``_orig_mod`` attribute (as
            set by ``torch.compile``), that wrapped module is saved instead.
        optimizer: Optimizer whose state is stored.
        args: Parsed command-line options.
        step: Training step the checkpoint corresponds to.
        best_val_loss: Best validation loss seen so far.
        meta: Extra metadata to carry along; ``None`` is treated as empty and
            the passed dict is not modified.
    """
    raw_model = getattr(model, "_orig_mod", model)

    merged_meta = dict(meta or {})
    merged_meta.update(
        {
            "task": "paragraph_explainer_sft",
            "tokenizer": str(args.tokenizer),
            "sft_file": str(args.sft_file),
            "important": "Prompt tokens are masked; answer is EOS-safe truncated.",
        }
    )

    payload = {
        "model": raw_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()},
        "config": dict(vars(raw_model.config)),
        "step": step,
        "best_val_loss": best_val_loss,
        "meta": merged_meta,
    }

    # Write to a sibling temporary file and rename it into place, so that an
    # interrupted save never leaves a truncated file where the previous good
    # checkpoint used to be.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        tmp_path.replace(path)
    finally:
        # Only present at this point if saving or renaming failed.
        tmp_path.unlink(missing_ok=True)
|
|
|
|
def main() -> None:
    """Run the supervised fine-tuning job end to end.

    Steps: parse options, seed the RNGs, pick the device and autocast dtype,
    load the tokenizer and the base (or resumed) checkpoint, build the
    optimizer and gradient scaler, load and split the SFT examples, then
    train for steps ``start_step .. max_steps`` (inclusive) with gradient
    accumulation, periodic logging, evaluation and checkpointing. The best
    model by validation loss goes to ``best.pt`` and the most recent state to
    ``latest.pt`` inside ``args.out_dir``.

    Raises:
        RuntimeError: Tokenizer and model vocabulary sizes differ, or no
            training examples remain after the validation split.
    """
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_type = "cuda" if device == "cuda" else "cpu"

    if device == "cuda" and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16
        console.print("AMP dtype: bfloat16")
    elif device == "cuda":
        amp_dtype = torch.float16
        console.print("AMP dtype: float16")
    else:
        amp_dtype = torch.float32
        console.print("AMP disabled on CPU")

    autocast_ctx = torch.amp.autocast(
        device_type=device_type,
        dtype=amp_dtype,
        enabled=(device == "cuda"),
    )

    tok = TextTokenizer(args.tokenizer)
    # Fall back to 0 both when the attribute is missing and when it is None.
    raw_pad_id = getattr(tok, "pad_id", None)
    pad_id = int(raw_pad_id) if raw_pad_id is not None else 0

    if args.resume is not None:
        ckpt_path = args.resume
        console.print(f"Resuming SFT checkpoint: {ckpt_path}")
    else:
        ckpt_path = args.base_checkpoint
        console.print(f"Starting from base checkpoint: {ckpt_path}")

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    # The checkpoint is loaded onto the CPU: its tensors are copied into the
    # model / optimizer below, so a device copy would only duplicate memory.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    config = GPTConfig(**ckpt["config"])
    # Dropout is disabled for this fine-tune regardless of the stored config.
    config.dropout = 0.0
    model = GPT(config)
    state_dict = strip_compile_prefix(ckpt["model"])
    model.load_state_dict(state_dict, strict=True)
    model.to(device)

    if tok.vocab_size != model.config.vocab_size:
        raise RuntimeError(
            f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
            "This is the most common cause of garbled output. Use the same tokenizer that produced the pretrain data."
        )

    optimizer = model.configure_optimizers(
        args.weight_decay,
        args.learning_rate,
        (0.9, 0.95),
        device_type,
    )

    start_step = 0
    best_val_loss = float("inf")
    if args.resume is not None and "optimizer" in ckpt:
        # Best effort: an incompatible optimizer state restarts the schedule
        # from step 0 instead of aborting the run.
        try:
            optimizer.load_state_dict(ckpt["optimizer"])
            start_step = int(ckpt.get("step", 0)) + 1
            best_val_loss = float(ckpt.get("best_val_loss", float("inf")))
            console.print(f"Resume from step {start_step}, previous best val {best_val_loss:.4f}")
        except Exception as exc:
            console.print(f"[yellow]Could not load optimizer state, starting fresh: {exc}[/yellow]")

    # Keep only the metadata; dropping the checkpoint releases its copies of
    # the weights and optimizer state for the rest of the run.
    base_meta = ckpt.get("meta", {})
    del ckpt, state_dict

    # Loss scaling is only needed for float16 autocast.
    scaler_enabled = device == "cuda" and amp_dtype == torch.float16
    try:
        scaler = torch.amp.GradScaler("cuda", enabled=scaler_enabled)
    except (AttributeError, TypeError):
        # torch builds without torch.amp.GradScaler raise AttributeError on
        # the lookup above; fall back to the CUDA-specific class.
        scaler = torch.cuda.amp.GradScaler(enabled=scaler_enabled)

    examples = load_examples(args.sft_file, tok, model.config.block_size, args)
    random.shuffle(examples)

    val_size = max(1, int(len(examples) * args.val_ratio))
    val_examples = examples[:val_size]
    train_examples = examples[val_size:]
    if not train_examples:
        raise RuntimeError("No training examples after split.")

    console.print(
        f"Train={len(train_examples):,} | Val={len(val_examples):,} | "
        f"Block size={model.config.block_size} | Device={device}"
    )

    if args.compile:
        console.print("Compiling model with torch.compile...")
        model = torch.compile(model)

    model.train()
    # A compiled model keeps the original module under ``_orig_mod``.
    block_size = getattr(model, "_orig_mod", model).config.block_size
    last_time = time.time()
    last_step = start_step

    for step in range(start_step, args.max_steps + 1):
        lr = get_lr(step, args)
        for group in optimizer.param_groups:
            group["lr"] = lr

        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        ok_micro_steps = 0

        for _ in range(args.grad_accum):
            x, y = make_batch(train_examples, args.batch_size, pad_id, device, block_size)
            with autocast_ctx:
                _, loss = model(x, y)
                # Scale so that the accumulated gradient is the batch mean.
                loss = loss / args.grad_accum
            if not torch.isfinite(loss):
                console.print(f"[yellow]Skipping non-finite loss at step {step}[/yellow]")
                continue
            scaler.scale(loss).backward()
            loss_accum += float(loss.item())
            ok_micro_steps += 1

        if ok_micro_steps == 0:
            # Nothing was back-propagated, so there is nothing to step. Do
            # NOT call scaler.update() here: an enabled GradScaler requires a
            # preceding scale()/step() in the iteration and asserts otherwise.
            continue

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        if step % args.log_interval == 0:
            now = time.time()
            steps_done = max(1, step - last_step)
            # loss_accum sums loss / grad_accum; rescale so that skipped
            # micro-batches do not make the reported mean look smaller.
            mean_loss = loss_accum * args.grad_accum / ok_micro_steps
            console.print(
                f"step {step:6d} | loss {mean_loss:.4f} | "
                f"lr {lr:.2e} | {(now - last_time) / steps_done:.2f}s/step"
            )
            last_time = now
            last_step = step

        if step > 0 and (step % args.eval_interval == 0 or step == args.max_steps):
            val_loss = evaluate(model, val_examples, args, pad_id, device, autocast_ctx, block_size)
            console.print(f"eval step {step}: val {val_loss:.4f}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint(
                    args.out_dir / "best.pt",
                    model,
                    optimizer,
                    args,
                    step,
                    best_val_loss,
                    base_meta,
                )
                console.print(f"[green]saved best checkpoint: {best_val_loss:.4f}[/green]")

        if step > 0 and step % args.save_interval == 0:
            save_checkpoint(
                args.out_dir / "latest.pt",
                model,
                optimizer,
                args,
                step,
                best_val_loss,
                base_meta,
            )

    # Always leave a final snapshot, labelled with the last scheduled step.
    save_checkpoint(
        args.out_dir / "latest.pt",
        model,
        optimizer,
        args,
        args.max_steps,
        best_val_loss,
        base_meta,
    )
    console.print("Fine-tuning complete.")
|
|
|
|
| if __name__ == "__main__": |
| main() |