dropout-decay / src /dropout_decay /experiment.py
Mandeep Sidhu
Add dropout pressure validation artifacts
3550904
"""
Derived from Andrej Karpathy's nanochat project.
MIT License
Copyright (c) 2025 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"""
from __future__ import annotations
import argparse
import csv
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
import json
import math
import os
from pathlib import Path
import random
import statistics
import sys
import time
from typing import Callable
import numpy as np
import torch
from dropout_decay.data import (
encode_corpus,
load_cached_splits,
resolve_paths,
train_or_load_tokenizer,
)
from dropout_decay.license import NANOCHAT_ATTRIBUTION
from dropout_decay.model import DropoutGPT, GPTConfig
from dropout_decay.scheduler import DropoutDecayConfig, DropoutDecayScheduler
# Grid of static dropout probabilities, from 0.0 up to 0.50.
# NOTE(review): presumably the default sweep for static conditions; the
# consumer is not visible in this part of the file -- confirm.
DEFAULT_DROPOUT_RATES = [0.0, 0.02, 0.05, 0.08, 0.10, 0.14, 0.20, 0.30, 0.50]
# Columns of one aggregated summary row. `summarize` copies the fields that
# exist on a metrics row from this list and then fills in `n` and the
# mean_*/std_* statistics itself.
SUMMARY_FIELDS = [
    "run_mode",
    "condition",
    "condition_kind",
    "stage",
    "token_limit",
    "model_name",
    "n_layer",
    "n_head",
    "n_embd",
    "parameters",
    "dropout_initial",
    "dropout_final",
    "dropout_schedule",
    "n",
    "mean_train_eval_loss",
    "std_train_eval_loss",
    "mean_val_eval_loss",
    "std_val_eval_loss",
    "mean_generalization_gap",
    "std_generalization_gap",
]
# Names matching the keys of the dicts produced by `build_model_selection`.
# NOTE(review): presumably used as the CSV header for the selection table;
# the writer call is not visible here -- confirm.
SELECTION_FIELDS = [
    "run_mode",
    "token_limit",
    "model_name",
    "n_layer",
    "n_head",
    "n_embd",
    "parameters",
    "n",
    "best_dropout",
    "best_val_loss",
    "best_val_std",
    "plateau_start_dropout",
    "plateau_end_dropout",
    "plateau_delta",
    "zero_dropout_val_loss",
    "zero_minus_best",
    "best_nonzero_dropout",
    "best_nonzero_val_loss",
    "zero_minus_best_nonzero",
    "max_dropout",
    "max_dropout_val_loss",
    "max_dropout_minus_best",
    "has_nonzero_optimum",
    "meets_target_dropout",
    "curve_json",
]
@dataclass(frozen=True)
class ModelSpec:
    """Immutable model-size description: a label plus layer, head and width counts."""

    name: str
    n_layer: int
    n_head: int
    n_embd: int

    def config(self, vocab_size: int, block_size: int, dropout: float) -> GPTConfig:
        """Return a ``GPTConfig`` combining this architecture with the given run settings."""
        architecture = {
            "n_layer": self.n_layer,
            "n_head": self.n_head,
            "n_embd": self.n_embd,
        }
        return GPTConfig(
            block_size=block_size,
            vocab_size=vocab_size,
            dropout=dropout,
            **architecture,
        )

    def to_dict(self) -> dict[str, int | str]:
        """Return the spec as a flat dict; the name is stored under ``model_name``."""
        return dict(
            model_name=self.name,
            n_layer=self.n_layer,
            n_head=self.n_head,
            n_embd=self.n_embd,
        )
@dataclass(frozen=True)
class DropoutCondition:
    """One dropout treatment: a static rate, a token-based decay, or an anchor lookup."""

    name: str
    kind: str
    initial: float
    final: float
    schedule: str = "constant"
    decay_tokens: int | None = None
    anchors: tuple[tuple[int, float], ...] = ()

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict; anchor pairs are converted to lists."""
        payload = {
            "name": self.name,
            "kind": self.kind,
            "initial": self.initial,
            "final": self.final,
            "schedule": self.schedule,
            "decay_tokens": self.decay_tokens,
        }
        payload["anchors"] = list(map(list, self.anchors))
        return payload

    def make_fn(
        self,
        fallback_decay_tokens: int,
        unique_tokens: int | None = None,
    ) -> Callable[[int], float]:
        """Return a callable mapping tokens seen so far to a dropout probability.

        * ``static``: always ``initial``.
        * ``anchor_decay``: the value ``anchor_dropout`` gives for
          ``unique_tokens``, held constant; raises ``ValueError`` when
          ``unique_tokens`` is ``None``.
        * anything else: the ``value`` method of a ``DropoutDecayScheduler``
          built from this condition, using ``fallback_decay_tokens`` when
          ``decay_tokens`` is unset (or zero).
        """

        def constant(rate: float) -> Callable[[int], float]:
            # The rate is bound as a default so it is fixed at creation time.
            return lambda _tokens_seen, p=rate: p

        if self.kind == "static":
            return constant(self.initial)
        if self.kind == "anchor_decay":
            if unique_tokens is None:
                raise ValueError("anchor_decay conditions require unique_tokens")
            return constant(anchor_dropout(unique_tokens, self.anchors))
        decay_config = DropoutDecayConfig(
            initial_dropout=self.initial,
            final_dropout=self.final,
            decay_tokens=self.decay_tokens or fallback_decay_tokens,
            schedule=self.schedule,
        )
        return DropoutDecayScheduler(decay_config).value
def anchor_dropout(unique_tokens: int, anchors: tuple[tuple[int, float], ...]) -> float:
    """Interpolate a dropout value for ``unique_tokens`` from (token_count, dropout) anchors.

    Anchors are sorted first. Counts at or below the smallest anchor return
    that anchor's dropout, counts at or above the largest return the last
    one, and counts in between are interpolated linearly in log(token count)
    between the two surrounding anchors.

    Raises:
        ValueError: if ``anchors`` is empty.
    """
    if not anchors:
        raise ValueError("anchor dropout schedule requires at least one anchor")
    knots = sorted(anchors)
    if unique_tokens <= knots[0][0]:
        return knots[0][1]
    if unique_tokens >= knots[-1][0]:
        return knots[-1][1]
    target = math.log(unique_tokens)
    for index in range(len(knots) - 1):
        lo_tokens, lo_dropout = knots[index]
        hi_tokens, hi_dropout = knots[index + 1]
        if not (lo_tokens <= unique_tokens <= hi_tokens):
            continue
        lo_log = math.log(lo_tokens)
        hi_log = math.log(hi_tokens)
        fraction = (target - lo_log) / (hi_log - lo_log)
        return lo_dropout + fraction * (hi_dropout - lo_dropout)
    # Fallback kept for safety when no bracketing pair matched.
    return knots[-1][1]
def assert_mps_only() -> torch.device:
    """Return the MPS device after making it the torch default, or exit.

    Exits via ``SystemExit`` when ``PYTORCH_ENABLE_MPS_FALLBACK`` is ``"1"``,
    when PyTorch was built without MPS, or when MPS is unavailable.
    """
    fallback_flag = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK")
    if fallback_flag == "1":
        raise SystemExit(
            "Refusing to run: PYTORCH_ENABLE_MPS_FALLBACK=1 could execute CPU fallbacks."
        )
    mps_backend = torch.backends.mps
    if not mps_backend.is_built():
        raise SystemExit("PyTorch was not built with MPS. Stopping.")
    if not mps_backend.is_available():
        raise SystemExit("MPS is not available. Stopping before any Torch experiment work.")
    mps_device = torch.device("mps")
    torch.set_default_device(mps_device)
    return mps_device
def clean_name(value: str) -> str:
    """Turn ``value`` into a safe identifier.

    Surrounding whitespace is dropped, every character that is not
    alphanumeric, ``-`` or ``_`` becomes ``_``, and leading/trailing
    underscores are stripped. An empty result falls back to ``"model"``.
    """
    keep_as_is = {"-", "_"}
    pieces = []
    for char in value.strip():
        pieces.append(char if (char.isalnum() or char in keep_as_is) else "_")
    trimmed = "".join(pieces).strip("_")
    return trimmed if trimmed else "model"
def rate_label(rate: float) -> str:
    """Format ``rate`` with up to three decimals, dropping trailing zeros and a bare dot."""
    text = f"{rate:.3f}"
    text = text.rstrip("0")
    text = text.rstrip(".")
    return text or "0"
def parse_model_spec(raw: str) -> ModelSpec:
    """Parse ``LxHxD`` or ``name=LxHxD`` (commas also separate) into a ``ModelSpec``.

    Without an explicit name the spec is labelled ``L{n_layer}_H{n_head}_D{n_embd}``.

    Raises:
        argparse.ArgumentTypeError: when there are not exactly three
            dimensions, a dimension is not a positive integer, ``n_embd`` is
            not divisible by ``n_head``, or the per-head width is odd.
    """
    label = ""
    dims = raw
    if "=" in raw:
        raw_label, _, dims = raw.partition("=")
        label = clean_name(raw_label)
    fields = dims.lower().replace(",", "x").split("x")
    if len(fields) != 3:
        raise argparse.ArgumentTypeError(
            "model specs must look like 8x8x256 or name=8x8x256"
        )
    try:
        n_layer, n_head, n_embd = map(int, fields)
    except ValueError as exc:
        raise argparse.ArgumentTypeError("model dimensions must be integers") from exc
    if min(n_layer, n_head, n_embd) <= 0:
        raise argparse.ArgumentTypeError("model dimensions must be positive")
    head_dim, leftover = divmod(n_embd, n_head)
    if leftover != 0:
        raise argparse.ArgumentTypeError("n_embd must be divisible by n_head")
    if head_dim % 2 != 0:
        raise argparse.ArgumentTypeError("n_embd / n_head must be even for rotary attention")
    if not label:
        label = f"L{n_layer}_H{n_head}_D{n_embd}"
    return ModelSpec(label, n_layer, n_head, n_embd)
def parse_decay_spec(raw: str) -> DropoutCondition:
    """Parse ``name:initial:final[:schedule[:decay_tokens]]`` into a decay condition.

    ``schedule`` defaults to ``cosine`` and must be ``cosine``, ``smoothstep``
    or ``linear``. ``decay_tokens`` is optional; when it is omitted (or 0)
    ``DropoutCondition.make_fn`` substitutes its fallback value.

    Both dropout values must satisfy ``0 <= p < 1``, the same bound the
    anchor-decay parser enforces, so out-of-range or NaN values are rejected
    when the command line is parsed rather than partway through a run.

    Raises:
        argparse.ArgumentTypeError: on a malformed spec, non-numeric fields,
            an out-of-range dropout value, a negative ``decay_tokens`` or an
            unknown schedule name.
    """
    parts = raw.split(":")
    if len(parts) not in {3, 4, 5}:
        raise argparse.ArgumentTypeError(
            "decay specs must look like "
            "name:initial:final[:cosine|smoothstep|linear[:decay_tokens]]"
        )
    try:
        initial = float(parts[1])
        final = float(parts[2])
        decay_tokens = int(parts[4]) if len(parts) == 5 else None
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            "decay dropout values and decay_tokens must be numeric"
        ) from exc
    for label, value in (("initial", initial), ("final", final)):
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0.0 <= value < 1.0:
            raise argparse.ArgumentTypeError(
                f"decay {label} dropout must satisfy 0 <= p < 1"
            )
    if decay_tokens is not None and decay_tokens < 0:
        raise argparse.ArgumentTypeError("decay_tokens must not be negative")
    schedule = parts[3] if len(parts) >= 4 else "cosine"
    if schedule not in {"cosine", "smoothstep", "linear"}:
        raise argparse.ArgumentTypeError(
            "decay schedule must be cosine, smoothstep, or linear"
        )
    return DropoutCondition(
        name=clean_name(parts[0]),
        kind="decay",
        initial=initial,
        final=final,
        schedule=schedule,
        decay_tokens=decay_tokens,
    )
def parse_anchor_decay_spec(raw: str) -> DropoutCondition:
    """Parse ``name:tokens=dropout,tokens=dropout,...`` into an anchor-decay condition.

    Anchors are de-duplicated and sorted by token count. At least two are
    required and dropout must not increase as the token count grows. The
    result's ``initial``/``final`` are the dropout values of the first and
    last anchor.

    Raises:
        argparse.ArgumentTypeError: on a missing ``:`` or ``=``, non-numeric
            fields, a non-positive token count, a dropout outside
            ``0 <= p < 1``, fewer than two distinct anchors, or increasing
            dropout values.
    """
    if ":" not in raw:
        raise argparse.ArgumentTypeError(
            "anchor decay specs must look like name:250000=0.60,500000=0.40"
        )
    label, _, anchor_text = raw.partition(":")
    collected: list[tuple[int, float]] = []
    for chunk in anchor_text.split(","):
        if "=" not in chunk:
            raise argparse.ArgumentTypeError(
                "anchor decay anchors must look like token_count=dropout"
            )
        tokens_text, _, dropout_text = chunk.partition("=")
        try:
            token_count = int(tokens_text)
            dropout = float(dropout_text)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(
                "anchor token counts must be integers and dropout values numeric"
            ) from exc
        if token_count <= 0:
            raise argparse.ArgumentTypeError("anchor token counts must be positive")
        if not 0.0 <= dropout < 1.0:
            raise argparse.ArgumentTypeError("anchor dropout must satisfy 0 <= p < 1")
        collected.append((token_count, dropout))
    ordered = sorted(set(collected))
    if len(ordered) < 2:
        raise argparse.ArgumentTypeError("provide at least two dropout anchors")
    rises = any(
        later[1] > earlier[1] for earlier, later in zip(ordered, ordered[1:])
    )
    if rises:
        raise argparse.ArgumentTypeError(
            "anchor dropout values must be non-increasing as token counts grow"
        )
    first_anchor, last_anchor = ordered[0], ordered[-1]
    return DropoutCondition(
        name=clean_name(label),
        kind="anchor_decay",
        initial=first_anchor[1],
        final=last_anchor[1],
        schedule="log_prefix_anchor",
        anchors=tuple(ordered),
    )
def default_seeds(mode: str, seeds: list[int] | None) -> list[int]:
    """Return ``seeds`` when non-empty; otherwise ``[1]`` for ``screen_static`` and ``[1, 2, 3]`` for other modes."""
    if not seeds:
        return [1] if mode == "screen_static" else [1, 2, 3]
    return seeds
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's legacy global RNG and torch with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def batch_seed(seed: int, model: ModelSpec, dropout_code: int, stage: int | None) -> int:
    """Combine seed, model shape, dropout code and stage into one integer RNG seed.

    Each input is multiplied by a fixed weight and the products are summed;
    a ``None`` stage counts as 0.
    """
    weighted_terms = (
        (seed, 1_000_003),
        (model.n_layer, 100_003),
        (model.n_head, 10_007),
        (model.n_embd, 101),
        (dropout_code, 37),
        (stage or 0, 997),
    )
    return sum(value * weight for value, weight in weighted_terms)
def make_batch(
    tokens: np.ndarray,
    token_limit: int,
    batch_size: int,
    block_size: int,
    rng: np.random.Generator,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random windows from the first ``token_limit`` tokens.

    Returns ``(x, y)`` as ``torch.long`` tensors on ``device``; each ``y`` row
    is the matching ``x`` window shifted one token to the right.

    Raises:
        ValueError: when the usable prefix is too short for ``block_size``.
    """
    usable = min(token_limit, len(tokens))
    max_start = usable - block_size - 1
    if max_start <= 0:
        raise ValueError("token_limit is too small for the requested block_size")
    starts = rng.integers(0, max_start, size=batch_size)

    def gather(offset: int) -> torch.Tensor:
        # One window per sampled start, shifted by `offset` tokens.
        windows = [
            tokens[begin + offset : begin + offset + block_size] for begin in starts
        ]
        stacked = np.stack(windows).astype(np.int64)
        return torch.tensor(stacked, dtype=torch.long, device=device)

    inputs = gather(0)
    targets = gather(1)
    return inputs, targets
@torch.no_grad()
def estimate_loss(
    model: DropoutGPT,
    tokens: np.ndarray,
    token_limit: int,
    batches: int,
    args: argparse.Namespace,
    device: torch.device,
    rng_seed: int,
) -> float:
    """Return the mean loss over ``batches`` random batches, computed in eval mode.

    Batches come from a fresh NumPy generator seeded with ``rng_seed``, using
    ``args.batch_size`` and ``args.block_size``. The model is put back into
    training mode before returning. Returns NaN when ``batches`` is not
    positive, without touching the model.
    """
    if batches <= 0:
        return float("nan")
    model.eval()
    sampler = np.random.default_rng(rng_seed)
    collected: list[float] = []
    remaining = batches
    while remaining > 0:
        inputs, targets = make_batch(
            tokens, token_limit, args.batch_size, args.block_size, sampler, device
        )
        batch_loss = model(inputs, targets)[1]
        collected.append(float(batch_loss.item()))
        remaining -= 1
    model.train()
    return float(statistics.fmean(collected))
def write_jsonl_row(handle, row: dict) -> None:
    """Write ``row`` as one key-sorted JSON line to ``handle`` and flush it."""
    encoded = json.dumps(row, sort_keys=True)
    handle.write(f"{encoded}\n")
    handle.flush()
def metric_key(row: dict) -> tuple:
    """Return the identity tuple of a metrics row.

    Counts are coerced to ``int`` and dropout values to ``float``; a missing
    ``stage`` becomes ``None``.
    """
    mode = row["run_mode"]
    condition_name = row["condition"]
    condition_kind = row["condition_kind"]
    stage = row.get("stage")
    token_limit = int(row["token_limit"])
    model_name = row["model_name"]
    n_layer = int(row["n_layer"])
    n_head = int(row["n_head"])
    n_embd = int(row["n_embd"])
    seed = int(row["seed"])
    dropout_initial = float(row["dropout_initial"])
    dropout_final = float(row["dropout_final"])
    schedule = row["dropout_schedule"]
    return (
        mode,
        condition_name,
        condition_kind,
        stage,
        token_limit,
        model_name,
        n_layer,
        n_head,
        n_embd,
        seed,
        dropout_initial,
        dropout_final,
        schedule,
    )
def planned_metric_key(
    *,
    mode: str,
    condition: DropoutCondition,
    model_spec: ModelSpec,
    seed: int,
    token_limit: int,
    stage: int | None = None,
) -> tuple:
    """Return the identity tuple for a planned run, in the same field order as ``metric_key``."""
    condition_part = (condition.name, condition.kind, stage)
    model_part = (
        model_spec.name,
        int(model_spec.n_layer),
        int(model_spec.n_head),
        int(model_spec.n_embd),
    )
    dropout_part = (
        float(condition.initial),
        float(condition.final),
        condition.schedule,
    )
    return (
        (mode,)
        + condition_part
        + (int(token_limit),)
        + model_part
        + (int(seed),)
        + dropout_part
    )
def load_metrics(path: Path) -> list[dict]:
    """Load a JSONL metrics file into a list of dicts.

    Returns an empty list when ``path`` does not exist; blank lines are skipped.
    """
    if not path.exists():
        return []
    content = path.read_text(encoding="utf-8")
    return [json.loads(entry) for entry in content.splitlines() if entry.strip()]
def train_segment(
    *,
    run_mode: str,
    condition: DropoutCondition,
    model_spec: ModelSpec,
    config: GPTConfig,
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    token_limit: int,
    steps: int,
    seed: int,
    args: argparse.Namespace,
    device: torch.device,
    dropout_fn: Callable[[int], float],
    metrics_file,
    trace_file,
    stage: int | None = None,
    model: DropoutGPT | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    tokens_seen_start: int = 0,
) -> tuple[DropoutGPT, torch.optim.Optimizer, int, dict]:
    """Train for ``steps`` optimizer steps and record one final metrics row.

    When ``model`` is None, a new ``DropoutGPT`` and AdamW optimizer are
    created after seeding all RNGs. Otherwise the given model/optimizer are
    reused and only torch is re-seeded, with an offset that depends on
    ``stage``.

    Each step sets the model's dropout to ``dropout_fn(tokens_seen)``,
    samples a batch from the first ``token_limit`` training tokens, and
    applies one update. Per-step and periodic-eval events go to
    ``trace_file``; the final summary row goes to ``metrics_file``.

    Returns:
        ``(model, optimizer, tokens_seen, row)`` where ``tokens_seen`` is the
        cumulative token counter (starting at ``tokens_seen_start``) and
        ``row`` is the dict written to ``metrics_file``.

    Raises:
        ValueError: if ``model`` is given without an ``optimizer``.
    """
    if model is None:
        # Fresh run: seed everything before the model is constructed.
        set_seed(seed)
        model = DropoutGPT(config).to(device)
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.95),
            weight_decay=args.weight_decay,
        )
    else:
        # Continuing with an existing model: only torch is re-seeded here.
        torch.manual_seed(seed + 10_000 + (stage or 0))
        if optimizer is None:
            raise ValueError("optimizer is required when reusing a model")
    model.train()
    # Integer code for the initial dropout (p * 10_000), mixed into the batch seed.
    dropout_code = int(round(condition.initial * 10_000))
    rng = np.random.default_rng(batch_seed(seed, model_spec, dropout_code, stage))
    tokens_seen = tokens_seen_start
    last_loss = float("nan")
    active_dropout = condition.initial
    t0 = time.time()
    for step in range(1, steps + 1):
        # Dropout is read from the counter BEFORE this step's tokens are added.
        active_dropout = dropout_fn(tokens_seen)
        model.set_dropout(active_dropout)
        x, y = make_batch(train_tokens, token_limit, args.batch_size, args.block_size, rng, device)
        _, loss = model(x, y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clipping is skipped when grad_clip is zero or negative.
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()
        tokens_seen += args.batch_size * args.block_size
        last_loss = float(loss.item())
        if args.log_every > 0 and step % args.log_every == 0:
            write_jsonl_row(
                trace_file,
                {
                    "event": "train_step",
                    "run_mode": run_mode,
                    "condition": condition.name,
                    "model_name": model_spec.name,
                    "seed": seed,
                    "stage": stage,
                    "step": step,
                    "steps": steps,
                    "token_limit": int(token_limit),
                    "tokens_seen": int(tokens_seen),
                    "dropout": float(active_dropout),
                    "train_batch_loss": last_loss,
                },
            )
        if args.eval_every > 0 and step % args.eval_every == 0:
            # Periodic evaluation; the eval RNG seed varies with the step number.
            train_eval = estimate_loss(
                model,
                train_tokens,
                token_limit,
                args.trace_eval_batches,
                args,
                device,
                rng_seed=seed + 20_000 + step,
            )
            # Validation samples from the whole validation array, not the prefix.
            val_eval = estimate_loss(
                model,
                val_tokens,
                len(val_tokens),
                args.trace_eval_batches,
                args,
                device,
                rng_seed=seed + 30_000 + step,
            )
            write_jsonl_row(
                trace_file,
                {
                    "event": "eval_step",
                    "run_mode": run_mode,
                    "condition": condition.name,
                    "model_name": model_spec.name,
                    "seed": seed,
                    "stage": stage,
                    "step": step,
                    "steps": steps,
                    "token_limit": int(token_limit),
                    "tokens_seen": int(tokens_seen),
                    "dropout": float(active_dropout),
                    "train_eval_loss": train_eval,
                    "val_eval_loss": val_eval,
                    "generalization_gap": val_eval - train_eval,
                },
            )
    # Final evaluation; seeds depend on the stage rather than the step.
    train_eval = estimate_loss(
        model,
        train_tokens,
        token_limit,
        args.train_eval_batches,
        args,
        device,
        rng_seed=seed + 40_000 + (stage or 0),
    )
    val_eval = estimate_loss(
        model,
        val_tokens,
        len(val_tokens),
        args.eval_batches,
        args,
        device,
        rng_seed=seed + 50_000 + (stage or 0),
    )
    row = {
        "run_mode": run_mode,
        "condition": condition.name,
        "condition_kind": condition.kind,
        "seed": seed,
        "stage": stage,
        "token_limit": int(token_limit),
        "steps": int(steps),
        "tokens_seen": int(tokens_seen),
        "dropout_initial": float(condition.initial),
        "dropout_final": float(condition.final),
        "dropout_schedule": condition.schedule,
        # Dropout used on the last step (condition.initial if steps == 0).
        "dropout_active_final": float(active_dropout),
        "train_loss_last": last_loss,
        "train_eval_loss": train_eval,
        "val_eval_loss": val_eval,
        # Same value as val_eval_loss, stored under a second key.
        "eval_loss": val_eval,
        "generalization_gap": val_eval - train_eval,
        # Includes the final evaluations, since t0 was taken before the loop.
        "elapsed_sec": time.time() - t0,
        "parameters": model.num_parameters(),
        "model_config": config.to_dict(),
        **model_spec.to_dict(),
    }
    write_jsonl_row(metrics_file, row)
    return model, optimizer, tokens_seen, row
def summarize(rows: list[dict]) -> list[dict]:
    """Aggregate per-run metric rows into one summary row per configuration.

    Rows sharing run mode, condition, stage, prefix, model shape, parameter
    count and dropout settings form one group (seeds are pooled). Each
    summary row carries the group's identifying fields, ``n``, and the mean
    and sample standard deviation (0.0 for a single row) of train loss,
    validation loss and generalization gap. The result is sorted by run
    mode, model name, token limit, condition and stage.
    """
    group_fields = (
        "run_mode",
        "condition",
        "condition_kind",
        "stage",
        "token_limit",
        "model_name",
        "n_layer",
        "n_head",
        "n_embd",
        "parameters",
        "dropout_initial",
        "dropout_final",
        "dropout_schedule",
    )
    metric_columns = (
        ("train_eval_loss", "mean_train_eval_loss", "std_train_eval_loss"),
        ("val_eval_loss", "mean_val_eval_loss", "std_val_eval_loss"),
        ("generalization_gap", "mean_generalization_gap", "std_generalization_gap"),
    )

    buckets: dict[tuple, list[dict]] = defaultdict(list)
    for row in rows:
        buckets[tuple(row[name] for name in group_fields)].append(row)

    summary: list[dict] = []
    for members in buckets.values():
        head = members[0]
        item = {}
        for field in SUMMARY_FIELDS:
            if field in head:
                item[field] = head[field]
        item["n"] = len(members)
        for source, mean_key, std_key in metric_columns:
            samples = [float(member[source]) for member in members]
            item[mean_key] = statistics.fmean(samples)
            item[std_key] = 0.0 if len(samples) <= 1 else statistics.stdev(samples)
        summary.append(item)

    def ordering(item: dict) -> tuple:
        # Rows without a stage (None) sort as -1; so does stage 0.
        return (
            item["run_mode"],
            item["model_name"],
            item["token_limit"],
            item["condition"],
            item["stage"] or -1,
        )

    summary.sort(key=ordering)
    return summary
def write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header; keys outside ``fieldnames`` are dropped."""
    with path.open("w", newline="", encoding="utf-8") as sink:
        table = csv.DictWriter(sink, fieldnames=fieldnames, extrasaction="ignore")
        table.writeheader()
        for record in rows:
            table.writerow(record)
def format_duration(seconds: float) -> str:
    """Render a duration compactly: ``XhMMm``, ``MmSSs`` or ``Ss``.

    Fractions are truncated and negative inputs count as zero. Seconds are
    omitted once the duration reaches an hour.
    """
    whole = max(0, int(seconds))
    hours = whole // 3600
    minutes = (whole % 3600) // 60
    secs = whole % 60
    if hours:
        return f"{hours}h{minutes:02d}m"
    if minutes:
        return f"{minutes}m{secs:02d}s"
    return f"{secs}s"
class ProgressMeter:
    """Counts finished runs and prints a one-line status with an ETA after each one."""

    def __init__(self, total: int):
        """Start the clock; a negative ``total`` is treated as zero."""
        self.total = total if total > 0 else 0
        self.done = 0
        self.started_at = time.time()

    def mark_done(self, row: dict) -> None:
        """Record one finished run and print its progress line.

        The ETA is the mean wall time per finished run multiplied by the
        number of runs still outstanding.
        """
        self.done += 1
        elapsed = time.time() - self.started_at
        per_run = elapsed / self.done if self.done else 0.0
        outstanding = max(0, self.total - self.done)
        eta = outstanding * per_run

        fields = [
            "progress",
            f"{self.done}/{self.total}",
            f"eta={format_duration(eta)}",
            f"mode={row['run_mode']}",
            f"model={row['model_name']}",
            f"params={int(row['parameters']):,}",
            f"prefix={int(row['token_limit']):,}",
        ]
        stage = row.get("stage")
        if stage is not None:
            # The stage field only appears for staged runs.
            fields.append(f"stage={stage}")
        fields.extend(
            [
                f"seed={row['seed']}",
                f"condition={row['condition']}",
                f"dropout={float(row['dropout_active_final']):.3f}",
                f"val={row['val_eval_loss']:.4f}",
                f"train={row['train_eval_loss']:.4f}",
                f"gap={row['generalization_gap']:.4f}",
                f"elapsed={format_duration(float(row['elapsed_sec']))}",
            ]
        )
        print(" ".join(fields), flush=True)
def build_model_selection(summary: list[dict], args: argparse.Namespace) -> list[dict]:
    """Summarise each static-dropout curve into one selection row.

    Only summary rows with ``condition_kind == "static"`` and no stage are
    used, grouped by (run mode, token limit, model name). For each group the
    row reports the best dropout by mean validation loss, the dropout range
    within ``args.plateau_delta`` of that best, comparisons against zero and
    maximum dropout, and the full curve as JSON.

    ``has_nonzero_optimum`` requires the best dropout to be positive, the
    zero-dropout loss to exceed the best by at least
    ``args.min_nonzero_margin``, and the highest-dropout loss to exceed it by
    at least ``args.min_high_dropout_margin``. ``meets_target_dropout``
    compares the best dropout with ``args.target_min_dropout``.

    Rows with a non-zero optimum sort first, then rows meeting the target,
    then by ascending best validation loss.
    """

    def rate_of(row: dict):
        return row["dropout_initial"]

    def val_of(row: dict):
        return row["mean_val_eval_loss"]

    curves: dict[tuple, list[dict]] = defaultdict(list)
    for row in summary:
        if row["condition_kind"] != "static" or row["stage"] is not None:
            continue
        curves[(row["run_mode"], row["token_limit"], row["model_name"])].append(row)

    selection: list[dict] = []
    for (run_mode, token_limit, model_name), members in curves.items():
        curve = sorted(members, key=rate_of)
        best = min(curve, key=val_of)
        best_loss = best["mean_val_eval_loss"]
        plateau_rates = [
            rate_of(row) for row in curve if val_of(row) <= best_loss + args.plateau_delta
        ]

        zero = next((row for row in curve if rate_of(row) == 0.0), None)
        positive = [row for row in curve if rate_of(row) > 0.0]
        best_nonzero = min(positive, key=val_of) if positive else None
        highest = max(curve, key=rate_of)

        zero_loss = zero["mean_val_eval_loss"] if zero else None
        if zero_loss is not None:
            zero_minus_best = zero_loss - best["mean_val_eval_loss"]
        else:
            zero_minus_best = None
        if zero_loss is not None and best_nonzero is not None:
            zero_minus_best_nonzero = zero_loss - best_nonzero["mean_val_eval_loss"]
        else:
            zero_minus_best_nonzero = None
        max_minus_best = highest["mean_val_eval_loss"] - best["mean_val_eval_loss"]

        has_nonzero_optimum = (
            best["dropout_initial"] > 0.0
            and zero_minus_best is not None
            and zero_minus_best >= args.min_nonzero_margin
            and max_minus_best >= args.min_high_dropout_margin
        )

        curve_points = []
        for row in curve:
            curve_points.append(
                {
                    "dropout": row["dropout_initial"],
                    "mean_val_loss": row["mean_val_eval_loss"],
                    "std_val_loss": row["std_val_eval_loss"],
                    "mean_train_loss": row["mean_train_eval_loss"],
                    "mean_generalization_gap": row["mean_generalization_gap"],
                    "n": row["n"],
                }
            )
        curve_json = json.dumps(curve_points, sort_keys=True)

        entry = {
            "run_mode": run_mode,
            "token_limit": token_limit,
            "model_name": model_name,
            "n_layer": best["n_layer"],
            "n_head": best["n_head"],
            "n_embd": best["n_embd"],
            "parameters": best["parameters"],
            "n": best["n"],
            "best_dropout": best["dropout_initial"],
            "best_val_loss": best["mean_val_eval_loss"],
            "best_val_std": best["std_val_eval_loss"],
            "plateau_start_dropout": min(plateau_rates),
            "plateau_end_dropout": max(plateau_rates),
            "plateau_delta": args.plateau_delta,
            "zero_dropout_val_loss": zero_loss,
            "zero_minus_best": zero_minus_best,
            "best_nonzero_dropout": (
                best_nonzero["dropout_initial"] if best_nonzero else None
            ),
            "best_nonzero_val_loss": (
                best_nonzero["mean_val_eval_loss"] if best_nonzero else None
            ),
            "zero_minus_best_nonzero": zero_minus_best_nonzero,
            "max_dropout": highest["dropout_initial"],
            "max_dropout_val_loss": highest["mean_val_eval_loss"],
            "max_dropout_minus_best": max_minus_best,
            "has_nonzero_optimum": has_nonzero_optimum,
            "meets_target_dropout": best["dropout_initial"] >= args.target_min_dropout,
            "curve_json": curve_json,
        }
        selection.append(entry)

    def ranking(entry: dict) -> tuple:
        return (
            not entry["has_nonzero_optimum"],
            not entry["meets_target_dropout"],
            entry["best_val_loss"],
        )

    selection.sort(key=ranking)
    return selection
def write_screen_markdown_summary(output_dir: Path, rows: list[dict]) -> None:
    """Write ``RESULT_SUMMARY.md`` for the static dropout screen into ``output_dir``.

    Only rows with ``run_mode == "screen_static"`` and
    ``condition_kind == "static"`` are used; nothing is written when there
    are none. The report has a models table, a best-dropout table per
    (model, prefix), and one detailed table per model and prefix. Statistics
    pool the seeds of each (model, prefix, dropout) cell.
    """
    if not rows:
        return
    static_rows = [
        row
        for row in rows
        if row["run_mode"] == "screen_static" and row["condition_kind"] == "static"
    ]
    if not static_rows:
        return

    def mean_and_std(values: list[float]) -> tuple[float, float]:
        # Sample standard deviation; a single value reports 0.0.
        mean = statistics.fmean(values)
        spread = statistics.stdev(values) if len(values) > 1 else 0.0
        return mean, spread

    # Pool seeds per (model, prefix, dropout) cell.
    cells: dict[tuple[str, int, float], list[dict]] = defaultdict(list)
    for row in static_rows:
        cell_key = (
            row["model_name"],
            int(row["token_limit"]),
            float(row["dropout_initial"]),
        )
        cells[cell_key].append(row)

    aggregates: list[dict] = []
    for (model_name, prefix, dropout), members in cells.items():
        head = members[0]
        val_mean, val_std = mean_and_std([float(m["val_eval_loss"]) for m in members])
        train_mean, train_std = mean_and_std(
            [float(m["train_eval_loss"]) for m in members]
        )
        gap_mean, gap_std = mean_and_std(
            [float(m["generalization_gap"]) for m in members]
        )
        aggregates.append(
            {
                "model_name": model_name,
                "token_limit": prefix,
                "dropout_initial": dropout,
                "n": len(members),
                "mean_val_eval_loss": val_mean,
                "std_val_eval_loss": val_std,
                "mean_train_eval_loss": train_mean,
                "std_train_eval_loss": train_std,
                "mean_generalization_gap": gap_mean,
                "std_generalization_gap": gap_std,
                "parameters": int(head["parameters"]),
                "n_layer": int(head["n_layer"]),
                "n_head": int(head["n_head"]),
                "n_embd": int(head["n_embd"]),
                "block_size": int(head["model_config"]["block_size"]),
                "vocab_size": int(head["model_config"]["vocab_size"]),
                "tokens_seen": int(head["tokens_seen"]),
                "seeds": sorted({int(m["seed"]) for m in members}),
            }
        )

    by_model: dict[str, list[dict]] = defaultdict(list)
    for aggregate in aggregates:
        by_model[aggregate["model_name"]].append(aggregate)

    model_rows = []
    for model_name, model_group in by_model.items():
        head = model_group[0]
        all_seeds = set()
        for aggregate in model_group:
            all_seeds.update(aggregate["seeds"])
        model_rows.append(
            {
                "model_name": model_name,
                "parameters": head["parameters"],
                "n_layer": head["n_layer"],
                "n_head": head["n_head"],
                "n_embd": head["n_embd"],
                "block_size": head["block_size"],
                "vocab_size": head["vocab_size"],
                "seeds": sorted(all_seeds),
            }
        )

    def split_by_prefix(model_group: list[dict]) -> dict[int, list[dict]]:
        grouped: dict[int, list[dict]] = defaultdict(list)
        for aggregate in model_group:
            grouped[int(aggregate["token_limit"])].append(aggregate)
        return grouped

    lines = [
        "# Static Dropout Screen Summary",
        "",
        f"Run directory: `{output_dir}`",
        "",
        "## Models",
        "",
        "| Model | Params | Layers | Heads | Embedding | Block | Vocab | Seeds |",
        "|---|---:|---:|---:|---:|---:|---:|---|",
    ]
    for model in sorted(model_rows, key=lambda item: item["parameters"]):
        seed_text = ", ".join(str(seed) for seed in model["seeds"])
        lines.append(
            "| "
            f"`{model['model_name']}` | {model['parameters']:,} | "
            f"{model['n_layer']} | {model['n_head']} | {model['n_embd']} | "
            f"{model['block_size']} | {model['vocab_size']} | "
            f"{seed_text} |"
        )

    lines += [
        "",
        "## Best Dropout By Model And Prefix",
        "",
        "| Model | Prefix tokens | Effective epochs | Best dropout | Mean val loss | Val std | Mean train loss | Mean gap | Plateau/bracket note |",
        "|---|---:|---:|---:|---:|---:|---:|---:|---|",
    ]
    for model_name in sorted(by_model):
        by_prefix = split_by_prefix(by_model[model_name])
        for prefix in sorted(by_prefix):
            prefix_rows = by_prefix[prefix]
            best = min(prefix_rows, key=lambda item: item["mean_val_eval_loss"])
            rates = [float(item["dropout_initial"]) for item in prefix_rows]
            eff_epochs = float(best["tokens_seen"]) / prefix
            # Say whether the optimum sits on the edge of the tested grid.
            if best["dropout_initial"] == max(rates):
                note = "not bracketed; best at top of tested grid"
            elif best["dropout_initial"] == min(rates):
                note = "not bracketed; best at bottom of tested grid"
            else:
                note = "bracketed by tested grid"
            lines.append(
                "| "
                f"`{model_name}` | {prefix:,} | {eff_epochs:.2f} | "
                f"{best['dropout_initial']:.2f} | "
                f"{best['mean_val_eval_loss']:.4f} | "
                f"{best['std_val_eval_loss']:.4f} | "
                f"{best['mean_train_eval_loss']:.4f} | "
                f"{best['mean_generalization_gap']:.4f} | {note} |"
            )

    for model_name in sorted(by_model):
        by_prefix = split_by_prefix(by_model[model_name])
        lines += ["", f"## Model `{model_name}`"]
        for prefix in sorted(by_prefix):
            prefix_rows = by_prefix[prefix]
            eff_epochs = float(prefix_rows[0]["tokens_seen"]) / prefix
            lines += [
                "",
                f"### Prefix {prefix:,} Tokens ({eff_epochs:.2f} Effective Epochs)",
                "",
                "| Dropout | N | Mean val loss | Val std | Mean train loss | Train std | Mean gap | Gap std | Sampled tokens | Params |",
                "|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
            ]
            for item in sorted(prefix_rows, key=lambda entry: entry["dropout_initial"]):
                lines.append(
                    "| "
                    f"{item['dropout_initial']:.2f} | {item['n']} | "
                    f"{item['mean_val_eval_loss']:.4f} | "
                    f"{item['std_val_eval_loss']:.4f} | "
                    f"{item['mean_train_eval_loss']:.4f} | "
                    f"{item['std_train_eval_loss']:.4f} | "
                    f"{item['mean_generalization_gap']:.4f} | "
                    f"{item['std_generalization_gap']:.4f} | "
                    f"{int(item['tokens_seen']):,} | {int(item['parameters']):,} |"
                )

    report = "\n".join(lines) + "\n"
    (output_dir / "RESULT_SUMMARY.md").write_text(report, encoding="utf-8")
def svg_escape(value: object) -> str:
    """Return ``str(value)`` with ``&``, ``<``, ``>`` and ``"`` replaced by XML entities."""
    replacements = (
        ("&", "&amp;"),  # must run first so later entities are not re-escaped
        ("<", "&lt;"),
        (">", "&gt;"),
        ('"', "&quot;"),
    )
    text = str(value)
    for raw, entity in replacements:
        text = text.replace(raw, entity)
    return text
def write_dropout_curve_svg(output_dir: Path, summary: list[dict]) -> None:
    """Write ``dropout_curves.svg``: a grid of validation-loss-vs-dropout panels.

    Uses only summary rows with ``run_mode == "screen_static"`` and
    ``condition_kind == "static"``; nothing is written when there are none.
    Grid rows are models ordered by parameter count and grid columns are
    prefix sizes (token limits) in ascending order. Each panel has its own
    y-range and marks the lowest-loss point with a larger filled circle.
    """
    rows = [
        row
        for row in summary
        if row["run_mode"] == "screen_static" and row["condition_kind"] == "static"
    ]
    if not rows:
        return
    # One curve per (model, prefix); also record each model's parameter count.
    grouped: dict[tuple[str, int], list[dict]] = defaultdict(list)
    model_params: dict[str, int] = {}
    for row in rows:
        model_name = row["model_name"]
        grouped[(model_name, int(row["token_limit"]))].append(row)
        model_params[model_name] = int(row["parameters"])
    models = sorted(model_params, key=lambda name: model_params[name])
    prefixes = sorted({int(row["token_limit"]) for row in rows})
    # Layout constants in SVG user units. margin_t is assigned but not
    # referenced anywhere below.
    panel_w, panel_h = 230, 170
    margin_l, margin_t = 58, 34
    plot_w, plot_h = 142, 94
    gap_x, gap_y = 18, 38
    width = margin_l + len(prefixes) * panel_w + gap_x
    height = 70 + len(models) * (panel_h + gap_y)
    # Stroke colour is chosen per model row and cycles through this list.
    colors = ["#1f77b4", "#d62728", "#2ca02c", "#9467bd", "#ff7f0e"]
    parts = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        "<style>",
        "text{font-family:Arial,Helvetica,sans-serif;fill:#111827}",
        ".small{font-size:10px}.label{font-size:11px}.title{font-size:15px;font-weight:700}",
        ".axis{stroke:#374151;stroke-width:1}.grid{stroke:#e5e7eb;stroke-width:1}.line{fill:none;stroke-width:2}",
        "</style>",
        '<rect width="100%" height="100%" fill="#ffffff"/>',
        '<text x="24" y="28" class="title">Static dropout law: validation loss vs dropout</text>',
        '<text x="24" y="48" class="label">Each panel uses its own y-scale. Points are one-seed means unless N &gt; 1.</text>',
    ]
    # Column headers: one centred label per prefix size.
    for col, prefix in enumerate(prefixes):
        x = margin_l + col * panel_w + plot_w / 2
        parts.append(
            f'<text x="{x:.1f}" y="70" text-anchor="middle" class="label">{prefix:,} prefix tokens</text>'
        )
    for row_idx, model_name in enumerate(models):
        row_y = 92 + row_idx * (panel_h + gap_y)
        # Rotated row label: model name plus parameter count in millions.
        parts.append(
            f'<text x="24" y="{row_y + 48}" class="label" transform="rotate(-90 24 {row_y + 48})">'
            f'{svg_escape(model_name)} ({model_params[model_name] / 1_000_000:.1f}M)</text>'
        )
        for col, prefix in enumerate(prefixes):
            panel_x = margin_l + col * panel_w
            panel_y = row_y
            curve = sorted(
                grouped.get((model_name, prefix), []),
                key=lambda item: float(item["dropout_initial"]),
            )
            # Leave the panel empty when this model has no rows for this prefix.
            if not curve:
                continue
            losses = [float(item["mean_val_eval_loss"]) for item in curve]
            min_loss, max_loss = min(losses), max(losses)
            # Padding is at least 0.02, so y_max - y_min is never zero.
            pad = max(0.02, (max_loss - min_loss) * 0.08)
            y_min, y_max = min_loss - pad, max_loss + pad
            best = min(curve, key=lambda item: float(item["mean_val_eval_loss"]))

            def px(dropout: float) -> float:
                # The x-axis spans dropout 0 to 0.9 across the plot width.
                return panel_x + (dropout / 0.9) * plot_w

            def py(loss: float) -> float:
                # SVG y grows downward, so larger losses map to smaller y.
                scale = (loss - y_min) / (y_max - y_min)
                return panel_y + plot_h - scale * plot_h

            # Axes, a top grid line, and the y-range / x-range tick labels.
            parts.extend(
                [
                    f'<line x1="{panel_x:.1f}" y1="{panel_y:.1f}" x2="{panel_x:.1f}" y2="{panel_y + plot_h:.1f}" class="axis"/>',
                    f'<line x1="{panel_x:.1f}" y1="{panel_y + plot_h:.1f}" x2="{panel_x + plot_w:.1f}" y2="{panel_y + plot_h:.1f}" class="axis"/>',
                    f'<line x1="{panel_x:.1f}" y1="{panel_y:.1f}" x2="{panel_x + plot_w:.1f}" y2="{panel_y:.1f}" class="grid"/>',
                    f'<text x="{panel_x:.1f}" y="{panel_y - 6:.1f}" class="small">{y_max:.2f}</text>',
                    f'<text x="{panel_x:.1f}" y="{panel_y + plot_h + 13:.1f}" class="small">{y_min:.2f}</text>',
                    f'<text x="{panel_x:.1f}" y="{panel_y + plot_h + 28:.1f}" class="small">0</text>',
                    f'<text x="{panel_x + plot_w:.1f}" y="{panel_y + plot_h + 28:.1f}" text-anchor="end" class="small">0.9</text>',
                ]
            )
            points = " ".join(
                f"{px(float(item['dropout_initial'])):.1f},{py(float(item['mean_val_eval_loss'])):.1f}"
                for item in curve
            )
            color = colors[row_idx % len(colors)]
            parts.append(f'<polyline points="{points}" class="line" stroke="{color}"/>')
            # Markers: the best point is larger and filled dark, others are hollow.
            for item in curve:
                dropout = float(item["dropout_initial"])
                loss = float(item["mean_val_eval_loss"])
                radius = 4 if item is best else 2.7
                fill = "#111827" if item is best else "#ffffff"
                parts.append(
                    f'<circle cx="{px(dropout):.1f}" cy="{py(loss):.1f}" r="{radius}" fill="{fill}" stroke="{color}" stroke-width="1.5"/>'
                )
            # Annotation to the right of the plot area with the best point.
            parts.append(
                f'<text x="{panel_x + plot_w + 8:.1f}" y="{panel_y + 14:.1f}" class="small">'
                f'best p={float(best["dropout_initial"]):.2f}</text>'
            )
            parts.append(
                f'<text x="{panel_x + plot_w + 8:.1f}" y="{panel_y + 28:.1f}" class="small">'
                f'loss={float(best["mean_val_eval_loss"]):.3f}</text>'
            )
    parts.append("</svg>")
    (output_dir / "dropout_curves.svg").write_text("\n".join(parts), encoding="utf-8")
def write_stream_markdown_summary(output_dir: Path, rows: list[dict]) -> None:
    """Write ``RESULT_SUMMARY.md`` describing the locked-stream metric rows.

    Only rows whose ``run_mode`` equals ``"locked_stream"`` are used; when
    there are none the function returns without writing a file.  The report
    contains a header built from the first locked-stream row, a "Condition
    Ranking" table ordered by the mean of the per-stage validation losses, and
    one "Stage Trajectory" table per stage ordered by mean validation loss.

    Args:
        output_dir: Directory that receives ``RESULT_SUMMARY.md``.
        rows: Metric rows as dicts.  Every row must carry ``run_mode``; the
            locked-stream rows must carry the fields read below.

    NOTE(review): rows are grouped by (condition, stage) only, so rows from
    several models would be averaged together -- confirm this is only used
    with a single model per run.
    """
    locked = [entry for entry in rows if entry["run_mode"] == "locked_stream"]
    if not locked:
        return

    def mean_of(group: list[dict], field: str) -> float:
        # Average one numeric field over a group of rows.
        return statistics.fmean(float(entry[field]) for entry in group)

    per_cell: dict[tuple[str, int], list[dict]] = defaultdict(list)
    per_condition: dict[str, list[dict]] = defaultdict(list)
    for entry in locked:
        label = entry["condition"]
        per_cell[(label, int(entry["stage"]))].append(entry)
        per_condition[label].append(entry)

    head = locked[0]
    seed_list = sorted({int(entry["seed"]) for entry in locked})

    def display_order(label: str) -> tuple:
        # Anchor-decay conditions first, then by initial dropout, then by name.
        lead = per_condition[label][0]
        return (
            lead["condition_kind"] != "anchor_decay",
            lead["dropout_initial"],
            label,
        )

    ordered = sorted(per_condition, key=display_order)
    stage_ids = sorted({int(entry["stage"]) for entry in locked})

    model_line = (
        f"Model: `{head['model_name']}` causal Transformer, "
        f"{int(head['parameters']):,} parameters, {head['n_layer']} layers, "
        f"{head['n_head']} heads, {head['n_embd']} embedding dim."
    )
    training_line = (
        f"Training per stage: {head['steps']:,} steps. "
        "Sampled tokens are cumulative in each stage row. "
        f"Seeds present: {', '.join(map(str, seed_list))}."
    )
    lines = [
        "# Locked Streaming Dropout Summary",
        "",
        f"Run directory: `{output_dir}`",
        "",
        model_line,
        training_line,
        "",
        "## Condition Ranking",
        "",
        "| Condition | Kind | Final dropout | Mean trajectory val loss | Final val loss | Final gap | Dropout path |",
        "|---|---|---:|---:|---:|---:|---|",
    ]

    # One ranking entry per condition, aggregated over the stages it has rows for.
    ranking = []
    for label in ordered:
        trail = []
        for stage in stage_ids:
            group = per_cell.get((label, stage), [])
            if group:
                trail.append(
                    {
                        "stage": stage,
                        "token_limit": int(group[0]["token_limit"]),
                        "mean_val": mean_of(group, "val_eval_loss"),
                        "mean_gap": mean_of(group, "generalization_gap"),
                        "mean_dropout": mean_of(group, "dropout_active_final"),
                        "kind": group[0]["condition_kind"],
                    }
                )
        if not trail:
            continue
        # stage_ids is ascending and unique, so the last entry is the highest stage.
        closing = trail[-1]
        ranking.append(
            {
                "condition": label,
                "kind": trail[0]["kind"],
                "trajectory_val": statistics.fmean(
                    step["mean_val"] for step in trail
                ),
                "final_val": closing["mean_val"],
                "final_gap": closing["mean_gap"],
                "final_dropout": closing["mean_dropout"],
                "dropout_path": " -> ".join(
                    f"{step['mean_dropout']:.2f}" for step in trail
                ),
            }
        )

    ranking.sort(key=lambda ranked: ranked["trajectory_val"])
    for ranked in ranking:
        lines.append(
            f"| `{ranked['condition']}` | {ranked['kind']} | "
            f"{ranked['final_dropout']:.2f} | {ranked['trajectory_val']:.4f} | "
            f"{ranked['final_val']:.4f} | {ranked['final_gap']:.4f} | "
            f"{ranked['dropout_path']} |"
        )

    lines.extend(["", "## Stage Trajectory", ""])
    for stage in stage_ids:
        present = [
            (label, per_cell[(label, stage)])
            for label in ordered
            if (label, stage) in per_cell
        ]
        if not present:
            continue
        prefix = int(present[0][1][0]["token_limit"])
        lines.extend(
            [
                f"### Stage {stage}: {prefix:,} Prefix Tokens",
                "",
                "| Condition | Dropout | Mean val loss | Mean train loss | Mean gap | N |",
                "|---|---:|---:|---:|---:|---:|",
            ]
        )
        # Compute the validation mean once and reuse it as the (stable) sort key.
        scored = [
            (mean_of(group, "val_eval_loss"), label, group)
            for label, group in present
        ]
        scored.sort(key=lambda triple: triple[0])
        for val, label, group in scored:
            train = mean_of(group, "train_eval_loss")
            gap = mean_of(group, "generalization_gap")
            dropout = mean_of(group, "dropout_active_final")
            lines.append(
                f"| `{label}` | {dropout:.2f} | {val:.4f} | "
                f"{train:.4f} | {gap:.4f} | {len(group)} |"
            )
        lines.append("")

    report = "\n".join(lines).rstrip() + "\n"
    (output_dir / "RESULT_SUMMARY.md").write_text(report, encoding="utf-8")
def static_conditions(dropout_rates: list[float]) -> list[DropoutCondition]:
    """Return one ``"static"`` condition per rate, in the order given.

    Each condition uses the same value for ``initial`` and ``final`` and is
    named ``static_dropout_<label>`` where the label comes from ``rate_label``.
    """
    built = []
    for rate in dropout_rates:
        built.append(
            DropoutCondition(
                name=f"static_dropout_{rate_label(rate)}",
                kind="static",
                initial=rate,
                final=rate,
            )
        )
    return built
def run_fixed_static_sweep(
    *,
    args: argparse.Namespace,
    model_specs: list[ModelSpec],
    seeds: list[int],
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    tokenizer_vocab_size: int,
    token_limits: list[int],
    device: torch.device,
    metrics_file,
    trace_file,
    completed_keys: set[tuple] | None = None,
) -> list[dict]:
    """Train every (token limit, model, static dropout rate, seed) combination.

    The dropout rates come from ``args.dropout_rates`` (de-duplicated and
    sorted ascending).  Combinations whose key is already in
    ``completed_keys`` are not trained; a ``skipped_completed_condition``
    event is written to ``trace_file`` instead.

    When ``args.mode == "screen_static"`` and ``args.screen_early_stop`` is
    set, the remaining rates for the current (token limit, model) pair are
    abandoned once the mean validation loss has been worse than the best by
    more than ``args.screen_prune_min_delta`` for at least
    ``args.screen_prune_patience`` conditions and the current rate is at
    least ``args.target_min_dropout``.  The mean only covers seeds trained in
    this call; a condition with no newly trained seed leaves the pruning
    state untouched.

    Returns:
        The metric rows produced by this call, in training order.
    """
    rows: list[dict] = []
    done = completed_keys or set()
    sweep = static_conditions(sorted(set(args.dropout_rates)))

    def key_for(limit, spec, cond, seed_value):
        # Key identifying one planned training run.
        return planned_metric_key(
            mode=args.mode,
            condition=cond,
            model_spec=spec,
            seed=seed_value,
            token_limit=limit,
        )

    # Count outstanding runs up front so the progress meter knows its total.
    outstanding = sum(
        1
        for limit in token_limits
        for spec in model_specs
        for cond in sweep
        for seed_value in seeds
        if key_for(limit, spec, cond, seed_value) not in done
    )
    progress = ProgressMeter(outstanding)
    min_delta = args.screen_prune_min_delta

    for token_limit in token_limits:
        for model_spec in model_specs:
            # Pruning state is tracked separately for each (limit, model) pair.
            best_val_loss = float("inf")
            worse_streak = 0
            for condition in sweep:
                fresh_rows: list[dict] = []
                for seed in seeds:
                    if key_for(token_limit, model_spec, condition, seed) in done:
                        write_jsonl_row(
                            trace_file,
                            {
                                "event": "skipped_completed_condition",
                                "run_mode": args.mode,
                                "condition": condition.name,
                                "model_name": model_spec.name,
                                "seed": seed,
                                "stage": None,
                                "token_limit": int(token_limit),
                            },
                        )
                        continue
                    config = model_spec.config(
                        tokenizer_vocab_size,
                        args.block_size,
                        condition.initial,
                    )
                    dropout_fn = condition.make_fn(
                        args.steps * args.batch_size * args.block_size
                    )
                    trained_model, trained_optimizer, _, row = train_segment(
                        run_mode=args.mode,
                        condition=condition,
                        model_spec=model_spec,
                        config=config,
                        train_tokens=train_tokens,
                        val_tokens=val_tokens,
                        token_limit=token_limit,
                        steps=args.steps,
                        seed=seed,
                        args=args,
                        device=device,
                        dropout_fn=dropout_fn,
                        metrics_file=metrics_file,
                        trace_file=trace_file,
                    )
                    rows.append(row)
                    fresh_rows.append(row)
                    done.add(metric_key(row))
                    progress.mark_done(row)
                    # Drop the references before asking MPS to release cached memory.
                    del trained_model, trained_optimizer
                    torch.mps.empty_cache()

                if not fresh_rows:
                    continue

                mean_val_loss = statistics.fmean(
                    float(row["val_eval_loss"]) for row in fresh_rows
                )
                if mean_val_loss < best_val_loss - min_delta:
                    best_val_loss = mean_val_loss
                    worse_streak = 0
                elif mean_val_loss > best_val_loss + min_delta:
                    worse_streak += 1

                prune_now = (
                    args.mode == "screen_static"
                    and args.screen_early_stop
                    and worse_streak >= args.screen_prune_patience
                    and condition.initial >= args.target_min_dropout
                )
                if not prune_now:
                    continue
                write_jsonl_row(
                    trace_file,
                    {
                        "event": "screen_pruned_model",
                        "run_mode": args.mode,
                        "model_name": model_spec.name,
                        "token_limit": int(token_limit),
                        "best_val_loss": best_val_loss,
                        "pruned_after_dropout": condition.initial,
                        "worse_streak": worse_streak,
                        "remaining_dropouts": [
                            rate
                            for rate in args.dropout_rates
                            if rate > condition.initial
                        ],
                    },
                )
                break
    return rows
def run_locked_stream(
    *,
    args: argparse.Namespace,
    model_specs: list[ModelSpec],
    seeds: list[int],
    train_tokens: np.ndarray,
    val_tokens: np.ndarray,
    tokenizer_vocab_size: int,
    stream_caps: list[int],
    device: torch.device,
    metrics_file,
    trace_file,
    completed_keys: set[tuple] | None = None,
) -> list[dict]:
    """Train each condition through the successive token caps in ``stream_caps``.

    Conditions are ``args.anchor_decays``, then one static condition per entry
    of ``args.dropout_rates``, then ``args.decays``.  For every (model,
    condition, seed) the model, optimizer and token counter returned by one
    stage are passed into the next stage's ``train_segment`` call, and are
    released once all stages for that seed are finished.

    Stages whose key is in ``completed_keys`` are not trained; a
    ``skipped_completed_condition`` event is written to ``trace_file`` and the
    carried state is left as it was.

    Returns:
        The metric rows produced by this call, in training order.
    """
    rows: list[dict] = []
    done = completed_keys or set()
    schedule = [
        *args.anchor_decays,
        *static_conditions(args.dropout_rates),
        *args.decays,
    ]

    # Without an explicit --decay-tokens, decay over all tokens sampled across stages.
    fallback_decay_tokens = args.decay_tokens
    if not fallback_decay_tokens:
        fallback_decay_tokens = (
            args.stage_steps * args.batch_size * args.block_size * len(stream_caps)
        )

    def key_for(spec, cond, seed_value, stage_index, cap):
        # Key identifying one planned (stage-level) training run.
        return planned_metric_key(
            mode=args.mode,
            condition=cond,
            model_spec=spec,
            seed=seed_value,
            token_limit=cap,
            stage=stage_index,
        )

    outstanding = sum(
        1
        for spec in model_specs
        for cond in schedule
        for seed_value in seeds
        for stage_index, cap in enumerate(stream_caps)
        if key_for(spec, cond, seed_value, stage_index, cap) not in done
    )
    progress = ProgressMeter(outstanding)

    for model_spec in model_specs:
        for condition in schedule:
            for seed in seeds:
                # State carried from one stage to the next for this seed.
                live_model = None
                live_optimizer = None
                consumed = 0
                for stage, token_limit in enumerate(stream_caps):
                    if key_for(model_spec, condition, seed, stage, token_limit) in done:
                        write_jsonl_row(
                            trace_file,
                            {
                                "event": "skipped_completed_condition",
                                "run_mode": args.mode,
                                "condition": condition.name,
                                "model_name": model_spec.name,
                                "seed": seed,
                                "stage": stage,
                                "token_limit": int(token_limit),
                            },
                        )
                        continue
                    config = model_spec.config(
                        tokenizer_vocab_size,
                        args.block_size,
                        condition.initial,
                    )
                    dropout_fn = condition.make_fn(
                        fallback_decay_tokens,
                        unique_tokens=token_limit,
                    )
                    live_model, live_optimizer, consumed, row = train_segment(
                        run_mode=args.mode,
                        condition=condition,
                        model_spec=model_spec,
                        config=config,
                        train_tokens=train_tokens,
                        val_tokens=val_tokens,
                        token_limit=token_limit,
                        steps=args.stage_steps,
                        seed=seed,
                        args=args,
                        device=device,
                        dropout_fn=dropout_fn,
                        metrics_file=metrics_file,
                        trace_file=trace_file,
                        stage=stage,
                        model=live_model,
                        optimizer=live_optimizer,
                        tokens_seen_start=consumed,
                    )
                    rows.append(row)
                    done.add(metric_key(row))
                    progress.mark_done(row)
                # All stages for this seed are finished: release the carried state.
                del live_model, live_optimizer
                torch.mps.empty_cache()
    return rows
def prepare_data(args: argparse.Namespace, output_dir: Path, required_train_tokens: int):
    """Load cached token splits or tokenize the corpus into the cache directory.

    The cache directory is ``args.cache_dir`` when set, otherwise
    ``output_dir / "cache"``; it is created if missing.

    Args:
        args: Parsed CLI options (cache, corpus, tokenizer and split settings).
        output_dir: Run directory, used only to derive the default cache path.
        required_train_tokens: Forwarded as ``max_required_train_tokens`` to
            the data helpers.

    Returns:
        With ``--use-cached-data``, whatever ``load_cached_splits`` returns
        (the caller unpacks it as ``(tokenizer, splits)``); otherwise the
        ``(tokenizer, splits)`` pair built here.

    Raises:
        ValueError: If ``--use-cached-data`` is combined with
            ``--force-retokenize``.
    """
    # Reject contradictory flags before touching the filesystem, so an invalid
    # invocation does not leave an empty cache directory behind.
    if args.use_cached_data and args.force_retokenize:
        raise ValueError("--use-cached-data cannot be combined with --force-retokenize")

    cache_dir = Path(args.cache_dir) if args.cache_dir else output_dir / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)

    if args.use_cached_data:
        return load_cached_splits(
            cache_dir=cache_dir,
            vocab_size=args.vocab_size,
            max_required_train_tokens=required_train_tokens,
            val_tokens=args.val_tokens,
            allow_short_corpus=args.allow_short_corpus,
        )

    paths = resolve_paths(args.corpus, args.corpus_glob)
    tokenizer = train_or_load_tokenizer(
        paths=paths,
        output_dir=cache_dir,
        vocab_size=args.vocab_size,
        tokenizer_train_chars=args.tokenizer_train_chars,
        text_column=args.text_column,
        force_retrain=args.force_retokenize,
    )
    splits = encode_corpus(
        paths=paths,
        tokenizer=tokenizer,
        output_dir=cache_dir,
        max_required_train_tokens=required_train_tokens,
        val_tokens=args.val_tokens,
        text_column=args.text_column,
        allow_short_corpus=args.allow_short_corpus,
        force_reencode=args.force_retokenize,
    )
    return tokenizer, splits
def run(args: argparse.Namespace) -> Path:
    """Execute one experiment run and write its artifacts.

    Validates the mode/flag combination, resolves the output directory (a new
    timestamped directory, or ``args.resume_from``), prepares the token
    splits, writes the run configuration, runs the sweep selected by
    ``args.mode`` while appending to ``metrics.jsonl`` / ``trace.jsonl``, and
    finally writes the summary files for that mode.

    Args:
        args: Parsed CLI options as produced by ``build_parser``.

    Returns:
        The directory containing the run's artifacts.

    Raises:
        ValueError: If decay schedules are given outside ``locked_stream``
            mode, or ``--resume-from`` is used with ``locked_stream``.
        FileNotFoundError: If the resume directory does not exist.
        NotADirectoryError: If the resume path exists but is not a directory.
    """
    device = assert_mps_only()
    seeds = default_seeds(args.mode, args.seeds)
    model_specs = [parse_model_spec(spec) for spec in args.models]

    if args.mode != "locked_stream" and (args.decays or args.anchor_decays):
        raise ValueError("--decays and --anchor-decays are only used with --mode locked_stream")
    if args.resume_from and args.mode == "locked_stream":
        raise ValueError("--resume-from currently supports fixed static sweeps only")

    if args.resume_from:
        output_dir = Path(args.resume_from)
        if not output_dir.exists():
            raise FileNotFoundError(f"resume directory does not exist: {output_dir}")
        # A regular file would pass exists() and only fail after data
        # preparation, with an unrelated error; reject it here instead.
        if not output_dir.is_dir():
            raise NotADirectoryError(f"resume path is not a directory: {output_dir}")
    else:
        run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = Path(args.output_dir) / args.mode / run_id
        output_dir.mkdir(parents=True, exist_ok=True)

    required_train_tokens = max(
        args.stream_token_caps if args.mode == "locked_stream" else args.token_limits
    )
    tokenizer, splits = prepare_data(args, output_dir, required_train_tokens)

    # Requested limits are clamped to the number of training tokens available.
    token_limits = [min(limit, len(splits.train)) for limit in args.token_limits]
    stream_caps = [min(limit, len(splits.train)) for limit in args.stream_token_caps]

    # Condition objects are replaced by their dict form in the saved payload.
    args_payload = vars(args).copy()
    args_payload["decays"] = [condition.to_dict() for condition in args.decays]
    args_payload["anchor_decays"] = [
        condition.to_dict() for condition in args.anchor_decays
    ]
    config_payload = {
        "args": args_payload,
        "mode": args.mode,
        "seeds": seeds,
        "models": [model.to_dict() for model in model_specs],
        "device": str(device),
        "torch": torch.__version__,
        "python": sys.version,
        "mps_available": torch.backends.mps.is_available(),
        "attribution": NANOCHAT_ATTRIBUTION,
        "tokenizer_path": str(splits.tokenizer_path),
        "encoded_path": str(splits.encoded_path),
        "train_tokens": int(len(splits.train)),
        "val_tokens": int(len(splits.val)),
        "effective_token_limits": [int(limit) for limit in token_limits],
        "effective_stream_token_caps": [int(limit) for limit in stream_caps],
        "resume_from": str(args.resume_from) if args.resume_from else None,
    }
    # A resumed run records its configuration next to the original config.json.
    config_name = "config.resume.json" if args.resume_from else "config.json"
    (output_dir / config_name).write_text(
        json.dumps(config_payload, indent=2),
        encoding="utf-8",
    )

    metrics_path = output_dir / "metrics.jsonl"
    trace_path = output_dir / "trace.jsonl"
    existing_rows = load_metrics(metrics_path) if args.resume_from else []
    completed_keys = {metric_key(row) for row in existing_rows}

    # Resumed runs append to the existing logs; fresh runs start new files.
    with (
        metrics_path.open("a" if args.resume_from else "w", encoding="utf-8") as metrics_file,
        trace_path.open("a" if args.resume_from else "w", encoding="utf-8") as trace_file,
    ):
        if args.mode in {"screen_static", "confirm_static"}:
            new_rows = run_fixed_static_sweep(
                args=args,
                model_specs=model_specs,
                seeds=seeds,
                train_tokens=splits.train,
                val_tokens=splits.val,
                tokenizer_vocab_size=tokenizer.vocab_size,
                token_limits=token_limits,
                device=device,
                metrics_file=metrics_file,
                trace_file=trace_file,
                completed_keys=completed_keys,
            )
        else:
            new_rows = run_locked_stream(
                args=args,
                model_specs=model_specs,
                seeds=seeds,
                train_tokens=splits.train,
                val_tokens=splits.val,
                tokenizer_vocab_size=tokenizer.vocab_size,
                stream_caps=stream_caps,
                device=device,
                metrics_file=metrics_file,
                trace_file=trace_file,
            )

    rows = existing_rows + new_rows
    summary = summarize(rows)
    (output_dir / "summary.json").write_text(
        json.dumps(summary, indent=2),
        encoding="utf-8",
    )
    write_csv(output_dir / "summary.csv", summary, SUMMARY_FIELDS)

    if args.mode in {"screen_static", "confirm_static"}:
        selection = build_model_selection(summary, args)
        (output_dir / "model_selection.json").write_text(
            json.dumps(selection, indent=2),
            encoding="utf-8",
        )
        write_csv(output_dir / "model_selection.csv", selection, SELECTION_FIELDS)
        write_screen_markdown_summary(output_dir, rows)
        write_dropout_curve_svg(output_dir, summary)
    elif args.mode == "locked_stream":
        write_stream_markdown_summary(output_dir, rows)

    print(
        json.dumps(
            {
                "output_dir": str(output_dir),
                "new_rows": len(new_rows),
                "total_metric_rows": len(rows),
                "summary_rows": len(summary),
            },
            indent=2,
        )
    )
    return output_dir
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the experiment runner.

    Options are registered in a fixed order; runs of plain ``int``/``float``
    options that only differ in flag and default are registered from small
    tables.
    """
    cli = argparse.ArgumentParser(
        description="MPS-only dropout/model selection experiments"
    )
    add = cli.add_argument

    add(
        "--mode",
        choices=["screen_static", "confirm_static", "locked_stream"],
        default="screen_static",
    )

    # Corpus and cache locations.
    add("--corpus", default=None, help="Text or parquet corpus path")
    add("--corpus-glob", default=None, help="Glob of text/parquet corpus paths")
    add("--text-column", default="text", help="Parquet text column")
    add(
        "--use-cached-data",
        action="store_true",
        help=(
            "Load tokenizer-v{vocab}.json and tokens-v{vocab}-*.npy from --cache-dir "
            "instead of requiring the original text/parquet corpus."
        ),
    )
    add("--output-dir", default="runs")
    add(
        "--resume-from",
        default=None,
        help="Existing fixed-static run directory; completed metric rows are skipped",
    )
    add("--cache-dir", default=".cache/dropout_decay")

    # Models, seeds and token budgets.
    add(
        "--models",
        nargs="+",
        default=["8x8x256"],
        help="Model specs like 8x8x256 or name=8x8x256",
    )
    add("--seeds", nargs="+", type=int, default=None)
    add("--token-limits", nargs="+", type=int, default=[5_000_000])
    add(
        "--stream-token-caps",
        nargs="+",
        type=int,
        default=[5_000_000, 10_000_000, 20_000_000, 40_000_000],
    )
    add("--val-tokens", type=int, default=500_000)
    add("--allow-short-corpus", action="store_true")
    add("--force-retokenize", action="store_true")

    # Tokenizer and training-shape integers.
    for flag, default in (
        ("--vocab-size", 4096),
        ("--tokenizer-train-chars", 10_000_000),
        ("--block-size", 128),
        ("--batch-size", 16),
        ("--steps", 2000),
        ("--stage-steps", 1000),
    ):
        add(flag, type=int, default=default)

    # Dropout conditions.
    add("--dropout-rates", nargs="*", type=float, default=DEFAULT_DROPOUT_RATES)
    add("--decays", nargs="*", type=parse_decay_spec, default=[])
    add(
        "--anchor-decays",
        nargs="*",
        type=parse_anchor_decay_spec,
        default=[],
        help=(
            "Prefix-token anchor schedules like "
            "fit:250000=0.60,500000=0.40,1000000=0.30"
        ),
    )

    # Decay length, evaluation and logging cadence.
    for flag, default in (
        ("--decay-tokens", None),
        ("--eval-batches", 64),
        ("--train-eval-batches", 32),
        ("--trace-eval-batches", 8),
        ("--eval-every", 0),
        ("--log-every", 100),
    ):
        add(flag, type=int, default=default)

    # Optimizer settings and selection thresholds.
    for flag, default in (
        ("--lr", 3e-4),
        ("--weight-decay", 0.1),
        ("--grad-clip", 1.0),
        ("--plateau-delta", 0.01),
        ("--target-min-dropout", 0.10),
        ("--min-nonzero-margin", 0.01),
        ("--min-high-dropout-margin", 0.03),
    ):
        add(flag, type=float, default=default)

    # Early stopping of the screening sweep.
    add("--screen-early-stop", action="store_true")
    add("--screen-prune-patience", type=int, default=3)
    add("--screen-prune-min-delta", type=float, default=0.01)
    return cli
def main() -> None:
    """Parse the command line and launch the experiment run."""
    run(build_parser().parse_args())


# Only start a run when executed as a script, not when imported.
if __name__ == "__main__":
    main()