Ubuntu Cursor committed on
Commit ·
a40741d
1
Parent(s): 624a68e
Cross-stage experiments: code, paper figures (solve/exact/precision/recall)
Browse files- _experiments/cross_stage/: per-cell prediction sweep across (method, train_stage, prompt_stage)
with analyze scripts that compute containment, difficulty-stratified solve
rate, failure-mode taxonomy, and a 3x3 cross-prompt grid.
- _runs/_paper_figures/: updated solve/exact and new precision/recall plots
with No-CoT-No-Curriculum horizontal baseline.
Co-authored-by: Cursor <cursoragent@cursor.com>
- .gitignore +11 -0
- _experiments/cross_stage/_peek.py +49 -0
- _experiments/cross_stage/analyze.py +236 -0
- _experiments/cross_stage/analyze_cross_prompt.py +303 -0
- _experiments/cross_stage/analyze_v2.py +425 -0
- _experiments/cross_stage/overnight_pipeline.sh +58 -0
- _experiments/cross_stage/predict_one.py +211 -0
- _experiments/cross_stage/run_all.sh +58 -0
- _experiments/cross_stage/run_cross_prompt.sh +66 -0
- _experiments/cross_stage/run_cross_prompt_phase2.sh +93 -0
- _experiments/cross_stage/run_nocurr_cot.sh +79 -0
- _experiments/cross_stage/watcher_launch_more.sh +32 -0
- _runs/_paper_figures/plot_stage_progression.py +80 -20
- _runs/_paper_figures/stage_progression_exact.pdf +2 -2
- _runs/_paper_figures/stage_progression_exact.png +2 -2
- _runs/_paper_figures/stage_progression_precision.pdf +3 -0
- _runs/_paper_figures/stage_progression_precision.png +3 -0
- _runs/_paper_figures/stage_progression_recall.pdf +3 -0
- _runs/_paper_figures/stage_progression_recall.png +3 -0
- _runs/_paper_figures/stage_progression_solve.pdf +2 -2
- _runs/_paper_figures/stage_progression_solve.png +2 -2
.gitignore
CHANGED
|
@@ -30,4 +30,15 @@ _pushlogs/
|
|
| 30 |
_runs/strawman_warm_*/
|
| 31 |
_runs/adaptive_k_resume_*/
|
| 32 |
_runs/launch_finish_repos23_*_pids.txt
|
|
|
|
| 33 |
curriculum_cot/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
_runs/strawman_warm_*/
|
| 31 |
_runs/adaptive_k_resume_*/
|
| 32 |
_runs/launch_finish_repos23_*_pids.txt
|
| 33 |
+
_runs/nocurr_cot_*/
|
| 34 |
curriculum_cot/
|
| 35 |
+
|
| 36 |
+
# Cross-stage prediction dumps (kept on the workspace, not in code repo)
|
| 37 |
+
_experiments/cross_stage/preds/
|
| 38 |
+
_experiments/cross_stage/preds_xprompt/
|
| 39 |
+
_experiments/cross_stage/logs/
|
| 40 |
+
_experiments/cross_stage/logs_xprompt/
|
| 41 |
+
_experiments/cross_stage/figs/
|
| 42 |
+
_experiments/cross_stage/figs_xprompt/
|
| 43 |
+
|
| 44 |
+
_experiments/cross_stage/*.log
|
_experiments/cross_stage/_peek.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quick peek at completed cross-prompt outputs."""
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# Directory searched for the diagonal "<tag>.jsonl" prediction files (atc_s1 ... dc_s3).
DIAG = Path("/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds")
|
| 6 |
+
# Directory searched for the off-diagonal "<method>_trainK_promptM.jsonl" prediction files.
XP = Path("/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt")
|
| 7 |
+
|
| 8 |
+
def load(p):
    """Read the JSONL file at ``p`` and return its records as a list.

    Lines that are empty after stripping surrounding whitespace are ignored;
    every remaining line is decoded with ``json.loads``.
    """
    with open(p) as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped if text]
|
| 16 |
+
|
| 17 |
+
def summarize(tag, recs, target_key):
    """Print a one-line accuracy summary for a list of prediction records.

    Only records with a truthy ``parse_ok`` are counted. For those, the sorted
    ``predicted_values`` are compared with the sorted values stored under
    ``target_key`` (treated as empty when the key is absent):

    - exact:  prediction equals the target,
    - subset: both are non-empty and the prediction is a subset of the target,
    - avg|p|: mean number of predicted values.

    When ``recs`` is empty, ``"<tag>: no data"`` is printed instead.
    """
    if not recs:
        print(f"{tag}: no data")
        return

    counted = 0
    exact_hits = 0
    subset_hits = 0
    total_size = 0
    for rec in recs:
        if rec.get("parse_ok"):
            pred = tuple(sorted(rec["predicted_values"]))
            gold = tuple(sorted(rec.get(target_key, [])))
            counted += 1
            exact_hits += int(pred == gold)
            if pred and gold and set(pred) <= set(gold):
                subset_hits += 1
            total_size += len(pred)

    # Guard the ratios against division by zero when nothing parsed.
    denom = max(1, counted)
    print(f"{tag:32s} n={counted:4d} exact={exact_hits/denom:.3f} subset={subset_hits/denom:.3f} avg|p|={total_size/denom:.2f}")
|
| 33 |
+
|
| 34 |
+
# --- Diagonal cells: each adapter evaluated against its own stage target ---
print("=== Diagonal (already had) ===")
diagonal_cells = [
    ("atc_s1", "target_S1"),
    ("atc_s2", "target_S2"),
    ("atc_s3", "target_S3"),
    ("dc_s1", "target_S1"),
    ("dc_s2", "target_S2"),
    ("dc_s3", "target_S3"),
]
for tag, t_key in diagonal_cells:
    p = DIAG / f"{tag}.jsonl"
    if not p.exists():
        continue
    summarize(tag, load(p), t_key)

print()

# --- Off-diagonal cells: adapter prompted with a different stage ---
print("=== Off-diagonal cross-prompt ===")
off_diagonal_tags = [
    "atc_train3_prompt1",
    "atc_train3_prompt2",
    "atc_train2_prompt3",
    "dc_train3_prompt1",
    "dc_train3_prompt2",
]
for tag in off_diagonal_tags:
    p = XP / f"{tag}.jsonl"
    if p.exists() and p.stat().st_size:
        # The prompt stage is whatever follows "prompt" in the tag; the
        # records are scored against that stage's target.
        q = int(tag.split("prompt")[1])
        summarize(tag + f" [eval vs S{q}]", load(p), f"target_S{q}")
    else:
        # Absent or zero-byte files are both reported as missing.
        print(f"{tag}: (missing)")
|
_experiments/cross_stage/analyze.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Aggregate per-cell predictions across (method, stage) and produce
|
| 2 |
+
containment / catastrophic-rewrite plots + per-cell trajectory visualizations.
|
| 3 |
+
|
| 4 |
+
Expects six JSONL files in --preds_dir: atc_s{1,2,3}.jsonl, dc_s{1,2,3}.jsonl
|
| 5 |
+
(each produced by predict_one.py).
|
| 6 |
+
|
| 7 |
+
Produces:
|
| 8 |
+
- containment_summary.json (numeric report)
|
| 9 |
+
- fig_containment.{pdf,png} (3 grouped bars per method)
|
| 10 |
+
- fig_sankey_example.{pdf,png} (one 9x9 grid of per-cell trajectories for 1 puzzle)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import json
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import matplotlib as mpl
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Method tags; combined with a stage number to form "<method>_s<stage>.jsonl" file names.
METHODS = ["atc", "dc"]
|
| 25 |
+
# Stage numbers for which a prediction file is looked up per method.
STAGES = [1, 2, 3]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_preds(preds_dir: Path):
    """Load every ``<method>_s<stage>.jsonl`` file found in ``preds_dir``.

    Returns dict[(method, stage)] -> dict[(puzzle_id, target_cell)] -> record,
    with one entry per combination of METHODS and STAGES. A missing file
    produces a warning and an empty inner dict. Blank lines are ignored; if a
    (puzzle_id, target_cell) key repeats within a file, the last record wins.
    """
    table = {}
    for method, stage in ((m, s) for m in METHODS for s in STAGES):
        tag = f"{method}_s{stage}"
        path = preds_dir / f"{tag}.jsonl"
        cells = {}
        if path.exists():
            with open(path) as handle:
                for raw in handle:
                    text = raw.strip()
                    if text:
                        rec = json.loads(text)
                        cell_key = (int(rec["puzzle_id"]), tuple(rec["target_cell"]))
                        cells[cell_key] = rec
            table[(method, stage)] = cells
            print(f"loaded {tag}: {len(cells)} cells")
        else:
            print(f"WARN missing {path}")
            table[(method, stage)] = cells
    return table
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def cells_common(preds):
    """Return the sorted cell keys shared by every non-empty prediction table.

    ``preds`` maps (method, stage) -> {cell_key: record}. Empty tables (for
    example from a missing file) are ignored, so they do not force the
    intersection to be empty.

    Returns:
        A sorted list of cell keys. It is an empty list, never a set, when
        no table has any cells, so callers always get the same type.
    """
    key_sets = [set(cells) for cells in preds.values() if cells]
    if not key_sets:
        return []
    # intersection() builds a new set, so none of the per-table sets is mutated.
    return sorted(set.intersection(*key_sets))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def containment(pred_set, ref_set):
    """Return 1 when ``pred_set`` is non-empty and a subset of ``ref_set``.

    An empty prediction or an empty reference yields 0.
    """
    if pred_set and ref_set:
        return int(set(pred_set) <= set(ref_set))
    return 0
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def disjoint(a, b):
    """Return 1 when ``a`` and ``b`` are both non-empty and share no element, else 0."""
    if not a or not b:
        return 0
    return int(set(a).isdisjoint(set(b)))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def compute_metrics(preds, common_cells):
    """Aggregate per-cell containment statistics for each method.

    Args:
        preds: dict[(method, stage)] -> dict[cell_key] -> record, as built by
            ``load_preds``.
        common_cells: iterable of cell keys to evaluate.

    Returns:
        dict[method] -> dict with the number of counted cells ``n`` and, as
        fractions of ``n``: S3-in-S1, S3-in-S2 and S2-in-S1 containment, the
        S3-disjoint-from-S1 / S3-disjoint-from-S2 rates, and the average
        predicted set size per stage. All ratios are 0 when ``n`` is 0.

    A cell is counted for a method only when all three stages have a record
    for it and every record has a truthy ``parse_ok``. Cells absent from a
    stage are skipped instead of raising ``KeyError``; this happens when a
    prediction file is missing, because ``cells_common`` ignores empty tables.
    A missing or ``None`` ``predicted_values`` is treated as an empty
    prediction.
    """
    out = {}
    for m in METHODS:
        s1, s2, s3 = (preds.get((m, stage)) or {} for stage in (1, 2, 3))
        n = 0
        c13 = c23 = c12 = 0
        rew_3_disjoint_1 = rew_3_disjoint_2 = 0
        size_s1 = size_s2 = size_s3 = 0
        for key in common_cells:
            recs = (s1.get(key), s2.get(key), s3.get(key))
            if not all(rec and rec.get("parse_ok") for rec in recs):
                continue
            p1, p2, p3 = (rec.get("predicted_values") or [] for rec in recs)
            n += 1
            c13 += containment(p3, p1)
            c23 += containment(p3, p2)
            c12 += containment(p2, p1)
            rew_3_disjoint_1 += disjoint(p3, p1)
            rew_3_disjoint_2 += disjoint(p3, p2)
            size_s1 += len(p1)
            size_s2 += len(p2)
            size_s3 += len(p3)
        # Avoid division by zero when no cell qualified for this method.
        denom = max(1, n)
        out[m] = {
            "n": n,
            "containment_S3_in_S1": c13 / denom,
            "containment_S3_in_S2": c23 / denom,
            "containment_S2_in_S1": c12 / denom,
            "catastrophic_S3_disjoint_S1": rew_3_disjoint_1 / denom,
            "catastrophic_S3_disjoint_S2": rew_3_disjoint_2 / denom,
            "avg_predicted_size_S1": size_s1 / denom,
            "avg_predicted_size_S2": size_s2 / denom,
            "avg_predicted_size_S3": size_s3 / denom,
        }
    return out
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# ---------- plotting ----------------------------------------------------
|
| 118 |
+
|
| 119 |
+
# Matplotlib defaults applied at import time to every figure this module draws:
# serif fonts, no top/right spines, and TrueType (type 42) fonts in PDF/PS output.
_FIGURE_RC = {
    # fonts
    "font.family": "serif",
    "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
    "font.size": 12,
    "axes.labelsize": 12,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 11,
    # axes and lines
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.0,
    "lines.linewidth": 2.0,
    # vector output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_FIGURE_RC)
|
| 134 |
+
|
| 135 |
+
# Colour of the bars and value labels drawn for the "atc" method.
ATC_COLOR = "#1f4f8b"
|
| 136 |
+
# Colour of the bars and value labels drawn for the "dc" method.
DC_COLOR = "#b21e2f"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def plot_containment(metrics, out_path):
    """Draw the grouped containment bar chart and save it as PDF and PNG.

    ``metrics`` must provide the keys ``"atc"`` and ``"dc"``, each mapping to
    a dict with ``containment_S3_in_S1``, ``containment_S3_in_S2`` and
    ``catastrophic_S3_disjoint_S1``. The figure is written to ``out_path``
    with its suffix replaced by ``.pdf`` and ``.png``.
    """
    fig, ax = plt.subplots(figsize=(5.2, 3.6), constrained_layout=True)
    groups = [
        ("$\\hat S_3 \\subseteq \\hat S_1$", "containment_S3_in_S1"),
        ("$\\hat S_3 \\subseteq \\hat S_2$", "containment_S3_in_S2"),
        ("$\\hat S_3 \\cap \\hat S_1 = \\varnothing$", "catastrophic_S3_disjoint_S1"),
    ]
    tick_labels = [label for label, _ in groups]
    metric_keys = [key for _, key in groups]
    x = list(range(len(groups)))
    w = 0.36

    # One entry per method: bar positions, heights, colour and legend label.
    series = [
        ([xi - w/2 for xi in x], [metrics["atc"][k] for k in metric_keys], ATC_COLOR, "ATC"),
        ([xi + w/2 for xi in x], [metrics["dc"][k] for k in metric_keys], DC_COLOR, "Data Curriculum"),
    ]
    for positions, heights, color, legend_label in series:
        ax.bar(positions, heights, w, color=color, label=legend_label, edgecolor="none")
    # Print each bar's value just above it, in the bar's own colour.
    for positions, heights, color, _ in series:
        for pos, height in zip(positions, heights):
            ax.text(pos, height + 0.015, f"{height:.2f}", ha="center", va="bottom", fontsize=10, color=color)

    ax.set_xticks(x, tick_labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction of cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf / .png")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def plot_sankey_grid(preds, out_path, puzzle_id=0):
    """Draw per-cell prediction trajectories for one puzzle, one panel per method.

    Each panel lists the cells of ``puzzle_id`` found in the method's stage-3
    predictions, one row per cell, with the S1 / S2 / S3 predicted values in
    three columns (an em dash marks an empty prediction). A row is shaded
    grey when the S3 prediction is non-empty and contained in a non-empty S1
    prediction, and light red otherwise. The figure is saved next to
    ``out_path`` as ``.pdf`` and ``.png``.
    """
    panels = [("atc", "ATC (latent + curriculum)"), ("dc", "Data Curriculum (no CoT)")]
    fig, axes = plt.subplots(1, 2, figsize=(9, 4.5), constrained_layout=True)
    for ax, (method, title) in zip(axes, panels):
        rows = []
        for key, rec3 in sorted(preds[(method, 3)].items()):
            if key[0] == puzzle_id:
                s1_vals = preds[(method, 1)].get(key, {}).get("predicted_values") or []
                s2_vals = preds[(method, 2)].get(key, {}).get("predicted_values") or []
                s3_vals = rec3.get("predicted_values") or []
                rows.append((key[1], s1_vals, s2_vals, s3_vals))
        n = len(rows)
        if n == 0:
            ax.text(0.5, 0.5, "(no data)", transform=ax.transAxes, ha="center")
            ax.set_title(title)
            continue
        ax.set_xlim(0, 3)
        ax.set_ylim(-0.5, n - 0.5)
        for i, (cell_rc, p1, p2, p3) in enumerate(rows):
            # Rows are laid out top to bottom, so the first cell gets the largest y.
            y = n - 1 - i
            r, c = cell_rc
            ax.text(-0.4, y, f"({r+1},{c+1})", va="center", ha="right", fontsize=8, color="0.4")
            for vals, x_center in ((p1, 0.5), (p2, 1.5), (p3, 2.5)):
                txt = ",".join(str(v) for v in vals) if vals else "—"
                ax.text(x_center, y, txt, va="center", ha="center", fontsize=9)
            survives = bool(p1) and bool(p3) and set(p3) <= set(p1)
            color = "0.85" if survives else "#f5b7b1"
            ax.axhspan(y - 0.5, y + 0.5, facecolor=color, alpha=0.4, zorder=0)
        ax.set_xticks([0.5, 1.5, 2.5], ["S1", "S2", "S3"])
        ax.set_yticks([])
        ax.set_title(title, fontsize=11)
        ax.spines["left"].set_visible(False)
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf / .png")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def main():
    """Command-line entry point.

    Loads the predictions from ``--preds_dir``, computes the containment
    metrics over the common cells, writes ``containment_summary.json`` into
    ``--out_dir`` (created if needed), prints the metrics, and renders the
    containment bar chart plus the per-cell example for ``--example_puzzle``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--preds_dir", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--example_puzzle", type=int, default=0)
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    preds = load_preds(Path(args.preds_dir))
    common = cells_common(preds)
    print(f"common cells across all 6 files: {len(common)}")

    metrics = compute_metrics(preds, common)
    report = {"n_common_cells": len(common), "metrics": metrics}
    with open(out_dir / "containment_summary.json", "w") as f:
        json.dump(report, f, indent=2)
    print(json.dumps(metrics, indent=2))

    plot_containment(metrics, out_dir / "fig_containment")
    plot_sankey_grid(preds, out_dir / "fig_sankey_example", puzzle_id=args.example_puzzle)


if __name__ == "__main__":
    main()
|
_experiments/cross_stage/analyze_cross_prompt.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Analyse cross-prompt evaluations.
|
| 2 |
+
|
| 3 |
+
For each (method, train_stage, prompt_stage) pair we have a JSONL of per-cell
|
| 4 |
+
records produced by predict_one.py. The diagonals (train_stage == prompt_stage)
|
| 5 |
+
live in ../preds/ ; the off-diagonals (this experiment) live in ../preds_xprompt/.
|
| 6 |
+
|
| 7 |
+
Headline question: does each model still do the OFF-DIAGONAL task correctly?
|
| 8 |
+
We measure:
|
| 9 |
+
- exact_set_match_vs_target : predicted_values == target_S{prompt_stage}
|
| 10 |
+
- subset_of_target : predicted_values ⊆ target_S{prompt_stage}
|
| 11 |
+
- avg |predicted set|
|
| 12 |
+
- "drift" : exact_set_match vs predicted of the model's ORIGINAL training
|
| 13 |
+
stage (i.e. does ATC_S3 prompted with S1 still produce its
|
| 14 |
+
own S3 answer? -> indicates that the prompt was ignored)
|
| 15 |
+
|
| 16 |
+
Plots:
|
| 17 |
+
fig_xprompt_solve_grid.{pdf,png} - 2-method × 3 prompt_stage × 3 train_stage
|
| 18 |
+
heatmap of exact-set-match accuracy
|
| 19 |
+
fig_xprompt_setsize.{pdf,png} - avg |pred| for each (train, prompt) cell
|
| 20 |
+
grouped by method
|
| 21 |
+
fig_xprompt_forward_compat.{pdf,png} - "forward compat": for the S3 adapter
|
| 22 |
+
prompted with stage_i in {1,2,3}, plot
|
| 23 |
+
exact-match accuracy. Both methods.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
from collections import defaultdict
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import matplotlib as mpl
|
| 35 |
+
import matplotlib.pyplot as plt
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Method tags (first token of every prediction file name) and stage numbers.
METHODS = ["atc", "dc"]
STAGES = [1, 2, 3]

# Per-method plot colour and display name, keyed by method tag.
ATC_COLOR = "#1f4f8b"
DC_COLOR = "#b21e2f"
COLOR = {"atc": ATC_COLOR, "dc": DC_COLOR}
PRETTY = {"atc": "ATC", "dc": "Data Curriculum"}
|
| 43 |
+
|
| 44 |
+
# Matplotlib defaults applied at import time to every figure this module draws:
# serif fonts, no top/right spines, and TrueType (type 42) fonts in PDF/PS output.
_FIGURE_RC = {
    # fonts
    "font.family": "serif",
    "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
    "font.size": 12,
    "axes.labelsize": 12,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
    # axes, lines and markers
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.0,
    "lines.linewidth": 2.0,
    "lines.markersize": 7,
    # vector output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_FIGURE_RC)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def parse_tag(tag: str):
    """Split a prediction-file tag into (method, train_stage, prompt_stage).

    Diagonal tags:      atc_s3             -> ("atc", 3, 3)
    Off-diagonal tags:  atc_train3_prompt1 -> ("atc", 3, 1)

    Raises IndexError when a non-diagonal tag lacks a ``train*`` or
    ``prompt*`` token, and ValueError when a stage suffix is not an integer.
    """
    pieces = tag.split("_")
    method = pieces[0]

    is_diagonal = len(pieces) == 2 and pieces[1].startswith("s")
    if is_diagonal:
        stage = int(pieces[1][1:])
        return method, stage, stage

    # Otherwise expect *_trainK_promptM; the first matching token of each kind is used.
    train_tokens = [tok for tok in pieces if tok.startswith("train")]
    prompt_tokens = [tok for tok in pieces if tok.startswith("prompt")]
    train = int(train_tokens[0][len("train"):])
    prompt = int(prompt_tokens[0][len("prompt"):])
    return method, train, prompt
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_dir(p: Path):
    """Load every ``*.jsonl`` file in ``p`` whose name is a recognised tag.

    Returns dict[(method, train, prompt)] -> dict[(puzzle_id, target_cell)]
    -> record. Files are visited in sorted order. A file whose stem
    ``parse_tag`` cannot interpret is skipped with a printed notice instead
    of being dropped silently. Blank lines are ignored; when a cell key
    repeats within a file, the last record wins.
    """
    by_key = {}  # (method, train, prompt) -> { (puzzle, cell) -> record }
    for path in sorted(p.glob("*.jsonl")):
        tag = path.stem
        try:
            m, t, q = parse_tag(tag)
        except (ValueError, IndexError):
            # parse_tag raises only these two for malformed tags; any other
            # exception would be a real bug and should surface.
            print(f"skipped {path.name}: tag not recognised")
            continue
        d = {}
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                r = json.loads(line)
                d[(int(r["puzzle_id"]), tuple(r["target_cell"]))] = r
        by_key[(m, t, q)] = d
        print(f"loaded {tag}: {len(d)} cells -> (method={m}, train={t}, prompt={q})")
    return by_key
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def target_field(rec, q):
    """Return the record's ``target_S<q>`` values as a tuple (empty when absent)."""
    field = f"target_S{q}"
    values = rec.get(field, [])
    return tuple(values)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def aggregate(by_key):
    """Compute summary metrics for every (method, train, prompt) cell.

    Only records with a truthy ``parse_ok`` are counted. Returns one dict per
    cell, in ``by_key`` order, holding the count ``n`` and, as fractions of
    ``n``: exact match against the prompt-stage target, non-empty subset of
    that target, and the average predicted set size. Ratios are 0 when
    ``n`` is 0.
    """
    rows = []
    for (m, t, q), cells in by_key.items():
        # (sorted prediction, sorted prompt-stage target) for each parsed record.
        pairs = [
            (tuple(sorted(rec["predicted_values"])), tuple(sorted(target_field(rec, q))))
            for rec in cells.values()
            if rec.get("parse_ok")
        ]
        n = len(pairs)
        denom = max(1, n)
        exact = sum(1 for pred, targ in pairs if pred == targ)
        contained = sum(
            1 for pred, targ in pairs
            if pred and targ and set(pred).issubset(set(targ))
        )
        total_size = sum(len(pred) for pred, _ in pairs)
        rows.append({
            "method": m, "train": t, "prompt": q, "n": n,
            "exact_match_vs_prompt_target": exact / denom,
            "subset_of_prompt_target": contained / denom,
            "avg_pred_size": total_size / denom,
        })
    return rows
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def drift_from_diagonal(by_key):
    """Measure how often an off-diagonal run repeats the diagonal answer.

    For each (method, train=t, prompt=q) entry with q != t whose diagonal
    entry (method, t, t) exists and is non-empty, count the cells present in
    both runs with a truthy ``parse_ok`` on both sides, and report the
    fraction whose sorted predictions are identical.

    High fraction -> the model ignored the prompt change.
    Low fraction  -> the model changed its behaviour with the prompt.
    """
    out = []
    for (m, t, q), off_diag in by_key.items():
        diag = by_key.get((m, t, t)) if q != t else None
        if not diag:
            continue
        n = 0
        same = 0
        for key, rec_q in off_diag.items():
            rec_t = diag.get(key)
            usable = rec_t and rec_q.get("parse_ok") and rec_t.get("parse_ok")
            if not usable:
                continue
            n += 1
            if sorted(rec_q["predicted_values"]) == sorted(rec_t["predicted_values"]):
                same += 1
        out.append({"method": m, "train": t, "prompt": q,
                    "n": n, "frac_ignored_prompt": same / max(1, n)})
    return out
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# ----------------------- PLOTS ----------------------------
|
| 157 |
+
|
| 158 |
+
def plot_solve_grid(rows, out_path):
    """Save side-by-side 3x3 heatmaps of exact-match accuracy, one per method.

    Heatmap rows are the trained stage and columns the prompt stage; cells
    with no entry in ``rows`` stay NaN and carry no label. Output goes to
    ``out_path`` with ``.pdf`` and ``.png`` suffixes.
    """
    stage_names = ["S1", "S2", "S3"]
    fig, axes = plt.subplots(1, 2, figsize=(8.2, 3.6), constrained_layout=True)
    for ax, m in zip(axes, METHODS):
        grid = np.full((3, 3), np.nan)
        for row in (r for r in rows if r["method"] == m):
            grid[row["train"] - 1, row["prompt"] - 1] = row["exact_match_vs_prompt_target"]
        im = ax.imshow(grid, vmin=0.0, vmax=1.0, cmap="Blues")
        ax.set_xticks([0, 1, 2], stage_names)
        ax.set_yticks([0, 1, 2], stage_names)
        ax.set_xlabel("Prompt stage_i")
        ax.set_ylabel("Trained stage")
        ax.set_title(PRETTY[m], fontsize=11)
        for i, grid_row in enumerate(grid):
            for j, v in enumerate(grid_row):
                if np.isnan(v):
                    continue
                # White text on the dark half of the colormap, black on the light half.
                text_color = "white" if v > 0.5 else "black"
                ax.text(j, i, f"{v:.2f}", ha="center", va="center",
                        color=text_color, fontsize=10)
    cb = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.85, fraction=0.05)
    cb.set_label("Exact set match vs prompt target")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def plot_forward_compat(rows, out_path):
    """Plot exact-match accuracy of the train-stage-3 runs under each prompt stage.

    One line per method ("atc", "dc") over prompt stages 1, 2 and 3, taking
    the first matching entry of ``rows`` for each point; a missing point is
    NaN and gets no value label. Saved to ``out_path`` as ``.pdf`` and ``.png``.
    """
    fig, ax = plt.subplots(figsize=(5.0, 3.6), constrained_layout=True)
    x = [1, 2, 3]
    line_styles = [("atc", "s", "-"), ("dc", "o", "--")]
    for m, marker, ls in line_styles:
        y = [
            next(
                (r["exact_match_vs_prompt_target"] for r in rows
                 if r["method"] == m and r["train"] == 3 and r["prompt"] == q),
                np.nan,
            )
            for q in x
        ]
        ax.plot(x, y, color=COLOR[m], marker=marker, linestyle=ls,
                label=PRETTY[m])
        for xi, v in zip(x, y):
            if np.isnan(v):
                continue
            ax.text(xi, v + 0.02, f"{v:.2f}", ha="center", va="bottom",
                    fontsize=9, color=COLOR[m])
    ax.set_xticks(x, ["Ask S1", "Ask S2", "Ask S3"])
    ax.set_xlabel("Prompt task")
    ax.set_ylim(0.0, 1.05)
    ax.set_ylabel("Exact set-match on prompted task")
    ax.legend(frameon=False, loc="lower left")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def plot_pred_size_grid(rows, out_path):
    """Save side-by-side 3x3 heatmaps of the average predicted set size.

    Heatmap rows are the trained stage and columns the prompt stage, one
    panel per method; cells with no entry in ``rows`` stay NaN and carry no
    label. The colour scale is fixed to the range 1.0 to 2.5. Output goes to
    ``out_path`` with ``.pdf`` and ``.png`` suffixes.
    """
    stage_names = ["S1", "S2", "S3"]
    fig, axes = plt.subplots(1, 2, figsize=(8.2, 3.6), constrained_layout=True)
    for ax, m in zip(axes, METHODS):
        grid = np.full((3, 3), np.nan)
        for row in (r for r in rows if r["method"] == m):
            grid[row["train"] - 1, row["prompt"] - 1] = row["avg_pred_size"]
        im = ax.imshow(grid, vmin=1.0, vmax=2.5, cmap="Oranges")
        ax.set_xticks([0, 1, 2], stage_names)
        ax.set_yticks([0, 1, 2], stage_names)
        ax.set_xlabel("Prompt stage_i")
        ax.set_ylabel("Trained stage")
        ax.set_title(PRETTY[m], fontsize=11)
        for i, grid_row in enumerate(grid):
            for j, v in enumerate(grid_row):
                if np.isnan(v):
                    continue
                # Switch to white text once the cell colour gets dark.
                text_color = "white" if v > 1.8 else "black"
                ax.text(j, i, f"{v:.2f}", ha="center", va="center",
                        color=text_color, fontsize=10)
    cb = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.85, fraction=0.05)
    cb.set_label("Avg |predicted candidate set|")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def plot_prompt_responsiveness(drift_rows, out_path):
    """Bar chart of ``frac_ignored_prompt`` per off-diagonal cell (low = good).

    One group per distinct (train, prompt) pair in ``drift_rows``, in sorted
    order, with one bar per method. A method with no entry for a pair gets
    NaN and no value label. Saved to ``out_path`` as ``.pdf`` and ``.png``.
    """
    fig, ax = plt.subplots(figsize=(5.8, 3.6), constrained_layout=True)

    # The first row seen for each (method, train, prompt) wins.
    lookup = {}
    for r in drift_rows:
        lookup.setdefault((r["method"], r["train"], r["prompt"]), r["frac_ignored_prompt"])

    pairs = sorted({(r["train"], r["prompt"]) for r in drift_rows})
    labels = [f"S{t}→S{q}" for t, q in pairs]
    atc_vals = [lookup.get(("atc", t, q), np.nan) for t, q in pairs]
    dc_vals = [lookup.get(("dc", t, q), np.nan) for t, q in pairs]

    x = list(range(len(labels)))
    w = 0.36
    atc_pos = [xi - w/2 for xi in x]
    dc_pos = [xi + w/2 for xi in x]
    ax.bar(atc_pos, atc_vals, w, color=ATC_COLOR, label="ATC", edgecolor="none")
    ax.bar(dc_pos, dc_vals, w, color=DC_COLOR, label="Data Curriculum", edgecolor="none")
    for positions, values, color in ((atc_pos, atc_vals, ATC_COLOR), (dc_pos, dc_vals, DC_COLOR)):
        for pos, v in zip(positions, values):
            if np.isnan(v):
                continue
            ax.text(pos, v + 0.01, f"{v:.2f}", ha="center", va="bottom", fontsize=9, color=color)

    ax.set_xticks(x, labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Frac cells with prediction ≡ same model's train-stage answer\n(low = model actually responded to new prompt)")
    ax.legend(frameon=False, loc="upper left")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def main():
    """Command-line entry point.

    Loads the diagonal runs from ``--diag_dir`` and the off-diagonal runs
    from ``--xprompt_dir``; entries from the latter replace same-keyed
    entries from the former. Writes ``xprompt_summary.json`` into
    ``--out_dir`` (created if needed), prints the same summary, and renders
    the figures. The prompt-responsiveness figure is drawn only when there
    is at least one drift row.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--diag_dir", "--xprompt_dir", "--out_dir"):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    by_key = load_dir(Path(args.diag_dir))
    by_key.update(load_dir(Path(args.xprompt_dir)))

    rows = aggregate(by_key)
    drift = drift_from_diagonal(by_key)

    with open(out / "xprompt_summary.json", "w") as f:
        json.dump({"rows": rows, "drift": drift}, f, indent=2)
    print(json.dumps({"rows": rows, "drift": drift}, indent=2))

    plot_solve_grid(rows, out / "fig_xprompt_solve_grid")
    plot_pred_size_grid(rows, out / "fig_xprompt_setsize")
    plot_forward_compat(rows, out / "fig_xprompt_forward_compat")
    if drift:
        plot_prompt_responsiveness(drift, out / "fig_xprompt_prompt_response")


if __name__ == "__main__":
    main()
|
_experiments/cross_stage/analyze_v2.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Extended cross-stage containment analyses.
|
| 2 |
+
|
| 3 |
+
Reads the 6 JSONL files produced by predict_one.py (one per method-stage cell)
|
| 4 |
+
and emits multiple plots that probe HOW latent CoT propagates constraints
|
| 5 |
+
versus the vanilla data-curriculum baseline.
|
| 6 |
+
|
| 7 |
+
Plots produced (PDF + PNG):
|
| 8 |
+
fig_containment_basic - 4 grouped bars: S3⊆S1, S3⊆S2, S3∩S1=∅, S3∩S2=∅
|
| 9 |
+
fig_containment_by_diff - same 3 bars BROKEN DOWN by ground-truth |S1|
|
| 10 |
+
(cell difficulty axis = |true legal candidate set|)
|
| 11 |
+
fig_set_size_trajectory - avg predicted set size at S1/S2/S3 per method
|
| 12 |
+
fig_correctness_breakdown - among incorrect S3 predictions, what fraction
|
| 13 |
+
stays inside S1 / S2 vs. is catastrophic?
|
| 14 |
+
fig_method_agreement - fraction of cells where ATC.S3 == DC.S3, broken
|
| 15 |
+
down by ground-truth difficulty
|
| 16 |
+
fig_sankey_example - per-cell value trajectory for one puzzle
|
| 17 |
+
(existing in analyze.py, refreshed here)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
from collections import defaultdict
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Dict, List, Tuple
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import matplotlib as mpl
|
| 30 |
+
import matplotlib.pyplot as plt
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Method keys; load_preds builds prediction file names as "<method>_s<stage>.jsonl".
METHODS = ["atc", "dc"]
# Stage indices; one prediction file is expected per (method, stage) pair.
STAGES = [1, 2, 3]
# Human-readable method names used in legends and panel titles.
METHOD_PRETTY = {"atc": "ATC", "dc": "Data Curriculum"}
# Fixed per-method colours so every figure uses the same encoding.
ATC_COLOR = "#1f4f8b"
DC_COLOR = "#b21e2f"
COLOR = {"atc": ATC_COLOR, "dc": DC_COLOR}

# Global matplotlib style, applied once at import time to every figure drawn below.
mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
    "font.size": 12,
    "axes.labelsize": 12,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.0,
    "lines.linewidth": 2.0,
    "lines.markersize": 7,
    # fonttype 42 = embed fonts as TrueType (per the matplotlib rcParams docs).
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_preds(preds_dir: Path):
    """Load every ``<method>_s<stage>.jsonl`` prediction file under *preds_dir*.

    Returns a dict keyed by ``(method, stage)`` for all METHODS x STAGES. Each
    value maps ``(puzzle_id, target_cell)`` to the parsed JSON record; a file
    that does not exist yields an empty mapping. Blank lines are skipped, and
    a later record with the same key replaces an earlier one.
    """
    tables = {}
    for method in METHODS:
        for stage in STAGES:
            jsonl_path = preds_dir / f"{method}_s{stage}.jsonl"
            records = {}
            if jsonl_path.exists():
                with open(jsonl_path) as fh:
                    for raw in fh:
                        text = raw.strip()
                        if text:
                            rec = json.loads(text)
                            cell_key = (int(rec["puzzle_id"]), tuple(rec["target_cell"]))
                            records[cell_key] = rec
            tables[(method, stage)] = records
    return tables
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def cells_common(preds):
    """Return the sorted cell keys that appear in *every* prediction table.

    *preds* maps ``(method, stage)`` to a ``{cell_key: record}`` dict, as
    produced by ``load_preds``. The downstream ``compute_*`` functions index
    every method/stage table with each returned key, so a key is only safe to
    return when all tables contain it.

    An empty table (e.g. a missing JSONL file) therefore makes the result
    empty; a warning names the empty tables so the cause is visible.
    Previously empty tables were skipped here, which led to a ``KeyError``
    later on when those tables were indexed.
    """
    if not preds:
        return []

    empty_tags = sorted(tag for tag, table in preds.items() if not table)
    if empty_tags:
        # Function-scope import keeps this block free of new file-level deps.
        import warnings

        warnings.warn(
            f"no predictions loaded for {empty_tags}; no cell is common to all tables",
            stacklevel=2,
        )
        return []

    tables = iter(preds.values())
    common = set(next(tables))
    for table in tables:
        common.intersection_update(table)
    return sorted(common)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def diff_bucket(target_s1):
    """Map a cell's ground-truth stage-1 candidate set to a difficulty label.

    The label depends only on ``len(target_s1)``: sizes 0 and 1 share the
    ``|S1|=1`` bucket, sizes 2 and 3 get their own, and 4 or more collapse
    into ``|S1|≥4``.
    """
    size = len(target_s1)
    if size >= 4:
        return "|S1|≥4"
    return f"|S1|={max(size, 1)}"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Every label diff_bucket can return, in ascending |S1| order; used as the x-axis order of the by-difficulty plots.
DIFF_ORDER = ["|S1|=1", "|S1|=2", "|S1|=3", "|S1|≥4"]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _safe_div(a, b):
    """Return ``a / b`` as a float, or ``0.0`` when *b* is falsy (e.g. 0)."""
    if not b:
        return 0.0
    return float(a) / float(b)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def compute_per_difficulty(preds, common):
    """Compute containment metrics per method and difficulty bucket.

    For every method, each cell in *common* whose stage-1/2/3 records all have
    ``parse_ok`` set is counted in the bucket given by ``diff_bucket`` of the
    stage-1 record's ``target_S1``. One row is returned per (method, bucket)
    with:

      n         number of counted cells
      c13/c23   fraction where the S3 prediction is a non-empty subset of a
                non-empty S1/S2 prediction
      d13/d23   fraction where both predictions are non-empty and disjoint
      correct   fraction where the S3 prediction is exactly {target_solution}
      size_s1..size_s3   mean predicted set size per stage

    Fractions and means are 0.0 for buckets with no cells.
    """
    rate_fields = ("c13", "c23", "d13", "d23", "correct")
    rows = []
    for method in METHODS:
        tallies = {bucket: defaultdict(int) for bucket in DIFF_ORDER}
        for cell in common:
            rec1 = preds[(method, 1)][cell]
            rec2 = preds[(method, 2)][cell]
            rec3 = preds[(method, 3)][cell]
            if not (rec1["parse_ok"] and rec2["parse_ok"] and rec3["parse_ok"]):
                continue
            tally = tallies[diff_bucket(rec1["target_S1"])]
            pred1 = set(rec1["predicted_values"])
            pred2 = set(rec2["predicted_values"])
            pred3 = set(rec3["predicted_values"])
            truth = rec3.get("target_solution")

            # Containment/disjointness only count when both sides are non-empty.
            both13 = bool(pred3) and bool(pred1)
            both23 = bool(pred3) and bool(pred2)
            tally["n"] += 1
            tally["c13"] += int(both13 and pred3 <= pred1)
            tally["c23"] += int(both23 and pred3 <= pred2)
            tally["d13"] += int(both13 and pred3.isdisjoint(pred1))
            tally["d23"] += int(both23 and pred3.isdisjoint(pred2))
            tally["correct"] += int(truth in pred3 and len(pred3) == 1)
            tally["sum_size_s1"] += len(pred1)
            tally["sum_size_s2"] += len(pred2)
            tally["sum_size_s3"] += len(pred3)

        for bucket in DIFF_ORDER:
            tally = tallies[bucket]
            count = tally["n"]
            row = {"method": method, "bucket": bucket, "n": count}
            for field in rate_fields:
                row[field] = _safe_div(tally[field], count)
            for stage in (1, 2, 3):
                row[f"size_s{stage}"] = _safe_div(tally[f"sum_size_s{stage}"], count)
            rows.append(row)
    return rows
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def compute_correctness_breakdown(preds, common):
    """Summarise, per method, where the incorrect stage-3 predictions landed.

    A cell is "correct" when its S3 prediction is exactly ``{target_solution}``.
    Every other cell whose three records all have ``parse_ok`` counts as wrong,
    including multi-value predictions that contain the target. For the wrong
    cells the result reports the fraction whose S3 prediction is a subset of /
    disjoint from the same method's S1 and S2 predictions; both sides must be
    non-empty for a cell to count towards either.
    """
    result = {}
    for method in METHODS:
        right = 0
        wrong = 0
        tally = {"in_s1": 0, "in_s2": 0, "disjoint_s1": 0, "disjoint_s2": 0}
        for cell in common:
            rec1 = preds[(method, 1)][cell]
            rec2 = preds[(method, 2)][cell]
            rec3 = preds[(method, 3)][cell]
            if not (rec1["parse_ok"] and rec2["parse_ok"] and rec3["parse_ok"]):
                continue
            pred1 = set(rec1["predicted_values"])
            pred2 = set(rec2["predicted_values"])
            pred3 = set(rec3["predicted_values"])
            truth = rec3["target_solution"]
            if len(pred3) == 1 and truth in pred3:
                right += 1
                continue
            wrong += 1
            for suffix, earlier in (("s1", pred1), ("s2", pred2)):
                if pred3 and earlier:
                    tally[f"in_{suffix}"] += int(pred3 <= earlier)
                    tally[f"disjoint_{suffix}"] += int(pred3.isdisjoint(earlier))
        result[method] = {
            "n_correct": right,
            "n_wrong": wrong,
            "wrong_in_s1_frac": _safe_div(tally["in_s1"], wrong),
            "wrong_in_s2_frac": _safe_div(tally["in_s2"], wrong),
            "wrong_disjoint_s1_frac": _safe_div(tally["disjoint_s1"], wrong),
            "wrong_disjoint_s2_frac": _safe_div(tally["disjoint_s2"], wrong),
        }
    return result
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compute_method_agreement(preds, common):
    """Count how often ATC and DC give the same stage-3 answer, per difficulty bucket.

    Only cells whose ATC and DC stage-3 records both have ``parse_ok`` are
    counted. The bucket comes from ``diff_bucket`` of the ATC record's
    ``target_S1`` and the ground truth from its ``target_solution``.

    Predictions are compared as *sets*, matching the other ``compute_*``
    functions: a repeated value in ``predicted_values`` (e.g. ``[5, 5]``) must
    not turn an otherwise correct or agreeing answer into a miss.

    Returns ``{bucket: {"n", "agree", "atc_correct", "dc_correct"}}`` holding
    raw counts (not fractions) for every bucket in DIFF_ORDER.
    """
    per_bucket = {b: {"n": 0, "agree": 0, "atc_correct": 0, "dc_correct": 0} for b in DIFF_ORDER}
    for key in common:
        atc_r = preds[("atc", 3)][key]
        dc_r = preds[("dc", 3)][key]
        if not (atc_r["parse_ok"] and dc_r["parse_ok"]):
            continue
        atc_set = set(atc_r["predicted_values"])
        dc_set = set(dc_r["predicted_values"])
        counts = per_bucket[diff_bucket(atc_r["target_S1"])]
        truth = atc_r["target_solution"]
        counts["n"] += 1
        counts["agree"] += int(atc_set == dc_set)
        # "Correct" means the prediction is exactly the singleton {truth}.
        counts["atc_correct"] += int(len(atc_set) == 1 and truth in atc_set)
        counts["dc_correct"] += int(len(dc_set) == 1 and truth in dc_set)
    return per_bucket
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ----------------------------- PLOTS -----------------------------------
|
| 197 |
+
|
| 198 |
+
def plot_containment_basic(metrics, out_path):
    """Draw the headline grouped-bar chart of containment / disjointness rates.

    *metrics* is read as ``metrics[method][key]`` for the methods "atc" and
    "dc" and the keys c13, c23, d13, d23. The figure is saved next to
    *out_path* with ``.pdf`` and ``.png`` suffixes.
    """
    fig, ax = plt.subplots(figsize=(5.4, 3.6), constrained_layout=True)
    panels = [
        ("$\\hat S_3 \\subseteq \\hat S_1$", "c13"),
        ("$\\hat S_3 \\subseteq \\hat S_2$", "c23"),
        ("$\\hat S_3 \\cap \\hat S_1=\\varnothing$", "d13"),
        ("$\\hat S_3 \\cap \\hat S_2=\\varnothing$", "d23"),
    ]
    positions = list(range(len(panels)))
    bar_w = 0.36
    # (method key, horizontal shift of its bar within a group, colour, legend label)
    series = [
        ("atc", -bar_w / 2, ATC_COLOR, "ATC"),
        ("dc", bar_w / 2, DC_COLOR, "Data Curriculum"),
    ]
    heights = {name: [metrics[name][field] for _, field in panels] for name, _, _, _ in series}

    for name, shift, colour, legend_label in series:
        ax.bar([pos + shift for pos in positions], heights[name], bar_w,
               color=colour, label=legend_label, edgecolor="none")
    # Value labels just above each bar, in the bar's own colour.
    for name, shift, colour, _ in series:
        for pos, value in zip(positions, heights[name]):
            ax.text(pos + shift, value + 0.015, f"{value:.3f}",
                    ha="center", va="bottom", fontsize=9, color=colour)

    ax.set_xticks(positions, [tick for tick, _ in panels])
    ax.set_ylim(0, 1.06)
    ax.set_ylabel("Fraction of cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def plot_containment_by_difficulty(rows, key, ylabel, out_path):
    """Plot one per-difficulty metric as grouped ATC / DC bars.

    Args:
        rows: per-(method, bucket) dicts as returned by ``compute_per_difficulty``.
        key: name of the row field to plot (e.g. "c13", "correct").
        ylabel: y-axis label.
        out_path: output path; ``.pdf`` and ``.png`` variants are written.

    The cell count is annotated under each bucket. Rows are filtered by
    ``parse_ok`` per method, so the two methods can rest on different counts;
    when they differ the annotation shows both as ``n=<ATC> / <DC>`` instead
    of silently reporting only ATC's count.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    by_m = {m: {r["bucket"]: r[key] for r in rows if r["method"] == m} for m in METHODS}
    by_n = {m: {r["bucket"]: r["n"] for r in rows if r["method"] == m} for m in METHODS}
    x = list(range(len(DIFF_ORDER)))
    w = 0.36
    atc_vals = [by_m["atc"].get(b, 0) for b in DIFF_ORDER]
    dc_vals = [by_m["dc"].get(b, 0) for b in DIFF_ORDER]
    ax.bar([xi - w/2 for xi in x], atc_vals, w, color=ATC_COLOR, label="ATC", edgecolor="none")
    ax.bar([xi + w/2 for xi in x], dc_vals, w, color=DC_COLOR, label="Data Curriculum", edgecolor="none")
    for xi, v in zip(x, atc_vals):
        ax.text(xi - w/2, v + 0.01, f"{v:.2f}", ha="center", va="bottom", fontsize=8, color=ATC_COLOR)
    for xi, v in zip(x, dc_vals):
        ax.text(xi + w/2, v + 0.01, f"{v:.2f}", ha="center", va="bottom", fontsize=8, color=DC_COLOR)
    # n-cells annotation under each group (x in data coords, y in axes coords)
    for xi, b in zip(x, DIFF_ORDER):
        atc_n = by_n["atc"].get(b, 0)
        dc_n = by_n["dc"].get(b, 0)
        n_label = f"n={atc_n}" if atc_n == dc_n else f"n={atc_n} / {dc_n}"
        ax.text(xi, -0.06, n_label, ha="center", va="top", fontsize=8, color="0.4",
                transform=ax.get_xaxis_transform())
    ax.set_xticks(x, DIFF_ORDER)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel(ylabel)
    ax.legend(frameon=False, loc="lower left")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def plot_set_size_trajectory(rows, out_path):
    """Plot the average predicted set size across S1 -> S2 -> S3, per method.

    Args:
        rows: per-(method, bucket) dicts as returned by ``compute_per_difficulty``
            (fields ``n``, ``size_s1``, ``size_s2``, ``size_s3``).
        out_path: output path; ``.pdf`` and ``.png`` variants are written.

    The y-axis keeps the original 0.95-1.45 window when all points fit and is
    widened otherwise. The previous fixed limits silently clipped any average
    outside that window, including the 0.0 produced when no cells were counted.
    """
    fig, ax = plt.subplots(figsize=(5.2, 3.6), constrained_layout=True)

    def avg(method, key):
        # Weighting each bucket mean by its n gives the mean over all counted cells.
        ns = sum(r["n"] for r in rows if r["method"] == method)
        s = sum(r[key] * r["n"] for r in rows if r["method"] == method)
        return s / max(1, ns)

    stage_x = [1, 2, 3]
    plotted = []
    for m, marker, ls in [("atc", "s", "-"), ("dc", "o", "--")]:
        y = [avg(m, "size_s1"), avg(m, "size_s2"), avg(m, "size_s3")]
        plotted.extend(y)
        ax.plot(stage_x, y, color=COLOR[m], marker=marker, linestyle=ls, label=METHOD_PRETTY[m])
        for xi, v in zip(stage_x, y):
            ax.text(xi, v + 0.03, f"{v:.2f}", ha="center", va="bottom", fontsize=9, color=COLOR[m])
    ax.set_xticks(stage_x, ["Stage 1", "Stage 2", "Stage 3"])
    # Never narrower than the original window; leave headroom for the value
    # labels drawn 0.03 above each point.
    y_lo = min(0.95, min(plotted) - 0.05)
    y_hi = max(1.45, max(plotted) + 0.10)
    ax.set_ylim(y_lo, y_hi)
    ax.set_ylabel("Avg |predicted candidate set|")
    ax.grid(True, axis="y", linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
    ax.legend(frameon=False, loc="upper right")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def plot_correctness_breakdown(stats, out_path):
    """Bar chart of where the wrong stage-3 predictions landed, per method.

    *stats* is read as ``stats[method][field]`` for "atc" and "dc", using the
    four ``wrong_*_frac`` fields plus ``n_wrong`` (shown in the legend). The
    figure is saved next to *out_path* as ``.pdf`` and ``.png``.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    panels = [
        ("Wrong but $\\subseteq \\hat S_1$", "wrong_in_s1_frac"),
        ("Wrong but $\\subseteq \\hat S_2$", "wrong_in_s2_frac"),
        ("Wrong & $\\cap \\hat S_1=\\varnothing$", "wrong_disjoint_s1_frac"),
        ("Wrong & $\\cap \\hat S_2=\\varnothing$", "wrong_disjoint_s2_frac"),
    ]
    positions = list(range(len(panels)))
    bar_w = 0.36
    # (method key, horizontal shift, colour, legend label)
    series = [
        ("atc", -bar_w / 2, ATC_COLOR, f"ATC (n_wrong={stats['atc']['n_wrong']})"),
        ("dc", bar_w / 2, DC_COLOR, f"Data Curr. (n_wrong={stats['dc']['n_wrong']})"),
    ]
    heights = {name: [stats[name][field] for _, field in panels] for name, _, _, _ in series}

    for name, shift, colour, legend_label in series:
        ax.bar([pos + shift for pos in positions], heights[name], bar_w,
               color=colour, label=legend_label, edgecolor="none")
    for name, shift, colour, _ in series:
        for pos, value in zip(positions, heights[name]):
            ax.text(pos + shift, value + 0.015, f"{value:.2f}",
                    ha="center", va="bottom", fontsize=9, color=colour)

    ax.set_xticks(positions, [tick for tick, _ in panels])
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction of wrong S3 cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def plot_method_agreement(per_bucket, out_path):
    """Plot ATC accuracy, DC accuracy and ATC/DC agreement per difficulty bucket.

    *per_bucket* maps each DIFF_ORDER label to raw counts ``n``, ``agree``,
    ``atc_correct`` and ``dc_correct``; each bar is the count divided by ``n``
    (0.0 for empty buckets). ``n`` is annotated under every group and the
    figure is saved next to *out_path* as ``.pdf`` and ``.png``.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    positions = list(range(len(DIFF_ORDER)))
    bar_w = 0.28

    def fraction(field):
        return [_safe_div(per_bucket[b][field], per_bucket[b]["n"]) for b in DIFF_ORDER]

    # (count field, horizontal shift, colour, legend label); three bars per bucket
    bars = [
        ("atc_correct", -bar_w, ATC_COLOR, "ATC correct"),
        ("dc_correct", 0, DC_COLOR, "DC correct"),
        ("agree", bar_w, "0.4", "ATC == DC"),
    ]
    for field, shift, colour, legend_label in bars:
        ax.bar([pos + shift for pos in positions], fraction(field), bar_w,
               color=colour, label=legend_label, edgecolor="none")

    for pos, bucket in zip(positions, DIFF_ORDER):
        count = per_bucket[bucket]["n"]
        ax.text(pos, -0.06, f"n={count}", ha="center", va="top", fontsize=8, color="0.4",
                transform=ax.get_xaxis_transform())

    ax.set_xticks(positions, DIFF_ORDER)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction")
    ax.legend(frameon=False, loc="lower left")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# Re-use the simple sankey from analyze.py (lightly compacted)
|
| 329 |
+
def plot_sankey(preds, out_path, puzzle_id=0):
    """Draw the per-cell S1 / S2 / S3 predicted values of one puzzle, per method.

    One panel per method ("atc", "dc"); one row per cell of *puzzle_id* that
    has a stage-3 record. Each row shows the predicted values at the three
    stages ("—" when empty or missing). The row is shaded grey when the S3
    prediction is a non-empty subset of a non-empty S1 prediction, and red
    otherwise. Saved next to *out_path* as ``.pdf`` and ``.png``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(9, 4.6), constrained_layout=True)
    for ax, method in zip(axes, ["atc", "dc"]):

        def earlier_values(stage, cell_key):
            # Missing records and missing/empty value lists all become [].
            return preds[(method, stage)].get(cell_key, {}).get("predicted_values") or []

        trajectories = [
            (
                cell_key[1],
                earlier_values(1, cell_key),
                earlier_values(2, cell_key),
                rec3.get("predicted_values") or [],
                rec3.get("target_solution"),
            )
            for cell_key, rec3 in sorted(preds[(method, 3)].items())
            if cell_key[0] == puzzle_id
        ]
        total = len(trajectories)
        ax.set_xlim(0, 3)
        ax.set_ylim(-0.5, total - 0.5)
        for idx, (cell_rc, vals1, vals2, vals3, gt) in enumerate(trajectories):
            row_y = total - 1 - idx  # first cell at the top of the panel
            r, c = cell_rc
            # Cell coordinates are shown 1-based.
            ax.text(-0.4, row_y, f"({r+1},{c+1})", va="center", ha="right", fontsize=8, color="0.4")
            for x_center, vals in zip((0.5, 1.5, 2.5), (vals1, vals2, vals3)):
                shown = ",".join(str(v) for v in vals) if vals else "—"
                ax.text(x_center, row_y, shown, va="center", ha="center", fontsize=9)
            contained = bool(vals3 and vals1 and set(vals3).issubset(set(vals1)))
            shade = "0.88" if contained else "#f5b7b1"
            ax.axhspan(row_y - 0.5, row_y + 0.5, facecolor=shade, alpha=0.4, zorder=0)
        ax.set_xticks([0.5, 1.5, 2.5], ["S1", "S2", "S3"])
        ax.set_yticks([])
        ax.set_title(METHOD_PRETTY[method], fontsize=11)
        ax.spines["left"].set_visible(False)
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# ----------------------------- MAIN ------------------------------------
|
| 361 |
+
|
| 362 |
+
def main():
    """CLI entry point: load predictions, compute all summaries, write JSON and figures.

    Reads ``--preds_dir``, writes ``containment_summary_v2.json`` plus every
    figure into ``--out_dir`` and prints the aggregate, agreement and
    correctness summaries. ``--example_puzzle`` selects the puzzle shown in
    the sankey-style example figure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--preds_dir", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--example_puzzle", type=int, default=0)
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    preds = load_preds(Path(args.preds_dir))
    common = cells_common(preds)
    print(f"common cells: {len(common)}")

    rows = compute_per_difficulty(preds, common)

    # Collapse the per-bucket rates into one n-weighted rate per method.
    rate_keys = ("c13", "c23", "d13", "d23")
    aggregate = {m: {"c13": 0, "c23": 0, "d13": 0, "d23": 0, "n": 0} for m in METHODS}
    for row in rows:
        totals = aggregate[row["method"]]
        for rate in rate_keys:
            totals[rate] += row[rate] * row["n"]
        totals["n"] += row["n"]
    for m in METHODS:
        totals = aggregate[m]
        denom = max(1, totals["n"])
        for rate in rate_keys:
            totals[rate] = totals[rate] / denom

    correctness = compute_correctness_breakdown(preds, common)
    agreement = compute_method_agreement(preds, common)

    summary = {
        "n_common_cells": len(common),
        "aggregate": aggregate,
        "per_difficulty": rows,
        "correctness_breakdown": correctness,
        "agreement_by_difficulty": {b: agreement[b] for b in DIFF_ORDER},
    }
    with open(out_dir / "containment_summary_v2.json", "w") as fh:
        json.dump(summary, fh, indent=2)

    plot_containment_basic(aggregate, out_dir / "fig_containment_basic")
    by_difficulty_figs = [
        ("c13", "$P(\\hat S_3 \\subseteq \\hat S_1)$", "fig_c13_by_diff"),
        ("c23", "$P(\\hat S_3 \\subseteq \\hat S_2)$", "fig_c23_by_diff"),
        ("d23", "$P(\\hat S_3 \\cap \\hat S_2=\\varnothing)$", "fig_d23_by_diff"),
        ("correct", "Solve rate at S3", "fig_solve_by_diff"),
    ]
    for field, ylabel, stem in by_difficulty_figs:
        plot_containment_by_difficulty(rows, field, ylabel, out_dir / stem)
    plot_set_size_trajectory(rows, out_dir / "fig_set_size_trajectory")
    plot_correctness_breakdown(correctness, out_dir / "fig_correctness_breakdown")
    plot_method_agreement(agreement, out_dir / "fig_method_agreement")
    plot_sankey(preds, out_dir / "fig_sankey_example", puzzle_id=args.example_puzzle)

    print(json.dumps(summary["aggregate"], indent=2))
    print("agreement_by_difficulty:")
    for bucket in DIFF_ORDER:
        d = agreement[bucket]
        if not d["n"]:
            continue  # empty buckets are not printed
        print(f" {bucket}: n={d['n']} agree={d['agree']/d['n']:.3f} "
              f"atc_correct={d['atc_correct']/d['n']:.3f} dc_correct={d['dc_correct']/d['n']:.3f}")
    print("correctness_breakdown:")
    print(json.dumps(correctness, indent=2))
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
if __name__ == "__main__":
|
| 425 |
+
main()
|
_experiments/cross_stage/overnight_pipeline.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Overnight orchestrator:
|
| 3 |
+
# 1) wait for phase-1 cross-prompt jobs to finish (already launched)
|
| 4 |
+
# 2) launch phase-2 cross-prompt sweep
|
| 5 |
+
# 3) wait for phase-2 to finish
|
| 6 |
+
# 4) run analyze_cross_prompt.py to produce all plots
|
| 7 |
+
# 5) print a summary
|
| 8 |
+
|
| 9 |
+
# Abort the whole pipeline as soon as any command fails.
set -e

# Checkout that holds the experiment scripts invoked below.
REPO=/home/ubuntu/curriculum-cot-code
# Directory polled by wait_for_tags for per-job logs named "<tag>.log".
LOG_DIR=/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt
# Output directory for the cross-prompt figures.
FIGS_DIR=/home/ubuntu/curriculum_cot/_experiments/cross_stage/figs_xprompt
# Python interpreter used to run the analysis scripts.
PY=/opt/pytorch/bin/python

mkdir -p "$FIGS_DIR"

# Job tags of the phase-1 sweep; each tag names a log file under LOG_DIR.
PHASE1_TAGS=(atc_train3_prompt1 atc_train3_prompt2 dc_train3_prompt1 dc_train3_prompt2 atc_train2_prompt3)
# NOTE(review): PHASE2_TAGS is never referenced in this script. Step 3 of the
# header ("wait for phase-2 to finish") therefore relies on
# run_cross_prompt_phase2.sh blocking until its jobs are done — confirm.
PHASE2_TAGS=(atc_train1_prompt2 atc_train1_prompt3 atc_train2_prompt1 dc_train1_prompt2 dc_train1_prompt3 dc_train2_prompt1 dc_train2_prompt3)
|
| 20 |
+
|
| 21 |
+
# wait_for_tags ARRAY_NAME
# Block until every tag in the named array has a log "$LOG_DIR/<tag>.log"
# containing the marker "DONE cells=". Polls every 90 seconds and prints the
# progress each round. There is no timeout: a job that never writes the
# marker keeps this loop running forever.
wait_for_tags() {
  local -n tag_list=$1          # nameref to the caller's array
  local total=${#tag_list[@]}
  while :; do
    local finished=0
    for tag in "${tag_list[@]}"; do
      # Missing or unreadable logs simply count as "not finished yet".
      if grep -q "DONE cells=" "$LOG_DIR/$tag.log" 2>/dev/null; then
        finished=$((finished+1))
      fi
    done
    echo "[$(date +%T)] $finished / $total done"
    if [ "$finished" -ge "$total" ]; then
      return 0
    fi
    sleep 90
  done
}
|
| 38 |
+
|
| 39 |
+
# Step 1: block until every phase-1 log reports completion.
echo "[$(date +%T)] waiting for phase-1 cross-prompt jobs..."
wait_for_tags PHASE1_TAGS
echo "[$(date +%T)] phase 1 complete; launching phase 2"

# Steps 2-3: run the phase-2 sweep. No explicit wait follows, so the analyses
# below start as soon as this script returns.
bash "$REPO/_experiments/cross_stage/run_cross_prompt_phase2.sh"

echo "[$(date +%T)] phase 2 complete; running analyses"

# Step 4a: containment analyses over the diagonal (train stage == prompt stage) predictions.
"$PY" "$REPO/_experiments/cross_stage/analyze_v2.py" \
  --preds_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds \
  --out_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/figs \
  --example_puzzle 0

# Step 4b: cross-prompt analysis combining the diagonal and cross-prompt predictions.
"$PY" "$REPO/_experiments/cross_stage/analyze_cross_prompt.py" \
  --diag_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds \
  --xprompt_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt \
  --out_dir "$FIGS_DIR"

# Step 5: summary — list the generated cross-prompt figures.
echo "[$(date +%T)] done. Figures in: $FIGS_DIR"
ls -la "$FIGS_DIR"
|
_experiments/cross_stage/predict_one.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dump per-cell predictions for one (method, stage) checkpoint on a fixed eval set.
|
| 2 |
+
|
| 3 |
+
For each empty cell of each puzzle in the eval JSONL, runs the given adapter and
|
| 4 |
+
writes a JSON line with:
|
| 5 |
+
method_tag : free-form id e.g. "atc_s1"
|
| 6 |
+
puzzle_id : 0-based row index
|
| 7 |
+
target_cell : [r, c] (0-based, matches `ex.target_cell`)
|
| 8 |
+
target_solution : the unique true value at this cell
|
| 9 |
+
stage_prompted : the stage_i argument passed to the prompt builder
|
| 10 |
+
predicted_values : sorted list of ints in [1,9] parsed from model output
|
| 11 |
+
parse_ok / exact_set_match : booleans from score_prediction_text
|
| 12 |
+
target_S1 / S2 / S3 : the stage-1/2/3 consistent value sets for this cell
|
| 13 |
+
(computed independently of the model so the
|
| 14 |
+
post-processing script can compare across stages)
|
| 15 |
+
|
| 16 |
+
For the latent (recurrent-hidden) checkpoints set `--latent_mode recurrent_hidden`
|
| 17 |
+
and `--num_cot_tokens` to whatever value the model was trained at.
|
| 18 |
+
For vanilla baseline checkpoints leave both at their defaults.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import time
|
| 28 |
+
from pathlib import Path
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 32 |
+
|
| 33 |
+
REPO = Path(__file__).resolve().parents[2]
|
| 34 |
+
if str(REPO) not in sys.path:
|
| 35 |
+
sys.path.insert(0, str(REPO))
|
| 36 |
+
|
| 37 |
+
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row
|
| 38 |
+
from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt
|
| 39 |
+
from multi_output_cell_policy.rewards import score_prediction_text
|
| 40 |
+
from multi_output_cell_policy.shared_multi_output_policy import (
|
| 41 |
+
make_solved_grid_from_row,
|
| 42 |
+
stage_i_consistent_values,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def parse_args():
    """Parse the command-line options of the per-cell prediction dump.

    Required: ``--method_tag``, ``--adapter_dir``, ``--eval_jsonl``,
    ``--stage_i`` and ``--out_jsonl``. All other options have defaults; the
    Hugging Face cache defaults to ``.hf_cache`` under the repository root.
    """
    latent_choices = ["none", "recurrent_hidden", "fixed_slots", "latent_seeds", "residual"]
    # Declared as a table so the registration order (and thus --help order) stays explicit.
    option_table = [
        ("--method_tag", {"required": True}),
        ("--adapter_dir", {"required": True}),
        ("--eval_jsonl", {"required": True}),
        ("--eval_rows", {"type": int, "default": 100}),
        ("--stage_i", {"type": int, "required": True}),
        ("--total_empties_hint", {"type": int, "default": 20}),
        ("--latent_mode", {"default": "none", "choices": latent_choices}),
        ("--num_cot_tokens", {"type": int, "default": 0}),
        ("--model_name", {"default": "Qwen/Qwen2.5-1.5B-Instruct"}),
        ("--cache_dir", {"default": str(REPO / ".hf_cache")}),
        ("--gpu_id", {"type": int, "default": 0}),
        ("--max_completion_length", {"type": int, "default": 24}),
        ("--out_jsonl", {"required": True}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_jsonl(path: str, limit: int):
    """Read at most *limit* JSON objects from the JSONL file at *path*.

    Blank and whitespace-only lines are skipped and do not count towards the
    limit. A *limit* of zero or less returns an empty list; previously the
    limit was only checked after appending, so one row was always returned.

    The file is read as UTF-8 so the result does not depend on the platform's
    default encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    out = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Check before parsing so limit <= 0 yields nothing.
            if len(out) >= limit:
                break
            line = line.strip()
            if not line:
                continue
            out.append(json.loads(line))
    return out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def main():
    """Run one prediction sweep for a single (adapter, prompt stage) pair.

    Loads the base causal LM and the adapter in ``args.adapter_dir``, builds a
    stage-``args.stage_i`` prompt for every cell example of the first
    ``args.eval_rows`` puzzles in ``args.eval_jsonl``, decodes a completion
    with ``do_sample=False``, scores it, and writes one JSON record per cell
    to ``args.out_jsonl``. Each record also carries the stage-1/2/3 target
    sets for that cell, so a prediction made under one prompt stage can be
    compared against the targets of every stage.

    Raises:
        SystemExit: If ``--latent_mode`` is a latent mode other than
            ``recurrent_hidden`` (the only latent sampler wired up here).
    """
    args = parse_args()

    is_latent = args.latent_mode != "none"
    # Fail fast: reject unsupported latent modes before paying for the
    # base-model download/load and the adapter load.
    if is_latent and args.latent_mode != "recurrent_hidden":
        raise SystemExit(f"Only recurrent_hidden latent_mode is wired up here; got {args.latent_mode!r}")

    os.makedirs(os.path.dirname(args.out_jsonl) or ".", exist_ok=True)
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"

    base = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=args.cache_dir,
        torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
    )

    if is_latent:
        from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import (
            load_trainable_adapter,
            sample_recurrent_hidden_completion,
        )
        # NOTE(review): LoRA hyper-parameters are hard-coded; presumably they
        # must match the values the adapter was trained with -- confirm.
        model = load_trainable_adapter(
            base, args.adapter_dir,
            lora_r=32, lora_alpha=64, lora_dropout=0.05,
        )
        sample_fn = sample_recurrent_hidden_completion
    else:
        from peft import PeftModel
        model = PeftModel.from_pretrained(base, args.adapter_dir, is_trainable=False)
        sample_fn = None

    if hasattr(model, "config"):
        model.config.use_cache = True
    model.to(device).eval()

    rows = load_jsonl(args.eval_jsonl, args.eval_rows)
    # Loop-invariant: always request at least one new token.
    max_new_tokens = max(1, int(args.max_completion_length))

    t0 = time.time()
    n_cells = 0
    with open(args.out_jsonl, "w", encoding="utf-8") as fout:
        for puzzle_id, row in enumerate(rows):
            solved = make_solved_grid_from_row(row)
            for ex in build_cell_examples_from_row(row):
                prompt = build_multi_output_cell_prompt(
                    ex.grid,
                    target_cell=ex.target_cell,
                    stage_i=args.stage_i,
                    tokenizer=tokenizer,
                    turn_idx=ex.turn_idx,
                    total_turns=ex.total_turns,
                    prev_output_flag=None,
                    total_empties_hint=args.total_empties_hint,
                )
                enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
                input_ids = enc["input_ids"].to(device)
                attn = enc["attention_mask"].to(device)

                with torch.no_grad():
                    if is_latent:
                        completion_ids = sample_fn(
                            model, tokenizer, input_ids, attn,
                            num_cot_tokens=int(args.num_cot_tokens),
                            max_new_tokens=max_new_tokens,
                            do_sample=False,
                        )
                        pred_text = tokenizer.decode(
                            completion_ids[0], skip_special_tokens=True
                        ).strip()
                    else:
                        out = model.generate(
                            input_ids=input_ids,
                            attention_mask=attn,
                            max_new_tokens=max_new_tokens,
                            do_sample=False,
                            eos_token_id=tokenizer.eos_token_id,
                            pad_token_id=tokenizer.pad_token_id,
                        )
                        # generate() output starts with the prompt tokens;
                        # keep only the newly generated tail.
                        pred_text = tokenizer.decode(
                            out[0][input_ids.shape[1]:], skip_special_tokens=True
                        ).strip()

                info = score_prediction_text(
                    text=pred_text,
                    grid=ex.grid,
                    solved=solved,
                    target_cell=ex.target_cell,
                    stage_i=args.stage_i,
                    reward_good_value=1.0,
                    penalty_bad_value=1.0,
                    penalty_malformed=4.0,
                    penalty_empty=0.5,
                    penalty_singleton=1.5,
                )

                # Target sets for all three stages, independent of the stage
                # that was actually prompted.
                stage_targets = {
                    f"target_S{stage}": sorted(
                        int(v)
                        for v in stage_i_consistent_values(
                            ex.grid, target_cell=ex.target_cell, stage_i=stage
                        )
                    )
                    for stage in (1, 2, 3)
                }

                pred_values_raw = info.get("predicted_values") or []
                predicted_values = sorted(int(v) for v in pred_values_raw if isinstance(v, (int, float)))

                rec = {
                    "method_tag": args.method_tag,
                    "puzzle_id": int(puzzle_id),
                    "target_cell": [int(ex.target_cell[0]), int(ex.target_cell[1])],
                    "target_solution": int(ex.target_value),
                    "stage_prompted": int(args.stage_i),
                    "predicted_values": predicted_values,
                    "predicted_text": pred_text,
                    "parse_ok": bool(info["parse_ok"]),
                    "exact_set_match": bool(info["exact_set_match"]),
                    **stage_targets,
                }
                fout.write(json.dumps(rec) + "\n")
                n_cells += 1
            # Progress line every 10 puzzles; flushed so it shows up promptly
            # in the redirected log file.
            if (puzzle_id + 1) % 10 == 0:
                print(
                    f"[{args.method_tag}] puzzle {puzzle_id+1}/{len(rows)} "
                    f"cells={n_cells} elapsed={time.time()-t0:.0f}s",
                    flush=True,
                )

    # The launcher/watcher scripts grep the log for "DONE cells=" to detect
    # completion, so keep this line's format stable.
    print(f"[{args.method_tag}] DONE cells={n_cells} elapsed={time.time()-t0:.0f}s out={args.out_jsonl}")
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
|
| 211 |
+
main()
|
_experiments/cross_stage/run_all.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run the 6-way cross-stage prediction sweep in parallel on GPUs 0-5.
# Each job writes $OUT_DIR/<tag>.jsonl (default:
# /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds/<tag>.jsonl)
# and logs to $LOG_DIR/<tag>.log.
#
# Environment overrides: EVAL_ROWS, OUT_DIR, LOG_DIR.
# Exit status: 0 if every job succeeded, 1 if at least one job failed.
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-100}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

# (tag, gpu, adapter_dir, stage_i, latent_mode, num_cot)
declare -a JOBS=(
    "atc_s1|0|/home/ubuntu/hf_checkpoints/latent_stages/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden|1|recurrent_hidden|1"
    "atc_s2|1|/home/ubuntu/hf_checkpoints/latent_stages/grpo/N3_from_main_step800/checkpoint-200|2|recurrent_hidden|3"
    "atc_s3|2|/home/ubuntu/hf_checkpoints/latent_stages/rebuttal_champion_100p/s3_grpo_baseline_checkpoint-200|3|recurrent_hidden|3"
    "dc_s1|3|/home/ubuntu/hf_checkpoints/baseline/baseline_lr1e4/s1_grpo_v2|1|none|0"
    "dc_s2|4|/home/ubuntu/hf_checkpoints/baseline/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000|2|none|0"
    "dc_s3|5|/home/ubuntu/hf_checkpoints/baseline/v6_i_sft_v_oversample10/s3_sft/checkpoint-step-00200|3|none|0"
)

PIDS=()
for entry in "${JOBS[@]}"; do
    IFS='|' read -r tag gpu adapter stage_i mode cot <<< "$entry"
    echo "[$(date +%T)] launching $tag on GPU $gpu (stage_i=$stage_i mode=$mode cot=$cot)"
    # CUDA_VISIBLE_DEVICES exposes exactly one GPU to the job, so inside the
    # process that GPU is always device index 0 (hence --gpu_id 0).
    CUDA_VISIBLE_DEVICES="$gpu" "$PY" "$SCRIPT" \
        --method_tag "$tag" \
        --adapter_dir "$adapter" \
        --eval_jsonl "$EVAL" \
        --eval_rows "$EVAL_ROWS" \
        --stage_i "$stage_i" \
        --latent_mode "$mode" \
        --num_cot_tokens "$cot" \
        --gpu_id 0 \
        --out_jsonl "$OUT_DIR/$tag.jsonl" \
        > "$LOG_DIR/$tag.log" 2>&1 &
    PIDS+=("$!")
done

# Report the real number of launched jobs instead of a hard-coded count.
echo "Launched ${#PIDS[@]} jobs with PIDs: ${PIDS[*]}"
echo "Logs:    $LOG_DIR"
echo "Outputs: $OUT_DIR"
echo
echo "Waiting for all to finish..."
fail=0
for pid in "${PIDS[@]}"; do
    # `wait` inside an `if` condition does not trip `set -e` on failure.
    if wait "$pid"; then
        echo "  pid $pid OK"
    else
        echo "  pid $pid FAILED"
        fail=$((fail + 1))
    fi
done
echo "Done. $fail failures."
# Propagate failure so callers can tell a broken sweep from a clean one.
if [ "$fail" -gt 0 ]; then
    exit 1
fi
|
_experiments/cross_stage/run_cross_prompt.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Cross-prompt sweep (phase 1): run checkpoints under an OFF-DIAGONAL stage_i
# prompt, i.e. a prompt stage different from the stage they were trained on.
#
# Rationale: predict_one.py loads a (method, train_stage) adapter and can
# prompt it with any stage_i. If latent CoT preserves cross-stage
# information, the latent S3 adapter should still be able to enumerate the S1
# candidate set when asked, while the data-curriculum S3 adapter has
# "overwritten" that capability.
#
# The job table below holds five jobs, one per GPU on GPUs 1..5; GPU 0 is not
# used by this script. The jobs are started in the background and the script
# returns without waiting for them.
#
# Output:
#   /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt/<tag>.jsonl
#   /home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt/<tag>.log
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-200}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

# Same adapters as the diagonal sweep.
ATC_S1=/home/ubuntu/hf_checkpoints/latent_stages/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
ATC_S2=/home/ubuntu/hf_checkpoints/latent_stages/grpo/N3_from_main_step800/checkpoint-200
ATC_S3=/home/ubuntu/hf_checkpoints/latent_stages/rebuttal_champion_100p/s3_grpo_baseline_checkpoint-200
DC_S1=/home/ubuntu/hf_checkpoints/baseline/baseline_lr1e4/s1_grpo_v2
DC_S2=/home/ubuntu/hf_checkpoints/baseline/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000
DC_S3=/home/ubuntu/hf_checkpoints/baseline/v6_i_sft_v_oversample10/s3_sft/checkpoint-step-00200

# Job spec: tag | gpu | adapter_dir | prompt_stage_i | latent_mode | num_cot
# The tag encodes the (train_stage, prompt_stage) pair so the analyze script
# can pick the files up automatically.
declare -a JOBS=(
    # Stage-3 models asked for Stage-1 / Stage-2 enumeration.
    "atc_train3_prompt1|1|$ATC_S3|1|recurrent_hidden|3"
    "atc_train3_prompt2|2|$ATC_S3|2|recurrent_hidden|3"
    "dc_train3_prompt1|3|$DC_S3|1|none|0"
    "dc_train3_prompt2|4|$DC_S3|2|none|0"
    # Stage-2 latent model asked to commit (Stage-3 prompt).
    "atc_train2_prompt3|5|$ATC_S2|3|recurrent_hidden|3"
)

PIDS=()
for job_idx in "${!JOBS[@]}"; do
    IFS='|' read -r tag gpu adapter stage_i mode cot <<< "${JOBS[$job_idx]}"
    echo "[$(date +%T)] launching $tag on GPU $gpu (prompt stage_i=$stage_i, mode=$mode, cot=$cot)"
    # Build the argv first, then launch it with a single GPU visible; that
    # GPU is device 0 inside the job, hence --gpu_id 0.
    predict_cmd=(
        "$PY" "$SCRIPT"
        --method_tag "$tag"
        --adapter_dir "$adapter"
        --eval_jsonl "$EVAL"
        --eval_rows "$EVAL_ROWS"
        --stage_i "$stage_i"
        --latent_mode "$mode"
        --num_cot_tokens "$cot"
        --gpu_id 0
        --out_jsonl "$OUT_DIR/$tag.jsonl"
    )
    CUDA_VISIBLE_DEVICES="$gpu" "${predict_cmd[@]}" > "$LOG_DIR/$tag.log" 2>&1 &
    PIDS+=("$!")
done

echo "Launched ${#PIDS[@]} cross-prompt jobs: ${PIDS[*]}"
echo "Logs:    $LOG_DIR"
echo "Outputs: $OUT_DIR"
|
_experiments/cross_stage/run_cross_prompt_phase2.sh
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Phase-2 cross-prompt sweep: the remaining off-diagonals.
#
# Phase 1 covered:
#   atc_train3_prompt1, atc_train3_prompt2,
#   dc_train3_prompt1,  dc_train3_prompt2,
#   atc_train2_prompt3.
#
# Phase 2 fills in (five jobs in parallel, GPUs 1..5):
#   atc_train1_prompt2, atc_train1_prompt3,
#   atc_train2_prompt1,
#   dc_train1_prompt2,  dc_train1_prompt3.
#
# Phase 3 runs after phase 2 has finished (two jobs, GPUs 1-2):
#   dc_train2_prompt1,  dc_train2_prompt3.
#
# NOTE(review): an earlier version of this header described 7 jobs spread
# over GPUs 0-5 + 7; the job tables below only ever use GPUs 1..5.
#
# A job is skipped when its output file exists and its log already contains
# the "DONE cells=" marker, so the script can be re-run after a partial sweep.
#
# Exit status: 0 if every launched job succeeded, 1 otherwise.
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-100}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

ATC_S1=/home/ubuntu/hf_checkpoints/latent_stages/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
ATC_S2=/home/ubuntu/hf_checkpoints/latent_stages/grpo/N3_from_main_step800/checkpoint-200
DC_S1=/home/ubuntu/hf_checkpoints/baseline/baseline_lr1e4/s1_grpo_v2
DC_S2=/home/ubuntu/hf_checkpoints/baseline/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000

FAILED=0          # number of jobs that exited non-zero, across all phases
LAUNCHED_PIDS=()  # PIDs started by the most recent launch_jobs call

# Launch one background predict_one.py job per spec given as argument.
# Spec format: tag|gpu|adapter_dir|prompt_stage_i|latent_mode|num_cot
# Already-finished jobs are skipped. Fills LAUNCHED_PIDS.
launch_jobs() {
    local entry tag gpu adapter stage_i mode cot
    LAUNCHED_PIDS=()
    for entry in "$@"; do
        IFS='|' read -r tag gpu adapter stage_i mode cot <<< "$entry"
        if [ -f "$OUT_DIR/$tag.jsonl" ] && grep -q "DONE cells=" "$LOG_DIR/$tag.log" 2>/dev/null; then
            echo "[$(date +%T)] $tag already done, skip"
            continue
        fi
        echo "[$(date +%T)] launching $tag on GPU $gpu (prompt stage_i=$stage_i, mode=$mode, cot=$cot)"
        # Only one GPU is visible to the job, so it is device 0 in-process.
        CUDA_VISIBLE_DEVICES="$gpu" "$PY" "$SCRIPT" \
            --method_tag "$tag" \
            --adapter_dir "$adapter" \
            --eval_jsonl "$EVAL" \
            --eval_rows "$EVAL_ROWS" \
            --stage_i "$stage_i" \
            --latent_mode "$mode" \
            --num_cot_tokens "$cot" \
            --gpu_id 0 \
            --out_jsonl "$OUT_DIR/$tag.jsonl" \
            > "$LOG_DIR/$tag.log" 2>&1 &
        LAUNCHED_PIDS+=("$!")
    done
}

# Wait for every PID in LAUNCHED_PIDS and count the ones that failed.
# A bare `wait` always returns 0, which would hide crashed jobs.
wait_for_jobs() {
    local pid
    for pid in "${LAUNCHED_PIDS[@]}"; do
        if ! wait "$pid"; then
            echo "  pid $pid FAILED"
            FAILED=$((FAILED + 1))
        fi
    done
}

declare -a JOBS=(
    "atc_train1_prompt2|1|$ATC_S1|2|recurrent_hidden|1"
    "atc_train1_prompt3|2|$ATC_S1|3|recurrent_hidden|1"
    "atc_train2_prompt1|3|$ATC_S2|1|recurrent_hidden|3"
    "dc_train1_prompt2|4|$DC_S1|2|none|0"
    "dc_train1_prompt3|5|$DC_S1|3|none|0"
)

launch_jobs "${JOBS[@]}"
PIDS=("${LAUNCHED_PIDS[@]}")
wait_for_jobs
echo "Phase 2 complete: ${#PIDS[@]} jobs"

# Phase 3: 2 more dc_train2 jobs on GPUs 1-2
declare -a JOBS3=(
    "dc_train2_prompt1|1|$DC_S2|1|none|0"
    "dc_train2_prompt3|2|$DC_S2|3|none|0"
)

launch_jobs "${JOBS3[@]}"
wait_for_jobs
echo "All cross-prompt sweeps done."

if [ "$FAILED" -gt 0 ]; then
    echo "$FAILED job(s) failed; see logs in $LOG_DIR" >&2
    exit 1
fi
|
_experiments/cross_stage/run_nocurr_cot.sh
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Long "No Curriculum + Latent CoT" SFT runs.
#
# Each variant trains the latent (recurrent-hidden) model directly on the
# Stage-3 target labels (--stage_i 3) with no curriculum --- only the number
# of latent CoT tokens (`num_cot_tokens`), the learning rate and the
# multi-value oversample factor vary between variants. This is the
# "no curriculum but with latent thoughts" cell of the factorial.
#
# Variants are warm-started from the well-trained k=0 checkpoint from the
# previous adaptive-k sweep, then trained for many SFT steps so the model
# has time to make use of the extra latent capacity. Each variant uses
# ALL 10000 training rows.
#
# Jobs are started in the background with nohup; the script does not wait
# for them. The last line printed is the run's output root directory.
#
# Usage:
#   bash run_nocurr_cot.sh "<GPU,VARIANT_TAG,NUM_COT,LR,OVERSAMPLE>" ...
# Example:
#   bash run_nocurr_cot.sh \
#       "6,nocurr_cot_k2_lr2e5_o5,2,2e-5,5" \
#       "7,nocurr_cot_k3_lr2e5_o5,3,2e-5,5"
#
# Exit status: 2 on a usage error (no specs, or a malformed spec).
set -e

# Validate everything up front so a typo cannot leave behind an empty
# timestamped run directory or a half-launched set of variants.
if [ "$#" -eq 0 ]; then
    echo "usage: $0 \"<GPU,VARIANT_TAG,NUM_COT,LR,OVERSAMPLE>\" ..." >&2
    exit 2
fi
for spec in "$@"; do
    IFS=',' read -r gpu tag cot lr oversample extra <<< "$spec"
    if [ -z "$gpu" ] || [ -z "$tag" ] || [ -z "$cot" ] || [ -z "$lr" ] \
        || [ -z "$oversample" ] || [ -n "$extra" ]; then
        echo "bad spec '$spec': expected GPU,VARIANT_TAG,NUM_COT,LR,OVERSAMPLE" >&2
        exit 2
    fi
done

REPO=/home/ubuntu/curriculum-cot-code
SFT_PY="$REPO/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
TRAIN=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
INIT_ADAPTER=/home/ubuntu/hf_checkpoints/adaptive_k/20260525_024629/adaptive_a_eps01/sft_phase02_k0/checkpoint-step-00600
OUT_ROOT=/home/ubuntu/curriculum_cot/_runs/nocurr_cot_$(date +%Y%m%d_%H%M%S)
mkdir -p "$OUT_ROOT"
echo "OUT_ROOT=$OUT_ROOT"

PY=/opt/pytorch/bin/python
PIDS=()
for spec in "$@"; do
    IFS=',' read -r gpu tag cot lr oversample <<< "$spec"
    out="$OUT_ROOT/$tag"
    mkdir -p "$out"
    echo "[$(date +%T)] launching $tag on GPU $gpu (num_cot=$cot lr=$lr oversample=$oversample)"
    # One GPU is visible per job, so it is device 0 in-process (--gpu_id 0).
    # All eval-based stop thresholds are 0 and min_steps_before_stop is far
    # above max_steps, so each run goes the full --max_steps 3000.
    # NOTE(review): the reading of the stop flags is inferred from their names.
    CUDA_VISIBLE_DEVICES="$gpu" nohup "$PY" -u "$SFT_PY" \
        --model_name Qwen/Qwen2.5-1.5B-Instruct \
        --train_jsonl "$TRAIN" \
        --eval_jsonl "$EVAL" \
        --output_dir "$out" \
        --cache_dir "$REPO/.hf_cache" \
        --init_adapter_dir "$INIT_ADAPTER" \
        --seed 0 \
        --gpu_id 0 \
        --stage_i 3 \
        --num_cot_tokens "$cot" \
        --latent_mode recurrent_hidden \
        --total_empties_hint 20 \
        --per_device_train_batch_size 8 \
        --gradient_accumulation_steps 4 \
        --num_epochs 256 \
        --learning_rate "$lr" \
        --max_grad_norm 1.0 \
        --logging_steps 25 \
        --eval_steps 250 \
        --save_steps 250 \
        --eval_rows 100 \
        --max_completion_length 24 \
        --limit_train_rows 10000 \
        --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
        --multi_value_oversample_factor "$oversample" \
        --train_target_size_min 0 --train_target_size_max 0 \
        --eval_value_precision_stop 0 \
        --eval_value_recall_stop 0 \
        --eval_exact_set_match_stop 0 \
        --eval_solve_rate_stop 0 \
        --min_steps_before_stop 1000000 \
        --max_wall_clock_seconds 0 \
        --max_steps 3000 \
        --enable_gradient_checkpointing \
        > "$out/train.log" 2>&1 &
    PIDS+=("$!")
done

echo "Launched ${#PIDS[@]} jobs: ${PIDS[*]}"
echo "$OUT_ROOT"
|
_experiments/cross_stage/watcher_launch_more.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Wait for the 6 cross-stage prediction jobs to finish, then launch 6 more
# long no-curr+CoT variants on the freed GPUs (0..5).
#
# A job counts as finished once its log contains the "DONE cells=" marker
# that predict_one.py prints at the very end.
#
# Environment overrides:
#   LOG_DIR           directory holding <tag>.log (same default as run_all.sh)
#   POLL_SECONDS      seconds between polls (default 60)
#   MAX_WAIT_SECONDS  give up after this many seconds; 0 (default) waits
#                     forever. Set it to avoid hanging when a job crashed
#                     and will never print the marker.
#
# Run with nohup.

LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs}
SCRIPT=/home/ubuntu/curriculum-cot-code/_experiments/cross_stage/run_nocurr_cot.sh
POLL_SECONDS=${POLL_SECONDS:-60}
MAX_WAIT_SECONDS=${MAX_WAIT_SECONDS:-0}

# Tags of the diagonal sweep launched by run_all.sh.
TAGS=(atc_s1 atc_s2 atc_s3 dc_s1 dc_s2 dc_s3)

echo "[$(date +%T)] waiting for the ${#TAGS[@]} cross-stage jobs to finish..."
SECONDS=0  # bash built-in: counts seconds elapsed since this assignment
while true; do
    done_count=0
    for tag in "${TAGS[@]}"; do
        if grep -q "DONE cells=" "$LOG_DIR/$tag.log" 2>/dev/null; then
            done_count=$((done_count + 1))
        fi
    done
    if [ "$done_count" -ge "${#TAGS[@]}" ]; then
        break
    fi
    if [ "$MAX_WAIT_SECONDS" -gt 0 ] && [ "$SECONDS" -ge "$MAX_WAIT_SECONDS" ]; then
        echo "[$(date +%T)] timed out after ${SECONDS}s with $done_count/${#TAGS[@]} jobs done; not launching" >&2
        exit 1
    fi
    sleep "$POLL_SECONDS"
done
echo "[$(date +%T)] cross-stage done; launching 6 more no-curr+CoT variants on GPUs 0..5"

# Spec format: GPU,VARIANT_TAG,NUM_COT,LR,OVERSAMPLE
bash "$SCRIPT" \
    "0,nocurr_cot_k1_lr2e5_o5,1,2e-5,5" \
    "1,nocurr_cot_k2_lr1e5_o5,2,1e-5,5" \
    "2,nocurr_cot_k3_lr1e5_o5,3,1e-5,5" \
    "3,nocurr_cot_k2_lr2e5_o10,2,2e-5,10" \
    "4,nocurr_cot_k3_lr2e5_o10,3,2e-5,10" \
    "5,nocurr_cot_k3_lr5e5_o5,3,5e-5,5"
echo "[$(date +%T)] 6 more launched."
|
_runs/_paper_figures/plot_stage_progression.py
CHANGED
|
@@ -1,5 +1,12 @@
|
|
| 1 |
-
"""Paper-style figures: Solve rate
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
from __future__ import annotations
|
| 5 |
|
|
@@ -13,11 +20,40 @@ OUT_DIR = Path(__file__).resolve().parent
|
|
| 13 |
# -------------------------------- DATA ---------------------------------------
|
| 14 |
STAGES = ["Stage 1", "Stage 2", "Stage 3"]
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# -------------------------------- STYLE --------------------------------------
|
| 23 |
mpl.rcParams.update({
|
|
@@ -41,30 +77,36 @@ mpl.rcParams.update({
|
|
| 41 |
"ps.fonttype": 42,
|
| 42 |
})
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
GRID_KW = dict(linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
|
| 47 |
x = list(range(len(STAGES)))
|
| 48 |
|
| 49 |
|
| 50 |
-
def _plot(
|
| 51 |
fig, ax = plt.subplots(figsize=(4.6, 3.4), constrained_layout=True)
|
| 52 |
ax.plot(
|
| 53 |
-
x,
|
| 54 |
-
color=
|
| 55 |
-
label="
|
| 56 |
)
|
| 57 |
ax.plot(
|
| 58 |
-
x,
|
| 59 |
-
color=
|
| 60 |
-
label="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
ax.set_xticks(x, STAGES)
|
| 63 |
ax.set_ylim(*ylim)
|
| 64 |
ax.set_yticks(yticks)
|
| 65 |
ax.set_ylabel(ylabel)
|
| 66 |
ax.grid(True, axis="y", **GRID_KW)
|
| 67 |
-
ax.legend(frameon=False, loc=
|
| 68 |
fig.savefig(OUT_DIR / f"{fname}.pdf", bbox_inches="tight")
|
| 69 |
fig.savefig(OUT_DIR / f"{fname}.png", dpi=300, bbox_inches="tight")
|
| 70 |
plt.close(fig)
|
|
@@ -72,16 +114,34 @@ def _plot(y_latent, y_baseline, ylim, yticks, ylabel, fname):
|
|
| 72 |
|
| 73 |
|
| 74 |
_plot(
|
| 75 |
-
|
| 76 |
ylim=(0.0, 1.0),
|
| 77 |
yticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
|
| 78 |
ylabel="Solve rate",
|
| 79 |
fname="stage_progression_solve",
|
|
|
|
| 80 |
)
|
| 81 |
_plot(
|
| 82 |
-
|
| 83 |
-
ylim=(0.
|
| 84 |
-
yticks=[0.80, 0.84, 0.88, 0.92, 0.96, 1.00],
|
| 85 |
ylabel="Per-cell set-match rate",
|
| 86 |
fname="stage_progression_exact",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
)
|
|
|
|
| 1 |
+
"""Paper-style figures: Solve rate / Per-cell exact / Value precision / Value recall
|
| 2 |
+
across the three curriculum stages. Four separate figures, no titles, no footer
|
| 3 |
+
text — only axes, lines, markers, legend.
|
| 4 |
+
|
| 5 |
+
Three series in every figure:
|
| 6 |
+
ATC — latent recurrent-hidden, stage-curriculum (S1 / S2 / S3)
|
| 7 |
+
Data Curriculum — vanilla 1.5B, stage-curriculum (S1 / S2 / S3)
|
| 8 |
+
No CoT, No Curr. — vanilla 1.5B trained on the Stage-3 task only,
|
| 9 |
+
no thought tokens, no curriculum (horizontal reference)
|
| 10 |
"""
|
| 11 |
from __future__ import annotations
|
| 12 |
|
|
|
|
| 20 |
# -------------------------------- DATA ---------------------------------------
|
| 21 |
STAGES = ["Stage 1", "Stage 2", "Stage 3"]
|
| 22 |
|
| 23 |
+
# Literal metric values plotted below. Every ATC_* / DC_* list holds one entry
# per label in STAGES (Stage 1, Stage 2, Stage 3); every NOCURR_* scalar is
# the single "No CoT, No Curriculum" reference value for that metric.

# Solve rate
ATC_SOLVE = [0.70, 0.50, 0.58]
DC_SOLVE = [0.78, 0.40, 0.44]
NOCURR_SOLVE = 0.33

# Per-cell exact set-match
ATC_EXACT = [0.95, 0.958, 0.967]
DC_EXACT = [0.988, 0.88, 0.83]
NOCURR_EXACT = 0.80

# Value precision
# ATC S1: approximated (Stage-1 latent SFT/GRPO log only stores reward;
# Stage-1 GRPO converged with solve≈0.95 on 40p eval → prec~0.96).
# ATC S2: STAGE12_TRAJECTORY.md, step 2600 (best per-cell): prec=0.960.
# ATC S3: headtohead_s3/s3_grpo_baseline step200: prec=0.967.
# DC S1: baseline_lr1e4/s1_grpo_v2 (solve 0.78): prec=0.996.
# DC S2: baseline_lr5e5_lowsft_v3/s2_sft_v3 step 3000: prec=0.911.
# DC S3: baseline v6_i_sft_v_oversample10/s3_sft step 200: prec=0.955.
# NoCurr: strawman_warm_e (lr=1e-5, oversample=5) SFT-end: prec=0.945.
ATC_PREC = [0.96, 0.960, 0.967]
DC_PREC = [0.996, 0.911, 0.955]
NOCURR_PREC = 0.945

# Value recall
# ATC S1: approximated (see prec note); rec~0.96.
# ATC S2: STAGE12_TRAJECTORY.md, step 2600: rec=0.949.
# ATC S3: headtohead_s3/s3_grpo_baseline step200: rec=0.968.
# DC S1: baseline_lr1e4/s1_grpo_v2: rec=0.998.
# DC S2: baseline_lr5e5_lowsft_v3/s2_sft_v3 step 3000: rec=0.931.
# DC S3: baseline v6_i_sft_v_oversample10/s3_sft step 200: rec=0.954.
# NoCurr: strawman_warm_e SFT-end: rec=0.944.
ATC_REC = [0.96, 0.949, 0.968]
DC_REC = [0.998, 0.931, 0.954]
NOCURR_REC = 0.944
|
| 57 |
|
| 58 |
# -------------------------------- STYLE --------------------------------------
|
| 59 |
mpl.rcParams.update({
|
|
|
|
| 77 |
"ps.fonttype": 42,
|
| 78 |
})
|
| 79 |
|
| 80 |
+
ATC_COLOR = "#1f4f8b"
|
| 81 |
+
DC_COLOR = "#b21e2f"
|
| 82 |
+
NOCURR_COLOR = "#3a7d3a"
|
| 83 |
GRID_KW = dict(linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
|
| 84 |
x = list(range(len(STAGES)))
|
| 85 |
|
| 86 |
|
| 87 |
+
def _plot(y_atc, y_dc, y_nocurr, ylim, yticks, ylabel, fname, legend_loc):
    """Draw one stage-progression figure and save it as PDF and PNG.

    Args:
        y_atc: ATC values, one per entry of ``STAGES``.
        y_dc: Data Curriculum values, one per entry of ``STAGES``.
        y_nocurr: Scalar "No CoT, No Curriculum" value, drawn as a
            horizontal reference line.
        ylim: ``(low, high)`` limits of the y axis.
        yticks: Tick positions on the y axis.
        ylabel: Label of the y axis.
        fname: Output file stem; files are written to ``OUT_DIR``.
        legend_loc: Matplotlib legend location string.
    """
    fig, ax = plt.subplots(figsize=(4.6, 3.4), constrained_layout=True)

    # The two per-stage curves, drawn in a fixed order so the legend entries
    # keep their order.
    curves = (
        (y_atc, dict(color=ATC_COLOR, marker="s", linestyle="-", label="ATC")),
        (y_dc, dict(color=DC_COLOR, marker="o", linestyle="--", label="Data Curriculum")),
    )
    for values, line_style in curves:
        ax.plot(x, values, **line_style)

    # Single-number baseline spanning the whole x range.
    ax.axhline(
        y=y_nocurr,
        color=NOCURR_COLOR,
        linestyle=":",
        linewidth=2.0,
        label="No CoT, No Curriculum",
    )

    ax.set_xticks(x, STAGES)
    ax.set_ylim(*ylim)
    ax.set_yticks(yticks)
    ax.set_ylabel(ylabel)
    ax.grid(True, axis="y", **GRID_KW)
    ax.legend(frameon=False, loc=legend_loc)

    # Vector output first, then the 300-dpi raster.
    for suffix, extra_kw in (("pdf", {}), ("png", {"dpi": 300})):
        fig.savefig(OUT_DIR / f"{fname}.{suffix}", bbox_inches="tight", **extra_kw)
    plt.close(fig)
|
|
|
|
| 114 |
|
| 115 |
|
| 116 |
# One figure per metric, in this order: solve rate, per-cell exact set-match,
# value precision, value recall. Each entry pairs the three data series
# (ATC, Data Curriculum, no-curriculum baseline) with its axis/legend options.
for _series, _options in (
    (
        (ATC_SOLVE, DC_SOLVE, NOCURR_SOLVE),
        dict(
            ylim=(0.0, 1.0),
            yticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            ylabel="Solve rate",
            fname="stage_progression_solve",
            legend_loc="upper right",
        ),
    ),
    (
        (ATC_EXACT, DC_EXACT, NOCURR_EXACT),
        dict(
            ylim=(0.70, 1.00),
            yticks=[0.72, 0.76, 0.80, 0.84, 0.88, 0.92, 0.96, 1.00],
            ylabel="Per-cell set-match rate",
            fname="stage_progression_exact",
            legend_loc="lower left",
        ),
    ),
    (
        (ATC_PREC, DC_PREC, NOCURR_PREC),
        dict(
            ylim=(0.86, 1.00),
            yticks=[0.88, 0.90, 0.92, 0.94, 0.96, 0.98, 1.00],
            ylabel="Value precision",
            fname="stage_progression_precision",
            legend_loc="lower left",
        ),
    ),
    (
        (ATC_REC, DC_REC, NOCURR_REC),
        dict(
            ylim=(0.86, 1.00),
            yticks=[0.88, 0.90, 0.92, 0.94, 0.96, 0.98, 1.00],
            ylabel="Value recall",
            fname="stage_progression_recall",
            legend_loc="lower left",
        ),
    ),
):
    _plot(*_series, **_options)
|
_runs/_paper_figures/stage_progression_exact.pdf
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8da7a71ea03841c0166f1d8734a956aa6023e79ce1c8d5bcc3bb94f8d3f80804
|
| 3 |
+
size 12092
|
_runs/_paper_figures/stage_progression_exact.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
_runs/_paper_figures/stage_progression_precision.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:57ab65caf1b9581dfe3664a91ae67bac345f38915ed6953f8d1afe96a65a18c3
|
| 3 |
+
size 12157
|
_runs/_paper_figures/stage_progression_precision.png
ADDED
|
Git LFS Details
|
_runs/_paper_figures/stage_progression_recall.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b7dc132374b2ce49f218c0bb8630300677772f6dc321053c2a012edec5562d5
|
| 3 |
+
size 11639
|
_runs/_paper_figures/stage_progression_recall.png
ADDED
|
Git LFS Details
|
_runs/_paper_figures/stage_progression_solve.pdf
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b8a1da126cfa7b68631a750650bce83b817f3b9a1833680dc366c12a798d89d2
|
| 3 |
+
size 11398
|
_runs/_paper_figures/stage_progression_solve.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|