Avra98's picture
Initial code dump (rebuttal-ready snapshot)
76de008 verified
from __future__ import annotations
import math
from pathlib import Path
from typing import Any
def _load_pyplot():
    """Import ``matplotlib.pyplot`` on demand and return the module.

    The import statement lives inside the function body, so matplotlib is
    only loaded when a plotting function calls this helper, not when this
    module itself is imported.
    """
    from matplotlib import pyplot

    return pyplot
def plot_training_history(
    history: list[dict[str, Any]],
    output_dir: Path,
) -> list[Path]:
    """Plot validation metrics and the curriculum stage against global step.

    Two PNG files are written into ``output_dir`` (created when missing):

    * ``training_curves.png`` -- validation digit accuracy, final-carry
      accuracy and exact-match, all on one axis limited to ``[0, 1.01]``.
    * ``stage_progression.png`` -- the ``stage`` value drawn as a step
      function of the global step.

    Args:
        history: Log entries. Every entry must provide the keys
            ``global_step``, ``validation_digit_accuracy``,
            ``validation_final_carry_accuracy``, ``validation_exact_match``
            and ``stage``.
        output_dir: Directory that receives the PNG files.

    Returns:
        The paths of the saved figures in the order they were written. An
        empty list is returned for an empty ``history``; in that case
        nothing is written and matplotlib is not imported.

    Raises:
        KeyError: If a history entry lacks one of the required keys.
    """
    if not history:
        return []

    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Extract every series before any figure exists, so a malformed entry
    # fails without leaving a figure behind.
    steps = [entry["global_step"] for entry in history]
    accuracy_curves = [
        (label, [entry[key] for entry in history])
        for key, label in (
            ("validation_digit_accuracy", "Val digit acc"),
            ("validation_final_carry_accuracy", "Val final carry acc"),
            ("validation_exact_match", "Val exact match"),
        )
    ]
    stages = [entry["stage"] for entry in history]

    saved_paths: list[Path] = []

    metrics_path = output_dir / "training_curves.png"
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        for label, values in accuracy_curves:
            ax.plot(steps, values, label=label)
        ax.set_xlabel("Global step")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(0.0, 1.01)
        ax.legend()
        fig.tight_layout()
        fig.savefig(metrics_path, dpi=160)
    finally:
        # pyplot keeps a figure alive until it is closed; close it even when
        # drawing or saving raises so the figure is not leaked.
        plt.close(fig)
    saved_paths.append(metrics_path)

    stage_path = output_dir / "stage_progression.png"
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        ax.step(steps, stages, where="post")
        ax.set_xlabel("Global step")
        ax.set_ylabel("Curriculum stage")
        fig.tight_layout()
        fig.savefig(stage_path, dpi=160)
    finally:
        plt.close(fig)
    saved_paths.append(stage_path)

    return saved_paths
def _collect_length_metric(
    aggregate: dict[str, Any],
    method: str,
    split: str,
    metric: str,
) -> tuple[list[int], list[float], list[float]]:
    """Collect the mean/std series of one metric, ordered by length.

    Reads ``aggregate[method][split]``, a mapping from length to per-metric
    statistics, and returns its entries sorted by the numeric value of the
    length key.

    Args:
        aggregate: Nested results indexed as
            ``aggregate[method][split][length][metric]``, where the innermost
            value has ``"mean"`` and ``"std"`` entries.
        method: Key selecting the method inside ``aggregate``.
        split: Key selecting the evaluation split of that method.
        metric: Name of the metric whose statistics are collected.

    Returns:
        ``(lengths, means, stds)`` -- three parallel lists with the lengths
        as ints in ascending order.

    Raises:
        KeyError: If ``method``, ``split``, ``metric``, ``"mean"`` or
            ``"std"`` is missing.
        ValueError: If a length key cannot be converted with ``int``.
    """
    per_length = aggregate[method][split]
    # Sort the ORIGINAL keys numerically and index with those same keys.
    # Rebuilding the key via str(int(key)) breaks for integer keys (results
    # that never went through JSON) and for zero-padded keys such as "03".
    ordered_keys = sorted(per_length, key=int)
    lengths = [int(key) for key in ordered_keys]
    means = [per_length[key][metric]["mean"] for key in ordered_keys]
    stds = [per_length[key][metric]["std"] for key in ordered_keys]
    return lengths, means, stds
def _plot_mean_with_band(ax, xs, means, stds, label) -> None:
    """Draw a mean curve with a +/- one std band clipped to ``[0, 1]``."""
    ax.plot(xs, means, marker="o", label=label)
    lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
    upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
    ax.fill_between(xs, lower, upper, alpha=0.15)


def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot per-method comparison curves with mean +/- std bands.

    Three PNG files are written into ``output_dir`` (created when missing):

    * ``uniform_exact_match.png`` -- exact-match by length on the
      ``test_uniform`` split.
    * ``carry_heavy_exact_match.png`` -- exact-match by length on the
      ``test_carry_heavy`` split.
    * ``validation_digit_accuracy_by_stage.png`` -- validation digit accuracy
      per curriculum stage, read from each method's ``stage_progression``.

    Every method in ``aggregate`` contributes one labelled curve per figure.

    Args:
        aggregate: Mapping from method name to that method's results. Each
            method must provide ``test_uniform`` and ``test_carry_heavy``
            (length -> ``{"exact_match": {"mean", "std"}}``) and
            ``stage_progression``
            (stage -> ``{"validation_digit_accuracy": {"mean", "std"}}``).
        output_dir: Directory that receives the PNG files.

    Returns:
        The paths of the saved figures in the order they were written.

    Raises:
        KeyError: If a method lacks one of the required entries. The figure
            being drawn at that point is still closed.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths: list[Path] = []
    methods = list(aggregate)

    splits = [
        ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
        ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
    ]
    for split, filename, title in splits:
        path = output_dir / filename
        fig, ax = plt.subplots(figsize=(8, 4.5))
        try:
            for method in methods:
                lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
                _plot_mean_with_band(ax, lengths, means, stds, method)
            ax.set_xlabel("Active digits")
            ax.set_ylabel("Exact-match accuracy")
            ax.set_title(title)
            ax.set_ylim(0.0, 1.01)
            ax.legend()
            fig.tight_layout()
            fig.savefig(path, dpi=160)
        finally:
            # A method without this split raises KeyError mid-figure; pyplot
            # keeps figures alive until closed, so always close here.
            plt.close(fig)
        saved_paths.append(path)

    stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        for method in methods:
            per_stage = aggregate[method]["stage_progression"]
            # Order the original keys numerically and index with them, so
            # both string and integer stage keys are accepted.
            stage_keys = sorted(per_stage, key=int)
            stages = [int(key) for key in stage_keys]
            means = [per_stage[key]["validation_digit_accuracy"]["mean"] for key in stage_keys]
            stds = [per_stage[key]["validation_digit_accuracy"]["std"] for key in stage_keys]
            _plot_mean_with_band(ax, stages, means, stds, method)
        ax.set_xlabel("Curriculum stage")
        ax.set_ylabel("Best validation digit accuracy")
        ax.set_ylim(0.0, 1.01)
        ax.set_title("Validation digit accuracy vs stage")
        ax.legend()
        fig.tight_layout()
        fig.savefig(stage_curve_path, dpi=160)
    finally:
        plt.close(fig)
    saved_paths.append(stage_curve_path)

    return saved_paths
def _exact_match_by_length(results: dict[Any, Any]) -> tuple[list[int], list[Any]]:
    """Return ``(lengths, exact_match values)`` ordered by numeric length."""
    # Index with the original keys so string and integer keys both work.
    ordered_keys = sorted(results, key=int)
    lengths = [int(key) for key in ordered_keys]
    values = [results[key]["exact_match"] for key in ordered_keys]
    return lengths, values


def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot the training history and final per-length results of one run.

    First delegates to ``plot_training_history`` with
    ``summary.get("history", [])``, then writes
    ``final_exact_match_by_length.png`` into ``output_dir`` (created when
    missing) with one exact-match curve for the uniform split and one for the
    carry-heavy split.

    Each split is plotted over its own set of lengths, so the two splits do
    not need to cover identical lengths.

    Args:
        summary: Run summary. Must contain ``final_results`` with the entries
            ``test_uniform`` and ``test_carry_heavy``, each mapping a length
            to a dict that has an ``exact_match`` value. ``history`` is
            optional.
        output_dir: Directory that receives the PNG files.

    Returns:
        The paths returned by ``plot_training_history`` followed by the path
        of the final exact-match figure.

    Raises:
        KeyError: If ``final_results``, one of the two splits, or an
            ``exact_match`` value is missing.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths = plot_training_history(summary.get("history", []), output_dir)

    final_results = summary["final_results"]
    uniform_lengths, uniform_exact = _exact_match_by_length(final_results["test_uniform"])
    carry_lengths, carry_exact = _exact_match_by_length(final_results["test_carry_heavy"])

    final_curve_path = output_dir / "final_exact_match_by_length.png"
    fig, ax = plt.subplots(figsize=(8, 4.5))
    try:
        ax.plot(uniform_lengths, uniform_exact, marker="o", label="Uniform")
        ax.plot(carry_lengths, carry_exact, marker="o", label="Carry-heavy")
        ax.set_xlabel("Active digits")
        ax.set_ylabel("Exact-match accuracy")
        ax.set_ylim(0.0, 1.01)
        ax.legend()
        fig.tight_layout()
        fig.savefig(final_curve_path, dpi=160)
    finally:
        # pyplot keeps a figure alive until it is closed; close it even when
        # drawing or saving raises so the figure is not leaked.
        plt.close(fig)
    saved_paths.append(final_curve_path)
    return saved_paths