File size: 5,376 Bytes
76de008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import math
from pathlib import Path
from typing import Any


def _load_pyplot():
    """Import ``matplotlib.pyplot`` on demand and return the module.

    The import statement lives inside the function body, so it is executed
    when this function is called rather than when the enclosing module is
    imported.
    """
    from matplotlib import pyplot

    return pyplot


def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]:
    """Write the validation-accuracy curves and the stage timeline as PNGs.

    Each entry of ``history`` is read for the keys ``global_step``,
    ``validation_digit_accuracy``, ``validation_final_carry_accuracy``,
    ``validation_exact_match`` and ``stage``.

    Args:
        history: Per-evaluation records; an empty value short-circuits.
        output_dir: Directory receiving the images (created if missing).

    Returns:
        ``[training_curves.png, stage_progression.png]`` as paths under
        ``output_dir``, or ``[]`` when ``history`` is empty. In the empty
        case the function returns before pyplot is loaded and before the
        directory is created.

    Raises:
        KeyError: If an entry lacks one of the keys listed above.
    """
    if not history:
        return []

    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    def column(key: str) -> list[Any]:
        # One value per history record, in record order.
        return [record[key] for record in history]

    # Every series is extracted before the first figure is opened, so a
    # missing key fails without leaving a half-drawn figure behind.
    x_steps = column("global_step")
    accuracy_curves = [
        (column("validation_digit_accuracy"), "Val digit acc"),
        (column("validation_final_carry_accuracy"), "Val final carry acc"),
        (column("validation_exact_match"), "Val exact match"),
    ]
    stage_ids = column("stage")

    written: list[Path] = []

    # Figure 1: the three validation metrics against the global step.
    plt.figure(figsize=(8, 4.5))
    for values, legend_label in accuracy_curves:
        plt.plot(x_steps, values, label=legend_label)
    plt.xlabel("Global step")
    plt.ylabel("Accuracy")
    plt.ylim(0.0, 1.01)  # slight headroom so a curve at 1.0 stays visible
    plt.legend()
    plt.tight_layout()
    curves_file = output_dir / "training_curves.png"
    plt.savefig(curves_file, dpi=160)
    plt.close()
    written.append(curves_file)

    # Figure 2: curriculum stage as a step function of the global step.
    plt.figure(figsize=(8, 4.5))
    plt.step(x_steps, stage_ids, where="post")
    plt.xlabel("Global step")
    plt.ylabel("Curriculum stage")
    plt.tight_layout()
    stages_file = output_dir / "stage_progression.png"
    plt.savefig(stages_file, dpi=160)
    plt.close()
    written.append(stages_file)

    return written


def _collect_length_metric(aggregate: dict[str, Any], method: str, split: str, metric: str) -> tuple[list[int], list[float], list[float]]:
    lengths = sorted(int(length) for length in aggregate[method][split].keys())
    means = [aggregate[method][split][str(length)][metric]["mean"] for length in lengths]
    stds = [aggregate[method][split][str(length)][metric]["std"] for length in lengths]
    return lengths, means, stds


def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
    """Plot per-method comparison figures from aggregated results.

    Three PNG files are written into ``output_dir`` (created if missing):

    * ``uniform_exact_match.png`` - ``exact_match`` mean by length for the
      ``test_uniform`` split, one line per method.
    * ``carry_heavy_exact_match.png`` - the same for ``test_carry_heavy``.
    * ``validation_digit_accuracy_by_stage.png`` - mean of
      ``validation_digit_accuracy`` per key of
      ``aggregate[method]["stage_progression"]``.

    Every line is drawn with a shaded mean +/- std band clipped to [0, 1].

    Args:
        aggregate: Mapping ``method -> {split -> ..., "stage_progression" -> ...}``
            where each leaf metric holds ``"mean"`` and ``"std"``.
        output_dir: Directory receiving the images.

    Returns:
        Paths of the written files, in the order listed above.

    Raises:
        KeyError: If a method lacks one of the expected splits or metrics.
            Any figure open at that point is closed before the error
            propagates.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)
    saved_paths: list[Path] = []
    methods = list(aggregate)

    def draw_mean_with_band(xs, means, stds, label):
        # Line for the mean plus a translucent +/- one std band, clipped to
        # the valid accuracy range.
        plt.plot(xs, means, marker="o", label=label)
        lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
        upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
        plt.fill_between(xs, lower, upper, alpha=0.15)

    splits = [
        ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
        ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
    ]
    for split, filename, title in splits:
        figure = plt.figure(figsize=(8, 4.5))
        # Data lookups happen while the figure is open, so close it in a
        # ``finally`` block; otherwise pyplot keeps the figure when a lookup
        # raises.
        try:
            for method in methods:
                lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
                draw_mean_with_band(lengths, means, stds, method)
            plt.xlabel("Active digits")
            plt.ylabel("Exact-match accuracy")
            plt.title(title)
            plt.ylim(0.0, 1.01)
            plt.legend()
            plt.tight_layout()
            path = output_dir / filename
            plt.savefig(path, dpi=160)
        finally:
            plt.close(figure)
        saved_paths.append(path)

    figure = plt.figure(figsize=(8, 4.5))
    try:
        for method in methods:
            by_stage = aggregate[method]["stage_progression"]
            # Sort the stored keys numerically and index with those same
            # keys, so integer or zero-padded stage keys are found too.
            stage_keys = sorted(by_stage, key=int)
            stages = [int(key) for key in stage_keys]
            stats = [by_stage[key]["validation_digit_accuracy"] for key in stage_keys]
            means = [stat["mean"] for stat in stats]
            stds = [stat["std"] for stat in stats]
            draw_mean_with_band(stages, means, stds, method)
        plt.xlabel("Curriculum stage")
        plt.ylabel("Best validation digit accuracy")
        plt.ylim(0.0, 1.01)
        plt.title("Validation digit accuracy vs stage")
        plt.legend()
        plt.tight_layout()
        stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
        plt.savefig(stage_curve_path, dpi=160)
    finally:
        plt.close(figure)
    saved_paths.append(stage_curve_path)
    return saved_paths


def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
    """Write all figures for one run and return the paths that were saved.

    The training-history figures are produced first from
    ``summary.get("history", [])`` via ``plot_training_history``. Then
    ``final_exact_match_by_length.png`` is drawn from
    ``summary["final_results"]``, comparing ``exact_match`` on the
    ``test_uniform`` and ``test_carry_heavy`` splits.

    The lengths on the x-axis are taken from the ``test_uniform`` keys and the
    carry-heavy results are indexed with those same lengths, so both splits
    are expected to cover them.

    Args:
        summary: Run summary holding ``final_results`` and, optionally,
            ``history``.
        output_dir: Directory receiving the images (created if missing).

    Returns:
        The training-history paths followed by the final-curve path.
    """
    plt = _load_pyplot()
    output_dir.mkdir(parents=True, exist_ok=True)

    # History plots are written before the final results are read.
    written = plot_training_history(summary.get("history", []), output_dir)

    final_results = summary["final_results"]
    uniform_results = final_results["test_uniform"]
    carry_results = final_results["test_carry_heavy"]

    digit_counts = sorted(map(int, uniform_results))
    curves = [
        ("Uniform", [uniform_results[str(count)]["exact_match"] for count in digit_counts]),
        ("Carry-heavy", [carry_results[str(count)]["exact_match"] for count in digit_counts]),
    ]

    plt.figure(figsize=(8, 4.5))
    for legend_label, accuracies in curves:
        plt.plot(digit_counts, accuracies, marker="o", label=legend_label)
    plt.xlabel("Active digits")
    plt.ylabel("Exact-match accuracy")
    plt.ylim(0.0, 1.01)
    plt.legend()
    plt.tight_layout()

    target = output_dir / "final_exact_match_by_length.png"
    plt.savefig(target, dpi=160)
    plt.close()

    written.append(target)
    return written