Ubuntu Cursor committed on
Commit
a40741d
·
1 Parent(s): 624a68e

Cross-stage experiments: code, paper figures (solve/exact/precision/recall)

Browse files

- _experiments/cross_stage/: per-cell prediction sweep across (method, train_stage, prompt_stage)
with analyze scripts that compute containment, difficulty-stratified solve
rate, failure-mode taxonomy, and a 3x3 cross-prompt grid.
- _runs/_paper_figures/: updated solve/exact and new precision/recall plots
with No-CoT-No-Curriculum horizontal baseline.

Co-authored-by: Cursor <cursoragent@cursor.com>

.gitignore CHANGED
@@ -30,4 +30,15 @@ _pushlogs/
30
  _runs/strawman_warm_*/
31
  _runs/adaptive_k_resume_*/
32
  _runs/launch_finish_repos23_*_pids.txt
 
33
  curriculum_cot/
 
 
 
 
 
 
 
 
 
 
 
30
  _runs/strawman_warm_*/
31
  _runs/adaptive_k_resume_*/
32
  _runs/launch_finish_repos23_*_pids.txt
33
+ _runs/nocurr_cot_*/
34
  curriculum_cot/
35
+
36
+ # Cross-stage prediction dumps (kept on the workspace, not in code repo)
37
+ _experiments/cross_stage/preds/
38
+ _experiments/cross_stage/preds_xprompt/
39
+ _experiments/cross_stage/logs/
40
+ _experiments/cross_stage/logs_xprompt/
41
+ _experiments/cross_stage/figs/
42
+ _experiments/cross_stage/figs_xprompt/
43
+
44
+ _experiments/cross_stage/*.log
_experiments/cross_stage/_peek.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quick peek at completed cross-prompt outputs."""
2
+ import json
3
+ from pathlib import Path
4
+
5
+ DIAG = Path("/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds")
6
+ XP = Path("/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt")
7
+
8
def load(p):
    """Read a JSONL file and return its records as a list.

    Whitespace-only lines are skipped; every other line is stripped and
    decoded with ``json.loads``.
    """
    with open(p) as fh:
        stripped = (raw.strip() for raw in fh)
        return [json.loads(text) for text in stripped if text]
16
+
17
def summarize(tag, recs, target_key):
    """Print a one-line set-match summary for a list of prediction records.

    Records whose ``parse_ok`` is falsy are ignored. For the rest, the sorted
    ``predicted_values`` are compared with the sorted ``target_key`` field
    (treated as empty when absent) to count exact matches and non-empty
    subset matches, and the mean predicted-set size is reported.
    Prints ``"<tag>: no data"`` when ``recs`` is empty.
    """
    if not recs:
        print(f"{tag}: no data")
        return
    parsed = [rec for rec in recs if rec.get("parse_ok")]
    total = len(parsed)
    exact = 0
    contained = 0
    pred_sizes = 0
    for rec in parsed:
        pred = tuple(sorted(rec["predicted_values"]))
        gold = tuple(sorted(rec.get(target_key, [])))
        exact += int(pred == gold)
        # Subset only counts when both sides are non-empty.
        contained += int(bool(pred) and bool(gold) and set(pred) <= set(gold))
        pred_sizes += len(pred)
    denom = max(1, total)  # avoid dividing by zero when nothing parsed
    print(f"{tag:32s} n={total:4d} exact={exact/denom:.3f} subset={contained/denom:.3f} avg|p|={pred_sizes/denom:.2f}")
33
+
34
# Diagonal runs: each adapter is evaluated against the target of its own stage.
print("=== Diagonal (already had) ===")
for method in ("atc", "dc"):
    for stage in (1, 2, 3):
        tag = f"{method}_s{stage}"
        pred_file = DIAG / f"{tag}.jsonl"
        if pred_file.exists():
            summarize(tag, load(pred_file), f"target_S{stage}")

print()
print("=== Off-diagonal cross-prompt ===")
XPROMPT_TAGS = [
    "atc_train3_prompt1",
    "atc_train3_prompt2",
    "atc_train2_prompt3",
    "dc_train3_prompt1",
    "dc_train3_prompt2",
]
for tag in XPROMPT_TAGS:
    pred_file = XP / f"{tag}.jsonl"
    # Treat an absent file and a zero-byte file the same way.
    if not (pred_file.exists() and pred_file.stat().st_size):
        print(f"{tag}: (missing)")
        continue
    # The prompt stage is whatever follows "prompt" at the end of the tag.
    prompt_stage = int(tag.split("prompt")[1])
    summarize(tag + f" [eval vs S{prompt_stage}]", load(pred_file), f"target_S{prompt_stage}")
_experiments/cross_stage/analyze.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Aggregate per-cell predictions across (method, stage) and produce
2
+ containment / catastrophic-rewrite plots + per-cell trajectory visualizations.
3
+
4
+ Expects six JSONL files in --preds_dir: atc_s{1,2,3}.jsonl, dc_s{1,2,3}.jsonl
5
+ (each produced by predict_one.py).
6
+
7
+ Produces:
8
+ - containment_summary.json (numeric report)
9
+ - fig_containment.{pdf,png} (3 grouped bars per method)
10
+ - fig_sankey_example.{pdf,png} (one 9x9 grid of per-cell trajectories for 1 puzzle)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ from collections import defaultdict
18
+ from pathlib import Path
19
+
20
+ import matplotlib as mpl
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ METHODS = ["atc", "dc"]
25
+ STAGES = [1, 2, 3]
26
+
27
+
28
def load_preds(preds_dir: Path):
    """Load ``<method>_s<stage>.jsonl`` for every method/stage combination.

    Returns a dict keyed by ``(method, stage)``; each value maps
    ``(puzzle_id, target_cell)`` to the decoded JSON record. A missing file
    yields an empty inner dict and a ``WARN`` line on stdout.
    """
    loaded = {}
    for method in METHODS:
        for stage in STAGES:
            tag = f"{method}_s{stage}"
            jsonl = preds_dir / f"{tag}.jsonl"
            if not jsonl.exists():
                print(f"WARN missing {jsonl}")
                loaded[(method, stage)] = {}
                continue
            records = {}
            with open(jsonl) as fh:
                for raw in fh:
                    text = raw.strip()
                    if text:
                        rec = json.loads(text)
                        cell_key = (int(rec["puzzle_id"]), tuple(rec["target_cell"]))
                        records[cell_key] = rec
            loaded[(method, stage)] = records
            print(f"loaded {tag}: {len(records)} cells")
    return loaded
51
+
52
+
53
def cells_common(preds):
    """Return the sorted cell keys present in all (method, stage) tables.

    Every combination of ``METHODS`` x ``STAGES`` participates in the
    intersection, including empty or missing ones. Callers index
    ``preds[(m, s)][key]`` for every method and stage, so a key may only be
    reported if it really exists everywhere; leaving an empty table out of
    the intersection would make those lookups raise ``KeyError``.

    Always returns a list (empty when nothing is shared).
    """
    key_sets = [set(preds.get((m, s), {})) for m in METHODS for s in STAGES]
    if not key_sets:
        return []
    return sorted(set.intersection(*key_sets))
62
+
63
+
64
def containment(pred_set, ref_set):
    """Return 1 when ``pred_set`` is a non-empty subset of ``ref_set``.

    An empty (or otherwise falsy) prediction or reference gives 0.
    """
    if pred_set and ref_set:
        return int(set(pred_set) <= set(ref_set))
    return 0
70
+
71
+
72
def disjoint(a, b):
    """Return 1 when ``a`` and ``b`` are both non-empty and share no value, else 0."""
    if not a or not b:
        return 0
    return int(set(a).isdisjoint(b))
74
+
75
+
76
def compute_metrics(preds, common_cells):
    """Aggregate containment / rewrite statistics per method.

    For each method and each cell in ``common_cells`` whose three stage
    records all have a truthy ``parse_ok``, counts containment between the
    predicted sets of different stages, disjointness of the S3 prediction
    from S1 and S2, and predicted-set sizes. All rates are divided by the
    number of counted cells (or by 1 when none were counted).
    """
    report = {}
    for method in METHODS:
        tally = dict.fromkeys(
            ("n", "c13", "c23", "c12", "d31", "d32", "sz1", "sz2", "sz3"), 0
        )
        for cell in common_cells:
            recs = [preds[(method, stage)][cell] for stage in (1, 2, 3)]
            s1, s2, s3 = (rec["predicted_values"] for rec in recs)
            if not all(rec["parse_ok"] for rec in recs):
                continue
            tally["n"] += 1
            tally["c13"] += containment(s3, s1)
            tally["c23"] += containment(s3, s2)
            tally["c12"] += containment(s2, s1)
            tally["d31"] += disjoint(s3, s1)
            tally["d32"] += disjoint(s3, s2)
            tally["sz1"] += len(s1)
            tally["sz2"] += len(s2)
            tally["sz3"] += len(s3)
        denom = max(1, tally["n"])  # keeps the ratios defined when n == 0
        report[method] = {
            "n": tally["n"],
            "containment_S3_in_S1": tally["c13"] / denom,
            "containment_S3_in_S2": tally["c23"] / denom,
            "containment_S2_in_S1": tally["c12"] / denom,
            "catastrophic_S3_disjoint_S1": tally["d31"] / denom,
            "catastrophic_S3_disjoint_S2": tally["d32"] / denom,
            "avg_predicted_size_S1": tally["sz1"] / denom,
            "avg_predicted_size_S2": tally["sz2"] / denom,
            "avg_predicted_size_S3": tally["sz3"] / denom,
        }
    return report
115
+
116
+
117
+ # ---------- plotting ----------------------------------------------------
118
+
119
+ mpl.rcParams.update({
120
+ "font.family": "serif",
121
+ "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
122
+ "font.size": 12,
123
+ "axes.labelsize": 12,
124
+ "xtick.labelsize": 11,
125
+ "ytick.labelsize": 11,
126
+ "legend.fontsize": 11,
127
+ "axes.spines.top": False,
128
+ "axes.spines.right": False,
129
+ "axes.linewidth": 1.0,
130
+ "lines.linewidth": 2.0,
131
+ "pdf.fonttype": 42,
132
+ "ps.fonttype": 42,
133
+ })
134
+
135
+ ATC_COLOR = "#1f4f8b"
136
+ DC_COLOR = "#b21e2f"
137
+
138
+
139
def plot_containment(metrics, out_path):
    """Draw the grouped ATC vs. Data Curriculum containment bar chart.

    ``metrics`` must provide, under both "atc" and "dc", the keys
    "containment_S3_in_S1", "containment_S3_in_S2" and
    "catastrophic_S3_disjoint_S1". The figure is saved as ``out_path`` with
    a ``.pdf`` and a ``.png`` suffix.
    """
    fig, ax = plt.subplots(figsize=(5.2, 3.6), constrained_layout=True)
    bar_groups = [
        ("$\\hat S_3 \\subseteq \\hat S_1$", "containment_S3_in_S1"),
        ("$\\hat S_3 \\subseteq \\hat S_2$", "containment_S3_in_S2"),
        ("$\\hat S_3 \\cap \\hat S_1 = \\varnothing$", "catastrophic_S3_disjoint_S1"),
    ]
    tick_labels = [label for label, _ in bar_groups]
    metric_keys = [key for _, key in bar_groups]
    centers = list(range(len(bar_groups)))
    width = 0.36
    atc_heights = [metrics["atc"][key] for key in metric_keys]
    dc_heights = [metrics["dc"][key] for key in metric_keys]
    # (bar heights, x positions, colour, legend label); ATC left, DC right.
    series = [
        (atc_heights, [c - width / 2 for c in centers], ATC_COLOR, "ATC"),
        (dc_heights, [c + width / 2 for c in centers], DC_COLOR, "Data Curriculum"),
    ]
    for heights, xs, colour, legend in series:
        ax.bar(xs, heights, width, color=colour, label=legend, edgecolor="none")
    # Numeric value printed just above each bar.
    for heights, xs, colour, _ in series:
        for x_pos, value in zip(xs, heights):
            ax.text(x_pos, value + 0.015, f"{value:.2f}", ha="center", va="bottom", fontsize=10, color=colour)
    ax.set_xticks(centers, tick_labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction of cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf / .png")
164
+
165
+
166
def plot_sankey_grid(preds, out_path, puzzle_id=0):
    """Render per-cell candidate-set trajectories for one puzzle.

    One panel is drawn per method. Each row is a cell of ``puzzle_id`` found
    in that method's stage-3 predictions, and its three columns list the
    values predicted at S1, S2 and S3 ("—" when the prediction is empty or
    the record is missing). A row is shaded grey when the S3 prediction is a
    non-empty subset of a non-empty S1 prediction, and light red otherwise.

    ``preds`` is keyed by ``(method, stage)`` with inner keys
    ``(puzzle_id, (row, col))``. The figure is saved as ``out_path`` with a
    ``.pdf`` and a ``.png`` suffix.
    """
    panels = [("atc", "ATC (latent + curriculum)"), ("dc", "Data Curriculum (no CoT)")]
    fig, axes = plt.subplots(1, 2, figsize=(9, 4.5), constrained_layout=True)
    for ax, (method, title) in zip(axes, panels):
        rows = []
        # Sort on the key only so record dicts are never compared.
        for key, rec3 in sorted(preds[(method, 3)].items(), key=lambda item: item[0]):
            if key[0] != puzzle_id:
                continue
            s1 = preds[(method, 1)].get(key, {}).get("predicted_values") or []
            s2 = preds[(method, 2)].get(key, {}).get("predicted_values") or []
            s3 = rec3.get("predicted_values") or []
            rows.append((key[1], s1, s2, s3))
        n = len(rows)
        if n == 0:
            ax.text(0.5, 0.5, "(no data)", transform=ax.transAxes, ha="center")
            ax.set_title(title)
            continue
        ax.set_xlim(0, 3)
        ax.set_ylim(-0.5, n - 0.5)
        for i, ((r, c), s1, s2, s3) in enumerate(rows):
            y = n - 1 - i  # first cell at the top of the panel
            ax.text(-0.4, y, f"({r+1},{c+1})", va="center", ha="right", fontsize=8, color="0.4")
            for vals, x_center in ((s1, 0.5), (s2, 1.5), (s3, 2.5)):
                txt = ",".join(str(v) for v in vals) if vals else "—"
                ax.text(x_center, y, txt, va="center", ha="center", fontsize=9)
            stays_in_s1 = bool(s1) and bool(s3) and set(s3).issubset(set(s1))
            shade = "0.85" if stays_in_s1 else "#f5b7b1"
            ax.axhspan(y - 0.5, y + 0.5, facecolor=shade, alpha=0.4, zorder=0)
        ax.set_xticks([0.5, 1.5, 2.5], ["S1", "S2", "S3"])
        ax.set_yticks([])
        ax.set_title(title, fontsize=11)
        ax.spines["left"].set_visible(False)
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {out_path}.pdf / .png")
205
+
206
+
207
def main():
    """CLI entry point: load predictions, write the summary JSON and both figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--preds_dir", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--example_puzzle", type=int, default=0)
    args = parser.parse_args()

    preds_dir = Path(args.preds_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    preds = load_preds(preds_dir)
    common = cells_common(preds)
    print(f"common cells across all 6 files: {len(common)}")

    metrics = compute_metrics(preds, common)
    summary_path = out_dir / "containment_summary.json"
    with open(summary_path, "w") as fh:
        json.dump({"n_common_cells": len(common), "metrics": metrics}, fh, indent=2)
    print(json.dumps(metrics, indent=2))

    plot_containment(metrics, out_dir / "fig_containment")
    plot_sankey_grid(
        preds,
        out_dir / "fig_sankey_example",
        puzzle_id=args.example_puzzle,
    )
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()
_experiments/cross_stage/analyze_cross_prompt.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Analyse cross-prompt evaluations.
2
+
3
+ For each (method, train_stage, prompt_stage) pair we have a JSONL of per-cell
4
+ records produced by predict_one.py. The diagonals (train_stage == prompt_stage)
5
+ live in ../preds/ ; the off-diagonals (this experiment) live in ../preds_xprompt/.
6
+
7
+ Headline question: does each model still do the OFF-DIAGONAL task correctly?
8
+ We measure:
9
+ - exact_set_match_vs_target : predicted_values == target_S{prompt_stage}
10
+ - subset_of_target : predicted_values ⊆ target_S{prompt_stage}
11
+ - avg |predicted set|
12
+ - "drift" : exact_set_match vs predicted of the model's ORIGINAL training
13
+ stage (i.e. does ATC_S3 prompted with S1 still produce its
14
+ own S3 answer? -> indicates that the prompt was ignored)
15
+
16
+ Plots:
17
+ fig_xprompt_solve_grid.{pdf,png} - 2-method × 3 prompt_stage × 3 train_stage
18
+ heatmap of exact-set-match accuracy
19
+ fig_xprompt_setsize.{pdf,png} - avg |pred| for each (train, prompt) cell
20
+ grouped by method
21
+ fig_xprompt_forward_compat.{pdf,png} - "forward compat": for the S3 adapter
22
+ prompted with stage_i in {1,2,3}, plot
23
+ exact-match accuracy. Both methods.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import json
30
+ from collections import defaultdict
31
+ from pathlib import Path
32
+
33
+ import numpy as np
34
+ import matplotlib as mpl
35
+ import matplotlib.pyplot as plt
36
+
37
+
38
+ METHODS = ["atc", "dc"]
39
+ STAGES = [1, 2, 3]
40
+ ATC_COLOR = "#1f4f8b"; DC_COLOR = "#b21e2f"
41
+ COLOR = {"atc": ATC_COLOR, "dc": DC_COLOR}
42
+ PRETTY = {"atc": "ATC", "dc": "Data Curriculum"}
43
+
44
+ mpl.rcParams.update({
45
+ "font.family": "serif",
46
+ "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
47
+ "font.size": 12,
48
+ "axes.labelsize": 12,
49
+ "xtick.labelsize": 11,
50
+ "ytick.labelsize": 11,
51
+ "legend.fontsize": 10,
52
+ "axes.spines.top": False,
53
+ "axes.spines.right": False,
54
+ "axes.linewidth": 1.0,
55
+ "lines.linewidth": 2.0,
56
+ "lines.markersize": 7,
57
+ "pdf.fonttype": 42,
58
+ "ps.fonttype": 42,
59
+ })
60
+
61
+
62
def parse_tag(tag: str):
    """Split a prediction-file tag into ``(method, train_stage, prompt_stage)``.

    Diagonal: ``atc_s3`` -> ``("atc", 3, 3)``
    Off-diag: ``atc_train3_prompt1`` -> ``("atc", 3, 1)``

    Raises:
        ValueError: if the tag has neither form (a missing ``train<K>`` /
            ``prompt<M>`` component, or a non-integer stage).
    """
    parts = tag.split("_")
    method = parts[0]
    if len(parts) == 2 and parts[1].startswith("s"):
        stage = int(parts[1][1:])
        return method, stage, stage

    def _stage(prefix):
        # The first component carrying the prefix wins.
        for part in parts:
            if part.startswith(prefix):
                return int(part[len(prefix):])
        raise ValueError(f"tag {tag!r} has no '{prefix}<N>' component")

    # expect *_trainK_promptM
    return method, _stage("train"), _stage("prompt")
77
+
78
+
79
def load_dir(p: Path):
    """Load every ``*.jsonl`` file in ``p`` whose stem is a recognised tag.

    Returns a dict mapping ``(method, train_stage, prompt_stage)`` to a dict
    of ``(puzzle_id, target_cell) -> record``. Files are visited in sorted
    order; a file whose stem ``parse_tag`` rejects is skipped with a warning
    on stdout rather than being dropped silently.
    """
    by_key = {}  # (method, train, prompt) -> { (puzzle, cell) -> record }
    for path in sorted(p.glob("*.jsonl")):
        tag = path.stem
        try:
            m, t, q = parse_tag(tag)
        except (ValueError, IndexError) as exc:
            # Only tag-format problems are tolerated; anything else propagates.
            print(f"WARN skipping {path.name}: unrecognised tag ({exc})")
            continue
        d = {}
        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                r = json.loads(line)
                d[(int(r["puzzle_id"]), tuple(r["target_cell"]))] = r
        by_key[(m, t, q)] = d
        print(f"loaded {tag}: {len(d)} cells -> (method={m}, train={t}, prompt={q})")
    return by_key
98
+
99
+
100
def target_field(rec, q):
    """Return the record's ``target_S<q>`` values as a tuple (empty if absent)."""
    values = rec.get(f"target_S{q}", [])
    return tuple(values)
102
+
103
+
104
def aggregate(by_key):
    """Summarise each (method, train, prompt) table against its prompt target.

    Only records with a truthy ``parse_ok`` are counted. Each output row
    reports the counted size ``n``, the exact-match rate and non-empty-subset
    rate of ``predicted_values`` against ``target_S<prompt>``, and the mean
    predicted-set size. Rates use a denominator of 1 when ``n`` is 0.
    """
    summary = []
    for (method, train, prompt), cells in by_key.items():
        target_key = f"target_S{prompt}"
        parsed = [rec for rec in cells.values() if rec.get("parse_ok")]
        exact = 0
        contained = 0
        total_size = 0
        for rec in parsed:
            pred = tuple(sorted(rec["predicted_values"]))
            gold = tuple(sorted(rec.get(target_key, [])))
            exact += int(pred == gold)
            contained += int(bool(pred) and bool(gold) and set(pred) <= set(gold))
            total_size += len(pred)
        count = len(parsed)
        denom = max(1, count)
        summary.append({
            "method": method, "train": train, "prompt": prompt, "n": count,
            "exact_match_vs_prompt_target": exact / denom,
            "subset_of_prompt_target": contained / denom,
            "avg_pred_size": total_size / denom,
        })
    return summary
127
+
128
+
129
def drift_from_diagonal(by_key):
    """Measure how often an off-diagonal run repeats the diagonal answer.

    For every table with ``prompt != train`` that has a non-empty diagonal
    counterpart ``(method, train, train)``, compares the sorted
    ``predicted_values`` cell by cell (only cells present in both tables and
    parsed in both). ``frac_ignored_prompt`` is the share of compared cells
    with identical predictions: high means the prompt change was ignored,
    low means the model's behaviour followed the prompt.
    """
    results = []
    for (method, train, prompt), off_diag in by_key.items():
        if prompt == train:
            continue
        diag = by_key.get((method, train, train))
        if not diag:
            continue
        compared = 0
        unchanged = 0
        for cell, rec_off in off_diag.items():
            rec_diag = diag.get(cell)
            usable = rec_diag and rec_off.get("parse_ok") and rec_diag.get("parse_ok")
            if not usable:
                continue
            compared += 1
            if sorted(rec_off["predicted_values"]) == sorted(rec_diag["predicted_values"]):
                unchanged += 1
        results.append({
            "method": method, "train": train, "prompt": prompt,
            "n": compared, "frac_ignored_prompt": unchanged / max(1, compared),
        })
    return results
154
+
155
+
156
+ # ----------------------- PLOTS ----------------------------
157
+
158
def plot_solve_grid(rows, out_path):
    """Plot side-by-side 3x3 train x prompt heatmaps of exact-match accuracy.

    One panel per method; cells with no matching row stay NaN and are left
    unannotated. Saved as ``out_path`` with ``.pdf`` and ``.png`` suffixes.
    """
    stage_ticks = ["S1", "S2", "S3"]
    fig, axes = plt.subplots(1, 2, figsize=(8.2, 3.6), constrained_layout=True)
    for ax, method in zip(axes, METHODS):
        acc = np.full((3, 3), np.nan)
        for row in (r for r in rows if r["method"] == method):
            acc[row["train"] - 1, row["prompt"] - 1] = row["exact_match_vs_prompt_target"]
        im = ax.imshow(acc, vmin=0.0, vmax=1.0, cmap="Blues")
        ax.set_xticks([0, 1, 2], stage_ticks)
        ax.set_yticks([0, 1, 2], stage_ticks)
        ax.set_xlabel("Prompt stage_i")
        ax.set_ylabel("Trained stage")
        ax.set_title(PRETTY[method], fontsize=11)
        for (i, j), value in np.ndenumerate(acc):
            if np.isnan(value):
                continue
            # White text on the dark half of the colormap, black on the light half.
            text_colour = "white" if value > 0.5 else "black"
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    color=text_colour, fontsize=10)
    cb = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.85, fraction=0.05)
    cb.set_label("Exact set match vs prompt target")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
185
+
186
+
187
def plot_forward_compat(rows, out_path):
    """Plot exact-match accuracy of the train-stage-3 rows per prompt stage.

    One line per method over prompt stages 1, 2, 3; a stage with no matching
    row is plotted as NaN and gets no value label. Saved as ``out_path``
    with ``.pdf`` and ``.png`` suffixes.
    """
    fig, ax = plt.subplots(figsize=(5.0, 3.6), constrained_layout=True)
    prompt_stages = [1, 2, 3]
    line_styles = [("atc", "s", "-"), ("dc", "o", "--")]
    for method, marker, linestyle in line_styles:
        # First matching row wins; NaN marks a missing (train=3, prompt=q) row.
        scores = [
            next(
                (r["exact_match_vs_prompt_target"] for r in rows
                 if r["method"] == method and r["train"] == 3 and r["prompt"] == q),
                np.nan,
            )
            for q in prompt_stages
        ]
        ax.plot(prompt_stages, scores, color=COLOR[method], marker=marker,
                linestyle=linestyle, label=PRETTY[method])
        for q, score in zip(prompt_stages, scores):
            if np.isnan(score):
                continue
            ax.text(q, score + 0.02, f"{score:.2f}", ha="center", va="bottom",
                    fontsize=9, color=COLOR[method])
    ax.set_xticks(prompt_stages, ["Ask S1", "Ask S2", "Ask S3"])
    ax.set_xlabel("Prompt task")
    ax.set_ylim(0.0, 1.05)
    ax.set_ylabel("Exact set-match on prompted task")
    ax.legend(frameon=False, loc="lower left")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
215
+
216
+
217
def plot_pred_size_grid(rows, out_path):
    """Plot side-by-side 3x3 train x prompt heatmaps of mean predicted-set size.

    One panel per method, colour scale fixed to [1.0, 2.5]; cells with no
    matching row stay NaN and are left unannotated. Saved as ``out_path``
    with ``.pdf`` and ``.png`` suffixes.
    """
    stage_ticks = ["S1", "S2", "S3"]
    fig, axes = plt.subplots(1, 2, figsize=(8.2, 3.6), constrained_layout=True)
    for ax, method in zip(axes, METHODS):
        sizes = np.full((3, 3), np.nan)
        for row in (r for r in rows if r["method"] == method):
            sizes[row["train"] - 1, row["prompt"] - 1] = row["avg_pred_size"]
        im = ax.imshow(sizes, vmin=1.0, vmax=2.5, cmap="Oranges")
        ax.set_xticks([0, 1, 2], stage_ticks)
        ax.set_yticks([0, 1, 2], stage_ticks)
        ax.set_xlabel("Prompt stage_i")
        ax.set_ylabel("Trained stage")
        ax.set_title(PRETTY[method], fontsize=11)
        for (i, j), value in np.ndenumerate(sizes):
            if np.isnan(value):
                continue
            text_colour = "white" if value > 1.8 else "black"
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    color=text_colour, fontsize=10)
    cb = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.85, fraction=0.05)
    cb.set_label("Avg |predicted candidate set|")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
243
+
244
+
245
def plot_prompt_responsiveness(drift_rows, out_path):
    """Bar chart of ``frac_ignored_prompt`` per off-diagonal (train, prompt) pair.

    Bars are grouped by pair (sorted), ATC on the left and Data Curriculum on
    the right; a method without a row for a pair gets a NaN bar and no label.
    Lower values are better. Saved as ``out_path`` with ``.pdf`` and ``.png``
    suffixes.
    """
    fig, ax = plt.subplots(figsize=(5.8, 3.6), constrained_layout=True)
    # First row seen for each (method, train, prompt) wins.
    frac = {}
    for row in drift_rows:
        frac.setdefault((row["method"], row["train"], row["prompt"]), row["frac_ignored_prompt"])
    pairs = sorted({(row["train"], row["prompt"]) for row in drift_rows})
    labels = [f"S{t}→S{q}" for t, q in pairs]
    atc_vals = [frac.get(("atc", t, q), np.nan) for t, q in pairs]
    dc_vals = [frac.get(("dc", t, q), np.nan) for t, q in pairs]

    centers = list(range(len(labels)))
    width = 0.36
    atc_x = [c - width / 2 for c in centers]
    dc_x = [c + width / 2 for c in centers]
    ax.bar(atc_x, atc_vals, width, color=ATC_COLOR, label="ATC", edgecolor="none")
    ax.bar(dc_x, dc_vals, width, color=DC_COLOR, label="Data Curriculum", edgecolor="none")
    for xs, vals, colour in ((atc_x, atc_vals, ATC_COLOR), (dc_x, dc_vals, DC_COLOR)):
        for x_pos, value in zip(xs, vals):
            if np.isnan(value):
                continue
            ax.text(x_pos, value + 0.01, f"{value:.2f}", ha="center", va="bottom", fontsize=9, color=colour)
    ax.set_xticks(centers, labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Frac cells with prediction ≡ same model's train-stage answer\n(low = model actually responded to new prompt)")
    ax.legend(frameon=False, loc="upper left")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
    print(f"saved {out_path}.pdf/.png")
275
+
276
+
277
def main():
    """CLI entry point: merge diagonal and cross-prompt runs, then report and plot."""
    parser = argparse.ArgumentParser()
    for flag in ("--diag_dir", "--xprompt_dir", "--out_dir"):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Cross-prompt tables override diagonal ones that share a key.
    by_key = load_dir(Path(args.diag_dir))
    by_key.update(load_dir(Path(args.xprompt_dir)))

    rows = aggregate(by_key)
    drift = drift_from_diagonal(by_key)
    summary = {"rows": rows, "drift": drift}

    with open(out / "xprompt_summary.json", "w") as fh:
        json.dump(summary, fh, indent=2)
    print(json.dumps(summary, indent=2))

    plot_solve_grid(rows, out / "fig_xprompt_solve_grid")
    plot_pred_size_grid(rows, out / "fig_xprompt_setsize")
    plot_forward_compat(rows, out / "fig_xprompt_forward_compat")
    # The responsiveness plot needs at least one off-diagonal comparison.
    if drift:
        plot_prompt_responsiveness(drift, out / "fig_xprompt_prompt_response")
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
_experiments/cross_stage/analyze_v2.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Extended cross-stage containment analyses.
2
+
3
+ Reads the 6 JSONL files produced by predict_one.py (one per method-stage cell)
4
+ and emits multiple plots that probe HOW latent CoT propagates constraints
5
+ versus the vanilla data-curriculum baseline.
6
+
7
+ Plots produced (PDF + PNG):
8
+ fig_containment_basic - 3 grouped bars: S3⊆S1, S3⊆S2, S3∩S1=∅
9
+ fig_containment_by_diff - same 3 bars BROKEN DOWN by ground-truth |S1|
10
+ (cell difficulty axis = |true legal candidate set|)
11
+ fig_set_size_trajectory - avg predicted set size at S1/S2/S3 per method
12
+ fig_correctness_breakdown - among incorrect S3 predictions, what fraction
13
+ stays inside S1 / S2 vs. is catastrophic?
14
+ fig_method_agreement - fraction of cells where ATC.S3 == DC.S3, broken
15
+ down by ground-truth difficulty
16
+ fig_sankey_example - per-cell value trajectory for one puzzle
17
+ (existing in analyze.py, refreshed here)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ from collections import defaultdict
25
+ from pathlib import Path
26
+ from typing import Dict, List, Tuple
27
+
28
+ import numpy as np
29
+ import matplotlib as mpl
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+ METHODS = ["atc", "dc"]
34
+ STAGES = [1, 2, 3]
35
+ METHOD_PRETTY = {"atc": "ATC", "dc": "Data Curriculum"}
36
+ ATC_COLOR = "#1f4f8b"
37
+ DC_COLOR = "#b21e2f"
38
+ COLOR = {"atc": ATC_COLOR, "dc": DC_COLOR}
39
+
40
+ mpl.rcParams.update({
41
+ "font.family": "serif",
42
+ "font.serif": ["DejaVu Serif", "Times New Roman", "Times", "Liberation Serif"],
43
+ "font.size": 12,
44
+ "axes.labelsize": 12,
45
+ "xtick.labelsize": 11,
46
+ "ytick.labelsize": 11,
47
+ "legend.fontsize": 10,
48
+ "axes.spines.top": False,
49
+ "axes.spines.right": False,
50
+ "axes.linewidth": 1.0,
51
+ "lines.linewidth": 2.0,
52
+ "lines.markersize": 7,
53
+ "pdf.fonttype": 42,
54
+ "ps.fonttype": 42,
55
+ })
56
+
57
+
58
def load_preds(preds_dir: Path):
    """Load ``<method>_s<stage>.jsonl`` for every method/stage combination.

    Returns a dict keyed by ``(method, stage)``; each value maps
    ``(puzzle_id, target_cell)`` to the decoded record. A missing file
    yields an empty inner dict.
    """
    table = {}
    for method in METHODS:
        for stage in STAGES:
            records = {}
            source = preds_dir / f"{method}_s{stage}.jsonl"
            if source.exists():
                with open(source) as fh:
                    for text in (raw.strip() for raw in fh):
                        if text:
                            rec = json.loads(text)
                            records[(int(rec["puzzle_id"]), tuple(rec["target_cell"]))] = rec
            table[(method, stage)] = records
    return table
75
+
76
+
77
def cells_common(preds):
    """Return the sorted cell keys present in all (method, stage) tables.

    Every combination of ``METHODS`` x ``STAGES`` takes part in the
    intersection, including empty ones. The per-method analyses index
    ``preds[(m, s)][key]`` directly, so a key must exist in every table;
    skipping an empty table here would let those lookups raise ``KeyError``.
    """
    key_sets = [set(preds.get((m, s), {})) for m in METHODS for s in STAGES]
    if not key_sets:
        return []
    return sorted(set.intersection(*key_sets))
85
+
86
+
87
def diff_bucket(target_s1):
    """Map a ground-truth S1 candidate collection to its difficulty label.

    Sizes 0 and 1 share the "|S1|=1" label; sizes of 4 or more share "|S1|≥4".
    """
    small_labels = ("|S1|=1", "|S1|=1", "|S1|=2", "|S1|=3")  # indexed by size 0..3
    size = len(target_s1)
    return small_labels[size] if size < 4 else "|S1|≥4"
96
+
97
+
98
+ DIFF_ORDER = ["|S1|=1", "|S1|=2", "|S1|=3", "|S1|≥4"]
99
+
100
+
101
def _safe_div(a, b):
    """Return ``a / b`` as a float, or 0.0 when ``b`` is falsy (e.g. zero)."""
    if not b:
        return 0.0
    return float(a) / float(b)
103
+
104
+
105
def compute_per_difficulty(preds, common):
    """Compute containment metrics per method and difficulty bucket.

    Cells are bucketed by ``diff_bucket`` of the stage-1 record's
    ``target_S1``; cells where any of the three stage records has a falsy
    ``parse_ok`` are skipped. Returns one row per (method, bucket) in
    ``DIFF_ORDER`` order, with rates normalised by the bucket's cell count
    via ``_safe_div``.
    """
    counters = (
        "n", "c13", "c23", "d13", "d23", "correct",
        "sum_size_s1", "sum_size_s2", "sum_size_s3",
    )
    table = []
    for method in METHODS:
        tallies = {bucket: dict.fromkeys(counters, 0) for bucket in DIFF_ORDER}
        for cell in common:
            rec1 = preds[(method, 1)][cell]
            rec2 = preds[(method, 2)][cell]
            rec3 = preds[(method, 3)][cell]
            if not (rec1["parse_ok"] and rec2["parse_ok"] and rec3["parse_ok"]):
                continue
            tally = tallies[diff_bucket(rec1["target_S1"])]
            s1 = set(rec1["predicted_values"])
            s2 = set(rec2["predicted_values"])
            s3 = set(rec3["predicted_values"])
            solution = rec3.get("target_solution")
            tally["n"] += 1
            # Containment and disjointness only count when both sets are non-empty.
            tally["c13"] += int(bool(s3) and bool(s1) and s3 <= s1)
            tally["c23"] += int(bool(s3) and bool(s2) and s3 <= s2)
            tally["d13"] += int(bool(s3) and bool(s1) and s3.isdisjoint(s1))
            tally["d23"] += int(bool(s3) and bool(s2) and s3.isdisjoint(s2))
            # Correct means S3 is exactly the single solution value.
            tally["correct"] += int(solution in s3 and len(s3) == 1)
            tally["sum_size_s1"] += len(s1)
            tally["sum_size_s2"] += len(s2)
            tally["sum_size_s3"] += len(s3)
        for bucket in DIFF_ORDER:
            tally = tallies[bucket]
            count = tally["n"]
            table.append({
                "method": method, "bucket": bucket, "n": count,
                "c13": _safe_div(tally["c13"], count),
                "c23": _safe_div(tally["c23"], count),
                "d13": _safe_div(tally["d13"], count),
                "d23": _safe_div(tally["d23"], count),
                "correct": _safe_div(tally["correct"], count),
                "size_s1": _safe_div(tally["sum_size_s1"], count),
                "size_s2": _safe_div(tally["sum_size_s2"], count),
                "size_s3": _safe_div(tally["sum_size_s3"], count),
            })
    return table
141
+
142
+
143
def compute_correctness_breakdown(preds, common):
    """Classify where wrong S3 predictions land relative to S1 and S2.

    A cell is correct when its S3 prediction is exactly the single
    ``target_solution`` value. For the remaining (wrong) cells, counts how
    often the S3 set is a non-empty subset of, or non-empty and disjoint
    from, the S1 and S2 predictions. Cells where any stage record has a falsy
    ``parse_ok`` are skipped. Fractions are relative to the wrong-cell count.
    """
    breakdown = {}
    for method in METHODS:
        correct_count = 0
        wrong = {"n": 0, "in_s1": 0, "in_s2": 0, "off_s1": 0, "off_s2": 0}
        for cell in common:
            rec1 = preds[(method, 1)][cell]
            rec2 = preds[(method, 2)][cell]
            rec3 = preds[(method, 3)][cell]
            if not (rec1["parse_ok"] and rec2["parse_ok"] and rec3["parse_ok"]):
                continue
            s1 = set(rec1["predicted_values"])
            s2 = set(rec2["predicted_values"])
            s3 = set(rec3["predicted_values"])
            solution = rec3["target_solution"]
            if len(s3) == 1 and solution in s3:
                correct_count += 1
                continue
            wrong["n"] += 1
            wrong["in_s1"] += int(bool(s3) and bool(s1) and s3 <= s1)
            wrong["in_s2"] += int(bool(s3) and bool(s2) and s3 <= s2)
            wrong["off_s1"] += int(bool(s3) and bool(s1) and s3.isdisjoint(s1))
            wrong["off_s2"] += int(bool(s3) and bool(s2) and s3.isdisjoint(s2))
        wrong_count = wrong["n"]
        breakdown[method] = {
            "n_correct": correct_count,
            "n_wrong": wrong_count,
            "wrong_in_s1_frac": _safe_div(wrong["in_s1"], wrong_count),
            "wrong_in_s2_frac": _safe_div(wrong["in_s2"], wrong_count),
            "wrong_disjoint_s1_frac": _safe_div(wrong["off_s1"], wrong_count),
            "wrong_disjoint_s2_frac": _safe_div(wrong["off_s2"], wrong_count),
        }
    return breakdown
177
+
178
+
179
def compute_method_agreement(preds, common):
    """Count ATC vs. DC stage-3 agreement, stratified by difficulty bucket.

    For each common cell parsed by both methods, the bucket comes from
    ``diff_bucket`` of the ATC record's ``target_S1``. Per bucket the result
    holds raw counts: cells seen (``n``), cells where the sorted predictions
    are equal (``agree``), and cells each method solved, i.e. predicted
    exactly the single ``target_solution`` value.
    """
    stats = {
        bucket: {"n": 0, "agree": 0, "atc_correct": 0, "dc_correct": 0}
        for bucket in DIFF_ORDER
    }
    for cell in common:
        atc_rec = preds[("atc", 3)][cell]
        dc_rec = preds[("dc", 3)][cell]
        if not (atc_rec["parse_ok"] and dc_rec["parse_ok"]):
            continue
        atc_pred = sorted(atc_rec["predicted_values"])
        dc_pred = sorted(dc_rec["predicted_values"])
        entry = stats[diff_bucket(atc_rec["target_S1"])]
        solution = atc_rec["target_solution"]
        entry["n"] += 1
        if atc_pred == dc_pred:
            entry["agree"] += 1
        if len(atc_pred) == 1 and solution in atc_pred:
            entry["atc_correct"] += 1
        if len(dc_pred) == 1 and solution in dc_pred:
            entry["dc_correct"] += 1
    return stats
194
+
195
+
196
+ # ----------------------------- PLOTS -----------------------------------
197
+
198
def plot_containment_basic(metrics, out_path):
    """Draw the headline grouped bar chart of containment/disjointness fractions.

    ``metrics`` is indexed as ``metrics["atc"][k]`` / ``metrics["dc"][k]`` for
    ``k`` in ``c13, c23, d13, d23``. The figure is written next to ``out_path``
    as both ``.pdf`` and ``.png`` (300 dpi) and then closed.
    """
    fig, ax = plt.subplots(figsize=(5.4, 3.6), constrained_layout=True)
    tick_labels = [
        "$\\hat S_3 \\subseteq \\hat S_1$",
        "$\\hat S_3 \\subseteq \\hat S_2$",
        "$\\hat S_3 \\cap \\hat S_1=\\varnothing$",
        "$\\hat S_3 \\cap \\hat S_2=\\varnothing$",
    ]
    metric_keys = ["c13", "c23", "d13", "d23"]
    positions = list(range(len(metric_keys)))
    bar_w = 0.36
    # (values, bar centres, colour, legend label) for each method, ATC first.
    series = [
        ([metrics["atc"][k] for k in metric_keys],
         [pos - bar_w/2 for pos in positions], ATC_COLOR, "ATC"),
        ([metrics["dc"][k] for k in metric_keys],
         [pos + bar_w/2 for pos in positions], DC_COLOR, "Data Curriculum"),
    ]
    for values, centres, colour, legend in series:
        ax.bar(centres, values, bar_w, color=colour, label=legend, edgecolor="none")
    for values, centres, colour, _ in series:
        for centre, val in zip(centres, values):
            ax.text(centre, val + 0.015, f"{val:.3f}", ha="center", va="bottom",
                    fontsize=9, color=colour)
    ax.set_xticks(positions, tick_labels)
    ax.set_ylim(0, 1.06)
    ax.set_ylabel("Fraction of cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
224
+
225
+
226
def plot_containment_by_difficulty(rows, key, ylabel, out_path):
    """Grouped bar chart of ``rows[*][key]`` per difficulty bucket and method.

    ``rows`` is a list of dicts carrying at least ``method``, ``bucket``, ``n``
    and ``key``. Buckets missing for a method are drawn as 0. The ``n=...``
    annotation under each group is taken from the ATC rows only.
    Saves ``.pdf`` and ``.png`` (300 dpi) variants of ``out_path``.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    value_of = {}
    for method in METHODS:
        value_of[method] = {row["bucket"]: row[key] for row in rows if row["method"] == method}
    # NOTE(review): only the ATC counts are shown; confirm DC counts always match.
    atc_counts = {row["bucket"]: row["n"] for row in rows if row["method"] == "atc"}

    positions = list(range(len(DIFF_ORDER)))
    bar_w = 0.36
    series = [
        ([value_of["atc"].get(b, 0) for b in DIFF_ORDER],
         [pos - bar_w/2 for pos in positions], ATC_COLOR, "ATC"),
        ([value_of["dc"].get(b, 0) for b in DIFF_ORDER],
         [pos + bar_w/2 for pos in positions], DC_COLOR, "Data Curriculum"),
    ]
    for values, centres, colour, legend in series:
        ax.bar(centres, values, bar_w, color=colour, label=legend, edgecolor="none")
    for values, centres, colour, _ in series:
        for centre, val in zip(centres, values):
            ax.text(centre, val + 0.01, f"{val:.2f}", ha="center", va="bottom",
                    fontsize=8, color=colour)
    # n-cells annotation under each group (x in data coords, y in axes coords)
    for pos, bucket in zip(positions, DIFF_ORDER):
        count = atc_counts.get(bucket, 0)
        ax.text(pos, -0.06, f"n={count}", ha="center", va="top", fontsize=8,
                color="0.4", transform=ax.get_xaxis_transform())
    ax.set_xticks(positions, DIFF_ORDER)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel(ylabel)
    ax.legend(frameon=False, loc="lower left")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
251
+
252
+
253
def plot_set_size_trajectory(rows, out_path):
    """Plot the average predicted set size across Stage 1 -> 2 -> 3, per method.

    ``rows`` is a list of per-(method, bucket) dicts with ``method``, ``n`` and
    ``size_s1`` / ``size_s2`` / ``size_s3``. Each plotted point is the average of
    the per-bucket value weighted by ``n``.

    The y-axis defaults to the range (0.95, 1.45) but is widened when any
    plotted average (or its text label) would otherwise fall outside it, so
    curves are never silently clipped.

    Saves ``.pdf`` and ``.png`` (300 dpi) variants of ``out_path``.
    """
    fig, ax = plt.subplots(figsize=(5.2, 3.6), constrained_layout=True)

    # average across all buckets weighted by n
    def avg(method, key):
        method_rows = [r for r in rows if r["method"] == method]
        ns = sum(r["n"] for r in method_rows)
        s = sum(r[key] * r["n"] for r in method_rows)
        return s / max(1, ns)

    stages = [1, 2, 3]
    plotted = []
    for m, marker, ls in [("atc", "s", "-"), ("dc", "o", "--")]:
        y = [avg(m, "size_s1"), avg(m, "size_s2"), avg(m, "size_s3")]
        plotted.extend(y)
        ax.plot(stages, y, color=COLOR[m], marker=marker, linestyle=ls, label=METHOD_PRETTY[m])
        for xi, v in zip(stages, y):
            ax.text(xi, v + 0.03, f"{v:.2f}", ha="center", va="bottom", fontsize=9, color=COLOR[m])
    ax.set_xticks(stages, ["Stage 1", "Stage 2", "Stage 3"])
    # Keep the historical window unless the data needs more room; the extra
    # headroom at the top leaves space for the value labels drawn above points.
    y_low = min(0.95, min(plotted) - 0.05)
    y_high = max(1.45, max(plotted) + 0.10)
    ax.set_ylim(y_low, y_high)
    ax.set_ylabel("Avg |predicted candidate set|")
    ax.grid(True, axis="y", linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
    ax.legend(frameon=False, loc="upper right")
    fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")
    fig.savefig(out_path.with_suffix(".png"), dpi=300, bbox_inches="tight")
    plt.close(fig)
274
+
275
+
276
def plot_correctness_breakdown(stats, out_path):
    """Bar chart: among wrong S3 cells, the share inside / disjoint from S1 and S2.

    ``stats`` is indexed as ``stats[method][key]`` for the four ``wrong_*_frac``
    keys plus ``n_wrong`` (shown in the legend). Saves ``.pdf`` and ``.png``
    (300 dpi) variants of ``out_path``.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    tick_labels = [
        "Wrong but $\\subseteq \\hat S_1$",
        "Wrong but $\\subseteq \\hat S_2$",
        "Wrong & $\\cap \\hat S_1=\\varnothing$",
        "Wrong & $\\cap \\hat S_2=\\varnothing$",
    ]
    stat_keys = [
        "wrong_in_s1_frac",
        "wrong_in_s2_frac",
        "wrong_disjoint_s1_frac",
        "wrong_disjoint_s2_frac",
    ]
    positions = list(range(len(stat_keys)))
    bar_w = 0.36
    series = [
        ([stats["atc"][k] for k in stat_keys],
         [pos - bar_w/2 for pos in positions], ATC_COLOR,
         f"ATC (n_wrong={stats['atc']['n_wrong']})"),
        ([stats["dc"][k] for k in stat_keys],
         [pos + bar_w/2 for pos in positions], DC_COLOR,
         f"Data Curr. (n_wrong={stats['dc']['n_wrong']})"),
    ]
    for values, centres, colour, legend in series:
        ax.bar(centres, values, bar_w, color=colour, label=legend, edgecolor="none")
    for values, centres, colour, _ in series:
        for centre, val in zip(centres, values):
            ax.text(centre, val + 0.015, f"{val:.2f}", ha="center", va="bottom",
                    fontsize=9, color=colour)
    ax.set_xticks(positions, tick_labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction of wrong S3 cells")
    ax.legend(frameon=False, loc="upper right")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
304
+
305
+
306
def plot_method_agreement(per_bucket, out_path):
    """Three-bar-per-bucket chart: ATC correct, DC correct, and ATC == DC rates.

    ``per_bucket[b]`` holds the raw counts ``n``, ``agree``, ``atc_correct`` and
    ``dc_correct`` for every bucket ``b`` in ``DIFF_ORDER``; rates are obtained
    with ``_safe_div(count, n)``. Saves ``.pdf`` and ``.png`` (300 dpi) variants
    of ``out_path``.
    """
    fig, ax = plt.subplots(figsize=(5.6, 3.6), constrained_layout=True)
    positions = list(range(len(DIFF_ORDER)))
    bar_w = 0.28

    def rate(count_key):
        return [_safe_div(per_bucket[b][count_key], per_bucket[b]["n"]) for b in DIFF_ORDER]

    agree_rate = rate("agree")
    atc_rate = rate("atc_correct")
    dc_rate = rate("dc_correct")

    ax.bar([pos - bar_w for pos in positions], atc_rate, bar_w,
           color=ATC_COLOR, label="ATC correct", edgecolor="none")
    ax.bar(list(positions), dc_rate, bar_w,
           color=DC_COLOR, label="DC correct", edgecolor="none")
    ax.bar([pos + bar_w for pos in positions], agree_rate, bar_w,
           color="0.4", label="ATC == DC", edgecolor="none")
    # Cell count under each group (x in data coords, y in axes coords).
    for pos, bucket in zip(positions, DIFF_ORDER):
        count = per_bucket[bucket]["n"]
        ax.text(pos, -0.06, f"n={count}", ha="center", va="top", fontsize=8,
                color="0.4", transform=ax.get_xaxis_transform())
    ax.set_xticks(positions, DIFF_ORDER)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Fraction")
    ax.legend(frameon=False, loc="lower left")
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
326
+
327
+
328
+ # Re-use the simple sankey from analyze.py (lightly compacted)
329
def plot_sankey(preds, out_path, puzzle_id=0):
    """Side-by-side table of S1/S2/S3 predictions for one puzzle (ATC vs DC).

    Keys of ``preds[(method, stage)]`` are indexed as ``key[0]`` (compared with
    ``puzzle_id``) and ``key[1]`` (unpacked as a ``(row, col)`` pair). One text
    row is drawn per Stage-3 cell of the chosen puzzle; missing S1/S2 records
    or empty predictions render as an em dash. A row is shaded grey when the S3
    prediction is a non-empty subset of a non-empty S1 prediction, red
    otherwise. Saves ``.pdf`` and ``.png`` (300 dpi) variants of ``out_path``.
    """
    fig, axes = plt.subplots(1, 2, figsize=(9, 4.6), constrained_layout=True)
    column_x = (0.5, 1.5, 2.5)
    for ax, method in zip(axes, ["atc", "dc"]):
        stage3 = preds[(method, 3)]
        table = []
        for key in sorted(stage3):
            if key[0] != puzzle_id:
                continue
            rec3 = stage3[key]
            s1_vals = preds[(method, 1)].get(key, {}).get("predicted_values") or []
            s2_vals = preds[(method, 2)].get(key, {}).get("predicted_values") or []
            s3_vals = rec3.get("predicted_values") or []
            table.append((key[1], s1_vals, s2_vals, s3_vals))
        n_rows = len(table)
        ax.set_xlim(0, 3)
        ax.set_ylim(-0.5, n_rows - 0.5)
        for idx, (cell_rc, s1_vals, s2_vals, s3_vals) in enumerate(table):
            row, col = cell_rc
            y_pos = n_rows - 1 - idx  # first cell is drawn at the top
            ax.text(-0.4, y_pos, f"({row+1},{col+1})", va="center", ha="right",
                    fontsize=8, color="0.4")
            for x_center, vals in zip(column_x, (s1_vals, s2_vals, s3_vals)):
                if vals:
                    shown = ",".join(str(v) for v in vals)
                else:
                    shown = "—"
                ax.text(x_center, y_pos, shown, va="center", ha="center", fontsize=9)
            contained = bool(s3_vals and s1_vals and set(s3_vals).issubset(set(s1_vals)))
            shade = "0.88" if contained else "#f5b7b1"
            ax.axhspan(y_pos - 0.5, y_pos + 0.5, facecolor=shade, alpha=0.4, zorder=0)
        ax.set_xticks([0.5, 1.5, 2.5], ["S1", "S2", "S3"])
        ax.set_yticks([])
        ax.set_title(METHOD_PRETTY[method], fontsize=11)
        ax.spines["left"].set_visible(False)
    for suffix, extra in ((".pdf", {}), (".png", {"dpi": 300})):
        fig.savefig(out_path.with_suffix(suffix), bbox_inches="tight", **extra)
    plt.close(fig)
358
+
359
+
360
+ # ----------------------------- MAIN ------------------------------------
361
+
362
def main():
    """CLI entry point: load predictions, compute all summaries, write JSON + figures.

    Reads ``--preds_dir``, writes ``containment_summary_v2.json`` and every
    figure into ``--out_dir`` (created if missing), then prints the aggregate
    metrics, per-bucket agreement rates and the correctness breakdown.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--preds_dir", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--example_puzzle", type=int, default=0)
    args = parser.parse_args()

    preds_dir = Path(args.preds_dir)
    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    preds = load_preds(preds_dir)
    common = cells_common(preds)
    print(f"common cells: {len(common)}")

    rows = compute_per_difficulty(preds, common)

    # Collapse per-bucket fractions into one n-weighted fraction per method.
    frac_keys = ("c13", "c23", "d13", "d23")
    aggregate = {}
    for method in METHODS:
        aggregate[method] = {"c13": 0, "c23": 0, "d13": 0, "d23": 0, "n": 0}
    for row in rows:
        bucket_total = aggregate[row["method"]]
        for fk in frac_keys:
            bucket_total[fk] += row[fk] * row["n"]
        bucket_total["n"] += row["n"]
    for method in METHODS:
        denom = max(1, aggregate[method]["n"])
        for fk in frac_keys:
            aggregate[method][fk] = aggregate[method][fk] / denom

    correctness = compute_correctness_breakdown(preds, common)
    agreement = compute_method_agreement(preds, common)

    summary = {
        "n_common_cells": len(common),
        "aggregate": aggregate,
        "per_difficulty": rows,
        "correctness_breakdown": correctness,
        "agreement_by_difficulty": {b: agreement[b] for b in DIFF_ORDER},
    }
    with open(out / "containment_summary_v2.json", "w") as f:
        json.dump(summary, f, indent=2)

    plot_containment_basic(aggregate, out / "fig_containment_basic")
    by_difficulty_figs = [
        ("c13", "$P(\\hat S_3 \\subseteq \\hat S_1)$", "fig_c13_by_diff"),
        ("c23", "$P(\\hat S_3 \\subseteq \\hat S_2)$", "fig_c23_by_diff"),
        ("d23", "$P(\\hat S_3 \\cap \\hat S_2=\\varnothing)$", "fig_d23_by_diff"),
        ("correct", "Solve rate at S3", "fig_solve_by_diff"),
    ]
    for row_key, ylabel, stem in by_difficulty_figs:
        plot_containment_by_difficulty(rows, row_key, ylabel, out / stem)
    plot_set_size_trajectory(rows, out / "fig_set_size_trajectory")
    plot_correctness_breakdown(correctness, out / "fig_correctness_breakdown")
    plot_method_agreement(agreement, out / "fig_method_agreement")
    plot_sankey(preds, out / "fig_sankey_example", puzzle_id=args.example_puzzle)

    print(json.dumps(summary["aggregate"], indent=2))
    print("agreement_by_difficulty:")
    for b in DIFF_ORDER:
        d = agreement[b]
        if not d["n"]:
            continue  # empty bucket: no rates to report
        print(f" {b}: n={d['n']} agree={d['agree']/d['n']:.3f} "
              f"atc_correct={d['atc_correct']/d['n']:.3f} dc_correct={d['dc_correct']/d['n']:.3f}")
    print("correctness_breakdown:")
    print(json.dumps(correctness, indent=2))


if __name__ == "__main__":
    main()
_experiments/cross_stage/overnight_pipeline.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Overnight orchestrator:
#   1) wait for phase-1 cross-prompt jobs to finish (already launched)
#   2) launch phase-2 cross-prompt sweep (the phase-2 script blocks until its jobs exit)
#   3) report any phase-2 job that never wrote its DONE marker
#   4) run analyze_v2.py and analyze_cross_prompt.py to produce all plots
#   5) list the produced figures
#
# Environment:
#   MAX_WAIT_SECS  upper bound (seconds) for the phase-1 polling loop.
#                  0 (default) waits forever, which is the historical behaviour;
#                  set it to avoid hanging all night if a phase-1 job crashed
#                  before printing "DONE cells=".

set -e

REPO=/home/ubuntu/curriculum-cot-code
LOG_DIR=/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt
FIGS_DIR=/home/ubuntu/curriculum_cot/_experiments/cross_stage/figs_xprompt
PY=/opt/pytorch/bin/python
MAX_WAIT_SECS=${MAX_WAIT_SECS:-0}

mkdir -p "$FIGS_DIR"

PHASE1_TAGS=(atc_train3_prompt1 atc_train3_prompt2 dc_train3_prompt1 dc_train3_prompt2 atc_train2_prompt3)
PHASE2_TAGS=(atc_train1_prompt2 atc_train1_prompt3 atc_train2_prompt1 dc_train1_prompt2 dc_train1_prompt3 dc_train2_prompt1 dc_train2_prompt3)

# tag_is_done TAG: succeeds when the job's log contains the DONE marker
# that predict_one.py prints as its last line.
tag_is_done() {
  grep -q "DONE cells=" "$LOG_DIR/$1.log" 2>/dev/null
}

# wait_for_tags ARRAY_NAME: poll every 90s until every tag in the named array
# is done. Returns 1 (aborting the script under `set -e`) if MAX_WAIT_SECS > 0
# and that many seconds elapse first.
wait_for_tags() {
  local -n tags=$1
  local need=${#tags[@]}
  local started=$SECONDS
  local tag done_count
  while true; do
    done_count=0
    for tag in "${tags[@]}"; do
      if tag_is_done "$tag"; then
        done_count=$((done_count+1))
      fi
    done
    echo "[$(date +%T)] $done_count / $need done"
    if [ "$done_count" -ge "$need" ]; then
      return 0
    fi
    if [ "$MAX_WAIT_SECS" -gt 0 ] && [ $((SECONDS - started)) -ge "$MAX_WAIT_SECS" ]; then
      echo "[$(date +%T)] timed out after ${MAX_WAIT_SECS}s waiting for: ${tags[*]}" >&2
      return 1
    fi
    sleep 90
  done
}

echo "[$(date +%T)] waiting for phase-1 cross-prompt jobs..."
wait_for_tags PHASE1_TAGS
echo "[$(date +%T)] phase 1 complete; launching phase 2"

bash "$REPO/_experiments/cross_stage/run_cross_prompt_phase2.sh"

# The phase-2 script returns once its background jobs have exited, whether or
# not they succeeded. Surface unfinished tags, but still run the analyses on
# whatever data exists.
missing=()
for tag in "${PHASE2_TAGS[@]}"; do
  tag_is_done "$tag" || missing+=("$tag")
done
if [ "${#missing[@]}" -gt 0 ]; then
  echo "[$(date +%T)] WARNING: phase-2 jobs without a DONE marker: ${missing[*]}" >&2
fi

echo "[$(date +%T)] phase 2 complete; running analyses"

"$PY" "$REPO/_experiments/cross_stage/analyze_v2.py" \
  --preds_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds \
  --out_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/figs \
  --example_puzzle 0

"$PY" "$REPO/_experiments/cross_stage/analyze_cross_prompt.py" \
  --diag_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds \
  --xprompt_dir /home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt \
  --out_dir "$FIGS_DIR"

echo "[$(date +%T)] done. Figures in: $FIGS_DIR"
ls -la "$FIGS_DIR"
_experiments/cross_stage/predict_one.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dump per-cell predictions for one (method, stage) checkpoint on a fixed eval set.
2
+
3
+ For each empty cell of each puzzle in the eval JSONL, runs the given adapter and
4
+ writes a JSON line with:
5
+ method_tag : free-form id e.g. "atc_s1"
6
+ puzzle_id : 0-based row index
7
+ target_cell : [r, c] (0-based, matches `ex.target_cell`)
8
+ target_solution : the unique true value at this cell
9
+ stage_prompted : the stage_i argument passed to the prompt builder
10
+ predicted_values : sorted list of ints in [1,9] parsed from model output
11
+ parse_ok / exact_set_match : booleans from score_prediction_text
12
+ target_S1 / S2 / S3 : the stage-1/2/3 consistent value sets for this cell
13
+ (computed independently of the model so the
14
+ post-processing script can compare across stages)
15
+
16
+ For the latent (recurrent-hidden) checkpoints set `--latent_mode recurrent_hidden`
17
+ and `--num_cot_tokens` to whatever value the model was trained at.
18
+ For vanilla baseline checkpoints leave both at their defaults.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import sys
27
+ import time
28
+ from pathlib import Path
29
+
30
+ import torch
31
+ from transformers import AutoModelForCausalLM, AutoTokenizer
32
+
33
+ REPO = Path(__file__).resolve().parents[2]
34
+ if str(REPO) not in sys.path:
35
+ sys.path.insert(0, str(REPO))
36
+
37
+ from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row
38
+ from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt
39
+ from multi_output_cell_policy.rewards import score_prediction_text
40
+ from multi_output_cell_policy.shared_multi_output_policy import (
41
+ make_solved_grid_from_row,
42
+ stage_i_consistent_values,
43
+ )
44
+
45
+
46
def parse_args():
    """Build the command-line parser for the prediction dump and parse ``sys.argv``.

    Required: ``--method_tag``, ``--adapter_dir``, ``--eval_jsonl``,
    ``--stage_i``, ``--out_jsonl``. Everything else has a default.
    """
    parser = argparse.ArgumentParser()

    # What to evaluate and how it is prompted.
    parser.add_argument("--method_tag", required=True)
    parser.add_argument("--adapter_dir", required=True)
    parser.add_argument("--eval_jsonl", required=True)
    parser.add_argument("--eval_rows", type=int, default=100)
    parser.add_argument("--stage_i", type=int, required=True)
    parser.add_argument("--total_empties_hint", type=int, default=20)

    # Latent-CoT configuration.
    latent_choices = ["none", "recurrent_hidden", "fixed_slots", "latent_seeds", "residual"]
    parser.add_argument("--latent_mode", default="none", choices=latent_choices)
    parser.add_argument("--num_cot_tokens", type=int, default=0)

    # Model, device and generation settings.
    parser.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--cache_dir", default=str(REPO / ".hf_cache"))
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--max_completion_length", type=int, default=24)

    parser.add_argument("--out_jsonl", required=True)
    return parser.parse_args()
63
+
64
+
65
def load_jsonl(path: str, limit: int) -> list:
    """Read at most ``limit`` JSON records from a JSON-lines file.

    Blank and whitespace-only lines are skipped and do not count towards the
    limit. A ``limit`` of zero or less yields an empty list (previously one
    record slipped through because the limit was only checked after appending).

    Args:
        path: Location of the ``.jsonl`` file (read as UTF-8).
        limit: Maximum number of records to return.

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    if limit <= 0:
        return records
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
            if len(records) >= limit:
                break
    return records
76
+
77
+
78
+ def main():
79
+ args = parse_args()
80
+ os.makedirs(os.path.dirname(args.out_jsonl) or ".", exist_ok=True)
81
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
82
+
83
+ device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
84
+
85
+ tokenizer = AutoTokenizer.from_pretrained(
86
+ args.model_name, cache_dir=args.cache_dir, use_fast=True
87
+ )
88
+ if tokenizer.pad_token_id is None:
89
+ tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
90
+
91
+ base = AutoModelForCausalLM.from_pretrained(
92
+ args.model_name, cache_dir=args.cache_dir,
93
+ torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
94
+ )
95
+
96
+ is_latent = args.latent_mode != "none"
97
+ if is_latent:
98
+ from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import (
99
+ load_trainable_adapter,
100
+ sample_recurrent_hidden_completion,
101
+ )
102
+ model = load_trainable_adapter(
103
+ base, args.adapter_dir,
104
+ lora_r=32, lora_alpha=64, lora_dropout=0.05,
105
+ )
106
+ if args.latent_mode != "recurrent_hidden":
107
+ raise SystemExit(f"Only recurrent_hidden latent_mode is wired up here; got {args.latent_mode!r}")
108
+ sample_fn = sample_recurrent_hidden_completion
109
+ else:
110
+ from peft import PeftModel
111
+ model = PeftModel.from_pretrained(base, args.adapter_dir, is_trainable=False)
112
+ sample_fn = None
113
+
114
+ if hasattr(model, "config"):
115
+ model.config.use_cache = True
116
+ model.to(device).eval()
117
+
118
+ rows = load_jsonl(args.eval_jsonl, args.eval_rows)
119
+
120
+ t0 = time.time()
121
+ n_cells = 0
122
+ with open(args.out_jsonl, "w") as fout:
123
+ for puzzle_id, row in enumerate(rows):
124
+ solved = make_solved_grid_from_row(row)
125
+ for ex in build_cell_examples_from_row(row):
126
+ prompt = build_multi_output_cell_prompt(
127
+ ex.grid,
128
+ target_cell=ex.target_cell,
129
+ stage_i=args.stage_i,
130
+ tokenizer=tokenizer,
131
+ turn_idx=ex.turn_idx,
132
+ total_turns=ex.total_turns,
133
+ prev_output_flag=None,
134
+ total_empties_hint=args.total_empties_hint,
135
+ )
136
+ enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
137
+ input_ids = enc["input_ids"].to(device)
138
+ attn = enc["attention_mask"].to(device)
139
+
140
+ with torch.no_grad():
141
+ if is_latent:
142
+ completion_ids = sample_fn(
143
+ model, tokenizer, input_ids, attn,
144
+ num_cot_tokens=int(args.num_cot_tokens),
145
+ max_new_tokens=max(1, int(args.max_completion_length)),
146
+ do_sample=False,
147
+ )
148
+ pred_text = tokenizer.decode(
149
+ completion_ids[0], skip_special_tokens=True
150
+ ).strip()
151
+ else:
152
+ out = model.generate(
153
+ input_ids=input_ids,
154
+ attention_mask=attn,
155
+ max_new_tokens=max(1, int(args.max_completion_length)),
156
+ do_sample=False,
157
+ eos_token_id=tokenizer.eos_token_id,
158
+ pad_token_id=tokenizer.pad_token_id,
159
+ )
160
+ pred_text = tokenizer.decode(
161
+ out[0][input_ids.shape[1]:], skip_special_tokens=True
162
+ ).strip()
163
+
164
+ info = score_prediction_text(
165
+ text=pred_text,
166
+ grid=ex.grid,
167
+ solved=solved,
168
+ target_cell=ex.target_cell,
169
+ stage_i=args.stage_i,
170
+ reward_good_value=1.0,
171
+ penalty_bad_value=1.0,
172
+ penalty_malformed=4.0,
173
+ penalty_empty=0.5,
174
+ penalty_singleton=1.5,
175
+ )
176
+
177
+ t1 = sorted(int(v) for v in stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=1))
178
+ t2 = sorted(int(v) for v in stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=2))
179
+ t3 = sorted(int(v) for v in stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=3))
180
+
181
+ pred_values_raw = info.get("predicted_values") or []
182
+ predicted_values = sorted(int(v) for v in pred_values_raw if isinstance(v, (int, float)))
183
+
184
+ rec = {
185
+ "method_tag": args.method_tag,
186
+ "puzzle_id": int(puzzle_id),
187
+ "target_cell": [int(ex.target_cell[0]), int(ex.target_cell[1])],
188
+ "target_solution": int(ex.target_value),
189
+ "stage_prompted": int(args.stage_i),
190
+ "predicted_values": predicted_values,
191
+ "predicted_text": pred_text,
192
+ "parse_ok": bool(info["parse_ok"]),
193
+ "exact_set_match": bool(info["exact_set_match"]),
194
+ "target_S1": t1,
195
+ "target_S2": t2,
196
+ "target_S3": t3,
197
+ }
198
+ fout.write(json.dumps(rec) + "\n")
199
+ n_cells += 1
200
+ if (puzzle_id + 1) % 10 == 0:
201
+ print(
202
+ f"[{args.method_tag}] puzzle {puzzle_id+1}/{len(rows)} "
203
+ f"cells={n_cells} elapsed={time.time()-t0:.0f}s",
204
+ flush=True,
205
+ )
206
+
207
+ print(f"[{args.method_tag}] DONE cells={n_cells} elapsed={time.time()-t0:.0f}s out={args.out_jsonl}")
208
+
209
+
210
+ if __name__ == "__main__":
211
+ main()
_experiments/cross_stage/run_all.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Diagonal cross-stage sweep: six (method, stage) checkpoints, one per GPU 0-5,
# each prompted with the stage it was trained for.
# Every job writes $OUT_DIR/<tag>.jsonl and logs to $LOG_DIR/<tag>.log.
# EVAL_ROWS, OUT_DIR and LOG_DIR can be overridden from the environment.
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-100}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

CKPT=/home/ubuntu/hf_checkpoints

# One spec per job: tag|gpu|adapter_dir|stage_i|latent_mode|num_cot
declare -a JOBS=(
  "atc_s1|0|$CKPT/latent_stages/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden|1|recurrent_hidden|1"
  "atc_s2|1|$CKPT/latent_stages/grpo/N3_from_main_step800/checkpoint-200|2|recurrent_hidden|3"
  "atc_s3|2|$CKPT/latent_stages/rebuttal_champion_100p/s3_grpo_baseline_checkpoint-200|3|recurrent_hidden|3"
  "dc_s1|3|$CKPT/baseline/baseline_lr1e4/s1_grpo_v2|1|none|0"
  "dc_s2|4|$CKPT/baseline/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000|2|none|0"
  "dc_s3|5|$CKPT/baseline/v6_i_sft_v_oversample10/s3_sft/checkpoint-step-00200|3|none|0"
)

PIDS=()

# launch_one SPEC: start predict_one.py in the background for one job spec and
# record its PID. CUDA_VISIBLE_DEVICES pins the physical GPU, so the process
# itself always sees it as device 0 (hence --gpu_id 0).
launch_one() {
  local tag gpu adapter stage_i mode cot
  IFS='|' read -r tag gpu adapter stage_i mode cot <<< "$1"
  echo "[$(date +%T)] launching $tag on GPU $gpu (stage_i=$stage_i mode=$mode cot=$cot)"
  CUDA_VISIBLE_DEVICES="$gpu" "$PY" "$SCRIPT" \
    --method_tag "$tag" \
    --adapter_dir "$adapter" \
    --eval_jsonl "$EVAL" \
    --eval_rows "$EVAL_ROWS" \
    --stage_i "$stage_i" \
    --latent_mode "$mode" \
    --num_cot_tokens "$cot" \
    --gpu_id 0 \
    --out_jsonl "$OUT_DIR/$tag.jsonl" \
    > "$LOG_DIR/$tag.log" 2>&1 &
  PIDS+=("$!")
}

for entry in "${JOBS[@]}"; do
  launch_one "$entry"
done

echo "Launched 6 jobs with PIDs: ${PIDS[*]}"
echo "Logs: $LOG_DIR"
echo "Outputs: $OUT_DIR"
echo
echo "Waiting for all to finish..."

# Reap every job individually so failures are counted rather than lost.
fail=0
for pid in "${PIDS[@]}"; do
  if ! wait "$pid"; then
    echo " pid $pid FAILED"
    fail=$((fail + 1))
  else
    echo " pid $pid OK"
  fi
done
echo "Done. $fail failures."
_experiments/cross_stage/run_cross_prompt.sh ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Cross-prompt sweep (phase 1): pair checkpoints with an OFF-DIAGONAL stage_i prompt.
#
# Idea: predict_one.py loads a (method, train_stage) adapter and can prompt it
# with any stage_i. If latent CoT preserves cross-stage information, the latent
# S3 adapter should still be able to enumerate the S1 candidate set when asked,
# while the data-curriculum S3 adapter has "overwritten" that capability.
#
# Launches 5 background jobs on GPUs 1..5 and returns immediately without
# waiting for them; completion is detected from the "DONE cells=" line each
# job writes to its log.
# Output:
#   $OUT_DIR/<tag>.jsonl  (default .../cross_stage/preds_xprompt)
#   $LOG_DIR/<tag>.log    (default .../cross_stage/logs_xprompt)
#
# NOTE(review): EVAL_ROWS defaults to 200 here but to 100 in run_all.sh and
# run_cross_prompt_phase2.sh -- confirm the mismatch is intended.
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-200}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

# Adapters re-used from the diagonal sweep (only the ones this phase needs)
ATC_S2=/home/ubuntu/hf_checkpoints/latent_stages/grpo/N3_from_main_step800/checkpoint-200
ATC_S3=/home/ubuntu/hf_checkpoints/latent_stages/rebuttal_champion_100p/s3_grpo_baseline_checkpoint-200
DC_S3=/home/ubuntu/hf_checkpoints/baseline/v6_i_sft_v_oversample10/s3_sft/checkpoint-step-00200

# Each row: tag | gpu | adapter_dir | prompt_stage_i | latent_mode | num_cot
# The tag embeds the (train_stage, prompt_stage) pair so the analyze script
# can pick the files up automatically.
declare -a JOBS=(
  # S3-trained models asked to do S1 / S2 enumeration
  "atc_train3_prompt1|1|$ATC_S3|1|recurrent_hidden|3"
  "atc_train3_prompt2|2|$ATC_S3|2|recurrent_hidden|3"
  "dc_train3_prompt1|3|$DC_S3|1|none|0"
  "dc_train3_prompt2|4|$DC_S3|2|none|0"
  # S2-trained latent model asked to commit (S3 prompt)
  "atc_train2_prompt3|5|$ATC_S2|3|recurrent_hidden|3"
)

PIDS=()

# launch_one SPEC: start one background predict_one.py job and record its PID.
# The GPU is selected via CUDA_VISIBLE_DEVICES, so --gpu_id stays 0.
launch_one() {
  local tag gpu adapter stage_i mode cot
  IFS='|' read -r tag gpu adapter stage_i mode cot <<< "$1"
  echo "[$(date +%T)] launching $tag on GPU $gpu (prompt stage_i=$stage_i, mode=$mode, cot=$cot)"
  CUDA_VISIBLE_DEVICES="$gpu" "$PY" "$SCRIPT" \
    --method_tag "$tag" \
    --adapter_dir "$adapter" \
    --eval_jsonl "$EVAL" \
    --eval_rows "$EVAL_ROWS" \
    --stage_i "$stage_i" \
    --latent_mode "$mode" \
    --num_cot_tokens "$cot" \
    --gpu_id 0 \
    --out_jsonl "$OUT_DIR/$tag.jsonl" \
    > "$LOG_DIR/$tag.log" 2>&1 &
  PIDS+=("$!")
}

for entry in "${JOBS[@]}"; do
  launch_one "$entry"
done

echo "Launched ${#PIDS[@]} cross-prompt jobs: ${PIDS[*]}"
echo "Logs: $LOG_DIR"
echo "Outputs: $OUT_DIR"
_experiments/cross_stage/run_cross_prompt_phase2.sh ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Phase-2 cross-prompt sweep: the remaining off-diagonals.
#
# Phase 1 covered:
#   atc_train3_prompt1, atc_train3_prompt2,
#   dc_train3_prompt1,  dc_train3_prompt2,
#   atc_train2_prompt3.
#
# Phase 2 fills in:
#   atc_train1_prompt2, atc_train1_prompt3,
#   atc_train2_prompt1,
#   dc_train1_prompt2,  dc_train1_prompt3,
#   dc_train2_prompt1,  dc_train2_prompt3.
#
# The 7 jobs run in two batches: five on GPUs 1-5, then the two dc_train2 jobs
# on GPUs 1-2 once the first batch has exited. Jobs whose output file exists
# and whose log already contains "DONE cells=" are skipped, so the script is
# safe to re-run.
#
# Failed jobs are counted and reported, but the script still exits 0 so a
# caller running under `set -e` can continue with whatever data was produced.
set -e

REPO=/home/ubuntu/curriculum-cot-code
EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
EVAL_ROWS=${EVAL_ROWS:-100}
OUT_DIR=${OUT_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/preds_xprompt}
LOG_DIR=${LOG_DIR:-/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs_xprompt}
mkdir -p "$OUT_DIR" "$LOG_DIR"

PY=/opt/pytorch/bin/python
SCRIPT="$REPO/_experiments/cross_stage/predict_one.py"

ATC_S1=/home/ubuntu/hf_checkpoints/latent_stages/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
ATC_S2=/home/ubuntu/hf_checkpoints/latent_stages/grpo/N3_from_main_step800/checkpoint-200
DC_S1=/home/ubuntu/hf_checkpoints/baseline/baseline_lr1e4/s1_grpo_v2
DC_S2=/home/ubuntu/hf_checkpoints/baseline/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000

# Each row: tag | gpu | adapter_dir | prompt_stage_i | latent_mode | num_cot
declare -a JOBS=(
  "atc_train1_prompt2|1|$ATC_S1|2|recurrent_hidden|1"
  "atc_train1_prompt3|2|$ATC_S1|3|recurrent_hidden|1"
  "atc_train2_prompt1|3|$ATC_S2|1|recurrent_hidden|3"
  "dc_train1_prompt2|4|$DC_S1|2|none|0"
  "dc_train1_prompt3|5|$DC_S1|3|none|0"
)

# Second batch: 2 more dc_train2 jobs on GPUs 1-2
declare -a JOBS3=(
  "dc_train2_prompt1|1|$DC_S2|1|none|0"
  "dc_train2_prompt3|2|$DC_S2|3|none|0"
)

BATCH_PIDS=()
TOTAL_FAILED=0

# launch_batch ARRAY_NAME: start every not-yet-finished job of the named spec
# array in the background and collect the PIDs in BATCH_PIDS.
launch_batch() {
  local -n specs=$1
  local entry tag gpu adapter stage_i mode cot
  BATCH_PIDS=()
  for entry in "${specs[@]}"; do
    IFS='|' read -r tag gpu adapter stage_i mode cot <<< "$entry"
    if [ -f "$OUT_DIR/$tag.jsonl" ] && grep -q "DONE cells=" "$LOG_DIR/$tag.log" 2>/dev/null; then
      echo "[$(date +%T)] $tag already done, skip"
      continue
    fi
    echo "[$(date +%T)] launching $tag on GPU $gpu (prompt stage_i=$stage_i, mode=$mode, cot=$cot)"
    # CUDA_VISIBLE_DEVICES pins the physical GPU, so the job sees it as device 0.
    CUDA_VISIBLE_DEVICES="$gpu" "$PY" "$SCRIPT" \
      --method_tag "$tag" \
      --adapter_dir "$adapter" \
      --eval_jsonl "$EVAL" \
      --eval_rows "$EVAL_ROWS" \
      --stage_i "$stage_i" \
      --latent_mode "$mode" \
      --num_cot_tokens "$cot" \
      --gpu_id 0 \
      --out_jsonl "$OUT_DIR/$tag.jsonl" \
      > "$LOG_DIR/$tag.log" 2>&1 &
    BATCH_PIDS+=("$!")
  done
}

# wait_batch: reap each PID in BATCH_PIDS individually. A bare `wait` always
# returns 0 and would hide crashed jobs, so each exit status is checked here.
wait_batch() {
  local pid
  for pid in "${BATCH_PIDS[@]}"; do
    if ! wait "$pid"; then
      echo "  pid $pid FAILED (see logs in $LOG_DIR)" >&2
      TOTAL_FAILED=$((TOTAL_FAILED + 1))
    fi
  done
}

launch_batch JOBS
wait_batch
echo "Phase 2 complete: ${#BATCH_PIDS[@]} jobs, $TOTAL_FAILED failed so far"

launch_batch JOBS3
wait_batch
echo "All cross-prompt sweeps done. $TOTAL_FAILED failures."
_experiments/cross_stage/run_nocurr_cot.sh ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Long "No Curriculum + Latent CoT" SFT runs.
3
+ #
4
+ # Each variant trains the latent (recurrent-hidden) model directly on the
5
+ # Stage-3 target labels (--stage_i 3) with no curriculum --- variants differ
6
+ # only in `num_cot_tokens`, learning rate and oversample factor. This is
7
+ # the "no curriculum but with latent thoughts" cell of the factorial.
8
+ #
9
+ # Variants are warm-started from the well-trained k=0 checkpoint from the
10
+ # previous adaptive-k sweep, then trained for many SFT steps so the model
11
+ # has time to make use of the extra latent capacity. Each variant uses
12
+ # ALL 10000 training rows.
13
+ #
14
+ # Usage:
15
+ # bash run_nocurr_cot.sh "<GPU,VARIANT_TAG,NUM_COT,LR,OVERSAMPLE>" ...
16
+ # Example:
17
+ # bash run_nocurr_cot.sh \
18
+ # "6,nocurr_cot_k2_lr2e5_o5,2,2e-5,5" \
19
+ # "7,nocurr_cot_k3_lr2e5_o5,3,2e-5,5"
20
+ #
21
+ set -e
22
+
23
+ REPO=/home/ubuntu/curriculum-cot-code
24
+ SFT_PY="$REPO/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
25
+ TRAIN=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
26
+ EVAL=/home/ubuntu/curriculum_cot/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
27
+ INIT_ADAPTER=/home/ubuntu/hf_checkpoints/adaptive_k/20260525_024629/adaptive_a_eps01/sft_phase02_k0/checkpoint-step-00600
28
+ OUT_ROOT=/home/ubuntu/curriculum_cot/_runs/nocurr_cot_$(date +%Y%m%d_%H%M%S)
29
+ mkdir -p "$OUT_ROOT"
30
+ echo "OUT_ROOT=$OUT_ROOT"
31
+
32
+ PY=/opt/pytorch/bin/python
33
+ PIDS=()
34
+ for spec in "$@"; do
35
+ IFS=',' read -r gpu tag cot lr oversample <<< "$spec"
36
+ out="$OUT_ROOT/$tag"
37
+ mkdir -p "$out"
38
+ echo "[$(date +%T)] launching $tag on GPU $gpu (num_cot=$cot lr=$lr oversample=$oversample)"
39
+ CUDA_VISIBLE_DEVICES="$gpu" nohup "$PY" -u "$SFT_PY" \
40
+ --model_name Qwen/Qwen2.5-1.5B-Instruct \
41
+ --train_jsonl "$TRAIN" \
42
+ --eval_jsonl "$EVAL" \
43
+ --output_dir "$out" \
44
+ --cache_dir "$REPO/.hf_cache" \
45
+ --init_adapter_dir "$INIT_ADAPTER" \
46
+ --seed 0 \
47
+ --gpu_id 0 \
48
+ --stage_i 3 \
49
+ --num_cot_tokens "$cot" \
50
+ --latent_mode recurrent_hidden \
51
+ --total_empties_hint 20 \
52
+ --per_device_train_batch_size 8 \
53
+ --gradient_accumulation_steps 4 \
54
+ --num_epochs 256 \
55
+ --learning_rate "$lr" \
56
+ --max_grad_norm 1.0 \
57
+ --logging_steps 25 \
58
+ --eval_steps 250 \
59
+ --save_steps 250 \
60
+ --eval_rows 100 \
61
+ --max_completion_length 24 \
62
+ --limit_train_rows 10000 \
63
+ --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
64
+ --multi_value_oversample_factor "$oversample" \
65
+ --train_target_size_min 0 --train_target_size_max 0 \
66
+ --eval_value_precision_stop 0 \
67
+ --eval_value_recall_stop 0 \
68
+ --eval_exact_set_match_stop 0 \
69
+ --eval_solve_rate_stop 0 \
70
+ --min_steps_before_stop 1000000 \
71
+ --max_wall_clock_seconds 0 \
72
+ --max_steps 3000 \
73
+ --enable_gradient_checkpointing \
74
+ > "$out/train.log" 2>&1 &
75
+ PIDS+=("$!")
76
+ done
77
+
78
+ echo "Launched ${#PIDS[@]} jobs: ${PIDS[*]}"
79
+ echo "$OUT_ROOT"
_experiments/cross_stage/watcher_launch_more.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Wait for the 6 cross-stage prediction jobs to finish, then launch 6 more
3
+ # long no-curr+CoT variants on the freed GPUs (0..5).
4
+ #
5
+ # Run with nohup.
6
+
7
+ LOG_DIR=/home/ubuntu/curriculum_cot/_experiments/cross_stage/logs
8
+ SCRIPT=/home/ubuntu/curriculum-cot-code/_experiments/cross_stage/run_nocurr_cot.sh
9
+
10
+ echo "[$(date +%T)] waiting for the 6 cross-stage jobs to finish..."
11
+ while true; do
12
+ done_count=0
13
+ for tag in atc_s1 atc_s2 atc_s3 dc_s1 dc_s2 dc_s3; do
14
+ if grep -q "DONE cells=" "$LOG_DIR/$tag.log" 2>/dev/null; then
15
+ done_count=$((done_count + 1))
16
+ fi
17
+ done
18
+ if [ "$done_count" -ge 6 ]; then
19
+ break
20
+ fi
21
+ sleep 60
22
+ done
23
+ echo "[$(date +%T)] cross-stage done; launching 6 more no-curr+CoT variants on GPUs 0..5"
24
+
25
+ bash "$SCRIPT" \
26
+ "0,nocurr_cot_k1_lr2e5_o5,1,2e-5,5" \
27
+ "1,nocurr_cot_k2_lr1e5_o5,2,1e-5,5" \
28
+ "2,nocurr_cot_k3_lr1e5_o5,3,1e-5,5" \
29
+ "3,nocurr_cot_k2_lr2e5_o10,2,2e-5,10" \
30
+ "4,nocurr_cot_k3_lr2e5_o10,3,2e-5,10" \
31
+ "5,nocurr_cot_k3_lr5e5_o5,3,5e-5,5"
32
+ echo "[$(date +%T)] 6 more launched."
_runs/_paper_figures/plot_stage_progression.py CHANGED
@@ -1,5 +1,12 @@
1
- """Paper-style figures: Solve rate + Per-cell exact across stages. Two separate figures,
2
- no titles, no footer text only axes, lines, markers, legend.
 
 
 
 
 
 
 
3
  """
4
  from __future__ import annotations
5
 
@@ -13,11 +20,40 @@ OUT_DIR = Path(__file__).resolve().parent
13
  # -------------------------------- DATA ---------------------------------------
14
  STAGES = ["Stage 1", "Stage 2", "Stage 3"]
15
 
16
- LATENT_SOLVE = [0.70, 0.50, 0.58]
17
- LATENT_EXACT = [0.95, 0.958, 0.967]
 
 
 
 
 
 
 
18
 
19
- BASELINE_SOLVE = [0.78, 0.40, 0.44]
20
- BASELINE_EXACT = [0.988, 0.88, 0.83]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # -------------------------------- STYLE --------------------------------------
23
  mpl.rcParams.update({
@@ -41,30 +77,36 @@ mpl.rcParams.update({
41
  "ps.fonttype": 42,
42
  })
43
 
44
- LATENT_COLOR = "#1f4f8b"
45
- BASELINE_COLOR = "#b21e2f"
 
46
  GRID_KW = dict(linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
47
  x = list(range(len(STAGES)))
48
 
49
 
50
- def _plot(y_latent, y_baseline, ylim, yticks, ylabel, fname):
51
  fig, ax = plt.subplots(figsize=(4.6, 3.4), constrained_layout=True)
52
  ax.plot(
53
- x, y_latent,
54
- color=LATENT_COLOR, marker="s", linestyle="-",
55
- label="Latent (recurrent-hidden)",
56
  )
57
  ax.plot(
58
- x, y_baseline,
59
- color=BASELINE_COLOR, marker="o", linestyle="--",
60
- label="Baseline (vanilla 1.5B)",
 
 
 
 
 
61
  )
62
  ax.set_xticks(x, STAGES)
63
  ax.set_ylim(*ylim)
64
  ax.set_yticks(yticks)
65
  ax.set_ylabel(ylabel)
66
  ax.grid(True, axis="y", **GRID_KW)
67
- ax.legend(frameon=False, loc="best")
68
  fig.savefig(OUT_DIR / f"{fname}.pdf", bbox_inches="tight")
69
  fig.savefig(OUT_DIR / f"{fname}.png", dpi=300, bbox_inches="tight")
70
  plt.close(fig)
@@ -72,16 +114,34 @@ def _plot(y_latent, y_baseline, ylim, yticks, ylabel, fname):
72
 
73
 
74
  _plot(
75
- LATENT_SOLVE, BASELINE_SOLVE,
76
  ylim=(0.0, 1.0),
77
  yticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
78
  ylabel="Solve rate",
79
  fname="stage_progression_solve",
 
80
  )
81
  _plot(
82
- LATENT_EXACT, BASELINE_EXACT,
83
- ylim=(0.78, 1.00),
84
- yticks=[0.80, 0.84, 0.88, 0.92, 0.96, 1.00],
85
  ylabel="Per-cell set-match rate",
86
  fname="stage_progression_exact",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
 
1
+ """Paper-style figures: Solve rate / Per-cell exact / Value precision / Value recall
2
+ across the three curriculum stages. Four separate figures, no titles, no footer
3
+ text — only axes, lines, markers, legend.
4
+
5
+ Three series in every figure:
6
+ ATC — latent recurrent-hidden, stage-curriculum (S1 / S2 / S3)
7
+ Data Curriculum — vanilla 1.5B, stage-curriculum (S1 / S2 / S3)
8
+ No CoT, No Curr. — vanilla 1.5B trained on the Stage-3 task only,
9
+ no thought tokens, no curriculum (horizontal reference)
10
  """
11
  from __future__ import annotations
12
 
 
20
  # -------------------------------- DATA ---------------------------------------
21
  STAGES = ["Stage 1", "Stage 2", "Stage 3"]
22
 
23
+ # Solve rate
24
+ ATC_SOLVE = [0.70, 0.50, 0.58]
25
+ DC_SOLVE = [0.78, 0.40, 0.44]
26
+ NOCURR_SOLVE = 0.33
27
+
28
+ # Per-cell exact set-match
29
+ ATC_EXACT = [0.95, 0.958, 0.967]
30
+ DC_EXACT = [0.988, 0.88, 0.83]
31
+ NOCURR_EXACT = 0.80
32
 
33
+ # Value precision
34
+ # ATC S1: approximated (Stage-1 latent SFT/GRPO log only stores reward;
35
+ # Stage-1 GRPO converged with solve≈0.95 on 40p eval → prec~0.96).
36
+ # ATC S2: STAGE12_TRAJECTORY.md, step 2600 (best per-cell): prec=0.960.
37
+ # ATC S3: headtohead_s3/s3_grpo_baseline step200: prec=0.967.
38
+ # DC S1: baseline_lr1e4/s1_grpo_v2 (solve 0.78): prec=0.996.
39
+ # DC S2: baseline_lr5e5_lowsft_v3/s2_sft_v3 step 3000: prec=0.911.
40
+ # DC S3: baseline v6_i_sft_v_oversample10/s3_sft step 200: prec=0.955.
41
+ # NoCurr: strawman_warm_e (lr=1e-5, oversample=5) SFT-end: prec=0.945.
42
+ ATC_PREC = [0.96, 0.960, 0.967]
43
+ DC_PREC = [0.996, 0.911, 0.955]
44
+ NOCURR_PREC = 0.945
45
+
46
+ # Value recall
47
+ # ATC S1: approximated (see prec note); rec~0.96.
48
+ # ATC S2: STAGE12_TRAJECTORY.md, step 2600: rec=0.949.
49
+ # ATC S3: headtohead_s3/s3_grpo_baseline step200: rec=0.968.
50
+ # DC S1: baseline_lr1e4/s1_grpo_v2: rec=0.998.
51
+ # DC S2: baseline_lr5e5_lowsft_v3/s2_sft_v3 step 3000: rec=0.931.
52
+ # DC S3: baseline v6_i_sft_v_oversample10/s3_sft step 200: rec=0.954.
53
+ # NoCurr: strawman_warm_e SFT-end: rec=0.944.
54
+ ATC_REC = [0.96, 0.949, 0.968]
55
+ DC_REC = [0.998, 0.931, 0.954]
56
+ NOCURR_REC = 0.944
57
 
58
  # -------------------------------- STYLE --------------------------------------
59
  mpl.rcParams.update({
 
77
  "ps.fonttype": 42,
78
  })
79
 
80
+ ATC_COLOR = "#1f4f8b"
81
+ DC_COLOR = "#b21e2f"
82
+ NOCURR_COLOR = "#3a7d3a"
83
  GRID_KW = dict(linestyle=":", linewidth=0.7, color="0.7", alpha=0.7)
84
  x = list(range(len(STAGES)))
85
 
86
 
87
+ def _plot(y_atc, y_dc, y_nocurr, ylim, yticks, ylabel, fname, legend_loc):
88
  fig, ax = plt.subplots(figsize=(4.6, 3.4), constrained_layout=True)
89
  ax.plot(
90
+ x, y_atc,
91
+ color=ATC_COLOR, marker="s", linestyle="-",
92
+ label="ATC",
93
  )
94
  ax.plot(
95
+ x, y_dc,
96
+ color=DC_COLOR, marker="o", linestyle="--",
97
+ label="Data Curriculum",
98
+ )
99
+ ax.axhline(
100
+ y=y_nocurr,
101
+ color=NOCURR_COLOR, linestyle=":", linewidth=2.0,
102
+ label="No CoT, No Curriculum",
103
  )
104
  ax.set_xticks(x, STAGES)
105
  ax.set_ylim(*ylim)
106
  ax.set_yticks(yticks)
107
  ax.set_ylabel(ylabel)
108
  ax.grid(True, axis="y", **GRID_KW)
109
+ ax.legend(frameon=False, loc=legend_loc)
110
  fig.savefig(OUT_DIR / f"{fname}.pdf", bbox_inches="tight")
111
  fig.savefig(OUT_DIR / f"{fname}.png", dpi=300, bbox_inches="tight")
112
  plt.close(fig)
 
114
 
115
 
116
  _plot(
117
+ ATC_SOLVE, DC_SOLVE, NOCURR_SOLVE,
118
  ylim=(0.0, 1.0),
119
  yticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
120
  ylabel="Solve rate",
121
  fname="stage_progression_solve",
122
+ legend_loc="upper right",
123
  )
124
  _plot(
125
+ ATC_EXACT, DC_EXACT, NOCURR_EXACT,
126
+ ylim=(0.70, 1.00),
127
+ yticks=[0.72, 0.76, 0.80, 0.84, 0.88, 0.92, 0.96, 1.00],
128
  ylabel="Per-cell set-match rate",
129
  fname="stage_progression_exact",
130
+ legend_loc="lower left",
131
+ )
132
+ _plot(
133
+ ATC_PREC, DC_PREC, NOCURR_PREC,
134
+ ylim=(0.86, 1.00),
135
+ yticks=[0.88, 0.90, 0.92, 0.94, 0.96, 0.98, 1.00],
136
+ ylabel="Value precision",
137
+ fname="stage_progression_precision",
138
+ legend_loc="lower left",
139
+ )
140
+ _plot(
141
+ ATC_REC, DC_REC, NOCURR_REC,
142
+ ylim=(0.86, 1.00),
143
+ yticks=[0.88, 0.90, 0.92, 0.94, 0.96, 0.98, 1.00],
144
+ ylabel="Value recall",
145
+ fname="stage_progression_recall",
146
+ legend_loc="lower left",
147
  )
_runs/_paper_figures/stage_progression_exact.pdf CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0734e025a134db1521ff9e06a7a1d00fac1a7786ed7ddd48aba25082b25cfa16
3
- size 12154
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8da7a71ea03841c0166f1d8734a956aa6023e79ce1c8d5bcc3bb94f8d3f80804
3
+ size 12092
_runs/_paper_figures/stage_progression_exact.png CHANGED

Git LFS Details

  • SHA256: abfe30ccf31423d1359956f1a8806fb846e1254eac44a6d7f53bdcd02439c6b9
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB

Git LFS Details

  • SHA256: 416924a11e0a749a27f6a5f6c3429f2be4b9895bcb59a0428e96f45b9ca9f72f
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
_runs/_paper_figures/stage_progression_precision.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57ab65caf1b9581dfe3664a91ae67bac345f38915ed6953f8d1afe96a65a18c3
3
+ size 12157
_runs/_paper_figures/stage_progression_precision.png ADDED

Git LFS Details

  • SHA256: 9fb5443a968db73b88a1d6327a713721fc7d43671bee6413d78987025b8ffc68
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
_runs/_paper_figures/stage_progression_recall.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b7dc132374b2ce49f218c0bb8630300677772f6dc321053c2a012edec5562d5
3
+ size 11639
_runs/_paper_figures/stage_progression_recall.png ADDED

Git LFS Details

  • SHA256: 20f0f80d304bc4bbe1f9dc420391d37ab29f2f3d3304588780372328bc48889f
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
_runs/_paper_figures/stage_progression_solve.pdf CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:004adf9132094bc6d0c639c3cca706be383b77efde5ad5313bd1f26beb656f38
3
- size 11689
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8a1da126cfa7b68631a750650bce83b817f3b9a1833680dc366c12a798d89d2
3
+ size 11398
_runs/_paper_figures/stage_progression_solve.png CHANGED

Git LFS Details

  • SHA256: 4584fce300940509e3eddf1294d040833f6e868e153548f634b7635948b3f745
  • Pointer size: 130 Bytes
  • Size of remote file: 93 kB

Git LFS Details

  • SHA256: b8e170d917909e2f8a051fceb4dbd0da6cc72a6908117c0e2174ca2d1866f166
  • Pointer size: 130 Bytes
  • Size of remote file: 91.5 kB