Add data/ JSONLs + _runs/ launch scripts (override .gitignore)
Browse files- _runs/LATENT_PID.txt +1 -0
- _runs/adaptive_k_cellpolicy_pipeline.py +430 -0
- _runs/adaptive_latent_baseline_sudoku_train.py +534 -0
- _runs/add_variants_g_h.sh +57 -0
- _runs/add_variants_i_j_k_l.sh +94 -0
- _runs/baseline_1p5b_pipeline_v4.sh +328 -0
- _runs/eval_strawman_cellpolicy.py +132 -0
- _runs/launch_adaptive_k_cellpolicy.sh +42 -0
- _runs/launch_adaptive_latent_baseline.sh +76 -0
- _runs/launch_baseline_1p5b_v4.sh +82 -0
- _runs/launch_baseline_push_v5.sh +84 -0
- _runs/launch_baseline_push_v6.sh +123 -0
- _runs/launch_latent_reproduction_overnight.sh +82 -0
- _runs/launch_simple_baseline.sh +97 -0
- _runs/launch_strawman_cellpolicy.sh +38 -0
- _runs/simple_baseline_sudoku_train.py +559 -0
- _runs/status.sh +42 -0
- _runs/strawman_cellpolicy_pipeline.sh +186 -0
_runs/LATENT_PID.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
164065 0-7 latent_reproduction_20260524_062728
|
_runs/adaptive_k_cellpolicy_pipeline.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Adaptive-k cell-policy pipeline (no curriculum).
|
| 3 |
+
|
| 4 |
+
Wraps the existing per-cell trainers to implement an "adaptive-k" schedule:
|
| 5 |
+
the model is trained at stage_i=3 only (no curriculum), with the number of
|
| 6 |
+
recurrent-hidden thought tokens k starting at 0 (vanilla SFT) and being
|
| 7 |
+
incremented whenever the eval exact_set_match metric plateaus. Each phase
|
| 8 |
+
runs ``sft_latent_multi_output_train.py`` for ``steps_per_phase`` SFT steps
|
| 9 |
+
at fixed k, initialised from the previous phase's latest checkpoint (so the
|
| 10 |
+
recurrent-hidden bank persists). After the final SFT phase, ``grpo_residual_projector_latent_train.py``
|
| 11 |
+
is invoked at the converged k.
|
| 12 |
+
|
| 13 |
+
The trainer scripts, prompt template, and scoring function are the *same*
|
| 14 |
+
ones used by every cell-policy / latent experiment. The only knob this
|
| 15 |
+
orchestrator provides is the k-schedule; per-cell prompt+supervision is
|
| 16 |
+
handled by the existing trainers.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
import shutil
|
| 26 |
+
import subprocess
|
| 27 |
+
import sys
|
| 28 |
+
import time
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import List, Optional
|
| 31 |
+
|
| 32 |
+
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent

# Both trainer entry points live in the same package directory.
_TRAINER_DIR = ROOT / "latent_multi_output_cell_policy"
SFT_SCRIPT = _TRAINER_DIR / "sft_latent_multi_output_train.py"
GRPO_SCRIPT = _TRAINER_DIR / "grpo_residual_projector_latent_train.py"

# Train / eval JSONL files handed to both trainers.
_DATA_DIR = ROOT / "data"
TRAIN_JSONL = _DATA_DIR / "sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
EVAL_JSONL = _DATA_DIR / "sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_args() -> argparse.Namespace:
    """Parse the orchestrator's command-line options from ``sys.argv``.

    ``--variant``, ``--gpu`` and ``--output_root`` are required; everything
    else has a default. Gradient checkpointing is on by default; pass
    ``--no_enable_gc`` to turn it off (``--enable_gc`` is still accepted so
    existing launch scripts keep working).

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--variant", required=True)
    p.add_argument("--gpu", required=True)
    p.add_argument("--output_root", required=True)
    p.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--cache_dir", default=str(ROOT / ".hf_cache"))
    p.add_argument("--python_bin", default="/opt/pytorch/bin/python")
    p.add_argument("--latent_mode", default="recurrent_hidden")
    p.add_argument("--start_k", type=int, default=0)
    p.add_argument("--max_k", type=int, default=4)
    p.add_argument("--steps_per_phase", type=int, default=600)
    p.add_argument(
        "--max_phases_per_k",
        type=int,
        default=2,
        help="Hard cap on how many ``steps_per_phase`` chunks to spend at a single k before bumping.",
    )
    p.add_argument(
        "--plateau_eps",
        type=float,
        default=0.01,
        help="If eval exact_set_match_rate improves by less than this between two consecutive phases at the same k, declare a plateau and bump k.",
    )
    p.add_argument("--sft_lr", type=float, default=2e-5)
    p.add_argument("--sft_bs", type=int, default=8)
    p.add_argument("--sft_ga", type=int, default=4)
    p.add_argument("--sft_oversample", type=int, default=3)
    p.add_argument("--grpo_steps", type=int, default=1500)
    p.add_argument("--grpo_lr", type=float, default=5e-6)
    p.add_argument("--grpo_bs", type=int, default=8)
    p.add_argument("--grpo_ga", type=int, default=4)
    p.add_argument("--grpo_ng", type=int, default=8)
    p.add_argument("--grpo_beta", type=float, default=0.0)
    p.add_argument("--grpo_max_prompt", type=int, default=768)
    p.add_argument("--grpo_max_completion", type=int, default=24)
    p.add_argument("--eval_rows", type=int, default=100)
    p.add_argument("--train_rows", type=int, default=10000)
    # ``store_true`` with ``default=True`` alone can never yield False, so a
    # paired ``store_false`` flag writing to the same dest provides the off switch.
    p.add_argument("--enable_gc", dest="enable_gc", action="store_true", default=True)
    p.add_argument(
        "--no_enable_gc",
        dest="enable_gc",
        action="store_false",
        help="Disable gradient checkpointing (it is enabled by default).",
    )
    p.add_argument("--seed", type=int, default=0)
    return p.parse_args()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---- log parsing -----------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
# Captures the first ``0.xxx`` / ``1.xxx`` decimal that follows the metric name
# on a line. A value printed without a decimal point (e.g. a bare ``1``) is
# not matched by this pattern.
EVAL_RE = re.compile(r"exact_set_match_rate.*?([01]\.\d+)")


def latest_eval_metric(log_path: Path) -> Optional[float]:
    """Return the most recent eval exact_set_match_rate from the SFT train log.

    The log is scanned top to bottom and the value from the last line that
    matches ``EVAL_RE`` wins.

    Args:
        log_path: Path to the trainer's combined stdout/stderr log.

    Returns:
        The last matched metric as a float, or ``None`` when the file does
        not exist or no line matches.
    """
    if not log_path.exists():
        return None
    last: Optional[float] = None
    # The log is raw subprocess output; decode leniently so a stray
    # non-UTF-8 byte cannot abort the orchestrator with UnicodeDecodeError
    # after a phase has already finished training.
    with open(log_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            m = EVAL_RE.search(line)
            if m:
                try:
                    last = float(m.group(1))
                except ValueError:
                    continue
    return last
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def latest_ckpt_dir(out_dir: Path) -> Optional[Path]:
    """Return the highest-numbered ``checkpoint-step-<N>`` directory in *out_dir*.

    Directories whose suffix is not an integer (for example
    ``checkpoint-step-final``) are ignored rather than raising, so one
    oddly named directory cannot abort the pipeline.

    Args:
        out_dir: Output directory of one SFT phase.

    Returns:
        The checkpoint directory with the largest step number; otherwise
        *out_dir* itself when it directly contains
        ``adapter_model.safetensors``; otherwise ``None`` (also when
        *out_dir* does not exist).
    """
    if not out_dir.exists():
        return None

    numbered = []
    for p in out_dir.iterdir():
        if not (p.is_dir() and p.name.startswith("checkpoint-step-")):
            continue
        try:
            step = int(p.name.split("-")[-1])
        except ValueError:
            # Non-numeric suffix: not a step checkpoint we can order.
            continue
        numbered.append((step, p))

    if numbered:
        # Order numerically (not lexicographically) so step 1000 beats step 200.
        numbered.sort(key=lambda item: item[0])
        return numbered[-1][1]
    if (out_dir / "adapter_model.safetensors").exists():
        return out_dir
    return None
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def best_grpo_ckpt(out_dir: Path) -> Optional[Path]:
    """Pick the GRPO checkpoint directory to report for *out_dir*.

    Despite the name, selection is by step number, not by eval score:
    among sub-directories named ``checkpoint-*`` the one with the largest
    trailing integer is returned. Directories with a non-numeric suffix
    rank below every numbered one (key ``-1``).

    Returns:
        The selected checkpoint directory; otherwise *out_dir* itself when
        it directly holds ``adapter_model.safetensors``; otherwise ``None``
        (also when *out_dir* does not exist).
    """
    if not out_dir.exists():
        return None

    def step_of(path: Path) -> int:
        # Trailing "-<N>" component, or -1 when it is not a plain number.
        tail = path.name.rsplit("-", 1)[-1]
        return int(tail) if tail.isdigit() else -1

    candidates = [
        child
        for child in out_dir.iterdir()
        if child.is_dir() and child.name.startswith("checkpoint-")
    ]
    candidates.sort(key=step_of)
    if candidates:
        return candidates[-1]
    if (out_dir / "adapter_model.safetensors").exists():
        return out_dir
    return None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ---- subprocess wrappers ---------------------------------------------------
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def run_sft_phase(
    *,
    args: argparse.Namespace,
    phase_dir: Path,
    init_adapter: str,
    num_cot_tokens: int,
    max_steps: int,
) -> Path:
    """Run one SFT phase at a fixed k as a blocking subprocess.

    The trainer's stdout and stderr are redirected to ``train.log`` inside
    *phase_dir* (the file is truncated on every call).

    Args:
        args: Parsed orchestrator options (see ``parse_args``).
        phase_dir: Output directory for this phase; created if missing.
        init_adapter: Adapter directory to start from; an empty string is
            forwarded unchanged and is logged as ``(BASE)``.
        num_cot_tokens: The k passed as ``--num_cot_tokens``.
        max_steps: Step budget passed as ``--max_steps``.

    Returns:
        The checkpoint chosen by ``latest_ckpt_dir(phase_dir)``.

    Raises:
        RuntimeError: If the trainer exits non-zero, or exits cleanly
            without leaving a checkpoint behind.
    """
    phase_dir.mkdir(parents=True, exist_ok=True)
    log_path = phase_dir / "train.log"

    # (flag, value) pairs in the exact order they are placed on the command line.
    flag_values = [
        ("--model_name", args.model_name),
        ("--train_jsonl", str(TRAIN_JSONL)),
        ("--eval_jsonl", str(EVAL_JSONL)),
        ("--output_dir", str(phase_dir)),
        ("--cache_dir", args.cache_dir),
        ("--init_adapter_dir", str(init_adapter)),
        ("--seed", str(args.seed)),
        # The child only sees the single GPU selected via CUDA_VISIBLE_DEVICES
        # below, hence the fixed id 0.
        ("--gpu_id", "0"),
        ("--stage_i", "3"),
        ("--num_cot_tokens", str(int(num_cot_tokens))),
        ("--latent_mode", args.latent_mode),
        ("--total_empties_hint", "20"),
        ("--per_device_train_batch_size", str(args.sft_bs)),
        ("--gradient_accumulation_steps", str(args.sft_ga)),
        ("--num_epochs", "256"),
        ("--learning_rate", str(args.sft_lr)),
        ("--max_grad_norm", "1.0"),
        ("--logging_steps", "25"),
        ("--eval_steps", "200"),
        ("--save_steps", "200"),
        ("--eval_rows", str(args.eval_rows)),
        ("--max_completion_length", "24"),
        ("--limit_train_rows", str(args.train_rows)),
        ("--lora_r", "32"),
        ("--lora_alpha", "64"),
        ("--lora_dropout", "0.05"),
        ("--multi_value_oversample_factor", str(args.sft_oversample)),
        ("--max_steps", str(int(max_steps))),
    ]
    cmd = [args.python_bin, "-u", str(SFT_SCRIPT)]
    for flag, value in flag_values:
        cmd.extend((flag, value))
    if args.enable_gc:
        cmd.append("--enable_gradient_checkpointing")

    for banner in (
        f"[adaptive-k] >>> SFT phase k={num_cot_tokens} max_steps={max_steps}",
        f"[adaptive-k] init={init_adapter or '(BASE)'}",
        f"[adaptive-k] out={phase_dir}",
        f"[adaptive-k] log={log_path}",
    ):
        print(banner, flush=True)

    # Inherit the parent environment, then pin the GPU and cache locations.
    env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": str(args.gpu),
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "HF_HOME": args.cache_dir,
        "TRANSFORMERS_CACHE": args.cache_dir,
    }
    with open(log_path, "w") as logf:
        ret = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT, env=env)
    if ret.returncode != 0:
        raise RuntimeError(f"SFT phase k={num_cot_tokens} failed (exit {ret.returncode}); see {log_path}")

    last = latest_ckpt_dir(phase_dir)
    if last is None:
        raise RuntimeError(f"No checkpoint produced under {phase_dir}")
    return last
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def run_grpo_phase(
    *,
    args: argparse.Namespace,
    phase_dir: Path,
    init_adapter: str,
    num_cot_tokens: int,
    max_steps: int,
) -> Optional[Path]:
    """Run the GRPO trainer at a fixed k as a blocking subprocess.

    Unlike the SFT wrapper, a non-zero exit is only reported as a warning;
    whatever checkpoint ``best_grpo_ckpt`` can find is still returned.

    Args:
        args: Parsed orchestrator options (see ``parse_args``).
        phase_dir: Output directory; created if missing. ``train.log``
            inside it receives the trainer's stdout and stderr.
        init_adapter: Adapter directory to initialise from.
        num_cot_tokens: The k passed as ``--num_cot_tokens``.
        max_steps: Step budget passed as ``--max_steps``.

    Returns:
        ``best_grpo_ckpt(phase_dir)``, which may be ``None``.
    """
    phase_dir.mkdir(parents=True, exist_ok=True)
    log_path = phase_dir / "train.log"

    # (flag, value) pairs in the exact order they are placed on the command line.
    flag_values = [
        ("--model_name", args.model_name),
        ("--train_jsonl", str(TRAIN_JSONL)),
        ("--eval_jsonl", str(EVAL_JSONL)),
        ("--output_dir", str(phase_dir)),
        ("--cache_dir", args.cache_dir),
        ("--init_adapter_dir", str(init_adapter)),
        ("--seed", str(args.seed)),
        # The child only sees the single GPU selected via CUDA_VISIBLE_DEVICES
        # below, hence the fixed id 0.
        ("--gpu_id", "0"),
        ("--stage_i", "3"),
        ("--num_cot_tokens", str(int(num_cot_tokens))),
        ("--latent_mode", args.latent_mode),
        ("--total_empties_hint", "20"),
        ("--per_device_train_batch_size", str(args.grpo_bs)),
        ("--gradient_accumulation_steps", str(args.grpo_ga)),
        ("--num_train_epochs", "100"),
        ("--learning_rate", str(args.grpo_lr)),
        ("--logging_steps", "10"),
        ("--save_steps", "200"),
        ("--eval_steps", "150"),
        ("--eval_rows", str(args.eval_rows)),
        ("--num_generations", str(args.grpo_ng)),
        ("--max_prompt_length", str(args.grpo_max_prompt)),
        ("--max_completion_length", str(args.grpo_max_completion)),
        ("--beta", str(args.grpo_beta)),
        ("--limit_train_rows", str(args.train_rows)),
        ("--lora_r", "32"),
        ("--lora_alpha", "64"),
        ("--lora_dropout", "0.05"),
        ("--max_steps", str(int(max_steps))),
    ]
    cmd = [args.python_bin, "-u", str(GRPO_SCRIPT)]
    for flag, value in flag_values:
        cmd.extend((flag, value))
    if args.enable_gc:
        cmd.append("--enable_gradient_checkpointing")

    for banner in (
        f"[adaptive-k] >>> GRPO phase k={num_cot_tokens} max_steps={max_steps}",
        f"[adaptive-k] init={init_adapter}",
        f"[adaptive-k] out={phase_dir}",
    ):
        print(banner, flush=True)

    # Inherit the parent environment, then pin the GPU and cache locations.
    env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": str(args.gpu),
        "TOKENIZERS_PARALLELISM": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "HF_HOME": args.cache_dir,
        "TRANSFORMERS_CACHE": args.cache_dir,
    }
    with open(log_path, "w") as logf:
        ret = subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT, env=env)
    if ret.returncode != 0:
        # Best effort: keep going and return any checkpoint saved before the failure.
        print(f"[adaptive-k] WARN: GRPO failed exit={ret.returncode}, see {log_path}", flush=True)
    return best_grpo_ckpt(phase_dir)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ---- main loop -------------------------------------------------------------
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def main() -> None:
    """Drive the adaptive-k schedule: repeated SFT phases, then one GRPO run.

    Each loop iteration runs one SFT phase at the current k, records its
    last eval metric, and rewrites ``STATE.json``. k is incremented when
    ``max_phases_per_k`` phases have been spent at it, or when the metric
    gained less than ``plateau_eps`` over the previous phase at the same
    k. The loop stops after the first phase run at ``max_k``. GRPO is then
    launched from the last SFT checkpoint at the final k, and the final
    state is written. Progress lines go to stdout and ``PIPELINE.log``.
    """
    args = parse_args()
    output_root = Path(args.output_root)
    output_root.mkdir(parents=True, exist_ok=True)
    state_path = output_root / "STATE.json"
    pipeline_log = output_root / "PIPELINE.log"

    def log(msg: str) -> None:
        # Timestamped line, echoed to stdout and appended to the pipeline log.
        stamped = f"[{time.strftime('%H:%M:%S')}] {msg}"
        print(stamped, flush=True)
        with open(pipeline_log, "a") as sink:
            sink.write(stamped + "\n")

    def write_state(payload: dict) -> None:
        # Overwrite STATE.json with the given snapshot.
        with open(state_path, "w") as sink:
            json.dump(payload, sink, indent=2)

    log(f"===== ADAPTIVE-K {args.variant} on GPU {args.gpu} =====")
    log(f" start_k={args.start_k} max_k={args.max_k} steps_per_phase={args.steps_per_phase} max_phases_per_k={args.max_phases_per_k}")
    log(f" plateau_eps={args.plateau_eps} sft_lr={args.sft_lr} grpo_lr={args.grpo_lr}")
    log(f" output_root={output_root}")

    max_k = int(args.max_k)
    phase_budget = int(args.max_phases_per_k)
    steps_per_phase = int(args.steps_per_phase)

    history: List[dict] = []
    cur_k = int(args.start_k)
    cur_init: str = ""  # "" -> train from base
    last_metric_at_k: Optional[float] = None
    phases_at_k = 0
    sft_phase_idx = 0

    while cur_k <= max_k:
        sft_phase_idx += 1
        phase_dir = output_root / f"sft_phase{sft_phase_idx:02d}_k{cur_k}"
        ckpt = run_sft_phase(
            args=args,
            phase_dir=phase_dir,
            init_adapter=cur_init,
            num_cot_tokens=cur_k,
            max_steps=steps_per_phase,
        )
        metric = latest_eval_metric(phase_dir / "train.log")
        log(
            f" phase{sft_phase_idx} k={cur_k} ckpt={ckpt.name} eval_exact_set_match_rate={metric}"
        )
        history.append(
            {
                "phase": sft_phase_idx,
                "k": cur_k,
                "phase_dir": str(phase_dir),
                "ckpt": str(ckpt),
                "exact_set_match_rate": metric,
            }
        )
        write_state({"history": history, "cur_k": cur_k, "cur_ckpt": str(ckpt)})

        # The next phase (or GRPO) starts from this phase's checkpoint.
        cur_init = str(ckpt)
        phases_at_k += 1

        if cur_k >= max_k:
            log(f" reached max_k={args.max_k}, stopping SFT loop")
            break

        # No delta is available on the first phase at a k or when a metric is missing.
        have_both = last_metric_at_k is not None and metric is not None
        improvement = float(metric) - float(last_metric_at_k) if have_both else None
        log(f" improvement_at_k={improvement} phases_at_k={phases_at_k}/{args.max_phases_per_k}")

        out_of_budget = phases_at_k >= phase_budget
        plateaued = improvement is not None and improvement < float(args.plateau_eps)
        if out_of_budget:
            log(" hit max_phases_per_k, bumping k")
        elif plateaued:
            log(f" improvement {improvement:.4f} < plateau_eps {args.plateau_eps:.4f}, bumping k")

        if out_of_budget or plateaued:
            # Move to the next k and reset the per-k bookkeeping.
            cur_k += 1
            last_metric_at_k = None
            phases_at_k = 0
        else:
            last_metric_at_k = metric

    log(f"===== final SFT k={cur_k} ckpt={cur_init} =====")
    grpo_ckpt = run_grpo_phase(
        args=args,
        phase_dir=output_root / f"grpo_k{cur_k}",
        init_adapter=cur_init,
        num_cot_tokens=cur_k,
        max_steps=int(args.grpo_steps),
    )
    log(f"===== GRPO done ckpt={grpo_ckpt} =====")
    write_state(
        {
            "history": history,
            "final_k": cur_k,
            "final_sft_ckpt": cur_init,
            "grpo_ckpt": str(grpo_ckpt) if grpo_ckpt else None,
        }
    )
    log(f"===== ADAPTIVE-K {args.variant} done =====")
| 428 |
+
|
| 429 |
+
if __name__ == "__main__":
|
| 430 |
+
main()
|
_runs/adaptive_latent_baseline_sudoku_train.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Adaptive-k thought-token baseline (experiment D in the 2x2 ablation).
|
| 3 |
+
|
| 4 |
+
Same single-stage, whole-puzzle setup as `simple_baseline_sudoku_train.py`
|
| 5 |
+
(experiment C, the "strawman"). Same JSONL data, same chat template, same
|
| 6 |
+
model, same LoRA. The ONLY difference is that this run inserts k recurrent
|
| 7 |
+
thought tokens between the prompt and the next-token logits, and grows k
|
| 8 |
+
on demand whenever the SFT loss plateaus.
|
| 9 |
+
|
| 10 |
+
Algorithm:
|
| 11 |
+
k = 0 (start as the vanilla baseline)
|
| 12 |
+
repeat:
|
| 13 |
+
train SFT for `min_steps_per_k` steps with current k
|
| 14 |
+
if rolling_avg(loss[-w:]) - rolling_avg(loss[-2w:-w]) > -plateau_eps:
|
| 15 |
+
k += 1 # grow capacity
|
| 16 |
+
if k > max_k: break
|
| 17 |
+
if loss has been steadily decreasing past `min_steps_per_k * 3`:
|
| 18 |
+
break # converged
|
| 19 |
+
save final adapter
|
| 20 |
+
|
| 21 |
+
The recurrent_hidden mechanism is imported verbatim from
|
| 22 |
+
`latent_multi_output_cell_policy.grpo_residual_projector_latent_train`
|
| 23 |
+
(via `latent_batched_completion_ce_loss`). For k=0 the loss reduces to
|
| 24 |
+
vanilla next-token CE, so the trajectory smoothly continues from the
|
| 25 |
+
strawman.
|
| 26 |
+
|
| 27 |
+
Reward / loss contract (see `simple_baseline_sudoku_train.py` for details):
|
| 28 |
+
- supervision is token-level CE against the JSONL `completion` field
|
| 29 |
+
(the 20 ground-truth digits at the 20 empty cells, row-major).
|
| 30 |
+
- this script is SFT-only; you can chain GRPO afterwards by passing the
|
| 31 |
+
saved adapter to `simple_baseline_sudoku_train.py --phase grpo`.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
from __future__ import annotations
|
| 35 |
+
|
| 36 |
+
import argparse
|
| 37 |
+
import json
|
| 38 |
+
import math
|
| 39 |
+
import os
|
| 40 |
+
import sys
|
| 41 |
+
import time
|
| 42 |
+
from collections import deque
|
| 43 |
+
from pathlib import Path
|
| 44 |
+
from typing import Any, Dict, List, Tuple
|
| 45 |
+
|
| 46 |
+
import torch
|
| 47 |
+
import torch.nn.functional as F
|
| 48 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 49 |
+
from torch.optim import AdamW
|
| 50 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
|
| 51 |
+
|
| 52 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 53 |
+
if str(ROOT) not in sys.path:
|
| 54 |
+
sys.path.insert(0, str(ROOT))
|
| 55 |
+
|
| 56 |
+
# Reuse helpers and the latent loss from the curriculum codebase. NO
|
| 57 |
+
# re-implementation of the recurrent_hidden mechanism here.
|
| 58 |
+
from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore
|
| 59 |
+
load_jsonl_rows,
|
| 60 |
+
pick_dtype,
|
| 61 |
+
)
|
| 62 |
+
from latent_multi_output_cell_policy.sft_latent_multi_output_train import ( # type: ignore
|
| 63 |
+
latent_batched_completion_ce_loss,
|
| 64 |
+
)
|
| 65 |
+
from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import ( # type: ignore
|
| 66 |
+
recurrent_hidden_next_token_logits_from_ids,
|
| 67 |
+
)
|
| 68 |
+
from _runs.simple_baseline_sudoku_train import ( # type: ignore
|
| 69 |
+
SYSTEM_PROMPT_STRAWMAN,
|
| 70 |
+
build_chat_prompt,
|
| 71 |
+
parse_int_list,
|
| 72 |
+
)
|
| 73 |
+
from multi_output_cell_policy.rewards import score_prediction_text # type: ignore
|
| 74 |
+
from multi_output_cell_policy.shared_multi_output_policy import ( # type: ignore
|
| 75 |
+
make_solved_grid_from_row,
|
| 76 |
+
stage_i_consistent_values,
|
| 77 |
+
)
|
| 78 |
+
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row # type: ignore
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---- Tokenization (mirror what latent_batched_completion_ce_loss expects) ---
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def tokenize_example(
    tokenizer: Any,
    raw_prompt: str,
    completion_text: str,
    max_prompt_length: int,
    max_completion_length: int,
) -> Dict[str, List[int]]:
    """Tokenize one (prompt, completion) pair into two id lists.

    The prompt is first wrapped by ``build_chat_prompt`` and then
    left-truncated: only its last ``max_prompt_length`` ids are kept. The
    completion gets the tokenizer's EOS string appended (``<|endoftext|>``
    when the tokenizer has none) and is right-truncated to its first
    ``max_completion_length`` ids, so the EOS can be cut off when the
    completion is too long.

    Returns:
        A dict with keys ``"prompt_ids"`` and ``"completion_ids"``.
    """
    chat_text = build_chat_prompt(tokenizer, raw_prompt)
    full_prompt_ids = tokenizer(chat_text, add_special_tokens=False).input_ids

    eos_text = tokenizer.eos_token or "<|endoftext|>"
    full_completion_ids = tokenizer(
        completion_text + eos_text, add_special_tokens=False
    ).input_ids

    return {
        "prompt_ids": full_prompt_ids[-max_prompt_length:],
        "completion_ids": full_completion_ids[:max_completion_length],
    }
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---- Eval (autoregressive greedy decode WITH k recurrent thought tokens) ---
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@torch.no_grad()
def latent_greedy_generate(
    model: torch.nn.Module,
    tokenizer: Any,
    prompt_text: str,
    device: torch.device,
    *,
    num_cot_tokens: int,
    max_new_tokens: int,
) -> str:
    """Greedy-decode a completion, one token per latent forward pass.

    Every step calls ``recurrent_hidden_next_token_logits_from_ids`` on the
    full running sequence with ``max(0, num_cot_tokens)`` thought tokens,
    appends the argmax token, and stops once EOS has been appended or
    ``max_new_tokens`` tokens have been produced.

    Returns:
        The newly generated tokens decoded with special tokens skipped and
        surrounding whitespace stripped.
    """
    # NOTE(review): the scalar ``.item()`` below assumes a single sequence
    # (batch size 1) -- confirm against the logits helper.
    encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    seq_ids = encoded["input_ids"].to(device)
    seq_mask = encoded["attention_mask"].to(device)
    n_prompt = int(seq_ids.shape[1])
    eos_id = tokenizer.eos_token_id

    remaining = int(max_new_tokens)
    while remaining > 0:
        remaining -= 1
        step_logits = recurrent_hidden_next_token_logits_from_ids(
            model, seq_ids, seq_mask, int(max(0, num_cot_tokens))
        )
        chosen = int(torch.argmax(step_logits, dim=-1).item())

        # Grow both the ids and the attention mask by one position.
        new_id = torch.tensor([[chosen]], device=device, dtype=seq_ids.dtype)
        new_mask = torch.ones((1, 1), device=device, dtype=seq_mask.dtype)
        seq_ids = torch.cat([seq_ids, new_id], dim=1)
        seq_mask = torch.cat([seq_mask, new_mask], dim=1)

        # EOS is appended before stopping; decode() drops it as a special token.
        if eos_id is None:
            continue
        if chosen == int(eos_id):
            break

    generated = seq_ids[0, n_prompt:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    *,
    num_cot_tokens: int,
    max_new_tokens: int,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Score whole-puzzle greedy rollouts cell by cell and aggregate the rates.

    For every row whose ``completion`` parses as an integer list, one rollout is
    produced with ``latent_greedy_generate``. Entry ``idx`` of the parsed
    prediction is wrapped as ``{"values": [v]}`` and scored against cell ``idx``
    from ``build_cell_examples_from_row`` via ``score_prediction_text``; when
    the output does not parse or is too short, that cell is scored on empty
    text. A puzzle is solved when it has at least one cell and every cell is an
    exact set match.

    Rows whose cell metadata cannot be built are reported (sharing the
    ``print_n`` print budget with the sample printouts) and skipped, yet they
    still count in ``n_total_puzzles``. The model is switched to eval mode on
    entry and to train mode before returning.

    Returns:
        Per-cell rates (denominator: scored cells, floored at 1) plus
        ``puzzle_parse_fail_rate`` and ``solve_rate`` (denominator: puzzles,
        floored at 1), and the two raw counts as floats.
    """
    model.eval()

    # Keys of the scorer's info dict that are summed per cell, in read order.
    summed_keys = (
        "parse_ok",
        "strict_canonical",
        "exact_set_match",
        "includes_ground_truth",
        "value_precision",
        "value_recall",
    )
    sums = {key: 0.0 for key in summed_keys}
    cardinality_hits = 0.0
    cell_count = 0
    puzzle_count = 0
    solved_count = 0
    unparsed_count = 0
    shown = 0

    for row in eval_rows:
        target_completion = parse_int_list(str(row["completion"]))
        if target_completion is None:
            continue
        puzzle_count += 1

        prompt_text = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
        gen = latent_greedy_generate(
            model,
            tokenizer,
            prompt_text,
            device,
            num_cot_tokens=num_cot_tokens,
            max_new_tokens=max_new_tokens,
        )
        pred_list = parse_int_list(gen)

        try:
            cells = build_cell_examples_from_row(row)
            solved = make_solved_grid_from_row(row)
        except Exception as e:
            if shown < print_n:
                print(f"[adaptive_k k={num_cot_tokens} eval] row skipped (no metadata): {e}", flush=True)
                shown += 1
            continue

        puzzle_exact = True
        puzzle_has_cell = False
        for idx, ex in enumerate(cells):
            target_values = stage_i_consistent_values(
                ex.grid, target_cell=ex.target_cell, stage_i=int(stage_i)
            )
            puzzle_has_cell = True

            # One singleton prediction per cell; nothing available -> empty text.
            if pred_list is None or idx >= len(pred_list):
                pred_text = ""
            else:
                pred_text = json.dumps({"values": [int(pred_list[idx])]})

            info = score_prediction_text(
                text=pred_text,
                grid=ex.grid,
                solved=solved,
                target_cell=ex.target_cell,
                stage_i=int(stage_i),
                reward_good_value=1.0,
                penalty_bad_value=1.75,
                penalty_malformed=4.0,
                penalty_empty=0.5,
                penalty_singleton=1.5,
            )

            cell_count += 1
            for key in summed_keys:
                sums[key] += float(info[key])
            if int(info["num_predicted_values"]) == int(len(target_values)):
                cardinality_hits += 1.0
            if float(info["exact_set_match"]) < 0.5:
                puzzle_exact = False

        if puzzle_has_cell and puzzle_exact:
            solved_count += 1
        if pred_list is None:
            unparsed_count += 1

        if shown < print_n:
            head_pred = "PARSE_FAIL" if pred_list is None else pred_list
            print(
                f"[adaptive_k k={num_cot_tokens} eval] target={target_completion} pred={head_pred} "
                f"solve={int(puzzle_exact and puzzle_has_cell)} gen={gen!r}",
                flush=True,
            )
            shown += 1

    model.train()

    cell_denom = max(1, cell_count)
    puzzle_denom = max(1, puzzle_count)
    return {
        "n_total_cells": float(cell_count),
        "n_total_puzzles": float(puzzle_count),
        "parse_rate": float(sums["parse_ok"] / cell_denom),
        "strict_canonical_rate": float(sums["strict_canonical"] / cell_denom),
        "exact_set_match_rate": float(sums["exact_set_match"] / cell_denom),
        "includes_ground_truth_rate": float(sums["includes_ground_truth"] / cell_denom),
        "value_precision": float(sums["value_precision"] / cell_denom),
        "value_recall": float(sums["value_recall"] / cell_denom),
        "cardinality_match_rate": float(cardinality_hits / cell_denom),
        "puzzle_parse_fail_rate": float(unparsed_count / puzzle_denom),
        "solve_rate": float(solved_count) / puzzle_denom,
    }
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---- Main loop --------------------------------------------------------------
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def parse_args() -> argparse.Namespace:
    """Declare the adaptive-k training CLI and parse the process arguments.

    ``--train_jsonl``, ``--eval_jsonl`` and ``--output_dir`` are mandatory;
    every other option has a default.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model, paths, reproducibility
    add("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    add("--train_jsonl", type=str, required=True)
    add("--eval_jsonl", type=str, required=True)
    add("--output_dir", type=str, required=True)
    add("--cache_dir", type=str, default=str(ROOT / ".hf_cache"))
    add("--init_adapter_dir", type=str, default="")
    add("--seed", type=int, default=0)

    # Data
    add("--limit_train_rows", type=int, default=10000)
    add("--eval_rows", type=int, default=50)

    # Train hyperparameters
    add("--per_device_train_batch_size", type=int, default=4)
    add("--gradient_accumulation_steps", type=int, default=2)
    add("--learning_rate", type=float, default=5e-5)
    add("--weight_decay", type=float, default=0.0)
    add("--max_steps", type=int, default=4000)
    add("--logging_steps", type=int, default=25)
    add("--save_steps", type=int, default=500)
    add("--eval_every_steps", type=int, default=500)
    add("--max_grad_norm", type=float, default=1.0)
    add("--max_completion_length", type=int, default=96)
    add("--max_prompt_length", type=int, default=1024)

    # LoRA
    add("--lora_r", type=int, default=32)
    add("--lora_alpha", type=int, default=64)
    add("--lora_dropout", type=float, default=0.05)
    add("--enable_gradient_checkpointing", action="store_true")

    # Adaptive-k schedule
    add("--start_k", type=int, default=0)
    add("--max_k", type=int, default=4)
    add(
        "--min_steps_per_k",
        type=int,
        default=400,
        help="Minimum SFT steps to spend at each k before considering an increment.",
    )
    add(
        "--plateau_window",
        type=int,
        default=100,
        help="Sliding window (in steps) used to compute rolling-mean loss for plateau detection.",
    )
    add(
        "--plateau_eps",
        type=float,
        default=0.005,
        help="If rolling_mean(loss[-w:]) - rolling_mean(loss[-2w:-w]) > -plateau_eps -> plateau detected.",
    )
    add(
        "--converged_eps",
        type=float,
        default=0.001,
        help="If two consecutive plateau windows pass with delta within this band, we declare convergence and stop.",
    )

    return parser.parse_args()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load tokenizer and base model, attach a LoRA adapter, move to ``device``.

    A non-blank ``args.init_adapter_dir`` loads that adapter as trainable;
    otherwise a fresh LoRA adapter is created over the attention and MLP
    projection modules listed below. Returns ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir, use_fast=True)
    if tokenizer.pad_token_id is None:
        # No pad token configured: reuse EOS, or the literal fallback string.
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )

    start_fresh = not str(args.init_adapter_dir).strip()
    if start_fresh:
        lora_cfg = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
        )
        model = get_peft_model(base_model, lora_cfg)
    else:
        model = PeftModel.from_pretrained(base_model, args.init_adapter_dir, is_trainable=True)

    if args.enable_gradient_checkpointing:
        # Each hook is optional on the wrapped model, hence the hasattr probes.
        # NOTE(review): nesting of the use_cache reset under this branch was
        # reconstructed from a flattened listing — confirm against the file.
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if hasattr(model, "config"):
            model.config.use_cache = False

    model.to(device)
    return model, tokenizer
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def detect_plateau(losses: deque, window: int, plateau_eps: float) -> Tuple[bool, float]:
    """Compare the mean of the last ``window`` losses with the window before it.

    Args:
        losses: Per-step losses, oldest first.
        window: Number of steps in each of the two compared windows.
        plateau_eps: Minimum mean-loss decrease between windows that still
            counts as progress.

    Returns:
        ``(plateau, delta)`` where ``delta = mean(recent) - mean(prior)`` and
        ``plateau`` is True when ``delta > -plateau_eps``. When fewer than
        ``2 * window`` losses are available, or ``window`` is not positive,
        the result is ``(False, 0.0)``.
    """
    # A non-positive window would leave the "prior" slice empty and divide by
    # zero; treat it like "not enough data" instead of crashing the run.
    if window <= 0 or len(losses) < 2 * window:
        return False, 0.0
    arr = list(losses)
    recent = arr[-window:]
    prior = arr[-2 * window : -window]
    delta = (sum(recent) / len(recent)) - (sum(prior) / len(prior))
    # If delta > -plateau_eps, loss hasn't decreased fast enough -> plateau.
    return (delta > -float(plateau_eps)), float(delta)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def save_adapter(model: torch.nn.Module, tokenizer: Any, out: str) -> None:
    """Create ``out`` if needed and save the model, then the tokenizer, into it.

    An object without a ``save_pretrained`` attribute is skipped silently.
    """
    os.makedirs(out, exist_ok=True)
    for artifact in (model, tokenizer):
        if not hasattr(artifact, "save_pretrained"):
            continue
        artifact.save_pretrained(out)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def main() -> None:
    """Run adaptive-k SFT: train at the current k and grow k on loss plateaus.

    Flow: parse args, load and pre-tokenize data, build model and optimizer,
    run an initial eval, then loop over micro-batches with gradient
    accumulation. After each optimizer step the loop may log, evaluate, save a
    checkpoint, and check for a plateau. A plateau below ``max_k`` bumps k by
    one; a plateau at ``max_k`` whose |delta| is under ``converged_eps`` stops
    training early. Finally the adapter is saved to ``<output_dir>/final``,
    a last eval is run, and ``summary.json`` is written.
    """
    args = parse_args()
    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    print(f"[adaptive_k] loaded {len(train_rows)} train rows, {len(eval_rows)} eval rows", flush=True)

    model, tokenizer = setup_model_and_tokenizer(args, device)
    pad_id = int(tokenizer.pad_token_id)

    # Pre-tokenize the train set once.
    # Rows that raise during tokenization are logged and dropped; rows whose
    # completion tokenizes to nothing are dropped silently.
    train_examples: List[Dict[str, Any]] = []
    for row in train_rows:
        try:
            ex = tokenize_example(
                tokenizer,
                str(row["prompt"]).strip(),
                str(row["completion"]).strip(),
                int(args.max_prompt_length),
                int(args.max_completion_length),
            )
            if ex["completion_ids"]:
                train_examples.append(ex)
        except Exception as e:  # noqa: BLE001
            print(f"[adaptive_k] tokenize skip: {e}", flush=True)
    print(f"[adaptive_k] tokenized {len(train_examples)} train examples", flush=True)

    # Only parameters that require grad are handed to the optimizer.
    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
    )

    bs = int(args.per_device_train_batch_size)
    ga = int(args.gradient_accumulation_steps)
    steps = 0  # optimizer steps taken so far (not micro-batches)
    # Full loss history; appended below but not read anywhere in this function.
    losses_per_step: List[float] = []
    # Bounded buffer feeding detect_plateau, which needs 2 * plateau_window entries.
    rolling: deque = deque(maxlen=2 * int(args.plateau_window) + 16)
    k = int(args.start_k)
    max_k = int(args.max_k)
    steps_at_current_k = 0
    grew_at: List[Tuple[int, int]] = []  # (step, new_k)

    print(f"[adaptive_k] starting at k={k}", flush=True)
    init_eval = run_eval(
        model, tokenizer, eval_rows, device,
        num_cot_tokens=k, max_new_tokens=int(args.max_completion_length),
    )
    print(f"[adaptive_k] init eval k={k}: {init_eval}", flush=True)

    t0 = time.time()
    # Dedicated CPU generator so the shuffle order is derived from the seed.
    rng_state = torch.Generator(device="cpu").manual_seed(int(args.seed))
    perm = torch.randperm(len(train_examples), generator=rng_state).tolist()
    cursor = 0

    optimizer.zero_grad(set_to_none=True)
    micro_in_step = 0
    micro_loss_accum = 0.0

    while steps < int(args.max_steps):
        # Reshuffle once fewer than `bs` indices remain; the leftover tail of
        # the old permutation is discarded.
        if cursor + bs > len(perm):
            perm = torch.randperm(len(train_examples), generator=rng_state).tolist()
            cursor = 0
        batch_indices = perm[cursor : cursor + bs]
        cursor += bs
        batch = [train_examples[i] for i in batch_indices]

        # Loss is pre-divided by `ga` so accumulated gradients average over
        # the micro-batches of one optimizer step.
        loss = latent_batched_completion_ce_loss(
            model,
            batch,
            device,
            num_cot_tokens=int(max(0, k)),
            latent_mode="recurrent_hidden",
            pad_token_id=pad_id,
        ) / float(ga)
        loss.backward()
        # Undo the 1/ga scaling so the logged value is the raw micro-batch loss.
        micro_loss_accum += float(loss.detach().item()) * float(ga)
        micro_in_step += 1

        if micro_in_step >= ga:
            torch.nn.utils.clip_grad_norm_(
                (p for p in model.parameters() if p.requires_grad),
                float(args.max_grad_norm),
            )
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            steps += 1
            steps_at_current_k += 1
            avg_micro_loss = micro_loss_accum / float(ga)
            losses_per_step.append(avg_micro_loss)
            rolling.append(avg_micro_loss)
            micro_in_step = 0
            micro_loss_accum = 0.0

            if steps % int(args.logging_steps) == 0:
                w = int(args.plateau_window)
                # Mean over the last `w` losses, or over all of them if fewer exist.
                recent = list(rolling)[-w:] if len(rolling) >= w else list(rolling)
                rec_mean = sum(recent) / max(1, len(recent))
                elapsed = time.time() - t0
                print(
                    f"[adaptive_k] step={steps} k={k} loss={avg_micro_loss:.4f} "
                    f"rolling_mean({len(recent)})={rec_mean:.4f} elapsed={elapsed:.0f}s",
                    flush=True,
                )

            if steps % int(args.eval_every_steps) == 0:
                ev = run_eval(
                    model, tokenizer, eval_rows, device,
                    num_cot_tokens=k, max_new_tokens=int(args.max_completion_length),
                )
                print(f"[adaptive_k] EVAL step={steps} k={k}: {ev}", flush=True)

            if steps % int(args.save_steps) == 0:
                save_adapter(model, tokenizer, os.path.join(args.output_dir, f"checkpoint-step-{steps:05d}"))

            # Plateau check (only after `min_steps_per_k` at current k, and we
            # have at least 2*plateau_window losses in the rolling buffer).
            if steps_at_current_k >= int(args.min_steps_per_k):
                plateau, delta = detect_plateau(rolling, int(args.plateau_window), float(args.plateau_eps))
                if plateau and k < max_k:
                    print(
                        f"[adaptive_k] plateau detected at step={steps} k={k} delta={delta:+.4f} -> growing k -> {k+1}",
                        flush=True,
                    )
                    k += 1
                    steps_at_current_k = 0
                    grew_at.append((steps, k))
                    # Clearing the buffer means the next plateau cannot fire until
                    # 2 * plateau_window fresh losses at the new k have accumulated.
                    rolling.clear()  # restart plateau tracking after capacity bump
                    save_adapter(model, tokenizer, os.path.join(args.output_dir, f"checkpoint-step-{steps:05d}-grow-k{k}"))
                elif plateau and k >= max_k and abs(delta) < float(args.converged_eps):
                    print(
                        f"[adaptive_k] convergence at step={steps} k={k} delta={delta:+.4f} (max_k reached) -> stopping",
                        flush=True,
                    )
                    break

    final_dir = os.path.join(args.output_dir, "final")
    save_adapter(model, tokenizer, final_dir)
    final_eval = run_eval(
        model, tokenizer, eval_rows, device,
        num_cot_tokens=k, max_new_tokens=int(args.max_completion_length),
    )
    # `training_seconds` is measured from just after the initial eval and
    # includes the periodic evals, checkpoint saves and the final eval.
    summary = {
        "final_k": k,
        "total_steps": steps,
        "max_k": max_k,
        "grew_at_steps": grew_at,
        "final_eval": final_eval,
        "training_seconds": time.time() - t0,
    }
    with open(os.path.join(args.output_dir, "summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"[adaptive_k] DONE summary={json.dumps(summary)}", flush=True)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# Script entry point: train only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
_runs/add_variants_g_h.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Add 2 more variants on GPUs 6 and 7 to the active sweep.
# Both seed from the lr5e5 lowsft S2 SFT step-3000 (the winning lineage at step 150).
set -euo pipefail

ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
# Default SWEEP_ROOT to the most recently modified baseline_1p5b_v4_* directory.
# The `|| true` is required: with `set -e -o pipefail`, a failing `ls` (no sweep
# directory yet) or an early-closing `head` would otherwise abort the script
# inside this assignment, before the explicit check below can report the problem.
if [[ -z "${SWEEP_ROOT:-}" ]]; then
  SWEEP_ROOT="$(ls -dt "${ROOT}"/_runs/baseline_1p5b_v4_*/ 2>/dev/null | head -n 1 | sed 's:/$::' || true)"
fi
PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"

[[ -d "${SWEEP_ROOT}" ]] || { echo "sweep root missing"; exit 1; }
echo "Sweep: ${SWEEP_ROOT}"

# S2 SFT adapter that both new variants start from.
CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
[[ -d "${CKPT_LR5E5}" ]] || { echo "missing init"; exit 1; }
|
| 15 |
+
|
| 16 |
+
# launch_variant GPU VARIANT INIT_CKPT [NAME=value ...]
# Start the pipeline for one variant in the background, detached from this
# shell, logging to <SWEEP_ROOT>/<VARIANT>/nohup.log and recording its PID.
# Extra NAME=value words are passed to `env` after the fixed settings.
launch_variant() {
  local gpu="$1"
  local variant="$2"
  local init="$3"
  shift 3

  local out="${SWEEP_ROOT}/${variant}"
  local nohup_log="${out}/nohup.log"
  local -a child_env=(
    "ROOT=${ROOT}"
    "VARIANT=${variant}"
    "GPU=${gpu}"
    "S2_SFT_CKPT=${init}"
    "OUTPUT_ROOT=${out}"
    "USE_WANDB=0"
    "WANDB_MODE=offline"
    "$@"
  )

  mkdir -p "${out}"
  printf 'GPU %s -> %s -> %s\n' "${gpu}" "${variant}" "${init}"
  nohup env "${child_env[@]}" bash "${PIPELINE}" </dev/null >"${nohup_log}" 2>&1 &
  local pid=$!

  printf ' pid=%s log=%s\n' "${pid}" "${nohup_log}"
  echo "${pid} ${gpu} ${variant}" >> "${SWEEP_ROOT}/PIDS.txt"
  # Best effort: a failing disown must not abort the script under `set -e`.
  disown "${pid}" 2>/dev/null || true
}
|
| 39 |
+
|
| 40 |
+
# Each call is: launch_variant GPU VARIANT INIT_CKPT [NAME=value overrides...].

# pipe_g: lr5e5 lineage, faster GRPO LR (1e-5) to push convergence
launch_variant 6 pipe_g_lr5e5_grpo1e5 "${CKPT_LR5E5}" GRPO_LR=1e-5 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5

# pipe_h: lr5e5 lineage, lower singleton penalty (1.0) to test if 1.5 hurts
launch_variant 7 pipe_h_lr5e5_grpo5e6_sngl10 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.0

# Update sweep README
# (unquoted EOF delimiter: the $(date ...) below is expanded when appending;
# it records the time of day only, not the date)
cat >>"${SWEEP_ROOT}/SWEEP_README.md" <<EOF

## Added at $(date '+%H:%M:%S')

| GPU | variant | S2 init | GRPO LR | S3 SFT LR | penalty_singleton |
| ---: | --- | --- | ---: | ---: | ---: |
| 6 | pipe_g_lr5e5_grpo1e5 | lr5e5_lowsft step-3000 | 1e-5 | 2e-5 | 1.5 |
| 7 | pipe_h_lr5e5_grpo5e6_sngl10 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 1.0 |
EOF

# NOTE(review): the variant count in this message is hard-coded; this script
# itself only launches two.
echo "Done. Now running 8 variants on GPUs 0..7."
|
_runs/add_variants_i_j_k_l.sh
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Add 4 high-throughput variants on freed GPUs 0,2,3,4.
# 3 variants fast-forward to S3 SFT (since S2 GRPO is plateau-stuck on baseline).
# 1 variant tries an aggressive 10x GRPO LR to break the S2 plateau.
set -euo pipefail

ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
# Default SWEEP_ROOT to the most recently modified baseline_1p5b_v4_* directory.
# The `|| true` is required: with `set -e -o pipefail`, a failing `ls` (no sweep
# directory yet) or an early-closing `head` would otherwise abort the script
# inside this assignment, before the explicit check below can report the problem.
if [[ -z "${SWEEP_ROOT:-}" ]]; then
  SWEEP_ROOT="$(ls -dt "${ROOT}"/_runs/baseline_1p5b_v4_*/ 2>/dev/null | head -n 1 | sed 's:/$::' || true)"
fi
PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"

[[ -d "${SWEEP_ROOT}" ]] || { echo "sweep root missing"; exit 1; }
echo "Sweep: ${SWEEP_ROOT}"

# S2 SFT adapters (two learning-rate lineages) used as starting points below.
CKPT_LR1E4="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr1e4_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
[[ -d "${CKPT_LR5E5}" ]] || { echo "missing init lr5e5"; exit 1; }
[[ -d "${CKPT_LR1E4}" ]] || { echo "missing init lr1e4"; exit 1; }
|
| 18 |
+
|
| 19 |
+
# launch_variant GPU VARIANT INIT_CKPT [NAME=value ...]
# Start the pipeline for one variant in the background, detached from this
# shell, logging to <SWEEP_ROOT>/<VARIANT>/nohup.log and recording its PID.
# Extra NAME=value words are passed to `env` after the fixed settings.
launch_variant() {
  local gpu="$1"
  local variant="$2"
  local init="$3"
  shift 3

  local out="${SWEEP_ROOT}/${variant}"
  local nohup_log="${out}/nohup.log"
  local -a child_env=(
    "ROOT=${ROOT}"
    "VARIANT=${variant}"
    "GPU=${gpu}"
    "S2_SFT_CKPT=${init}"
    "OUTPUT_ROOT=${out}"
    "USE_WANDB=0"
    "WANDB_MODE=offline"
    "$@"
  )

  mkdir -p "${out}"
  printf 'GPU %s -> %s -> %s\n' "${gpu}" "${variant}" "${init}"
  nohup env "${child_env[@]}" bash "${PIPELINE}" </dev/null >"${nohup_log}" 2>&1 &
  local pid=$!

  printf ' pid=%s log=%s\n' "${pid}" "${nohup_log}"
  echo "${pid} ${gpu} ${variant}" >> "${SWEEP_ROOT}/PIDS.txt"
  # Best effort: a failing disown must not abort the script under `set -e`.
  disown "${pid}" 2>/dev/null || true
}
|
| 42 |
+
|
| 43 |
+
# Each call is: launch_variant GPU VARIANT INIT_CKPT [NAME=value overrides...].

# pipe_i (GPU 0): fast-forward to S3 SFT from lr5e5 lowsft step-3000.
# high-throughput: no GC, bs=32x1, larger eval batches.
launch_variant 0 pipe_i_s3sft_lr5e5_fast "${CKPT_LR5E5}" \
  START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR5E5}" \
  SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
  GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
  USE_GC=0

# pipe_j (GPU 2): fast-forward to S3 SFT from lr5e5 with lower LR for stability.
launch_variant 2 pipe_j_s3sft_lr5e5_lr1e5 "${CKPT_LR5E5}" \
  START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR5E5}" \
  SFT_LR_S3=1e-5 SFT_BS=32 SFT_GA=1 \
  GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
  USE_GC=0

# pipe_k (GPU 3): fast-forward to S3 SFT from lr1e4 lineage (mirror of i but other init).
launch_variant 3 pipe_k_s3sft_lr1e4_fast "${CKPT_LR1E4}" \
  START_PHASE=s3_sft S3_SFT_INIT="${CKPT_LR1E4}" \
  SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
  GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
  USE_GC=0

# pipe_l (GPU 4): aggressive 10x GRPO LR + 16 generations, push past S2 plateau.
# NOTE(review): SFT_LR_S3=2e-5 is passed here, while the README row written
# below shows "-" in the SFT LR column for this variant.
launch_variant 4 pipe_l_lr5e5_grpo5e5_ng16 "${CKPT_LR5E5}" \
  START_PHASE=s2_grpo \
  GRPO_LR=5e-5 GRPO_BS=16 GRPO_GA=1 GRPO_NG=16 \
  PENALTY_SINGLETON=1.5 \
  SFT_LR_S3=2e-5 SFT_BS=32 SFT_GA=1 \
  USE_GC=0

# Append a dated section to the sweep README (unquoted EOF delimiter: the
# $(date ...) below is expanded when appending; it records the time of day only).
cat >>"${SWEEP_ROOT}/SWEEP_README.md" <<EOF

## Added at $(date '+%H:%M:%S') — high-throughput / S3 fast-forward

S2 GRPO plateaued at solve=0.14 (lr5e5 lineage) or 0.05 (lr1e4 lineage) for all
of pipe_a/b/c/d/e — bit-identical evals from step 150 to 450. The per-cell
exact ceiling (~0.91) caps puzzle solve at ~0.91^20 ~= 0.14 regardless of
GRPO. Real lever is S3 SFT on harder cells (multi-value).

Killed pipe_a, pipe_c, pipe_d, pipe_e (flat). Launched 4 replacements with
USE_GC=0 (gradient checkpointing OFF — we have 80 GB headroom) and bs=32x1
for ~2-3x throughput per GPU.

| GPU | variant | start phase | init | SFT LR (S3) | GRPO LR | bs | ng |
| ---: | --- | --- | --- | ---: | ---: | ---: | ---: |
| 0 | pipe_i_s3sft_lr5e5_fast | s3_sft | lr5e5 step-3000 | 2e-5 | 5e-6 | 32 | 8 |
| 2 | pipe_j_s3sft_lr5e5_lr1e5 | s3_sft | lr5e5 step-3000 | 1e-5 | 5e-6 | 32 | 8 |
| 3 | pipe_k_s3sft_lr1e4_fast | s3_sft | lr1e4 step-3000 | 2e-5 | 5e-6 | 32 | 8 |
| 4 | pipe_l_lr5e5_grpo5e5_ng16 | s2_grpo | lr5e5 step-3000 | - | 5e-5 | 16 | 16 |
EOF

# NOTE(review): the variant count in this message is hard-coded.
echo "Done. Now running 8 variants on GPUs 0..7."
|
_runs/baseline_1p5b_pipeline_v4.sh
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# 1.5B vanilla baseline: S2 GRPO -> S3 SFT -> S3 GRPO, single GPU.
# Optionally pre-pends extra S2 SFT steps if EXTRA_S2_SFT_STEPS>0.
#
# Required env:
#   VARIANT             variant name (used in dirs / wandb)
#   GPU                 CUDA index for this variant (0..7)
#   S2_SFT_CKPT         path to S2 SFT LoRA adapter (uses this as S2 GRPO init)
#
# Optional env:
#   ROOT                default /home/ubuntu/curriculum_cot
#   PYTHON_BIN          default /opt/pytorch/bin/python
#   OUTPUT_ROOT         default $ROOT/_runs/baseline_1p5b_v4_$(date)/$VARIANT
#   MODEL_NAME          default Qwen/Qwen2.5-1.5B-Instruct
#   GRPO_LR             default 5e-6
#   GRPO_BETA           default 0.0
#   GRPO_NG             default 8
#   GRPO_BS             default 16
#   GRPO_GA             default 2
#   GRPO_PROMPT         default 768
#   GRPO_COMPL          default 24
#   PENALTY_SINGLETON   default 1.5
#   PENALTY_BAD         default 1.0
#   REWARD_GOOD         default 1.25
#   PENALTY_MAL         default 4.0
#   PENALTY_EMPTY       default 0.5
#   PENALTY_MISSING     default 0.0
#   EXACT_MATCH_BONUS   default 0.0
#   CARD_MISMATCH_PEN   default 0.0
#   SFT_OVERSAMPLE      default 1
#   SFT_TGT_MIN         default 0
#   SFT_TGT_MAX         default 0
#   SFT_LR_S3           default 2e-5
#   SFT_BS              default 16
#   SFT_GA              default 2
#   VALUE_TARGET        default 0.98
#   S2_GRPO_MAX_STEPS   default 1200 (pipeline budget)
#   S3_SFT_MAX_STEPS    default 2400
#   S3_GRPO_MAX_STEPS   default 1500
#   EXTRA_S2_SFT_STEPS  default 0 (extra S2 SFT steps before S2 GRPO)
#   EXTRA_S2_SFT_LR     default 1e-5
#   EVAL_ROWS           default 100
#   TRAIN_ROWS          default 10000
#   USE_WANDB           default 0
#   WANDB_PROJECT       default sudoku-baseline-1p5b-v4
#   WANDB_MODE          default offline
#   PHASE_WALL_SECS     default 0 (no phase wallclock cap)
#   START_PHASE         default s2_grpo (one of: s2_sft_extra,s2_grpo,s3_sft,s3_grpo)
#   S3_SFT_INIT         if START_PHASE=s3_sft, S3-SFT init adapter (overrides S2 GRPO output)
#   S3_GRPO_INIT        if START_PHASE=s3_grpo, S3-GRPO init adapter
#   USE_GC              default 0 (1 to enable gradient checkpointing; we usually have memory)

set -euo pipefail

ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
PYTHON_BIN="${PYTHON_BIN:-/opt/pytorch/bin/python}"
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"

# Abort with the given message when a required variable is unset or empty.
: "${VARIANT:?VARIANT required}"
: "${GPU:?GPU required}"
: "${S2_SFT_CKPT:?S2_SFT_CKPT required}"

OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/_runs/baseline_1p5b_v4_$(date +%Y%m%d_%H%M%S)/${VARIANT}}"
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"

# Every knob below keeps a caller-provided value and otherwise takes the default shown.
GRPO_LR="${GRPO_LR:-5e-6}"
GRPO_BETA="${GRPO_BETA:-0.0}"
GRPO_NG="${GRPO_NG:-8}"
GRPO_BS="${GRPO_BS:-16}"
GRPO_GA="${GRPO_GA:-2}"
GRPO_PROMPT="${GRPO_PROMPT:-768}"
GRPO_COMPL="${GRPO_COMPL:-24}"
PENALTY_SINGLETON="${PENALTY_SINGLETON:-1.5}"
PENALTY_BAD="${PENALTY_BAD:-1.0}"
PENALTY_MAL="${PENALTY_MAL:-4.0}"
PENALTY_EMPTY="${PENALTY_EMPTY:-0.5}"
REWARD_GOOD="${REWARD_GOOD:-1.25}"
PENALTY_MISSING="${PENALTY_MISSING:-0.0}"
EXACT_MATCH_BONUS="${EXACT_MATCH_BONUS:-0.0}"
CARD_MISMATCH_PEN="${CARD_MISMATCH_PEN:-0.0}"
SFT_OVERSAMPLE="${SFT_OVERSAMPLE:-1}"
SFT_TGT_MIN="${SFT_TGT_MIN:-0}"
SFT_TGT_MAX="${SFT_TGT_MAX:-0}"
SFT_LR_S3="${SFT_LR_S3:-2e-5}"
SFT_BS="${SFT_BS:-16}"
SFT_GA="${SFT_GA:-2}"
VALUE_TARGET="${VALUE_TARGET:-0.98}"
S2_GRPO_MAX_STEPS="${S2_GRPO_MAX_STEPS:-1200}"
S3_SFT_MAX_STEPS="${S3_SFT_MAX_STEPS:-2400}"
S3_GRPO_MAX_STEPS="${S3_GRPO_MAX_STEPS:-1500}"
EXTRA_S2_SFT_STEPS="${EXTRA_S2_SFT_STEPS:-0}"
EXTRA_S2_SFT_LR="${EXTRA_S2_SFT_LR:-1e-5}"
EVAL_ROWS="${EVAL_ROWS:-100}"
TRAIN_ROWS="${TRAIN_ROWS:-10000}"
USE_WANDB="${USE_WANDB:-0}"
WANDB_PROJECT="${WANDB_PROJECT:-sudoku-baseline-1p5b-v4}"
WANDB_MODE="${WANDB_MODE:-offline}"
PHASE_WALL_SECS="${PHASE_WALL_SECS:-0}"
START_PHASE="${START_PHASE:-s2_grpo}"
S3_SFT_INIT="${S3_SFT_INIT:-}"
S3_GRPO_INIT="${S3_GRPO_INIT:-}"
USE_GC="${USE_GC:-0}"

# Dataset paths are fixed relative to ROOT (no env override).
TRAIN_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
EVAL_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"

mkdir -p "${OUTPUT_ROOT}"
PIPELINE_LOG="${OUTPUT_ROOT}/PIPELINE.log"
|
| 104 |
+
|
| 105 |
+
# ts: print the current wall-clock time as HH:MM:SS.
ts() {
  date +'%H:%M:%S'
}

# log MESSAGE...: timestamp the message, append it to PIPELINE_LOG and echo it on stderr.
log() {
  printf '[%s] %s\n' "$(ts)" "$*" \
    | tee -a "${PIPELINE_LOG}" >&2
}
|
| 107 |
+
|
| 108 |
+
# latest_ckpt_step DIR
# Print the checkpoint-step-* entry under DIR that sorts last in version order.
# Returns 1 (printing nothing) when DIR has no such entry.
latest_ckpt_step() {
  local d="$1"
  # nullglob makes a non-matching pattern expand to an empty array instead of
  # the literal pattern string.
  shopt -s nullglob
  local found=("${d}"/checkpoint-step-*)
  shopt -u nullglob
  if (( ${#found[@]} == 0 )); then
    return 1
  fi
  printf '%s\n' "${found[@]}" | sort -V | tail -n 1
}
|
| 116 |
+
|
| 117 |
+
# best_grpo_adapter DIR
# Print the adapter directory to continue from:
#   - DIR itself when it directly contains adapter_model.safetensors;
#   - otherwise the DIR/checkpoint-<N> subdirectory with the largest numeric N
#     that contains adapter_model.safetensors.
# Returns 1 (printing nothing) when neither exists.
best_grpo_adapter() {
  local d="$1"
  if [[ -f "${d}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${d}"; return 0
  fi
  # `c` and `n` are declared local so the loop does not leak them into (or
  # clobber) the caller's variables.
  local best="" step=-1 c n
  shopt -s nullglob
  for c in "${d}"/checkpoint-*; do
    [[ -d "$c" ]] || continue
    [[ -f "$c/adapter_model.safetensors" ]] || continue
    n="${c##*checkpoint-}"
    # 10# forces base 10 so zero-padded step numbers are not read as octal.
    if [[ "$n" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
      step=$((10#${n})); best="$c"
    fi
  done
  shopt -u nullglob
  [[ -n "$best" ]] || return 1
  printf '%s\n' "$best"
}
|
| 136 |
+
|
| 137 |
+
# Fail fast if either dataset file is missing.
if [[ ! -f "${TRAIN_JSONL}" || ! -f "${EVAL_JSONL}" ]]; then
  log "ERROR: missing dataset jsonls (${TRAIN_JSONL} / ${EVAL_JSONL})."
  exit 1
fi

# Restrict this process tree to the single requested GPU.
export CUDA_VISIBLE_DEVICES="${GPU}"
export TOKENIZERS_PARALLELISM=false
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Both Hugging Face cache variables point at the same directory under ROOT.
export HF_HOME="${ROOT}/.hf_cache"
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
export WANDB_MODE="${WANDB_MODE}"
|
| 148 |
+
|
| 149 |
+
# Run one SFT phase by invoking ${SFT_SCRIPT} with ${PYTHON_BIN}.
#
# Arguments:
#   $1 stage         value passed as --stage_i
#   $2 init_adapter  value passed as --init_adapter_dir
#   $3 out_dir       output directory (created); also receives train.log
#   $4 lr            learning rate
#   $5 max_steps     step cap
#   $6 tag           short label used in log lines and the W&B run name
# Reads (globals): MODEL_NAME, TRAIN_JSONL, EVAL_JSONL, ROOT, SFT_BS, SFT_GA,
#   EVAL_ROWS, TRAIN_ROWS, VALUE_TARGET, PHASE_WALL_SECS, SFT_OVERSAMPLE,
#   SFT_TGT_MIN, SFT_TGT_MAX, USE_WANDB, WANDB_PROJECT, WANDB_MODE, VARIANT,
#   USE_GC.
# Output of the trainer (stdout and stderr) is shown and copied to
# ${out_dir}/train.log.
run_sft() {
  local stage="$1" init_adapter="$2" out_dir="$3" lr="$4" max_steps="$5" tag="$6"
  mkdir -p "${out_dir}"
  log "=== Stage ${stage} SFT (${tag}) lr=${lr} max_steps=${max_steps} bs=${SFT_BS}x${SFT_GA} GC=${USE_GC} init=${init_adapter} ==="
  log " out=${out_dir}"

  # Full command line; the optional switches are appended at the end, in the
  # same position they always occupied.
  local cmd=(
    "${PYTHON_BIN}" -u "${SFT_SCRIPT}"
    --model_name "${MODEL_NAME}"
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --total_empties_hint 20
    --per_device_train_batch_size "${SFT_BS}"
    --gradient_accumulation_steps "${SFT_GA}"
    --num_epochs 256
    --learning_rate "${lr}"
    --max_grad_norm 1.0
    --logging_steps 25
    --eval_steps 150
    --save_steps 200
    --eval_rows "${EVAL_ROWS}"
    --max_completion_length 24
    --limit_train_rows "${TRAIN_ROWS}"
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_exact_set_match_stop 0
    --eval_solve_rate_stop 0
    --min_steps_before_stop 100
    --max_wall_clock_seconds "${PHASE_WALL_SECS}"
    --max_steps "${max_steps}"
    --multi_value_oversample_factor "${SFT_OVERSAMPLE}"
    --train_target_size_min "${SFT_TGT_MIN}"
    --train_target_size_max "${SFT_TGT_MAX}"
  )
  if [[ "${USE_WANDB}" == "1" ]]; then
    cmd+=(
      --use_wandb --wandb_project "${WANDB_PROJECT}"
      --wandb_run_name "${VARIANT}_${tag}" --wandb_mode "${WANDB_MODE}"
    )
  fi
  if [[ "${USE_GC}" == "1" ]]; then
    cmd+=(--enable_gradient_checkpointing)
  fi

  "${cmd[@]}" 2>&1 | tee "${out_dir}/train.log"
}
|
| 197 |
+
|
| 198 |
+
# Run one GRPO phase by invoking ${GRPO_SCRIPT} with ${PYTHON_BIN}.
#
# Arguments:
#   $1 stage         value passed as --stage_i
#   $2 init_adapter  value passed as --init_adapter_dir
#   $3 out_dir       output directory (created); also receives train.log
#   $4 max_steps     step cap
#   $5 tag           short label used in log lines and the W&B run name
# Reads (globals): MODEL_NAME, TRAIN_JSONL, EVAL_JSONL, ROOT, EVAL_ROWS,
#   TRAIN_ROWS, VALUE_TARGET, PHASE_WALL_SECS, GRPO_LR, GRPO_NG, GRPO_BS,
#   GRPO_GA, GRPO_PROMPT, GRPO_COMPL, GRPO_BETA, REWARD_GOOD, PENALTY_BAD,
#   PENALTY_MAL, PENALTY_EMPTY, PENALTY_SINGLETON, PENALTY_MISSING,
#   EXACT_MATCH_BONUS, CARD_MISMATCH_PEN, USE_WANDB, WANDB_PROJECT,
#   WANDB_MODE, VARIANT, USE_GC.
# Output of the trainer (stdout and stderr) is shown and copied to
# ${out_dir}/train.log.
run_grpo() {
  local stage="$1" init_adapter="$2" out_dir="$3" max_steps="$4" tag="$5"
  mkdir -p "${out_dir}"
  log "=== Stage ${stage} GRPO (${tag}) lr=${GRPO_LR} ng=${GRPO_NG} bs=${GRPO_BS}x${GRPO_GA} prompt=${GRPO_PROMPT} GC=${USE_GC} max_steps=${max_steps} init=${init_adapter} ==="
  log " rewards: good=${REWARD_GOOD} bad=${PENALTY_BAD} mal=${PENALTY_MAL} empty=${PENALTY_EMPTY} sngl=${PENALTY_SINGLETON} missing=${PENALTY_MISSING} exact_b=${EXACT_MATCH_BONUS} card_pen=${CARD_MISMATCH_PEN}"
  log " out=${out_dir}"

  # Full command line; the optional switches are appended at the end, in the
  # same position they always occupied.
  local cmd=(
    "${PYTHON_BIN}" -u "${GRPO_SCRIPT}"
    --model_name "${MODEL_NAME}"
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out_dir}"
    --cache_dir "${ROOT}/.hf_cache"
    --init_adapter_dir "${init_adapter}"
    --seed 0
    --gpu_id 0
    --stage_i "${stage}"
    --total_empties_hint 20
    --per_device_train_batch_size "${GRPO_BS}"
    --gradient_accumulation_steps "${GRPO_GA}"
    --num_train_epochs 100
    --learning_rate "${GRPO_LR}"
    --logging_steps 10
    --save_steps 200
    --eval_steps 150
    --eval_rows "${EVAL_ROWS}"
    --num_generations "${GRPO_NG}"
    --max_prompt_length "${GRPO_PROMPT}"
    --max_completion_length "${GRPO_COMPL}"
    --beta "${GRPO_BETA}"
    --limit_train_rows "${TRAIN_ROWS}"
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --reward_good_value "${REWARD_GOOD}"
    --penalty_bad_value "${PENALTY_BAD}"
    --penalty_malformed "${PENALTY_MAL}"
    --penalty_empty "${PENALTY_EMPTY}"
    --penalty_singleton "${PENALTY_SINGLETON}"
    --penalty_missing "${PENALTY_MISSING}"
    --exact_match_bonus "${EXACT_MATCH_BONUS}"
    --cardinality_mismatch_penalty "${CARD_MISMATCH_PEN}"
    --eval_value_precision_stop "${VALUE_TARGET}"
    --eval_value_recall_stop "${VALUE_TARGET}"
    --eval_solve_rate_stop 0
    --min_steps_before_stop 100
    --max_wall_clock_seconds "${PHASE_WALL_SECS}"
    --max_steps "${max_steps}"
  )
  if [[ "${USE_WANDB}" == "1" ]]; then
    cmd+=(
      --use_wandb --wandb_project "${WANDB_PROJECT}"
      --wandb_run_name "${VARIANT}_${tag}" --wandb_mode "${WANDB_MODE}"
    )
  fi
  if [[ "${USE_GC}" == "1" ]]; then
    cmd+=(--enable_gradient_checkpointing)
  fi

  "${cmd[@]}" 2>&1 | tee "${out_dir}/train.log"
}
|
| 253 |
+
|
| 254 |
+
log "===== ${VARIANT} on GPU ${GPU} ====="
|
| 255 |
+
log "S2 SFT init: ${S2_SFT_CKPT}"
|
| 256 |
+
log "START_PHASE=${START_PHASE} GRPO_LR=${GRPO_LR} SFT_LR_S3=${SFT_LR_S3} PENALTY_SINGLETON=${PENALTY_SINGLETON} USE_GC=${USE_GC}"
|
| 257 |
+
log " EXTRA_S2_SFT_STEPS=${EXTRA_S2_SFT_STEPS} GRPO_BS=${GRPO_BS}x${GRPO_GA} SFT_BS=${SFT_BS}x${SFT_GA} GRPO_NG=${GRPO_NG}"
|
| 258 |
+
|
| 259 |
+
S2_SFT_DIR_FOR_GRPO="${S2_SFT_CKPT}"
|
| 260 |
+
S2_GRPO_ADAPTER=""
|
| 261 |
+
S3_SFT_INIT_RESOLVED=""
|
| 262 |
+
S3_GRPO_INIT_RESOLVED=""
|
| 263 |
+
|
| 264 |
+
# Map a phase name to its 1-based position in the pipeline order
# (s2_sft_extra=1, s2_grpo=2, s3_sft=3, s3_grpo=4).
#
# Arguments:
#   $1  phase name (typically ${START_PHASE}).
# Outputs:
#   The index on stdout. An unrecognised (or empty) name still yields 2
#   (s2_grpo) as before, but a warning is now written to stderr so a typo in
#   START_PHASE does not silently restart the run from S2 GRPO.
phase_idx() {
  local phase="${1:-}"
  case "${phase}" in
    s2_sft_extra) echo 1 ;;
    s2_grpo)      echo 2 ;;
    s3_sft)       echo 3 ;;
    s3_grpo)      echo 4 ;;
    *)
      printf 'WARN: unknown phase "%s"; defaulting to s2_grpo\n' "${phase}" >&2
      echo 2
      ;;
  esac
}
|
| 273 |
+
# ---------------------------------------------------------------------------
# Phase orchestration: (optional extra S2 SFT) -> S2 GRPO -> S3 SFT -> S3 GRPO.
# START_PHASE selects the first phase to run; later phases always run.
#
# Every helper lookup below uses the `if ! var="$(helper)"` form. A bare
# `var="$(helper)"` followed by a `-z` test never reaches its error message
# when errexit is active, because the failing assignment ends the script first.
# This form behaves the same with or without errexit.
# ---------------------------------------------------------------------------
START_IDX="$(phase_idx "${START_PHASE}")"

# Phase 1 (optional): continue S2 SFT for EXTRA_S2_SFT_STEPS before GRPO.
if (( START_IDX <= 1 )) && (( EXTRA_S2_SFT_STEPS > 0 )); then
  S2_SFT_EXTRA_DIR="${OUTPUT_ROOT}/s2_sft_extra"
  run_sft 2 "${S2_SFT_CKPT}" "${S2_SFT_EXTRA_DIR}" "${EXTRA_S2_SFT_LR}" "${EXTRA_S2_SFT_STEPS}" "s2sft_extra"
  if NEW_CKPT="$(latest_ckpt_step "${S2_SFT_EXTRA_DIR}")"; then
    log ">>> Extra S2 SFT ckpt: ${NEW_CKPT}"
    S2_SFT_DIR_FOR_GRPO="${NEW_CKPT}"
  else
    log "WARN: no new S2 SFT ckpt produced; falling back to ${S2_SFT_CKPT}"
  fi
fi

# Phase 2: S2 GRPO; its adapter seeds S3 SFT. When starting at s3_sft the seed
# must be supplied through S3_SFT_INIT instead.
if (( START_IDX <= 2 )); then
  S2_GRPO_DIR="${OUTPUT_ROOT}/s2_grpo"
  run_grpo 2 "${S2_SFT_DIR_FOR_GRPO}" "${S2_GRPO_DIR}" "${S2_GRPO_MAX_STEPS}" "s2grpo"
  if ! S2_GRPO_ADAPTER="$(best_grpo_adapter "${S2_GRPO_DIR}")" || [[ -z "${S2_GRPO_ADAPTER}" ]]; then
    log "ERROR: no S2 GRPO adapter under ${S2_GRPO_DIR}"; exit 1
  fi
  log ">>> S2 GRPO adapter: ${S2_GRPO_ADAPTER}"
  S3_SFT_INIT_RESOLVED="${S2_GRPO_ADAPTER}"
elif (( START_IDX == 3 )); then
  if [[ -z "${S3_SFT_INIT}" ]]; then
    log "ERROR: START_PHASE=s3_sft but S3_SFT_INIT is empty"; exit 1
  fi
  S3_SFT_INIT_RESOLVED="${S3_SFT_INIT}"
  log ">>> Skipping to S3 SFT, init=${S3_SFT_INIT_RESOLVED}"
fi

# Phase 3: S3 SFT; its latest checkpoint seeds S3 GRPO. When starting at
# s3_grpo the seed must be supplied through S3_GRPO_INIT instead.
if (( START_IDX <= 3 )); then
  S3_SFT_DIR="${OUTPUT_ROOT}/s3_sft"
  run_sft 3 "${S3_SFT_INIT_RESOLVED}" "${S3_SFT_DIR}" "${SFT_LR_S3}" "${S3_SFT_MAX_STEPS}" "s3sft"
  if ! S3_SFT_CKPT="$(latest_ckpt_step "${S3_SFT_DIR}")" || [[ -z "${S3_SFT_CKPT}" ]]; then
    log "ERROR: no S3 SFT ckpt under ${S3_SFT_DIR}"; exit 1
  fi
  log ">>> S3 SFT ckpt: ${S3_SFT_CKPT}"
  S3_GRPO_INIT_RESOLVED="${S3_SFT_CKPT}"
elif (( START_IDX == 4 )); then
  if [[ -z "${S3_GRPO_INIT}" ]]; then
    log "ERROR: START_PHASE=s3_grpo but S3_GRPO_INIT is empty"; exit 1
  fi
  S3_GRPO_INIT_RESOLVED="${S3_GRPO_INIT}"
  log ">>> Skipping to S3 GRPO, init=${S3_GRPO_INIT_RESOLVED}"
fi

# Phase 4: S3 GRPO (always runs).
S3_GRPO_DIR="${OUTPUT_ROOT}/s3_grpo"
run_grpo 3 "${S3_GRPO_INIT_RESOLVED}" "${S3_GRPO_DIR}" "${S3_GRPO_MAX_STEPS}" "s3grpo"
if ! S3_GRPO_ADAPTER="$(best_grpo_adapter "${S3_GRPO_DIR}")" || [[ -z "${S3_GRPO_ADAPTER}" ]]; then
  log "ERROR: no S3 GRPO adapter under ${S3_GRPO_DIR}"; exit 1
fi
log ">>> S3 GRPO adapter: ${S3_GRPO_ADAPTER}"

log "===== ${VARIANT} DONE — final S3 GRPO adapter at ${S3_GRPO_ADAPTER} ====="
|
_runs/eval_strawman_cellpolicy.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric.
|
| 3 |
+
|
| 4 |
+
This is a thin CLI wrapper that:
|
| 5 |
+
|
| 6 |
+
1. Loads a base model + LoRA adapter.
|
| 7 |
+
2. Runs the same scoring procedure as
|
| 8 |
+
``multi_output_cell_policy/sft_multi_output_train.py::run_eval``,
|
| 9 |
+
i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate
|
| 10 |
+
over empty cells in row-major order and scores each predicted value
|
| 11 |
+
with ``score_prediction_text`` against the i-consistent target set at
|
| 12 |
+
``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal).
|
| 13 |
+
3. The only difference vs the cell-policy is that the model emits the whole
|
| 14 |
+
puzzle in ONE forward pass, then the predicted list is split into
|
| 15 |
+
per-cell singletons.
|
| 16 |
+
|
| 17 |
+
Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``)
|
| 18 |
+
and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k
|
| 19 |
+
models (``adaptive_latent_baseline_sudoku_train.py``).
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import sys
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from typing import Any, Dict, List
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
from peft import PeftModel
|
| 32 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
|
| 33 |
+
|
| 34 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 35 |
+
if str(ROOT) not in sys.path:
|
| 36 |
+
sys.path.insert(0, str(ROOT))
|
| 37 |
+
|
| 38 |
+
from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore # noqa: E402
|
| 39 |
+
load_jsonl_rows,
|
| 40 |
+
pick_dtype,
|
| 41 |
+
)
|
| 42 |
+
from _runs.simple_baseline_sudoku_train import ( # type: ignore # noqa: E402
|
| 43 |
+
run_eval as run_eval_strawman,
|
| 44 |
+
)
|
| 45 |
+
from _runs.adaptive_latent_baseline_sudoku_train import ( # type: ignore # noqa: E402
|
| 46 |
+
run_eval as run_eval_adaptive_k,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the cell-policy re-evaluation tool.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list lets other code drive the parser.

    Returns:
        The populated namespace. ``--kind``, ``--adapter_dir`` and
        ``--eval_jsonl`` are required; everything else has a default.
    """
    p = argparse.ArgumentParser(
        description="Re-evaluate a strawman / adaptive-k checkpoint with the cell-policy metric."
    )
    p.add_argument(
        "--kind",
        choices=["strawman", "adaptive_k"],
        required=True,
        help="Which evaluation routine to run.",
    )
    p.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--adapter_dir", required=True, help="Directory of the LoRA adapter to load.")
    p.add_argument("--eval_jsonl", required=True, help="Evaluation JSONL file.")
    p.add_argument("--cache_dir", default=str(ROOT / ".hf_cache"))
    p.add_argument("--eval_rows", type=int, default=100)
    p.add_argument("--max_completion_length", type=int, default=96)
    p.add_argument("--stage_i", type=int, default=3)
    p.add_argument(
        "--num_cot_tokens",
        type=int,
        default=0,
        help="Only used when --kind adaptive_k.",
    )
    p.add_argument("--seed", type=int, default=0)
    p.add_argument(
        "--out_json",
        default="",
        help="If non-empty, also write the metrics to this JSON file.",
    )
    return p.parse_args(argv)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def main() -> None:
    """Load base model + LoRA adapter, run the selected eval, report metrics.

    Steps: parse CLI options, seed the RNGs, load tokenizer and model (the
    adapter is applied on top of ``--model_name``), read up to ``--eval_rows``
    rows from ``--eval_jsonl``, dispatch to the strawman or adaptive-k
    ``run_eval``, print the metrics as JSON, and optionally write them to
    ``--out_json`` (parent directories are created).
    """
    args = parse_args()
    set_seed(int(args.seed))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = pick_dtype()

    print(f"[eval-cellpolicy] kind={args.kind} adapter={args.adapter_dir}", flush=True)
    print(f"[eval-cellpolicy] eval_jsonl={args.eval_jsonl} stage_i={args.stage_i}", flush=True)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, torch_dtype=dtype
    )
    model = PeftModel.from_pretrained(base, args.adapter_dir)
    model.to(device)
    model.eval()

    rows: List[Dict[str, Any]] = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))
    print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True)

    if args.kind == "strawman":
        metrics = run_eval_strawman(
            model, tokenizer, rows, device,
            max_new_tokens=int(args.max_completion_length),
            print_n=3,
            stage_i=int(args.stage_i),
        )
    else:
        metrics = run_eval_adaptive_k(
            model, tokenizer, rows, device,
            num_cot_tokens=int(args.num_cot_tokens),
            max_new_tokens=int(args.max_completion_length),
            print_n=3,
            stage_i=int(args.stage_i),
        )

    print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True)
    if args.out_json:
        out_path = Path(args.out_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        report = {
            "kind": args.kind,
            "adapter_dir": args.adapter_dir,
            "eval_jsonl": args.eval_jsonl,
            "stage_i": int(args.stage_i),
            "num_cot_tokens": int(args.num_cot_tokens),
            "metrics": metrics,
        }
        # Explicit encoding: the report must not depend on the locale's default.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
        print(f"[eval-cellpolicy] wrote {args.out_json}", flush=True)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
_runs/launch_adaptive_k_cellpolicy.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Launch two adaptive-k variants (single-stage cell-policy at stage_i=3,
|
| 3 |
+
# no curriculum, but with growing recurrent-hidden thought tokens k).
|
| 4 |
+
set -euo pipefail
|
| 5 |
+
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
|
| 6 |
+
TS="$(date +%Y%m%d_%H%M%S)"
|
| 7 |
+
SWEEP_ROOT="${ROOT}/_runs/adaptive_k_cellpolicy_${TS}"
|
| 8 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 9 |
+
PY="${ROOT}/_runs/adaptive_k_cellpolicy_pipeline.py"
|
| 10 |
+
|
| 11 |
+
# Start one adaptive-k pipeline run in the background.
#
# Usage: launch <variant> <gpu> [pipeline CLI flags...]
#   <variant>  run name; also the output subdirectory under ${SWEEP_ROOT}
#   <gpu>      value forwarded as --gpu
#   remaining  forwarded verbatim to ${PY} after the three flags set here
#
# The interpreter defaults to /opt/pytorch/bin/python and can be overridden by
# setting PYTHON_BIN. stdout and stderr of the run go to <out>/console.log, and
# "<variant>=<pid>" is appended to ${SWEEP_ROOT}/PIDS.txt.
launch() {
  local variant="$1" gpu="$2"
  shift 2
  local out="${SWEEP_ROOT}/${variant}"
  mkdir -p "${out}"
  echo "[launch] ${variant} on GPU ${gpu} out=${out}"
  nohup "${PYTHON_BIN:-/opt/pytorch/bin/python}" -u "${PY}" \
    --variant "${variant}" \
    --gpu "${gpu}" \
    --output_root "${out}" \
    "$@" > "${out}/console.log" 2>&1 &
  local pid=$!
  # Detach from job control; tolerate shells where the job is already gone.
  disown "${pid}" || true
  echo "${variant}=${pid}" >> "${SWEEP_ROOT}/PIDS.txt"
}
|
| 27 |
+
|
| 28 |
+
# adaptive_a: classic schedule (start at k=0, plateau-bumps with eps=0.01).
|
| 29 |
+
launch adaptive_a_eps01 2 \
|
| 30 |
+
--start_k 0 --max_k 4 --steps_per_phase 600 --max_phases_per_k 2 \
|
| 31 |
+
--plateau_eps 0.01 --sft_lr 2e-5 --sft_bs 8 --sft_ga 4 \
|
| 32 |
+
--grpo_steps 1500 --grpo_lr 5e-6 --grpo_bs 8 --grpo_ga 4 --grpo_ng 8
|
| 33 |
+
|
| 34 |
+
# adaptive_b: faster k-growth (max_phases_per_k=1, force bump every phase).
|
| 35 |
+
launch adaptive_b_fastgrow 3 \
|
| 36 |
+
--start_k 0 --max_k 4 --steps_per_phase 800 --max_phases_per_k 1 \
|
| 37 |
+
--plateau_eps 1.0 --sft_lr 2e-5 --sft_bs 8 --sft_ga 4 \
|
| 38 |
+
--grpo_steps 1500 --grpo_lr 5e-6 --grpo_bs 8 --grpo_ga 4 --grpo_ng 8
|
| 39 |
+
|
| 40 |
+
echo "[launch] sweep root: ${SWEEP_ROOT}"
|
| 41 |
+
echo "[launch] PIDs:"
|
| 42 |
+
cat "${SWEEP_ROOT}/PIDS.txt"
|
_runs/launch_adaptive_latent_baseline.sh
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Adaptive-k thought-token baseline (experiment D in the 2x2 ablation).
|
| 3 |
+
#
|
| 4 |
+
# Same single-stage, whole-puzzle setup as launch_simple_baseline.sh
|
| 5 |
+
# (experiment C, the "strawman"); same model, LoRA, JSONL, chat template.
|
| 6 |
+
# The ONLY change is that the SFT loss uses the recurrent_hidden mechanism
|
| 7 |
+
# with k thought tokens, and k grows automatically when the rolling-mean
|
| 8 |
+
# loss plateaus.
|
| 9 |
+
set -euo pipefail
|
| 10 |
+
|
| 11 |
+
ROOT=/home/ubuntu/curriculum_cot
|
| 12 |
+
SCRIPT=${ROOT}/_runs/adaptive_latent_baseline_sudoku_train.py
|
| 13 |
+
PYTHON_BIN=/opt/pytorch/bin/python
|
| 14 |
+
|
| 15 |
+
TRAIN_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
|
| 16 |
+
EVAL_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
|
| 17 |
+
|
| 18 |
+
SWEEP_ROOT=${ROOT}/_runs/adaptive_latent_$(date +%Y%m%d_%H%M%S)
|
| 19 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 20 |
+
echo "${SWEEP_ROOT}" > "${ROOT}/_runs/current_adaptive_latent_sweep_dir"
|
| 21 |
+
echo "SWEEP_ROOT=${SWEEP_ROOT}"
|
| 22 |
+
|
| 23 |
+
export TOKENIZERS_PARALLELISM=false
|
| 24 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 25 |
+
export HF_HOME="${ROOT}/.hf_cache"
|
| 26 |
+
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
|
| 27 |
+
export WANDB_MODE=offline
|
| 28 |
+
|
| 29 |
+
# Start one adaptive-k SFT run in the background on a single GPU.
#
# Arguments:
#   $1 gpu              value exported as CUDA_VISIBLE_DEVICES for the run
#   $2 tag              run name; output goes to ${SWEEP_ROOT}/<tag>
#   $3 lr               learning rate
#   $4 max_k            value passed as --max_k
#   $5 min_steps_per_k  value passed as --min_steps_per_k
# Side effects: truncates <out>/train.log, then the run appends to it;
# "<pid> <gpu> <tag>" is appended to ${SWEEP_ROOT}/PIDS.txt and a one-line
# summary is printed.
run_variant() {
  local gpu="$1" tag="$2" lr="$3" max_k="$4" min_steps_per_k="$5"
  local out="${SWEEP_ROOT}/${tag}"
  local log="${out}/train.log"
  mkdir -p "${out}"
  : > "${log}"

  local train_args=(
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
    --output_dir "${out}"
    --learning_rate "${lr}"
    --max_steps 4000
    --per_device_train_batch_size 4
    --gradient_accumulation_steps 2
    --logging_steps 25
    --save_steps 500
    --eval_every_steps 500
    --eval_rows 50
    --max_completion_length 96
    --max_prompt_length 1024
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --enable_gradient_checkpointing
    --start_k 0
    --max_k "${max_k}"
    --min_steps_per_k "${min_steps_per_k}"
    --plateau_window 100
    --plateau_eps 0.005
    --converged_eps 0.001
    --seed 0
  )

  # The subshell keeps the GPU selection local to this run; $! below is the
  # subshell's PID.
  (
    export CUDA_VISIBLE_DEVICES="${gpu}"
    "${PYTHON_BIN}" -u "${SCRIPT}" "${train_args[@]}" >> "${log}" 2>&1
  ) >/dev/null 2>&1 &
  local pid=$!

  echo "$pid $gpu $tag" >> "${SWEEP_ROOT}/PIDS.txt"
  disown $pid 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s log=%s\n' "$gpu" "$tag" "$pid" "$log"
}
|
| 67 |
+
|
| 68 |
+
# 2 variants on idle GPUs 2,3:
|
| 69 |
+
# - adaptive_a: same LR (5e-5) as strawman variant a, max_k=4, min_steps_per_k=400
|
| 70 |
+
# - adaptive_b: smaller min_steps_per_k=250 to grow k more aggressively
|
| 71 |
+
run_variant 2 adaptive_a_lr5e5_maxk4 5e-5 4 400
|
| 72 |
+
run_variant 3 adaptive_b_lr5e5_fastgrow 5e-5 4 250
|
| 73 |
+
|
| 74 |
+
echo
|
| 75 |
+
echo "=== launched ==="
|
| 76 |
+
cat "${SWEEP_ROOT}/PIDS.txt"
|
_runs/launch_baseline_1p5b_v4.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Launch 6 baseline 1.5B variants in parallel, one per GPU (0..5).
|
| 3 |
+
# Each runs S2 GRPO -> S3 SFT -> S3 GRPO from a v3 lowsft S2 SFT checkpoint.
|
| 4 |
+
set -euo pipefail
|
| 5 |
+
|
| 6 |
+
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
|
| 7 |
+
SWEEP_ID="${SWEEP_ID:-$(date +%Y%m%d_%H%M%S)}"
|
| 8 |
+
SWEEP_ROOT="${ROOT}/_runs/baseline_1p5b_v4_${SWEEP_ID}"
|
| 9 |
+
PIPELINE="${ROOT}/_runs/baseline_1p5b_pipeline_v4.sh"
|
| 10 |
+
|
| 11 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 12 |
+
SUMMARY="${SWEEP_ROOT}/SWEEP_README.md"
|
| 13 |
+
|
| 14 |
+
CKPT_LR1E4="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr1e4_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
|
| 15 |
+
CKPT_LR5E5="${ROOT}/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000"
|
| 16 |
+
|
| 17 |
+
if [[ ! -d "${CKPT_LR1E4}" || ! -d "${CKPT_LR5E5}" ]]; then
|
| 18 |
+
echo "ERROR: missing init checkpoints" >&2
|
| 19 |
+
exit 1
|
| 20 |
+
fi
|
| 21 |
+
|
| 22 |
+
cat >"${SUMMARY}" <<EOF
|
| 23 |
+
# Baseline 1.5B v4 sweep — ${SWEEP_ID}
|
| 24 |
+
|
| 25 |
+
Single GPU per variant. All 6 variants resume from the v3 lowsft S2 SFT
|
| 26 |
+
checkpoints (the only ones with positive trend), then run S2 GRPO -> S3 SFT
|
| 27 |
+
-> S3 GRPO with various GRPO LR / penalty / extra-S2-SFT settings.
|
| 28 |
+
|
| 29 |
+
| GPU | variant | S2 init | GRPO LR | S3 SFT LR | penalty_singleton | extra S2 SFT (steps @ LR) |
|
| 30 |
+
| ---: | --- | --- | ---: | ---: | ---: | --- |
|
| 31 |
+
| 0 | pipe_a_lr1e4_grpo5e6 | lr1e4_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 0 |
|
| 32 |
+
| 1 | pipe_b_lr5e5_grpo5e6 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 0 |
|
| 33 |
+
| 2 | pipe_c_lr1e4_grpo2e6 | lr1e4_lowsft step-3000 | 2e-6 | 2e-5 | 1.5 | 0 |
|
| 34 |
+
| 3 | pipe_d_lr5e5_grpo2e6 | lr5e5_lowsft step-3000 | 2e-6 | 2e-5 | 1.5 | 0 |
|
| 35 |
+
| 4 | pipe_e_lr5e5_grpo5e6_sngl25 | lr5e5_lowsft step-3000 | 5e-6 | 2e-5 | 2.5 | 0 |
|
| 36 |
+
| 5 | pipe_f_lr1e4_extraS2sft | lr1e4_lowsft step-3000 | 5e-6 | 2e-5 | 1.5 | 1500 @ 1e-5 |
|
| 37 |
+
|
| 38 |
+
Pipeline budget per variant:
|
| 39 |
+
- S2 GRPO max 1200 steps (early stop on prec AND recall >= 0.98)
|
| 40 |
+
- S3 SFT max 2400 steps (same early stop)
|
| 41 |
+
- S3 GRPO max 1500 steps (same early stop)
|
| 42 |
+
|
| 43 |
+
Logs: \`<variant>/PIPELINE.log\`, per-phase: \`<variant>/{s2_grpo,s3_sft,s3_grpo}/train.log\`
|
| 44 |
+
EOF
|
| 45 |
+
|
| 46 |
+
# Start one pipeline run in the background with its settings passed through
# the environment.
#
# Usage: launch_variant <gpu> <variant> <init_ckpt> [KEY=VALUE ...]
#   <gpu>        exported to the pipeline as GPU
#   <variant>    run name; output goes to ${SWEEP_ROOT}/<variant>
#   <init_ckpt>  exported to the pipeline as S2_SFT_CKPT
#   KEY=VALUE    extra environment assignments handed to `env`; they come
#                after the fixed ones, so they win on conflict
# Side effects: the run's output goes to <out>/nohup.log and
# "<pid> <gpu> <variant>" is appended to ${SWEEP_ROOT}/PIDS.txt.
launch_variant() {
  local gpu="$1" variant="$2" init="$3"
  shift 3
  local out="${SWEEP_ROOT}/${variant}"
  local nohup_log="${out}/nohup.log"
  mkdir -p "${out}"
  printf 'GPU %s -> %s -> %s\n' "${gpu}" "${variant}" "${init}"

  local env_args=(
    ROOT="${ROOT}"
    VARIANT="${variant}"
    GPU="${gpu}"
    S2_SFT_CKPT="${init}"
    OUTPUT_ROOT="${out}"
    USE_WANDB=0
    WANDB_MODE=offline
    "$@"
  )
  nohup env "${env_args[@]}" bash "${PIPELINE}" </dev/null >"${nohup_log}" 2>&1 &
  local pid=$!

  printf ' pid=%s log=%s\n' "${pid}" "${nohup_log}"
  echo "${pid} ${gpu} ${variant}" >> "${SWEEP_ROOT}/PIDS.txt"
  disown "${pid}" 2>/dev/null || true
}
|
| 69 |
+
|
| 70 |
+
: > "${SWEEP_ROOT}/PIDS.txt"
|
| 71 |
+
|
| 72 |
+
launch_variant 0 pipe_a_lr1e4_grpo5e6 "${CKPT_LR1E4}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
|
| 73 |
+
launch_variant 1 pipe_b_lr5e5_grpo5e6 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
|
| 74 |
+
launch_variant 2 pipe_c_lr1e4_grpo2e6 "${CKPT_LR1E4}" GRPO_LR=2e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
|
| 75 |
+
launch_variant 3 pipe_d_lr5e5_grpo2e6 "${CKPT_LR5E5}" GRPO_LR=2e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5
|
| 76 |
+
launch_variant 4 pipe_e_lr5e5_grpo5e6_sngl25 "${CKPT_LR5E5}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=2.5
|
| 77 |
+
launch_variant 5 pipe_f_lr1e4_extraS2sft "${CKPT_LR1E4}" GRPO_LR=5e-6 SFT_LR_S3=2e-5 PENALTY_SINGLETON=1.5 EXTRA_S2_SFT_STEPS=1500 EXTRA_S2_SFT_LR=1e-5
|
| 78 |
+
|
| 79 |
+
echo
|
| 80 |
+
echo "Sweep root: ${SWEEP_ROOT}"
|
| 81 |
+
echo "Tail PIDS:"
|
| 82 |
+
cat "${SWEEP_ROOT}/PIDS.txt"
|
_runs/launch_baseline_push_v5.sh
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Wave-5: push baseline 1.5B past solve=0.35.
|
| 3 |
+
#
|
| 4 |
+
# Idea: best ckpts so far cap at per-cell-exact ~0.943 (solve 0.35; note that under independence 0.943^20 ≈ 0.31, and 0.35 ≈ 0.949^20).
|
| 5 |
+
# To reach solve=0.5 we need exact ~= 0.965. That's +2.2pp of per-cell exact.
|
| 6 |
+
#
|
| 7 |
+
# 4 variants, single-GPU each, on GPUs 4..7.
|
| 8 |
+
# All start from the leader (pipe_m post-S3-GRPO at solve=0.35) or its S3 SFT
|
| 9 |
+
# ckpt, then push S3 GRPO further with different levers:
|
| 10 |
+
# - lower LR (escape / fine refine)
|
| 11 |
+
# - longer steps (3000 instead of 1500)
|
| 12 |
+
# - KL anchor (beta>0) to prevent regression
|
| 13 |
+
# - sharper rewards (mirror what worked for the latent's `s3_grpo_sharp_rwd`)
|
| 14 |
+
set -euo pipefail
|
| 15 |
+
|
| 16 |
+
ROOT=/home/ubuntu/curriculum_cot
|
| 17 |
+
SWEEP_ROOT=/home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_20260523_184952
|
| 18 |
+
PIPELINE=$ROOT/_runs/baseline_1p5b_pipeline_v4.sh
|
| 19 |
+
|
| 20 |
+
# best wave-2 anchors
|
| 21 |
+
# Most recently modified checkpoint-* directory of each S3 GRPO run.
# `|| true` matters: with `set -euo pipefail` a glob that matches nothing makes
# `ls` fail, the failing pipeline aborts the assignment, and the script would
# exit silently before the sanity check below could name what is missing.
PIPE_M_S3GRPO_LATEST=$(ls -dt "$SWEEP_ROOT"/pipe_m_s3sft_from_b/s3_grpo/checkpoint-* 2>/dev/null | head -n 1 || true)
PIPE_M_S3SFT_LATEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_sft/checkpoint-step-02400
PIPE_O_S3SFT_LATEST=$SWEEP_ROOT/pipe_o_s3sft_lr5e6/s3_sft/checkpoint-step-02400
PIPE_J_S3GRPO_LATEST=$(ls -dt "$SWEEP_ROOT"/pipe_j_s3sft_lr5e5_lr1e5/s3_grpo/checkpoint-* 2>/dev/null | head -n 1 || true)

# Sanity: every anchor must be an existing directory. The variable name is
# printed too, so an empty value (no checkpoint found) is still identifiable.
for anchor in PIPE_M_S3GRPO_LATEST PIPE_M_S3SFT_LATEST PIPE_O_S3SFT_LATEST PIPE_J_S3GRPO_LATEST; do
  c="${!anchor}"
  [[ -d "$c" ]] || { echo "MISSING: ${anchor}=${c:-<no checkpoint found>}"; exit 1; }
done
|
| 30 |
+
|
| 31 |
+
CKPT_LR5E5=$ROOT/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000
|
| 32 |
+
|
| 33 |
+
# Start one pipeline run in the background, seeded from ${CKPT_LR5E5}.
#
# Usage: launch <gpu> <variant> [KEY=VALUE ...]
#   <gpu>      exported to the pipeline as GPU
#   <variant>  run name; output goes to ${SWEEP_ROOT}/<variant>
#   KEY=VALUE  extra environment assignments handed to `env` after the fixed
#              ones, so they win on conflict
# Side effects: the run's output goes to <out>/nohup.log and
# "<pid> <gpu> <variant>" is appended to ${SWEEP_ROOT}/PIDS.txt.
launch() {
  local gpu="$1" variant="$2"
  shift 2
  local out="$SWEEP_ROOT/$variant"
  mkdir -p "$out"

  local env_args=(
    ROOT="$ROOT"
    VARIANT="$variant"
    GPU="$gpu"
    S2_SFT_CKPT="$CKPT_LR5E5"
    OUTPUT_ROOT="$out"
    USE_WANDB=0
    WANDB_MODE=offline
    "$@"
  )
  nohup env "${env_args[@]}" bash "$PIPELINE" </dev/null >"$out/nohup.log" 2>&1 &
  local pid=$!

  echo "$pid $gpu $variant" >> "$SWEEP_ROOT/PIDS.txt"
  disown $pid 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s\n' "$gpu" "$variant" "$pid"
}
|
| 44 |
+
|
| 45 |
+
# pipe_t (GPU 4): continue pipe_m's S3 GRPO with lower LR + KL anchor + longer steps.
|
| 46 |
+
# Keep the policy near the SFT reference to avoid the regression we saw earlier.
|
| 47 |
+
launch 4 pipe_t_grpo_low_kl \
|
| 48 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3GRPO_LATEST" \
|
| 49 |
+
GRPO_LR=1e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 50 |
+
GRPO_BETA=0.04 \
|
| 51 |
+
S3_GRPO_MAX_STEPS=3000 \
|
| 52 |
+
USE_GC=0
|
| 53 |
+
|
| 54 |
+
# pipe_u (GPU 5): re-run S3 GRPO from pipe_m's S3-SFT ckpt with sharper rewards
|
| 55 |
+
# (mirror latent `s3_grpo_sharp_rwd` recipe: bigger penalty for bad).
|
| 56 |
+
launch 5 pipe_u_grpo_sharp_rwd \
|
| 57 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3SFT_LATEST" \
|
| 58 |
+
GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 59 |
+
REWARD_GOOD=1.5 PENALTY_BAD=2.0 PENALTY_MAL=4.0 \
|
| 60 |
+
S3_GRPO_MAX_STEPS=3000 \
|
| 61 |
+
USE_GC=0
|
| 62 |
+
|
| 63 |
+
# pipe_v (GPU 6): extend pipe_o's S3 SFT (the strongest pure-SFT path) with very
|
| 64 |
+
# low LR for 4000 more steps. Then S3 GRPO at LR=1e-6.
|
| 65 |
+
launch 6 pipe_v_sft_extend \
|
| 66 |
+
START_PHASE=s3_sft S3_SFT_INIT="$PIPE_O_S3SFT_LATEST" \
|
| 67 |
+
SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
|
| 68 |
+
S3_SFT_MAX_STEPS=4000 \
|
| 69 |
+
GRPO_LR=1e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 70 |
+
S3_GRPO_MAX_STEPS=2000 \
|
| 71 |
+
USE_GC=0
|
| 72 |
+
|
| 73 |
+
# pipe_w (GPU 7): continue pipe_j's S3 GRPO with very low LR + KL anchor.
|
| 74 |
+
# Different lineage from pipe_m, so this gives an independent push.
|
| 75 |
+
launch 7 pipe_w_j_low_kl \
|
| 76 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_J_S3GRPO_LATEST" \
|
| 77 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 78 |
+
GRPO_BETA=0.02 \
|
| 79 |
+
S3_GRPO_MAX_STEPS=3000 \
|
| 80 |
+
USE_GC=0
|
| 81 |
+
|
| 82 |
+
echo
|
| 83 |
+
echo "=== launched ==="
|
| 84 |
+
cat "$SWEEP_ROOT/PIDS.txt" | tail -4
|
_runs/launch_baseline_push_v6.sh
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Wave-6: push baseline 1.5B past solve=0.40 by porting the latent's winning
|
| 3 |
+
# reward shaping + multi-value oversampling into the vanilla baseline pipeline.
|
| 4 |
+
#
|
| 5 |
+
# Diagnosis from v4/v5 logs:
|
| 6 |
+
# At plateau, eval shows `avg_set_size=1.000` for every step. The model
|
| 7 |
+
# is predicting only ONE value per cell even when the target is multi-valued.
|
| 8 |
+
# Per-cell exact pinned at 0.95 → solve = 0.95^20 ≈ 0.36.
|
| 9 |
+
# Same failure mode the latent's `s3_grpo_sharp_rwd` recipe fixed:
|
| 10 |
+
# exact_match_bonus + cardinality_mismatch_penalty + penalty_missing
|
| 11 |
+
# plus SFT-side multi_value_oversample_factor=5 (and target_size_min=2 for
|
| 12 |
+
# the most aggressive variant).
|
| 13 |
+
#
|
| 14 |
+
# 8 variants on GPUs 0..7. All seed from existing v4 best ckpts so we don't
|
| 15 |
+
# burn cycles redoing S2.
|
| 16 |
+
set -euo pipefail
|
| 17 |
+
|
| 18 |
+
ROOT=/home/ubuntu/curriculum_cot
|
| 19 |
+
SWEEP_ROOT=$ROOT/_runs/baseline_1p5b_v4_20260523_184952
|
| 20 |
+
PIPELINE=$ROOT/_runs/baseline_1p5b_pipeline_v4.sh
|
| 21 |
+
|
| 22 |
+
# --- v4 anchors ----
|
| 23 |
+
PIPE_V_S3SFT_LATEST=$SWEEP_ROOT/pipe_v_sft_extend/s3_sft/checkpoint-step-04000
|
| 24 |
+
PIPE_M_S3SFT_LATEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_sft/checkpoint-step-02400
|
| 25 |
+
PIPE_V_S3GRPO_BEST=$SWEEP_ROOT/pipe_v_sft_extend/s3_grpo/checkpoint-1000 # step 1050 was 0.40 peak; 1000 is closest saved
|
| 26 |
+
PIPE_M_S3GRPO_BEST=$SWEEP_ROOT/pipe_m_s3sft_from_b/s3_grpo/checkpoint-200 # peak per pipe_m logs
|
| 27 |
+
PIPE_O_S3SFT_LATEST=$SWEEP_ROOT/pipe_o_s3sft_lr5e6/s3_sft/checkpoint-step-02400
|
| 28 |
+
CKPT_LR5E5=$ROOT/checkpoints/sudoku-9x9-20empty-baseline-1p5b-sweep/baseline_lr5e5_lowsft_v3/s2_sft_v3/checkpoint-step-03000
|
| 29 |
+
|
| 30 |
+
for c in "$PIPE_V_S3SFT_LATEST" "$PIPE_M_S3SFT_LATEST" "$PIPE_V_S3GRPO_BEST" "$PIPE_M_S3GRPO_BEST" "$PIPE_O_S3SFT_LATEST"; do
|
| 31 |
+
[[ -d "$c" ]] || { echo "MISSING: $c"; exit 1; }
|
| 32 |
+
done
|
| 33 |
+
|
| 34 |
+
launch() {
  # Start one pipeline variant in the background and record its PID.
  #   $1   GPU index, forwarded to the pipeline as GPU
  #   $2   variant name; also the output subdirectory under $SWEEP_ROOT
  #   $@   remaining KEY=VALUE pairs, placed after the built-in defaults on the
  #        `env` command line
  local gpu="$1"
  local variant="$2"
  shift 2

  local run_dir=$SWEEP_ROOT/$variant
  mkdir -p "$run_dir"

  # Detached run: stdin comes from /dev/null, stdout+stderr go to nohup.log.
  nohup env \
    ROOT="$ROOT" VARIANT="$variant" GPU="$gpu" S2_SFT_CKPT="$CKPT_LR5E5" \
    OUTPUT_ROOT="$run_dir" USE_WANDB=0 WANDB_MODE=offline "$@" \
    bash "$PIPELINE" </dev/null >"$run_dir/nohup.log" 2>&1 &
  local job_pid=$!

  # One "<pid> <gpu> <variant>" line per job; the script tails this file at the end.
  echo "$job_pid $gpu $variant" >> "$SWEEP_ROOT/PIDS.txt"
  disown $job_pid 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s\n' "$gpu" "$variant" "$job_pid"
}
|
| 45 |
+
|
| 46 |
+
# === GRPO continuations (the high-leverage knob) ===
|
| 47 |
+
|
| 48 |
+
# v6_a (GPU 0): continue best v4 GRPO with the FULL latent recipe.
|
| 49 |
+
# card_pen=1.0 + missing=0.75 + exact_b=2.0; LR slightly lower than v4 to be safe.
|
| 50 |
+
launch 0 v6_a_grpo_v_card \
|
| 51 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
|
| 52 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 53 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 54 |
+
S3_GRPO_MAX_STEPS=2000
|
| 55 |
+
|
| 56 |
+
# v6_b (GPU 1): "sharp" version — mirror s3_grpo_sharp_rwd's stronger weights.
|
| 57 |
+
launch 1 v6_b_grpo_v_sharp \
|
| 58 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
|
| 59 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 60 |
+
PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=4.0 CARD_MISMATCH_PEN=3.0 \
|
| 61 |
+
S3_GRPO_MAX_STEPS=2000
|
| 62 |
+
|
| 63 |
+
# v6_c (GPU 2): full recipe but from pipe_v's S3 SFT (fresh GRPO, not continuation).
|
| 64 |
+
launch 2 v6_c_grpo_vsft_card \
|
| 65 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3SFT_LATEST" \
|
| 66 |
+
GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 67 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 68 |
+
S3_GRPO_MAX_STEPS=2000
|
| 69 |
+
|
| 70 |
+
# v6_d (GPU 3): same recipe but from pipe_m's S3 SFT (different lineage; champion).
|
| 71 |
+
launch 3 v6_d_grpo_msft_card \
|
| 72 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_M_S3SFT_LATEST" \
|
| 73 |
+
GRPO_LR=5e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 74 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 75 |
+
S3_GRPO_MAX_STEPS=2000
|
| 76 |
+
|
| 77 |
+
# === SFT push w/ oversample (the data-side knob) ===
|
| 78 |
+
|
| 79 |
+
# v6_e (GPU 4): continue pipe_v S3 SFT with oversample=5. Mirrors r1_sft_c_oversample5.
|
| 80 |
+
launch 4 v6_e_sft_v_oversample5 \
|
| 81 |
+
START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
|
| 82 |
+
SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
|
| 83 |
+
SFT_OVERSAMPLE=5 \
|
| 84 |
+
S3_SFT_MAX_STEPS=2500 \
|
| 85 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 86 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 87 |
+
S3_GRPO_MAX_STEPS=1500
|
| 88 |
+
|
| 89 |
+
# v6_f (GPU 5): same but oversample=8 (more aggressive).
|
| 90 |
+
launch 5 v6_f_sft_v_oversample8 \
|
| 91 |
+
START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
|
| 92 |
+
SFT_LR_S3=2e-6 SFT_BS=16 SFT_GA=1 \
|
| 93 |
+
SFT_OVERSAMPLE=8 \
|
| 94 |
+
S3_SFT_MAX_STEPS=2500 \
|
| 95 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 96 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 97 |
+
S3_GRPO_MAX_STEPS=1500
|
| 98 |
+
|
| 99 |
+
# v6_g (GPU 6): oversample=5 + train_target_size_min=2 (only multi-value cells).
|
| 100 |
+
# This is the most surgical variant — focus all training mass on the failing cells.
|
| 101 |
+
launch 6 v6_g_sft_v_mv_only \
|
| 102 |
+
START_PHASE=s3_sft S3_SFT_INIT="$PIPE_V_S3SFT_LATEST" \
|
| 103 |
+
SFT_LR_S3=1e-6 SFT_BS=16 SFT_GA=1 \
|
| 104 |
+
SFT_OVERSAMPLE=5 SFT_TGT_MIN=2 \
|
| 105 |
+
S3_SFT_MAX_STEPS=2000 \
|
| 106 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 107 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 108 |
+
S3_GRPO_MAX_STEPS=1500
|
| 109 |
+
|
| 110 |
+
# v6_h (GPU 7): same as v6_a but with even more steps + KL anchor for stability.
|
| 111 |
+
# The latent best (s3_grpo_baseline) ran with beta=0.0; we know KL>0 hurts long term.
|
| 112 |
+
# But here we want to see whether the new shape rewards survive more steps without
|
| 113 |
+
# regression. Use a small beta (0.01) for gentle anchoring.
|
| 114 |
+
launch 7 v6_h_grpo_v_card_long \
|
| 115 |
+
START_PHASE=s3_grpo S3_GRPO_INIT="$PIPE_V_S3GRPO_BEST" \
|
| 116 |
+
GRPO_LR=2e-6 GRPO_BS=32 GRPO_GA=1 GRPO_NG=8 \
|
| 117 |
+
GRPO_BETA=0.01 \
|
| 118 |
+
PENALTY_MISSING=0.75 EXACT_MATCH_BONUS=2.0 CARD_MISMATCH_PEN=1.0 \
|
| 119 |
+
S3_GRPO_MAX_STEPS=3000
|
| 120 |
+
|
| 121 |
+
echo
|
| 122 |
+
echo "=== launched ==="
|
| 123 |
+
tail -8 "$SWEEP_ROOT/PIDS.txt"
|
_runs/launch_latent_reproduction_overnight.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Overnight reproduction of the latent recurrent-hidden 3-stage champion.
|
| 3 |
+
# Mirrors the recipe that produced solve=0.60 (100p) / 0.675 (40p) on 2026-05-22.
|
| 4 |
+
#
|
| 5 |
+
# Single distributed job on GPUs 0-3 (4 of the 8 H100s; see GPU_IDS below). The original 8-GPU run took ~6-7 hrs end-to-end.
|
| 6 |
+
#
|
| 7 |
+
# Stages: S1 SFT (cot=1) -> S1 GRPO (cot=1)
|
| 8 |
+
# -> S2 SFT (cot=2) -> S2 GRPO (cot=2)
|
| 9 |
+
# -> S3 SFT (cot=3) -> S3 GRPO (cot=3)
|
| 10 |
+
#
|
| 11 |
+
# Hyperparameters (defaults, faithful to original):
|
| 12 |
+
# model Qwen/Qwen2.5-0.5B-Instruct
|
| 13 |
+
# num_cot_tokens 1->2->3 across stages
|
| 14 |
+
# latent_mode recurrent_hidden
|
| 15 |
+
# bs=8/device, grad_accum=2, gradient checkpointing ON
|
| 16 |
+
# stage1_sft_lr=2e-4, stage2/3_sft_lr=5e-5, grpo_lr=1e-6 (hardcoded)
|
| 17 |
+
# value_target=0.98 (precision AND recall)
|
| 18 |
+
# train_puzzles=10000 eval_puzzles=100
|
| 19 |
+
# num_generations=4 max_completion_length=24
|
| 20 |
+
|
| 21 |
+
set -euo pipefail
|
| 22 |
+
|
| 23 |
+
ROOT=/home/ubuntu/curriculum_cot
|
| 24 |
+
SCRIPT="${ROOT}/hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh"
|
| 25 |
+
|
| 26 |
+
RUN_TAG="latent_reproduction_overnight_$(date +%Y%m%d_%H%M%S)"
|
| 27 |
+
OUTPUT_ROOT="${ROOT}/_runs/${RUN_TAG}"
|
| 28 |
+
LOG="${OUTPUT_ROOT}/PIPELINE.log"
|
| 29 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 30 |
+
|
| 31 |
+
# Point the HF caches at the repo-local .hf_cache directory so our pre-downloaded Qwen 0.5B is found
|
| 32 |
+
export HF_HOME="${ROOT}/.hf_cache"
|
| 33 |
+
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
|
| 34 |
+
export HF_HUB_OFFLINE=0
|
| 35 |
+
export TOKENIZERS_PARALLELISM=false
|
| 36 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 37 |
+
# wandb is not authenticated on this machine — keep offline so jobs don't hang
|
| 38 |
+
export WANDB_MODE=offline
|
| 39 |
+
# avoid the upstream script trying to pull from a wandb entity we don't own
|
| 40 |
+
export WANDB_ENTITY="local"
|
| 41 |
+
|
| 42 |
+
# Use our preinstalled pytorch venv
|
| 43 |
+
export PYTHON_BIN=/opt/pytorch/bin/python
|
| 44 |
+
|
| 45 |
+
# 4-GPU distributed run with doubled grad accum to preserve the original
|
| 46 |
+
# effective batch size (8*2*8 = 128 -> 8*4*4 = 128). Takes ~2x wall-clock
|
| 47 |
+
# but is faithful to the original convergence dynamics.
|
| 48 |
+
export GPU_IDS=0,1,2,3
|
| 49 |
+
export NUM_PROCESSES=4
|
| 50 |
+
export SFT_GRAD_ACCUM=4
|
| 51 |
+
export GRPO_GRAD_ACCUM=4
|
| 52 |
+
|
| 53 |
+
# Match original
|
| 54 |
+
export MODEL_NAME="Qwen/Qwen2.5-0.5B-Instruct"
|
| 55 |
+
export VALUE_TARGET=0.98
|
| 56 |
+
export SFT_VALUE_TARGET=0.95
|
| 57 |
+
export GRPO_VALUE_TARGET=0.98
|
| 58 |
+
export TRAIN_PUZZLES=10000
|
| 59 |
+
export EVAL_PUZZLES=100
|
| 60 |
+
export MIN_STEPS_BEFORE_STOP=50
|
| 61 |
+
|
| 62 |
+
# Cap per-phase wallclock to keep us safely under one overnight session.
|
| 63 |
+
# The original took ~6-7 hours; we cap each phase at 75 min to let all 6 phases
|
| 64 |
+
# finish within ~7.5 hrs even if one phase slow-runs.
|
| 65 |
+
export PHASE_WALL_CLOCK_SECONDS=4500
|
| 66 |
+
|
| 67 |
+
# Hard step caps (in addition to early stop on prec+recall)
|
| 68 |
+
export SFT_MAX_STEPS=4000
|
| 69 |
+
export GRPO_MAX_STEPS=2000
|
| 70 |
+
|
| 71 |
+
export RUN_TAG
|
| 72 |
+
export OUTPUT_ROOT
|
| 73 |
+
export CHECKPOINT_ROOT="${OUTPUT_ROOT}"
|
| 74 |
+
|
| 75 |
+
printf '[launch_latent_reproduction] %s\n' "$(date -Is)" | tee -a "${LOG}"
|
| 76 |
+
printf ' RUN_TAG=%s\n' "${RUN_TAG}" | tee -a "${LOG}"
|
| 77 |
+
printf ' OUTPUT_ROOT=%s\n' "${OUTPUT_ROOT}" | tee -a "${LOG}"
|
| 78 |
+
printf ' GPUs=%s nproc=%s model=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" "${MODEL_NAME}" | tee -a "${LOG}"
|
| 79 |
+
printf ' VALUE_TARGET=%s SFT_VALUE_TARGET=%s GRPO_VALUE_TARGET=%s\n' "${VALUE_TARGET}" "${SFT_VALUE_TARGET}" "${GRPO_VALUE_TARGET}" | tee -a "${LOG}"
|
| 80 |
+
printf ' PHASE_WALL_CLOCK=%ss SFT_MAX_STEPS=%s GRPO_MAX_STEPS=%s\n' "${PHASE_WALL_CLOCK_SECONDS}" "${SFT_MAX_STEPS}" "${GRPO_MAX_STEPS}" | tee -a "${LOG}"
|
| 81 |
+
|
| 82 |
+
bash "${SCRIPT}" 2>&1 | tee -a "${LOG}"
|
_runs/launch_simple_baseline.sh
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Strawman baseline for the rebuttal: vanilla LoRA, no curriculum, no thought
|
| 3 |
+
# tokens, single-shot whole-puzzle prediction. SFT followed by GRPO.
|
| 4 |
+
#
|
| 5 |
+
# Same model (Qwen2.5-1.5B-Instruct), same LoRA (r=32, α=64, dropout=0.05),
|
| 6 |
+
# same JSONL data file, same Qwen chat template wrapping as the cell-policy
|
| 7 |
+
# experiments. The ONLY differences from the cell-policy baseline are:
|
| 8 |
+
# - no per-cell expansion (one example per puzzle)
|
| 9 |
+
# - no stage_i / curriculum
|
| 10 |
+
# - no multi_value_oversample, no exact_match_bonus / cardinality penalties
|
| 11 |
+
# - reward = number of correct values out of 20 + whole-solve bonus
|
| 12 |
+
set -euo pipefail
|
| 13 |
+
|
| 14 |
+
ROOT=/home/ubuntu/curriculum_cot
|
| 15 |
+
SCRIPT=${ROOT}/_runs/simple_baseline_sudoku_train.py
|
| 16 |
+
PYTHON_BIN=/opt/pytorch/bin/python
|
| 17 |
+
|
| 18 |
+
TRAIN_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl
|
| 19 |
+
EVAL_JSONL=${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl
|
| 20 |
+
|
| 21 |
+
SWEEP_ROOT=${ROOT}/_runs/strawman_baseline_$(date +%Y%m%d_%H%M%S)
|
| 22 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 23 |
+
echo "${SWEEP_ROOT}" > "${ROOT}/_runs/current_strawman_sweep_dir"
|
| 24 |
+
echo "SWEEP_ROOT=${SWEEP_ROOT}"
|
| 25 |
+
|
| 26 |
+
export TOKENIZERS_PARALLELISM=false
|
| 27 |
+
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 28 |
+
export HF_HOME="${ROOT}/.hf_cache"
|
| 29 |
+
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"
|
| 30 |
+
export WANDB_MODE=offline
|
| 31 |
+
|
| 32 |
+
run_pipeline() {
  # Run SFT followed by GRPO for one strawman variant, in the background, on a
  # single GPU.
  #   $1 gpu        GPU index exported as CUDA_VISIBLE_DEVICES
  #   $2 tag        variant name; output goes to ${SWEEP_ROOT}/<tag>
  #   $3 sft_lr     SFT learning rate
  #   $4 grpo_lr    GRPO learning rate
  #   $5 sft_max    SFT --max_steps
  #   $6 grpo_max   GRPO --max_steps
  # Relies on SWEEP_ROOT, PYTHON_BIN, SCRIPT, TRAIN_JSONL and EVAL_JSONL being
  # set by the surrounding script.
  local gpu="$1"
  local tag="$2"
  local sft_lr="$3"
  local grpo_lr="$4"
  local sft_max="$5"
  local grpo_max="$6"

  local run_dir=${SWEEP_ROOT}/${tag}
  local log_file=${run_dir}/pipeline.log
  mkdir -p "${run_dir}"
  : > "${log_file}"   # start from an empty log

  # Arguments shared verbatim by both phases; kept in the same argv positions
  # as when they were spelled out inline.
  local -a data_args=(
    --train_jsonl "${TRAIN_JSONL}"
    --eval_jsonl "${EVAL_JSONL}"
  )
  local -a tail_args=(
    --logging_steps 25
    --save_steps 200
    --eval_rows 100
    --max_completion_length 96
    --max_prompt_length 1024
    --lora_r 32 --lora_alpha 64 --lora_dropout 0.05
    --seed 0
  )

  # Both phases run sequentially inside one backgrounded subshell; GRPO starts
  # from the adapter that SFT writes to <run_dir>/sft/final.
  (
    export CUDA_VISIBLE_DEVICES="${gpu}"

    echo "[$(date +%H:%M:%S)] === ${tag} on GPU ${gpu}: SFT lr=${sft_lr} max_steps=${sft_max} ===" >> "${log_file}"
    "${PYTHON_BIN}" -u "${SCRIPT}" \
      --phase sft \
      "${data_args[@]}" \
      --output_dir "${run_dir}/sft" \
      --learning_rate "${sft_lr}" \
      --max_steps "${sft_max}" \
      --per_device_train_batch_size 8 \
      --gradient_accumulation_steps 2 \
      --num_epochs 8 \
      "${tail_args[@]}" \
      >> "${log_file}" 2>&1

    echo "[$(date +%H:%M:%S)] === ${tag} on GPU ${gpu}: GRPO lr=${grpo_lr} max_steps=${grpo_max} ===" >> "${log_file}"
    "${PYTHON_BIN}" -u "${SCRIPT}" \
      --phase grpo \
      --init_adapter_dir "${run_dir}/sft/final" \
      "${data_args[@]}" \
      --output_dir "${run_dir}/grpo" \
      --learning_rate "${grpo_lr}" \
      --max_steps "${grpo_max}" \
      --per_device_train_batch_size 4 \
      --gradient_accumulation_steps 2 \
      --num_generations 8 \
      --beta 0.0 \
      --temperature 1.0 \
      --num_epochs 50 \
      "${tail_args[@]}" \
      >> "${log_file}" 2>&1

    echo "[$(date +%H:%M:%S)] === ${tag} DONE ===" >> "${log_file}"
  ) >/dev/null 2>&1 &
  local job_pid=$!

  echo "$job_pid $gpu $tag" >> "${SWEEP_ROOT}/PIDS.txt"
  disown $job_pid 2>/dev/null || true
  printf 'GPU %s -> %s pid=%s log=%s\n' "$gpu" "$tag" "$job_pid" "$log_file"
}
|
| 90 |
+
|
| 91 |
+
# 2 variants on GPUs 0,1: explore SFT LR (5e-5 and 1e-4) — same GRPO LR (5e-6).
|
| 92 |
+
run_pipeline 0 strawman_a_sft5e5_grpo5e6 5e-5 5e-6 2000 1500
|
| 93 |
+
run_pipeline 1 strawman_b_sft1e4_grpo5e6 1e-4 5e-6 2000 1500
|
| 94 |
+
|
| 95 |
+
echo
|
| 96 |
+
echo "=== launched ==="
|
| 97 |
+
cat "${SWEEP_ROOT}/PIDS.txt"
|
_runs/launch_strawman_cellpolicy.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Launch two strawman variants (single-stage cell-policy at stage_i=3, no
|
| 3 |
+
# curriculum, no thought tokens) on GPUs 0 and 1.
|
| 4 |
+
set -euo pipefail
|
| 5 |
+
ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
|
| 6 |
+
TS="$(date +%Y%m%d_%H%M%S)"
|
| 7 |
+
SWEEP_ROOT="${ROOT}/_runs/strawman_cellpolicy_${TS}"
|
| 8 |
+
mkdir -p "${SWEEP_ROOT}"
|
| 9 |
+
PIPE="${ROOT}/_runs/strawman_cellpolicy_pipeline.sh"
|
| 10 |
+
chmod +x "${PIPE}"
|
| 11 |
+
|
| 12 |
+
launch() {
  # Usage: launch <variant> <gpu> <KEY=VALUE>...
  # Starts the strawman cell-policy pipeline for one variant in the background,
  # passing the KEY=VALUE pairs through `env`, and appends "<variant>=<pid>" to
  # ${SWEEP_ROOT}/PIDS.txt.
  local variant="$1"
  local gpu="$2"
  shift 2

  local run_dir="${SWEEP_ROOT}/${variant}"
  mkdir -p "${run_dir}"
  echo "[launch] ${variant} on GPU ${gpu} out=${run_dir}"

  # stdout and stderr of the pipeline both land in console.log.
  nohup env VARIANT="${variant}" GPU="${gpu}" OUTPUT_ROOT="${run_dir}" "$@" \
    bash "${PIPE}" > "${run_dir}/console.log" 2>&1 &
  local job_pid=$!

  disown "${job_pid}" || true
  echo "${variant}=${job_pid}" >> "${SWEEP_ROOT}/PIDS.txt"
}
|
| 25 |
+
|
| 26 |
+
launch strawman_a_lr2e5 0 \
|
| 27 |
+
SFT_LR=2e-5 GRPO_LR=5e-6 SFT_MAX_STEPS=3000 GRPO_MAX_STEPS=1500 \
|
| 28 |
+
PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=1.0 CARD_MISMATCH_PEN=1.5 \
|
| 29 |
+
SFT_OVERSAMPLE=3
|
| 30 |
+
|
| 31 |
+
launch strawman_b_lr5e5 1 \
|
| 32 |
+
SFT_LR=5e-5 GRPO_LR=5e-6 SFT_MAX_STEPS=4000 GRPO_MAX_STEPS=1500 \
|
| 33 |
+
PENALTY_MISSING=1.0 EXACT_MATCH_BONUS=1.0 CARD_MISMATCH_PEN=1.5 \
|
| 34 |
+
SFT_OVERSAMPLE=3
|
| 35 |
+
|
| 36 |
+
echo "[launch] sweep root: ${SWEEP_ROOT}"
|
| 37 |
+
echo "[launch] PIDs:"
|
| 38 |
+
cat "${SWEEP_ROOT}/PIDS.txt"
|
_runs/simple_baseline_sudoku_train.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Strawman baseline for the rebuttal.
|
| 3 |
+
|
| 4 |
+
Vanilla Qwen2.5-1.5B-Instruct + LoRA on top of the *existing* JSONL data
|
| 5 |
+
(`data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`).
|
| 6 |
+
|
| 7 |
+
Compared to the cell-policy / latent recipes, this strawman intentionally
|
| 8 |
+
removes everything that helped:
|
| 9 |
+
|
| 10 |
+
- NO curriculum (single stage; we don't even read `stage_i`).
|
| 11 |
+
- NO chain-of-thought / latent thought tokens.
|
| 12 |
+
- NO per-cell expansion (one example == one whole puzzle).
|
| 13 |
+
- NO multi-value oversampling, no special reward shaping (just matches/N).
|
| 14 |
+
|
| 15 |
+
It uses the *same* model, *same* LoRA config, *same* tokenizer + chat
|
| 16 |
+
template wrapping that every cell-policy experiment used, so any solve
|
| 17 |
+
gap vs the cell-policy / latent runs is purely due to task framing,
|
| 18 |
+
not data, prompt, model, or PEFT differences.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
python simple_baseline_sudoku_train.py --phase sft --output_dir <out>/sft --learning_rate 5e-5
|
| 22 |
+
python simple_baseline_sudoku_train.py --phase grpo --init_adapter_dir <out>/sft/final \
|
| 23 |
+
--output_dir <out>/grpo --learning_rate 5e-6
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import math
|
| 31 |
+
import os
|
| 32 |
+
import re
|
| 33 |
+
import sys
|
| 34 |
+
import time
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
from datasets import Dataset
|
| 40 |
+
from peft import LoraConfig, PeftModel, get_peft_model
|
| 41 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
|
| 42 |
+
|
| 43 |
+
# Reuse existing helpers (these are the canonical ones used by every cell-policy run).
|
| 44 |
+
ROOT = Path(__file__).resolve().parent.parent
|
| 45 |
+
if str(ROOT) not in sys.path:
|
| 46 |
+
sys.path.insert(0, str(ROOT))
|
| 47 |
+
|
| 48 |
+
from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore
|
| 49 |
+
load_jsonl_rows,
|
| 50 |
+
pick_dtype,
|
| 51 |
+
)
|
| 52 |
+
from multi_output_cell_policy.rewards import score_prediction_text # type: ignore
|
| 53 |
+
from multi_output_cell_policy.shared_multi_output_policy import ( # type: ignore
|
| 54 |
+
make_solved_grid_from_row,
|
| 55 |
+
stage_i_consistent_values,
|
| 56 |
+
)
|
| 57 |
+
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row # type: ignore
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---- Strawman task definition -----------------------------------------------
|
| 61 |
+
# This is the ONLY new piece relative to the cell-policy experiments. The
|
| 62 |
+
# system prompt asks the model to emit the missing values for ALL empty cells
|
| 63 |
+
# in one shot, in the row-major order that the existing JSONL `completion`
|
| 64 |
+
# field already uses. The user message is the raw `prompt` field from the
|
| 65 |
+
# JSONL (puzzle as (row,col,value) tuples), which is byte-identical to what
|
| 66 |
+
# `prompt_builder.py` consumes in cell-policy runs.
|
| 67 |
+
|
| 68 |
+
SYSTEM_PROMPT_STRAWMAN = (
    "You are a Sudoku solver.\n"
    "You will be given a 9x9 Sudoku grid encoded as (row,col,value) tuples in "
    "row-major order, where value 0 marks an empty cell.\n"
    "Predict the missing values for ALL empty cells in row-major order.\n"
    "Return ONLY a JSON list of integers like [v1,v2,...,vK], where K is the "
    "number of empty cells (typically 20). Each value must be an integer in "
    "[1,9].\n"
    "Do not include any explanation, markdown, or text outside the JSON list."
)


def build_chat_prompt(tokenizer: Any, raw_prompt: str) -> str:
    """Wrap a raw puzzle prompt in the strawman system prompt.

    When the tokenizer exposes a non-empty ``chat_template``, the system and
    user messages are rendered through ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``. Otherwise a plain
    ``<system>\\n\\n<prompt>\\n`` string is returned.
    """
    system_text = SYSTEM_PROMPT_STRAWMAN.strip()

    # No usable chat template: fall back to simple concatenation.
    if not getattr(tokenizer, "chat_template", None):
        return system_text + "\n\n" + raw_prompt + "\n"

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": raw_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ---- Reward -----------------------------------------------------------------
|
| 95 |
+
|
| 96 |
+
LIST_RE = re.compile(r"\[[^\[\]]*\]")


def parse_int_list(text: str) -> Optional[List[int]]:
    """Extract a JSON list of Sudoku digits from a model emission.

    Two candidates are tried in order: the whole stripped text, then the first
    bracket-delimited span without nested brackets found by ``LIST_RE``. A
    candidate is accepted only if it decodes to a JSON list whose every element
    is a non-bool ``int`` in ``[1, 9]``.

    Returns the accepted list, or ``None`` if the text is empty or neither
    candidate qualifies.
    """
    stripped = str(text).strip()
    if not stripped:
        return None

    first_span = LIST_RE.search(stripped)
    attempts = [stripped] if first_span is None else [stripped, first_span.group(0)]

    for attempt in attempts:
        try:
            decoded = json.loads(attempt)
        except Exception:
            continue
        if not isinstance(decoded, list):
            continue
        # bool is a subclass of int, so it has to be rejected explicitly.
        digits_only = all(
            isinstance(item, int) and not isinstance(item, bool) and 1 <= item <= 9
            for item in decoded
        )
        if digits_only:
            return [int(item) for item in decoded]
    return None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def whole_puzzle_reward(
    *,
    pred_list: Optional[List[int]],
    target_list: List[int],
    parse_penalty: float = 4.0,
    length_mismatch_penalty: float = 0.5,
    full_solve_bonus: float = 5.0,
) -> float:
    """Score a whole-puzzle prediction against its target list.

    Args:
        pred_list: Parsed prediction, or ``None`` when parsing failed.
        target_list: Ground-truth values, compared position by position.
        parse_penalty: Magnitude of the (negative) reward for an unparsable
            prediction.
        length_mismatch_penalty: Deducted once per element of length
            difference between prediction and target.
        full_solve_bonus: Added when the prediction has the target's length
            and every position matches.

    Returns:
        ``-parse_penalty`` if ``pred_list`` is ``None``; otherwise the number
        of positional matches, adjusted by the length penalty or solve bonus.
    """
    if pred_list is None:
        return -float(parse_penalty)

    expected_len = len(target_list)
    # zip stops at the shorter list, so only overlapping positions are compared.
    hits = sum(
        1 for guess, truth in zip(pred_list, target_list) if int(guess) == int(truth)
    )

    score = float(hits)
    length_gap = abs(len(pred_list) - expected_len)
    if length_gap:
        score -= float(length_mismatch_penalty) * length_gap
    elif hits == expected_len:
        score += float(full_solve_bonus)
    return score
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ---- Dataset construction ---------------------------------------------------
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def build_dataset(rows: List[Dict[str, Any]], tokenizer: Any) -> Dataset:
    """Turn raw JSONL rows into a ``Dataset`` of whole-puzzle examples.

    Each row must provide ``"prompt"`` and ``"completion"``. Rows whose
    completion does not parse via ``parse_int_list`` are dropped. The result
    has three string columns:

    - ``prompt``: the chat-formatted prompt from ``build_chat_prompt``.
    - ``completion``: the stripped completion text, unchanged otherwise.
    - ``target``: the parsed values re-serialized as compact JSON.
    """
    columns: Dict[str, List[str]] = {"prompt": [], "completion": [], "target": []}
    for record in rows:
        user_text = str(record["prompt"]).strip()
        gold_text = str(record["completion"]).strip()
        gold_values = parse_int_list(gold_text)
        if gold_values is None:
            # Unparsable target: the example cannot be scored, so skip it.
            continue
        columns["prompt"].append(build_chat_prompt(tokenizer, user_text))
        columns["completion"].append(gold_text)
        columns["target"].append(json.dumps(gold_values, separators=(",", ":")))
    return Dataset.from_dict(columns)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ---- Eval (deterministic, greedy, single-shot) ------------------------------
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    max_new_tokens: int = 96,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Apples-to-apples eval with the cell-policy framework.

    The strawman model emits the WHOLE puzzle (a JSON list of integers) in
    one greedy generation. We then split that list into per-cell SINGLETON
    predictions and score each cell with the same ``score_prediction_text``
    function the cell-policy / latent baselines use, against the i-consistent
    target set at ``stage_i`` (default 3 — matching the S3 eval used for the
    rebuttal v6 baseline and the latent champion).

    Reported metrics mirror ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``
    so numbers are directly comparable across all four 2x2 ablation cells.

    Args:
        model: Model to evaluate. It is switched to ``eval()`` mode and left
            there when the function returns.
        tokenizer: Tokenizer used for encoding the prompt and decoding output.
        eval_rows: Rows with at least ``"prompt"`` and ``"completion"``.
        device: Device the encoded prompt tensors are moved to.
        max_new_tokens: Generation budget per puzzle.
        print_n: Maximum number of debug lines printed (skip notices and
            per-puzzle summaries share this budget).
        stage_i: Stage passed to the consistency / scoring helpers.

    Returns:
        Dict of aggregate metrics. Per-cell rates are divided by the number
        of scored cells; ``puzzle_parse_fail_rate`` and ``solve_rate`` are
        divided by the number of rows whose target completion parsed. Rows
        skipped for missing metadata still count in that puzzle denominator.
    """
    model.eval()
    total_cells = 0
    parse_ok = 0.0
    canonical_ok = 0.0
    exact_set_match = 0.0
    includes_gt = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    cardinality_match_sum = 0.0
    n_solve = 0
    n_total_puzzles = 0
    n_parse_fail_puzzles = 0
    printed = 0
    for row in eval_rows:
        target_completion = parse_int_list(str(row["completion"]))
        if target_completion is None:
            continue
        n_total_puzzles += 1

        # Resolve the per-cell metadata BEFORE generating: a row without it is
        # skipped anyway, so generating first would waste a full decode.
        try:
            cells = build_cell_examples_from_row(row)
            solved = make_solved_grid_from_row(row)
        except Exception as e:
            if printed < print_n:
                print(f"[strawman eval debug] row skipped (no metadata): {e}", flush=True)
                printed += 1
            continue

        prompt = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
        enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        enc = {k: v.to(device) for k, v in enc.items()}
        out = model.generate(
            **enc,
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        # Decode only the newly generated tokens (drop the echoed prompt).
        gen = tokenizer.decode(
            out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True
        ).strip()
        pred_list = parse_int_list(gen)

        row_all_exact = True
        row_has_eval_cell = False
        for idx, ex in enumerate(cells):
            target_values = stage_i_consistent_values(
                ex.grid, target_cell=ex.target_cell, stage_i=int(stage_i)
            )
            row_has_eval_cell = True
            # The idx-th integer of the whole-puzzle output becomes a
            # singleton prediction for the idx-th cell; missing entries are
            # scored as an empty string.
            if pred_list is not None and idx < len(pred_list):
                pred_text = json.dumps({"values": [int(pred_list[idx])]})
            else:
                pred_text = ""
            info = score_prediction_text(
                text=pred_text,
                grid=ex.grid,
                solved=solved,
                target_cell=ex.target_cell,
                stage_i=int(stage_i),
                reward_good_value=1.0,
                penalty_bad_value=1.75,
                penalty_malformed=4.0,
                penalty_empty=0.5,
                penalty_singleton=1.5,
            )
            total_cells += 1
            parse_ok += float(info["parse_ok"])
            canonical_ok += float(info["strict_canonical"])
            exact_set_match += float(info["exact_set_match"])
            includes_gt += float(info["includes_ground_truth"])
            precision_sum += float(info["value_precision"])
            recall_sum += float(info["value_recall"])
            if int(info["num_predicted_values"]) == int(len(target_values)):
                cardinality_match_sum += 1.0
            if float(info["exact_set_match"]) < 0.5:
                row_all_exact = False
        # A puzzle counts as solved only if it had cells and every one matched.
        if row_has_eval_cell and row_all_exact:
            n_solve += 1
        if pred_list is None:
            n_parse_fail_puzzles += 1
        if printed < print_n:
            head_pred = pred_list if pred_list is not None else "PARSE_FAIL"
            print(
                f"[strawman eval debug] target={target_completion} pred={head_pred} "
                f"solve={int(row_all_exact and row_has_eval_cell)} gen={gen!r}",
                flush=True,
            )
            printed += 1
    return {
        "n_total_cells": float(total_cells),
        "n_total_puzzles": float(n_total_puzzles),
        "parse_rate": float(parse_ok / max(1, total_cells)),
        "strict_canonical_rate": float(canonical_ok / max(1, total_cells)),
        "exact_set_match_rate": float(exact_set_match / max(1, total_cells)),
        "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)),
        "value_precision": float(precision_sum / max(1, total_cells)),
        "value_recall": float(recall_sum / max(1, total_cells)),
        "cardinality_match_rate": float(cardinality_match_sum / max(1, total_cells)),
        "puzzle_parse_fail_rate": float(n_parse_fail_puzzles / max(1, n_total_puzzles)),
        "solve_rate": float(n_solve) / max(1, n_total_puzzles),
    }
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ---- Main -------------------------------------------------------------------
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command line for the strawman SFT / GRPO trainer.

    Options are declared in a single ordered table and registered in a loop;
    the flags, their order, types and defaults are exactly those the script
    has always accepted. ``--phase``, ``--train_jsonl``, ``--eval_jsonl`` and
    ``--output_dir`` are mandatory.
    """
    option_table = [
        ("--phase", dict(choices=["sft", "grpo"], required=True)),
        ("--model_name", dict(type=str, default="Qwen/Qwen2.5-1.5B-Instruct")),
        ("--train_jsonl", dict(type=str, required=True)),
        ("--eval_jsonl", dict(type=str, required=True)),
        ("--output_dir", dict(type=str, required=True)),
        ("--cache_dir", dict(type=str, default=str(ROOT / ".hf_cache"))),
        ("--init_adapter_dir", dict(type=str, default="")),
        ("--seed", dict(type=int, default=0)),
        # Data
        ("--limit_train_rows", dict(type=int, default=10000)),
        ("--eval_rows", dict(type=int, default=100)),
        # Train hyperparameters
        ("--per_device_train_batch_size", dict(type=int, default=8)),
        ("--gradient_accumulation_steps", dict(type=int, default=2)),
        ("--learning_rate", dict(type=float, default=5e-5)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--num_epochs", dict(type=float, default=8.0)),
        ("--max_steps", dict(type=int, default=2000)),
        ("--logging_steps", dict(type=int, default=25)),
        ("--save_steps", dict(type=int, default=200)),
        ("--eval_steps", dict(type=int, default=150)),
        ("--max_grad_norm", dict(type=float, default=1.0)),
        ("--max_completion_length", dict(type=int, default=96)),
        ("--max_prompt_length", dict(type=int, default=1024)),
        # LoRA
        ("--lora_r", dict(type=int, default=32)),
        ("--lora_alpha", dict(type=int, default=64)),
        ("--lora_dropout", dict(type=float, default=0.05)),
        ("--enable_gradient_checkpointing", dict(action="store_true")),
        # GRPO-only
        ("--num_generations", dict(type=int, default=8)),
        ("--beta", dict(type=float, default=0.0)),
        ("--temperature", dict(type=float, default=1.0)),
        ("--full_solve_bonus", dict(type=float, default=5.0)),
        ("--length_mismatch_penalty", dict(type=float, default=0.5)),
        ("--parse_penalty", dict(type=float, default=4.0)),
        # W&B
        ("--use_wandb", dict(action="store_true")),
        ("--wandb_project", dict(type=str, default="sudoku-strawman-baseline")),
        ("--wandb_run_name", dict(type=str, default="")),
        ("--wandb_mode", dict(type=str, default="offline")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load the tokenizer and base model, attach a LoRA adapter, move to ``device``.

    Reads from ``args``: ``model_name``, ``cache_dir``, ``init_adapter_dir``,
    ``lora_r``, ``lora_alpha``, ``lora_dropout`` and
    ``enable_gradient_checkpointing``.

    Returns:
        ``(model, tokenizer)`` — the PEFT-wrapped model on ``device`` and the
        tokenizer configured for left padding.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    # Guarantee a pad token: reuse EOS, or a literal fallback if EOS is unset.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
    # Force left padding (presumably so batched generation stays aligned with
    # the end of each prompt — confirm against the trainer's expectations).
    if tokenizer.padding_side != "left":
        tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )
    # A non-blank --init_adapter_dir resumes from an existing adapter (kept
    # trainable); otherwise a fresh LoRA adapter is created.
    if str(args.init_adapter_dir).strip():
        model = PeftModel.from_pretrained(model, args.init_adapter_dir, is_trainable=True)
    else:
        lora = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
        )
        model = get_peft_model(model, lora)

    if args.enable_gradient_checkpointing:
        # Each step is guarded with hasattr so a model lacking the hook is
        # simply left as-is instead of raising.
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if hasattr(model, "config"):
            # NOTE(review): this also disables the KV cache for the
            # generate() calls made during eval — confirm that is acceptable.
            model.config.use_cache = False
    model.to(device)
    return model, tokenizer
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def run_sft(args: argparse.Namespace) -> None:
    """Run the supervised fine-tuning phase end to end.

    Steps: seed, load train/eval rows, build the model and tokenizer, build
    the chat-templated ``{prompt, completion, target}`` dataset, evaluate,
    train with TRL's ``SFTTrainer``, save the final adapter to
    ``<output_dir>/final``, and evaluate again.

    Both evaluations generate with ``--max_completion_length`` as the token
    budget, so eval honours the same completion cap as the rest of the run
    (the default, 96, equals ``run_eval``'s own default).

    Args:
        args: Namespace produced by ``parse_args``.
    """
    from trl import SFTConfig, SFTTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)

    # Build dataset of {prompt, completion} where prompt is chat-templated.
    train_ds = build_dataset(train_rows, tokenizer)

    # Resolve the dtype once; it selects between bf16 and fp16 training.
    dtype = pick_dtype()
    cfg = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=4,
        eval_strategy="no",
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_length=int(args.max_prompt_length + args.max_completion_length + 8),
        completion_only_loss=True,
        seed=int(args.seed),
    )

    trainer = SFTTrainer(
        model=model,
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    # The trainer's built-in eval is disabled (eval_strategy="no"); the custom
    # cell-policy eval runs exactly twice instead: once before training and
    # once after the final step. There is no periodic eval in between.
    eval_budget = int(args.max_completion_length)
    print(
        "[strawman sft] BEFORE-train eval:",
        run_eval(model, tokenizer, eval_rows, device, max_new_tokens=eval_budget),
        flush=True,
    )

    t0 = time.time()
    trainer.train()
    print(f"[strawman sft] training time = {time.time() - t0:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman sft] saved final adapter to {final_dir}", flush=True)

    print(
        "[strawman sft] AFTER-train eval:",
        run_eval(model, tokenizer, eval_rows, device, max_new_tokens=eval_budget),
        flush=True,
    )
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def run_grpo(args: argparse.Namespace) -> None:
    """Run the GRPO phase end to end.

    Steps: seed, load train/eval rows, build the model and tokenizer, build
    the dataset, evaluate, train with TRL's ``GRPOTrainer`` using the
    whole-puzzle reward, save the final adapter to ``<output_dir>/final``,
    and evaluate again.

    Both evaluations generate with ``--max_completion_length`` as the token
    budget, i.e. the same cap GRPO itself samples under, instead of a fixed
    96 (which is also the flag's default, so default runs are unchanged).

    Args:
        args: Namespace produced by ``parse_args``.
    """
    from trl import GRPOConfig, GRPOTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)
    train_ds = build_dataset(train_rows, tokenizer)

    parse_penalty = float(args.parse_penalty)
    length_mismatch_penalty = float(args.length_mismatch_penalty)
    full_solve_bonus = float(args.full_solve_bonus)

    def reward_fn(completions, target, **kwargs):
        """Score each sampled completion against its row's ``target`` column."""
        rewards: List[float] = []
        for c, tgt in zip(completions, target):
            # ``target`` is stored as a JSON string by build_dataset; accept
            # an already-decoded sequence as well.
            tgt_list = json.loads(tgt) if isinstance(tgt, str) else list(tgt)
            pred = parse_int_list(str(c))
            rewards.append(
                whole_puzzle_reward(
                    pred_list=pred,
                    target_list=tgt_list,
                    parse_penalty=parse_penalty,
                    length_mismatch_penalty=length_mismatch_penalty,
                    full_solve_bonus=full_solve_bonus,
                )
            )
        return rewards

    # Resolve the dtype once; it selects between bf16 and fp16 training.
    dtype = pick_dtype()
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=6,
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_prompt_length=int(args.max_prompt_length),
        max_completion_length=int(args.max_completion_length),
        num_generations=int(args.num_generations),
        beta=float(args.beta),
        temperature=float(args.temperature),
        seed=int(args.seed),
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[reward_fn],
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    # The custom cell-policy eval runs once before and once after training.
    eval_budget = int(args.max_completion_length)
    print(
        "[strawman grpo] BEFORE-train eval:",
        run_eval(model, tokenizer, eval_rows, device, max_new_tokens=eval_budget),
        flush=True,
    )

    t0 = time.time()
    trainer.train()
    print(f"[strawman grpo] training time = {time.time() - t0:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman grpo] saved final adapter to {final_dir}", flush=True)

    print(
        "[strawman grpo] AFTER-train eval:",
        run_eval(model, tokenizer, eval_rows, device, max_new_tokens=eval_budget),
        flush=True,
    )
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def main() -> None:
    """Entry point: parse the CLI, configure W&B if requested, run one phase.

    With ``--use_wandb``, ``WANDB_MODE`` is only filled in when the
    environment does not already define it, whereas ``WANDB_PROJECT`` is
    always overwritten with ``--wandb_project``.
    """
    cli = parse_args()
    if cli.use_wandb:
        os.environ.setdefault("WANDB_MODE", str(cli.wandb_mode))
        os.environ["WANDB_PROJECT"] = cli.wandb_project
    # argparse restricts --phase to "sft" or "grpo", so anything that is not
    # "sft" is the GRPO phase.
    phase_runner = run_sft if cli.phase == "sft" else run_grpo
    phase_runner(cli)


if __name__ == "__main__":
    main()
|
_runs/status.sh
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# One-shot snapshot of the active sweep.
#
# Usage: status.sh [SWEEP_DIR]
#   SWEEP_DIR defaults to the most recently modified
#   /home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_* directory.
#
# Prints GPU utilisation, the liveness of every process listed in
# ${SWEEP}/PIDS.txt (one "pid gpu name" triple per line), and for every
# pipe_* variant directory the latest phase plus the tail of its eval logs.
SWEEP="${1:-$(ls -dt /home/ubuntu/curriculum_cot/_runs/baseline_1p5b_v4_* 2>/dev/null | head -1)}"
[[ -z "${SWEEP}" || ! -d "${SWEEP}" ]] && { echo "no sweep"; exit 1; }
echo "=== sweep: ${SWEEP} ==="
echo "=== nvidia-smi ==="
nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total,power.draw --format=csv,noheader
echo
echo "=== pids ==="
if [[ -f "${SWEEP}/PIDS.txt" ]]; then
  # "|| [[ -n $pid ]]" keeps a final line that lacks a trailing newline.
  while read -r pid gpu name || [[ -n "${pid}" ]]; do
    # Blank lines carry no pid; skip them instead of reporting them as DEAD.
    [[ -z "${pid}" ]] && continue
    if kill -0 "$pid" 2>/dev/null; then alive=ALIVE; else alive=DEAD; fi
    printf ' pid=%-6s gpu=%s %-30s %s\n' "$pid" "$gpu" "$name" "$alive"
  done < "${SWEEP}/PIDS.txt"
else
  # Report a missing file in the snapshot rather than as a shell error.
  echo " (no PIDS.txt in ${SWEEP})"
fi
echo
echo "=== per-variant phase + best/last eval ==="
for v in "${SWEEP}"/pipe_*; do
  # Without a match the glob stays literal; skip anything not a directory.
  [[ -d "$v" ]] || continue
  vn="$(basename "$v")"
  # The last phase directory that exists, in pipeline order, is current.
  current_phase="(starting)"
  for ph in s2_sft_extra s2_grpo s3_sft s3_grpo; do
    [[ -d "$v/$ph" ]] && current_phase="$ph"
  done
  printf '\n--- %s (phase=%s) ---\n' "$vn" "${current_phase}"
  # Pipeline log tail
  if [[ -f "$v/PIPELINE.log" ]]; then
    tail -3 "$v/PIPELINE.log" | sed 's/^/ PL: /'
  fi
  # Phase-specific evals
  for ph in s2_sft_extra s2_grpo s3_sft s3_grpo; do
    log="$v/$ph/train.log"
    [[ -f "$log" ]] || continue
    # SFT eval lines
    last_sft="$(grep -E "\[baseline sft eval\] " "$log" 2>/dev/null | tail -3)"
    last_grpo="$(grep -E "\[baseline grpo (custom )?eval" "$log" 2>/dev/null | tail -3)"
    last_train="$(grep -E "\[baseline (sft|grpo) (train|final)" "$log" 2>/dev/null | tail -1)"
    if [[ -n "$last_sft$last_grpo$last_train" ]]; then
      printf ' [%s]\n' "$ph"
      [[ -n "$last_train" ]] && echo "$last_train" | sed 's/^/ tr: /'
      [[ -n "$last_sft" ]] && echo "$last_sft" | sed 's/^/ ev: /'
      [[ -n "$last_grpo" ]] && echo "$last_grpo" | sed 's/^/ ev: /'
    fi
  done
done
|
_runs/strawman_cellpolicy_pipeline.sh
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Strawman = single-stage cell-policy at stage_i=3 from BASE (no curriculum,
# no thought tokens). Same per-cell prompt, same trainer scripts, same scoring
# function as the v6 baseline and the latent champion. The ONLY differences
# vs the v6 baseline are:
#   - No prior SFT/GRPO at stage_i=1 or stage_i=2 (start fresh from base Qwen).
#   - Single SFT phase + single GRPO phase, both at stage_i=3.
#   - No latent recurrent-hidden tokens (vanilla LoRA on base model).
# Required env vars: VARIANT, GPU.
# Optional: OUTPUT_ROOT (defaults to a timestamped directory under
# ${ROOT}/_runs); every other setting below is likewise an env override with
# a default.
set -euo pipefail

ROOT="${ROOT:-/home/ubuntu/curriculum_cot}"
PYTHON_BIN="${PYTHON_BIN:-/opt/pytorch/bin/python}"
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"

# Abort immediately with a message if a mandatory variable is unset or empty.
: "${VARIANT:?VARIANT required}"
: "${GPU:?GPU required}"

OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/_runs/strawman_cellpolicy_$(date +%Y%m%d_%H%M%S)/${VARIANT}}"
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"

# Use the same S3 hyperparameters as the v6 baseline so the only knob is
# "did we do the curriculum or not".
SFT_LR="${SFT_LR:-2e-5}"
SFT_BS="${SFT_BS:-16}"
SFT_GA="${SFT_GA:-2}"
SFT_MAX_STEPS="${SFT_MAX_STEPS:-3000}"

GRPO_LR="${GRPO_LR:-5e-6}"
GRPO_BETA="${GRPO_BETA:-0.0}"
GRPO_NG="${GRPO_NG:-8}"
GRPO_BS="${GRPO_BS:-16}"
GRPO_GA="${GRPO_GA:-2}"
GRPO_PROMPT="${GRPO_PROMPT:-768}"
GRPO_COMPL="${GRPO_COMPL:-24}"
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-1500}"

# v6-style reward shaping (same as the v6 sweep that hit solve=0.44).
REWARD_GOOD="${REWARD_GOOD:-1.25}"
PENALTY_BAD="${PENALTY_BAD:-1.0}"
PENALTY_MAL="${PENALTY_MAL:-4.0}"
PENALTY_EMPTY="${PENALTY_EMPTY:-0.5}"
PENALTY_SINGLETON="${PENALTY_SINGLETON:-1.5}"
PENALTY_MISSING="${PENALTY_MISSING:-1.0}"
EXACT_MATCH_BONUS="${EXACT_MATCH_BONUS:-1.0}"
CARD_MISMATCH_PEN="${CARD_MISMATCH_PEN:-1.5}"
SFT_OVERSAMPLE="${SFT_OVERSAMPLE:-3}"
SFT_TGT_MIN="${SFT_TGT_MIN:-0}"
SFT_TGT_MAX="${SFT_TGT_MAX:-0}"

# VALUE_TARGET feeds both the precision and the recall early-stop thresholds
# of each phase; PHASE_WALL_SECS is forwarded as --max_wall_clock_seconds
# (presumably 0 means "no limit" — confirm in the trainer scripts).
VALUE_TARGET="${VALUE_TARGET:-0.98}"
EVAL_ROWS="${EVAL_ROWS:-100}"
TRAIN_ROWS="${TRAIN_ROWS:-10000}"
USE_GC="${USE_GC:-1}" # GC=1 to allow bs 16 on a single 80G GPU
PHASE_WALL_SECS="${PHASE_WALL_SECS:-0}"

TRAIN_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl"
EVAL_JSONL="${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_eval.jsonl"

mkdir -p "${OUTPUT_ROOT}"
PIPELINE_LOG="${OUTPUT_ROOT}/PIPELINE.log"
# ts: current wall-clock time as HH:MM:SS.
ts() { date +'%H:%M:%S'; }
# log: timestamped message appended to PIPELINE.log and echoed on stderr, so
# it never pollutes stdout captured via $(...).
log() { printf '[%s] %s\n' "$(ts)" "$*" | tee -a "${PIPELINE_LOG}" >&2; }
|
| 65 |
+
|
| 66 |
+
# best_ckpt DIR
# Print the adapter directory to resume from and return 0, or return 1 when
# DIR holds nothing usable. DIR itself wins if it directly contains
# adapter_model.safetensors; otherwise the checkpoint-step-* / checkpoint-*
# entry that sorts last in version order is printed.
best_ckpt() {
  local run_dir="$1"
  local -a candidates

  if [[ -f "${run_dir}/adapter_model.safetensors" ]]; then
    printf '%s\n' "${run_dir}"
    return 0
  fi

  # nullglob makes unmatched patterns vanish instead of staying literal.
  shopt -s nullglob
  candidates=("${run_dir}"/checkpoint-step-* "${run_dir}"/checkpoint-*)
  shopt -u nullglob

  if (( ${#candidates[@]} == 0 )); then
    return 1
  fi
  printf '%s\n' "${candidates[@]}" | sort -V | tail -n 1
}
|
| 77 |
+
|
| 78 |
+
# Preflight: both dataset files must exist before any GPU time is spent.
if [[ ! -f "${TRAIN_JSONL}" || ! -f "${EVAL_JSONL}" ]]; then
  log "ERROR: missing dataset jsonls"; exit 1
fi

# Pin the run to the requested GPU; the trainers then address it as
# --gpu_id 0 because it is the only visible device.
export CUDA_VISIBLE_DEVICES="${GPU}"
export TOKENIZERS_PARALLELISM=false
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export HF_HOME="${ROOT}/.hf_cache"
export TRANSFORMERS_CACHE="${ROOT}/.hf_cache"

# Optional gradient-checkpointing flag, kept in an array so that it expands
# to zero words when disabled.
GC_FLAG=()
if [[ "${USE_GC}" == "1" ]]; then GC_FLAG=(--enable_gradient_checkpointing); fi

log "===== STRAWMAN ${VARIANT} on GPU ${GPU} ====="
log " SFT lr=${SFT_LR} max_steps=${SFT_MAX_STEPS} bs=${SFT_BS}x${SFT_GA} GC=${USE_GC}"
log " GRPO lr=${GRPO_LR} max_steps=${GRPO_MAX_STEPS} ng=${GRPO_NG} bs=${GRPO_BS}x${GRPO_GA}"
log " rewards good=${REWARD_GOOD} bad=${PENALTY_BAD} mal=${PENALTY_MAL} empty=${PENALTY_EMPTY} sng=${PENALTY_SINGLETON} miss=${PENALTY_MISSING} bonus=${EXACT_MATCH_BONUS} card=${CARD_MISMATCH_PEN}"
log " out=${OUTPUT_ROOT}"

# ----- Phase 1: SFT at stage_i=3 from BASE (no init adapter) -----
SFT_DIR="${OUTPUT_ROOT}/sft"
mkdir -p "${SFT_DIR}"
log "=== PHASE SFT (stage_i=3, init=BASE) ==="
# ${GC_FLAG[@]+...} guards the expansion: on bash < 4.4 a plain
# "${GC_FLAG[@]}" of an EMPTY array is an "unbound variable" error under
# `set -u`, which would abort the pipeline whenever USE_GC != 1.
"${PYTHON_BIN}" -u "${SFT_SCRIPT}" \
  --model_name "${MODEL_NAME}" \
  --train_jsonl "${TRAIN_JSONL}" \
  --eval_jsonl "${EVAL_JSONL}" \
  --output_dir "${SFT_DIR}" \
  --cache_dir "${ROOT}/.hf_cache" \
  --init_adapter_dir "" \
  --seed 0 \
  --gpu_id 0 \
  --stage_i 3 \
  --total_empties_hint 20 \
  --per_device_train_batch_size "${SFT_BS}" \
  --gradient_accumulation_steps "${SFT_GA}" \
  --num_epochs 256 \
  --learning_rate "${SFT_LR}" \
  --max_grad_norm 1.0 \
  --logging_steps 25 \
  --eval_steps 200 \
  --save_steps 200 \
  --eval_rows "${EVAL_ROWS}" \
  --max_completion_length 24 \
  --limit_train_rows "${TRAIN_ROWS}" \
  --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
  --eval_value_precision_stop "${VALUE_TARGET}" \
  --eval_value_recall_stop "${VALUE_TARGET}" \
  --eval_exact_set_match_stop 0 \
  --eval_solve_rate_stop 0 \
  --min_steps_before_stop 200 \
  --max_wall_clock_seconds "${PHASE_WALL_SECS}" \
  --max_steps "${SFT_MAX_STEPS}" \
  --multi_value_oversample_factor "${SFT_OVERSAMPLE}" \
  --train_target_size_min "${SFT_TGT_MIN}" \
  --train_target_size_max "${SFT_TGT_MAX}" \
  ${GC_FLAG[@]+"${GC_FLAG[@]}"} 2>&1 | tee "${SFT_DIR}/train.log"

# Without an SFT checkpoint there is nothing for GRPO to start from.
SFT_CKPT="$(best_ckpt "${SFT_DIR}")" || { log "ERROR: no SFT ckpt"; exit 1; }
log ">>> SFT ckpt: ${SFT_CKPT}"
|
| 138 |
+
|
| 139 |
+
# ----- Phase 2: GRPO at stage_i=3 from SFT output -----
GRPO_DIR="${OUTPUT_ROOT}/grpo"
mkdir -p "${GRPO_DIR}"
log "=== PHASE GRPO (stage_i=3, init=${SFT_CKPT}) ==="
# ${GC_FLAG[@]+...} guards the expansion: on bash < 4.4 a plain
# "${GC_FLAG[@]}" of an EMPTY array is an "unbound variable" error under
# `set -u`, which would abort the pipeline whenever USE_GC != 1.
"${PYTHON_BIN}" -u "${GRPO_SCRIPT}" \
  --model_name "${MODEL_NAME}" \
  --train_jsonl "${TRAIN_JSONL}" \
  --eval_jsonl "${EVAL_JSONL}" \
  --output_dir "${GRPO_DIR}" \
  --cache_dir "${ROOT}/.hf_cache" \
  --init_adapter_dir "${SFT_CKPT}" \
  --seed 0 \
  --gpu_id 0 \
  --stage_i 3 \
  --total_empties_hint 20 \
  --per_device_train_batch_size "${GRPO_BS}" \
  --gradient_accumulation_steps "${GRPO_GA}" \
  --num_train_epochs 100 \
  --learning_rate "${GRPO_LR}" \
  --logging_steps 10 \
  --save_steps 200 \
  --eval_steps 150 \
  --eval_rows "${EVAL_ROWS}" \
  --num_generations "${GRPO_NG}" \
  --max_prompt_length "${GRPO_PROMPT}" \
  --max_completion_length "${GRPO_COMPL}" \
  --beta "${GRPO_BETA}" \
  --limit_train_rows "${TRAIN_ROWS}" \
  --lora_r 32 --lora_alpha 64 --lora_dropout 0.05 \
  --reward_good_value "${REWARD_GOOD}" \
  --penalty_bad_value "${PENALTY_BAD}" \
  --penalty_malformed "${PENALTY_MAL}" \
  --penalty_empty "${PENALTY_EMPTY}" \
  --penalty_singleton "${PENALTY_SINGLETON}" \
  --penalty_missing "${PENALTY_MISSING}" \
  --exact_match_bonus "${EXACT_MATCH_BONUS}" \
  --cardinality_mismatch_penalty "${CARD_MISMATCH_PEN}" \
  --eval_value_precision_stop "${VALUE_TARGET}" \
  --eval_value_recall_stop "${VALUE_TARGET}" \
  --eval_solve_rate_stop 0 \
  --min_steps_before_stop 100 \
  --max_wall_clock_seconds "${PHASE_WALL_SECS}" \
  --max_steps "${GRPO_MAX_STEPS}" \
  ${GC_FLAG[@]+"${GC_FLAG[@]}"} 2>&1 | tee "${GRPO_DIR}/train.log"

# A missing GRPO checkpoint is only a warning: the SFT result still stands,
# so the pipeline exits successfully.
GRPO_CKPT="$(best_ckpt "${GRPO_DIR}")" || { log "WARN: no GRPO ckpt found"; exit 0; }
log ">>> GRPO ckpt: ${GRPO_CKPT}"
log "===== STRAWMAN ${VARIANT} done ====="
|