Avra98 commited on
Commit
48c96cf
·
verified ·
1 Parent(s): 76de008

Add data/ JSONLs + _runs/ launch scripts (override .gitignore)

Browse files
_runs/LATENT_PID.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 164065 0-7 latent_reproduction_20260524_062728
_runs/adaptive_k_cellpolicy_pipeline.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Adaptive-k cell-policy pipeline (no curriculum).
3
+
4
+ Wraps the existing per-cell trainers to implement an "adaptive-k" schedule:
5
+ the model is trained at stage_i=3 only (no curriculum), with the number of
6
+ recurrent-hidden thought tokens k starting at 0 (vanilla SFT) and being
7
+ incremented whenever the eval exact_set_match metric plateaus. Each phase
8
+ runs ``sft_latent_multi_output_train.py`` for ``steps_per_phase`` SFT steps
9
+ at fixed k, initialised from the previous phase's latest checkpoint (so the
10
+ recurrent-hidden bank persists). After the final SFT phase, ``grpo_residual_projector_latent_train.py``
11
+ is invoked at the converged k.
12
+
13
+ The trainer scripts, prompt template, and scoring function are the *same*
14
+ ones used by every cell-policy / latent experiment. The only knob this
15
+ orchestrator provides is the k-schedule; per-cell prompt+supervision is
16
+ handled by the existing trainers.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import os
24
+ import re
25
+ import shutil
26
+ import subprocess
27
+ import sys
28
+ import time
29
+ from pathlib import Path
30
+ from typing import List, Optional
31
+
32
+ ROOT = Path(__file__).resolve().parent.parent
33
+ SFT_SCRIPT = ROOT / "latent_multi_output_cell_policy" / "sft_latent_multi_output_train.py"
34
+ GRPO_SCRIPT = ROOT / "latent_multi_output_cell_policy" / "grpo_residual_projector_latent_train.py"
35
+ TRAIN_JSONL = ROOT / "data" / "sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
36
+ EVAL_JSONL = ROOT / "data" / "sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"
37
+
38
+
39
def parse_args() -> argparse.Namespace:
    """Parse the orchestrator's command-line options.

    ``--variant``, ``--gpu`` and ``--output_root`` are required. The remaining
    options configure the k-schedule (``--start_k``, ``--max_k``,
    ``--steps_per_phase``, ``--max_phases_per_k``, ``--plateau_eps``) and the
    hyperparameters forwarded to the SFT / GRPO trainer subprocesses.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--variant", required=True)
    p.add_argument("--gpu", required=True)
    p.add_argument("--output_root", required=True)
    p.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--cache_dir", default=str(ROOT / ".hf_cache"))
    p.add_argument("--python_bin", default="/opt/pytorch/bin/python")
    p.add_argument("--latent_mode", default="recurrent_hidden")
    p.add_argument("--start_k", type=int, default=0)
    p.add_argument("--max_k", type=int, default=4)
    p.add_argument("--steps_per_phase", type=int, default=600)
    p.add_argument(
        "--max_phases_per_k",
        type=int,
        default=2,
        help="Hard cap on how many ``steps_per_phase`` chunks to spend at a single k before bumping.",
    )
    p.add_argument(
        "--plateau_eps",
        type=float,
        default=0.01,
        help="If eval exact_set_match_rate improves by less than this between two consecutive phases at the same k, declare a plateau and bump k.",
    )
    p.add_argument("--sft_lr", type=float, default=2e-5)
    p.add_argument("--sft_bs", type=int, default=8)
    p.add_argument("--sft_ga", type=int, default=4)
    p.add_argument("--sft_oversample", type=int, default=3)
    p.add_argument("--grpo_steps", type=int, default=1500)
    p.add_argument("--grpo_lr", type=float, default=5e-6)
    p.add_argument("--grpo_bs", type=int, default=8)
    p.add_argument("--grpo_ga", type=int, default=4)
    p.add_argument("--grpo_ng", type=int, default=8)
    p.add_argument("--grpo_beta", type=float, default=0.0)
    p.add_argument("--grpo_max_prompt", type=int, default=768)
    p.add_argument("--grpo_max_completion", type=int, default=24)
    p.add_argument("--eval_rows", type=int, default=100)
    p.add_argument("--train_rows", type=int, default=10000)
    # Gradient checkpointing is on by default. ``--enable_gc`` is kept so that
    # existing launch scripts keep working, but a ``store_true`` flag whose
    # default is already True can never become False; ``--disable_gc`` writes
    # to the same destination and is the way to switch it off.
    p.add_argument("--enable_gc", dest="enable_gc", action="store_true", default=True)
    p.add_argument(
        "--disable_gc",
        dest="enable_gc",
        action="store_false",
        help="Do not pass --enable_gradient_checkpointing to the trainer subprocesses.",
    )
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()
80
+
81
+
82
+ # ---- log parsing -----------------------------------------------------------
83
+
84
# Captures the first ``0.xxx`` / ``1.xxx`` number that follows the metric name
# on a log line.
EVAL_RE = re.compile(r"exact_set_match_rate.*?([01]\.\d+)")


def latest_eval_metric(log_path: Path) -> Optional[float]:
    """Return the most recent eval exact_set_match_rate from the SFT train log.

    The log is scanned line by line and the value from the last line matching
    ``EVAL_RE`` wins.

    Args:
        log_path: Path to the trainer's combined stdout/stderr log.

    Returns:
        The last matched rate as a float, or ``None`` when the file does not
        exist or contains no matching line.
    """
    if not log_path.exists():
        return None
    last: Optional[float] = None
    # The log is raw subprocess output; decode leniently so a single stray
    # byte cannot abort the whole pipeline after a finished training phase.
    with open(log_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            m = EVAL_RE.search(line)
            if m:
                # The pattern only captures ``digit.digits``, which float()
                # always accepts, so no ValueError handling is needed.
                last = float(m.group(1))
    return last
101
+
102
+
103
def latest_ckpt_dir(out_dir: Path) -> Optional[Path]:
    """Return the highest-step ``checkpoint-step-<N>`` directory under ``out_dir``.

    Sub-directories whose suffix after the last ``-`` is not a plain integer
    (for example ``checkpoint-step-best``) are ignored instead of raising.

    Args:
        out_dir: Output directory of one SFT phase.

    Returns:
        The checkpoint directory with the largest step number; ``out_dir``
        itself when it has no such checkpoint but directly contains
        ``adapter_model.safetensors``; otherwise ``None`` (also when
        ``out_dir`` is missing or is not a directory).
    """
    if not out_dir.is_dir():
        return None
    numbered = []
    for p in out_dir.iterdir():
        if not (p.is_dir() and p.name.startswith("checkpoint-step-")):
            continue
        suffix = p.name.split("-")[-1]
        if suffix.isdigit():
            numbered.append((int(suffix), p))
    if numbered:
        # Compare on the integer step so that step-1000 beats step-200.
        return max(numbered, key=lambda item: item[0])[1]
    if (out_dir / "adapter_model.safetensors").exists():
        return out_dir
    return None
115
+
116
+
117
def best_grpo_ckpt(out_dir: Path) -> Optional[Path]:
    """Pick the GRPO checkpoint directory with the largest numeric suffix.

    Every sub-directory named ``checkpoint-*`` is a candidate; a candidate
    whose trailing ``-`` component is not all digits sorts as step ``-1``.
    Falls back to ``out_dir`` when it directly holds
    ``adapter_model.safetensors``, and to ``None`` when nothing is found or
    ``out_dir`` does not exist.
    """

    def step_of(candidate: Path) -> int:
        tail = candidate.name.split("-")[-1]
        return int(tail) if tail.isdigit() else -1

    if not out_dir.exists():
        return None

    candidates = []
    for entry in out_dir.iterdir():
        if entry.is_dir() and entry.name.startswith("checkpoint-"):
            candidates.append(entry)
    candidates.sort(key=step_of)

    if candidates:
        return candidates[-1]
    adapter_file = out_dir / "adapter_model.safetensors"
    return out_dir if adapter_file.exists() else None
129
+
130
+
131
+ # ---- subprocess wrappers ---------------------------------------------------
132
+
133
+
134
def run_sft_phase(
    *,
    args: argparse.Namespace,
    phase_dir: Path,
    init_adapter: str,
    num_cot_tokens: int,
    max_steps: int,
) -> Path:
    """Run one SFT trainer subprocess at a fixed k and return its newest checkpoint.

    The trainer's stdout and stderr are written to ``phase_dir / "train.log"``
    (truncated on each call). The subprocess sees only the GPU named by
    ``args.gpu`` via ``CUDA_VISIBLE_DEVICES``.

    Raises:
        RuntimeError: if the trainer exits non-zero, or if no checkpoint can
            be found under ``phase_dir`` afterwards.
    """
    phase_dir.mkdir(parents=True, exist_ok=True)
    log_path = phase_dir / "train.log"

    # (flag, value) pairs, in the order they are placed on the command line.
    options = [
        ("--model_name", args.model_name),
        ("--train_jsonl", str(TRAIN_JSONL)),
        ("--eval_jsonl", str(EVAL_JSONL)),
        ("--output_dir", str(phase_dir)),
        ("--cache_dir", args.cache_dir),
        ("--init_adapter_dir", str(init_adapter)),
        ("--seed", str(args.seed)),
        ("--gpu_id", "0"),
        ("--stage_i", "3"),
        ("--num_cot_tokens", str(int(num_cot_tokens))),
        ("--latent_mode", args.latent_mode),
        ("--total_empties_hint", "20"),
        ("--per_device_train_batch_size", str(args.sft_bs)),
        ("--gradient_accumulation_steps", str(args.sft_ga)),
        ("--num_epochs", "256"),
        ("--learning_rate", str(args.sft_lr)),
        ("--max_grad_norm", "1.0"),
        ("--logging_steps", "25"),
        ("--eval_steps", "200"),
        ("--save_steps", "200"),
        ("--eval_rows", str(args.eval_rows)),
        ("--max_completion_length", "24"),
        ("--limit_train_rows", str(args.train_rows)),
        ("--lora_r", "32"),
        ("--lora_alpha", "64"),
        ("--lora_dropout", "0.05"),
        ("--multi_value_oversample_factor", str(args.sft_oversample)),
        ("--max_steps", str(int(max_steps))),
    ]
    cmd = [args.python_bin, "-u", str(SFT_SCRIPT)]
    for flag, value in options:
        cmd += [flag, value]
    if args.enable_gc:
        cmd.append("--enable_gradient_checkpointing")

    for banner in (
        f"[adaptive-k] >>> SFT phase k={num_cot_tokens} max_steps={max_steps}",
        f"[adaptive-k] init={init_adapter or '(BASE)'}",
        f"[adaptive-k] out={phase_dir}",
        f"[adaptive-k] log={log_path}",
    ):
        print(banner, flush=True)

    # Inherit the parent environment, then pin the GPU and the HF cache.
    env = dict(os.environ)
    env.update(
        {
            "CUDA_VISIBLE_DEVICES": str(args.gpu),
            "TOKENIZERS_PARALLELISM": "false",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            "HF_HOME": args.cache_dir,
            "TRANSFORMERS_CACHE": args.cache_dir,
        }
    )
    with open(log_path, "w") as logf:
        completed = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT, env=env)
    if completed.returncode != 0:
        raise RuntimeError(f"SFT phase k={num_cot_tokens} failed (exit {completed.returncode}); see {log_path}")

    newest = latest_ckpt_dir(phase_dir)
    if newest is None:
        raise RuntimeError(f"No checkpoint produced under {phase_dir}")
    return newest
229
+
230
+
231
def run_grpo_phase(
    *,
    args: argparse.Namespace,
    phase_dir: Path,
    init_adapter: str,
    num_cot_tokens: int,
    max_steps: int,
) -> Optional[Path]:
    """Run the GRPO trainer subprocess at a fixed k.

    Trainer output goes to ``phase_dir / "train.log"``. Unlike the SFT
    wrapper, a non-zero exit status only prints a warning; whatever checkpoint
    ``best_grpo_ckpt`` can find under ``phase_dir`` is still returned
    (``None`` if there is none).
    """
    phase_dir.mkdir(parents=True, exist_ok=True)
    log_path = phase_dir / "train.log"

    # (flag, value) pairs, in the order they are placed on the command line.
    options = [
        ("--model_name", args.model_name),
        ("--train_jsonl", str(TRAIN_JSONL)),
        ("--eval_jsonl", str(EVAL_JSONL)),
        ("--output_dir", str(phase_dir)),
        ("--cache_dir", args.cache_dir),
        ("--init_adapter_dir", str(init_adapter)),
        ("--seed", str(args.seed)),
        ("--gpu_id", "0"),
        ("--stage_i", "3"),
        ("--num_cot_tokens", str(int(num_cot_tokens))),
        ("--latent_mode", args.latent_mode),
        ("--total_empties_hint", "20"),
        ("--per_device_train_batch_size", str(args.grpo_bs)),
        ("--gradient_accumulation_steps", str(args.grpo_ga)),
        ("--num_train_epochs", "100"),
        ("--learning_rate", str(args.grpo_lr)),
        ("--logging_steps", "10"),
        ("--save_steps", "200"),
        ("--eval_steps", "150"),
        ("--eval_rows", str(args.eval_rows)),
        ("--num_generations", str(args.grpo_ng)),
        ("--max_prompt_length", str(args.grpo_max_prompt)),
        ("--max_completion_length", str(args.grpo_max_completion)),
        ("--beta", str(args.grpo_beta)),
        ("--limit_train_rows", str(args.train_rows)),
        ("--lora_r", "32"),
        ("--lora_alpha", "64"),
        ("--lora_dropout", "0.05"),
        ("--max_steps", str(int(max_steps))),
    ]
    cmd = [args.python_bin, "-u", str(GRPO_SCRIPT)]
    for flag, value in options:
        cmd += [flag, value]
    if args.enable_gc:
        cmd.append("--enable_gradient_checkpointing")

    for banner in (
        f"[adaptive-k] >>> GRPO phase k={num_cot_tokens} max_steps={max_steps}",
        f"[adaptive-k] init={init_adapter}",
        f"[adaptive-k] out={phase_dir}",
    ):
        print(banner, flush=True)

    # Inherit the parent environment, then pin the GPU and the HF cache.
    env = dict(os.environ)
    env.update(
        {
            "CUDA_VISIBLE_DEVICES": str(args.gpu),
            "TOKENIZERS_PARALLELISM": "false",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            "HF_HOME": args.cache_dir,
            "TRANSFORMERS_CACHE": args.cache_dir,
        }
    )
    with open(log_path, "w") as logf:
        completed = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT, env=env)
    if completed.returncode != 0:
        # Best effort: report the failure but still surface any checkpoint.
        print(f"[adaptive-k] WARN: GRPO failed exit={completed.returncode}, see {log_path}", flush=True)
    return best_grpo_ckpt(phase_dir)
321
+
322
+
323
+ # ---- main loop -------------------------------------------------------------
324
+
325
+
326
def main() -> None:
    """Drive the adaptive-k SFT schedule, then a single GRPO phase.

    SFT phases of ``steps_per_phase`` steps are launched one after another,
    each initialised from the checkpoint returned by the previous phase. After
    a phase, k is incremented when either ``max_phases_per_k`` phases have
    been spent at the current k, or the eval metric improved by less than
    ``plateau_eps`` over the previous phase at the same k. The loop ends after
    the first phase run at ``max_k``. Progress is appended to ``PIPELINE.log``
    and the running history is rewritten to ``STATE.json`` under
    ``output_root``.
    """
    args = parse_args()
    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    state_path = output_root / "STATE.json"
    pipeline_log = output_root / "PIPELINE.log"

    def log(msg: str) -> None:
        # Echo to stdout and append to the persistent pipeline log.
        stamped = f"[{time.strftime('%H:%M:%S')}] {msg}"
        print(stamped, flush=True)
        with open(pipeline_log, "a") as sink:
            sink.write(stamped + "\n")

    log(f"===== ADAPTIVE-K {args.variant} on GPU {args.gpu} =====")
    log(f" start_k={args.start_k} max_k={args.max_k} steps_per_phase={args.steps_per_phase} max_phases_per_k={args.max_phases_per_k}")
    log(f" plateau_eps={args.plateau_eps} sft_lr={args.sft_lr} grpo_lr={args.grpo_lr}")
    log(f" output_root={output_root}")

    max_k = int(args.max_k)
    phase_cap = int(args.max_phases_per_k)
    eps = float(args.plateau_eps)

    records: List[dict] = []
    k = int(args.start_k)
    adapter: str = ""  # "" -> train from base
    prev_metric: Optional[float] = None
    phases_here = 0
    phase_no = 0

    while k <= max_k:
        phase_no += 1
        phase_dir = output_root / f"sft_phase{phase_no:02d}_k{k}"
        ckpt = run_sft_phase(
            args=args,
            phase_dir=phase_dir,
            init_adapter=adapter,
            num_cot_tokens=k,
            max_steps=int(args.steps_per_phase),
        )
        metric = latest_eval_metric(phase_dir / "train.log")
        log(
            f" phase{phase_no} k={k} ckpt={ckpt.name} eval_exact_set_match_rate={metric}"
        )
        records.append(
            {
                "phase": phase_no,
                "k": k,
                "phase_dir": str(phase_dir),
                "ckpt": str(ckpt),
                "exact_set_match_rate": metric,
            }
        )
        with open(state_path, "w") as fh:
            json.dump({"history": records, "cur_k": k, "cur_ckpt": str(ckpt)}, fh, indent=2)

        adapter = str(ckpt)
        phases_here += 1

        if k >= max_k:
            log(f" reached max_k={args.max_k}, stopping SFT loop")
            break

        # An improvement is only defined with two metrics at the same k.
        have_both = prev_metric is not None and metric is not None
        improvement = float(metric) - float(prev_metric) if have_both else None
        log(f" improvement_at_k={improvement} phases_at_k={phases_here}/{args.max_phases_per_k}")

        if phases_here >= phase_cap:
            log(" hit max_phases_per_k, bumping k")
        elif improvement is not None and improvement < eps:
            log(f" improvement {improvement:.4f} < plateau_eps {args.plateau_eps:.4f}, bumping k")
        else:
            # Stay at this k; remember the metric for the next comparison.
            prev_metric = metric
            continue

        k += 1
        prev_metric = None
        phases_here = 0

    log(f"===== final SFT k={k} ckpt={adapter} =====")
    grpo_dir = output_root / f"grpo_k{k}"
    grpo_ckpt = run_grpo_phase(
        args=args,
        phase_dir=grpo_dir,
        init_adapter=adapter,
        num_cot_tokens=k,
        max_steps=int(args.grpo_steps),
    )
    log(f"===== GRPO done ckpt={grpo_ckpt} =====")
    summary = {
        "history": records,
        "final_k": k,
        "final_sft_ckpt": adapter,
        "grpo_ckpt": str(grpo_ckpt) if grpo_ckpt else None,
    }
    with open(state_path, "w") as fh:
        json.dump(summary, fh, indent=2)
    log(f"===== ADAPTIVE-K {args.variant} done =====")
428
+
429
+ if __name__ == "__main__":
430
+ main()
_runs/adaptive_latent_baseline_sudoku_train.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Adaptive-k thought-token baseline (experiment D in the 2x2 ablation).
3
+
4
+ Same single-stage, whole-puzzle setup as `simple_baseline_sudoku_train.py`
5
+ (experiment C, the "strawman"). Same JSONL data, same chat template, same
6
+ model, same LoRA. The ONLY difference is that this run inserts k recurrent
7
+ thought tokens between the prompt and the next-token logits, and grows k
8
+ on demand whenever the SFT loss plateaus.
9
+
10
+ Algorithm:
11
+ k = 0 (start as the vanilla baseline)
12
+ repeat:
13
+ train SFT for `min_steps_per_k` steps with current k
14
+ if rolling_avg(loss[-w:]) - rolling_avg(loss[-2w:-w]) > -plateau_eps:
15
+ k += 1 # grow capacity
16
+ if k > max_k: break
17
+ if loss has been steadily decreasing past `min_steps_per_k * 3`:
18
+ break # converged
19
+ save final adapter
20
+
21
+ The recurrent_hidden mechanism is imported verbatim from
22
+ `latent_multi_output_cell_policy.grpo_residual_projector_latent_train`
23
+ (via `latent_batched_completion_ce_loss`). For k=0 the loss reduces to
24
+ vanilla next-token CE, so the trajectory smoothly continues from the
25
+ strawman.
26
+
27
+ Reward / loss contract (see `simple_baseline_sudoku_train.py` for details):
28
+ - supervision is token-level CE against the JSONL `completion` field
29
+ (the 20 ground-truth digits at the 20 empty cells, row-major).
30
+ - this script is SFT-only; you can chain GRPO afterwards by passing the
31
+ saved adapter to `simple_baseline_sudoku_train.py --phase grpo`.
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import argparse
37
+ import json
38
+ import math
39
+ import os
40
+ import sys
41
+ import time
42
+ from collections import deque
43
+ from pathlib import Path
44
+ from typing import Any, Dict, List, Tuple
45
+
46
+ import torch
47
+ import torch.nn.functional as F
48
+ from peft import LoraConfig, PeftModel, get_peft_model
49
+ from torch.optim import AdamW
50
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
51
+
52
+ ROOT = Path(__file__).resolve().parent.parent
53
+ if str(ROOT) not in sys.path:
54
+ sys.path.insert(0, str(ROOT))
55
+
56
+ # Reuse helpers and the latent loss from the curriculum codebase. NO
57
+ # re-implementation of the recurrent_hidden mechanism here.
58
+ from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore
59
+ load_jsonl_rows,
60
+ pick_dtype,
61
+ )
62
+ from latent_multi_output_cell_policy.sft_latent_multi_output_train import ( # type: ignore
63
+ latent_batched_completion_ce_loss,
64
+ )
65
+ from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import ( # type: ignore
66
+ recurrent_hidden_next_token_logits_from_ids,
67
+ )
68
+ from _runs.simple_baseline_sudoku_train import ( # type: ignore
69
+ SYSTEM_PROMPT_STRAWMAN,
70
+ build_chat_prompt,
71
+ parse_int_list,
72
+ )
73
+ from multi_output_cell_policy.rewards import score_prediction_text # type: ignore
74
+ from multi_output_cell_policy.shared_multi_output_policy import ( # type: ignore
75
+ make_solved_grid_from_row,
76
+ stage_i_consistent_values,
77
+ )
78
+ from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row # type: ignore
79
+
80
+
81
+ # ---- Tokenization (mirror what latent_batched_completion_ce_loss expects) ---
82
+
83
+
84
def tokenize_example(
    tokenizer: Any,
    raw_prompt: str,
    completion_text: str,
    max_prompt_length: int,
    max_completion_length: int,
) -> Dict[str, List[int]]:
    """Tokenize one (prompt, completion) pair for the latent CE loss.

    The raw prompt is wrapped with ``build_chat_prompt`` and the completion
    gets the tokenizer's EOS string appended (``"<|endoftext|>"`` when the
    tokenizer defines none). The prompt keeps its LAST ``max_prompt_length``
    token ids; the completion keeps its FIRST ``max_completion_length``.

    Returns:
        ``{"prompt_ids": [...], "completion_ids": [...]}``.
    """
    chat_text = build_chat_prompt(tokenizer, raw_prompt)
    encoded_prompt = tokenizer(chat_text, add_special_tokens=False)

    eos_text = tokenizer.eos_token or "<|endoftext|>"
    encoded_completion = tokenizer(completion_text + eos_text, add_special_tokens=False)

    return {
        "prompt_ids": encoded_prompt.input_ids[-max_prompt_length:],
        "completion_ids": encoded_completion.input_ids[:max_completion_length],
    }
100
+
101
+
102
+ # ---- Eval (autoregressive greedy decode WITH k recurrent thought tokens) ---
103
+
104
+
105
@torch.no_grad()
def latent_greedy_generate(
    model: torch.nn.Module,
    tokenizer: Any,
    prompt_text: str,
    device: torch.device,
    *,
    num_cot_tokens: int,
    max_new_tokens: int,
) -> str:
    """Greedy-decode a completion for one prompt using k thought tokens.

    At every step the next-token logits come from
    ``recurrent_hidden_next_token_logits_from_ids`` with
    ``max(0, num_cot_tokens)`` thought tokens; the argmax id is appended to
    the sequence. Decoding stops after ``max_new_tokens`` ids or right after
    the tokenizer's EOS id has been appended.

    Returns:
        The decoded new tokens (special tokens skipped), stripped.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)
    n_prompt = int(ids.shape[1])
    eos_id = tokenizer.eos_token_id
    thought_tokens = int(max(0, num_cot_tokens))

    remaining = int(max_new_tokens)
    while remaining > 0:
        remaining -= 1
        step_logits = recurrent_hidden_next_token_logits_from_ids(
            model, ids, mask, thought_tokens
        )
        chosen = int(torch.argmax(step_logits, dim=-1).item())
        chosen_col = torch.tensor([[chosen]], device=device, dtype=ids.dtype)
        one_col = torch.ones((1, 1), device=device, dtype=mask.dtype)
        ids = torch.cat([ids, chosen_col], dim=1)
        mask = torch.cat([mask, one_col], dim=1)
        if eos_id is not None and chosen == int(eos_id):
            break

    generated = ids[0, n_prompt:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
135
+
136
+
137
@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    *,
    num_cot_tokens: int,
    max_new_tokens: int,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Apples-to-apples eval with the cell-policy framework (see strawman script).

    The model emits the WHOLE puzzle (JSON list of integers) in one greedy
    rollout with `num_cot_tokens` recurrent thought tokens prepended at each
    step. We split that list into per-cell SINGLETON predictions and score
    each cell with ``score_prediction_text`` against the i-consistent target
    set at ``stage_i`` (default 3 — matches the S3 eval used for the v6
    baseline and the latent champion).

    Rows whose ``completion`` does not parse are ignored. Rows whose cell
    metadata cannot be built are skipped BEFORE generation and are not counted
    in ``n_total_puzzles``, so the puzzle-level rates are computed over the
    puzzles that were actually scored.

    Returns:
        A dict of float metrics: cell-level rates averaged over
        ``n_total_cells`` and puzzle-level ``solve_rate`` /
        ``puzzle_parse_fail_rate`` averaged over ``n_total_puzzles``.

    Side effects:
        Puts the model in eval mode for the duration and leaves it in train
        mode on return; prints up to ``print_n`` sample / skip lines.
    """
    model.eval()
    total_cells = 0
    parse_ok = 0.0
    canonical_ok = 0.0
    exact_set_match = 0.0
    includes_gt = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    cardinality_match_sum = 0.0
    n_solve = 0
    n_total_puzzles = 0
    n_parse_fail_puzzles = 0
    printed = 0
    for row in eval_rows:
        target_completion = parse_int_list(str(row["completion"]))
        if target_completion is None:
            continue

        # Validate the row's metadata first: an unscorable row should neither
        # cost a full greedy rollout nor enter the puzzle-level denominators.
        try:
            cells = build_cell_examples_from_row(row)
            solved = make_solved_grid_from_row(row)
        except Exception as e:  # noqa: BLE001
            if printed < print_n:
                print(f"[adaptive_k k={num_cot_tokens} eval] row skipped (no metadata): {e}", flush=True)
                printed += 1
            continue

        n_total_puzzles += 1
        prompt_text = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
        gen = latent_greedy_generate(
            model, tokenizer, prompt_text, device,
            num_cot_tokens=num_cot_tokens, max_new_tokens=max_new_tokens,
        )
        pred_list = parse_int_list(gen)

        row_all_exact = True
        row_has_eval_cell = False
        for idx, ex in enumerate(cells):
            target_values = stage_i_consistent_values(
                ex.grid, target_cell=ex.target_cell, stage_i=int(stage_i)
            )
            row_has_eval_cell = True
            # The idx-th generated integer is this cell's singleton
            # prediction; a missing entry is scored as empty text.
            if pred_list is not None and idx < len(pred_list):
                pred_text = json.dumps({"values": [int(pred_list[idx])]})
            else:
                pred_text = ""
            info = score_prediction_text(
                text=pred_text,
                grid=ex.grid,
                solved=solved,
                target_cell=ex.target_cell,
                stage_i=int(stage_i),
                reward_good_value=1.0,
                penalty_bad_value=1.75,
                penalty_malformed=4.0,
                penalty_empty=0.5,
                penalty_singleton=1.5,
            )
            total_cells += 1
            parse_ok += float(info["parse_ok"])
            canonical_ok += float(info["strict_canonical"])
            exact_set_match += float(info["exact_set_match"])
            includes_gt += float(info["includes_ground_truth"])
            precision_sum += float(info["value_precision"])
            recall_sum += float(info["value_recall"])
            if int(info["num_predicted_values"]) == int(len(target_values)):
                cardinality_match_sum += 1.0
            if float(info["exact_set_match"]) < 0.5:
                row_all_exact = False
        # A puzzle counts as solved only if every one of its cells matched.
        if row_has_eval_cell and row_all_exact:
            n_solve += 1
        if pred_list is None:
            n_parse_fail_puzzles += 1
        if printed < print_n:
            head_pred = pred_list if pred_list is not None else "PARSE_FAIL"
            print(
                f"[adaptive_k k={num_cot_tokens} eval] target={target_completion} pred={head_pred} "
                f"solve={int(row_all_exact and row_has_eval_cell)} gen={gen!r}",
                flush=True,
            )
            printed += 1
    model.train()
    cell_denom = max(1, total_cells)
    puzzle_denom = max(1, n_total_puzzles)
    return {
        "n_total_cells": float(total_cells),
        "n_total_puzzles": float(n_total_puzzles),
        "parse_rate": float(parse_ok / cell_denom),
        "strict_canonical_rate": float(canonical_ok / cell_denom),
        "exact_set_match_rate": float(exact_set_match / cell_denom),
        "includes_ground_truth_rate": float(includes_gt / cell_denom),
        "value_precision": float(precision_sum / cell_denom),
        "value_recall": float(recall_sum / cell_denom),
        "cardinality_match_rate": float(cardinality_match_sum / cell_denom),
        "puzzle_parse_fail_rate": float(n_parse_fail_puzzles / puzzle_denom),
        "solve_rate": float(n_solve) / puzzle_denom,
    }
252
+
253
+
254
+ # ---- Main loop --------------------------------------------------------------
255
+
256
+
257
def parse_args() -> argparse.Namespace:
    """Build the CLI for the adaptive-k SFT baseline and parse ``sys.argv``.

    ``--train_jsonl``, ``--eval_jsonl`` and ``--output_dir`` are required;
    every other option has a default.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    add("--train_jsonl", type=str, required=True)
    add("--eval_jsonl", type=str, required=True)
    add("--output_dir", type=str, required=True)
    add("--cache_dir", type=str, default=str(ROOT / ".hf_cache"))
    add("--init_adapter_dir", type=str, default="")
    add("--seed", type=int, default=0)

    # Dataset size limits.
    add("--limit_train_rows", type=int, default=10000)
    add("--eval_rows", type=int, default=50)

    # Optimisation settings.
    add("--per_device_train_batch_size", type=int, default=4)
    add("--gradient_accumulation_steps", type=int, default=2)
    add("--learning_rate", type=float, default=5e-5)
    add("--weight_decay", type=float, default=0.0)
    add("--max_steps", type=int, default=4000)
    add("--logging_steps", type=int, default=25)
    add("--save_steps", type=int, default=500)
    add("--eval_every_steps", type=int, default=500)
    add("--max_grad_norm", type=float, default=1.0)
    add("--max_completion_length", type=int, default=96)
    add("--max_prompt_length", type=int, default=1024)

    # LoRA adapter settings.
    add("--lora_r", type=int, default=32)
    add("--lora_alpha", type=int, default=64)
    add("--lora_dropout", type=float, default=0.05)
    add("--enable_gradient_checkpointing", action="store_true")

    # Schedule for growing k.
    add("--start_k", type=int, default=0)
    add("--max_k", type=int, default=4)
    add(
        "--min_steps_per_k",
        type=int,
        default=400,
        help="Minimum SFT steps to spend at each k before considering an increment.",
    )
    add(
        "--plateau_window",
        type=int,
        default=100,
        help="Sliding window (in steps) used to compute rolling-mean loss for plateau detection.",
    )
    add(
        "--plateau_eps",
        type=float,
        default=0.005,
        help="If rolling_mean(loss[-w:]) - rolling_mean(loss[-2w:-w]) > -plateau_eps -> plateau detected.",
    )
    add(
        "--converged_eps",
        type=float,
        default=0.001,
        help="If two consecutive plateau windows pass with delta within this band, we declare convergence and stop.",
    )

    return parser.parse_args()
319
+
320
+
321
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load the tokenizer and the LoRA-wrapped causal LM, moved to ``device``.

    When ``args.init_adapter_dir`` is non-blank the adapter is loaded from it
    as trainable; otherwise a fresh LoRA adapter is attached to the attention
    and MLP projection modules. Gradient checkpointing is enabled on request.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir, use_fast=True)
    if tokenizer.pad_token_id is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )

    resume_from_adapter = bool(str(args.init_adapter_dir).strip())
    if resume_from_adapter:
        model = PeftModel.from_pretrained(base_model, args.init_adapter_dir, is_trainable=True)
    else:
        lora_targets = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=lora_targets,
        )
        model = get_peft_model(base_model, lora_config)

    if args.enable_gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if hasattr(model, "config"):
            model.config.use_cache = False

    model.to(device)
    return model, tokenizer
353
+
354
+
355
def detect_plateau(losses: deque, window: int, plateau_eps: float) -> Tuple[bool, float]:
    """Compare the mean of the last ``window`` losses with the window before it.

    Args:
        losses: Recent per-step losses, oldest first.
        window: Number of steps in each of the two compared windows.
        plateau_eps: Minimum decrease of the rolling mean that still counts as
            progress.

    Returns:
        ``(is_plateau, delta)`` where ``delta`` is
        ``mean(losses[-w:]) - mean(losses[-2w:-w])`` and ``is_plateau`` is
        ``delta > -plateau_eps``. Returns ``(False, 0.0)`` when fewer than
        ``2 * window`` losses are available or when ``window`` is not
        positive (no meaningful window can be formed, and a zero-length
        window would otherwise divide by zero).
    """
    if window <= 0 or len(losses) < 2 * window:
        return False, 0.0
    tail = list(losses)[-2 * window:]
    prior = tail[:window]
    recent = tail[window:]
    delta = (sum(recent) / len(recent)) - (sum(prior) / len(prior))
    # If delta > -plateau_eps, loss hasn't decreased fast enough -> plateau.
    return (delta > -float(plateau_eps)), float(delta)
364
+
365
+
366
def save_adapter(model: torch.nn.Module, tokenizer: Any, out: str) -> None:
    """Write the model and tokenizer to ``out`` via their ``save_pretrained``.

    The directory is created if needed. An object without a
    ``save_pretrained`` attribute is silently skipped.
    """
    os.makedirs(out, exist_ok=True)
    # Model first, then tokenizer.
    for artifact in (model, tokenizer):
        if hasattr(artifact, "save_pretrained"):
            artifact.save_pretrained(out)
372
+
373
+
374
def main() -> None:
    """Run SFT while growing the thought-token count ``k`` on loss plateaus.

    Steps performed:
      1. Load train/eval rows and the model, then pre-tokenize the train rows
         (rows that fail to tokenize or have no completion ids are skipped).
      2. Train with gradient accumulation using
         ``latent_batched_completion_ce_loss`` in ``recurrent_hidden`` mode at
         the current ``k``.
      3. After each optimizer step, optionally log, evaluate and checkpoint,
         then run the plateau check: a plateau below ``max_k`` increments
         ``k``; a plateau at ``max_k`` with ``|delta| < converged_eps`` stops
         training.
      4. Save the final adapter and write ``summary.json`` to ``output_dir``.

    ``logging_steps``, ``eval_every_steps`` and ``save_steps`` values of zero
    or less disable the corresponding periodic action.

    Raises:
        RuntimeError: If no training example survives tokenization.
    """
    args = parse_args()
    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    print(f"[adaptive_k] loaded {len(train_rows)} train rows, {len(eval_rows)} eval rows", flush=True)

    model, tokenizer = setup_model_and_tokenizer(args, device)
    pad_id = int(tokenizer.pad_token_id)

    # Pre-tokenize the train set once.
    train_examples: List[Dict[str, Any]] = []
    for row in train_rows:
        try:
            ex = tokenize_example(
                tokenizer,
                str(row["prompt"]).strip(),
                str(row["completion"]).strip(),
                int(args.max_prompt_length),
                int(args.max_completion_length),
            )
            if ex["completion_ids"]:
                train_examples.append(ex)
        except Exception as e:  # noqa: BLE001
            # Best effort: one bad row must not abort the whole run.
            print(f"[adaptive_k] tokenize skip: {e}", flush=True)
    print(f"[adaptive_k] tokenized {len(train_examples)} train examples", flush=True)
    if not train_examples:
        # Without this guard the loop below would build empty batches forever
        # (or fail inside the loss with an unrelated error).
        raise RuntimeError(
            f"no usable training examples after tokenizing {len(train_rows)} rows "
            f"from {args.train_jsonl}"
        )

    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
    )

    # Loop-invariant settings, converted once.
    bs = int(args.per_device_train_batch_size)
    ga = int(args.gradient_accumulation_steps)
    max_steps = int(args.max_steps)
    max_grad_norm = float(args.max_grad_norm)
    max_new_tokens = int(args.max_completion_length)
    logging_steps = int(args.logging_steps)
    eval_every_steps = int(args.eval_every_steps)
    save_steps = int(args.save_steps)
    plateau_window = int(args.plateau_window)
    plateau_eps = float(args.plateau_eps)
    converged_eps = float(args.converged_eps)
    min_steps_per_k = int(args.min_steps_per_k)

    steps = 0
    # Bounded history of per-step losses used by the plateau detector.
    rolling: deque = deque(maxlen=2 * plateau_window + 16)
    k = int(args.start_k)
    max_k = int(args.max_k)
    steps_at_current_k = 0
    grew_at: List[Tuple[int, int]] = []  # (step, new_k)

    print(f"[adaptive_k] starting at k={k}", flush=True)
    init_eval = run_eval(
        model, tokenizer, eval_rows, device,
        num_cot_tokens=k, max_new_tokens=max_new_tokens,
    )
    print(f"[adaptive_k] init eval k={k}: {init_eval}", flush=True)

    t0 = time.time()
    rng_state = torch.Generator(device="cpu").manual_seed(int(args.seed))
    perm = torch.randperm(len(train_examples), generator=rng_state).tolist()
    cursor = 0

    optimizer.zero_grad(set_to_none=True)
    micro_in_step = 0
    micro_loss_accum = 0.0

    while steps < max_steps:
        # Reshuffle when the remaining permutation cannot fill a whole batch.
        if cursor + bs > len(perm):
            perm = torch.randperm(len(train_examples), generator=rng_state).tolist()
            cursor = 0
        batch_indices = perm[cursor : cursor + bs]
        cursor += bs
        batch = [train_examples[i] for i in batch_indices]

        loss = latent_batched_completion_ce_loss(
            model,
            batch,
            device,
            num_cot_tokens=int(max(0, k)),
            latent_mode="recurrent_hidden",
            pad_token_id=pad_id,
        ) / float(ga)
        loss.backward()
        # Undo the 1/ga scaling so the logged value is the raw micro-batch loss.
        micro_loss_accum += float(loss.detach().item()) * float(ga)
        micro_in_step += 1

        if micro_in_step < ga:
            continue  # keep accumulating gradients

        torch.nn.utils.clip_grad_norm_(
            (p for p in model.parameters() if p.requires_grad),
            max_grad_norm,
        )
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        steps += 1
        steps_at_current_k += 1
        avg_micro_loss = micro_loss_accum / float(ga)
        rolling.append(avg_micro_loss)
        micro_in_step = 0
        micro_loss_accum = 0.0

        if logging_steps > 0 and steps % logging_steps == 0:
            w = plateau_window
            recent = list(rolling)[-w:] if len(rolling) >= w else list(rolling)
            rec_mean = sum(recent) / max(1, len(recent))
            elapsed = time.time() - t0
            print(
                f"[adaptive_k] step={steps} k={k} loss={avg_micro_loss:.4f} "
                f"rolling_mean({len(recent)})={rec_mean:.4f} elapsed={elapsed:.0f}s",
                flush=True,
            )

        if eval_every_steps > 0 and steps % eval_every_steps == 0:
            ev = run_eval(
                model, tokenizer, eval_rows, device,
                num_cot_tokens=k, max_new_tokens=max_new_tokens,
            )
            print(f"[adaptive_k] EVAL step={steps} k={k}: {ev}", flush=True)

        if save_steps > 0 and steps % save_steps == 0:
            save_adapter(model, tokenizer, os.path.join(args.output_dir, f"checkpoint-step-{steps:05d}"))

        # Plateau check (only after `min_steps_per_k` at current k, and we
        # have at least 2*plateau_window losses in the rolling buffer).
        if steps_at_current_k >= min_steps_per_k:
            plateau, delta = detect_plateau(rolling, plateau_window, plateau_eps)
            if plateau and k < max_k:
                print(
                    f"[adaptive_k] plateau detected at step={steps} k={k} delta={delta:+.4f} -> growing k -> {k+1}",
                    flush=True,
                )
                k += 1
                steps_at_current_k = 0
                grew_at.append((steps, k))
                rolling.clear()  # restart plateau tracking after capacity bump
                save_adapter(model, tokenizer, os.path.join(args.output_dir, f"checkpoint-step-{steps:05d}-grow-k{k}"))
            elif plateau and k >= max_k and abs(delta) < converged_eps:
                print(
                    f"[adaptive_k] convergence at step={steps} k={k} delta={delta:+.4f} (max_k reached) -> stopping",
                    flush=True,
                )
                break

    final_dir = os.path.join(args.output_dir, "final")
    save_adapter(model, tokenizer, final_dir)
    final_eval = run_eval(
        model, tokenizer, eval_rows, device,
        num_cot_tokens=k, max_new_tokens=max_new_tokens,
    )
    summary = {
        "final_k": k,
        "total_steps": steps,
        "max_k": max_k,
        "grew_at_steps": grew_at,
        "final_eval": final_eval,
        "training_seconds": time.time() - t0,
    }
    with open(os.path.join(args.output_dir, "summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"[adaptive_k] DONE summary={json.dumps(summary)}", flush=True)
531
+
532
+
533
+ if __name__ == "__main__":
534
+ main()
_runs/add_variants_g_h.sh ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Add 2 more variants on GPUs 6 and 7 to the active sweep.
3
+ # Both seed from the lr5e5 lowsft S2 SFT step-3000 (the winning lineage at step 150).
4
+ set -euo pipefail
5
+
6
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
# Default SWEEP_ROOT to the most recently modified baseline_1p5b_v4_* directory.
# `|| true` matters: with `set -euo pipefail`, a failing `ls` (no matching
# directory) would otherwise abort the script on this line, before the
# explicit "sweep root missing" check below gets a chance to report it.
if [[ -z "${SWEEP_ROOT:-}" ]]; then
  SWEEP_ROOT="$(ls -dt "${ROOT}"/_runs/baseline_1p5b_v4_*/ 2>/dev/null | head -1 | sed 's:/$::' || true)"
fi
PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"

[[ -d "${SWEEP_ROOT}" ]] || { echo "sweep root missing"; exit 1; }
echo "Sweep: ${SWEEP_ROOT}"

# Both variants start from this S2 SFT checkpoint.
CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
[[ -d "${CKPT_LR5E5}" ]] || { echo "missing init"; exit 1; }
15
+
16
# launch_variant <gpu> <variant> <init_ckpt> [KEY=VALUE ...]
# Starts ${PIPELINE} in the background under nohup with the given environment,
# logs to <SWEEP_ROOT>/<variant>/nohup.log and appends "pid gpu variant" to
# <SWEEP_ROOT>/PIDS.txt. Extra KEY=VALUE arguments are handed to `env` after
# the fixed settings.
launch_variant() {
  local gpu_id="$1" name="$2" init_ckpt="$3"
  shift 3
  local run_dir="${SWEEP_ROOT}/${name}"
  local log_file="${run_dir}/nohup.log"
  mkdir -p "${run_dir}"
  printf 'GPU %s -> %s -> %s\n' "${gpu_id}" "${name}" "${init_ckpt}"

  local -a env_settings=(
    "ROOT=${ROOT}"
    "VARIANT=${name}"
    "GPU=${gpu_id}"
    "S2_SFT_CKPT=${init_ckpt}"
    "OUTPUT_ROOT=${run_dir}"
    "USE_WANDB=0"
    "WANDB_MODE=offline"
    "$@"
  )
  nohup env "${env_settings[@]}" bash "${PIPELINE}" \
    </dev/null >"${log_file}" 2>&1 &
  local child_pid=$!

  printf ' pid=%s log=%s\n' "${child_pid}" "${log_file}"
  echo "${child_pid} ${gpu_id} ${name}" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${child_pid}" 2>/dev/null || true
}
39
+
40
+ # pipe_g: lr5e5 lineage, faster GRPO LR (1e-5) to push convergence
41
+ launch_variant 6 pipe_g_lr5e5_grpo1e5 "${CKPT_LR5E5}" GRPO_LR=1e-5 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
42
+
43
+ # pipe_h: lr5e5 lineage, lower singleton penalty (1.0) to test if 1.5 hurts
44
+ launch_variant 7 pipe_h_lr5e5_grpo5e6_sngl10 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.0
45
+
46
+ # Update sweep README
47
+ cat >>"${SWEEP_ROOT}/SWEEP_README.md" <<EOF
48
+
49
+ ## Added at $(date '+%H:%M:%S')
50
+
51
+ | GPU | variant | S2 init | GRPO LR | S3 SFT LR | penalty_singleton |
52
+ | ---: | --- | --- | ---: | ---: | ---: |
53
+ | 6 | pipe_g_lr5e5_grpo1e5 | lr5e5_lowsft step-3000 | 1e-5 | 2e-5 | 1.5 |
54
+ | 7 | pipe_h_lr5e5_grpo5e6_sngl10 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 1.0 |
55
+ EOF
56
+
57
+ echo "Done. Now running 8 variants on GPUs 0..7."
_runs/add_variants_i_j_k_l.sh ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Add 4 high-throughput variants on freed GPUs 0,2,3,4.
3
+ # 3 variants fast-forward to S3 SFT (since S2 GRPO is plateau-stuck on baseline).
4
+ # 1 variant tries an aggressive 10x GRPO LR to break the S2 plateau.
5
+ set -euo pipefail
6
+
7
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
# Default SWEEP_ROOT to the most recently modified baseline_1p5b_v4_* directory.
# `|| true` matters: with `set -euo pipefail`, a failing `ls` (no matching
# directory) would otherwise abort the script on this line, before the
# explicit "sweep root missing" check below gets a chance to report it.
if [[ -z "${SWEEP_ROOT:-}" ]]; then
  SWEEP_ROOT="$(ls -dt "${ROOT}"/_runs/baseline_1p5b_v4_*/ 2>/dev/null | head -1 | sed 's:/$::' || true)"
fi
PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"

[[ -d "${SWEEP_ROOT}" ]] || { echo "sweep root missing"; exit 1; }
echo "Sweep: ${SWEEP_ROOT}"

# S2 SFT checkpoints of the two lineages used as variant inits.
CKPT_LR1E4="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr1e4_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
[[ -d "${CKPT_LR5E5}" ]] || { echo "missing init lr5e5"; exit 1; }
[[ -d "${CKPT_LR1E4}" ]] || { echo "missing init lr1e4"; exit 1; }
18
+
19
# launch_variant <gpu> <variant> <init_ckpt> [KEY=VALUE ...]
# Starts ${PIPELINE} in the background under nohup with the given environment,
# logs to <SWEEP_ROOT>/<variant>/nohup.log and appends "pid gpu variant" to
# <SWEEP_ROOT>/PIDS.txt. Extra KEY=VALUE arguments are handed to `env` after
# the fixed settings.
launch_variant() {
  local gpu_id="$1" name="$2" init_ckpt="$3"
  shift 3
  local run_dir="${SWEEP_ROOT}/${name}"
  local log_file="${run_dir}/nohup.log"
  mkdir -p "${run_dir}"
  printf 'GPU %s -> %s -> %s\n' "${gpu_id}" "${name}" "${init_ckpt}"

  local -a env_settings=(
    "ROOT=${ROOT}"
    "VARIANT=${name}"
    "GPU=${gpu_id}"
    "S2_SFT_CKPT=${init_ckpt}"
    "OUTPUT_ROOT=${run_dir}"
    "USE_WANDB=0"
    "WANDB_MODE=offline"
    "$@"
  )
  nohup env "${env_settings[@]}" bash "${PIPELINE}" \
    </dev/null >"${log_file}" 2>&1 &
  local child_pid=$!

  printf ' pid=%s log=%s\n' "${child_pid}" "${log_file}"
  echo "${child_pid} ${gpu_id} ${name}" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${child_pid}" 2>/dev/null || true
}
42
+
43
+ # pipe_i (GPU 0): fast-forward to S3 SFT from lr5e5 lowsft step-3000.
44
+ # high-throughput: no GC, bs=32x1, larger eval batches.
45
+ launch_variant 0 pipe_i_s3sft_lr5e5_fast "${CKPT_LR5E5}" \
46
+ START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR5E5}" \
47
+ SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
48
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
49
+ USE_GC=0
50
+
51
+ # pipe_j (GPU 2): fast-forward to S3 SFT from lr5e5 with lower LR for stability.
52
+ launch_variant 2 pipe_j_s3sft_lr5e5_lr1e5 "${CKPT_LR5E5}" \
53
+ START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR5E5}" \
54
+ SFT_LR_S3=1e-5 SFT_BS=32 SFT_GA=1 \
55
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
56
+ USE_GC=0
57
+
58
+ # pipe_k (GPU 3): fast-forward to S3 SFT from lr1e4 lineage (mirror of i but other init).
59
+ launch_variant 3 pipe_k_s3sft_lr1e4_fast "${CKPT_LR1E4}" \
60
+ START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR1E4}" \
61
+ SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
62
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
63
+ USE_GC=0
64
+
65
+ # pipe_l (GPU 4): aggressive 10x GRPO LR + 16 generations, push past S2 plateau.
66
+ launch_variant 4 pipe_l_lr5e5_grpo5e5_ng16 "${CKPT_LR5E5}" \
67
+ START_PHASE=s2_grpo \
68
+ GRPO_LR=5e-5 GRPO_BS=16 GRPO_GA=1 GRPO_NG=16 \
69
+ PENALTY_SINGLETON=1.5 \
70
+ SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
71
+ USE_GC=0
72
+
73
+ cat >>"${SWEEP_ROOT}/SWEEP_README.md" <<EOF
74
+
75
+ ## Added at $(date '+%H:%M:%S') — high-throughput / S3 fast-forward
76
+
77
+ S2 GRPO plateaued at solve=0.14 (lr5e5 lineage) or 0.05 (lr1e4 lineage) for all
78
+ of pipe_a/b/c/d/e — bit-identical evals from step 150 to 450. The per-cell
79
+ exact ceiling (~0.91) caps puzzle solve at ~0.91^20 ~= 0.14 regardless of
80
+ GRPO. Real lever is S3 SFT on harder cells (multi-value).
81
+
82
+ Killed pipe_a, pipe_c, pipe_d, pipe_e (flat). Launched 4 replacements with
83
+ USE_GC=0 (gradient checkpointing OFF — we have 80 GB headroom) and bs=32x1
84
+ for ~2-3x throughput per GPU.
85
+
86
+ | GPU | variant | start phase | init | SFT LR (S3) | GRPO LR | bs | ng |
87
+ | ---: | --- | --- | --- | ---: | ---: | ---: | ---: |
88
+ | 0 | pipe_i_s3sft_lr5e5_fast | s3_sft | lr5e5 step-3000 | 2e-5 | 5e-6 | 32 | 8 |
89
+ | 2 | pipe_j_s3sft_lr5e5_lr1e5 | s3_sft | lr5e5 step-3000 | 1e-5 | 5e-6 | 32 | 8 |
90
+ | 3 | pipe_k_s3sft_lr1e4_fast | s3_sft | lr1e4 step-3000 | 2e-5 | 5e-6 | 32 | 8 |
91
+ | 4 | pipe_l_lr5e5_grpo5e5_ng16 | s2_grpo | lr5e5 step-3000 | - | 5e-5 | 16 | 16 |
92
+ EOF
93
+
94
+ echo "Done. Now running 8 variants on GPUs 0..7."
_runs/baseline_1p5b_pipeline_v4.sh ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # 1.5B vanilla baseline: S2 GRPO -> S3 SFT -> S3 GRPO, single GPU.
3
+ # Optionally prepends extra S2 SFT steps when START_PHASE=s2_sft_extra and EXTRA_S2_SFT_STEPS>0.
4
+ #
5
+ # Required env:
6
+ # VARIANT variant name (used in dirs / wandb)
7
+ # GPU CUDA index for this variant (0..7)
8
+ # S2_SFT_CKPT path to S2 SFT LoRA adapter (uses this as S2 GRPO init)
9
+ #
10
+ # Optional env:
11
+ # ROOT default /home/ubuntu/curriculum_cot
12
+ # PYTHON_BIN default /opt/pytorch/bin/python
13
+ # OUTPUT_ROOT default $ROOT/_runs/baseline_1p5b_v4_$(date)/$VARIANT
14
+ # MODEL_NAME default Qwen/Qwen2.5-1.5B-Instruct
15
+ # GRPO_LR default 5e-6
16
+ # GRPO_BETA default 0.0
17
+ # GRPO_NG default 8
18
+ # GRPO_BS default 16
19
+ # GRPO_GA default 2
20
+ # GRPO_PROMPT default 768
21
+ # GRPO_COMPL default 24
22
+ # PENALTY_SINGLETON default 1.5
23
+ # PENALTY_BAD default 1.0
24
+ # REWARD_GOOD default 1.25
25
+ # PENALTY_MAL default 4.0
26
+ # PENALTY_EMPTY default 0.5
27
+ # SFT_LR_S3 default 2e-5
28
+ # SFT_BS default 16
29
+ # SFT_GA default 2
30
+ # VALUE_TARGET default 0.98
31
+ # S2_GRPO_MAX_STEPS default 1200 (pipeline budget)
32
+ # S3_SFT_MAX_STEPS default 2400
33
+ # S3_GRPO_MAX_STEPS default 1500
34
+ # EXTRA_S2_SFT_STEPS default 0 (extra S2 SFT steps before S2 GRPO)
35
+ # EXTRA_S2_SFT_LR default 1e-5
36
+ # EVAL_ROWS default 100
37
+ # TRAIN_ROWS default 10000
38
+ # USE_WANDB default 0
39
+ # WANDB_PROJECT default sudoku-baseline-1p5b-v4
40
+ # WANDB_MODE default offline
41
+ # PHASE_WALL_SECS default 0 (no phase wallclock cap)
42
+ # START_PHASE default s2_grpo (one of: s2_sft_extra,s2_grpo,s3_sft,s3_grpo)
43
+ # S3_SFT_INIT if START_PHASE=s3_sft, S3-SFT init adapter (overrides S2 GRPO output)
44
+ # S3_GRPO_INIT if START_PHASE=s3_grpo, S3-GRPO init adapter
45
+ # USE_GC default 0 (1 to enable gradient checkpointing; we usually have memory)
46
+
47
+ set -euo pipefail
48
+
49
+ ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
50
+ PYTHON_BIN="${PYTHON_BIN:-/opt/pytorch/bin/python}"
51
+ SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
52
+ GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
53
+
54
+ : "${VARIANT:?VARIANT required}"
55
+ : "${GPU:?GPU required}"
56
+ : "${S2_SFT_CKPT:?S2_SFT_CKPT required}"
57
+
58
+ OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/_runs/baseline_1p5b_v4_$(date +%Y%m%d_%H%M%S)/${VARIANT}}"
59
+ MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"
60
+
61
+ GRPO_LR="${GRPO_LR:-5e-6}"
62
+ GRPO_BETA="${GRPO_BETA:-0.0}"
63
+ GRPO_NG="${GRPO_NG:-8}"
64
+ GRPO_BS="${GRPO_BS:-16}"
65
+ GRPO_GA="${GRPO_GA:-2}"
66
+ GRPO_PROMPT="${GRPO_PROMPT:-768}"
67
+ GRPO_COMPL="${GRPO_COMPL:-24}"
68
+ PENALTY_SINGLETON="${PENALTY_SINGLETON:-1.5}"
69
+ PENALTY_BAD="${PENALTY_BAD:-1.0}"
70
+ PENALTY_MAL="${PENALTY_MAL:-4.0}"
71
+ PENALTY_EMPTY="${PENALTY_EMPTY:-0.5}"
72
+ REWARD_GOOD="${REWARD_GOOD:-1.25}"
73
+ PENALTY_MISSING="${PENALTY_MISSING:-0.0}"
74
+ EXACT_MATCH_BONUS="${EXACT_MATCH_BONUS:-0.0}"
75
+ CARD_MISMATCH_PEN="${CARD_MISMATCH_PEN:-0.0}"
76
+ SFT_OVERSAMPLE="${SFT_OVERSAMPLE:-1}"
77
+ SFT_TGT_MIN="${SFT_TGT_MIN:-0}"
78
+ SFT_TGT_MAX="${SFT_TGT_MAX:-0}"
79
+ SFT_LR_S3="${SFT_LR_S3:-2e-5}"
80
+ SFT_BS="${SFT_BS:-16}"
81
+ SFT_GA="${SFT_GA:-2}"
82
+ VALUE_TARGET="${VALUE_TARGET:-0.98}"
83
+ S2_GRPO_MAX_STEPS="${S2_GRPO_MAX_STEPS:-1200}"
84
+ S3_SFT_MAX_STEPS="${S3_SFT_MAX_STEPS:-2400}"
85
+ S3_GRPO_MAX_STEPS="${S3_GRPO_MAX_STEPS:-1500}"
86
+ EXTRA_S2_SFT_STEPS="${EXTRA_S2_SFT_STEPS:-0}"
87
+ EXTRA_S2_SFT_LR="${EXTRA_S2_SFT_LR:-1e-5}"
88
+ EVAL_ROWS="${EVAL_ROWS:-100}"
89
+ TRAIN_ROWS="${TRAIN_ROWS:-10000}"
90
+ USE_WANDB="${USE_WANDB:-0}"
91
+ WANDB_PROJECT="${WANDB_PROJECT:-sudoku-baseline-1p5b-v4}"
92
+ WANDB_MODE="${WANDB_MODE:-offline}"
93
+ PHASE_WALL_SECS="${PHASE_WALL_SECS:-0}"
94
+ START_PHASE="${START_PHASE:-s2_grpo}"
95
+ S3_SFT_INIT="${S3_SFT_INIT:-}"
96
+ S3_GRPO_INIT="${S3_GRPO_INIT:-}"
97
+ USE_GC="${USE_GC:-0}"
98
+
99
+ TRAIN_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
100
+ EVAL_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"
101
+
102
+ mkdir -p "${OUTPUT_ROOT}"
103
+ PIPELINE_LOG="${OUTPUT_ROOT}/PIPELINE.log"
104
+
105
# Print the current wall-clock time as HH:MM:SS.
ts() {
  date +'%H:%M:%S'
}
106
# Emit a timestamped message on stderr and append it to ${PIPELINE_LOG}.
log() {
  local stamp
  stamp="$(ts)"
  printf '[%s] %s\n' "${stamp}" "$*" | tee -a "${PIPELINE_LOG}" >&2
}
107
+
108
# latest_ckpt_step <dir>
# Print the checkpoint-step-* entry under <dir> that sorts last in version
# order (sort -V). Returns 1 without output when there is none.
latest_ckpt_step() {
  local ckpt_root="$1"
  local -a found
  # nullglob makes an unmatched pattern expand to nothing instead of itself.
  shopt -s nullglob
  found=("${ckpt_root}"/checkpoint-step-*)
  shopt -u nullglob
  if [[ ${#found[@]} -eq 0 ]]; then
    return 1
  fi
  printf '%s\n' "${found[@]}" | sort -V | tail -n 1
}
116
+
117
# best_grpo_adapter <dir>
# Print the adapter directory to use from a GRPO output directory:
#   * <dir> itself when it contains adapter_model.safetensors; otherwise
#   * the sub-directory named checkpoint-<digits> with the highest number
#     that contains adapter_model.safetensors.
# Sub-directories whose suffix is not purely numeric are ignored.
# Returns 1 without output when nothing qualifies.
best_grpo_adapter() {
  local d="$1"
  if [[ -f "${d}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${d}"; return 0
  fi
  # `c` and `n` are declared local so the loop does not leak or clobber
  # same-named variables of the calling script.
  local best="" step=-1 c n
  shopt -s nullglob
  for c in "${d}"/checkpoint-*; do
    [[ -d "$c" ]] || continue
    [[ -f "$c/adapter_model.safetensors" ]] || continue
    n="${c##*checkpoint-}"
    # 10# forces base 10 so zero-padded step numbers are not read as octal.
    if [[ "$n" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
      step=$((10#${n})); best="$c"
    fi
  done
  shopt -u nullglob
  [[ -n "$best" ]] || return 1
  printf '%s\n' "$best"
}
136
+
137
+ if [[ ! -f "${TRAIN_JSONL}" || ! -f "${EVAL_JSONL}" ]]; then
138
+ log "ERROR: missing dataset jsonls (${TRAIN_JSONL} / ${EVAL_JSONL})."
139
+ exit 1
140
+ fi
141
+
142
+ export CUDA_VISIBLE_DEVICES="${GPU}"
143
+ export TOKENIZERS_PARALLELISM=false
144
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
145
+ export HF_HOME="${ROOT}/.hf_cache"
146
+ export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
147
+ export WANDB_MODE="${WANDB_MODE}"
148
+
149
# run_sft <stage> <init_adapter> <out_dir> <lr> <max_steps> <tag>
# Runs ${SFT_SCRIPT} for one SFT phase, writing into <out_dir> and mirroring
# the trainer's combined stdout/stderr to <out_dir>/train.log.
run_sft() {
  local stage="$1" init_adapter="$2" out_dir="$3" lr="$4" max_steps="$5" tag="$6"
  mkdir -p "${out_dir}"
  log "=== Stage ${stage} SFT (${tag}) lr=${lr} max_steps=${max_steps} bs=${SFT_BS}x${SFT_GA} GC=${USE_GC} init=${init_adapter} ==="
  log " out=${out_dir}"

  # Fixed part of the trainer command line.
  local -a cmd=(
    "${PYTHON_BIN}" -u "${SFT_SCRIPT}"
    --model_name "${MODEL_NAME}"
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --total_empties_hint 20
    --per_device_train_batch_size "${SFT_BS}"
    --gradient_accumulation_steps "${SFT_GA}"
    --num_epochs 256
    --learning_rate "${lr}"
    --max_grad_norm 1.0
    --logging_steps 25
    --eval_steps 150
    --save_steps 200
    --eval_rows "${EVAL_ROWS}"
    --max_completion_length 24
    --limit_train_rows "${TRAIN_ROWS}"
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_exact_set_match_stop 0
    --eval_solve_rate_stop 0
    --min_steps_before_stop 100
    --max_wall_clock_seconds "${PHASE_WALL_SECS}"
    --max_steps "${max_steps}"
    --multi_value_oversample_factor "${SFT_OVERSAMPLE}"
    --train_target_size_min "${SFT_TGT_MIN}"
    --train_target_size_max "${SFT_TGT_MAX}"
  )

  # Optional flags, appended last.
  local -a extra=()
  if [[ "${USE_WANDB}" == "1" ]]; then
    extra+=(--use_wandb)
    extra+=(--wandb_project "${WANDB_PROJECT}")
    extra+=(--wandb_run_name "${VARIANT}_${tag}")
    extra+=(--wandb_mode "${WANDB_MODE}")
  fi
  if [[ "${USE_GC}" == "1" ]]; then
    extra+=(--enable_gradient_checkpointing)
  fi

  "${cmd[@]}" "${extra[@]}" 2>&1 | tee "${out_dir}/train.log"
}
197
+
198
# run_grpo <stage> <init_adapter> <out_dir> <max_steps> <tag>
# Runs ${GRPO_SCRIPT} for one GRPO phase, writing into <out_dir> and mirroring
# the trainer's combined stdout/stderr to <out_dir>/train.log.
run_grpo() {
  local stage="$1" init_adapter="$2" out_dir="$3" max_steps="$4" tag="$5"
  mkdir -p "${out_dir}"
  log "=== Stage ${stage} GRPO (${tag}) lr=${GRPO_LR} ng=${GRPO_NG} bs=${GRPO_BS}x${GRPO_GA} prompt=${GRPO_PROMPT} GC=${USE_GC} max_steps=${max_steps} init=${init_adapter} ==="
  log " rewards: good=${REWARD_GOOD} bad=${PENALTY_BAD} mal=${PENALTY_MAL} empty=${PENALTY_EMPTY} sngl=${PENALTY_SINGLETON} missing=${PENALTY_MISSING} exact_b=${EXACT_MATCH_BONUS} card_pen=${CARD_MISMATCH_PEN}"
  log " out=${out_dir}"

  # Fixed part of the trainer command line.
  local -a cmd=(
    "${PYTHON_BIN}" -u "${GRPO_SCRIPT}"
    --model_name "${MODEL_NAME}"
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --total_empties_hint 20
    --per_device_train_batch_size "${GRPO_BS}"
    --gradient_accumulation_steps "${GRPO_GA}"
    --num_train_epochs 100
    --learning_rate "${GRPO_LR}"
    --logging_steps 10
    --save_steps 200
    --eval_steps 150
    --eval_rows "${EVAL_ROWS}"
    --num_generations "${GRPO_NG}"
    --max_prompt_length "${GRPO_PROMPT}"
    --max_completion_length "${GRPO_COMPL}"
    --beta "${GRPO_BETA}"
    --limit_train_rows "${TRAIN_ROWS}"
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --reward_good_value "${REWARD_GOOD}"
    --penalty_bad_value "${PENALTY_BAD}"
    --penalty_malformed "${PENALTY_MAL}"
    --penalty_empty "${PENALTY_EMPTY}"
    --penalty_singleton "${PENALTY_SINGLETON}"
    --penalty_missing "${PENALTY_MISSING}"
    --exact_match_bonus "${EXACT_MATCH_BONUS}"
    --cardinality_mismatch_penalty "${CARD_MISMATCH_PEN}"
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_solve_rate_stop 0
    --min_steps_before_stop 100
    --max_wall_clock_seconds "${PHASE_WALL_SECS}"
    --max_steps "${max_steps}"
  )

  # Optional flags, appended last.
  local -a extra=()
  if [[ "${USE_WANDB}" == "1" ]]; then
    extra+=(--use_wandb)
    extra+=(--wandb_project "${WANDB_PROJECT}")
    extra+=(--wandb_run_name "${VARIANT}_${tag}")
    extra+=(--wandb_mode "${WANDB_MODE}")
  fi
  if [[ "${USE_GC}" == "1" ]]; then
    extra+=(--enable_gradient_checkpointing)
  fi

  "${cmd[@]}" "${extra[@]}" 2>&1 | tee "${out_dir}/train.log"
}
253
+
254
+ log "===== ${VARIANT} on GPU ${GPU} ====="
255
+ log "S2 SFT init: ${S2_SFT_CKPT}"
256
+ log "START_PHASE=${START_PHASE} GRPO_LR=${GRPO_LR} SFT_LR_S3=${SFT_LR_S3} PENALTY_SINGLETON=${PENALTY_SINGLETON} USE_GC=${USE_GC}"
257
+ log " EXTRA_S2_SFT_STEPS=${EXTRA_S2_SFT_STEPS} GRPO_BS=${GRPO_BS}x${GRPO_GA} SFT_BS=${SFT_BS}x${SFT_GA} GRPO_NG=${GRPO_NG}"
258
+
259
+ S2_SFT_DIR_FOR_GRPO="${S2_SFT_CKPT}"
260
+ S2_GRPO_ADAPTER=""
261
+ S3_SFT_INIT_RESOLVED=""
262
+ S3_GRPO_INIT_RESOLVED=""
263
+
264
# phase_idx <phase-name>
# Print the 1-based position of a pipeline phase:
#   s2_sft_extra=1, s2_grpo=2, s3_sft=3, s3_grpo=4.
# Any other name is treated like s2_grpo (2).
phase_idx() {
  local phase="$1" idx
  case "${phase}" in
    s2_sft_extra) idx=1 ;;
    s3_sft)       idx=3 ;;
    s3_grpo)      idx=4 ;;
    *)            idx=2 ;;  # s2_grpo and every unrecognised name
  esac
  echo "${idx}"
}
273
+ START_IDX="$(phase_idx "${START_PHASE}")"
274
+
275
# --- Phase 1 (optional): extra S2 SFT before S2 GRPO ------------------------
# Runs only when START_PHASE=s2_sft_extra and EXTRA_S2_SFT_STEPS > 0.
if (( START_IDX <= 1 )) && (( EXTRA_S2_SFT_STEPS > 0 )); then
  S2_SFT_EXTRA_DIR="${OUTPUT_ROOT}/s2_sft_extra"
  run_sft 2 "${S2_SFT_CKPT}" "${S2_SFT_EXTRA_DIR}" "${EXTRA_S2_SFT_LR}" "${EXTRA_S2_SFT_STEPS}" "s2sft_extra"
  if NEW_CKPT="$(latest_ckpt_step "${S2_SFT_EXTRA_DIR}")"; then
    log ">>> Extra S2 SFT ckpt: ${NEW_CKPT}"
    S2_SFT_DIR_FOR_GRPO="${NEW_CKPT}"
  else
    log "WARN: no new S2 SFT ckpt produced; falling back to ${S2_SFT_CKPT}"
  fi
fi

# NOTE: the lookups below are tested inside `if !` on purpose. With `set -e`,
# a plain `VAR="$(helper)"` aborts the script as soon as the helper returns
# non-zero, so a separate emptiness check after it would never get to log
# the error.

# --- Phase 2: S2 GRPO --------------------------------------------------------
if (( START_IDX <= 2 )); then
  S2_GRPO_DIR="${OUTPUT_ROOT}/s2_grpo"
  run_grpo 2 "${S2_SFT_DIR_FOR_GRPO}" "${S2_GRPO_DIR}" "${S2_GRPO_MAX_STEPS}" "s2grpo"
  if ! S2_GRPO_ADAPTER="$(best_grpo_adapter "${S2_GRPO_DIR}")" || [[ -z "${S2_GRPO_ADAPTER}" ]]; then
    log "ERROR: no S2 GRPO adapter under ${S2_GRPO_DIR}"; exit 1
  fi
  log ">>> S2 GRPO adapter: ${S2_GRPO_ADAPTER}"
  S3_SFT_INIT_RESOLVED="${S2_GRPO_ADAPTER}"
elif (( START_IDX == 3 )); then
  if [[ -z "${S3_SFT_INIT}" ]]; then
    log "ERROR: START_PHASE=s3_sft but S3_SFT_INIT is empty"; exit 1
  fi
  S3_SFT_INIT_RESOLVED="${S3_SFT_INIT}"
  log ">>> Skipping to S3 SFT, init=${S3_SFT_INIT_RESOLVED}"
fi

# --- Phase 3: S3 SFT ---------------------------------------------------------
if (( START_IDX <= 3 )); then
  S3_SFT_DIR="${OUTPUT_ROOT}/s3_sft"
  run_sft 3 "${S3_SFT_INIT_RESOLVED}" "${S3_SFT_DIR}" "${SFT_LR_S3}" "${S3_SFT_MAX_STEPS}" "s3sft"
  if ! S3_SFT_CKPT="$(latest_ckpt_step "${S3_SFT_DIR}")" || [[ -z "${S3_SFT_CKPT}" ]]; then
    log "ERROR: no S3 SFT ckpt under ${S3_SFT_DIR}"; exit 1
  fi
  log ">>> S3 SFT ckpt: ${S3_SFT_CKPT}"
  S3_GRPO_INIT_RESOLVED="${S3_SFT_CKPT}"
elif (( START_IDX == 4 )); then
  if [[ -z "${S3_GRPO_INIT}" ]]; then
    log "ERROR: START_PHASE=s3_grpo but S3_GRPO_INIT is empty"; exit 1
  fi
  S3_GRPO_INIT_RESOLVED="${S3_GRPO_INIT}"
  log ">>> Skipping to S3 GRPO, init=${S3_GRPO_INIT_RESOLVED}"
fi

# --- Phase 4: S3 GRPO (always runs) ------------------------------------------
S3_GRPO_DIR="${OUTPUT_ROOT}/s3_grpo"
run_grpo 3 "${S3_GRPO_INIT_RESOLVED}" "${S3_GRPO_DIR}" "${S3_GRPO_MAX_STEPS}" "s3grpo"
if ! S3_GRPO_ADAPTER="$(best_grpo_adapter "${S3_GRPO_DIR}")" || [[ -z "${S3_GRPO_ADAPTER}" ]]; then
  log "ERROR: no S3 GRPO adapter under ${S3_GRPO_DIR}"; exit 1
fi
log ">>> S3 GRPO adapter: ${S3_GRPO_ADAPTER}"

log "===== ${VARIANT} DONE — final S3 GRPO adapter at ${S3_GRPO_ADAPTER} ====="
_runs/eval_strawman_cellpolicy.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric.
3
+
4
+ This is a thin CLI wrapper that:
5
+
6
+ 1. Loads a base model + LoRA adapter.
7
+ 2. Runs the same scoring procedure as
8
+ ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``,
9
+ i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate
10
+ over empty cells in row-major order and scores each predicted value
11
+ with ``score_prediction_text`` against the i-consistent target set at
12
+ ``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal).
13
+ 3. The only difference vs the cell-policy is that the model emits the whole
14
+ puzzle in ONE forward pass, then the predicted list is split into
15
+ per-cell singletons.
16
+
17
+ Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``)
18
+ and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k
19
+ models (``adaptive_latent_baseline_sudoku_train.py``).
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import argparse
25
+ import json
26
+ import sys
27
+ from pathlib import Path
28
+ from typing import Any, Dict, List
29
+
30
+ import torch
31
+ from peft import PeftModel
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
33
+
34
+ ROOT = Path(__file__).resolve().parent.parent
35
+ if str(ROOT) not in sys.path:
36
+ sys.path.insert(0, str(ROOT))
37
+
38
+ from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore # noqa: E402
39
+ load_jsonl_rows,
40
+ pick_dtype,
41
+ )
42
+ from _runs.simple_baseline_sudoku_train import ( # type: ignore # noqa: E402
43
+ run_eval as run_eval_strawman,
44
+ )
45
+ from _runs.adaptive_latent_baseline_sudoku_train import ( # type: ignore # noqa: E402
46
+ run_eval as run_eval_adaptive_k,
47
+ )
48
+
49
+
50
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the re-evaluation script.

    ``--kind``, ``--adapter_dir`` and ``--eval_jsonl`` are required; every
    other option has a default.
    """
    parser = argparse.ArgumentParser()

    # What to evaluate and with which base model / adapter.
    parser.add_argument("--kind", choices=["strawman", "adaptive_k"], required=True)
    parser.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--adapter_dir", required=True)

    # Evaluation data and generation settings.
    parser.add_argument("--eval_jsonl", required=True)
    parser.add_argument("--cache_dir", default=str(ROOT / ".hf_cache"))
    parser.add_argument("--eval_rows", type=int, default=100)
    parser.add_argument("--max_completion_length", type=int, default=96)
    parser.add_argument("--stage_i", type=int, default=3)
    parser.add_argument(
        "--num_cot_tokens",
        type=int,
        default=0,
        help="Only used when --kind adaptive_k.",
    )

    # Reproducibility and output.
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--out_json", default="")

    return parser.parse_args()
69
+
70
+
71
def main() -> None:
    """Load a base model plus LoRA adapter, evaluate it, and report metrics.

    ``--kind strawman`` dispatches to ``run_eval_strawman``; any other kind
    dispatches to ``run_eval_adaptive_k`` with ``--num_cot_tokens``. Metrics
    are printed as JSON and, when ``--out_json`` is non-empty, also written to
    that path together with the run configuration.
    """
    args = parse_args()
    set_seed(int(args.seed))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = pick_dtype()

    print(f"[eval-cellpolicy] kind={args.kind} adapter={args.adapter_dir}", flush=True)
    print(f"[eval-cellpolicy] eval_jsonl={args.eval_jsonl} stage_i={args.stage_i}", flush=True)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, torch_dtype=dtype
    )
    model = PeftModel.from_pretrained(base_model, args.adapter_dir)
    model.to(device)
    model.eval()

    rows: List[Dict[str, Any]] = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))
    print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True)

    # Keyword arguments shared by both evaluation back-ends.
    shared_kwargs: Dict[str, Any] = {
        "max_new_tokens": int(args.max_completion_length),
        "print_n": 3,
        "stage_i": int(args.stage_i),
    }
    if args.kind == "strawman":
        metrics = run_eval_strawman(model, tokenizer, rows, device, **shared_kwargs)
    else:
        metrics = run_eval_adaptive_k(
            model, tokenizer, rows, device,
            num_cot_tokens=int(args.num_cot_tokens),
            **shared_kwargs,
        )

    print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True)
    if not args.out_json:
        return

    Path(args.out_json).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out_json, "w") as f:
        report = {
            "kind": args.kind,
            "adapter_dir": args.adapter_dir,
            "eval_jsonl": args.eval_jsonl,
            "stage_i": int(args.stage_i),
            "num_cot_tokens": int(args.num_cot_tokens),
            "metrics": metrics,
        }
        json.dump(report, f, indent=2)
    print(f"[eval-cellpolicy] wrote {args.out_json}", flush=True)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
_runs/launch_adaptive_k_cellpolicy.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Launch two adaptive-k variants (single-stage cell-policy at stage_i=3,
3
+ # no curriculum, but with growing recurrent-hidden thought tokens k).
4
+ set -euo pipefail
5
+ ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
6
+ TS="$(date +%Y%m%d_%H%M%S)"
7
+ SWEEP_ROOT="${ROOT}/_runs/adaptive_k_cellpolicy_${TS}"
8
+ mkdir -p "${SWEEP_ROOT}"
9
+ PY="${ROOT}/_runs/adaptive_k_cellpolicy_pipeline.py"
10
+
11
# launch <variant> <gpu> [pipeline CLI args ...]
# Starts ${PY} in the background under nohup with --variant/--gpu/--output_root
# followed by the remaining arguments, sends its output to
# <SWEEP_ROOT>/<variant>/console.log and appends "<variant>=<pid>" to
# <SWEEP_ROOT>/PIDS.txt.
launch() {
  local name="$1" gpu_id="$2"
  shift 2
  local run_dir="${SWEEP_ROOT}/${name}"
  mkdir -p "${run_dir}"
  echo "[launch] ${name} on GPU ${gpu_id} out=${run_dir}"

  local -a cmd=(
    /opt/pytorch/bin/python -u "${PY}"
    --variant "${name}"
    --gpu "${gpu_id}"
    --output_root "${run_dir}"
    "$@"
  )
  nohup "${cmd[@]}" > "${run_dir}/console.log" 2>&1 &
  local child_pid=$!

  disown "${child_pid}" || true
  echo "${name}=${child_pid}" >> "${SWEEP_ROOT}/PIDS.txt"
}
27
+
28
+ # adaptive_a: classic schedule (start at k=0, plateau-bumps with eps=0.01).
29
+ launch adaptive_a_eps01 2 \
30
+ --start_k 0 --max_k 4 --steps_per_phase 600 --max_phases_per_k 2 \
31
+ --plateau_eps 0.01 --sft_lr 2e-5 --sft_bs 8 --sft_ga 4 \
32
+ --grpo_steps 1500 --grpo_lr 5e-6 --grpo_bs 8 --grpo_ga 4 --grpo_ng 8
33
+
34
+ # adaptive_b: faster k-growth (max_phases_per_k=1, force bump every phase).
35
+ launch adaptive_b_fastgrow 3 \
36
+ --start_k 0 --max_k 4 --steps_per_phase 800 --max_phases_per_k 1 \
37
+ --plateau_eps 1.0 --sft_lr 2e-5 --sft_bs 8 --sft_ga 4 \
38
+ --grpo_steps 1500 --grpo_lr 5e-6 --grpo_bs 8 --grpo_ga 4 --grpo_ng 8
39
+
40
+ echo "[launch] sweep root: ${SWEEP_ROOT}"
41
+ echo "[launch] PIDs:"
42
+ cat "${SWEEP_ROOT}/PIDS.txt"
_runs/launch_adaptive_latent_baseline.sh ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Adaptive-k thought-token baseline (experiment D in the 2x2 ablation).
3
+ #
4
+ # Same single-stage, whole-puzzle setup as launch_simple_baseline.sh
5
+ # (experiment C, the "strawman"); same model, LoRA, JSONL, chat template.
6
+ # The ONLY change is that the SFT loss uses the recurrent_hidden mechanism
7
+ # with k thought tokens, and k grows automatically when the rolling-mean
8
+ # loss plateaus.
9
+ set -euo pipefail
10
+
11
+ ROOT=/home/ubuntu/curriculum_cot
12
+ SCRIPT=${ROOT}/_runs/adaptive_latent_baseline_sudoku_train.py
13
+ PYTHON_BIN=/opt/pytorch/bin/python
14
+
15
+ TRAIN_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
16
+ EVAL_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
17
+
18
+ SWEEP_ROOT=${ROOT}/_runs/adaptive_latent_$(date +%Y%m%d_%H%M%S)
19
+ mkdir -p "${SWEEP_ROOT}"
20
+ echo "${SWEEP_ROOT}" > "${ROOT}/_runs/current_adaptive_latent_sweep_dir"
21
+ echo "SWEEP_ROOT=${SWEEP_ROOT}"
22
+
23
+ export TOKENIZERS_PARALLELISM=false
24
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
25
+ export HF_HOME="${ROOT}/.hf_cache"
26
+ export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
27
+ export WANDB_MODE=offline
28
+
29
run_variant() {
  # Usage: run_variant <gpu> <tag> <learning_rate> <max_k> <min_steps_per_k>
  # Launches one detached training run pinned to <gpu>; all trainer output is
  # appended to <SWEEP_ROOT>/<tag>/train.log.
  local gpu="$1"
  local tag="$2"
  local lr="$3"
  local max_k="$4"
  local min_steps_per_k="$5"

  local out="${SWEEP_ROOT}/${tag}"
  local log="${out}/train.log"
  mkdir -p "${out}"
  : > "${log}"   # start from an empty log

  # Trainer command line, kept in one array so the launch itself stays short.
  local -a train_args=(
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out}"
    --learning_rate "${lr}"
    --max_steps 4000
    --per_device_train_batch_size 4
    --gradient_accumulation_steps 2
    --logging_steps 25
    --save_steps 500
    --eval_every_steps 500
    --eval_rows 50
    --max_completion_length 96
    --max_prompt_length 1024
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --enable_gradient_checkpointing
    --start_k 0
    --max_k "${max_k}"
    --min_steps_per_k "${min_steps_per_k}"
    --plateau_window 100
    --plateau_eps 0.005
    --converged_eps 0.001
    --seed 0
  )

  (
    export CUDA_VISIBLE_DEVICES="${gpu}"
    "${PYTHON_BIN}" -u "${SCRIPT}" "${train_args[@]}" >> "${log}" 2>&1
  ) >/dev/null 2>&1 &
  local pid=$!

  echo "$pid $gpu $tag" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s log=%s\n' "$gpu" "$tag" "$pid" "$log"
}
67
+
68
+ # 2 variants on idle GPUs 2,3:
69
+ # - adaptive_a: same LR (5e-5) as strawman variant a, max_k=4, min_steps_per_k=400
70
+ # - adaptive_b: smaller min_steps_per_k=250 to grow k more aggressively
71
+ run_variant 2 adaptive_a_lr5e5_maxk4 5e-5 4 400
72
+ run_variant 3 adaptive_b_lr5e5_fastgrow 5e-5 4 250
73
+
74
+ echo
75
+ echo "=== launched ==="
76
+ cat "${SWEEP_ROOT}/PIDS.txt"
_runs/launch_baseline_1p5b_v4.sh ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Launch 6 baseline 1.5B variants in parallel, one per GPU (0..5).
3
+ # Each runs S2 GRPO -> S3 SFT -> S3 GRPO from a v3 lowsft S2 SFT checkpoint.
4
+ set -euo pipefail
5
+
6
+ ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
7
+ SWEEP_ID="${SWEEP_ID:-$(date +%Y%m%d_%H%M%S)}"
8
+ SWEEP_ROOT="${ROOT}/_runs/baseline_1p5b_v4_${SWEEP_ID}"
9
+ PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"
10
+
11
+ mkdir -p "${SWEEP_ROOT}"
12
+ SUMMARY="${SWEEP_ROOT}/SWEEP_README.md"
13
+
14
+ CKPT_LR1E4="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr1e4_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
15
+ CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
16
+
17
+ if [[ ! -d "${CKPT_LR1E4}" || ! -d "${CKPT_LR5E5}" ]]; then
18
+ echo "ERROR: missing init checkpoints" >&2
19
+ exit 1
20
+ fi
21
+
22
+ cat >"${SUMMARY}" <<EOF
23
+ # Baseline 1.5B v4 sweep — ${SWEEP_ID}
24
+
25
+ Single GPU per variant. All 6 variants resume from the v3 lowsft S2 SFT
26
+ checkpoints (the only ones with positive trend), then run S2 GRPO -> S3 SFT
27
+ -> S3 GRPO with various GRPO LR / penalty / extra-S2-SFT settings.
28
+
29
+ | GPU | variant | S2 init | GRPO LR | S3 SFT LR | penalty_singleton | extra S2 SFT (steps @ LR) |
30
+ | ---: | --- | --- | ---: | ---: | ---: | --- |
31
+ | 0 | pipe_a_lr1e4_grpo5e6 | lr1e4_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 0 |
32
+ | 1 | pipe_b_lr5e5_grpo5e6 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 0 |
33
+ | 2 | pipe_c_lr1e4_grpo2e6 | lr1e4_lowsft step-3000 | 2e-6 | 2e-5 | 1.5 | 0 |
34
+ | 3 | pipe_d_lr5e5_grpo2e6 | lr5e5_lowsft step-3000 | 2e-6 | 2e-5 | 1.5 | 0 |
35
+ | 4 | pipe_e_lr5e5_grpo5e6_sngl25 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 2.5 | 0 |
36
+ | 5 | pipe_f_lr1e4_extraS2sft | lr1e4_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 1500 @ 1e-5 |
37
+
38
+ Pipeline budget per variant:
39
+ - S2 GRPO max 1200 steps (early stop on prec AND recall >= 0.98)
40
+ - S3 SFT max 2400 steps (same early stop)
41
+ - S3 GRPO max 1500 steps (same early stop)
42
+
43
+ Logs: \`<variant>/PIPELINE.log\`, per-phase: \`<variant>/{s2_grpo,s3_sft,s3_grpo}/train.log\`
44
+ EOF
45
+
46
launch_variant() {
  # Usage: launch_variant <gpu> <variant> <init_ckpt> [KEY=VALUE ...]
  # Runs the pipeline script detached, with the per-variant settings passed
  # through the environment. Trailing KEY=VALUE pairs come last, so they take
  # precedence over the fixed settings.
  local gpu="$1"
  local variant="$2"
  local init="$3"
  shift 3

  local out="${SWEEP_ROOT}/${variant}"
  local nohup_log="${out}/nohup.log"
  mkdir -p "${out}"
  printf 'GPU %s -> %s -> %s\n' "${gpu}" "${variant}" "${init}"

  local -a job_env=(
    ROOT="${ROOT}"
    VARIANT="${variant}"
    GPU="${gpu}"
    S2_SFT_CKPT="${init}"
    OUTPUT_ROOT="${out}"
    USE_WANDB=0
    WANDB_MODE=offline
    "$@"
  )
  nohup env "${job_env[@]}" bash "${PIPELINE}" </dev/null >"${nohup_log}" 2>&1 &
  local pid=$!

  printf ' pid=%s log=%s\n' "${pid}" "${nohup_log}"
  echo "${pid} ${gpu} ${variant}" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
}
69
+
70
+ : > "${SWEEP_ROOT}/PIDS.txt"
71
+
72
+ launch_variant 0 pipe_a_lr1e4_grpo5e6 "${CKPT_LR1E4}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
73
+ launch_variant 1 pipe_b_lr5e5_grpo5e6 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
74
+ launch_variant 2 pipe_c_lr1e4_grpo2e6 "${CKPT_LR1E4}" GRPO_LR=2e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
75
+ launch_variant 3 pipe_d_lr5e5_grpo2e6 "${CKPT_LR5E5}" GRPO_LR=2e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
76
+ launch_variant 4 pipe_e_lr5e5_grpo5e6_sngl25 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=2.5
77
+ launch_variant 5 pipe_f_lr1e4_extraS2sft "${CKPT_LR1E4}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5 EXTRA_S2_SFT_STEPS=1500 EXTRA_S2_SFT_LR=1e-5
78
+
79
+ echo
80
+ echo "Sweep root: ${SWEEP_ROOT}"
81
+ echo "Tail PIDS:"
82
+ cat "${SWEEP_ROOT}/PIDS.txt"
_runs/launch_baseline_push_v5.sh ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Wave-5: push baseline 1.5B past solve=0.35.
3
+ #
4
+ # Idea: best ckpts so far cap at per-cell-exact ~0.943 and solve ~0.35 (with independent cells 0.943^20 ~= 0.31; solve 0.35 corresponds to ~0.949 per cell).
5
+ # To reach solve=0.5 we need exact ~= 0.966 (0.5^(1/20)). That's roughly +2pp of per-cell exact.
6
+ #
7
+ # 4 variants, single-GPU each, on GPUs 4..7.
8
+ # All start from the leader (pipe_m post-S3-GRPO at solve=0.35) or its S3 SFT
9
+ # ckpt, then push S3 GRPO further with different levers:
10
+ # - lower LR (escape / fine refine)
11
+ # - longer steps (3000 instead of 1500)
12
+ # - KL anchor (beta>0) to prevent regression
13
+ # - sharper rewards (mirror what worked for the latent's `s3_grpo_sharp_rwd`)
14
+ set -euo pipefail
15
+
16
+ ROOT=/home/ubuntu/curriculum_cot
17
+ SWEEP_ROOT=/home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_20260523_184952
18
+ PIPELINE=$ROOT/_runs/baseline_1p5b_pipeline_v4.sh
19
+
20
# best wave-2 anchors

# Print the most recently modified checkpoint-* entry under directory $1, or
# nothing when there is none. The trailing `|| true` matters: under
# `set -e -o pipefail` a failing `ls` (no match) or an `ls` cut short by
# `head` would otherwise abort the script right at the assignment, before the
# sanity loop below gets a chance to report what is missing.
latest_ckpt() {
  ls -dt "$1"/checkpoint-* 2>/dev/null | head -1 || true
}

PIPE_M_S3GRPO_LATEST=$(latest_ckpt "$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_grpo")
PIPE_M_S3SFT_LATEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_sft/checkpoint-step-02400
PIPE_O_S3SFT_LATEST=$SWEEP_ROOT/pipe_o_s3sft_lr5e6/s3_sft/checkpoint-step-02400
PIPE_J_S3GRPO_LATEST=$(latest_ckpt "$SWEEP_ROOT/pipe_j_s3sft_lr5e5_lr1e5/s3_grpo")

# Sanity: every anchor must resolve to an existing directory. The variable
# name is reported as well, because a failed lookup leaves the path empty.
for name in PIPE_M_S3GRPO_LATEST PIPE_M_S3SFT_LATEST PIPE_O_S3SFT_LATEST PIPE_J_S3GRPO_LATEST; do
  c="${!name}"
  [[ -n "$c" && -d "$c" ]] || { echo "MISSING: ${name}=${c:-<no checkpoint-* found>}" >&2; exit 1; }
done
30
+
31
+ CKPT_LR5E5=$ROOT/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000
32
+
33
launch() {
  # Usage: launch <gpu> <variant> [KEY=VALUE ...]
  # Runs the pipeline script detached with its settings passed through the
  # environment; trailing KEY=VALUE pairs are applied after the fixed ones.
  local gpu="$1"
  local variant="$2"
  shift 2

  local out="${SWEEP_ROOT}/${variant}"
  mkdir -p "${out}"

  local -a job_env=(
    ROOT="${ROOT}"
    VARIANT="${variant}"
    GPU="${gpu}"
    S2_SFT_CKPT="${CKPT_LR5E5}"
    OUTPUT_ROOT="${out}"
    USE_WANDB=0
    WANDB_MODE=offline
    "$@"
  )
  nohup env "${job_env[@]}" bash "${PIPELINE}" </dev/null >"${out}/nohup.log" 2>&1 &
  local pid=$!

  echo "$pid $gpu $variant" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s\n' "$gpu" "$variant" "$pid"
}
44
+
45
+ # pipe_t (GPU 4): continue pipe_m's S3 GRPO with lower LR + KL anchor + longer steps.
46
+ # Keep the policy near the SFT reference to avoid the regression we saw earlier.
47
+ launch 4 pipe_t_grpo_low_kl \
48
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3GRPO_LATEST" \
49
+ GRPO_LR=1e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
50
+ GRPO_BETA=0.04 \
51
+ S3_GRPO_MAX_STEPS=3000 \
52
+ USE_GC=0
53
+
54
+ # pipe_u (GPU 5): re-run S3 GRPO from pipe_m's S3-SFT ckpt with sharper rewards
55
+ # (mirror latent `s3_grpo_sharp_rwd` recipe: bigger penalty for bad).
56
+ launch 5 pipe_u_grpo_sharp_rwd \
57
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3SFT_LATEST" \
58
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
59
+ REWARD_GOOD=1.5 PENALTY_BAD=2.0 PENALTY_MAL=4.0 \
60
+ S3_GRPO_MAX_STEPS=3000 \
61
+ USE_GC=0
62
+
63
+ # pipe_v (GPU 6): extend pipe_o's S3 SFT (the strongest pure-SFT path) with very
64
+ # low LR for 4000 more steps. Then S3 GRPO at LR=1e-6.
65
+ launch 6 pipe_v_sft_extend \
66
+ START_PHASE=s3_sft S3_SFT_INIT="$PIPE_O_S3SFT_LATEST" \
67
+ SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
68
+ S3_SFT_MAX_STEPS=4000 \
69
+ GRPO_LR=1e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
70
+ S3_GRPO_MAX_STEPS=2000 \
71
+ USE_GC=0
72
+
73
+ # pipe_w (GPU 7): continue pipe_j's S3 GRPO with very low LR + KL anchor.
74
+ # Different lineage from pipe_m, so this gives an independent push.
75
+ launch 7 pipe_w_j_low_kl \
76
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_J_S3GRPO_LATEST" \
77
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
78
+ GRPO_BETA=0.02 \
79
+ S3_GRPO_MAX_STEPS=3000 \
80
+ USE_GC=0
81
+
82
echo
echo "=== launched ==="
# Show the last four PIDS.txt entries; the file is shared with earlier waves.
tail -n 4 "$SWEEP_ROOT/PIDS.txt"
_runs/launch_baseline_push_v6.sh ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Wave-6: push baseline 1.5B past solve=0.40 by porting the latent's winning
3
+ # reward shaping + multi-value oversampling into the vanilla baseline pipeline.
4
+ #
5
+ # Diagnosis from v4/v5 logs:
6
+ # At plateau, eval shows `avg_set_size=1.000` for every step. The model
7
+ # is predicting only ONE value per cell even when the target is multi-valued.
8
+ # Per-cell exact pinned at 0.95 → solve = 0.95^20 ≈ 0.36.
9
+ # Same failure mode the latent's `s3_grpo_sharp_rwd` recipe fixed:
10
+ # exact_match_bonus + cardinality_mismatch_penalty + penalty_missing
11
+ # plus SFT-side multi_value_oversample_factor=5 (and target_size_min=2 for
12
+ # the most aggressive variant).
13
+ #
14
+ # 8 variants on GPUs 0..7. All seed from existing v4 best ckpts so we don't
15
+ # burn cycles redoing S2.
16
+ set -euo pipefail
17
+
18
+ ROOT=/home/ubuntu/curriculum_cot
19
+ SWEEP_ROOT=$ROOT/_runs/baseline_1p5b_v4_20260523_184952
20
+ PIPELINE=$ROOT/_runs/baseline_1p5b_pipeline_v4.sh
21
+
22
+ # --- v4 anchors ----
23
+ PIPE_V_S3SFT_LATEST=$SWEEP_ROOT/pipe_v_sft_extend/s3_sft/checkpoint-step-04000
24
+ PIPE_M_S3SFT_LATEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_sft/checkpoint-step-02400
25
+ PIPE_V_S3GRPO_BEST=$SWEEP_ROOT/pipe_v_sft_extend/s3_grpo/checkpoint-1000 # step 1050 was 0.40 peak; 1000 is closest saved
26
+ PIPE_M_S3GRPO_BEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_grpo/checkpoint-200 # peak per pipe_m logs
27
+ PIPE_O_S3SFT_LATEST=$SWEEP_ROOT/pipe_o_s3sft_lr5e6/s3_sft/checkpoint-step-02400
28
+ CKPT_LR5E5=$ROOT/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000
29
+
30
# Sanity: every anchor checkpoint must exist before any job is launched.
# Failures go to stderr and name the offending variable.
for name in PIPE_V_S3SFT_LATEST PIPE_M_S3SFT_LATEST PIPE_V_S3GRPO_BEST PIPE_M_S3GRPO_BEST PIPE_O_S3SFT_LATEST; do
  c="${!name}"
  [[ -d "$c" ]] || { echo "MISSING: ${name}=${c}" >&2; exit 1; }
done
33
+
34
launch() {
  # Usage: launch <gpu> <variant> [KEY=VALUE ...]
  # Starts the pipeline script as a detached job. Settings travel through the
  # environment, and caller-supplied KEY=VALUE pairs are listed last.
  local gpu="$1"
  local variant="$2"
  shift 2

  local out="${SWEEP_ROOT}/${variant}"
  local console="${out}/nohup.log"
  mkdir -p "${out}"

  local -a settings=(
    ROOT="${ROOT}"
    VARIANT="${variant}"
    GPU="${gpu}"
    S2_SFT_CKPT="${CKPT_LR5E5}"
    OUTPUT_ROOT="${out}"
    USE_WANDB=0
    WANDB_MODE=offline
  )
  nohup env "${settings[@]}" "$@" bash "${PIPELINE}" </dev/null >"${console}" 2>&1 &
  local pid=$!

  echo "$pid $gpu $variant" >> "$SWEEP_ROOT/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s\n' "$gpu" "$variant" "$pid"
}
45
+
46
+ # === GRPO continuations (the high-leverage knob) ===
47
+
48
+ # v6_a (GPU 0): continue best v4 GRPO with the FULL latent recipe.
49
+ # card_pen=1.0 + missing=0.75 + exact_b=2.0; LR slightly lower than v4 to be safe.
50
+ launch 0 v6_a_grpo_v_card \
51
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
52
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
53
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
54
+ S3_GRPO_MAX_STEPS=2000
55
+
56
+ # v6_b (GPU 1): "sharp" version — mirror s3_grpo_sharp_rwd's stronger weights.
57
+ launch 1 v6_b_grpo_v_sharp \
58
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
59
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
60
+ PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=4.0 CARD_MISMATCH_PEN=3.0 \
61
+ S3_GRPO_MAX_STEPS=2000
62
+
63
+ # v6_c (GPU 2): full recipe but from pipe_v's S3 SFT (fresh GRPO, not continuation).
64
+ launch 2 v6_c_grpo_vsft_card \
65
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3SFT_LATEST" \
66
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
67
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
68
+ S3_GRPO_MAX_STEPS=2000
69
+
70
+ # v6_d (GPU 3): same recipe but from pipe_m's S3 SFT (different lineage; champion).
71
+ launch 3 v6_d_grpo_msft_card \
72
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3SFT_LATEST" \
73
+ GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
74
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
75
+ S3_GRPO_MAX_STEPS=2000
76
+
77
+ # === SFT push w/ oversample (the data-side knob) ===
78
+
79
+ # v6_e (GPU 4): continue pipe_v S3 SFT with oversample=5. Mirrors r1_sft_c_oversample5.
80
+ launch 4 v6_e_sft_v_oversample5 \
81
+ START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
82
+ SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
83
+ SFT_OVERSAMPLE=5 \
84
+ S3_SFT_MAX_STEPS=2500 \
85
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
86
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
87
+ S3_GRPO_MAX_STEPS=1500
88
+
89
+ # v6_f (GPU 5): same but oversample=8 (more aggressive).
90
+ launch 5 v6_f_sft_v_oversample8 \
91
+ START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
92
+ SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
93
+ SFT_OVERSAMPLE=8 \
94
+ S3_SFT_MAX_STEPS=2500 \
95
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
96
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
97
+ S3_GRPO_MAX_STEPS=1500
98
+
99
+ # v6_g (GPU 6): oversample=5 + train_target_size_min=2 (only multi-value cells).
100
+ # This is the most surgical variant — focus all training mass on the failing cells.
101
+ launch 6 v6_g_sft_v_mv_only \
102
+ START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
103
+ SFT_LR_S3=1e-6 SFT_BS=16 SFT_GA=1 \
104
+ SFT_OVERSAMPLE=5 SFT_TGT_MIN=2 \
105
+ S3_SFT_MAX_STEPS=2000 \
106
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
107
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
108
+ S3_GRPO_MAX_STEPS=1500
109
+
110
+ # v6_h (GPU 7): same as v6_a but with even more steps + KL anchor for stability.
111
+ # The latent best (s3_grpo_baseline) ran with beta=0.0; we know KL>0 hurts long term.
112
+ # But here we want to see whether the new shape rewards survive more steps without
113
+ # regression. Use a small beta (0.01) for gentle anchoring.
114
+ launch 7 v6_h_grpo_v_card_long \
115
+ START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
116
+ GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
117
+ GRPO_BETA=0.01 \
118
+ PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
119
+ S3_GRPO_MAX_STEPS=3000
120
+
121
+ echo
122
+ echo "=== launched ==="
123
+ tail -8 "$SWEEP_ROOT/PIDS.txt"
_runs/launch_latent_reproduction_overnight.sh ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Overnight reproduction of the latent recurrent-hidden 3-stage champion.
3
+ # Mirrors the recipe that produced solve=0.60 (100p) / 0.675 (40p) on 2026-05-22.
4
+ #
5
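+ # Single distributed job. The original ran across all 8 H100s (~6-7 hrs end to end); this launcher uses 4 GPUs with doubled grad accumulation (see GPU_IDS / *_GRAD_ACCUM below), at roughly 2x the wall-clock.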
6
+ #
7
+ # Stages: S1 SFT (cot=1) -> S1 GRPO (cot=1)
8
+ # -> S2 SFT (cot=2) -> S2 GRPO (cot=2)
9
+ # -> S3 SFT (cot=3) -> S3 GRPO (cot=3)
10
+ #
11
+ # Hyperparameters (defaults, faithful to original):
12
+ # model Qwen/Qwen2.5-0.5B-Instruct
13
+ # num_cot_tokens 1->2->3 across stages
14
+ # latent_mode recurrent_hidden
15
+ # bs=8/device, grad_accum=2 (raised to 4 below for the 4-GPU run), gradient checkpointing ON
16
+ # stage1_sft_lr=2e-4, stage2/3_sft_lr=5e-5, grpo_lr=1e-6 (hardcoded)
17
+ # value_target=0.98 (precision AND recall)
18
+ # train_puzzles=10000 eval_puzzles=100
19
+ # num_generations=4 max_completion_length=24
20
+
21
+ set -euo pipefail
22
+
23
+ ROOT=/home/ubuntu/curriculum_cot
24
+ SCRIPT="${ROOT}/hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh"
25
+
26
+ RUN_TAG="latent_reproduction_overnight_$(date +%Y%m%d_%H%M%S)"
27
+ OUTPUT_ROOT="${ROOT}/_runs/${RUN_TAG}"
28
+ LOG="${OUTPUT_ROOT}/PIPELINE.log"
29
+ mkdir -p "${OUTPUT_ROOT}"
30
+
31
+ # Point the HF caches at the repo-local .hf_cache so our pre-downloaded Qwen 0.5B is found
32
+ export HF_HOME="${ROOT}/.hf_cache"
33
+ export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
34
+ export HF_HUB_OFFLINE=0
35
+ export TOKENIZERS_PARALLELISM=false
36
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
37
+ # wandb is not authenticated on this machine — keep offline so jobs don't hang
38
+ export WANDB_MODE=offline
39
+ # avoid the upstream script trying to pull from a wandb entity we don't own
40
+ export WANDB_ENTITY="local"
41
+
42
+ # Use our preinstalled pytorch venv
43
+ export PYTHON_BIN=/opt/pytorch/bin/python
44
+
45
+ # 4-GPU distributed run with doubled grad accum to preserve the original
46
+ # effective batch size (8*2*8 = 128 -> 8*4*4 = 128). Takes ~2x wall-clock
47
+ # but is faithful to the original convergence dynamics.
48
+ export GPU_IDS=0,1,2,3
49
+ export NUM_PROCESSES=4
50
+ export SFT_GRAD_ACCUM=4
51
+ export GRPO_GRAD_ACCUM=4
52
+
53
+ # Match original
54
+ export MODEL_NAME="Qwen/Qwen2.5-0.5B-Instruct"
55
+ export VALUE_TARGET=0.98
56
+ export SFT_VALUE_TARGET=0.95
57
+ export GRPO_VALUE_TARGET=0.98
58
+ export TRAIN_PUZZLES=10000
59
+ export EVAL_PUZZLES=100
60
+ export MIN_STEPS_BEFORE_STOP=50
61
+
62
+ # Cap per-phase wallclock to keep us safely under one overnight session.
63
+ # The original took ~6-7 hours; we cap each phase at 75 min to let all 6 phases
64
+ # finish within ~7.5 hrs even if one phase slow-runs.
65
+ export PHASE_WALL_CLOCK_SECONDS=4500
66
+
67
+ # Hard step caps (in addition to early stop on prec+recall)
68
+ export SFT_MAX_STEPS=4000
69
+ export GRPO_MAX_STEPS=2000
70
+
71
+ export RUN_TAG
72
+ export OUTPUT_ROOT
73
+ export CHECKPOINT_ROOT="${OUTPUT_ROOT}"
74
+
75
+ printf '[launch_latent_reproduction] %s\n' "$(date -Is)" | tee -a "${LOG}"
76
+ printf ' RUN_TAG=%s\n' "${RUN_TAG}" | tee -a "${LOG}"
77
+ printf ' OUTPUT_ROOT=%s\n' "${OUTPUT_ROOT}" | tee -a "${LOG}"
78
+ printf ' GPUs=%s nproc=%s model=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" "${MODEL_NAME}" | tee -a "${LOG}"
79
+ printf ' VALUE_TARGET=%s SFT_VALUE_TARGET=%s GRPO_VALUE_TARGET=%s\n' "${VALUE_TARGET}" "${SFT_VALUE_TARGET}" "${GRPO_VALUE_TARGET}" | tee -a "${LOG}"
80
+ printf ' PHASE_WALL_CLOCK=%ss SFT_MAX_STEPS=%s GRPO_MAX_STEPS=%s\n' "${PHASE_WALL_CLOCK_SECONDS}" "${SFT_MAX_STEPS}" "${GRPO_MAX_STEPS}" | tee -a "${LOG}"
81
+
82
+ bash "${SCRIPT}" 2>&1 | tee -a "${LOG}"
_runs/launch_simple_baseline.sh ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Strawman baseline for the rebuttal: vanilla LoRA, no curriculum, no thought
3
+ # tokens, single-shot whole-puzzle prediction. SFT followed by GRPO.
4
+ #
5
+ # Same model (Qwen2.5-1.5B-Instruct), same LoRA (r=32, α=64, dropout=0.05),
6
+ # same JSONL data file, same Qwen chat template wrapping as the cell-policy
7
+ # experiments. The ONLY differences from the cell-policy baseline are:
8
+ # - no per-cell expansion (one example per puzzle)
9
+ # - no stage_i / curriculum
10
+ # - no multi_value_oversample, no exact_match_bonus / cardinality penalties
11
+ # - reward = number of correct values out of 20 + whole-solve bonus
12
+ set -euo pipefail
13
+
14
+ ROOT=/home/ubuntu/curriculum_cot
15
+ SCRIPT=${ROOT}/_runs/simple_baseline_sudoku_train.py
16
+ PYTHON_BIN=/opt/pytorch/bin/python
17
+
18
+ TRAIN_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
19
+ EVAL_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
20
+
21
+ SWEEP_ROOT=${ROOT}/_runs/strawman_baseline_$(date +%Y%m%d_%H%M%S)
22
+ mkdir -p "${SWEEP_ROOT}"
23
+ echo "${SWEEP_ROOT}" > "${ROOT}/_runs/current_strawman_sweep_dir"
24
+ echo "SWEEP_ROOT=${SWEEP_ROOT}"
25
+
26
+ export TOKENIZERS_PARALLELISM=false
27
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
28
+ export HF_HOME="${ROOT}/.hf_cache"
29
+ export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
30
+ export WANDB_MODE=offline
31
+
32
run_pipeline() {
  # Usage: run_pipeline <gpu> <tag> <sft_lr> <grpo_lr> <sft_max_steps> <grpo_max_steps>
  # Runs SFT and then GRPO for one variant as a single detached job pinned to
  # <gpu>. Everything is appended to <SWEEP_ROOT>/<tag>/pipeline.log. GRPO is
  # initialised from <out>/sft/final and only starts if SFT exits cleanly.
  local gpu="$1" tag="$2" sft_lr="$3" grpo_lr="$4" sft_max="$5" grpo_max="$6"
  local out=${SWEEP_ROOT}/${tag}
  mkdir -p "${out}"
  local log=${out}/pipeline.log
  : > "${log}"
  (
    export CUDA_VISIBLE_DEVICES="${gpu}"
    echo "[$(date +%H:%M:%S)] === ${tag} on GPU ${gpu}: SFT lr=${sft_lr} max_steps=${sft_max} ===" >> "${log}"
    # A failing phase is recorded in the log explicitly; previously the
    # subshell just died under errexit and the log ended without any marker.
    if ! "${PYTHON_BIN}" -u "${SCRIPT}" \
      --phase sft \
      --train_jsonl "${TRAIN_JSONL}" \
      --eval_jsonl "${EVAL_JSONL}" \
      --output_dir "${out}/sft" \
      --learning_rate "${sft_lr}" \
      --max_steps "${sft_max}" \
      --per_device_train_batch_size 8 \
      --gradient_accumulation_steps 2 \
      --num_epochs 8 \
      --logging_steps 25 \
      --save_steps 200 \
      --eval_rows 100 \
      --max_completion_length 96 \
      --max_prompt_length 1024 \
      --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
      --seed 0 \
      >> "${log}" 2>&1; then
      echo "[$(date +%H:%M:%S)] === ${tag} FAILED during SFT ===" >> "${log}"
      exit 1
    fi

    echo "[$(date +%H:%M:%S)] === ${tag} on GPU ${gpu}: GRPO lr=${grpo_lr} max_steps=${grpo_max} ===" >> "${log}"
    if ! "${PYTHON_BIN}" -u "${SCRIPT}" \
      --phase grpo \
      --init_adapter_dir "${out}/sft/final" \
      --train_jsonl "${TRAIN_JSONL}" \
      --eval_jsonl "${EVAL_JSONL}" \
      --output_dir "${out}/grpo" \
      --learning_rate "${grpo_lr}" \
      --max_steps "${grpo_max}" \
      --per_device_train_batch_size 4 \
      --gradient_accumulation_steps 2 \
      --num_generations 8 \
      --beta 0.0 \
      --temperature 1.0 \
      --num_epochs 50 \
      --logging_steps 25 \
      --save_steps 200 \
      --eval_rows 100 \
      --max_completion_length 96 \
      --max_prompt_length 1024 \
      --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
      --seed 0 \
      >> "${log}" 2>&1; then
      echo "[$(date +%H:%M:%S)] === ${tag} FAILED during GRPO ===" >> "${log}"
      exit 1
    fi
    echo "[$(date +%H:%M:%S)] === ${tag} DONE ===" >> "${log}"
  ) >/dev/null 2>&1 &
  local pid=$!
  echo "$pid $gpu $tag" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s log=%s\n' "$gpu" "$tag" "$pid" "$log"
}
90
+
91
+ # 2 variants on GPUs 0,1: explore SFT LR (5e-5 and 1e-4) — same GRPO LR (5e-6).
92
+ run_pipeline 0 strawman_a_sft5e5_grpo5e6 5e-5 5e-6 2000 1500
93
+ run_pipeline 1 strawman_b_sft1e4_grpo5e6 1e-4 5e-6 2000 1500
94
+
95
+ echo
96
+ echo "=== launched ==="
97
+ cat "${SWEEP_ROOT}/PIDS.txt"
_runs/launch_strawman_cellpolicy.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Launch two strawman variants (single-stage cell-policy at stage_i=3, no
3
+ # curriculum, no thought tokens) on GPUs 0 and 1.
4
+ set -euo pipefail
5
+ ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
6
+ TS="$(date +%Y%m%d_%H%M%S)"
7
+ SWEEP_ROOT="${ROOT}/_runs/strawman_cellpolicy_${TS}"
8
+ mkdir -p "${SWEEP_ROOT}"
9
+ PIPE="${ROOT}/_runs/strawman_cellpolicy_pipeline.sh"
10
+ chmod +x "${PIPE}"
11
+
12
launch() {
  # Usage: launch <variant> <gpu> <KEY=VALUE>...
  # Starts the pipeline script detached; the variant name, GPU id, output
  # directory and all KEY=VALUE pairs are handed over as environment variables.
  local variant="$1"
  local gpu="$2"
  shift 2

  local out="${SWEEP_ROOT}/${variant}"
  mkdir -p "${out}"
  echo "[launch] ${variant} on GPU ${gpu} out=${out}"

  local -a job_env=(
    VARIANT="${variant}"
    GPU="${gpu}"
    OUTPUT_ROOT="${out}"
    "$@"
  )
  nohup env "${job_env[@]}" bash "${PIPE}" > "${out}/console.log" 2>&1 &
  local pid=$!

  disown "${pid}" || true
  echo "${variant}=${pid}" >> "${SWEEP_ROOT}/PIDS.txt"
}
25
+
26
+ launch strawman_a_lr2e5 0 \
27
+ SFT_LR=2e-5 GRPO_LR=5e-6 SFT_MAX_STEPS=3000 GRPO_MAX_STEPS=1500 \
28
+ PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=1.0 CARD_MISMATCH_PEN=1.5 \
29
+ SFT_OVERSAMPLE=3
30
+
31
+ launch strawman_b_lr5e5 1 \
32
+ SFT_LR=5e-5 GRPO_LR=5e-6 SFT_MAX_STEPS=4000 GRPO_MAX_STEPS=1500 \
33
+ PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=1.0 CARD_MISMATCH_PEN=1.5 \
34
+ SFT_OVERSAMPLE=3
35
+
36
+ echo "[launch] sweep root: ${SWEEP_ROOT}"
37
+ echo "[launch] PIDs:"
38
+ cat "${SWEEP_ROOT}/PIDS.txt"
_runs/simple_baseline_sudoku_train.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Strawman baseline for the rebuttal.
3
+
4
+ Vanilla Qwen2.5-1.5B-Instruct + LoRA on top of the *existing* JSONL data
5
+ (`data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`).
6
+
7
+ Compared to the cell-policy / latent recipes, this strawman intentionally
8
+ removes everything that helped:
9
+
10
+ - NO curriculum (single stage; we don't even read `stage_i`).
11
+ - NO chain-of-thought / latent thought tokens.
12
+ - NO per-cell expansion (one example == one whole puzzle).
13
+ - NO multi-value oversampling, no special reward shaping (just the per-cell match count, plus a full-solve bonus and parse/length penalties).
14
+
15
+ It uses the *same* model, *same* LoRA config, *same* tokenizer + chat
16
+ template wrapping that every cell-policy experiment used, so any solve
17
+ gap vs the cell-policy / latent runs is purely due to task framing,
18
+ not data, prompt, model, or PEFT differences.
19
+
20
+ Usage:
21
+ python simple_baseline_sudoku_train.py --phase sft --output_dir <out>/sft --learning_rate 5e-5
22
+ python simple_baseline_sudoku_train.py --phase grpo --init_adapter_dir <out>/sft/final \
23
+ --output_dir <out>/grpo --learning_rate 5e-6
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import argparse
29
+ import json
30
+ import math
31
+ import os
32
+ import re
33
+ import sys
34
+ import time
35
+ from pathlib import Path
36
+ from typing import Any, Callable, Dict, List, Optional
37
+
38
+ import torch
39
+ from datasets import Dataset
40
+ from peft import LoraConfig, PeftModel, get_peft_model
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
42
+
43
+ # Reuse existing helpers (these are the canonical ones used by every cell-policy run).
44
+ ROOT = Path(__file__).resolve().parent.parent
45
+ if str(ROOT) not in sys.path:
46
+ sys.path.insert(0, str(ROOT))
47
+
48
+ from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore
49
+ load_jsonl_rows,
50
+ pick_dtype,
51
+ )
52
+ from multi_output_cell_policy.rewards import score_prediction_text # type: ignore
53
+ from multi_output_cell_policy.shared_multi_output_policy import ( # type: ignore
54
+ make_solved_grid_from_row,
55
+ stage_i_consistent_values,
56
+ )
57
+ from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row # type: ignore
58
+
59
+
60
+ # ---- Strawman task definition -----------------------------------------------
61
+ # This is the ONLY new piece relative to the cell-policy experiments. The
62
+ # system prompt asks the model to emit the missing values for ALL empty cells
63
+ # in one shot, in the row-major order that the existing JSONL `completion`
64
+ # field already uses. The user message is the raw `prompt` field from the
65
+ # JSONL (puzzle as (row,col,value) tuples), which is byte-identical to what
66
+ # `prompt_builder.py` consumes in cell-policy runs.
67
+
68
+ SYSTEM_PROMPT_STRAWMAN = (
69
+ "You are a Sudoku solver.\n"
70
+ "You will be given a 9x9 Sudoku grid encoded as (row,col,value) tuples in "
71
+ "row-major order, where value 0 marks an empty cell.\n"
72
+ "Predict the missing values for ALL empty cells in row-major order.\n"
73
+ "Return ONLY a JSON list of integers like [v1,v2,...,vK], where K is the "
74
+ "number of empty cells (typically 20). Each value must be an integer in "
75
+ "[1,9].\n"
76
+ "Do not include any explanation, markdown, or text outside the JSON list."
77
+ )
78
+
79
+
80
def build_chat_prompt(tokenizer: Any, raw_prompt: str) -> str:
    """Wrap *raw_prompt* in the strawman system prompt and return the prompt text.

    When the tokenizer has a non-empty ``chat_template``, the system and user
    messages are rendered through ``apply_chat_template`` (untokenized, with the
    generation prompt appended). Otherwise a plain-text fallback is returned:
    system prompt, blank line, raw prompt, trailing newline.
    """
    system_text = SYSTEM_PROMPT_STRAWMAN.strip()

    # Guard clause: tokenizers without a chat template get the plain fallback.
    if not getattr(tokenizer, "chat_template", None):
        return system_text + "\n\n" + raw_prompt + "\n"

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": raw_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
92
+
93
+
94
+ # ---- Reward -----------------------------------------------------------------
95
+
96
# Matches one flat bracketed span, i.e. "[...]" with no nested brackets inside.
LIST_RE = re.compile(r"\[[^\[\]]*\]")


def parse_int_list(text: str) -> Optional[List[int]]:
    """Parse the model's emission as a JSON int list with values in [1,9].

    Tolerant: the whole (stripped) completion is tried first; if that is not
    a valid list, every flat bracketed span in the text is tried in order of
    appearance and the first one that is a well-formed list wins. Earlier
    spans that are not valid JSON (e.g. ``[r1c2]``) no longer hide a valid
    list that follows them.

    Args:
        text: Raw completion text; it is coerced with ``str()``.

    Returns:
        The list of integers, or ``None`` when the text is empty or no
        candidate is a JSON list made only of integers in ``[1, 9]``
        (booleans and floats are rejected). An empty JSON list ``[]`` is
        accepted and returned as ``[]``.
    """
    s = str(text).strip()
    if not s:
        return None

    candidates: List[str] = [s]
    candidates.extend(m.group(0) for m in LIST_RE.finditer(s))

    for cand in candidates:
        try:
            obj = json.loads(cand)
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; RecursionError guards against
            # pathologically nested input in the whole-text candidate.
            continue
        if not isinstance(obj, list):
            continue
        # bool is a subclass of int, so it has to be excluded explicitly.
        if all(
            isinstance(v, int) and not isinstance(v, bool) and 1 <= v <= 9
            for v in obj
        ):
            return [int(v) for v in obj]
    return None
133
+
134
+
135
def whole_puzzle_reward(
    *,
    pred_list: Optional[List[int]],
    target_list: List[int],
    parse_penalty: float = 4.0,
    length_mismatch_penalty: float = 0.5,
    full_solve_bonus: float = 5.0,
) -> float:
    """Score a whole-puzzle prediction against its target list.

    Args:
        pred_list: Parsed prediction, or ``None`` when parsing failed.
        target_list: Ground-truth values, position-aligned with the prediction.
        parse_penalty: Magnitude of the (negative) reward for a parse failure.
        length_mismatch_penalty: Penalty per element of length difference.
        full_solve_bonus: Bonus added when the lengths agree and every
            position matches.

    Returns:
        ``-parse_penalty`` for an unparsed prediction. Otherwise the number
        of positionally matching values over the overlapping prefix, minus
        the length penalty when the lengths differ, plus the bonus on an
        exact full-length match.
    """
    if pred_list is None:
        return -float(parse_penalty)

    expected_len = len(target_list)
    length_gap = abs(len(pred_list) - expected_len)

    # zip() stops at the shorter list, i.e. it compares the overlapping prefix.
    hits = sum(
        1
        for guess, truth in zip(pred_list, target_list)
        if int(guess) == int(truth)
    )

    score = float(hits)
    if length_gap:
        score -= float(length_mismatch_penalty) * length_gap
    elif hits == expected_len:
        score += float(full_solve_bonus)
    return score
157
+
158
+
159
+ # ---- Dataset construction ---------------------------------------------------
160
+
161
+
162
def build_dataset(rows: List[Dict[str, Any]], tokenizer: Any) -> Dataset:
    """Turn JSONL rows into a ``Dataset`` with prompt/completion/target columns.

    For each row the ``prompt`` field is chat-wrapped via
    ``build_chat_prompt``, the ``completion`` field is kept as stripped text,
    and ``target`` holds the parsed completion re-serialised as compact JSON.
    Rows whose completion does not parse via ``parse_int_list`` are dropped.
    """
    columns: Dict[str, List[str]] = {"prompt": [], "completion": [], "target": []}

    for record in rows:
        user_text = str(record["prompt"]).strip()
        answer_text = str(record["completion"]).strip()

        parsed_answer = parse_int_list(answer_text)
        if parsed_answer is None:
            # Unusable target: skip the row entirely.
            continue

        columns["prompt"].append(build_chat_prompt(tokenizer, user_text))
        columns["completion"].append(answer_text)
        columns["target"].append(json.dumps(parsed_answer, separators=(",", ":")))

    return Dataset.from_dict(columns)
176
+
177
+
178
+ # ---- Eval (deterministic, greedy, single-shot) ------------------------------
179
+
180
+
181
@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    max_new_tokens: int = 96,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Apples-to-apples eval with the cell-policy framework.

    The strawman model emits the WHOLE puzzle (a JSON list of integers) in
    one greedy generation. That list is split into per-cell SINGLETON
    predictions and each cell is scored with the same
    ``score_prediction_text`` function the cell-policy / latent baselines
    use, against the i-consistent target set at ``stage_i`` (default 3 —
    matching the S3 eval used for the rebuttal v6 baseline and the latent
    champion).

    Reported metrics mirror
    ``multi_output_cell_policy/sft_multi_output_train.py::run_eval`` so
    numbers are directly comparable across all four 2x2 ablation cells.

    Rows are handled as follows:

    * rows whose ``completion`` does not parse are ignored entirely;
    * rows whose cell metadata cannot be built are skipped *before* any
      generation is run and are reported under ``n_skipped_puzzles``; they
      are not part of ``n_total_puzzles`` and therefore do not dilute
      ``solve_rate`` / ``puzzle_parse_fail_rate``;
    * all remaining rows are generated for and scored cell by cell. Cells
      beyond the end of the predicted list (or all cells on a parse failure)
      are scored with an empty prediction text.

    Args:
        model: Model exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer used for prompt building, encoding and decoding.
        eval_rows: JSONL rows with at least ``prompt`` and ``completion``.
        device: Device the encoded prompt tensors are moved to.
        max_new_tokens: Generation budget per puzzle.
        print_n: Maximum number of debug lines (skips and samples combined).
        stage_i: Consistency stage passed to the scoring helpers.

    Returns:
        Dict of float metrics: cell-level rates averaged over
        ``n_total_cells``, puzzle-level rates averaged over
        ``n_total_puzzles``, plus the raw counts.
    """
    model.eval()
    total_cells = 0
    parse_ok = 0.0
    canonical_ok = 0.0
    exact_set_match = 0.0
    includes_gt = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    cardinality_match_sum = 0.0
    n_solve = 0
    n_total_puzzles = 0
    n_skipped_puzzles = 0
    n_parse_fail_puzzles = 0
    printed = 0
    for row in eval_rows:
        target_completion = parse_int_list(str(row["completion"]))
        if target_completion is None:
            continue

        # Resolve the per-cell metadata first: a row that cannot be scored
        # should neither cost a generate() call nor count as an attempted
        # puzzle in the puzzle-level denominators.
        try:
            cells = build_cell_examples_from_row(row)
            solved = make_solved_grid_from_row(row)
        except Exception as e:  # best-effort: helper exception types are project-defined
            n_skipped_puzzles += 1
            if printed < print_n:
                print(f"[strawman eval debug] row skipped (no metadata): {e}", flush=True)
                printed += 1
            continue

        n_total_puzzles += 1
        prompt = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
        enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        enc = {k: v.to(device) for k, v in enc.items()}
        out = model.generate(
            **enc,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode only the newly generated tokens (everything after the prompt).
        gen = tokenizer.decode(
            out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True
        ).strip()
        pred_list = parse_int_list(gen)

        row_all_exact = True
        row_has_eval_cell = False
        for idx, ex in enumerate(cells):
            target_values = stage_i_consistent_values(
                ex.grid, target_cell=ex.target_cell, stage_i=int(stage_i)
            )
            row_has_eval_cell = True
            if pred_list is not None and idx < len(pred_list):
                # Re-encode the idx-th value as a singleton prediction.
                pred_text = json.dumps({"values": [int(pred_list[idx])]})
            else:
                pred_text = ""
            info = score_prediction_text(
                text=pred_text,
                grid=ex.grid,
                solved=solved,
                target_cell=ex.target_cell,
                stage_i=int(stage_i),
                reward_good_value=1.0,
                penalty_bad_value=1.75,
                penalty_malformed=4.0,
                penalty_empty=0.5,
                penalty_singleton=1.5,
            )
            total_cells += 1
            parse_ok += float(info["parse_ok"])
            canonical_ok += float(info["strict_canonical"])
            exact_set_match += float(info["exact_set_match"])
            includes_gt += float(info["includes_ground_truth"])
            precision_sum += float(info["value_precision"])
            recall_sum += float(info["value_recall"])
            if int(info["num_predicted_values"]) == int(len(target_values)):
                cardinality_match_sum += 1.0
            if float(info["exact_set_match"]) < 0.5:
                row_all_exact = False

        # A puzzle counts as solved only if it has cells and all are exact.
        if row_has_eval_cell and row_all_exact:
            n_solve += 1
        if pred_list is None:
            n_parse_fail_puzzles += 1
        if printed < print_n:
            head_pred = pred_list if pred_list is not None else "PARSE_FAIL"
            print(
                f"[strawman eval debug] target={target_completion} pred={head_pred} "
                f"solve={int(row_all_exact and row_has_eval_cell)} gen={gen!r}",
                flush=True,
            )
            printed += 1
    return {
        "n_total_cells": float(total_cells),
        "n_total_puzzles": float(n_total_puzzles),
        "n_skipped_puzzles": float(n_skipped_puzzles),
        "parse_rate": float(parse_ok / max(1, total_cells)),
        "strict_canonical_rate": float(canonical_ok / max(1, total_cells)),
        "exact_set_match_rate": float(exact_set_match / max(1, total_cells)),
        "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)),
        "value_precision": float(precision_sum / max(1, total_cells)),
        "value_recall": float(recall_sum / max(1, total_cells)),
        "cardinality_match_rate": float(cardinality_match_sum / max(1, total_cells)),
        "puzzle_parse_fail_rate": float(n_parse_fail_puzzles / max(1, n_total_puzzles)),
        "solve_rate": float(n_solve) / max(1, n_total_puzzles),
    }
305
+
306
+
307
+ # ---- Main -------------------------------------------------------------------
308
+
309
+
310
def parse_args() -> argparse.Namespace:
    """Declare the strawman trainer's command-line options and parse them.

    The options are listed in a single table, in the order they are
    registered, and grouped by purpose with comments. Returns the namespace
    produced by ``ArgumentParser.parse_args()``.
    """
    option_table = [
        # Run selection and I/O locations
        ("--phase", {"choices": ["sft", "grpo"], "required": True}),
        ("--model_name", {"type": str, "default": "Qwen/Qwen2.5-1.5B-Instruct"}),
        ("--train_jsonl", {"type": str, "required": True}),
        ("--eval_jsonl", {"type": str, "required": True}),
        ("--output_dir", {"type": str, "required": True}),
        ("--cache_dir", {"type": str, "default": str(ROOT / ".hf_cache")}),
        ("--init_adapter_dir", {"type": str, "default": ""}),
        ("--seed", {"type": int, "default": 0}),
        # Data
        ("--limit_train_rows", {"type": int, "default": 10000}),
        ("--eval_rows", {"type": int, "default": 100}),
        # Train hyperparameters
        ("--per_device_train_batch_size", {"type": int, "default": 8}),
        ("--gradient_accumulation_steps", {"type": int, "default": 2}),
        ("--learning_rate", {"type": float, "default": 5e-5}),
        ("--weight_decay", {"type": float, "default": 0.0}),
        ("--num_epochs", {"type": float, "default": 8.0}),
        ("--max_steps", {"type": int, "default": 2000}),
        ("--logging_steps", {"type": int, "default": 25}),
        ("--save_steps", {"type": int, "default": 200}),
        ("--eval_steps", {"type": int, "default": 150}),
        ("--max_grad_norm", {"type": float, "default": 1.0}),
        ("--max_completion_length", {"type": int, "default": 96}),
        ("--max_prompt_length", {"type": int, "default": 1024}),
        # LoRA
        ("--lora_r", {"type": int, "default": 32}),
        ("--lora_alpha", {"type": int, "default": 64}),
        ("--lora_dropout", {"type": float, "default": 0.05}),
        ("--enable_gradient_checkpointing", {"action": "store_true"}),
        # GRPO-only
        ("--num_generations", {"type": int, "default": 8}),
        ("--beta", {"type": float, "default": 0.0}),
        ("--temperature", {"type": float, "default": 1.0}),
        ("--full_solve_bonus", {"type": float, "default": 5.0}),
        ("--length_mismatch_penalty", {"type": float, "default": 0.5}),
        ("--parse_penalty", {"type": float, "default": 4.0}),
        # W&B
        ("--use_wandb", {"action": "store_true"}),
        ("--wandb_project", {"type": str, "default": "sudoku-strawman-baseline"}),
        ("--wandb_run_name", {"type": str, "default": ""}),
        ("--wandb_mode", {"type": str, "default": "offline"}),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
360
+
361
+
362
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load the tokenizer and a LoRA-wrapped causal LM, then move it to ``device``.

    The tokenizer gets a pad token if it lacks one and is switched to left
    padding. The base model is loaded with the dtype chosen by
    ``pick_dtype()``. If ``args.init_adapter_dir`` is non-blank, a trainable
    adapter is loaded from it; otherwise a fresh LoRA adapter is created from
    the ``--lora_*`` arguments.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        use_fast=True,
    )
    if tokenizer.pad_token_id is None:
        # Reuse EOS as padding; fall back to a literal token if EOS is unset.
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
    if tokenizer.padding_side != "left":
        tokenizer.padding_side = "left"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )

    resume_from_adapter = bool(str(args.init_adapter_dir).strip())
    if resume_from_adapter:
        model = PeftModel.from_pretrained(
            base_model, args.init_adapter_dir, is_trainable=True
        )
    else:
        lora_targets = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        lora_cfg = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=lora_targets,
        )
        model = get_peft_model(base_model, lora_cfg)

    # NOTE(review): the source was whitespace-stripped; the three hasattr
    # checks are assumed to all sit under the gradient-checkpointing flag —
    # confirm against the original file.
    if args.enable_gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if hasattr(model, "config"):
            model.config.use_cache = False

    model.to(device)
    return model, tokenizer
404
+
405
+
406
def run_sft(args: argparse.Namespace) -> None:
    """Run the supervised fine-tuning phase end to end.

    Seeds the run, loads train/eval rows, builds the model and the
    prompt/completion dataset, evaluates once before training, trains with
    TRL's ``SFTTrainer``, saves the result under ``<output_dir>/final`` and
    evaluates once more. Eval results are printed, not returned.
    """
    from trl import SFTConfig, SFTTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)

    # Dataset rows are {prompt, completion, target}; prompts are chat-templated.
    train_ds = build_dataset(train_rows, tokenizer)

    sft_settings = dict(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=4,
        eval_strategy="no",
        bf16=(pick_dtype() == torch.bfloat16),
        fp16=(pick_dtype() == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        # Prompt budget + completion budget + a small 8-token margin.
        max_length=int(args.max_prompt_length + args.max_completion_length + 8),
        completion_only_loss=True,
        seed=int(args.seed),
    )
    cfg = SFTConfig(**sft_settings)

    trainer = SFTTrainer(
        model=model,
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    # No periodic in-training eval is wired up here; evaluate once before
    # training and once after the final step instead.
    metrics_before = run_eval(model, tokenizer, eval_rows, device)
    print("[strawman sft] BEFORE-train eval:", metrics_before, flush=True)

    started_at = time.time()
    trainer.train()
    print(f"[strawman sft] training time = {time.time() - started_at:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman sft] saved final adapter to {final_dir}", flush=True)

    metrics_after = run_eval(model, tokenizer, eval_rows, device)
    print("[strawman sft] AFTER-train eval:", metrics_after, flush=True)
465
+
466
+
467
def run_grpo(args: argparse.Namespace) -> None:
    """Run the GRPO phase end to end.

    Mirrors ``run_sft``: seed, load rows, build model and dataset, evaluate,
    train with TRL's ``GRPOTrainer`` using ``whole_puzzle_reward`` as the
    only reward, save under ``<output_dir>/final`` and evaluate again. Eval
    results are printed, not returned.
    """
    from trl import GRPOConfig, GRPOTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)
    train_ds = build_dataset(train_rows, tokenizer)

    # Reward-shaping knobs are resolved once so the closure captures floats.
    parse_penalty = float(args.parse_penalty)
    length_mismatch_penalty = float(args.length_mismatch_penalty)
    full_solve_bonus = float(args.full_solve_bonus)

    def reward_fn(completions, target, **kwargs):
        """Score each completion against the matching dataset ``target`` entry."""

        def _score(completion, tgt) -> float:
            # ``target`` is stored as a JSON string by build_dataset; accept
            # an already-decoded sequence as well.
            expected = json.loads(tgt) if isinstance(tgt, str) else list(tgt)
            return whole_puzzle_reward(
                pred_list=parse_int_list(str(completion)),
                target_list=expected,
                parse_penalty=parse_penalty,
                length_mismatch_penalty=length_mismatch_penalty,
                full_solve_bonus=full_solve_bonus,
            )

        return [_score(c, tgt) for c, tgt in zip(completions, target)]

    grpo_settings = dict(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=6,
        bf16=(pick_dtype() == torch.bfloat16),
        fp16=(pick_dtype() == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_prompt_length=int(args.max_prompt_length),
        max_completion_length=int(args.max_completion_length),
        num_generations=int(args.num_generations),
        beta=float(args.beta),
        temperature=float(args.temperature),
        seed=int(args.seed),
    )
    cfg = GRPOConfig(**grpo_settings)

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[reward_fn],
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    metrics_before = run_eval(model, tokenizer, eval_rows, device)
    print("[strawman grpo] BEFORE-train eval:", metrics_before, flush=True)

    started_at = time.time()
    trainer.train()
    print(f"[strawman grpo] training time = {time.time() - started_at:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman grpo] saved final adapter to {final_dir}", flush=True)

    metrics_after = run_eval(model, tokenizer, eval_rows, device)
    print("[strawman grpo] AFTER-train eval:", metrics_after, flush=True)
545
+
546
+
547
def main() -> None:
    """Parse the CLI, configure W&B environment variables, run the chosen phase."""
    args = parse_args()

    if args.use_wandb:
        # An already-exported WANDB_MODE wins over --wandb_mode; the project
        # name is always taken from the CLI.
        if "WANDB_MODE" not in os.environ:
            os.environ["WANDB_MODE"] = str(args.wandb_mode)
        os.environ["WANDB_PROJECT"] = args.wandb_project

    # --phase is restricted to {"sft", "grpo"} by the parser.
    phase_runner = run_sft if args.phase == "sft" else run_grpo
    phase_runner(args)


if __name__ == "__main__":
    main()
_runs/status.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# One-shot snapshot of the active sweep.
#
# Usage: status.sh [SWEEP_DIR]
#   SWEEP_DIR defaults to the most recently modified
#   /home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_* directory.
#
# Prints GPU utilisation, the liveness of every PID listed in
# ${SWEEP}/PIDS.txt (one "pid gpu name" triple per line), and for each
# pipe_* variant its latest phase plus the tail of its pipeline/train logs.
SWEEP="${1:-$(ls -dt /home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_* 2>/dev/null | head -1)}"
[[ -z "${SWEEP}" || ! -d "${SWEEP}" ]] && { echo "no sweep"; exit 1; }
echo "=== sweep: ${SWEEP} ==="
echo "=== nvidia-smi ==="
nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total,power.draw --format=csv,noheader
echo
echo "=== pids ==="
if [[ -f "${SWEEP}/PIDS.txt" ]]; then
  # "|| [[ -n $pid ]]" keeps a final line that lacks a trailing newline.
  while read -r pid gpu name || [[ -n "$pid" ]]; do
    [[ -z "$pid" ]] && continue   # skip blank lines
    # kill -0 only probes for existence; no signal is delivered.
    if kill -0 "$pid" 2>/dev/null; then alive=ALIVE; else alive=DEAD; fi
    printf ' pid=%-6s gpu=%s %-30s %s\n' "$pid" "$gpu" "$name" "$alive"
  done < "${SWEEP}/PIDS.txt"
else
  echo " (no PIDS.txt in ${SWEEP})"
fi
echo
echo "=== per-variant phase + best/last eval ==="
for v in "${SWEEP}"/pipe_*; do
  # Without nullglob an unmatched pattern is passed through literally.
  [[ -d "$v" ]] || continue
  vn="$(basename "$v")"
  current_phase="(starting)"
  # Phases are listed in execution order, so the last existing dir wins.
  for ph in s2_sft_extra s2_grpo s3_sft s3_grpo; do
    [[ -d "$v/$ph" ]] && current_phase="$ph"
  done
  printf '\n--- %s (phase=%s) ---\n' "$vn" "${current_phase}"
  # Pipeline log tail
  if [[ -f "$v/PIPELINE.log" ]]; then
    tail -3 "$v/PIPELINE.log" | sed 's/^/ PL: /'
  fi
  # Phase-specific evals
  for ph in s2_sft_extra s2_grpo s3_sft s3_grpo; do
    log="$v/$ph/train.log"
    [[ -f "$log" ]] || continue
    # SFT eval lines
    last_sft="$(grep -E "\[baseline sft eval\] " "$log" 2>/dev/null | tail -3)"
    last_grpo="$(grep -E "\[baseline grpo (custom )?eval" "$log" 2>/dev/null | tail -3)"
    last_train="$(grep -E "\[baseline (sft|grpo) (train|final)" "$log" 2>/dev/null | tail -1)"
    if [[ -n "$last_sft$last_grpo$last_train" ]]; then
      printf ' [%s]\n' "$ph"
      [[ -n "$last_train" ]] && echo "$last_train" | sed 's/^/ tr: /'
      [[ -n "$last_sft" ]] && echo "$last_sft" | sed 's/^/ ev: /'
      [[ -n "$last_grpo" ]] && echo "$last_grpo" | sed 's/^/ ev: /'
    fi
  done
done
_runs/strawman_cellpolicy_pipeline.sh ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Strawman = single-stage cell-policy at stage_i=3 from BASE (no curriculum,
# no thought tokens). Same per-cell prompt, same trainer scripts, same scoring
# function as the v6 baseline and the latent champion. The ONLY differences
# vs the v6 baseline are:
# - No prior SFT/GRPO at stage_i=1 or stage_i=2 (start fresh from base Qwen).
# - Single SFT phase + single GRPO phase, both at stage_i=3.
# - No latent recurrent-hidden tokens (vanilla LoRA on base model).
# Required env vars: VARIANT, GPU.
# Optional env vars: OUTPUT_ROOT (defaults to a timestamped directory under
# ${ROOT}/_runs/), plus every other VAR="${VAR:-default}" setting below.
set -euo pipefail

# Repository root, interpreter, and the two trainer entry points.
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
PYTHON_BIN="${PYTHON_BIN:-/opt/pytorch/bin/python}"
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"

# Abort immediately (with the given message) when a required variable is unset or empty.
: "${VARIANT:?VARIANT required}"
: "${GPU:?GPU required}"

OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/_runs/strawman_cellpolicy_$(date +%Y%m%d_%H%M%S)/${VARIANT}}"
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"

# Use the same S3 hyperparameters as the v6 baseline so the only knob is
# "did we do the curriculum or not".
SFT_LR="${SFT_LR:-2e-5}"
SFT_BS="${SFT_BS:-16}"
SFT_GA="${SFT_GA:-2}"
SFT_MAX_STEPS="${SFT_MAX_STEPS:-3000}"

GRPO_LR="${GRPO_LR:-5e-6}"
GRPO_BETA="${GRPO_BETA:-0.0}"
GRPO_NG="${GRPO_NG:-8}"
GRPO_BS="${GRPO_BS:-16}"
GRPO_GA="${GRPO_GA:-2}"
GRPO_PROMPT="${GRPO_PROMPT:-768}"
GRPO_COMPL="${GRPO_COMPL:-24}"
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-1500}"

# v6-style reward shaping (same as the v6 sweep that hit solve=0.44).
REWARD_GOOD="${REWARD_GOOD:-1.25}"
PENALTY_BAD="${PENALTY_BAD:-1.0}"
PENALTY_MAL="${PENALTY_MAL:-4.0}"
PENALTY_EMPTY="${PENALTY_EMPTY:-0.5}"
PENALTY_SINGLETON="${PENALTY_SINGLETON:-1.5}"
PENALTY_MISSING="${PENALTY_MISSING:-1.0}"
EXACT_MATCH_BONUS="${EXACT_MATCH_BONUS:-1.0}"
CARD_MISMATCH_PEN="${CARD_MISMATCH_PEN:-1.5}"
SFT_OVERSAMPLE="${SFT_OVERSAMPLE:-3}"
SFT_TGT_MIN="${SFT_TGT_MIN:-0}"
SFT_TGT_MAX="${SFT_TGT_MAX:-0}"

# Early-stop target (passed as the value precision/recall stop thresholds),
# data sizes, and runtime switches.
VALUE_TARGET="${VALUE_TARGET:-0.98}"
EVAL_ROWS="${EVAL_ROWS:-100}"
TRAIN_ROWS="${TRAIN_ROWS:-10000}"
USE_GC="${USE_GC:-1}" # GC=1 to allow bs 16 on a single 80G GPU
PHASE_WALL_SECS="${PHASE_WALL_SECS:-0}"

TRAIN_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
EVAL_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"

mkdir -p "${OUTPUT_ROOT}"
PIPELINE_LOG="${OUTPUT_ROOT}/PIPELINE.log"
# ts: print the current wall-clock time as HH:MM:SS.
ts() { date +'%H:%M:%S'; }
# log: timestamped message appended to PIPELINE_LOG and echoed on stderr, so
# that stdout stays clean for command substitutions.
log() { printf '[%s] %s\n' "$(ts)" "$*" | tee -a "${PIPELINE_LOG}" >&2; }
65
+
66
# best_ckpt DIR
# Print the adapter directory to resume from and return 0, or return 1 when
# nothing usable exists. DIR itself is preferred when it directly contains
# adapter_model.safetensors; otherwise the checkpoint-* / checkpoint-step-*
# subdirectory that sorts last under version sort is chosen.
best_ckpt() {
  local run_dir="$1"
  local -a found

  [[ -f "${run_dir}/adapter_model.safetensors" ]] && {
    printf '%s\n' "${run_dir}"
    return 0
  }

  # nullglob makes unmatched patterns expand to nothing instead of themselves.
  shopt -s nullglob
  found=("${run_dir}"/checkpoint-step-* "${run_dir}"/checkpoint-*)
  shopt -u nullglob

  if (( ${#found[@]} == 0 )); then
    return 1
  fi
  printf '%s\n' "${found[@]}" | sort -V | tail -n 1
}
77
+
78
# Pre-flight: both dataset files must exist before any GPU time is spent.
[[ -f "${TRAIN_JSONL}" && -f "${EVAL_JSONL}" ]] || {
  log "ERROR: missing dataset jsonls"
  exit 1
}

# Runtime environment for the trainer processes: pin the GPU and point the
# Hugging Face caches at the repository-local directory.
export CUDA_VISIBLE_DEVICES="${GPU}"
export TOKENIZERS_PARALLELISM=false
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export HF_HOME="${ROOT}/.hf_cache"
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"

# Optional trainer flag, kept in an array so it expands to zero or one word.
GC_FLAG=()
case "${USE_GC}" in
  1) GC_FLAG=(--enable_gradient_checkpointing) ;;
esac

# Banner: record the effective configuration in the pipeline log.
log "===== STRAWMAN ${VARIANT} on GPU ${GPU} ====="
log " SFT lr=${SFT_LR} max_steps=${SFT_MAX_STEPS} bs=${SFT_BS}x${SFT_GA} GC=${USE_GC}"
log " GRPO lr=${GRPO_LR} max_steps=${GRPO_MAX_STEPS} ng=${GRPO_NG} bs=${GRPO_BS}x${GRPO_GA}"
log " rewards good=${REWARD_GOOD} bad=${PENALTY_BAD} mal=${PENALTY_MAL} empty=${PENALTY_EMPTY} sng=${PENALTY_SINGLETON} miss=${PENALTY_MISSING} bonus=${EXACT_MATCH_BONUS} card=${CARD_MISMATCH_PEN}"
log " out=${OUTPUT_ROOT}"
96
+
97
# ----- Phase 1: SFT at stage_i=3 from BASE (no init adapter) -----
# Trainer output is mirrored to ${SFT_DIR}/train.log. With `pipefail` the
# pipeline's status reflects a trainer failure, which is logged to
# PIPELINE.log before the script exits with that status.
SFT_DIR="${OUTPUT_ROOT}/sft"
mkdir -p "${SFT_DIR}"
log "=== PHASE SFT (stage_i=3, init=BASE) ==="
# ${GC_FLAG[@]+...} expands to nothing for an empty array; a bare
# "${GC_FLAG[@]}" aborts under `set -u` on bash < 4.4 when USE_GC != 1.
"${PYTHON_BIN}" -u "${SFT_SCRIPT}" \
  --model_name "${MODEL_NAME}" \
  --train_jsonl "${TRAIN_JSONL}" \
  --eval_jsonl "${EVAL_JSONL}" \
  --output_dir "${SFT_DIR}" \
  --cache_dir "${ROOT}/.hf_cache" \
  --init_adapter_dir "" \
  --seed 0 \
  --gpu_id 0 \
  --stage_i 3 \
  --total_empties_hint 20 \
  --per_device_train_batch_size "${SFT_BS}" \
  --gradient_accumulation_steps "${SFT_GA}" \
  --num_epochs 256 \
  --learning_rate "${SFT_LR}" \
  --max_grad_norm 1.0 \
  --logging_steps 25 \
  --eval_steps 200 \
  --save_steps 200 \
  --eval_rows "${EVAL_ROWS}" \
  --max_completion_length 24 \
  --limit_train_rows "${TRAIN_ROWS}" \
  --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
  --eval_value_precision_stop "${VALUE_TARGET}" \
  --eval_value_recall_stop "${VALUE_TARGET}" \
  --eval_exact_set_match_stop 0 \
  --eval_solve_rate_stop 0 \
  --min_steps_before_stop 200 \
  --max_wall_clock_seconds "${PHASE_WALL_SECS}" \
  --max_steps "${SFT_MAX_STEPS}" \
  --multi_value_oversample_factor "${SFT_OVERSAMPLE}" \
  --train_target_size_min "${SFT_TGT_MIN}" \
  --train_target_size_max "${SFT_TGT_MAX}" \
  ${GC_FLAG[@]+"${GC_FLAG[@]}"} 2>&1 | tee "${SFT_DIR}/train.log" \
  || { rc=$?; log "ERROR: SFT phase failed (exit ${rc})"; exit "${rc}"; }

SFT_CKPT="$(best_ckpt "${SFT_DIR}")" || { log "ERROR: no SFT ckpt"; exit 1; }
log ">>> SFT ckpt: ${SFT_CKPT}"
138
+
139
# ----- Phase 2: GRPO at stage_i=3 from SFT output -----
# Trainer output is mirrored to ${GRPO_DIR}/train.log. With `pipefail` the
# pipeline's status reflects a trainer failure, which is logged to
# PIPELINE.log before the script exits with that status.
GRPO_DIR="${OUTPUT_ROOT}/grpo"
mkdir -p "${GRPO_DIR}"
log "=== PHASE GRPO (stage_i=3, init=${SFT_CKPT}) ==="
# ${GC_FLAG[@]+...} expands to nothing for an empty array; a bare
# "${GC_FLAG[@]}" aborts under `set -u` on bash < 4.4 when USE_GC != 1.
"${PYTHON_BIN}" -u "${GRPO_SCRIPT}" \
  --model_name "${MODEL_NAME}" \
  --train_jsonl "${TRAIN_JSONL}" \
  --eval_jsonl "${EVAL_JSONL}" \
  --output_dir "${GRPO_DIR}" \
  --cache_dir "${ROOT}/.hf_cache" \
  --init_adapter_dir "${SFT_CKPT}" \
  --seed 0 \
  --gpu_id 0 \
  --stage_i 3 \
  --total_empties_hint 20 \
  --per_device_train_batch_size "${GRPO_BS}" \
  --gradient_accumulation_steps "${GRPO_GA}" \
  --num_train_epochs 100 \
  --learning_rate "${GRPO_LR}" \
  --logging_steps 10 \
  --save_steps 200 \
  --eval_steps 150 \
  --eval_rows "${EVAL_ROWS}" \
  --num_generations "${GRPO_NG}" \
  --max_prompt_length "${GRPO_PROMPT}" \
  --max_completion_length "${GRPO_COMPL}" \
  --beta "${GRPO_BETA}" \
  --limit_train_rows "${TRAIN_ROWS}" \
  --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
  --reward_good_value "${REWARD_GOOD}" \
  --penalty_bad_value "${PENALTY_BAD}" \
  --penalty_malformed "${PENALTY_MAL}" \
  --penalty_empty "${PENALTY_EMPTY}" \
  --penalty_singleton "${PENALTY_SINGLETON}" \
  --penalty_missing "${PENALTY_MISSING}" \
  --exact_match_bonus "${EXACT_MATCH_BONUS}" \
  --cardinality_mismatch_penalty "${CARD_MISMATCH_PEN}" \
  --eval_value_precision_stop "${VALUE_TARGET}" \
  --eval_value_recall_stop "${VALUE_TARGET}" \
  --eval_solve_rate_stop 0 \
  --min_steps_before_stop 100 \
  --max_wall_clock_seconds "${PHASE_WALL_SECS}" \
  --max_steps "${GRPO_MAX_STEPS}" \
  ${GC_FLAG[@]+"${GC_FLAG[@]}"} 2>&1 | tee "${GRPO_DIR}/train.log" \
  || { rc=$?; log "ERROR: GRPO phase failed (exit ${rc})"; exit "${rc}"; }

# A missing GRPO checkpoint is only a warning: the SFT result still exists.
GRPO_CKPT="$(best_ckpt "${GRPO_DIR}")" || { log "WARN: no GRPO ckpt found"; exit 0; }
log ">>> GRPO ckpt: ${GRPO_CKPT}"
log "===== STRAWMAN ${VARIANT} done ====="