""" Derived from Andrej Karpathy's nanochat project. MIT License Copyright (c) 2025 Andrej Karpathy Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. """ from __future__ import annotations import argparse import csv from collections import defaultdict from dataclasses import dataclass from datetime import datetime import json import math import os from pathlib import Path import random import statistics import sys import time from typing import Callable import numpy as np import torch from dropout_decay.data import ( encode_corpus, load_cached_splits, resolve_paths, train_or_load_tokenizer, ) from dropout_decay.license import NANOCHAT_ATTRIBUTION from dropout_decay.model import DropoutGPT, GPTConfig from dropout_decay.scheduler import DropoutDecayConfig, DropoutDecayScheduler DEFAULT_DROPOUT_RATES = [0.0, 0.02, 0.05, 0.08, 0.10, 0.14, 0.20, 0.30, 0.50] SUMMARY_FIELDS = [ "run_mode", "condition", "condition_kind", "stage", "token_limit", "model_name", "n_layer", "n_head", "n_embd", "parameters", "dropout_initial", "dropout_final", "dropout_schedule", "n", "mean_train_eval_loss", "std_train_eval_loss", "mean_val_eval_loss", "std_val_eval_loss", "mean_generalization_gap", "std_generalization_gap", ] SELECTION_FIELDS = [ "run_mode", "token_limit", "model_name", "n_layer", "n_head", "n_embd", "parameters", "n", "best_dropout", "best_val_loss", "best_val_std", "plateau_start_dropout", "plateau_end_dropout", "plateau_delta", "zero_dropout_val_loss", "zero_minus_best", "best_nonzero_dropout", "best_nonzero_val_loss", "zero_minus_best_nonzero", "max_dropout", "max_dropout_val_loss", "max_dropout_minus_best", "has_nonzero_optimum", "meets_target_dropout", "curve_json", ] @dataclass(frozen=True) class ModelSpec: name: str n_layer: int n_head: int n_embd: int def config(self, vocab_size: int, block_size: int, dropout: float) -> GPTConfig: return GPTConfig( block_size=block_size, vocab_size=vocab_size, n_layer=self.n_layer, n_head=self.n_head, n_embd=self.n_embd, dropout=dropout, ) def to_dict(self) -> dict[str, int | str]: return { "model_name": self.name, "n_layer": self.n_layer, "n_head": self.n_head, "n_embd": self.n_embd, } @dataclass(frozen=True) class DropoutCondition: name: str kind: str initial: float final: float schedule: str = "constant" decay_tokens: int | None = None anchors: tuple[tuple[int, float], ...] = () def to_dict(self) -> dict: return { "name": self.name, "kind": self.kind, "initial": self.initial, "final": self.final, "schedule": self.schedule, "decay_tokens": self.decay_tokens, "anchors": [list(anchor) for anchor in self.anchors], } def make_fn( self, fallback_decay_tokens: int, unique_tokens: int | None = None, ) -> Callable[[int], float]: if self.kind == "static": return lambda _tokens_seen, p=self.initial: p if self.kind == "anchor_decay": if unique_tokens is None: raise ValueError("anchor_decay conditions require unique_tokens") p = anchor_dropout(unique_tokens, self.anchors) return lambda _tokens_seen, p=p: p scheduler = DropoutDecayScheduler( DropoutDecayConfig( initial_dropout=self.initial, final_dropout=self.final, decay_tokens=self.decay_tokens or fallback_decay_tokens, schedule=self.schedule, ) ) return scheduler.value def anchor_dropout(unique_tokens: int, anchors: tuple[tuple[int, float], ...]) -> float: if not anchors: raise ValueError("anchor dropout schedule requires at least one anchor") ordered = sorted(anchors) if unique_tokens <= ordered[0][0]: return ordered[0][1] if unique_tokens >= ordered[-1][0]: return ordered[-1][1] log_unique = math.log(unique_tokens) for (left_tokens, left_dropout), (right_tokens, right_dropout) in zip( ordered, ordered[1:] ): if left_tokens <= unique_tokens <= right_tokens: left_log = math.log(left_tokens) right_log = math.log(right_tokens) mix = (log_unique - left_log) / (right_log - left_log) return left_dropout + mix * (right_dropout - left_dropout) return ordered[-1][1] def assert_mps_only() -> torch.device: if os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1": raise SystemExit( "Refusing to run: PYTORCH_ENABLE_MPS_FALLBACK=1 could execute CPU fallbacks." ) if not torch.backends.mps.is_built(): raise SystemExit("PyTorch was not built with MPS. Stopping.") if not torch.backends.mps.is_available(): raise SystemExit("MPS is not available. Stopping before any Torch experiment work.") device = torch.device("mps") torch.set_default_device(device) return device def clean_name(value: str) -> str: cleaned = "".join( c if c.isalnum() or c in {"-", "_"} else "_" for c in value.strip() ) return cleaned.strip("_") or "model" def rate_label(rate: float) -> str: label = f"{rate:.3f}".rstrip("0").rstrip(".") return label if label else "0" def parse_model_spec(raw: str) -> ModelSpec: if "=" in raw: name, dims = raw.split("=", 1) name = clean_name(name) else: dims = raw name = "" parts = dims.lower().replace(",", "x").split("x") if len(parts) != 3: raise argparse.ArgumentTypeError( "model specs must look like 8x8x256 or name=8x8x256" ) try: n_layer, n_head, n_embd = [int(part) for part in parts] except ValueError as exc: raise argparse.ArgumentTypeError("model dimensions must be integers") from exc if n_layer <= 0 or n_head <= 0 or n_embd <= 0: raise argparse.ArgumentTypeError("model dimensions must be positive") if n_embd % n_head != 0: raise argparse.ArgumentTypeError("n_embd must be divisible by n_head") if (n_embd // n_head) % 2 != 0: raise argparse.ArgumentTypeError("n_embd / n_head must be even for rotary attention") return ModelSpec( name or f"L{n_layer}_H{n_head}_D{n_embd}", n_layer, n_head, n_embd, ) def parse_decay_spec(raw: str) -> DropoutCondition: parts = raw.split(":") if len(parts) not in {3, 4, 5}: raise argparse.ArgumentTypeError( "decay specs must look like " "name:initial:final[:cosine|smoothstep|linear[:decay_tokens]]" ) name = clean_name(parts[0]) try: initial = float(parts[1]) final = float(parts[2]) decay_tokens = int(parts[4]) if len(parts) == 5 else None except ValueError as exc: raise argparse.ArgumentTypeError( "decay dropout values and decay_tokens must be numeric" ) from exc schedule = parts[3] if len(parts) >= 4 else "cosine" if schedule not in {"cosine", "smoothstep", "linear"}: raise argparse.ArgumentTypeError( "decay schedule must be cosine, smoothstep, or linear" ) return DropoutCondition( name=name, kind="decay", initial=initial, final=final, schedule=schedule, decay_tokens=decay_tokens, ) def parse_anchor_decay_spec(raw: str) -> DropoutCondition: if ":" not in raw: raise argparse.ArgumentTypeError( "anchor decay specs must look like name:250000=0.60,500000=0.40" ) name, raw_anchors = raw.split(":", 1) anchors: list[tuple[int, float]] = [] for piece in raw_anchors.split(","): if "=" not in piece: raise argparse.ArgumentTypeError( "anchor decay anchors must look like token_count=dropout" ) raw_tokens, raw_dropout = piece.split("=", 1) try: token_count = int(raw_tokens) dropout = float(raw_dropout) except ValueError as exc: raise argparse.ArgumentTypeError( "anchor token counts must be integers and dropout values numeric" ) from exc if token_count <= 0: raise argparse.ArgumentTypeError("anchor token counts must be positive") if not 0.0 <= dropout < 1.0: raise argparse.ArgumentTypeError("anchor dropout must satisfy 0 <= p < 1") anchors.append((token_count, dropout)) anchors = sorted(set(anchors)) if len(anchors) < 2: raise argparse.ArgumentTypeError("provide at least two dropout anchors") for left, right in zip(anchors, anchors[1:]): if right[1] > left[1]: raise argparse.ArgumentTypeError( "anchor dropout values must be non-increasing as token counts grow" ) return DropoutCondition( name=clean_name(name), kind="anchor_decay", initial=anchors[0][1], final=anchors[-1][1], schedule="log_prefix_anchor", anchors=tuple(anchors), ) def default_seeds(mode: str, seeds: list[int] | None) -> list[int]: if seeds: return seeds return [1] if mode == "screen_static" else [1, 2, 3] def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) def batch_seed(seed: int, model: ModelSpec, dropout_code: int, stage: int | None) -> int: return ( seed * 1_000_003 + model.n_layer * 100_003 + model.n_head * 10_007 + model.n_embd * 101 + dropout_code * 37 + (stage or 0) * 997 ) def make_batch( tokens: np.ndarray, token_limit: int, batch_size: int, block_size: int, rng: np.random.Generator, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: limit = min(token_limit, len(tokens)) max_start = limit - block_size - 1 if max_start <= 0: raise ValueError("token_limit is too small for the requested block_size") starts = rng.integers(0, max_start, size=batch_size) x_np = np.stack([tokens[start : start + block_size] for start in starts]).astype( np.int64 ) y_np = np.stack( [tokens[start + 1 : start + 1 + block_size] for start in starts] ).astype(np.int64) return ( torch.tensor(x_np, dtype=torch.long, device=device), torch.tensor(y_np, dtype=torch.long, device=device), ) @torch.no_grad() def estimate_loss( model: DropoutGPT, tokens: np.ndarray, token_limit: int, batches: int, args: argparse.Namespace, device: torch.device, rng_seed: int, ) -> float: if batches <= 0: return float("nan") model.eval() rng = np.random.default_rng(rng_seed) losses: list[float] = [] for _ in range(batches): x, y = make_batch(tokens, token_limit, args.batch_size, args.block_size, rng, device) _, loss = model(x, y) losses.append(float(loss.item())) model.train() return float(statistics.fmean(losses)) def write_jsonl_row(handle, row: dict) -> None: handle.write(json.dumps(row, sort_keys=True) + "\n") handle.flush() def metric_key(row: dict) -> tuple: return ( row["run_mode"], row["condition"], row["condition_kind"], row.get("stage"), int(row["token_limit"]), row["model_name"], int(row["n_layer"]), int(row["n_head"]), int(row["n_embd"]), int(row["seed"]), float(row["dropout_initial"]), float(row["dropout_final"]), row["dropout_schedule"], ) def planned_metric_key( *, mode: str, condition: DropoutCondition, model_spec: ModelSpec, seed: int, token_limit: int, stage: int | None = None, ) -> tuple: return ( mode, condition.name, condition.kind, stage, int(token_limit), model_spec.name, int(model_spec.n_layer), int(model_spec.n_head), int(model_spec.n_embd), int(seed), float(condition.initial), float(condition.final), condition.schedule, ) def load_metrics(path: Path) -> list[dict]: if not path.exists(): return [] rows = [] for line in path.read_text(encoding="utf-8").splitlines(): if line.strip(): rows.append(json.loads(line)) return rows def train_segment( *, run_mode: str, condition: DropoutCondition, model_spec: ModelSpec, config: GPTConfig, train_tokens: np.ndarray, val_tokens: np.ndarray, token_limit: int, steps: int, seed: int, args: argparse.Namespace, device: torch.device, dropout_fn: Callable[[int], float], metrics_file, trace_file, stage: int | None = None, model: DropoutGPT | None = None, optimizer: torch.optim.Optimizer | None = None, tokens_seen_start: int = 0, ) -> tuple[DropoutGPT, torch.optim.Optimizer, int, dict]: if model is None: set_seed(seed) model = DropoutGPT(config).to(device) optimizer = torch.optim.AdamW( model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay, ) else: torch.manual_seed(seed + 10_000 + (stage or 0)) if optimizer is None: raise ValueError("optimizer is required when reusing a model") model.train() dropout_code = int(round(condition.initial * 10_000)) rng = np.random.default_rng(batch_seed(seed, model_spec, dropout_code, stage)) tokens_seen = tokens_seen_start last_loss = float("nan") active_dropout = condition.initial t0 = time.time() for step in range(1, steps + 1): active_dropout = dropout_fn(tokens_seen) model.set_dropout(active_dropout) x, y = make_batch(train_tokens, token_limit, args.batch_size, args.block_size, rng, device) _, loss = model(x, y) optimizer.zero_grad(set_to_none=True) loss.backward() if args.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() tokens_seen += args.batch_size * args.block_size last_loss = float(loss.item()) if args.log_every > 0 and step % args.log_every == 0: write_jsonl_row( trace_file, { "event": "train_step", "run_mode": run_mode, "condition": condition.name, "model_name": model_spec.name, "seed": seed, "stage": stage, "step": step, "steps": steps, "token_limit": int(token_limit), "tokens_seen": int(tokens_seen), "dropout": float(active_dropout), "train_batch_loss": last_loss, }, ) if args.eval_every > 0 and step % args.eval_every == 0: train_eval = estimate_loss( model, train_tokens, token_limit, args.trace_eval_batches, args, device, rng_seed=seed + 20_000 + step, ) val_eval = estimate_loss( model, val_tokens, len(val_tokens), args.trace_eval_batches, args, device, rng_seed=seed + 30_000 + step, ) write_jsonl_row( trace_file, { "event": "eval_step", "run_mode": run_mode, "condition": condition.name, "model_name": model_spec.name, "seed": seed, "stage": stage, "step": step, "steps": steps, "token_limit": int(token_limit), "tokens_seen": int(tokens_seen), "dropout": float(active_dropout), "train_eval_loss": train_eval, "val_eval_loss": val_eval, "generalization_gap": val_eval - train_eval, }, ) train_eval = estimate_loss( model, train_tokens, token_limit, args.train_eval_batches, args, device, rng_seed=seed + 40_000 + (stage or 0), ) val_eval = estimate_loss( model, val_tokens, len(val_tokens), args.eval_batches, args, device, rng_seed=seed + 50_000 + (stage or 0), ) row = { "run_mode": run_mode, "condition": condition.name, "condition_kind": condition.kind, "seed": seed, "stage": stage, "token_limit": int(token_limit), "steps": int(steps), "tokens_seen": int(tokens_seen), "dropout_initial": float(condition.initial), "dropout_final": float(condition.final), "dropout_schedule": condition.schedule, "dropout_active_final": float(active_dropout), "train_loss_last": last_loss, "train_eval_loss": train_eval, "val_eval_loss": val_eval, "eval_loss": val_eval, "generalization_gap": val_eval - train_eval, "elapsed_sec": time.time() - t0, "parameters": model.num_parameters(), "model_config": config.to_dict(), **model_spec.to_dict(), } write_jsonl_row(metrics_file, row) return model, optimizer, tokens_seen, row def summarize(rows: list[dict]) -> list[dict]: groups: dict[tuple, list[dict]] = defaultdict(list) for row in rows: key = ( row["run_mode"], row["condition"], row["condition_kind"], row["stage"], row["token_limit"], row["model_name"], row["n_layer"], row["n_head"], row["n_embd"], row["parameters"], row["dropout_initial"], row["dropout_final"], row["dropout_schedule"], ) groups[key].append(row) summary: list[dict] = [] for group_rows in groups.values(): first = group_rows[0] item = {field: first[field] for field in SUMMARY_FIELDS if field in first} item["n"] = len(group_rows) for source, mean_key, std_key in [ ("train_eval_loss", "mean_train_eval_loss", "std_train_eval_loss"), ("val_eval_loss", "mean_val_eval_loss", "std_val_eval_loss"), ("generalization_gap", "mean_generalization_gap", "std_generalization_gap"), ]: values = [float(row[source]) for row in group_rows] item[mean_key] = statistics.fmean(values) item[std_key] = statistics.stdev(values) if len(values) > 1 else 0.0 summary.append(item) return sorted( summary, key=lambda row: ( row["run_mode"], row["model_name"], row["token_limit"], row["condition"], row["stage"] or -1, ), ) def write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None: with path.open("w", newline="", encoding="utf-8") as handle: writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(rows) def format_duration(seconds: float) -> str: seconds = max(0, int(seconds)) hours, remainder = divmod(seconds, 3600) minutes, secs = divmod(remainder, 60) if hours: return f"{hours}h{minutes:02d}m" if minutes: return f"{minutes}m{secs:02d}s" return f"{secs}s" class ProgressMeter: def __init__(self, total: int): self.total = max(0, total) self.done = 0 self.started_at = time.time() def mark_done(self, row: dict) -> None: self.done += 1 elapsed = time.time() - self.started_at mean = elapsed / self.done if self.done else 0.0 remaining = max(0, self.total - self.done) eta = remaining * mean stage = row.get("stage") stage_text = f"stage={stage} " if stage is not None else "" print( "progress " f"{self.done}/{self.total} " f"eta={format_duration(eta)} " f"mode={row['run_mode']} " f"model={row['model_name']} " f"params={int(row['parameters']):,} " f"prefix={int(row['token_limit']):,} " f"{stage_text}" f"seed={row['seed']} " f"condition={row['condition']} " f"dropout={float(row['dropout_active_final']):.3f} " f"val={row['val_eval_loss']:.4f} " f"train={row['train_eval_loss']:.4f} " f"gap={row['generalization_gap']:.4f} " f"elapsed={format_duration(float(row['elapsed_sec']))}", flush=True, ) def build_model_selection(summary: list[dict], args: argparse.Namespace) -> list[dict]: groups: dict[tuple, list[dict]] = defaultdict(list) for row in summary: if row["condition_kind"] == "static" and row["stage"] is None: groups[(row["run_mode"], row["token_limit"], row["model_name"])].append(row) selection: list[dict] = [] for (run_mode, token_limit, model_name), rows in groups.items(): curve = sorted(rows, key=lambda row: row["dropout_initial"]) best = min(curve, key=lambda row: row["mean_val_eval_loss"]) plateau = [ row for row in curve if row["mean_val_eval_loss"] <= best["mean_val_eval_loss"] + args.plateau_delta ] zero = next((row for row in curve if row["dropout_initial"] == 0.0), None) nonzero = [row for row in curve if row["dropout_initial"] > 0.0] best_nonzero = min(nonzero, key=lambda row: row["mean_val_eval_loss"]) if nonzero else None max_dropout = max(curve, key=lambda row: row["dropout_initial"]) zero_loss = zero["mean_val_eval_loss"] if zero else None zero_minus_best = zero_loss - best["mean_val_eval_loss"] if zero_loss is not None else None zero_minus_best_nonzero = ( zero_loss - best_nonzero["mean_val_eval_loss"] if zero_loss is not None and best_nonzero is not None else None ) max_minus_best = max_dropout["mean_val_eval_loss"] - best["mean_val_eval_loss"] has_nonzero_optimum = ( best["dropout_initial"] > 0.0 and zero_minus_best is not None and zero_minus_best >= args.min_nonzero_margin and max_minus_best >= args.min_high_dropout_margin ) curve_json = json.dumps( [ { "dropout": row["dropout_initial"], "mean_val_loss": row["mean_val_eval_loss"], "std_val_loss": row["std_val_eval_loss"], "mean_train_loss": row["mean_train_eval_loss"], "mean_generalization_gap": row["mean_generalization_gap"], "n": row["n"], } for row in curve ], sort_keys=True, ) selection.append( { "run_mode": run_mode, "token_limit": token_limit, "model_name": model_name, "n_layer": best["n_layer"], "n_head": best["n_head"], "n_embd": best["n_embd"], "parameters": best["parameters"], "n": best["n"], "best_dropout": best["dropout_initial"], "best_val_loss": best["mean_val_eval_loss"], "best_val_std": best["std_val_eval_loss"], "plateau_start_dropout": min(row["dropout_initial"] for row in plateau), "plateau_end_dropout": max(row["dropout_initial"] for row in plateau), "plateau_delta": args.plateau_delta, "zero_dropout_val_loss": zero_loss, "zero_minus_best": zero_minus_best, "best_nonzero_dropout": ( best_nonzero["dropout_initial"] if best_nonzero else None ), "best_nonzero_val_loss": ( best_nonzero["mean_val_eval_loss"] if best_nonzero else None ), "zero_minus_best_nonzero": zero_minus_best_nonzero, "max_dropout": max_dropout["dropout_initial"], "max_dropout_val_loss": max_dropout["mean_val_eval_loss"], "max_dropout_minus_best": max_minus_best, "has_nonzero_optimum": has_nonzero_optimum, "meets_target_dropout": best["dropout_initial"] >= args.target_min_dropout, "curve_json": curve_json, } ) return sorted( selection, key=lambda row: ( not row["has_nonzero_optimum"], not row["meets_target_dropout"], row["best_val_loss"], ), ) def write_screen_markdown_summary(output_dir: Path, rows: list[dict]) -> None: if not rows: return static_rows = [ row for row in rows if row["run_mode"] == "screen_static" and row["condition_kind"] == "static" ] if not static_rows: return by_model_prefix_rate: dict[tuple[str, int, float], list[dict]] = defaultdict(list) for row in rows: if row["run_mode"] == "screen_static" and row["condition_kind"] == "static": by_model_prefix_rate[ ( row["model_name"], int(row["token_limit"]), float(row["dropout_initial"]), ) ].append(row) aggregates: list[dict] = [] for (model_name, prefix, dropout), group_rows in by_model_prefix_rate.items(): first = group_rows[0] val_losses = [float(row["val_eval_loss"]) for row in group_rows] train_losses = [float(row["train_eval_loss"]) for row in group_rows] gaps = [float(row["generalization_gap"]) for row in group_rows] aggregates.append( { "model_name": model_name, "token_limit": prefix, "dropout_initial": dropout, "n": len(group_rows), "mean_val_eval_loss": statistics.fmean(val_losses), "std_val_eval_loss": statistics.stdev(val_losses) if len(val_losses) > 1 else 0.0, "mean_train_eval_loss": statistics.fmean(train_losses), "std_train_eval_loss": statistics.stdev(train_losses) if len(train_losses) > 1 else 0.0, "mean_generalization_gap": statistics.fmean(gaps), "std_generalization_gap": statistics.stdev(gaps) if len(gaps) > 1 else 0.0, "parameters": int(first["parameters"]), "n_layer": int(first["n_layer"]), "n_head": int(first["n_head"]), "n_embd": int(first["n_embd"]), "block_size": int(first["model_config"]["block_size"]), "vocab_size": int(first["model_config"]["vocab_size"]), "tokens_seen": int(first["tokens_seen"]), "seeds": sorted({int(row["seed"]) for row in group_rows}), } ) by_model: dict[str, list[dict]] = defaultdict(list) for row in aggregates: by_model[row["model_name"]].append(row) model_rows = [] for model_name, model_group in by_model.items(): first = model_group[0] seeds = sorted({seed for row in model_group for seed in row["seeds"]}) model_rows.append( { "model_name": model_name, "parameters": first["parameters"], "n_layer": first["n_layer"], "n_head": first["n_head"], "n_embd": first["n_embd"], "block_size": first["block_size"], "vocab_size": first["vocab_size"], "seeds": seeds, } ) lines = [ "# Static Dropout Screen Summary", "", f"Run directory: `{output_dir}`", "", "## Models", "", "| Model | Params | Layers | Heads | Embedding | Block | Vocab | Seeds |", "|---|---:|---:|---:|---:|---:|---:|---|", ] for model in sorted(model_rows, key=lambda item: item["parameters"]): lines.append( "| " f"`{model['model_name']}` | {model['parameters']:,} | " f"{model['n_layer']} | {model['n_head']} | {model['n_embd']} | " f"{model['block_size']} | {model['vocab_size']} | " f"{', '.join(str(seed) for seed in model['seeds'])} |" ) lines.extend( [ "", "## Best Dropout By Model And Prefix", "", "| Model | Prefix tokens | Effective epochs | Best dropout | Mean val loss | Val std | Mean train loss | Mean gap | Plateau/bracket note |", "|---|---:|---:|---:|---:|---:|---:|---:|---|", ] ) for model_name, model_group in sorted(by_model.items()): by_prefix: dict[int, list[dict]] = defaultdict(list) for row in model_group: by_prefix[int(row["token_limit"])].append(row) for prefix, prefix_rows in sorted(by_prefix.items()): best = min(prefix_rows, key=lambda row: row["mean_val_eval_loss"]) rates = [float(row["dropout_initial"]) for row in prefix_rows] eff_epochs = float(best["tokens_seen"]) / prefix if best["dropout_initial"] == max(rates): note = "not bracketed; best at top of tested grid" elif best["dropout_initial"] == min(rates): note = "not bracketed; best at bottom of tested grid" else: note = "bracketed by tested grid" lines.append( "| " f"`{model_name}` | {prefix:,} | {eff_epochs:.2f} | " f"{best['dropout_initial']:.2f} | " f"{best['mean_val_eval_loss']:.4f} | " f"{best['std_val_eval_loss']:.4f} | " f"{best['mean_train_eval_loss']:.4f} | " f"{best['mean_generalization_gap']:.4f} | {note} |" ) for model_name, model_group in sorted(by_model.items()): by_prefix = defaultdict(list) for row in model_group: by_prefix[int(row["token_limit"])].append(row) lines.extend( [ "", f"## Model `{model_name}`", ] ) for prefix, prefix_rows in sorted(by_prefix.items()): eff_epochs = float(prefix_rows[0]["tokens_seen"]) / prefix lines.extend( [ "", f"### Prefix {prefix:,} Tokens ({eff_epochs:.2f} Effective Epochs)", "", "| Dropout | N | Mean val loss | Val std | Mean train loss | Train std | Mean gap | Gap std | Sampled tokens | Params |", "|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|", ] ) for row in sorted(prefix_rows, key=lambda item: item["dropout_initial"]): lines.append( "| " f"{row['dropout_initial']:.2f} | {row['n']} | " f"{row['mean_val_eval_loss']:.4f} | " f"{row['std_val_eval_loss']:.4f} | " f"{row['mean_train_eval_loss']:.4f} | " f"{row['std_train_eval_loss']:.4f} | " f"{row['mean_generalization_gap']:.4f} | " f"{row['std_generalization_gap']:.4f} | " f"{int(row['tokens_seen']):,} | {int(row['parameters']):,} |" ) output = "\n".join(lines) + "\n" (output_dir / "RESULT_SUMMARY.md").write_text(output, encoding="utf-8") def svg_escape(value: object) -> str: return ( str(value) .replace("&", "&") .replace("<", "<") .replace(">", ">") .replace('"', """) ) def write_dropout_curve_svg(output_dir: Path, summary: list[dict]) -> None: rows = [ row for row in summary if row["run_mode"] == "screen_static" and row["condition_kind"] == "static" ] if not rows: return grouped: dict[tuple[str, int], list[dict]] = defaultdict(list) model_params: dict[str, int] = {} for row in rows: model_name = row["model_name"] grouped[(model_name, int(row["token_limit"]))].append(row) model_params[model_name] = int(row["parameters"]) models = sorted(model_params, key=lambda name: model_params[name]) prefixes = sorted({int(row["token_limit"]) for row in rows}) panel_w, panel_h = 230, 170 margin_l, margin_t = 58, 34 plot_w, plot_h = 142, 94 gap_x, gap_y = 18, 38 width = margin_l + len(prefixes) * panel_w + gap_x height = 70 + len(models) * (panel_h + gap_y) colors = ["#1f77b4", "#d62728", "#2ca02c", "#9467bd", "#ff7f0e"] parts = [ f'', "", '', 'Static dropout law: validation loss vs dropout', 'Each panel uses its own y-scale. Points are one-seed means unless N > 1.', ] for col, prefix in enumerate(prefixes): x = margin_l + col * panel_w + plot_w / 2 parts.append( f'{prefix:,} prefix tokens' ) for row_idx, model_name in enumerate(models): row_y = 92 + row_idx * (panel_h + gap_y) parts.append( f'' f'{svg_escape(model_name)} ({model_params[model_name] / 1_000_000:.1f}M)' ) for col, prefix in enumerate(prefixes): panel_x = margin_l + col * panel_w panel_y = row_y curve = sorted( grouped.get((model_name, prefix), []), key=lambda item: float(item["dropout_initial"]), ) if not curve: continue losses = [float(item["mean_val_eval_loss"]) for item in curve] min_loss, max_loss = min(losses), max(losses) pad = max(0.02, (max_loss - min_loss) * 0.08) y_min, y_max = min_loss - pad, max_loss + pad best = min(curve, key=lambda item: float(item["mean_val_eval_loss"])) def px(dropout: float) -> float: return panel_x + (dropout / 0.9) * plot_w def py(loss: float) -> float: scale = (loss - y_min) / (y_max - y_min) return panel_y + plot_h - scale * plot_h parts.extend( [ f'', f'', f'', f'{y_max:.2f}', f'{y_min:.2f}', f'0', f'0.9', ] ) points = " ".join( f"{px(float(item['dropout_initial'])):.1f},{py(float(item['mean_val_eval_loss'])):.1f}" for item in curve ) color = colors[row_idx % len(colors)] parts.append(f'') for item in curve: dropout = float(item["dropout_initial"]) loss = float(item["mean_val_eval_loss"]) radius = 4 if item is best else 2.7 fill = "#111827" if item is best else "#ffffff" parts.append( f'' ) parts.append( f'' f'best p={float(best["dropout_initial"]):.2f}' ) parts.append( f'' f'loss={float(best["mean_val_eval_loss"]):.3f}' ) parts.append("") (output_dir / "dropout_curves.svg").write_text("\n".join(parts), encoding="utf-8") def write_stream_markdown_summary(output_dir: Path, rows: list[dict]) -> None: stream_rows = [row for row in rows if row["run_mode"] == "locked_stream"] if not stream_rows: return by_condition_stage: dict[tuple[str, int], list[dict]] = defaultdict(list) by_condition: dict[str, list[dict]] = defaultdict(list) for row in stream_rows: condition = row["condition"] by_condition_stage[(condition, int(row["stage"]))].append(row) by_condition[condition].append(row) first = stream_rows[0] seeds = sorted({int(row["seed"]) for row in stream_rows}) conditions = sorted( by_condition, key=lambda name: ( by_condition[name][0]["condition_kind"] != "anchor_decay", by_condition[name][0]["dropout_initial"], name, ), ) stages = sorted({int(row["stage"]) for row in stream_rows}) lines = [ "# Locked Streaming Dropout Summary", "", f"Run directory: `{output_dir}`", "", ( f"Model: `{first['model_name']}` causal Transformer, " f"{int(first['parameters']):,} parameters, {first['n_layer']} layers, " f"{first['n_head']} heads, {first['n_embd']} embedding dim." ), ( f"Training per stage: {first['steps']:,} steps. " "Sampled tokens are cumulative in each stage row. " f"Seeds present: {', '.join(str(seed) for seed in seeds)}." ), "", "## Condition Ranking", "", "| Condition | Kind | Final dropout | Mean trajectory val loss | Final val loss | Final gap | Dropout path |", "|---|---|---:|---:|---:|---:|---|", ] ranking = [] for condition in conditions: stage_items = [] for stage in stages: group = by_condition_stage.get((condition, stage), []) if not group: continue stage_items.append( { "stage": stage, "token_limit": int(group[0]["token_limit"]), "mean_val": statistics.fmean( float(row["val_eval_loss"]) for row in group ), "mean_gap": statistics.fmean( float(row["generalization_gap"]) for row in group ), "mean_dropout": statistics.fmean( float(row["dropout_active_final"]) for row in group ), "kind": group[0]["condition_kind"], } ) if not stage_items: continue final = max(stage_items, key=lambda item: item["stage"]) ranking.append( { "condition": condition, "kind": stage_items[0]["kind"], "trajectory_val": statistics.fmean(item["mean_val"] for item in stage_items), "final_val": final["mean_val"], "final_gap": final["mean_gap"], "final_dropout": final["mean_dropout"], "dropout_path": " -> ".join( f"{item['mean_dropout']:.2f}" for item in stage_items ), } ) for item in sorted(ranking, key=lambda row: row["trajectory_val"]): lines.append( "| " f"`{item['condition']}` | {item['kind']} | " f"{item['final_dropout']:.2f} | {item['trajectory_val']:.4f} | " f"{item['final_val']:.4f} | {item['final_gap']:.4f} | " f"{item['dropout_path']} |" ) lines.extend(["", "## Stage Trajectory", ""]) for stage in stages: stage_groups = { condition: by_condition_stage[(condition, stage)] for condition in conditions if (condition, stage) in by_condition_stage } if not stage_groups: continue prefix = int(next(iter(stage_groups.values()))[0]["token_limit"]) lines.extend( [ f"### Stage {stage}: {prefix:,} Prefix Tokens", "", "| Condition | Dropout | Mean val loss | Mean train loss | Mean gap | N |", "|---|---:|---:|---:|---:|---:|", ] ) for condition, group in sorted( stage_groups.items(), key=lambda item: statistics.fmean( float(row["val_eval_loss"]) for row in item[1] ), ): val = statistics.fmean(float(row["val_eval_loss"]) for row in group) train = statistics.fmean(float(row["train_eval_loss"]) for row in group) gap = statistics.fmean(float(row["generalization_gap"]) for row in group) dropout = statistics.fmean( float(row["dropout_active_final"]) for row in group ) lines.append( "| " f"`{condition}` | {dropout:.2f} | {val:.4f} | " f"{train:.4f} | {gap:.4f} | {len(group)} |" ) lines.append("") (output_dir / "RESULT_SUMMARY.md").write_text( "\n".join(lines).rstrip() + "\n", encoding="utf-8", ) def static_conditions(dropout_rates: list[float]) -> list[DropoutCondition]: return [ DropoutCondition( name=f"static_dropout_{rate_label(rate)}", kind="static", initial=rate, final=rate, ) for rate in dropout_rates ] def run_fixed_static_sweep( *, args: argparse.Namespace, model_specs: list[ModelSpec], seeds: list[int], train_tokens: np.ndarray, val_tokens: np.ndarray, tokenizer_vocab_size: int, token_limits: list[int], device: torch.device, metrics_file, trace_file, completed_keys: set[tuple] | None = None, ) -> list[dict]: rows: list[dict] = [] completed_keys = completed_keys or set() conditions = static_conditions(sorted(set(args.dropout_rates))) planned = 0 for token_limit in token_limits: for model_spec in model_specs: for condition in conditions: for seed in seeds: key = planned_metric_key( mode=args.mode, condition=condition, model_spec=model_spec, seed=seed, token_limit=token_limit, ) if key not in completed_keys: planned += 1 progress = ProgressMeter(planned) for token_limit in token_limits: for model_spec in model_specs: best_val_loss = float("inf") worse_streak = 0 for condition in conditions: condition_rows: list[dict] = [] for seed in seeds: key = planned_metric_key( mode=args.mode, condition=condition, model_spec=model_spec, seed=seed, token_limit=token_limit, ) if key in completed_keys: write_jsonl_row( trace_file, { "event": "skipped_completed_condition", "run_mode": args.mode, "condition": condition.name, "model_name": model_spec.name, "seed": seed, "stage": None, "token_limit": int(token_limit), }, ) continue config = model_spec.config( tokenizer_vocab_size, args.block_size, condition.initial, ) model, optimizer, _, row = train_segment( run_mode=args.mode, condition=condition, model_spec=model_spec, config=config, train_tokens=train_tokens, val_tokens=val_tokens, token_limit=token_limit, steps=args.steps, seed=seed, args=args, device=device, dropout_fn=condition.make_fn( args.steps * args.batch_size * args.block_size ), metrics_file=metrics_file, trace_file=trace_file, ) rows.append(row) condition_rows.append(row) completed_keys.add(metric_key(row)) progress.mark_done(row) del model, optimizer torch.mps.empty_cache() if not condition_rows: continue mean_val_loss = statistics.fmean( float(row["val_eval_loss"]) for row in condition_rows ) if mean_val_loss < best_val_loss - args.screen_prune_min_delta: best_val_loss = mean_val_loss worse_streak = 0 elif mean_val_loss > best_val_loss + args.screen_prune_min_delta: worse_streak += 1 if ( args.mode == "screen_static" and args.screen_early_stop and worse_streak >= args.screen_prune_patience and condition.initial >= args.target_min_dropout ): write_jsonl_row( trace_file, { "event": "screen_pruned_model", "run_mode": args.mode, "model_name": model_spec.name, "token_limit": int(token_limit), "best_val_loss": best_val_loss, "pruned_after_dropout": condition.initial, "worse_streak": worse_streak, "remaining_dropouts": [ rate for rate in args.dropout_rates if rate > condition.initial ], }, ) break return rows def run_locked_stream( *, args: argparse.Namespace, model_specs: list[ModelSpec], seeds: list[int], train_tokens: np.ndarray, val_tokens: np.ndarray, tokenizer_vocab_size: int, stream_caps: list[int], device: torch.device, metrics_file, trace_file, completed_keys: set[tuple] | None = None, ) -> list[dict]: rows: list[dict] = [] completed_keys = completed_keys or set() conditions = args.anchor_decays + static_conditions(args.dropout_rates) + args.decays fallback_decay_tokens = ( args.decay_tokens or args.stage_steps * args.batch_size * args.block_size * len(stream_caps) ) planned = 0 for model_spec in model_specs: for condition in conditions: for seed in seeds: for stage, token_limit in enumerate(stream_caps): key = planned_metric_key( mode=args.mode, condition=condition, model_spec=model_spec, seed=seed, token_limit=token_limit, stage=stage, ) if key not in completed_keys: planned += 1 progress = ProgressMeter(planned) for model_spec in model_specs: for condition in conditions: for seed in seeds: model = None optimizer = None tokens_seen = 0 for stage, token_limit in enumerate(stream_caps): key = planned_metric_key( mode=args.mode, condition=condition, model_spec=model_spec, seed=seed, token_limit=token_limit, stage=stage, ) if key in completed_keys: write_jsonl_row( trace_file, { "event": "skipped_completed_condition", "run_mode": args.mode, "condition": condition.name, "model_name": model_spec.name, "seed": seed, "stage": stage, "token_limit": int(token_limit), }, ) continue config = model_spec.config( tokenizer_vocab_size, args.block_size, condition.initial, ) model, optimizer, tokens_seen, row = train_segment( run_mode=args.mode, condition=condition, model_spec=model_spec, config=config, train_tokens=train_tokens, val_tokens=val_tokens, token_limit=token_limit, steps=args.stage_steps, seed=seed, args=args, device=device, dropout_fn=condition.make_fn( fallback_decay_tokens, unique_tokens=token_limit, ), metrics_file=metrics_file, trace_file=trace_file, stage=stage, model=model, optimizer=optimizer, tokens_seen_start=tokens_seen, ) rows.append(row) completed_keys.add(metric_key(row)) progress.mark_done(row) del model, optimizer torch.mps.empty_cache() return rows def prepare_data(args: argparse.Namespace, output_dir: Path, required_train_tokens: int): cache_dir = Path(args.cache_dir) if args.cache_dir else output_dir / "cache" cache_dir.mkdir(parents=True, exist_ok=True) if args.use_cached_data: if args.force_retokenize: raise ValueError("--use-cached-data cannot be combined with --force-retokenize") return load_cached_splits( cache_dir=cache_dir, vocab_size=args.vocab_size, max_required_train_tokens=required_train_tokens, val_tokens=args.val_tokens, allow_short_corpus=args.allow_short_corpus, ) paths = resolve_paths(args.corpus, args.corpus_glob) tokenizer = train_or_load_tokenizer( paths=paths, output_dir=cache_dir, vocab_size=args.vocab_size, tokenizer_train_chars=args.tokenizer_train_chars, text_column=args.text_column, force_retrain=args.force_retokenize, ) splits = encode_corpus( paths=paths, tokenizer=tokenizer, output_dir=cache_dir, max_required_train_tokens=required_train_tokens, val_tokens=args.val_tokens, text_column=args.text_column, allow_short_corpus=args.allow_short_corpus, force_reencode=args.force_retokenize, ) return tokenizer, splits def run(args: argparse.Namespace) -> Path: device = assert_mps_only() seeds = default_seeds(args.mode, args.seeds) model_specs = [parse_model_spec(spec) for spec in args.models] if args.mode != "locked_stream" and (args.decays or args.anchor_decays): raise ValueError("--decays and --anchor-decays are only used with --mode locked_stream") if args.resume_from and args.mode == "locked_stream": raise ValueError("--resume-from currently supports fixed static sweeps only") if args.resume_from: output_dir = Path(args.resume_from) if not output_dir.exists(): raise FileNotFoundError(f"resume directory does not exist: {output_dir}") else: run_id = datetime.now().strftime("%Y%m%d-%H%M%S") output_dir = Path(args.output_dir) / args.mode / run_id output_dir.mkdir(parents=True, exist_ok=True) required_train_tokens = max( args.stream_token_caps if args.mode == "locked_stream" else args.token_limits ) tokenizer, splits = prepare_data(args, output_dir, required_train_tokens) token_limits = [min(limit, len(splits.train)) for limit in args.token_limits] stream_caps = [min(limit, len(splits.train)) for limit in args.stream_token_caps] args_payload = vars(args).copy() args_payload["decays"] = [condition.to_dict() for condition in args.decays] args_payload["anchor_decays"] = [ condition.to_dict() for condition in args.anchor_decays ] config_payload = { "args": args_payload, "mode": args.mode, "seeds": seeds, "models": [model.to_dict() for model in model_specs], "device": str(device), "torch": torch.__version__, "python": sys.version, "mps_available": torch.backends.mps.is_available(), "attribution": NANOCHAT_ATTRIBUTION, "tokenizer_path": str(splits.tokenizer_path), "encoded_path": str(splits.encoded_path), "train_tokens": int(len(splits.train)), "val_tokens": int(len(splits.val)), "effective_token_limits": [int(limit) for limit in token_limits], "effective_stream_token_caps": [int(limit) for limit in stream_caps], "resume_from": str(args.resume_from) if args.resume_from else None, } config_name = "config.resume.json" if args.resume_from else "config.json" (output_dir / config_name).write_text( json.dumps(config_payload, indent=2), encoding="utf-8", ) metrics_path = output_dir / "metrics.jsonl" trace_path = output_dir / "trace.jsonl" existing_rows = load_metrics(metrics_path) if args.resume_from else [] completed_keys = {metric_key(row) for row in existing_rows} with ( metrics_path.open("a" if args.resume_from else "w", encoding="utf-8") as metrics_file, trace_path.open("a" if args.resume_from else "w", encoding="utf-8") as trace_file, ): if args.mode in {"screen_static", "confirm_static"}: new_rows = run_fixed_static_sweep( args=args, model_specs=model_specs, seeds=seeds, train_tokens=splits.train, val_tokens=splits.val, tokenizer_vocab_size=tokenizer.vocab_size, token_limits=token_limits, device=device, metrics_file=metrics_file, trace_file=trace_file, completed_keys=completed_keys, ) else: new_rows = run_locked_stream( args=args, model_specs=model_specs, seeds=seeds, train_tokens=splits.train, val_tokens=splits.val, tokenizer_vocab_size=tokenizer.vocab_size, stream_caps=stream_caps, device=device, metrics_file=metrics_file, trace_file=trace_file, ) rows = existing_rows + new_rows summary = summarize(rows) (output_dir / "summary.json").write_text( json.dumps(summary, indent=2), encoding="utf-8", ) write_csv(output_dir / "summary.csv", summary, SUMMARY_FIELDS) if args.mode in {"screen_static", "confirm_static"}: selection = build_model_selection(summary, args) (output_dir / "model_selection.json").write_text( json.dumps(selection, indent=2), encoding="utf-8", ) write_csv(output_dir / "model_selection.csv", selection, SELECTION_FIELDS) write_screen_markdown_summary(output_dir, rows) write_dropout_curve_svg(output_dir, summary) elif args.mode == "locked_stream": write_stream_markdown_summary(output_dir, rows) print( json.dumps( { "output_dir": str(output_dir), "new_rows": len(new_rows), "total_metric_rows": len(rows), "summary_rows": len(summary), }, indent=2, ) ) return output_dir def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="MPS-only dropout/model selection experiments" ) parser.add_argument( "--mode", choices=["screen_static", "confirm_static", "locked_stream"], default="screen_static", ) parser.add_argument("--corpus", default=None, help="Text or parquet corpus path") parser.add_argument("--corpus-glob", default=None, help="Glob of text/parquet corpus paths") parser.add_argument("--text-column", default="text", help="Parquet text column") parser.add_argument( "--use-cached-data", action="store_true", help=( "Load tokenizer-v{vocab}.json and tokens-v{vocab}-*.npy from --cache-dir " "instead of requiring the original text/parquet corpus." ), ) parser.add_argument("--output-dir", default="runs") parser.add_argument( "--resume-from", default=None, help="Existing fixed-static run directory; completed metric rows are skipped", ) parser.add_argument("--cache-dir", default=".cache/dropout_decay") parser.add_argument( "--models", nargs="+", default=["8x8x256"], help="Model specs like 8x8x256 or name=8x8x256", ) parser.add_argument("--seeds", nargs="+", type=int, default=None) parser.add_argument("--token-limits", nargs="+", type=int, default=[5_000_000]) parser.add_argument( "--stream-token-caps", nargs="+", type=int, default=[5_000_000, 10_000_000, 20_000_000, 40_000_000], ) parser.add_argument("--val-tokens", type=int, default=500_000) parser.add_argument("--allow-short-corpus", action="store_true") parser.add_argument("--force-retokenize", action="store_true") parser.add_argument("--vocab-size", type=int, default=4096) parser.add_argument("--tokenizer-train-chars", type=int, default=10_000_000) parser.add_argument("--block-size", type=int, default=128) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--steps", type=int, default=2000) parser.add_argument("--stage-steps", type=int, default=1000) parser.add_argument("--dropout-rates", nargs="*", type=float, default=DEFAULT_DROPOUT_RATES) parser.add_argument("--decays", nargs="*", type=parse_decay_spec, default=[]) parser.add_argument( "--anchor-decays", nargs="*", type=parse_anchor_decay_spec, default=[], help=( "Prefix-token anchor schedules like " "fit:250000=0.60,500000=0.40,1000000=0.30" ), ) parser.add_argument("--decay-tokens", type=int, default=None) parser.add_argument("--eval-batches", type=int, default=64) parser.add_argument("--train-eval-batches", type=int, default=32) parser.add_argument("--trace-eval-batches", type=int, default=8) parser.add_argument("--eval-every", type=int, default=0) parser.add_argument("--log-every", type=int, default=100) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--weight-decay", type=float, default=0.1) parser.add_argument("--grad-clip", type=float, default=1.0) parser.add_argument("--plateau-delta", type=float, default=0.01) parser.add_argument("--target-min-dropout", type=float, default=0.10) parser.add_argument("--min-nonzero-margin", type=float, default=0.01) parser.add_argument("--min-high-dropout-margin", type=float, default=0.03) parser.add_argument("--screen-early-stop", action="store_true") parser.add_argument("--screen-prune-patience", type=int, default=3) parser.add_argument("--screen-prune-min-delta", type=float, default=0.01) return parser def main() -> None: args = build_parser().parse_args() run(args) if __name__ == "__main__": main()