"""Aggregate ``*_results.json`` evaluation outputs for configs A..I and render
the Q1-Q4 analysis figures (entropy effect, hardness curves, structure effect,
training-duration scaling) plus topline text tables."""

import os
import re
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict

import matplotlib.pyplot as plt

# -----------------------------
# Experiment schema (fixed)
# -----------------------------
CONFIGS = list("ABCDEFGHI")
CONFIG_META = {
    # config: (dataset_entropy_group, structure)
    "A": ("HNO3", "0-shot"),
    "B": ("HNO3", "CoT"),
    "C": ("HNO3", "Fake CoT"),
    "D": ("HNO2", "0-shot"),
    "E": ("HNO2", "CoT"),
    "F": ("HNO2", "Fake CoT"),
    "G": ("HNO1", "0-shot"),
    "H": ("HNO1", "CoT"),
    "I": ("HNO1", "Fake CoT"),
}

# Canonical display order for dataset entropy groups and training structures,
# shared by every table and plot below.
DATASETS = ["HNO1", "HNO2", "HNO3"]
STRUCTURES = ["0-shot", "CoT", "Fake CoT"]

# How to interpret file suffixes from evaluating_construct.py
EVAL_KIND_ORDER = {
    "Base": ["Base"],  # original distribution eval file (no _P/_R/_A)
    "Paraphrase": [f"P{i}" for i in range(1, 6)],  # P1..P5
    "Reverse": [f"R{i}" for i in range(1, 4)],  # R1..R3
    "Aggregate": [f"A{i}" for i in range(1, 5)],  # A1..A4
}

# For plotting hardness curves, combine into one ordered axis if desired
HARDNESS_AXIS = (
    [("Base", "Base")]
    + [("Paraphrase", f"P{i}") for i in range(1, 6)]
    + [("Reverse", f"R{i}") for i in range(1, 4)]
    + [("Aggregate", f"A{i}") for i in range(1, 5)]
)

# -----------------------------
# Paths (edit if needed)
# -----------------------------
# expects /workspace/v121rc_exp1/A, /B, ... each containing *_results.json
DEFAULT_RESULTS_ROOT = "/workspace/v121rc_exp1"
DEFAULT_FIG_DIR = "/workspace/v121rc_exp1/plots"


# -----------------------------
# Helpers
# -----------------------------
def safe_mkdir(p: str) -> None:
    """Create directory *p* (and parents) if it does not already exist."""
    os.makedirs(p, exist_ok=True)


def list_result_files(config_dir: str) -> List[str]:
    """Return sorted paths of all ``*_results.json`` files under the config's
    eval subdirectory, or an empty list if that subdirectory is missing."""
    # TODO: or PandaEval12_2_results
    target = os.path.join(config_dir, "PandaEval12_1_results")
    files: List[str] = []
    if not os.path.isdir(target):
        return files
    for root, _, fns in os.walk(target):
        for fn in fns:
            if fn.endswith("_results.json"):
                files.append(os.path.join(root, fn))
    return sorted(files)


def infer_eval_tag_from_filename(fn: str) -> Tuple[str, str]:
    """
    Returns: (kind, tag)
      kind in {Base, Paraphrase, Reverse, Aggregate}
      tag: Base or Pi/Ri/Ai
    """
    base = os.path.basename(fn)
    stem = base[:-len("_results.json")] if base.endswith("_results.json") else base
    m = re.search(r"_(P[1-5]|R[1-3]|A[1-4])$", stem)
    if m:
        tag = m.group(1)
        if tag.startswith("P"):
            return "Paraphrase", tag
        if tag.startswith("R"):
            return "Reverse", tag
        if tag.startswith("A"):
            return "Aggregate", tag
    # No recognized suffix -> original-distribution eval file.
    return "Base", "Base"


def load_json(path: str) -> Any:
    """Load and return the JSON payload at *path* (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_ckpts_from_entries(entries: List[dict]) -> List[int]:
    """Collect the sorted set of checkpoint steps present across *entries*
    (keys shaped like ``step_<int>``)."""
    ckpts = set()
    for e in entries:
        for k in e.keys():
            if k.startswith("step_"):
                try:
                    ckpts.add(int(k.split("_", 1)[1]))
                except ValueError:
                    # Malformed suffix (e.g. "step_final") — skip it.
                    pass
    return sorted(ckpts)


def step_accuracy(entry: dict, ckpt: int) -> Optional[float]:
    """Return the accuracy recorded for *entry* at checkpoint *ckpt*, or
    None when the step is missing, malformed, or has no accuracy value."""
    v = entry.get(f"step_{ckpt}")
    if not isinstance(v, dict):
        return None
    acc = v.get("accuracy")
    if acc is None:
        return None
    try:
        return float(acc)
    except (TypeError, ValueError):
        return None


def file_accuracy(entries: List[dict], ckpt: int) -> Tuple[float, int]:
    """
    Returns (mean_acc, n_used). Skips missing/empty entries.
    """
    accs = []
    for e in entries:
        a = step_accuracy(e, ckpt)
        if a is not None:
            accs.append(a)
    if not accs:
        return float("nan"), 0
    return sum(accs) / len(accs), len(accs)


@dataclass
class FileMetric:
    """One (results file, checkpoint) row: mean accuracy over n entries."""
    config: str      # config letter A..I
    dataset: str     # entropy group HNO1/HNO2/HNO3
    structure: str   # 0-shot / CoT / Fake CoT
    ckpt: int        # training checkpoint step
    kind: str        # Base / Paraphrase / Reverse / Aggregate
    tag: str         # Base or P1..P5 / R1..R3 / A1..A4
    mean_acc: float  # mean accuracy over usable entries
    n: int           # number of entries contributing to mean_acc
    path: str        # source results file


# -----------------------------
# Load + aggregate
# -----------------------------
def collect_metrics(results_root: str) -> List[FileMetric]:
    """Walk every config directory and build one FileMetric per
    (results file, checkpoint) with at least one usable accuracy."""
    metrics: List[FileMetric] = []
    for cfg in CONFIGS:
        files = list_result_files(os.path.join(results_root, cfg))
        if not files:
            continue
        dataset, structure = CONFIG_META[cfg]
        for fp in files:
            kind, tag = infer_eval_tag_from_filename(fp)
            try:
                entries = load_json(fp)
            except (OSError, json.JSONDecodeError):
                # Unreadable or corrupt file — skip, keep aggregating the rest.
                continue
            if not isinstance(entries, list) or not entries:
                continue
            for ckpt in get_ckpts_from_entries(entries):
                mean_acc, n = file_accuracy(entries, ckpt)
                if n == 0 or math.isnan(mean_acc):
                    continue
                metrics.append(
                    FileMetric(
                        config=cfg,
                        dataset=dataset,
                        structure=structure,
                        ckpt=ckpt,
                        kind=kind,
                        tag=tag,
                        mean_acc=mean_acc,
                        n=n,
                        path=fp,
                    )
                )
    return metrics


def aggregate_accuracy(
    metrics: List[FileMetric],
    group_keys: Tuple[str, ...],
    filter_fn=None,
) -> Dict[Tuple, float]:
    """
    Weighted average by n across FileMetric rows.
    group_keys uses FileMetric attributes: config, dataset, structure, ckpt, kind, tag.
    Returns {group_tuple: weighted_mean_acc}
    """
    num = defaultdict(float)
    den = defaultdict(float)
    for m in metrics:
        if filter_fn and not filter_fn(m):
            continue
        key = tuple(getattr(m, k) for k in group_keys)
        num[key] += m.mean_acc * m.n
        den[key] += m.n
    return {k: num[k] / den[k] for k in num if den[k] > 0}


def _series_from_agg(agg: Dict[Tuple, float], first_key: Any) -> Tuple[list, list]:
    """From a 2-key aggregate {(k1, k2): acc}, extract the (xs, ys) line
    series for rows whose first key equals *first_key*, ordered by k2."""
    pairs = sorted((k2, acc) for (k1, k2), acc in agg.items() if k1 == first_key)
    return [p[0] for p in pairs], [p[1] for p in pairs]


# -----------------------------
# Plotting (matplotlib only, no custom colors)
# -----------------------------
def plot_entropy_effect(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q1: bar charts of accuracy by entropy group, one subplot per eval kind."""
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(metrics, ("dataset", "kind"))
    kinds = ["Base", "Paraphrase", "Reverse", "Aggregate"]
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    axes = axes.ravel()
    for ax, kind in zip(axes, kinds):
        ys = [agg.get((ds, kind), float("nan")) for ds in DATASETS]
        ax.bar(DATASETS, ys)
        ax.set_ylim(0, 1)
        ax.set_title(f"Entropy effect — {kind}")
        ax.set_ylabel("Accuracy")
    fig.suptitle("Q1: Effect of effective information entropy (HNO1 vs HNO2 vs HNO3)")
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q1_entropy_effect_by_kind.png"), dpi=200)
    plt.close(fig)


def plot_hardness_curves(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q2: one line per entropy group across the full hardness axis
    (Base, P1..P5, R1..R3, A1..A4)."""
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(metrics, ("dataset", "kind", "tag"))
    x_labels = [t for _, t in HARDNESS_AXIS]
    x = list(range(len(HARDNESS_AXIS)))
    fig = plt.figure(figsize=(14, 6))
    ax = fig.add_subplot(111)
    for ds in DATASETS:
        y = [agg.get((ds, kind, tag), float("nan")) for kind, tag in HARDNESS_AXIS]
        ax.plot(x, y, marker="o", label=ds)
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.set_ylim(0, 1)
    ax.set_ylabel("Accuracy")
    ax.set_title("Q2: Performance across evaluation sets (hardness axis)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q2_hardness_curves_by_entropy.png"), dpi=200)
    plt.close(fig)


def _structure_bar_chart(agg: Dict[Tuple, float], title: str, out_path: str) -> None:
    """Grouped bar chart: accuracy per (dataset, structure) from a 2-key aggregate."""
    width = 0.25
    x = list(range(len(DATASETS)))
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    for i, st in enumerate(STRUCTURES):
        y = [agg.get((ds, st), float("nan")) for ds in DATASETS]
        # Offset each structure's bars so groups sit side by side per dataset.
        ax.bar([xi + (i - 1) * width for xi in x], y, width=width, label=st)
    ax.set_xticks(x)
    ax.set_xticklabels(DATASETS)
    ax.set_ylim(0, 1)
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)


def plot_structure_effect(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q3: structure effect bars — (a) Base eval only, (b) all eval sets."""
    safe_mkdir(fig_dir)
    # (a) Base only
    base_agg = aggregate_accuracy(
        metrics,
        ("dataset", "structure"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    _structure_bar_chart(
        base_agg,
        "Q3a: Structure effect on Base (train-distribution) evaluation",
        os.path.join(fig_dir, "Q3a_structure_effect_base.png"),
    )
    # (b) Overall across all eval files/kinds/tags
    overall_agg = aggregate_accuracy(metrics, ("dataset", "structure"))
    _structure_bar_chart(
        overall_agg,
        "Q3b: Structure effect on Overall evaluation (all eval sets)",
        os.path.join(fig_dir, "Q3b_structure_effect_overall.png"),
    )


def plot_config_breakdown(metrics: List[FileMetric], fig_dir: str) -> None:
    """Extra: per-config (A..I) hardness curves, one figure per dataset group."""
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(metrics, ("config", "kind", "tag"))
    x_labels = [t for _, t in HARDNESS_AXIS]
    x = list(range(len(HARDNESS_AXIS)))
    for dataset in DATASETS:
        fig = plt.figure(figsize=(14, 6))
        ax = fig.add_subplot(111)
        for cfg in (c for c in CONFIGS if CONFIG_META[c][0] == dataset):
            y = [agg.get((cfg, kind, tag), float("nan")) for kind, tag in HARDNESS_AXIS]
            ax.plot(x, y, marker="o", label=cfg)
        ax.set_xticks(x)
        ax.set_xticklabels(x_labels, rotation=45, ha="right")
        ax.set_ylim(0, 1)
        ax.set_ylabel("Accuracy")
        ax.set_title(f"Per-config hardness curves ({dataset}) — A..I")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, f"extra_per_config_hardness_{dataset}.png"), dpi=200)
        plt.close(fig)


# -----------------------------
# NEW: Q4 (step-wise / scaling law) plots
# -----------------------------
def plot_q4_learning_curve_base(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4a: Learning curves on Base (train-distribution) eval.
      x-axis: checkpoint step
      y-axis: accuracy
      lines: HNO1/HNO2/HNO3 (averaged over structures/configs)
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("dataset", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    for ds in DATASETS:
        xs, ys = _series_from_agg(agg, ds)
        if xs:
            ax.plot(xs, ys, marker="o", label=ds)
    ax.set_xlabel("Training checkpoint")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1)
    ax.set_title("Q4a: Learning curve vs training duration (Base eval)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q4a_learning_curve_base.png"), dpi=200)
    plt.close(fig)


def plot_q4_scaling_curves_ood(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4b: Generalization scaling curves (OOD evals) vs training duration.
      Subplots: Paraphrase / Reverse / Aggregate
      Lines: HNO1/HNO2/HNO3 (averaged over structures/configs and tags within that kind)
    """
    safe_mkdir(fig_dir)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
    for ax, kind in zip(axes, ["Paraphrase", "Reverse", "Aggregate"]):
        # Bind the loop variable as a default to avoid late-binding surprises.
        agg = aggregate_accuracy(
            metrics,
            ("dataset", "ckpt"),
            filter_fn=lambda m, kind=kind: m.kind == kind,
        )
        for ds in DATASETS:
            xs, ys = _series_from_agg(agg, ds)
            if xs:
                ax.plot(xs, ys, marker="o", label=ds)
        ax.set_title(kind)
        ax.set_xlabel("Training checkpoint")
    axes[0].set_ylabel("Accuracy")
    axes[0].legend()
    fig.suptitle("Q4b: Generalization scaling vs training duration (OOD evals)")
    fig.tight_layout()
    fig.savefig(
        os.path.join(fig_dir, "Q4b_scaling_curve_paraphrase_reverse_aggregate.png"),
        dpi=200,
    )
    plt.close(fig)


def plot_q4_scaling_by_structure(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4c: Scaling vs training duration by training structure.
      x-axis: checkpoint
      y-axis: accuracy
      lines: 0-shot / CoT / Fake CoT
      (averaged over datasets+configs; filter to Base eval by default)
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("structure", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    for st in STRUCTURES:
        xs, ys = _series_from_agg(agg, st)
        if xs:
            ax.plot(xs, ys, marker="o", label=st)
    ax.set_xlabel("Training checkpoint")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1)
    ax.set_title("Q4c: Scaling vs training duration by structure (Base eval)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q4c_scaling_by_structure.png"), dpi=200)
    plt.close(fig)


def plot_q4_per_config_learning_curves(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4d (extra): Per-config (A..I) learning curve on Base eval.
      x-axis: checkpoint
      y-axis: accuracy
      One plot per dataset group (HNO1/HNO2/HNO3), lines are configs within that dataset.
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("config", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    for dataset in DATASETS:
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111)
        for cfg in (c for c in CONFIGS if CONFIG_META[c][0] == dataset):
            xs, ys = _series_from_agg(agg, cfg)
            if xs:
                ax.plot(xs, ys, marker="o", label=cfg)
        ax.set_xlabel("Training checkpoint")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(0, 1)
        ax.set_title(f"Q4d: Per-config learning curves (Base eval) — {dataset}")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, f"Q4d_per_config_learning_curve_{dataset}.png"), dpi=200)
        plt.close(fig)


# -----------------------------
# Text summary (optional)
# -----------------------------
def print_topline_tables(metrics: List[FileMetric]) -> None:
    """Print Q1-Q4 topline accuracy tables to stdout."""
    overall_by_ds = aggregate_accuracy(metrics, ("dataset",))
    print("\n=== Q1 topline: Overall accuracy by dataset entropy group ===")
    for ds in DATASETS:
        v = overall_by_ds.get((ds,), float("nan"))
        print(f"{ds}: {v:.4f}" if not math.isnan(v) else f"{ds}: NaN")

    by_ds_kind_tag = aggregate_accuracy(metrics, ("dataset", "kind", "tag"))
    print("\n=== Q2 topline: Hardness axis by dataset (Base, P1..P5, R1..R3, A1..A4) ===")
    for ds in DATASETS:
        row = []
        for kind, tag in HARDNESS_AXIS:
            v = by_ds_kind_tag.get((ds, kind, tag), float("nan"))
            row.append("NaN" if math.isnan(v) else f"{v:.3f}")
        print(ds + ":\t" + "\t".join(row))

    by_ds_structure = aggregate_accuracy(metrics, ("dataset", "structure"))
    print("\n=== Q3 topline: Overall accuracy by structure within each dataset ===")
    for ds in DATASETS:
        for st in STRUCTURES:
            v = by_ds_structure.get((ds, st), float("nan"))
            print(f"{ds} | {st}: {v:.4f}" if not math.isnan(v) else f"{ds} | {st}: NaN")

    # Q4 topline: base curve endpoints
    base_by_ds_ckpt = aggregate_accuracy(
        metrics,
        ("dataset", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    all_ckpts = sorted({ckpt for (_, ckpt) in base_by_ds_ckpt.keys()})
    if all_ckpts:
        first_ckpt, last_ckpt = all_ckpts[0], all_ckpts[-1]
        print("\n=== Q4 topline: Base eval improvement from first to last checkpoint ===")
        for ds in DATASETS:
            a0 = base_by_ds_ckpt.get((ds, first_ckpt), float("nan"))
            a1 = base_by_ds_ckpt.get((ds, last_ckpt), float("nan"))
            if not (math.isnan(a0) or math.isnan(a1)):
                print(f"{ds}: {first_ckpt}->{last_ckpt}: {a0:.4f} -> {a1:.4f} (Δ={a1-a0:+.4f})")
            else:
                print(f"{ds}: insufficient checkpoints to compute Δ")


# -----------------------------
# Main
# -----------------------------
def main(results_root: str = DEFAULT_RESULTS_ROOT, fig_dir: str = DEFAULT_FIG_DIR) -> None:
    """Load all metrics, print topline tables, and save every figure."""
    safe_mkdir(fig_dir)
    metrics = collect_metrics(results_root)
    if not metrics:
        raise RuntimeError(f"No metrics found. Expected *_results.json under {results_root}/A..I")
    print(f"Loaded FileMetric rows: {len(metrics)}")
    print_topline_tables(metrics)

    # Q1–Q3
    plot_entropy_effect(metrics, fig_dir)
    plot_hardness_curves(metrics, fig_dir)
    plot_structure_effect(metrics, fig_dir)
    plot_config_breakdown(metrics, fig_dir)

    # Q4 (training duration / scaling)
    plot_q4_learning_curve_base(metrics, fig_dir)
    plot_q4_scaling_curves_ood(metrics, fig_dir)
    plot_q4_scaling_by_structure(metrics, fig_dir)
    plot_q4_per_config_learning_curves(metrics, fig_dir)

    print(f"\nSaved plots to: {fig_dir}")
    print("Generated files:")
    print(" - Q1_entropy_effect_by_kind.png")
    print(" - Q2_hardness_curves_by_entropy.png")
    print(" - Q3a_structure_effect_base.png")
    print(" - Q3b_structure_effect_overall.png")
    print(" - extra_per_config_hardness_HNO1.png / HNO2 / HNO3")
    print(" - Q4a_learning_curve_base.png")
    print(" - Q4b_scaling_curve_paraphrase_reverse_aggregate.png")
    print(" - Q4c_scaling_by_structure.png")
    print(" - Q4d_per_config_learning_curve_HNO1.png / HNO2 / HNO3")


if __name__ == "__main__":
    results_root = os.environ.get("RESULTS_ROOT", DEFAULT_RESULTS_ROOT)
    fig_dir = os.environ.get("FIG_DIR", DEFAULT_FIG_DIR)
    main(results_root, fig_dir)