dropout-decay / scripts /fit_dropout_coefficients.py
Mandeep Sidhu
Add dropout pressure validation artifacts
3550904
#!/usr/bin/env python3
"""
Derived from Andrej Karpathy's nanochat project.
MIT License
Copyright (c) 2025 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"""
from __future__ import annotations
import argparse
import csv
from dataclasses import dataclass
import json
import math
from pathlib import Path
import statistics
import numpy as np
@dataclass(frozen=True)
class CurvePoint:
    """One measured point on a validation-loss-versus-dropout curve."""

    # Dropout rate at which the loss was measured.
    dropout: float
    # Validation loss recorded for that rate.
    mean_val_loss: float
    # Spread of the validation loss; parse_curve stores 0.0 when the input
    # carries no such field.
    std_val_loss: float
    # Count attached to the point; parse_curve stores 1 when the input has
    # none (presumably the number of runs averaged -- confirm with the writer
    # of model_selection.csv).
    n: int
@dataclass(frozen=True)
class CalibrationCell:
    """One (model, token budget) row of a sweep, with its dropout optimum.

    Instances are built by load_cells from a run directory's
    model_selection.csv and serve as the observations of the coefficient fit.
    """

    # Label derived from the run directory path by source_label.
    source: str
    # Run directory the row was read from.
    run_dir: Path
    # Columns copied from model_selection.csv.
    run_mode: str
    model_name: str
    n_layer: int
    n_head: int
    n_embd: int
    parameters: int
    # Taken from the CSV "token_limit" column.
    unique_tokens: int
    # Largest "tokens_seen" found in metrics.jsonl for this cell, or a
    # fallback computed by load_cells when no such record exists.
    sampled_tokens: int
    # Dropout reported in the CSV "best_dropout" column.
    best_grid_dropout: float
    # Dropout estimated by quadratic_minimum on the parsed curve.
    best_quad_dropout: float
    # Regression target: the quad or grid optimum, as chosen by load_cells.
    target_dropout: float
    best_val_loss: float
    best_val_std: float
    # True when the grid optimum equals the lowest or highest rate of the curve.
    boundary_optimum: bool
    # Second value returned by quadratic_minimum for this curve.
    bracketed_optimum: bool
    # Weight used by the weighted least-squares fit.
    weight: float
    # Parsed curve points, sorted by dropout.
    curve: tuple[CurvePoint, ...]

    @property
    def x_model_pressure(self) -> float:
        """Return log10(parameters / unique_tokens)."""
        return math.log10(self.parameters / self.unique_tokens)

    @property
    def x_sample_pressure(self) -> float:
        """Return log10(sampled_tokens / unique_tokens)."""
        return math.log10(self.sampled_tokens / self.unique_tokens)
# Building blocks for the feature families: (coefficient name, printable term).
# In the labels, P / U and C / U correspond to the two pressure properties of
# CalibrationCell (parameters and sampled tokens over unique tokens).
_LINEAR_TERMS = (
    ("A", "log10(P / U)"),
    ("B", "log10(C / U)"),
)
_INTERACTION_TERM = (("D", "log10(P / U) * log10(C / U)"),)
_SQUARED_TERMS = (
    ("Qp", "log10(P / U)^2"),
    ("Qc", "log10(C / U)^2"),
)
_INTERCEPT_TERM = (("C0", "1"),)

# Feature family name -> ordered terms. The order of each tuple must stay in
# step with the list returned by feature_vector, because fitted coefficients
# are paired with these entries positionally.
FEATURE_SETS = {
    "base": _LINEAR_TERMS + _INTERCEPT_TERM,
    "interaction": _LINEAR_TERMS + _INTERACTION_TERM + _INTERCEPT_TERM,
    "quadratic": _LINEAR_TERMS + _SQUARED_TERMS + _INTERCEPT_TERM,
    "full_quadratic": (
        _LINEAR_TERMS + _INTERACTION_TERM + _SQUARED_TERMS + _INTERCEPT_TERM
    ),
}
def feature_vector(cell: CalibrationCell, feature_set: str) -> list[float]:
    """Return the regression features of *cell* for the named feature set.

    The list starts with the two pressure values, optionally followed by
    their product and/or their squares, and always ends with the constant
    1.0 for the intercept.

    Raises:
        ValueError: if *feature_set* is not one of the known families.
    """
    model_p = cell.x_model_pressure
    sample_p = cell.x_sample_pressure

    if feature_set not in ("base", "interaction", "quadratic", "full_quadratic"):
        raise ValueError(f"unknown feature set: {feature_set}")

    features = [model_p, sample_p]
    if feature_set in ("interaction", "full_quadratic"):
        features.append(model_p * sample_p)
    if feature_set in ("quadratic", "full_quadratic"):
        features.extend((model_p * model_p, sample_p * sample_p))
    features.append(1.0)
    return features
def coefficient_names(feature_set: str) -> list[str]:
    """Return the coefficient names of *feature_set*, in fitting order."""
    return [term[0] for term in FEATURE_SETS[feature_set]]
def formula_terms(feature_set: str, coef: np.ndarray) -> list[str]:
    """Render each fitted coefficient as a signed, printable formula term.

    The intercept (label "1") is rendered as the bare signed value; every
    other term is rendered as "<signed value> * <label>".
    """
    rendered: list[str] = []
    for (_, label), value in zip(FEATURE_SETS[feature_set], coef):
        signed = f"{value:+.6f}"
        rendered.append(signed if label == "1" else f"{signed} * {label}")
    return rendered
def predict_dropout(cell: CalibrationCell, coef: np.ndarray, feature_set: str) -> float:
    """Return the dropout predicted for *cell* by the linear model *coef*."""
    features = np.array(feature_vector(cell, feature_set), dtype=np.float64)
    prediction = features @ coef
    return float(prediction)
def parse_curve(raw: str) -> list[CurvePoint]:
    """Parse a JSON-encoded dropout curve into points sorted by dropout.

    Two layouts are accepted:

    * an object mapping dropout rate to validation loss; every point then
      gets ``std_val_loss=0.0`` and ``n=1``;
    * an array of objects with a ``dropout`` field, a loss stored under
      ``mean_val_loss``, ``val_loss`` or ``eval_loss`` (first non-null one
      wins), and optional ``std_val_loss`` / ``n`` fields.

    Args:
        raw: JSON text of the curve.

    Returns:
        The parsed points in ascending dropout order.

    Raises:
        ValueError: if an array item has no usable validation-loss field.
    """
    data = json.loads(raw)
    points: list[CurvePoint] = []

    if isinstance(data, dict):
        for dropout, value in data.items():
            points.append(
                CurvePoint(
                    dropout=float(dropout),
                    mean_val_loss=float(value),
                    std_val_loss=0.0,
                    n=1,
                )
            )
    else:
        for item in data:
            mean_val_loss = None
            for field in ("mean_val_loss", "val_loss", "eval_loss"):
                mean_val_loss = item.get(field)
                if mean_val_loss is not None:
                    break
            if mean_val_loss is None:
                raise ValueError(f"curve point has no validation-loss field: {item}")

            # A JSON null is treated like a missing field, the same way the
            # loss lookup above does, instead of crashing in float()/int().
            std_val_loss = item.get("std_val_loss")
            count = item.get("n")
            points.append(
                CurvePoint(
                    dropout=float(item["dropout"]),
                    mean_val_loss=float(mean_val_loss),
                    std_val_loss=0.0 if std_val_loss is None else float(std_val_loss),
                    n=1 if count is None else int(count),
                )
            )

    return sorted(points, key=lambda point: point.dropout)
def quadratic_minimum(points: list[CurvePoint]) -> tuple[float, bool]:
    """Estimate the loss-minimising dropout from a parabola through 3 points.

    The lowest-loss point and its two neighbours (in list order, so callers
    are expected to pass points sorted by dropout) define a parabola whose
    vertex is the estimate.

    Returns:
        ``(dropout, fitted)``. ``fitted`` is True only when the parabola was
        usable; otherwise the dropout of the lowest-loss grid point is
        returned with False (best point at either end of the list, coincident
        dropout values, or a non-convex fit).

    Raises:
        ValueError: if *points* is empty.
    """
    if not points:
        raise ValueError("cannot estimate optimum from an empty curve")

    best_index, lowest = min(
        enumerate(points),
        key=lambda pair: pair[1].mean_val_loss,
    )
    # An optimum at either end has no neighbour on one side to fit against.
    if best_index in (0, len(points) - 1):
        return lowest.dropout, False

    before = points[best_index - 1]
    after = points[best_index + 1]
    x1, y1 = before.dropout, before.mean_val_loss
    x2, y2 = lowest.dropout, lowest.mean_val_loss
    x3, y3 = after.dropout, after.mean_val_loss

    scale = (x1 - x2) * (x1 - x3) * (x2 - x3)
    if abs(scale) < 1e-12:
        # Two of the three dropout values (nearly) coincide.
        return lowest.dropout, False

    # Coefficients of y = curvature * x^2 + slope * x + c through the points.
    curvature = (x3 * (y2 - y1) + x2 * (y1 - y3) + x1 * (y3 - y2)) / scale
    slope = (
        x3 * x3 * (y1 - y2)
        + x2 * x2 * (y3 - y1)
        + x1 * x1 * (y2 - y3)
    ) / scale
    if curvature <= 0.0:
        # Flat or concave fit: the vertex is not a minimum.
        return lowest.dropout, False

    vertex = -slope / (2.0 * curvature)
    low = min(x1, x3)
    high = max(x1, x3)
    # Keep the estimate inside the interval spanned by the neighbours.
    return max(low, min(high, vertex)), True
def load_metrics_by_cell(run_dir: Path) -> dict[tuple[str, int], int]:
    """Return the largest ``tokens_seen`` per (model_name, token_limit).

    Reads ``metrics.jsonl`` under *run_dir*, one JSON object per non-blank
    line. Records without ``tokens_seen`` count as 0. An absent file yields
    an empty mapping.
    """
    metrics_path = run_dir / "metrics.jsonl"
    if not metrics_path.exists():
        return {}

    peak_tokens: dict[tuple[str, int], int] = {}
    text = metrics_path.read_text(encoding="utf-8")
    for raw_line in text.splitlines():
        if not raw_line.strip():
            continue
        record = json.loads(raw_line)
        cell_key = (str(record["model_name"]), int(record["token_limit"]))
        seen = int(record.get("tokens_seen", 0))
        previous = peak_tokens.get(cell_key, 0)
        peak_tokens[cell_key] = seen if seen > previous else previous
    return peak_tokens
def source_label(run_dir: Path) -> str:
    """Return a short label for *run_dir*: its last three path components.

    Paths with fewer than three components are returned unchanged as a
    string.
    """
    tail = run_dir.parts[-3:]
    return "/".join(tail) if len(tail) == 3 else str(run_dir)
def _heuristic_weight(
    boundary: bool,
    bracketed: bool,
    loss_span: float,
    best_std: float,
) -> float:
    """Return the fit weight of a cell, shrunk for less trustworthy optima."""
    weight = 1.0
    if boundary:
        # Optimum sits on the edge of the tested grid.
        weight *= 0.30
    if not bracketed:
        # No usable parabola through the optimum and its neighbours.
        weight *= 0.50
    if loss_span < 0.02:
        # Nearly flat curve: the optimum location is poorly determined.
        weight *= 0.50
    if best_std > 0.0:
        # Noisier best point -> smaller weight.
        weight *= 1.0 / (1.0 + 20.0 * best_std)
    return weight


def _fallback_sampled_tokens(row: dict[str, str], unique_tokens: int) -> int:
    """Return sampled tokens from the CSV row, never below *unique_tokens*.

    csv.DictReader yields "" for an empty cell and None for a short row, so
    both are treated as "no value" rather than passed to float().
    """
    raw = row.get("tokens_seen")
    if raw is None or not str(raw).strip():
        return unique_tokens
    return max(unique_tokens, int(float(raw)))


def load_cells(run_dirs: list[Path], target: str, weighting: str) -> list[CalibrationCell]:
    """Build calibration cells from the sweep outputs in *run_dirs*.

    Each run directory must contain ``model_selection.csv``; an optional
    ``metrics.jsonl`` supplies the sampled-token count per cell. Rows whose
    curve is empty are skipped.

    Args:
        run_dirs: Run directories to read.
        target: ``"quad"`` to use the quadratic-fit optimum as the regression
            target; any other value uses the grid optimum from the CSV.
        weighting: ``"heuristic"`` to downweight boundary, unbracketed, flat
            or noisy optima; any other value gives every cell weight 1.0.

    Returns:
        One CalibrationCell per usable CSV row, in reading order. Weights are
        floored at 0.05.

    Raises:
        FileNotFoundError: if a run directory has no ``model_selection.csv``.
    """
    cells: list[CalibrationCell] = []
    for run_dir in run_dirs:
        selection_path = run_dir / "model_selection.csv"
        if not selection_path.exists():
            raise FileNotFoundError(f"missing model_selection.csv under {run_dir}")
        sampled_by_cell = load_metrics_by_cell(run_dir)
        label = source_label(run_dir)

        with selection_path.open(newline="", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                curve = parse_curve(row["curve_json"])
                if not curve:
                    continue

                best_grid = float(row["best_dropout"])
                best_quad, bracketed = quadratic_minimum(curve)
                target_dropout = best_quad if target == "quad" else best_grid

                model_name = row["model_name"]
                unique_tokens = int(float(row["token_limit"]))
                # Prefer the count observed in metrics.jsonl; fall back to the
                # CSV row when there is no (non-zero) record for this cell.
                sampled_tokens = sampled_by_cell.get((model_name, unique_tokens), 0)
                if not sampled_tokens:
                    sampled_tokens = _fallback_sampled_tokens(row, unique_tokens)

                rates = [point.dropout for point in curve]
                boundary = best_grid in {min(rates), max(rates)}
                best = min(curve, key=lambda point: point.mean_val_loss)
                loss_span = (
                    max(point.mean_val_loss for point in curve) - best.mean_val_loss
                )

                weight = 1.0
                if weighting == "heuristic":
                    weight = _heuristic_weight(
                        boundary,
                        bracketed,
                        loss_span,
                        best.std_val_loss,
                    )

                cells.append(
                    CalibrationCell(
                        source=label,
                        run_dir=run_dir,
                        run_mode=row["run_mode"],
                        model_name=model_name,
                        n_layer=int(float(row["n_layer"])),
                        n_head=int(float(row["n_head"])),
                        n_embd=int(float(row["n_embd"])),
                        parameters=int(float(row["parameters"])),
                        unique_tokens=unique_tokens,
                        sampled_tokens=sampled_tokens,
                        best_grid_dropout=best_grid,
                        best_quad_dropout=best_quad,
                        target_dropout=target_dropout,
                        best_val_loss=float(row["best_val_loss"]),
                        best_val_std=float(row["best_val_std"]),
                        boundary_optimum=boundary,
                        bracketed_optimum=bracketed,
                        weight=max(weight, 0.05),
                        curve=tuple(curve),
                    )
                )
    return cells
def fit_coefficients(
    cells: list[CalibrationCell],
    feature_set: str,
) -> tuple[np.ndarray, dict[str, float]]:
    """Fit the dropout model by weighted least squares.

    Args:
        cells: Observations; each contributes its features, target dropout
            and weight.
        feature_set: Name of the feature family to fit.

    Returns:
        ``(coef, metrics)`` where ``coef`` follows the term order of the
        feature set and ``metrics`` holds in-sample ``n``, ``rmse``, ``mae``,
        ``bias`` (mean of prediction minus target) and the weighted RMSE/MAE.

    Raises:
        ValueError: if there are fewer cells than coefficients to fit.
    """
    required = len(FEATURE_SETS[feature_set])
    if len(cells) < required:
        raise ValueError(
            f"need at least {required} cells to fit feature set {feature_set}"
        )

    rows: list[list[float]] = []
    targets: list[float] = []
    cell_weights: list[float] = []
    for cell in cells:
        rows.append(feature_vector(cell, feature_set))
        targets.append(cell.target_dropout)
        cell_weights.append(cell.weight)

    design = np.array(rows, dtype=np.float64)
    observed = np.array(targets, dtype=np.float64)
    weights = np.array(cell_weights, dtype=np.float64)

    # Scaling rows and targets by sqrt(weight) turns ordinary least squares
    # into weighted least squares.
    root_w = np.sqrt(weights)
    solution = np.linalg.lstsq(design * root_w[:, None], observed * root_w, rcond=None)
    coef = solution[0]

    residuals = design @ coef - observed
    squared = residuals * residuals
    absolute = np.abs(residuals)
    metrics = {
        "n": float(len(cells)),
        "rmse": float(np.sqrt(np.mean(squared))),
        "mae": float(np.mean(absolute)),
        "bias": float(np.mean(residuals)),
        "weighted_rmse": float(np.sqrt(np.average(squared, weights=weights))),
        "weighted_mae": float(np.average(absolute, weights=weights)),
    }
    return coef, metrics
def _cv_group_key(cell: CalibrationCell, key_name: str) -> str:
    """Return the leave-one-group-out key of *cell* for *key_name*."""
    if key_name == "model":
        return cell.model_name
    if key_name == "prefix":
        return str(cell.unique_tokens)
    if key_name == "source":
        return cell.source
    raise ValueError(f"unknown cv key: {key_name}")


def grouped_cv(
    cells: list[CalibrationCell],
    key_name: str,
    feature_set: str,
) -> dict[str, float]:
    """Run leave-one-group-out cross-validation of the dropout fit.

    Cells are grouped by model name, unique-token count or source. Each
    group is held out in turn, the model is refit on the remaining cells and
    the held-out cells are predicted. Groups whose training set has fewer
    cells than coefficients are skipped.

    Args:
        cells: All calibration cells.
        key_name: ``"model"``, ``"prefix"`` or ``"source"``.
        feature_set: Name of the feature family to fit.

    Returns:
        ``n``, ``rmse``, ``mae`` and ``bias`` over all held-out predictions
        (error is prediction minus target); the three statistics are NaN and
        ``n`` is 0.0 when no group could be evaluated.

    Raises:
        ValueError: if *key_name* is unknown (raised once a cell is grouped).
    """
    keys: list[str] = []
    groups: dict[str, list[CalibrationCell]] = {}
    for cell in cells:
        key = _cv_group_key(cell, key_name)
        keys.append(key)
        groups.setdefault(key, []).append(cell)

    errors: list[float] = []
    for held_key, held_out in groups.items():
        # Split on the group key rather than `cell not in held_out`: the
        # result is the same, without a quadratic scan that compares whole
        # dataclass instances (including their curves) for equality.
        train = [cell for cell, key in zip(cells, keys) if key != held_key]
        if len(train) < len(FEATURE_SETS[feature_set]):
            continue
        coef, _ = fit_coefficients(train, feature_set)
        for cell in held_out:
            prediction = predict_dropout(cell, coef, feature_set)
            errors.append(prediction - cell.target_dropout)

    if not errors:
        return {"n": 0.0, "rmse": float("nan"), "mae": float("nan"), "bias": float("nan")}
    return {
        "n": float(len(errors)),
        "rmse": math.sqrt(statistics.fmean(error * error for error in errors)),
        "mae": statistics.fmean(abs(error) for error in errors),
        "bias": statistics.fmean(errors),
    }
def suggest_fine_rates(
    cell: CalibrationCell,
    min_rate: float,
    max_rate: float,
    spacing: float,
) -> list[float]:
    """Propose untested dropout rates around the cell's target dropout.

    Candidates are the target and the rates one and two *spacing* steps on
    either side, clamped to ``[min_rate, max_rate]`` and rounded to three
    decimals. Rates already present on the cell's curve (after the same
    rounding) are left out.

    Returns:
        The remaining candidate rates in ascending order.
    """
    center = cell.target_dropout
    already_run = {round(point.dropout, 3) for point in cell.curve}

    proposals: set[float] = set()
    for delta in (-2 * spacing, -spacing, 0.0, spacing, 2 * spacing):
        clamped = max(min_rate, min(max_rate, center + delta))
        proposals.add(round(clamped, 3))

    return sorted(proposals - already_run)
def write_cells(
    path: Path,
    cells: list[CalibrationCell],
    coef: np.ndarray,
    feature_set: str,
) -> None:
    """Write one CSV row per cell with its inputs, prediction and residual.

    The residual column is prediction minus target dropout. The file at
    *path* is created or overwritten.
    """
    columns = [
        "source",
        "run_dir",
        "run_mode",
        "model_name",
        "n_layer",
        "n_head",
        "n_embd",
        "parameters",
        "unique_tokens",
        "sampled_tokens",
        "x_model_pressure",
        "x_sample_pressure",
        "best_grid_dropout",
        "best_quad_dropout",
        "target_dropout",
        "predicted_dropout",
        "residual",
        "best_val_loss",
        "best_val_std",
        "boundary_optimum",
        "bracketed_optimum",
        "weight",
    ]
    # Columns that are not a plain copy of the same-named cell attribute.
    computed = ("run_dir", "predicted_dropout", "residual")

    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for cell in cells:
            predicted = predict_dropout(cell, coef, feature_set)
            record = {
                name: getattr(cell, name)
                for name in columns
                if name not in computed
            }
            record["run_dir"] = str(cell.run_dir)
            record["predicted_dropout"] = predicted
            record["residual"] = predicted - cell.target_dropout
            writer.writerow(record)
def write_suggestions(
    path: Path,
    cells: list[CalibrationCell],
    min_rate: float,
    max_rate: float,
    spacing: float,
) -> None:
    """Write a CSV of follow-up dropout rates suggested for every cell.

    Each row carries the cell identity, its target dropout (four decimals)
    and the space-separated rates from suggest_fine_rates (three decimals).
    The file at *path* is created or overwritten.
    """
    columns = [
        "source",
        "model_name",
        "unique_tokens",
        "target_dropout",
        "boundary_optimum",
        "bracketed_optimum",
        "suggested_rates",
    ]
    copied = (
        "source",
        "model_name",
        "unique_tokens",
        "boundary_optimum",
        "bracketed_optimum",
    )

    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for cell in cells:
            rates = suggest_fine_rates(cell, min_rate, max_rate, spacing)
            record = {name: getattr(cell, name) for name in copied}
            record["target_dropout"] = f"{cell.target_dropout:.4f}"
            record["suggested_rates"] = " ".join(f"{rate:.3f}" for rate in rates)
            writer.writerow(record)
def write_report(
    path: Path,
    cells: list[CalibrationCell],
    coef: np.ndarray,
    metrics: dict[str, float],
    cv: dict[str, dict[str, float]],
    target: str,
    feature_set: str,
) -> None:
    """Write the Markdown diagnostics report for a fit.

    Sections, in order: coefficients table, formula, in-sample fit metrics,
    cross-validation results, and one table row per calibration cell sorted
    by (source, parameters, unique tokens). The file at *path* is created or
    overwritten.
    """
    report: list[str] = [
        "# Dropout Coefficient Fit Diagnostics",
        "",
        f"Target: `{target}`",
        f"Feature set: `{feature_set}`",
        "",
        "## Coefficients",
        "",
        "| Coefficient | Term | Value |",
        "|---|---|---:|",
    ]
    report += [
        f"| `{name}` | `{label}` | {value:.6f} |"
        for value, (name, label) in zip(coef, FEATURE_SETS[feature_set])
    ]

    formula_lines = formula_terms(feature_set, coef)
    report += [
        "",
        "Formula:",
        "",
        "```text",
        "p = clamp(p_min, p_max,",
        " " + "\n ".join(formula_lines) + ")",
        "```",
        "",
        "## Fit Metrics",
        "",
        "| Metric | Value |",
        "|---|---:|",
    ]
    metric_order = ("n", "rmse", "mae", "bias", "weighted_rmse", "weighted_mae")
    report += [f"| `{metric}` | {metrics[metric]:.6f} |" for metric in metric_order]

    report += [
        "",
        "## Cross-Validation",
        "",
        "| Holdout | n | RMSE | MAE | Bias |",
        "|---|---:|---:|---:|---:|",
    ]
    for holdout, stats in cv.items():
        report.append(
            f"| `{holdout}` | {stats['n']:.0f} | {stats['rmse']:.6f} | "
            f"{stats['mae']:.6f} | {stats['bias']:.6f} |"
        )

    report += [
        "",
        "## Calibration Cells",
        "",
        "| Source | Model | Params | Unique | Sampled | Grid p | Quad p | Target p | Pred p | Residual | Weight | Bracketed | Boundary |",
        "|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---|---|",
    ]
    ordered = sorted(cells, key=lambda c: (c.source, c.parameters, c.unique_tokens))
    for cell in ordered:
        predicted = predict_dropout(cell, coef, feature_set)
        columns = [
            f"`{cell.source}`",
            f"`{cell.model_name}`",
            f"{cell.parameters:,}",
            f"{cell.unique_tokens:,}",
            f"{cell.sampled_tokens:,}",
            f"{cell.best_grid_dropout:.3f}",
            f"{cell.best_quad_dropout:.3f}",
            f"{cell.target_dropout:.3f}",
            f"{predicted:.3f}",
            f"{predicted - cell.target_dropout:+.3f}",
            f"{cell.weight:.3f}",
            f"{cell.bracketed_optimum}",
            f"{cell.boundary_optimum}",
        ]
        report.append("| " + " | ".join(columns) + " |")

    path.write_text("\n".join(report) + "\n", encoding="utf-8")
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the coefficient-fitting script."""
    cli = argparse.ArgumentParser(
        description="Fit dropout pressure coefficients from saved static sweep outputs."
    )

    # Inputs and outputs.
    cli.add_argument("--run-dirs", type=Path, nargs="+", required=True)
    cli.add_argument("--output-dir", type=Path, required=True)

    # What to fit and how to weight the observations.
    cli.add_argument("--target", default="quad", choices=["grid", "quad"])
    cli.add_argument(
        "--weighting",
        default="heuristic",
        choices=["none", "heuristic"],
        help="Use equal weights or downweight boundary/flat/noisy optima.",
    )

    # Bounds and step used when suggesting follow-up dropout rates.
    rate_options = (
        ("--min-rate", 0.0),
        ("--max-rate", 0.65),
        ("--fine-spacing", 0.02),
    )
    for flag, default in rate_options:
        cli.add_argument(flag, type=float, default=default)

    cli.add_argument(
        "--feature-set",
        default="base",
        choices=sorted(FEATURE_SETS),
        help="Feature family used to map pressure variables to dropout.",
    )
    return cli
def main() -> None:
    """Fit the coefficients and write all artifacts to the output directory.

    Writes ``coefficients.json``, ``calibration_cells.csv``,
    ``next_dropout_suggestions.csv`` and ``fit_diagnostics.md``, then prints
    a short JSON summary to stdout.
    """
    args = build_parser().parse_args()
    out_dir = args.output_dir
    feature_set = args.feature_set
    out_dir.mkdir(parents=True, exist_ok=True)

    cells = load_cells(args.run_dirs, args.target, args.weighting)
    coef, metrics = fit_coefficients(cells, feature_set)

    # Leave-source-out only makes sense with more than one source.
    holdouts = [("leave_model", "model"), ("leave_prefix", "prefix")]
    if len({cell.source for cell in cells}) > 1:
        holdouts.append(("leave_source", "source"))
    cv = {label: grouped_cv(cells, key, feature_set) for label, key in holdouts}

    coefficient_values = dict(zip(coefficient_names(feature_set), map(float, coef)))
    payload = {
        "target": args.target,
        "feature_set": feature_set,
        "formula": "p = " + " ".join(formula_terms(feature_set, coef)),
        "weighting": args.weighting,
        "coefficients": coefficient_values,
        "metrics": metrics,
        "cv": cv,
        "run_dirs": [str(path) for path in args.run_dirs],
        # The coefficients are also exposed as top-level keys.
        **coefficient_values,
    }
    (out_dir / "coefficients.json").write_text(
        json.dumps(payload, indent=2),
        encoding="utf-8",
    )

    write_cells(out_dir / "calibration_cells.csv", cells, coef, feature_set)
    write_suggestions(
        out_dir / "next_dropout_suggestions.csv",
        cells,
        args.min_rate,
        args.max_rate,
        args.fine_spacing,
    )
    write_report(
        out_dir / "fit_diagnostics.md",
        cells,
        coef,
        metrics,
        cv,
        args.target,
        feature_set,
    )

    summary = {
        "output_dir": str(out_dir),
        "cells": len(cells),
        "feature_set": feature_set,
        "coefficients": coefficient_values,
        "rmse": metrics["rmse"],
        "mae": metrics["mae"],
    }
    print(json.dumps(summary, indent=2))
# Run the fitting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()