| """ |
| Derived from Andrej Karpathy's nanochat project. |
| |
| MIT License |
| |
| Copyright (c) 2025 Andrej Karpathy |
| |
| Permission is hereby granted, free of charge, to any person obtaining a copy |
| of this software and associated documentation files (the "Software"), to deal |
| in the Software without restriction, including without limitation the rights |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| copies of the Software, and to permit persons to whom the Software is |
| furnished to do so, subject to the following conditions: |
| |
| The above copyright notice and this permission notice shall be included in all |
| copies or substantial portions of the Software. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| from collections import defaultdict |
| from dataclasses import dataclass |
| from datetime import datetime |
| import json |
| import math |
| import os |
| from pathlib import Path |
| import random |
| import statistics |
| import sys |
| import time |
| from typing import Callable |
|
|
| import numpy as np |
| import torch |
|
|
| from dropout_decay.data import ( |
| encode_corpus, |
| load_cached_splits, |
| resolve_paths, |
| train_or_load_tokenizer, |
| ) |
| from dropout_decay.license import NANOCHAT_ATTRIBUTION |
| from dropout_decay.model import DropoutGPT, GPTConfig |
| from dropout_decay.scheduler import DropoutDecayConfig, DropoutDecayScheduler |
|
|
|
|
# Default grid of static dropout probabilities.
DEFAULT_DROPOUT_RATES = [0.0, 0.02, 0.05, 0.08, 0.10, 0.14, 0.20, 0.30, 0.50]

# Column order for summary rows. summarize() copies whichever of these exist
# on the first row of a group and then fills in "n" plus the mean/std pairs.
SUMMARY_FIELDS = [
    # run / condition identity
    "run_mode", "condition", "condition_kind", "stage", "token_limit",
    # model identity
    "model_name", "n_layer", "n_head", "n_embd", "parameters",
    # dropout configuration
    "dropout_initial", "dropout_final", "dropout_schedule",
    # sample count and aggregates
    "n",
    "mean_train_eval_loss", "std_train_eval_loss",
    "mean_val_eval_loss", "std_val_eval_loss",
    "mean_generalization_gap", "std_generalization_gap",
]

# Column order for model-selection rows; the names match the keys of the
# dicts assembled by build_model_selection().
SELECTION_FIELDS = [
    # run / model identity
    "run_mode", "token_limit", "model_name",
    "n_layer", "n_head", "n_embd", "parameters", "n",
    # best point on the dropout curve
    "best_dropout", "best_val_loss", "best_val_std",
    # plateau around the best point
    "plateau_start_dropout", "plateau_end_dropout", "plateau_delta",
    # comparison against zero dropout
    "zero_dropout_val_loss", "zero_minus_best",
    "best_nonzero_dropout", "best_nonzero_val_loss", "zero_minus_best_nonzero",
    # comparison against the largest tested dropout
    "max_dropout", "max_dropout_val_loss", "max_dropout_minus_best",
    # verdict flags and the serialized curve
    "has_nonzero_optimum", "meets_target_dropout", "curve_json",
]
|
|
|
|
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model size: a label plus layer/head/width counts."""

    name: str
    n_layer: int
    n_head: int
    n_embd: int

    def config(self, vocab_size: int, block_size: int, dropout: float) -> GPTConfig:
        """Return a ``GPTConfig`` combining these dimensions with the run settings."""
        architecture = {
            "n_layer": self.n_layer,
            "n_head": self.n_head,
            "n_embd": self.n_embd,
        }
        return GPTConfig(
            block_size=block_size,
            vocab_size=vocab_size,
            dropout=dropout,
            **architecture,
        )

    def to_dict(self) -> dict[str, int | str]:
        """Return the spec as a flat dict; the label is stored under ``model_name``."""
        return dict(
            model_name=self.name,
            n_layer=self.n_layer,
            n_head=self.n_head,
            n_embd=self.n_embd,
        )
|
|
|
|
@dataclass(frozen=True)
class DropoutCondition:
    """One dropout treatment: a static rate, a token-count decay, or an anchor table."""

    name: str
    kind: str
    initial: float
    final: float
    schedule: str = "constant"
    decay_tokens: int | None = None
    anchors: tuple[tuple[int, float], ...] = ()

    def to_dict(self) -> dict:
        """Return a plain-dict copy of the condition; anchors become ``[tokens, p]`` lists."""
        anchor_lists = [list(pair) for pair in self.anchors]
        return {
            "name": self.name,
            "kind": self.kind,
            "initial": self.initial,
            "final": self.final,
            "schedule": self.schedule,
            "decay_tokens": self.decay_tokens,
            "anchors": anchor_lists,
        }

    def make_fn(
        self,
        fallback_decay_tokens: int,
        unique_tokens: int | None = None,
    ) -> Callable[[int], float]:
        """Return a callable mapping tokens-seen to the dropout rate to apply.

        ``static`` conditions always yield ``initial``. ``anchor_decay``
        conditions also yield a constant, interpolated once from ``anchors``
        at ``unique_tokens``. Any other kind delegates to a
        ``DropoutDecayScheduler``.

        Raises:
            ValueError: for ``anchor_decay`` when ``unique_tokens`` is ``None``.
        """

        def constant(rate: float) -> Callable[[int], float]:
            # The rate is bound as a default so the callable carries its own value.
            def fixed(_tokens_seen, p=rate):
                return p

            return fixed

        if self.kind == "static":
            return constant(self.initial)

        if self.kind == "anchor_decay":
            if unique_tokens is None:
                raise ValueError("anchor_decay conditions require unique_tokens")
            return constant(anchor_dropout(unique_tokens, self.anchors))

        # A falsy decay_tokens (None or 0) defers to the caller's fallback.
        decay_config = DropoutDecayConfig(
            initial_dropout=self.initial,
            final_dropout=self.final,
            decay_tokens=self.decay_tokens or fallback_decay_tokens,
            schedule=self.schedule,
        )
        return DropoutDecayScheduler(decay_config).value
|
|
|
|
def anchor_dropout(unique_tokens: int, anchors: tuple[tuple[int, float], ...]) -> float:
    """Interpolate a dropout rate for ``unique_tokens`` from ``(tokens, dropout)`` anchors.

    Anchors are sorted first. Token counts at or outside the outermost anchors
    clamp to the nearest anchor's dropout; counts in between are interpolated
    linearly in log(token count) between the two surrounding anchors.

    Raises:
        ValueError: if ``anchors`` is empty.
    """
    if not anchors:
        raise ValueError("anchor dropout schedule requires at least one anchor")

    ordered = sorted(anchors)
    lowest, highest = ordered[0], ordered[-1]
    if unique_tokens <= lowest[0]:
        return lowest[1]
    if unique_tokens >= highest[0]:
        return highest[1]

    log_unique = math.log(unique_tokens)
    for lower, upper in zip(ordered, ordered[1:]):
        left_tokens, left_dropout = lower
        right_tokens, right_dropout = upper
        if not (left_tokens <= unique_tokens <= right_tokens):
            continue
        left_log = math.log(left_tokens)
        right_log = math.log(right_tokens)
        mix = (log_unique - left_log) / (right_log - left_log)
        return left_dropout + mix * (right_dropout - left_dropout)

    # No bracketing pair matched; fall back to the last anchor's dropout.
    return highest[1]
|
|
|
|
def assert_mps_only() -> torch.device:
    """Exit unless PyTorch can run on MPS with the CPU fallback disabled.

    On success the MPS device is installed as torch's default device and
    returned.

    Raises:
        SystemExit: if ``PYTORCH_ENABLE_MPS_FALLBACK`` is ``"1"``, or if MPS
            is not built into PyTorch, or is not available.
    """
    fallback_flag = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")
    if fallback_flag == "1":
        raise SystemExit(
            "Refusing to run: PYTORCH_ENABLE_MPS_FALLBACK=1 could execute CPU fallbacks."
        )

    mps_backend = torch.backends.mps
    if not mps_backend.is_built():
        raise SystemExit("PyTorch was not built with MPS. Stopping.")
    if not mps_backend.is_available():
        raise SystemExit("MPS is not available. Stopping before any Torch experiment work.")

    mps_device = torch.device("mps")
    torch.set_default_device(mps_device)
    return mps_device
|
|
|
|
def clean_name(value: str) -> str:
    """Reduce ``value`` to alphanumerics, ``-`` and ``_``.

    Surrounding whitespace is stripped, every other character becomes ``_``,
    and leading/trailing underscores are removed. If nothing is left the
    result is ``"model"``.
    """
    allowed_punctuation = {"-", "_"}
    safe_chars = [
        ch if (ch.isalnum() or ch in allowed_punctuation) else "_"
        for ch in value.strip()
    ]
    trimmed = "".join(safe_chars).strip("_")
    return trimmed or "model"
|
|
|
|
def rate_label(rate: float) -> str:
    """Format ``rate`` with up to three decimals, dropping trailing zeros and the dot."""
    fixed = f"{rate:.3f}"
    trimmed = fixed.rstrip("0").rstrip(".")
    return trimmed or "0"
|
|
|
|
def parse_model_spec(raw: str) -> ModelSpec:
    """Parse ``LxHxD`` or ``name=LxHxD`` into a ``ModelSpec``.

    Dimensions may be separated by ``x`` (any case) or ``,``. Without an
    explicit name the spec is labelled ``L{n_layer}_H{n_head}_D{n_embd}``.

    Raises:
        argparse.ArgumentTypeError: if there are not exactly three
            dimensions, any is not a positive integer, ``n_embd`` is not a
            multiple of ``n_head``, or the per-head width is odd.
    """
    label, separator, remainder = raw.partition("=")
    if separator:
        name, dims = clean_name(label), remainder
    else:
        name, dims = "", raw

    fields = dims.lower().replace(",", "x").split("x")
    if len(fields) != 3:
        raise argparse.ArgumentTypeError(
            "model specs must look like 8x8x256 or name=8x8x256"
        )
    try:
        n_layer, n_head, n_embd = map(int, fields)
    except ValueError as exc:
        raise argparse.ArgumentTypeError("model dimensions must be integers") from exc

    if min(n_layer, n_head, n_embd) <= 0:
        raise argparse.ArgumentTypeError("model dimensions must be positive")
    head_width, leftover = divmod(n_embd, n_head)
    if leftover != 0:
        raise argparse.ArgumentTypeError("n_embd must be divisible by n_head")
    if head_width % 2 != 0:
        raise argparse.ArgumentTypeError("n_embd / n_head must be even for rotary attention")

    if not name:
        name = f"L{n_layer}_H{n_head}_D{n_embd}"
    return ModelSpec(name, n_layer, n_head, n_embd)
|
|
|
|
def parse_decay_spec(raw: str) -> DropoutCondition:
    """Parse ``name:initial:final[:schedule[:decay_tokens]]`` into a decay condition.

    ``schedule`` defaults to ``cosine`` and must be one of ``cosine``,
    ``smoothstep`` or ``linear``. ``decay_tokens`` is optional; when omitted
    (or ``0``) the condition stores a falsy value and
    ``DropoutCondition.make_fn`` uses its fallback instead.

    Both dropout values are range-checked with the same ``0 <= p < 1`` rule
    that ``parse_anchor_decay_spec`` applies, so NaN, negative and >= 1
    probabilities are rejected at the command line instead of reaching the
    model. Negative ``decay_tokens`` are rejected as well.

    Raises:
        argparse.ArgumentTypeError: on a malformed spec, non-numeric values,
            out-of-range dropout, negative ``decay_tokens`` or an unknown
            schedule name.
    """
    parts = raw.split(":")
    if len(parts) not in {3, 4, 5}:
        raise argparse.ArgumentTypeError(
            "decay specs must look like "
            "name:initial:final[:cosine|smoothstep|linear[:decay_tokens]]"
        )
    try:
        initial = float(parts[1])
        final = float(parts[2])
        decay_tokens = int(parts[4]) if len(parts) == 5 else None
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            "decay dropout values and decay_tokens must be numeric"
        ) from exc

    # The chained comparison is False for NaN, so NaN is rejected here too.
    for label, value in (("initial", initial), ("final", final)):
        if not 0.0 <= value < 1.0:
            raise argparse.ArgumentTypeError(
                f"decay {label} dropout must satisfy 0 <= p < 1"
            )
    # Zero stays legal: make_fn treats a falsy decay_tokens as "use the fallback".
    if decay_tokens is not None and decay_tokens < 0:
        raise argparse.ArgumentTypeError("decay_tokens must not be negative")

    schedule = parts[3] if len(parts) >= 4 else "cosine"
    if schedule not in {"cosine", "smoothstep", "linear"}:
        raise argparse.ArgumentTypeError(
            "decay schedule must be cosine, smoothstep, or linear"
        )
    return DropoutCondition(
        name=clean_name(parts[0]),
        kind="decay",
        initial=initial,
        final=final,
        schedule=schedule,
        decay_tokens=decay_tokens,
    )
|
|
|
|
def parse_anchor_decay_spec(raw: str) -> DropoutCondition:
    """Parse ``name:tokens=dropout,tokens=dropout,...`` into an anchor-decay condition.

    Exact duplicate anchors are collapsed and the rest sorted. The condition's
    ``initial``/``final`` are the dropouts of the first and last sorted anchor.

    Raises:
        argparse.ArgumentTypeError: if the spec has no ``:``, an anchor lacks
            ``=``, a value is not numeric, a token count is not positive, a
            dropout is outside ``0 <= p < 1``, fewer than two distinct
            anchors remain, or dropout rises between consecutive anchors.
    """
    name, colon, raw_anchors = raw.partition(":")
    if not colon:
        raise argparse.ArgumentTypeError(
            "anchor decay specs must look like name:250000=0.60,500000=0.40"
        )

    distinct: set[tuple[int, float]] = set()
    for piece in raw_anchors.split(","):
        raw_tokens, equals, raw_dropout = piece.partition("=")
        if not equals:
            raise argparse.ArgumentTypeError(
                "anchor decay anchors must look like token_count=dropout"
            )
        try:
            token_count, dropout = int(raw_tokens), float(raw_dropout)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                "anchor token counts must be integers and dropout values numeric"
            ) from exc
        if token_count <= 0:
            raise argparse.ArgumentTypeError("anchor token counts must be positive")
        if not 0.0 <= dropout < 1.0:
            raise argparse.ArgumentTypeError("anchor dropout must satisfy 0 <= p < 1")
        distinct.add((token_count, dropout))

    anchors = sorted(distinct)
    if len(anchors) < 2:
        raise argparse.ArgumentTypeError("provide at least two dropout anchors")
    rises = any(
        later[1] > earlier[1] for earlier, later in zip(anchors, anchors[1:])
    )
    if rises:
        raise argparse.ArgumentTypeError(
            "anchor dropout values must be non-increasing as token counts grow"
        )

    first_anchor, last_anchor = anchors[0], anchors[-1]
    return DropoutCondition(
        name=clean_name(name),
        kind="anchor_decay",
        initial=first_anchor[1],
        final=last_anchor[1],
        schedule="log_prefix_anchor",
        anchors=tuple(anchors),
    )
|
|
|
|
def default_seeds(mode: str, seeds: list[int] | None) -> list[int]:
    """Return ``seeds`` when non-empty, else ``[1]`` for ``screen_static`` and ``[1, 2, 3]`` otherwise."""
    if seeds:
        return seeds
    if mode == "screen_static":
        return [1]
    return [1, 2, 3]
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``, in that order."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
|
def batch_seed(seed: int, model: ModelSpec, dropout_code: int, stage: int | None) -> int:
    """Combine seed, model dimensions, dropout code and stage into one integer seed.

    Each component is scaled by its own fixed multiplier and the products are
    summed; a ``None`` stage contributes the same as stage ``0``.
    """
    weighted_terms = (
        (seed, 1_000_003),
        (model.n_layer, 100_003),
        (model.n_head, 10_007),
        (model.n_embd, 101),
        (dropout_code, 37),
        (stage or 0, 997),
    )
    return sum(value * weight for value, weight in weighted_terms)
|
|
|
|
def make_batch(
    tokens: np.ndarray,
    token_limit: int,
    batch_size: int,
    block_size: int,
    rng: np.random.Generator,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random windows from the first ``token_limit`` tokens.

    Returns ``(x, y)`` as int64 tensors on ``device``. Each row of ``x`` is
    ``block_size`` consecutive tokens, and the matching row of ``y`` is the
    same window shifted one position to the right.

    Raises:
        ValueError: if the usable prefix is too short to place a window.
    """
    usable = min(token_limit, len(tokens))
    max_start = usable - block_size - 1
    if max_start <= 0:
        raise ValueError("token_limit is too small for the requested block_size")
    starts = rng.integers(0, max_start, size=batch_size)

    def windows(offset: int) -> np.ndarray:
        # One window per sampled start, shifted right by `offset` tokens.
        slices = [
            tokens[begin + offset : begin + offset + block_size] for begin in starts
        ]
        return np.stack(slices).astype(np.int64)

    inputs = torch.tensor(windows(0), dtype=torch.long, device=device)
    targets = torch.tensor(windows(1), dtype=torch.long, device=device)
    return inputs, targets
|
|
|
|
@torch.no_grad()
def estimate_loss(
    model: DropoutGPT,
    tokens: np.ndarray,
    token_limit: int,
    batches: int,
    args: argparse.Namespace,
    device: torch.device,
    rng_seed: int,
) -> float:
    """Average the model's loss over ``batches`` freshly sampled batches.

    Batches come from ``make_batch`` with a generator seeded by ``rng_seed``
    and the sizes in ``args.batch_size`` / ``args.block_size``. The model is
    put in eval mode while measuring and switched to train mode afterwards.
    Returns NaN, without touching the model, when ``batches`` is not positive.
    """
    if batches <= 0:
        return float("nan")

    model.eval()
    sampler = np.random.default_rng(rng_seed)

    def one_batch_loss() -> float:
        inputs, targets = make_batch(
            tokens, token_limit, args.batch_size, args.block_size, sampler, device
        )
        _, loss = model(inputs, targets)
        return float(loss.item())

    batch_losses: list[float] = [one_batch_loss() for _ in range(batches)]
    model.train()
    return float(statistics.fmean(batch_losses))
|
|
|
|
def write_jsonl_row(handle, row: dict) -> None:
    """Write ``row`` to ``handle`` as one key-sorted JSON line and flush immediately."""
    encoded = json.dumps(row, sort_keys=True)
    handle.write(f"{encoded}\n")
    handle.flush()
|
|
|
|
def metric_key(row: dict) -> tuple:
    """Build the identity tuple of a logged metrics row.

    Numeric fields are coerced to ``int``/``float``; the layout matches what
    ``planned_metric_key`` builds. A missing ``stage`` is read as ``None``.
    """
    run_part = (
        row["run_mode"],
        row["condition"],
        row["condition_kind"],
        row.get("stage"),
    )
    model_part = (
        int(row["token_limit"]),
        row["model_name"],
        *(int(row[field]) for field in ("n_layer", "n_head", "n_embd", "seed")),
    )
    dropout_part = (
        float(row["dropout_initial"]),
        float(row["dropout_final"]),
        row["dropout_schedule"],
    )
    return run_part + model_part + dropout_part
|
|
|
|
def planned_metric_key(
    *,
    mode: str,
    condition: DropoutCondition,
    model_spec: ModelSpec,
    seed: int,
    token_limit: int,
    stage: int | None = None,
) -> tuple:
    """Build the identity tuple for a run that is about to be executed.

    The layout and coercions match ``metric_key`` so that planned runs can be
    compared against rows that were already logged.
    """
    run_part = (mode, condition.name, condition.kind, stage)
    model_part = (
        int(token_limit),
        model_spec.name,
        int(model_spec.n_layer),
        int(model_spec.n_head),
        int(model_spec.n_embd),
        int(seed),
    )
    dropout_part = (
        float(condition.initial),
        float(condition.final),
        condition.schedule,
    )
    return run_part + model_part + dropout_part
|
|
|
|
def load_metrics(path: Path) -> list[dict]:
    """Read a JSON-lines file into a list of decoded rows.

    Whitespace-only lines are skipped. A missing file yields an empty list.
    """
    if not path.exists():
        return []
    text = path.read_text(encoding="utf-8")
    return [json.loads(line) for line in text.splitlines() if line.strip()]
|
|
|
|
def train_segment(
    *,
    run_mode: str,
    condition: DropoutCondition,
    model_spec: ModelSpec,
    config: GPTConfig,
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    token_limit: int,
    steps: int,
    seed: int,
    args: argparse.Namespace,
    device: torch.device,
    dropout_fn: Callable[[int], float],
    metrics_file,
    trace_file,
    stage: int | None = None,
    model: DropoutGPT | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    tokens_seen_start: int = 0,
) -> tuple[DropoutGPT, torch.optim.Optimizer, int, dict]:
    """Train for ``steps`` optimizer steps, then evaluate and log one metrics row.

    With ``model=None`` a fresh ``DropoutGPT`` and AdamW optimizer are created
    after seeding all RNGs. Otherwise the supplied model/optimizer are reused
    and only torch is re-seeded (offset by the stage).

    Every step asks ``dropout_fn`` for the rate at the current token count,
    applies it with ``model.set_dropout`` and trains on one batch from the
    first ``token_limit`` training tokens. ``args.log_every`` and
    ``args.eval_every`` control the optional ``train_step`` / ``eval_step``
    rows written to ``trace_file``. The final summary row goes to
    ``metrics_file``.

    Returns:
        ``(model, optimizer, tokens_seen, row)`` where ``row`` is the summary
        dict that was written.

    Raises:
        ValueError: if a model is supplied without its optimizer.
    """
    stage_offset = stage or 0
    if model is not None:
        torch.manual_seed(seed + 10_000 + stage_offset)
        if optimizer is None:
            raise ValueError("optimizer is required when reusing a model")
    else:
        set_seed(seed)
        model = DropoutGPT(config).to(device)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.95),
            weight_decay=args.weight_decay,
        )

    def evaluate(tokens: np.ndarray, limit: int, batches: int, rng_seed: int) -> float:
        # Thin wrapper so each call site only states what differs.
        return estimate_loss(model, tokens, limit, batches, args, device, rng_seed=rng_seed)

    def trace_row(event: str, step: int, seen: int, dropout: float, **metrics) -> dict:
        # Fields shared by every trace event; event-specific metrics are appended.
        return {
            "event": event,
            "run_mode": run_mode,
            "condition": condition.name,
            "model_name": model_spec.name,
            "seed": seed,
            "stage": stage,
            "step": step,
            "steps": steps,
            "token_limit": int(token_limit),
            "tokens_seen": int(seen),
            "dropout": float(dropout),
            **metrics,
        }

    model.train()
    dropout_code = int(round(condition.initial * 10_000))
    batch_rng = np.random.default_rng(batch_seed(seed, model_spec, dropout_code, stage))
    tokens_seen = tokens_seen_start
    last_loss = float("nan")
    active_dropout = condition.initial
    started = time.time()

    for step in range(1, steps + 1):
        # The rate is looked up before the step, at the tokens seen so far.
        active_dropout = dropout_fn(tokens_seen)
        model.set_dropout(active_dropout)
        inputs, targets = make_batch(
            train_tokens, token_limit, args.batch_size, args.block_size, batch_rng, device
        )
        _, loss = model(inputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        tokens_seen += args.batch_size * args.block_size
        last_loss = float(loss.item())

        log_due = args.log_every > 0 and step % args.log_every == 0
        if log_due:
            write_jsonl_row(
                trace_file,
                trace_row(
                    "train_step",
                    step,
                    tokens_seen,
                    active_dropout,
                    train_batch_loss=last_loss,
                ),
            )

        eval_due = args.eval_every > 0 and step % args.eval_every == 0
        if eval_due:
            trace_train = evaluate(
                train_tokens, token_limit, args.trace_eval_batches, seed + 20_000 + step
            )
            trace_val = evaluate(
                val_tokens, len(val_tokens), args.trace_eval_batches, seed + 30_000 + step
            )
            write_jsonl_row(
                trace_file,
                trace_row(
                    "eval_step",
                    step,
                    tokens_seen,
                    active_dropout,
                    train_eval_loss=trace_train,
                    val_eval_loss=trace_val,
                    generalization_gap=trace_val - trace_train,
                ),
            )

    final_train = evaluate(
        train_tokens, token_limit, args.train_eval_batches, seed + 40_000 + stage_offset
    )
    final_val = evaluate(
        val_tokens, len(val_tokens), args.eval_batches, seed + 50_000 + stage_offset
    )
    row = {
        "run_mode": run_mode,
        "condition": condition.name,
        "condition_kind": condition.kind,
        "seed": seed,
        "stage": stage,
        "token_limit": int(token_limit),
        "steps": int(steps),
        "tokens_seen": int(tokens_seen),
        "dropout_initial": float(condition.initial),
        "dropout_final": float(condition.final),
        "dropout_schedule": condition.schedule,
        "dropout_active_final": float(active_dropout),
        "train_loss_last": last_loss,
        "train_eval_loss": final_train,
        "val_eval_loss": final_val,
        # "eval_loss" repeats the validation loss under a second key.
        "eval_loss": final_val,
        "generalization_gap": final_val - final_train,
        "elapsed_sec": time.time() - started,
        "parameters": model.num_parameters(),
        "model_config": config.to_dict(),
        **model_spec.to_dict(),
    }
    write_jsonl_row(metrics_file, row)
    return model, optimizer, tokens_seen, row
|
|
|
|
def summarize(rows: list[dict]) -> list[dict]:
    """Aggregate per-run metric rows into one summary row per condition.

    Rows are grouped on run mode, condition, stage, token limit, model
    identity and dropout configuration, so the seed is what varies inside a
    group. Each summary row carries the ``SUMMARY_FIELDS`` present on the
    group's first row, the group size ``n``, and the mean and sample standard
    deviation (0.0 for a single row) of train loss, validation loss and
    generalization gap.

    The result is sorted by run mode, model name, token limit, condition and
    stage. Rows without a stage sort before every numbered stage, including
    stage 0; the previous ``stage or -1`` key lumped stage 0 in with them.
    """
    group_fields = (
        "run_mode",
        "condition",
        "condition_kind",
        "stage",
        "token_limit",
        "model_name",
        "n_layer",
        "n_head",
        "n_embd",
        "parameters",
        "dropout_initial",
        "dropout_final",
        "dropout_schedule",
    )
    aggregates = (
        ("train_eval_loss", "mean_train_eval_loss", "std_train_eval_loss"),
        ("val_eval_loss", "mean_val_eval_loss", "std_val_eval_loss"),
        ("generalization_gap", "mean_generalization_gap", "std_generalization_gap"),
    )

    groups: dict[tuple, list[dict]] = defaultdict(list)
    for row in rows:
        groups[tuple(row[field] for field in group_fields)].append(row)

    summary: list[dict] = []
    for group_rows in groups.values():
        first = group_rows[0]
        item = {field: first[field] for field in SUMMARY_FIELDS if field in first}
        item["n"] = len(group_rows)
        for source, mean_key, std_key in aggregates:
            values = [float(row[source]) for row in group_rows]
            item[mean_key] = statistics.fmean(values)
            # stdev needs two samples; a lone run reports zero spread.
            item[std_key] = statistics.stdev(values) if len(values) > 1 else 0.0
        summary.append(item)

    def sort_key(item: dict) -> tuple:
        stage = item["stage"]
        return (
            item["run_mode"],
            item["model_name"],
            item["token_limit"],
            item["condition"],
            # Explicit None test so stage 0 keeps its own position.
            -1 if stage is None else stage,
        )

    return sorted(summary, key=sort_key)
|
|
|
|
def write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header of ``fieldnames``.

    Keys not listed in ``fieldnames`` are ignored.
    """
    with path.open("w", newline="", encoding="utf-8") as out:
        table = csv.DictWriter(out, fieldnames=fieldnames, extrasaction="ignore")
        table.writeheader()
        for record in rows:
            table.writerow(record)
|
|
|
|
def format_duration(seconds: float) -> str:
    """Render a duration as ``XhMMm``, ``MmSSs`` or ``Ss``.

    The value is truncated to whole seconds and negatives are treated as zero.
    """
    whole_seconds = max(0, int(seconds))
    whole_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    if hours:
        return f"{hours}h{minutes:02d}m"
    if minutes:
        return f"{minutes}m{secs:02d}s"
    return f"{secs}s"
|
|
|
|
class ProgressMeter:
    """Counts finished runs and prints a one-line progress report with an ETA."""

    def __init__(self, total: int):
        """Start the clock; a negative ``total`` is stored as 0."""
        self.started_at = time.time()
        self.total = max(0, total)
        self.done = 0

    def mark_done(self, row: dict) -> None:
        """Record one finished run and print its progress line.

        The ETA is the mean wall-clock time per finished run multiplied by
        the number of runs still outstanding. ``row`` is read for the run's
        identity and loss fields; ``stage`` is shown only when it is set.
        """
        self.done += 1
        elapsed = time.time() - self.started_at
        per_run = elapsed / self.done if self.done else 0.0
        eta = max(0, self.total - self.done) * per_run

        fields = [
            "progress",
            f"{self.done}/{self.total}",
            f"eta={format_duration(eta)}",
            f"mode={row['run_mode']}",
            f"model={row['model_name']}",
            f"params={int(row['parameters']):,}",
            f"prefix={int(row['token_limit']):,}",
        ]
        stage = row.get("stage")
        if stage is not None:
            fields.append(f"stage={stage}")
        fields.extend(
            [
                f"seed={row['seed']}",
                f"condition={row['condition']}",
                f"dropout={float(row['dropout_active_final']):.3f}",
                f"val={row['val_eval_loss']:.4f}",
                f"train={row['train_eval_loss']:.4f}",
                f"gap={row['generalization_gap']:.4f}",
                f"elapsed={format_duration(float(row['elapsed_sec']))}",
            ]
        )
        print(" ".join(fields), flush=True)
|
|
|
|
def build_model_selection(summary: list[dict], args: argparse.Namespace) -> list[dict]:
    """Condense each static dropout curve into one model-selection row.

    Only summary rows with ``condition_kind == "static"`` and no stage are
    used, grouped by ``(run_mode, token_limit, model_name)``. For each curve
    the row reports the best dropout, the span of rates within
    ``args.plateau_delta`` of the best loss, and comparisons against zero
    dropout and the largest tested dropout. It also carries the
    ``has_nonzero_optimum`` / ``meets_target_dropout`` flags and the curve
    serialized as JSON.

    Rows with a non-zero optimum sort first, then rows meeting the target
    dropout, then by ascending best validation loss.
    """

    def val_loss(item: dict) -> float:
        return item["mean_val_eval_loss"]

    def rate(item: dict) -> float:
        return item["dropout_initial"]

    curves: dict[tuple, list[dict]] = defaultdict(list)
    for item in summary:
        if item["condition_kind"] != "static" or item["stage"] is not None:
            continue
        curves[(item["run_mode"], item["token_limit"], item["model_name"])].append(item)

    selection: list[dict] = []
    for (run_mode, token_limit, model_name), members in curves.items():
        curve = sorted(members, key=rate)
        best = min(curve, key=val_loss)
        best_loss = val_loss(best)

        plateau_ceiling = best_loss + args.plateau_delta
        plateau_rates = [rate(item) for item in curve if val_loss(item) <= plateau_ceiling]

        zero = next((item for item in curve if rate(item) == 0.0), None)
        nonzero = [item for item in curve if rate(item) > 0.0]
        best_nonzero = min(nonzero, key=val_loss) if nonzero else None
        top = max(curve, key=rate)

        zero_loss = val_loss(zero) if zero else None
        if zero_loss is None:
            zero_minus_best = None
            zero_minus_best_nonzero = None
        else:
            zero_minus_best = zero_loss - best_loss
            zero_minus_best_nonzero = (
                zero_loss - val_loss(best_nonzero) if best_nonzero is not None else None
            )
        max_minus_best = val_loss(top) - best_loss

        # The optimum counts as non-zero only if it beats zero dropout AND the
        # top of the grid by the configured margins.
        has_nonzero_optimum = (
            rate(best) > 0.0
            and zero_minus_best is not None
            and zero_minus_best >= args.min_nonzero_margin
            and max_minus_best >= args.min_high_dropout_margin
        )

        curve_points = [
            {
                "dropout": rate(item),
                "mean_val_loss": val_loss(item),
                "std_val_loss": item["std_val_eval_loss"],
                "mean_train_loss": item["mean_train_eval_loss"],
                "mean_generalization_gap": item["mean_generalization_gap"],
                "n": item["n"],
            }
            for item in curve
        ]

        selection.append(
            {
                "run_mode": run_mode,
                "token_limit": token_limit,
                "model_name": model_name,
                "n_layer": best["n_layer"],
                "n_head": best["n_head"],
                "n_embd": best["n_embd"],
                "parameters": best["parameters"],
                "n": best["n"],
                "best_dropout": rate(best),
                "best_val_loss": best_loss,
                "best_val_std": best["std_val_eval_loss"],
                "plateau_start_dropout": min(plateau_rates),
                "plateau_end_dropout": max(plateau_rates),
                "plateau_delta": args.plateau_delta,
                "zero_dropout_val_loss": zero_loss,
                "zero_minus_best": zero_minus_best,
                "best_nonzero_dropout": rate(best_nonzero) if best_nonzero else None,
                "best_nonzero_val_loss": val_loss(best_nonzero) if best_nonzero else None,
                "zero_minus_best_nonzero": zero_minus_best_nonzero,
                "max_dropout": rate(top),
                "max_dropout_val_loss": val_loss(top),
                "max_dropout_minus_best": max_minus_best,
                "has_nonzero_optimum": has_nonzero_optimum,
                "meets_target_dropout": rate(best) >= args.target_min_dropout,
                "curve_json": json.dumps(curve_points, sort_keys=True),
            }
        )

    def rank(entry: dict) -> tuple:
        return (
            not entry["has_nonzero_optimum"],
            not entry["meets_target_dropout"],
            entry["best_val_loss"],
        )

    return sorted(selection, key=rank)
|
|
|
|
def write_screen_markdown_summary(output_dir: Path, rows: list[dict]) -> None:
    """Write ``RESULT_SUMMARY.md`` for the static dropout screen into ``output_dir``.

    Only rows with ``run_mode == "screen_static"`` and
    ``condition_kind == "static"`` are used; without any, nothing is written.
    The report has a table of models, a best-dropout table per model and
    prefix, and a per-model section listing every tested dropout rate.
    """
    if not rows:
        return
    static_rows = [
        row
        for row in rows
        if row["run_mode"] == "screen_static" and row["condition_kind"] == "static"
    ]
    if not static_rows:
        return

    def spread(values: list[float]) -> float:
        # Sample standard deviation; a single value reports zero spread.
        return statistics.stdev(values) if len(values) > 1 else 0.0

    def table_row(cells: list[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    def split_by_prefix(model_group: list[dict]) -> list[tuple[int, list[dict]]]:
        buckets: dict[int, list[dict]] = defaultdict(list)
        for entry in model_group:
            buckets[int(entry["token_limit"])].append(entry)
        return sorted(buckets.items())

    # One aggregate per (model, prefix, dropout) cell, averaged over seeds.
    cells_by_key: dict[tuple[str, int, float], list[dict]] = defaultdict(list)
    for row in static_rows:
        cell_key = (
            row["model_name"],
            int(row["token_limit"]),
            float(row["dropout_initial"]),
        )
        cells_by_key[cell_key].append(row)

    aggregates: list[dict] = []
    for (model_name, prefix, dropout), members in cells_by_key.items():
        head = members[0]
        val_losses = [float(member["val_eval_loss"]) for member in members]
        train_losses = [float(member["train_eval_loss"]) for member in members]
        gaps = [float(member["generalization_gap"]) for member in members]
        aggregates.append(
            {
                "model_name": model_name,
                "token_limit": prefix,
                "dropout_initial": dropout,
                "n": len(members),
                "mean_val_eval_loss": statistics.fmean(val_losses),
                "std_val_eval_loss": spread(val_losses),
                "mean_train_eval_loss": statistics.fmean(train_losses),
                "std_train_eval_loss": spread(train_losses),
                "mean_generalization_gap": statistics.fmean(gaps),
                "std_generalization_gap": spread(gaps),
                # Model-level facts are taken from the first row of the cell.
                "parameters": int(head["parameters"]),
                "n_layer": int(head["n_layer"]),
                "n_head": int(head["n_head"]),
                "n_embd": int(head["n_embd"]),
                "block_size": int(head["model_config"]["block_size"]),
                "vocab_size": int(head["model_config"]["vocab_size"]),
                "tokens_seen": int(head["tokens_seen"]),
                "seeds": sorted({int(member["seed"]) for member in members}),
            }
        )

    by_model: dict[str, list[dict]] = defaultdict(list)
    for aggregate in aggregates:
        by_model[aggregate["model_name"]].append(aggregate)

    model_rows = []
    for model_name, model_group in by_model.items():
        head = model_group[0]
        all_seeds = sorted({seed for entry in model_group for seed in entry["seeds"]})
        model_rows.append(
            {
                "model_name": model_name,
                "parameters": head["parameters"],
                "n_layer": head["n_layer"],
                "n_head": head["n_head"],
                "n_embd": head["n_embd"],
                "block_size": head["block_size"],
                "vocab_size": head["vocab_size"],
                "seeds": all_seeds,
            }
        )

    lines = [
        "# Static Dropout Screen Summary",
        "",
        f"Run directory: `{output_dir}`",
        "",
        "## Models",
        "",
        "| Model | Params | Layers | Heads | Embedding | Block | Vocab | Seeds |",
        "|---|---:|---:|---:|---:|---:|---:|---|",
    ]
    for model in sorted(model_rows, key=lambda entry: entry["parameters"]):
        lines.append(
            table_row(
                [
                    f"`{model['model_name']}`",
                    f"{model['parameters']:,}",
                    f"{model['n_layer']}",
                    f"{model['n_head']}",
                    f"{model['n_embd']}",
                    f"{model['block_size']}",
                    f"{model['vocab_size']}",
                    ", ".join(str(seed) for seed in model["seeds"]),
                ]
            )
        )

    # Both report sections walk the models alphabetically, prefixes ascending.
    model_names = sorted(by_model)
    prefix_tables = {name: split_by_prefix(by_model[name]) for name in model_names}

    lines.extend(
        [
            "",
            "## Best Dropout By Model And Prefix",
            "",
            "| Model | Prefix tokens | Effective epochs | Best dropout | Mean val loss | Val std | Mean train loss | Mean gap | Plateau/bracket note |",
            "|---|---:|---:|---:|---:|---:|---:|---:|---|",
        ]
    )
    for model_name in model_names:
        for prefix, prefix_rows in prefix_tables[model_name]:
            best = min(prefix_rows, key=lambda entry: entry["mean_val_eval_loss"])
            tested = [float(entry["dropout_initial"]) for entry in prefix_rows]
            eff_epochs = float(best["tokens_seen"]) / prefix
            # The optimum is only bracketed when it sits strictly inside the grid.
            if best["dropout_initial"] == max(tested):
                note = "not bracketed; best at top of tested grid"
            elif best["dropout_initial"] == min(tested):
                note = "not bracketed; best at bottom of tested grid"
            else:
                note = "bracketed by tested grid"
            lines.append(
                table_row(
                    [
                        f"`{model_name}`",
                        f"{prefix:,}",
                        f"{eff_epochs:.2f}",
                        f"{best['dropout_initial']:.2f}",
                        f"{best['mean_val_eval_loss']:.4f}",
                        f"{best['std_val_eval_loss']:.4f}",
                        f"{best['mean_train_eval_loss']:.4f}",
                        f"{best['mean_generalization_gap']:.4f}",
                        note,
                    ]
                )
            )

    for model_name in model_names:
        lines.extend(["", f"## Model `{model_name}`"])
        for prefix, prefix_rows in prefix_tables[model_name]:
            eff_epochs = float(prefix_rows[0]["tokens_seen"]) / prefix
            lines.extend(
                [
                    "",
                    f"### Prefix {prefix:,} Tokens ({eff_epochs:.2f} Effective Epochs)",
                    "",
                    "| Dropout | N | Mean val loss | Val std | Mean train loss | Train std | Mean gap | Gap std | Sampled tokens | Params |",
                    "|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
                ]
            )
            for entry in sorted(prefix_rows, key=lambda item: item["dropout_initial"]):
                lines.append(
                    table_row(
                        [
                            f"{entry['dropout_initial']:.2f}",
                            f"{entry['n']}",
                            f"{entry['mean_val_eval_loss']:.4f}",
                            f"{entry['std_val_eval_loss']:.4f}",
                            f"{entry['mean_train_eval_loss']:.4f}",
                            f"{entry['std_train_eval_loss']:.4f}",
                            f"{entry['mean_generalization_gap']:.4f}",
                            f"{entry['std_generalization_gap']:.4f}",
                            f"{int(entry['tokens_seen']):,}",
                            f"{int(entry['parameters']):,}",
                        ]
                    )
                )

    report = "\n".join(lines) + "\n"
    (output_dir / "RESULT_SUMMARY.md").write_text(report, encoding="utf-8")
|
|
|
|
def svg_escape(value: object) -> str:
    """Return ``str(value)`` made safe for SVG text and double-quoted attributes.

    ``&``, ``<``, ``>`` and ``"`` are replaced by their XML entities. The
    replacement happens in a single pass, so the ``&`` of a freshly inserted
    entity is never escaped a second time.
    """
    entities = str.maketrans(
        {
            "&": "&amp;",
            "<": "&lt;",
            ">": "&gt;",
            '"': "&quot;",
        }
    )
    return str(value).translate(entities)
|
|
|
|
| def write_dropout_curve_svg(output_dir: Path, summary: list[dict]) -> None: |
| rows = [ |
| row |
| for row in summary |
| if row["run_mode"] == "screen_static" and row["condition_kind"] == "static" |
| ] |
| if not rows: |
| return |
|
|
| grouped: dict[tuple[str, int], list[dict]] = defaultdict(list) |
| model_params: dict[str, int] = {} |
| for row in rows: |
| model_name = row["model_name"] |
| grouped[(model_name, int(row["token_limit"]))].append(row) |
| model_params[model_name] = int(row["parameters"]) |
|
|
| models = sorted(model_params, key=lambda name: model_params[name]) |
| prefixes = sorted({int(row["token_limit"]) for row in rows}) |
| panel_w, panel_h = 230, 170 |
| margin_l, margin_t = 58, 34 |
| plot_w, plot_h = 142, 94 |
| gap_x, gap_y = 18, 38 |
| width = margin_l + len(prefixes) * panel_w + gap_x |
| height = 70 + len(models) * (panel_h + gap_y) |
| colors = ["#1f77b4", "#d62728", "#2ca02c", "#9467bd", "#ff7f0e"] |
|
|
| parts = [ |
| f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">', |
| "<style>", |
| "text{font-family:Arial,Helvetica,sans-serif;fill:#111827}", |
| ".small{font-size:10px}.label{font-size:11px}.title{font-size:15px;font-weight:700}", |
| ".axis{stroke:#374151;stroke-width:1}.grid{stroke:#e5e7eb;stroke-width:1}.line{fill:none;stroke-width:2}", |
| "</style>", |
| '<rect width="100%" height="100%" fill="#ffffff"/>', |
| '<text x="24" y="28" class="title">Static dropout law: validation loss vs dropout</text>', |
| '<text x="24" y="48" class="label">Each panel uses its own y-scale. Points are one-seed means unless N > 1.</text>', |
| ] |
|
|
| for col, prefix in enumerate(prefixes): |
| x = margin_l + col * panel_w + plot_w / 2 |
| parts.append( |
| f'<text x="{x:.1f}" y="70" text-anchor="middle" class="label">{prefix:,} prefix tokens</text>' |
| ) |
|
|
| for row_idx, model_name in enumerate(models): |
| row_y = 92 + row_idx * (panel_h + gap_y) |
| parts.append( |
| f'<text x="24" y="{row_y + 48}" class="label" transform="rotate(-90 24 {row_y + 48})">' |
| f'{svg_escape(model_name)} ({model_params[model_name] / 1_000_000:.1f}M)</text>' |
| ) |
| for col, prefix in enumerate(prefixes): |
| panel_x = margin_l + col * panel_w |
| panel_y = row_y |
| curve = sorted( |
| grouped.get((model_name, prefix), []), |
| key=lambda item: float(item["dropout_initial"]), |
| ) |
| if not curve: |
| continue |
| losses = [float(item["mean_val_eval_loss"]) for item in curve] |
| min_loss, max_loss = min(losses), max(losses) |
| pad = max(0.02, (max_loss - min_loss) * 0.08) |
| y_min, y_max = min_loss - pad, max_loss + pad |
| best = min(curve, key=lambda item: float(item["mean_val_eval_loss"])) |
|
|
| def px(dropout: float) -> float: |
| return panel_x + (dropout / 0.9) * plot_w |
|
|
| def py(loss: float) -> float: |
| scale = (loss - y_min) / (y_max - y_min) |
| return panel_y + plot_h - scale * plot_h |
|
|
| parts.extend( |
| [ |
| f'<line x1="{panel_x:.1f}" y1="{panel_y:.1f}" x2="{panel_x:.1f}" y2="{panel_y + plot_h:.1f}" class="axis"/>', |
| f'<line x1="{panel_x:.1f}" y1="{panel_y + plot_h:.1f}" x2="{panel_x + plot_w:.1f}" y2="{panel_y + plot_h:.1f}" class="axis"/>', |
| f'<line x1="{panel_x:.1f}" y1="{panel_y:.1f}" x2="{panel_x + plot_w:.1f}" y2="{panel_y:.1f}" class="grid"/>', |
| f'<text x="{panel_x:.1f}" y="{panel_y - 6:.1f}" class="small">{y_max:.2f}</text>', |
| f'<text x="{panel_x:.1f}" y="{panel_y + plot_h + 13:.1f}" class="small">{y_min:.2f}</text>', |
| f'<text x="{panel_x:.1f}" y="{panel_y + plot_h + 28:.1f}" class="small">0</text>', |
| f'<text x="{panel_x + plot_w:.1f}" y="{panel_y + plot_h + 28:.1f}" text-anchor="end" class="small">0.9</text>', |
| ] |
| ) |
| points = " ".join( |
| f"{px(float(item['dropout_initial'])):.1f},{py(float(item['mean_val_eval_loss'])):.1f}" |
| for item in curve |
| ) |
| color = colors[row_idx % len(colors)] |
| parts.append(f'<polyline points="{points}" class="line" stroke="{color}"/>') |
| for item in curve: |
| dropout = float(item["dropout_initial"]) |
| loss = float(item["mean_val_eval_loss"]) |
| radius = 4 if item is best else 2.7 |
| fill = "#111827" if item is best else "#ffffff" |
| parts.append( |
| f'<circle cx="{px(dropout):.1f}" cy="{py(loss):.1f}" r="{radius}" fill="{fill}" stroke="{color}" stroke-width="1.5"/>' |
| ) |
| parts.append( |
| f'<text x="{panel_x + plot_w + 8:.1f}" y="{panel_y + 14:.1f}" class="small">' |
| f'best p={float(best["dropout_initial"]):.2f}</text>' |
| ) |
| parts.append( |
| f'<text x="{panel_x + plot_w + 8:.1f}" y="{panel_y + 28:.1f}" class="small">' |
| f'loss={float(best["mean_val_eval_loss"]):.3f}</text>' |
| ) |
|
|
| parts.append("</svg>") |
| (output_dir / "dropout_curves.svg").write_text("\n".join(parts), encoding="utf-8") |
|
|
|
|
def write_stream_markdown_summary(output_dir: Path, rows: list[dict]) -> None:
    """Write ``RESULT_SUMMARY.md`` summarizing the ``locked_stream`` rows.

    Rows whose ``run_mode`` is not ``"locked_stream"`` are ignored; when no
    stream rows remain, nothing is written. The report has three parts:

    * a header describing the model of the first stream row, the per-stage
      step count and the seeds present;
    * a "Condition Ranking" table, one line per condition, ordered by the mean
      of that condition's per-stage mean validation losses;
    * a "Stage Trajectory" section with one table per stage, ordered by mean
      validation loss within the stage.

    Args:
        output_dir: Directory that receives ``RESULT_SUMMARY.md``.
        rows: Metric rows. ``stage``, ``seed``, ``token_limit``, ``parameters``
            and ``steps`` are coerced with ``int`` and the loss/dropout fields
            with ``float`` before formatting.
    """
    stream_rows = [row for row in rows if row["run_mode"] == "locked_stream"]
    if not stream_rows:
        return

    def mean_of(group: list[dict], field: str) -> float:
        """Average one numeric field over a group of rows."""
        return statistics.fmean(float(row[field]) for row in group)

    by_condition_stage: dict[tuple[str, int], list[dict]] = defaultdict(list)
    by_condition: dict[str, list[dict]] = defaultdict(list)
    for row in stream_rows:
        condition = row["condition"]
        by_condition_stage[(condition, int(row["stage"]))].append(row)
        by_condition[condition].append(row)

    first = stream_rows[0]
    seeds = sorted({int(row["seed"]) for row in stream_rows})
    # Anchor-decay conditions sort first (False < True), then by initial
    # dropout, then by name.
    conditions = sorted(
        by_condition,
        key=lambda name: (
            by_condition[name][0]["condition_kind"] != "anchor_decay",
            by_condition[name][0]["dropout_initial"],
            name,
        ),
    )
    stages = sorted({int(row["stage"]) for row in stream_rows})

    lines = [
        "# Locked Streaming Dropout Summary",
        "",
        f"Run directory: `{output_dir}`",
        "",
        (
            f"Model: `{first['model_name']}` causal Transformer, "
            f"{int(first['parameters']):,} parameters, {first['n_layer']} layers, "
            f"{first['n_head']} heads, {first['n_embd']} embedding dim."
        ),
        (
            # int() keeps the thousands separator working when the value was
            # read back as a string, matching how `parameters` is handled.
            f"Training per stage: {int(first['steps']):,} steps. "
            "Sampled tokens are cumulative in each stage row. "
            f"Seeds present: {', '.join(str(seed) for seed in seeds)}."
        ),
        "",
        "## Condition Ranking",
        "",
        "| Condition | Kind | Final dropout | Mean trajectory val loss | Final val loss | Final gap | Dropout path |",
        "|---|---|---:|---:|---:|---:|---|",
    ]

    ranking = []
    for condition in conditions:
        stage_items = []
        for stage in stages:
            # .get() avoids inserting empty groups into the defaultdict.
            group = by_condition_stage.get((condition, stage), [])
            if not group:
                continue
            stage_items.append(
                {
                    "stage": stage,
                    "token_limit": int(group[0]["token_limit"]),
                    "mean_val": mean_of(group, "val_eval_loss"),
                    "mean_gap": mean_of(group, "generalization_gap"),
                    "mean_dropout": mean_of(group, "dropout_active_final"),
                    "kind": group[0]["condition_kind"],
                }
            )
        if not stage_items:
            continue
        final = max(stage_items, key=lambda item: item["stage"])
        ranking.append(
            {
                "condition": condition,
                "kind": stage_items[0]["kind"],
                "trajectory_val": statistics.fmean(
                    item["mean_val"] for item in stage_items
                ),
                "final_val": final["mean_val"],
                "final_gap": final["mean_gap"],
                "final_dropout": final["mean_dropout"],
                "dropout_path": " -> ".join(
                    f"{item['mean_dropout']:.2f}" for item in stage_items
                ),
            }
        )

    for item in sorted(ranking, key=lambda entry: entry["trajectory_val"]):
        lines.append(
            "| "
            f"`{item['condition']}` | {item['kind']} | "
            f"{item['final_dropout']:.2f} | {item['trajectory_val']:.4f} | "
            f"{item['final_val']:.4f} | {item['final_gap']:.4f} | "
            f"{item['dropout_path']} |"
        )

    lines.extend(["", "## Stage Trajectory", ""])
    for stage in stages:
        stage_groups = {
            condition: by_condition_stage[(condition, stage)]
            for condition in conditions
            if (condition, stage) in by_condition_stage
        }
        if not stage_groups:
            continue
        # The stage heading takes its prefix size from the first group's first row.
        prefix = int(next(iter(stage_groups.values()))[0]["token_limit"])
        lines.extend(
            [
                f"### Stage {stage}: {prefix:,} Prefix Tokens",
                "",
                "| Condition | Dropout | Mean val loss | Mean train loss | Mean gap | N |",
                "|---|---:|---:|---:|---:|---:|",
            ]
        )
        for condition, group in sorted(
            stage_groups.items(),
            key=lambda item: mean_of(item[1], "val_eval_loss"),
        ):
            val = mean_of(group, "val_eval_loss")
            train = mean_of(group, "train_eval_loss")
            gap = mean_of(group, "generalization_gap")
            dropout = mean_of(group, "dropout_active_final")
            lines.append(
                "| "
                f"`{condition}` | {dropout:.2f} | {val:.4f} | "
                f"{train:.4f} | {gap:.4f} | {len(group)} |"
            )
        lines.append("")

    # Trailing blank lines are collapsed into a single final newline.
    (output_dir / "RESULT_SUMMARY.md").write_text(
        "\n".join(lines).rstrip() + "\n",
        encoding="utf-8",
    )
|
|
|
|
def static_conditions(dropout_rates: list[float]) -> list[DropoutCondition]:
    """Build one constant-rate ``DropoutCondition`` per entry of *dropout_rates*.

    Each condition has kind ``"static"``, identical initial and final rates,
    and a name of the form ``static_dropout_<rate_label(rate)>``. The input
    order is preserved and duplicates are not removed.
    """
    conditions: list[DropoutCondition] = []
    for rate in dropout_rates:
        label = rate_label(rate)
        conditions.append(
            DropoutCondition(
                name=f"static_dropout_{label}",
                kind="static",
                initial=rate,
                final=rate,
            )
        )
    return conditions
|
|
|
|
def run_fixed_static_sweep(
    *,
    args: argparse.Namespace,
    model_specs: list[ModelSpec],
    seeds: list[int],
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    tokenizer_vocab_size: int,
    token_limits: list[int],
    device: torch.device,
    metrics_file,
    trace_file,
    completed_keys: set[tuple] | None = None,
) -> list[dict]:
    """Train every (token limit, model, static dropout rate, seed) combination.

    Dropout rates come from ``args.dropout_rates``, deduplicated and sorted
    ascending. Each combination is trained from scratch through
    ``train_segment`` for ``args.steps`` steps. Combinations whose key is
    already in *completed_keys* are not retrained; a
    ``skipped_completed_condition`` event is written to *trace_file* instead.

    For each (token limit, model) pair the mean validation loss per rate is
    tracked. When ``args.mode == "screen_static"`` and
    ``args.screen_early_stop`` is set, the remaining (higher) rates for that
    pair are abandoned once ``worse_streak`` reaches
    ``args.screen_prune_patience`` and the current rate is at least
    ``args.target_min_dropout``; a ``screen_pruned_model`` event records this.

    Args:
        args: Parsed command-line namespace.
        model_specs: Model specifications to sweep.
        seeds: Seeds to run per condition.
        train_tokens: Training token array handed to ``train_segment``.
        val_tokens: Validation token array handed to ``train_segment``.
        tokenizer_vocab_size: Vocabulary size used to build model configs.
        token_limits: Training-prefix sizes to sweep.
        device: Device handed to ``train_segment``.
        metrics_file: Open handle passed through to ``train_segment``.
        trace_file: Open handle for trace events.
        completed_keys: Keys of runs that already exist. A non-empty set is
            extended in place with the keys of newly finished runs.

    Returns:
        The metric rows produced by this call (skipped runs are not included).
    """
    rows: list[dict] = []
    completed_keys = completed_keys or set()
    # Single source of truth for the rates actually swept, so the prune trace
    # below reports exactly the rates that were skipped.
    dropout_rates = sorted(set(args.dropout_rates))
    conditions = static_conditions(dropout_rates)

    def key_for(condition, model_spec, seed, token_limit):
        """Return the resume key for one planned training run."""
        return planned_metric_key(
            mode=args.mode,
            condition=condition,
            model_spec=model_spec,
            seed=seed,
            token_limit=token_limit,
        )

    # Count only the runs that still need training so progress totals are right
    # when resuming.
    planned = sum(
        1
        for token_limit in token_limits
        for model_spec in model_specs
        for condition in conditions
        for seed in seeds
        if key_for(condition, model_spec, seed, token_limit) not in completed_keys
    )
    progress = ProgressMeter(planned)

    for token_limit in token_limits:
        for model_spec in model_specs:
            best_val_loss = float("inf")
            worse_streak = 0
            for condition in conditions:
                condition_rows: list[dict] = []
                for seed in seeds:
                    key = key_for(condition, model_spec, seed, token_limit)
                    if key in completed_keys:
                        write_jsonl_row(
                            trace_file,
                            {
                                "event": "skipped_completed_condition",
                                "run_mode": args.mode,
                                "condition": condition.name,
                                "model_name": model_spec.name,
                                "seed": seed,
                                "stage": None,
                                "token_limit": int(token_limit),
                            },
                        )
                        continue
                    config = model_spec.config(
                        tokenizer_vocab_size,
                        args.block_size,
                        condition.initial,
                    )
                    model, optimizer, _, row = train_segment(
                        run_mode=args.mode,
                        condition=condition,
                        model_spec=model_spec,
                        config=config,
                        train_tokens=train_tokens,
                        val_tokens=val_tokens,
                        token_limit=token_limit,
                        steps=args.steps,
                        seed=seed,
                        args=args,
                        device=device,
                        dropout_fn=condition.make_fn(
                            args.steps * args.batch_size * args.block_size
                        ),
                        metrics_file=metrics_file,
                        trace_file=trace_file,
                    )
                    rows.append(row)
                    condition_rows.append(row)
                    completed_keys.add(metric_key(row))
                    progress.mark_done(row)
                    # Drop references before asking MPS to release cached memory.
                    del model, optimizer
                    torch.mps.empty_cache()

                # Nothing was trained for this rate (all seeds skipped), so
                # there is no new loss to feed into the pruning state.
                if not condition_rows:
                    continue
                mean_val_loss = statistics.fmean(
                    float(row["val_eval_loss"]) for row in condition_rows
                )
                if mean_val_loss < best_val_loss - args.screen_prune_min_delta:
                    best_val_loss = mean_val_loss
                    worse_streak = 0
                elif mean_val_loss > best_val_loss + args.screen_prune_min_delta:
                    worse_streak += 1

                if (
                    args.mode == "screen_static"
                    and args.screen_early_stop
                    and worse_streak >= args.screen_prune_patience
                    and condition.initial >= args.target_min_dropout
                ):
                    write_jsonl_row(
                        trace_file,
                        {
                            "event": "screen_pruned_model",
                            "run_mode": args.mode,
                            "model_name": model_spec.name,
                            "token_limit": int(token_limit),
                            "best_val_loss": best_val_loss,
                            "pruned_after_dropout": condition.initial,
                            "worse_streak": worse_streak,
                            "remaining_dropouts": [
                                rate
                                for rate in dropout_rates
                                if rate > condition.initial
                            ],
                        },
                    )
                    break
    return rows
|
|
|
|
def run_locked_stream(
    *,
    args: argparse.Namespace,
    model_specs: list[ModelSpec],
    seeds: list[int],
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    tokenizer_vocab_size: int,
    stream_caps: list[int],
    device: torch.device,
    metrics_file,
    trace_file,
    completed_keys: set[tuple] | None = None,
) -> list[dict]:
    """Run the staged "locked stream" experiment and return its metric rows.

    Conditions are ``args.anchor_decays``, then one static condition per entry
    of ``args.dropout_rates``, then ``args.decays``. For every
    (model, condition, seed) the stages in *stream_caps* are trained in order
    for ``args.stage_steps`` steps each; the model, optimizer and token count
    returned by one stage are passed into the next stage's ``train_segment``
    call.

    Stages whose key is already in *completed_keys* are skipped with a
    ``skipped_completed_condition`` trace event.
    NOTE: a skipped stage does not restore any model state; the following stage
    receives whatever ``model``/``optimizer`` are current (``None`` if nothing
    has been trained yet for that seed).

    Returns:
        The metric rows produced by this call.
    """
    rows: list[dict] = []
    completed_keys = completed_keys or set()
    conditions = [
        *args.anchor_decays,
        *static_conditions(args.dropout_rates),
        *args.decays,
    ]

    # Without an explicit --decay-tokens, decay spans the tokens sampled across
    # all stages.
    fallback_decay_tokens = args.decay_tokens
    if not fallback_decay_tokens:
        fallback_decay_tokens = (
            args.stage_steps * args.batch_size * args.block_size * len(stream_caps)
        )

    def stage_key(condition, model_spec, seed, stage_idx, cap):
        """Return the resume key for one (condition, model, seed, stage) run."""
        return planned_metric_key(
            mode=args.mode,
            condition=condition,
            model_spec=model_spec,
            seed=seed,
            token_limit=cap,
            stage=stage_idx,
        )

    pending = sum(
        1
        for model_spec in model_specs
        for condition in conditions
        for seed in seeds
        for stage_idx, cap in enumerate(stream_caps)
        if stage_key(condition, model_spec, seed, stage_idx, cap) not in completed_keys
    )
    progress = ProgressMeter(pending)

    for model_spec in model_specs:
        for condition in conditions:
            for seed in seeds:
                # Each seed starts fresh; state is then threaded through stages.
                model, optimizer, tokens_seen = None, None, 0
                for stage_idx, cap in enumerate(stream_caps):
                    if stage_key(condition, model_spec, seed, stage_idx, cap) in completed_keys:
                        skip_event = {
                            "event": "skipped_completed_condition",
                            "run_mode": args.mode,
                            "condition": condition.name,
                            "model_name": model_spec.name,
                            "seed": seed,
                            "stage": stage_idx,
                            "token_limit": int(cap),
                        }
                        write_jsonl_row(trace_file, skip_event)
                        continue
                    config = model_spec.config(
                        tokenizer_vocab_size,
                        args.block_size,
                        condition.initial,
                    )
                    dropout_fn = condition.make_fn(
                        fallback_decay_tokens,
                        unique_tokens=cap,
                    )
                    model, optimizer, tokens_seen, row = train_segment(
                        run_mode=args.mode,
                        condition=condition,
                        model_spec=model_spec,
                        config=config,
                        train_tokens=train_tokens,
                        val_tokens=val_tokens,
                        token_limit=cap,
                        steps=args.stage_steps,
                        seed=seed,
                        args=args,
                        device=device,
                        dropout_fn=dropout_fn,
                        metrics_file=metrics_file,
                        trace_file=trace_file,
                        stage=stage_idx,
                        model=model,
                        optimizer=optimizer,
                        tokens_seen_start=tokens_seen,
                    )
                    rows.append(row)
                    completed_keys.add(metric_key(row))
                    progress.mark_done(row)
                # All stages for this seed are done: release the model before
                # the next seed allocates a new one.
                del model, optimizer
                torch.mps.empty_cache()
    return rows
|
|
|
|
def prepare_data(args: argparse.Namespace, output_dir: Path, required_train_tokens: int):
    """Return ``(tokenizer, splits)`` for the run, from cache or from the corpus.

    The cache directory is ``args.cache_dir`` when set, otherwise
    ``output_dir / "cache"``; it is created if missing. With
    ``args.use_cached_data`` the result of ``load_cached_splits`` is returned
    directly. Otherwise the corpus paths are resolved, a tokenizer is trained
    or loaded, and the corpus is encoded into the cache directory.

    Raises:
        ValueError: If ``--use-cached-data`` is combined with
            ``--force-retokenize``.
    """
    if args.cache_dir:
        cache_dir = Path(args.cache_dir)
    else:
        cache_dir = output_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    if args.use_cached_data and args.force_retokenize:
        raise ValueError("--use-cached-data cannot be combined with --force-retokenize")
    if args.use_cached_data:
        return load_cached_splits(
            cache_dir=cache_dir,
            vocab_size=args.vocab_size,
            max_required_train_tokens=required_train_tokens,
            val_tokens=args.val_tokens,
            allow_short_corpus=args.allow_short_corpus,
        )

    corpus_paths = resolve_paths(args.corpus, args.corpus_glob)
    tokenizer = train_or_load_tokenizer(
        paths=corpus_paths,
        output_dir=cache_dir,
        vocab_size=args.vocab_size,
        tokenizer_train_chars=args.tokenizer_train_chars,
        text_column=args.text_column,
        force_retrain=args.force_retokenize,
    )
    encoded = encode_corpus(
        paths=corpus_paths,
        tokenizer=tokenizer,
        output_dir=cache_dir,
        max_required_train_tokens=required_train_tokens,
        val_tokens=args.val_tokens,
        text_column=args.text_column,
        allow_short_corpus=args.allow_short_corpus,
        force_reencode=args.force_retokenize,
    )
    return tokenizer, encoded
|
|
|
|
def run(args: argparse.Namespace) -> Path:
    """Execute one experiment described by *args* and return its output directory.

    Steps, in order: validate the mode/flag combination, pick the output
    directory (the ``--resume-from`` directory, or a new timestamped directory
    under ``<output_dir>/<mode>/``), prepare tokenized data, write the config
    JSON, run the sweep for the selected mode while appending to (resume) or
    overwriting ``metrics.jsonl``/``trace.jsonl``, then write the summary
    files and print a short JSON report to stdout.

    Raises:
        ValueError: If decay schedules are given outside ``locked_stream``
            mode, or ``--resume-from`` is used with ``locked_stream``.
        FileNotFoundError: If the ``--resume-from`` directory does not exist.
    """
    device = assert_mps_only()
    seeds = default_seeds(args.mode, args.seeds)
    model_specs = [parse_model_spec(spec) for spec in args.models]

    is_stream = args.mode == "locked_stream"
    is_static = args.mode in {"screen_static", "confirm_static"}
    if not is_stream and (args.decays or args.anchor_decays):
        raise ValueError("--decays and --anchor-decays are only used with --mode locked_stream")
    if args.resume_from and is_stream:
        raise ValueError("--resume-from currently supports fixed static sweeps only")

    if args.resume_from:
        output_dir = Path(args.resume_from)
        if not output_dir.exists():
            raise FileNotFoundError(f"resume directory does not exist: {output_dir}")
    else:
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = Path(args.output_dir) / args.mode / stamp
    output_dir.mkdir(parents=True, exist_ok=True)

    def dump_json(name: str, payload) -> None:
        """Write *payload* as indented JSON to ``output_dir / name``."""
        (output_dir / name).write_text(
            json.dumps(payload, indent=2),
            encoding="utf-8",
        )

    # Only the largest requested prefix has to be available in the corpus.
    if is_stream:
        required_train_tokens = max(args.stream_token_caps)
    else:
        required_train_tokens = max(args.token_limits)
    tokenizer, splits = prepare_data(args, output_dir, required_train_tokens)
    # Clamp requested prefixes to what the encoded corpus actually provides.
    available = len(splits.train)
    token_limits = [min(limit, available) for limit in args.token_limits]
    stream_caps = [min(limit, available) for limit in args.stream_token_caps]

    # Condition objects are converted to plain dicts before JSON encoding.
    args_payload = vars(args).copy()
    args_payload["decays"] = [condition.to_dict() for condition in args.decays]
    args_payload["anchor_decays"] = [
        condition.to_dict() for condition in args.anchor_decays
    ]
    config_payload = {
        "args": args_payload,
        "mode": args.mode,
        "seeds": seeds,
        "models": [model.to_dict() for model in model_specs],
        "device": str(device),
        "torch": torch.__version__,
        "python": sys.version,
        "mps_available": torch.backends.mps.is_available(),
        "attribution": NANOCHAT_ATTRIBUTION,
        "tokenizer_path": str(splits.tokenizer_path),
        "encoded_path": str(splits.encoded_path),
        "train_tokens": int(len(splits.train)),
        "val_tokens": int(len(splits.val)),
        "effective_token_limits": [int(limit) for limit in token_limits],
        "effective_stream_token_caps": [int(limit) for limit in stream_caps],
        "resume_from": str(args.resume_from) if args.resume_from else None,
    }
    # A resumed run keeps the original config.json untouched.
    dump_json("config.resume.json" if args.resume_from else "config.json", config_payload)

    metrics_path = output_dir / "metrics.jsonl"
    trace_path = output_dir / "trace.jsonl"
    existing_rows = load_metrics(metrics_path) if args.resume_from else []
    completed_keys = {metric_key(row) for row in existing_rows}
    file_mode = "a" if args.resume_from else "w"
    with (
        metrics_path.open(file_mode, encoding="utf-8") as metrics_file,
        trace_path.open(file_mode, encoding="utf-8") as trace_file,
    ):
        shared = dict(
            args=args,
            model_specs=model_specs,
            seeds=seeds,
            train_tokens=splits.train,
            val_tokens=splits.val,
            tokenizer_vocab_size=tokenizer.vocab_size,
            device=device,
            metrics_file=metrics_file,
            trace_file=trace_file,
        )
        if is_static:
            new_rows = run_fixed_static_sweep(
                token_limits=token_limits,
                completed_keys=completed_keys,
                **shared,
            )
        else:
            new_rows = run_locked_stream(stream_caps=stream_caps, **shared)
    rows = existing_rows + new_rows

    summary = summarize(rows)
    dump_json("summary.json", summary)
    write_csv(output_dir / "summary.csv", summary, SUMMARY_FIELDS)
    if is_static:
        selection = build_model_selection(summary, args)
        dump_json("model_selection.json", selection)
        write_csv(output_dir / "model_selection.csv", selection, SELECTION_FIELDS)
        write_screen_markdown_summary(output_dir, rows)
        write_dropout_curve_svg(output_dir, summary)
    elif is_stream:
        write_stream_markdown_summary(output_dir, rows)

    report = {
        "output_dir": str(output_dir),
        "new_rows": len(new_rows),
        "total_metric_rows": len(rows),
        "summary_rows": len(summary),
    }
    print(json.dumps(report, indent=2))
    return output_dir
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dropout experiments.

    Options are registered in a fixed order (which is also the order shown by
    ``--help``); runs of plain ``--flag <number>`` options are declared from
    small tables to keep the definition compact.
    """
    parser = argparse.ArgumentParser(
        description="MPS-only dropout/model selection experiments"
    )
    add = parser.add_argument

    # Run mode, input corpus and output locations.
    add(
        "--mode",
        choices=["screen_static", "confirm_static", "locked_stream"],
        default="screen_static",
    )
    add("--corpus", default=None, help="Text or parquet corpus path")
    add("--corpus-glob", default=None, help="Glob of text/parquet corpus paths")
    add("--text-column", default="text", help="Parquet text column")
    add(
        "--use-cached-data",
        action="store_true",
        help=(
            "Load tokenizer-v{vocab}.json and tokens-v{vocab}-*.npy from --cache-dir "
            "instead of requiring the original text/parquet corpus."
        ),
    )
    add("--output-dir", default="runs")
    add(
        "--resume-from",
        default=None,
        help="Existing fixed-static run directory; completed metric rows are skipped",
    )
    add("--cache-dir", default=".cache/dropout_decay")

    # Models, seeds and data budgets.
    add(
        "--models",
        nargs="+",
        default=["8x8x256"],
        help="Model specs like 8x8x256 or name=8x8x256",
    )
    add("--seeds", nargs="+", type=int, default=None)
    add("--token-limits", nargs="+", type=int, default=[5_000_000])
    add(
        "--stream-token-caps",
        nargs="+",
        type=int,
        default=[5_000_000, 10_000_000, 20_000_000, 40_000_000],
    )
    add("--val-tokens", type=int, default=500_000)
    add("--allow-short-corpus", action="store_true")
    add("--force-retokenize", action="store_true")

    # Tokenizer and training-shape integers.
    for flag, default in (
        ("--vocab-size", 4096),
        ("--tokenizer-train-chars", 10_000_000),
        ("--block-size", 128),
        ("--batch-size", 16),
        ("--steps", 2000),
        ("--stage-steps", 1000),
    ):
        add(flag, type=int, default=default)

    # Dropout conditions.
    add("--dropout-rates", nargs="*", type=float, default=DEFAULT_DROPOUT_RATES)
    add("--decays", nargs="*", type=parse_decay_spec, default=[])
    add(
        "--anchor-decays",
        nargs="*",
        type=parse_anchor_decay_spec,
        default=[],
        help=(
            "Prefix-token anchor schedules like "
            "fit:250000=0.60,500000=0.40,1000000=0.30"
        ),
    )

    # Evaluation cadence, optimizer settings and selection thresholds.
    for flag, kind, default in (
        ("--decay-tokens", int, None),
        ("--eval-batches", int, 64),
        ("--train-eval-batches", int, 32),
        ("--trace-eval-batches", int, 8),
        ("--eval-every", int, 0),
        ("--log-every", int, 100),
        ("--lr", float, 3e-4),
        ("--weight-decay", float, 0.1),
        ("--grad-clip", float, 1.0),
        ("--plateau-delta", float, 0.01),
        ("--target-min-dropout", float, 0.10),
        ("--min-nonzero-margin", float, 0.01),
        ("--min-high-dropout-margin", float, 0.03),
    ):
        add(flag, type=kind, default=default)

    # Screening early-stop controls.
    add("--screen-early-stop", action="store_true")
    add("--screen-prune-patience", type=int, default=3)
    add("--screen-prune-min-delta", type=float, default=0.01)
    return parser
|
|
|
|
def main() -> None:
    """Command-line entry point: parse the arguments and run the experiment."""
    run(build_parser().parse_args())


# Only start an experiment when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|