from __future__ import annotations import argparse import json import math import os import random import time from pathlib import Path from typing import Any import numpy as np import torch from rich.console import Console from searshorai.model import GPT, GPTConfig console = Console() PRESETS = { "quick_test": dict( n_layer=6, n_head=6, n_embd=384, block_size=256, batch_size=8, grad_accum=8, max_steps=1000, ), "gpu_16gb": dict( n_layer=10, n_head=10, n_embd=640, block_size=512, batch_size=4, grad_accum=16, max_steps=20000, ), "rtx3090_8h": dict( n_layer=12, n_head=12, n_embd=768, block_size=512, batch_size=8, grad_accum=16, max_steps=20000, ), "rtx3090_quality": dict( n_layer=16, n_head=16, n_embd=1024, block_size=512, batch_size=4, grad_accum=24, max_steps=30000, ), "gpu_40gb_quality": dict( n_layer=20, n_head=16, n_embd=1024, block_size=768, batch_size=4, grad_accum=32, max_steps=40000, ), } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train a GPT-style language model from scratch.") parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103")) parser.add_argument("--out_dir", type=Path, default=Path("runs/wikitext-gpt")) parser.add_argument("--preset", choices=PRESETS.keys(), default="gpu_16gb") parser.add_argument("--resume", type=Path, default=None) parser.add_argument("--reset_optimizer", action="store_true") parser.add_argument("--reset_step", action="store_true", help="When resuming, restart step counter at 0 (useful when restarting a fresh schedule).") parser.add_argument("--n_layer", type=int, default=None) parser.add_argument("--n_head", type=int, default=None) parser.add_argument("--n_embd", type=int, default=None) parser.add_argument("--block_size", type=int, default=None) parser.add_argument("--batch_size", type=int, default=None, help="Micro-batch size.") parser.add_argument("--grad_accum", type=int, default=None) parser.add_argument("--max_steps", type=int, default=None) parser.add_argument("--learning_rate", type=float, default=2.5e-4) parser.add_argument("--min_lr", type=float, default=2.5e-5) parser.add_argument("--warmup_steps", type=int, default=1000) parser.add_argument("--weight_decay", type=float, default=0.1) parser.add_argument("--dropout", type=float, default=0.0) parser.add_argument("--grad_clip", type=float, default=1.0) parser.add_argument("--eval_interval", type=int, default=500) parser.add_argument("--eval_iters", type=int, default=100) parser.add_argument("--save_interval", type=int, default=1000) parser.add_argument("--log_interval", type=int, default=20) parser.add_argument("--seed", type=int, default=1337) parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"]) parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"]) parser.add_argument("--compile", action="store_true") parser.add_argument("--gradient_checkpointing", action="store_true") parser.add_argument( "--no_gradient_checkpointing", "--no-gradient-checkpointing", action="store_true", help="Disable checkpointing when resuming from a checkpoint that was trained with it.", ) parser.add_argument("--eval_only", action="store_true") parser.add_argument("--always_save_checkpoint", action="store_true") parser.add_argument("--save_optimizer", action="store_true") return parser.parse_args() def apply_preset(args: argparse.Namespace) -> argparse.Namespace: preset = PRESETS[args.preset] for key, value in preset.items(): if getattr(args, key) is None: setattr(args, key, value) return args def setup_reproducibility(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True def choose_device(args: argparse.Namespace) -> str: if args.device == "auto": return "cuda" if torch.cuda.is_available() else "cpu" if args.device == "cuda" and not torch.cuda.is_available(): raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.") return args.device def choose_dtype(args: argparse.Namespace, device: str) -> torch.dtype: if device == "cpu": return torch.float32 if args.dtype == "float32": return torch.float32 if args.dtype == "float16": return torch.float16 if args.dtype == "bfloat16": if torch.cuda.is_bf16_supported(): return torch.bfloat16 console.print("[yellow]bfloat16 requested but not supported. Falling back to float16.[/yellow]") return torch.float16 if torch.cuda.is_bf16_supported(): return torch.bfloat16 return torch.float16 def make_autocast_context(device: str, dtype: torch.dtype): enabled = device == "cuda" and dtype in (torch.float16, torch.bfloat16) return torch.amp.autocast(device_type=device, dtype=dtype, enabled=enabled) def make_grad_scaler(device: str, dtype: torch.dtype): enabled = device == "cuda" and dtype == torch.float16 try: return torch.amp.GradScaler("cuda", enabled=enabled) except TypeError: return torch.cuda.amp.GradScaler(enabled=enabled) def get_lr(step: int, args: argparse.Namespace) -> float: if step < args.warmup_steps: return args.learning_rate * step / max(1, args.warmup_steps) if step > args.max_steps: return args.min_lr decay_ratio = (step - args.warmup_steps) / max(1, args.max_steps - args.warmup_steps) decay_ratio = min(1.0, max(0.0, decay_ratio)) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return args.min_lr + coeff * (args.learning_rate - args.min_lr) def load_json(path: Path) -> dict[str, Any]: if not path.exists(): raise FileNotFoundError(f"Missing required file: {path}") return json.loads(path.read_text(encoding="utf-8")) def validate_meta(meta: dict[str, Any]) -> None: required_keys = ["vocab_size", "dtype"] for key in required_keys: if key not in meta: raise KeyError(f"meta.json is missing required key: {key}") if meta["dtype"] not in ("uint16", "uint32"): raise ValueError(f"Unsupported meta dtype: {meta['dtype']}. Expected uint16 or uint32.") if int(meta["vocab_size"]) <= 0: raise ValueError("meta.json vocab_size must be greater than zero.") if meta["dtype"] == "uint16" and int(meta["vocab_size"]) > 65535: raise ValueError("meta dtype is uint16 but vocab_size is greater than 65535. Use uint32 data files.") def load_memmap(path: Path, dtype: str) -> np.memmap: if not path.exists(): raise FileNotFoundError(f"Missing required file: {path}") np_dtype = np.uint16 if dtype == "uint16" else np.uint32 return np.memmap(path, dtype=np_dtype, mode="r") def validate_dataset(train_data: np.memmap, val_data: np.memmap, block_size: int, vocab_size: int) -> None: min_required = block_size + 2 if len(train_data) < min_required: raise ValueError( f"train.bin is too small. Need at least {min_required} tokens for block_size={block_size}, " f"but got {len(train_data)}." ) if len(val_data) < min_required: raise ValueError( f"val.bin is too small. Need at least {min_required} tokens for block_size={block_size}, " f"but got {len(val_data)}." ) sample_count = min(10000, len(train_data)) sample_positions = np.linspace(0, len(train_data) - 1, sample_count, dtype=np.int64) sample = np.asarray(train_data[sample_positions], dtype=np.int64) max_token = int(sample.max()) min_token = int(sample.min()) if min_token < 0: raise ValueError(f"Dataset contains negative token id: {min_token}") if max_token >= vocab_size: raise ValueError( f"Dataset token id {max_token} is >= vocab_size {vocab_size}. " "This usually means tokenizer/meta/train.bin mismatch." ) def get_batch( data: np.memmap, batch_size: int, block_size: int, device: str, ) -> tuple[torch.Tensor, torch.Tensor]: """ Fast batch loader: one vectorized gather, then a single host->device transfer. The old code did batch_size python-level numpy slices per call, which was a major bottleneck. """ max_start = len(data) - block_size - 1 if max_start <= 0: raise ValueError("Dataset is too small for the configured block_size.") # Random start positions. ix = np.random.randint(0, max_start, size=(batch_size,), dtype=np.int64) # Allocate contiguous int64 arrays. memmap reads are cheap for sequential blocks. x_np = np.empty((batch_size, block_size), dtype=np.int64) y_np = np.empty((batch_size, block_size), dtype=np.int64) for row, start in enumerate(ix): x_np[row] = data[start : start + block_size] y_np[row] = data[start + 1 : start + 1 + block_size] x = torch.from_numpy(x_np) y = torch.from_numpy(y_np) if device == "cuda": x = x.pin_memory().to(device, non_blocking=True) y = y.pin_memory().to(device, non_blocking=True) else: x = x.to(device) y = y.to(device) return x, y @torch.no_grad() def estimate_loss( model: GPT, train_data: np.memmap, val_data: np.memmap, args: argparse.Namespace, device: str, autocast_ctx, ) -> dict[str, float]: out: dict[str, float] = {} model.eval() for split, data in [("train", train_data), ("val", val_data)]: losses = [] for _ in range(args.eval_iters): x, y = get_batch(data, args.batch_size, args.block_size, device) with autocast_ctx: _, loss = model(x, y) if torch.isfinite(loss): losses.append(float(loss.item())) out[split] = float(sum(losses) / max(1, len(losses))) model.train() return out def unwrap_model(model: GPT) -> GPT: if hasattr(model, "_orig_mod"): return model._orig_mod return model def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: cleaned = {} for key, value in state_dict.items(): if key.startswith("_orig_mod."): key = key[len("_orig_mod.") :] cleaned[key] = value return cleaned def optimizer_to_device(optimizer: torch.optim.Optimizer, device: str) -> None: for state in optimizer.state.values(): for key, value in state.items(): if isinstance(value, torch.Tensor): state[key] = value.to(device) def save_checkpoint( path: Path, model: GPT, optimizer: torch.optim.Optimizer | None, args: argparse.Namespace, step: int, best_val_loss: float, meta: dict[str, Any], ) -> None: raw_model = unwrap_model(model) checkpoint: dict[str, Any] = { "model": raw_model.state_dict(), "args": vars(args), "config": vars(raw_model.config), "step": step, "best_val_loss": best_val_loss, "meta": meta, } if args.save_optimizer and optimizer is not None: checkpoint["optimizer"] = optimizer.state_dict() torch.save(checkpoint, path) def write_run_config(args: argparse.Namespace, meta: dict[str, Any], device: str, dtype: torch.dtype) -> None: config_path = args.out_dir / "run_config.json" payload = { "args": {k: (str(v) if isinstance(v, Path) else v) for k, v in vars(args).items()}, "meta": meta, "device": device, "dtype": str(dtype), "torch_version": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_device_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, } config_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") def build_model_from_checkpoint( ckpt_path: Path, device: str, args: argparse.Namespace, ) -> tuple[GPT, int, float, dict[str, Any]]: ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) config = GPTConfig(**ckpt["config"]) if hasattr(config, "gradient_checkpointing"): if args.no_gradient_checkpointing: config.gradient_checkpointing = False elif args.gradient_checkpointing: config.gradient_checkpointing = True model = GPT(config) state_dict = strip_compile_prefix(ckpt["model"]) model.load_state_dict(state_dict, strict=True) start_step = int(ckpt.get("step", 0)) best_val_loss = float(ckpt.get("best_val_loss", float("inf"))) checkpoint_meta = ckpt.get("meta", {}) return model, start_step, best_val_loss, checkpoint_meta def build_new_model(meta: dict[str, Any], args: argparse.Namespace) -> tuple[GPT, int, float]: config = GPTConfig( vocab_size=int(meta["vocab_size"]), block_size=int(args.block_size), n_layer=int(args.n_layer), n_head=int(args.n_head), n_embd=int(args.n_embd), dropout=float(args.dropout), gradient_checkpointing=bool(args.gradient_checkpointing), ) model = GPT(config) return model, 0, float("inf") def print_startup_info( model: GPT, args: argparse.Namespace, device: str, dtype: torch.dtype, train_data: np.memmap, val_data: np.memmap, start_step: int, ) -> None: raw_model = unwrap_model(model) tokens_per_step = args.batch_size * args.grad_accum * args.block_size if hasattr(raw_model, "num_parameters"): num_params = raw_model.num_parameters() else: num_params = sum(p.numel() for p in raw_model.parameters()) console.print("") console.print("[bold green]Training configuration[/bold green]") console.print(f"Device: {device}") console.print(f"Dtype: {dtype}") console.print(f"Preset: {args.preset}") console.print(f"Parameters: {num_params / 1e6:.2f}M") console.print(f"Layers: {args.n_layer}") console.print(f"Heads: {args.n_head}") console.print(f"Embedding size: {args.n_embd}") console.print(f"Block size: {args.block_size}") console.print(f"Batch size: {args.batch_size}") console.print(f"Grad accumulation: {args.grad_accum}") console.print(f"Tokens per step: {tokens_per_step:,}") console.print(f"Train tokens: {len(train_data):,}") console.print(f"Val tokens: {len(val_data):,}") console.print(f"Start step: {start_step:,}") console.print(f"Max steps: {args.max_steps:,}") console.print(f"Learning rate: {args.learning_rate:.2e}") console.print(f"Min LR: {args.min_lr:.2e}") console.print(f"Warmup steps: {args.warmup_steps:,}") console.print(f"Grad clip: {args.grad_clip}") console.print("") def main() -> None: args = apply_preset(parse_args()) args.out_dir.mkdir(parents=True, exist_ok=True) setup_reproducibility(args.seed) device = choose_device(args) dtype = choose_dtype(args, device) autocast_ctx = make_autocast_context(device, dtype) scaler = make_grad_scaler(device, dtype) meta_path = args.data_dir / "meta.json" meta = load_json(meta_path) validate_meta(meta) train_data = load_memmap(args.data_dir / "train.bin", meta["dtype"]) val_data = load_memmap(args.data_dir / "val.bin", meta["dtype"]) validate_dataset( train_data=train_data, val_data=val_data, block_size=int(args.block_size), vocab_size=int(meta["vocab_size"]), ) if args.resume is not None: console.print(f"[yellow]Resuming from checkpoint:[/yellow] {args.resume}") model, start_step, best_val_loss, checkpoint_meta = build_model_from_checkpoint(args.resume, device, args) if checkpoint_meta: meta = checkpoint_meta else: model, start_step, best_val_loss = build_new_model(meta, args) if args.reset_step: start_step = 0 best_val_loss = float("inf") console.print("[yellow]reset_step set: step counter restarted at 0.[/yellow]") model.to(device) optimizer = model.configure_optimizers( args.weight_decay, args.learning_rate, (0.9, 0.95), "cuda" if device == "cuda" else "cpu", ) if args.resume is not None and not args.reset_optimizer: ckpt = torch.load(args.resume, map_location=device, weights_only=False) if "optimizer" in ckpt: try: optimizer.load_state_dict(ckpt["optimizer"]) optimizer_to_device(optimizer, device) console.print("[green]Loaded optimizer state from checkpoint.[/green]") except Exception as exc: console.print(f"[yellow]Could not load optimizer state. Continuing with fresh optimizer. Error: {exc}[/yellow]") else: console.print("[yellow]Checkpoint has no optimizer state. Continuing with fresh optimizer.[/yellow]") elif args.resume is not None and args.reset_optimizer: console.print("[yellow]reset_optimizer set: starting with fresh Adam moments.[/yellow]") if args.compile: console.print("[cyan]Compiling model...[/cyan]") model = torch.compile(model) write_run_config(args, meta, device, dtype) print_startup_info(model, args, device, dtype, train_data, val_data, start_step) if args.eval_only: losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx) console.print(f"eval only: train {losses['train']:.4f}, val {losses['val']:.4f}") return model.train() tokens_per_step = args.batch_size * args.grad_accum * args.block_size start_time = time.time() last_log_time = start_time last_log_step = start_step for completed_step in range(start_step, args.max_steps): step = completed_step + 1 lr = get_lr(step, args) for param_group in optimizer.param_groups: param_group["lr"] = lr optimizer.zero_grad(set_to_none=True) loss_accum = 0.0 skipped_micro = 0 for _ in range(args.grad_accum): x, y = get_batch(train_data, args.batch_size, args.block_size, device) with autocast_ctx: _, loss = model(x, y) loss = loss / args.grad_accum if not torch.isfinite(loss): console.print(f"[yellow]Non-finite loss at step {step}, skipping micro-batch.[/yellow]") skipped_micro += 1 continue scaler.scale(loss).backward() loss_accum += float(loss.item()) if skipped_micro == args.grad_accum: # Whole step was bad. Skip the optimizer update. scaler.update() continue scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() if step % args.log_interval == 0 or step == start_step + 1: now = time.time() elapsed = max(now - last_log_time, 1e-9) steps_done = max(1, step - last_log_step) toks_per_sec = (tokens_per_step * steps_done) / elapsed last_log_time = now last_log_step = step console.print( f"step {step:7d} | " f"loss {loss_accum:.4f} | " f"lr {lr:.2e} | " f"grad {float(grad_norm):.2f} | " f"{toks_per_sec:,.0f} tok/s" ) should_eval = step % args.eval_interval == 0 or step == args.max_steps if should_eval: losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx) console.print( f"[bold]eval step {step}:[/bold] " f"train {losses['train']:.4f}, val {losses['val']:.4f}" ) if losses["val"] < best_val_loss: best_val_loss = losses["val"] save_checkpoint( args.out_dir / "best.pt", model, optimizer, args, step, best_val_loss, meta, ) console.print(f"[green]saved best checkpoint: val {best_val_loss:.4f}[/green]") if args.always_save_checkpoint: save_checkpoint( args.out_dir / f"step_{step}.pt", model, optimizer, args, step, best_val_loss, meta, ) if step % args.save_interval == 0: save_checkpoint( args.out_dir / "latest.pt", model, optimizer, args, step, best_val_loss, meta, ) console.print(f"[cyan]saved latest checkpoint at step {step}[/cyan]") save_checkpoint( args.out_dir / "latest.pt", model, optimizer, args, args.max_steps, best_val_loss, meta, ) elapsed_hours = (time.time() - start_time) / 3600.0 console.print("") console.print(f"[bold green]Finished in {elapsed_hours:.2f} hours.[/bold green]") console.print(f"[bold green]Best validation loss: {best_val_loss:.4f}[/bold green]") console.print(f"[bold green]Best checkpoint: {args.out_dir / 'best.pt'}[/bold green]") console.print(f"[bold green]Latest checkpoint: {args.out_dir / 'latest.pt'}[/bold green]") if __name__ == "__main__": main()