| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from rich.console import Console |
|
|
| from searshorai.model import GPT, GPTConfig |
|
|
|
|
# Shared rich console through which every status, warning and progress
# message of this script is printed.
console = Console()
|
|
|
|
# Named hyperparameter bundles selectable with --preset. Each entry supplies
# the model shape (n_layer, n_head, n_embd, block_size) and the training
# schedule (batch_size, grad_accum, max_steps). apply_preset() copies a value
# onto the parsed arguments only when the matching CLI flag was left unset.
PRESETS = {
    # Smallest model and shortest schedule of the set.
    "quick_test": {
        "n_layer": 6,
        "n_head": 6,
        "n_embd": 384,
        "block_size": 256,
        "batch_size": 8,
        "grad_accum": 8,
        "max_steps": 1000,
    },
    # Default preset (see the --preset argument).
    "gpu_16gb": {
        "n_layer": 10,
        "n_head": 10,
        "n_embd": 640,
        "block_size": 512,
        "batch_size": 4,
        "grad_accum": 16,
        "max_steps": 20000,
    },
    "rtx3090_8h": {
        "n_layer": 12,
        "n_head": 12,
        "n_embd": 768,
        "block_size": 512,
        "batch_size": 8,
        "grad_accum": 16,
        "max_steps": 20000,
    },
    "rtx3090_quality": {
        "n_layer": 16,
        "n_head": 16,
        "n_embd": 1024,
        "block_size": 512,
        "batch_size": 4,
        "grad_accum": 24,
        "max_steps": 30000,
    },
    # Deepest model, longest context and longest schedule of the set.
    "gpu_40gb_quality": {
        "n_layer": 20,
        "n_head": 16,
        "n_embd": 1024,
        "block_size": 768,
        "batch_size": 4,
        "grad_accum": 32,
        "max_steps": 40000,
    },
}
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the training command line.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point relies
            on; an explicit list allows programmatic use and testing.

    Returns:
        The parsed namespace. The model-shape and schedule options
        (``n_layer``, ``n_head``, ``n_embd``, ``block_size``, ``batch_size``,
        ``grad_accum``, ``max_steps``) default to ``None`` so that a preset
        can fill them in afterwards.
    """
    parser = argparse.ArgumentParser(description="Train a GPT-style language model from scratch.")

    # Input data and output locations.
    parser.add_argument("--data_dir", type=Path, default=Path("data/wikitext103"))
    parser.add_argument("--out_dir", type=Path, default=Path("runs/wikitext-gpt"))

    # Hyperparameter preset; a plain list keeps the choices stable and printable.
    parser.add_argument("--preset", choices=list(PRESETS), default="gpu_16gb")

    # Resuming from an existing checkpoint.
    parser.add_argument("--resume", type=Path, default=None)
    parser.add_argument("--reset_optimizer", action="store_true")
    parser.add_argument("--reset_step", action="store_true",
                        help="When resuming, restart step counter at 0 (useful when restarting a fresh schedule).")

    # Model shape overrides (None -> taken from the preset).
    parser.add_argument("--n_layer", type=int, default=None)
    parser.add_argument("--n_head", type=int, default=None)
    parser.add_argument("--n_embd", type=int, default=None)
    parser.add_argument("--block_size", type=int, default=None)

    # Schedule overrides (None -> taken from the preset).
    parser.add_argument("--batch_size", type=int, default=None, help="Micro-batch size.")
    parser.add_argument("--grad_accum", type=int, default=None)
    parser.add_argument("--max_steps", type=int, default=None)

    # Optimisation hyperparameters.
    parser.add_argument("--learning_rate", type=float, default=2.5e-4)
    parser.add_argument("--min_lr", type=float, default=2.5e-5)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--grad_clip", type=float, default=1.0)

    # Evaluation, checkpoint and logging cadence (all measured in steps).
    parser.add_argument("--eval_interval", type=int, default=500)
    parser.add_argument("--eval_iters", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=1000)
    parser.add_argument("--log_interval", type=int, default=20)

    parser.add_argument("--seed", type=int, default=1337)

    # Hardware / numerics.
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    parser.add_argument("--dtype", type=str, default="auto", choices=["auto", "float32", "float16", "bfloat16"])

    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument(
        "--no_gradient_checkpointing",
        "--no-gradient-checkpointing",
        action="store_true",
        help="Disable checkpointing when resuming from a checkpoint that was trained with it.",
    )

    parser.add_argument("--eval_only", action="store_true",
                        help="Estimate train/val loss once and exit without training.")
    parser.add_argument("--always_save_checkpoint", action="store_true",
                        help="Also write a step_<N>.pt checkpoint at every evaluation.")
    parser.add_argument("--save_optimizer", action="store_true",
                        help="Include the optimizer state in saved checkpoints.")

    return parser.parse_args(argv)
|
|
|
|
def apply_preset(args: argparse.Namespace) -> argparse.Namespace:
    """Fill unset options on ``args`` from the preset named by ``args.preset``.

    Only attributes that are currently ``None`` are overwritten, so values
    given explicitly on the command line win over the preset. ``args`` is
    modified in place and also returned for convenient chaining.
    """
    chosen = PRESETS[args.preset]
    unset_names = [name for name in chosen if getattr(args, name) is None]
    for name in unset_names:
        setattr(args, name, chosen[name])
    return args
|
|
|
|
def setup_reproducibility(seed: int) -> None:
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (all CUDA
    devices too when CUDA is available). It also switches on TF32 for CUDA
    matmul and cuDNN and enables cuDNN benchmark mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Backend switches; they only influence CUDA/cuDNN kernels.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
|
|
|
|
def choose_device(args: argparse.Namespace) -> str:
    """Resolve ``args.device`` to a concrete device string.

    ``"auto"`` becomes ``"cuda"`` when CUDA is available and ``"cpu"``
    otherwise; any other value is returned unchanged.

    Raises:
        RuntimeError: ``"cuda"`` was requested explicitly but is unavailable.
    """
    requested = args.device
    if requested != "auto":
        if requested == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested, but torch.cuda.is_available() is False.")
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
def choose_dtype(args: argparse.Namespace, device: str) -> torch.dtype:
    """Pick the compute dtype for the given device and ``args.dtype`` request.

    * On ``"cpu"`` the result is always ``torch.float32``.
    * ``"float32"`` / ``"float16"`` are honoured as given.
    * ``"bfloat16"`` and ``"auto"`` yield bfloat16 when the GPU supports it
      and float16 otherwise; an explicit bfloat16 request that cannot be met
      additionally prints a warning.
    """
    if device == "cpu":
        return torch.float32

    explicit = {"float32": torch.float32, "float16": torch.float16}
    if args.dtype in explicit:
        return explicit[args.dtype]

    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    if args.dtype == "bfloat16":
        # Only an explicit request deserves a warning; "auto" falls back silently.
        console.print("[yellow]bfloat16 requested but not supported. Falling back to float16.[/yellow]")
    return torch.float16
|
|
|
|
| def make_autocast_context(device: str, dtype: torch.dtype): |
| enabled = device == "cuda" and dtype in (torch.float16, torch.bfloat16) |
| return torch.amp.autocast(device_type=device, dtype=dtype, enabled=enabled) |
|
|
|
|
def make_grad_scaler(device: str, dtype: torch.dtype):
    """Create the gradient scaler used for mixed-precision training.

    The scaler is enabled only for float16 on CUDA; for every other
    combination a disabled scaler is returned, so the training loop can call
    ``scale`` / ``unscale_`` / ``step`` / ``update`` unconditionally.

    The device-generic ``torch.amp.GradScaler`` is preferred. PyTorch builds
    that do not have it raise ``AttributeError`` on the attribute lookup (not
    ``TypeError``), so both are caught before falling back to the legacy
    ``torch.cuda.amp.GradScaler``.
    """
    enabled = device == "cuda" and dtype == torch.float16
    try:
        return torch.amp.GradScaler("cuda", enabled=enabled)
    except (AttributeError, TypeError):
        return torch.cuda.amp.GradScaler(enabled=enabled)
|
|
|
|
def get_lr(step: int, args: argparse.Namespace) -> float:
    """Return the learning rate for ``step``: linear warmup, then cosine decay.

    * ``step < warmup_steps``: ramps linearly from 0 towards ``learning_rate``.
    * ``warmup_steps <= step <= max_steps``: cosine curve from
      ``learning_rate`` down to ``min_lr``.
    * ``step > max_steps``: stays at ``min_lr``.
    """
    peak, floor = args.learning_rate, args.min_lr
    warmup, total = args.warmup_steps, args.max_steps

    if step < warmup:
        return peak * step / max(1, warmup)
    if step > total:
        return floor

    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup) / max(1, total - warmup)
    progress = max(0.0, min(1.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + cosine * (peak - floor)
|
|
|
|
def load_json(path: Path) -> dict[str, Any]:
    """Read ``path`` as UTF-8 text and return the decoded JSON value.

    Raises:
        FileNotFoundError: ``path`` does not exist.
    """
    if path.exists():
        raw_text = path.read_text(encoding="utf-8")
        return json.loads(raw_text)
    raise FileNotFoundError(f"Missing required file: {path}")
|
|
|
|
def validate_meta(meta: dict[str, Any]) -> None:
    """Check that the dataset metadata is complete and self-consistent.

    Requirements:
        * keys ``vocab_size`` and ``dtype`` are present;
        * ``dtype`` is ``"uint16"`` or ``"uint32"``;
        * ``vocab_size`` is positive;
        * with ``uint16`` storage the vocabulary has at most 65536 entries.
          Token ids run from 0 to ``vocab_size - 1`` and a uint16 holds
          0..65535, so exactly 65536 entries still fit.

    Raises:
        KeyError: a required key is missing.
        ValueError: any of the value checks fails.
    """
    for key in ("vocab_size", "dtype"):
        if key not in meta:
            raise KeyError(f"meta.json is missing required key: {key}")

    dtype = meta["dtype"]
    if dtype not in ("uint16", "uint32"):
        raise ValueError(f"Unsupported meta dtype: {dtype}. Expected uint16 or uint32.")

    vocab_size = int(meta["vocab_size"])
    if vocab_size <= 0:
        raise ValueError("meta.json vocab_size must be greater than zero.")

    # Largest id is vocab_size - 1, which must not exceed 65535.
    if dtype == "uint16" and vocab_size > 65536:
        raise ValueError("meta dtype is uint16 but vocab_size is greater than 65536. Use uint32 data files.")
|
|
|
|
def load_memmap(path: Path, dtype: str) -> np.memmap:
    """Open a token file as a read-only NumPy memory map.

    ``dtype == "uint16"`` maps the file as uint16; any other value maps it as
    uint32.

    Raises:
        FileNotFoundError: ``path`` does not exist.
    """
    if path.exists():
        element_type = np.uint32 if dtype != "uint16" else np.uint16
        return np.memmap(path, dtype=element_type, mode="r")
    raise FileNotFoundError(f"Missing required file: {path}")
|
|
|
|
def validate_dataset(train_data: np.memmap, val_data: np.memmap, block_size: int, vocab_size: int) -> None:
    """Sanity-check both token arrays before training starts.

    For each of the train and validation splits this verifies that

    * it holds at least ``block_size + 2`` tokens, and
    * an evenly spaced sample of up to 10,000 tokens (always including the
      first and last token) lies inside ``[0, vocab_size)``.

    The token-range check is a sample, not a full scan, so it can miss
    isolated bad ids in very large files.

    Raises:
        ValueError: a split is too small or contains an out-of-range token id.
    """
    min_required = block_size + 2
    splits = (("train.bin", train_data), ("val.bin", val_data))

    # Size checks first so that a too-small split is reported before sampling.
    for file_name, tokens in splits:
        if len(tokens) < min_required:
            raise ValueError(
                f"{file_name} is too small. Need at least {min_required} tokens for block_size={block_size}, "
                f"but got {len(tokens)}."
            )

    # Range checks on BOTH splits; a validation file built with a different
    # tokenizer would otherwise only fail later inside the model.
    for file_name, tokens in splits:
        sample_count = min(10000, len(tokens))
        sample_positions = np.linspace(0, len(tokens) - 1, sample_count, dtype=np.int64)
        sample = np.asarray(tokens[sample_positions], dtype=np.int64)
        max_token = int(sample.max())
        min_token = int(sample.min())
        if min_token < 0:
            raise ValueError(f"Dataset contains negative token id: {min_token}")
        if max_token >= vocab_size:
            raise ValueError(
                f"Dataset token id {max_token} is >= vocab_size {vocab_size}. "
                f"This usually means tokenizer/meta/{file_name} mismatch."
            )
|
|
|
|
def get_batch(
    data: np.memmap,
    batch_size: int,
    block_size: int,
    device: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, target) token windows.

    ``batch_size`` start offsets are drawn from NumPy's global generator; for
    each offset ``s`` the input row is ``data[s : s + block_size]`` and the
    target row is the same window shifted one token to the right.

    Both arrays are gathered with one vectorized index operation each instead
    of a Python loop over rows, then moved to ``device`` (through pinned
    memory with a non-blocking copy when ``device == "cuda"``).

    Returns:
        ``(x, y)``: two int64 tensors of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: ``data`` is too short for the requested ``block_size``.
    """
    max_start = len(data) - block_size - 1
    if max_start <= 0:
        raise ValueError("Dataset is too small for the configured block_size.")

    ix = np.random.randint(0, max_start, size=(batch_size,), dtype=np.int64)

    # (batch_size, block_size) grid of absolute positions: row r holds
    # ix[r], ix[r] + 1, ..., ix[r] + block_size - 1.
    positions = ix[:, None] + np.arange(block_size, dtype=np.int64)

    # np.asarray yields plain, contiguous int64 arrays (not memmap views),
    # which is what torch.from_numpy / pin_memory expect.
    x_np = np.asarray(data[positions], dtype=np.int64)
    y_np = np.asarray(data[positions + 1], dtype=np.int64)

    x = torch.from_numpy(x_np)
    y = torch.from_numpy(y_np)

    if device == "cuda":
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
        y = y.to(device)
    return x, y
|
|
|
|
@torch.no_grad()
def estimate_loss(
    model: GPT,
    train_data: np.memmap,
    val_data: np.memmap,
    args: argparse.Namespace,
    device: str,
    autocast_ctx,
) -> dict[str, float]:
    """Estimate the mean loss on the train and validation splits.

    For each split, ``args.eval_iters`` random batches are run through the
    model in eval mode without gradients; non-finite losses are ignored. The
    model is switched back to training mode before returning.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}``. A split with no finite
        loss at all (or ``eval_iters == 0``) is reported as NaN rather than
        0.0, so a diverged model can never compare as a new best loss.
    """
    out: dict[str, float] = {}
    model.eval()
    for split, data in (("train", train_data), ("val", val_data)):
        finite_losses = []
        for _ in range(args.eval_iters):
            x, y = get_batch(data, args.batch_size, args.block_size, device)
            with autocast_ctx:
                _, loss = model(x, y)
            if torch.isfinite(loss):
                finite_losses.append(float(loss.item()))
        if finite_losses:
            out[split] = float(sum(finite_losses) / len(finite_losses))
        else:
            # NaN compares False against any threshold, unlike a fake 0.0.
            out[split] = float("nan")
    model.train()
    return out
|
|
|
|
def unwrap_model(model: GPT) -> GPT:
    """Return ``model._orig_mod`` when that attribute exists, else ``model``.

    ``_orig_mod`` is the attribute this script treats as the handle to the
    underlying module of a compiled model (compare the ``"_orig_mod."`` key
    prefix handled by ``strip_compile_prefix``).
    """
    return getattr(model, "_orig_mod", model)
|
|
|
|
def strip_compile_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` without a leading ``"_orig_mod."`` on keys.

    Only one leading occurrence is removed per key; values are passed through
    untouched and the input mapping is not modified.
    """
    marker = "_orig_mod."
    cut = len(marker)
    return {
        (name[cut:] if name.startswith(marker) else name): tensor
        for name, tensor in state_dict.items()
    }
|
|
|
|
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: str) -> None:
    """Move every tensor stored in the optimizer's per-parameter state to ``device``.

    Non-tensor state entries are left as they are. The state dictionaries are
    updated in place.
    """
    for param_state in optimizer.state.values():
        tensor_names = [
            name for name, value in param_state.items() if isinstance(value, torch.Tensor)
        ]
        for name in tensor_names:
            param_state[name] = param_state[name].to(device)
|
|
|
|
def save_checkpoint(
    path: Path,
    model: GPT,
    optimizer: torch.optim.Optimizer | None,
    args: argparse.Namespace,
    step: int,
    best_val_loss: float,
    meta: dict[str, Any],
) -> None:
    """Write a training checkpoint to ``path`` without clobbering on failure.

    The checkpoint holds the unwrapped model's weights, the parsed arguments,
    the model config, the step counter, the best validation loss seen so far
    and the dataset metadata. Optimizer state is added only when
    ``args.save_optimizer`` is set and an optimizer is supplied.

    The file is first written next to ``path`` as ``<name>.tmp`` and then
    moved into place with ``os.replace``. An interruption or a failed write
    (e.g. a full disk) therefore leaves any previous checkpoint at ``path``
    intact instead of truncated.
    """
    raw_model = unwrap_model(model)
    checkpoint: dict[str, Any] = {
        "model": raw_model.state_dict(),
        "args": vars(args),
        "config": vars(raw_model.config),
        "step": step,
        "best_val_loss": best_val_loss,
        "meta": meta,
    }
    if args.save_optimizer and optimizer is not None:
        checkpoint["optimizer"] = optimizer.state_dict()

    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up the partial file left by a failed or interrupted save.
        if tmp_path.exists():
            tmp_path.unlink()
|
|
|
|
def write_run_config(args: argparse.Namespace, meta: dict[str, Any], device: str, dtype: torch.dtype) -> None:
    """Dump the run's settings to ``<args.out_dir>/run_config.json``.

    The JSON document records all parsed arguments (``Path`` values are
    stringified), the dataset metadata, the chosen device and dtype, the
    PyTorch version and whether CUDA is available (with the name of device 0
    when it is).
    """
    serializable_args = {}
    for name, value in vars(args).items():
        serializable_args[name] = str(value) if isinstance(value, Path) else value

    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
    payload = {
        "args": serializable_args,
        "meta": meta,
        "device": device,
        "dtype": str(dtype),
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_name": gpu_name,
    }

    destination = args.out_dir / "run_config.json"
    destination.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
|
|
|
def build_model_from_checkpoint(
    ckpt_path: Path,
    device: str,
    args: argparse.Namespace,
) -> tuple[GPT, int, float, dict[str, Any]]:
    """Rebuild a model from a checkpoint file.

    The model config stored in the checkpoint is used as-is, except that
    ``--no_gradient_checkpointing`` / ``--gradient_checkpointing`` may
    override its ``gradient_checkpointing`` field (the "no" flag wins when
    both are given, and the field is only touched when the config has it).

    Returns:
        ``(model, step, best_val_loss, meta)`` where ``step`` defaults to 0,
        ``best_val_loss`` to infinity and ``meta`` to ``{}`` when the
        checkpoint lacks the corresponding entry.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only resume from checkpoint files you trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    config = GPTConfig(**checkpoint["config"])

    if args.no_gradient_checkpointing:
        override = False
    elif args.gradient_checkpointing:
        override = True
    else:
        override = None
    if override is not None and hasattr(config, "gradient_checkpointing"):
        config.gradient_checkpointing = override

    model = GPT(config)
    # Checkpoints saved from a compiled model may carry "_orig_mod." key prefixes.
    weights = strip_compile_prefix(checkpoint["model"])
    model.load_state_dict(weights, strict=True)

    return (
        model,
        int(checkpoint.get("step", 0)),
        float(checkpoint.get("best_val_loss", float("inf"))),
        checkpoint.get("meta", {}),
    )
|
|
|
|
def build_new_model(meta: dict[str, Any], args: argparse.Namespace) -> tuple[GPT, int, float]:
    """Create a freshly initialised model from the dataset metadata and CLI args.

    Returns:
        ``(model, 0, inf)``: the model together with the starting step and
        the initial best validation loss for a brand-new run.
    """
    config_kwargs = {
        "vocab_size": int(meta["vocab_size"]),
        "block_size": int(args.block_size),
        "n_layer": int(args.n_layer),
        "n_head": int(args.n_head),
        "n_embd": int(args.n_embd),
        "dropout": float(args.dropout),
        "gradient_checkpointing": bool(args.gradient_checkpointing),
    }
    fresh_model = GPT(GPTConfig(**config_kwargs))
    return fresh_model, 0, float("inf")
|
|
|
|
def print_startup_info(
    model: GPT,
    args: argparse.Namespace,
    device: str,
    dtype: torch.dtype,
    train_data: np.memmap,
    val_data: np.memmap,
    start_step: int,
) -> None:
    """Print a summary of the run configuration to the console.

    The parameter count comes from the unwrapped model (its
    ``num_parameters()`` method when present, otherwise the sum over
    ``parameters()``); all other figures are read from ``args`` and the
    dataset lengths.
    """
    raw_model = unwrap_model(model)
    num_params = (
        raw_model.num_parameters()
        if hasattr(raw_model, "num_parameters")
        else sum(p.numel() for p in raw_model.parameters())
    )
    # Tokens consumed by one optimizer step across all micro-batches.
    tokens_per_step = args.batch_size * args.grad_accum * args.block_size

    report = [
        "",
        "[bold green]Training configuration[/bold green]",
        f"Device: {device}",
        f"Dtype: {dtype}",
        f"Preset: {args.preset}",
        f"Parameters: {num_params / 1e6:.2f}M",
        f"Layers: {args.n_layer}",
        f"Heads: {args.n_head}",
        f"Embedding size: {args.n_embd}",
        f"Block size: {args.block_size}",
        f"Batch size: {args.batch_size}",
        f"Grad accumulation: {args.grad_accum}",
        f"Tokens per step: {tokens_per_step:,}",
        f"Train tokens: {len(train_data):,}",
        f"Val tokens: {len(val_data):,}",
        f"Start step: {start_step:,}",
        f"Max steps: {args.max_steps:,}",
        f"Learning rate: {args.learning_rate:.2e}",
        f"Min LR: {args.min_lr:.2e}",
        f"Warmup steps: {args.warmup_steps:,}",
        f"Grad clip: {args.grad_clip}",
        "",
    ]
    for line in report:
        console.print(line)
|
|
|
|
def main() -> None:
    """Run the training job described by the command line.

    Steps: parse arguments and apply the preset, load and validate the
    tokenized dataset, build a fresh model or restore one from ``--resume``,
    then run the gradient-accumulation training loop with periodic logging,
    evaluation and checkpointing. With ``--eval_only`` a single loss estimate
    is printed and no training happens.
    """
    args = apply_preset(parse_args())
    args.out_dir.mkdir(parents=True, exist_ok=True)
    setup_reproducibility(args.seed)

    # --- Device / numerics -------------------------------------------------
    device = choose_device(args)
    dtype = choose_dtype(args, device)
    autocast_ctx = make_autocast_context(device, dtype)
    scaler = make_grad_scaler(device, dtype)

    # --- Dataset -----------------------------------------------------------
    meta_path = args.data_dir / "meta.json"
    meta = load_json(meta_path)
    validate_meta(meta)

    train_data = load_memmap(args.data_dir / "train.bin", meta["dtype"])
    val_data = load_memmap(args.data_dir / "val.bin", meta["dtype"])
    validate_dataset(
        train_data=train_data,
        val_data=val_data,
        block_size=int(args.block_size),
        vocab_size=int(meta["vocab_size"]),
    )

    # --- Model -------------------------------------------------------------
    if args.resume is not None:
        console.print(f"[yellow]Resuming from checkpoint:[/yellow] {args.resume}")
        model, start_step, best_val_loss, checkpoint_meta = build_model_from_checkpoint(args.resume, device, args)
        if checkpoint_meta:
            meta = checkpoint_meta

        # The restored model keeps the architecture stored in the checkpoint,
        # not the one implied by the preset/CLI. Mirror it onto ``args`` so
        # the startup banner, run_config.json and newly saved checkpoints
        # describe the model that is actually being trained.
        for key in ("n_layer", "n_head", "n_embd"):
            restored = getattr(model.config, key, None)
            if restored is not None and restored != getattr(args, key):
                console.print(
                    f"[yellow]{key}={getattr(args, key)} ignored: checkpoint model uses {key}={restored}.[/yellow]"
                )
                setattr(args, key, restored)

        # NOTE(review): assumes GPT cannot process sequences longer than its
        # config.block_size - confirm against searshorai.model. Shorter
        # training windows are left as requested.
        restored_block = getattr(model.config, "block_size", None)
        if restored_block is not None and args.block_size > restored_block:
            console.print(
                f"[yellow]block_size={args.block_size} exceeds the checkpoint's block_size={restored_block}; "
                f"using {restored_block}.[/yellow]"
            )
            args.block_size = restored_block
    else:
        model, start_step, best_val_loss = build_new_model(meta, args)

    if args.reset_step:
        start_step = 0
        best_val_loss = float("inf")
        console.print("[yellow]reset_step set: step counter restarted at 0.[/yellow]")

    model.to(device)

    # --- Optimizer ---------------------------------------------------------
    optimizer = model.configure_optimizers(
        args.weight_decay,
        args.learning_rate,
        (0.9, 0.95),
        "cuda" if device == "cuda" else "cpu",
    )

    if args.resume is not None and not args.reset_optimizer:
        # Load onto the CPU and drop the reference afterwards: the checkpoint
        # also contains a full copy of the model weights, which would
        # otherwise sit on the training device for the whole run.
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only resume from trusted checkpoint files.
        ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
        if "optimizer" in ckpt:
            try:
                optimizer.load_state_dict(ckpt["optimizer"])
                optimizer_to_device(optimizer, device)
                console.print("[green]Loaded optimizer state from checkpoint.[/green]")
            except Exception as exc:
                # Deliberate best-effort: a mismatching optimizer state must
                # not prevent resuming the weights.
                console.print(f"[yellow]Could not load optimizer state. Continuing with fresh optimizer. Error: {exc}[/yellow]")
        else:
            console.print("[yellow]Checkpoint has no optimizer state. Continuing with fresh optimizer.[/yellow]")
        del ckpt
    elif args.resume is not None and args.reset_optimizer:
        console.print("[yellow]reset_optimizer set: starting with fresh Adam moments.[/yellow]")

    if args.compile:
        console.print("[cyan]Compiling model...[/cyan]")
        model = torch.compile(model)

    write_run_config(args, meta, device, dtype)
    print_startup_info(model, args, device, dtype, train_data, val_data, start_step)

    if args.eval_only:
        losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
        console.print(f"eval only: train {losses['train']:.4f}, val {losses['val']:.4f}")
        return

    def save(file_name: str, at_step: int) -> None:
        """Write a checkpoint named ``file_name`` into the output directory."""
        save_checkpoint(args.out_dir / file_name, model, optimizer, args, at_step, best_val_loss, meta)

    # --- Training loop -----------------------------------------------------
    model.train()
    tokens_per_step = args.batch_size * args.grad_accum * args.block_size

    start_time = time.time()
    last_log_time = start_time
    last_log_step = start_step

    for completed_step in range(start_step, args.max_steps):
        step = completed_step + 1  # 1-based number of the step being run

        lr = get_lr(step, args)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        optimizer.zero_grad(set_to_none=True)
        loss_accum = 0.0
        skipped_micro = 0

        for _ in range(args.grad_accum):
            x, y = get_batch(train_data, args.batch_size, args.block_size, device)
            with autocast_ctx:
                _, loss = model(x, y)
                # Average over micro-batches so accumulated gradients match
                # one large batch.
                loss = loss / args.grad_accum
            if not torch.isfinite(loss):
                console.print(f"[yellow]Non-finite loss at step {step}, skipping micro-batch.[/yellow]")
                skipped_micro += 1
                continue
            scaler.scale(loss).backward()
            loss_accum += float(loss.item())

        if skipped_micro == args.grad_accum:
            # Nothing was back-propagated, so there is no gradient to apply.
            # scaler.update() must NOT be called here: an enabled GradScaler
            # asserts that inf checks were recorded (via unscale_/step) since
            # the previous update, which is not the case on this path.
            continue

        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        if step % args.log_interval == 0 or step == start_step + 1:
            now = time.time()
            elapsed = max(now - last_log_time, 1e-9)
            steps_done = max(1, step - last_log_step)
            toks_per_sec = (tokens_per_step * steps_done) / elapsed
            last_log_time = now
            last_log_step = step
            console.print(
                f"step {step:7d} | "
                f"loss {loss_accum:.4f} | "
                f"lr {lr:.2e} | "
                f"grad {float(grad_norm):.2f} | "
                f"{toks_per_sec:,.0f} tok/s"
            )

        should_eval = step % args.eval_interval == 0 or step == args.max_steps
        if should_eval:
            losses = estimate_loss(model, train_data, val_data, args, device, autocast_ctx)
            console.print(
                f"[bold]eval step {step}:[/bold] "
                f"train {losses['train']:.4f}, val {losses['val']:.4f}"
            )
            if losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                save("best.pt", step)
                console.print(f"[green]saved best checkpoint: val {best_val_loss:.4f}[/green]")
            if args.always_save_checkpoint:
                save(f"step_{step}.pt", step)

        if step % args.save_interval == 0:
            save("latest.pt", step)
            console.print(f"[cyan]saved latest checkpoint at step {step}[/cyan]")

    # Final checkpoint, written even when the loop body never ran.
    save("latest.pt", args.max_steps)

    elapsed_hours = (time.time() - start_time) / 3600.0
    console.print("")
    console.print(f"[bold green]Finished in {elapsed_hours:.2f} hours.[/bold green]")
    console.print(f"[bold green]Best validation loss: {best_val_loss:.4f}[/bold green]")
    console.print(f"[bold green]Best checkpoint: {args.out_dir / 'best.pt'}[/bold green]")
    console.print(f"[bold green]Latest checkpoint: {args.out_dir / 'latest.pt'}[/bold green]")
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()