from __future__ import annotations

import math
from pathlib import Path
from typing import Any

# Rendering settings shared by every chart this module writes.
_FIGURE_SIZE = (8, 4.5)
_SAVE_DPI = 160


def _load_pyplot():
    """Import ``matplotlib.pyplot`` on demand and return the module.

    The import lives inside the function so that importing this module does
    not itself require matplotlib; it is only needed once a plot is drawn.
    """
    import matplotlib.pyplot as plt

    return plt


def _numeric_key_order(mapping: dict[Any, Any]) -> list[tuple[int, Any]]:
    """Return ``(int(key), key)`` pairs of *mapping*, sorted by integer value.

    The original key is kept next to its integer form so callers can index
    the mapping with the key exactly as stored.  Converting back with
    ``str(int(key))`` would miss integer keys and non-canonical strings such
    as ``"03"``.
    """
    return sorted(((int(key), key) for key in mapping), key=lambda pair: pair[0])


def _clipped_band(means: list[float], stds: list[float]) -> tuple[list[float], list[float]]:
    """Return the ``mean - std`` / ``mean + std`` band clipped to ``[0.0, 1.0]``."""
    lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
    upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
    return lower, upper


def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]:
    """Plot validation accuracy curves and the curriculum stage over training.

    Args:
        history: One dict per evaluation point.  Each entry must provide
            ``global_step``, ``validation_digit_accuracy``,
            ``validation_final_carry_accuracy``, ``validation_exact_match``
            and ``stage``.
        output_dir: Directory the PNG files are written to; created if missing.

    Returns:
        Paths of the written files (``training_curves.png`` and
        ``stage_progression.png``), or an empty list when *history* is empty.
        In the empty case nothing is imported, created or written.

    Raises:
        KeyError: If a history entry lacks one of the required keys.
    """
    if not history:
        return []

    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pull every series out first so malformed input fails before any figure
    # is opened.
    steps = [entry["global_step"] for entry in history]
    digit_acc = [entry["validation_digit_accuracy"] for entry in history]
    carry_acc = [entry["validation_final_carry_accuracy"] for entry in history]
    exact_match = [entry["validation_exact_match"] for entry in history]
    stages = [entry["stage"] for entry in history]

    saved_paths: list[Path] = []

    metrics_path = output_dir / "training_curves.png"
    fig = plt.figure(figsize=_FIGURE_SIZE)
    try:
        plt.plot(steps, digit_acc, label="Val digit acc")
        plt.plot(steps, carry_acc, label="Val final carry acc")
        plt.plot(steps, exact_match, label="Val exact match")
        plt.xlabel("Global step")
        plt.ylabel("Accuracy")
        plt.ylim(0.0, 1.01)
        plt.legend()
        plt.tight_layout()
        plt.savefig(metrics_path, dpi=_SAVE_DPI)
    finally:
        # Close this exact figure even if drawing or saving failed, so a
        # failed call does not leave it registered with pyplot.
        plt.close(fig)
    saved_paths.append(metrics_path)

    stage_path = output_dir / "stage_progression.png"
    fig = plt.figure(figsize=_FIGURE_SIZE)
    try:
        plt.step(steps, stages, where="post")
        plt.xlabel("Global step")
        plt.ylabel("Curriculum stage")
        plt.tight_layout()
        plt.savefig(stage_path, dpi=_SAVE_DPI)
    finally:
        plt.close(fig)
    saved_paths.append(stage_path)

    return saved_paths


def _collect_length_metric(
    aggregate: dict[str, Any], method: str, split: str, metric: str
) -> tuple[list[int], list[float], list[float]]:
    """Gather the mean/std of *metric* per length for one method and split.

    Reads ``aggregate[method][split][length][metric]["mean" | "std"]``.
    Length keys may be anything ``int()`` accepts (``"3"``, ``"03"``, ``3``).

    Returns:
        ``(lengths, means, stds)`` as parallel lists ordered by ascending
        integer length.

    Raises:
        KeyError: If the method, split, metric or a ``mean``/``std`` entry is
            missing.
        ValueError: If a length key cannot be converted to ``int``.
    """
    per_length = aggregate[method][split]
    ordered = _numeric_key_order(per_length)
    lengths = [length for length, _ in ordered]
    means = [per_length[key][metric]["mean"] for _, key in ordered]
    stds = [per_length[key][metric]["std"] for _, key in ordered]
    return lengths, means, stds


def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot per-method exact-match by length and digit accuracy by stage.

    For every method in *aggregate* this draws the mean curve with a
    ``mean +/- std`` band (clipped to ``[0, 1]``) for:

    * ``test_uniform`` exact match      -> ``uniform_exact_match.png``
    * ``test_carry_heavy`` exact match  -> ``carry_heavy_exact_match.png``
    * ``stage_progression`` validation digit accuracy
      -> ``validation_digit_accuracy_by_stage.png``

    Args:
        aggregate: Mapping of method name to its aggregated results.
        output_dir: Directory the PNG files are written to; created if missing.

    Returns:
        Paths of the three written files, in the order listed above.

    Raises:
        KeyError: If a method lacks one of the expected sections or entries.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths: list[Path] = []
    methods = list(aggregate)

    splits = [
        ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
        ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
    ]
    for split, filename, title in splits:
        path = output_dir / filename
        fig = plt.figure(figsize=_FIGURE_SIZE)
        try:
            for method in methods:
                lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
                plt.plot(lengths, means, marker="o", label=method)
                lower, upper = _clipped_band(means, stds)
                plt.fill_between(lengths, lower, upper, alpha=0.15)
            plt.xlabel("Active digits")
            plt.ylabel("Exact-match accuracy")
            plt.title(title)
            plt.ylim(0.0, 1.01)
            plt.legend()
            plt.tight_layout()
            plt.savefig(path, dpi=_SAVE_DPI)
        finally:
            # Also covers a KeyError raised while collecting a method's data.
            plt.close(fig)
        saved_paths.append(path)

    stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
    fig = plt.figure(figsize=_FIGURE_SIZE)
    try:
        for method in methods:
            progression = aggregate[method]["stage_progression"]
            ordered = _numeric_key_order(progression)
            stages = [stage for stage, _ in ordered]
            means = [progression[key]["validation_digit_accuracy"]["mean"] for _, key in ordered]
            stds = [progression[key]["validation_digit_accuracy"]["std"] for _, key in ordered]
            plt.plot(stages, means, marker="o", label=method)
            lower, upper = _clipped_band(means, stds)
            plt.fill_between(stages, lower, upper, alpha=0.15)
        plt.xlabel("Curriculum stage")
        plt.ylabel("Best validation digit accuracy")
        plt.ylim(0.0, 1.01)
        plt.title("Validation digit accuracy vs stage")
        plt.legend()
        plt.tight_layout()
        plt.savefig(stage_curve_path, dpi=_SAVE_DPI)
    finally:
        plt.close(fig)
    saved_paths.append(stage_curve_path)

    return saved_paths


def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot one run: training history plus final exact-match by length.

    Args:
        summary: Run summary.  ``summary["history"]`` is optional and is
            passed to :func:`plot_training_history`.
            ``summary["final_results"]`` must contain ``test_uniform`` and
            ``test_carry_heavy``, each mapping a length to a dict with an
            ``exact_match`` value.
        output_dir: Directory the PNG files are written to; created if missing.

    Returns:
        Paths written by :func:`plot_training_history` followed by
        ``final_exact_match_by_length.png``.

    Raises:
        KeyError: If ``final_results``, one of the two splits, or an
            ``exact_match`` entry is missing.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths = plot_training_history(summary.get("history", []), output_dir)

    final_results = summary["final_results"]
    uniform = final_results["test_uniform"]
    carry_heavy = final_results["test_carry_heavy"]

    # Each split is plotted over its own lengths, so a length present in only
    # one of them neither raises nor gets silently dropped.
    curves: list[tuple[str, list[int], list[float]]] = []
    for label, per_length in (("Uniform", uniform), ("Carry-heavy", carry_heavy)):
        ordered = _numeric_key_order(per_length)
        lengths = [length for length, _ in ordered]
        values = [per_length[key]["exact_match"] for _, key in ordered]
        curves.append((label, lengths, values))

    final_curve_path = output_dir / "final_exact_match_by_length.png"
    fig = plt.figure(figsize=_FIGURE_SIZE)
    try:
        for label, lengths, values in curves:
            plt.plot(lengths, values, marker="o", label=label)
        plt.xlabel("Active digits")
        plt.ylabel("Exact-match accuracy")
        plt.ylim(0.0, 1.01)
        plt.legend()
        plt.tight_layout()
        plt.savefig(final_curve_path, dpi=_SAVE_DPI)
    finally:
        plt.close(fig)
    saved_paths.append(final_curve_path)

    return saved_paths