# v121rc_exp1 / analyze_and_plot_hno.py
# (uploaded by Linksome via the upload-large-folder tool; commit ac94d57, verified)
import os
import re
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict
import matplotlib.pyplot as plt
# -----------------------------
# Experiment schema (fixed)
# -----------------------------
# Nine training configs: the cross product of 3 entropy groups x 3 prompt structures.
CONFIGS = list("ABCDEFGHI")
CONFIG_META = {
    # config: (dataset_entropy_group, structure)
    "A": ("HNO3", "0-shot"),
    "B": ("HNO3", "CoT"),
    "C": ("HNO3", "Fake CoT"),
    "D": ("HNO2", "0-shot"),
    "E": ("HNO2", "CoT"),
    "F": ("HNO2", "Fake CoT"),
    "G": ("HNO1", "0-shot"),
    "H": ("HNO1", "CoT"),
    "I": ("HNO1", "Fake CoT"),
}
# How to interpret file suffixes from evaluating_construct.py
EVAL_KIND_ORDER = {
    "Base": ["Base"],  # original distribution eval file (no _P/_R/_A)
    "Paraphrase": [f"P{i}" for i in range(1, 6)],  # P1..P5
    "Reverse": [f"R{i}" for i in range(1, 4)],  # R1..R3
    "Aggregate": [f"A{i}" for i in range(1, 5)],  # A1..A4
}
# For plotting hardness curves, combine into one ordered axis if desired.
# Order: Base, then P1..P5, R1..R3, A1..A4 (13 (kind, tag) points total).
HARDNESS_AXIS = (
    [("Base", "Base")] +
    [("Paraphrase", f"P{i}") for i in range(1, 6)] +
    [("Reverse", f"R{i}") for i in range(1, 4)] +
    [("Aggregate", f"A{i}") for i in range(1, 5)]
)
# -----------------------------
# Paths (edit if needed)
# -----------------------------
DEFAULT_RESULTS_ROOT = "/workspace/v121rc_exp1"  # expects /workspace/v121rc_exp1/A, /B, ... each containing *_results.json
DEFAULT_FIG_DIR = "/workspace/v121rc_exp1/plots"
# -----------------------------
# Helpers
# -----------------------------
def safe_mkdir(p: str) -> None:
    """Create directory *p* (including parents); no error if it already exists."""
    os.makedirs(p, exist_ok=True)
# def list_result_files(config_dir: str) -> List[str]:
# # Only *_results.json produced by run*.py scripts
# files = []
# if not os.path.isdir(config_dir):
# return files
# for fn in os.listdir(config_dir):
# if fn.endswith("_results.json"):
# files.append(os.path.join(config_dir, fn))
# return sorted(files)
def list_result_files(config_dir: str, subdir: str = "PandaEval12_1_results") -> List[str]:
    """Recursively collect *_results.json files under ``config_dir/subdir``.

    Parameters:
        config_dir: a single config directory (e.g. ``<root>/A``).
        subdir: name of the eval-output folder to scan. Parameterized (was a
            hard-coded TODO) so the alternate "PandaEval12_2_results" tree can
            be scanned without editing this function; the default preserves the
            original behavior.

    Returns:
        Sorted list of matching file paths; empty list if the folder is missing.
    """
    target = os.path.join(config_dir, subdir)
    if not os.path.isdir(target):
        return []
    files = [
        os.path.join(root, fn)
        for root, _, fns in os.walk(target)
        for fn in fns
        if fn.endswith("_results.json")
    ]
    return sorted(files)
def infer_eval_tag_from_filename(fn: str) -> Tuple[str, str]:
    """Classify a results file by its name suffix.

    Returns (kind, tag) where kind is one of Base/Paraphrase/Reverse/Aggregate
    and tag is "Base" or the matched P1..P5 / R1..R3 / A1..A4 suffix.
    """
    stem = os.path.basename(fn)
    trailer = "_results.json"
    if stem.endswith(trailer):
        stem = stem[:-len(trailer)]
    match = re.search(r"_(P[1-5]|R[1-3]|A[1-4])$", stem)
    if not match:
        return "Base", "Base"
    tag = match.group(1)
    # First letter of the tag determines the eval kind.
    kind_by_letter = {"P": "Paraphrase", "R": "Reverse", "A": "Aggregate"}
    return kind_by_letter[tag[0]], tag
def load_json(path: str) -> Any:
    """Read and parse a UTF-8 encoded JSON file at *path*."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
def get_ckpts_from_entries(entries: List[dict]) -> List[int]:
    """Return the sorted set of checkpoint steps present across *entries*.

    Each entry dict may carry keys like ``"step_1000"``; the integer suffix is
    the checkpoint step. Keys whose suffix is not an integer are skipped.

    Note: the original caught bare ``Exception`` here; narrowed to
    ``ValueError`` (the only error ``int()`` raises on a malformed suffix) so
    unrelated bugs are not silently swallowed.
    """
    ckpts = set()
    for e in entries:
        for k in e.keys():
            if k.startswith("step_"):
                try:
                    ckpts.add(int(k.split("_", 1)[1]))
                except ValueError:  # e.g. "step_final" — skip malformed keys
                    pass
    return sorted(ckpts)
def step_accuracy(entry: dict, ckpt: int) -> Optional[float]:
    """Accuracy recorded in *entry* at checkpoint *ckpt*, or None.

    Returns None when the ``step_<ckpt>`` key is absent, its value is not a
    dict, the ``accuracy`` field is missing/None, or the value cannot be
    converted to float.

    Note: the original caught bare ``Exception`` around ``float()``; narrowed
    to ``(TypeError, ValueError)``, the errors float() raises for non-numeric
    payloads, so unrelated bugs are not silently swallowed.
    """
    v = entry.get(f"step_{ckpt}")
    if not isinstance(v, dict):
        return None
    acc = v.get("accuracy")
    if acc is None:
        return None
    try:
        return float(acc)
    except (TypeError, ValueError):  # non-numeric accuracy payload
        return None
def file_accuracy(entries: List[dict], ckpt: int) -> Tuple[float, int]:
    """Mean accuracy at *ckpt* over entries that report one.

    Returns (mean_acc, n_used); entries without a usable accuracy are skipped,
    and (nan, 0) is returned when none contribute.
    """
    accs = [a for a in (step_accuracy(e, ckpt) for e in entries) if a is not None]
    if not accs:
        return float("nan"), 0
    return sum(accs) / len(accs), len(accs)
@dataclass
class FileMetric:
    """Per-(results-file, checkpoint) accuracy summary row.

    One instance per eval results file per checkpoint step found inside it.
    """
    config: str      # training config letter A..I
    dataset: str     # entropy group: HNO1 / HNO2 / HNO3
    structure: str   # prompt structure: 0-shot / CoT / Fake CoT
    ckpt: int        # checkpoint step number parsed from "step_<n>" keys
    kind: str        # eval kind: Base / Paraphrase / Reverse / Aggregate
    tag: str         # eval tag: "Base" or P1..P5 / R1..R3 / A1..A4
    mean_acc: float  # mean accuracy over the file's entries at this ckpt
    n: int           # number of entries contributing to mean_acc
    path: str        # source *_results.json path (for traceability)
# -----------------------------
# Load + aggregate
# -----------------------------
def collect_metrics(results_root: str) -> List[FileMetric]:
    """Scan every config directory under *results_root* and build FileMetric rows.

    For each results file, one row is emitted per checkpoint step that has at
    least one usable accuracy value. Unreadable or empty files are skipped.
    """
    rows: List[FileMetric] = []
    for cfg in CONFIGS:
        dataset, structure = CONFIG_META[cfg]
        for fp in list_result_files(os.path.join(results_root, cfg)):
            kind, tag = infer_eval_tag_from_filename(fp)
            try:
                entries = load_json(fp)
            except Exception:
                # Corrupt/unreadable file: best-effort scan, move on.
                continue
            if not isinstance(entries, list) or not entries:
                continue
            for ckpt in get_ckpts_from_entries(entries):
                mean_acc, n = file_accuracy(entries, ckpt)
                if n == 0 or math.isnan(mean_acc):
                    continue
                rows.append(
                    FileMetric(cfg, dataset, structure, ckpt, kind, tag, mean_acc, n, fp)
                )
    return rows
def aggregate_accuracy(
    metrics: List[FileMetric],
    group_keys: Tuple[str, ...],
    filter_fn=None,
) -> Dict[Tuple, float]:
    """n-weighted mean accuracy per group.

    group_keys name FileMetric attributes (config, dataset, structure, ckpt,
    kind, tag); rows rejected by filter_fn (when given) are ignored.
    Returns {group_tuple: weighted_mean_acc}.
    """
    # key -> [weighted accuracy sum, total weight]
    sums: Dict[Tuple, List[float]] = {}
    for row in metrics:
        if filter_fn is not None and not filter_fn(row):
            continue
        key = tuple(getattr(row, attr) for attr in group_keys)
        acc_n = sums.setdefault(key, [0.0, 0.0])
        acc_n[0] += row.mean_acc * row.n
        acc_n[1] += row.n
    return {k: num / den for k, (num, den) in sums.items() if den > 0}
# -----------------------------
# Plotting (matplotlib only, no custom colors)
# -----------------------------
def plot_entropy_effect(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q1: accuracy by entropy group, one bar subplot per eval kind."""
    safe_mkdir(fig_dir)
    by_ds_kind = aggregate_accuracy(metrics, ("dataset", "kind"))
    dataset_order = ["HNO1", "HNO2", "HNO3"]
    kind_order = ["Base", "Paraphrase", "Reverse", "Aggregate"]
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    for axis, kind in zip(axes.ravel(), kind_order):
        heights = [by_ds_kind.get((ds, kind), float("nan")) for ds in dataset_order]
        axis.bar(dataset_order, heights)
        axis.set_ylim(0, 1)
        axis.set_title(f"Entropy effect — {kind}")
        axis.set_ylabel("Accuracy")
    fig.suptitle("Q1: Effect of effective information entropy (HNO1 vs HNO2 vs HNO3)")
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q1_entropy_effect_by_kind.png"), dpi=200)
    plt.close(fig)
def plot_hardness_curves(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q2: one line per entropy group across the ordered hardness axis."""
    safe_mkdir(fig_dir)
    by_ds_kind_tag = aggregate_accuracy(metrics, ("dataset", "kind", "tag"))
    tick_labels = [tag for _, tag in HARDNESS_AXIS]
    positions = list(range(len(HARDNESS_AXIS)))
    fig = plt.figure(figsize=(14, 6))
    ax = fig.add_subplot(111)
    for ds in ["HNO1", "HNO2", "HNO3"]:
        series = [
            by_ds_kind_tag.get((ds, kind, tag), float("nan"))
            for kind, tag in HARDNESS_AXIS
        ]
        ax.plot(positions, series, marker="o", label=ds)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right")
    ax.set_ylim(0, 1)
    ax.set_ylabel("Accuracy")
    ax.set_title("Q2: Performance across evaluation sets (hardness axis)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q2_hardness_curves_by_entropy.png"), dpi=200)
    plt.close(fig)
def plot_structure_effect(metrics: List[FileMetric], fig_dir: str) -> None:
    """Q3: grouped bar charts of the prompt-structure effect.

    Emits two figures: (a) restricted to the Base (train-distribution) eval,
    and (b) the overall average across all eval files/kinds/tags. The two
    figures were near-duplicated code; factored into one private helper.
    """
    safe_mkdir(fig_dir)
    structures = ["0-shot", "CoT", "Fake CoT"]
    datasets = ["HNO1", "HNO2", "HNO3"]

    def _grouped_bar(agg, title, out_name):
        # One bar group per dataset; one bar per structure within a group.
        width = 0.25
        x = list(range(len(datasets)))
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111)
        for i, st in enumerate(structures):
            y = [agg.get((ds, st), float("nan")) for ds in datasets]
            ax.bar([xi + (i - 1) * width for xi in x], y, width=width, label=st)
        ax.set_xticks(x)
        ax.set_xticklabels(datasets)
        ax.set_ylim(0, 1)
        ax.set_ylabel("Accuracy")
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, out_name), dpi=200)
        plt.close(fig)

    # (a) Base only
    _grouped_bar(
        aggregate_accuracy(
            metrics,
            ("dataset", "structure"),
            filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
        ),
        "Q3a: Structure effect on Base (train-distribution) evaluation",
        "Q3a_structure_effect_base.png",
    )
    # (b) Overall across all eval files/kinds/tags
    _grouped_bar(
        aggregate_accuracy(metrics, ("dataset", "structure")),
        "Q3b: Structure effect on Overall evaluation (all eval sets)",
        "Q3b_structure_effect_overall.png",
    )
def plot_config_breakdown(metrics: List[FileMetric], fig_dir: str) -> None:
    """Extra: per-config (A..I) hardness curves, one figure per entropy group."""
    safe_mkdir(fig_dir)
    by_cfg_kind_tag = aggregate_accuracy(metrics, ("config", "kind", "tag"))
    tick_labels = [tag for _, tag in HARDNESS_AXIS]
    positions = list(range(len(HARDNESS_AXIS)))
    for dataset in ["HNO1", "HNO2", "HNO3"]:
        fig = plt.figure(figsize=(14, 6))
        ax = fig.add_subplot(111)
        for cfg in (c for c in CONFIGS if CONFIG_META[c][0] == dataset):
            series = [
                by_cfg_kind_tag.get((cfg, kind, tag), float("nan"))
                for kind, tag in HARDNESS_AXIS
            ]
            ax.plot(positions, series, marker="o", label=cfg)
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")
        ax.set_ylim(0, 1)
        ax.set_ylabel("Accuracy")
        ax.set_title(f"Per-config hardness curves ({dataset}) — A..I")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, f"extra_per_config_hardness_{dataset}.png"), dpi=200)
        plt.close(fig)
# -----------------------------
# NEW: Q4 (step-wise / scaling law) plots
# -----------------------------
def plot_q4_learning_curve_base(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4a: Learning curves on the Base (train-distribution) eval.
    x-axis: checkpoint step; y-axis: accuracy; one line per entropy group
    (HNO1/HNO2/HNO3), averaged over structures/configs.
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("dataset", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    for ds in ["HNO1", "HNO2", "HNO3"]:
        points = sorted(
            (ckpt, acc) for (dataset, ckpt), acc in agg.items() if dataset == ds
        )
        if points:
            ax.plot([c for c, _ in points], [a for _, a in points], marker="o", label=ds)
    ax.set_xlabel("Training checkpoint")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1)
    ax.set_title("Q4a: Learning curve vs training duration (Base eval)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q4a_learning_curve_base.png"), dpi=200)
    plt.close(fig)
def plot_q4_scaling_curves_ood(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4b: Generalization scaling curves (OOD evals) vs training duration.
    One subplot per eval kind (Paraphrase / Reverse / Aggregate); one line
    per entropy group, averaged over structures/configs and tags.
    """
    safe_mkdir(fig_dir)
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharey=True)
    for ax, kind in zip(axes, ["Paraphrase", "Reverse", "Aggregate"]):
        agg = aggregate_accuracy(
            metrics,
            ("dataset", "ckpt"),
            # Bind the loop variable eagerly for clarity (the lambda is only
            # used within this iteration, so behavior is unchanged).
            filter_fn=lambda m, k=kind: m.kind == k,
        )
        for ds in ["HNO1", "HNO2", "HNO3"]:
            points = sorted(
                (ckpt, acc) for (dataset, ckpt), acc in agg.items() if dataset == ds
            )
            if points:
                ax.plot([c for c, _ in points], [a for _, a in points], marker="o", label=ds)
        ax.set_title(kind)
        ax.set_xlabel("Training checkpoint")
    axes[0].set_ylabel("Accuracy")
    axes[0].legend()
    fig.suptitle("Q4b: Generalization scaling vs training duration (OOD evals)")
    fig.tight_layout()
    fig.savefig(
        os.path.join(fig_dir, "Q4b_scaling_curve_paraphrase_reverse_aggregate.png"),
        dpi=200,
    )
    plt.close(fig)
def plot_q4_scaling_by_structure(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4c: Scaling vs training duration by training structure.
    One line per structure (0-shot / CoT / Fake CoT), averaged over
    datasets/configs; restricted to the Base eval.
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("structure", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    for st in ["0-shot", "CoT", "Fake CoT"]:
        points = sorted(
            (ckpt, acc) for (structure, ckpt), acc in agg.items() if structure == st
        )
        if points:
            ax.plot([c for c, _ in points], [a for _, a in points], marker="o", label=st)
    ax.set_xlabel("Training checkpoint")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0, 1)
    ax.set_title("Q4c: Scaling vs training duration by structure (Base eval)")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(fig_dir, "Q4c_scaling_by_structure.png"), dpi=200)
    plt.close(fig)
def plot_q4_per_config_learning_curves(metrics: List[FileMetric], fig_dir: str) -> None:
    """
    Q4d (extra): Per-config (A..I) learning curves on the Base eval.
    One figure per entropy group (HNO1/HNO2/HNO3); lines are the configs
    belonging to that group.
    """
    safe_mkdir(fig_dir)
    agg = aggregate_accuracy(
        metrics,
        ("config", "ckpt"),
        filter_fn=lambda m: m.kind == "Base" and m.tag == "Base",
    )
    for dataset in ["HNO1", "HNO2", "HNO3"]:
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111)
        for cfg in (c for c in CONFIGS if CONFIG_META[c][0] == dataset):
            points = sorted(
                (ckpt, acc) for (c, ckpt), acc in agg.items() if c == cfg
            )
            if points:
                ax.plot([s for s, _ in points], [a for _, a in points], marker="o", label=cfg)
        ax.set_xlabel("Training checkpoint")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(0, 1)
        ax.set_title(f"Q4d: Per-config learning curves (Base eval) — {dataset}")
        ax.legend()
        fig.tight_layout()
        fig.savefig(os.path.join(fig_dir, f"Q4d_per_config_learning_curve_{dataset}.png"), dpi=200)
        plt.close(fig)
# -----------------------------
# Text summary (optional)
# -----------------------------
def print_topline_tables(metrics: List[FileMetric]) -> None:
    """Print topline accuracy tables for Q1-Q4 to stdout."""
    datasets = ["HNO1", "HNO2", "HNO3"]

    by_ds = aggregate_accuracy(metrics, ("dataset",))
    print("\n=== Q1 topline: Overall accuracy by dataset entropy group ===")
    for ds in datasets:
        v = by_ds.get((ds,), float("nan"))
        print(f"{ds}: NaN" if math.isnan(v) else f"{ds}: {v:.4f}")

    by_ds_kind_tag = aggregate_accuracy(metrics, ("dataset", "kind", "tag"))
    print("\n=== Q2 topline: Hardness axis by dataset (Base, P1..P5, R1..R3, A1..A4) ===")
    for ds in datasets:
        cells = []
        for kind, tag in HARDNESS_AXIS:
            v = by_ds_kind_tag.get((ds, kind, tag), float("nan"))
            cells.append("NaN" if math.isnan(v) else f"{v:.3f}")
        print(ds + ":\t" + "\t".join(cells))

    by_ds_structure = aggregate_accuracy(metrics, ("dataset", "structure"))
    print("\n=== Q3 topline: Overall accuracy by structure within each dataset ===")
    for ds in datasets:
        for st in ["0-shot", "CoT", "Fake CoT"]:
            v = by_ds_structure.get((ds, st), float("nan"))
            print(f"{ds} | {st}: NaN" if math.isnan(v) else f"{ds} | {st}: {v:.4f}")

    # Q4 topline: Base-eval accuracy at the first vs last checkpoint seen.
    base_by_ds_ckpt = aggregate_accuracy(
        metrics, ("dataset", "ckpt"), filter_fn=lambda m: m.kind == "Base" and m.tag == "Base"
    )
    all_ckpts = sorted({ckpt for (_, ckpt) in base_by_ds_ckpt})
    if not all_ckpts:
        return
    first_ckpt, last_ckpt = all_ckpts[0], all_ckpts[-1]
    print("\n=== Q4 topline: Base eval improvement from first to last checkpoint ===")
    for ds in datasets:
        a0 = base_by_ds_ckpt.get((ds, first_ckpt), float("nan"))
        a1 = base_by_ds_ckpt.get((ds, last_ckpt), float("nan"))
        if math.isnan(a0) or math.isnan(a1):
            print(f"{ds}: insufficient checkpoints to compute Δ")
        else:
            print(f"{ds}: {first_ckpt}->{last_ckpt}: {a0:.4f} -> {a1:.4f} (Δ={a1-a0:+.4f})")
# -----------------------------
# Main
# -----------------------------
def main(results_root: str = DEFAULT_RESULTS_ROOT, fig_dir: str = DEFAULT_FIG_DIR) -> None:
    """Load metrics from *results_root*, print toplines, and write all figures.

    Raises RuntimeError when no results files are found.
    """
    safe_mkdir(fig_dir)
    metrics = collect_metrics(results_root)
    if not metrics:
        raise RuntimeError(f"No metrics found. Expected *_results.json under {results_root}/A..I")
    print(f"Loaded FileMetric rows: {len(metrics)}")
    print_topline_tables(metrics)
    # Q1–Q3
    for plot_fn in (
        plot_entropy_effect,
        plot_hardness_curves,
        plot_structure_effect,
        plot_config_breakdown,
    ):
        plot_fn(metrics, fig_dir)
    # Q4 (training duration / scaling)
    for plot_fn in (
        plot_q4_learning_curve_base,
        plot_q4_scaling_curves_ood,
        plot_q4_scaling_by_structure,
        plot_q4_per_config_learning_curves,
    ):
        plot_fn(metrics, fig_dir)
    print(f"\nSaved plots to: {fig_dir}")
    print("Generated files:")
    for line in (
        " - Q1_entropy_effect_by_kind.png",
        " - Q2_hardness_curves_by_entropy.png",
        " - Q3a_structure_effect_base.png",
        " - Q3b_structure_effect_overall.png",
        " - extra_per_config_hardness_HNO1.png / HNO2 / HNO3",
        " - Q4a_learning_curve_base.png",
        " - Q4b_scaling_curve_paraphrase_reverse_aggregate.png",
        " - Q4c_scaling_by_structure.png",
        " - Q4d_per_config_learning_curve_HNO1.png / HNO2 / HNO3",
    ):
        print(line)
if __name__ == "__main__":
    # Allow overriding the default input/output locations via environment variables.
    results_root = os.environ.get("RESULTS_ROOT", DEFAULT_RESULTS_ROOT)
    fig_dir = os.environ.get("FIG_DIR", DEFAULT_FIG_DIR)
    main(results_root, fig_dir)