| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| from pathlib import Path |
| from statistics import mean, pstdev |
| from typing import Any |
|
|
| from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args |
| from addition.plots import plot_method_comparison |
| from addition.train import run_experiment |
|
|
|
|
| def _mean_std(values: list[float]) -> dict[str, float]: |
| if not values: |
| return {"mean": 0.0, "std": 0.0} |
| if len(values) == 1: |
| return {"mean": float(values[0]), "std": 0.0} |
| return {"mean": float(mean(values)), "std": float(pstdev(values))} |
|
|
|
|
| def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]: |
| lengths = sorted(run_summaries[0]["final_results"][split_name].keys(), key=int) |
| metric_names = ["digit_accuracy", "final_carry_accuracy", "exact_match", "avg_carry_chain", "avg_carry_density"] |
| aggregated: dict[str, Any] = {} |
| for length in lengths: |
| aggregated[length] = {} |
| for metric_name in metric_names: |
| values = [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries] |
| aggregated[length][metric_name] = _mean_std(values) |
| return aggregated |
|
|
|
|
| def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]: |
| max_stage = max(int(entry["stage"]) for summary in run_summaries for entry in summary["history"]) |
| aggregated: dict[str, Any] = {} |
| for stage in range(1, max_stage + 1): |
| stage_values = [] |
| stage_exact = [] |
| for summary in run_summaries: |
| stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage] |
| if not stage_entries: |
| continue |
| stage_values.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries)) |
| stage_exact.append(max(float(entry["validation_exact_match"]) for entry in stage_entries)) |
| aggregated[str(stage)] = { |
| "validation_digit_accuracy": _mean_std(stage_values), |
| "validation_exact_match": _mean_std(stage_exact), |
| } |
| return aggregated |
|
|
|
|
def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate the per-run diagnostic metrics across runs.

    Reads ``summary["final_results"]["diagnostics"]`` from every run.
    ``probe_accuracy`` is required in each. The ``attention_uniform`` and
    ``attention_carry_heavy`` sub-dicts are optional.

    Args:
        run_summaries: Run summaries, one per seed.

    Returns:
        ``{"probe_accuracy": {...}, "attention_uniform": {metric: {...}},
        "attention_carry_heavy": {metric: {...}}}``, each leaf a ``{"mean", "std"}``
        dict. Attention metric names are the sorted union over runs. Each metric
        is averaged only over the runs that report it.
    """
    diagnostics = [summary["final_results"]["diagnostics"] for summary in run_summaries]
    output: dict[str, Any] = {
        "probe_accuracy": _mean_std([float(diag["probe_accuracy"]) for diag in diagnostics]),
    }
    for attention_key in ("attention_uniform", "attention_carry_heavy"):
        # A run may omit the attention block entirely or store None for it.
        per_run = [diag.get(attention_key) or {} for diag in diagnostics]
        metric_names = sorted({name for metrics in per_run for name in metrics})
        # Missing metrics are skipped rather than counted as 0.0, which would bias the mean.
        output[attention_key] = {
            name: _mean_std([float(metrics[name]) for metrics in per_run if name in metrics])
            for name in metric_names
        }
    return output
|
|
|
|
def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]:
    """Collapse per-seed run summaries into cross-seed statistics for every method.

    Args:
        results_by_method: Maps a method name to its list of run summaries
            (one per seed).

    Returns:
        ``{method: {"test_uniform": ..., "test_carry_heavy": ..., "stage_progression": ...,
        "diagnostics": ...}}``. Each section is built by the matching ``_aggregate_*``
        helper in this module.
    """

    def summarize(runs: list[dict[str, Any]]) -> dict[str, Any]:
        per_split = {
            split: _aggregate_split_metrics(runs, split)
            for split in ("test_uniform", "test_carry_heavy")
        }
        return {
            **per_split,
            "stage_progression": _aggregate_stage_progression(runs),
            "diagnostics": _aggregate_diagnostics(runs),
        }

    return {method: summarize(runs) for method, runs in results_by_method.items()}
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: shared training-config flags plus comparison-specific flags.

    ``--methods`` and ``--seeds`` define the grid of runs. ``--comparison_output_dir``
    overrides the root directory for per-run outputs and the aggregate.
    """
    parser = argparse.ArgumentParser(description="Run the full addition comparison across methods and seeds.")
    add_config_arguments(parser)
    # nargs="+" so a bare `--methods` is rejected instead of silently running nothing.
    parser.add_argument(
        "--methods",
        nargs="+",
        default=list(VALID_MODELS),
        choices=VALID_MODELS,
        help="Models to compare (default: all valid models).",
    )
    parser.add_argument(
        "--seeds",
        nargs="*",
        type=int,
        default=None,
        help="Seeds to run for every method (default: 0..comparison_num_seeds-1 from the config).",
    )
    parser.add_argument(
        "--comparison_output_dir",
        type=str,
        default="",
        help="Root output directory (default: addition_runs/comparison_<preset>).",
    )
    return parser
|
|
|
|
def main() -> None:
    """Run every (method, seed) combination, then save the cross-seed aggregate and plots.

    Each run is configured from the parsed CLI arguments, with ``model``, ``seed`` and
    ``output_dir`` overridden, and trained via ``run_experiment``. Per-run outputs go
    to ``<comparison_root>/<method>_seed<seed>``. The aggregate is written to
    ``<comparison_root>/aggregate_results.json`` and the plots to ``<comparison_root>/plots``.
    """
    parser = build_parser()
    args = parser.parse_args()
    base_config = apply_preset(build_config_from_args(args))

    # Order-preserving dedupe: a repeated method or seed would re-run into the same
    # output directory and be counted twice in the cross-seed mean/std.
    methods = list(dict.fromkeys(args.methods))
    seeds = list(dict.fromkeys(args.seeds or range(base_config.comparison_num_seeds)))
    if not methods:
        parser.error("no methods to run: pass at least one value to --methods")
    if not seeds:
        # Without this, aggregation would later crash on an empty run list.
        parser.error("no seeds to run: pass --seeds or configure comparison_num_seeds > 0")

    comparison_root = Path(args.comparison_output_dir or f"addition_runs/comparison_{base_config.preset}")
    comparison_root.mkdir(parents=True, exist_ok=True)

    results_by_method: dict[str, list[dict[str, Any]]] = {}
    for method in methods:
        results_by_method[method] = []
        for seed in seeds:
            run_dir = str(comparison_root / f"{method}_seed{seed}")
            # Fresh namespace per run so overrides never leak into later iterations.
            run_args = argparse.Namespace(**vars(args))
            run_args.model = method
            run_args.seed = seed
            run_args.output_dir = run_dir
            config = apply_preset(build_config_from_args(run_args))
            # Re-assigned after apply_preset, presumably in case the preset overrides output_dir.
            config.output_dir = run_dir
            print(f"[addition comparison] running method={method} seed={seed}", flush=True)
            results_by_method[method].append(run_experiment(config))

    aggregate = aggregate_runs(results_by_method)
    aggregate_payload = {
        "methods": methods,
        "seeds": seeds,
        "aggregate": aggregate,
    }
    with (comparison_root / "aggregate_results.json").open("w", encoding="utf-8") as handle:
        json.dump(aggregate_payload, handle, indent=2, sort_keys=True)
    plot_method_comparison(aggregate, comparison_root / "plots")
|
|
|
|
if __name__ == "__main__":
    # Start the comparison only when executed as a script, not when the module is imported.
    main()
|
|