File size: 5,407 Bytes
76de008
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from statistics import mean, pstdev
from typing import Any

from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args
from addition.plots import plot_method_comparison
from addition.train import run_experiment


def _mean_std(values: list[float]) -> dict[str, float]:
    """Summarise *values* as a ``{"mean": ..., "std": ...}`` mapping.

    The spread is the population standard deviation (``statistics.pstdev``).
    An empty list yields zeros for both statistics; a single value yields
    that value as the mean and a zero spread.
    """
    count = len(values)
    if count == 0:
        return {"mean": 0.0, "std": 0.0}
    if count > 1:
        return {"mean": float(mean(values)), "std": float(pstdev(values))}
    # Exactly one sample: no spread to measure.
    only_value = values[0]
    return {"mean": float(only_value), "std": 0.0}


def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]:
    """Aggregate the per-length metrics of one evaluation split across runs.

    The lengths are taken from the first run's ``final_results[split_name]``
    and ordered numerically (keys are sorted with ``key=int``). Every other run
    is expected to contain at least those lengths. For each length, each metric
    in ``metric_names`` is collected from all runs and reduced with
    ``_mean_std``.

    Args:
        run_summaries: Run summaries, each holding a ``final_results`` mapping.
        split_name: Key of the split inside ``final_results``.

    Returns:
        ``{length: {metric_name: {"mean": ..., "std": ...}}}`` keyed by the
        original length keys, or an empty dict when ``run_summaries`` is empty.

    Raises:
        KeyError: If a run lacks a length or metric present in the first run.
    """
    # Guard: the length set is read from run_summaries[0] below, which would
    # raise IndexError when no runs were performed.
    if not run_summaries:
        return {}
    lengths = sorted(run_summaries[0]["final_results"][split_name].keys(), key=int)
    metric_names = ["digit_accuracy", "final_carry_accuracy", "exact_match", "avg_carry_chain", "avg_carry_density"]
    aggregated: dict[str, Any] = {}
    for length in lengths:
        aggregated[length] = {}
        for metric_name in metric_names:
            values = [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries]
            aggregated[length][metric_name] = _mean_std(values)
    return aggregated


def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise each run's best validation scores per stage across runs.

    Stages run from 1 up to the highest ``stage`` value found in any run's
    ``history`` (stage values below 1 are not reported). For each stage, every
    run that logged at least one history entry for it contributes its best
    (max) ``validation_digit_accuracy`` and ``validation_exact_match``. These
    per-run bests are reduced with ``_mean_std``. A stage that no run logged
    is summarised from an empty list.

    Returns:
        ``{str(stage): {"validation_digit_accuracy": {...},
        "validation_exact_match": {...}}}``. Empty when there are no runs or
        no history entries at all.
    """
    # default=0 leaves the stage range empty instead of max() raising
    # ValueError on an empty iterable (no runs, or all histories empty).
    max_stage = max(
        (int(entry["stage"]) for summary in run_summaries for entry in summary["history"]),
        default=0,
    )
    aggregated: dict[str, Any] = {}
    for stage in range(1, max_stage + 1):
        stage_values: list[float] = []
        stage_exact: list[float] = []
        for summary in run_summaries:
            stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage]
            # A run that never logged this stage contributes nothing to it.
            if not stage_entries:
                continue
            stage_values.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries))
            stage_exact.append(max(float(entry["validation_exact_match"]) for entry in stage_entries))
        aggregated[str(stage)] = {
            "validation_digit_accuracy": _mean_std(stage_values),
            "validation_exact_match": _mean_std(stage_exact),
        }
    return aggregated


def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate the ``final_results["diagnostics"]`` section across runs.

    ``probe_accuracy`` must be present in every run. For each attention
    section (``attention_uniform`` / ``attention_carry_heavy``), the union of
    metric names seen in any run is aggregated in sorted order. A run missing
    the section or a given metric contributes ``0.0`` for it.
    """
    per_run = [run["final_results"]["diagnostics"] for run in run_summaries]
    result: dict[str, Any] = {}
    result["probe_accuracy"] = _mean_std([float(diag["probe_accuracy"]) for diag in per_run])

    for section_key in ("attention_uniform", "attention_carry_heavy"):
        sections = [diag.get(section_key, {}) for diag in per_run]
        seen_metrics: set[str] = set()
        for section in sections:
            seen_metrics.update(section.keys())
        section_summary: dict[str, Any] = {}
        for metric in sorted(seen_metrics):
            section_summary[metric] = _mean_std([float(section.get(metric, 0.0)) for section in sections])
        result[section_key] = section_summary
    return result


def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]:
    """Build the per-method aggregate report from raw run summaries.

    For every method, the result holds the aggregated ``test_uniform`` and
    ``test_carry_heavy`` split metrics, the ``stage_progression`` and the
    ``diagnostics`` summaries, in that key order.
    """
    report: dict[str, Any] = {}
    for method_name, runs in results_by_method.items():
        method_report: dict[str, Any] = {}
        for split in ("test_uniform", "test_carry_heavy"):
            method_report[split] = _aggregate_split_metrics(runs, split)
        method_report["stage_progression"] = _aggregate_stage_progression(runs)
        method_report["diagnostics"] = _aggregate_diagnostics(runs)
        report[method_name] = method_report
    return report


def build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: shared config options plus comparison-specific flags.

    ``--methods`` defaults to every entry of ``VALID_MODELS``; ``--seeds``
    defaults to ``None``; ``--comparison_output_dir`` defaults to ``""``.
    """
    cli = argparse.ArgumentParser(
        description="Run the full addition comparison across methods and seeds."
    )
    add_config_arguments(cli)
    cli.add_argument(
        "--methods",
        nargs="*",
        choices=VALID_MODELS,
        default=list(VALID_MODELS),
    )
    cli.add_argument("--seeds", type=int, nargs="*", default=None)
    cli.add_argument("--comparison_output_dir", default="", type=str)
    return cli


def main() -> None:
    """Run every (method, seed) experiment and write the aggregated comparison.

    Parses the CLI and builds a base config (preset applied) to resolve the
    default seed list and output root. It then runs ``run_experiment`` once
    per method/seed pair, writing into ``<root>/<method>_seed<seed>``. All run
    summaries are aggregated with ``aggregate_runs``, written to
    ``<root>/aggregate_results.json`` and passed to ``plot_method_comparison``.

    Exits through ``parser.error`` when there is no method or no seed to run.
    """
    parser = build_parser()
    args = parser.parse_args()
    base_config = apply_preset(build_config_from_args(args))

    # Deduplicate while preserving order: a repeated method or seed would
    # re-run into the same output directory, and a repeated seed would be
    # counted twice in the mean/std.
    methods = list(dict.fromkeys(args.methods))
    seeds = list(dict.fromkeys(args.seeds or range(base_config.comparison_num_seeds)))
    # The aggregation step indexes the first run of each method, so with no
    # seeds it would fail with an opaque IndexError; with no methods the
    # report would be empty. Fail fast with a clear CLI message instead.
    if not methods:
        parser.error("--methods requires at least one method")
    if not seeds:
        parser.error("no seeds to run: pass --seeds or use a config with comparison_num_seeds > 0")

    comparison_root = Path(args.comparison_output_dir or f"addition_runs/comparison_{base_config.preset}")
    comparison_root.mkdir(parents=True, exist_ok=True)

    results_by_method: dict[str, list[dict[str, Any]]] = {}
    for method in methods:
        method_runs: list[dict[str, Any]] = []
        for seed in seeds:
            run_dir = str(comparison_root / f"{method}_seed{seed}")
            # Override the per-run fields on the namespace before rebuilding
            # the config from it.
            args.model = method
            args.seed = seed
            args.output_dir = run_dir
            config = apply_preset(build_config_from_args(args))
            # Re-assert the run directory on the final config; presumably
            # guards against apply_preset overriding output_dir (TODO confirm).
            config.output_dir = run_dir
            print(f"[addition comparison] running method={method} seed={seed}", flush=True)
            method_runs.append(run_experiment(config))
        results_by_method[method] = method_runs

    aggregate = aggregate_runs(results_by_method)
    aggregate_payload = {
        "methods": methods,
        "seeds": seeds,
        "aggregate": aggregate,
    }
    with (comparison_root / "aggregate_results.json").open("w", encoding="utf-8") as handle:
        json.dump(aggregate_payload, handle, indent=2, sort_keys=True)
    plot_method_comparison(aggregate, comparison_root / "plots")


if __name__ == "__main__":
    # Run the comparison only when this file is executed as a script, not on import.
    main()