# Repository path: curriculum-cot-code/addition/run_comparison.py
# Page header residue: "Avra98's picture"
# Commit message: Initial code dump (rebuttal-ready snapshot)
# Commit: 76de008 (verified)
from __future__ import annotations
import argparse
import json
import math
from pathlib import Path
from statistics import mean, pstdev
from typing import Any
from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args
from addition.plots import plot_method_comparison
from addition.train import run_experiment
def _mean_std(values: list[float]) -> dict[str, float]:
    """Summarise ``values`` as a ``{"mean": ..., "std": ...}`` dictionary.

    The spread is the population standard deviation (``statistics.pstdev``).
    An empty input yields ``0.0`` for both entries, and a single value is
    reported as-is with a spread of ``0.0``.
    """
    if not values:
        return {"mean": 0.0, "std": 0.0}

    if len(values) > 1:
        center = float(mean(values))
        spread = float(pstdev(values))
    else:
        # A lone observation has no spread; skip the statistics helpers.
        center = float(values[0])
        spread = 0.0
    return {"mean": center, "std": spread}
def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]:
    """Aggregate the per-length metrics of one evaluation split across runs.

    Args:
        run_summaries: One summary per run. Each is read at
            ``summary["final_results"][split_name][length][metric_name]``.
        split_name: Key of the split inside ``final_results``.

    Returns:
        ``{length: {metric_name: {"mean": ..., "std": ...}}}`` for the five
        metrics listed below. The set of lengths is taken from the first run
        and ordered by their integer value. An empty ``run_summaries`` yields
        an empty mapping instead of raising ``IndexError``.

    Raises:
        KeyError: If a run lacks the split, one of the first run's lengths,
            or one of the metrics.
    """
    # Nothing to aggregate; without this guard the first-run lookup below
    # would fail on an empty list.
    if not run_summaries:
        return {}

    metric_names = (
        "digit_accuracy",
        "final_carry_accuracy",
        "exact_match",
        "avg_carry_chain",
        "avg_carry_density",
    )
    # Lengths come from the first run only; keys are kept in their original
    # form and merely ordered numerically.
    lengths = sorted(run_summaries[0]["final_results"][split_name], key=int)

    return {
        length: {
            metric_name: _mean_std(
                [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries]
            )
            for metric_name in metric_names
        }
        for length in lengths
    }
def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate the best validation scores reached in each stage across runs.

    Args:
        run_summaries: One summary per run. Each ``summary["history"]`` is a
            sequence of entries read at ``"stage"``,
            ``"validation_digit_accuracy"`` and ``"validation_exact_match"``.

    Returns:
        ``{str(stage): {"validation_digit_accuracy": {"mean", "std"},
        "validation_exact_match": {"mean", "std"}}}`` for every stage from 1
        up to the largest stage seen in any run. For each run the best
        (maximum) value within the stage is used; runs without entries for a
        stage are left out of that stage's statistics. When there are no runs
        or no history entries at all, an empty mapping is returned instead of
        raising ``ValueError``.
    """
    # ``default=0`` keeps ``max`` from raising on an empty history and makes
    # the stage range below empty.
    max_stage = max(
        (int(entry["stage"]) for summary in run_summaries for entry in summary["history"]),
        default=0,
    )

    aggregated: dict[str, Any] = {}
    for stage in range(1, max_stage + 1):
        best_digit_accuracy = []
        best_exact_match = []
        for summary in run_summaries:
            stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage]
            if not stage_entries:
                # This run never logged the stage; it contributes nothing here.
                continue
            best_digit_accuracy.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries))
            best_exact_match.append(max(float(entry["validation_exact_match"]) for entry in stage_entries))
        aggregated[str(stage)] = {
            "validation_digit_accuracy": _mean_std(best_digit_accuracy),
            "validation_exact_match": _mean_std(best_exact_match),
        }
    return aggregated
def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate the diagnostics section of every run summary.

    Reads ``summary["final_results"]["diagnostics"]`` from each run and
    returns a mapping with:

    * ``"probe_accuracy"``: mean/std of each run's ``probe_accuracy``.
    * ``"attention_uniform"`` and ``"attention_carry_heavy"``: for every
      metric name found in any run's section, the mean/std over all runs.
      A run without the section, or without a given metric, counts as
      ``0.0`` for that metric.
    """
    per_run = [summary["final_results"]["diagnostics"] for summary in run_summaries]

    probe_scores = [float(run_diag["probe_accuracy"]) for run_diag in per_run]
    output: dict[str, Any] = {"probe_accuracy": _mean_std(probe_scores)}

    for attention_key in ("attention_uniform", "attention_carry_heavy"):
        sections = [run_diag.get(attention_key, {}) for run_diag in per_run]

        # Union of the metric names reported by any run for this section.
        seen_metrics = set()
        for section in sections:
            seen_metrics.update(section.keys())

        summary_by_metric = {}
        for metric_name in sorted(seen_metrics):
            samples = [float(section.get(metric_name, 0.0)) for section in sections]
            summary_by_metric[metric_name] = _mean_std(samples)
        output[attention_key] = summary_by_metric

    return output
def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]:
    """Build the per-method aggregate from the run summaries of each method.

    Args:
        results_by_method: Maps a method name to the list of its run summaries.

    Returns:
        A mapping from each method name to a dictionary with the keys
        ``"test_uniform"``, ``"test_carry_heavy"``, ``"stage_progression"``
        and ``"diagnostics"``, each produced by the matching aggregation
        helper in this module.
    """
    return {
        method_name: {
            "test_uniform": _aggregate_split_metrics(method_runs, "test_uniform"),
            "test_carry_heavy": _aggregate_split_metrics(method_runs, "test_carry_heavy"),
            "stage_progression": _aggregate_stage_progression(method_runs),
            "diagnostics": _aggregate_diagnostics(method_runs),
        }
        for method_name, method_runs in results_by_method.items()
    }
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the comparison script.

    ``add_config_arguments`` is applied to the parser first, then three
    options are registered:

    * ``--methods``: zero or more values restricted to ``VALID_MODELS``;
      defaults to all of them.
    * ``--seeds``: zero or more integers; defaults to ``None``.
    * ``--comparison_output_dir``: a string; defaults to ``""``.
    """
    cli = argparse.ArgumentParser(
        description="Run the full addition comparison across methods and seeds.",
    )
    add_config_arguments(cli)

    # Script-specific options, registered in this order.
    comparison_options = (
        ("--methods", {"nargs": "*", "default": list(VALID_MODELS), "choices": VALID_MODELS}),
        ("--seeds", {"nargs": "*", "type": int, "default": None}),
        ("--comparison_output_dir", {"type": str, "default": ""}),
    )
    for flag, options in comparison_options:
        cli.add_argument(flag, **options)
    return cli
def main() -> None:
    """Run every requested method/seed pair, then aggregate, save and plot.

    Steps, in order:

    1. Parse the command line and build a base config with
       ``apply_preset(build_config_from_args(args))``.
    2. Use ``--seeds`` when given and non-empty; otherwise
       ``range(base_config.comparison_num_seeds)``.
    3. Create the comparison directory: ``--comparison_output_dir`` when
       non-empty, else ``addition_runs/comparison_<preset>``.
    4. For each method and seed, build a config and call ``run_experiment``
       with its output directory set to ``<root>/<method>_seed<seed>``.
    5. Write ``aggregate_results.json`` into the comparison directory and
       call ``plot_method_comparison`` with ``<root>/plots``.
    """
    args = build_parser().parse_args()
    base_config = apply_preset(build_config_from_args(args))

    seeds = args.seeds
    if not seeds:
        seeds = list(range(base_config.comparison_num_seeds))

    if args.comparison_output_dir:
        comparison_root = Path(args.comparison_output_dir)
    else:
        comparison_root = Path(f"addition_runs/comparison_{base_config.preset}")
    comparison_root.mkdir(parents=True, exist_ok=True)

    results_by_method: dict[str, list[dict[str, Any]]] = {}
    for method in args.methods:
        method_runs = results_by_method[method] = []
        for seed in seeds:
            run_dir = str(comparison_root / f"{method}_seed{seed}")
            # The parsed namespace is overwritten in place so the config
            # builder sees this run's method, seed and output directory.
            args.model, args.seed, args.output_dir = method, seed, run_dir
            config = apply_preset(build_config_from_args(args))
            # NOTE(review): output_dir is set again on the built config,
            # presumably in case the builder/preset changes it - confirm.
            config.output_dir = run_dir
            print(f"[addition comparison] running method={method} seed={seed}", flush=True)
            method_runs.append(run_experiment(config))

    aggregate = aggregate_runs(results_by_method)
    aggregate_payload = dict(methods=args.methods, seeds=seeds, aggregate=aggregate)

    results_path = comparison_root / "aggregate_results.json"
    with results_path.open("w", encoding="utf-8") as out_file:
        json.dump(aggregate_payload, out_file, indent=2, sort_keys=True)

    plot_method_comparison(aggregate, comparison_root / "plots")
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()