File size: 5,376 Bytes
76de008 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | from __future__ import annotations
import math
from pathlib import Path
from typing import Any
def _load_pyplot():
    """Import ``matplotlib.pyplot`` at call time and return the module.

    The import statement runs when this function is called rather than when
    the enclosing module is loaded.
    """
    from matplotlib import pyplot

    return pyplot
def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]:
    """Plot validation metrics and the curriculum stage against global step.

    Two PNG files are written into ``output_dir`` (created if missing):
    ``training_curves.png`` with the three validation accuracy curves and
    ``stage_progression.png`` with the stage drawn as a step function.

    Args:
        history: One dict per logged evaluation. Every entry is read for the
            keys ``"global_step"``, ``"validation_digit_accuracy"``,
            ``"validation_final_carry_accuracy"``,
            ``"validation_exact_match"`` and ``"stage"``.
        output_dir: Directory that receives the images.

    Returns:
        Paths of the written images in creation order. An empty list is
        returned for an empty ``history``; in that case pyplot is not loaded
        and ``output_dir`` is not created.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    if not history:
        return []

    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Extract every series up front so a malformed entry fails before any
    # figure has been created.
    steps = [entry["global_step"] for entry in history]
    digit_acc = [entry["validation_digit_accuracy"] for entry in history]
    carry_acc = [entry["validation_final_carry_accuracy"] for entry in history]
    exact_match = [entry["validation_exact_match"] for entry in history]
    stages = [entry["stage"] for entry in history]
    saved_paths: list[Path] = []

    metrics_path = output_dir / "training_curves.png"
    metrics_figure = plt.figure(figsize=(8, 4.5))
    try:
        plt.plot(steps, digit_acc, label="Val digit acc")
        plt.plot(steps, carry_acc, label="Val final carry acc")
        plt.plot(steps, exact_match, label="Val exact match")
        plt.xlabel("Global step")
        plt.ylabel("Accuracy")
        plt.ylim(0.0, 1.01)
        plt.legend()
        plt.tight_layout()
        plt.savefig(metrics_path, dpi=160)
    finally:
        # Close this exact figure even when plotting or saving raises, so it
        # is not left open in pyplot's figure registry.
        plt.close(metrics_figure)
    saved_paths.append(metrics_path)

    stage_path = output_dir / "stage_progression.png"
    stage_figure = plt.figure(figsize=(8, 4.5))
    try:
        plt.step(steps, stages, where="post")
        plt.xlabel("Global step")
        plt.ylabel("Curriculum stage")
        plt.tight_layout()
        plt.savefig(stage_path, dpi=160)
    finally:
        plt.close(stage_figure)
    saved_paths.append(stage_path)

    return saved_paths
def _collect_length_metric(aggregate: dict[str, Any], method: str, split: str, metric: str) -> tuple[list[int], list[float], list[float]]:
    """Gather the per-length mean and standard deviation of one metric.

    Reads ``aggregate[method][split]``, a mapping from a length key to a
    mapping of metric name to ``{"mean": ..., "std": ...}``, and returns the
    values ordered by ascending numeric length.

    Args:
        aggregate: Nested results indexed as
            ``aggregate[method][split][length][metric]``.
        method: Top-level key selecting the method.
        split: Key of the evaluation split under that method.
        metric: Name of the metric to extract for every length.

    Returns:
        Three parallel lists: the lengths as ints in ascending order, the
        matching means and the matching standard deviations.

    Raises:
        KeyError: If ``method``, ``split``, ``metric``, ``"mean"`` or
            ``"std"`` is missing.
        ValueError: If a length key cannot be converted with ``int``.
    """
    per_length = aggregate[method][split]
    # Sort the original items by numeric key and keep each payload next to
    # its key. Rebuilding the key with str(int(key)) would miss int keys and
    # non-canonical strings such as "04".
    ordered = sorted(per_length.items(), key=lambda item: int(item[0]))
    lengths = [int(key) for key, _ in ordered]
    means = [stats[metric]["mean"] for _, stats in ordered]
    stds = [stats[metric]["std"] for _, stats in ordered]
    return lengths, means, stds
def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot per-method comparison curves with mean +/- std bands.

    Three PNG files are written into ``output_dir`` (created if missing):

    * ``uniform_exact_match.png``: exact match by length on ``test_uniform``.
    * ``carry_heavy_exact_match.png``: the same on ``test_carry_heavy``.
    * ``validation_digit_accuracy_by_stage.png``: validation digit accuracy
      per curriculum stage.

    Every top-level key of ``aggregate`` is treated as one method and drawn as
    a line with a shaded band from ``mean - std`` to ``mean + std``, clipped
    to the range [0, 1].

    Args:
        aggregate: Nested results. For each method this reads the
            ``"test_uniform"`` and ``"test_carry_heavy"`` splits through
            ``_collect_length_metric`` and
            ``aggregate[method]["stage_progression"][stage]
            ["validation_digit_accuracy"]`` with ``"mean"`` and ``"std"``.
        output_dir: Directory that receives the images.

    Returns:
        Paths of the three written images in creation order.

    Raises:
        KeyError: If an expected key is missing from ``aggregate``.
        ValueError: If a stage key cannot be converted with ``int``.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths: list[Path] = []
    methods = list(aggregate)

    splits = [
        ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
        ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
    ]
    for split, filename, title in splits:
        path = output_dir / filename
        figure = plt.figure(figsize=(8, 4.5))
        try:
            for method in methods:
                lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
                plt.plot(lengths, means, marker="o", label=method)
                # Accuracies live in [0, 1], so the band is clipped to it.
                lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
                upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
                plt.fill_between(lengths, lower, upper, alpha=0.15)
            plt.xlabel("Active digits")
            plt.ylabel("Exact-match accuracy")
            plt.title(title)
            plt.ylim(0.0, 1.01)
            plt.legend()
            plt.tight_layout()
            plt.savefig(path, dpi=160)
        finally:
            # Close this exact figure even when plotting or saving raises, so
            # it is not left open in pyplot's figure registry.
            plt.close(figure)
        saved_paths.append(path)

    stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
    stage_figure = plt.figure(figsize=(8, 4.5))
    try:
        for method in methods:
            # Sort the original items by numeric stage and read each payload
            # directly; rebuilding keys with str(int(key)) would miss int
            # keys and non-canonical strings.
            by_stage = sorted(
                aggregate[method]["stage_progression"].items(),
                key=lambda item: int(item[0]),
            )
            stages = [int(key) for key, _ in by_stage]
            means = [stats["validation_digit_accuracy"]["mean"] for _, stats in by_stage]
            stds = [stats["validation_digit_accuracy"]["std"] for _, stats in by_stage]
            plt.plot(stages, means, marker="o", label=method)
            plt.fill_between(
                stages,
                [max(0.0, mean - std) for mean, std in zip(means, stds)],
                [min(1.0, mean + std) for mean, std in zip(means, stds)],
                alpha=0.15,
            )
        plt.xlabel("Curriculum stage")
        plt.ylabel("Best validation digit accuracy")
        plt.ylim(0.0, 1.01)
        plt.title("Validation digit accuracy vs stage")
        plt.legend()
        plt.tight_layout()
        plt.savefig(stage_curve_path, dpi=160)
    finally:
        plt.close(stage_figure)
    saved_paths.append(stage_curve_path)

    return saved_paths
def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot the training history and final exact-match curves of one run.

    First delegates to ``plot_training_history`` with
    ``summary.get("history", [])``, then writes
    ``final_exact_match_by_length.png`` comparing exact match on the
    ``test_uniform`` and ``test_carry_heavy`` results by length.

    Args:
        summary: Run summary. ``summary["final_results"]`` must contain
            ``"test_uniform"`` and ``"test_carry_heavy"``, each mapping a
            length key to a dict with an ``"exact_match"`` entry. The lengths
            plotted are the keys of ``"test_uniform"``; ``"test_carry_heavy"``
            is indexed with those same keys. ``"history"`` is optional.
        output_dir: Directory that receives the images (created if missing).

    Returns:
        Paths returned by ``plot_training_history`` followed by the path of
        the final exact-match image.

    Raises:
        KeyError: If an expected key is missing from ``summary``.
        ValueError: If a length key cannot be converted with ``int``.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths = plot_training_history(summary.get("history", []), output_dir)

    uniform = summary["final_results"]["test_uniform"]
    carry_heavy = summary["final_results"]["test_carry_heavy"]
    # Order the original keys numerically and index with them as-is;
    # rebuilding keys with str(int(key)) would miss int keys and
    # non-canonical strings.
    length_keys = sorted(uniform, key=int)
    lengths = [int(key) for key in length_keys]
    uniform_exact = [uniform[key]["exact_match"] for key in length_keys]
    carry_exact = [carry_heavy[key]["exact_match"] for key in length_keys]

    final_curve_path = output_dir / "final_exact_match_by_length.png"
    figure = plt.figure(figsize=(8, 4.5))
    try:
        plt.plot(lengths, uniform_exact, marker="o", label="Uniform")
        plt.plot(lengths, carry_exact, marker="o", label="Carry-heavy")
        plt.xlabel("Active digits")
        plt.ylabel("Exact-match accuracy")
        plt.ylim(0.0, 1.01)
        plt.legend()
        plt.tight_layout()
        plt.savefig(final_curve_path, dpi=160)
    finally:
        # Close this exact figure even when plotting or saving raises, so it
        # is not left open in pyplot's figure registry.
        plt.close(figure)
    saved_paths.append(final_curve_path)

    return saved_paths
|