diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..ee7780b906bc33a1cfaa8a69fca67cfda8aa4037 --- /dev/null +++ b/.env.example @@ -0,0 +1,70 @@ +# Copy to .env for local runs. Do not put real secrets in this example file. + +# Hugging Face token, if needed for private/gated repos or higher rate limits. +# HF_TOKEN=hf_xxx + +# Weights & Biases configuration. +WANDB_MODE=online +WANDB_ENTITY=training-dynamics +WANDB_PROJECT= +WANDB_RUN_NAME= +WANDB__SERVICE_WAIT=300 + +# GPU/distributed launch defaults. +GPU_IDS=0,1,2,3,4,5,6,7 +NUM_PROCESSES=8 +NPROC_PER_JOB=2 +MASTER_PORT=29501 +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Model/cache/output defaults used by launch scripts. +MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct +CACHE_DIR=.hf_cache +RUN_TAG= +CHECKPOINT_ROOT= +OUTPUT_ROOT= + +# Warm-baseline all-latent stage pipeline defaults. +EMPTIES=20 +MODES_SPEC=recurrent_hidden +GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7 +TRAIN_PUZZLES=10000 +EVAL_PUZZLES=100 +SOLVE_TARGET=0.95 +VALUE_TARGET=0 +MIN_STEPS_BEFORE_STOP=50 +BASELINE_WARM_MAX_STEPS=1000 +LATENT_SFT_MAX_STEPS=1000 +LATENT_GRPO_MAX_STEPS=500 +SFT_NUM_EPOCHS=64 +GRPO_NUM_TRAIN_EPOCHS=50 + +# SFT/GRPO batch and LoRA defaults. +SFT_PER_DEVICE_BS=8 +SFT_GRAD_ACCUM=2 +BASELINE_PER_DEVICE_BS=16 +BASELINE_GRAD_ACCUM=2 +GRPO_PER_DEVICE_BS=4 +GRPO_GRAD_ACCUM=2 +LORA_R=32 +LORA_ALPHA=64 +LORA_DROPOUT=0.05 +GRPO_BETA=0.0 + +# Optional resume adapters. +STAGE1_BASELINE_ADAPTER_DIR= +STAGE1_LATENT_SFT_ADAPTER_DIR= +STAGE1_LATENT_GRPO_ADAPTER_DIR= +STAGE2_BASELINE_WARM_ADAPTER_DIR= +STAGE2_LATENT_SFT_ADAPTER_DIR= +STAGE2_LATENT_GRPO_ADAPTER_DIR= +STAGE3_BASELINE_WARM_ADAPTER_DIR= +STAGE3_LATENT_SFT_ADAPTER_DIR= + +# Optional debug knobs. +FIXED_SLOT_DEBUG_LIMIT=0 +FIXED_SLOT_DECODE_DEBUG_LIMIT=0 +LATENT_VOCAB_DEBUG_TOPK=1 +ATTN_DENSITY_DEBUG_LIMIT=0 +ATTN_DENSITY_THRESHOLD_MULT=1.0 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..622cc9cd1cbc743f69e8833710c805150b4cdb78 --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +__pycache__/ +*.pyc + +.env +.env.* +!.env.example + +_prepared_data/ +data/ +checkpoints/ +final_checkpoint/ +location_learner/ + +*.jsonl + +.wandb/ +wandb/ +**/.wandb/ +**/wandb/ +**/wandb_runtime/ + +.venv/ +.hf_cache/ +addition_runs/ +tmp_latent_debug/ +logs/ diff --git a/REBUTTAL_REPORT.md b/REBUTTAL_REPORT.md new file mode 100644 index 0000000000000000000000000000000000000000..9bc7723267de1059c8b191e653d4b13dcade141b --- /dev/null +++ b/REBUTTAL_REPORT.md @@ -0,0 +1,589 @@ +# Curriculum CoT for 9x9 Sudoku — Rebuttal/Paper-Section Material + +_Last updated: 2026-05-24_ + +This document is a comprehensive, paper-ready reference of (a) the data pipeline, +(b) the instruction-tuning prompt format, (c) the curriculum and reward design, +(d) the latent thought-token architecture, (e) the multi-stage SFT-then-GRPO +training recipe, and (f) the headline numerical results — so a rebuttal section +can be assembled directly from this document. + +--- + +## 1. Task + +We use the model as a **per-cell value policy** for 9×9 Sudoku. For a fixed +target empty cell, the model emits a JSON set of candidate digits that are +"i-consistent" with the current grid (definition in §4). We evaluate two +metrics: + +- **per-cell exact set match** (`exact_set_match`) — predicted set equals the + ground-truth i-consistent set; +- **whole-puzzle solve rate** (`solve`) — every empty cell on a 20-empty puzzle + produces an exact set match. + +Because solve = ∏ exact_set_match across the ~20 empty cells of a puzzle, +the two metrics are non-linearly coupled: + +$$ \text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}} $$ + +so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ — every percentage +point of per-cell exact maps to a much larger swing in solve. + +--- + +## 2. Data pipeline + +### 2.1 Puzzle generation + +Generated by `simple_9x9_curriculum/build_dataset.py`: + +- Start from a base Latin-square grid; randomly relabel digits, permute + rows and columns within bands, and transpose. +- Sample `empties=20` cell positions uniformly at random and erase them. +- Save 10 000 train + 1 000 eval puzzles (seed 0, seed 1). +- Output JSONL files + `data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`. + +A single record contains: + +```json +{ + "prompt": "", + "completion": "[7,3,8,2,6,9,4,5,...]", + "metadata": { + "grid_size": 9, "box_size": 3, "empties": 20, + "empty_locs_1based": [[1,4],[1,9],...], + "target_triples_1based": [[1,4,7],[1,9,3],...] + } +} +``` + +The 20 `target_triples` give the **solved** value at each of the 20 empty +positions, so per-cell training targets are always available. At training +time we expand each puzzle into 20 (puzzle, target_cell) examples. + +### 2.2 Cell-policy framing + +The model is never asked to solve a whole puzzle in one shot. Each example +is one (current_grid, target_cell) pair, and the supervised target is the +set of digits that are "i-consistent" with the current grid (see §4). This +turns Sudoku into a **classification-into-a-set** problem and lets us share +parameters across cells, stages, and puzzle sizes. + +### 2.3 Multi-value oversampling (data-side trick) + +Implemented in `multi_output_cell_policy/sft_multi_output_train.py` via +`tokenizer._multi_value_oversample_factor` and the CLI flags + +``` +--multi_value_oversample_factor INT (default 1) +--train_target_size_min INT (default 0) +--train_target_size_max INT (default 0) +``` + +Inside the dataset builder, examples whose target set has more than one +digit are repeated `multi_value_oversample_factor` times in the training +mix. This biases gradient steps toward exactly the cells the model gets +wrong (multi-value cells). Empirically, this is the single biggest data-side +lever — see §10. + +### 2.4 Where the bottleneck lives + +For 20-empty puzzles in stage 3, only ~25 % of empty cells have a +multi-value target set (the rest collapse to one i-consistent value). Yet +those multi-value cells are responsible for the entire solve-rate gap: +they are the cells where the model under-predicts (returns a singleton +when the target is a 2- or 3-element set), and a single failed cell kills +the whole-puzzle solve. The reward shaping in §6 and the oversample in 2.3 +both attack this single failure mode. + +--- + +## 3. Instruction format + +### 3.1 System prompt + +(verbatim from `multi_output_cell_policy/prompt_builder.py`) + +```text +You are a Sudoku value policy. +This setup uses puzzles with about 20 empty cells. +You will be given one target empty cell. +Return ONLY one JSON object of the form {"values":[...]}. +The JSON object must contain exactly one key named "values". +The "values" field must be a JSON array of unique integers in [1,9]. +You may return as many candidate values as you want, including one, several, +or many values. +Choose the number of returned values yourself based on which values seem +i-consistent. +The order of the values does not matter. +Do not output any explanation, markdown, punctuation outside JSON, or extra text. +Current stage objective: i={i} consistency. +``` + +### 3.2 User message + +```text +Sudoku grid (0 means empty): + +Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), … +Target cell to fill now: (R,C). +Turn: t/T. +Return only JSON with candidate values for this target cell: {"values":[...]} +``` + +We use the Qwen2.5-Instruct chat template (`tokenizer.apply_chat_template`, +`add_generation_prompt=True`) to wrap system + user into the actual prompt +ids. `max_prompt_length = 768`. + +### 3.3 Output format + +```json +{"values":[3,7]} +``` + +Strictly canonical JSON (single key `values`, sorted unique digit list, +no whitespace). Outputs are scored by `parse_values_json` +(`shared_multi_output_policy.py`); any deviation collapses the whole +prediction to `parse_ok=0` and a hard-coded malformed penalty. + +`max_completion_length = 24` tokens — enough to emit any 9-digit set. + +--- + +## 4. Curriculum: stage-i consistency + +The curriculum lives in `_stage_i_consistent_values_for_grid`: + +- **Stage 1 — i=1 (legal moves).** A value v is i=1 consistent at cell c + iff placing v at c violates no Sudoku constraint (row, column, 3×3 box). + This is just "legal candidates". + +- **Stage 2 — i=2.** v is i=2 consistent at c iff (a) it is i=1 consistent + AND (b) after placing v, every other empty cell in the grid still has at + least one i=1-consistent value (i.e. placing v does not immediately make + the puzzle unsolvable by 1-step propagation). + +- **Stage 3 — i=3.** Same recursion one more level deep: v is i=3 consistent + iff after placing v, every other empty cell still has at least one i=2 + consistent value. + +This is bounded look-ahead constraint propagation. Stage-3 sets are tighter +than stage-2 sets which are tighter than stage-1 sets. The curriculum +goal at deployment time is stage-3. + +In data, we use the same source records and just change `--stage_i`; the +target set is regenerated on the fly by `stage_i_consistent_values`. + +--- + +## 5. Latent thought-token architecture + +Base model: **Qwen/Qwen2.5-1.5B-Instruct** + LoRA (r=32, α=64, dropout=0.05) +on `q,k,v,o,gate,up,down`. The latent variant adds **k thought-token slots** +between the prompt and the next-token logits. + +Four modes are implemented (`latent_multi_output_cell_policy/`); the winning +mode for the final number is **`recurrent_hidden`**: + +> `build_recurrent_hidden_latent_hidden(model, ids, mask, k)` +> +> 1. Run the backbone once on the prompt. Keep `base_hidden = h[:,-1,:]`. +> 2. Set `latent_token = base_hidden`. +> 3. Repeat k times: append `latent_token` (as an embedding) to the running +> sequence, run the backbone again on the extended sequence, and replace +> `latent_token` with the new last hidden state. +> 4. After k recursions, `latent_hidden` is fed through the LM head to +> produce the next-token distribution. + +In equations, with E the input embedding lookup, f_θ the LoRA-decorated +backbone, U the LM head: + +$$ z_0 = f_\theta(E([x_1,\dots,x_T]))_{T} $$ +$$ z_{j+1} = f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1,\dots,z_j]\bigr)_{T+j+1},\ j=0,\dots,k-1 $$ +$$ p(\cdot \mid x_{1:T}) = \mathrm{softmax}(U z_k) $$ + +The model can therefore "iterate" k extra forward passes on the same prompt +before committing to a token, with the k extra hidden states carrying +intermediate computation. Setting k=0 recovers the vanilla baseline. + +The other three latent modes are alternatives that we ablated: +`fixed_slots` (concatenate k trainable seed embeddings — Option-2), +`latent_seeds` (similar to fixed_slots), and `residual` (project k extra +hidden states back onto the base hidden state via a learned residual). All +modes share the SFT and GRPO trainers; only the next-token logit function +changes. + +For the curriculum, we grow k stage by stage: + +| stage | num_cot_tokens | comment | +| --- | ---: | --- | +| 1 | 1 | one extra recursion as soon as the model has the surface form | +| 2 | 2 | two — needed for 1-step propagation reasoning | +| 3 | 3 | three — needed for 2-step propagation reasoning | + +--- + +## 6. The reward function + +Defined in `multi_output_cell_policy/rewards.py`. + +Given target set T, predicted set P (after JSON parse), let + +- `num_good = |P ∩ T|` +- `num_bad = |P \ T|` +- `num_missing = max(0, |T| − num_good)` +- `is_exact = (P ≠ ∅) ∧ (P = T)` +- `tri(n) = n(n+1)/2` (rewards larger correct sets superlinearly) + +Then + +$$ +r = \mathrm{tri}(\mathrm{num\_good}) \cdot R_g \;-\; \mathrm{num\_bad} \cdot P_b \\ +- \mathbb{1}[P=\varnothing]\, P_e \;-\; \mathbb{1}[|P|=1, |T|>1, i<2]\, P_s \\ +- \mathrm{num\_missing}\cdot P_m \;+\; \mathbb{1}[\text{is\_exact}]\, B_x \\ +- \mathbb{1}[|P|<|T|, |T|>1]\, P_c +$$ + +with parameters (this is the recipe that produced the 0.58/0.68 latent solve): + +| symbol | flag | value | role | +|---|---|---:|---| +| $R_g$ | `--reward_good_value` | 1.25 | per-correct-value reward (with triangular shape) | +| $P_b$ | `--penalty_bad_value` | 1.0 | per-extra-wrong-value penalty | +| $P_{\!\text{mal}}$ | `--penalty_malformed` | 4.0 | flat penalty if JSON parse fails | +| $P_e$ | `--penalty_empty` | 0.5 | flat penalty if predicted set is empty | +| $P_s$ | `--penalty_singleton` | 1.5 | only at stage<2: punishes singleton on multi-value targets | +| $P_m$ | `--penalty_missing` | **0.75** | per-missing-value (recall pressure) — **NEW** | +| $B_x$ | `--exact_match_bonus` | **2.0** | only when P = T — **NEW** | +| $P_c$ | `--cardinality_mismatch_penalty` | **1.0** | when |P| < |T| and |T|>1 — **NEW** | + +Parse failures short-circuit to `r = -P_mal` and zero per-cell metrics. + +### 6.1 Why those three new terms exist (the breakthrough) + +Diagnosis: at the v3/v4 plateau, eval reported + +``` +exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000 +``` + +across all checkpoints. Per-cell exact and precision/recall were all near +0.95 but the model **always predicted a single digit** (`avg_set_size=1.000`). +On a multi-value target $T=\{8,9\}$, predicting $\{8\}$ keeps precision=1.0, +recall=0.5 and yet `exact_set_match=0`. Solve = exact_set_match^N is +catastrophic in $N$ (=20), so even a small fraction of multi-value cells +killed it. + +Without any of the new terms the optimum of $r$ on a multi-value cell is +trivially "predict the singleton you are most confident about" — there is +no upside to enumerate the second value. The three new terms close exactly +that hole: + +- $P_m$ (`penalty_missing`) directly penalises recall; +- $B_x$ (`exact_match_bonus`) makes $P=T$ strictly dominate any singleton; +- $P_c$ (`cardinality_mismatch_penalty`) is a flat hammer whenever $|P|<|T|$. + +After these terms were added, GRPO on the latent variant moved solve from +~0.30 to ~0.58 (100-puzzle eval) over ~200 steps. The same fix is what we +ported back into the baseline pipeline this evening (see §10). + +--- + +## 7. Multi-stage warm-baseline pipeline (the recipe that worked) + +Master script: +`hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`. + +For each curriculum stage we run **three sub-phases in order**: + +``` +[stage i] + (1) baseline warm SFT (no latent tokens, k=0, vanilla LM) + (2) latent SFT (k = i, latent mode = recurrent_hidden) + (3) latent GRPO (k = i) +``` + +**The warm baseline phase (1) is the trick that makes the curriculum work.** +At every stage transition the data distribution changes (i increases → +target sets shrink) and a new latent slot appears. Doing a vanilla SFT on +the new distribution first lets the LM relearn the surface form on familiar +parameters; THEN the latent SFT adds the extra thought slot on top of an +already-good policy. When we tried to add a new latent slot directly on +top of the previous stage's GRPO checkpoint, training loss did NOT +decrease. + +Concrete LR schedule used for the champion run: + +| phase | init from | LR | k | +|---|---|---:|---:| +| S1 baseline SFT | base Qwen | 2e-4 | 0 | +| S1 latent SFT | S1 baseline | 2e-4 | 1 | +| S1 latent GRPO | S1 latent SFT | 1e-6 | 1 | +| S2 baseline warm SFT | S1 GRPO | 5e-5 | 0 | +| S2 latent SFT | S2 baseline | 5e-5 | 2 | +| S2 latent GRPO | S2 latent SFT | 1e-6 | 2 | +| S3 baseline warm SFT | S2 GRPO | 5e-5 | 0 | +| S3 latent SFT | S3 baseline | 5e-5 → 1e-5 (champion) | 3 | +| S3 latent GRPO | S3 latent SFT | 5e-6 (β=0) | 3 | + +Other shared knobs: + +``` +LoRA: r=32 α=64 dropout=0.05 on q,k,v,o,gate,up,down +SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128 +GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64 + num_generations=4 beta=0.0 max_prompt_length=1024 + max_completion_length=24 +multi_value_oversample_factor=5, exact_match_bonus=2.0, +penalty_missing=0.75, cardinality_mismatch_penalty=1.0 +``` + +--- + +## 8. GRPO settings that mattered + +- **β = 0.** The KL anchor was harmful in every sweep where we tried β>0. + `s3_grpo_kl04` (β=0.04) peaked at solve=0.625 (40p) at step 100 and + regressed to 0.525 by step 500. + +- **num_generations = 4.** With num_generations=2 we routinely saw + `reward_std = 0` (all sampled completions identical → no gradient). + Bumping to 4 fixed it. + +- **Low LR.** `lr=5e-6` was the steadiest. `lr=1e-5` peaked at step 200 + (solve 0.65) then collapsed back to 0.54 — classic mode collapse. + +- **Effective bs ≥ 64.** TRL's GRPOConfig requires + `eff_bs * grad_accum % num_generations == 0`; with 8 GPUs we hit this + trivially, but we caution single-GPU rerunners to set + `per_device_bs=4 grad_accum=2 num_generations=4`. + +- **`enable_input_require_grads()` on the wrapped backbone.** Required for + TRL 0.15.x + PEFT LoRA + gradient checkpointing — otherwise the loss + tensor produced by GRPOTrainer has `requires_grad=False` and `.backward()` + raises. Also `unwrapped.config.use_cache = False`. + +--- + +## 9. Final hyperparameters table — champion latent run + +| group | hyperparameter | value | +|---|---|---| +| Backbone | model | Qwen/Qwen2.5-1.5B-Instruct | +| Backbone | dtype | bf16 | +| Backbone | LoRA target modules | q,k,v,o,gate,up,down | +| Backbone | LoRA r / α / dropout | 32 / 64 / 0.05 | +| Latent | mode | recurrent_hidden | +| Latent | num_cot_tokens (S1/S2/S3) | 1 / 2 / 3 | +| Latent | max_latent_slots / seeds | 8 / 8 | +| Data | total empties | 20 | +| Data | train rows / eval rows | 10 000 / 100 | +| Data | multi_value_oversample_factor | 5 | +| Data | mixed_stage1_ratio (S1) | 1 | +| Data | mixed_stage2_ratio (S≥2) | 1 | +| SFT | per_device_bs / grad_accum | 8 / 2 | +| SFT | num_epochs (cap) | 64 | +| SFT | LR (S1 latent) | 2e-4 | +| SFT | LR (S2/S3 baseline warm + latent) | 5e-5 | +| SFT | LR (S3 latent champion `s3b_lr1e5_o5`) | 1e-5 | +| SFT | weight_decay | 0.0 | +| SFT | gradient checkpointing | on | +| GRPO | per_device_bs / grad_accum | 4 / 2 | +| GRPO | num_generations | 4 | +| GRPO | LR | 5e-6 (S3); 1e-6 (S1, S2) | +| GRPO | β (KL) | 0.0 | +| GRPO | max_prompt_length | 1024 | +| GRPO | max_completion_length | 24 | +| Reward | reward_good_value | 1.25 | +| Reward | penalty_bad_value | 1.0 | +| Reward | penalty_malformed | 4.0 | +| Reward | penalty_empty | 0.5 | +| Reward | penalty_singleton | 1.5 | +| Reward | penalty_missing | 0.75 | +| Reward | exact_match_bonus | 2.0 | +| Reward | cardinality_mismatch_penalty | 1.0 | +| Eval | early-stop on prec/recall | 0.98 | + +--- + +## 10. Headline results + +### 10.1 Latent (with thought tokens, recurrent_hidden) + +| eval | model / phase | step | exact | prec | recall | **solve** | +|---|---|---:|---:|---:|---:|---:| +| **100p (auth.)** | `s3_grpo_baseline` (S3 GRPO, β=0, lr=5e-6) | 200 | 0.9665 | 0.9673 | 0.9680 | **0.580 (58/100)** | +| 40p | `s3_grpo_sharp_rwd` (exact_b=4, card_pen=3) | 300 | — | — | — | **0.675 (27/40)** | +| 40p | `s3_grpo_lr1e5` | 200 | 0.978 | 0.978 | 0.979 | 0.650 | +| 40p | `s3b_lr1e5_o5` (S3 SFT champion) | 2400 | 0.974 | 0.974 | 0.975 | 0.600 | + +### 10.2 Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA) + +| sweep | best variant | best step | exact | **solve (100p)** | +|---|---|---:|---:|---:| +| v3 (single-GPU LR=2e-5, no oversample, no new reward terms) | `baseline_3stage_20260522` | — | 0.730 | **0.000** | +| v4 (LR sweep, multi-GPU, original reward) | `pipe_v_sft_extend` (S3 SFT extended) | 4000 | 0.948 | **0.400** | +| **v6 (this evening, ports latent reward + oversample)** | `v6_i_sft_v_oversample10` (oversample=10) | running | 0.952+ | **0.440 (best so far)** | + +The v6 sweep is still running — `v6_e/f/i` are in S3 SFT continuation, +GRPO follow-on phases queued. The v6_i variant has hit **solve=0.44** at +SFT eval (new baseline best, +0.04 over v4) and is still climbing. + +### 10.3 Stage-by-stage trajectory (latent, 40-puzzle eval) + +``` +S1 SFT : exact ≈ 0.85, solve ≈ 0.20 +S1 GRPO : exact ≈ 0.90, solve ≈ 0.20 +S2 SFT (no oversample) : exact ≈ 0.94, solve ≈ 0.20-0.25 <- the wall +S2 SFT + multi_value_oversample=5 : exact ≈ 0.96, solve ≈ 0.30-0.35 +S2 GRPO + new reward terms : exact ≈ 0.96, solve ≈ 0.35-0.40 +S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion +S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion +S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak +``` + +### 10.4 Latent vs baseline gap (head-to-head, same 100p eval, same prompts) + +| model | exact | prec | recall | **solve** | solved/100 | +|---|---:|---:|---:|---:|---:| +| Latent recurrent_hidden, S3 GRPO | 0.9665 | 0.9673 | 0.9680 | **0.580** | 58 | +| Vanilla baseline, v6_i (best at time of writing) | 0.952 | 0.952 | 0.952 | **0.440** | 44 | + +Gap on 100-puzzle solve: ≈ **+0.14 absolute / +32 % relative** for latent +over the strongest baseline we have. + +--- + +## 11. Why the latent works (interpretation hypotheses) + +These are the working hypotheses the experiments are consistent with; +none is fully proven and ablations are still WIP. + +1. **Constraint-propagation depth.** Stage-3 i-consistency is essentially + 2-ply lookahead. With $k=3$ recurrent hidden tokens the model gets + exactly three extra forward passes between prompt and output — one for + the legality check, one for 1-step propagation, one for the second + step of propagation. Empirically the gap to the no-thought-token + baseline appears at stages where multi-step propagation matters + (stage 2 onward; stage 1 numbers are essentially identical). + +2. **Multi-value cells require enumeration, which a singleton softmax can't + do in one forward pass.** A vanilla LM at 1.5B parameters predicts + essentially deterministically once temperature is low; for a target set + {8, 9} the LM picks one of the two and stops. The latent model can use + one of the recurrent hidden steps to "consider" each option without + committing yet, which is exactly the failure mode in the data + (`avg_set_size = 1.000` for the baseline, `≈ 1.05` for the latent S3 + model on the same eval). + +3. **Stable curriculum capacity growth.** Adding a new latent slot at every + stage gives the model a "fresh slate" of representational capacity at the + exact transition where the task gets harder. The warm-baseline SFT + between stages prevents the new slot from corrupting the previously + learned policy. Without warm baseline, training loss did not decrease + at all (we observed this directly when we tried to skip the warm + baseline yesterday). + +4. **GRPO without latent slots is starved of variance.** With max_completion + length 24 and the model essentially deterministic, GRPO's 4 sampled + completions per prompt collapse to a single answer — `reward_std = 0`, + no gradient. With latent recurrence + the new exact_match_bonus reward, + the model occasionally samples a 2-element set, gets a much higher + reward, and that prompt gets a real gradient signal. + +--- + +## 12. Reproducibility + +Code repository: `https://github.com/Avra98/curriculum_cot` +Latent checkpoints: `https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages` +Baseline checkpoints: `https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep` + +Key scripts: + +- Master orchestrator (latent, 9-phase warm-baseline pipeline): + `hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh` +- Vanilla baseline pipeline: + `_runs/baseline_1p5b_pipeline_v4.sh` (with v6 launchers + `_runs/launch_baseline_push_v6.sh`) +- SFT trainer (vanilla): `multi_output_cell_policy/sft_multi_output_train.py` +- GRPO trainer (vanilla): `multi_output_cell_policy/grpo_multi_output_train.py` +- SFT trainer (latent): `latent_multi_output_cell_policy/sft_latent_multi_output_train.py` +- GRPO trainer (latent): `latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py` +- Reward function: `multi_output_cell_policy/rewards.py` +- Prompt builder: `multi_output_cell_policy/prompt_builder.py` +- Stage-i consistency: `multi_output_cell_policy/shared_multi_output_policy.py` +- 100-puzzle evaluator: `analysis/eval_stage2_checkpoint.py` + +To reproduce the latent champion (1.5B, 9-phase, ~16 GPU·h on 8×H100 80GB): + +```bash +export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter +bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh +``` + +To reproduce the v6 baseline push (single-GPU per variant, ~6 GPU·h): + +```bash +bash _runs/launch_baseline_push_v6.sh +``` + +--- + +## Appendix A. The reward fix as a one-line patch + +The single most consequential code change in this whole project, as a +self-contained patch on `multi_output_cell_policy/rewards.py`: + +```python +# new args (default 0 preserves legacy behaviour) +penalty_missing: float = 0.0 +exact_match_bonus: float = 0.0 +cardinality_mismatch_penalty: float = 0.0 + +num_missing = max(0, len(target_set) - num_good) +is_exact = bool(predicted_values) and (set(predicted_values) == target_set) + +# ... base reward (triangular_number(num_good)*reward_good_value - num_bad*penalty_bad_value) + +if num_missing > 0: + reward -= num_missing * penalty_missing +if is_exact: + reward += exact_match_bonus +if len(predicted_values) < len(target_values) and len(target_values) > 1: + reward -= cardinality_mismatch_penalty +``` + +Defaults are zero so old runs are unaffected; the recipe sets +`(P_m, B_x, P_c) = (0.75, 2.0, 1.0)` for the vanilla recipe and +`(1.0, 4.0, 3.0)` for the "sharp_rwd" variant. + +## Appendix B. The warm-baseline trick as a sequence diagram + +``` +Stage 1 Stage 2 Stage 3 +───────── ───────── ───────── +[base Qwen] ↓ ↓ + ↓ ↓ ↓ +S1 baseline SFT → S2 baseline SFT → S3 baseline SFT +(no latent, k=0) (no latent, k=0) (no latent, k=0) + ↓ ↓ ↓ +S1 latent SFT → S2 latent SFT → S3 latent SFT +(k=1) (k=2) (k=3) + ↓ ↓ ↓ +S1 latent GRPO → S2 latent GRPO → S3 latent GRPO +(k=1, β=0, lr 1e-6) (k=2, β=0) (k=3, β=0, lr 5e-6) + ↓ ↓ ↓ + [final policy] +``` + +Every arrow is `init_adapter_dir = `. Each row is a +"slot in the curriculum"; the column adds reasoning capacity (k+=1) and +moves to a harder target distribution (i+=1). The diagonal across the +diagram is the actual training trajectory. + +--- + +_End of report._ diff --git a/REBUTTAL_REPORT.tex b/REBUTTAL_REPORT.tex new file mode 100644 index 0000000000000000000000000000000000000000..194ac7d6681b107d55ec25ef72e7b12a9ccaba77 --- /dev/null +++ b/REBUTTAL_REPORT.tex @@ -0,0 +1,711 @@ +\documentclass[11pt]{article} +\usepackage[a4paper,margin=1in]{geometry} +\usepackage[utf8]{inputenc} +\usepackage[T1]{fontenc} +\usepackage{lmodern} +\usepackage{microtype} +\usepackage{amsmath,amssymb} +\usepackage{booktabs} +\usepackage{array} +\usepackage{longtable} +\usepackage{tabularx} +\usepackage{xcolor} +\usepackage{listings} +\usepackage{hyperref} +\usepackage{enumitem} +\usepackage{parskip} + +\hypersetup{colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue} + +\lstset{ + basicstyle=\ttfamily\small, + breaklines=true, + columns=fullflexible, + keepspaces=true, + frame=single, + framerule=0.4pt, + xleftmargin=0.5em, + xrightmargin=0.5em, + showstringspaces=false, +} + +\newcommand{\code}[1]{\texttt{#1}} +\newcommand{\indic}{\mathbf{1}} + +\title{Curriculum CoT for $9{\times}9$ Sudoku\\[2pt] +\large Rebuttal / Paper-Section Material} +\author{} +\date{Last updated: 2026--05--24} + +\begin{document} +\maketitle + +\noindent +This document is a comprehensive, paper-ready reference of (a) the data +pipeline, (b) the instruction-tuning prompt format, (c) the curriculum +and reward design, (d) the latent thought-token architecture, (e) the +multi-stage SFT-then-GRPO training recipe, and (f) the headline numerical +results --- so a rebuttal section can be assembled directly from this +document. + +\bigskip +\hrule +\bigskip + +\section{Task} + +We use the model as a \textbf{per-cell value policy} for $9\times 9$ +Sudoku. For a fixed target empty cell, the model emits a JSON set of +candidate digits that are ``i-consistent'' with the current grid +(definition in \S 4). We evaluate two metrics: + +\begin{itemize}[leftmargin=*] +\item \textbf{per-cell exact set match} (\code{exact\_set\_match}) --- +predicted set equals the ground-truth i-consistent set; +\item \textbf{whole-puzzle solve rate} (\code{solve}) --- every empty +cell on a 20-empty puzzle produces an exact set match. +\end{itemize} + +Because $\text{solve} = \prod \text{exact\_set\_match}$ across the $\sim 20$ +empty cells of a puzzle, the two metrics are non-linearly coupled: +\[ +\text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}} +\] +so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ --- every +percentage point of per-cell exact maps to a much larger swing in solve. + +\section{Data pipeline} + +\subsection{Puzzle generation} + +Generated by \code{simple\_9x9\_curriculum/build\_dataset.py}: + +\begin{itemize}[leftmargin=*] +\item Start from a base Latin-square grid; randomly relabel digits, +permute rows and columns within bands, and transpose. +\item Sample \code{empties=20} cell positions uniformly at random and +erase them. +\item Save 10\,000 train + 1\,000 eval puzzles (seed 0, seed 1). +\item Output JSONL files \code{data/sudoku\_t3\_20empty\_value\_qwen\_text\_stage1\_\{train,eval\}.jsonl}. +\end{itemize} + +A single record contains: + +\begin{lstlisting} +{ + "prompt": "", + "completion": "[7,3,8,2,6,9,4,5,...]", + "metadata": { + "grid_size": 9, "box_size": 3, "empties": 20, + "empty_locs_1based": [[1,4],[1,9],...], + "target_triples_1based": [[1,4,7],[1,9,3],...] + } +} +\end{lstlisting} + +The 20 \code{target\_triples} give the \textbf{solved} value at each of +the 20 empty positions, so per-cell training targets are always +available. At training time we expand each puzzle into 20 (puzzle, +target\_cell) examples. + +\subsection{Cell-policy framing} + +The model is never asked to solve a whole puzzle in one shot. Each +example is one (current\_grid, target\_cell) pair, and the supervised +target is the set of digits that are ``i-consistent'' with the current +grid (see \S 4). This turns Sudoku into a +\textbf{classification-into-a-set} problem and lets us share parameters +across cells, stages, and puzzle sizes. + +\subsection{Multi-value oversampling (data-side trick)} + +Implemented in \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py} +via \code{tokenizer.\_multi\_value\_oversample\_factor} and the CLI flags + +\begin{lstlisting} +--multi_value_oversample_factor INT (default 1) +--train_target_size_min INT (default 0) +--train_target_size_max INT (default 0) +\end{lstlisting} + +Inside the dataset builder, examples whose target set has more than one +digit are repeated \code{multi\_value\_oversample\_factor} times in the +training mix. This biases gradient steps toward exactly the cells the +model gets wrong (multi-value cells). Empirically, this is the single +biggest data-side lever --- see \S 10. + +\subsection{Where the bottleneck lives} + +For 20-empty puzzles in stage 3, only $\sim 25\%$ of empty cells have a +multi-value target set (the rest collapse to one i-consistent value). +Yet those multi-value cells are responsible for the entire solve-rate +gap: they are the cells where the model under-predicts (returns a +singleton when the target is a 2- or 3-element set), and a single +failed cell kills the whole-puzzle solve. The reward shaping in \S 6 +and the oversample in 2.3 both attack this single failure mode. + +\section{Instruction format} + +\subsection{System prompt} + +(verbatim from \code{multi\_output\_cell\_policy/prompt\_builder.py}) + +\begin{lstlisting} +You are a Sudoku value policy. +This setup uses puzzles with about 20 empty cells. +You will be given one target empty cell. +Return ONLY one JSON object of the form {"values":[...]}. +The JSON object must contain exactly one key named "values". +The "values" field must be a JSON array of unique integers in [1,9]. +You may return as many candidate values as you want, including one, +several, or many values. +Choose the number of returned values yourself based on which values seem +i-consistent. +The order of the values does not matter. +Do not output any explanation, markdown, punctuation outside JSON, or +extra text. +Current stage objective: i={i} consistency. +\end{lstlisting} + +\subsection{User message} + +\begin{lstlisting} +Sudoku grid (0 means empty): + +Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), ... +Target cell to fill now: (R,C). +Turn: t/T. +Return only JSON with candidate values for this target cell: {"values":[...]} +\end{lstlisting} + +We use the Qwen2.5-Instruct chat template +(\code{tokenizer.apply\_chat\_template}, \code{add\_generation\_prompt=True}) +to wrap system + user into the actual prompt ids. +\code{max\_prompt\_length = 768}. + +\subsection{Output format} + +\begin{lstlisting} +{"values":[3,7]} +\end{lstlisting} + +Strictly canonical JSON (single key \code{values}, sorted unique digit +list, no whitespace). Outputs are scored by \code{parse\_values\_json} +(\code{shared\_multi\_output\_policy.py}); any deviation collapses the +whole prediction to \code{parse\_ok=0} and a hard-coded malformed +penalty. + +\code{max\_completion\_length = 24} tokens --- enough to emit any +9-digit set. + +\section{Curriculum: stage-i consistency} + +The curriculum lives in \code{\_stage\_i\_consistent\_values\_for\_grid}: + +\begin{itemize}[leftmargin=*] +\item \textbf{Stage 1 --- $i=1$ (legal moves).} A value $v$ is $i=1$ +consistent at cell $c$ iff placing $v$ at $c$ violates no Sudoku +constraint (row, column, $3\times 3$ box). This is just ``legal +candidates''. + +\item \textbf{Stage 2 --- $i=2$.} $v$ is $i=2$ consistent at $c$ iff +(a) it is $i=1$ consistent AND (b) after placing $v$, every other +empty cell in the grid still has at least one $i=1$-consistent value +(i.e.\ placing $v$ does not immediately make the puzzle unsolvable +by 1-step propagation). + +\item \textbf{Stage 3 --- $i=3$.} Same recursion one more level deep: +$v$ is $i=3$ consistent iff after placing $v$, every other empty cell +still has at least one $i=2$ consistent value. +\end{itemize} + +This is bounded look-ahead constraint propagation. Stage-3 sets are +tighter than stage-2 sets which are tighter than stage-1 sets. The +curriculum goal at deployment time is stage-3. + +In data, we use the same source records and just change \code{--stage\_i}; +the target set is regenerated on the fly by +\code{stage\_i\_consistent\_values}. + +\section{Latent thought-token architecture} + +Base model: \textbf{Qwen/Qwen2.5-1.5B-Instruct} + LoRA +($r=32$, $\alpha=64$, dropout $=0.05$) on +\code{q,k,v,o,gate,up,down}. The latent variant adds \textbf{$k$ +thought-token slots} between the prompt and the next-token logits. + +Four modes are implemented (\code{latent\_multi\_output\_cell\_policy/}); +the winning mode for the final number is \textbf{\code{recurrent\_hidden}}: + +\begin{quote} +\code{build\_recurrent\_hidden\_latent\_hidden(model, ids, mask, k)} +\begin{enumerate}[leftmargin=*,nosep] +\item Run the backbone once on the prompt. Keep +\code{base\_hidden = h[:,-1,:]}. +\item Set \code{latent\_token = base\_hidden}. +\item Repeat $k$ times: append \code{latent\_token} (as an embedding) +to the running sequence, run the backbone again on the extended +sequence, and replace \code{latent\_token} with the new last hidden +state. +\item After $k$ recursions, \code{latent\_hidden} is fed through the LM +head to produce the next-token distribution. +\end{enumerate} +\end{quote} + +In equations, with $E$ the input embedding lookup, $f_\theta$ the +LoRA-decorated backbone, $U$ the LM head: +\begin{align*} +z_0 &= f_\theta\bigl(E([x_1,\dots,x_T])\bigr)_T \\ +z_{j+1} &= f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1, \dots, z_j]\bigr)_{T+j+1},\quad j=0,\dots,k-1 \\ +p(\cdot \mid x_{1:T}) &= \mathrm{softmax}(U z_k) +\end{align*} + +The model can therefore ``iterate'' $k$ extra forward passes on the +same prompt before committing to a token, with the $k$ extra hidden +states carrying intermediate computation. Setting $k=0$ recovers the +vanilla baseline. + +The other three latent modes are alternatives that we ablated: +\code{fixed\_slots} (concatenate $k$ trainable seed embeddings --- +Option-2), \code{latent\_seeds} (similar to \code{fixed\_slots}), and +\code{residual} (project $k$ extra hidden states back onto the base +hidden state via a learned residual). All modes share the SFT and GRPO +trainers; only the next-token logit function changes. + +For the curriculum, we grow $k$ stage by stage: + +\begin{center} +\begin{tabular}{ccl} +\toprule +\textbf{stage} & \textbf{num\_cot\_tokens} & \textbf{comment} \\ +\midrule +1 & 1 & one extra recursion as soon as the model has the surface form \\ +2 & 2 & two --- needed for 1-step propagation reasoning \\ +3 & 3 & three --- needed for 2-step propagation reasoning \\ +\bottomrule +\end{tabular} +\end{center} + +\section{The reward function} + +Defined in \code{multi\_output\_cell\_policy/rewards.py}. + +Given target set $T$, predicted set $P$ (after JSON parse), let +\begin{itemize}[leftmargin=*,nosep] +\item \code{num\_good} $= |P \cap T|$ +\item \code{num\_bad} $= |P \setminus T|$ +\item \code{num\_missing} $= \max(0, |T| - \text{num\_good})$ +\item \code{is\_exact} $= (P \neq \varnothing) \land (P = T)$ +\item $\mathrm{tri}(n) = n(n+1)/2$ (rewards larger correct sets superlinearly) +\end{itemize} + +Then +\begin{align*} +r &= \mathrm{tri}(\text{num\_good}) \cdot R_g \;-\; \text{num\_bad} \cdot P_b \\ + &\quad - \indic[P=\varnothing]\, P_e \;-\; \indic[|P|=1, |T|>1, i<2]\, P_s \\ + &\quad - \text{num\_missing}\cdot P_m \;+\; \indic[\text{is\_exact}]\, B_x \\ + &\quad - \indic[|P|<|T|, |T|>1]\, P_c +\end{align*} + +with parameters (this is the recipe that produced the 0.58/0.68 latent +solve): + +\begin{center} +\begin{tabular}{cllr} +\toprule +\textbf{symbol} & \textbf{flag} & \textbf{role} & \textbf{value} \\ +\midrule +$R_g$ & \code{--reward\_good\_value} & per-correct-value reward (triangular shape) & 1.25 \\ +$P_b$ & \code{--penalty\_bad\_value} & per-extra-wrong-value penalty & 1.0 \\ +$P_{\text{mal}}$ & \code{--penalty\_malformed} & flat penalty if JSON parse fails & 4.0 \\ +$P_e$ & \code{--penalty\_empty} & flat penalty if predicted set is empty & 0.5 \\ +$P_s$ & \code{--penalty\_singleton} & only at stage$<$2: punishes singleton on multi-value targets & 1.5 \\ +$P_m$ & \code{--penalty\_missing} & per-missing-value (recall pressure) --- \textbf{NEW} & \textbf{0.75} \\ +$B_x$ & \code{--exact\_match\_bonus} & only when $P = T$ --- \textbf{NEW} & \textbf{2.0} \\ +$P_c$ & \code{--cardinality\_mismatch\_penalty} & when $|P| < |T|$ and $|T|>1$ --- \textbf{NEW} & \textbf{1.0} \\ +\bottomrule +\end{tabular} +\end{center} + +Parse failures short-circuit to $r = -P_{\text{mal}}$ and zero per-cell +metrics. + +\subsection{Why those three new terms exist (the breakthrough)} + +Diagnosis: at the v3/v4 plateau, eval reported + +\begin{lstlisting} +exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000 +\end{lstlisting} + +across all checkpoints. Per-cell exact and precision/recall were all +near 0.95 but the model \textbf{always predicted a single digit} +(\code{avg\_set\_size=1.000}). On a multi-value target $T=\{8,9\}$, +predicting $\{8\}$ keeps precision $=1.0$, recall $=0.5$ and yet +\code{exact\_set\_match}$=0$. Solve $= \text{exact\_set\_match}^N$ is +catastrophic in $N$ ($=20$), so even a small fraction of multi-value +cells killed it. + +Without any of the new terms the optimum of $r$ on a multi-value cell +is trivially ``predict the singleton you are most confident about'' --- +there is no upside to enumerate the second value. The three new terms +close exactly that hole: + +\begin{itemize}[leftmargin=*,nosep] +\item $P_m$ (\code{penalty\_missing}) directly penalises recall; +\item $B_x$ (\code{exact\_match\_bonus}) makes $P=T$ strictly dominate any singleton; +\item $P_c$ (\code{cardinality\_mismatch\_penalty}) is a flat hammer whenever $|P|<|T|$. +\end{itemize} + +After these terms were added, GRPO on the latent variant moved solve +from $\sim 0.30$ to $\sim 0.58$ (100-puzzle eval) over $\sim 200$ +steps. The same fix is what we ported back into the baseline pipeline +this evening (see \S 10). + +\section{Multi-stage warm-baseline pipeline (the recipe that worked)} + +Master script: +\code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}. + +For each curriculum stage we run \textbf{three sub-phases in order}: + +\begin{lstlisting} +[stage i] + (1) baseline warm SFT (no latent tokens, k=0, vanilla LM) + (2) latent SFT (k = i, latent mode = recurrent_hidden) + (3) latent GRPO (k = i) +\end{lstlisting} + +\textbf{The warm baseline phase (1) is the trick that makes the +curriculum work.} At every stage transition the data distribution +changes ($i$ increases $\Rightarrow$ target sets shrink) and a new +latent slot appears. Doing a vanilla SFT on the new distribution first +lets the LM relearn the surface form on familiar parameters; THEN the +latent SFT adds the extra thought slot on top of an already-good policy. +When we tried to add a new latent slot directly on top of the previous +stage's GRPO checkpoint, training loss did NOT decrease. + +Concrete LR schedule used for the champion run: + +\begin{center} +\begin{tabular}{lllc} +\toprule +\textbf{phase} & \textbf{init from} & \textbf{LR} & \textbf{k} \\ +\midrule +S1 baseline SFT & base Qwen & 2e-4 & 0 \\ +S1 latent SFT & S1 baseline & 2e-4 & 1 \\ +S1 latent GRPO & S1 latent SFT & 1e-6 & 1 \\ +S2 baseline warm SFT & S1 GRPO & 5e-5 & 0 \\ +S2 latent SFT & S2 baseline & 5e-5 & 2 \\ +S2 latent GRPO & S2 latent SFT & 1e-6 & 2 \\ +S3 baseline warm SFT & S2 GRPO & 5e-5 & 0 \\ +S3 latent SFT & S3 baseline & 5e-5 $\rightarrow$ 1e-5 (champion) & 3 \\ +S3 latent GRPO & S3 latent SFT & 5e-6 ($\beta=0$) & 3 \\ +\bottomrule +\end{tabular} +\end{center} + +Other shared knobs: + +\begin{lstlisting} +LoRA: r=32 a=64 dropout=0.05 on q,k,v,o,gate,up,down +SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128 +GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64 + num_generations=4 beta=0.0 max_prompt_length=1024 + max_completion_length=24 +multi_value_oversample_factor=5, exact_match_bonus=2.0, +penalty_missing=0.75, cardinality_mismatch_penalty=1.0 +\end{lstlisting} + +\section{GRPO settings that mattered} + +\begin{itemize}[leftmargin=*] +\item \textbf{$\beta = 0$.} The KL anchor was harmful in every sweep +where we tried $\beta>0$. \code{s3\_grpo\_kl04} ($\beta=0.04$) peaked +at solve $=0.625$ (40p) at step 100 and regressed to $0.525$ by step +500. + +\item \textbf{\code{num\_generations} $= 4$.} With \code{num\_generations}$=2$ +we routinely saw \code{reward\_std}$=0$ (all sampled completions +identical $\Rightarrow$ no gradient). Bumping to 4 fixed it. + +\item \textbf{Low LR.} \code{lr=5e-6} was the steadiest. \code{lr=1e-5} +peaked at step 200 (solve $0.65$) then collapsed back to $0.54$ --- +classic mode collapse. + +\item \textbf{Effective bs $\geq 64$.} TRL's GRPOConfig requires +\code{eff\_bs * grad\_accum \% num\_generations == 0}; with 8 GPUs we +hit this trivially, but we caution single-GPU rerunners to set +\code{per\_device\_bs=4 grad\_accum=2 num\_generations=4}. + +\item \textbf{\code{enable\_input\_require\_grads()} on the wrapped backbone.} +Required for TRL 0.15.x + PEFT LoRA + gradient checkpointing --- +otherwise the loss tensor produced by GRPOTrainer has +\code{requires\_grad=False} and \code{.backward()} raises. Also +\code{unwrapped.config.use\_cache = False}. +\end{itemize} + +\section{Final hyperparameters table --- champion latent run} + +\begin{center} +\begin{longtable}{lll} +\toprule +\textbf{group} & \textbf{hyperparameter} & \textbf{value} \\ +\midrule +\endfirsthead +\toprule +\textbf{group} & \textbf{hyperparameter} & \textbf{value} \\ +\midrule +\endhead +Backbone & model & Qwen/Qwen2.5-1.5B-Instruct \\ +Backbone & dtype & bf16 \\ +Backbone & LoRA target modules & q,k,v,o,gate,up,down \\ +Backbone & LoRA $r$ / $\alpha$ / dropout & 32 / 64 / 0.05 \\ +Latent & mode & \code{recurrent\_hidden} \\ +Latent & \code{num\_cot\_tokens} (S1/S2/S3) & 1 / 2 / 3 \\ +Latent & \code{max\_latent\_slots} / seeds & 8 / 8 \\ +Data & total empties & 20 \\ +Data & train rows / eval rows & 10\,000 / 100 \\ +Data & \code{multi\_value\_oversample\_factor} & 5 \\ +Data & \code{mixed\_stage1\_ratio} (S1) & 1 \\ +Data & \code{mixed\_stage2\_ratio} (S$\geq 2$) & 1 \\ +SFT & per\_device\_bs / grad\_accum & 8 / 2 \\ +SFT & \code{num\_epochs} (cap) & 64 \\ +SFT & LR (S1 latent) & 2e-4 \\ +SFT & LR (S2/S3 baseline warm + latent) & 5e-5 \\ +SFT & LR (S3 latent champion \code{s3b\_lr1e5\_o5}) & 1e-5 \\ +SFT & weight\_decay & 0.0 \\ +SFT & gradient checkpointing & on \\ +GRPO & per\_device\_bs / grad\_accum & 4 / 2 \\ +GRPO & \code{num\_generations} & 4 \\ +GRPO & LR & 5e-6 (S3); 1e-6 (S1, S2) \\ +GRPO & $\beta$ (KL) & 0.0 \\ +GRPO & \code{max\_prompt\_length} & 1024 \\ +GRPO & \code{max\_completion\_length} & 24 \\ +Reward & \code{reward\_good\_value} & 1.25 \\ +Reward & \code{penalty\_bad\_value} & 1.0 \\ +Reward & \code{penalty\_malformed} & 4.0 \\ +Reward & \code{penalty\_empty} & 0.5 \\ +Reward & \code{penalty\_singleton} & 1.5 \\ +Reward & \code{penalty\_missing} & 0.75 \\ +Reward & \code{exact\_match\_bonus} & 2.0 \\ +Reward & \code{cardinality\_mismatch\_penalty} & 1.0 \\ +Eval & early-stop on prec/recall & 0.98 \\ +\bottomrule +\end{longtable} +\end{center} + +\section{Headline results} + +\subsection{Latent (with thought tokens, \code{recurrent\_hidden})} + +\begin{center} +\begin{tabular}{llrrrrr} +\toprule +\textbf{eval} & \textbf{model / phase} & \textbf{step} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} \\ +\midrule +\textbf{100p (auth.)} & \code{s3\_grpo\_baseline} (S3 GRPO, $\beta=0$, lr=5e-6) & 200 & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580 (58/100)} \\ +40p & \code{s3\_grpo\_sharp\_rwd} ($B_x{=}4$, $P_c{=}3$) & 300 & --- & --- & --- & \textbf{0.675 (27/40)} \\ +40p & \code{s3\_grpo\_lr1e5} & 200 & 0.978 & 0.978 & 0.979 & 0.650 \\ +40p & \code{s3b\_lr1e5\_o5} (S3 SFT champion) & 2400 & 0.974 & 0.974 & 0.975 & 0.600 \\ +\bottomrule +\end{tabular} +\end{center} + +\subsection{Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)} + +\begin{center} +\begin{tabular}{llrrr} +\toprule +\textbf{sweep} & \textbf{best variant} & \textbf{best step} & \textbf{exact} & \textbf{solve (100p)} \\ +\midrule +v3 (single-GPU, no oversample, no new reward) & \code{baseline\_3stage\_20260522} & --- & 0.730 & \textbf{0.000} \\ +v4 (LR sweep, multi-GPU, original reward) & \code{pipe\_v\_sft\_extend} (S3 SFT extended) & 4000 & 0.948 & \textbf{0.400} \\ +\textbf{v6 (this evening; ports latent reward + oversample)} & \code{v6\_i\_sft\_v\_oversample10} & running & 0.952$+$ & \textbf{0.440 (best so far)} \\ +\bottomrule +\end{tabular} +\end{center} + +The v6 sweep is still running --- \code{v6\_e/f/i} are in S3 SFT +continuation, GRPO follow-on phases queued. The \code{v6\_i} variant +has hit \textbf{solve $=0.44$} at SFT eval (new baseline best, +$+0.04$ over v4) and is still climbing. + +\subsection{Stage-by-stage trajectory (latent, 40-puzzle eval)} + +\begin{lstlisting} +S1 SFT : exact ~ 0.85, solve ~ 0.20 +S1 GRPO : exact ~ 0.90, solve ~ 0.20 +S2 SFT (no oversample) : exact ~ 0.94, solve ~ 0.20-0.25 <- the wall +S2 SFT + multi_value_oversample=5 : exact ~ 0.96, solve ~ 0.30-0.35 +S2 GRPO + new reward terms : exact ~ 0.96, solve ~ 0.35-0.40 +S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion +S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion +S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak +\end{lstlisting} + +\subsection{Latent vs baseline gap (head-to-head, same 100p eval, same prompts)} + +\begin{center} +\begin{tabular}{lrrrrr} +\toprule +\textbf{model} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} & \textbf{solved/100} \\ +\midrule +Latent \code{recurrent\_hidden}, S3 GRPO & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580} & 58 \\ +Vanilla baseline, \code{v6\_i} (best at time of writing) & 0.952 & 0.952 & 0.952 & \textbf{0.440} & 44 \\ +\bottomrule +\end{tabular} +\end{center} + +Gap on 100-puzzle solve: $\approx$ \textbf{$+0.14$ absolute / $+32\%$ +relative} for latent over the strongest baseline we have. + +\section{Why the latent works (interpretation hypotheses)} + +These are the working hypotheses the experiments are consistent with; +none is fully proven and ablations are still WIP. + +\begin{enumerate}[leftmargin=*] +\item \textbf{Constraint-propagation depth.} Stage-3 i-consistency is +essentially 2-ply lookahead. With $k=3$ recurrent hidden tokens the +model gets exactly three extra forward passes between prompt and +output --- one for the legality check, one for 1-step propagation, +one for the second step of propagation. Empirically the gap to the +no-thought-token baseline appears at stages where multi-step +propagation matters (stage 2 onward; stage 1 numbers are essentially +identical). + +\item \textbf{Multi-value cells require enumeration, which a singleton +softmax can't do in one forward pass.} A vanilla LM at 1.5B +parameters predicts essentially deterministically once temperature is +low; for a target set $\{8, 9\}$ the LM picks one of the two and +stops. The latent model can use one of the recurrent hidden steps to +``consider'' each option without committing yet, which is exactly +the failure mode in the data (\code{avg\_set\_size} $= 1.000$ for the +baseline, $\approx 1.05$ for the latent S3 model on the same eval). + +\item \textbf{Stable curriculum capacity growth.} Adding a new latent +slot at every stage gives the model a ``fresh slate'' of +representational capacity at the exact transition where the task +gets harder. The warm-baseline SFT between stages prevents the new +slot from corrupting the previously learned policy. Without warm +baseline, training loss did not decrease at all (we observed this +directly when we tried to skip the warm baseline). + +\item \textbf{GRPO without latent slots is starved of variance.} With +\code{max\_completion\_length} 24 and the model essentially +deterministic, GRPO's 4 sampled completions per prompt collapse to a +single answer --- \code{reward\_std}$=0$, no gradient. With latent +recurrence + the new \code{exact\_match\_bonus} reward, the model +occasionally samples a 2-element set, gets a much higher reward, and +that prompt gets a real gradient signal. +\end{enumerate} + +\section{Reproducibility} + +\noindent +Code repository: \url{https://github.com/Avra98/curriculum_cot} \\ +Latent checkpoints: \url{https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages} \\ +Baseline checkpoints: \url{https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep} + +Key scripts: + +\begin{itemize}[leftmargin=*,nosep] +\item Master orchestrator (latent, 9-phase warm-baseline pipeline): +\code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh} +\item Vanilla baseline pipeline: +\code{\_runs/baseline\_1p5b\_pipeline\_v4.sh} (with v6 launchers +\code{\_runs/launch\_baseline\_push\_v6.sh}) +\item SFT trainer (vanilla): +\code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py} +\item GRPO trainer (vanilla): +\code{multi\_output\_cell\_policy/grpo\_multi\_output\_train.py} +\item SFT trainer (latent): +\code{latent\_multi\_output\_cell\_policy/sft\_latent\_multi\_output\_train.py} +\item GRPO trainer (latent): +\code{latent\_multi\_output\_cell\_policy/grpo\_residual\_projector\_latent\_train.py} +\item Reward function: \code{multi\_output\_cell\_policy/rewards.py} +\item Prompt builder: \code{multi\_output\_cell\_policy/prompt\_builder.py} +\item Stage-i consistency: +\code{multi\_output\_cell\_policy/shared\_multi\_output\_policy.py} +\item 100-puzzle evaluator: \code{analysis/eval\_stage2\_checkpoint.py} +\end{itemize} + +To reproduce the latent champion (1.5B, 9-phase, $\sim 16$ GPU$\cdot$h +on $8\times$H100 80GB): + +\begin{lstlisting} +export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter +bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh +\end{lstlisting} + +To reproduce the v6 baseline push (single-GPU per variant, $\sim 6$ +GPU$\cdot$h): + +\begin{lstlisting} +bash _runs/launch_baseline_push_v6.sh +\end{lstlisting} + +\appendix + +\section{The reward fix as a one-line patch} + +The single most consequential code change in this whole project, as a +self-contained patch on \code{multi\_output\_cell\_policy/rewards.py}: + +\begin{lstlisting}[language=Python] +# new args (default 0 preserves legacy behaviour) +penalty_missing: float = 0.0 +exact_match_bonus: float = 0.0 +cardinality_mismatch_penalty: float = 0.0 + +num_missing = max(0, len(target_set) - num_good) +is_exact = bool(predicted_values) and (set(predicted_values) == target_set) + +# ... base reward (triangular_number(num_good)*reward_good_value +# - num_bad*penalty_bad_value) + +if num_missing > 0: + reward -= num_missing * penalty_missing +if is_exact: + reward += exact_match_bonus +if len(predicted_values) < len(target_values) and len(target_values) > 1: + reward -= cardinality_mismatch_penalty +\end{lstlisting} + +Defaults are zero so old runs are unaffected; the recipe sets +$(P_m, B_x, P_c) = (0.75, 2.0, 1.0)$ for the vanilla recipe and +$(1.0, 4.0, 3.0)$ for the ``sharp\_rwd'' variant. + +\section{The warm-baseline trick as a sequence diagram} + +\begin{lstlisting} +Stage 1 Stage 2 Stage 3 +--------- --------- --------- +[base Qwen] | | + | | | + v v v +S1 baseline SFT -> S2 baseline SFT -> S3 baseline SFT +(no latent, k=0) (no latent, k=0) (no latent, k=0) + | | | + v v v +S1 latent SFT -> S2 latent SFT -> S3 latent SFT +(k=1) (k=2) (k=3) + | | | + v v v +S1 latent GRPO -> S2 latent GRPO -> S3 latent GRPO +(k=1, b=0, lr 1e-6) (k=2, b=0) (k=3, b=0, lr 5e-6) + | | | + [final policy] +\end{lstlisting} + +Every arrow is \code{init\_adapter\_dir = }. Each row +is a ``slot in the curriculum''; the column adds reasoning capacity +($k\mathrel{+}=1$) and moves to a harder target distribution +($i\mathrel{+}=1$). The diagonal across the diagram is the actual +training trajectory. + +\bigskip +\noindent\emph{End of report.} + +\end{document} diff --git a/addition/README.md b/addition/README.md new file mode 100644 index 0000000000000000000000000000000000000000..643f959d04ad2dffee1039068c56fc4c471ccae2 --- /dev/null +++ b/addition/README.md @@ -0,0 +1,114 @@ +# Addition Carry Experiment + +This folder contains a standalone PyTorch experiment for algorithmic addition with carry on a one-layer decoder-only Transformer. + +The comparison includes exactly three methods: + +- `nocurr_nocot`: no curriculum, no latent chain-of-thought +- `curr_nocot`: digit-length curriculum, no latent chain-of-thought +- `curr_cot`: same one-layer backbone plus recurrent latent scratchpad tokens + +## Task + +Each example adds two reversed digit sequences in a configurable radix. Stage `k` means only the first `k` least-significant positions vary and the rest are zero. Every method now trains on the full example in one forward pass: + +- predict all `k` active sum digits +- predict the final carry bit as an additional output slot +- compute masked loss over the active digits plus the final carry + +This means the baseline and both curriculum variants learn whole-example addition rather than a single queried digit at a time. Internal carry targets are still kept for diagnostics and linear probing, but not as an auxiliary training loss. + +The latent method reuses the same one-layer Transformer recurrently. After an initial pass over the inputs and output slots, the model appends continuous latent scratchpad tokens before the output slots and reruns the same layer, giving later curriculum stages more internal workspace for carry-like computation. + +## Files + +- `config.py`: experiment config and CLI handling +- `data.py`: synthetic data generation, curriculum stages, carry-heavy subsets +- `model.py`: one-layer decoder-only Transformer and latent recurrence +- `train.py`: single-run training entrypoint +- `eval.py`: evaluation and diagnostics +- `plots.py`: local plotting +- `run_comparison.py`: multi-seed comparison across all three methods + +## Outputs + +Each run writes: + +- `config.json` +- `artifacts/history.jsonl` +- `artifacts/summary.json` +- `checkpoints/best.pt` +- `checkpoints/last.pt` +- local plots under `plots/` + +If W&B is enabled, the same run also logs metrics there. + +## Run A Single Method + +Default settings: + +```bash +python addition/train.py --model nocurr_nocot --use_wandb +python addition/train.py --model curr_nocot --use_wandb +python addition/train.py --model curr_cot --use_wandb +``` + +The default backbone now uses a single attention head. To run a harder hexadecimal setting: + +```bash +python addition/train.py --model curr_cot --radix 16 --use_wandb --output_dir addition_runs/hex_curr_cot +``` + +Run offline or local-only: + +```bash +python addition/train.py --model curr_cot --wandb_mode offline +python addition/train.py --model curr_cot --no_wandb +``` + +## Smoke Test + +Use the smoke preset to verify the whole pipeline quickly: + +```bash +python addition/train.py --model curr_cot --preset smoke --no_wandb --output_dir addition_runs/smoke_curr_cot +``` + +## Run The Full Comparison + +This runs all three methods across multiple seeds and saves aggregate plots and JSON: + +```bash +python addition/run_comparison.py --preset default --use_wandb --comparison_output_dir addition_runs/comparison_default +``` + +Small fast comparison: + +```bash +python addition/run_comparison.py --preset smoke --no_wandb --comparison_output_dir addition_runs/comparison_smoke +``` + +## Main Metrics + +The experiment reports: + +- digit accuracy by output position +- final-carry accuracy +- exact whole-sum accuracy by active length +- average digit accuracy by length +- in-distribution results up to `train_max_digits` +- OOD results on longer lengths +- separate uniform and carry-heavy evaluations + +## Diagnostics + +The evaluation also includes: + +- a linear probe on output-slot hidden states for carry prediction +- attention summaries showing how strongly the final carry readout attends to operand digits, previous output slots, and latent tokens + +## Notes + +- The first version is intentionally small enough to iterate locally. +- The backbone depth stays fixed at one layer in all methods. +- The latent method gets more recurrent compute, not more layers. diff --git a/addition/__init__.py b/addition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75a72cdafe45f051e00df4b98a67613c58e8f22f --- /dev/null +++ b/addition/__init__.py @@ -0,0 +1 @@ +"""Standalone addition-with-carry experiment package.""" diff --git a/addition/config.py b/addition/config.py new file mode 100644 index 0000000000000000000000000000000000000000..258f55be0a7b0c33fe3ba6fef8b5c33b541da16d --- /dev/null +++ b/addition/config.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import argparse +import dataclasses +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +import torch + + +VALID_MODELS = ("nocurr_nocot", "curr_nocot", "curr_cot") +VALID_PRESETS = ("default", "smoke") + + +@dataclass +class ExperimentConfig: + model: str = "nocurr_nocot" + output_dir: str = "addition_runs/default" + seed: int = 0 + device: str = "cuda" if torch.cuda.is_available() else "cpu" + preset: str = "default" + run_name: str = "" + notes: str = "" + use_wandb: bool = True + wandb_project: str = "addition-carry" + wandb_entity: str = "" + wandb_mode: str = "online" + radix: int = 10 + train_max_digits: int = 12 + eval_max_digits: int = 20 + ood_lengths: tuple[int, ...] = (14, 16, 20) + train_batch_size: int = 256 + eval_batch_size: int = 512 + learning_rate: float = 3e-4 + weight_decay: float = 1e-2 + grad_clip_norm: float = 1.0 + carry_loss_weight: float = 0.0 + train_steps: int = 3600 + max_steps_per_stage: int = 300 + validation_interval: int = 100 + stage_accuracy_threshold: float = 0.99 + initial_stage: int = 1 + eval_examples_per_length: int = 256 + carry_heavy_examples_per_length: int = 256 + train_carry_heavy_prob: float = 0.15 + d_model: int = 512 + n_heads: int = 1 + ff_dim: int = 2048 + dropout: float = 0.0 + max_latent_steps: int = 12 + attention_probe_examples: int = 256 + linear_probe_epochs: int = 150 + linear_probe_lr: float = 1e-2 + comparison_num_seeds: int = 5 + + def __post_init__(self) -> None: + if self.model not in VALID_MODELS: + raise ValueError(f"Unsupported model: {self.model}") + if self.preset not in VALID_PRESETS: + raise ValueError(f"Unsupported preset: {self.preset}") + if self.train_max_digits > self.eval_max_digits: + raise ValueError("train_max_digits must be <= eval_max_digits") + if self.max_latent_steps < 0: + raise ValueError("max_latent_steps must be non-negative") + if self.radix < 2 or self.radix > 16: + raise ValueError("radix must be between 2 and 16") + if self.initial_stage < 1 or self.initial_stage > self.train_max_digits: + raise ValueError("initial_stage must be between 1 and train_max_digits") + self.ood_lengths = tuple(int(v) for v in self.ood_lengths if int(v) > self.train_max_digits) + if not self.ood_lengths: + self.ood_lengths = (self.eval_max_digits,) + + @property + def uses_curriculum(self) -> bool: + return self.model in {"curr_nocot", "curr_cot"} + + @property + def uses_latent_cot(self) -> bool: + return self.model == "curr_cot" + + @property + def discrete_vocab_size(self) -> int: + return self.radix + 2 + + @property + def digit_vocab_size(self) -> int: + return self.radix + + @property + def input_sequence_length(self) -> int: + return self.input_sequence_length_for_digits(self.eval_max_digits) + + @property + def output_sequence_length(self) -> int: + return self.output_sequence_length_for_digits(self.eval_max_digits) + + @property + def base_sequence_length(self) -> int: + return self.base_sequence_length_for_digits(self.eval_max_digits) + + @property + def max_sequence_length(self) -> int: + return self.base_sequence_length + self.max_latent_steps + + @property + def effective_run_name(self) -> str: + if self.run_name: + return self.run_name + return f"{self.model}_base{self.radix}_seed{self.seed}" + + def input_sequence_length_for_digits(self, active_digits: int) -> int: + return (int(active_digits) * 2) + 2 + + def output_sequence_length_for_digits(self, active_digits: int) -> int: + return int(active_digits) + 1 + + def base_sequence_length_for_digits(self, active_digits: int) -> int: + return self.input_sequence_length_for_digits(active_digits) + self.output_sequence_length_for_digits(active_digits) + + def latent_steps_for_stage(self, stage: int) -> int: + if not self.uses_latent_cot: + return 0 + return max(0, min(int(stage), int(self.max_latent_steps))) + + +def default_output_root() -> Path: + return Path("addition_runs") + + +def apply_preset(config: ExperimentConfig) -> ExperimentConfig: + config = dataclasses.replace(config) + if config.preset == "smoke": + config.output_dir = config.output_dir or str(default_output_root() / "smoke") + config.train_batch_size = 64 + config.eval_batch_size = 128 + config.d_model = 128 + config.n_heads = 1 + config.ff_dim = 512 + config.train_steps = 180 + config.max_steps_per_stage = 40 + config.validation_interval = 20 + config.eval_examples_per_length = 64 + config.carry_heavy_examples_per_length = 64 + config.attention_probe_examples = 64 + config.linear_probe_epochs = 60 + config.comparison_num_seeds = 2 + return config + + +def config_to_dict(config: ExperimentConfig) -> dict: + data = dataclasses.asdict(config) + data["ood_lengths"] = list(config.ood_lengths) + data["uses_curriculum"] = config.uses_curriculum + data["uses_latent_cot"] = config.uses_latent_cot + data["discrete_vocab_size"] = config.discrete_vocab_size + data["input_sequence_length"] = config.input_sequence_length + data["output_sequence_length"] = config.output_sequence_length + data["base_sequence_length"] = config.base_sequence_length + data["max_sequence_length"] = config.max_sequence_length + data["effective_run_name"] = config.effective_run_name + return data + + +def save_config(config: ExperimentConfig, output_dir: Path) -> None: + output_dir.mkdir(parents=True, exist_ok=True) + with (output_dir / "config.json").open("w", encoding="utf-8") as handle: + json.dump(config_to_dict(config), handle, indent=2, sort_keys=True) + + +def add_config_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--model", choices=VALID_MODELS, default="nocurr_nocot") + parser.add_argument("--output_dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--preset", choices=VALID_PRESETS, default="default") + parser.add_argument("--run_name", type=str, default="") + parser.add_argument("--notes", type=str, default="") + parser.add_argument("--use_wandb", action="store_true") + parser.add_argument("--no_wandb", action="store_true") + parser.add_argument("--wandb_project", type=str, default="addition-carry") + parser.add_argument("--wandb_entity", type=str, default="") + parser.add_argument("--wandb_mode", type=str, default="online", choices=("online", "offline", "disabled")) + parser.add_argument("--radix", type=int, default=10) + parser.add_argument("--train_max_digits", type=int, default=12) + parser.add_argument("--eval_max_digits", type=int, default=20) + parser.add_argument("--ood_lengths", type=int, nargs="*", default=[14, 16, 20]) + parser.add_argument("--train_batch_size", type=int, default=256) + parser.add_argument("--eval_batch_size", type=int, default=512) + parser.add_argument("--learning_rate", type=float, default=3e-4) + parser.add_argument("--weight_decay", type=float, default=1e-2) + parser.add_argument("--grad_clip_norm", type=float, default=1.0) + parser.add_argument("--carry_loss_weight", type=float, default=0.0) + parser.add_argument("--train_steps", type=int, default=3600) + parser.add_argument("--max_steps_per_stage", type=int, default=300) + parser.add_argument("--validation_interval", type=int, default=100) + parser.add_argument("--stage_accuracy_threshold", type=float, default=0.99) + parser.add_argument("--initial_stage", type=int, default=1) + parser.add_argument("--eval_examples_per_length", type=int, default=256) + parser.add_argument("--carry_heavy_examples_per_length", type=int, default=256) + parser.add_argument("--train_carry_heavy_prob", type=float, default=0.15) + parser.add_argument("--d_model", type=int, default=512) + parser.add_argument("--n_heads", type=int, default=1) + parser.add_argument("--ff_dim", type=int, default=2048) + parser.add_argument("--dropout", type=float, default=0.0) + parser.add_argument("--max_latent_steps", type=int, default=12) + parser.add_argument("--attention_probe_examples", type=int, default=256) + parser.add_argument("--linear_probe_epochs", type=int, default=150) + parser.add_argument("--linear_probe_lr", type=float, default=1e-2) + parser.add_argument("--comparison_num_seeds", type=int, default=5) + + +def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig: + use_wandb = bool(args.use_wandb or not args.no_wandb) + if args.wandb_mode == "disabled": + use_wandb = False + output_dir = args.output_dir or str(default_output_root() / f"{args.model}_base{args.radix}_seed{args.seed}") + config = ExperimentConfig( + model=args.model, + output_dir=output_dir, + seed=args.seed, + device=args.device, + preset=args.preset, + run_name=args.run_name, + notes=args.notes, + use_wandb=use_wandb, + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + wandb_mode=args.wandb_mode, + radix=args.radix, + train_max_digits=args.train_max_digits, + eval_max_digits=args.eval_max_digits, + ood_lengths=tuple(args.ood_lengths), + train_batch_size=args.train_batch_size, + eval_batch_size=args.eval_batch_size, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + grad_clip_norm=args.grad_clip_norm, + carry_loss_weight=args.carry_loss_weight, + train_steps=args.train_steps, + max_steps_per_stage=args.max_steps_per_stage, + validation_interval=args.validation_interval, + stage_accuracy_threshold=args.stage_accuracy_threshold, + initial_stage=args.initial_stage, + eval_examples_per_length=args.eval_examples_per_length, + carry_heavy_examples_per_length=args.carry_heavy_examples_per_length, + train_carry_heavy_prob=args.train_carry_heavy_prob, + d_model=args.d_model, + n_heads=args.n_heads, + ff_dim=args.ff_dim, + dropout=args.dropout, + max_latent_steps=args.max_latent_steps, + attention_probe_examples=args.attention_probe_examples, + linear_probe_epochs=args.linear_probe_epochs, + linear_probe_lr=args.linear_probe_lr, + comparison_num_seeds=args.comparison_num_seeds, + ) + return apply_preset(config) + + +def build_arg_parser(description: str) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=description) + add_config_arguments(parser) + return parser + + +def parse_config(description: str) -> ExperimentConfig: + parser = build_arg_parser(description) + args = parser.parse_args() + return build_config_from_args(args) + + +def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]: + root = Path(config.output_dir) + directories = { + "root": root, + "checkpoints": root / "checkpoints", + "stage_checkpoints": root / "checkpoints" / "stages", + "plots": root / "plots", + "artifacts": root / "artifacts", + } + for directory in directories.values(): + directory.mkdir(parents=True, exist_ok=True) + return directories + + +def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]: + return {f"{prefix}{key}": value for key, value in metrics.items()} + + +def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]: + for stage in range(1, config.train_max_digits + 1): + yield stage diff --git a/addition/data.py b/addition/data.py new file mode 100644 index 0000000000000000000000000000000000000000..61becdfe3d93cb9197e9fc1435505d69180dd383 --- /dev/null +++ b/addition/data.py @@ -0,0 +1,390 @@ +from __future__ import annotations + +import dataclasses +import math +import random +from dataclasses import dataclass +from typing import Iterable + +import torch + +from addition.config import ExperimentConfig + + +DIGIT_OFFSET = 0 +DEFAULT_SYMBOLS = "0123456789ABCDEF" + + +@dataclass +class AdditionProblem: + a_digits: list[int] + b_digits: list[int] + sum_digits: list[int] + carry_out: list[int] + active_digits: int + is_carry_heavy: bool + + +@dataclass +class Batch: + input_ids: torch.Tensor + target_digits: torch.Tensor + target_digit_mask: torch.Tensor + target_carry: torch.Tensor + target_final_carry: torch.Tensor + active_digits: torch.Tensor + is_carry_heavy: torch.Tensor + + +@dataclass +class EvaluationSuite: + validation_uniform: dict[int, list[AdditionProblem]] + test_uniform: dict[int, list[AdditionProblem]] + test_carry_heavy: dict[int, list[AdditionProblem]] + + +def a_token_id(radix: int) -> int: + return radix + + +def b_token_id(radix: int) -> int: + return radix + 1 + + +def seed_everything(seed: int) -> None: + random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def compute_sum_and_carry(a_digits: list[int], b_digits: list[int], radix: int) -> tuple[list[int], list[int]]: + sum_digits: list[int] = [] + carry_out: list[int] = [] + carry = 0 + for a_digit, b_digit in zip(a_digits, b_digits): + total = int(a_digit) + int(b_digit) + carry + sum_digits.append(total % radix) + carry = total // radix + carry_out.append(carry) + return sum_digits, carry_out + + +def sample_uniform_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem: + a_digits = [0] * max_digits + b_digits = [0] * max_digits + for index in range(active_digits): + a_digits[index] = rng.randint(0, radix - 1) + b_digits[index] = rng.randint(0, radix - 1) + sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix) + return AdditionProblem( + a_digits=a_digits, + b_digits=b_digits, + sum_digits=sum_digits, + carry_out=carry_out, + active_digits=active_digits, + is_carry_heavy=False, + ) + + +def sample_carry_heavy_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem: + a_digits = [0] * max_digits + b_digits = [0] * max_digits + carry = 0 + for index in range(active_digits): + high_floor = max(0, radix // 2) + a_digit = rng.randint(high_floor, radix - 1) + if carry == 0: + min_b = max(0, radix - a_digit) + else: + min_b = max(0, (radix - 1) - a_digit) + b_digit = rng.randint(min_b, radix - 1) + a_digits[index] = a_digit + b_digits[index] = b_digit + total = a_digit + b_digit + carry + carry = total // radix + sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix) + return AdditionProblem( + a_digits=a_digits, + b_digits=b_digits, + sum_digits=sum_digits, + carry_out=carry_out, + active_digits=active_digits, + is_carry_heavy=True, + ) + + +def sample_problem( + max_digits: int, + active_digits: int, + radix: int, + rng: random.Random, + carry_heavy: bool = False, +) -> AdditionProblem: + if carry_heavy: + return sample_carry_heavy_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng) + return sample_uniform_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng) + + +def encode_problem_tokens(problem: AdditionProblem, radix: int) -> list[int]: + return ( + [a_token_id(radix)] + + [DIGIT_OFFSET + digit for digit in problem.a_digits[: problem.active_digits]] + + [b_token_id(radix)] + + [DIGIT_OFFSET + digit for digit in problem.b_digits[: problem.active_digits]] + ) + + +def build_batch( + problems: list[AdditionProblem], + radix: int, + device: str, +) -> Batch: + active_digits = problems[0].active_digits if problems else 0 + input_ids = torch.tensor( + [ + encode_problem_tokens(problem=problem, radix=radix) + for problem in problems + ], + dtype=torch.long, + device=device, + ) + target_digits = torch.tensor( + [problem.sum_digits[:active_digits] for problem in problems], + dtype=torch.long, + device=device, + ) + target_digit_mask = torch.tensor( + [[1] * active_digits for _ in problems], + dtype=torch.bool, + device=device, + ) + target_carry = torch.tensor( + [problem.carry_out[:active_digits] for problem in problems], + dtype=torch.long, + device=device, + ) + target_final_carry = torch.tensor( + [problem.carry_out[problem.active_digits - 1] for problem in problems], + dtype=torch.long, + device=device, + ) + return Batch( + input_ids=input_ids, + target_digits=target_digits, + target_digit_mask=target_digit_mask, + target_carry=target_carry, + target_final_carry=target_final_carry, + active_digits=torch.tensor([problem.active_digits for problem in problems], dtype=torch.long, device=device), + is_carry_heavy=torch.tensor([int(problem.is_carry_heavy) for problem in problems], dtype=torch.bool, device=device), + ) + + +def sample_training_batch( + config: ExperimentConfig, + stage: int, + rng: random.Random, + device: str, +) -> Batch: + problems: list[AdditionProblem] = [] + for _ in range(config.train_batch_size): + carry_heavy = rng.random() < config.train_carry_heavy_prob + problem = sample_problem( + max_digits=stage, + active_digits=stage, + radix=config.radix, + rng=rng, + carry_heavy=carry_heavy, + ) + problems.append(problem) + return build_batch( + problems=problems, + radix=config.radix, + device=device, + ) + + +def build_problem_set( + *, + max_digits: int, + active_digits: int, + radix: int, + count: int, + seed: int, + carry_heavy: bool, +) -> list[AdditionProblem]: + rng = random.Random(seed) + return [ + sample_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng, carry_heavy=carry_heavy) + for _ in range(count) + ] + + +def build_evaluation_suite(config: ExperimentConfig) -> EvaluationSuite: + validation_uniform: dict[int, list[AdditionProblem]] = {} + test_uniform: dict[int, list[AdditionProblem]] = {} + test_carry_heavy: dict[int, list[AdditionProblem]] = {} + all_lengths = sorted(set(range(1, config.train_max_digits + 1)).union(config.ood_lengths)) + for length in all_lengths: + validation_uniform[length] = build_problem_set( + max_digits=length, + active_digits=length, + radix=config.radix, + count=config.eval_examples_per_length, + seed=10_000 + length, + carry_heavy=False, + ) + test_uniform[length] = build_problem_set( + max_digits=length, + active_digits=length, + radix=config.radix, + count=config.eval_examples_per_length, + seed=20_000 + length, + carry_heavy=False, + ) + test_carry_heavy[length] = build_problem_set( + max_digits=length, + active_digits=length, + radix=config.radix, + count=config.carry_heavy_examples_per_length, + seed=30_000 + length, + carry_heavy=True, + ) + return EvaluationSuite( + validation_uniform=validation_uniform, + test_uniform=test_uniform, + test_carry_heavy=test_carry_heavy, + ) + + +def digits_to_string(digits: Iterable[int], final_carry: int, radix: int) -> str: + digits = list(digits) + significant_digits = list(digits) + if final_carry: + significant_digits.append(final_carry) + while len(significant_digits) > 1 and significant_digits[-1] == 0: + significant_digits.pop() + symbols = DEFAULT_SYMBOLS[:radix] + return "".join(symbols[digit] for digit in reversed(significant_digits)) + + +def value_from_digits(digits: Iterable[int], final_carry: int, radix: int) -> int: + value = 0 + place = 1 + for digit in digits: + value += int(digit) * place + place *= radix + if final_carry: + value += int(final_carry) * place + return value + + +def exact_sum_matches( + predicted_digits: list[int], + predicted_final_carry: int, + truth_digits: list[int], + truth_final_carry: int, +) -> bool: + return predicted_digits == truth_digits and int(predicted_final_carry) == int(truth_final_carry) + + +def summarize_problem(problem: AdditionProblem, radix: int) -> dict[str, int | str]: + final_carry = problem.carry_out[problem.active_digits - 1] + return { + "a": digits_to_string(problem.a_digits[: problem.active_digits], final_carry=0, radix=radix), + "b": digits_to_string(problem.b_digits[: problem.active_digits], final_carry=0, radix=radix), + "sum": digits_to_string(problem.sum_digits[: problem.active_digits], final_carry=final_carry, radix=radix), + "radix": radix, + "active_digits": problem.active_digits, + "carry_heavy": int(problem.is_carry_heavy), + } + + +def count_carry_chain(problem: AdditionProblem) -> int: + longest = 0 + current = 0 + for index in range(problem.active_digits): + if problem.carry_out[index]: + current += 1 + longest = max(longest, current) + else: + current = 0 + return longest + + +def carry_density(problem: AdditionProblem) -> float: + if problem.active_digits <= 0: + return 0.0 + return float(sum(problem.carry_out[: problem.active_digits])) / float(problem.active_digits) + + +def curriculum_stage_lengths(config: ExperimentConfig) -> list[int]: + if config.uses_curriculum: + return list(range(1, config.train_max_digits + 1)) + return [config.train_max_digits] + + +def infer_eval_lengths(config: ExperimentConfig) -> list[int]: + return sorted(set(range(1, config.train_max_digits + 1)).union(config.ood_lengths)) + + +def estimate_train_tokens_per_step(config: ExperimentConfig, stage: int) -> int: + latent_steps = config.latent_steps_for_stage(stage) + return config.train_batch_size * (config.base_sequence_length_for_digits(stage) + latent_steps) + + +def stage_fraction(stage: int, max_stage: int) -> float: + if max_stage <= 1: + return 1.0 + return float(stage - 1) / float(max_stage - 1) + + +def maybe_trim_examples(problems: list[AdditionProblem], limit: int) -> list[AdditionProblem]: + if limit <= 0 or len(problems) <= limit: + return list(problems) + return list(problems[:limit]) + + +def stage_display_name(stage: int) -> str: + suffix = "th" + if stage % 10 == 1 and stage % 100 != 11: + suffix = "st" + elif stage % 10 == 2 and stage % 100 != 12: + suffix = "nd" + elif stage % 10 == 3 and stage % 100 != 13: + suffix = "rd" + return f"{stage}{suffix}-digit" + + +def ideal_carry_chain_examples(config: ExperimentConfig, active_digits: int) -> list[AdditionProblem]: + examples: list[AdditionProblem] = [] + for base_digit in (max(0, config.radix - 2), config.radix - 1): + a_digits = [base_digit] * active_digits + b_digits = [1] * active_digits + sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=config.radix) + examples.append( + AdditionProblem( + a_digits=a_digits, + b_digits=b_digits, + sum_digits=sum_digits, + carry_out=carry_out, + active_digits=active_digits, + is_carry_heavy=True, + ) + ) + return examples + + +def expected_sum_length(problem: AdditionProblem) -> int: + final_carry = problem.carry_out[problem.active_digits - 1] + return problem.active_digits + int(final_carry > 0) + + +def average_query_count(config: ExperimentConfig) -> float: + lengths = curriculum_stage_lengths(config) + return sum(lengths) / float(len(lengths)) + + +def token_budget(config: ExperimentConfig) -> int: + avg_stage = int(math.ceil(average_query_count(config))) + return config.base_sequence_length_for_digits(avg_stage) + config.latent_steps_for_stage(avg_stage) diff --git a/addition/eval.py b/addition/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..afdca22b1fc96509c444db1f88bd97e05bc75d3e --- /dev/null +++ b/addition/eval.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Iterable + +import torch +from torch import nn + +from addition.config import ExperimentConfig +from addition.data import ( + AdditionProblem, + EvaluationSuite, + build_batch, + carry_density, + count_carry_chain, + exact_sum_matches, + maybe_trim_examples, +) +from addition.model import AdditionTransformer + + +@dataclass +class LengthMetrics: + digit_accuracy: float + final_carry_accuracy: float + exact_match: float + avg_carry_chain: float + avg_carry_density: float + example_count: int + per_position_digit_accuracy: list[float] + + +def _chunked(sequence: list[AdditionProblem], chunk_size: int) -> Iterable[list[AdditionProblem]]: + for start in range(0, len(sequence), chunk_size): + yield sequence[start : start + chunk_size] + + +@torch.no_grad() +def evaluate_problem_set( + model: AdditionTransformer, + config: ExperimentConfig, + problems: list[AdditionProblem], + active_digits: int, + *, + device: str, + return_attention: bool = False, +) -> tuple[LengthMetrics, dict[str, float] | None]: + model.eval() + latent_steps = config.latent_steps_for_stage(active_digits) + num_examples = len(problems) + if num_examples == 0: + empty = LengthMetrics( + digit_accuracy=0.0, + final_carry_accuracy=0.0, + exact_match=0.0, + avg_carry_chain=0.0, + avg_carry_density=0.0, + example_count=0, + per_position_digit_accuracy=[0.0] * active_digits, + ) + return empty, None + + predicted_digits = torch.zeros(num_examples, active_digits, dtype=torch.long) + predicted_final_carry = torch.zeros(num_examples, dtype=torch.long) + truth_digits = torch.tensor([[problem.sum_digits[position] for position in range(active_digits)] for problem in problems], dtype=torch.long) + truth_final_carry = torch.tensor([problem.carry_out[active_digits - 1] for problem in problems], dtype=torch.long) + attention_stats: dict[str, float] | None = None + + offset = 0 + for problem_chunk in _chunked(problems, config.eval_batch_size): + batch = build_batch( + problems=problem_chunk, + radix=config.radix, + device=device, + ) + outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=return_attention) + chunk_size = len(problem_chunk) + predicted_digits[offset : offset + chunk_size] = outputs.digit_logits.argmax(dim=-1)[:, :active_digits].cpu() + predicted_final_carry[offset : offset + chunk_size] = outputs.final_carry_logits.argmax(dim=-1).cpu() + if return_attention and attention_stats is None: + attention_stats = summarize_attention( + attention_weights=outputs.attention_weights, + active_digits=active_digits, + input_sequence_length=batch.input_ids.shape[1], + output_sequence_length=outputs.output_hidden.shape[1], + ) + offset += chunk_size + + exact_matches = [] + for example_index, problem in enumerate(problems): + exact_matches.append( + exact_sum_matches( + predicted_digits=predicted_digits[example_index].tolist(), + predicted_final_carry=int(predicted_final_carry[example_index].item()), + truth_digits=problem.sum_digits[:active_digits], + truth_final_carry=problem.carry_out[active_digits - 1], + ) + ) + + per_position_digit = (predicted_digits == truth_digits).float().mean(dim=0).tolist() + metrics = LengthMetrics( + digit_accuracy=float((predicted_digits == truth_digits).float().mean().item()), + final_carry_accuracy=float((predicted_final_carry == truth_final_carry).float().mean().item()), + exact_match=float(torch.tensor(exact_matches, dtype=torch.float32).mean().item()), + avg_carry_chain=float(sum(count_carry_chain(problem) for problem in problems) / len(problems)), + avg_carry_density=float(sum(carry_density(problem) for problem in problems) / len(problems)), + example_count=len(problems), + per_position_digit_accuracy=[float(value) for value in per_position_digit], + ) + return metrics, attention_stats + + +def summarize_attention( + attention_weights: torch.Tensor | None, + *, + active_digits: int, + input_sequence_length: int, + output_sequence_length: int, +) -> dict[str, float]: + if attention_weights is None: + return {} + # Shape: [batch, heads, target_len, source_len] + final_attention = attention_weights[:, :, -1, :] + attention_mean = final_attention.mean(dim=(0, 1)) + active_last_a_index = active_digits + active_last_b_index = input_sequence_length // 2 + active_digits + latent_slice = attention_mean[input_sequence_length : -output_sequence_length] + output_slice = attention_mean[-output_sequence_length:-1] + entropy = -torch.sum(attention_mean * torch.log(attention_mean.clamp_min(1e-9))).item() + summary = { + "lsd_a_attention": float(attention_mean[1].item()), + "msd_a_attention": float(attention_mean[active_last_a_index].item()), + "lsd_b_attention": float(attention_mean[(input_sequence_length // 2) + 1].item()), + "msd_b_attention": float(attention_mean[active_last_b_index].item()), + "attention_entropy": float(entropy), + "all_latent_attention": float(latent_slice.sum().item()) if latent_slice.numel() else 0.0, + "previous_output_attention": float(output_slice.sum().item()) if output_slice.numel() else 0.0, + } + return summary + + +@torch.no_grad() +def evaluate_length_dict( + model: AdditionTransformer, + config: ExperimentConfig, + problems_by_length: dict[int, list[AdditionProblem]], + *, + device: str, + attention_length: int | None = None, +) -> dict[str, dict]: + structured: dict[str, dict] = {} + for length, problems in sorted(problems_by_length.items()): + length_metrics, attention = evaluate_problem_set( + model=model, + config=config, + problems=problems, + active_digits=length, + device=device, + return_attention=attention_length is not None and attention_length == length, + ) + structured[str(length)] = { + "digit_accuracy": length_metrics.digit_accuracy, + "final_carry_accuracy": length_metrics.final_carry_accuracy, + "exact_match": length_metrics.exact_match, + "avg_carry_chain": length_metrics.avg_carry_chain, + "avg_carry_density": length_metrics.avg_carry_density, + "example_count": length_metrics.example_count, + "per_position_digit_accuracy": length_metrics.per_position_digit_accuracy, + } + if attention is not None: + structured[str(length)]["attention_summary"] = attention + return structured + + +def collect_hidden_dataset( + model: AdditionTransformer, + config: ExperimentConfig, + problems: list[AdditionProblem], + *, + active_digits: int, + device: str, + limit_examples: int, +) -> tuple[torch.Tensor, torch.Tensor]: + model.eval() + latent_steps = config.latent_steps_for_stage(active_digits) + selected = maybe_trim_examples(problems, limit_examples) + hidden_states: list[torch.Tensor] = [] + carry_targets: list[torch.Tensor] = [] + with torch.no_grad(): + for problem_chunk in _chunked(selected, config.eval_batch_size): + batch = build_batch( + problems=problem_chunk, + radix=config.radix, + device=device, + ) + outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False) + slot_hidden = outputs.output_hidden[:, :active_digits, :] + slot_mask = batch.target_digit_mask + hidden_states.append(slot_hidden[slot_mask].detach().cpu()) + carry_targets.append(batch.target_carry[slot_mask].detach().cpu()) + return torch.cat(hidden_states, dim=0), torch.cat(carry_targets, dim=0) + + +def fit_linear_probe( + hidden_states: torch.Tensor, + carry_targets: torch.Tensor, + *, + epochs: int, + learning_rate: float, +) -> dict[str, float]: + if hidden_states.numel() == 0: + return {"probe_accuracy": 0.0} + indices = torch.randperm(hidden_states.shape[0]) + hidden_states = hidden_states[indices] + carry_targets = carry_targets[indices] + split_index = max(1, int(0.8 * hidden_states.shape[0])) + train_hidden = hidden_states[:split_index] + train_targets = carry_targets[:split_index] + test_hidden = hidden_states[split_index:] + test_targets = carry_targets[split_index:] + if test_hidden.numel() == 0: + test_hidden = train_hidden + test_targets = train_targets + + probe = nn.Linear(hidden_states.shape[-1], 2) + optimizer = torch.optim.AdamW(probe.parameters(), lr=learning_rate) + loss_fn = nn.CrossEntropyLoss() + for _ in range(epochs): + logits = probe(train_hidden) + loss = loss_fn(logits, train_targets) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + predictions = probe(test_hidden).argmax(dim=-1) + accuracy = float((predictions == test_targets).float().mean().item()) + return {"probe_accuracy": accuracy} + + +def evaluate_suite( + model: AdditionTransformer, + config: ExperimentConfig, + suite: EvaluationSuite, + *, + device: str, +) -> dict[str, dict]: + id_lengths = list(range(1, config.train_max_digits + 1)) + ood_lengths = list(config.ood_lengths) + max_attention_length = max(ood_lengths) if ood_lengths else config.train_max_digits + + validation = evaluate_length_dict( + model=model, + config=config, + problems_by_length={length: suite.validation_uniform[length] for length in id_lengths}, + device=device, + ) + uniform_all = evaluate_length_dict( + model=model, + config=config, + problems_by_length={length: suite.test_uniform[length] for length in sorted(set(id_lengths + ood_lengths))}, + device=device, + attention_length=max_attention_length, + ) + carry_heavy_all = evaluate_length_dict( + model=model, + config=config, + problems_by_length={length: suite.test_carry_heavy[length] for length in sorted(set(id_lengths + ood_lengths))}, + device=device, + attention_length=max_attention_length, + ) + probe_hidden, probe_targets = collect_hidden_dataset( + model=model, + config=config, + problems=suite.test_carry_heavy[max_attention_length], + active_digits=max_attention_length, + device=device, + limit_examples=config.attention_probe_examples, + ) + diagnostics = fit_linear_probe( + hidden_states=probe_hidden, + carry_targets=probe_targets, + epochs=config.linear_probe_epochs, + learning_rate=config.linear_probe_lr, + ) + diagnostics["attention_uniform"] = uniform_all[str(max_attention_length)].get("attention_summary", {}) + diagnostics["attention_carry_heavy"] = carry_heavy_all[str(max_attention_length)].get("attention_summary", {}) + return { + "validation_uniform": validation, + "test_uniform": uniform_all, + "test_carry_heavy": carry_heavy_all, + "diagnostics": diagnostics, + } + + +def stage_validation_metric(results: dict[str, dict], stage: int) -> float: + stage_metrics = results["validation_uniform"][str(stage)] + return float(stage_metrics["digit_accuracy"]) + + +def flatten_nested_metrics(prefix: str, nested: dict[str, dict]) -> dict[str, float]: + flat: dict[str, float] = {} + for split_name, split_metrics in nested.items(): + if split_name == "diagnostics": + for key, value in split_metrics.items(): + if isinstance(value, dict): + for inner_key, inner_value in value.items(): + flat[f"{prefix}{split_name}/{key}/{inner_key}"] = float(inner_value) + else: + flat[f"{prefix}{split_name}/{key}"] = float(value) + continue + for length, length_metrics in split_metrics.items(): + if not isinstance(length_metrics, dict): + continue + for metric_name, metric_value in length_metrics.items(): + if isinstance(metric_value, list): + if metric_value: + flat[f"{prefix}{split_name}/length_{length}/{metric_name}_mean"] = float(sum(metric_value) / len(metric_value)) + continue + if isinstance(metric_value, dict): + for inner_key, inner_value in metric_value.items(): + flat[f"{prefix}{split_name}/length_{length}/{metric_name}/{inner_key}"] = float(inner_value) + continue + flat[f"{prefix}{split_name}/length_{length}/{metric_name}"] = float(metric_value) + return flat diff --git a/addition/model.py b/addition/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5b15447221c4e0e64bbc5e7e5ce12cf0c739bb --- /dev/null +++ b/addition/model.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch +from torch import nn + +from addition.config import ExperimentConfig + + +@dataclass +class ModelOutput: + digit_logits: torch.Tensor + final_carry_logits: torch.Tensor + output_hidden: torch.Tensor + latent_history: list[torch.Tensor] + attention_weights: torch.Tensor | None + + +class TransformerBlock(nn.Module): + def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None: + super().__init__() + self.ln_1 = nn.LayerNorm(d_model) + self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True) + self.dropout = nn.Dropout(dropout) + self.ln_2 = nn.LayerNorm(d_model) + self.mlp = nn.Sequential( + nn.Linear(d_model, ff_dim), + nn.GELU(), + nn.Linear(ff_dim, d_model), + nn.Dropout(dropout), + ) + + def forward(self, hidden_states: torch.Tensor, need_weights: bool = False) -> tuple[torch.Tensor, torch.Tensor | None]: + seq_len = hidden_states.shape[1] + causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool).triu(1) + normed = self.ln_1(hidden_states) + attn_output, attn_weights = self.attn( + normed, + normed, + normed, + need_weights=need_weights, + average_attn_weights=False, + attn_mask=causal_mask, + ) + hidden_states = hidden_states + self.dropout(attn_output) + hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states)) + return hidden_states, attn_weights if need_weights else None + + +class AdditionTransformer(nn.Module): + def __init__(self, config: ExperimentConfig) -> None: + super().__init__() + self.config = config + self.token_embedding = nn.Embedding(config.discrete_vocab_size, config.d_model) + self.position_embedding = nn.Embedding(config.max_sequence_length, config.d_model) + self.latent_type_embedding = nn.Parameter(torch.zeros(config.d_model)) + self.output_slot_embeddings = nn.Parameter(torch.zeros(config.output_sequence_length, config.d_model)) + self.block = TransformerBlock( + d_model=config.d_model, + n_heads=config.n_heads, + ff_dim=config.ff_dim, + dropout=config.dropout, + ) + self.final_ln = nn.LayerNorm(config.d_model) + self.digit_head = nn.Linear(config.d_model, config.digit_vocab_size) + self.final_carry_head = nn.Linear(config.d_model, 2) + self.reset_parameters() + + def reset_parameters(self) -> None: + nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02) + nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02) + nn.init.normal_(self.latent_type_embedding, mean=0.0, std=0.02) + nn.init.normal_(self.output_slot_embeddings, mean=0.0, std=0.02) + nn.init.xavier_uniform_(self.digit_head.weight) + nn.init.zeros_(self.digit_head.bias) + nn.init.xavier_uniform_(self.final_carry_head.weight) + nn.init.zeros_(self.final_carry_head.bias) + + def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor: + seq_len = input_ids.shape[1] + positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) + return self.token_embedding(input_ids) + self.position_embedding(positions) + + def embed_output_slots( + self, + batch_size: int, + output_length: int, + latent_count: int, + input_length: int, + device: torch.device, + ) -> torch.Tensor: + positions = torch.arange(output_length, device=device) + input_length + latent_count + positioned = self.output_slot_embeddings[:output_length] + self.position_embedding(positions) + return positioned.unsqueeze(0).expand(batch_size, -1, -1) + + def _run_block( + self, + embeddings: torch.Tensor, + *, + need_attention: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + hidden_states, attention_weights = self.block(embeddings, need_weights=need_attention) + hidden_states = self.final_ln(hidden_states) + return hidden_states, attention_weights + + def forward( + self, + input_ids: torch.Tensor, + *, + latent_steps: int = 0, + return_attention: bool = False, + ) -> ModelOutput: + base_embeddings = self.embed_discrete_tokens(input_ids) + latent_history: list[torch.Tensor] = [] + attention_weights: torch.Tensor | None = None + batch_size = input_ids.shape[0] + input_length = input_ids.shape[1] + active_digits = max(1, (input_length - 2) // 2) + output_length = active_digits + 1 + output_embeddings = self.embed_output_slots( + batch_size=batch_size, + output_length=output_length, + latent_count=0, + input_length=input_length, + device=input_ids.device, + ) + hidden_states, attention_weights = self._run_block( + torch.cat([base_embeddings, output_embeddings], dim=1), + need_attention=return_attention, + ) + output_hidden = hidden_states[:, -output_length:, :] + summary_hidden = output_hidden[:, -1, :] + latent_history.append(summary_hidden) + + latent_embeddings: list[torch.Tensor] = [] + for step_index in range(int(latent_steps)): + latent_token = summary_hidden.unsqueeze(1) + self.latent_type_embedding.view(1, 1, -1) + latent_position_index = input_length + step_index + latent_token = latent_token + self.position_embedding.weight[latent_position_index].view(1, 1, -1) + latent_embeddings.append(latent_token) + output_embeddings = self.embed_output_slots( + batch_size=batch_size, + output_length=output_length, + latent_count=len(latent_embeddings), + input_length=input_length, + device=input_ids.device, + ) + hidden_states, attention_weights = self._run_block( + torch.cat([base_embeddings] + latent_embeddings + [output_embeddings], dim=1), + need_attention=return_attention, + ) + latent_index = input_length + step_index + summary_hidden = hidden_states[:, latent_index, :] + output_hidden = hidden_states[:, -output_length:, :] + latent_history.append(summary_hidden) + + digit_logits = self.digit_head(output_hidden[:, :active_digits, :]) + final_carry_logits = self.final_carry_head(output_hidden[:, -1, :]) + return ModelOutput( + digit_logits=digit_logits, + final_carry_logits=final_carry_logits, + output_hidden=output_hidden, + latent_history=latent_history, + attention_weights=attention_weights, + ) + + def parameter_count(self) -> int: + return sum(parameter.numel() for parameter in self.parameters()) + + +def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer: + model = AdditionTransformer(config) + if device is not None: + model = model.to(device) + return model + + +@torch.no_grad() +def describe_model(config: ExperimentConfig) -> dict[str, int]: + model = build_model(config) + total_params = model.parameter_count() + head_params = sum(parameter.numel() for name, parameter in model.named_parameters() if "head" in name) + embedding_params = sum(parameter.numel() for name, parameter in model.named_parameters() if "embedding" in name) + return { + "total_params": int(total_params), + "embedding_params": int(embedding_params), + "head_params": int(head_params), + "backbone_params": int(total_params - head_params), + } diff --git a/addition/plots.py b/addition/plots.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd6d0aef566dae27b8bed59ee8fd148c8e8d712 --- /dev/null +++ b/addition/plots.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import math +from pathlib import Path +from typing import Any + + +def _load_pyplot(): + import matplotlib.pyplot as plt + + return plt + + +def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]: + if not history: + return [] + plt = _load_pyplot() + output_dir.mkdir(parents=True, exist_ok=True) + steps = [entry["global_step"] for entry in history] + digit_acc = [entry["validation_digit_accuracy"] for entry in history] + carry_acc = [entry["validation_final_carry_accuracy"] for entry in history] + exact_match = [entry["validation_exact_match"] for entry in history] + stages = [entry["stage"] for entry in history] + + saved_paths: list[Path] = [] + + plt.figure(figsize=(8, 4.5)) + plt.plot(steps, digit_acc, label="Val digit acc") + plt.plot(steps, carry_acc, label="Val final carry acc") + plt.plot(steps, exact_match, label="Val exact match") + plt.xlabel("Global step") + plt.ylabel("Accuracy") + plt.ylim(0.0, 1.01) + plt.legend() + plt.tight_layout() + metrics_path = output_dir / "training_curves.png" + plt.savefig(metrics_path, dpi=160) + plt.close() + saved_paths.append(metrics_path) + + plt.figure(figsize=(8, 4.5)) + plt.step(steps, stages, where="post") + plt.xlabel("Global step") + plt.ylabel("Curriculum stage") + plt.tight_layout() + stage_path = output_dir / "stage_progression.png" + plt.savefig(stage_path, dpi=160) + plt.close() + saved_paths.append(stage_path) + + return saved_paths + + +def _collect_length_metric(aggregate: dict[str, Any], method: str, split: str, metric: str) -> tuple[list[int], list[float], list[float]]: + lengths = sorted(int(length) for length in aggregate[method][split].keys()) + means = [aggregate[method][split][str(length)][metric]["mean"] for length in lengths] + stds = [aggregate[method][split][str(length)][metric]["std"] for length in lengths] + return lengths, means, stds + + +def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]: + plt = _load_pyplot() + output_dir.mkdir(parents=True, exist_ok=True) + saved_paths: list[Path] = [] + methods = list(aggregate.keys()) + splits = [ + ("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"), + ("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"), + ] + for split, filename, title in splits: + plt.figure(figsize=(8, 4.5)) + for method in methods: + lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match") + plt.plot(lengths, means, marker="o", label=method) + lower = [max(0.0, mean - std) for mean, std in zip(means, stds)] + upper = [min(1.0, mean + std) for mean, std in zip(means, stds)] + plt.fill_between(lengths, lower, upper, alpha=0.15) + plt.xlabel("Active digits") + plt.ylabel("Exact-match accuracy") + plt.title(title) + plt.ylim(0.0, 1.01) + plt.legend() + plt.tight_layout() + path = output_dir / filename + plt.savefig(path, dpi=160) + plt.close() + saved_paths.append(path) + + plt.figure(figsize=(8, 4.5)) + for method in methods: + stages = sorted(int(stage) for stage in aggregate[method]["stage_progression"].keys()) + means = [aggregate[method]["stage_progression"][str(stage)]["validation_digit_accuracy"]["mean"] for stage in stages] + stds = [aggregate[method]["stage_progression"][str(stage)]["validation_digit_accuracy"]["std"] for stage in stages] + plt.plot(stages, means, marker="o", label=method) + plt.fill_between( + stages, + [max(0.0, mean - std) for mean, std in zip(means, stds)], + [min(1.0, mean + std) for mean, std in zip(means, stds)], + alpha=0.15, + ) + plt.xlabel("Curriculum stage") + plt.ylabel("Best validation digit accuracy") + plt.ylim(0.0, 1.01) + plt.title("Validation digit accuracy vs stage") + plt.legend() + plt.tight_layout() + stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png" + plt.savefig(stage_curve_path, dpi=160) + plt.close() + saved_paths.append(stage_curve_path) + return saved_paths + + +def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]: + plt = _load_pyplot() + output_dir.mkdir(parents=True, exist_ok=True) + saved_paths = plot_training_history(summary.get("history", []), output_dir) + uniform = summary["final_results"]["test_uniform"] + carry_heavy = summary["final_results"]["test_carry_heavy"] + lengths = sorted(int(length) for length in uniform.keys()) + uniform_exact = [uniform[str(length)]["exact_match"] for length in lengths] + carry_exact = [carry_heavy[str(length)]["exact_match"] for length in lengths] + plt.figure(figsize=(8, 4.5)) + plt.plot(lengths, uniform_exact, marker="o", label="Uniform") + plt.plot(lengths, carry_exact, marker="o", label="Carry-heavy") + plt.xlabel("Active digits") + plt.ylabel("Exact-match accuracy") + plt.ylim(0.0, 1.01) + plt.legend() + plt.tight_layout() + final_curve_path = output_dir / "final_exact_match_by_length.png" + plt.savefig(final_curve_path, dpi=160) + plt.close() + saved_paths.append(final_curve_path) + return saved_paths diff --git a/addition/run_comparison.py b/addition/run_comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..48b62f58bf617807a1a5f9f1c56dc10c800b61e0 --- /dev/null +++ b/addition/run_comparison.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from statistics import mean, pstdev +from typing import Any + +from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args +from addition.plots import plot_method_comparison +from addition.train import run_experiment + + +def _mean_std(values: list[float]) -> dict[str, float]: + if not values: + return {"mean": 0.0, "std": 0.0} + if len(values) == 1: + return {"mean": float(values[0]), "std": 0.0} + return {"mean": float(mean(values)), "std": float(pstdev(values))} + + +def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]: + lengths = sorted(run_summaries[0]["final_results"][split_name].keys(), key=int) + metric_names = ["digit_accuracy", "final_carry_accuracy", "exact_match", "avg_carry_chain", "avg_carry_density"] + aggregated: dict[str, Any] = {} + for length in lengths: + aggregated[length] = {} + for metric_name in metric_names: + values = [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries] + aggregated[length][metric_name] = _mean_std(values) + return aggregated + + +def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]: + max_stage = max(int(entry["stage"]) for summary in run_summaries for entry in summary["history"]) + aggregated: dict[str, Any] = {} + for stage in range(1, max_stage + 1): + stage_values = [] + stage_exact = [] + for summary in run_summaries: + stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage] + if not stage_entries: + continue + stage_values.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries)) + stage_exact.append(max(float(entry["validation_exact_match"]) for entry in stage_entries)) + aggregated[str(stage)] = { + "validation_digit_accuracy": _mean_std(stage_values), + "validation_exact_match": _mean_std(stage_exact), + } + return aggregated + + +def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]: + diagnostics = [summary["final_results"]["diagnostics"] for summary in run_summaries] + output: dict[str, Any] = { + "probe_accuracy": _mean_std([float(diag["probe_accuracy"]) for diag in diagnostics]), + } + for attention_key in ("attention_uniform", "attention_carry_heavy"): + attention_values = [diag.get(attention_key, {}) for diag in diagnostics] + metric_names = sorted({metric for diag in attention_values for metric in diag.keys()}) + output[attention_key] = { + metric_name: _mean_std([float(diag.get(metric_name, 0.0)) for diag in attention_values]) for metric_name in metric_names + } + return output + + +def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]: + aggregate: dict[str, Any] = {} + for method, run_summaries in results_by_method.items(): + aggregate[method] = { + "test_uniform": _aggregate_split_metrics(run_summaries, "test_uniform"), + "test_carry_heavy": _aggregate_split_metrics(run_summaries, "test_carry_heavy"), + "stage_progression": _aggregate_stage_progression(run_summaries), + "diagnostics": _aggregate_diagnostics(run_summaries), + } + return aggregate + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run the full addition comparison across methods and seeds.") + add_config_arguments(parser) + parser.add_argument("--methods", nargs="*", default=list(VALID_MODELS), choices=VALID_MODELS) + parser.add_argument("--seeds", nargs="*", type=int, default=None) + parser.add_argument("--comparison_output_dir", type=str, default="") + return parser + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + base_config = apply_preset(build_config_from_args(args)) + seeds = args.seeds or list(range(base_config.comparison_num_seeds)) + comparison_root = Path(args.comparison_output_dir or f"addition_runs/comparison_{base_config.preset}") + comparison_root.mkdir(parents=True, exist_ok=True) + + results_by_method: dict[str, list[dict[str, Any]]] = {} + for method in args.methods: + results_by_method[method] = [] + for seed in seeds: + args.model = method + args.seed = seed + args.output_dir = str(comparison_root / f"{method}_seed{seed}") + config = apply_preset(build_config_from_args(args)) + config.output_dir = str(comparison_root / f"{method}_seed{seed}") + print(f"[addition comparison] running method={method} seed={seed}", flush=True) + summary = run_experiment(config) + results_by_method[method].append(summary) + + aggregate = aggregate_runs(results_by_method) + aggregate_payload = { + "methods": args.methods, + "seeds": seeds, + "aggregate": aggregate, + } + with (comparison_root / "aggregate_results.json").open("w", encoding="utf-8") as handle: + json.dump(aggregate_payload, handle, indent=2, sort_keys=True) + plot_method_comparison(aggregate, comparison_root / "plots") + + +if __name__ == "__main__": + main() diff --git a/addition/train.py b/addition/train.py new file mode 100644 index 0000000000000000000000000000000000000000..90dc0937d6ddbb46c6a4801e7a4b37e041ffe570 --- /dev/null +++ b/addition/train.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Any + +import torch +from torch import nn + +from addition.config import ExperimentConfig, ensure_output_dirs, parse_config, save_config +from addition.data import build_batch, build_evaluation_suite, digits_to_string, exact_sum_matches, sample_training_batch, seed_everything +from addition.eval import evaluate_problem_set, evaluate_suite, flatten_nested_metrics +from addition.model import build_model, describe_model +from addition.plots import plot_single_run_results + + +def _maybe_init_wandb(config: ExperimentConfig, output_dir: Path): + if not config.use_wandb or config.wandb_mode == "disabled": + return None + try: + import wandb + except ImportError: + print("wandb is not installed; continuing with local logging only.") + return None + run = wandb.init( + project=config.wandb_project, + entity=config.wandb_entity or None, + name=config.effective_run_name, + mode=config.wandb_mode, + config={"experiment": config.__dict__}, + dir=str(output_dir), + reinit=True, + ) + return run + + +def _save_json(path: Path, payload: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + + +def _save_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer, metadata: dict[str, Any]) -> None: + torch.save( + { + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + "metadata": metadata, + }, + path, + ) + + +def _stage_checkpoint_path(stage_directory: Path, stage: int) -> Path: + return stage_directory / f"stage_{stage:02d}_passed.pt" + + +def _evaluate_current_stage( + model: nn.Module, + config: ExperimentConfig, + suite, + stage: int, + device: str, +) -> dict[str, float]: + stage_metrics, _ = evaluate_problem_set( + model=model, + config=config, + problems=suite.validation_uniform[stage], + active_digits=stage, + device=device, + return_attention=False, + ) + return { + "digit_accuracy": stage_metrics.digit_accuracy, + "final_carry_accuracy": stage_metrics.final_carry_accuracy, + "exact_match": stage_metrics.exact_match, + } + + +def _masked_digit_loss( + logits: torch.Tensor, + targets: torch.Tensor, + mask: torch.Tensor, + loss_fn: nn.Module, +) -> torch.Tensor: + masked_logits = logits[mask] + masked_targets = targets[mask] + if masked_logits.numel() == 0: + return logits.new_zeros(()) + return loss_fn(masked_logits, masked_targets) + + +@torch.no_grad() +def _print_model_debug_format( + model: nn.Module, + config: ExperimentConfig, + *, + stage: int, + rng, + device: str, +) -> None: + debug_batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device) + outputs = model(debug_batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) + print("[addition debug] model_architecture", flush=True) + print(model, flush=True) + print( + "[addition debug] batch_format " + f"stage={stage} input_shape={tuple(debug_batch.input_ids.shape)} " + f"target_digits_shape={tuple(debug_batch.target_digits.shape)} " + f"target_mask_shape={tuple(debug_batch.target_digit_mask.shape)} " + f"target_final_carry_shape={tuple(debug_batch.target_final_carry.shape)} " + f"digit_logits_shape={tuple(outputs.digit_logits.shape)} " + f"final_carry_logits_shape={tuple(outputs.final_carry_logits.shape)} " + f"output_hidden_shape={tuple(outputs.output_hidden.shape)}", + flush=True, + ) + + +@torch.no_grad() +def _print_validation_samples( + model: nn.Module, + config: ExperimentConfig, + problems, + *, + stage: int, + device: str, + limit: int = 3, +) -> None: + sample_problems = list(problems[:limit]) + if not sample_problems: + return + batch = build_batch(problems=sample_problems, radix=config.radix, device=device) + outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) + predicted_digits = outputs.digit_logits.argmax(dim=-1).cpu().tolist() + predicted_final_carry = outputs.final_carry_logits.argmax(dim=-1).cpu().tolist() + + for example_index, problem in enumerate(sample_problems): + truth_digits = problem.sum_digits[:stage] + truth_final_carry = problem.carry_out[stage - 1] + pred_digits = predicted_digits[example_index][:stage] + pred_final_carry = int(predicted_final_carry[example_index]) + exact = exact_sum_matches( + predicted_digits=pred_digits, + predicted_final_carry=pred_final_carry, + truth_digits=truth_digits, + truth_final_carry=truth_final_carry, + ) + a_text = digits_to_string(problem.a_digits[:stage], final_carry=0, radix=config.radix) + b_text = digits_to_string(problem.b_digits[:stage], final_carry=0, radix=config.radix) + pred_text = digits_to_string(pred_digits, final_carry=pred_final_carry, radix=config.radix) + truth_text = digits_to_string(truth_digits, final_carry=truth_final_carry, radix=config.radix) + print( + f"[addition sample] stage={stage} idx={example_index} " + f"a={a_text} b={b_text} pred={pred_text} true={truth_text} " + f"pred_digits={pred_digits} pred_carry={pred_final_carry} " + f"true_digits={truth_digits} true_carry={truth_final_carry} exact={int(exact)}", + flush=True, + ) + + +def run_experiment(config: ExperimentConfig) -> dict[str, Any]: + directories = ensure_output_dirs(config) + save_config(config, directories["root"]) + seed_everything(config.seed) + device = config.device + model = build_model(config, device=device) + optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay) + digit_loss_fn = nn.CrossEntropyLoss() + final_carry_loss_fn = nn.CrossEntropyLoss() + suite = build_evaluation_suite(config) + rng = __import__("random").Random(config.seed + 12345) + history: list[dict[str, Any]] = [] + best_validation = -1.0 + best_checkpoint_path = directories["checkpoints"] / "best.pt" + last_checkpoint_path = directories["checkpoints"] / "last.pt" + stage = config.initial_stage if config.uses_curriculum else config.train_max_digits + stage_steps = 0 + global_step = 0 + stop_reason = "train_steps_exhausted" + wandb_run = _maybe_init_wandb(config, directories["root"]) + started_at = time.time() + param_counts = describe_model(config) + print( + f"[addition train] model={config.model} seed={config.seed} device={device} " + f"params={param_counts['total_params']} stage={stage}", + flush=True, + ) + _print_model_debug_format(model=model, config=config, stage=stage, rng=rng, device=device) + + while global_step < config.train_steps: + model.train() + batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device) + optimizer.zero_grad(set_to_none=True) + outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False) + digit_loss = _masked_digit_loss( + logits=outputs.digit_logits, + targets=batch.target_digits, + mask=batch.target_digit_mask, + loss_fn=digit_loss_fn, + ) + final_carry_loss = final_carry_loss_fn(outputs.final_carry_logits, batch.target_final_carry) + loss = digit_loss + final_carry_loss + loss.backward() + if config.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm) + optimizer.step() + + global_step += 1 + stage_steps += 1 + + if global_step % max(1, config.validation_interval // 2) == 0: + train_message = ( + f"[addition train] step={global_step} stage={stage} " + f"loss={loss.item():.4f} digit_loss={digit_loss.item():.4f} " + f"final_carry_loss={final_carry_loss.item():.4f}" + ) + print(train_message, flush=True) + + should_validate = ( + global_step % config.validation_interval == 0 + or global_step == config.train_steps + or (config.uses_curriculum and stage_steps == config.max_steps_per_stage) + ) + if not should_validate: + continue + + validation = _evaluate_current_stage(model=model, config=config, suite=suite, stage=stage, device=device) + history_entry = { + "global_step": global_step, + "stage": stage, + "stage_steps": stage_steps, + "loss": float(loss.item()), + "digit_loss": float(digit_loss.item()), + "final_carry_loss": float(final_carry_loss.item()), + "validation_digit_accuracy": validation["digit_accuracy"], + "validation_final_carry_accuracy": validation["final_carry_accuracy"], + "validation_exact_match": validation["exact_match"], + "latent_steps": config.latent_steps_for_stage(stage), + } + history.append(history_entry) + print( + f"[addition val] step={global_step} stage={stage} " + f"digit_acc={validation['digit_accuracy']:.4f} final_carry_acc={validation['final_carry_accuracy']:.4f} " + f"exact={validation['exact_match']:.4f}", + flush=True, + ) + _print_validation_samples( + model=model, + config=config, + problems=suite.validation_uniform[stage], + stage=stage, + device=device, + ) + if wandb_run is not None: + payload = { + "train/loss": float(loss.item()), + "train/digit_loss": float(digit_loss.item()), + "train/final_carry_loss": float(final_carry_loss.item()), + "train/stage": stage, + "train/latent_steps": config.latent_steps_for_stage(stage), + "validation/digit_accuracy": validation["digit_accuracy"], + "validation/final_carry_accuracy": validation["final_carry_accuracy"], + "validation/exact_match": validation["exact_match"], + "step": global_step, + } + wandb_run.log(payload) + + if validation["exact_match"] >= best_validation: + best_validation = validation["exact_match"] + _save_checkpoint( + best_checkpoint_path, + model, + optimizer, + metadata={ + "global_step": global_step, + "stage": stage, + "best_validation_exact_match": best_validation, + }, + ) + + reached_threshold = validation["exact_match"] >= config.stage_accuracy_threshold + reached_cap = stage_steps >= config.max_steps_per_stage + + if config.uses_curriculum: + if stage < config.train_max_digits and reached_threshold: + _save_checkpoint( + _stage_checkpoint_path(directories["stage_checkpoints"], stage), + model, + optimizer, + metadata={ + "global_step": global_step, + "stage": stage, + "validation_exact_match": validation["exact_match"], + "validation_digit_accuracy": validation["digit_accuracy"], + "validation_final_carry_accuracy": validation["final_carry_accuracy"], + }, + ) + print( + f"[addition curriculum] advance {stage} -> {stage + 1} " + f"(exact_match={validation['exact_match']:.4f})", + flush=True, + ) + stage += 1 + stage_steps = 0 + continue + if reached_cap and not reached_threshold: + print( + f"[addition curriculum] hold stage={stage} after {stage_steps} steps " + f"(exact_match={validation['exact_match']:.4f} < threshold={config.stage_accuracy_threshold:.2f})", + flush=True, + ) + if stage == config.train_max_digits and reached_threshold: + stop_reason = "final_stage_threshold" + break + + _save_checkpoint( + last_checkpoint_path, + model, + optimizer, + metadata={ + "global_step": global_step, + "stage": stage, + "stop_reason": stop_reason, + }, + ) + + best_payload = torch.load(best_checkpoint_path, map_location=device) + model.load_state_dict(best_payload["model_state"]) + final_results = evaluate_suite(model=model, config=config, suite=suite, device=device) + flat_final_metrics = flatten_nested_metrics("", final_results) + summary = { + "config": config.__dict__, + "param_counts": param_counts, + "best_validation_exact_match": best_validation, + "global_step": global_step, + "final_stage": stage, + "stop_reason": stop_reason, + "elapsed_seconds": time.time() - started_at, + "history": history, + "final_results": final_results, + "flat_final_metrics": flat_final_metrics, + } + _save_json(directories["artifacts"] / "summary.json", summary) + with (directories["artifacts"] / "history.jsonl").open("w", encoding="utf-8") as handle: + for entry in history: + handle.write(json.dumps(entry, sort_keys=True) + "\n") + plot_single_run_results(summary, directories["plots"]) + + if wandb_run is not None: + wandb_run.log(flat_final_metrics | {"step": global_step}) + wandb_run.summary.update( + { + "best_validation_exact_match": best_validation, + "final_stage": stage, + "stop_reason": stop_reason, + } + ) + wandb_run.finish() + + return summary + + +def main() -> None: + config = parse_config("Train the addition carry experiment.") + run_experiment(config) + + +if __name__ == "__main__": + main() diff --git a/aligned_cell_policy/shared_cell_policy.py b/aligned_cell_policy/shared_cell_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..a7881d5cfd0266048a70c1fabd66fa7eec131c79 --- /dev/null +++ b/aligned_cell_policy/shared_cell_policy.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import os +import re +import sys +from dataclasses import dataclass +from typing import Any, Dict, List + +import numpy as np + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from format_utils_icon import parse_n_value_prediction + + +@dataclass(frozen=True) +class CellExample: + grid: np.ndarray + target_cell: tuple[int, int] + target_value: int + turn_idx: int + total_turns: int + + +_TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)") + + +def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray: + triples = _TUPLE_PROMPT_RE.findall(str(prompt_text)) + if len(triples) < 81: + raise ValueError("Could not recover 81 (row,col,value) tuples from prompt.") + grid = np.zeros((9, 9), dtype=int) + for rr, cc, vv in triples[:81]: + grid[int(rr) - 1, int(cc) - 1] = int(vv) + return grid + + +def build_cell_examples_from_row(row: Dict[str, Any]) -> List[CellExample]: + prompt = str(row["prompt"]) + grid = parse_grid_from_tuple_prompt(prompt) + metadata = dict(row.get("metadata", {})) + empty_locs = metadata.get("empty_locs_1based") + target_triples = metadata.get("target_triples_1based") + + if not empty_locs or not target_triples: + completion = str(row.get("completion", "")) + parsed, _ = parse_n_value_prediction(completion, int(metadata.get("empties", 0) or 0)) + if parsed is None: + raise ValueError("Row is missing metadata and completion could not be parsed.") + empty_locs = [(r + 1, c + 1) for r, c in np.argwhere(grid == 0).tolist()] + target_triples = [(int(r), int(c), int(v)) for (r, c), v in zip(empty_locs, parsed)] + + total_turns = len(target_triples) + out: List[CellExample] = [] + for idx, triple in enumerate(target_triples, start=1): + rr, cc, value = int(triple[0]) - 1, int(triple[1]) - 1, int(triple[2]) + out.append( + CellExample( + grid=np.asarray(grid, dtype=int).copy(), + target_cell=(rr, cc), + target_value=value, + turn_idx=idx, + total_turns=total_turns, + ) + ) + return out diff --git a/analysis/eval_saved_hard9x9_checkpoints.py b/analysis/eval_saved_hard9x9_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8bea67bdf554892ab5c75ece147df681e2d491 --- /dev/null +++ b/analysis/eval_saved_hard9x9_checkpoints.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import json +import os +from types import SimpleNamespace +import sys + +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +if ROOT := "/home/ubuntu/curriculum_cot": + if ROOT not in sys.path: + sys.path.insert(0, ROOT) + +from multi_output_cell_policy import grpo_multi_output_train as baseline_grpo +from multi_output_cell_policy import sft_multi_output_train as baseline_sft +from latent_multi_output_cell_policy import grpo_residual_projector_latent_train as latent_grpo +from latent_multi_output_cell_policy import residual_projector_warmstart_sft_latent_multi_output_train as latent_sft + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +CACHE_DIR = os.path.join(ROOT, ".hf_cache") +DATA_PATH = os.path.join(ROOT, "data", "sudoku_t3_30empty_value_qwen_text.jsonl") +EVAL_ROWS = 20 +TOTAL_EMPTIES_HINT = 30 + + +def make_tokenizer() -> AutoTokenizer: + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + return tokenizer + + +def make_device() -> torch.device: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def make_baseline_sft_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module: + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + cache_dir=CACHE_DIR, + torch_dtype=baseline_sft.pick_dtype() if torch.cuda.is_available() else torch.float32, + low_cpu_mem_usage=True, + ) + model = PeftModel.from_pretrained(base, checkpoint_dir, is_trainable=False) + model.to(device) + model.eval() + return model + + +def make_baseline_grpo_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module: + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + cache_dir=CACHE_DIR, + torch_dtype=baseline_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32, + low_cpu_mem_usage=True, + ) + model = baseline_grpo.load_trainable_adapter(base, checkpoint_dir) + model.to(device) + model.eval() + return model + + +def make_latent_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module: + base = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + cache_dir=CACHE_DIR, + torch_dtype=latent_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32, + low_cpu_mem_usage=True, + ) + model = latent_grpo.load_trainable_adapter(base, checkpoint_dir) + projector_hidden = latent_grpo.infer_projector_hidden_from_state(checkpoint_dir) or latent_grpo.PROJECTOR_HIDDEN + latent_grpo.attach_residual_projector_modules( + model, + hidden_size=int(latent_grpo.unwrap_backbone(model).config.hidden_size), + projector_hidden=projector_hidden, + ) + latent_grpo.maybe_load_projector_state(model, checkpoint_dir) + model.to(device) + model.eval() + return model + + +def common_reward_args() -> dict: + return { + "reward_good_value": 1.0, + "penalty_bad_value": 1.75, + "penalty_malformed": 4.0, + "penalty_empty": 0.5, + "penalty_singleton": 1.5, + } + + +def eval_baseline_sft(checkpoint_dir: str, stage_i: int) -> dict: + device = make_device() + tokenizer = make_tokenizer() + model = make_baseline_sft_model(checkpoint_dir, device) + rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS) + args = SimpleNamespace( + stage_i=int(stage_i), + total_empties_hint=TOTAL_EMPTIES_HINT, + max_completion_length=24, + debug_print_limit=0, + ) + metrics = baseline_sft.run_eval(args, rows, model, tokenizer, device) + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return metrics + + +def eval_baseline_grpo(checkpoint_dir: str, stage_i: int) -> dict: + device = make_device() + tokenizer = make_tokenizer() + model = make_baseline_grpo_model(checkpoint_dir, device) + rows = baseline_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS) + args = SimpleNamespace( + stage_i=int(stage_i), + total_empties_hint=TOTAL_EMPTIES_HINT, + max_completion_length=24, + debug_print_limit=0, + **common_reward_args(), + ) + metrics = baseline_grpo.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device) + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return metrics + + +def eval_latent_sft(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict: + device = make_device() + tokenizer = make_tokenizer() + model = make_latent_model(checkpoint_dir, device) + rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS) + args = SimpleNamespace( + stage_i=int(stage_i), + num_cot_tokens=int(num_cot_tokens), + total_empties_hint=TOTAL_EMPTIES_HINT, + max_completion_length=32, + debug_print_limit=0, + **common_reward_args(), + ) + metrics = latent_sft.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device, eval_stage_i=int(stage_i)) + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return metrics + + +def eval_latent_grpo(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict: + device = make_device() + tokenizer = make_tokenizer() + model = make_latent_model(checkpoint_dir, device) + rows = latent_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS) + args = SimpleNamespace( + stage_i=int(stage_i), + num_cot_tokens=int(num_cot_tokens), + total_empties_hint=TOTAL_EMPTIES_HINT, + max_completion_length=32, + debug_print_limit=0, + **common_reward_args(), + ) + metrics = latent_grpo.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device, eval_stage_i=int(stage_i)) + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return metrics + + +def main() -> None: + # Explicit step dirs (not run roots) so metrics match the agreed endpoints. + checkpoints = [ + { + "label": "baseline_stage1_sft", + "stage_i": 1, + "kind": "baseline_sft", + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/20260404_023600_baseline30_clean/baseline_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty/checkpoint-step-01000", + ), + }, + { + "label": "baseline_stage1_grpo", + "stage_i": 1, + "kind": "baseline_grpo", + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i1_20260404_fixed_baseline_grpo_i1/checkpoint-5350", + ), + }, + { + "label": "baseline_stage2_sft", + "stage_i": 2, + "kind": "baseline_sft", + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft/i2_20260404_stage2_baseline_sft_from_grpo5350/checkpoint-step-13100", + ), + }, + { + "label": "baseline_stage2_grpo", + "stage_i": 2, + "kind": "baseline_grpo", + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i2_20260405_stage2_baseline_grpo_from_sft13100/checkpoint-4325", + ), + }, + { + "label": "latent_stage1_sft", + "stage_i": 1, + "kind": "latent_sft", + "num_cot_tokens": 1, + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/20260404_013500_latent30_frombaseline/latent_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty_residual_projector/checkpoint-step-00200", + ), + }, + { + "label": "latent_stage1_grpo", + "stage_i": 1, + "kind": "latent_grpo", + "num_cot_tokens": 1, + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740", + ), + }, + { + "label": "latent_stage2_sft", + "stage_i": 2, + "kind": "latent_sft", + "num_cot_tokens": 2, + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft/i2_cot2_20260404_stage2_latent_sft_from_grpo2740/checkpoint-step-00700", + ), + }, + { + "label": "latent_stage2_grpo", + "stage_i": 2, + "kind": "latent_grpo", + "num_cot_tokens": 2, + "checkpoint_dir": os.path.join( + ROOT, + "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i2_cot2_20260405_stage2_latent_grpo_from_sft00700/checkpoint-1620", + ), + }, + ] + + results: dict[str, dict] = {} + for item in checkpoints: + label = item["label"] + print(f"[eval] starting {label}", flush=True) + if item["kind"] == "baseline_sft": + metrics = eval_baseline_sft(item["checkpoint_dir"], item["stage_i"]) + elif item["kind"] == "baseline_grpo": + metrics = eval_baseline_grpo(item["checkpoint_dir"], item["stage_i"]) + elif item["kind"] == "latent_sft": + metrics = eval_latent_sft(item["checkpoint_dir"], item["stage_i"], item["num_cot_tokens"]) + else: + metrics = eval_latent_grpo(item["checkpoint_dir"], item["stage_i"], item["num_cot_tokens"]) + results[label] = metrics + print(json.dumps({"label": label, "metrics": metrics}, sort_keys=True), flush=True) + + print("[eval] complete", flush=True) + print(json.dumps(results, sort_keys=True, indent=2), flush=True) + + +if __name__ == "__main__": + main() diff --git a/checkpoint_utils.py b/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..399d5ce4a3555ef487557a0f4977597b65e42cf3 --- /dev/null +++ b/checkpoint_utils.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import os +import shutil +from typing import Any, Callable + +import torch +from peft import get_peft_model_state_dict +from safetensors.torch import save_file as save_safetensors_file + +FINAL_CHECKPOINT_DIRNAME = "final_checkpoint" +_WEIGHT_FILENAMES = ( + "adapter_model.safetensors", + "adapter_model.bin", + "model.safetensors", + "pytorch_model.bin", +) + + +def ensure_final_checkpoint_dir(output_dir: str) -> str: + repo_root = os.path.dirname(os.path.abspath(__file__)) + output_dir_abs = os.path.abspath(output_dir) + try: + rel_output_dir = os.path.relpath(output_dir_abs, repo_root) + except Exception: + rel_output_dir = os.path.basename(output_dir_abs.rstrip(os.sep)) + rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ("", ".")] + if rel_parts and rel_parts[0] == FINAL_CHECKPOINT_DIRNAME: + rel_parts = rel_parts[1:] + if rel_parts and rel_parts[0] == "checkpoints": + rel_parts = rel_parts[1:] + if not rel_parts: + rel_parts = [os.path.basename(output_dir_abs.rstrip(os.sep)) or "run"] + final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts) + os.makedirs(final_dir, exist_ok=True) + return final_dir + + +def final_checkpoint_root(*parts: str) -> str: + repo_root = os.path.dirname(os.path.abspath(__file__)) + root = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *parts) + os.makedirs(root, exist_ok=True) + return root + + +def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str: + raw = str(path or "").strip() + if not raw: + return final_checkpoint_root(*default_parts) + abs_path = os.path.abspath(raw) + repo_root = os.path.dirname(os.path.abspath(__file__)) + rel_path = os.path.relpath(abs_path, repo_root) + rel_parts = [part for part in rel_path.split(os.sep) if part not in ("", ".")] + if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]: + return abs_path + if rel_parts[:1] == ["checkpoints"]: + rel_parts = rel_parts[1:] + return final_checkpoint_root(*rel_parts) + return abs_path + + +def _has_saved_weights(target_dir: str) -> bool: + return any(os.path.exists(os.path.join(target_dir, name)) for name in _WEIGHT_FILENAMES) + + +def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None: + if _has_saved_weights(target_dir): + return + state = get_peft_model_state_dict(model) + cpu_state = { + key: value.detach().cpu().contiguous() + for key, value in state.items() + if torch.is_tensor(value) + } + if cpu_state: + save_safetensors_file(cpu_state, os.path.join(target_dir, "adapter_model.safetensors")) + + +def save_model_artifacts( + model: Any, + tokenizer: Any, + target_dir: str, + *, + extra_save_fn: Callable[[Any, str], None] | None = None, +) -> str: + os.makedirs(target_dir, exist_ok=True) + model.save_pretrained(target_dir) + if tokenizer is not None: + tokenizer.save_pretrained(target_dir) + _fallback_save_adapter_weights(model, target_dir) + if extra_save_fn is not None: + extra_save_fn(model, target_dir) + return target_dir + + +def _replace_dir_contents(src_dir: str, dst_dir: str) -> None: + os.makedirs(dst_dir, exist_ok=True) + src_dir_abs = os.path.abspath(src_dir) + for name in os.listdir(dst_dir): + path = os.path.join(dst_dir, name) + if os.path.abspath(path) == src_dir_abs: + continue + if os.path.isdir(path) and not os.path.islink(path): + shutil.rmtree(path) + else: + os.unlink(path) + for name in os.listdir(src_dir): + src_path = os.path.join(src_dir, name) + dst_path = os.path.join(dst_dir, name) + if os.path.isdir(src_path) and not os.path.islink(src_path): + shutil.copytree(src_path, dst_path) + else: + shutil.copy2(src_path, dst_path) + + +def save_checkpoint_and_update_final( + model: Any, + tokenizer: Any, + output_dir: str, + checkpoint_name: str, + *, + extra_save_fn: Callable[[Any, str], None] | None = None, +) -> str: + checkpoint_dir = os.path.join(output_dir, checkpoint_name) + save_model_artifacts(model, tokenizer, checkpoint_dir, extra_save_fn=extra_save_fn) + _replace_dir_contents(checkpoint_dir, ensure_final_checkpoint_dir(output_dir)) + return checkpoint_dir diff --git a/format_utils_icon.py b/format_utils_icon.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc5c99d4dcf79cb577f774db76d1e18615b5478 --- /dev/null +++ b/format_utils_icon.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import json +import re +from typing import List, Tuple + +import numpy as np + + +_INT_RE = re.compile(r"-?\d+") + + +def grid_to_text(grid_9x9: np.ndarray) -> str: + grid = np.asarray(grid_9x9, dtype=int).reshape(9, 9) + return "\n".join(" ".join(str(int(value)) for value in row) for row in grid.tolist()) + + +def parse_n_value_prediction(text: str, n: int) -> Tuple[List[int] | None, bool]: + raw = str(text or "").strip() + if not raw: + return None, False + + try: + parsed = json.loads(raw) + if isinstance(parsed, dict) and isinstance(parsed.get("values"), list): + values = [int(v) for v in parsed["values"]] + if len(values) == int(n): + return values, True + if isinstance(parsed, list): + values = [int(v) for v in parsed] + if len(values) == int(n): + return values, True + except Exception: + pass + + values = [int(match.group(0)) for match in _INT_RE.finditer(raw)] + if len(values) == int(n): + return values, True + return None, False diff --git a/formatting_icon.py b/formatting_icon.py new file mode 100644 index 0000000000000000000000000000000000000000..e6aa5947dca971a8465df038dfb57805ec087988 --- /dev/null +++ b/formatting_icon.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import numpy as np + + +def is_consistent_pair(grid, *, cell: int, value: int, t: int = 3, n: int = 9) -> bool: + g = np.asarray(grid, dtype=int).reshape(int(n), int(n)) + cell = int(cell) + value = int(value) + if value < 1 or value > int(n): + return False + rr, cc = divmod(cell, int(n)) + current = int(g[rr, cc]) + if current != 0 and current != value: + return False + + row = g[rr, :] + for idx, existing in enumerate(row): + if idx != cc and int(existing) == value: + return False + + col = g[:, cc] + for idx, existing in enumerate(col): + if idx != rr and int(existing) == value: + return False + + box_r = (rr // int(t)) * int(t) + box_c = (cc // int(t)) * int(t) + for r in range(box_r, box_r + int(t)): + for c in range(box_c, box_c + int(t)): + if (r != rr or c != cc) and int(g[r, c]) == value: + return False + + return True diff --git a/hard_9x9_10empty/launch_baseline_stage3_pipeline.sh b/hard_9x9_10empty/launch_baseline_stage3_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..03d22c891a8e9bc534ff2db77f2514db3040361e --- /dev/null +++ b/hard_9x9_10empty/launch_baseline_stage3_pipeline.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py" + +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_10empty_value_qwen_text_longrun.jsonl}" +NUM_PUZZLES="${NUM_PUZZLES:-5000}" +DATASET_SEED="${DATASET_SEED:-0}" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_qwen05b/baseline}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_10empty_3stage_hard9x9}" + +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" +WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}" +WAIT_SECONDS="${WAIT_SECONDS:-60}" + +if [[ ! -f "${TRAIN_JSONL}" ]]; then + mkdir -p "$(dirname "${TRAIN_JSONL}")" + printf 'Building 10-empty dataset: %s\n' "${TRAIN_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${TRAIN_JSONL}" \ + --num_puzzles "${NUM_PUZZLES}" \ + --empties 10 \ + --seed "${DATASET_SEED}" +fi + +if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then + while pgrep -f "/home/ubuntu/curriculum_cot/.venv/bin/python.*(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)" >/dev/null; do + printf 'Existing training detected; waiting %ss before launching 10-empty baseline pipeline...\n' "${WAIT_SECONDS}" + sleep "${WAIT_SECONDS}" + done +fi + +mkdir -p "${CHECKPOINT_ROOT}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${ROOT}/.hf_cache" + --model_name "Qwen/Qwen2.5-0.5B-Instruct" + --checkpoint_root "${CHECKPOINT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage 1 + --max_stage 3 + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint 10 + --limit_train_rows 5000 + --sft_num_epochs 3.0 + --grpo_num_train_epochs 1.5 + --sft_gradient_accumulation_steps 8 + --grpo_per_device_train_batch_size 8 + --grpo_gradient_accumulation_steps 2 + --grpo_num_generations 4 + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_eval_steps 100 + --sft_save_steps 100 + --grpo_eval_steps 50 + --grpo_save_steps 50 + --sft_eval_rows 100 + --grpo_eval_rows 100 + --sft_stage_max_steps "1:2000,2:2000,3:2000" + --grpo_stage_max_steps "1:1200,2:1200,3:1200" + --sft_eval_solve_rate_stop 0.8 + --sft_min_steps_before_stop 100 + --grpo_eval_solve_rate_stop 0.8 + --grpo_min_steps_before_stop 50 + --grpo_reward_good_value 1.25 + --grpo_penalty_bad_value 1.0 + --grpo_penalty_malformed 4.0 + --grpo_penalty_empty 0.5 + --grpo_penalty_singleton 1.0 + --phase_max_wall_clock_seconds 36000 + --wandb_mode "${WANDB_MODE}" + --use_wandb +) + +if [[ -n "${WANDB_ENTITY}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching 10-empty baseline stage-3 pipeline\n' +printf 'Dataset: %s\n' "${TRAIN_JSONL}" +printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" + +exec "${cmd[@]}" diff --git a/hard_9x9_15empty/launch_baseline_pipeline.sh b/hard_9x9_15empty/launch_baseline_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..42eb74fd4dd86690bf7a1523e35b5b791743a289 --- /dev/null +++ b/hard_9x9_15empty/launch_baseline_pipeline.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +PIPELINE_LAUNCHER="${ROOT}/large_baseline_extension/launch_nonlocation_pipeline.sh" + +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text.jsonl}" +NUM_PUZZLES="${NUM_PUZZLES:-20000}" +DATASET_SEED="${DATASET_SEED:-0}" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_15empty_4stage_hard9x9}" + +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" +WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}" +WAIT_SECONDS="${WAIT_SECONDS:-60}" + +if [[ ! -f "${TRAIN_JSONL}" ]]; then + mkdir -p "$(dirname "${TRAIN_JSONL}")" + printf 'Building 15-empty dataset: %s\n' "${TRAIN_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${TRAIN_JSONL}" \ + --num_puzzles "${NUM_PUZZLES}" \ + --empties 15 \ + --seed "${DATASET_SEED}" +fi + +if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then + while pgrep -f "/home/ubuntu/curriculum_cot/.venv/bin/python.*(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)" >/dev/null; do + printf 'Existing training detected; waiting %ss before launching 15-empty baseline...\n' "${WAIT_SECONDS}" + sleep "${WAIT_SECONDS}" + done +fi + +mkdir -p "${CHECKPOINT_ROOT}" + +export TRAIN_JSONL +export TOTAL_EMPTIES_HINT=15 +export GPU_IDS +export NUM_PROCESSES +export MIN_STAGE +export MAX_STAGE +export RUN_TAG +export CHECKPOINT_ROOT +export OUTPUT_ROOT +export WANDB_MODE +export WANDB_ENTITY + +printf 'Launching 15-empty hard 9x9 baseline pipeline\n' +printf 'Dataset: %s\n' "${TRAIN_JSONL}" +printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" + +exec "${PIPELINE_LAUNCHER}" diff --git a/hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh b/hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..2fae68f74653d940ac2f89748d0119f01ae7b22d --- /dev/null +++ b/hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" + +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_train.jsonl}" +EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_eval.jsonl}" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-2000}" +TRAIN_SEED="${TRAIN_SEED:-0}" +EVAL_SEED="${EVAL_SEED:-1}" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline_stage1_multivalue}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_15empty_size2only}" + +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +if [[ ! -f "${TRAIN_JSONL}" ]]; then + mkdir -p "$(dirname "${TRAIN_JSONL}")" + printf 'Building 15-empty train dataset: %s\n' "${TRAIN_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${TRAIN_JSONL}" \ + --num_puzzles "${TRAIN_PUZZLES}" \ + --empties 15 \ + --seed "${TRAIN_SEED}" +fi + +if [[ ! -f "${EVAL_JSONL}" ]]; then + mkdir -p "$(dirname "${EVAL_JSONL}")" + printf 'Building 15-empty eval dataset: %s\n' "${EVAL_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${EVAL_JSONL}" \ + --num_puzzles "${EVAL_PUZZLES}" \ + --empties 15 \ + --seed "${EVAL_SEED}" +fi + +mkdir -p "${CHECKPOINT_ROOT}" + +cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" + --model_name "Qwen/Qwen2.5-0.5B-Instruct" + --train_jsonl "${TRAIN_JSONL}" + --eval_jsonl "${EVAL_JSONL}" + --output_dir "${OUTPUT_DIR}" + --cache_dir "${ROOT}/.hf_cache" + --seed 0 + --gpu_id 0 + --stage_i 1 + --total_empties_hint 15 + --per_device_train_batch_size 16 + --gradient_accumulation_steps 2 + --num_epochs 4.0 + --learning_rate 2e-4 + --enable_gradient_checkpointing + --logging_steps 10 + --eval_steps 50 + --save_steps 50 + --eval_rows "${EVAL_PUZZLES}" + --max_completion_length 24 + --limit_train_rows "${TRAIN_PUZZLES}" + --lora_r 32 + --lora_alpha 64 + --lora_dropout 0.05 + --multi_value_oversample_factor 1 + --train_target_size_min 2 + --train_target_size_max 2 + --eval_target_size_min 2 + --eval_target_size_max 2 + --eval_value_precision_stop 0.95 + --eval_value_recall_stop 0.95 + --min_steps_before_stop 100 + --max_wall_clock_seconds 7200 + --max_steps 600 + --use_wandb + --wandb_project "sudoku-multi-output-sft" + --wandb_run_name "baseline_stage01_sft_i1_15empty_size2only" + --wandb_mode "${WANDB_MODE}" +) + +if [[ -n "${WANDB_ENTITY}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +printf 'Launching 15-empty stage-1 size-2-only SFT baseline\n' +printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}" +printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" + +exec "${cmd[@]}" diff --git a/hard_9x9_7empty/launch_stage1_sft.sh b/hard_9x9_7empty/launch_stage1_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..a5000bef62654367a6448f0f124cd55eb1728a3c --- /dev/null +++ b/hard_9x9_7empty/launch_stage1_sft.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" + +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_train.jsonl}" +EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_eval.jsonl}" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-400}" +EVAL_PUZZLES="${EVAL_PUZZLES:-200}" +TRAIN_SEED="${TRAIN_SEED:-0}" +EVAL_SEED="${EVAL_SEED:-1}" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_qwen05b/baseline_stage1}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_7empty}" + +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +if [[ ! -f "${TRAIN_JSONL}" ]]; then + mkdir -p "$(dirname "${TRAIN_JSONL}")" + printf 'Building 7-empty train dataset: %s\n' "${TRAIN_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${TRAIN_JSONL}" \ + --num_puzzles "${TRAIN_PUZZLES}" \ + --empties 7 \ + --seed "${TRAIN_SEED}" +fi + +if [[ ! -f "${EVAL_JSONL}" ]]; then + mkdir -p "$(dirname "${EVAL_JSONL}")" + printf 'Building 7-empty eval dataset: %s\n' "${EVAL_JSONL}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" \ + --output "${EVAL_JSONL}" \ + --num_puzzles "${EVAL_PUZZLES}" \ + --empties 7 \ + --seed "${EVAL_SEED}" +fi + +mkdir -p "${CHECKPOINT_ROOT}" + +cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" + --model_name "Qwen/Qwen2.5-0.5B-Instruct" + --train_jsonl "${TRAIN_JSONL}" + --eval_jsonl "${EVAL_JSONL}" + --output_dir "${OUTPUT_DIR}" + --cache_dir "${ROOT}/.hf_cache" + --seed 0 + --gpu_id 0 + --stage_i 1 + --total_empties_hint 7 + --num_epochs 3.0 + --learning_rate 2e-4 + --gradient_accumulation_steps 8 + --enable_gradient_checkpointing + --logging_steps 10 + --eval_steps 25 + --save_steps 25 + --eval_rows "${EVAL_PUZZLES}" + --max_completion_length 24 + --limit_train_rows "${TRAIN_PUZZLES}" + --lora_r 32 + --lora_alpha 64 + --lora_dropout 0.05 + --multi_value_oversample_factor 16 + --eval_exact_set_match_stop 0.999 + --eval_value_precision_stop 0.999 + --eval_value_recall_stop 0.999 + --min_steps_before_stop 50 + --max_wall_clock_seconds 1800 + --max_steps 250 + --use_wandb + --wandb_project "sudoku-multi-output-sft" + --wandb_run_name "baseline_stage01_sft_i1_7empty" + --wandb_mode "${WANDB_MODE}" +) + +if [[ -n "${WANDB_ENTITY}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +printf 'Launching 7-empty stage-1 SFT baseline\n' +printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}" +printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" + +exec "${cmd[@]}" diff --git a/hard_9x9_curriculum/build_stage3_hard_dataset.py b/hard_9x9_curriculum/build_stage3_hard_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..618379f40a24826f1cf16c8eff0a7e65fa0b3bc1 --- /dev/null +++ b/hard_9x9_curriculum/build_stage3_hard_dataset.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +import argparse +import json +import random +import sys +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Sequence, Tuple + +import numpy as np + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from formatting_icon import is_consistent_pair +from multi_output_cell_policy.shared_multi_output_policy import stage_i_consistent_values + + +GRID_SIZE = 9 +BOX_SIZE = 3 +ALL_VALUES = tuple(range(1, 10)) +DEFAULT_BASE_NAME = "sudoku_t3_30empty_stage3hard" + + +@dataclass(frozen=True) +class DifficultyProfile: + stage1_solved: bool + stage2_solved: bool + stage3_solved: bool + stage1_steps: int + stage2_steps: int + stage3_steps: int + + +@dataclass(frozen=True) +class SeedMask: + mask_cells: tuple[int, ...] + profile: DifficultyProfile + + +def parse_args() -> argparse.Namespace: + root = PARENT_DIR + default_train = root / "data" / f"{DEFAULT_BASE_NAME}_value_qwen_text.jsonl" + default_eval = root / "data" / f"{DEFAULT_BASE_NAME}_eval_value_qwen_text.jsonl" + default_manifest = root / "data" / f"{DEFAULT_BASE_NAME}_manifest.json" + p = argparse.ArgumentParser() + p.add_argument("--train_output", type=str, default=str(default_train)) + p.add_argument("--eval_output", type=str, default=str(default_eval)) + p.add_argument("--manifest_output", type=str, default=str(default_manifest)) + p.add_argument("--num_train_puzzles", type=int, default=4000) + p.add_argument("--num_eval_puzzles", type=int, default=200) + p.add_argument("--empties", type=int, default=30) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--max_attempts", type=int, default=200000) + p.add_argument("--progress_every", type=int, default=250) + p.add_argument("--num_seed_masks", type=int, default=8) + return p.parse_args() + + +def permute_groups(rng: random.Random, values: Sequence[int], group_size: int) -> List[int]: + groups = [list(values[idx : idx + group_size]) for idx in range(0, len(values), group_size)] + rng.shuffle(groups) + out: List[int] = [] + for group in groups: + rng.shuffle(group) + out.extend(group) + return out + + +def base_solved_grid() -> np.ndarray: + return np.asarray( + [[((rr * BOX_SIZE + rr // BOX_SIZE + cc) % GRID_SIZE) + 1 for cc in range(GRID_SIZE)] for rr in range(GRID_SIZE)], + dtype=int, + ) + + +def row_major_empty_locs(grid: np.ndarray) -> List[Tuple[int, int]]: + return [(int(r), int(c)) for r, c in np.argwhere(np.asarray(grid, dtype=int) == 0).tolist()] + + +def make_prompt(grid: np.ndarray) -> str: + tuples = [f"({r + 1},{c + 1},{int(grid[r, c])})" for r in range(GRID_SIZE) for c in range(GRID_SIZE)] + return ( + "9x9 Sudoku board encoded as (row,col,value) tuples in row-major order.\n" + "Value 0 means the cell is empty.\n" + + " ".join(tuples) + ) + + +def legal_values(grid: np.ndarray, row: int, col: int) -> List[int]: + cell = int(row) * GRID_SIZE + int(col) + return [int(value) for value in ALL_VALUES if is_consistent_pair(grid, cell=cell, value=int(value), t=3, n=9)] + + +def count_solutions(grid: np.ndarray, *, limit: int = 2) -> int: + board = np.asarray(grid, dtype=int).copy() + solutions = 0 + + def backtrack() -> None: + nonlocal solutions + if solutions >= int(limit): + return + best_cell: Tuple[int, int] | None = None + best_values: List[int] | None = None + for rr, cc in row_major_empty_locs(board): + values = legal_values(board, rr, cc) + if not values: + return + if best_values is None or len(values) < len(best_values): + best_cell = (rr, cc) + best_values = values + if len(best_values) == 1: + break + if best_cell is None: + solutions += 1 + return + rr, cc = best_cell + for value in best_values or []: + board[rr, cc] = int(value) + backtrack() + board[rr, cc] = 0 + if solutions >= int(limit): + return + + backtrack() + return int(solutions) + + +def propagate_stage(grid: np.ndarray, *, stage_i: int) -> Tuple[np.ndarray | None, int]: + board = np.asarray(grid, dtype=int).copy() + num_assignments = 0 + while True: + chosen: Tuple[int, int, int] | None = None + for rr, cc in row_major_empty_locs(board): + values = stage_i_consistent_values(board, target_cell=(rr, cc), stage_i=int(stage_i)) + if not values: + return None, num_assignments + if len(values) == 1: + chosen = (rr, cc, int(values[0])) + break + if chosen is None: + return board, num_assignments + rr, cc, value = chosen + board[rr, cc] = int(value) + num_assignments += 1 + + +def build_difficulty_profile(puzzle: np.ndarray, solved: np.ndarray) -> DifficultyProfile | None: + stage1_board, stage1_steps = propagate_stage(puzzle, stage_i=1) + if stage1_board is None: + return None + stage2_board, stage2_steps = propagate_stage(puzzle, stage_i=2) + if stage2_board is None: + return None + stage3_board, stage3_steps = propagate_stage(puzzle, stage_i=3) + if stage3_board is None: + return None + return DifficultyProfile( + stage1_solved=bool(np.array_equal(stage1_board, solved)), + stage2_solved=bool(np.array_equal(stage2_board, solved)), + stage3_solved=bool(np.array_equal(stage3_board, solved)), + stage1_steps=int(stage1_steps), + stage2_steps=int(stage2_steps), + stage3_steps=int(stage3_steps), + ) + + +def qualifies(profile: DifficultyProfile) -> bool: + return (not profile.stage1_solved) and (not profile.stage2_solved) and profile.stage3_solved + + +def build_puzzle_from_mask(solved: np.ndarray, mask_cells: Sequence[int]) -> np.ndarray: + puzzle = np.asarray(solved, dtype=int).copy() + for cell in mask_cells: + rr, cc = divmod(int(cell), GRID_SIZE) + puzzle[rr, cc] = 0 + return puzzle + + +def sample_mask_cells(*, empties: int, rng: random.Random) -> tuple[int, ...]: + cells = list(range(GRID_SIZE * GRID_SIZE)) + rng.shuffle(cells) + return tuple(sorted(int(cell) for cell in cells[: int(empties)])) + + +def greedy_find_seed_mask( + *, + empties: int, + max_attempts: int, + rng: random.Random, + progress_every: int, +) -> Tuple[SeedMask | None, Dict[str, int]]: + solved = base_solved_grid() + attempts = 0 + restarts = 0 + while attempts < int(max_attempts): + restarts += 1 + mask: List[int] = [] + remaining = list(range(GRID_SIZE * GRID_SIZE)) + rng.shuffle(remaining) + current_profile: DifficultyProfile | None = None + + while len(mask) < int(empties) and attempts < int(max_attempts): + best_cell: int | None = None + best_profile: DifficultyProfile | None = None + best_score: Tuple[int, int, int] | None = None + candidate_cells = list(remaining[: min(len(remaining), 12)]) + if not candidate_cells: + break + + for cell in candidate_cells: + attempts += 1 + trial_mask = tuple(sorted(mask + [int(cell)])) + puzzle = build_puzzle_from_mask(solved, trial_mask) + profile = build_difficulty_profile(puzzle, solved) + if profile is None or not profile.stage3_solved: + continue + score = ( + int(not profile.stage2_solved), + int(not profile.stage1_solved), + int(profile.stage3_steps - profile.stage2_steps), + ) + if best_score is None or score > best_score: + best_cell = int(cell) + best_profile = profile + best_score = score + + if attempts == 1 or attempts % max(1, int(progress_every)) == 0: + print( + f"[search hard 9x9 masks] attempts={attempts} restarts={restarts} current_empties={len(mask)}", + flush=True, + ) + + if best_cell is None or best_profile is None: + break + + mask.append(int(best_cell)) + mask.sort() + remaining.remove(int(best_cell)) + current_profile = best_profile + + if len(mask) != int(empties) or current_profile is None: + continue + + final_mask = tuple(sorted(int(cell) for cell in mask)) + final_puzzle = build_puzzle_from_mask(solved, final_mask) + final_profile = build_difficulty_profile(final_puzzle, solved) + if final_profile is None or not qualifies(final_profile): + continue + if count_solutions(final_puzzle, limit=2) != 1: + continue + return SeedMask(mask_cells=final_mask, profile=final_profile), { + "attempts": int(attempts), + "restarts": int(restarts), + } + + return None, {"attempts": int(attempts), "restarts": int(restarts)} + + +def random_symmetry( + rng: random.Random, *, solved: np.ndarray, mask_cells: Sequence[int] +) -> Tuple[np.ndarray, tuple[int, ...]]: + digits = list(ALL_VALUES) + rng.shuffle(digits) + digit_map = {src: dst for src, dst in zip(ALL_VALUES, digits, strict=True)} + transformed = np.vectorize(lambda value: digit_map[int(value)], otypes=[int])(np.asarray(solved, dtype=int).copy()) + + row_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + col_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + inverse_row = {old: new for new, old in enumerate(row_order)} + inverse_col = {old: new for new, old in enumerate(col_order)} + + transformed = transformed[row_order, :] + transformed = transformed[:, col_order] + + transformed_cells: List[int] = [] + for cell in mask_cells: + rr, cc = divmod(int(cell), GRID_SIZE) + new_r = int(inverse_row[int(rr)]) + new_c = int(inverse_col[int(cc)]) + transformed_cells.append(new_r * GRID_SIZE + new_c) + + if rng.random() < 0.5: + transformed = transformed.T + transformed_cells = [int(cc) * GRID_SIZE + int(rr) for rr, cc in (divmod(cell, GRID_SIZE) for cell in transformed_cells)] + + return np.asarray(transformed, dtype=int), tuple(sorted(int(cell) for cell in transformed_cells)) + + +def make_example(solved: np.ndarray, mask_cells: Sequence[int], *, empties: int, profile: DifficultyProfile) -> Dict[str, object]: + puzzle = build_puzzle_from_mask(solved, mask_cells) + empty_locs_1based = [(rr + 1, cc + 1) for rr, cc in row_major_empty_locs(puzzle)] + target_triples_1based = [(rr + 1, cc + 1, int(solved[rr, cc])) for rr, cc in row_major_empty_locs(puzzle)] + completion_values = [int(value) for _, _, value in target_triples_1based] + return { + "prompt": make_prompt(puzzle), + "completion": json.dumps(completion_values, separators=(",", ":")), + "metadata": { + "grid_size": GRID_SIZE, + "box_size": BOX_SIZE, + "empties": int(empties), + "empty_locs_1based": empty_locs_1based, + "target_triples_1based": target_triples_1based, + "required_consistency_stage": 3, + "difficulty_profile": asdict(profile), + }, + } + + +def search_seed_masks( + *, + num_seed_masks: int, + empties: int, + max_attempts: int, + seed: int, + progress_every: int, +) -> Tuple[List[SeedMask], Dict[str, int]]: + rng = random.Random(int(seed)) + seeds: List[SeedMask] = [] + seen = set() + total_attempts = 0 + total_restarts = 0 + + while len(seeds) < int(num_seed_masks) and total_attempts < int(max_attempts): + mask_seed, stats = greedy_find_seed_mask( + empties=int(empties), + max_attempts=max(1, int(max_attempts) - int(total_attempts)), + rng=rng, + progress_every=int(progress_every), + ) + total_attempts += int(stats.get("attempts", 0)) + total_restarts += int(stats.get("restarts", 0)) + if mask_seed is None: + break + if mask_seed.mask_cells in seen: + continue + seen.add(mask_seed.mask_cells) + seeds.append(mask_seed) + print( + f"[search hard 9x9 masks] attempts={total_attempts} accepted={len(seeds)}/{num_seed_masks}", + flush=True, + ) + + stats = { + "attempts": int(total_attempts), + "restarts": int(total_restarts), + "accepted_seed_masks": int(len(seeds)), + } + return seeds, stats + + +def generate_examples( + *, + num_examples: int, + empties: int, + seed_masks: Sequence[SeedMask], + seed: int, +) -> List[Dict[str, object]]: + if not seed_masks: + raise ValueError("seed_masks must not be empty") + rng = random.Random(int(seed) + 1) + solved = base_solved_grid() + rows: List[Dict[str, object]] = [] + for idx in range(int(num_examples)): + seed_mask = seed_masks[idx % len(seed_masks)] + transformed_solved, transformed_mask = random_symmetry( + rng, solved=solved, mask_cells=seed_mask.mask_cells + ) + rows.append( + make_example( + transformed_solved, + transformed_mask, + empties=int(empties), + profile=seed_mask.profile, + ) + ) + return rows + + +def write_jsonl(path: Path, rows: Iterable[Dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, separators=(",", ":")) + "\n") + + +def main() -> None: + args = parse_args() + total_needed = int(args.num_train_puzzles) + int(args.num_eval_puzzles) + num_seed_masks = min(max(1, int(args.num_seed_masks)), total_needed) + seed_masks, search_stats = search_seed_masks( + num_seed_masks=num_seed_masks, + empties=int(args.empties), + max_attempts=int(args.max_attempts), + seed=int(args.seed), + progress_every=int(args.progress_every), + ) + if len(seed_masks) < num_seed_masks: + raise RuntimeError( + f"Only found {len(seed_masks)} qualifying seed masks out of requested {num_seed_masks}. " + f"Try increasing --max_attempts or reducing --num_seed_masks." + ) + rows = generate_examples( + num_examples=total_needed, + empties=int(args.empties), + seed_masks=seed_masks, + seed=int(args.seed), + ) + + eval_rows = rows[: int(args.num_eval_puzzles)] + train_rows = rows[int(args.num_eval_puzzles) :] + + train_output = Path(args.train_output).resolve() + eval_output = Path(args.eval_output).resolve() + manifest_output = Path(args.manifest_output).resolve() + + write_jsonl(train_output, train_rows) + write_jsonl(eval_output, eval_rows) + manifest_output.parent.mkdir(parents=True, exist_ok=True) + manifest_output.write_text( + json.dumps( + { + "train_output": str(train_output), + "eval_output": str(eval_output), + "num_train_puzzles": int(len(train_rows)), + "num_eval_puzzles": int(len(eval_rows)), + "empties": int(args.empties), + "seed": int(args.seed), + "required_consistency_stage": 3, + "num_seed_masks": int(num_seed_masks), + "search_stats": search_stats, + }, + indent=2, + sort_keys=True, + ) + + "\n", + encoding="utf-8", + ) + print(f"Wrote {len(train_rows)} train puzzles to {train_output}") + print(f"Wrote {len(eval_rows)} eval puzzles to {eval_output}") + print(f"Wrote manifest to {manifest_output}") + + +if __name__ == "__main__": + main() diff --git a/hard_9x9_stage1_consistency_queue/README.md b/hard_9x9_stage1_consistency_queue/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f4605ce3e1c223a10e7c9e78566870288faebd6e --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/README.md @@ -0,0 +1,117 @@ +# Stage-1 Latent SFT Mode Sweep + +This folder contains launchers for the 9x9 Sudoku curriculum experiments. The +stage-1 latent sweep launcher is: + +```bash +bash hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh +``` + +The goal of this sweep is to compare the four latent implementations under the +same stage-1 SFT setup and measure which one gives the fastest useful +convergence. The main comparison should include training loss, held-out value +precision/recall, completion quality, wall-clock time, and GPU efficiency. In +particular, compare both loss vs. optimizer step and loss vs. elapsed time, +because some methods do more transformer forward passes per step. + +## Four Latent Modes + +### `residual` + +The residual mode performs a dynamic latent hidden rollout, then projects the +difference between the latent hidden state and the base hidden state back into +the model hidden space. This projected delta is added to the base next-token +hidden state before computing logits. It is expressive, but it is slower because +the latent rollout requires repeated transformer passes. + +### `fixed_slots` + +The fixed-slots mode learns a bank of trainable latent slot embeddings plus a +separate final readout slot. For each prediction, the model runs once on: + +```text +[prompt tokens, slot_1, ..., slot_k, final_slot] +``` + +The next token is predicted from the hidden state at `final_slot`. This is a +parallel latent method: all latent slots are inserted at once, so it avoids the +recursive pass used by recurrent methods. + +### `recurrent_hidden` + +The recurrent-hidden mode generates latent tokens dynamically from the current +example. It appends a hidden latent token, reruns the transformer, takes the new +last hidden state as the next latent token, and repeats for `num_cot_tokens`. +This is the closest to iterative hidden reasoning, but it is usually the +slowest because the latent steps are serial. + +### `latent_seeds` + +The latent-seeds mode learns a bank of trainable seed embeddings. For each +prediction, the model runs once on: + +```text +[prompt tokens, seed_1, ..., seed_k] +``` + +The next token is predicted from the hidden state at the last seed position. +Like fixed slots, this is parallel and avoids recursive transformer passes. The +main difference from `fixed_slots` is that there is no separate final readout +slot; the last seed position acts as the readout. + +## Experimental Strategy + +Run all four modes in parallel on stage 1 with the same dataset, LoRA settings, +number of latent tokens, stopping rule, and evaluation set. The default launcher +splits an 8-GPU node into four two-GPU jobs: + +```text +residual -> GPUs 0,1 +fixed_slots -> GPUs 2,3 +recurrent_hidden -> GPUs 4,5 +latent_seeds -> GPUs 6,7 +``` + +Use the results to decide which one or two methods should be promoted to deeper +curriculum stages. The expected practical tradeoff is that `fixed_slots` and +`latent_seeds` should be much faster per wall-clock time, while `residual` and +`recurrent_hidden` test more iterative, example-dependent latent computation. + +## Warm Baseline Stages 1-3 Pipeline + +The full warm-baseline launcher is: + +```bash +STAGE1_BASELINE_ADAPTER_DIR=/path/to/warmed/stage1/baseline/checkpoint \ + bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh +``` + +It runs all four latent modes in parallel, two GPUs per mode: + +```text +residual -> GPUs 0,1 +fixed_slots -> GPUs 2,3 +recurrent_hidden -> GPUs 4,5 +latent_seeds -> GPUs 6,7 +``` + +For each mode, the intended sequence is: + +```text +stage1 latent SFT + -> stage1 latent GRPO + -> stage2 baseline warm-up SFT + -> stage2 latent SFT + -> stage2 latent GRPO + -> stage3 baseline warm-up SFT + -> stage3 latent SFT + -> stage3 latent GRPO +``` + +The run is capped by fixed step budgets by default (`1000` SFT steps and `500` +GRPO steps per phase) and can stop early when the configured solve-rate target +is reached. The current 1.5B run uses the warmed Stage-1 baseline adapter from +`hard_9x9_20empty_baseline_1p5b_warmup`. + +See `warm_baseline_all_latent_modes_stages123_results.md` for the current +solve-rate snapshot from the ongoing full-pipeline run. diff --git a/hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh b/hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh new file mode 100644 index 0000000000000000000000000000000000000000..c915afdc1e0421d45a0a1b1149576adf1485bdcd --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh @@ -0,0 +1,158 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +DATA_PATH="${DATA_PATH:-${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl}" +GPU_ID="${GPU_ID:-0}" +NUM_COT="${NUM_COT:-3}" +MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}" +LIMIT_ROWS="${LIMIT_ROWS:-1}" +TRAIN_STEPS="${TRAIN_STEPS:-60}" +LR="${LR:-1e-1}" +LORA_R="${LORA_R:-32}" +LORA_ALPHA="${LORA_ALPHA:-64}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" + +export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}" +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${GPU_ID}}" + +exec "${PYTHON_BIN}" - <<'PY' +import os + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import ( + attach_fixed_latent_slot_modules, + fixed_slot_next_token_logits_from_ids, + load_jsonl_rows, + load_trainable_adapter, + pick_dtype, + sample_fixed_slot_completion, + unwrap_backbone, +) +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.shared_multi_output_policy import build_supervised_completion + + +def env_int(name: str, default: int) -> int: + return int(os.environ.get(name, str(default))) + + +def env_float(name: str, default: float) -> float: + return float(os.environ.get(name, str(default))) + + +model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct") +data_path = os.environ.get("DATA_PATH", "data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl") +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +num_cot = env_int("NUM_COT", 5) +max_latent_slots = env_int("MAX_LATENT_SLOTS", 8) +limit_rows = env_int("LIMIT_ROWS", 1) +train_steps = env_int("TRAIN_STEPS", 60) +lr = env_float("LR", 1e-1) +lora_r = env_int("LORA_R", 32) +lora_alpha = env_int("LORA_ALPHA", 64) +lora_dropout = env_float("LORA_DROPOUT", 0.05) + +rows = load_jsonl_rows(data_path, limit_rows=limit_rows) +ex = build_cell_examples_from_row(rows[0])[0] + +tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) +if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + +prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=1, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=20, +) +target_text = build_supervised_completion(ex, stage_i=1) + (tokenizer.eos_token or "") +print("target_text", target_text) + +base = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, +) +model = load_trainable_adapter(base, "", lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) +attach_fixed_latent_slot_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + max_latent_slots=max_latent_slots, +) +if hasattr(model, "config"): + model.config.use_cache = False +backbone = unwrap_backbone(model) +if hasattr(backbone, "config"): + backbone.config.use_cache = False +model.to(device) + +for p in model.parameters(): + p.requires_grad = False +model.fixed_latent_slots.requires_grad_(True) +model.fixed_final_slot_embed.requires_grad_(True) +optimizer = torch.optim.AdamW([model.fixed_latent_slots, model.fixed_final_slot_embed], lr=lr) + +prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device) +completion_ids = tokenizer(target_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) + + +@torch.no_grad() +def sample_now(tag: str) -> None: + model.eval() + attn = torch.ones_like(prompt_ids, device=device) + logits = fixed_slot_next_token_logits_from_ids(model, prompt_ids, attn, num_cot) + probs = torch.softmax(logits[0].float(), dim=-1) + top_probs, top_ids = torch.topk(probs, k=5) + out_ids = sample_fixed_slot_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=num_cot, + max_new_tokens=12, + do_sample=False, + ) + top_next = [(tokenizer.decode([int(i)]), round(float(p), 4)) for i, p in zip(top_ids.tolist(), top_probs.tolist())] + print(tag, tokenizer.decode(out_ids[0], skip_special_tokens=True), "top_next", top_next) + + +sample_now("before:") + +for step in range(1, train_steps + 1): + model.train() + cur_ids = prompt_ids + cur_mask = torch.ones_like(prompt_ids, device=device) + losses = [] + for idx in range(int(completion_ids.shape[1])): + logits = fixed_slot_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot) + target = completion_ids[:, idx] + losses.append(F.cross_entropy(logits.float(), target, reduction="mean")) + cur_ids = torch.cat([cur_ids, completion_ids[:, idx : idx + 1]], dim=1) + cur_mask = torch.cat( + [ + cur_mask, + torch.ones((cur_mask.shape[0], 1), dtype=cur_mask.dtype, device=cur_mask.device), + ], + dim=1, + ) + loss = torch.stack(losses).mean() + optimizer.zero_grad(set_to_none=True) + loss.backward() + optimizer.step() + if step == 1 or step % 10 == 0 or step == train_steps: + print(f"step={step} loss={float(loss.item()):.6f}") + sample_now(f"after_step_{step}:") +PY diff --git a/hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..d71a67ea5e80e184bc7d943ba65fc7a2224acac5 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Full 10-empty baseline pipeline, matching the successful 7-empty procedure: +# 1) Stage-1 SFT to value precision/recall >= 0.98 +# 2) Stage-1 GRPO +# 3) Stage-2 SFT +# 4) Stage-2 GRPO +# 5) Stage-3 SFT +# 6) Stage-3 GRPO +# +# This is a wrapper around: +# - launch_10empty_sft_stage1_98p.sh +# - launch_10empty_post_s1sft_stages123_value98.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_full_stages123_value98}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" + +SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_10empty_sft_stage1_98p.sh" +POST_S1_SCRIPT="${SCRIPT_DIR}/launch_10empty_post_s1sft_stages123_value98.sh" +S1_DIR="${OUTPUT_ROOT}/10empty/stage01_sft_i1_10empty_sft98" + +latest_checkpoint_in_dir() { + local d="$1" + shopt -s nullglob + local checkpoints=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#checkpoints[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +printf '=== 10-empty full baseline pipeline (stage1 SFT -> stages123) ===\n' +printf 'run_tag=%s\n' "${RUN_TAG}" +printf 'output_root=%s\n' "${OUTPUT_ROOT}" + +OUTPUT_DIR="${S1_DIR}" \ +RUN_TAG="${RUN_TAG}" \ +CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \ +"${SFT_STAGE1_SCRIPT}" + +STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")" +if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2 + exit 1 +fi + +printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" + +STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \ +RUN_TAG="${RUN_TAG}" \ +CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \ +OUTPUT_ROOT="${OUTPUT_ROOT}" \ +"${POST_S1_SCRIPT}" diff --git a/hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..f22621541961c8d736c2cb84978ee3b7c96551a1 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh @@ -0,0 +1,365 @@ +#!/usr/bin/env bash +# Run AFTER stage-1 SFT finishes (10-empty). Order: +# 1) Stage-1 GRPO (init = your stage-1 SFT adapter) +# 2) Stage-2 SFT (init = stage-1 GRPO adapter) +# 3) Stage-2 GRPO (init = stage-2 SFT adapter) +# 4) Stage-3 SFT (init = stage-2 GRPO adapter) +# 5) Stage-3 GRPO (init = stage-3 SFT adapter) +# +# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall +# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use +# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap +# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling). +# +# Required (full pipeline from stage-1 SFT): +# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX +# +# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT): +# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_10empty +# (OUTPUT_ROOT defaults to dirname of that dir.) +# +# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO): +# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_10empty +# +# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only): +# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_10empty +# +# Optional: +# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=... +# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0 +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=10 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_stages123_value98}" +START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}" +START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}" +RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}" + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}" +elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}" +elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}" +else + if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +fi + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +mkdir -p "${OUTPUT_ROOT}" + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +resolve_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) +fi + +run_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + mkdir -p "${out_dir}" + printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + mkdir -p "${out_dir}" + printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta 0.0 \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-grpo" \ + --wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + S2_DIR="${START_AT_STAGE2_GRPO_DIR}" + CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" + if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 + fi + printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2 + G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" + run_grpo 2 "${CKPT_S2}" "${G2_DIR}" + A2="$(resolve_grpo_adapter "${G2_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 + fi + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started at stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2 + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started after stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" +else + printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" +fi +printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \ + "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}" + +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")" +else + run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}" + A1="$(resolve_grpo_adapter "${G1_DIR}")" +fi +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2 + exit 1 +fi +printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty" +run_sft 2 "${A1}" "${S2_DIR}" "5e-5" +CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" +if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 +fi +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" +run_grpo 2 "${CKPT_S2}" "${G2_DIR}" +A2="$(resolve_grpo_adapter "${G2_DIR}")" +if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" +run_sft 3 "${A2}" "${S3_DIR}" "5e-5" +CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" +if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 +fi +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" +run_grpo 3 "${CKPT_S3}" "${G3_DIR}" +A3="$(resolve_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh b/hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh new file mode 100644 index 0000000000000000000000000000000000000000..cdbc17f76c2284eaaf2fceec95ba7c1ecbdd9399 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash +# Stage-1 SFT only for 10-empty: train until eval value_precision AND value_recall +# both reach 0.98 (or max_steps / optional wall clock). Use the resulting +# checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for +# launch_10empty_post_s1sft_stages123_value98.sh. +# +# Fresh LoRA on base model: +# ./launch_10empty_sft_stage1_98p.sh +# +# Continue from a prior SFT checkpoint: +# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_10empty_sft_stage1_98p.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=10 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +SFT_TARGET="${SFT_TARGET:-0.98}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MAX_STEPS="${MAX_STEPS:-30000}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_sft98_stage1}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +if [[ ! -f "${train_jsonl}" ]]; then + mkdir -p "$(dirname "${train_jsonl}")" + printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0 +fi +if [[ ! -f "${eval_jsonl}" ]]; then + mkdir -p "$(dirname "${eval_jsonl}")" + printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1 +fi + +mkdir -p "${OUTPUT_DIR}" + +INIT_FLAGS=() +if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then + INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}") + printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}" +fi + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) + printf 'NOTE: USE_GC=1 — slower, less VRAM.\n' +fi + +if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}" +else + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}" +fi +printf 'Output: %s\n' "${OUTPUT_DIR}" + +exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${OUTPUT_DIR}" \ + --cache_dir "${ROOT}/.hf_cache" \ + "${INIT_FLAGS[@]}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i 1 \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs 64.0 \ + --learning_rate 2e-4 \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 100 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${SFT_TARGET}" \ + --eval_value_recall_stop "${SFT_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh new file mode 100644 index 0000000000000000000000000000000000000000..a66042515a0f2da8b80af91cc3546d1760ce2d81 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh @@ -0,0 +1,125 @@ +#!/usr/bin/env bash +# Stage-1 fixed-slot latent SFT for 20-empty: train until eval value_precision AND +# value_recall both reach 0.98. This uses prompt + z1 + final_slot during stage 1, +# while still updating LoRA weights so the transformer can learn how to use z1. +# +# Fresh run: +# ./launch_20empty_fixed_slot_sft_stage1_98p.sh +# +# Warm-start from a prior checkpoint: +# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_fixed_slot_sft_stage1_98p.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6}" +NUM_PROCESSES="${NUM_PROCESSES:-7}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=20 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +SFT_TARGET="${SFT_TARGET:-0.98}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MAX_STEPS="${MAX_STEPS:-30000}" + +LORA_R="${LORA_R:-32}" +LORA_ALPHA="${LORA_ALPHA:-64}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" +MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-3}" + +PER_DEVICE_TRAIN_BATCH_SIZE="${PER_DEVICE_TRAIN_BATCH_SIZE:-4}" +GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-8}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_fixed_slot_sft98_stage1}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_fixed_slot_sft98_i1_${EMPTIES}empty}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +if [[ ! -f "${train_jsonl}" ]]; then + mkdir -p "$(dirname "${train_jsonl}")" + printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0 +fi +if [[ ! -f "${eval_jsonl}" ]]; then + mkdir -p "$(dirname "${eval_jsonl}")" + printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1 +fi + +mkdir -p "${OUTPUT_DIR}" + +INIT_FLAGS=() +if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then + INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}") + printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}" +fi + +GC_FLAGS=() +if [[ "${USE_GC:-1}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) + printf 'NOTE: USE_GC=1 - slower, less VRAM.\n' +fi + +if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then + printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}" +else + printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}" +fi +printf 'Output: %s\n' "${OUTPUT_DIR}" +printf 'LoRA: r=%s alpha=%s dropout=%s | latent_mode=fixed_slots | active_z=1 | max_latent_slots=%s\n' "${LORA_R}" "${LORA_ALPHA}" "${LORA_DROPOUT}" "${MAX_LATENT_SLOTS}" +printf 'DDP: visible_gpus=%s nproc=%s | batch/device=%s grad_accum=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" "${PER_DEVICE_TRAIN_BATCH_SIZE}" "${GRADIENT_ACCUMULATION_STEPS}" + +exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${OUTPUT_DIR}" \ + --cache_dir "${ROOT}/.hf_cache" \ + "${INIT_FLAGS[@]}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i 1 \ + --num_cot_tokens 1 \ + --latent_mode fixed_slots \ + --max_latent_slots "${MAX_LATENT_SLOTS}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE}" \ + --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS}" \ + --num_epochs 64.0 \ + --learning_rate 2e-4 \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 100 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --eval_value_precision_stop "${SFT_TARGET}" \ + --eval_value_recall_stop "${SFT_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-fixed-slot-sft" \ + --wandb_run_name "${WANDB_RUN_NAME:-stage01_fixed_slot_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..213f674a393b191ccbeeab66806e2adb14ae7f45 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +# Full 20-empty baseline pipeline, matching the successful 10-empty procedure: +# 1) Stage-1 SFT to value precision/recall >= 0.98 +# 2) Stage-1 GRPO +# 3) Stage-2 SFT +# 4) Stage-2 GRPO +# 5) Stage-3 SFT +# 6) Stage-3 GRPO +# +# This is a wrapper around: +# - launch_20empty_sft_stage1_98p.sh +# - launch_20empty_post_s1sft_stages123_value98.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_full_stages123_value98}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" + +SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_20empty_sft_stage1_98p.sh" +POST_S1_SCRIPT="${SCRIPT_DIR}/launch_20empty_post_s1sft_stages123_value98.sh" +S1_DIR="${OUTPUT_ROOT}/20empty/stage01_sft_i1_20empty_sft98" + +latest_checkpoint_in_dir() { + local d="$1" + shopt -s nullglob + local checkpoints=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#checkpoints[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +printf '=== 20-empty full baseline pipeline (stage1 SFT -> stages123) ===\n' +printf 'run_tag=%s\n' "${RUN_TAG}" +printf 'output_root=%s\n' "${OUTPUT_ROOT}" + +OUTPUT_DIR="${S1_DIR}" \ +RUN_TAG="${RUN_TAG}" \ +CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \ +"${SFT_STAGE1_SCRIPT}" + +STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")" +if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2 + exit 1 +fi + +printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" + +STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \ +RUN_TAG="${RUN_TAG}" \ +CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \ +OUTPUT_ROOT="${OUTPUT_ROOT}" \ +"${POST_S1_SCRIPT}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..996b5a6eda6b67f14adf96dd73f8760ae115546a --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh @@ -0,0 +1,341 @@ +#!/usr/bin/env bash +# Latent recurrent-hidden (Coconut-style) pipeline for 20-empty Sudoku. +# +# Per-stage latent token count grows with curriculum: +# stage 1 -> num_cot_tokens = 1 +# stage 2 -> num_cot_tokens = 2 +# stage 3 -> num_cot_tokens = 3 +# +# Pipeline: +# Stage 1 SFT (cot=1, fresh LoRA + random latent state) +# -> Stage 1 GRPO (cot=1) +# -> Stage 2 SFT (cot=2) +# -> Stage 2 GRPO (cot=2) +# -> Stage 3 SFT (cot=3) +# -> Stage 3 GRPO (cot=3) +# +# Mirrors the hyperparameters of the successful 20-empty recurrent-hidden stage-1 +# run (bs=8 per-device, gradient accumulation 2, gradient checkpointing ON). +# +# Optional overrides: +# STAGE1_INIT_ADAPTER_DIR=/path/to/adapter +# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX +# VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +LATENT_MODE="recurrent_hidden" +EMPTIES=20 +TAG_SUFFIX="latent_recurrent" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +# Per-phase early-stop bars. Default behavior preserved: both phases use +# VALUE_TARGET unless explicitly overridden. Recommended: SFT_VALUE_TARGET=0.95 +# (let SFT do bulk learning quickly) and GRPO_VALUE_TARGET=0.98 (let GRPO push +# the last few percent of value precision/recall). +SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}" +GRPO_VALUE_TARGET="${GRPO_VALUE_TARGET:-${VALUE_TARGET}}" +MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +# -1 resolves in code to hidden_size, and alpha=-1 resolves to 2 * resolved rank. +LORA_R="${LORA_R:--1}" +LORA_ALPHA="${LORA_ALPHA:--1}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" +STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}" +SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}" +SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}" +GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-8}" +GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_recurrent_stages123_value98}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}" +STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}" +# When set, skip both Stage-1 SFT and Stage-1 GRPO and use this adapter +# directly as the init for Stage-2 SFT. Useful for resuming after a Stage-1 +# GRPO post-training eval hangs but the LoRA adapter is already on disk. +STAGE1_GRPO_ADAPTER_DIR="${STAGE1_GRPO_ADAPTER_DIR:-}" +STAGE2_SFT_ADAPTER_DIR="${STAGE2_SFT_ADAPTER_DIR:-}" +STAGE2_GRPO_ADAPTER_DIR="${STAGE2_GRPO_ADAPTER_DIR:-}" +# When set, skip Stage-3 SFT and use this adapter directly as the init for +# Stage-3 GRPO. Useful when SFT plateaus mid-training and we want GRPO to push +# the last few percentage points without burning more SFT compute. +STAGE3_SFT_ADAPTER_DIR="${STAGE3_SFT_ADAPTER_DIR:-}" +# KL anchor for GRPO. Setting > 0 keeps the policy close to the SFT reference +# and prevents singleton/mode collapse seen in Stage-2 GRPO. 0.0 = no KL. +GRPO_BETA="${GRPO_BETA:-0.0}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +mkdir -p "${OUTPUT_ROOT}" + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train or eval jsonl.\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +resolve_latent_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +run_latent_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + local cot="$5" + local ms1=0 ms2=1 + if [[ "${stage}" == "1" ]]; then + ms1=1 + ms2=0 + fi + mkdir -p "${out_dir}" + printf '\n=== Latent(recurrent) stage %s SFT -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${SFT_VALUE_TARGET}" "${cot}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --latent_mode "${LATENT_MODE}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio "${ms1}" \ + --mixed_stage2_ratio "${ms2}" \ + --per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --weight_decay 0.0 \ + --enable_gradient_checkpointing \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --eval_value_precision_stop "${SFT_VALUE_TARGET}" \ + --eval_value_recall_stop "${SFT_VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-sft-recurrent" \ + --wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_cot${cot}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_latent_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local cot="$4" + mkdir -p "${out_dir}" + printf '\n=== Latent(recurrent) stage %s GRPO -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${GRPO_VALUE_TARGET}" "${cot}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --latent_mode "${LATENT_MODE}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio 0 \ + --mixed_stage2_ratio 1 \ + --per_device_train_batch_size "${GRPO_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${GRPO_GRAD_ACCUM}" \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta "${GRPO_BETA}" \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${GRPO_VALUE_TARGET}" \ + --eval_value_recall_stop "${GRPO_VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-grpo-recurrent" \ + --wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_cot${cot}_val${GRPO_VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +printf 'Latent mode: %s (cot grows 1->2->3 per stage)\n' "${LATENT_MODE}" +printf 'Value gate: SFT prec+recall >= %s ; GRPO prec+recall >= %s (min_steps=%s) ; GRPO_BETA=%s\n' "${SFT_VALUE_TARGET}" "${GRPO_VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}" "${GRPO_BETA}" +printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-}" + +S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}" +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}" +if [[ -n "${STAGE1_GRPO_ADAPTER_DIR}" ]]; then + A1="${STAGE1_GRPO_ADAPTER_DIR}" + printf 'Using existing stage-1 GRPO adapter (skipping stage-1 SFT + GRPO): %s\n' "${A1}" >&2 +elif [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}" + printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2 + run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1 + A1="$(resolve_latent_grpo_adapter "${G1_DIR}")" +else + run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1 + G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")" + if [[ -z "${G1_SFT_CKPT}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2 + exit 1 + fi + run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1 + A1="$(resolve_latent_grpo_adapter "${G1_DIR}")" +fi +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2 + exit 1 +fi +printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}" +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}" +if [[ -n "${STAGE2_GRPO_ADAPTER_DIR}" ]]; then + A2="${STAGE2_GRPO_ADAPTER_DIR}" + printf 'Using existing stage-2 GRPO adapter (skipping stage-2 SFT + GRPO): %s\n' "${A2}" >&2 +elif [[ -n "${STAGE2_SFT_ADAPTER_DIR}" ]]; then + CKPT_S2="${STAGE2_SFT_ADAPTER_DIR}" + printf 'Using existing stage-2 SFT checkpoint as GRPO init (skipping stage-2 SFT train): %s\n' "${CKPT_S2}" >&2 + run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2 + A2="$(resolve_latent_grpo_adapter "${G2_DIR}")" +else + run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2 + CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" + if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 + fi + run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2 + A2="$(resolve_latent_grpo_adapter "${G2_DIR}")" +fi + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}" +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}" +if [[ -n "${STAGE3_SFT_ADAPTER_DIR}" ]]; then + CKPT_S3="${STAGE3_SFT_ADAPTER_DIR}" + printf 'Using existing stage-3 SFT checkpoint as GRPO init (skipping stage-3 SFT train): %s\n' "${CKPT_S3}" >&2 +else + run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3 + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi +fi +run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3 +A3="$(resolve_latent_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll latent(recurrent) phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..a6eefaef58277781a339250fc8c9768b7edfeda9 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh @@ -0,0 +1,279 @@ +#!/usr/bin/env bash +# Latent residual-projector pipeline for 20-empty Sudoku. +# Default behavior: +# - Stage 1 SFT starts from fresh LoRA + random latent residual modules +# - Then stage 1 GRPO -> stage 2 SFT -> stage 2 GRPO -> stage 3 SFT -> stage 3 GRPO +# Optional: +# STAGE1_INIT_ADAPTER_DIR=/path/to/adapter +# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX +# VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=20 +TAG_SUFFIX="latent_residual" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +LORA_R="${LORA_R:-32}" +LORA_ALPHA="${LORA_ALPHA:-64}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" +STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_residual_stages123_value98}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}" +STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +mkdir -p "${OUTPUT_ROOT}" + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train or eval jsonl.\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +resolve_latent_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + [[ -f "${c}/latent_cot_state.pt" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +run_latent_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + local cot="$5" + local ms1=0 ms2=1 + if [[ "${stage}" == "1" ]]; then + ms1=1 + ms2=0 + fi + mkdir -p "${out_dir}" + printf '\n=== Latent stage %s SFT -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio "${ms1}" \ + --mixed_stage2_ratio "${ms2}" \ + --gradient_accumulation_steps 2 \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --weight_decay 0.0 \ + --enable_gradient_checkpointing \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-sft-residual-projector" \ + --wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_latent_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local cot="$4" + mkdir -p "${out_dir}" + printf '\n=== Latent stage %s GRPO -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio 0 \ + --mixed_stage2_ratio 1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta 0.0 \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-grpo-residual-projector" \ + --wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}" +printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-}" + +S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}" +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}" +if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}" + printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2 +else + run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1 + G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")" + if [[ -z "${G1_SFT_CKPT}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2 + exit 1 + fi +fi +run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1 +A1="$(resolve_latent_grpo_adapter "${G1_DIR}")" +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2 + exit 1 +fi +printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2 +CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" +if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 +fi +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2 +A2="$(resolve_latent_grpo_adapter "${G2_DIR}")" +if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3 +CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" +if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 +fi +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3 +A3="$(resolve_latent_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll latent residual phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e8e6bbd3c16ad9cc1ba28c5efd81d025e0da6f7 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh @@ -0,0 +1,368 @@ +#!/usr/bin/env bash +# Run AFTER stage-1 SFT finishes (20-empty). Order: +# 1) Stage-1 GRPO (init = your stage-1 SFT adapter) +# 2) Stage-2 SFT (init = stage-1 GRPO adapter) +# 3) Stage-2 GRPO (init = stage-2 SFT adapter) +# 4) Stage-3 SFT (init = stage-2 GRPO adapter) +# 5) Stage-3 GRPO (init = stage-3 SFT adapter) +# +# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall +# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use +# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap +# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling). +# +# Required (full pipeline from stage-1 SFT): +# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX +# +# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT): +# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_20empty +# (OUTPUT_ROOT defaults to dirname of that dir.) +# +# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO): +# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_20empty +# +# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only): +# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_20empty +# +# Optional: +# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=... +# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0 +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=20 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_stages123_value98}" +START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}" +START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}" +RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}" + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}" +elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}" +elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}" +else + if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +fi + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +mkdir -p "${OUTPUT_ROOT}" + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +resolve_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) +fi + +run_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + mkdir -p "${out_dir}" + printf '\n=== Stage %s SFT -> stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + mkdir -p "${out_dir}" + printf '\n=== Stage %s GRPO -> stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta 0.0 \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-grpo" \ + --wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 SFT dir %s -> stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + S2_DIR="${START_AT_STAGE2_GRPO_DIR}" + CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" + if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 + fi + printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2 + G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" + run_grpo 2 "${CKPT_S2}" "${G2_DIR}" + A2="$(resolve_grpo_adapter "${G2_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 + fi + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started at stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 GRPO dir %s -> stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2 + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started after stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" +else + printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" +fi +printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \ + "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}" + +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")" +else + run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}" + A1="$(resolve_grpo_adapter "${G1_DIR}")" +fi +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2 + exit 1 +fi +printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty" +run_sft 2 "${A1}" "${S2_DIR}" "5e-5" +CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" +if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 +fi +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" +run_grpo 2 "${CKPT_S2}" "${G2_DIR}" +A2="$(resolve_grpo_adapter "${G2_DIR}")" +if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" +run_sft 3 "${A2}" "${S3_DIR}" "5e-5" +CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" +if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 +fi +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" +run_grpo 3 "${CKPT_S3}" "${G3_DIR}" +A3="$(resolve_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5ae996286e8221617e4f7ee35e440d8e30500ec --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash +# Stage-1 SFT only for 20-empty: train until eval value_precision AND value_recall +# both reach 0.98 (or max_steps / optional wall clock). Use the resulting +# checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for +# launch_20empty_post_s1sft_stages123_value98.sh. +# +# Fresh LoRA on base model: +# ./launch_20empty_sft_stage1_98p.sh +# +# Continue from a prior SFT checkpoint: +# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_sft_stage1_98p.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=20 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +SFT_TARGET="${SFT_TARGET:-0.98}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MAX_STEPS="${MAX_STEPS:-30000}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_sft98_stage1}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +if [[ ! -f "${train_jsonl}" ]]; then + mkdir -p "$(dirname "${train_jsonl}")" + printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0 +fi +if [[ ! -f "${eval_jsonl}" ]]; then + mkdir -p "$(dirname "${eval_jsonl}")" + printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1 +fi + +mkdir -p "${OUTPUT_DIR}" + +INIT_FLAGS=() +if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then + INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}") + printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}" +fi + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) + printf 'NOTE: USE_GC=1 - slower, less VRAM.\n' +fi + +if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}" +else + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}" +fi +printf 'Output: %s\n' "${OUTPUT_DIR}" + +exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${OUTPUT_DIR}" \ + --cache_dir "${ROOT}/.hf_cache" \ + "${INIT_FLAGS[@]}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i 1 \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs 64.0 \ + --learning_rate 2e-4 \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 100 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${SFT_TARGET}" \ + --eval_value_recall_stop "${SFT_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh new file mode 100644 index 0000000000000000000000000000000000000000..7d8ce726928e9bdbc898ceda1782cfde8df14d80 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh @@ -0,0 +1,187 @@ +#!/usr/bin/env bash +# Stage-1 SFT sweep over all latent modes for 20-empty Sudoku. +# +# Runs four independent SFT jobs in parallel: +# residual, fixed_slots, recurrent_hidden, latent_seeds +# +# Default GPU split on an 8-GPU node: +# residual -> CUDA_VISIBLE_DEVICES=0,1 +# fixed_slots -> CUDA_VISIBLE_DEVICES=2,3 +# recurrent_hidden -> CUDA_VISIBLE_DEVICES=4,5 +# latent_seeds -> CUDA_VISIBLE_DEVICES=6,7 +# +# Useful overrides: +# RUN_TAG=... CHECKPOINT_ROOT=... +# GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1 +# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 SFT_VALUE_TARGET=0.98 +# STAGE1_INIT_ADAPTER_DIR=/path/to/init_adapter +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py" + +MODES=("residual" "fixed_slots" "recurrent_hidden" "latent_seeds") +MODE_TAGS=("latent_residual" "latent_fixed_slots" "latent_recurrent_hidden" "latent_seeds") + +# Space-separated list of CUDA_VISIBLE_DEVICES groups, one per latent mode. +# Example for one GPU per method: GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1 +GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}" +read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}" + +NPROC_PER_JOB="${NPROC_PER_JOB:-2}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES="${EMPTIES:-20}" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}" +MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" + +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +# Match the recurrent 20-empty launcher defaults: -1 resolves inside the +# trainer to hidden_size, and alpha=-1 resolves to 2 * resolved rank. +LORA_R="${LORA_R:--1}" +LORA_ALPHA="${LORA_ALPHA:--1}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" +STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}" +SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}" +SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}" +NUM_COT_TOKENS="${NUM_COT_TOKENS:-1}" +MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}" +MAX_LATENT_SEEDS="${MAX_LATENT_SEEDS:-8}" +STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_stage1_sft_all_latent_modes}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}" + +if [[ ${#GPU_GROUPS[@]} -ne ${#MODES[@]} ]]; then + printf 'ERROR: expected %d GPU groups, got %d.\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2 + printf 'Example: GPU_GROUPS_SPEC="0,1 2,3 4,5 6,7"\n' >&2 + exit 1 +fi + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train or eval jsonl.\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then + printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2 + exit 1 +fi + +mkdir -p "${OUTPUT_ROOT}" + +run_stage1_sft_for_mode() { + local mode="$1" + local tag="$2" + local gpu_group="$3" + local out_dir="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${tag}" + local log_dir="${OUTPUT_ROOT}/logs" + local log_file="${log_dir}/stage01_sft_${mode}.log" + + mkdir -p "${out_dir}" "${log_dir}" + printf '\n=== launching stage-1 SFT: mode=%s gpus=%s out=%s ===\n' "${mode}" "${gpu_group}" "${out_dir}" >&2 + + ( + export CUDA_VISIBLE_DEVICES="${gpu_group}" + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i 1 \ + --num_cot_tokens "${NUM_COT_TOKENS}" \ + --latent_mode "${mode}" \ + --max_latent_slots "${MAX_LATENT_SLOTS}" \ + --max_latent_seeds "${MAX_LATENT_SEEDS}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio 1 \ + --mixed_stage2_ratio 0 \ + --per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${STAGE1_SFT_LR}" \ + --weight_decay 0.0 \ + --enable_gradient_checkpointing \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --eval_value_precision_stop "${SFT_VALUE_TARGET}" \ + --eval_value_recall_stop "${SFT_VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-stage1-sft-all-modes" \ + --wandb_run_name "latent20_stage1_sft_${mode}_cot${NUM_COT_TOKENS}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" + ) >"${log_file}" 2>&1 & + + printf '%s\n' "$!" +} + +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-}" +printf 'Modes: %s\n' "${MODES[*]}" +printf 'GPU groups: %s\n' "${GPU_GROUPS[*]}" +printf 'Processes per job: %s\n' "${NPROC_PER_JOB}" + +pids=() +names=() +for i in "${!MODES[@]}"; do + pid="$(run_stage1_sft_for_mode "${MODES[$i]}" "${MODE_TAGS[$i]}" "${GPU_GROUPS[$i]}")" + pids+=("${pid}") + names+=("${MODES[$i]}") +done + +failed=0 +for i in "${!pids[@]}"; do + if wait "${pids[$i]}"; then + printf 'DONE: %s\n' "${names[$i]}" + else + printf 'FAILED: %s (pid=%s). See logs under %s/logs\n' "${names[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2 + failed=1 + fi +done + +if [[ "${failed}" -ne 0 ]]; then + exit 1 +fi + +printf '\nAll stage-1 latent SFT jobs finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" diff --git a/hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh b/hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh new file mode 100644 index 0000000000000000000000000000000000000000..de2687530d65e03a3983393a1318fd88bad658eb --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh @@ -0,0 +1,394 @@ +#!/usr/bin/env bash +# Full 20-empty latent comparison with baseline warm-up before latent stages. +# +# Required: +# STAGE1_BASELINE_ADAPTER_DIR=/path/to/baseline/stage1/checkpoint-step-XXXXX +# +# Default mode split on 8 GPUs: +# residual -> GPUs 0,1 +# fixed_slots -> GPUs 2,3 +# recurrent_hidden -> GPUs 4,5 +# latent_seeds -> GPUs 6,7 +# +# Per mode: +# stage1 latent SFT -> stage1 latent GRPO +# stage2 baseline SFT warm-up -> stage2 latent SFT -> stage2 latent GRPO +# stage3 baseline SFT warm-up -> stage3 latent SFT -> stage3 latent GRPO +# +# Optional resume adapters, intended for single-mode resumes: +# STAGE1_LATENT_SFT_ADAPTER_DIR=/path/to/stage01_latent_sft_or_checkpoint +# STAGE1_LATENT_GRPO_ADAPTER_DIR=/path/to/stage01_latent_grpo +# STAGE2_BASELINE_WARM_ADAPTER_DIR=/path/to/stage02_baseline_warm_sft +# STAGE2_LATENT_SFT_INIT_ADAPTER_DIR=/path/to/stage02_latent_sft_checkpoint_to_continue_training +# STAGE2_LATENT_SFT_ADAPTER_DIR=/path/to/stage02_latent_sft_or_checkpoint +# STAGE2_LATENT_GRPO_ADAPTER_DIR=/path/to/stage02_latent_grpo +# STAGE3_BASELINE_WARM_ADAPTER_DIR=/path/to/stage03_baseline_warm_sft +# STAGE3_LATENT_SFT_ADAPTER_DIR=/path/to/stage03_latent_sft_or_checkpoint +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +BASELINE_SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" +LATENT_SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py" +LATENT_GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py" + +EMPTIES="${EMPTIES:-20}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +MODES_SPEC="${MODES_SPEC:-residual fixed_slots recurrent_hidden latent_seeds}" +GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}" +NPROC_PER_JOB="${NPROC_PER_JOB:-2}" + +STAGE1_BASELINE_ADAPTER_DIR="${STAGE1_BASELINE_ADAPTER_DIR:-}" +STAGE1_LATENT_SFT_ADAPTER_DIR="${STAGE1_LATENT_SFT_ADAPTER_DIR:-}" +STAGE1_LATENT_GRPO_ADAPTER_DIR="${STAGE1_LATENT_GRPO_ADAPTER_DIR:-}" +STAGE2_BASELINE_WARM_ADAPTER_DIR="${STAGE2_BASELINE_WARM_ADAPTER_DIR:-}" +STAGE2_LATENT_SFT_INIT_ADAPTER_DIR="${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR:-}" +STAGE2_LATENT_SFT_ADAPTER_DIR="${STAGE2_LATENT_SFT_ADAPTER_DIR:-}" +STAGE2_LATENT_GRPO_ADAPTER_DIR="${STAGE2_LATENT_GRPO_ADAPTER_DIR:-}" +STAGE3_BASELINE_WARM_ADAPTER_DIR="${STAGE3_BASELINE_WARM_ADAPTER_DIR:-}" +STAGE3_LATENT_SFT_ADAPTER_DIR="${STAGE3_LATENT_SFT_ADAPTER_DIR:-}" +if [[ -z "${STAGE1_LATENT_SFT_ADAPTER_DIR}" && -z "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then + if [[ -z "${STAGE1_BASELINE_ADAPTER_DIR}" ]] || [[ ! -d "${STAGE1_BASELINE_ADAPTER_DIR}" ]]; then + printf 'ERROR: Set STAGE1_BASELINE_ADAPTER_DIR to a finished baseline SFT checkpoint directory, or provide a stage-1 latent resume adapter.\n' >&2 + exit 1 + fi +fi + +SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}" +SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}" +BASELINE_PER_DEVICE_BS="${BASELINE_PER_DEVICE_BS:-16}" +BASELINE_GRAD_ACCUM="${BASELINE_GRAD_ACCUM:-2}" +GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-4}" +GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}" + +BASELINE_WARM_MAX_STEPS="${BASELINE_WARM_MAX_STEPS:-1000}" +LATENT_SFT_MAX_STEPS="${LATENT_SFT_MAX_STEPS:-1000}" +LATENT_GRPO_MAX_STEPS="${LATENT_GRPO_MAX_STEPS:-500}" +LATENT_SFT_EVAL_STEPS="${LATENT_SFT_EVAL_STEPS:-250}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-64}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-50}" + +SOLVE_TARGET="${SOLVE_TARGET:-0.95}" +VALUE_TARGET="${VALUE_TARGET:-0}" +MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}" +GRPO_BETA="${GRPO_BETA:-0.0}" + +LORA_R="${LORA_R:-32}" +LORA_ALPHA="${LORA_ALPHA:-64}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_warm_baseline_all_latent_modes_stages123}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}" + +read -r -a MODES <<< "${MODES_SPEC}" +read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}" +if [[ ${#MODES[@]} -ne ${#GPU_GROUPS[@]} ]]; then + printf 'ERROR: expected one GPU group per mode. modes=%d gpu_groups=%d\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2 + exit 1 +fi + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train or eval jsonl.\n %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +mkdir -p "${OUTPUT_ROOT}/logs" + +mode_tag() { + case "$1" in + residual) printf 'latent_residual' ;; + fixed_slots) printf 'latent_fixed_slots' ;; + recurrent_hidden) printf 'latent_recurrent_hidden' ;; + latent_seeds) printf 'latent_seeds' ;; + *) printf 'latent_%s' "$1" ;; + esac +} + +latest_checkpoint_or_dir() { + local d="$1" + shopt -s nullglob + local checkpoints=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#checkpoints[@]} > 0 )); then + printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1 + return 0 + fi + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + printf '' + return 1 +} + +run_baseline_sft() { + local stage="$1" init_adapter="$2" out_dir="$3" lr="$4" run_name="$5" + mkdir -p "${out_dir}" + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${BASELINE_SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size "${BASELINE_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${BASELINE_GRAD_ACCUM}" \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --max_grad_norm 1.0 \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop "${SOLVE_TARGET}" \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds 0 \ + --max_steps "${BASELINE_WARM_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-baseline-stage-warmups" \ + --wandb_run_name "${run_name}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_latent_sft() { + local mode="$1" stage="$2" cot="$3" init_adapter="$4" out_dir="$5" lr="$6" run_name="$7" + local ms1=0 ms2=1 + if [[ "${stage}" == "1" ]]; then + ms1=1 + ms2=0 + fi + mkdir -p "${out_dir}" + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${LATENT_SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --latent_mode "${mode}" \ + --max_latent_slots 8 \ + --max_latent_seeds 8 \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio "${ms1}" \ + --mixed_stage2_ratio "${ms2}" \ + --per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --weight_decay 0.0 \ + --enable_gradient_checkpointing \ + --logging_steps 20 \ + --eval_steps "${LATENT_SFT_EVAL_STEPS}" \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop "${SOLVE_TARGET}" \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds 0 \ + --max_steps "${LATENT_SFT_MAX_STEPS}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-stage-sft-warm-baseline" \ + --wandb_run_name "${run_name}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_latent_grpo() { + local mode="$1" stage="$2" cot="$3" init_adapter="$4" out_dir="$5" run_name="$6" + mkdir -p "${out_dir}" + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${LATENT_GRPO_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --latent_mode "${mode}" \ + --max_latent_seeds 8 \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio 0 \ + --mixed_stage2_ratio 1 \ + --per_device_train_batch_size "${GRPO_PER_DEVICE_BS}" \ + --gradient_accumulation_steps "${GRPO_GRAD_ACCUM}" \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta "${GRPO_BETA}" \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop "${SOLVE_TARGET}" \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds 0 \ + --max_steps "${LATENT_GRPO_MAX_STEPS}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-stage-grpo-warm-baseline" \ + --wandb_run_name "${run_name}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_mode_pipeline() { + local mode="$1" gpu_group="$2" tag + tag="$(mode_tag "${mode}")" + local mode_root="${OUTPUT_ROOT}/${tag}" + local log="${OUTPUT_ROOT}/logs/${tag}.log" + mkdir -p "${mode_root}" + export CUDA_VISIBLE_DEVICES="${gpu_group}" + printf 'Mode %s on GPUs %s\n' "${mode}" "${gpu_group}" + + local s1_lat="${mode_root}/stage01_latent_sft_i1_${EMPTIES}empty_${tag}" + local g1="${mode_root}/stage01_latent_grpo_i1_${EMPTIES}empty_${tag}" + local a_s1_lat a_g1 + if [[ -n "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then + a_g1="$(latest_checkpoint_or_dir "${STAGE1_LATENT_GRPO_ADAPTER_DIR}")" + printf 'Using existing stage-1 latent GRPO adapter for %s: %s\n' "${mode}" "${a_g1}" | tee -a "${log}" + else + if [[ -n "${STAGE1_LATENT_SFT_ADAPTER_DIR}" ]]; then + a_s1_lat="$(latest_checkpoint_or_dir "${STAGE1_LATENT_SFT_ADAPTER_DIR}")" + printf 'Using existing stage-1 latent SFT adapter for %s: %s\n' "${mode}" "${a_s1_lat}" | tee -a "${log}" + else + run_latent_sft "${mode}" 1 1 "${STAGE1_BASELINE_ADAPTER_DIR}" "${s1_lat}" "2e-4" "warmfull_${mode}_st1_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_s1_lat="$(latest_checkpoint_or_dir "${s1_lat}")" + fi + run_latent_grpo "${mode}" 1 1 "${a_s1_lat}" "${g1}" "warmfull_${mode}_st1_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_g1="$(latest_checkpoint_or_dir "${g1}")" + fi + + local b2="${mode_root}/stage02_baseline_warm_sft_i2_${EMPTIES}empty_${tag}" + local s2_lat="${mode_root}/stage02_latent_sft_i2_${EMPTIES}empty_${tag}" + local g2="${mode_root}/stage02_latent_grpo_i2_${EMPTIES}empty_${tag}" + local a_b2 a_s2_lat a_g2 + if [[ -n "${STAGE2_LATENT_GRPO_ADAPTER_DIR}" ]]; then + a_g2="$(latest_checkpoint_or_dir "${STAGE2_LATENT_GRPO_ADAPTER_DIR}")" + printf 'Using existing stage-2 latent GRPO adapter for %s: %s\n' "${mode}" "${a_g2}" | tee -a "${log}" + else + if [[ -n "${STAGE2_LATENT_SFT_ADAPTER_DIR}" ]]; then + a_s2_lat="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_ADAPTER_DIR}")" + printf 'Using existing stage-2 latent SFT adapter for %s: %s\n' "${mode}" "${a_s2_lat}" | tee -a "${log}" + else + if [[ -n "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}" ]]; then + a_b2="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}")" + printf 'Continuing stage-2 latent SFT for %s from adapter: %s\n' "${mode}" "${a_b2}" | tee -a "${log}" + elif [[ -n "${STAGE2_BASELINE_WARM_ADAPTER_DIR}" ]]; then + a_b2="$(latest_checkpoint_or_dir "${STAGE2_BASELINE_WARM_ADAPTER_DIR}")" + printf 'Using existing stage-2 baseline warm adapter for %s: %s\n' "${mode}" "${a_b2}" | tee -a "${log}" + else + run_baseline_sft 2 "${a_g1}" "${b2}" "5e-5" "warmfull_${mode}_st2_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_b2="$(latest_checkpoint_or_dir "${b2}")" + fi + run_latent_sft "${mode}" 2 2 "${a_b2}" "${s2_lat}" "5e-5" "warmfull_${mode}_st2_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_s2_lat="$(latest_checkpoint_or_dir "${s2_lat}")" + fi + run_latent_grpo "${mode}" 2 2 "${a_s2_lat}" "${g2}" "warmfull_${mode}_st2_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_g2="$(latest_checkpoint_or_dir "${g2}")" + fi + + local b3="${mode_root}/stage03_baseline_warm_sft_i3_${EMPTIES}empty_${tag}" + local s3_lat="${mode_root}/stage03_latent_sft_i3_${EMPTIES}empty_${tag}" + local g3="${mode_root}/stage03_latent_grpo_i3_${EMPTIES}empty_${tag}" + local a_b3 a_s3_lat + if [[ -n "${STAGE3_LATENT_SFT_ADAPTER_DIR}" ]]; then + a_s3_lat="$(latest_checkpoint_or_dir "${STAGE3_LATENT_SFT_ADAPTER_DIR}")" + printf 'Using existing stage-3 latent SFT adapter for %s: %s\n' "${mode}" "${a_s3_lat}" | tee -a "${log}" + else + if [[ -n "${STAGE3_BASELINE_WARM_ADAPTER_DIR}" ]]; then + a_b3="$(latest_checkpoint_or_dir "${STAGE3_BASELINE_WARM_ADAPTER_DIR}")" + printf 'Using existing stage-3 baseline warm adapter for %s: %s\n' "${mode}" "${a_b3}" | tee -a "${log}" + else + run_baseline_sft 3 "${a_g2}" "${b3}" "5e-5" "warmfull_${mode}_st3_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_b3="$(latest_checkpoint_or_dir "${b3}")" + fi + run_latent_sft "${mode}" 3 3 "${a_b3}" "${s3_lat}" "5e-5" "warmfull_${mode}_st3_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}" + a_s3_lat="$(latest_checkpoint_or_dir "${s3_lat}")" + fi + run_latent_grpo "${mode}" 3 3 "${a_s3_lat}" "${g3}" "warmfull_${mode}_st3_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}" + + printf 'Mode %s finished. Output: %s\n' "${mode}" "${mode_root}" | tee -a "${log}" +} + +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stage-1 baseline adapter: %s\n' "${STAGE1_BASELINE_ADAPTER_DIR}" +printf 'Solve target: %s (value target: %s)\n' "${SOLVE_TARGET}" "${VALUE_TARGET}" + +pids=() +for i in "${!MODES[@]}"; do + ( + run_mode_pipeline "${MODES[$i]}" "${GPU_GROUPS[$i]}" + ) >"${OUTPUT_ROOT}/logs/$(mode_tag "${MODES[$i]}").supervisor.log" 2>&1 & + pids+=("$!") + printf 'Launched mode=%s pid=%s gpus=%s\n' "${MODES[$i]}" "${pids[-1]}" "${GPU_GROUPS[$i]}" +done + +failed=0 +for i in "${!pids[@]}"; do + if wait "${pids[$i]}"; then + printf 'DONE: %s\n' "${MODES[$i]}" + else + printf 'FAILED: %s (pid=%s). See %s/logs\n' "${MODES[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2 + failed=1 + fi +done + +exit "${failed}" diff --git a/hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..bdd4a8cc7ab1fd457912e7c8f888802effb9e723 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh @@ -0,0 +1,419 @@ +#!/usr/bin/env bash +# Latent residual projector pipeline (7-empty), aligned with the text +# launch_7empty_post_s1sft_stages123_value98.sh order and value gate: +# 1) Stage-1 SFT (default: init = STAGE1_INIT_ADAPTER_DIR or fresh LoRA + random residual) +# 2) Stage-1 GRPO (init = stage-1 SFT checkpoint-step-* dir, or STAGE1_SFT_ADAPTER_DIR if set) +# 3) Stage-2 SFT (init = stage-1 GRPO adapter) +# 4) Stage-2 GRPO +# 5) Stage-3 SFT +# 6) Stage-3 GRPO +# +# Legacy GRPO-first (skip training stage-1 SFT): STAGE1_GRPO_FIRST=1 +# +# Latent structure (implemented in latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py): +# - attach_residual_projector_modules(): adds trainable special_thought_embed, latent_mix_logit, +# and MLP latent_projector_in/out (hidden→4096→hidden) on the Peft-wrapped model. +# - build_latent_hidden() / residual_next_token_logits_from_ids(): append num_cot_tokens "latent" +# virtual tokens, run backbone, take (latent_hidden - base_hidden), project through the MLP, +# mix with base hidden (sigmoid(latent_mix_logit)), then lm_head logits (with optional fallback). +# - sample_latent_completion() / GRPO use this path for generation; SFT uses the same via +# residual_projector_warmstart_sft_latent_multi_output_train.py (latent_residual_completion_ce_loss). +# - latent_cot_state.pt saves/loads the projector + special_thought_embed + mix logit. +# +# Each phase stops when eval value_precision AND value_recall are both >= VALUE_TARGET +# (default 0.98), after MIN_STEPS_BEFORE_STOP optimizer steps (SFT) / GRPO steps (GRPO). +# Eval rows come from eval_jsonl (same held-out file as the text pipeline). +# +# Stage-1 SFT init (when not using STAGE1_SFT_ADAPTER_DIR or STAGE1_GRPO_FIRST): +# Default: omit STAGE1_INIT_ADAPTER_DIR → fresh LoRA + random residual (same as trainers --init_adapter_dir ""). +# Optional: STAGE1_INIT_ADAPTER_DIR=/path/to/adapter +# +# Skip running stage-1 SFT (you already have a finished SFT checkpoint-step-*): +# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft_.../checkpoint-step-XXXX +# → first trained phase is stage-1 GRPO with that init. +# +# Resume: +# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty_latent_residual +# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty_latent_residual +# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty_latent_residual +# +# Optional env: VALUE_TARGET, TRAIN_PUZZLES, EVAL_PUZZLES, RUN_TAG, CHECKPOINT_ROOT, GPU_IDS, +# WANDB_MODE, WANDB_ENTITY, SFT_NUM_EPOCHS, GRPO_NUM_TRAIN_EPOCHS, SFT_MAX_STEPS, GRPO_MAX_STEPS, +# STAGE1_SFT_LR (default 2e-4), STAGE1_GRPO_FIRST, STAGE1_SFT_ADAPTER_DIR +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=7 +TAG_SUFFIX="latent_residual" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}" +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +# Fresh-LoRA defaults (match text 7-empty SFT scale); override if you use a different init checkpoint. +LORA_R="${LORA_R:-32}" +LORA_ALPHA="${LORA_ALPHA:-64}" +LORA_DROPOUT="${LORA_DROPOUT:-0.05}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_latent_residual_stages123_value98}" +START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}" +START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}" +RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}" +STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}" +STAGE1_GRPO_FIRST="${STAGE1_GRPO_FIRST:-0}" +STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}" + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}" +elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}" +elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}" +else + if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then + printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2 + exit 1 + fi + if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: STAGE1_SFT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +fi + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +mkdir -p "${OUTPUT_ROOT}" + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +resolve_latent_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +run_latent_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + local cot="$5" + # Stage-1 SFT must weight stage-1 rows only (mixed 1/0). Stages 2–3 use stage-i curriculum (mixed 0/1). + local ms1=0 ms2=1 + if [[ "${stage}" == "1" ]]; then + ms1=1 + ms2=0 + fi + mkdir -p "${out_dir}" + printf '\n=== Latent stage %s SFT (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio "${ms1}" \ + --mixed_stage2_ratio "${ms2}" \ + --gradient_accumulation_steps 2 \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --weight_decay 0.0 \ + --enable_gradient_checkpointing \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --eval_jsonl "${eval_jsonl}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-sft-residual-projector" \ + --wandb_run_name "latent7_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_latent_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local cot="$4" + mkdir -p "${out_dir}" + printf '\n=== Latent stage %s GRPO (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2 + printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "${MODEL_NAME}" \ + --train_jsonl "${train_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --num_cot_tokens "${cot}" \ + --total_empties_hint "${EMPTIES}" \ + --mixed_stage1_ratio 0 \ + --mixed_stage2_ratio 1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --eval_jsonl "${eval_jsonl}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta 0.0 \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --lora_r "${LORA_R}" \ + --lora_alpha "${LORA_ALPHA}" \ + --lora_dropout "${LORA_DROPOUT}" \ + --use_wandb \ + --wandb_project "sudoku-latent-multi-output-grpo-residual-projector" \ + --wandb_run_name "latent7_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train or eval jsonl.\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 latent SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + S2_DIR="${START_AT_STAGE2_GRPO_DIR}" + CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" + if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 + fi + G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}" + run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2 + A2="$(resolve_latent_grpo_adapter "${G2_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 + fi + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}" + run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3 + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}" + run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3 + A3="$(resolve_latent_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll latent phases finished (started at stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final latent GRPO adapter: %s\n' "${A3}" + exit 0 +fi + +if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 latent GRPO dir %s → stage-3 SFT + GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + A2="$(resolve_latent_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}" + run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3 + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}" + run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3 + A3="$(resolve_latent_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll latent phases finished (started after stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final latent GRPO adapter: %s\n' "${A3}" + exit 0 +fi + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}" + +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}" +S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}" +STAGE1_INIT="${STAGE1_INIT_ADAPTER_DIR:-}" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + A1="$(resolve_latent_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")" +elif [[ "${STAGE1_GRPO_FIRST}" == "1" ]]; then + # Legacy: stage-1 GRPO first (fresh LoRA + random residual unless STAGE1_INIT_ADAPTER_DIR set). + run_latent_grpo 1 "${STAGE1_INIT}" "${G1_DIR}" 1 + A1="$(resolve_latent_grpo_adapter "${G1_DIR}")" +else + # Default: stage-1 SFT → stage-1 GRPO (matches text post-s1sft pipeline). + if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}" + printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2 + else + run_latent_sft 1 "${STAGE1_INIT}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1 + G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")" + if [[ -z "${G1_SFT_CKPT}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2 + exit 1 + fi + fi + run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1 + A1="$(resolve_latent_grpo_adapter "${G1_DIR}")" +fi +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2 + exit 1 +fi +printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2 +CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" +if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 +fi +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2 +A2="$(resolve_latent_grpo_adapter "${G2_DIR}")" +if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3 +CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" +if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 +fi +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}" +run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3 +A3="$(resolve_latent_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll latent residual phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh b/hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh new file mode 100644 index 0000000000000000000000000000000000000000..3733e3b8447a8da86b71ee78d44c7e747a98f38b --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh @@ -0,0 +1,372 @@ +#!/usr/bin/env bash +# Run AFTER stage-1 SFT finishes (7-empty). Order: +# 1) Stage-1 GRPO (init = your stage-1 SFT adapter) +# 2) Stage-2 SFT (init = stage-1 GRPO adapter) +# 3) Stage-2 GRPO (init = stage-2 SFT adapter) +# 4) Stage-3 SFT (init = stage-2 GRPO adapter) +# 5) Stage-3 GRPO (init = stage-3 SFT adapter) +# +# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall +# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use +# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap +# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling). +# +# Required (full pipeline from stage-1 SFT): +# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX +# +# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT): +# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty +# (OUTPUT_ROOT defaults to dirname of that dir.) +# +# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO): +# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty +# +# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only): +# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty +# +# Optional: +# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=... +# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0 +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" +GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES=7 +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +VALUE_TARGET="${VALUE_TARGET:-0.98}" +# Large caps so training is gated by VALUE_TARGET, not an arbitrary step limit. +SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}" +GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}" +SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}" +GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_stages123_value98}" +START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}" +START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}" +RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}" + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}" +elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}" +elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2 + exit 1 + fi + # Keep stage02+ next to stage01 (same parent dir). Override with OUTPUT_ROOT=... if needed. + OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}" +else + if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then + printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2 + exit 1 + fi + OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}" +fi + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +mkdir -p "${OUTPUT_ROOT}" + +latest_sft_step_ckpt() { + local d="$1" + shopt -s nullglob + local cks=("${d}"/checkpoint-step-*) + shopt -u nullglob + if (( ${#cks[@]} == 0 )); then + printf '' + return 1 + fi + # Avoid set -o pipefail treating a pipeline edge case as failure (was breaking the caller). + set +o pipefail + printf '%s\n' "${cks[@]}" | sort -V | tail -n 1 + set -o pipefail +} + +# GRPO: prefer final root adapter; else latest TRL checkpoint dir with adapter weights. +resolve_grpo_adapter() { + local d="$1" + if [[ -f "${d}/adapter_model.safetensors" ]]; then + printf '%s\n' "${d}" + return 0 + fi + local best="" step=-1 + shopt -s nullglob + local c + for c in "${d}"/checkpoint-*; do + [[ -d "${c}" ]] || continue + [[ -f "${c}/adapter_model.safetensors" ]] || continue + local n + n="${c##*checkpoint-}" + if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then + step=$((10#${n})) + best="${c}" + fi + done + shopt -u nullglob + if [[ -n "${best}" ]]; then + printf '%s\n' "${best}" + return 0 + fi + printf '' + return 1 +} + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) +fi + +run_sft() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + local lr="$4" + mkdir -p "${out_dir}" + printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs "${SFT_NUM_EPOCHS}" \ + --learning_rate "${lr}" \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 200 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_exact_set_match_stop 0 \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${SFT_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +run_grpo() { + local stage="$1" + local init_adapter="$2" + local out_dir="$3" + mkdir -p "${out_dir}" + printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2 + printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2 + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${out_dir}" \ + --cache_dir "${ROOT}/.hf_cache" \ + --init_adapter_dir "${init_adapter}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i "${stage}" \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \ + --learning_rate 1e-6 \ + --logging_steps 20 \ + --save_steps 200 \ + --eval_steps 500 \ + --eval_rows "${EVAL_PUZZLES}" \ + --num_generations 4 \ + --max_prompt_length 1024 \ + --max_completion_length 24 \ + --beta 0.0 \ + --enable_gradient_checkpointing \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --reward_good_value 1.25 \ + --penalty_bad_value 1.0 \ + --penalty_malformed 4.0 \ + --penalty_empty 0.5 \ + --penalty_singleton 1.5 \ + --eval_value_precision_stop "${VALUE_TARGET}" \ + --eval_value_recall_stop "${VALUE_TARGET}" \ + --eval_solve_rate_stop 0 \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${GRPO_MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-grpo" \ + --wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" +} + +if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then + printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2 + printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2 + exit 1 +fi + +if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + S2_DIR="${START_AT_STAGE2_GRPO_DIR}" + CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" + if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 + fi + printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2 + G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" + run_grpo 2 "${CKPT_S2}" "${G2_DIR}" + A2="$(resolve_grpo_adapter "${G2_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 + fi + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started at stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then + printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" + A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")" + if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2 + exit 1 + fi + printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2 + S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" + run_sft 3 "${A2}" "${S3_DIR}" "5e-5" + CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" + if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 + fi + G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" + run_grpo 3 "${CKPT_S3}" "${G3_DIR}" + A3="$(resolve_grpo_adapter "${G3_DIR}")" + if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 + fi + printf '\nAll phases finished (started after stage-2 GRPO).\n' + printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" + printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" + exit 0 +fi + +printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" +else + printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" +fi +printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \ + "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}" + +# --- Stage 1 GRPO (skip if resuming) --- +G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty" +if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then + A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")" +else + run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}" + A1="$(resolve_grpo_adapter "${G1_DIR}")" +fi +if [[ -z "${A1}" ]]; then + printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2 + exit 1 +fi +printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}" + +# --- Stage 2 SFT + GRPO --- +S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty" +run_sft 2 "${A1}" "${S2_DIR}" "5e-5" +CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")" +if [[ -z "${CKPT_S2}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2 + exit 1 +fi +G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty" +run_grpo 2 "${CKPT_S2}" "${G2_DIR}" +A2="$(resolve_grpo_adapter "${G2_DIR}")" +if [[ -z "${A2}" ]]; then + printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2 + exit 1 +fi + +# --- Stage 3 SFT + GRPO --- +S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty" +run_sft 3 "${A2}" "${S3_DIR}" "5e-5" +CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")" +if [[ -z "${CKPT_S3}" ]]; then + printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2 + exit 1 +fi +G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty" +run_grpo 3 "${CKPT_S3}" "${G3_DIR}" +A3="$(resolve_grpo_adapter "${G3_DIR}")" +if [[ -z "${A3}" ]]; then + printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2 + exit 1 +fi + +printf '\nAll phases finished.\n' +printf 'Outputs under: %s\n' "${OUTPUT_ROOT}" +printf 'Final GRPO adapter (stage 3): %s\n' "${A3}" diff --git a/hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh b/hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh new file mode 100644 index 0000000000000000000000000000000000000000..bdb6a9cda53ff6b97eacdf2ae119163a8d379ba0 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh @@ -0,0 +1,113 @@ +#!/usr/bin/env bash +# Stage-1 SFT only (no GRPO): train until eval value_precision AND value_recall both >= SFT_TARGET +# (default 0.95), or max_steps / optional wall clock. Use this before tuning GRPO. +# +# Fresh LoRA on base model: +# ./launch_sft_stage1_95p.sh +# EMPTIES=10 ./launch_sft_stage1_95p.sh +# +# Continue from a prior SFT checkpoint (recommended after a 90% run plateaus): +# INIT_ADAPTER_DIR=/path/to/checkpoint-step-01200 ./launch_sft_stage1_95p.sh +# +# Optional W&B display name: +# WANDB_RUN_NAME=my_run_name SFT_TARGET=0.99 MAX_STEPS=30000 INIT_ADAPTER_DIR=... ./launch_sft_stage1_95p.sh +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py" +SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" + +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +WANDB_MODE="${WANDB_MODE:-online}" +WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}" + +EMPTIES="${EMPTIES:-7}" +TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}" +# Held-out eval: 100 rows lines up with the baseline queue and stabilizes 95% gate vs tiny eval. +EVAL_PUZZLES="${EVAL_PUZZLES:-100}" +SFT_TARGET="${SFT_TARGET:-0.95}" +PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}" +MAX_STEPS="${MAX_STEPS:-12000}" + +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_sft95_stage1}" +OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft95}" + +train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl" +eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl" + +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" + +if [[ ! -f "${train_jsonl}" ]]; then + mkdir -p "$(dirname "${train_jsonl}")" + printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles 10000 --empties "${EMPTIES}" --seed 0 +fi +if [[ ! -f "${eval_jsonl}" ]]; then + mkdir -p "$(dirname "${eval_jsonl}")" + printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}" + "${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1 +fi + +mkdir -p "${OUTPUT_DIR}" + +INIT_FLAGS=() +if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then + INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}") + printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}" +fi + +GC_FLAGS=() +if [[ "${USE_GC:-0}" == "1" ]]; then + GC_FLAGS+=(--enable_gradient_checkpointing) + printf 'NOTE: USE_GC=1 — slower, less VRAM.\n' +fi + +if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}" +else + printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}" +fi +printf 'Output: %s\n' "${OUTPUT_DIR}" + +exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \ + --model_name "Qwen/Qwen2.5-0.5B-Instruct" \ + --train_jsonl "${train_jsonl}" \ + --eval_jsonl "${eval_jsonl}" \ + --output_dir "${OUTPUT_DIR}" \ + --cache_dir "${ROOT}/.hf_cache" \ + "${INIT_FLAGS[@]}" \ + --seed 0 \ + --gpu_id 0 \ + --stage_i 1 \ + --total_empties_hint "${EMPTIES}" \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 2 \ + --num_epochs 24.0 \ + --learning_rate 2e-4 \ + --max_grad_norm 1.0 \ + "${GC_FLAGS[@]}" \ + --logging_steps 20 \ + --eval_steps 250 \ + --save_steps 100 \ + --eval_rows "${EVAL_PUZZLES}" \ + --max_completion_length 24 \ + --limit_train_rows "${TRAIN_PUZZLES}" \ + --lora_r 32 \ + --lora_alpha 64 \ + --lora_dropout 0.05 \ + --eval_value_precision_stop "${SFT_TARGET}" \ + --eval_value_recall_stop "${SFT_TARGET}" \ + --min_steps_before_stop 50 \ + --max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \ + --max_steps "${MAX_STEPS}" \ + --use_wandb \ + --wandb_project "sudoku-multi-output-sft" \ + --wandb_run_name "${WANDB_RUN_NAME:-stage01_sft95_i1_${EMPTIES}empty_${RUN_TAG}}" \ + --wandb_mode "${WANDB_MODE}" \ + --wandb_entity "${WANDB_ENTITY}" diff --git a/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md b/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md new file mode 100644 index 0000000000000000000000000000000000000000..4e7c5a2f45492fc353b8cbdf755f4606756f80f2 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md @@ -0,0 +1,68 @@ +# Recurrent Hidden Stage 2 Resume Summary + +This note summarizes the May 16, 2026 stage-2 recurrent-hidden SFT recovery and monitoring changes. + +## What Changed + +- Added explicit eval lifecycle logging in `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`. +- W&B now receives `eval/in_progress`, `eval/rows`, and `eval/duration_seconds`. +- Local logs now print `[latent sft eval start ...]` and `[latent sft eval end ...]` markers. + +## Why + +The previous resumed stage-2 run reached step 2000 but appeared silent during validation. The validation metrics only logged after the whole eval completed, and the old `eval_rows=100` setting made a single validation take roughly 35 minutes. The run then crashed before producing the step-2000 eval metrics or checkpoint. + +## Probe Result + +A one-GPU eval probe from `checkpoint-step-01800` measured validation cost: + +- Eval rows: 20 puzzles +- Eval duration: 427.3 seconds, about 7.1 minutes +- Exact set match: 0.9225 +- Value precision: 0.945 +- Value recall: 0.934 +- Solve rate: 0.15 +- W&B run: `xudqbjqh` + +## Active Resume Run + +The main run was restarted from: + +`final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_from200_20260515_205857/latent_recurrent_hidden/stage02_latent_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01800` + +Run settings: + +- Stage: 2 +- Latent mode: recurrent_hidden +- GPUs: 8 +- Eval rows: 20 +- Eval interval: every 100 steps +- Checkpoint interval: every 100 steps +- Max steps: 5000 +- Early stop: disabled for solve rate; precision and recall target set to 0.9999 +- W&B run: `h3lxi62v` + +At the first eval: + +- Step: 100 +- Eval duration: 427.3 seconds +- Exact set match: 0.935 +- Value precision: 0.95875 +- Value recall: 0.94875 +- Solve rate: 0.25 + +## Checkpoint Sync + +The run output is periodically synced to Hugging Face every 10 minutes: + +`Avra98/sudoku-latent-recurrent-hidden-20empty-stages/resume_runs/recurrent_hidden_resume_stage2sft_from1800_eval20_long_20260516_090446` + +Confirmed uploaded checkpoint: + +- `checkpoint-step-00100/adapter_model.safetensors` +- `checkpoint-step-00100/adapter_config.json` +- `checkpoint-step-00100/tokenizer.json` +- `checkpoint-step-00100/tokenizer_config.json` +- `checkpoint-step-00100/chat_template.jinja` +- `checkpoint-step-00100/README.md` + diff --git a/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md b/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md new file mode 100644 index 0000000000000000000000000000000000000000..c3a75c3e9f905f0953e05e5f49ebf7495a659d63 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md @@ -0,0 +1,83 @@ +# Recurrent-Hidden 20-Empty Stage-2 SFT Resume + +This note records the recovered recurrent-hidden run restarted on May 15, 2026. + +## Source Checkpoints + +Recovered adapters were downloaded from: + +```text +https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages +``` + +Local snapshot path: + +```text +/home/ubuntu/curriculum_cot/final_checkpoint/hf_sudoku_latent_recurrent_hidden_20empty_stages +``` + +Available recovered folders: + +```text +stage01_latent_sft_i1_20empty_latent_recurrent_hidden +stage01_latent_grpo_i1_20empty_latent_recurrent_hidden +stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden +stage02_latent_sft_i2_20empty_latent_recurrent_hidden +``` + +The uploaded stage-2 latent SFT checkpoint did not include `trainer_state.json` +or solve-rate metadata, so the restart intentionally resumes from the stage-2 +baseline warm-up adapter and reruns stage-2 latent SFT instead of jumping to +stage-2 GRPO. + +## Active Resume Run + +Output root: + +```text +/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858 +``` + +W&B run: + +```text +https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1vyq1a1n +``` + +Launch settings: + +```text +MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct +MODES_SPEC=recurrent_hidden +GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7 +NPROC_PER_JOB=8 +STAGE1_LATENT_GRPO_ADAPTER_DIR=/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden +STAGE2_BASELINE_WARM_ADAPTER_DIR=/stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01000 +LATENT_SFT_MAX_STEPS=5000 +LATENT_GRPO_MAX_STEPS=500 +SOLVE_TARGET=0.95 +VALUE_TARGET=0 +MIN_STEPS_BEFORE_STOP=50 +WANDB_MODE=online +WANDB_ENTITY=training-dynamics +``` + +## Backup Plan + +Code changes are pushed to GitHub branch: + +```text +llm-policy-icon-code +``` + +Checkpoint backups should be pushed periodically to the same Hugging Face repo +using: + +```bash +HF_TOKEN=hf_xxx \ +RUN_OUTPUT_DIR=/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858 \ +bash hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh +``` + +The sync script uploads checkpoint folders, adapter files, tokenizer files, and +logs while ignoring W&B runtime directories and prepared-data caches. diff --git a/hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh b/hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh new file mode 100644 index 0000000000000000000000000000000000000000..5324cc8e5cce65e60cc370c4c9a2c64e7e290801 --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# Periodically upload the active recurrent-hidden resume output to Hugging Face. +# +# Required: +# RUN_OUTPUT_DIR=/path/to/recurrent_hidden_resume_stage2sft_... +# +# Optional: +# HF_TOKEN=hf_... # otherwise uses `hf auth login` / cached login +# HF_REPO_ID=Avra98/sudoku-latent-recurrent-hidden-20empty-stages +# HF_REPO_PREFIX=resume_runs/ +# SYNC_INTERVAL_SECONDS=900 + +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +HF_REPO_ID="${HF_REPO_ID:-Avra98/sudoku-latent-recurrent-hidden-20empty-stages}" +RUN_OUTPUT_DIR="${RUN_OUTPUT_DIR:-}" +SYNC_INTERVAL_SECONDS="${SYNC_INTERVAL_SECONDS:-900}" + +if [[ -z "${RUN_OUTPUT_DIR}" ]] || [[ ! -d "${RUN_OUTPUT_DIR}" ]]; then + printf 'ERROR: Set RUN_OUTPUT_DIR to an existing run output directory.\n' >&2 + exit 1 +fi + +if [[ ! -x "${PYTHON_BIN}" ]]; then + printf 'ERROR: Python not found at %s\n' "${PYTHON_BIN}" >&2 + exit 1 +fi + +RUN_NAME="$(basename "${RUN_OUTPUT_DIR}")" +HF_REPO_PREFIX="${HF_REPO_PREFIX:-resume_runs/${RUN_NAME}}" +export HF_REPO_ID RUN_OUTPUT_DIR HF_REPO_PREFIX + +upload_once() { + "${PYTHON_BIN}" - <<'PY' +import os +from pathlib import Path +from huggingface_hub import HfApi, get_token + +repo_id = os.environ["HF_REPO_ID"] +folder = Path(os.environ["RUN_OUTPUT_DIR"]).resolve() +path_in_repo = os.environ["HF_REPO_PREFIX"].strip("/") + +token = os.environ.get("HF_TOKEN") or get_token() +if not token: + raise SystemExit("No Hugging Face token found. Run `hf auth login` or set HF_TOKEN.") + +api = HfApi(token=token) +api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=str(folder), + path_in_repo=path_in_repo, + commit_message=f"Sync recurrent-hidden resume checkpoints: {folder.name}", + allow_patterns=[ + "logs/**", + "**/checkpoint*/**", + "**/adapter_config.json", + "**/adapter_model.safetensors", + "**/tokenizer.json", + "**/tokenizer_config.json", + "**/chat_template.jinja", + "**/README.md", + "**/training_args.bin", + ], + ignore_patterns=[ + "**/wandb_runtime/**", + "**/.wandb/**", + "**/wandb/**", + "**/optimizer.pt", + "**/scheduler.pt", + "**/rng_state_*.pth", + ], +) +print(f"Uploaded {folder} to {repo_id}/{path_in_repo}") +PY +} + +while true; do + date -Is + upload_once + sleep "${SYNC_INTERVAL_SECONDS}" +done diff --git a/hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md b/hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md new file mode 100644 index 0000000000000000000000000000000000000000..e72f4bfad28fb410fd6bcb86922714e56bea16da --- /dev/null +++ b/hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md @@ -0,0 +1,65 @@ +# Warm Baseline All-Latent Stages 1-3 Results + +Run tag: `warmbaseline_alllatent_stages123_20260512_1620` + +Base model: `Qwen/Qwen2.5-1.5B-Instruct` + +Stage-1 warm baseline adapter: + +```text +/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_baseline_1p5b_warmup/baseline_1p5b_warmup_bs32_eval100_20260512_203845/20empty/stage01_sft_i1_20empty_1p5b_warmup/checkpoint-step-01000 +``` + +This file records the solve-rate snapshot from the ongoing full pipeline. Later +stages should be updated when all modes finish. + +## Current Phase Snapshot + +| Mode | Current phase at snapshot | +| --- | --- | +| `residual` | Stage-2 latent SFT | +| `fixed_slots` | Stage-2 latent SFT | +| `recurrent_hidden` | Stage-2 baseline warm-up SFT | +| `latent_seeds` | Stage-3 baseline warm-up SFT | + +## Latest Solve Rates By Phase + +| Mode | Stage 1 latent SFT | Stage 1 latent GRPO | Stage 2 baseline warm-up | Stage 2 latent SFT | Stage 2 latent GRPO | Stage 3 baseline warm-up | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| `residual` | 0.470 latest / 0.610 best | 0.620 | 0.110 latest / 0.150 best | in progress | not reached | not reached | +| `fixed_slots` | 0.770 latest / 0.770 best | 0.870 | 0.140 latest / 0.140 best | 0.100 latest / 0.100 best | not reached | not reached | +| `recurrent_hidden` | 0.860 latest / 0.860 best | 0.950 | 0.110 latest / 0.110 best | not reached | not reached | not reached | +| `latent_seeds` | 0.740 latest / 0.740 best | 0.860 | 0.090 latest / 0.100 best | 0.120 latest / 0.120 best | 0.090 | started, no eval yet | + +## Stage 1 Solve Trajectories + +| Mode | Latent SFT solve rates | Post-GRPO solve rate | +| --- | --- | ---: | +| `residual` | 0.320 -> 0.610 -> 0.520 -> 0.470 | 0.620 | +| `fixed_slots` | 0.650 -> 0.200 -> 0.660 -> 0.770 | 0.870 | +| `recurrent_hidden` | 0.400 -> 0.600 -> 0.800 -> 0.860 | 0.950 | +| `latent_seeds` | 0.290 -> 0.500 -> 0.640 -> 0.740 | 0.860 | + +## Stage 2 Solve Trajectories So Far + +| Mode | Baseline warm-up solve rates | Latent SFT solve rates | Post-GRPO solve rate | +| --- | --- | --- | ---: | +| `residual` | 0.050 -> 0.150 -> 0.110 -> 0.110 | in progress | not reached | +| `fixed_slots` | 0.090 -> 0.120 -> 0.080 -> 0.140 | 0.080 -> 0.100 | not reached | +| `recurrent_hidden` | 0.060 -> 0.090 -> 0.100 -> 0.110 | not reached | not reached | +| `latent_seeds` | 0.090 -> 0.100 -> 0.080 -> 0.090 | 0.080 -> 0.090 -> 0.110 -> 0.120 | 0.090 | + +## W&B Links + +Stage 1 latent SFT: + +- `residual`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/sp4seb59 +- `fixed_slots`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/d62aiu1g +- `recurrent_hidden`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/cv3nr7ie +- `latent_seeds`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1f818jfg + +Additional stage runs are logged under: + +- SFT project: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline +- GRPO project: https://wandb.ai/training-dynamics/sudoku-latent-stage-grpo-warm-baseline +- Baseline warm-up project: https://wandb.ai/training-dynamics/sudoku-baseline-stage-warmups diff --git a/large_baseline_extension/README.md b/large_baseline_extension/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae20090e90caa051237c698d13d19955b1e837cb --- /dev/null +++ b/large_baseline_extension/README.md @@ -0,0 +1,35 @@ +# Large Baseline Extension Launchers + +This folder contains launch scripts for the non-location baseline multi-output runs. + +- `launch_nonlocation_pipeline.sh` +- `launch_nonlocation_sft.sh` +- `launch_nonlocation_grpo.sh` + +The main entry point for a full staged resume run is `launch_nonlocation_pipeline.sh`. + +Useful environment variables: + +- `MIN_STAGE` +- `MAX_STAGE` +- `NUM_PROCESSES` +- `GPU_IDS` +- `BOOTSTRAP_ADAPTER_DIR` +- `OUTPUT_ROOT` +- `RUN_TAG` +- `LIMIT_TRAIN_ROWS` +- `WANDB_MODE` +- `WANDB_ENTITY` + +Example: + +```bash +MIN_STAGE=3 \ +MAX_STAGE=5 \ +NUM_PROCESSES=8 \ +GPU_IDS=0,1,2,3,4,5,6,7 \ +BOOTSTRAP_ADAPTER_DIR=/path/to/stage02_grpo \ +WANDB_MODE=online \ +WANDB_ENTITY=training-dynamics \ +bash launch_nonlocation_pipeline.sh +``` diff --git a/large_baseline_extension/launch_nonlocation_grpo.sh b/large_baseline_extension/launch_nonlocation_grpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..38656596a2aa7b9f5873b0eec7b272cbbd631eaa --- /dev/null +++ b/large_baseline_extension/launch_nonlocation_grpo.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +TRAINER="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_ID="${GPU_ID:-0}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +STAGE_I="${STAGE_I:-2}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo}" +OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}" +WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-grpo}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_grpo_i${STAGE_I}_${RUN_TAG}}" +WANDB_GROUP="${WANDB_GROUP:-large_baseline_extension_noloc_grpo_i${STAGE_I}}" + +case "${STAGE_I}" in + 2) default_bs=4; default_gas=2 ;; + 3) default_bs=3; default_gas=2 ;; + 4) default_bs=2; default_gas=4 ;; + 5) default_bs=2; default_gas=4 ;; + *) default_bs=2; default_gas=4 ;; +esac + +mkdir -p "${OUTPUT_DIR}" +export CUDA_DEVICE_ORDER=PCI_BUS_ID +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +else + export CUDA_VISIBLE_DEVICES="${GPU_ID}" +fi +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}" + ) +else + cmd=( + "${PYTHON_BIN}" -u "${TRAINER}" + ) +fi + +cmd+=( + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --output_dir "${OUTPUT_DIR}" + --init_adapter_dir "${INIT_ADAPTER_DIR:?INIT_ADAPTER_DIR must be set}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" + --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" + --num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}" + --learning_rate "${LEARNING_RATE:-1e-6}" + --logging_steps "${LOGGING_STEPS:-5}" + --save_steps "${SAVE_STEPS:-25}" + --eval_steps "${EVAL_STEPS:-25}" + --eval_rows "${EVAL_ROWS:-20}" + --num_generations "${NUM_GENERATIONS:-2}" + --max_prompt_length "${MAX_PROMPT_LENGTH:-1024}" + --max_completion_length "${MAX_COMPLETION_LENGTH:-24}" + --beta "${BETA:-0.0}" + --enable_gradient_checkpointing + --wandb_project "${WANDB_PROJECT}" + --wandb_run_name "${WANDB_RUN_NAME}" + --wandb_group "${WANDB_GROUP}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${MAX_STEPS:-}" ]]; then + cmd+=(--max_steps "${MAX_STEPS}") +fi + +if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then + cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}") +fi + +printf 'Launching hard 9x9 baseline GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}" +printf 'stage_i=%s batch=%s grad_accum=%s processes=%s\n' \ + "${STAGE_I}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/large_baseline_extension/launch_nonlocation_pipeline.sh b/large_baseline_extension/launch_nonlocation_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7859250500080d8e5f6531870495978a23a3c7b --- /dev/null +++ b/large_baseline_extension/launch_nonlocation_pipeline.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_30empty_4stage_hard9x9}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${CHECKPOINT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --grpo_eval_solve_rate_stop "${GRPO_EVAL_SOLVE_RATE_STOP:-0.8}" + --grpo_min_steps_before_stop "${GRPO_MIN_STEPS_BEFORE_STOP:-100}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching hard 9x9 baseline pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/large_baseline_extension/launch_nonlocation_sft.sh b/large_baseline_extension/launch_nonlocation_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..05fdf82feef383ebe946bae42915c8533cb5ab12 --- /dev/null +++ b/large_baseline_extension/launch_nonlocation_sft.sh @@ -0,0 +1,87 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +TRAINER="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_ID="${GPU_ID:-0}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +STAGE_I="${STAGE_I:-2}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft}" +OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}" +WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-sft}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_sft_i${STAGE_I}_${RUN_TAG}}" + +mkdir -p "${OUTPUT_DIR}" +export CUDA_DEVICE_ORDER=PCI_BUS_ID +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +else + export CUDA_VISIBLE_DEVICES="${GPU_ID}" +fi +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}" + ) +else + cmd=( + "${PYTHON_BIN}" -u "${TRAINER}" + ) +fi + +cmd+=( + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --output_dir "${OUTPUT_DIR}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-4}" + --num_epochs "${NUM_EPOCHS:-1.0}" + --learning_rate "${LEARNING_RATE:-2e-4}" + --weight_decay "${WEIGHT_DECAY:-0.0}" + --enable_gradient_checkpointing + --logging_steps "${LOGGING_STEPS:-10}" + --save_steps "${SAVE_STEPS:-100}" + --eval_steps "${EVAL_STEPS:-100}" + --eval_rows "${EVAL_ROWS:-20}" + --max_completion_length "${MAX_COMPLETION_LENGTH:-24}" + --wandb_project "${WANDB_PROJECT}" + --wandb_run_name "${WANDB_RUN_NAME}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--init_adapter_dir "${INIT_ADAPTER_DIR}") +fi + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${MAX_STEPS:-}" ]]; then + cmd+=(--max_steps "${MAX_STEPS}") +fi + +printf 'Launching hard 9x9 baseline SFT on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'Stage=%s processes=%s\n' "${STAGE_I}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/large_latent_extension/README.md b/large_latent_extension/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c36cf37096313b1c1659e44a83d2a6cb4230fc4f --- /dev/null +++ b/large_latent_extension/README.md @@ -0,0 +1,32 @@ +# Large Latent Extension Launchers + +This folder contains the launch scripts for the non-location latent CoT runs. + +- `launch_nonlocation_sft.sh` +- `launch_nonlocation_grpo.sh` + +These are the scripts used for the distributed multi-GPU non-location curriculum. + +Useful environment variables: + +- `NUM_COT_TOKENS` +- `STAGE_I` +- `NUM_PROCESSES` +- `GPU_IDS` +- `INIT_ADAPTER_DIR` +- `OUTPUT_DIR` +- `LIMIT_TRAIN_ROWS` +- `WANDB_MODE` +- `WANDB_ENTITY` + +Example: + +```bash +NUM_COT_TOKENS=3 \ +STAGE_I=3 \ +NUM_PROCESSES=8 \ +GPU_IDS=0,1,2,3,4,5,6,7 \ +WANDB_MODE=online \ +WANDB_ENTITY=training-dynamics \ +bash launch_nonlocation_sft.sh +``` diff --git a/large_latent_extension/launch_nonlocation_grpo.sh b/large_latent_extension/launch_nonlocation_grpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb60519ce0c6ea3adeedfa6c56dfc52392e18fe6 --- /dev/null +++ b/large_latent_extension/launch_nonlocation_grpo.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +TRAINER="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_ID="${GPU_ID:-0}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +NUM_COT_TOKENS="${NUM_COT_TOKENS:?NUM_COT_TOKENS must be set}" +STAGE_I="${STAGE_I:-2}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo}" +OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}" +INIT_ADAPTER_DIR="${INIT_ADAPTER_DIR:-}" +WANDB_PROJECT="${WANDB_PROJECT:-sudoku-latent-multi-output-grpo-residual-projector}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_latent_noloc_grpo_i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}" +WANDB_GROUP="${WANDB_GROUP:-large_latent_extension_noloc_grpo_i${STAGE_I}}" + +case "${NUM_COT_TOKENS}" in + 2) default_bs=4; default_gas=2 ;; + 4) default_bs=2; default_gas=4 ;; + 5) default_bs=2; default_gas=4 ;; + *) default_bs=2; default_gas=4 ;; +esac + +mkdir -p "${OUTPUT_DIR}" +export CUDA_DEVICE_ORDER=PCI_BUS_ID +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +else + export CUDA_VISIBLE_DEVICES="${GPU_ID}" +fi +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}" + ) +else + cmd=( + "${PYTHON_BIN}" -u "${TRAINER}" + ) +fi + +cmd+=( + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --output_dir "${OUTPUT_DIR}" + --init_adapter_dir "${INIT_ADAPTER_DIR}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --num_cot_tokens "${NUM_COT_TOKENS}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" + --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" + --num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}" + --learning_rate "${LEARNING_RATE:-7e-7}" + --logging_steps "${LOGGING_STEPS:-5}" + --save_steps "${SAVE_STEPS:-10}" + --eval_steps "${EVAL_STEPS:-25}" + --eval_rows "${EVAL_ROWS:-20}" + --num_generations "${NUM_GENERATIONS:-2}" + --max_prompt_length "${MAX_PROMPT_LENGTH:-1024}" + --max_completion_length "${MAX_COMPLETION_LENGTH:-32}" + --beta "${BETA:-0.01}" + --enable_gradient_checkpointing + --wandb_project "${WANDB_PROJECT}" + --wandb_run_name "${WANDB_RUN_NAME}" + --wandb_group "${WANDB_GROUP}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then + cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}") +fi + +printf 'Launching hard 9x9 latent GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}" +printf 'num_cot_tokens=%s batch=%s grad_accum=%s stage_i=%s num_processes=%s\n' \ + "${NUM_COT_TOKENS}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${STAGE_I}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/large_latent_extension/launch_nonlocation_pipeline.sh b/large_latent_extension/launch_nonlocation_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..373f4b51ac35fd9646a117307342debf42ffa171 --- /dev/null +++ b/large_latent_extension/launch_nonlocation_pipeline.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${ROOT}/run_latent_residual_projector_pipeline.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/latent_pipeline_30empty_4stage_hard9x9}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${CHECKPOINT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching hard 9x9 latent pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/large_latent_extension/launch_nonlocation_sft.sh b/large_latent_extension/launch_nonlocation_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..3487c94f2bf58c0075136834861cab173a375b4d --- /dev/null +++ b/large_latent_extension/launch_nonlocation_sft.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +TRAINER="${ROOT}/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_ID="${GPU_ID:-0}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +NUM_COT_TOKENS="${NUM_COT_TOKENS:?NUM_COT_TOKENS must be set}" +STAGE_I="${STAGE_I:-2}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft}" +OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}" +INIT_ADAPTER_DIR="${INIT_ADAPTER_DIR:-}" +WANDB_PROJECT="${WANDB_PROJECT:-sudoku-latent-multi-output-sft-residual-projector}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_latent_noloc_sft_i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}" +WANDB_GROUP="${WANDB_GROUP:-large_latent_extension_noloc_sft_i${STAGE_I}}" + +mkdir -p "${OUTPUT_DIR}" +export CUDA_DEVICE_ORDER=PCI_BUS_ID +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +else + export CUDA_VISIBLE_DEVICES="${GPU_ID}" +fi +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +if [[ "${NUM_PROCESSES}" -gt 1 ]]; then + cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}" + ) +else + cmd=( + "${PYTHON_BIN}" -u "${TRAINER}" + ) +fi + +cmd+=( + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --output_dir "${OUTPUT_DIR}" + --init_adapter_dir "${INIT_ADAPTER_DIR}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --num_cot_tokens "${NUM_COT_TOKENS}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-8}" + --num_epochs "${NUM_EPOCHS:-1.0}" + --learning_rate "${LEARNING_RATE:-1e-6}" + --weight_decay "${WEIGHT_DECAY:-0.0}" + --enable_gradient_checkpointing + --logging_steps "${LOGGING_STEPS:-10}" + --save_steps "${SAVE_STEPS:-100}" + --eval_steps "${EVAL_STEPS:-100}" + --eval_rows "${EVAL_ROWS:-20}" + --max_completion_length "${MAX_COMPLETION_LENGTH:-32}" + --wandb_project "${WANDB_PROJECT}" + --wandb_run_name "${WANDB_RUN_NAME}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +printf 'Launching hard 9x9 latent SFT on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}" +printf 'Output dir: %s\n' "${OUTPUT_DIR}" +printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}" +printf 'num_cot_tokens=%s stage_i=%s num_processes=%s\n' "${NUM_COT_TOKENS}" "${STAGE_I}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/latent_multi_output_cell_policy/__init__.py b/latent_multi_output_cell_policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe049dd27343195a95167ea280c34dfb64dbc8a --- /dev/null +++ b/latent_multi_output_cell_policy/__init__.py @@ -0,0 +1 @@ +"""Recovered package wrapper; original source is currently unavailable.""" diff --git a/latent_multi_output_cell_policy/grpo_multimode_latent_train.py b/latent_multi_output_cell_policy/grpo_multimode_latent_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6c8062c9411026ccc5280b1630db2b53187796ac --- /dev/null +++ b/latent_multi_output_cell_policy/grpo_multimode_latent_train.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +"""GRPO entrypoint with latent-mode dispatch aligned to latent SFT. + +This script intentionally reuses ``grpo_residual_projector_latent_train`` after +the underlying module was updated to route GRPO forward/generation through the +selected ``--latent_mode``. Keeping this thin entrypoint makes the experiment +name explicit while preserving the existing CLI and checkpoint format. + +Supported modes match ``sft_latent_multi_output_train.py``: + +* ``residual`` +* ``fixed_slots`` +* ``recurrent_hidden`` +* ``latent_seeds`` +""" + +import os +import sys + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import main + + +if __name__ == "__main__": + main() diff --git a/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py b/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py new file mode 100644 index 0000000000000000000000000000000000000000..bd317e5e4ba759c6fdd2dbe0bbe6a42e2ca63777 --- /dev/null +++ b/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py @@ -0,0 +1,2058 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import sys +import time +from types import MethodType +from dataclasses import dataclass +from typing import Any, Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset +from peft import LoraConfig, PeftModel, get_peft_model +from safetensors.torch import load_file as load_safetensors_file +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed +from transformers.modeling_outputs import CausalLMOutput + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_model_artifacts +from mixed_curriculum_cot.runtime_mixed_curriculum import build_two_stage_mixed_rows, training_stage_i_for_row +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.rewards import score_prediction_text +from multi_output_cell_policy.shared_multi_output_policy import make_solved_grid_from_row + + +try: + import wandb +except Exception: + wandb = None + + +PROJECTOR_HIDDEN = 4096 + + +@dataclass +class Args: + model_name: str + train_jsonl: str + train_jsonl_stage1: str + train_jsonl_stage2: str + eval_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + num_cot_tokens: int + latent_mode: str + max_latent_seeds: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_train_epochs: float + learning_rate: float + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + num_generations: int + max_prompt_length: int + max_completion_length: int + beta: float + enable_gradient_checkpointing: bool + lora_r: int + lora_alpha: int + lora_dropout: float + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + wandb_group: str + wandb_run_id: str + debug_print_limit: int + limit_train_rows: int + mixed_stage1_ratio: float + mixed_stage2_ratio: float + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + max_wall_clock_seconds: int + max_steps: int + resume_from_checkpoint: str + eval_value_precision_stop: float + eval_value_recall_stop: float + eval_solve_rate_stop: float + min_steps_before_stop: int + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available(): + try: + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + if int(getattr(props, "major", 0)) >= 8: + return torch.bfloat16 + except Exception: + pass + return torch.float16 + + +def ensure_trl_fsdp_compat() -> None: + try: + import torch.distributed.fsdp as fsdp + + if not hasattr(fsdp, "FSDPModule") and hasattr(fsdp, "FullyShardedDataParallel"): + fsdp.FSDPModule = fsdp.FullyShardedDataParallel + except Exception: + pass + + +def resolve_lora_hparams(base_model: torch.nn.Module, *, lora_r: int, lora_alpha: int) -> tuple[int, int]: + """Resolve sentinel LoRA hyperparameters. + + ``lora_r <= 0`` means "full-rank for hidden-width projections", implemented + as ``config.hidden_size``. This is not full model fine-tuning, but it makes + the LoRA update full-rank for the main hidden-width target matrices. + + ``lora_alpha <= 0`` tracks the usual local convention ``alpha = 2 * r``. + """ + resolved_r = int(lora_r) + if resolved_r <= 0: + config = getattr(base_model, "config", None) + hidden_size = int(getattr(config, "hidden_size", 0) or getattr(config, "n_embd", 0) or 0) + if hidden_size <= 0: + raise ValueError("Cannot resolve full-rank LoRA: model config has no hidden_size/n_embd.") + resolved_r = hidden_size + resolved_alpha = int(lora_alpha) + if resolved_alpha <= 0: + resolved_alpha = 2 * resolved_r + return resolved_r, resolved_alpha + + +def load_trainable_adapter( + base_model: torch.nn.Module, + adapter_dir: str, + *, + lora_r: int = 128, + lora_alpha: int = 256, + lora_dropout: float = 0.05, +) -> torch.nn.Module: + if not str(adapter_dir).strip(): + resolved_r, resolved_alpha = resolve_lora_hparams( + base_model, + lora_r=int(lora_r), + lora_alpha=int(lora_alpha), + ) + lora = LoraConfig( + r=resolved_r, + lora_alpha=resolved_alpha, + lora_dropout=float(lora_dropout), + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + ) + return get_peft_model(base_model, lora) + try: + return PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True) + except Exception: + config_path = os.path.join(adapter_dir, "adapter_config.json") + model_path = os.path.join(adapter_dir, "adapter_model.safetensors") + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + lora = LoraConfig( + r=int(cfg["r"]), + lora_alpha=int(cfg["lora_alpha"]), + lora_dropout=float(cfg["lora_dropout"]), + bias=str(cfg.get("bias", "none")), + task_type=str(cfg.get("task_type", "CAUSAL_LM")), + target_modules=list(cfg["target_modules"]), + ) + model = get_peft_model(base_model, lora) + state = load_safetensors_file(model_path) + remapped: Dict[str, torch.Tensor] = {} + for key, value in state.items(): + new_key = key.replace(".lora_A.weight", ".lora_A.default.weight") + new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight") + remapped[new_key] = value + model.load_state_dict(remapped, strict=False) + return model + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def build_grpo_dataset( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> Dataset: + records: List[Dict[str, Any]] = [] + total_rows = len(rows) + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + row_stage_i = training_stage_i_for_row(row, stage_i) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=row_stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + records.append( + { + "prompt": prompt, + "grid_json": json.dumps(ex.grid.tolist()), + "solved_json": json.dumps(solved.tolist()), + "target_row": int(ex.target_cell[0]), + "target_col": int(ex.target_cell[1]), + "stage_i": int(row_stage_i), + } + ) + if progress_callback is not None and ( + row_idx == total_rows or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx=row_idx, total_rows=total_rows, record_count=len(records)) + return Dataset.from_list(records) + + +def _prepared_data_dir() -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "latent_multi_output_cell_policy") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_grpo_cache_path(args: Args) -> str: + payload = { + "kind": "grpo", + "train_jsonl": os.path.abspath(args.train_jsonl), + "train_jsonl_stage1": os.path.abspath(args.train_jsonl_stage1 or args.train_jsonl), + "train_jsonl_stage2": os.path.abspath(args.train_jsonl_stage2 or args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "mixed_stage1_ratio": float(args.mixed_stage1_ratio), + "mixed_stage2_ratio": float(args.mixed_stage2_ratio), + "model_name": str(args.model_name), + } + digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:20] + return os.path.join(_prepared_data_dir(), f"grpo_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=True) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_seconds: int = 6 * 60 * 60) -> None: + start = time.time() + while not os.path.exists(path): + if (time.time() - start) > float(timeout_seconds): + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_grpo_records( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_grpo_cache_path(args) + if os.path.exists(cache_path): + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] loading prepared cache: {cache_path}", flush=True) + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + dataset = build_grpo_dataset( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + records = [dataset[int(i)] for i in range(len(dataset))] + _write_jsonl(cache_path, records) + elif world_size > 1: + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +def load_training_rows(args: Args) -> List[Dict[str, Any]]: + stage1_path = str(args.train_jsonl_stage1 or "").strip() + stage2_path = str(args.train_jsonl_stage2 or "").strip() + use_mixed = bool(stage1_path or stage2_path) + if not use_mixed: + return load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + + stage1_rows = load_jsonl_rows(stage1_path or args.train_jsonl, limit_rows=0) + stage2_rows = load_jsonl_rows(stage2_path or args.train_jsonl, limit_rows=0) + return build_two_stage_mixed_rows( + stage1_rows, + stage2_rows, + stage1_ratio=float(args.mixed_stage1_ratio), + stage2_ratio=float(args.mixed_stage2_ratio), + seed=int(args.seed), + target_stage=int(args.stage_i), + total_rows=int(args.limit_train_rows), + ) + + +def make_reward_func(args: Args): + def reward_func(completions, grid_json, solved_json, target_row, target_col, stage_i, **kwargs): + rewards: List[float] = [] + for completion, grid_s, solved_s, rr, cc, stage_val in zip( + completions, grid_json, solved_json, target_row, target_col, stage_i + ): + info = score_prediction_text( + text=str(completion), + grid=torch.tensor(json.loads(grid_s), dtype=torch.long).numpy(), + solved=torch.tensor(json.loads(solved_s), dtype=torch.long).numpy(), + target_cell=(int(rr), int(cc)), + stage_i=int(stage_val), + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + rewards.append(float(info["reward"])) + return rewards + + return reward_func + + +def unwrap_backbone(model: nn.Module) -> nn.Module: + if isinstance(model, PeftModel): + return model.get_base_model() + return model + + +def unwrap_training_model(model: Any) -> Any: + current = model + while hasattr(current, "module"): + current = current.module + return current + + +def get_input_embeddings_module(model: nn.Module) -> nn.Module: + return unwrap_backbone(model).get_input_embeddings() + + +def get_output_embeddings_module(model: nn.Module) -> nn.Module: + base = unwrap_backbone(model) + return base.get_output_embeddings() or base.lm_head + + +def get_last_hidden_state(model_output: Any) -> torch.Tensor: + hidden = getattr(model_output, "last_hidden_state", None) + if hidden is not None: + return hidden + return model_output.hidden_states[-1] + + +def run_backbone_from_embeds( + backbone: nn.Module, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + *, + output_attentions: bool = False, +): + base = unwrap_backbone(backbone) + inner = getattr(base, "model", base) + return inner( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_attentions=bool(output_attentions), + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + + +def extend_attention_mask(mask: torch.Tensor, extra_tokens: int) -> torch.Tensor: + extra = torch.ones(mask.shape[0], int(extra_tokens), dtype=mask.dtype, device=mask.device) + return torch.cat([mask, extra], dim=1) + + +def _maybe_print_attention_density(model: nn.Module, attentions: Any, *, label: str) -> None: + debug_count = int(getattr(model, "_attention_density_debug_count", 0)) + debug_limit = int(getattr(model, "_attention_density_debug_limit", 0)) + if debug_limit <= 0 or debug_count >= debug_limit or attentions is None: + return + + threshold_mult = float(getattr(model, "_attention_density_threshold_mult", 1.0)) + summaries: list[str] = [] + for layer_idx, layer_attn in enumerate(attentions): + if layer_attn is None or layer_attn.ndim != 4 or int(layer_attn.shape[0]) <= 0: + continue + probs = layer_attn[0].float().clamp_min(1e-8) + seq_len = int(probs.shape[-1]) + uniform = 1.0 / max(1, seq_len) + dense_frac = float((probs > (threshold_mult * uniform)).float().mean().item()) + entropy = -(probs * probs.log()).sum(dim=-1) + eff_support = float((entropy.exp() / max(1, seq_len)).mean().item()) + final_row = probs[:, -1, :] + final_row_entropy = -(final_row * final_row.log()).sum(dim=-1) + final_row_eff = float((final_row_entropy.exp() / max(1, seq_len)).mean().item()) + final_row_max = float(final_row.max(dim=-1).values.mean().item()) + summaries.append( + f"L{layer_idx}:dense>{threshold_mult:.1f}u={dense_frac:.3f},eff={eff_support:.3f}," + f"final_eff={final_row_eff:.3f},final_max={final_row_max:.3f}" + ) + + if summaries: + print(f"[attention density] {label} " + " | ".join(summaries), flush=True) + model._attention_density_debug_count = debug_count + 1 + + +def _maybe_print_fixed_slot_debug( + model: nn.Module, + *, + input_ids: torch.Tensor, + latent_hiddens: torch.Tensor, + final_hidden: torch.Tensor, + logits: torch.Tensor, +) -> None: + debug_count = int(getattr(model, "_fixed_slot_debug_count", 0)) + debug_limit = int(getattr(model, "_fixed_slot_debug_limit", 0)) + if debug_limit <= 0 or debug_count >= debug_limit or int(logits.shape[0]) <= 0: + return + + tokenizer = getattr(model, "_latent_debug_tokenizer", None) + topk = 1 + probs = torch.softmax(logits[0].float(), dim=-1) + top_probs, top_ids = torch.topk(probs, k=min(topk, int(probs.shape[-1])), dim=-1) + + def _fmt(ids: torch.Tensor, probs: torch.Tensor) -> str: + parts = [] + for tok_id, prob in zip(ids.tolist(), probs.tolist(), strict=True): + piece = repr(tokenizer.decode([int(tok_id)])) if tokenizer is not None else f"id={int(tok_id)}" + parts.append(f"{piece}:{float(prob):.4f}") + return " | ".join(parts) + + print( + "[fixed slot debug] " + f"token_step={debug_count} " + f"next_top1={_fmt(top_ids, top_probs)}", + flush=True, + ) + model._fixed_slot_debug_count = debug_count + 1 + + +def _maybe_print_fixed_slot_decode_debug(model: nn.Module, tokenizer: Any, next_id: torch.Tensor, step_idx: int) -> None: + debug_count = int(getattr(model, "_fixed_slot_decode_debug_count", 0)) + debug_limit = int(getattr(model, "_fixed_slot_decode_debug_limit", 0)) + if debug_limit <= 0 or debug_count >= debug_limit or int(next_id.shape[0]) <= 0: + return + piece = repr(tokenizer.decode([int(next_id[0, 0].item())])) if tokenizer is not None else f"id={int(next_id[0, 0].item())}" + print(f"[fixed slot decode] step={step_idx} next_token={piece}", flush=True) + model._fixed_slot_decode_debug_count = debug_count + 1 + + +def attach_residual_projector_modules(model: nn.Module, hidden_size: int, projector_hidden: int = PROJECTOR_HIDDEN) -> None: + if hasattr(model, "latent_projector_in") and hasattr(model, "latent_projector_out") and hasattr( + model, "special_thought_embed" + ): + return + projector_hidden = int(projector_hidden) + model.special_thought_embed = nn.Parameter(torch.randn(hidden_size) * 0.02) + model.latent_mix_logit = nn.Parameter(torch.tensor(-8.0)) + model.latent_projector_in = nn.Linear(hidden_size, projector_hidden, bias=True) + model.latent_projector_out = nn.Linear(projector_hidden, hidden_size, bias=True) + nn.init.normal_(model.special_thought_embed, std=0.02) + nn.init.xavier_uniform_(model.latent_projector_in.weight) + nn.init.zeros_(model.latent_projector_in.bias) + nn.init.xavier_uniform_(model.latent_projector_out.weight) + nn.init.zeros_(model.latent_projector_out.bias) + + +def attach_fixed_latent_slot_modules(model: nn.Module, hidden_size: int, max_latent_slots: int = 8) -> None: + if hasattr(model, "fixed_latent_slots") and hasattr(model, "fixed_final_slot_embed"): + return + max_latent_slots = max(1, int(max_latent_slots)) + model.fixed_latent_slots = nn.Parameter(torch.randn(max_latent_slots, hidden_size) * 0.02) + model.fixed_final_slot_embed = nn.Parameter(torch.randn(hidden_size) * 0.02) + nn.init.normal_(model.fixed_latent_slots, std=0.02) + nn.init.normal_(model.fixed_final_slot_embed, std=0.02) + + +def maybe_load_fixed_slot_state(model: nn.Module, path_or_dir: str) -> bool: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "fixed_slot_latent_state.pt") + if not os.path.exists(state_path): + return False + state = torch.load(state_path, map_location="cpu") + with torch.no_grad(): + if "fixed_latent_slots" in state: + model.fixed_latent_slots.copy_(state["fixed_latent_slots"].to(model.fixed_latent_slots)) + if "fixed_final_slot_embed" in state: + model.fixed_final_slot_embed.copy_(state["fixed_final_slot_embed"].to(model.fixed_final_slot_embed)) + return True + + +def maybe_load_projector_state(model: nn.Module, path_or_dir: str) -> bool: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_cot_state.pt") + if not os.path.exists(state_path): + return False + state = torch.load(state_path, map_location="cpu") + with torch.no_grad(): + for name in [ + "special_thought_embed", + "latent_mix_logit", + "latent_projector_in_weight", + "latent_projector_in_bias", + "latent_projector_out_weight", + "latent_projector_out_bias", + ]: + if name not in state: + continue + if name == "special_thought_embed": + model.special_thought_embed.copy_(state[name].to(model.special_thought_embed)) + elif name == "latent_mix_logit": + model.latent_mix_logit.copy_(state[name].to(model.latent_mix_logit)) + elif name == "latent_projector_in_weight": + model.latent_projector_in.weight.copy_(state[name].to(model.latent_projector_in.weight)) + elif name == "latent_projector_in_bias": + model.latent_projector_in.bias.copy_(state[name].to(model.latent_projector_in.bias)) + elif name == "latent_projector_out_weight": + model.latent_projector_out.weight.copy_(state[name].to(model.latent_projector_out.weight)) + elif name == "latent_projector_out_bias": + model.latent_projector_out.bias.copy_(state[name].to(model.latent_projector_out.bias)) + return True + + +def infer_projector_hidden_from_state(path_or_dir: str) -> int | None: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_cot_state.pt") + if not os.path.exists(state_path): + return None + state = torch.load(state_path, map_location="cpu") + weight = state.get("latent_projector_in_weight") + if isinstance(weight, torch.Tensor) and weight.ndim == 2: + return int(weight.shape[0]) + return None + + +def infer_fixed_slot_count_from_state(path_or_dir: str) -> int | None: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "fixed_slot_latent_state.pt") + if not os.path.exists(state_path): + return None + state = torch.load(state_path, map_location="cpu") + slots = state.get("fixed_latent_slots") + if isinstance(slots, torch.Tensor) and slots.ndim == 2: + return int(slots.shape[0]) + return None + + +def save_latent_projector_state(model: nn.Module, output_dir: str) -> None: + state = { + "special_thought_embed": model.special_thought_embed.detach().cpu(), + "latent_mix_logit": model.latent_mix_logit.detach().cpu(), + "latent_projector_in_weight": model.latent_projector_in.weight.detach().cpu(), + "latent_projector_in_bias": model.latent_projector_in.bias.detach().cpu(), + "latent_projector_out_weight": model.latent_projector_out.weight.detach().cpu(), + "latent_projector_out_bias": model.latent_projector_out.bias.detach().cpu(), + } + torch.save(state, os.path.join(output_dir, "latent_cot_state.pt")) + + +def save_fixed_slot_latent_state(model: nn.Module, output_dir: str) -> None: + state = { + "fixed_latent_slots": model.fixed_latent_slots.detach().cpu(), + "fixed_final_slot_embed": model.fixed_final_slot_embed.detach().cpu(), + } + torch.save(state, os.path.join(output_dir, "fixed_slot_latent_state.pt")) + + +def attach_latent_seed_modules(model: nn.Module, hidden_size: int, max_latent_seeds: int = 8) -> None: + """Option-2 architecture: a bank of trainable latent seed vectors m_1..m_k + that are appended to the prompt embeddings. Only LoRA + these seeds are + trained. Seeds persist across examples and (after save/load) across + curriculum stages, giving the model an explicit "latent memory" that can + carry information across complexity stages. + + Shape: ``model.latent_seed_embeds`` is ``nn.Parameter[max_latent_seeds, d]``. + """ + if hasattr(model, "latent_seed_embeds"): + return + max_latent_seeds = max(1, int(max_latent_seeds)) + model.latent_seed_embeds = nn.Parameter(torch.randn(max_latent_seeds, hidden_size) * 0.02) + nn.init.normal_(model.latent_seed_embeds, std=0.02) + + +def maybe_load_latent_seed_state(model: nn.Module, path_or_dir: str) -> bool: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_seed_state.pt") + if not os.path.exists(state_path): + return False + state = torch.load(state_path, map_location="cpu") + if "latent_seed_embeds" not in state: + return False + with torch.no_grad(): + saved = state["latent_seed_embeds"] + target = model.latent_seed_embeds + shared = min(int(saved.shape[0]), int(target.shape[0])) + if shared > 0: + target[:shared].copy_(saved[:shared].to(target)) + return True + + +def infer_latent_seed_count_from_state(path_or_dir: str) -> int | None: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_seed_state.pt") + if not os.path.exists(state_path): + return None + state = torch.load(state_path, map_location="cpu") + seeds = state.get("latent_seed_embeds") + if isinstance(seeds, torch.Tensor) and seeds.ndim == 2: + return int(seeds.shape[0]) + return None + + +def save_latent_seed_state(model: nn.Module, output_dir: str) -> None: + state = {"latent_seed_embeds": model.latent_seed_embeds.detach().cpu()} + torch.save(state, os.path.join(output_dir, "latent_seed_state.pt")) + + +def build_fixed_slot_latent_hidden( + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_cot_tokens: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + backbone = unwrap_backbone(model) + inner_backbone = getattr(backbone, "model", backbone) + input_embeds = get_input_embeddings_module(model)(input_ids) + base_out = inner_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + base_hidden = get_last_hidden_state(base_out)[:, -1, :] + + slot_bank = model.fixed_latent_slots.to(device=input_embeds.device, dtype=input_embeds.dtype) + slot_count = min(max(0, int(num_cot_tokens)), int(slot_bank.shape[0])) + latent_slots = slot_bank[:slot_count].unsqueeze(0).expand(input_embeds.shape[0], slot_count, slot_bank.shape[-1]) + final_slot = model.fixed_final_slot_embed.to(device=input_embeds.device, dtype=input_embeds.dtype).view(1, 1, -1) + final_slot = final_slot.expand(input_embeds.shape[0], 1, final_slot.shape[-1]) + + if slot_count > 0: + full_embeds = torch.cat([input_embeds, latent_slots, final_slot], dim=1) + else: + full_embeds = torch.cat([input_embeds, final_slot], dim=1) + full_mask = extend_attention_mask(attention_mask, slot_count + 1) + capture_attn = bool(getattr(model, "_attention_density_debug_limit", 0) > 0) + full_out = run_backbone_from_embeds(backbone, full_embeds, full_mask, output_attentions=capture_attn) + _maybe_print_attention_density(model, getattr(full_out, "attentions", None), label="fixed_slots") + full_hidden = get_last_hidden_state(full_out) + latent_hiddens = full_hidden[:, input_embeds.shape[1] : input_embeds.shape[1] + slot_count, :] + final_hidden = full_hidden[:, -1, :] + return base_hidden, latent_hiddens, final_hidden + + +def fixed_slot_next_token_logits_from_ids( + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_cot_tokens: int, +) -> torch.Tensor: + _, latent_hiddens, final_hidden = build_fixed_slot_latent_hidden(model, input_ids, attention_mask, num_cot_tokens) + output_embeddings = get_output_embeddings_module(model) + output_dtype = getattr(getattr(output_embeddings, "weight", None), "dtype", final_hidden.dtype) + logits = output_embeddings(final_hidden.to(dtype=output_dtype)) + logits = _sanitize_logits(logits, output_dtype=output_dtype) + _maybe_print_fixed_slot_debug( + model, + input_ids=input_ids, + latent_hiddens=latent_hiddens, + final_hidden=final_hidden, + logits=logits, + ) + return logits + + +def project_hidden(model: nn.Module, hidden: torch.Tensor) -> torch.Tensor: + input_dtype = hidden.dtype + hidden = hidden.to(torch.float32) + hidden = F.linear( + hidden, + model.latent_projector_in.weight.to(dtype=torch.float32), + None if model.latent_projector_in.bias is None else model.latent_projector_in.bias.to(dtype=torch.float32), + ) + hidden = F.gelu(hidden) + hidden = F.linear( + hidden, + model.latent_projector_out.weight.to(dtype=torch.float32), + None if model.latent_projector_out.bias is None else model.latent_projector_out.bias.to(dtype=torch.float32), + ) + hidden = torch.nan_to_num(hidden, nan=0.0, posinf=50.0, neginf=-50.0) + hidden = hidden.clamp(min=-50.0, max=50.0) + return hidden.to(input_dtype) + + +def _sanitize_logits(logits: torch.Tensor, *, output_dtype: torch.dtype) -> torch.Tensor: + logits = torch.nan_to_num(logits.float(), nan=0.0, posinf=50.0, neginf=-50.0) + logits = logits.clamp(min=-50.0, max=50.0) + return logits.to(dtype=output_dtype) + + +def _debug_print_vocab_comparison( + model: nn.Module, + base_logits: torch.Tensor, + latent_logits: torch.Tensor, + *, + fallback_mask: torch.Tensor, +) -> None: + debug_count = int(getattr(model, "_latent_vocab_debug_count", 0)) + debug_limit = int(getattr(model, "_latent_vocab_debug_limit", 10000)) + if debug_count >= debug_limit or int(base_logits.shape[0]) <= 0: + return + + tokenizer = getattr(model, "_latent_debug_tokenizer", None) + topk = max(1, int(getattr(model, "_latent_vocab_debug_topk", 1))) + base_probs = torch.softmax(base_logits[0].float(), dim=-1) + latent_probs = torch.softmax(latent_logits[0].float(), dim=-1) + base_top_probs, base_top_ids = torch.topk(base_probs, k=min(topk, int(base_probs.shape[-1])), dim=-1) + latent_top_probs, latent_top_ids = torch.topk(latent_probs, k=min(topk, int(latent_probs.shape[-1])), dim=-1) + gate = float(torch.sigmoid(model.latent_mix_logit.float()).item()) if hasattr(model, "latent_mix_logit") else float("nan") + + def _fmt(ids: torch.Tensor, probs: torch.Tensor) -> str: + parts = [] + for tok_id, prob in zip(ids.tolist(), probs.tolist(), strict=True): + piece = repr(tokenizer.decode([int(tok_id)])) if tokenizer is not None else f"id={int(tok_id)}" + parts.append(f"{piece}:{float(prob):.4f}") + return " | ".join(parts) + + print( + "[latent vocab debug] " + f"token_step={debug_count} " + f"gate={gate:.6f} " + f"fallback_row0={bool(fallback_mask[0].item())} " + f"base_next={_fmt(base_top_ids, base_top_probs)} " + f"latent_next={_fmt(latent_top_ids, latent_top_probs)}", + flush=True, + ) + model._latent_vocab_debug_count = debug_count + 1 + + +def _should_fallback_to_base(model: nn.Module, latent_logits: torch.Tensor) -> torch.Tensor: + # Safety checker for latent decoding: if the latent-steered logits become too + # sharp / low-entropy / numerically suspicious, ignore them for that row and + # fall back to plain base-model logits instead. + scores = torch.nan_to_num(latent_logits.float(), nan=0.0, posinf=50.0, neginf=-50.0) + probs = torch.softmax(scores, dim=-1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8) + max_prob = probs.max(dim=-1).values + entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1) + if int(scores.shape[-1]) > 1: + top2 = torch.topk(scores, k=2, dim=-1).values + margin = top2[..., 0] - top2[..., 1] + else: + margin = torch.full_like(max_prob, float("inf")) + fallback = (~torch.isfinite(scores)).any(dim=-1) + fallback |= max_prob > float(getattr(model, "_latent_fallback_max_prob", 0.995)) + fallback |= entropy < float(getattr(model, "_latent_fallback_entropy_min", 0.02)) + fallback |= margin > float(getattr(model, "_latent_fallback_margin_max", 25.0)) + return fallback + + +def build_latent_hidden(model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int): + backbone = unwrap_backbone(model) + inner_backbone = getattr(backbone, "model", backbone) + input_embeds = get_input_embeddings_module(model)(input_ids) + base_out = inner_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + base_hidden = get_last_hidden_state(base_out)[:, -1, :] + if num_cot_tokens <= 0: + return base_hidden, base_hidden + + cur_embeds = input_embeds + cur_mask = attention_mask + latent_token = None + special = model.special_thought_embed.to(device=input_embeds.device, dtype=input_embeds.dtype).view(1, 1, -1) + for step_idx in range(int(num_cot_tokens)): + # This is the actual latent-thought rollout: append one hidden "thought" + # token, rerun the backbone, then feed the resulting last hidden state back + # in as the next latent token embedding. + next_embed = special.expand(cur_embeds.shape[0], 1, -1) if latent_token is None else latent_token + # print( + # f"[latent rollout shapes] step={step_idx} " + # f"next_embed={tuple(next_embed.shape)} cur_embeds_before={tuple(cur_embeds.shape)} " + # f"cur_mask_before={tuple(cur_mask.shape)}", + # flush=True, + # ) + cur_embeds = torch.cat([cur_embeds, next_embed], dim=1) + cur_mask = extend_attention_mask(cur_mask, 1) + out = run_backbone_from_embeds(backbone, cur_embeds, cur_mask) + latent_token = get_last_hidden_state(out)[:, -1:, :] + # print( + # f"[latent rollout shapes] step={step_idx} " + # f"cur_embeds_after={tuple(cur_embeds.shape)} cur_mask_after={tuple(cur_mask.shape)} " + # f"latent_token={tuple(latent_token.shape)}", + # flush=True, + # ) + latent_hidden = latent_token[:, 0, :] + # print( + # f"[latent rollout shapes] final base_hidden={tuple(base_hidden.shape)} latent_hidden={tuple(latent_hidden.shape)}", + # flush=True, + # ) + return base_hidden, latent_hidden + + +def build_recurrent_hidden_latent_hidden( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> tuple[torch.Tensor, torch.Tensor]: + backbone = unwrap_backbone(model) + inner_backbone = getattr(backbone, "model", backbone) + input_embeds = get_input_embeddings_module(model)(input_ids) + base_out = inner_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + base_hidden = get_last_hidden_state(base_out)[:, -1, :] + if num_cot_tokens <= 0: + return base_hidden, base_hidden + + cur_embeds = input_embeds + cur_mask = attention_mask + latent_token = base_hidden.to(device=input_embeds.device, dtype=input_embeds.dtype).unsqueeze(1) + for _ in range(int(num_cot_tokens)): + cur_embeds = torch.cat([cur_embeds, latent_token], dim=1) + cur_mask = extend_attention_mask(cur_mask, 1) + out = run_backbone_from_embeds(backbone, cur_embeds, cur_mask) + latent_token = get_last_hidden_state(out)[:, -1:, :] + latent_hidden = latent_token[:, 0, :] + return base_hidden, latent_hidden + + +def recurrent_hidden_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + _, latent_hidden = build_recurrent_hidden_latent_hidden(model, input_ids, attention_mask, num_cot_tokens) + output_embeddings = get_output_embeddings_module(model) + output_dtype = getattr(getattr(output_embeddings, "weight", None), "dtype", latent_hidden.dtype) + logits = output_embeddings(latent_hidden.to(dtype=output_dtype)) + return _sanitize_logits(logits, output_dtype=output_dtype) + + +def build_latent_seed_hidden( + model: nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + num_cot_tokens: int, +) -> torch.Tensor: + """Option-2 forward: append k trainable seed vectors m_1..m_k to the prompt + embeddings and run the backbone once. The next-token hidden state is the + last position of the resulting hidden states. + + Architecture (one backbone forward per next-token prediction): + inputs = [ E[x_1], ..., E[x_T], m_1, m_2, ..., m_k ] # [B, T+k, d] + H = f_theta(inputs) # [B, T+k, d] + z_k = H[:, T+k-1, :] # [B, d] + + Only the k seed parameters ``model.latent_seed_embeds[:k]`` and the LoRA + deltas get gradients. When ``num_cot_tokens == 0`` this reduces to the + standard last-hidden prediction (no seeds appended). + """ + backbone = unwrap_backbone(model) + inner_backbone = getattr(backbone, "model", backbone) + input_embeds = get_input_embeddings_module(model)(input_ids) + k = max(0, int(num_cot_tokens)) + if k <= 0: + base_out = inner_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + return get_last_hidden_state(base_out)[:, -1, :] + seed_bank = model.latent_seed_embeds.to(device=input_embeds.device, dtype=input_embeds.dtype) + k = min(k, int(seed_bank.shape[0])) + seed_tokens = seed_bank[:k].unsqueeze(0).expand(input_embeds.shape[0], k, seed_bank.shape[-1]) + full_embeds = torch.cat([input_embeds, seed_tokens], dim=1) + full_mask = extend_attention_mask(attention_mask, k) + out = run_backbone_from_embeds(backbone, full_embeds, full_mask) + return get_last_hidden_state(out)[:, -1, :] + + +def latent_seed_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + latent_hidden = build_latent_seed_hidden(model, input_ids, attention_mask, num_cot_tokens) + output_embeddings = get_output_embeddings_module(model) + output_dtype = getattr(getattr(output_embeddings, "weight", None), "dtype", latent_hidden.dtype) + logits = output_embeddings(latent_hidden.to(dtype=output_dtype)) + return _sanitize_logits(logits, output_dtype=output_dtype) + + +def residual_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + base_hidden, latent_hidden = build_latent_hidden(model, input_ids, attention_mask, num_cot_tokens) + # The latent path does not emit tokens directly; it produces a hidden-state + # delta that is projected back onto the base next-token hidden state. + projected_delta = project_hidden(model, latent_hidden - base_hidden).float() + mix = torch.sigmoid(model.latent_mix_logit.float()).to(projected_delta.device) + projected_delta = projected_delta * float(getattr(model, "_latent_delta_scale", 1.0)) * mix + base_hidden_fp32 = base_hidden.float() + base_norm = base_hidden_fp32.norm(dim=-1, keepdim=True).clamp_min(1e-6) + delta_norm = projected_delta.norm(dim=-1, keepdim=True).clamp_min(1e-6) + max_ratio = max(0.0, float(getattr(model, "_latent_delta_max_ratio", 0.5))) + max_delta_norm = base_norm * max_ratio + projected_delta = projected_delta * torch.clamp(max_delta_norm / delta_norm, max=1.0) + final_hidden = torch.nan_to_num(base_hidden_fp32 + projected_delta, nan=0.0, posinf=50.0, neginf=-50.0) + output_embeddings = get_output_embeddings_module(model) + output_dtype = getattr(getattr(output_embeddings, "weight", None), "dtype", final_hidden.dtype) + base_logits = _sanitize_logits(output_embeddings(base_hidden_fp32.to(dtype=output_dtype)), output_dtype=output_dtype) + latent_logits = _sanitize_logits(output_embeddings(final_hidden.to(dtype=output_dtype)), output_dtype=output_dtype) + fallback_mask = _should_fallback_to_base(model, latent_logits) + if bool(fallback_mask.any()): + warn_count = int(getattr(model, "_latent_fallback_warn_count", 0)) + if warn_count < 5: + print(f"[latent grpo] falling back to base logits for {int(fallback_mask.sum().item())} rows") + model._latent_fallback_warn_count = warn_count + 1 + fallback_hidden = base_hidden_fp32[fallback_mask].to(dtype=output_dtype) + fallback_logits = _sanitize_logits(output_embeddings(fallback_hidden), output_dtype=output_dtype) + latent_logits = latent_logits.clone() + latent_logits[fallback_mask] = fallback_logits + _debug_print_vocab_comparison(model, base_logits, latent_logits, fallback_mask=fallback_mask) + return latent_logits + + +def _apply_repetition_penalty(logits: torch.Tensor, tokens: torch.Tensor, penalty: float) -> torch.Tensor: + if penalty == 1.0 or tokens.numel() == 0: + return logits + adjusted = logits.clone() + unique_tokens = torch.unique(tokens, sorted=False) + seen_logits = adjusted.index_select(dim=-1, index=unique_tokens) + seen_logits = torch.where(seen_logits < 0, seen_logits * penalty, seen_logits / penalty) + adjusted.index_copy_(dim=-1, index=unique_tokens, source=seen_logits) + return adjusted + + +def _sample_from_latent_logits( + logits: torch.Tensor, + *, + do_sample: bool, + temperature: float, + top_p: float, + top_k: int, +) -> torch.Tensor: + if not do_sample: + return torch.argmax(logits, dim=-1, keepdim=True) + + temperature = max(float(temperature), 1e-5) + scores = logits / temperature + + if int(top_k) > 0 and int(top_k) < scores.shape[-1]: + topk_values, _ = torch.topk(scores, k=int(top_k), dim=-1) + cutoff = topk_values[:, -1:].expand_as(scores) + scores = torch.where(scores < cutoff, torch.full_like(scores, float("-inf")), scores) + + if 0.0 < float(top_p) < 1.0: + sorted_scores, sorted_indices = torch.sort(scores, dim=-1, descending=True) + sorted_probs = torch.softmax(sorted_scores, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + nucleus_mask = cumulative_probs > float(top_p) + nucleus_mask[:, 0] = False + sorted_scores = sorted_scores.masked_fill(nucleus_mask, float("-inf")) + scores = torch.full_like(scores, float("-inf")) + scores.scatter_(dim=-1, index=sorted_indices, src=sorted_scores) + + probs = torch.softmax(scores, dim=-1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8) + return torch.multinomial(probs, num_samples=1) + + +@torch.no_grad() +def sample_latent_completion( + model: nn.Module, + tokenizer: Any, + prompt_ids: torch.Tensor, + attention_mask: torch.Tensor, + *, + num_cot_tokens: int, + max_new_tokens: int, + do_sample: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + repetition_penalty: float = 1.0, +) -> torch.Tensor: + generated = prompt_ids + mask = attention_mask + eos = tokenizer.eos_token_id + for _ in range(max(1, int(max_new_tokens))): + logits = residual_next_token_logits_from_ids(model, generated, mask, num_cot_tokens) + logits = _apply_repetition_penalty(logits, generated, float(repetition_penalty)) + next_id = _sample_from_latent_logits( + logits.float(), + do_sample=bool(do_sample), + temperature=float(temperature), + top_p=float(top_p), + top_k=int(top_k), + ) + generated = torch.cat([generated, next_id], dim=1) + mask = extend_attention_mask(mask, 1) + if eos is not None and bool((next_id == int(eos)).all()): + break + return generated[:, prompt_ids.shape[1] :] + + +@torch.no_grad() +def sample_recurrent_hidden_completion( + model: nn.Module, + tokenizer: Any, + prompt_ids: torch.Tensor, + attention_mask: torch.Tensor, + *, + num_cot_tokens: int, + max_new_tokens: int, + do_sample: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + repetition_penalty: float = 1.0, +) -> torch.Tensor: + generated = prompt_ids + mask = attention_mask + eos = tokenizer.eos_token_id + for _ in range(max(1, int(max_new_tokens))): + logits = recurrent_hidden_next_token_logits_from_ids(model, generated, mask, num_cot_tokens) + logits = _apply_repetition_penalty(logits, generated, float(repetition_penalty)) + next_id = _sample_from_latent_logits( + logits.float(), + do_sample=bool(do_sample), + temperature=float(temperature), + top_p=float(top_p), + top_k=int(top_k), + ) + generated = torch.cat([generated, next_id], dim=1) + mask = extend_attention_mask(mask, 1) + if eos is not None and bool((next_id == int(eos)).all()): + break + return generated[:, prompt_ids.shape[1] :] + + +@torch.no_grad() +def sample_latent_seed_completion( + model: nn.Module, + tokenizer: Any, + prompt_ids: torch.Tensor, + attention_mask: torch.Tensor, + *, + num_cot_tokens: int, + max_new_tokens: int, + do_sample: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + repetition_penalty: float = 1.0, +) -> torch.Tensor: + """Greedy/temperature sampler for the Option-2 latent-seed architecture. + + At each output step the trainable seeds m_1..m_k are appended to the + current real-token prefix and a single backbone forward pass produces the + next-token logits. The seeds are not carried across output tokens; they are + re-appended each step (they are constant parameters, so this is equivalent + to caching them). + """ + generated = prompt_ids + mask = attention_mask + eos = tokenizer.eos_token_id + for _ in range(max(1, int(max_new_tokens))): + logits = latent_seed_next_token_logits_from_ids(model, generated, mask, num_cot_tokens) + logits = _apply_repetition_penalty(logits, generated, float(repetition_penalty)) + next_id = _sample_from_latent_logits( + logits.float(), + do_sample=bool(do_sample), + temperature=float(temperature), + top_p=float(top_p), + top_k=int(top_k), + ) + generated = torch.cat([generated, next_id], dim=1) + mask = extend_attention_mask(mask, 1) + if eos is not None and bool((next_id == int(eos)).all()): + break + return generated[:, prompt_ids.shape[1] :] + + +@torch.no_grad() +def sample_fixed_slot_completion( + model: nn.Module, + tokenizer: Any, + prompt_ids: torch.Tensor, + attention_mask: torch.Tensor, + *, + num_cot_tokens: int, + max_new_tokens: int, + do_sample: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + repetition_penalty: float = 1.0, +) -> torch.Tensor: + generated = prompt_ids + mask = attention_mask + eos = tokenizer.eos_token_id + for _ in range(max(1, int(max_new_tokens))): + logits = fixed_slot_next_token_logits_from_ids(model, generated, mask, num_cot_tokens) + logits = _apply_repetition_penalty(logits, generated, float(repetition_penalty)) + next_id = _sample_from_latent_logits( + logits, + do_sample=bool(do_sample), + temperature=float(temperature), + top_p=float(top_p), + top_k=int(top_k), + ) + _maybe_print_fixed_slot_decode_debug( + model, + tokenizer, + next_id, + step_idx=int(generated.shape[1] - prompt_ids.shape[1]), + ) + generated = torch.cat([generated, next_id], dim=1) + mask = extend_attention_mask(mask, 1) + if eos is not None and bool((next_id == int(eos)).all()): + break + return generated[:, prompt_ids.shape[1] :] + + +def install_latent_grpo_model_interface( + model: nn.Module, + tokenizer: Any, + *, + num_cot_tokens: int, + latent_mode: str = "residual", + latent_delta_scale: float = 1.0, + latent_delta_max_ratio: float = 0.5, +) -> nn.Module: + if getattr(model, "_latent_grpo_interface_installed", False): + model._latent_grpo_num_cot_tokens = int(num_cot_tokens) + model._latent_grpo_latent_mode = str(latent_mode).strip().lower() + model._latent_grpo_tokenizer = tokenizer + model._latent_delta_scale = float(latent_delta_scale) + model._latent_delta_max_ratio = float(latent_delta_max_ratio) + return model + + model._latent_grpo_interface_installed = True + model._latent_grpo_num_cot_tokens = int(num_cot_tokens) + model._latent_grpo_latent_mode = str(latent_mode).strip().lower() + model._latent_grpo_tokenizer = tokenizer + model._latent_delta_scale = float(latent_delta_scale) + model._latent_delta_max_ratio = float(latent_delta_max_ratio) + model._latent_original_forward = model.forward + model._latent_original_generate = model.generate + + def next_token_logits_for_mode(self, prefix_ids: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor: + mode = str(getattr(self, "_latent_grpo_latent_mode", "residual")).strip().lower() + k = int(getattr(self, "_latent_grpo_num_cot_tokens", 0)) + if mode == "fixed_slots": + return fixed_slot_next_token_logits_from_ids(self, prefix_ids, prefix_mask, k) + if mode == "recurrent_hidden": + return recurrent_hidden_next_token_logits_from_ids(self, prefix_ids, prefix_mask, k) + if mode == "latent_seeds": + return latent_seed_next_token_logits_from_ids(self, prefix_ids, prefix_mask, k) + return residual_next_token_logits_from_ids(self, prefix_ids, prefix_mask, k) + + def latent_forward( + self, + input_ids=None, + attention_mask=None, + logits_to_keep=None, + use_cache=None, + **kwargs, + ): + if input_ids is None or attention_mask is None or logits_to_keep is None: + return self._latent_original_forward( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=use_cache, + **kwargs, + ) + + seq_len = int(input_ids.shape[1]) + keep = max(1, min(int(logits_to_keep), seq_len)) + start = max(1, seq_len - keep) + logits = [] + for prefix_len in range(start, seq_len + 1): + prefix_ids = input_ids[:, :prefix_len] + prefix_mask = attention_mask[:, :prefix_len] + step_logits = next_token_logits_for_mode(self, prefix_ids, prefix_mask) + logits.append(step_logits.unsqueeze(1)) + return CausalLMOutput(logits=torch.cat(logits, dim=1)) + + @torch.no_grad() + def latent_generate(self, input_ids=None, attention_mask=None, generation_config=None, **kwargs): + if input_ids is None or attention_mask is None: + return self._latent_original_generate( + input_ids=input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + **kwargs, + ) + + tokenizer_local = self._latent_grpo_tokenizer + max_new_tokens = int( + getattr(generation_config, "max_new_tokens", None) or kwargs.get("max_new_tokens") or 16 + ) + do_sample = bool(getattr(generation_config, "do_sample", True)) + # New TRL/transformers can leave temperature/top_p/top_k/repetition_penalty as None in + # GenerationConfig. Fall back to sane defaults instead of crashing in int()/float(). + _temperature_val = getattr(generation_config, "temperature", 1.0) + temperature = float(_temperature_val) if _temperature_val is not None else 1.0 + _top_p_val = getattr(generation_config, "top_p", 1.0) + top_p = float(_top_p_val) if _top_p_val is not None else 1.0 + _top_k_val = getattr(generation_config, "top_k", 0) + top_k = int(_top_k_val) if _top_k_val is not None else 0 + _rep_val = getattr(generation_config, "repetition_penalty", 1.0) + repetition_penalty = float(_rep_val) if _rep_val is not None else 1.0 + pad_token_id = getattr(generation_config, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = getattr(tokenizer_local, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = getattr(tokenizer_local, "eos_token_id", 0) + + rows = [] + for row_ids, row_mask in zip(input_ids, attention_mask, strict=True): + row_prompt = row_ids.unsqueeze(0) + row_attn = row_mask.unsqueeze(0) + mode = str(getattr(self, "_latent_grpo_latent_mode", "residual")).strip().lower() + sample_kwargs = { + "num_cot_tokens": int(self._latent_grpo_num_cot_tokens), + "max_new_tokens": max_new_tokens, + "do_sample": do_sample, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "repetition_penalty": repetition_penalty, + } + if mode == "fixed_slots": + completion = sample_fixed_slot_completion(self, tokenizer_local, row_prompt, row_attn, **sample_kwargs) + elif mode == "recurrent_hidden": + completion = sample_recurrent_hidden_completion(self, tokenizer_local, row_prompt, row_attn, **sample_kwargs) + elif mode == "latent_seeds": + completion = sample_latent_seed_completion(self, tokenizer_local, row_prompt, row_attn, **sample_kwargs) + else: + completion = sample_latent_completion(self, tokenizer_local, row_prompt, row_attn, **sample_kwargs) + rows.append(torch.cat([row_prompt, completion], dim=1).squeeze(0)) + + max_len = max(int(row.shape[0]) for row in rows) + padded = [] + for row in rows: + if int(row.shape[0]) < max_len: + pad = torch.full((max_len - int(row.shape[0]),), int(pad_token_id), device=row.device, dtype=row.dtype) + row = torch.cat([row, pad], dim=0) + padded.append(row) + return torch.stack(padded, dim=0) + + model.forward = MethodType(latent_forward, model) + model.generate = MethodType(latent_generate, model) + return model + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: torch.nn.Module, + tokenizer: Any, + device: torch.device, + eval_stage_i: int | None = None, + log_prefix: str = "latent grpo eval", +) -> Dict[str, float]: + model.eval() + stage_i = int(eval_stage_i if eval_stage_i is not None else args.stage_i) + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + prompt_ids = enc["input_ids"].to(device) + attn = enc["attention_mask"].to(device) + latent_mode_eval = str(getattr(args, "latent_mode", "residual")).strip().lower() + if latent_mode_eval == "recurrent_hidden": + completion_ids = sample_recurrent_hidden_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + elif latent_mode_eval == "fixed_slots": + completion_ids = sample_fixed_slot_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + elif latent_mode_eval == "latent_seeds": + completion_ids = sample_latent_seed_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + else: + completion_ids = sample_latent_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + pred_text = tokenizer.decode(completion_ids[0], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[latent grpo eval debug] target=({rr+1},{cc+1}) output={pred_text!r}") + print(f"[latent grpo eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}") + printed += 1 + solve_ok += int(row_all_exact) + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + "eval_cells": float(total_cells), + } + print( + f"[{log_prefix}] parse={out['parse_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f} " + f"avg_set_size={out['avg_predicted_set_size']:.3f} " + f"good={out['avg_num_i_consistent_values']:.3f} " + f"bad={out['avg_num_non_i_consistent_values']:.3f}" + ) + return out + + +def run_dual_eval( + *, + args: Args, + eval_rows_stage1: List[Dict[str, Any]], + eval_rows_stage2: List[Dict[str, Any]], + model: torch.nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + metrics_stage1 = run_eval( + args=args, + rows=eval_rows_stage1, + model=model, + tokenizer=tokenizer, + device=device, + eval_stage_i=1, + log_prefix="latent grpo eval stage1", + ) + metrics_stage2 = run_eval( + args=args, + rows=eval_rows_stage2, + model=model, + tokenizer=tokenizer, + device=device, + eval_stage_i=max(1, int(args.stage_i)), + log_prefix=f"latent grpo eval stage{int(args.stage_i)}", + ) + out = {f"stage1/{k}": float(v) for k, v in metrics_stage1.items()} + out.update({f"stage{int(args.stage_i)}/{k}": float(v) for k, v in metrics_stage2.items()}) + return out + + +class ResidualProjectorEvalCallback(TrainerCallback): + def __init__( + self, + args: Args, + eval_rows_stage1: List[Dict[str, Any]], + eval_rows_stage2: List[Dict[str, Any]], + tokenizer: Any, + device: torch.device, + wb_run: Any, + is_main_process: bool, + ): + self.args = args + self.eval_rows_stage1 = eval_rows_stage1 + self.eval_rows_stage2 = eval_rows_stage2 + self.tokenizer = tokenizer + self.device = device + self.wb_run = wb_run + self.is_main_process = is_main_process + self.last_logged_step = -1 + + def on_step_end(self, args, state, control, **kwargs): + step = int(state.global_step) + eval_every = int(self.args.eval_steps) + if step <= 0 or step % eval_every != 0: + return control + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_dist = world_size > 1 and torch.distributed.is_available() and torch.distributed.is_initialized() + stop_tensor = torch.zeros(1, dtype=torch.int32, device=self.device) + + if self.is_main_process: + if step != self.last_logged_step: + model = kwargs.get("model") + if model is not None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + metrics = run_eval( + args=self.args, + rows=self.eval_rows_stage2, + model=unwrap_training_model(model), + tokenizer=self.tokenizer, + device=self.device, + eval_stage_i=max(1, int(self.args.stage_i)), + log_prefix=f"latent grpo callback eval stage{int(self.args.stage_i)}", + ) + if self.eval_rows_stage1: + stage1_metrics = run_eval( + args=self.args, + rows=self.eval_rows_stage1, + model=unwrap_training_model(model), + tokenizer=self.tokenizer, + device=self.device, + eval_stage_i=1, + log_prefix="latent grpo callback eval stage1", + ) + metrics = {f"stage1/{k}": float(v) for k, v in stage1_metrics.items()} | { + f"stage{int(self.args.stage_i)}/{k}": float(v) for k, v in metrics.items() + } + else: + metrics = {f"stage{int(self.args.stage_i)}/{k}": float(v) for k, v in metrics.items()} + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.last_logged_step = step + si = int(self.args.stage_i) + pfx = f"stage{si}/" + print( + f"[latent grpo custom eval step {step}] " + f"stage1_exact={metrics.get('stage1/exact_set_match_rate', float('nan')):.3f} " + f"stage{si}_exact={metrics[f'{pfx}exact_set_match_rate']:.3f} " + f"stage{si}_prec={metrics[f'{pfx}value_precision']:.3f} " + f"stage{si}_rec={metrics[f'{pfx}value_recall']:.3f} " + f"stage{si}_solve={metrics[f'{pfx}solve_rate']:.3f}", + flush=True, + ) + if self.args.use_wandb and self.wb_run is not None: + payload = {f"custom_eval/{k}": float(v) for k, v in metrics.items()} + payload["custom_eval/global_step"] = float(step) + wandb.log(payload) + + if step >= int(self.args.min_steps_before_stop): + vp = float(metrics[f"{pfx}value_precision"]) + vr = float(metrics[f"{pfx}value_recall"]) + sr = float(metrics[f"{pfx}solve_rate"]) + if ( + float(self.args.eval_value_precision_stop) > 0.0 + and float(self.args.eval_value_recall_stop) > 0.0 + and vp >= float(self.args.eval_value_precision_stop) + and vr >= float(self.args.eval_value_recall_stop) + ): + print( + f"[latent grpo custom eval step {step}] stopping early: " + f"value_precision={vp:.3f} >= {float(self.args.eval_value_precision_stop):.3f} " + f"and value_recall={vr:.3f} >= {float(self.args.eval_value_recall_stop):.3f}", + flush=True, + ) + stop_tensor[0] = 1 + if ( + int(stop_tensor.item()) == 0 + and float(self.args.eval_solve_rate_stop) > 0.0 + and sr >= float(self.args.eval_solve_rate_stop) + ): + print( + f"[latent grpo custom eval step {step}] stopping early: " + f"solve_rate={sr:.3f} >= {float(self.args.eval_solve_rate_stop):.3f}", + flush=True, + ) + stop_tensor[0] = 1 + + if use_dist: + torch.distributed.broadcast(stop_tensor, src=0) + + if int(stop_tensor.item()) != 0: + control.should_training_stop = True + return control + + +class SaveLatentStateCallback(TrainerCallback): + def __init__(self, is_main_process: bool, extra_save_fn: Any | None = save_latent_projector_state): + self.is_main_process = is_main_process + self.extra_save_fn = extra_save_fn + + def on_save(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is None: + return control + step_dir = os.path.join(args.output_dir, f"checkpoint-{int(state.global_step)}") + if os.path.isdir(step_dir) and self.extra_save_fn is not None: + self.extra_save_fn(unwrap_training_model(model), step_dir) + return control + + +class FinalCheckpointCallback(TrainerCallback): + def __init__( + self, + output_dir: str, + tokenizer: Any, + is_main_process: bool, + extra_save_fn: Any | None = save_latent_projector_state, + ): + self.output_dir = output_dir + self.tokenizer = tokenizer + self.is_main_process = is_main_process + self.extra_save_fn = extra_save_fn + + def _save(self, model: Any) -> None: + save_model_artifacts( + unwrap_training_model(model), + self.tokenizer, + ensure_final_checkpoint_dir(self.output_dir), + extra_save_fn=self.extra_save_fn, + ) + + def on_save(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + def on_train_end(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + +class WallClockStopCallback(TrainerCallback): + def __init__(self, max_wall_clock_seconds: int): + self.max_wall_clock_seconds = int(max_wall_clock_seconds) + self.start_time = time.time() + + def on_step_end(self, args, state, control, **kwargs): + if self.max_wall_clock_seconds > 0 and (time.time() - self.start_time) >= float(self.max_wall_clock_seconds): + control.should_training_stop = True + return control + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument( + "--train_jsonl", + type=str, + default="/egr/research-slim/ghoshavr/curriculum-CoT/sudoku/llm_policy_icon/data/sudoku_t3_20empty_value_qwen_text.jsonl", + ) + p.add_argument("--train_jsonl_stage1", type=str, default="") + p.add_argument("--train_jsonl_stage2", type=str, default="") + p.add_argument( + "--eval_jsonl", + type=str, + default="", + help="If set, first eval_rows lines are used for both stage1/stage2 eval (held-out). Else slice train files.", + ) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/egr/research-slim/ghoshavr/.hf_cache") + p.add_argument( + "--init_adapter_dir", + type=str, + default="", + help="Peft adapter checkpoint dir, or empty string for fresh LoRA on the base model (random init).", + ) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--num_cot_tokens", type=int, default=1) + p.add_argument( + "--latent_mode", + type=str, + default="residual", + choices=["residual", "fixed_slots", "recurrent_hidden", "latent_seeds"], + help="Which latent-COT architecture to use for GRPO forward, generation, and eval sampling.", + ) + p.add_argument( + "--max_latent_seeds", + type=int, + default=8, + help="For --latent_mode latent_seeds: size of the trainable seed bank (num_cot_tokens per stage must be <=).", + ) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=4) + p.add_argument("--gradient_accumulation_steps", type=int, default=4) + p.add_argument("--num_train_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--logging_steps", type=int, default=5) + p.add_argument("--save_steps", type=int, default=10) + p.add_argument("--eval_steps", type=int, default=25) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--num_generations", type=int, default=2) + p.add_argument("--max_prompt_length", type=int, default=1024) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--beta", type=float, default=0.0) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument( + "--lora_r", + type=int, + default=192, + help="LoRA rank. Use -1 to resolve to model hidden_size, i.e. full-rank adapters for hidden-width projections.", + ) + p.add_argument( + "--lora_alpha", + type=int, + default=384, + help="LoRA alpha. Use -1 to resolve to 2 * resolved_lora_r.", + ) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-latent-multi-output-grpo-residual-projector") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--wandb_group", type=str, default="") + p.add_argument("--wandb_run_id", type=str, default="") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--mixed_stage1_ratio", type=float, default=0.0) + p.add_argument("--mixed_stage2_ratio", type=float, default=1.0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--resume_from_checkpoint", type=str, default="") + p.add_argument( + "--eval_value_precision_stop", + type=float, + default=0.0, + help="If >0 and --eval_value_recall_stop>0, stop when both reached on current stage_i eval (with min_steps_before_stop).", + ) + p.add_argument("--eval_value_recall_stop", type=float, default=0.0) + p.add_argument( + "--eval_solve_rate_stop", + type=float, + default=0.0, + help="If >0, stop when stage_i solve_rate reaches this threshold (after min_steps_before_stop).", + ) + p.add_argument("--min_steps_before_stop", type=int, default=0) + return Args(**vars(p.parse_args())) + + +def main() -> None: + args = parse_args() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_main_process = rank == 0 + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + if preset_visible_devices: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}") + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank if world_size > 1 else max(0, int(args.gpu_id))) + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + print(f"Using Hugging Face cache dir: {cache_dir}") + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + wb_run = wandb.init( + project=args.wandb_project, + entity=args.wandb_entity or None, + name=args.wandb_run_name or None, + mode=args.wandb_mode, + group=args.wandb_group or None, + id=args.wandb_run_id or None, + ) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/records_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_training_rows(args) + eval_src = str(getattr(args, "eval_jsonl", "") or "").strip() + if eval_src: + _eval_slice = load_jsonl_rows(eval_src, limit_rows=0)[: max(1, int(args.eval_rows))] + eval_rows_stage1 = _eval_slice + eval_rows_stage2 = _eval_slice + else: + eval_rows_stage1 = load_jsonl_rows(args.train_jsonl_stage1 or args.train_jsonl, limit_rows=0)[ + : max(1, int(args.eval_rows)) + ] + eval_rows_stage2 = load_jsonl_rows(args.train_jsonl_stage2 or args.train_jsonl, limit_rows=0)[ + : max(1, int(args.eval_rows)) + ] + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter( + base, + args.init_adapter_dir, + lora_r=int(args.lora_r), + lora_alpha=int(args.lora_alpha), + lora_dropout=float(args.lora_dropout), + ) + init_ad = str(args.init_adapter_dir).strip() + hidden_size = int(unwrap_backbone(model).config.hidden_size) + latent_mode = str(args.latent_mode).strip().lower() + extra_save_fn: Any | None + if init_ad: + print(f"Loaded init adapter: {init_ad}", flush=True) + else: + print( + "init_adapter_dir empty: fresh LoRA on base (weights random); matches --lora_r/--lora_alpha/--lora_dropout.", + flush=True, + ) + + if latent_mode == "fixed_slots": + max_latent_slots = max(1, int(args.num_cot_tokens)) + if init_ad: + max_latent_slots = infer_fixed_slot_count_from_state(init_ad) or max_latent_slots + attach_fixed_latent_slot_modules(model, hidden_size=hidden_size, max_latent_slots=max_latent_slots) + if init_ad and maybe_load_fixed_slot_state(model, init_ad): + print(f"Loaded fixed_slot_latent_state.pt from: {init_ad}", flush=True) + else: + print(f"Fixed latent slots active (max_latent_slots={max_latent_slots}).", flush=True) + extra_save_fn = save_fixed_slot_latent_state + elif latent_mode == "latent_seeds": + max_latent_seeds = max(1, int(args.max_latent_seeds), int(args.num_cot_tokens)) + if init_ad: + max_latent_seeds = infer_latent_seed_count_from_state(init_ad) or max_latent_seeds + attach_latent_seed_modules(model, hidden_size=hidden_size, max_latent_seeds=max_latent_seeds) + if init_ad and maybe_load_latent_seed_state(model, init_ad): + print(f"Loaded latent_seed_state.pt from: {init_ad}", flush=True) + else: + print(f"Latent seed bank active (max_latent_seeds={max_latent_seeds}).", flush=True) + extra_save_fn = save_latent_seed_state + elif latent_mode == "recurrent_hidden": + print( + f"Recurrent hidden latent rollout active (num_cot_tokens={int(args.num_cot_tokens)}); " + "no extra latent parameters to save.", + flush=True, + ) + extra_save_fn = None + else: + projector_hidden = infer_projector_hidden_from_state(init_ad) or PROJECTOR_HIDDEN if init_ad else PROJECTOR_HIDDEN + attach_residual_projector_modules(model, hidden_size=hidden_size, projector_hidden=projector_hidden) + if init_ad and maybe_load_projector_state(model, init_ad): + print(f"Loaded latent_cot_state.pt from: {init_ad}", flush=True) + else: + print("Residual projector + special_thought_embed active.", flush=True) + extra_save_fn = save_latent_projector_state + model._latent_debug_tokenizer = tokenizer + install_latent_grpo_model_interface( + model, + tokenizer, + num_cot_tokens=int(args.num_cot_tokens), + latent_mode=latent_mode, + ) + if world_size <= 1: + model.to(device) + model.train() + + def on_prep_progress(*, row_idx: int, total_rows: int, record_count: int) -> None: + if not is_main_process: + return + print( + f"[dataset build][grpo stage {args.stage_i}] rows={row_idx}/{total_rows} records={record_count}", + flush=True, + ) + if wb_run is not None: + wandb.log( + { + "prep/rows_done": float(row_idx), + "prep/rows_total": float(total_rows), + "prep/records_built": float(record_count), + } + ) + + train_records = load_or_build_grpo_records( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + train_dataset = Dataset.from_list(train_records) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_grpo_cache_path(args))), + "prep/records_final": float(len(train_records)), + } + ) + reward_func = make_reward_func(args) + + ensure_trl_fsdp_compat() + from trl import GRPOConfig, GRPOTrainer + + if int(args.limit_train_rows) > 0 and int(args.max_steps) <= 0: + args.max_steps = 1 + config_kwargs = { + "output_dir": args.output_dir, + "per_device_train_batch_size": args.per_device_train_batch_size, + "gradient_accumulation_steps": args.gradient_accumulation_steps, + "num_train_epochs": args.num_train_epochs, + "learning_rate": args.learning_rate, + "logging_steps": args.logging_steps, + "save_steps": args.save_steps, + "eval_strategy": "no", + "do_eval": False, + "max_completion_length": args.max_completion_length, + "num_generations": args.num_generations, + "beta": args.beta, + "gradient_checkpointing": bool(args.enable_gradient_checkpointing), + "bf16": (pick_dtype() == torch.bfloat16), + "report_to": (["wandb"] if args.use_wandb and is_main_process else []), + "remove_unused_columns": False, + # Custom eval runs only on rank 0 (latent decode is heavy), while + # ranks 1..N-1 wait inside torch.distributed.broadcast for the + # stop signal. With the NCCL watchdog default of 1800s, anything + # slower than 30 min triggers a fatal collective timeout. We + # extend it generously so deeper-stage evals (k>=2 latent steps) + # cannot deadlock the run. + "ddp_timeout": 7200, + } + if int(args.max_steps) > 0: + config_kwargs["max_steps"] = int(args.max_steps) + config = GRPOConfig(**config_kwargs) + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[reward_func], + args=config, + train_dataset=train_dataset, + ) + trainer.add_callback( + ResidualProjectorEvalCallback( + args, + eval_rows_stage1, + eval_rows_stage2, + tokenizer, + device, + wb_run, + is_main_process, + ) + ) + trainer.add_callback(SaveLatentStateCallback(is_main_process, extra_save_fn=extra_save_fn)) + trainer.add_callback(FinalCheckpointCallback(args.output_dir, tokenizer, is_main_process, extra_save_fn=extra_save_fn)) + trainer.add_callback(WallClockStopCallback(args.max_wall_clock_seconds)) + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint or None) + + if hasattr(trainer, "accelerator"): + trainer.accelerator.wait_for_everyone() + final_model = trainer.accelerator.unwrap_model(trainer.model) if hasattr(trainer, "accelerator") else trainer.model + final_model = unwrap_training_model(final_model) + if is_main_process: + eval_metrics = run_dual_eval( + args=args, + eval_rows_stage1=eval_rows_stage1, + eval_rows_stage2=eval_rows_stage2, + model=final_model, + tokenizer=tokenizer, + device=device, + ) + si = int(args.stage_i) + print( + f"[latent grpo final eval] " + f"stage1_exact={eval_metrics.get('stage1/exact_set_match_rate', float('nan')):.3f} " + f"stage{si}_exact={eval_metrics[f'stage{si}/exact_set_match_rate']:.3f} " + f"stage{si}_prec={eval_metrics[f'stage{si}/value_precision']:.3f} " + f"stage{si}_rec={eval_metrics[f'stage{si}/value_recall']:.3f} " + f"stage{si}_solve={eval_metrics[f'stage{si}/solve_rate']:.3f}" + ) + trainer.save_model(args.output_dir) + if extra_save_fn is not None: + extra_save_fn(final_model, args.output_dir) + save_model_artifacts( + final_model, + tokenizer, + ensure_final_checkpoint_dir(args.output_dir), + extra_save_fn=extra_save_fn, + ) + if wb_run is not None: + wandb.log({f"final_eval/{k}": float(v) for k, v in eval_metrics.items()}) + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py b/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6c932b253fb342534939bc2b925b478b53fbc4c1 --- /dev/null +++ b/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py @@ -0,0 +1,931 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import os +import sys +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from peft import PeftModel +from torch.optim import AdamW +from torch.utils.data import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final +from mixed_curriculum_cot.runtime_mixed_curriculum import training_stage_i_for_row +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.rewards import score_prediction_text +from multi_output_cell_policy.shared_multi_output_policy import build_supervised_completion, make_solved_grid_from_row +from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import ( + PROJECTOR_HIDDEN, + attach_residual_projector_modules, + build_latent_hidden, + configure_hf_cache, + extend_attention_mask, + get_output_embeddings_module, + infer_projector_hidden_from_state, + load_jsonl_rows, + load_trainable_adapter, + maybe_load_projector_state, + pick_dtype, + project_hidden, + residual_next_token_logits_from_ids as shared_residual_next_token_logits_from_ids, + sample_latent_completion, + save_latent_projector_state, + unwrap_backbone, +) + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + train_jsonl_stage1: str + train_jsonl_stage2: str + eval_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + lora_r: int + lora_alpha: int + lora_dropout: float + seed: int + gpu_id: int + stage_i: int + num_cot_tokens: int + total_empties_hint: int + gradient_accumulation_steps: int + num_epochs: float + learning_rate: float + weight_decay: float + enable_gradient_checkpointing: bool + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + max_completion_length: int + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + debug_print_limit: int + limit_train_rows: int + mixed_stage1_ratio: float + mixed_stage2_ratio: float + eval_exact_set_match_stop: float + eval_value_precision_stop: float + eval_value_recall_stop: float + eval_solve_rate_stop: float + min_steps_before_stop: int + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + max_wall_clock_seconds: int + max_steps: int + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def build_training_examples( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +): + examples = [] + total_rows = len(rows) + eos_text = getattr(tokenizer, "eos_token", None) or "" + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + row_stage_i = training_stage_i_for_row(row, stage_i) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=row_stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + examples.append( + { + "prompt_text": prompt, + "completion_text": build_supervised_completion(ex, stage_i=row_stage_i) + eos_text, + "grid": ex.grid, + "solved": solved, + "target_cell": ex.target_cell, + "stage_i": int(row_stage_i), + } + ) + if progress_callback is not None and ( + row_idx == total_rows or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx=row_idx, total_rows=total_rows, example_count=len(examples)) + return examples + + +def _prepared_data_dir() -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "latent_multi_output_cell_policy") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_sft_cache_path( + *, + train_jsonl_path: str, + stage_i: int, + total_empties_hint: int, + limit_train_rows: int, + model_name: str, + dataset_tag: str, +) -> str: + payload = { + "completion_format_version": 2, + "kind": "sft", + "dataset_tag": str(dataset_tag), + "train_jsonl": os.path.abspath(train_jsonl_path), + "stage_i": int(stage_i), + "total_empties_hint": int(total_empties_hint), + "limit_train_rows": int(limit_train_rows), + "model_name": str(model_name), + } + digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:20] + return os.path.join(_prepared_data_dir(), f"sft_stage{int(stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=True) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_seconds: int = 6 * 60 * 60) -> None: + start = time.time() + while not os.path.exists(path): + if (time.time() - start) > float(timeout_seconds): + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def normalize_loss_weights(stage1_ratio: float, stage2_ratio: float) -> tuple[float, float]: + weight1 = max(0.0, float(stage1_ratio)) + weight2 = max(0.0, float(stage2_ratio)) + weight_sum = weight1 + weight2 + if weight_sum <= 0.0: + raise ValueError("At least one mixed curriculum ratio must be positive.") + return (weight1 / weight_sum, weight2 / weight_sum) + + +def load_or_build_sft_examples( + *, + cache_train_jsonl_path: str, + cache_dataset_tag: str, + rows: List[Dict[str, Any]], + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + limit_train_rows: int, + model_name: str, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_sft_cache_path( + train_jsonl_path=cache_train_jsonl_path, + stage_i=stage_i, + total_empties_hint=total_empties_hint, + limit_train_rows=limit_train_rows, + model_name=model_name, + dataset_tag=cache_dataset_tag, + ) + if os.path.exists(cache_path): + if rank == 0: + print(f"[dataset build][{cache_dataset_tag}] loading prepared cache: {cache_path}", flush=True) + return _read_jsonl(cache_path) + + if rank == 0: + print(f"[dataset build][{cache_dataset_tag}] building prepared cache: {cache_path}", flush=True) + built = build_training_examples( + rows, + tokenizer=tokenizer, + stage_i=stage_i, + total_empties_hint=total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + serializable = [ + { + "prompt_text": ex["prompt_text"], + "completion_text": ex["completion_text"], + } + for ex in built + ] + _write_jsonl(cache_path, serializable) + elif world_size > 1: + _wait_for_cache(cache_path) + + if world_size > 1 and dist.is_initialized(): + dist.barrier() + return _read_jsonl(cache_path) + + +def load_weighted_training_row_groups(args: Args) -> tuple[float, List[Dict[str, Any]], float, List[Dict[str, Any]]]: + stage1_weight, stage2_weight = normalize_loss_weights(args.mixed_stage1_ratio, args.mixed_stage2_ratio) + stage1_rows: List[Dict[str, Any]] = [] + stage2_rows: List[Dict[str, Any]] = [] + if stage1_weight > 0.0: + stage1_rows = load_jsonl_rows(args.train_jsonl_stage1 or args.train_jsonl, limit_rows=args.limit_train_rows) + if stage2_weight > 0.0: + stage2_rows = load_jsonl_rows(args.train_jsonl_stage2 or args.train_jsonl, limit_rows=args.limit_train_rows) + return stage1_weight, stage1_rows, stage2_weight, stage2_rows + + +def residual_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + # Keep SFT teacher-forced CE aligned with the same latent logits path used by + # eval / rollout decoding, including mix-gating, clipping, and fallback logic. + return shared_residual_next_token_logits_from_ids(model, input_ids, attention_mask, num_cot_tokens) + + +def latent_residual_completion_ce_loss( + model: nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, + *, + num_cot_tokens: int, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( + device + ) + completion_ids = all_ids[:, int(prompt_ids.shape[1]) :] + if int(completion_ids.shape[1]) <= 0: + return torch.zeros((), device=device, dtype=torch.float32, requires_grad=True) + + cur_ids = prompt_ids + cur_mask = torch.ones_like(prompt_ids, device=device) + token_losses: List[torch.Tensor] = [] + for idx in range(int(completion_ids.shape[1])): + logits = residual_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot_tokens) + target = completion_ids[:, idx] + token_losses.append(F.cross_entropy(logits.float(), target, reduction="mean")) + cur_ids = torch.cat([cur_ids, completion_ids[:, idx : idx + 1]], dim=1) + cur_mask = extend_attention_mask(cur_mask, 1) + return torch.stack(token_losses, dim=0).mean() + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: nn.Module, + tokenizer: Any, + device: torch.device, + eval_stage_i: int | None = None, + log_prefix: str = "latent sft eval", +) -> Dict[str, float]: + model.eval() + stage_i = int(eval_stage_i if eval_stage_i is not None else args.stage_i) + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + prompt_ids = enc["input_ids"].to(device) + attn = enc["attention_mask"].to(device) + completion_ids = sample_latent_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + pred_text = tokenizer.decode(completion_ids[0], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[latent sft eval debug] target=({rr+1},{cc+1}) output={pred_text!r}") + print(f"[latent sft eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}") + printed += 1 + solve_ok += int(row_all_exact) + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + } + print( + f"[{log_prefix}] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}" + ) + model.train() + return out + + +def run_dual_eval( + *, + args: Args, + eval_rows_stage1: List[Dict[str, Any]], + eval_rows_stage2: List[Dict[str, Any]], + model: nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + metrics_stage1 = run_eval( + args=args, + rows=eval_rows_stage1, + model=model, + tokenizer=tokenizer, + device=device, + eval_stage_i=1, + log_prefix="latent sft eval stage1", + ) + metrics_stage2 = run_eval( + args=args, + rows=eval_rows_stage2, + model=model, + tokenizer=tokenizer, + device=device, + eval_stage_i=max(1, int(args.stage_i)), + log_prefix=f"latent sft eval stage{int(args.stage_i)}", + ) + out = {f"stage1/{k}": float(v) for k, v in metrics_stage1.items()} + out.update({f"stage{int(args.stage_i)}/{k}": float(v) for k, v in metrics_stage2.items()}) + return out + + +def save_checkpoint(model: nn.Module, tokenizer: Any, output_dir: str, step: int) -> None: + save_checkpoint_and_update_final( + model, + tokenizer, + output_dir, + f"checkpoint-step-{step:05d}", + extra_save_fn=save_latent_projector_state, + ) + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument( + "--train_jsonl", + type=str, + default="/egr/research-slim/ghoshavr/curriculum-CoT/sudoku/llm_policy_icon/data/sudoku_t3_20empty_value_qwen_text.jsonl", + ) + p.add_argument("--train_jsonl_stage1", type=str, default="") + p.add_argument("--train_jsonl_stage2", type=str, default="") + p.add_argument( + "--eval_jsonl", + type=str, + default="", + help="If set, first eval_rows lines used for both stage1/stage2 eval. Else slice train files.", + ) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/egr/research-slim/ghoshavr/.hf_cache") + p.add_argument( + "--init_adapter_dir", + type=str, + default="", + help="Peft checkpoint dir, or empty for fresh LoRA on base (random).", + ) + p.add_argument("--lora_r", type=int, default=32) + p.add_argument("--lora_alpha", type=int, default=64) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=2) + p.add_argument("--num_cot_tokens", type=int, default=2) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--gradient_accumulation_steps", type=int, default=8) + p.add_argument("--num_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--weight_decay", type=float, default=0.0) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--logging_steps", type=int, default=10) + p.add_argument("--save_steps", type=int, default=100) + p.add_argument("--eval_steps", type=int, default=100) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-latent-multi-output-sft-residual-projector") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--mixed_stage1_ratio", type=float, default=0.0) + p.add_argument("--mixed_stage2_ratio", type=float, default=1.0) + p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) + p.add_argument( + "--eval_value_precision_stop", + type=float, + default=0.0, + help="With eval_value_recall_stop>0, stop when both hold on stage_i eval (after min_steps_before_stop).", + ) + p.add_argument("--eval_value_recall_stop", type=float, default=0.0) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + return Args(**vars(p.parse_args())) + + +def main() -> None: + args = parse_args() + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_distributed = world_size > 1 + if torch.cuda.is_available(): + if is_distributed: + torch.cuda.set_device(local_rank) + else: + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + if not preset_visible_devices and int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + if is_distributed and not dist.is_initialized(): + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + timeout=timedelta(hours=2), + ) + is_main_process = rank == 0 + + set_seed(args.seed + rank) + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + if is_distributed and dist.is_initialized(): + dist.barrier() + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) + + stage1_weight, stage1_rows, stage2_weight, stage2_rows = load_weighted_training_row_groups(args) + eval_src = str(getattr(args, "eval_jsonl", "") or "").strip() + if eval_src: + _eval_slice = load_jsonl_rows(eval_src, limit_rows=0)[: max(1, int(args.eval_rows))] + eval_rows_stage1 = _eval_slice + eval_rows_stage2 = _eval_slice + else: + eval_rows_stage1 = load_jsonl_rows(args.train_jsonl_stage1 or args.train_jsonl, limit_rows=0)[ + : max(1, int(args.eval_rows)) + ] + eval_rows_stage2 = load_jsonl_rows(args.train_jsonl_stage2 or args.train_jsonl, limit_rows=0)[ + : max(1, int(args.eval_rows)) + ] + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + if torch.cuda.is_available(): + device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") + else: + device = torch.device("cpu") + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter( + base, + args.init_adapter_dir, + lora_r=int(args.lora_r), + lora_alpha=int(args.lora_alpha), + lora_dropout=float(args.lora_dropout), + ) + init_ad = str(args.init_adapter_dir).strip() + if init_ad: + projector_hidden = infer_projector_hidden_from_state(init_ad) or PROJECTOR_HIDDEN + else: + projector_hidden = PROJECTOR_HIDDEN + attach_residual_projector_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + projector_hidden=projector_hidden, + ) + if init_ad: + maybe_load_projector_state(model, init_ad) + if is_main_process: + if init_ad: + print(f"Init adapter: {init_ad}", flush=True) + else: + print( + "init_adapter_dir empty: fresh LoRA (random) + residual projector random init " + f"(lora_r={args.lora_r} lora_alpha={args.lora_alpha}).", + flush=True, + ) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + backbone = unwrap_backbone(model) + if hasattr(backbone, "config"): + backbone.config.use_cache = False + model.to(device) + model.train() + + def on_prep_progress(*, dataset_tag: str, row_idx: int, total_rows: int, example_count: int) -> None: + if not is_main_process: + return + print( + f"[dataset build][{dataset_tag}] rows={row_idx}/{total_rows} examples={example_count}", + flush=True, + ) + if wb_run is not None: + wandb.log( + { + f"prep/{dataset_tag}_rows_done": float(row_idx), + f"prep/{dataset_tag}_rows_total": float(total_rows), + f"prep/{dataset_tag}_examples_built": float(example_count), + } + ) + + stage1_examples = load_or_build_sft_examples( + cache_train_jsonl_path=args.train_jsonl_stage1 or args.train_jsonl, + cache_dataset_tag="sft_stage1_weighted", + rows=stage1_rows, + tokenizer=tokenizer, + stage_i=1, + total_empties_hint=args.total_empties_hint, + limit_train_rows=args.limit_train_rows, + model_name=args.model_name, + rank=rank, + world_size=world_size, + progress_callback=lambda **kwargs: on_prep_progress(dataset_tag="sft_stage1_weighted", **kwargs), + ) + stage2_examples = load_or_build_sft_examples( + cache_train_jsonl_path=args.train_jsonl_stage2 or args.train_jsonl, + cache_dataset_tag=f"sft_stage{int(args.stage_i)}_weighted", + rows=stage2_rows, + tokenizer=tokenizer, + stage_i=int(args.stage_i), + total_empties_hint=args.total_empties_hint, + limit_train_rows=args.limit_train_rows, + model_name=args.model_name, + rank=rank, + world_size=world_size, + progress_callback=lambda **kwargs: on_prep_progress(dataset_tag=f"sft_stage{int(args.stage_i)}_weighted", **kwargs), + ) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/stage1_weight": float(stage1_weight), + "prep/stage2_weight": float(stage2_weight), + "prep/stage1_examples_final": float(len(stage1_examples)), + "prep/stage2_examples_final": float(len(stage2_examples)), + } + ) + optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) + examples_per_epoch = max(len(stage1_examples) if stage1_weight > 0.0 else 0, len(stage2_examples) if stage2_weight > 0.0 else 0, 1) + total_steps = max(1, math.ceil(examples_per_epoch * args.num_epochs / max(1, args.gradient_accumulation_steps))) + if int(args.max_steps) > 0: + total_steps = min(total_steps, int(args.max_steps)) + step = 0 + start_time = time.time() + + def average_scalar(value: float) -> float: + if not is_distributed or not dist.is_initialized(): + return float(value) + tensor = torch.tensor(float(value), device=device, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float((tensor / float(world_size)).item()) + + def all_reduce_gradients() -> None: + if not is_distributed or not dist.is_initialized(): + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(float(world_size)) + + def sync_stop(local_stop: bool) -> bool: + if not is_distributed or not dist.is_initialized(): + return bool(local_stop) + tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item()) > 0) + + def build_epoch_order(examples: List[Dict[str, Any]], *, seed_offset: int, epoch_idx: int) -> List[int]: + if not examples: + return [] + if is_distributed: + sampler = DistributedSampler( + examples, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=args.seed + seed_offset, + drop_last=False, + ) + sampler.set_epoch(epoch_idx) + return list(iter(sampler)) + generator = torch.Generator() + generator.manual_seed(args.seed + seed_offset + epoch_idx) + return torch.randperm(len(examples), generator=generator).tolist() + + def cycle_order(order: List[int], target_len: int) -> List[int]: + if not order or target_len <= 0: + return [] + if len(order) >= target_len: + return order[:target_len] + out: List[int] = [] + while len(out) < target_len: + out.extend(order) + return out[:target_len] + + for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): + stage1_order = build_epoch_order(stage1_examples, seed_offset=1009, epoch_idx=epoch_idx) + stage2_order = build_epoch_order(stage2_examples, seed_offset=2003, epoch_idx=epoch_idx) + epoch_micro_steps = max( + len(stage1_order) if stage1_weight > 0.0 else 0, + len(stage2_order) if stage2_weight > 0.0 else 0, + 1, + ) + stage1_order = cycle_order(stage1_order, epoch_micro_steps) + stage2_order = cycle_order(stage2_order, epoch_micro_steps) + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + for micro_idx in range(epoch_micro_steps): + total_loss = torch.zeros((), device=device, dtype=torch.float32) + stage1_loss_value = float("nan") + stage2_loss_value = float("nan") + if stage1_weight > 0.0: + ex_stage1 = stage1_examples[stage1_order[micro_idx]] + stage1_loss = latent_residual_completion_ce_loss( + model, + tokenizer, + ex_stage1["prompt_text"], + ex_stage1["completion_text"], + device, + num_cot_tokens=args.num_cot_tokens, + ) + total_loss = total_loss + (stage1_loss * stage1_weight) + stage1_loss_value = float(stage1_loss.detach().item()) + if stage2_weight > 0.0: + ex_stage2 = stage2_examples[stage2_order[micro_idx]] + stage2_loss = latent_residual_completion_ce_loss( + model, + tokenizer, + ex_stage2["prompt_text"], + ex_stage2["completion_text"], + device, + num_cot_tokens=args.num_cot_tokens, + ) + total_loss = total_loss + (stage2_loss * stage2_weight) + stage2_loss_value = float(stage2_loss.detach().item()) + scaled_loss = total_loss / max(1, int(args.gradient_accumulation_steps)) + scaled_loss.backward() + accum_count += 1 + if accum_count >= int(args.gradient_accumulation_steps): + all_reduce_gradients() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + step += 1 + if step % int(args.logging_steps) == 0: + loss_value = average_scalar(float(total_loss.detach().item())) + stage1_loss_log = average_scalar(stage1_loss_value) if stage1_weight > 0.0 else 0.0 + stage2_loss_log = average_scalar(stage2_loss_value) if stage2_weight > 0.0 else 0.0 + if is_main_process: + print( + f"[latent sft train step {step:05d}] loss={loss_value:.4f} " + f"stage1_loss={stage1_loss_log:.4f} stage2_loss={stage2_loss_log:.4f} " + f"stage1_w={stage1_weight:.2f} stage2_w={stage2_weight:.2f}", + flush=True, + ) + if wb_run is not None: + wandb.log( + { + "train/loss": loss_value, + "train/stage1_loss": stage1_loss_log, + "train/stage2_loss": stage2_loss_log, + "train/stage1_weight": float(stage1_weight), + "train/stage2_weight": float(stage2_weight), + "step": step, + } + ) + if step % int(args.eval_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + should_stop_eval = False + if is_main_process: + metrics = run_dual_eval( + args=args, + eval_rows_stage1=eval_rows_stage1, + eval_rows_stage2=eval_rows_stage2, + model=model, + tokenizer=tokenizer, + device=device, + ) + if wb_run is not None: + wandb.log({f"eval/{k}": float(v) for k, v in metrics.items()} | {"step": step}) + si = int(args.stage_i) + pfx = f"stage{si}/" + if ( + args.eval_exact_set_match_stop > 0.0 + and float(metrics[f"{pfx}exact_set_match_rate"]) + >= args.eval_exact_set_match_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + if ( + not should_stop_eval + and step >= int(args.min_steps_before_stop) + and float(args.eval_value_precision_stop) > 0.0 + and float(args.eval_value_recall_stop) > 0.0 + and float(metrics[f"{pfx}value_precision"]) + >= float(args.eval_value_precision_stop) + and float(metrics[f"{pfx}value_recall"]) + >= float(args.eval_value_recall_stop) + ): + print( + f"[latent sft eval] stopping early: value_precision=" + f"{float(metrics[f'{pfx}value_precision']):.3f} value_recall=" + f"{float(metrics[f'{pfx}value_recall']):.3f}", + flush=True, + ) + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + if ( + not should_stop_eval + and step >= int(args.min_steps_before_stop) + and float(args.eval_solve_rate_stop) > 0.0 + and float(metrics[f"{pfx}solve_rate"]) + >= float(args.eval_solve_rate_stop) + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + should_stop_eval = sync_stop(should_stop_eval) + if is_distributed and dist.is_initialized(): + dist.barrier() + if should_stop_eval: + if is_main_process and wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.destroy_process_group() + return + if step % int(args.save_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and ( + time.time() - start_time >= float(args.max_wall_clock_seconds) + ) + if sync_stop(reached_limit or exceeded_wall): + break + if accum_count > 0: + all_reduce_gradients() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + step += 1 + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and (time.time() - start_time >= float(args.max_wall_clock_seconds)) + if sync_stop(reached_limit or exceeded_wall): + break + + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + if is_main_process and wb_run is not None: + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/latent_multi_output_cell_policy/sft_latent_multi_output_train.py b/latent_multi_output_cell_policy/sft_latent_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..926feccf13a35b6bb5831dfb7721345f715231f0 --- /dev/null +++ b/latent_multi_output_cell_policy/sft_latent_multi_output_train.py @@ -0,0 +1,1055 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import os +import sys +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import AdamW +from torch.utils.data import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.rewards import score_prediction_text +from multi_output_cell_policy.shared_multi_output_policy import ( + build_supervised_completion, + make_solved_grid_from_row, + stage_i_consistent_values, +) +from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import ( + PROJECTOR_HIDDEN, + attach_fixed_latent_slot_modules, + attach_latent_seed_modules, + attach_residual_projector_modules, + extend_attention_mask, + fixed_slot_next_token_logits_from_ids, + infer_latent_seed_count_from_state, + infer_projector_hidden_from_state, + infer_fixed_slot_count_from_state, + latent_seed_next_token_logits_from_ids, + load_trainable_adapter, + maybe_load_fixed_slot_state, + maybe_load_latent_seed_state, + maybe_load_projector_state, + recurrent_hidden_next_token_logits_from_ids, + residual_next_token_logits_from_ids as shared_residual_next_token_logits_from_ids, + sample_recurrent_hidden_completion, + sample_fixed_slot_completion, + sample_latent_completion, + sample_latent_seed_completion, + save_fixed_slot_latent_state, + save_latent_projector_state, + save_latent_seed_state, + unwrap_backbone, +) + + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + train_jsonl_stage1: str + train_jsonl_stage2: str + eval_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + num_cot_tokens: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_epochs: float + learning_rate: float + weight_decay: float + max_grad_norm: float + enable_gradient_checkpointing: bool + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + max_completion_length: int + lora_r: int + lora_alpha: int + lora_dropout: float + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + debug_print_limit: int + limit_train_rows: int + eval_exact_set_match_stop: float + eval_value_precision_stop: float + eval_value_recall_stop: float + eval_solve_rate_stop: float + min_steps_before_stop: int + max_wall_clock_seconds: int + max_steps: int + mixed_stage1_ratio: float + mixed_stage2_ratio: float + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + multi_value_oversample_factor: int + train_target_size_min: int + train_target_size_max: int + eval_target_size_min: int + eval_target_size_max: int + latent_mode: str + max_latent_slots: int + max_latent_seeds: int + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def target_size_allowed(target_size: int, min_size: int, max_size: int) -> bool: + if int(min_size) > 0 and int(target_size) < int(min_size): + return False + if int(max_size) > 0 and int(target_size) > int(max_size): + return False + return True + + +def build_training_examples( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + examples: List[Dict[str, Any]] = [] + eos_text = getattr(tokenizer, "eos_token", None) or "" + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) + if not target_size_allowed( + len(target_values), + getattr(tokenizer, "_train_target_size_min", 0), + getattr(tokenizer, "_train_target_size_max", 0), + ): + continue + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + target_text = build_supervised_completion(ex, stage_i=stage_i) + if eos_text: + target_text = target_text + eos_text + repeat_count = max(1, int(getattr(tokenizer, "_multi_value_oversample_factor", 1))) if len(target_values) > 1 else 1 + for _ in range(repeat_count): + examples.append( + { + "prompt_text": prompt, + "completion_text": target_text, + "target_values": list(target_values), + "grid": ex.grid, + "solved": solved, + "target_cell": ex.target_cell, + } + ) + if progress_callback is not None and ( + row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx, len(rows), len(examples)) + return examples + + +def _prepared_data_dir(args: Args) -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "multi_output_cell_policy") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_sft_cache_path(args: Args) -> str: + payload = json.dumps( + { + "completion_format_version": 2, + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + "multi_value_oversample_factor": int(args.multi_value_oversample_factor), + "train_target_size_min": int(args.train_target_size_min), + "train_target_size_max": int(args.train_target_size_max), + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(payload).hexdigest()[:20] + return os.path.join(_prepared_data_dir(args), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _to_jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_jsonable(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_to_jsonable(v) for v in value] + if hasattr(value, "tolist"): + return _to_jsonable(value.tolist()) + return value + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(_to_jsonable(row), separators=(",", ":")) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: + start = time.time() + while not os.path.exists(path): + if time.time() - start > timeout_s: + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_sft_examples( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_sft_cache_path(args) + if os.path.exists(cache_path): + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + examples = build_training_examples( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + _write_jsonl(cache_path, examples) + return examples + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +def residual_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + # Keep SFT on the exact same latent rollout/logit path used by decoding and GRPO, + # so training is supervising the same latent mechanism we later sample from. + return shared_residual_next_token_logits_from_ids(model, input_ids, attention_mask, num_cot_tokens) + + +def fixed_slot_next_token_logits_for_sft( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + return fixed_slot_next_token_logits_from_ids(model, input_ids, attention_mask, num_cot_tokens) + + +def recurrent_hidden_next_token_logits_for_sft( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + return recurrent_hidden_next_token_logits_from_ids(model, input_ids, attention_mask, num_cot_tokens) + + +def latent_seed_next_token_logits_for_sft( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + return latent_seed_next_token_logits_from_ids(model, input_ids, attention_mask, num_cot_tokens) + + +def _tokenize_prompt_completion_pair(tokenizer: Any, prompt_text: str, completion_text: str) -> tuple[List[int], List[int]]: + prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids[0] + all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids[0] + completion_ids = all_ids[int(prompt_ids.shape[0]) :] + return prompt_ids.tolist(), completion_ids.tolist() + + +def prepare_tokenized_train_examples(examples: List[Dict[str, Any]], tokenizer: Any) -> List[Dict[str, Any]]: + for ex in examples: + if "prompt_ids" in ex and "completion_ids" in ex: + continue + prompt_ids, completion_ids = _tokenize_prompt_completion_pair( + tokenizer, + str(ex["prompt_text"]), + str(ex["completion_text"]), + ) + ex["prompt_ids"] = prompt_ids + ex["completion_ids"] = completion_ids + return examples + + +def latent_residual_completion_ce_loss( + model: nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, + *, + num_cot_tokens: int, + latent_mode: str, +) -> torch.Tensor: + prompt_ids, completion_ids = _tokenize_prompt_completion_pair(tokenizer, prompt_text, completion_text) + return latent_batched_completion_ce_loss( + model, + [{"prompt_ids": prompt_ids, "completion_ids": completion_ids}], + device, + num_cot_tokens=num_cot_tokens, + latent_mode=latent_mode, + pad_token_id=int(tokenizer.pad_token_id), + ) + + +def latent_batched_completion_ce_loss( + model: nn.Module, + batch_examples: List[Dict[str, Any]], + device: torch.device, + *, + num_cot_tokens: int, + latent_mode: str, + pad_token_id: int, +) -> torch.Tensor: + if not batch_examples: + return torch.zeros((), device=device, dtype=torch.float32, requires_grad=True) + prompt_token_lists = [list(ex.get("prompt_ids", [])) for ex in batch_examples] + completion_token_lists = [list(ex.get("completion_ids", [])) for ex in batch_examples] + if not any(completion_token_lists): + return torch.zeros((), device=device, dtype=torch.float32, requires_grad=True) + + batch_size = len(batch_examples) + max_prompt_len = max(len(ids) for ids in prompt_token_lists) + max_completion_len = max(len(ids) for ids in completion_token_lists) + total_seq_len = max_prompt_len + max_completion_len + + full_ids = torch.full((batch_size, total_seq_len), int(pad_token_id), dtype=torch.long, device=device) + full_mask = torch.zeros((batch_size, total_seq_len), dtype=torch.long, device=device) + completion_tokens = torch.full((batch_size, max_completion_len), int(pad_token_id), dtype=torch.long, device=device) + completion_mask = torch.zeros((batch_size, max_completion_len), dtype=torch.bool, device=device) + + for row_idx, (prompt_ids, completion_ids) in enumerate(zip(prompt_token_lists, completion_token_lists, strict=True)): + prompt_len = len(prompt_ids) + prompt_start = max_prompt_len - prompt_len + if prompt_len > 0: + full_ids[row_idx, prompt_start:max_prompt_len] = torch.tensor(prompt_ids, dtype=torch.long, device=device) + full_mask[row_idx, prompt_start:max_prompt_len] = 1 + if completion_ids: + comp_len = len(completion_ids) + completion_tokens[row_idx, :comp_len] = torch.tensor(completion_ids, dtype=torch.long, device=device) + completion_mask[row_idx, :comp_len] = True + + total_loss = torch.zeros((), device=device, dtype=torch.float32) + total_tokens = torch.zeros((), device=device, dtype=torch.float32) + mode = str(latent_mode).strip().lower() + for idx in range(max_completion_len): + cur_len = max_prompt_len + idx + cur_ids = full_ids[:, :cur_len] + cur_mask = full_mask[:, :cur_len] + if mode == "fixed_slots": + logits = fixed_slot_next_token_logits_for_sft(model, cur_ids, cur_mask, num_cot_tokens) + elif mode == "recurrent_hidden": + logits = recurrent_hidden_next_token_logits_for_sft(model, cur_ids, cur_mask, num_cot_tokens) + elif mode == "latent_seeds": + logits = latent_seed_next_token_logits_for_sft(model, cur_ids, cur_mask, num_cot_tokens) + else: + logits = residual_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot_tokens) + active = completion_mask[:, idx] + if torch.any(active): + targets = completion_tokens[:, idx] + token_losses = F.cross_entropy(logits.float(), targets, reduction="none") + total_loss = total_loss + token_losses[active].sum() + total_tokens = total_tokens + active.to(dtype=torch.float32).sum() + full_ids[:, cur_len] = completion_tokens[:, idx] + full_mask[:, cur_len] = active.to(dtype=full_mask.dtype) + if float(total_tokens.detach().item()) <= 0.0: + return torch.zeros((), device=device, dtype=torch.float32, requires_grad=True) + return total_loss / total_tokens + + +@torch.no_grad() +def run_eval(args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device): + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + solve_rows = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + row_has_eval_cell = False + row_debug_lines: List[str] = [] + for ex in build_cell_examples_from_row(row): + target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i) + if not target_size_allowed(len(target_values), int(args.eval_target_size_min), int(args.eval_target_size_max)): + continue + row_has_eval_cell = True + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + prompt_ids = enc["input_ids"].to(device) + attn = enc["attention_mask"].to(device) + sft_eval_mode = str(args.latent_mode).strip().lower() + if sft_eval_mode == "fixed_slots": + completion_ids = sample_fixed_slot_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + ) + elif sft_eval_mode == "recurrent_hidden": + completion_ids = sample_recurrent_hidden_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + ) + elif sft_eval_mode == "latent_seeds": + completion_ids = sample_latent_seed_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + ) + else: + completion_ids = sample_latent_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + ) + pred_text = tokenizer.decode(completion_ids[0], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + row_debug_lines.append( + f"[latent sft eval debug] true_values={info['target_values']} " + f"predicted_values={info['predicted_values']} output={pred_text!r}" + ) + if row_has_eval_cell: + if printed < int(args.debug_print_limit) and row_debug_lines: + print("[latent sft eval debug] puzzle_outputs_begin", flush=True) + for line in row_debug_lines: + print(line, flush=True) + print("[latent sft eval debug] puzzle_outputs_end", flush=True) + printed += 1 + solve_ok += int(row_all_exact) + solve_rows += 1 + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, solve_rows)), + } + print( + f"[latent sft eval] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}", + flush=True, + ) + model.train() + return out + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument("--train_jsonl", type=str, required=True) + p.add_argument("--train_jsonl_stage1", type=str, default="") + p.add_argument("--train_jsonl_stage2", type=str, default="") + p.add_argument("--eval_jsonl", type=str, default="") + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, default="") + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--num_cot_tokens", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=1) + p.add_argument("--gradient_accumulation_steps", type=int, default=8) + p.add_argument("--num_epochs", type=float, default=1.0) + p.add_argument("--learning_rate", type=float, default=2e-4) + p.add_argument("--weight_decay", type=float, default=0.0) + p.add_argument( + "--max_grad_norm", + type=float, + default=1.0, + help="Clip global grad norm before each optimizer step (0 disables).", + ) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--logging_steps", type=int, default=10) + p.add_argument("--save_steps", type=int, default=100) + p.add_argument("--eval_steps", type=int, default=100) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument( + "--lora_r", + type=int, + default=16, + help="LoRA rank. Use -1 to resolve to model hidden_size, i.e. full-rank adapters for hidden-width projections.", + ) + p.add_argument( + "--lora_alpha", + type=int, + default=32, + help="LoRA alpha. Use -1 to resolve to 2 * resolved_lora_r.", + ) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-latent-multi-output-sft") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) + p.add_argument("--eval_value_precision_stop", type=float, default=0.0) + p.add_argument("--eval_value_recall_stop", type=float, default=0.0) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--mixed_stage1_ratio", type=float, default=0.0) + p.add_argument("--mixed_stage2_ratio", type=float, default=1.0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--multi_value_oversample_factor", type=int, default=1) + p.add_argument("--train_target_size_min", type=int, default=0) + p.add_argument("--train_target_size_max", type=int, default=0) + p.add_argument("--eval_target_size_min", type=int, default=0) + p.add_argument("--eval_target_size_max", type=int, default=0) + p.add_argument( + "--latent_mode", + type=str, + default="residual", + choices=["residual", "fixed_slots", "recurrent_hidden", "latent_seeds"], + ) + p.add_argument("--max_latent_slots", type=int, default=8) + p.add_argument( + "--max_latent_seeds", + type=int, + default=8, + help="For --latent_mode latent_seeds: bank size of trainable seed vectors m_1..m_{max}. " + "Each stage uses --num_cot_tokens of them; the bank persists across stages when loaded from init_adapter_dir.", + ) + return Args(**vars(p.parse_args())) + + +def save_checkpoint( + model: torch.nn.Module, tokenizer: Any, output_dir: str, step: int, extra_save_fn: Any | None = None +) -> None: + save_checkpoint_and_update_final( + model, + tokenizer, + output_dir, + f"checkpoint-step-{step:05d}", + extra_save_fn=extra_save_fn, + ) + + +def main() -> None: + args = parse_args() + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_distributed = world_size > 1 + is_main_process = rank == 0 + + if preset_visible_devices: + if is_main_process: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + if is_distributed: + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", timeout=timedelta(hours=2)) + + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) + + train_source = args.train_jsonl + if int(args.stage_i) <= 1 and str(args.train_jsonl_stage1).strip(): + train_source = args.train_jsonl_stage1 + elif int(args.stage_i) > 1 and str(args.train_jsonl_stage2).strip(): + train_source = args.train_jsonl_stage2 + rows = load_jsonl_rows(train_source, limit_rows=args.limit_train_rows) + eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else train_source + eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows))) + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + tokenizer._multi_value_oversample_factor = max(1, int(args.multi_value_oversample_factor)) + tokenizer._train_target_size_min = max(0, int(args.train_target_size_min)) + tokenizer._train_target_size_max = max(0, int(args.train_target_size_max)) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") + else: + device = torch.device("cpu") + + latent_mode = str(args.latent_mode).strip().lower() + base = AutoModelForCausalLM.from_pretrained( + args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True + ) + model = load_trainable_adapter( + base, + args.init_adapter_dir, + lora_r=int(args.lora_r), + lora_alpha=int(args.lora_alpha), + lora_dropout=float(args.lora_dropout), + ) + init_ad = str(args.init_adapter_dir).strip() + if latent_mode == "fixed_slots": + max_latent_slots = max(1, int(args.max_latent_slots)) + if init_ad: + max_latent_slots = infer_fixed_slot_count_from_state(init_ad) or max_latent_slots + attach_fixed_latent_slot_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + max_latent_slots=max_latent_slots, + ) + if init_ad: + maybe_load_fixed_slot_state(model, init_ad) + extra_save_fn = save_fixed_slot_latent_state + elif latent_mode == "latent_seeds": + max_latent_seeds = max(1, int(args.max_latent_seeds)) + if init_ad: + max_latent_seeds = infer_latent_seed_count_from_state(init_ad) or max_latent_seeds + attach_latent_seed_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + max_latent_seeds=max_latent_seeds, + ) + if init_ad: + maybe_load_latent_seed_state(model, init_ad) + extra_save_fn = save_latent_seed_state + elif latent_mode == "recurrent_hidden": + extra_save_fn = None + else: + if init_ad: + projector_hidden = infer_projector_hidden_from_state(init_ad) or PROJECTOR_HIDDEN + else: + projector_hidden = PROJECTOR_HIDDEN + attach_residual_projector_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + projector_hidden=projector_hidden, + ) + if init_ad: + maybe_load_projector_state(model, init_ad) + extra_save_fn = save_latent_projector_state + if is_main_process: + if init_ad: + print(f"Init adapter: {init_ad}", flush=True) + elif latent_mode == "fixed_slots": + print( + "init_adapter_dir empty: fresh LoRA (random) + fixed latent slot random init " + f"(lora_r={args.lora_r} lora_alpha={args.lora_alpha} max_latent_slots={args.max_latent_slots}).", + flush=True, + ) + elif latent_mode == "recurrent_hidden": + print( + "init_adapter_dir empty: fresh LoRA (random) + recurrent hidden latent rollout " + f"(lora_r={args.lora_r} lora_alpha={args.lora_alpha} num_cot_tokens={args.num_cot_tokens}).", + flush=True, + ) + elif latent_mode == "latent_seeds": + print( + "init_adapter_dir empty: fresh LoRA (random) + trainable latent seed bank " + f"(lora_r={args.lora_r} lora_alpha={args.lora_alpha} " + f"max_latent_seeds={args.max_latent_seeds} num_cot_tokens={args.num_cot_tokens}).", + flush=True, + ) + else: + print( + "init_adapter_dir empty: fresh LoRA (random) + residual projector random init " + f"(lora_r={args.lora_r} lora_alpha={args.lora_alpha}).", + flush=True, + ) + model._latent_debug_tokenizer = tokenizer + model._fixed_slot_debug_limit = int(os.environ.get("FIXED_SLOT_DEBUG_LIMIT", "0")) + model._fixed_slot_debug_count = 0 + model._fixed_slot_decode_debug_limit = int(os.environ.get("FIXED_SLOT_DECODE_DEBUG_LIMIT", "0")) + model._fixed_slot_decode_debug_count = 0 + model._latent_vocab_debug_topk = int(os.environ.get("LATENT_VOCAB_DEBUG_TOPK", "1")) + model._attention_density_debug_limit = int(os.environ.get("ATTN_DENSITY_DEBUG_LIMIT", "0")) + model._attention_density_debug_count = 0 + model._attention_density_threshold_mult = float(os.environ.get("ATTN_DENSITY_THRESHOLD_MULT", "1.0")) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + backbone = unwrap_backbone(model) + if hasattr(backbone, "config"): + backbone.config.use_cache = False + model.to(device) + model.train() + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + + def on_prep_progress(rows_done: int, total_rows: int, examples_built: int) -> None: + if is_main_process: + print( + f"[dataset build][sft stage {args.stage_i}] rows={rows_done}/{total_rows} examples={examples_built}", + flush=True, + ) + if wb_run is not None: + wandb.log({"prep/rows_done": float(rows_done), "prep/examples_built": float(examples_built)}) + + train_examples = load_or_build_sft_examples( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + train_examples = prepare_tokenized_train_examples(train_examples, tokenizer) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_sft_cache_path(args))), + "prep/examples_final": float(len(train_examples)), + } + ) + + optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) + denom = max(1, int(args.gradient_accumulation_steps)) * max(1, int(args.per_device_train_batch_size)) + total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / denom)) + if int(args.max_steps) > 0: + total_steps = min(total_steps, int(args.max_steps)) + step = 0 + start_time = time.time() + + def average_scalar(value: float) -> float: + if not is_distributed or not dist.is_initialized(): + return float(value) + tensor = torch.tensor(float(value), device=device, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float((tensor / float(world_size)).item()) + + def all_reduce_gradients() -> None: + if not is_distributed or not dist.is_initialized(): + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(float(world_size)) + + def sync_stop(local_stop: bool) -> bool: + if not is_distributed or not dist.is_initialized(): + return bool(local_stop) + tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item()) > 0) + + for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): + if is_distributed: + sampler = DistributedSampler( + train_examples, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=args.seed, + drop_last=False, + ) + sampler.set_epoch(epoch_idx) + order = list(iter(sampler)) + else: + generator = torch.Generator() + generator.manual_seed(args.seed + epoch_idx) + order = torch.randperm(len(train_examples), generator=generator).tolist() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + accum_ce_sum = 0.0 + microbatch_size = max(1, int(args.per_device_train_batch_size)) + for batch_start in range(0, len(order), microbatch_size): + batch_indices = order[batch_start : batch_start + microbatch_size] + batch_examples = [train_examples[ex_idx] for ex_idx in batch_indices] + ce_full = latent_batched_completion_ce_loss( + model, + batch_examples, + device, + num_cot_tokens=args.num_cot_tokens, + latent_mode=latent_mode, + pad_token_id=int(tokenizer.pad_token_id), + ) + loss = ce_full / max(1, int(args.gradient_accumulation_steps)) + loss.backward() + accum_ce_sum += float(ce_full.detach().item()) + accum_count += 1 + if accum_count >= int(args.gradient_accumulation_steps): + all_reduce_gradients() + if float(args.max_grad_norm) > 0.0: + torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.max_grad_norm)) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + step += 1 + mean_ce = accum_ce_sum / max(1, int(args.gradient_accumulation_steps)) + accum_ce_sum = 0.0 + if step % int(args.logging_steps) == 0: + loss_value = average_scalar(mean_ce) + if is_main_process: + print(f"[latent sft train step {step:05d}] loss={loss_value:.4f}", flush=True) + if wb_run is not None: + wandb.log({"train/loss": loss_value, "step": step}) + if step % int(args.eval_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + should_stop_eval = False + if is_main_process: + eval_started_at = time.time() + print( + f"[latent sft eval start step {step:05d}] rows={len(eval_rows)}", + flush=True, + ) + if wb_run is not None: + wandb.log({"eval/in_progress": 1.0, "eval/rows": float(len(eval_rows)), "step": step}) + ev = run_eval(args, eval_rows, model, tokenizer, device) + eval_duration = time.time() - eval_started_at + print( + f"[latent sft eval end step {step:05d}] duration_seconds={eval_duration:.1f}", + flush=True, + ) + if wb_run is not None: + wandb.log( + {f"eval/{k}": float(v) for k, v in ev.items()} + | { + "eval/duration_seconds": float(eval_duration), + "eval/in_progress": 0.0, + "step": step, + } + ) + if ( + args.eval_exact_set_match_stop > 0.0 + and float(ev["exact_set_match_rate"]) >= args.eval_exact_set_match_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step, extra_save_fn=extra_save_fn) + should_stop_eval = True + if ( + not should_stop_eval + and step >= int(args.min_steps_before_stop) + and args.eval_value_precision_stop > 0.0 + and args.eval_value_recall_stop > 0.0 + and float(ev["value_precision"]) >= args.eval_value_precision_stop + and float(ev["value_recall"]) >= args.eval_value_recall_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step, extra_save_fn=extra_save_fn) + should_stop_eval = True + if ( + not should_stop_eval + and args.eval_solve_rate_stop > 0.0 + and step >= int(args.min_steps_before_stop) + and float(ev["solve_rate"]) >= args.eval_solve_rate_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step, extra_save_fn=extra_save_fn) + should_stop_eval = True + should_stop_eval = sync_stop(should_stop_eval) + if is_distributed and dist.is_initialized(): + dist.barrier() + if should_stop_eval: + if is_main_process and wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.destroy_process_group() + return + if step % int(args.save_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step, extra_save_fn=extra_save_fn) + if is_distributed and dist.is_initialized(): + dist.barrier() + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and ( + time.time() - start_time >= float(args.max_wall_clock_seconds) + ) + should_stop = sync_stop(reached_limit or exceeded_wall) + if should_stop: + break + if sync_stop(step >= total_steps): + break + + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step, extra_save_fn=extra_save_fn) + if wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/mixed_curriculum_cot/__init__.py b/mixed_curriculum_cot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/mixed_curriculum_cot/__init__.py @@ -0,0 +1 @@ + diff --git a/mixed_curriculum_cot/build_mixed_curriculum_dataset.py b/mixed_curriculum_cot/build_mixed_curriculum_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2b1b1d0fb6f920332d323336206b4a45b318867f --- /dev/null +++ b/mixed_curriculum_cot/build_mixed_curriculum_dataset.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import argparse +import json +import random +from pathlib import Path +from typing import Dict, Iterable, List, Sequence + + +def parse_args() -> argparse.Namespace: + root = Path(__file__).resolve().parent.parent + default_base = root / "data" / "sudoku_t3_30empty_value_qwen_text.jsonl" + default_output = root / "data" / "mixed_curriculum_cot" / "30empty" + p = argparse.ArgumentParser() + p.add_argument("--base_jsonl", type=str, default=str(default_base)) + p.add_argument("--output_dir", type=str, default=str(default_output)) + p.add_argument("--max_stage", type=int, default=4) + p.add_argument("--rows_per_stage", type=int, default=0) + p.add_argument("--current_stage_fraction", type=float, default=0.8) + p.add_argument("--seed", type=int, default=0) + return p.parse_args() + + +def load_jsonl(path: Path) -> List[Dict[str, object]]: + rows: List[Dict[str, object]] = [] + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def write_jsonl(path: Path, rows: Iterable[Dict[str, object]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, separators=(",", ":")) + "\n") + + +def stage_mix(stage: int, current_fraction: float) -> List[tuple[int, float]]: + stage = int(stage) + if stage <= 1: + return [(1, 1.0)] + current_fraction = max(0.0, min(1.0, float(current_fraction))) + prev_fraction = 1.0 - current_fraction + return [(stage, current_fraction), (stage - 1, prev_fraction)] + + +def sample_rows( + base_rows: Sequence[Dict[str, object]], + *, + count: int, + rng: random.Random, + offset: int, +) -> List[Dict[str, object]]: + total = len(base_rows) + if total == 0: + return [] + if count <= total: + indices = list(range(total)) + rng.shuffle(indices) + chosen = indices[:count] + return [dict(base_rows[idx]) for idx in chosen] + rows: List[Dict[str, object]] = [] + for ii in range(int(count)): + idx = (int(offset) + ii) % total + rows.append(dict(base_rows[idx])) + rng.shuffle(rows) + return rows + + +def build_stage_rows( + base_rows: Sequence[Dict[str, object]], + *, + target_stage: int, + rows_per_stage: int, + current_stage_fraction: float, + rng: random.Random, +) -> List[Dict[str, object]]: + mixed_rows: List[Dict[str, object]] = [] + mix = stage_mix(target_stage, current_stage_fraction) + assigned = 0 + for item_idx, (source_stage, fraction) in enumerate(mix): + if item_idx == len(mix) - 1: + count = int(rows_per_stage - assigned) + else: + count = int(round(float(rows_per_stage) * float(fraction))) + assigned += count + sampled = sample_rows( + base_rows, + count=count, + rng=rng, + offset=target_stage * 100003 + source_stage * 1009 + item_idx * 17, + ) + for row in sampled: + metadata = dict(row.get("metadata", {})) + metadata["mixed_curriculum_target_stage"] = int(target_stage) + metadata["mixed_curriculum_source_stage"] = int(source_stage) + metadata["mixed_curriculum_fraction"] = float(fraction) + row["metadata"] = metadata + mixed_rows.append(row) + rng.shuffle(mixed_rows) + return mixed_rows + + +def main() -> None: + args = parse_args() + base_path = Path(args.base_jsonl).resolve() + output_dir = Path(args.output_dir).resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + base_rows = load_jsonl(base_path) + if not base_rows: + raise RuntimeError(f"No rows found in {base_path}") + + rows_per_stage = int(args.rows_per_stage) if int(args.rows_per_stage) > 0 else len(base_rows) + rng = random.Random(int(args.seed)) + manifest: Dict[str, object] = { + "base_jsonl": str(base_path), + "output_dir": str(output_dir), + "max_stage": int(args.max_stage), + "rows_per_stage": int(rows_per_stage), + "current_stage_fraction": float(args.current_stage_fraction), + "stages": {}, + } + + for stage in range(1, int(args.max_stage) + 1): + stage_rows = build_stage_rows( + base_rows, + target_stage=stage, + rows_per_stage=rows_per_stage, + current_stage_fraction=float(args.current_stage_fraction), + rng=rng, + ) + stage_file = output_dir / f"stage{stage:02d}_mixed.jsonl" + write_jsonl(stage_file, stage_rows) + manifest["stages"][f"stage{stage:02d}"] = { + "path": str(stage_file), + "mix": [ + {"source_stage": int(source_stage), "fraction": float(fraction)} + for source_stage, fraction in stage_mix(stage, float(args.current_stage_fraction)) + ], + "rows": int(len(stage_rows)), + } + print( + f"Wrote stage {stage} mixed dataset to {stage_file} " + f"with mix={manifest['stages'][f'stage{stage:02d}']['mix']}", + flush=True, + ) + + manifest_path = output_dir / "mixed_curriculum_manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + print(f"Wrote manifest to {manifest_path}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/mixed_curriculum_cot/launch_latent_stage2_mixed_curriculum_cot_70_30.sh b/mixed_curriculum_cot/launch_latent_stage2_mixed_curriculum_cot_70_30.sh new file mode 100644 index 0000000000000000000000000000000000000000..4941cde1ac889dacb28762be4444a20cedf2ea56 --- /dev/null +++ b/mixed_curriculum_cot/launch_latent_stage2_mixed_curriculum_cot_70_30.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +SFT_TRAINER="${ROOT}/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py" +GRPO_TRAINER="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py" + +BASE_JSONL="${BASE_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}" +TRAIN_JSONL_STAGE1="${TRAIN_JSONL_STAGE1:-${BASE_JSONL}}" +TRAIN_JSONL_STAGE2="${TRAIN_JSONL_STAGE2:-${BASE_JSONL}}" +TRAIN_JSONL="${TRAIN_JSONL:-${TRAIN_JSONL_STAGE2}}" + +STAGE_I="${STAGE_I:-2}" +NUM_COT_TOKENS="${NUM_COT_TOKENS:-2}" +MIX_STAGE1_RATIO="${MIX_STAGE1_RATIO:-30}" +MIX_STAGE2_RATIO="${MIX_STAGE2_RATIO:-70}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" + +STAGE1_LATENT_GRPO_DIR="${STAGE1_LATENT_GRPO_DIR:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" + +OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/mixed_curriculum_cot_70_30/latent}" +SFT_OUTPUT_DIR="${SFT_OUTPUT_DIR:-${OUTPUT_ROOT}/stage02_sft_i2_cot2_${RUN_TAG}}" +GRPO_OUTPUT_DIR="${GRPO_OUTPUT_DIR:-${OUTPUT_ROOT}/stage02_grpo_i2_cot2_${RUN_TAG}}" + +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export CUDA_VISIBLE_DEVICES="${GPU_IDS}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +mkdir -p "${SFT_OUTPUT_DIR}" "${GRPO_OUTPUT_DIR}" + +common_wandb_args=() +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + common_wandb_args+=(--use_wandb) +fi +if [[ -n "${WANDB_ENTITY:-}" ]]; then + common_wandb_args+=(--wandb_entity "${WANDB_ENTITY}") +fi + +sft_cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_TRAINER}" + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --train_jsonl_stage1 "${TRAIN_JSONL_STAGE1}" + --train_jsonl_stage2 "${TRAIN_JSONL_STAGE2}" + --mixed_stage1_ratio "${MIX_STAGE1_RATIO}" + --mixed_stage2_ratio "${MIX_STAGE2_RATIO}" + --output_dir "${SFT_OUTPUT_DIR}" + --init_adapter_dir "${STAGE1_LATENT_GRPO_DIR}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --num_cot_tokens "${NUM_COT_TOKENS}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --learning_rate "${SFT_LEARNING_RATE:-1e-6}" + --weight_decay "${SFT_WEIGHT_DECAY:-0.0}" + --enable_gradient_checkpointing + --logging_steps "${SFT_LOGGING_STEPS:-10}" + --save_steps "${SFT_SAVE_STEPS:-100}" + --eval_steps "${SFT_EVAL_STEPS:-100}" + --eval_rows "${SFT_EVAL_ROWS:-20}" + --max_completion_length "${SFT_MAX_COMPLETION_LENGTH:-32}" + --max_wall_clock_seconds "${SFT_MAX_WALL_CLOCK_SECONDS:-0}" + --wandb_project "${SFT_WANDB_PROJECT:-sudoku-latent-multi-output-sft-residual-projector}" + --wandb_run_name "${SFT_WANDB_RUN_NAME:-latent_stage02_sft_mixed_curriculum_cot_70_30_${RUN_TAG}}" + --wandb_mode "${WANDB_MODE:-offline}" +) +sft_cmd+=("${common_wandb_args[@]}") + +grpo_cmd=( + "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_TRAINER}" + --model_name "${MODEL_NAME}" + --train_jsonl "${TRAIN_JSONL}" + --train_jsonl_stage1 "${TRAIN_JSONL_STAGE1}" + --train_jsonl_stage2 "${TRAIN_JSONL_STAGE2}" + --mixed_stage1_ratio "${MIX_STAGE1_RATIO}" + --mixed_stage2_ratio "${MIX_STAGE2_RATIO}" + --output_dir "${GRPO_OUTPUT_DIR}" + --init_adapter_dir "${SFT_OUTPUT_DIR}" + --cache_dir "${CACHE_DIR}" + --gpu_id 0 + --stage_i "${STAGE_I}" + --num_cot_tokens "${NUM_COT_TOKENS}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-30}" + --per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-4}" + --gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-2}" + --num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}" + --learning_rate "${GRPO_LEARNING_RATE:-7e-7}" + --logging_steps "${GRPO_LOGGING_STEPS:-5}" + --save_steps "${GRPO_SAVE_STEPS:-25}" + --eval_steps "${GRPO_EVAL_STEPS:-25}" + --eval_rows "${GRPO_EVAL_ROWS:-20}" + --num_generations "${GRPO_NUM_GENERATIONS:-2}" + --max_prompt_length "${GRPO_MAX_PROMPT_LENGTH:-1024}" + --max_completion_length "${GRPO_MAX_COMPLETION_LENGTH:-32}" + --beta "${GRPO_BETA:-0.01}" + --enable_gradient_checkpointing + --max_wall_clock_seconds "${GRPO_MAX_WALL_CLOCK_SECONDS:-0}" + --wandb_project "${GRPO_WANDB_PROJECT:-sudoku-latent-multi-output-grpo-residual-projector}" + --wandb_run_name "${GRPO_WANDB_RUN_NAME:-latent_stage02_grpo_mixed_curriculum_cot_70_30_${RUN_TAG}}" + --wandb_group "${GRPO_WANDB_GROUP:-mixed_curriculum_cot_70_30}" + --wandb_mode "${WANDB_MODE:-offline}" +) +grpo_cmd+=("${common_wandb_args[@]}") + +echo "Launching latent mixed curriculum_cot 70-30 stage-2 SFT" +"${sft_cmd[@]}" + +echo "Launching latent mixed curriculum_cot 70-30 stage-2 GRPO" +"${grpo_cmd[@]}" + diff --git a/mixed_curriculum_cot/runtime_mixed_curriculum.py b/mixed_curriculum_cot/runtime_mixed_curriculum.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6ea589edd587c94248f57e54ce3cc5a8353fce --- /dev/null +++ b/mixed_curriculum_cot/runtime_mixed_curriculum.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import random +from typing import Any, Dict, List, Sequence + + +def sample_rows( + rows: Sequence[Dict[str, Any]], + *, + count: int, + rng: random.Random, + offset: int, +) -> List[Dict[str, Any]]: + total = len(rows) + if total == 0 or int(count) <= 0: + return [] + if int(count) <= total: + indices = list(range(total)) + rng.shuffle(indices) + chosen = indices[: int(count)] + return [dict(rows[idx]) for idx in chosen] + out: List[Dict[str, Any]] = [] + for item_idx in range(int(count)): + out.append(dict(rows[(int(offset) + item_idx) % total])) + rng.shuffle(out) + return out + + +def annotate_mixed_curriculum_row( + row: Dict[str, Any], + *, + source_stage: int, + target_stage: int, + fraction: float, +) -> Dict[str, Any]: + out = dict(row) + metadata = dict(out.get("metadata", {})) + metadata["mixed_curriculum_target_stage"] = int(target_stage) + metadata["mixed_curriculum_source_stage"] = int(source_stage) + metadata["mixed_curriculum_fraction"] = float(fraction) + metadata["mixed_curriculum_runtime"] = True + out["metadata"] = metadata + return out + + +def build_two_stage_mixed_rows( + stage1_rows: Sequence[Dict[str, Any]], + stage2_rows: Sequence[Dict[str, Any]], + *, + stage1_ratio: float, + stage2_ratio: float, + seed: int, + target_stage: int, + total_rows: int = 0, +) -> List[Dict[str, Any]]: + weight1 = max(0.0, float(stage1_ratio)) + weight2 = max(0.0, float(stage2_ratio)) + weight_sum = weight1 + weight2 + if weight_sum <= 0.0: + raise ValueError("At least one mixed curriculum ratio must be positive.") + + if int(total_rows) > 0: + target_count = int(total_rows) + else: + target_count = max(len(stage1_rows), len(stage2_rows)) + target_count = max(1, int(target_count)) + + count1 = int(round(target_count * (weight1 / weight_sum))) + count1 = min(max(count1, 0), target_count) + count2 = int(target_count - count1) + rng = random.Random(int(seed)) + + mixed: List[Dict[str, Any]] = [] + for row in sample_rows(stage1_rows, count=count1, rng=rng, offset=1009): + mixed.append( + annotate_mixed_curriculum_row( + row, + source_stage=1, + target_stage=int(target_stage), + fraction=(weight1 / weight_sum), + ) + ) + for row in sample_rows(stage2_rows, count=count2, rng=rng, offset=2003): + mixed.append( + annotate_mixed_curriculum_row( + row, + source_stage=2, + target_stage=int(target_stage), + fraction=(weight2 / weight_sum), + ) + ) + rng.shuffle(mixed) + return mixed + + +def training_stage_i_for_row(row: Dict[str, Any], default_stage_i: int) -> int: + metadata = dict(row.get("metadata", {})) + source_stage = metadata.get("mixed_curriculum_source_stage") + if source_stage is not None: + return max(1, int(source_stage)) + row_stage = row.get("stage_i") + if row_stage is not None: + return max(1, int(row_stage)) + return max(1, int(default_stage_i)) diff --git a/multi_output_cell_policy/__init__.py b/multi_output_cell_policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe049dd27343195a95167ea280c34dfb64dbc8a --- /dev/null +++ b/multi_output_cell_policy/__init__.py @@ -0,0 +1 @@ +"""Recovered package wrapper; original source is currently unavailable.""" diff --git a/multi_output_cell_policy/grpo_multi_output_train.py b/multi_output_cell_policy/grpo_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..8392639bac9d7bee5cf1959a4b1f8a4be573235b --- /dev/null +++ b/multi_output_cell_policy/grpo_multi_output_train.py @@ -0,0 +1,782 @@ +from __future__ import annotations + +import argparse +import hashlib +import inspect +import json +import os +import sys +import time +from dataclasses import dataclass +from typing import Any, Dict, List + +import torch +from datasets import Dataset +from peft import LoraConfig, PeftModel, get_peft_model +from safetensors.torch import load_file as load_safetensors_file +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_model_artifacts +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.rewards import score_prediction_text +from multi_output_cell_policy.shared_multi_output_policy import make_solved_grid_from_row + + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + eval_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_train_epochs: float + learning_rate: float + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + num_generations: int + max_prompt_length: int + max_completion_length: int + beta: float + lora_r: int + lora_alpha: int + lora_dropout: float + enable_gradient_checkpointing: bool + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + wandb_group: str + wandb_run_id: str + debug_print_limit: int + limit_train_rows: int + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + penalty_missing: float + exact_match_bonus: float + cardinality_mismatch_penalty: float + eval_value_precision_stop: float + eval_value_recall_stop: float + eval_solve_rate_stop: float + min_steps_before_stop: int + max_wall_clock_seconds: int + max_steps: int + resume_from_checkpoint: str + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available(): + try: + if torch.cuda.is_bf16_supported(): + return torch.bfloat16 + except Exception: + pass + return torch.float16 + + +def ensure_trl_fsdp_compat() -> None: + try: + import torch.distributed.fsdp as fsdp + + if not hasattr(fsdp, "FSDPModule") and hasattr(fsdp, "FullyShardedDataParallel"): + fsdp.FSDPModule = fsdp.FullyShardedDataParallel + except Exception: + pass + + +def load_trainable_adapter(base_model: torch.nn.Module, adapter_dir: str) -> torch.nn.Module: + try: + return PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True) + except Exception: + config_path = os.path.join(adapter_dir, "adapter_config.json") + model_path = os.path.join(adapter_dir, "adapter_model.safetensors") + if not (os.path.exists(config_path) and os.path.exists(model_path)): + raise + + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + lora = LoraConfig( + r=int(cfg["r"]), + lora_alpha=int(cfg["lora_alpha"]), + lora_dropout=float(cfg["lora_dropout"]), + bias=str(cfg.get("bias", "none")), + task_type=str(cfg.get("task_type", "CAUSAL_LM")), + target_modules=list(cfg["target_modules"]), + ) + model = get_peft_model(base_model, lora) + state = load_safetensors_file(model_path) + remapped: Dict[str, torch.Tensor] = {} + for key, value in state.items(): + new_key = key.replace(".lora_A.weight", ".lora_A.default.weight") + new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight") + remapped[new_key] = value + model.load_state_dict(remapped, strict=False) + return model + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def build_grpo_records( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + records: List[Dict[str, Any]] = [] + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + records.append( + { + "prompt": prompt, + "grid_json": json.dumps(ex.grid.tolist(), separators=(",", ":")), + "solved_json": json.dumps(solved.tolist(), separators=(",", ":")), + "target_row": int(ex.target_cell[0]), + "target_col": int(ex.target_cell[1]), + "stage_i": int(stage_i), + } + ) + if progress_callback is not None and ( + row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx, len(rows), len(records)) + return records + + +def _prepared_data_dir(args: Args) -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "multi_output_cell_policy") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_grpo_cache_path(args: Args) -> str: + payload = json.dumps( + { + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(payload).hexdigest()[:20] + return os.path.join(_prepared_data_dir(args), f"grpo_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, separators=(",", ":")) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: + start = time.time() + while not os.path.exists(path): + if time.time() - start > timeout_s: + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_grpo_records( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_grpo_cache_path(args) + if os.path.exists(cache_path): + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + records = build_grpo_records( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + _write_jsonl(cache_path, records) + return records + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +def make_reward_func(args: Args): + def reward_func(completions, grid_json, solved_json, target_row, target_col, stage_i, **kwargs): + rewards: List[float] = [] + for completion, grid_s, solved_s, rr, cc, stage_val in zip( + completions, grid_json, solved_json, target_row, target_col, stage_i + ): + info = score_prediction_text( + text=str(completion), + grid=torch.tensor(json.loads(grid_s), dtype=torch.long).numpy(), + solved=torch.tensor(json.loads(solved_s), dtype=torch.long).numpy(), + target_cell=(int(rr), int(cc)), + stage_i=int(stage_val), + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + penalty_missing=args.penalty_missing, + exact_match_bonus=args.exact_match_bonus, + cardinality_mismatch_penalty=args.cardinality_mismatch_penalty, + ) + rewards.append(float(info["reward"])) + return rewards + + return reward_func + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: torch.nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + enc = {k: v.to(device) for k, v in enc.items()} + out = model.generate( + **enc, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + penalty_missing=args.penalty_missing, + exact_match_bonus=args.exact_match_bonus, + cardinality_mismatch_penalty=args.cardinality_mismatch_penalty, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[baseline grpo eval debug] target=({rr+1},{cc+1}) output={pred_text!r}", flush=True) + print( + f"[baseline grpo eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}", + flush=True, + ) + printed += 1 + solve_ok += int(row_all_exact) + + return { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + "eval_cells": float(total_cells), + } + + +def unwrap_training_model(model: Any) -> Any: + current = model + while hasattr(current, "module"): + current = current.module + return current + + +class CustomEvalCallback(TrainerCallback): + def __init__( + self, + args: Args, + eval_rows: List[Dict[str, Any]], + tokenizer: Any, + device: torch.device, + wb_run: Any, + is_main_process: bool, + ): + self.args = args + self.eval_rows = eval_rows + self.tokenizer = tokenizer + self.device = device + self.wb_run = wb_run + self.is_main_process = is_main_process + self.last_logged_step = -1 + + def on_step_end(self, args, state, control, **kwargs): + step = int(state.global_step) + eval_every = int(self.args.eval_steps) + if step <= 0 or step % eval_every != 0: + return control + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + use_dist = world_size > 1 and torch.distributed.is_available() and torch.distributed.is_initialized() + stop_tensor = torch.zeros(1, dtype=torch.int32, device=self.device) + + if self.is_main_process: + if step != self.last_logged_step: + model = kwargs.get("model") + if model is not None: + metrics = run_eval( + args=self.args, + rows=self.eval_rows, + model=unwrap_training_model(model), + tokenizer=self.tokenizer, + device=self.device, + ) + self.last_logged_step = step + print( + f"[baseline grpo custom eval step {step}] parse={metrics['parse_rate']:.3f} " + f"solve={metrics['solve_rate']:.3f} " + f"avg_set_size={metrics['avg_predicted_set_size']:.3f} " + f"good={metrics['avg_num_i_consistent_values']:.3f} " + f"bad={metrics['avg_num_non_i_consistent_values']:.3f}", + flush=True, + ) + if self.args.use_wandb and self.wb_run is not None: + payload = {f"custom_eval/{k}": float(v) for k, v in metrics.items()} + payload["custom_eval/global_step"] = float(step) + wandb.log(payload) + if ( + float(self.args.eval_value_precision_stop) > 0.0 + and float(self.args.eval_value_recall_stop) > 0.0 + and step >= int(self.args.min_steps_before_stop) + and float(metrics["value_precision"]) >= float(self.args.eval_value_precision_stop) + and float(metrics["value_recall"]) >= float(self.args.eval_value_recall_stop) + ): + print( + f"[baseline grpo custom eval step {step}] stopping early: " + f"value_precision={metrics['value_precision']:.3f} >= {float(self.args.eval_value_precision_stop):.3f} " + f"and value_recall={metrics['value_recall']:.3f} >= {float(self.args.eval_value_recall_stop):.3f}", + flush=True, + ) + stop_tensor[0] = 1 + if ( + int(stop_tensor.item()) == 0 + and float(self.args.eval_solve_rate_stop) > 0.0 + and step >= int(self.args.min_steps_before_stop) + and float(metrics["solve_rate"]) >= float(self.args.eval_solve_rate_stop) + ): + print( + f"[baseline grpo custom eval step {step}] stopping early: " + f"solve_rate={metrics['solve_rate']:.3f} >= {float(self.args.eval_solve_rate_stop):.3f}", + flush=True, + ) + stop_tensor[0] = 1 + + if use_dist: + torch.distributed.broadcast(stop_tensor, src=0) + + if int(stop_tensor.item()) != 0: + control.should_training_stop = True + return control + + +class FinalCheckpointCallback(TrainerCallback): + def __init__(self, output_dir: str, tokenizer: Any, is_main_process: bool): + self.output_dir = output_dir + self.tokenizer = tokenizer + self.is_main_process = is_main_process + + def _save(self, model: Any) -> None: + if self.is_main_process: + save_model_artifacts(unwrap_training_model(model), self.tokenizer, ensure_final_checkpoint_dir(self.output_dir)) + + def on_save(self, args, state, control, **kwargs): + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + def on_train_end(self, args, state, control, **kwargs): + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + +class WallClockStopCallback(TrainerCallback): + def __init__(self, max_wall_clock_seconds: int): + self.max_wall_clock_seconds = int(max_wall_clock_seconds) + self.start_time = time.time() + + def on_step_end(self, args, state, control, **kwargs): + if self.max_wall_clock_seconds > 0 and (time.time() - self.start_time) >= float(self.max_wall_clock_seconds): + control.should_training_stop = True + return control + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument("--train_jsonl", type=str, required=True) + p.add_argument("--eval_jsonl", type=str, default="") + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, required=True) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=2) + p.add_argument("--gradient_accumulation_steps", type=int, default=4) + p.add_argument("--num_train_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--logging_steps", type=int, default=5) + p.add_argument("--save_steps", type=int, default=25) + p.add_argument("--eval_steps", type=int, default=25) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--num_generations", type=int, default=2) + p.add_argument("--max_prompt_length", type=int, default=1024) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--beta", type=float, default=0.0) + p.add_argument("--lora_r", type=int, default=8) + p.add_argument("--lora_alpha", type=int, default=16) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-grpo") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--wandb_group", type=str, default="") + p.add_argument("--wandb_run_id", type=str, default="") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument( + "--penalty_missing", + type=float, + default=0.0, + help="Per-missing-value penalty: reward -= penalty_missing * |target_set \\ predicted_set|. " + "Defaults to 0 (legacy); set ~0.75 at stage>=2 to push recall up.", + ) + p.add_argument( + "--exact_match_bonus", + type=float, + default=0.0, + help="Bonus added only when set(predicted_values) == set(target_values) and prediction is non-empty. " + "Defaults to 0; set ~2.0 to strictly dominate partial supersets.", + ) + p.add_argument( + "--cardinality_mismatch_penalty", + type=float, + default=0.0, + help="Penalty when len(predicted_values) < len(target_values) for multi-value targets " + "(stage-agnostic). Defaults to 0; set ~1.0 at stage>=2 to deter under-prediction.", + ) + p.add_argument("--eval_value_precision_stop", type=float, default=0.0) + p.add_argument("--eval_value_recall_stop", type=float, default=0.0) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--resume_from_checkpoint", type=str, default="") + ns = p.parse_args() + return Args(**vars(ns)) + + +def main() -> None: + args = parse_args() + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_main_process = rank == 0 + + if preset_visible_devices: + if is_main_process: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + if is_main_process: + print(f"Using Hugging Face cache dir: {cache_dir}", flush=True) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + "group": args.wandb_group or None, + "id": args.wandb_run_id or None, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/records_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else args.train_jsonl + eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows))) + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if is_main_process: + print(f"Using device: {device}", flush=True) + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter(base, args.init_adapter_dir) + if is_main_process: + print(f"Loaded init adapter: {args.init_adapter_dir}", flush=True) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + if world_size <= 1: + model.to(device) + model.train() + + def on_prep_progress(rows_done: int, total_rows: int, records_built: int) -> None: + if is_main_process: + print( + f"[dataset build][grpo stage {args.stage_i}] rows={rows_done}/{total_rows} records={records_built}", + flush=True, + ) + if wb_run is not None: + wandb.log({"prep/rows_done": float(rows_done), "prep/records_built": float(records_built)}) + + train_records = load_or_build_grpo_records( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_grpo_cache_path(args))), + "prep/records_final": float(len(train_records)), + } + ) + + train_dataset = Dataset.from_list(train_records) + reward_func = make_reward_func(args) + + if int(args.limit_train_rows) > 0 and int(args.max_steps) <= 0: + args.max_steps = 1 + + ensure_trl_fsdp_compat() + from trl import GRPOConfig, GRPOTrainer + + config_kwargs = { + "output_dir": args.output_dir, + "per_device_train_batch_size": args.per_device_train_batch_size, + "gradient_accumulation_steps": args.gradient_accumulation_steps, + "num_train_epochs": args.num_train_epochs, + "learning_rate": args.learning_rate, + "logging_steps": args.logging_steps, + "save_steps": args.save_steps, + "eval_strategy": "steps", + "eval_steps": args.eval_steps, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + "num_generations": args.num_generations, + "beta": args.beta, + "bf16": (pick_dtype() == torch.bfloat16), + "report_to": ["wandb"] if args.use_wandb and is_main_process else [], + "remove_unused_columns": False, + } + if int(args.max_steps) > 0: + config_kwargs["max_steps"] = int(args.max_steps) + grpo_config_params = inspect.signature(GRPOConfig.__init__).parameters + unsupported_keys = sorted(key for key in config_kwargs if key not in grpo_config_params) + for key in unsupported_keys: + config_kwargs.pop(key, None) + if is_main_process and unsupported_keys: + print(f"Skipping unsupported GRPOConfig args: {', '.join(unsupported_keys)}", flush=True) + config = GRPOConfig(**config_kwargs) + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[reward_func], + args=config, + train_dataset=train_dataset, + eval_dataset=train_dataset.select(range(min(len(train_dataset), max(1, int(args.eval_rows))))), + ) + trainer.add_callback(CustomEvalCallback(args, eval_rows, tokenizer, device, wb_run, is_main_process)) + trainer.add_callback(FinalCheckpointCallback(args.output_dir, tokenizer, is_main_process)) + trainer.add_callback(WallClockStopCallback(args.max_wall_clock_seconds)) + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint or None) + + final_model = unwrap_training_model(trainer.model) + if is_main_process: + eval_metrics = run_eval(args=args, rows=eval_rows, model=final_model, tokenizer=tokenizer, device=device) + print( + f"[baseline grpo final eval] parse={eval_metrics['parse_rate']:.3f} " + f"canonical={eval_metrics['strict_canonical_rate']:.3f} " + f"exact={eval_metrics['exact_set_match_rate']:.3f} precision={eval_metrics['value_precision']:.3f} " + f"recall={eval_metrics['value_recall']:.3f} solve={eval_metrics['solve_rate']:.3f}", + flush=True, + ) + if wb_run is not None: + wandb.log({f"custom_eval/{k}": float(v) for k, v in eval_metrics.items()}) + trainer.save_model(args.output_dir) + save_model_artifacts(final_model, tokenizer, ensure_final_checkpoint_dir(args.output_dir)) + if wb_run is not None: + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/multi_output_cell_policy/prompt_builder.py b/multi_output_cell_policy/prompt_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..475b4609c55b9fa74cc81ebb5a1f58026f910e87 --- /dev/null +++ b/multi_output_cell_policy/prompt_builder.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import os +import sys +from typing import Any, Optional + +import numpy as np + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from format_utils_icon import grid_to_text + + +def make_multi_output_system_prompt(*, stage_i: int, total_empties_hint: int = 10) -> str: + i = max(1, int(stage_i)) + return ( + "You are a Sudoku value policy.\n" + f"This setup uses puzzles with about {int(total_empties_hint)} empty cells.\n" + "You will be given one target empty cell.\n" + 'Return ONLY one JSON object of the form {"values":[...]}.\n' + 'The JSON object must contain exactly one key named "values".\n' + 'The "values" field must be a JSON array of unique integers in [1,9].\n' + "You may return as many candidate values as you want, including one, several, or many values.\n" + "Choose the number of returned values yourself based on which values seem i-consistent.\n" + "The order of the values does not matter.\n" + "Do not output any explanation, markdown, punctuation outside JSON, or extra text.\n" + f"Current stage objective: i={i} consistency.\n" + ) + + +def build_multi_output_cell_prompt( + grid_9x9: np.ndarray, + *, + target_cell: tuple[int, int], + stage_i: int, + tokenizer: Any, + turn_idx: int, + total_turns: int, + prev_output_flag: Optional[str] = None, + total_empties_hint: int = 10, +) -> str: + g = np.asarray(grid_9x9, dtype=int).reshape(9, 9) + empties = int(np.sum(g == 0)) + rr, cc = int(target_cell[0]), int(target_cell[1]) + system_msg = make_multi_output_system_prompt( + stage_i=stage_i, total_empties_hint=total_empties_hint + ).strip() + empty_locs = [(int(r) + 1, int(c) + 1) for r, c in np.argwhere(g == 0).tolist()] + empty_locs_text = ", ".join(f"({r},{c})" for r, c in empty_locs) + user_msg = ( + "Sudoku grid (0 means empty):\n" + + grid_to_text(g) + + "\n" + + f"Empty cells in row-major order ({empties} total): {empty_locs_text}\n\n" + + f"Target cell to fill now: ({rr + 1},{cc + 1}).\n" + + f"Turn: {int(turn_idx)}/{int(total_turns)}.\n" + + 'Return only JSON with candidate values for this target cell: {"values":[...]}' + ) + if prev_output_flag is not None: + user_msg += f"\nPrevious output_flag (context only): {prev_output_flag}" + + chat_template = getattr(tokenizer, "chat_template", None) + if chat_template: + messages = [ + {"role": "system", "content": system_msg}, + {"role": "user", "content": user_msg}, + ] + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + return system_msg + "\n\n" + user_msg + "\n" diff --git a/multi_output_cell_policy/rewards.py b/multi_output_cell_policy/rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b74f63511785a595eda438355efbb220c65dba --- /dev/null +++ b/multi_output_cell_policy/rewards.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import os +import sys +from typing import Dict, List + +import numpy as np + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from multi_output_cell_policy.shared_multi_output_policy import ( + compute_set_precision_recall, + parse_values_json, + stage_i_consistent_values, +) + + +def triangular_number(n: int) -> float: + nn = max(0, int(n)) + return float(nn * (nn + 1) // 2) + + +def score_prediction_text( + *, + text: str, + grid: np.ndarray, + solved: np.ndarray, + target_cell: tuple[int, int], + stage_i: int, + reward_good_value: float, + penalty_bad_value: float, + penalty_malformed: float, + penalty_empty: float, + penalty_singleton: float, + penalty_missing: float = 0.0, + exact_match_bonus: float = 0.0, + cardinality_mismatch_penalty: float = 0.0, +) -> Dict[str, float | List[int] | str]: + parsed = parse_values_json(text) + target_values = stage_i_consistent_values(grid, target_cell=target_cell, stage_i=stage_i) + solved_value = int(np.asarray(solved, dtype=int).reshape(9, 9)[int(target_cell[0]), int(target_cell[1])]) + # Legacy gating preserved: at stage>=2 the original singleton penalty is off by default. + # Under-prediction pressure at stage>=2 is supplied by the new cardinality_mismatch_penalty + # below (if > 0). At stage 1, both penalties may stack. + singleton_penalty = 0.0 if int(stage_i) >= 2 else float(penalty_singleton) + + if not parsed.parse_ok: + return { + "reward": -float(penalty_malformed), + "parse_ok": 0.0, + "strict_canonical": 0.0, + "num_predicted_values": 0.0, + "num_i_consistent_values": 0.0, + "num_non_i_consistent_values": 0.0, + "num_missing_values": float(len(target_values)), + "includes_ground_truth": 0.0, + "value_precision": 0.0, + "value_recall": 0.0, + "exact_set_match": 0.0, + "predicted_values": [], + "target_values": [int(v) for v in target_values], + "format_error": "parse_failed", + } + + predicted_values = [int(v) for v in parsed.values] + target_set = set(int(v) for v in target_values) + num_good = sum(1 for v in predicted_values if v in target_set) + num_bad = sum(1 for v in predicted_values if v not in target_set) + num_missing = max(0, len(target_set) - num_good) + is_exact = bool(predicted_values) and (set(predicted_values) == target_set) + + # Base reward: encourage larger all-good sets while making extra wrong values expensive. + reward = triangular_number(num_good) * float(reward_good_value) - float(num_bad) * float( + penalty_bad_value + ) + # Directly penalize missing target values so recall is part of the optimization signal. + if num_missing > 0: + reward -= float(num_missing) * float(penalty_missing) + # Bonus only when the predicted set exactly matches the target (and is non-empty), + # so the optimum strictly dominates partial supersets. + if is_exact: + reward += float(exact_match_bonus) + if not predicted_values: + reward -= float(penalty_empty) + if len(predicted_values) == 1 and len(target_values) > 1: + reward -= singleton_penalty + # Stage-agnostic cardinality-mismatch pressure for multi-value targets. + # Fires whenever the prediction has strictly fewer values than the target set + # (the dominant failure mode for stage>=2 multi-value cells). + if len(predicted_values) < len(target_values) and len(target_values) > 1: + reward -= float(cardinality_mismatch_penalty) + + precision, recall = compute_set_precision_recall(predicted_values, target_values) + return { + "reward": float(reward), + "parse_ok": 1.0, + "strict_canonical": 1.0 if parsed.strict_canonical else 0.0, + "num_predicted_values": float(len(predicted_values)), + "num_i_consistent_values": float(num_good), + "num_non_i_consistent_values": float(num_bad), + "num_missing_values": float(num_missing), + "includes_ground_truth": 1.0 if solved_value in predicted_values else 0.0, + "value_precision": float(precision), + "value_recall": float(recall), + "exact_set_match": 1.0 if is_exact else 0.0, + "predicted_values": predicted_values, + "target_values": [int(v) for v in target_values], + "format_error": "", + } diff --git a/multi_output_cell_policy/run_baseline_multi_output_pipeline_10empty_5stage.py b/multi_output_cell_policy/run_baseline_multi_output_pipeline_10empty_5stage.py new file mode 100644 index 0000000000000000000000000000000000000000..849cc55ae956b2cd8e6390d1706b23d9daf0b704 --- /dev/null +++ b/multi_output_cell_policy/run_baseline_multi_output_pipeline_10empty_5stage.py @@ -0,0 +1,22 @@ +"""Recovered wrapper around the surviving compiled pipeline module.""" + +from __future__ import annotations + +import os +import sys + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from _sourceless_loader import load_pyc_into_globals + + +_MODULE = load_pyc_into_globals( + __file__, "run_baseline_multi_output_pipeline_10empty_5stage.cpython-311.pyc", globals() +) + + +if __name__ == "__main__" and hasattr(_MODULE, "main"): + _MODULE.main() diff --git a/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py b/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py new file mode 100644 index 0000000000000000000000000000000000000000..954a5e67135aab02efa780ffed186f2034506c55 --- /dev/null +++ b/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py @@ -0,0 +1,624 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("multi_output_cell_policy")) +DEFAULT_CACHE_DIR = Path("/home/ubuntu/curriculum-CoT/.hf_cache") +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" + +SFT_SCRIPT = CURRENT_DIR / "sft_multi_output_train.py" +GRPO_SCRIPT = CURRENT_DIR / "grpo_multi_output_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=3) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default="baseline_multi_output_pipeline_resume") + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=2e-4) + p.add_argument("--sft_learning_rate_later", type=float, default=5e-5) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--sft_eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--sft_min_steps_before_stop", type=int, default=0) + p.add_argument("--grpo_num_train_epochs", type=float, default=0.5) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=2) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--grpo_reward_good_value", type=float, default=1.0) + p.add_argument("--grpo_penalty_bad_value", type=float, default=1.75) + p.add_argument("--grpo_penalty_malformed", type=float, default=4.0) + p.add_argument("--grpo_penalty_empty", type=float, default=0.5) + p.add_argument("--grpo_penalty_singleton", type=float, default=1.5) + p.add_argument("--grpo_eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--grpo_min_steps_before_stop", type=int, default=0) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + if root_adapter.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + if not adapter.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + checkpoint_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + best: Optional[Artifact] = None + for stage_dir in checkpoint_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path(normalize_to_final_checkpoint_root(args.output_root, "multi_output_cell_policy")).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_resume" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku_t3_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: Optional[str], + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--eval_solve_rate_stop", + str(float(args.sft_eval_solve_rate_stop)), + "--min_steps_before_stop", + str(int(args.sft_min_steps_before_stop)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + "sudoku-multi-output-sft", + "--wandb_run_name", + f"baseline_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + if init_adapter_dir: + command.extend(["--init_adapter_dir", str(init_adapter_dir)]) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--reward_good_value", + str(float(args.grpo_reward_good_value)), + "--penalty_bad_value", + str(float(args.grpo_penalty_bad_value)), + "--penalty_malformed", + str(float(args.grpo_penalty_malformed)), + "--penalty_empty", + str(float(args.grpo_penalty_empty)), + "--penalty_singleton", + str(float(args.grpo_penalty_singleton)), + "--eval_solve_rate_stop", + str(float(args.grpo_eval_solve_rate_stop)), + "--min_steps_before_stop", + str(int(args.grpo_min_steps_before_stop)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + "sudoku-multi-output-grpo", + "--wandb_run_name", + f"baseline_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def main() -> None: + args = parse_args() + checkpoint_root = Path(normalize_to_final_checkpoint_root(args.checkpoint_root, "multi_output_cell_policy")).resolve() + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + checkpoint_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + checkpoint_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + elif stage == 1: + init_adapter_dir = str(args.stage1_init_adapter_dir).strip() or None + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch baseline stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/multi_output_cell_policy/sft_multi_output_train.py b/multi_output_cell_policy/sft_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..aea7fd9e2eab38c983089dc6de2a81ea731af693 --- /dev/null +++ b/multi_output_cell_policy/sft_multi_output_train.py @@ -0,0 +1,706 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import os +import random +import sys +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +from peft import LoraConfig, PeftModel, get_peft_model +from torch.optim import AdamW +from torch.utils.data import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final +from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt +from multi_output_cell_policy.rewards import score_prediction_text +from multi_output_cell_policy.shared_multi_output_policy import ( + batched_completion_ce_loss, + build_supervised_completion, + completion_ce_loss, + make_solved_grid_from_row, + stage_i_consistent_values, +) + + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + eval_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_epochs: float + learning_rate: float + weight_decay: float + max_grad_norm: float + enable_gradient_checkpointing: bool + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + max_completion_length: int + lora_r: int + lora_alpha: int + lora_dropout: float + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + debug_print_limit: int + limit_train_rows: int + eval_exact_set_match_stop: float + eval_value_precision_stop: float + eval_value_recall_stop: float + eval_solve_rate_stop: float + min_steps_before_stop: int + max_wall_clock_seconds: int + max_steps: int + multi_value_oversample_factor: int + train_target_size_min: int + train_target_size_max: int + eval_target_size_min: int + eval_target_size_max: int + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def target_size_allowed(target_size: int, min_size: int, max_size: int) -> bool: + if int(min_size) > 0 and int(target_size) < int(min_size): + return False + if int(max_size) > 0 and int(target_size) > int(max_size): + return False + return True + + +def build_training_examples( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + examples: List[Dict[str, Any]] = [] + eos_text = getattr(tokenizer, "eos_token", None) or "" + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) + if not target_size_allowed( + len(target_values), + getattr(tokenizer, "_train_target_size_min", 0), + getattr(tokenizer, "_train_target_size_max", 0), + ): + continue + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + target_text = build_supervised_completion(ex, stage_i=stage_i) + if eos_text: + target_text = target_text + eos_text + repeat_count = max(1, int(getattr(tokenizer, "_multi_value_oversample_factor", 1))) if len(target_values) > 1 else 1 + for _ in range(repeat_count): + examples.append( + { + "prompt_text": prompt, + "completion_text": target_text, + "target_values": list(target_values), + "grid": ex.grid, + "solved": solved, + "target_cell": ex.target_cell, + } + ) + if progress_callback is not None and ( + row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx, len(rows), len(examples)) + return examples + + +def _prepared_data_dir(args: Args) -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "multi_output_cell_policy") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_sft_cache_path(args: Args) -> str: + payload = json.dumps( + { + "completion_format_version": 2, + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + "multi_value_oversample_factor": int(args.multi_value_oversample_factor), + "train_target_size_min": int(args.train_target_size_min), + "train_target_size_max": int(args.train_target_size_max), + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(payload).hexdigest()[:20] + return os.path.join(_prepared_data_dir(args), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _to_jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_jsonable(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_to_jsonable(v) for v in value] + if hasattr(value, "tolist"): + return _to_jsonable(value.tolist()) + return value + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(_to_jsonable(row), separators=(",", ":")) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: + start = time.time() + while not os.path.exists(path): + if time.time() - start > timeout_s: + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_sft_examples( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_sft_cache_path(args) + if os.path.exists(cache_path): + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + examples = build_training_examples( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + _write_jsonl(cache_path, examples) + return examples + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +@torch.no_grad() +def run_eval(args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device): + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + solve_rows = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + row_has_eval_cell = False + row_debug_lines: List[str] = [] + for ex in build_cell_examples_from_row(row): + target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i) + if not target_size_allowed(len(target_values), int(args.eval_target_size_min), int(args.eval_target_size_max)): + continue + row_has_eval_cell = True + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + enc = {k: v.to(device) for k, v in enc.items()} + out = model.generate( + **enc, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=1.0, + penalty_bad_value=1.75, + penalty_malformed=4.0, + penalty_empty=0.5, + penalty_singleton=1.5, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + row_debug_lines.append( + f"[baseline sft eval debug] true_values={info['target_values']} " + f"predicted_values={info['predicted_values']} output={pred_text!r}" + ) + if row_has_eval_cell: + if printed < int(args.debug_print_limit) and row_debug_lines: + print("[baseline sft eval debug] puzzle_outputs_begin", flush=True) + for line in row_debug_lines: + print(line, flush=True) + print("[baseline sft eval debug] puzzle_outputs_end", flush=True) + printed += 1 + solve_ok += int(row_all_exact) + solve_rows += 1 + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, solve_rows)), + } + print( + f"[baseline sft eval] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}", + flush=True, + ) + model.train() + return out + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument("--train_jsonl", type=str, required=True) + p.add_argument("--eval_jsonl", type=str, default="") + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, default="") + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=1) + p.add_argument("--gradient_accumulation_steps", type=int, default=8) + p.add_argument("--num_epochs", type=float, default=1.0) + p.add_argument("--learning_rate", type=float, default=2e-4) + p.add_argument("--weight_decay", type=float, default=0.0) + p.add_argument( + "--max_grad_norm", + type=float, + default=1.0, + help="Clip global grad norm before each optimizer step (0 disables).", + ) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--logging_steps", type=int, default=10) + p.add_argument("--save_steps", type=int, default=100) + p.add_argument("--eval_steps", type=int, default=100) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--lora_r", type=int, default=16) + p.add_argument("--lora_alpha", type=int, default=32) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-sft") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) + p.add_argument("--eval_value_precision_stop", type=float, default=0.0) + p.add_argument("--eval_value_recall_stop", type=float, default=0.0) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--multi_value_oversample_factor", type=int, default=1) + p.add_argument("--train_target_size_min", type=int, default=0) + p.add_argument("--train_target_size_max", type=int, default=0) + p.add_argument("--eval_target_size_min", type=int, default=0) + p.add_argument("--eval_target_size_max", type=int, default=0) + return Args(**vars(p.parse_args())) + + +def save_checkpoint(model: torch.nn.Module, tokenizer: Any, output_dir: str, step: int) -> None: + save_checkpoint_and_update_final(model, tokenizer, output_dir, f"checkpoint-step-{step:05d}") + + +def main() -> None: + args = parse_args() + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_distributed = world_size > 1 + is_main_process = rank == 0 + + if preset_visible_devices: + if is_main_process: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + if is_distributed: + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", timeout=timedelta(hours=2)) + + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else args.train_jsonl + eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows))) + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + tokenizer._multi_value_oversample_factor = max(1, int(args.multi_value_oversample_factor)) + tokenizer._train_target_size_min = max(0, int(args.train_target_size_min)) + tokenizer._train_target_size_max = max(0, int(args.train_target_size_max)) + if torch.cuda.is_available(): + device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") + else: + device = torch.device("cpu") + + model = AutoModelForCausalLM.from_pretrained( + args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True + ) + if str(args.init_adapter_dir).strip(): + model = PeftModel.from_pretrained(model, args.init_adapter_dir, is_trainable=True) + else: + lora = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + ) + model = get_peft_model(model, lora) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + model.to(device) + model.train() + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + + def on_prep_progress(rows_done: int, total_rows: int, examples_built: int) -> None: + if is_main_process: + print( + f"[dataset build][sft stage {args.stage_i}] rows={rows_done}/{total_rows} examples={examples_built}", + flush=True, + ) + if wb_run is not None: + wandb.log({"prep/rows_done": float(rows_done), "prep/examples_built": float(examples_built)}) + + train_examples = load_or_build_sft_examples( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_sft_cache_path(args))), + "prep/examples_final": float(len(train_examples)), + } + ) + + optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) + denom = max(1, int(args.gradient_accumulation_steps)) * max(1, int(args.per_device_train_batch_size)) + total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / denom)) + if int(args.max_steps) > 0: + total_steps = min(total_steps, int(args.max_steps)) + step = 0 + start_time = time.time() + + def average_scalar(value: float) -> float: + if not is_distributed or not dist.is_initialized(): + return float(value) + tensor = torch.tensor(float(value), device=device, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float((tensor / float(world_size)).item()) + + def all_reduce_gradients() -> None: + if not is_distributed or not dist.is_initialized(): + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(float(world_size)) + + def sync_stop(local_stop: bool) -> bool: + if not is_distributed or not dist.is_initialized(): + return bool(local_stop) + tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item()) > 0) + + for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): + if is_distributed: + sampler = DistributedSampler( + train_examples, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=args.seed, + drop_last=False, + ) + sampler.set_epoch(epoch_idx) + order = list(iter(sampler)) + else: + generator = torch.Generator() + generator.manual_seed(args.seed + epoch_idx) + order = torch.randperm(len(train_examples), generator=generator).tolist() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + accum_ce_sum = 0.0 + microbatch_size = max(1, int(args.per_device_train_batch_size)) + for batch_start in range(0, len(order), microbatch_size): + batch_indices = order[batch_start : batch_start + microbatch_size] + batch_examples = [train_examples[ex_idx] for ex_idx in batch_indices] + ce_full = batched_completion_ce_loss( + model, + tokenizer, + [str(ex["prompt_text"]) for ex in batch_examples], + [str(ex["completion_text"]) for ex in batch_examples], + device, + ) + loss = ce_full / max(1, int(args.gradient_accumulation_steps)) + loss.backward() + accum_ce_sum += float(ce_full.detach().item()) + accum_count += 1 + if accum_count >= int(args.gradient_accumulation_steps): + all_reduce_gradients() + if float(args.max_grad_norm) > 0.0: + torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.max_grad_norm)) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + step += 1 + mean_ce = accum_ce_sum / max(1, int(args.gradient_accumulation_steps)) + accum_ce_sum = 0.0 + if step % int(args.logging_steps) == 0: + loss_value = average_scalar(mean_ce) + if is_main_process: + print(f"[baseline sft train step {step:05d}] loss={loss_value:.4f}", flush=True) + if wb_run is not None: + wandb.log({"train/loss": loss_value, "step": step}) + if step % int(args.eval_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + should_stop_eval = False + if is_main_process: + ev = run_eval(args, eval_rows, model, tokenizer, device) + if wb_run is not None: + wandb.log({f"eval/{k}": float(v) for k, v in ev.items()} | {"step": step}) + if ( + args.eval_exact_set_match_stop > 0.0 + and float(ev["exact_set_match_rate"]) >= args.eval_exact_set_match_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + if ( + not should_stop_eval + and step >= int(args.min_steps_before_stop) + and args.eval_value_precision_stop > 0.0 + and args.eval_value_recall_stop > 0.0 + and float(ev["value_precision"]) >= args.eval_value_precision_stop + and float(ev["value_recall"]) >= args.eval_value_recall_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + if ( + not should_stop_eval + and args.eval_solve_rate_stop > 0.0 + and step >= int(args.min_steps_before_stop) + and float(ev["solve_rate"]) >= args.eval_solve_rate_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + should_stop_eval = sync_stop(should_stop_eval) + if is_distributed and dist.is_initialized(): + dist.barrier() + if should_stop_eval: + if is_main_process and wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.destroy_process_group() + return + if step % int(args.save_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and ( + time.time() - start_time >= float(args.max_wall_clock_seconds) + ) + should_stop = sync_stop(reached_limit or exceeded_wall) + if should_stop: + break + if sync_stop(step >= total_steps): + break + + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/multi_output_cell_policy/shared_multi_output_policy.py b/multi_output_cell_policy/shared_multi_output_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..caef1de5158d5f992b32eceef1d221cb37cdf526 --- /dev/null +++ b/multi_output_cell_policy/shared_multi_output_policy.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import itertools +import json +import math +import os +import random +import re +import sys +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Iterable, List, Optional, Sequence + +import numpy as np +import torch + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from aligned_cell_policy.shared_cell_policy import CellExample +from formatting_icon import is_consistent_pair + + +def all_remaining_empties_have_legal_value(grid: np.ndarray) -> bool: + g = np.asarray(grid, dtype=int).reshape(9, 9) + for r in range(9): + for c in range(9): + if int(g[r, c]) != 0: + continue + cell = r * 9 + c + has_legal = any(is_consistent_pair(g, cell=cell, value=v, t=3, n=9) for v in range(1, 10)) + if not has_legal: + return False + return True + + +@dataclass(frozen=True) +class ParsedValues: + values: tuple[int, ...] + parse_ok: bool + strict_canonical: bool + + +def all_digit_values() -> List[int]: + return list(range(1, 10)) + + +def make_solved_grid_from_row(row: Dict[str, Any]) -> np.ndarray: + grid = parse_grid_from_tuple_prompt(str(row["prompt"])) + solved = np.asarray(grid, dtype=int).copy() + triples = row.get("metadata", {}).get("target_triples_1based", []) + for rr, cc, value in triples: + solved[int(rr) - 1, int(cc) - 1] = int(value) + return solved + + +def _grid_state_key(grid: np.ndarray) -> tuple[int, ...]: + return tuple(int(v) for v in np.asarray(grid, dtype=int).reshape(-1)) + + +def _legal_values_for_cell(state: tuple[int, ...], cell: int) -> tuple[int, ...]: + rr, cc = divmod(int(cell), 9) + if int(state[cell]) != 0: + return tuple() + g = np.asarray(state, dtype=int).reshape(9, 9) + return tuple( + int(value) + for value in all_digit_values() + if is_consistent_pair(g, cell=int(cell), value=int(value), t=3, n=9) + ) + + +@lru_cache(maxsize=200000) +def _stage_i_consistent_values_for_grid(state: tuple[int, ...], stage_i: int) -> tuple[tuple[int, ...], ...]: + stage_i = max(1, int(stage_i)) + out: List[tuple[int, ...]] = [tuple() for _ in range(81)] + + for cell in range(81): + legal_values = _legal_values_for_cell(state, cell) + if not legal_values: + continue + if stage_i <= 1: + out[cell] = legal_values + continue + + consistent_values: List[int] = [] + for value in legal_values: + child = list(state) + child[cell] = int(value) + child_state = tuple(child) + child_stage_values = _stage_i_consistent_values_for_grid(child_state, stage_i - 1) + if all(int(child_state[idx]) != 0 or len(child_stage_values[idx]) > 0 for idx in range(81)): + consistent_values.append(int(value)) + out[cell] = tuple(consistent_values) + + return tuple(out) + + +def stage_i_consistent_values( + grid: np.ndarray, + *, + target_cell: tuple[int, int], + stage_i: int, +) -> List[int]: + g = np.asarray(grid, dtype=int).reshape(9, 9) + rr, cc = int(target_cell[0]), int(target_cell[1]) + if int(g[rr, cc]) != 0: + return [] + cell = rr * 9 + cc + stage_values = _stage_i_consistent_values_for_grid(_grid_state_key(g), int(stage_i)) + return [int(value) for value in stage_values[cell]] + + +def canonicalize_values(values: Iterable[int]) -> List[int]: + out: List[int] = [] + seen = set() + for value in values: + if isinstance(value, bool): + raise ValueError("Boolean values are not allowed.") + vv = int(value) + if vv < 1 or vv > 9: + raise ValueError(f"Value must be in [1,9], got {vv}.") + if vv not in seen: + seen.add(vv) + out.append(vv) + return out + + +def values_json_text(values: Iterable[int]) -> str: + return json.dumps({"values": canonicalize_values(values)}, separators=(",", ":")) + + +def parse_values_json(text: str) -> ParsedValues: + raw = str(text).strip() + if not raw: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + try: + obj = json.loads(raw) + except Exception: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if not isinstance(obj, dict): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if set(obj.keys()) != {"values"}: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + values_obj = obj.get("values") + if not isinstance(values_obj, list): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + try: + values = canonicalize_values(values_obj) + except Exception: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if len(values) != len(values_obj): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + canonical = values_json_text(values) + return ParsedValues(values=tuple(values), parse_ok=True, strict_canonical=(canonical == raw)) + + +def compute_set_precision_recall(pred_values: Sequence[int], target_values: Sequence[int]) -> tuple[float, float]: + pred = set(int(v) for v in pred_values) + target = set(int(v) for v in target_values) + precision = 0.0 if not pred else float(len(pred & target) / max(1, len(pred))) + recall = 1.0 if not target else float(len(pred & target) / max(1, len(target))) + return precision, recall + + +def completion_ce_loss( + model: torch.nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( + device + ) + labels = all_ids.clone() + labels[:, : int(prompt_ids.shape[1])] = -100 + out = model(input_ids=all_ids, labels=labels) + return out.loss + + +def batched_completion_ce_loss( + model: torch.nn.Module, + tokenizer: Any, + prompt_texts: Sequence[str], + completion_texts: Sequence[str], + device: torch.device, +) -> torch.Tensor: + if len(prompt_texts) != len(completion_texts): + raise ValueError("prompt_texts and completion_texts must have the same length") + if not prompt_texts: + raise ValueError("batched_completion_ce_loss requires at least one example") + + full_texts = [str(p) + str(c) for p, c in zip(prompt_texts, completion_texts, strict=True)] + batch = tokenizer(full_texts, return_tensors="pt", add_special_tokens=False, padding=True) + prompt_batch = tokenizer(list(prompt_texts), return_tensors="pt", add_special_tokens=False, padding=True) + + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + prompt_attention = prompt_batch["attention_mask"] + prompt_lengths = prompt_attention.sum(dim=1).tolist() + + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + for row_idx, prompt_len in enumerate(prompt_lengths): + labels[row_idx, : int(prompt_len)] = -100 + + out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + return out.loss + + +def completion_logprob( + model: torch.nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( + device + ) + labels = all_ids.clone() + labels[:, : int(prompt_ids.shape[1])] = -100 + out = model(input_ids=all_ids, labels=labels) + num_completion_tokens = int((labels != -100).sum().item()) + return -out.loss * max(1, num_completion_tokens) + + +def enumerate_value_permutations( + values: Sequence[int], + *, + max_permutations: int, + rng: Optional[random.Random] = None, +) -> List[tuple[int, ...]]: + uniq = tuple(canonicalize_values(values)) + if len(uniq) <= 1: + return [uniq] + total = math.factorial(len(uniq)) + if total <= max(1, int(max_permutations)): + return [tuple(p) for p in itertools.permutations(uniq)] + + rr = rng or random.Random(0) + perms = set() + base = list(uniq) + max_needed = max(1, int(max_permutations)) + while len(perms) < max_needed: + shuffled = list(base) + rr.shuffle(shuffled) + perms.add(tuple(shuffled)) + return list(perms) + + +def permutation_invariant_json_ce_loss( + model: torch.nn.Module, + tokenizer: Any, + prompt_text: str, + values: Sequence[int], + device: torch.device, + *, + max_permutations: int, + rng: Optional[random.Random] = None, +) -> torch.Tensor: + permutations = enumerate_value_permutations(values, max_permutations=max_permutations, rng=rng) + logps = [ + completion_logprob(model, tokenizer, prompt_text, values_json_text(perm), device) for perm in permutations + ] + stacked = torch.stack(logps, dim=0) + return -(torch.logsumexp(stacked, dim=0) - math.log(float(len(permutations)))) + + +def build_supervised_completion( + ex: CellExample, + *, + stage_i: int, + rng: Optional[random.Random] = None, + randomize_order: bool = False, +) -> str: + values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) + if randomize_order and len(values) > 1: + shuffled = list(values) + (rng or random).shuffle(shuffled) + values = shuffled + return values_json_text(values) + + +def summarize_values(values: Iterable[int]) -> str: + return "[" + ", ".join(str(int(v)) for v in values) + "]" + + +_TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)") + + +def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray: + triples = _TUPLE_PROMPT_RE.findall(str(prompt_text)) + if len(triples) < 81: + raise ValueError("Could not recover 81 (row,col,value) tuples from prompt.") + grid = np.zeros((9, 9), dtype=int) + for rr, cc, vv in triples[:81]: + r = int(rr) - 1 + c = int(cc) - 1 + grid[r, c] = int(vv) + return grid diff --git a/run_latent_residual_projector_pipeline.py b/run_latent_residual_projector_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5f7b9782f87c57f301666884fc37eb8c52cc1950 --- /dev/null +++ b/run_latent_residual_projector_pipeline.py @@ -0,0 +1,604 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("latent_multi_output_cell_policy")) +DEFAULT_CACHE_DIR = Path("/home/ubuntu/curriculum-CoT/.hf_cache") +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" + +SFT_SCRIPT = CURRENT_DIR / "latent_multi_output_cell_policy" / "residual_projector_warmstart_sft_latent_multi_output_train.py" +GRPO_SCRIPT = CURRENT_DIR / "latent_multi_output_cell_policy" / "grpo_residual_projector_latent_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=3) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default="latent_residual_projector_pipeline") + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=1e-6) + p.add_argument("--sft_learning_rate_later", type=float, default=1e-6) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--grpo_num_train_epochs", type=float, default=1.0) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=4) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + if root_adapter.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + latent_state = child / "latent_cot_state.pt" + if not adapter.exists() or not latent_state.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + search_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + if not search_root.exists(): + return None + best: Optional[Artifact] = None + for stage_dir in search_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path(normalize_to_final_checkpoint_root(args.output_root, "latent_multi_output_cell_policy")).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"latent_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku_t3_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty_residual_projector" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend( + [ + "--use_wandb", + "--wandb_entity", + args.wandb_entity, + "--wandb_project", + "sudoku-latent-multi-output-sft-residual-projector", + "--wandb_run_name", + f"latent_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend( + [ + "--use_wandb", + "--wandb_entity", + args.wandb_entity, + "--wandb_project", + "sudoku-latent-multi-output-grpo-residual-projector", + "--wandb_run_name", + f"latent_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def main() -> None: + args = parse_args() + checkpoint_root = Path(normalize_to_final_checkpoint_root(args.checkpoint_root, "latent_multi_output_cell_policy")).resolve() + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + checkpoint_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + checkpoint_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + elif stage == 1: + init_adapter_dir = str(args.stage1_init_adapter_dir).strip() + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch latent stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching latent SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing latent SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching latent GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/simple_9x9_curriculum/build_dataset.py b/simple_9x9_curriculum/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7500a3c56efa663f5596cdf844e4493dea84b1b5 --- /dev/null +++ b/simple_9x9_curriculum/build_dataset.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import argparse +import json +import random +from pathlib import Path +from typing import Iterable, List, Sequence, Tuple + +import numpy as np + + +GRID_SIZE = 9 +BOX_SIZE = 3 +ALL_VALUES = tuple(range(1, 10)) + + +def parse_args() -> argparse.Namespace: + root = Path(__file__).resolve().parent.parent + default_output = root / "data" / "sudoku_t3_30empty_value_qwen_text.jsonl" + p = argparse.ArgumentParser() + p.add_argument("--output", type=str, default=str(default_output)) + p.add_argument("--num_puzzles", type=int, default=20000) + p.add_argument("--empties", type=int, default=30) + p.add_argument("--seed", type=int, default=0) + return p.parse_args() + + +def permute_groups(rng: random.Random, values: Sequence[int], group_size: int) -> List[int]: + groups = [list(values[idx : idx + group_size]) for idx in range(0, len(values), group_size)] + rng.shuffle(groups) + out: List[int] = [] + for group in groups: + rng.shuffle(group) + out.extend(group) + return out + + +def base_grid() -> np.ndarray: + return np.asarray( + [[((rr * BOX_SIZE + rr // BOX_SIZE + cc) % GRID_SIZE) + 1 for cc in range(GRID_SIZE)] for rr in range(GRID_SIZE)], + dtype=int, + ) + + +def random_solved_grid(rng: random.Random) -> np.ndarray: + grid = base_grid().copy() + + digits = list(ALL_VALUES) + rng.shuffle(digits) + digit_map = {src: dst for src, dst in zip(ALL_VALUES, digits, strict=True)} + grid = np.vectorize(lambda value: digit_map[int(value)], otypes=[int])(grid) + + row_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + col_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + grid = grid[row_order, :] + grid = grid[:, col_order] + + if rng.random() < 0.5: + grid = grid.T + return np.asarray(grid, dtype=int) + + +def row_major_empty_locs(grid: np.ndarray) -> List[Tuple[int, int]]: + return [(int(r), int(c)) for r, c in np.argwhere(np.asarray(grid, dtype=int) == 0).tolist()] + + +def make_prompt(grid: np.ndarray) -> str: + tuples = [f"({r + 1},{c + 1},{int(grid[r, c])})" for r in range(GRID_SIZE) for c in range(GRID_SIZE)] + return ( + "9x9 Sudoku board encoded as (row,col,value) tuples in row-major order.\n" + "Value 0 means the cell is empty.\n" + + " ".join(tuples) + ) + + +def make_example(solved: np.ndarray, *, empties: int, rng: random.Random) -> dict: + if empties <= 0 or empties >= GRID_SIZE * GRID_SIZE: + raise ValueError(f"empties must be between 1 and {GRID_SIZE * GRID_SIZE - 1}") + + cells = list(range(GRID_SIZE * GRID_SIZE)) + rng.shuffle(cells) + masked_cells = sorted(cells[:empties]) + + puzzle = np.asarray(solved, dtype=int).copy() + for cell in masked_cells: + rr, cc = divmod(int(cell), GRID_SIZE) + puzzle[rr, cc] = 0 + + empty_locs_1based = [(rr + 1, cc + 1) for rr, cc in row_major_empty_locs(puzzle)] + target_triples_1based = [(rr + 1, cc + 1, int(solved[rr, cc])) for rr, cc in row_major_empty_locs(puzzle)] + completion_values = [int(value) for _, _, value in target_triples_1based] + + return { + "prompt": make_prompt(puzzle), + "completion": json.dumps(completion_values, separators=(",", ":")), + "metadata": { + "grid_size": GRID_SIZE, + "box_size": BOX_SIZE, + "empties": int(empties), + "empty_locs_1based": empty_locs_1based, + "target_triples_1based": target_triples_1based, + }, + } + + +def generate_examples(num_puzzles: int, *, empties: int, seed: int) -> Iterable[dict]: + rng = random.Random(int(seed)) + for _ in range(int(num_puzzles)): + solved = random_solved_grid(rng) + yield make_example(solved, empties=int(empties), rng=rng) + + +def main() -> None: + args = parse_args() + output_path = Path(args.output).resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w", encoding="utf-8") as f: + for row in generate_examples(args.num_puzzles, empties=args.empties, seed=args.seed): + f.write(json.dumps(row, separators=(",", ":")) + "\n") + print(f"Wrote {int(args.num_puzzles)} puzzles to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/small_model_20empty/launch_small_baseline_pipeline.sh b/small_model_20empty/launch_small_baseline_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f7bfe256de6e5cbf1bc9c7a5caf6916d649f70a --- /dev/null +++ b/small_model_20empty/launch_small_baseline_pipeline.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="/home/ubuntu/curriculum-CoT" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${ROOT}/sudoku/llm_policy_icon/small_model_20empty/run_small_baseline_pipeline.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/sudoku/llm_policy_icon/data/sudoku_t3_20empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/sudoku/llm_policy_icon/final_checkpoint/small_model_20empty/baseline}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_20empty_4stage_small}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${CHECKPOINT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-20}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--use_wandb --wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching small baseline pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/small_model_20empty/launch_small_latent_pipeline.sh b/small_model_20empty/launch_small_latent_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..145d365c76133d962e412e88eaedf7f719cb4bc7 --- /dev/null +++ b/small_model_20empty/launch_small_latent_pipeline.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="/home/ubuntu/curriculum-CoT" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${ROOT}/sudoku/llm_policy_icon/small_model_20empty/run_small_latent_pipeline.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/sudoku/llm_policy_icon/data/sudoku_t3_20empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +BASELINE_CHECKPOINT_ROOT="${BASELINE_CHECKPOINT_ROOT:-${ROOT}/sudoku/llm_policy_icon/final_checkpoint/small_model_20empty/baseline}" +LATENT_CHECKPOINT_ROOT="${LATENT_CHECKPOINT_ROOT:-${ROOT}/sudoku/llm_policy_icon/final_checkpoint/small_model_20empty/latent}" +BASELINE_OUTPUT_ROOT="${BASELINE_OUTPUT_ROOT:-${BASELINE_CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_20empty_4stage_small}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${LATENT_CHECKPOINT_ROOT}/${RUN_TAG}/latent_pipeline_20empty_4stage_small}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${LATENT_CHECKPOINT_ROOT}" + --baseline_output_root "${BASELINE_OUTPUT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-20}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-1.0}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-4}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--use_wandb --wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching small latent pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Baseline root: %s\n' "${BASELINE_OUTPUT_ROOT}" +printf 'Latent output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/small_model_20empty/run_small_baseline_pipeline.py b/small_model_20empty/run_small_baseline_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c469702fec4a2bbd752e934b5f4d8241136d10e7 --- /dev/null +++ b/small_model_20empty/run_small_baseline_pipeline.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("small_model_20empty", "baseline")) +DEFAULT_CACHE_DIR = Path("/home/ubuntu/curriculum-CoT/.hf_cache") +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +DEFAULT_WANDB_GROUP = "small_model_20empty_baseline_pipeline" +DEFAULT_SFT_PROJECT = "sudoku-small-20empty-baseline-sft" +DEFAULT_GRPO_PROJECT = "sudoku-small-20empty-baseline-grpo" + +SFT_SCRIPT = PARENT_DIR / "multi_output_cell_policy" / "sft_multi_output_train.py" +GRPO_SCRIPT = PARENT_DIR / "multi_output_cell_policy" / "grpo_multi_output_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=4) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default=DEFAULT_WANDB_GROUP) + p.add_argument("--wandb_sft_project", type=str, default=DEFAULT_SFT_PROJECT) + p.add_argument("--wandb_grpo_project", type=str, default=DEFAULT_GRPO_PROJECT) + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=2e-4) + p.add_argument("--sft_learning_rate_later", type=float, default=5e-5) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--grpo_num_train_epochs", type=float, default=0.5) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=2) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + if root_adapter.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + if not adapter.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + checkpoint_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + best: Optional[Artifact] = None + for stage_dir in checkpoint_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path( + normalize_to_final_checkpoint_root(args.output_root, "small_model_20empty", "baseline") + ).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_small" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku_t3_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: Optional[str], + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_sft_project, + "--wandb_run_name", + f"small_baseline_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + if init_adapter_dir: + command.extend(["--init_adapter_dir", str(init_adapter_dir)]) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_grpo_project, + "--wandb_run_name", + f"small_baseline_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def main() -> None: + args = parse_args() + checkpoint_root = Path( + normalize_to_final_checkpoint_root(args.checkpoint_root, "small_model_20empty", "baseline") + ).resolve() + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "model_name": str(args.model_name), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + elif stage == 1: + init_adapter_dir = str(args.stage1_init_adapter_dir).strip() or None + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch baseline stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/small_model_20empty/run_small_latent_pipeline.py b/small_model_20empty/run_small_latent_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3481380ffe50c73b17b263f0114b5c629af7bc --- /dev/null +++ b/small_model_20empty/run_small_latent_pipeline.py @@ -0,0 +1,650 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("small_model_20empty", "latent")) +DEFAULT_BASELINE_ROOT = Path(final_checkpoint_root("small_model_20empty", "baseline")) +DEFAULT_CACHE_DIR = Path("/home/ubuntu/curriculum-CoT/.hf_cache") +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +DEFAULT_WANDB_GROUP = "small_model_20empty_latent_pipeline" +DEFAULT_SFT_PROJECT = "sudoku-small-20empty-latent-sft-residual-projector" +DEFAULT_GRPO_PROJECT = "sudoku-small-20empty-latent-grpo-residual-projector" + +SFT_SCRIPT = PARENT_DIR / "latent_multi_output_cell_policy" / "residual_projector_warmstart_sft_latent_multi_output_train.py" +GRPO_SCRIPT = PARENT_DIR / "latent_multi_output_cell_policy" / "grpo_residual_projector_latent_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--baseline_output_root", type=str, default="") + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=4) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default=DEFAULT_WANDB_GROUP) + p.add_argument("--wandb_sft_project", type=str, default=DEFAULT_SFT_PROJECT) + p.add_argument("--wandb_grpo_project", type=str, default=DEFAULT_GRPO_PROJECT) + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=1e-6) + p.add_argument("--sft_learning_rate_later", type=float, default=1e-6) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--grpo_num_train_epochs", type=float, default=1.0) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=4) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + root_latent_state = stage_dir / "latent_cot_state.pt" + if root_adapter.exists() and root_latent_state.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + latent_state = child / "latent_cot_state.pt" + if not adapter.exists() or not latent_state.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + search_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + if not search_root.exists(): + return None + best: Optional[Artifact] = None + for stage_dir in search_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path( + normalize_to_final_checkpoint_root(args.output_root, "small_model_20empty", "latent") + ).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"latent_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_small" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku_t3_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def resolve_baseline_root(args: argparse.Namespace) -> Path: + if str(args.baseline_output_root).strip(): + return Path( + normalize_to_final_checkpoint_root(args.baseline_output_root, "small_model_20empty", "baseline") + ).resolve() + run_tag = str(args.run_tag).strip() + if run_tag: + return DEFAULT_BASELINE_ROOT / run_tag / f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_small" + return DEFAULT_BASELINE_ROOT + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty_residual_projector" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_sft_project, + "--wandb_run_name", + f"small_latent_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_grpo_project, + "--wandb_run_name", + f"small_latent_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def resolve_stage1_init_adapter(args: argparse.Namespace, baseline_root: Path) -> str: + if str(args.bootstrap_adapter_dir).strip(): + return str(args.bootstrap_adapter_dir).strip() + if str(args.stage1_init_adapter_dir).strip(): + return str(args.stage1_init_adapter_dir).strip() + baseline_stage1 = discover_latest_artifact( + baseline_root, + stage=1, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if baseline_stage1 is None: + raise RuntimeError( + "Latent stage 1 needs a warmstart adapter. Pass --bootstrap_adapter_dir / --stage1_init_adapter_dir " + "or provide --baseline_output_root with a completed baseline stage 1 SFT checkpoint." + ) + return baseline_stage1.path + + +def main() -> None: + args = parse_args() + checkpoint_root = Path( + normalize_to_final_checkpoint_root(args.checkpoint_root, "small_model_20empty", "latent") + ).resolve() + baseline_root = resolve_baseline_root(args) + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "baseline_output_root": str(baseline_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "model_name": str(args.model_name), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing latent GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == 1: + init_adapter_dir = resolve_stage1_init_adapter(args, baseline_root) + elif stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch latent stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching latent SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing latent SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching latent GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/__init__.py b/sudoku4x4_11empty/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sudoku4x4_11empty/build_dataset.py b/sudoku4x4_11empty/build_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..42db7b3d274e865f4a97e30bfa225aeaa140f344 --- /dev/null +++ b/sudoku4x4_11empty/build_dataset.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import argparse +import json +import random +from pathlib import Path +from typing import Iterable, List, Sequence, Tuple + +import numpy as np + + +GRID_SIZE = 4 +BOX_SIZE = 2 +ALL_VALUES = (1, 2, 3, 4) + +BASE_GRID = np.array( + [ + [1, 2, 3, 4], + [3, 4, 1, 2], + [2, 1, 4, 3], + [4, 3, 2, 1], + ], + dtype=int, +) + + +def parse_args() -> argparse.Namespace: + root = Path(__file__).resolve().parent.parent + default_output = root / "data" / "sudoku4x4_11empty_value_qwen_text.jsonl" + p = argparse.ArgumentParser() + p.add_argument("--output", type=str, default=str(default_output)) + p.add_argument("--num_puzzles", type=int, default=20000) + p.add_argument("--empties", type=int, default=11) + p.add_argument("--seed", type=int, default=0) + return p.parse_args() + + +def permute_groups(rng: random.Random, values: Sequence[int], group_size: int) -> List[int]: + groups = [list(values[idx : idx + group_size]) for idx in range(0, len(values), group_size)] + rng.shuffle(groups) + out: List[int] = [] + for group in groups: + rng.shuffle(group) + out.extend(group) + return out + + +def random_solved_grid(rng: random.Random) -> np.ndarray: + grid = np.asarray(BASE_GRID, dtype=int).copy() + + digits = list(ALL_VALUES) + rng.shuffle(digits) + digit_map = {src: dst for src, dst in zip(ALL_VALUES, digits, strict=True)} + grid = np.vectorize(lambda value: digit_map[int(value)], otypes=[int])(grid) + + row_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + col_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE) + grid = grid[row_order, :] + grid = grid[:, col_order] + + if rng.random() < 0.5: + grid = grid.T + return np.asarray(grid, dtype=int) + + +def row_major_empty_locs(grid: np.ndarray) -> List[Tuple[int, int]]: + return [(int(r), int(c)) for r, c in np.argwhere(np.asarray(grid, dtype=int) == 0).tolist()] + + +def make_prompt(grid: np.ndarray) -> str: + tuples = [ + f"({r + 1},{c + 1},{int(grid[r, c])})" + for r in range(GRID_SIZE) + for c in range(GRID_SIZE) + ] + return ( + "4x4 Sudoku board encoded as (row,col,value) tuples in row-major order.\n" + "Value 0 means the cell is empty.\n" + + " ".join(tuples) + ) + + +def make_example(solved: np.ndarray, *, empties: int, rng: random.Random) -> dict: + if empties <= 0 or empties >= GRID_SIZE * GRID_SIZE: + raise ValueError(f"empties must be between 1 and {GRID_SIZE * GRID_SIZE - 1}") + + cells = list(range(GRID_SIZE * GRID_SIZE)) + rng.shuffle(cells) + masked_cells = sorted(cells[:empties]) + + puzzle = np.asarray(solved, dtype=int).copy() + for cell in masked_cells: + rr, cc = divmod(int(cell), GRID_SIZE) + puzzle[rr, cc] = 0 + + empty_locs_1based = [(rr + 1, cc + 1) for rr, cc in row_major_empty_locs(puzzle)] + target_triples_1based = [ + (rr + 1, cc + 1, int(solved[rr, cc])) + for rr, cc in row_major_empty_locs(puzzle) + ] + completion_values = [int(value) for _, _, value in target_triples_1based] + + return { + "prompt": make_prompt(puzzle), + "completion": json.dumps(completion_values, separators=(",", ":")), + "metadata": { + "grid_size": GRID_SIZE, + "box_size": BOX_SIZE, + "empties": int(empties), + "empty_locs_1based": empty_locs_1based, + "target_triples_1based": target_triples_1based, + }, + } + + +def generate_examples(num_puzzles: int, *, empties: int, seed: int) -> Iterable[dict]: + rng = random.Random(int(seed)) + for _ in range(int(num_puzzles)): + solved = random_solved_grid(rng) + yield make_example(solved, empties=int(empties), rng=rng) + + +def main() -> None: + args = parse_args() + output_path = Path(args.output).resolve() + output_path.parent.mkdir(parents=True, exist_ok=True) + + with output_path.open("w", encoding="utf-8") as f: + for row in generate_examples(args.num_puzzles, empties=args.empties, seed=args.seed): + f.write(json.dumps(row, separators=(",", ":")) + "\n") + + print(f"Wrote {int(args.num_puzzles)} puzzles to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/grpo_multi_output_train.py b/sudoku4x4_11empty/grpo_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..86c40d8470dc3813d442eb3e4e9385ec0c8aa261 --- /dev/null +++ b/sudoku4x4_11empty/grpo_multi_output_train.py @@ -0,0 +1,728 @@ +from __future__ import annotations + +import argparse +import hashlib +import inspect +import json +import os +import sys +import time +from dataclasses import dataclass +from typing import Any, Dict, List + +import torch +from datasets import Dataset +from peft import LoraConfig, PeftModel, get_peft_model +from safetensors.torch import load_file as load_safetensors_file +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from sudoku4x4_11empty.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_model_artifacts +from sudoku4x4_11empty.prompt_builder import build_multi_output_cell_prompt +from sudoku4x4_11empty.rewards import score_prediction_text +from sudoku4x4_11empty.shared_multi_output_policy import make_solved_grid_from_row + + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_train_epochs: float + learning_rate: float + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + num_generations: int + max_prompt_length: int + max_completion_length: int + beta: float + lora_r: int + lora_alpha: int + lora_dropout: float + enable_gradient_checkpointing: bool + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + wandb_group: str + wandb_run_id: str + debug_print_limit: int + limit_train_rows: int + limit_train_examples: int + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + eval_solve_rate_stop: float + min_steps_before_stop: int + max_wall_clock_seconds: int + max_steps: int + resume_from_checkpoint: str + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available(): + try: + if torch.cuda.is_bf16_supported(): + return torch.bfloat16 + except Exception: + pass + return torch.float16 + + +def ensure_trl_fsdp_compat() -> None: + try: + import torch.distributed.fsdp as fsdp + + if not hasattr(fsdp, "FSDPModule") and hasattr(fsdp, "FullyShardedDataParallel"): + fsdp.FSDPModule = fsdp.FullyShardedDataParallel + except Exception: + pass + + +def load_trainable_adapter(base_model: torch.nn.Module, adapter_dir: str) -> torch.nn.Module: + try: + return PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True) + except Exception: + config_path = os.path.join(adapter_dir, "adapter_config.json") + model_path = os.path.join(adapter_dir, "adapter_model.safetensors") + if not (os.path.exists(config_path) and os.path.exists(model_path)): + raise + + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + lora = LoraConfig( + r=int(cfg["r"]), + lora_alpha=int(cfg["lora_alpha"]), + lora_dropout=float(cfg["lora_dropout"]), + bias=str(cfg.get("bias", "none")), + task_type=str(cfg.get("task_type", "CAUSAL_LM")), + target_modules=list(cfg["target_modules"]), + ) + model = get_peft_model(base_model, lora) + state = load_safetensors_file(model_path) + remapped: Dict[str, torch.Tensor] = {} + for key, value in state.items(): + new_key = key.replace(".lora_A.weight", ".lora_A.default.weight") + new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight") + remapped[new_key] = value + model.load_state_dict(remapped, strict=False) + return model + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def build_grpo_records( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + max_records: int = 0, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + records: List[Dict[str, Any]] = [] + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + records.append( + { + "prompt": prompt, + "grid_json": json.dumps(ex.grid.tolist(), separators=(",", ":")), + "solved_json": json.dumps(solved.tolist(), separators=(",", ":")), + "target_row": int(ex.target_cell[0]), + "target_col": int(ex.target_cell[1]), + "stage_i": int(stage_i), + } + ) + if int(max_records) > 0 and len(records) >= int(max_records): + break + if progress_callback is not None and ( + row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx, len(rows), len(records)) + if int(max_records) > 0 and len(records) >= int(max_records): + break + return records + + +def _prepared_data_dir(args: Args) -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "sudoku4x4_11empty") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_grpo_cache_path(args: Args) -> str: + payload = json.dumps( + { + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "limit_train_examples": int(args.limit_train_examples), + "model_name": str(args.model_name), + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(payload).hexdigest()[:20] + return os.path.join(_prepared_data_dir(args), f"grpo_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, separators=(",", ":")) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: + start = time.time() + while not os.path.exists(path): + if time.time() - start > timeout_s: + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_grpo_records( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_grpo_cache_path(args) + if os.path.exists(cache_path): + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + records = build_grpo_records( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + max_records=int(args.limit_train_examples), + progress_every_rows=10, + progress_callback=progress_callback, + ) + _write_jsonl(cache_path, records) + return records + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +def make_reward_func(args: Args): + def reward_func(completions, grid_json, solved_json, target_row, target_col, stage_i, **kwargs): + rewards: List[float] = [] + for completion, grid_s, solved_s, rr, cc, stage_val in zip( + completions, grid_json, solved_json, target_row, target_col, stage_i + ): + info = score_prediction_text( + text=str(completion), + grid=torch.tensor(json.loads(grid_s), dtype=torch.long).numpy(), + solved=torch.tensor(json.loads(solved_s), dtype=torch.long).numpy(), + target_cell=(int(rr), int(cc)), + stage_i=int(stage_val), + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + rewards.append(float(info["reward"])) + return rewards + + return reward_func + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: torch.nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + enc = {k: v.to(device) for k, v in enc.items()} + out = model.generate( + **enc, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[baseline grpo eval debug] target=({rr+1},{cc+1}) output={pred_text!r}", flush=True) + print( + f"[baseline grpo eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}", + flush=True, + ) + printed += 1 + solve_ok += int(row_all_exact) + + return { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + "eval_cells": float(total_cells), + } + + +def unwrap_training_model(model: Any) -> Any: + current = model + while hasattr(current, "module"): + current = current.module + return current + + +class CustomEvalCallback(TrainerCallback): + def __init__( + self, + args: Args, + eval_rows: List[Dict[str, Any]], + tokenizer: Any, + device: torch.device, + wb_run: Any, + is_main_process: bool, + ): + self.args = args + self.eval_rows = eval_rows + self.tokenizer = tokenizer + self.device = device + self.wb_run = wb_run + self.is_main_process = is_main_process + self.last_logged_step = -1 + + def on_step_end(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + step = int(state.global_step) + if step <= 0 or step == self.last_logged_step or step % int(self.args.eval_steps) != 0: + return control + model = kwargs.get("model") + if model is None: + return control + metrics = run_eval( + args=self.args, + rows=self.eval_rows, + model=unwrap_training_model(model), + tokenizer=self.tokenizer, + device=self.device, + ) + self.last_logged_step = step + print( + f"[baseline grpo custom eval step {step}] parse={metrics['parse_rate']:.3f} " + f"avg_set_size={metrics['avg_predicted_set_size']:.3f} " + f"good={metrics['avg_num_i_consistent_values']:.3f} " + f"bad={metrics['avg_num_non_i_consistent_values']:.3f}", + flush=True, + ) + if self.args.use_wandb and self.wb_run is not None: + payload = {f"custom_eval/{k}": float(v) for k, v in metrics.items()} + payload["custom_eval/global_step"] = float(step) + wandb.log(payload) + if ( + int(step) >= int(self.args.min_steps_before_stop) + and float(self.args.eval_solve_rate_stop) > 0.0 + and float(metrics["solve_rate"]) >= float(self.args.eval_solve_rate_stop) + ): + print( + f"[baseline grpo custom eval step {step}] early stop: " + f"solve_rate={metrics['solve_rate']:.3f} >= {float(self.args.eval_solve_rate_stop):.3f}", + flush=True, + ) + control.should_training_stop = True + return control + + +class FinalCheckpointCallback(TrainerCallback): + def __init__(self, output_dir: str, tokenizer: Any, is_main_process: bool): + self.output_dir = output_dir + self.tokenizer = tokenizer + self.is_main_process = is_main_process + + def _save(self, model: Any) -> None: + if self.is_main_process: + save_model_artifacts(unwrap_training_model(model), self.tokenizer, ensure_final_checkpoint_dir(self.output_dir)) + + def on_save(self, args, state, control, **kwargs): + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + def on_train_end(self, args, state, control, **kwargs): + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + +class WallClockStopCallback(TrainerCallback): + def __init__(self, max_wall_clock_seconds: int): + self.max_wall_clock_seconds = int(max_wall_clock_seconds) + self.start_time = time.time() + + def on_step_end(self, args, state, control, **kwargs): + if self.max_wall_clock_seconds > 0 and (time.time() - self.start_time) >= float(self.max_wall_clock_seconds): + control.should_training_stop = True + return control + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument("--train_jsonl", type=str, required=True) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum_cot/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, required=True) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=2) + p.add_argument("--gradient_accumulation_steps", type=int, default=4) + p.add_argument("--num_train_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--logging_steps", type=int, default=5) + p.add_argument("--save_steps", type=int, default=25) + p.add_argument("--eval_steps", type=int, default=25) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--num_generations", type=int, default=2) + p.add_argument("--max_prompt_length", type=int, default=1024) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--beta", type=float, default=0.0) + p.add_argument("--lora_r", type=int, default=8) + p.add_argument("--lora_alpha", type=int, default=16) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-grpo") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--wandb_group", type=str, default="") + p.add_argument("--wandb_run_id", type=str, default="") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--limit_train_examples", type=int, default=0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--resume_from_checkpoint", type=str, default="") + ns = p.parse_args() + return Args(**vars(ns)) + + +def main() -> None: + args = parse_args() + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_main_process = rank == 0 + + if preset_visible_devices: + if is_main_process: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + if is_main_process: + print(f"Using Hugging Face cache dir: {cache_dir}", flush=True) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + "group": args.wandb_group or None, + "id": args.wandb_run_id or None, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/records_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_rows = rows[: max(1, int(args.eval_rows))] + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + if is_main_process: + print(f"Using device: {device}", flush=True) + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter(base, args.init_adapter_dir) + if is_main_process: + print(f"Loaded init adapter: {args.init_adapter_dir}", flush=True) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + if world_size <= 1: + model.to(device) + model.train() + + def on_prep_progress(rows_done: int, total_rows: int, records_built: int) -> None: + if is_main_process: + print( + f"[dataset build][grpo stage {args.stage_i}] rows={rows_done}/{total_rows} records={records_built}", + flush=True, + ) + if wb_run is not None: + wandb.log({"prep/rows_done": float(rows_done), "prep/records_built": float(records_built)}) + + train_records = load_or_build_grpo_records( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + if is_main_process and int(args.limit_train_examples) > 0: + print(f"Limiting GRPO train records to {len(train_records)} examples", flush=True) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_grpo_cache_path(args))), + "prep/records_final": float(len(train_records)), + } + ) + + train_dataset = Dataset.from_list(train_records) + reward_func = make_reward_func(args) + + if int(args.limit_train_rows) > 0 and int(args.max_steps) <= 0: + args.max_steps = 1 + + ensure_trl_fsdp_compat() + from trl import GRPOConfig, GRPOTrainer + + config_kwargs = { + "output_dir": args.output_dir, + "per_device_train_batch_size": args.per_device_train_batch_size, + "gradient_accumulation_steps": args.gradient_accumulation_steps, + "num_train_epochs": args.num_train_epochs, + "learning_rate": args.learning_rate, + "logging_steps": args.logging_steps, + "save_steps": args.save_steps, + "eval_strategy": "steps", + "eval_steps": args.eval_steps, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + "num_generations": args.num_generations, + "beta": args.beta, + "bf16": (pick_dtype() == torch.bfloat16), + "report_to": ["wandb"] if args.use_wandb and is_main_process else [], + "remove_unused_columns": False, + "max_steps": int(args.max_steps), + } + grpo_config_params = inspect.signature(GRPOConfig.__init__).parameters + unsupported_keys = sorted(key for key in config_kwargs if key not in grpo_config_params) + for key in unsupported_keys: + config_kwargs.pop(key, None) + if is_main_process and unsupported_keys: + print(f"Skipping unsupported GRPOConfig args: {', '.join(unsupported_keys)}", flush=True) + config = GRPOConfig(**config_kwargs) + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[reward_func], + args=config, + train_dataset=train_dataset, + eval_dataset=train_dataset.select(range(min(len(train_dataset), max(1, int(args.eval_rows))))), + ) + trainer.add_callback(CustomEvalCallback(args, eval_rows, tokenizer, device, wb_run, is_main_process)) + trainer.add_callback(FinalCheckpointCallback(args.output_dir, tokenizer, is_main_process)) + trainer.add_callback(WallClockStopCallback(args.max_wall_clock_seconds)) + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint or None) + + final_model = unwrap_training_model(trainer.model) + if is_main_process: + eval_metrics = run_eval(args=args, rows=eval_rows, model=final_model, tokenizer=tokenizer, device=device) + print( + f"[baseline grpo final eval] parse={eval_metrics['parse_rate']:.3f} " + f"canonical={eval_metrics['strict_canonical_rate']:.3f} " + f"exact={eval_metrics['exact_set_match_rate']:.3f} precision={eval_metrics['value_precision']:.3f} " + f"recall={eval_metrics['value_recall']:.3f} solve={eval_metrics['solve_rate']:.3f}", + flush=True, + ) + if wb_run is not None: + wandb.log({f"custom_eval/{k}": float(v) for k, v in eval_metrics.items()}) + trainer.save_model(args.output_dir) + save_model_artifacts(final_model, tokenizer, ensure_final_checkpoint_dir(args.output_dir)) + if wb_run is not None: + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/grpo_residual_projector_latent_train.py b/sudoku4x4_11empty/grpo_residual_projector_latent_train.py new file mode 100644 index 0000000000000000000000000000000000000000..0225b344463e89907567076fa46162b69ffd5f6a --- /dev/null +++ b/sudoku4x4_11empty/grpo_residual_projector_latent_train.py @@ -0,0 +1,1186 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import sys +import time +from types import MethodType +from dataclasses import dataclass +from typing import Any, Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset +from peft import LoraConfig, PeftModel, get_peft_model +from safetensors.torch import load_file as load_safetensors_file +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed +from transformers.modeling_outputs import CausalLMOutput + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from sudoku4x4_11empty.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_model_artifacts +from sudoku4x4_11empty.prompt_builder import build_multi_output_cell_prompt +from sudoku4x4_11empty.rewards import score_prediction_text +from sudoku4x4_11empty.shared_multi_output_policy import make_solved_grid_from_row + + +try: + import wandb +except Exception: + wandb = None + + +PROJECTOR_HIDDEN = 4096 + + +@dataclass +class Args: + model_name: str + train_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + num_cot_tokens: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_train_epochs: float + learning_rate: float + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + num_generations: int + max_prompt_length: int + max_completion_length: int + beta: float + enable_gradient_checkpointing: bool + lora_r: int + lora_alpha: int + lora_dropout: float + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + wandb_group: str + wandb_run_id: str + debug_print_limit: int + limit_train_rows: int + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + eval_solve_rate_stop: float + min_steps_before_stop: int + max_wall_clock_seconds: int + max_steps: int + resume_from_checkpoint: str + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available(): + try: + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + if int(getattr(props, "major", 0)) >= 8: + return torch.bfloat16 + except Exception: + pass + return torch.float16 + + +def ensure_trl_fsdp_compat() -> None: + try: + import torch.distributed.fsdp as fsdp + + if not hasattr(fsdp, "FSDPModule") and hasattr(fsdp, "FullyShardedDataParallel"): + fsdp.FSDPModule = fsdp.FullyShardedDataParallel + except Exception: + pass + + +def load_trainable_adapter( + base_model: torch.nn.Module, + adapter_dir: str, + *, + lora_r: int = 128, + lora_alpha: int = 256, + lora_dropout: float = 0.05, +) -> torch.nn.Module: + if not str(adapter_dir).strip(): + lora = LoraConfig( + r=int(lora_r), + lora_alpha=int(lora_alpha), + lora_dropout=float(lora_dropout), + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + ) + return get_peft_model(base_model, lora) + try: + return PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True) + except Exception: + config_path = os.path.join(adapter_dir, "adapter_config.json") + model_path = os.path.join(adapter_dir, "adapter_model.safetensors") + with open(config_path, "r", encoding="utf-8") as f: + cfg = json.load(f) + lora = LoraConfig( + r=int(cfg["r"]), + lora_alpha=int(cfg["lora_alpha"]), + lora_dropout=float(cfg["lora_dropout"]), + bias=str(cfg.get("bias", "none")), + task_type=str(cfg.get("task_type", "CAUSAL_LM")), + target_modules=list(cfg["target_modules"]), + ) + model = get_peft_model(base_model, lora) + state = load_safetensors_file(model_path) + remapped: Dict[str, torch.Tensor] = {} + for key, value in state.items(): + new_key = key.replace(".lora_A.weight", ".lora_A.default.weight") + new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight") + remapped[new_key] = value + model.load_state_dict(remapped, strict=False) + return model + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def build_grpo_dataset( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> Dataset: + records: List[Dict[str, Any]] = [] + total_rows = len(rows) + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + records.append( + { + "prompt": prompt, + "grid_json": json.dumps(ex.grid.tolist()), + "solved_json": json.dumps(solved.tolist()), + "target_row": int(ex.target_cell[0]), + "target_col": int(ex.target_cell[1]), + "stage_i": int(stage_i), + } + ) + if progress_callback is not None and ( + row_idx == total_rows or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx=row_idx, total_rows=total_rows, record_count=len(records)) + return Dataset.from_list(records) + + +def _prepared_data_dir() -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "sudoku4x4_11empty_latent") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_grpo_cache_path(args: Args) -> str: + payload = { + "kind": "grpo", + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + } + digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:20] + return os.path.join(_prepared_data_dir(), f"grpo_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=True) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_seconds: int = 6 * 60 * 60) -> None: + start = time.time() + while not os.path.exists(path): + if (time.time() - start) > float(timeout_seconds): + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_grpo_records( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_grpo_cache_path(args) + if os.path.exists(cache_path): + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] loading prepared cache: {cache_path}", flush=True) + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][grpo stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + dataset = build_grpo_dataset( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + records = [dataset[int(i)] for i in range(len(dataset))] + _write_jsonl(cache_path, records) + elif world_size > 1: + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +def make_reward_func(args: Args): + def reward_func(completions, grid_json, solved_json, target_row, target_col, stage_i, **kwargs): + rewards: List[float] = [] + for completion, grid_s, solved_s, rr, cc, stage_val in zip( + completions, grid_json, solved_json, target_row, target_col, stage_i + ): + info = score_prediction_text( + text=str(completion), + grid=torch.tensor(json.loads(grid_s), dtype=torch.long).numpy(), + solved=torch.tensor(json.loads(solved_s), dtype=torch.long).numpy(), + target_cell=(int(rr), int(cc)), + stage_i=int(stage_val), + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + rewards.append(float(info["reward"])) + return rewards + + return reward_func + + +def unwrap_backbone(model: nn.Module) -> nn.Module: + if isinstance(model, PeftModel): + return model.get_base_model() + return model + + +def unwrap_training_model(model: Any) -> Any: + current = model + while hasattr(current, "module"): + current = current.module + return current + + +def get_input_embeddings_module(model: nn.Module) -> nn.Module: + return unwrap_backbone(model).get_input_embeddings() + + +def get_output_embeddings_module(model: nn.Module) -> nn.Module: + base = unwrap_backbone(model) + return base.get_output_embeddings() or base.lm_head + + +def get_last_hidden_state(model_output: Any) -> torch.Tensor: + hidden = getattr(model_output, "last_hidden_state", None) + if hidden is not None: + return hidden + return model_output.hidden_states[-1] + + +def run_backbone_from_embeds(backbone: nn.Module, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor): + base = unwrap_backbone(backbone) + inner = getattr(base, "model", base) + return inner( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + + +def extend_attention_mask(mask: torch.Tensor, extra_tokens: int) -> torch.Tensor: + extra = torch.ones(mask.shape[0], int(extra_tokens), dtype=mask.dtype, device=mask.device) + return torch.cat([mask, extra], dim=1) + + +def attach_residual_projector_modules(model: nn.Module, hidden_size: int, projector_hidden: int = PROJECTOR_HIDDEN) -> None: + if hasattr(model, "latent_projector_in") and hasattr(model, "latent_projector_out") and hasattr( + model, "special_thought_embed" + ): + return + projector_hidden = int(projector_hidden) + model.special_thought_embed = nn.Parameter(torch.randn(hidden_size) * 0.02) + model.latent_mix_logit = nn.Parameter(torch.tensor(-8.0)) + model.latent_projector_in = nn.Linear(hidden_size, projector_hidden, bias=True) + model.latent_projector_out = nn.Linear(projector_hidden, hidden_size, bias=True) + nn.init.normal_(model.special_thought_embed, std=0.02) + nn.init.xavier_uniform_(model.latent_projector_in.weight) + nn.init.zeros_(model.latent_projector_in.bias) + nn.init.xavier_uniform_(model.latent_projector_out.weight) + nn.init.zeros_(model.latent_projector_out.bias) + + +def maybe_load_projector_state(model: nn.Module, path_or_dir: str) -> bool: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_cot_state.pt") + if not os.path.exists(state_path): + return False + state = torch.load(state_path, map_location="cpu") + with torch.no_grad(): + for name in [ + "special_thought_embed", + "latent_mix_logit", + "latent_projector_in_weight", + "latent_projector_in_bias", + "latent_projector_out_weight", + "latent_projector_out_bias", + ]: + if name not in state: + continue + if name == "special_thought_embed": + model.special_thought_embed.copy_(state[name].to(model.special_thought_embed)) + elif name == "latent_mix_logit": + model.latent_mix_logit.copy_(state[name].to(model.latent_mix_logit)) + elif name == "latent_projector_in_weight": + model.latent_projector_in.weight.copy_(state[name].to(model.latent_projector_in.weight)) + elif name == "latent_projector_in_bias": + model.latent_projector_in.bias.copy_(state[name].to(model.latent_projector_in.bias)) + elif name == "latent_projector_out_weight": + model.latent_projector_out.weight.copy_(state[name].to(model.latent_projector_out.weight)) + elif name == "latent_projector_out_bias": + model.latent_projector_out.bias.copy_(state[name].to(model.latent_projector_out.bias)) + return True + + +def infer_projector_hidden_from_state(path_or_dir: str) -> int | None: + state_path = str(path_or_dir) + if os.path.isdir(state_path): + state_path = os.path.join(state_path, "latent_cot_state.pt") + if not os.path.exists(state_path): + return None + state = torch.load(state_path, map_location="cpu") + weight = state.get("latent_projector_in_weight") + if isinstance(weight, torch.Tensor) and weight.ndim == 2: + return int(weight.shape[0]) + return None + + +def save_latent_projector_state(model: nn.Module, output_dir: str) -> None: + state = { + "special_thought_embed": model.special_thought_embed.detach().cpu(), + "latent_mix_logit": model.latent_mix_logit.detach().cpu(), + "latent_projector_in_weight": model.latent_projector_in.weight.detach().cpu(), + "latent_projector_in_bias": model.latent_projector_in.bias.detach().cpu(), + "latent_projector_out_weight": model.latent_projector_out.weight.detach().cpu(), + "latent_projector_out_bias": model.latent_projector_out.bias.detach().cpu(), + } + torch.save(state, os.path.join(output_dir, "latent_cot_state.pt")) + + +def project_hidden(model: nn.Module, hidden: torch.Tensor) -> torch.Tensor: + input_dtype = hidden.dtype + hidden = hidden.to(torch.float32) + hidden = F.linear( + hidden, + model.latent_projector_in.weight.to(dtype=torch.float32), + None if model.latent_projector_in.bias is None else model.latent_projector_in.bias.to(dtype=torch.float32), + ) + hidden = F.gelu(hidden) + hidden = F.linear( + hidden, + model.latent_projector_out.weight.to(dtype=torch.float32), + None if model.latent_projector_out.bias is None else model.latent_projector_out.bias.to(dtype=torch.float32), + ) + hidden = torch.nan_to_num(hidden, nan=0.0, posinf=50.0, neginf=-50.0) + hidden = hidden.clamp(min=-50.0, max=50.0) + return hidden.to(input_dtype) + + +def _sanitize_logits(logits: torch.Tensor, *, output_dtype: torch.dtype) -> torch.Tensor: + logits = torch.nan_to_num(logits.float(), nan=0.0, posinf=50.0, neginf=-50.0) + logits = logits.clamp(min=-50.0, max=50.0) + return logits.to(dtype=output_dtype) + + +def _should_fallback_to_base(model: nn.Module, latent_logits: torch.Tensor) -> torch.Tensor: + scores = torch.nan_to_num(latent_logits.float(), nan=0.0, posinf=50.0, neginf=-50.0) + probs = torch.softmax(scores, dim=-1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8) + max_prob = probs.max(dim=-1).values + entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=-1) + if int(scores.shape[-1]) > 1: + top2 = torch.topk(scores, k=2, dim=-1).values + margin = top2[..., 0] - top2[..., 1] + else: + margin = torch.full_like(max_prob, float("inf")) + fallback = (~torch.isfinite(scores)).any(dim=-1) + fallback |= max_prob > float(getattr(model, "_latent_fallback_max_prob", 0.995)) + fallback |= entropy < float(getattr(model, "_latent_fallback_entropy_min", 0.02)) + fallback |= margin > float(getattr(model, "_latent_fallback_margin_max", 25.0)) + return fallback + + +def build_latent_hidden(model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int): + backbone = unwrap_backbone(model) + inner_backbone = getattr(backbone, "model", backbone) + input_embeds = get_input_embeddings_module(model)(input_ids) + base_out = inner_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=False, + return_dict=True, + use_cache=False, + ) + base_hidden = get_last_hidden_state(base_out)[:, -1, :] + if num_cot_tokens <= 0: + return base_hidden, base_hidden + + cur_embeds = input_embeds + cur_mask = attention_mask + latent_token = None + special = model.special_thought_embed.to(device=input_embeds.device, dtype=input_embeds.dtype).view(1, 1, -1) + for _ in range(int(num_cot_tokens)): + next_embed = special.expand(cur_embeds.shape[0], 1, -1) if latent_token is None else latent_token + cur_embeds = torch.cat([cur_embeds, next_embed], dim=1) + cur_mask = extend_attention_mask(cur_mask, 1) + out = run_backbone_from_embeds(backbone, cur_embeds, cur_mask) + latent_token = get_last_hidden_state(out)[:, -1:, :] + latent_hidden = latent_token[:, 0, :] + return base_hidden, latent_hidden + + +def residual_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + base_hidden, latent_hidden = build_latent_hidden(model, input_ids, attention_mask, num_cot_tokens) + projected_delta = project_hidden(model, latent_hidden - base_hidden).float() + mix = torch.sigmoid(model.latent_mix_logit.float()).to(projected_delta.device) + projected_delta = projected_delta * float(getattr(model, "_latent_delta_scale", 1.0)) * mix + base_hidden_fp32 = base_hidden.float() + base_norm = base_hidden_fp32.norm(dim=-1, keepdim=True).clamp_min(1e-6) + delta_norm = projected_delta.norm(dim=-1, keepdim=True).clamp_min(1e-6) + max_ratio = max(0.0, float(getattr(model, "_latent_delta_max_ratio", 0.5))) + max_delta_norm = base_norm * max_ratio + projected_delta = projected_delta * torch.clamp(max_delta_norm / delta_norm, max=1.0) + final_hidden = torch.nan_to_num(base_hidden_fp32 + projected_delta, nan=0.0, posinf=50.0, neginf=-50.0) + output_embeddings = get_output_embeddings_module(model) + output_dtype = getattr(getattr(output_embeddings, "weight", None), "dtype", final_hidden.dtype) + latent_logits = _sanitize_logits(output_embeddings(final_hidden.to(dtype=output_dtype)), output_dtype=output_dtype) + fallback_mask = _should_fallback_to_base(model, latent_logits) + if bool(fallback_mask.any()): + warn_count = int(getattr(model, "_latent_fallback_warn_count", 0)) + if warn_count < 5: + print(f"[latent grpo] falling back to base logits for {int(fallback_mask.sum().item())} rows") + model._latent_fallback_warn_count = warn_count + 1 + fallback_hidden = base_hidden_fp32[fallback_mask].to(dtype=output_dtype) + fallback_logits = _sanitize_logits(output_embeddings(fallback_hidden), output_dtype=output_dtype) + latent_logits = latent_logits.clone() + latent_logits[fallback_mask] = fallback_logits + return latent_logits + + +def _apply_repetition_penalty(logits: torch.Tensor, tokens: torch.Tensor, penalty: float) -> torch.Tensor: + if penalty == 1.0 or tokens.numel() == 0: + return logits + adjusted = logits.clone() + unique_tokens = torch.unique(tokens, sorted=False) + seen_logits = adjusted.index_select(dim=-1, index=unique_tokens) + seen_logits = torch.where(seen_logits < 0, seen_logits * penalty, seen_logits / penalty) + adjusted.index_copy_(dim=-1, index=unique_tokens, source=seen_logits) + return adjusted + + +def _sample_from_latent_logits( + logits: torch.Tensor, + *, + do_sample: bool, + temperature: float, + top_p: float, + top_k: int, +) -> torch.Tensor: + if not do_sample: + return torch.argmax(logits, dim=-1, keepdim=True) + + temperature = max(float(temperature), 1e-5) + scores = logits / temperature + + if int(top_k) > 0 and int(top_k) < scores.shape[-1]: + topk_values, _ = torch.topk(scores, k=int(top_k), dim=-1) + cutoff = topk_values[:, -1:].expand_as(scores) + scores = torch.where(scores < cutoff, torch.full_like(scores, float("-inf")), scores) + + if 0.0 < float(top_p) < 1.0: + sorted_scores, sorted_indices = torch.sort(scores, dim=-1, descending=True) + sorted_probs = torch.softmax(sorted_scores, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + nucleus_mask = cumulative_probs > float(top_p) + nucleus_mask[:, 0] = False + sorted_scores = sorted_scores.masked_fill(nucleus_mask, float("-inf")) + scores = torch.full_like(scores, float("-inf")) + scores.scatter_(dim=-1, index=sorted_indices, src=sorted_scores) + + probs = torch.softmax(scores, dim=-1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-8) + return torch.multinomial(probs, num_samples=1) + + +@torch.no_grad() +def sample_latent_completion( + model: nn.Module, + tokenizer: Any, + prompt_ids: torch.Tensor, + attention_mask: torch.Tensor, + *, + num_cot_tokens: int, + max_new_tokens: int, + do_sample: bool = False, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + repetition_penalty: float = 1.0, +) -> torch.Tensor: + generated = prompt_ids + mask = attention_mask + eos = tokenizer.eos_token_id + for _ in range(max(1, int(max_new_tokens))): + logits = residual_next_token_logits_from_ids(model, generated, mask, num_cot_tokens) + logits = _apply_repetition_penalty(logits, generated, float(repetition_penalty)) + next_id = _sample_from_latent_logits( + logits.float(), + do_sample=bool(do_sample), + temperature=float(temperature), + top_p=float(top_p), + top_k=int(top_k), + ) + generated = torch.cat([generated, next_id], dim=1) + mask = extend_attention_mask(mask, 1) + if eos is not None and bool((next_id == int(eos)).all()): + break + return generated[:, prompt_ids.shape[1] :] + + +def install_latent_grpo_model_interface( + model: nn.Module, + tokenizer: Any, + *, + num_cot_tokens: int, + latent_delta_scale: float = 1.0, + latent_delta_max_ratio: float = 0.5, +) -> nn.Module: + if getattr(model, "_latent_grpo_interface_installed", False): + model._latent_grpo_num_cot_tokens = int(num_cot_tokens) + model._latent_grpo_tokenizer = tokenizer + model._latent_delta_scale = float(latent_delta_scale) + model._latent_delta_max_ratio = float(latent_delta_max_ratio) + return model + + model._latent_grpo_interface_installed = True + model._latent_grpo_num_cot_tokens = int(num_cot_tokens) + model._latent_grpo_tokenizer = tokenizer + model._latent_delta_scale = float(latent_delta_scale) + model._latent_delta_max_ratio = float(latent_delta_max_ratio) + model._latent_original_forward = model.forward + model._latent_original_generate = model.generate + + def latent_forward( + self, + input_ids=None, + attention_mask=None, + logits_to_keep=None, + use_cache=None, + **kwargs, + ): + if input_ids is None or attention_mask is None or logits_to_keep is None: + return self._latent_original_forward( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=use_cache, + **kwargs, + ) + + seq_len = int(input_ids.shape[1]) + keep = max(1, min(int(logits_to_keep), seq_len)) + start = max(1, seq_len - keep) + logits = [] + for prefix_len in range(start, seq_len + 1): + prefix_ids = input_ids[:, :prefix_len] + prefix_mask = attention_mask[:, :prefix_len] + step_logits = residual_next_token_logits_from_ids( + self, + prefix_ids, + prefix_mask, + int(self._latent_grpo_num_cot_tokens), + ) + logits.append(step_logits.unsqueeze(1)) + return CausalLMOutput(logits=torch.cat(logits, dim=1)) + + @torch.no_grad() + def latent_generate(self, input_ids=None, attention_mask=None, generation_config=None, **kwargs): + if input_ids is None or attention_mask is None: + return self._latent_original_generate( + input_ids=input_ids, + attention_mask=attention_mask, + generation_config=generation_config, + **kwargs, + ) + + tokenizer_local = self._latent_grpo_tokenizer + max_new_tokens = int( + getattr(generation_config, "max_new_tokens", None) or kwargs.get("max_new_tokens") or 16 + ) + do_sample = bool(getattr(generation_config, "do_sample", True)) + temperature = float(getattr(generation_config, "temperature", 1.0)) + top_p = float(getattr(generation_config, "top_p", 1.0)) + top_k = int(getattr(generation_config, "top_k", 0)) + repetition_penalty = float(getattr(generation_config, "repetition_penalty", 1.0)) + pad_token_id = getattr(generation_config, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = getattr(tokenizer_local, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = getattr(tokenizer_local, "eos_token_id", 0) + + rows = [] + for row_ids, row_mask in zip(input_ids, attention_mask, strict=True): + row_prompt = row_ids.unsqueeze(0) + row_attn = row_mask.unsqueeze(0) + completion = sample_latent_completion( + self, + tokenizer_local, + row_prompt, + row_attn, + num_cot_tokens=int(self._latent_grpo_num_cot_tokens), + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repetition_penalty=repetition_penalty, + ) + rows.append(torch.cat([row_prompt, completion], dim=1).squeeze(0)) + + max_len = max(int(row.shape[0]) for row in rows) + padded = [] + for row in rows: + if int(row.shape[0]) < max_len: + pad = torch.full((max_len - int(row.shape[0]),), int(pad_token_id), device=row.device, dtype=row.dtype) + row = torch.cat([row, pad], dim=0) + padded.append(row) + return torch.stack(padded, dim=0) + + model.forward = MethodType(latent_forward, model) + model.generate = MethodType(latent_generate, model) + return model + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: torch.nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + prompt_ids = enc["input_ids"].to(device) + attn = enc["attention_mask"].to(device) + completion_ids = sample_latent_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + pred_text = tokenizer.decode(completion_ids[0], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[latent grpo eval debug] target=({rr+1},{cc+1}) output={pred_text!r}") + print(f"[latent grpo eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}") + printed += 1 + solve_ok += int(row_all_exact) + return { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + "eval_cells": float(total_cells), + } + + +class ResidualProjectorEvalCallback(TrainerCallback): + def __init__( + self, + args: Args, + eval_rows: List[Dict[str, Any]], + tokenizer: Any, + device: torch.device, + wb_run: Any, + is_main_process: bool, + ): + self.args = args + self.eval_rows = eval_rows + self.tokenizer = tokenizer + self.device = device + self.wb_run = wb_run + self.is_main_process = is_main_process + self.last_logged_step = -1 + + def on_step_end(self, args, state, control, **kwargs): + step = int(state.global_step) + if ( + not self.is_main_process + or step <= 0 + or step == self.last_logged_step + or step % int(self.args.eval_steps) != 0 + ): + return control + model = kwargs.get("model") + if model is None: + return control + if torch.cuda.is_available(): + torch.cuda.empty_cache() + metrics = run_eval( + args=self.args, + rows=self.eval_rows, + model=unwrap_training_model(model), + tokenizer=self.tokenizer, + device=self.device, + ) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + self.last_logged_step = step + print( + f"[latent grpo custom eval step {step}] parse={metrics['parse_rate']:.3f} " + f"avg_set_size={metrics['avg_predicted_set_size']:.3f} " + f"good={metrics['avg_num_i_consistent_values']:.3f} " + f"bad={metrics['avg_num_non_i_consistent_values']:.3f}" + ) + if self.args.use_wandb and self.wb_run is not None: + payload = {f"custom_eval/{k}": float(v) for k, v in metrics.items()} + payload["custom_eval/global_step"] = float(step) + wandb.log(payload) + if ( + int(step) >= int(self.args.min_steps_before_stop) + and float(self.args.eval_solve_rate_stop) > 0.0 + and float(metrics["solve_rate"]) >= float(self.args.eval_solve_rate_stop) + ): + print( + f"[latent grpo custom eval step {step}] early stop: " + f"solve_rate={metrics['solve_rate']:.3f} >= {float(self.args.eval_solve_rate_stop):.3f}" + ) + control.should_training_stop = True + return control + + +class SaveLatentStateCallback(TrainerCallback): + def __init__(self, is_main_process: bool): + self.is_main_process = is_main_process + + def on_save(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is None: + return control + step_dir = os.path.join(args.output_dir, f"checkpoint-{int(state.global_step)}") + if os.path.isdir(step_dir): + save_latent_projector_state(unwrap_training_model(model), step_dir) + return control + + +class FinalCheckpointCallback(TrainerCallback): + def __init__(self, output_dir: str, tokenizer: Any, is_main_process: bool): + self.output_dir = output_dir + self.tokenizer = tokenizer + self.is_main_process = is_main_process + + def _save(self, model: Any) -> None: + save_model_artifacts( + unwrap_training_model(model), + self.tokenizer, + ensure_final_checkpoint_dir(self.output_dir), + extra_save_fn=save_latent_projector_state, + ) + + def on_save(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + def on_train_end(self, args, state, control, **kwargs): + if not self.is_main_process: + return control + model = kwargs.get("model") + if model is not None: + self._save(model) + return control + + +class WallClockStopCallback(TrainerCallback): + def __init__(self, max_wall_clock_seconds: int): + self.max_wall_clock_seconds = int(max_wall_clock_seconds) + self.start_time = time.time() + + def on_step_end(self, args, state, control, **kwargs): + if self.max_wall_clock_seconds > 0 and (time.time() - self.start_time) >= float(self.max_wall_clock_seconds): + control.should_training_stop = True + return control + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument( + "--train_jsonl", + type=str, + default="/egr/research-slim/ghoshavr/curriculum-CoT/sudoku/llm_policy_icon/data/sudoku4x4_11empty_value_qwen_text.jsonl", + ) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum_cot/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, required=True) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--num_cot_tokens", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=4) + p.add_argument("--gradient_accumulation_steps", type=int, default=4) + p.add_argument("--num_train_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--logging_steps", type=int, default=5) + p.add_argument("--save_steps", type=int, default=10) + p.add_argument("--eval_steps", type=int, default=25) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--num_generations", type=int, default=2) + p.add_argument("--max_prompt_length", type=int, default=1024) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--beta", type=float, default=0.0) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--lora_r", type=int, default=192) + p.add_argument("--lora_alpha", type=int, default=384) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-latent-multi-output-grpo-residual-projector") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--wandb_group", type=str, default="") + p.add_argument("--wandb_run_id", type=str, default="") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) + p.add_argument("--min_steps_before_stop", type=int, default=0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + p.add_argument("--resume_from_checkpoint", type=str, default="") + return Args(**vars(p.parse_args())) + + +def main() -> None: + args = parse_args() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_main_process = rank == 0 + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + if preset_visible_devices: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}") + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank if world_size > 1 else max(0, int(args.gpu_id))) + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + print(f"Using Hugging Face cache dir: {cache_dir}") + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + wb_run = wandb.init( + project=args.wandb_project, + entity=args.wandb_entity or None, + name=args.wandb_run_name or None, + mode=args.wandb_mode, + group=args.wandb_group or None, + id=args.wandb_run_id or None, + ) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/records_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_rows = rows[: max(1, int(args.eval_rows))] + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter(base, args.init_adapter_dir) + print(f"Loaded init adapter: {args.init_adapter_dir}") + projector_hidden = infer_projector_hidden_from_state(args.init_adapter_dir) or PROJECTOR_HIDDEN + attach_residual_projector_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + projector_hidden=projector_hidden, + ) + maybe_load_projector_state(model, args.init_adapter_dir) + print(f"Loaded latent/projector state from: {args.init_adapter_dir}") + if world_size <= 1: + model.to(device) + model.train() + + def on_prep_progress(*, row_idx: int, total_rows: int, record_count: int) -> None: + if not is_main_process: + return + print( + f"[dataset build][grpo stage {args.stage_i}] rows={row_idx}/{total_rows} records={record_count}", + flush=True, + ) + if wb_run is not None: + wandb.log( + { + "prep/rows_done": float(row_idx), + "prep/rows_total": float(total_rows), + "prep/records_built": float(record_count), + } + ) + + train_records = load_or_build_grpo_records( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + train_dataset = Dataset.from_list(train_records) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_grpo_cache_path(args))), + "prep/records_final": float(len(train_records)), + } + ) + reward_func = make_reward_func(args) + + ensure_trl_fsdp_compat() + from trl import GRPOConfig, GRPOTrainer + + if int(args.limit_train_rows) > 0 and int(args.max_steps) <= 0: + args.max_steps = 1 + config = GRPOConfig( + output_dir=args.output_dir, + per_device_train_batch_size=args.per_device_train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + num_train_epochs=args.num_train_epochs, + learning_rate=args.learning_rate, + logging_steps=args.logging_steps, + save_steps=args.save_steps, + eval_strategy="no", + do_eval=False, + max_completion_length=args.max_completion_length, + num_generations=args.num_generations, + beta=args.beta, + gradient_checkpointing=bool(args.enable_gradient_checkpointing), + bf16=(pick_dtype() == torch.bfloat16), + report_to=(["wandb"] if args.use_wandb and is_main_process else []), + remove_unused_columns=False, + max_steps=int(args.max_steps), + ) + + trainer = GRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[reward_func], + args=config, + train_dataset=train_dataset, + ) + trainer.add_callback(ResidualProjectorEvalCallback(args, eval_rows, tokenizer, device, wb_run, is_main_process)) + trainer.add_callback(SaveLatentStateCallback(is_main_process)) + trainer.add_callback(FinalCheckpointCallback(args.output_dir, tokenizer, is_main_process)) + trainer.add_callback(WallClockStopCallback(args.max_wall_clock_seconds)) + trainer.train(resume_from_checkpoint=args.resume_from_checkpoint or None) + + if hasattr(trainer, "accelerator"): + trainer.accelerator.wait_for_everyone() + final_model = trainer.accelerator.unwrap_model(trainer.model) if hasattr(trainer, "accelerator") else trainer.model + final_model = unwrap_training_model(final_model) + if is_main_process: + eval_metrics = run_eval(args=args, rows=eval_rows, model=final_model, tokenizer=tokenizer, device=device) + print( + f"[latent grpo final eval] parse={eval_metrics['parse_rate']:.3f} " + f"canonical={eval_metrics['strict_canonical_rate']:.3f} " + f"exact={eval_metrics['exact_set_match_rate']:.3f} " + f"precision={eval_metrics['value_precision']:.3f} " + f"recall={eval_metrics['value_recall']:.3f} solve={eval_metrics['solve_rate']:.3f}" + ) + trainer.save_model(args.output_dir) + save_latent_projector_state(final_model, args.output_dir) + save_model_artifacts( + final_model, + tokenizer, + ensure_final_checkpoint_dir(args.output_dir), + extra_save_fn=save_latent_projector_state, + ) + if wb_run is not None: + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/launch_4x4_11empty_baseline_pipeline.sh b/sudoku4x4_11empty/launch_4x4_11empty_baseline_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..4eb2ad25d67e4ff6bfa9433e8c16b9751d83805b --- /dev/null +++ b/sudoku4x4_11empty/launch_4x4_11empty_baseline_pipeline.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${SCRIPT_DIR}/run_4x4_11empty_baseline_pipeline.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku4x4_11empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/sudoku4x4_11empty/baseline}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_11empty_4stage_4x4}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${CHECKPOINT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-11}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --grpo_eval_solve_rate_stop "${GRPO_EVAL_SOLVE_RATE_STOP:-0.65}" + --grpo_min_steps_before_stop "${GRPO_MIN_STEPS_BEFORE_STOP:-3000}" + --grpo_limit_train_examples "${GRPO_LIMIT_TRAIN_EXAMPLES:-0}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching 4x4 baseline pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/sudoku4x4_11empty/launch_4x4_11empty_latent_pipeline.sh b/sudoku4x4_11empty/launch_4x4_11empty_latent_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..7f9d7cabd1f5c953b77bd8213ace4a422148fcd3 --- /dev/null +++ b/sudoku4x4_11empty/launch_4x4_11empty_latent_pipeline.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}" +PIPELINE="${SCRIPT_DIR}/run_4x4_11empty_latent_pipeline.py" +TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku4x4_11empty_value_qwen_text.jsonl}" +CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}" +MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}" +GPU_IDS="${GPU_IDS:-0}" +NUM_PROCESSES="${NUM_PROCESSES:-1}" +MIN_STAGE="${MIN_STAGE:-1}" +MAX_STAGE="${MAX_STAGE:-4}" +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}" +BASELINE_CHECKPOINT_ROOT="${BASELINE_CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/sudoku4x4_11empty/baseline}" +LATENT_CHECKPOINT_ROOT="${LATENT_CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/sudoku4x4_11empty/latent}" +BASELINE_OUTPUT_ROOT="${BASELINE_OUTPUT_ROOT:-${BASELINE_CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_11empty_4stage_4x4}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${LATENT_CHECKPOINT_ROOT}/${RUN_TAG}/latent_pipeline_11empty_4stage_4x4}" + +cmd=( + "${PYTHON_BIN}" "${PIPELINE}" + --python_executable "${PYTHON_BIN}" + --train_jsonl "${TRAIN_JSONL}" + --cache_dir "${CACHE_DIR}" + --model_name "${MODEL_NAME}" + --checkpoint_root "${LATENT_CHECKPOINT_ROOT}" + --baseline_output_root "${BASELINE_OUTPUT_ROOT}" + --output_root "${OUTPUT_ROOT}" + --run_tag "${RUN_TAG}" + --min_stage "${MIN_STAGE}" + --max_stage "${MAX_STAGE}" + --distributed_gpu_ids "${GPU_IDS}" + --sft_num_processes "${NUM_PROCESSES}" + --grpo_num_processes "${NUM_PROCESSES}" + --total_empties_hint "${TOTAL_EMPTIES_HINT:-11}" + --sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}" + --grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-1.0}" + --sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}" + --grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-4}" + --grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}" + --grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}" + --grpo_eval_solve_rate_stop "${GRPO_EVAL_SOLVE_RATE_STOP:-0.65}" + --grpo_min_steps_before_stop "${GRPO_MIN_STEPS_BEFORE_STOP:-3000}" + --sft_enable_gradient_checkpointing + --grpo_enable_gradient_checkpointing + --sft_save_steps "${SFT_SAVE_STEPS:-100}" + --sft_eval_steps "${SFT_EVAL_STEPS:-100}" + --grpo_save_steps "${GRPO_SAVE_STEPS:-25}" + --grpo_eval_steps "${GRPO_EVAL_STEPS:-25}" + --phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}" + --wandb_mode "${WANDB_MODE:-offline}" +) + +if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then + cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}") +fi + +if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then + cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}") +fi + +if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then + cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}") +fi + +if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}") +fi + +if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then + cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}") +fi + +if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then + cmd+=(--use_wandb) +fi + +if [[ -n "${WANDB_ENTITY:-}" ]]; then + cmd+=(--wandb_entity "${WANDB_ENTITY}") +fi + +printf 'Launching 4x4 latent pipeline on GPUs %s\n' "${GPU_IDS}" +printf 'Baseline root: %s\n' "${BASELINE_OUTPUT_ROOT}" +printf 'Latent output root: %s\n' "${OUTPUT_ROOT}" +printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}" + +"${cmd[@]}" diff --git a/sudoku4x4_11empty/prompt_builder.py b/sudoku4x4_11empty/prompt_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aad26407e429eaa6c824b6123353e94e256cb5 --- /dev/null +++ b/sudoku4x4_11empty/prompt_builder.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import Any, Optional + +import numpy as np + + +def grid_to_text(grid_4x4: np.ndarray) -> str: + grid = np.asarray(grid_4x4, dtype=int).reshape(4, 4) + return "\n".join(" ".join(str(int(value)) for value in row) for row in grid.tolist()) + + +def make_multi_output_system_prompt(*, stage_i: int, total_empties_hint: int = 11) -> str: + i = max(1, int(stage_i)) + return ( + "You are a 4x4 Sudoku value policy.\n" + f"This setup uses 4x4 Sudoku puzzles with about {int(total_empties_hint)} empty cells.\n" + "Each row, each column, and each 2x2 box must contain the numbers 1 through 4 exactly once.\n" + "You will be given one target empty cell.\n" + 'Return ONLY one JSON object of the form {"values":[...]}.\n' + 'The JSON object must contain exactly one key named "values".\n' + 'The "values" field must be a JSON array of unique integers in [1,4].\n' + "You may return as many candidate values as you want, including one, several, or many values.\n" + "Choose the number of returned values yourself based on which values seem i-consistent.\n" + "The order of the values does not matter.\n" + "Do not output any explanation, markdown, punctuation outside JSON, or extra text.\n" + f"Current stage objective: i={i} consistency.\n" + ) + + +def build_multi_output_cell_prompt( + grid_4x4: np.ndarray, + *, + target_cell: tuple[int, int], + stage_i: int, + tokenizer: Any, + turn_idx: int, + total_turns: int, + prev_output_flag: Optional[str] = None, + total_empties_hint: int = 11, +) -> str: + g = np.asarray(grid_4x4, dtype=int).reshape(4, 4) + empties = int(np.sum(g == 0)) + rr, cc = int(target_cell[0]), int(target_cell[1]) + system_msg = make_multi_output_system_prompt( + stage_i=stage_i, total_empties_hint=total_empties_hint + ).strip() + empty_locs = [(int(r) + 1, int(c) + 1) for r, c in np.argwhere(g == 0).tolist()] + empty_locs_text = ", ".join(f"({r},{c})" for r, c in empty_locs) + user_msg = ( + "4x4 Sudoku grid (0 means empty):\n" + + grid_to_text(g) + + "\n" + + f"Empty cells in row-major order ({empties} total): {empty_locs_text}\n\n" + + f"Target cell to fill now: ({rr + 1},{cc + 1}).\n" + + f"Turn: {int(turn_idx)}/{int(total_turns)}.\n" + + 'Return only JSON with candidate values for this target cell: {"values":[...]}' + ) + if prev_output_flag is not None: + user_msg += f"\nPrevious output_flag (context only): {prev_output_flag}" + + chat_template = getattr(tokenizer, "chat_template", None) + if chat_template: + messages = [ + {"role": "system", "content": system_msg}, + {"role": "user", "content": user_msg}, + ] + return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + return system_msg + "\n\n" + user_msg + "\n" diff --git a/sudoku4x4_11empty/residual_projector_warmstart_sft_latent_multi_output_train.py b/sudoku4x4_11empty/residual_projector_warmstart_sft_latent_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4de45f32bd68cad011be5688b60a15dd92aab8 --- /dev/null +++ b/sudoku4x4_11empty/residual_projector_warmstart_sft_latent_multi_output_train.py @@ -0,0 +1,664 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import os +import sys +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from peft import PeftModel +from torch.optim import AdamW +from torch.utils.data import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from sudoku4x4_11empty.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final +from sudoku4x4_11empty.prompt_builder import build_multi_output_cell_prompt +from sudoku4x4_11empty.rewards import score_prediction_text +from sudoku4x4_11empty.shared_multi_output_policy import build_supervised_completion, make_solved_grid_from_row +from sudoku4x4_11empty.grpo_residual_projector_latent_train import ( + PROJECTOR_HIDDEN, + attach_residual_projector_modules, + build_latent_hidden, + configure_hf_cache, + extend_attention_mask, + get_output_embeddings_module, + infer_projector_hidden_from_state, + load_jsonl_rows, + load_trainable_adapter, + maybe_load_projector_state, + pick_dtype, + project_hidden, + sample_latent_completion, + save_latent_projector_state, + unwrap_backbone, +) + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + num_cot_tokens: int + total_empties_hint: int + gradient_accumulation_steps: int + num_epochs: float + learning_rate: float + weight_decay: float + enable_gradient_checkpointing: bool + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + max_completion_length: int + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + debug_print_limit: int + limit_train_rows: int + eval_exact_set_match_stop: float + reward_good_value: float + penalty_bad_value: float + penalty_malformed: float + penalty_empty: float + penalty_singleton: float + max_wall_clock_seconds: int + max_steps: int + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def build_training_examples( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +): + examples = [] + total_rows = len(rows) + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + examples.append( + { + "prompt_text": prompt, + "completion_text": build_supervised_completion(ex, stage_i=stage_i), + "grid": ex.grid, + "solved": solved, + "target_cell": ex.target_cell, + } + ) + if progress_callback is not None and ( + row_idx == total_rows or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx=row_idx, total_rows=total_rows, example_count=len(examples)) + return examples + + +def _prepared_data_dir() -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "sudoku4x4_11empty_latent") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_sft_cache_path(args: Args) -> str: + payload = { + "kind": "sft", + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + } + digest = hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:20] + return os.path.join(_prepared_data_dir(), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(row, ensure_ascii=True) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_seconds: int = 6 * 60 * 60) -> None: + start = time.time() + while not os.path.exists(path): + if (time.time() - start) > float(timeout_seconds): + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_sft_examples( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_sft_cache_path(args) + if os.path.exists(cache_path): + if rank == 0: + print(f"[dataset build][sft stage {args.stage_i}] loading prepared cache: {cache_path}", flush=True) + return _read_jsonl(cache_path) + + if rank == 0: + print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + built = build_training_examples( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + serializable = [ + { + "prompt_text": ex["prompt_text"], + "completion_text": ex["completion_text"], + } + for ex in built + ] + _write_jsonl(cache_path, serializable) + elif world_size > 1: + _wait_for_cache(cache_path) + + if world_size > 1 and dist.is_initialized(): + dist.barrier() + return _read_jsonl(cache_path) + + +def residual_next_token_logits_from_ids( + model: nn.Module, input_ids: torch.Tensor, attention_mask: torch.Tensor, num_cot_tokens: int +) -> torch.Tensor: + base_hidden, latent_hidden = build_latent_hidden(model, input_ids, attention_mask, num_cot_tokens) + final_hidden = base_hidden + project_hidden(model, latent_hidden - base_hidden) + return get_output_embeddings_module(model)(final_hidden) + + +def latent_residual_completion_ce_loss( + model: nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, + *, + num_cot_tokens: int, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors="pt", add_special_tokens=False).input_ids.to( + device + ) + completion_ids = all_ids[:, int(prompt_ids.shape[1]) :] + if int(completion_ids.shape[1]) <= 0: + return torch.zeros((), device=device, dtype=torch.float32, requires_grad=True) + + cur_ids = prompt_ids + cur_mask = torch.ones_like(prompt_ids, device=device) + token_losses: List[torch.Tensor] = [] + for idx in range(int(completion_ids.shape[1])): + logits = residual_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot_tokens) + target = completion_ids[:, idx] + token_losses.append(F.cross_entropy(logits.float(), target, reduction="mean")) + cur_ids = torch.cat([cur_ids, completion_ids[:, idx : idx + 1]], dim=1) + cur_mask = extend_attention_mask(cur_mask, 1) + return torch.stack(token_losses, dim=0).mean() + + +@torch.no_grad() +def run_eval( + *, + args: Args, + rows: List[Dict[str, Any]], + model: nn.Module, + tokenizer: Any, + device: torch.device, +) -> Dict[str, float]: + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + prompt_ids = enc["input_ids"].to(device) + attn = enc["attention_mask"].to(device) + completion_ids = sample_latent_completion( + model, + tokenizer, + prompt_ids, + attn, + num_cot_tokens=args.num_cot_tokens, + max_new_tokens=args.max_completion_length, + do_sample=False, + ) + pred_text = tokenizer.decode(completion_ids[0], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=args.reward_good_value, + penalty_bad_value=args.penalty_bad_value, + penalty_malformed=args.penalty_malformed, + penalty_empty=args.penalty_empty, + penalty_singleton=args.penalty_singleton, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[latent sft eval debug] target=({rr+1},{cc+1}) output={pred_text!r}") + print(f"[latent sft eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}") + printed += 1 + solve_ok += int(row_all_exact) + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + } + print( + f"[latent sft eval] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}" + ) + model.train() + return out + + +def save_checkpoint(model: nn.Module, tokenizer: Any, output_dir: str, step: int) -> None: + save_checkpoint_and_update_final( + model, + tokenizer, + output_dir, + f"checkpoint-step-{step:05d}", + extra_save_fn=save_latent_projector_state, + ) + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument( + "--train_jsonl", + type=str, + default="/egr/research-slim/ghoshavr/curriculum-CoT/sudoku/llm_policy_icon/data/sudoku4x4_11empty_value_qwen_text.jsonl", + ) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum_cot/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, required=True) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=2) + p.add_argument("--num_cot_tokens", type=int, default=2) + p.add_argument("--total_empties_hint", type=int, default=20) + p.add_argument("--gradient_accumulation_steps", type=int, default=8) + p.add_argument("--num_epochs", type=float, default=0.5) + p.add_argument("--learning_rate", type=float, default=1e-6) + p.add_argument("--weight_decay", type=float, default=0.0) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--logging_steps", type=int, default=10) + p.add_argument("--save_steps", type=int, default=100) + p.add_argument("--eval_steps", type=int, default=100) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-latent-multi-output-sft-residual-projector") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) + p.add_argument("--reward_good_value", type=float, default=1.0) + p.add_argument("--penalty_bad_value", type=float, default=1.75) + p.add_argument("--penalty_malformed", type=float, default=4.0) + p.add_argument("--penalty_empty", type=float, default=0.5) + p.add_argument("--penalty_singleton", type=float, default=1.5) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + return Args(**vars(p.parse_args())) + + +def main() -> None: + args = parse_args() + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_distributed = world_size > 1 + if torch.cuda.is_available(): + if is_distributed: + torch.cuda.set_device(local_rank) + else: + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + if not preset_visible_devices and int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + if is_distributed and not dist.is_initialized(): + dist.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + timeout=timedelta(hours=2), + ) + is_main_process = rank == 0 + + set_seed(args.seed + rank) + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + if is_distributed and dist.is_initialized(): + dist.barrier() + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_rows = rows[: max(1, int(args.eval_rows))] + + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + if torch.cuda.is_available(): + device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") + else: + device = torch.device("cpu") + + base = AutoModelForCausalLM.from_pretrained( + args.model_name, + cache_dir=cache_dir, + torch_dtype=pick_dtype(), + low_cpu_mem_usage=True, + ) + model = load_trainable_adapter(base, args.init_adapter_dir) + projector_hidden = infer_projector_hidden_from_state(args.init_adapter_dir) or PROJECTOR_HIDDEN + attach_residual_projector_modules( + model, + hidden_size=int(unwrap_backbone(model).config.hidden_size), + projector_hidden=projector_hidden, + ) + maybe_load_projector_state(model, args.init_adapter_dir) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + backbone = unwrap_backbone(model) + if hasattr(backbone, "config"): + backbone.config.use_cache = False + model.to(device) + model.train() + + def on_prep_progress(*, row_idx: int, total_rows: int, example_count: int) -> None: + if not is_main_process: + return + print( + f"[dataset build][sft stage {args.stage_i}] rows={row_idx}/{total_rows} examples={example_count}", + flush=True, + ) + if wb_run is not None: + wandb.log( + { + "prep/rows_done": float(row_idx), + "prep/rows_total": float(total_rows), + "prep/examples_built": float(example_count), + } + ) + + train_examples = load_or_build_sft_examples( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_sft_cache_path(args))), + "prep/examples_final": float(len(train_examples)), + } + ) + optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) + total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / max(1, args.gradient_accumulation_steps))) + if int(args.max_steps) > 0: + total_steps = min(total_steps, int(args.max_steps)) + step = 0 + start_time = time.time() + + def average_scalar(value: float) -> float: + if not is_distributed or not dist.is_initialized(): + return float(value) + tensor = torch.tensor(float(value), device=device, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float((tensor / float(world_size)).item()) + + def all_reduce_gradients() -> None: + if not is_distributed or not dist.is_initialized(): + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(float(world_size)) + + def sync_stop(local_stop: bool) -> bool: + if not is_distributed or not dist.is_initialized(): + return bool(local_stop) + tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item()) > 0) + + for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): + if is_distributed: + sampler = DistributedSampler( + train_examples, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=args.seed, + drop_last=False, + ) + sampler.set_epoch(epoch_idx) + order = list(iter(sampler)) + else: + generator = torch.Generator() + generator.manual_seed(args.seed + epoch_idx) + order = torch.randperm(len(train_examples), generator=generator).tolist() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + for idx, ex_idx in enumerate(order, start=1): + ex = train_examples[ex_idx] + loss = latent_residual_completion_ce_loss( + model, + tokenizer, + ex["prompt_text"], + ex["completion_text"], + device, + num_cot_tokens=args.num_cot_tokens, + ) / max(1, int(args.gradient_accumulation_steps)) + loss.backward() + accum_count += 1 + if accum_count >= int(args.gradient_accumulation_steps): + all_reduce_gradients() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + step += 1 + if step % int(args.logging_steps) == 0: + loss_value = average_scalar(float(loss.item()) * args.gradient_accumulation_steps) + if is_main_process: + print(f"[latent sft train step {step:05d}] loss={loss_value:.4f}", flush=True) + if wb_run is not None: + wandb.log({"train/loss": loss_value, "step": step}) + if step % int(args.eval_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + should_stop_eval = False + if is_main_process: + metrics = run_eval(args=args, rows=eval_rows, model=model, tokenizer=tokenizer, device=device) + if wb_run is not None: + wandb.log({f"eval/{k}": float(v) for k, v in metrics.items()} | {"step": step}) + if ( + args.eval_exact_set_match_stop > 0.0 + and float(metrics["exact_set_match_rate"]) >= args.eval_exact_set_match_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + should_stop_eval = sync_stop(should_stop_eval) + if is_distributed and dist.is_initialized(): + dist.barrier() + if should_stop_eval: + if is_main_process and wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.destroy_process_group() + return + if step % int(args.save_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and ( + time.time() - start_time >= float(args.max_wall_clock_seconds) + ) + if sync_stop(reached_limit or exceeded_wall): + break + if accum_count > 0: + all_reduce_gradients() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + step += 1 + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and (time.time() - start_time >= float(args.max_wall_clock_seconds)) + if sync_stop(reached_limit or exceeded_wall): + break + + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + if is_main_process and wb_run is not None: + wb_run.finish() + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/rewards.py b/sudoku4x4_11empty/rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5c7f340cf5768d6329ccb9b1a593856191b8eb --- /dev/null +++ b/sudoku4x4_11empty/rewards.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Dict, List + +import numpy as np + +from sudoku4x4_11empty.shared_multi_output_policy import ( + compute_set_precision_recall, + parse_values_json, + stage_i_consistent_values, +) + + +def triangular_number(n: int) -> float: + nn = max(0, int(n)) + return float(nn * (nn + 1) // 2) + + +def score_prediction_text( + *, + text: str, + grid: np.ndarray, + solved: np.ndarray, + target_cell: tuple[int, int], + stage_i: int, + reward_good_value: float, + penalty_bad_value: float, + penalty_malformed: float, + penalty_empty: float, + penalty_singleton: float, +) -> Dict[str, float | List[int] | str]: + parsed = parse_values_json(text) + target_values = stage_i_consistent_values(grid, target_cell=target_cell, stage_i=stage_i) + solved_value = int(np.asarray(solved, dtype=int).reshape(4, 4)[int(target_cell[0]), int(target_cell[1])]) + singleton_penalty = 0.0 if int(stage_i) >= 2 else float(penalty_singleton) + + if not parsed.parse_ok: + return { + 'reward': -float(penalty_malformed), + 'parse_ok': 0.0, + 'strict_canonical': 0.0, + 'num_predicted_values': 0.0, + 'num_i_consistent_values': 0.0, + 'num_non_i_consistent_values': 0.0, + 'includes_ground_truth': 0.0, + 'value_precision': 0.0, + 'value_recall': 0.0, + 'exact_set_match': 0.0, + 'predicted_values': [], + 'target_values': [int(v) for v in target_values], + 'format_error': 'parse_failed', + } + + predicted_values = [int(v) for v in parsed.values] + target_set = set(int(v) for v in target_values) + num_good = sum(1 for v in predicted_values if v in target_set) + num_bad = sum(1 for v in predicted_values if v not in target_set) + + reward = triangular_number(num_good) * float(reward_good_value) - float(num_bad) * float(penalty_bad_value) + if not predicted_values: + reward -= float(penalty_empty) + if len(predicted_values) == 1 and len(target_values) > 1: + reward -= singleton_penalty + + precision, recall = compute_set_precision_recall(predicted_values, target_values) + return { + 'reward': float(reward), + 'parse_ok': 1.0, + 'strict_canonical': 1.0 if parsed.strict_canonical else 0.0, + 'num_predicted_values': float(len(predicted_values)), + 'num_i_consistent_values': float(num_good), + 'num_non_i_consistent_values': float(num_bad), + 'includes_ground_truth': 1.0 if solved_value in predicted_values else 0.0, + 'value_precision': float(precision), + 'value_recall': float(recall), + 'exact_set_match': 1.0 if set(predicted_values) == target_set else 0.0, + 'predicted_values': predicted_values, + 'target_values': [int(v) for v in target_values], + 'format_error': '', + } diff --git a/sudoku4x4_11empty/run_4x4_11empty_baseline_pipeline.py b/sudoku4x4_11empty/run_4x4_11empty_baseline_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d264432fd6544fb1aab1107bd7010384f2d015ba --- /dev/null +++ b/sudoku4x4_11empty/run_4x4_11empty_baseline_pipeline.py @@ -0,0 +1,614 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("sudoku4x4_11empty", "baseline")) +DEFAULT_CACHE_DIR = PARENT_DIR / ".hf_cache" +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +DEFAULT_WANDB_GROUP = "sudoku4x4_11empty_baseline_pipeline" +DEFAULT_SFT_PROJECT = "sudoku4x4-11empty-baseline-sft" +DEFAULT_GRPO_PROJECT = "sudoku4x4-11empty-baseline-grpo" + +SFT_SCRIPT = CURRENT_DIR / "sft_multi_output_train.py" +GRPO_SCRIPT = CURRENT_DIR / "grpo_multi_output_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=11) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=4) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default=DEFAULT_WANDB_GROUP) + p.add_argument("--wandb_sft_project", type=str, default=DEFAULT_SFT_PROJECT) + p.add_argument("--wandb_grpo_project", type=str, default=DEFAULT_GRPO_PROJECT) + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=2e-4) + p.add_argument("--sft_learning_rate_later", type=float, default=5e-5) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--grpo_num_train_epochs", type=float, default=0.5) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=2) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--grpo_eval_solve_rate_stop", type=float, default=0.65) + p.add_argument("--grpo_min_steps_before_stop", type=int, default=3000) + p.add_argument("--grpo_limit_train_examples", type=int, default=0) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + if root_adapter.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + if not adapter.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + checkpoint_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + best: Optional[Artifact] = None + for stage_dir in checkpoint_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path( + normalize_to_final_checkpoint_root(args.output_root, "sudoku4x4_11empty", "baseline") + ).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_4x4" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku4x4_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: Optional[str], + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_sft_project, + "--wandb_run_name", + f"sudoku4x4_baseline_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + if init_adapter_dir: + command.extend(["--init_adapter_dir", str(init_adapter_dir)]) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--eval_solve_rate_stop", + str(float(args.grpo_eval_solve_rate_stop)), + "--min_steps_before_stop", + str(int(args.grpo_min_steps_before_stop)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_grpo_project, + "--wandb_run_name", + f"sudoku4x4_baseline_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def main() -> None: + args = parse_args() + checkpoint_root = Path( + normalize_to_final_checkpoint_root(args.checkpoint_root, "sudoku4x4_11empty", "baseline") + ).resolve() + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "model_name": str(args.model_name), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + checkpoint_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + checkpoint_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + elif stage == 1: + init_adapter_dir = str(args.stage1_init_adapter_dir).strip() or None + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch baseline stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/run_4x4_11empty_latent_pipeline.py b/sudoku4x4_11empty/run_4x4_11empty_latent_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5a754a90d88427a0ebbfb65346bdb7fc6be87b --- /dev/null +++ b/sudoku4x4_11empty/run_4x4_11empty_latent_pipeline.py @@ -0,0 +1,656 @@ +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +CURRENT_DIR = Path(__file__).resolve().parent +PARENT_DIR = CURRENT_DIR.parent +if str(PARENT_DIR) not in sys.path: + sys.path.insert(0, str(PARENT_DIR)) + +from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root + + +DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("sudoku4x4_11empty", "latent")) +DEFAULT_BASELINE_ROOT = Path(final_checkpoint_root("sudoku4x4_11empty", "baseline")) +DEFAULT_CACHE_DIR = PARENT_DIR / ".hf_cache" +DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +DEFAULT_WANDB_GROUP = "sudoku4x4_11empty_latent_pipeline" +DEFAULT_SFT_PROJECT = "sudoku4x4-11empty-latent-sft-residual-projector" +DEFAULT_GRPO_PROJECT = "sudoku4x4-11empty-latent-grpo-residual-projector" + +SFT_SCRIPT = CURRENT_DIR / "residual_projector_warmstart_sft_latent_multi_output_train.py" +GRPO_SCRIPT = CURRENT_DIR / "grpo_residual_projector_latent_train.py" +STAGE_COMPLETE_MARKER = "_stage_complete.json" + + +@dataclass +class Artifact: + path: str + stage: int + phase: str + step: int + mtime: float + source_dir: str + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--python_executable", type=str, default=sys.executable) + p.add_argument("--checkpoint_root", type=str, default=str(DEFAULT_CHECKPOINT_ROOT)) + p.add_argument("--baseline_output_root", type=str, default="") + p.add_argument("--output_root", type=str, default="") + p.add_argument("--run_tag", type=str, default="") + p.add_argument("--train_jsonl", type=str, default="") + p.add_argument("--cache_dir", type=str, default=str(DEFAULT_CACHE_DIR)) + p.add_argument("--model_name", type=str, default=DEFAULT_MODEL_NAME) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--total_empties_hint", type=int, default=11) + p.add_argument("--min_stage", type=int, default=1) + p.add_argument("--max_stage", type=int, default=4) + p.add_argument("--sft_gpu_id", type=int, default=0) + p.add_argument("--grpo_gpu_id", type=int, default=1) + p.add_argument("--stage1_init_adapter_dir", type=str, default="") + p.add_argument("--bootstrap_adapter_dir", type=str, default="") + p.add_argument("--distributed_gpu_ids", type=str, default="") + p.add_argument("--sft_num_processes", type=int, default=1) + p.add_argument("--grpo_num_processes", type=int, default=1) + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_group", type=str, default=DEFAULT_WANDB_GROUP) + p.add_argument("--wandb_sft_project", type=str, default=DEFAULT_SFT_PROJECT) + p.add_argument("--wandb_grpo_project", type=str, default=DEFAULT_GRPO_PROJECT) + p.add_argument("--sft_num_epochs", type=float, default=1.0) + p.add_argument("--sft_learning_rate_stage1", type=float, default=1e-6) + p.add_argument("--sft_learning_rate_later", type=float, default=1e-6) + p.add_argument("--sft_gradient_accumulation_steps", type=int, default=8) + p.add_argument("--sft_enable_gradient_checkpointing", action="store_true") + p.add_argument("--sft_logging_steps", type=int, default=10) + p.add_argument("--sft_eval_steps", type=int, default=100) + p.add_argument("--sft_save_steps", type=int, default=100) + p.add_argument("--sft_eval_rows", type=int, default=20) + p.add_argument("--sft_max_completion_length", type=int, default=24) + p.add_argument("--grpo_num_train_epochs", type=float, default=1.0) + p.add_argument("--grpo_learning_rate", type=float, default=1e-6) + p.add_argument("--grpo_per_device_train_batch_size", type=int, default=4) + p.add_argument("--grpo_gradient_accumulation_steps", type=int, default=4) + p.add_argument("--grpo_enable_gradient_checkpointing", action="store_true") + p.add_argument("--grpo_logging_steps", type=int, default=5) + p.add_argument("--grpo_eval_steps", type=int, default=25) + p.add_argument("--grpo_save_steps", type=int, default=25) + p.add_argument("--grpo_eval_rows", type=int, default=20) + p.add_argument("--grpo_num_generations", type=int, default=2) + p.add_argument("--grpo_max_prompt_length", type=int, default=1024) + p.add_argument("--grpo_max_completion_length", type=int, default=24) + p.add_argument("--grpo_beta", type=float, default=0.0) + p.add_argument("--grpo_eval_solve_rate_stop", type=float, default=0.65) + p.add_argument("--grpo_min_steps_before_stop", type=int, default=3000) + p.add_argument("--phase_max_wall_clock_seconds", type=int, default=21600) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--sft_stage_max_steps", type=str, default="") + p.add_argument("--grpo_stage_max_steps", type=str, default="") + p.add_argument("--dry_run", action="store_true") + return p.parse_args() + + +def stage_dir_pattern(stage: int, phase: str, empties: int) -> str: + return f"stage{stage:02d}_{phase}_i{stage}_{empties}empty*" + + +def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]: + match = re.fullmatch(rf"{re.escape(prefix)}(\d+)", name) + return int(match.group(1)) if match else None + + +def stage_complete_path(stage_dir: Path) -> Path: + return stage_dir / STAGE_COMPLETE_MARKER + + +def is_stage_complete(stage_dir: Path) -> bool: + return stage_complete_path(stage_dir).is_file() + + +def output_root_has_stage_artifacts(path: Path) -> bool: + if not path.exists(): + return False + if (path / "pipeline_state.json").exists(): + return True + return any(path.glob("stage[0-9][0-9]_*")) + + +def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-step-") + if step is None: + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="sft", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]: + best: Optional[Artifact] = None + root_adapter = stage_dir / "adapter_model.safetensors" + root_latent_state = stage_dir / "latent_cot_state.pt" + if root_adapter.exists() and root_latent_state.exists(): + best = Artifact( + path=str(stage_dir), + stage=-1, + phase="grpo", + step=10**9, + mtime=stage_dir.stat().st_mtime, + source_dir=str(stage_dir), + ) + for child in stage_dir.iterdir(): + if not child.is_dir(): + continue + step = extract_numeric_suffix(child.name, "checkpoint-") + if step is None: + continue + adapter = child / "adapter_model.safetensors" + latent_state = child / "latent_cot_state.pt" + if not adapter.exists() or not latent_state.exists(): + continue + artifact = Artifact( + path=str(child), + stage=-1, + phase="grpo", + step=step, + mtime=child.stat().st_mtime, + source_dir=str(stage_dir), + ) + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def discover_latest_artifact( + search_root: Path, + *, + stage: int, + phase: str, + empties: int, + require_complete: bool = True, +) -> Optional[Artifact]: + if not search_root.exists(): + return None + best: Optional[Artifact] = None + for stage_dir in search_root.rglob(stage_dir_pattern(stage, phase, empties)): + if not stage_dir.is_dir(): + continue + if require_complete and not is_stage_complete(stage_dir): + continue + artifact = latest_sft_checkpoint(stage_dir) if phase == "sft" else latest_grpo_artifact(stage_dir) + if artifact is None: + continue + artifact.stage = stage + artifact.phase = phase + if best is None or (artifact.mtime, artifact.step) > (best.mtime, best.step): + best = artifact + return best + + +def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path: + if str(args.output_root).strip(): + requested_root = Path( + normalize_to_final_checkpoint_root(args.output_root, "sudoku4x4_11empty", "latent") + ).resolve() + if output_root_has_stage_artifacts(requested_root): + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return requested_root / run_tag + return requested_root + run_tag = str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S") + return checkpoint_root / run_tag / f"latent_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_4x4" + + +def default_train_jsonl(args: argparse.Namespace) -> Path: + if str(args.train_jsonl).strip(): + return Path(args.train_jsonl).resolve() + return (PARENT_DIR / "data" / f"sudoku4x4_{int(args.total_empties_hint)}empty_value_qwen_text.jsonl").resolve() + + +def resolve_baseline_root(args: argparse.Namespace) -> Path: + if str(args.baseline_output_root).strip(): + return Path( + normalize_to_final_checkpoint_root(args.baseline_output_root, "sudoku4x4_11empty", "baseline") + ).resolve() + run_tag = str(args.run_tag).strip() + if run_tag: + return DEFAULT_BASELINE_ROOT / run_tag / f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_4x4" + return DEFAULT_BASELINE_ROOT + + +def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path: + return output_root / f"stage{stage:02d}_{phase}_i{stage}_{empties}empty_residual_projector" + + +def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None: + print("") + print("Running command:") + print(" ".join(subprocess.list2cmdline([part]) for part in command)) + if dry_run: + print("Dry run enabled; command not executed.") + return + subprocess.run(command, env=env, check=True) + + +def parse_stage_int_map(raw: str) -> Dict[int, int]: + mapping: Dict[int, int] = {} + text = str(raw or "").strip() + if not text: + return mapping + for part in text.split(","): + item = part.strip() + if not item: + continue + stage_text, value_text = item.split(":", 1) + mapping[int(stage_text.strip())] = int(value_text.strip()) + return mapping + + +def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int: + return int(mapping.get(int(stage), 0)) + + +def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]: + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + requested = [part.strip() for part in str(gpu_ids or "").split(",") if part.strip()] + if int(num_processes) > 1: + if requested: + env["CUDA_VISIBLE_DEVICES"] = ",".join(requested[: int(num_processes)]) + else: + env["CUDA_VISIBLE_DEVICES"] = str(requested[0] if requested else int(gpu_id)) + env["WANDB__SERVICE_WAIT"] = "300" + env["WANDB_MODE"] = str(wandb_mode) + return env + + +def build_sft_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.sft_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(SFT_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(SFT_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.sft_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--num_epochs", + str(float(args.sft_num_epochs)), + "--learning_rate", + str(float(args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later)), + "--gradient_accumulation_steps", + str(int(args.sft_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.sft_enable_gradient_checkpointing else "", + "--logging_steps", + str(int(args.sft_logging_steps)), + "--save_steps", + str(int(args.sft_save_steps)), + "--eval_steps", + str(int(args.sft_eval_steps)), + "--eval_rows", + str(int(args.sft_eval_rows)), + "--max_completion_length", + str(int(args.sft_max_completion_length)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_sft_project, + "--wandb_run_name", + f"sudoku4x4_latent_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def build_grpo_command( + args: argparse.Namespace, + *, + train_jsonl: Path, + output_dir: Path, + stage: int, + init_adapter_dir: str, + stage_max_steps: int, +) -> List[str]: + num_processes = max(1, int(args.grpo_num_processes)) + if num_processes > 1: + command = [ + args.python_executable, + "-m", + "torch.distributed.run", + "--standalone", + "--nproc_per_node", + str(num_processes), + str(GRPO_SCRIPT), + ] + else: + command = [args.python_executable, "-u", str(GRPO_SCRIPT)] + command.extend( + [ + "--model_name", + args.model_name, + "--train_jsonl", + str(train_jsonl), + "--output_dir", + str(output_dir), + "--cache_dir", + args.cache_dir, + "--init_adapter_dir", + str(init_adapter_dir), + "--seed", + str(int(args.seed)), + "--gpu_id", + str(0 if num_processes > 1 else int(args.grpo_gpu_id)), + "--stage_i", + str(int(stage)), + "--num_cot_tokens", + str(int(stage)), + "--total_empties_hint", + str(int(args.total_empties_hint)), + "--per_device_train_batch_size", + str(int(args.grpo_per_device_train_batch_size)), + "--gradient_accumulation_steps", + str(int(args.grpo_gradient_accumulation_steps)), + "--enable_gradient_checkpointing" if args.grpo_enable_gradient_checkpointing else "", + "--num_train_epochs", + str(float(args.grpo_num_train_epochs)), + "--learning_rate", + str(float(args.grpo_learning_rate)), + "--logging_steps", + str(int(args.grpo_logging_steps)), + "--save_steps", + str(int(args.grpo_save_steps)), + "--eval_steps", + str(int(args.grpo_eval_steps)), + "--eval_rows", + str(int(args.grpo_eval_rows)), + "--num_generations", + str(int(args.grpo_num_generations)), + "--max_prompt_length", + str(int(args.grpo_max_prompt_length)), + "--max_completion_length", + str(int(args.grpo_max_completion_length)), + "--beta", + str(float(args.grpo_beta)), + "--eval_solve_rate_stop", + str(float(args.grpo_eval_solve_rate_stop)), + "--min_steps_before_stop", + str(int(args.grpo_min_steps_before_stop)), + "--max_wall_clock_seconds", + str(int(args.phase_max_wall_clock_seconds)), + "--wandb_group", + args.wandb_group, + ] + ) + command = [part for part in command if part != ""] + if int(args.limit_train_rows) > 0: + command.extend(["--limit_train_rows", str(int(args.limit_train_rows))]) + if int(stage_max_steps) > 0: + command.extend(["--max_steps", str(int(stage_max_steps))]) + if args.use_wandb: + command.extend(["--use_wandb"]) + if str(args.wandb_entity).strip(): + command.extend(["--wandb_entity", args.wandb_entity]) + command.extend( + [ + "--wandb_project", + args.wandb_grpo_project, + "--wandb_run_name", + f"sudoku4x4_latent_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty_residual_projector", + "--wandb_mode", + args.wandb_mode, + ] + ) + return command + + +def write_state(path: Path, payload: Dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None: + write_state( + stage_complete_path(stage_dir), + { + "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "artifact": asdict(artifact), + }, + ) + + +def resolve_stage1_init_adapter(args: argparse.Namespace, baseline_root: Path) -> str: + if str(args.bootstrap_adapter_dir).strip(): + return str(args.bootstrap_adapter_dir).strip() + if str(args.stage1_init_adapter_dir).strip(): + return str(args.stage1_init_adapter_dir).strip() + baseline_stage1 = discover_latest_artifact( + baseline_root, + stage=1, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if baseline_stage1 is None: + raise RuntimeError( + "Latent stage 1 needs a warmstart adapter. Pass --bootstrap_adapter_dir / --stage1_init_adapter_dir " + "or provide --baseline_output_root with a completed baseline stage 1 SFT checkpoint." + ) + return baseline_stage1.path + + +def main() -> None: + args = parse_args() + checkpoint_root = Path( + normalize_to_final_checkpoint_root(args.checkpoint_root, "sudoku4x4_11empty", "latent") + ).resolve() + baseline_root = resolve_baseline_root(args) + output_root = choose_output_root(args, checkpoint_root) + train_jsonl = default_train_jsonl(args) + state_path = output_root / "pipeline_state.json" + sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps) + grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps) + + output_root.mkdir(parents=True, exist_ok=True) + if not train_jsonl.exists(): + raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}") + + state: Dict[str, Any] = { + "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"), + "train_jsonl": str(train_jsonl), + "checkpoint_root": str(checkpoint_root), + "baseline_output_root": str(baseline_root), + "output_root": str(output_root), + "min_stage": int(args.min_stage), + "max_stage": int(args.max_stage), + "total_empties_hint": int(args.total_empties_hint), + "model_name": str(args.model_name), + "stages": [], + } + + previous_grpo: Optional[Artifact] = None + for stage in range(int(args.min_stage), int(args.max_stage) + 1): + stage_record: Dict[str, Any] = {"stage": stage} + existing_sft = discover_latest_artifact( + output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint) + ) + existing_grpo = discover_latest_artifact( + output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint) + ) + + if existing_grpo is not None: + previous_grpo = existing_grpo + stage_record["status"] = "using_existing_grpo" + stage_record["grpo_artifact"] = asdict(existing_grpo) + if existing_sft is not None: + stage_record["sft_artifact"] = asdict(existing_sft) + state["stages"].append(stage_record) + write_state(state_path, state) + print(f"Stage {stage}: using existing latent GRPO artifact {existing_grpo.path}") + continue + + if existing_sft is None: + sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=int(args.total_empties_hint)) + if stage == 1: + init_adapter_dir = resolve_stage1_init_adapter(args, baseline_root) + elif stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip(): + init_adapter_dir = str(args.bootstrap_adapter_dir).strip() + else: + if previous_grpo is None: + raise RuntimeError(f"Missing previous GRPO artifact needed to launch latent stage {stage} SFT.") + init_adapter_dir = previous_grpo.path + print(f"Stage {stage}: launching latent SFT into {sft_output_dir}") + run_command( + build_sft_command( + args, + train_jsonl=train_jsonl, + output_dir=sft_output_dir, + stage=stage, + init_adapter_dir=init_adapter_dir, + stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.sft_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.sft_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_sft = discover_latest_artifact( + output_root, + stage=stage, + phase="sft", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_sft is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent SFT checkpoint for stage {stage} after running SFT.") + if existing_sft is not None: + mark_stage_complete(Path(existing_sft.source_dir), existing_sft) + stage_record["sft_artifact"] = asdict(existing_sft) + else: + stage_record["sft_artifact"] = asdict(existing_sft) + print(f"Stage {stage}: using existing latent SFT artifact {existing_sft.path}") + + if existing_sft is None: + stage_record["status"] = "dry_run_pending_grpo" + state["stages"].append(stage_record) + write_state(state_path, state) + break + + grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=int(args.total_empties_hint)) + print(f"Stage {stage}: launching latent GRPO into {grpo_output_dir}") + run_command( + build_grpo_command( + args, + train_jsonl=train_jsonl, + output_dir=grpo_output_dir, + stage=stage, + init_adapter_dir=existing_sft.path, + stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage), + ), + env=make_env( + gpu_id=int(args.grpo_gpu_id), + wandb_mode=args.wandb_mode, + gpu_ids=args.distributed_gpu_ids, + num_processes=int(args.grpo_num_processes), + ), + dry_run=bool(args.dry_run), + ) + existing_grpo = discover_latest_artifact( + output_root, + stage=stage, + phase="grpo", + empties=int(args.total_empties_hint), + require_complete=False, + ) + if existing_grpo is None and not args.dry_run: + raise RuntimeError(f"Could not locate latent GRPO artifact for stage {stage} after running GRPO.") + if existing_grpo is not None: + mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo) + previous_grpo = existing_grpo + stage_record["grpo_artifact"] = asdict(existing_grpo) + stage_record["status"] = "launched" + state["stages"].append(stage_record) + write_state(state_path, state) + + state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S") + write_state(state_path, state) + print("") + print(f"Pipeline state written to: {state_path}") + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/sft_multi_output_train.py b/sudoku4x4_11empty/sft_multi_output_train.py new file mode 100644 index 0000000000000000000000000000000000000000..00e08b87e7c9461721196d6d7139dd2cd835cbae --- /dev/null +++ b/sudoku4x4_11empty/sft_multi_output_train.py @@ -0,0 +1,610 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import math +import os +import random +import sys +import time +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, List + +import torch +import torch.distributed as dist +from peft import LoraConfig, PeftModel, get_peft_model +from torch.optim import AdamW +from torch.utils.data import DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(CURRENT_DIR) +if PARENT_DIR not in sys.path: + sys.path.insert(0, PARENT_DIR) + +from sudoku4x4_11empty.shared_cell_policy import build_cell_examples_from_row +from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final +from sudoku4x4_11empty.prompt_builder import build_multi_output_cell_prompt +from sudoku4x4_11empty.rewards import score_prediction_text +from sudoku4x4_11empty.shared_multi_output_policy import ( + build_supervised_completion, + completion_ce_loss, + make_solved_grid_from_row, +) + + +try: + import wandb +except Exception: + wandb = None + + +@dataclass +class Args: + model_name: str + train_jsonl: str + output_dir: str + cache_dir: str + init_adapter_dir: str + seed: int + gpu_id: int + stage_i: int + total_empties_hint: int + per_device_train_batch_size: int + gradient_accumulation_steps: int + num_epochs: float + learning_rate: float + weight_decay: float + enable_gradient_checkpointing: bool + logging_steps: int + save_steps: int + eval_steps: int + eval_rows: int + max_completion_length: int + lora_r: int + lora_alpha: int + lora_dropout: float + use_wandb: bool + wandb_entity: str + wandb_project: str + wandb_run_name: str + wandb_mode: str + debug_print_limit: int + limit_train_rows: int + eval_exact_set_match_stop: float + max_wall_clock_seconds: int + max_steps: int + + +def configure_hf_cache(cache_dir: str) -> str: + cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) + hub_dir = os.path.join(cache_dir, "hub") + transformers_dir = os.path.join(cache_dir, "transformers") + os.makedirs(hub_dir, exist_ok=True) + os.makedirs(transformers_dir, exist_ok=True) + os.environ["HF_HOME"] = cache_dir + os.environ["HF_HUB_CACHE"] = hub_dir + os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir + os.environ["TRANSFORMERS_CACHE"] = transformers_dir + os.environ.setdefault("HF_HUB_DISABLE_XET", "1") + return cache_dir + + +def configure_wandb_dirs(output_dir: str) -> None: + wandb_dir = os.path.join(output_dir, "wandb_runtime") + os.makedirs(wandb_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", wandb_dir) + os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) + os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) + + +def pick_dtype() -> torch.dtype: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + return torch.float16 + + +def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + if limit_rows > 0 and len(rows) >= limit_rows: + break + return rows + + +def build_training_examples( + rows: List[Dict[str, Any]], + *, + tokenizer: Any, + stage_i: int, + total_empties_hint: int, + progress_every_rows: int = 10, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + examples: List[Dict[str, Any]] = [] + eos_text = getattr(tokenizer, "eos_token", None) or "" + for row_idx, row in enumerate(rows, start=1): + solved = make_solved_grid_from_row(row) + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=total_empties_hint, + ) + target_text = build_supervised_completion(ex, stage_i=stage_i) + if eos_text: + target_text = target_text + eos_text + examples.append( + { + "prompt_text": prompt, + "completion_text": target_text, + "grid": ex.grid, + "solved": solved, + "target_cell": ex.target_cell, + } + ) + if progress_callback is not None and ( + row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 + ): + progress_callback(row_idx, len(rows), len(examples)) + return examples + + +def _prepared_data_dir(args: Args) -> str: + path = os.path.join(PARENT_DIR, "_prepared_data", "sudoku4x4_11empty") + os.makedirs(path, exist_ok=True) + return path + + +def _prepared_sft_cache_path(args: Args) -> str: + payload = json.dumps( + { + "completion_format_version": 2, + "train_jsonl": os.path.abspath(args.train_jsonl), + "stage_i": int(args.stage_i), + "total_empties_hint": int(args.total_empties_hint), + "limit_train_rows": int(args.limit_train_rows), + "model_name": str(args.model_name), + }, + sort_keys=True, + ).encode("utf-8") + digest = hashlib.sha1(payload).hexdigest()[:20] + return os.path.join(_prepared_data_dir(args), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl") + + +def _to_jsonable(value: Any) -> Any: + if isinstance(value, dict): + return {k: _to_jsonable(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_to_jsonable(v) for v in value] + if hasattr(value, "tolist"): + return _to_jsonable(value.tolist()) + return value + + +def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: + tmp_path = f"{path}.tmp" + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + f.write(json.dumps(_to_jsonable(row), separators=(",", ":")) + "\n") + os.replace(tmp_path, path) + + +def _read_jsonl(path: str) -> List[Dict[str, Any]]: + rows: List[Dict[str, Any]] = [] + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + rows.append(json.loads(line)) + return rows + + +def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: + start = time.time() + while not os.path.exists(path): + if time.time() - start > timeout_s: + raise TimeoutError(f"Timed out waiting for prepared cache: {path}") + time.sleep(2.0) + + +def load_or_build_sft_examples( + args: Args, + *, + rows: List[Dict[str, Any]], + tokenizer: Any, + rank: int, + world_size: int, + progress_callback: Any = None, +) -> List[Dict[str, Any]]: + cache_path = _prepared_sft_cache_path(args) + if os.path.exists(cache_path): + return _read_jsonl(cache_path) + if rank == 0: + print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) + examples = build_training_examples( + rows, + tokenizer=tokenizer, + stage_i=args.stage_i, + total_empties_hint=args.total_empties_hint, + progress_every_rows=10, + progress_callback=progress_callback, + ) + _write_jsonl(cache_path, examples) + return examples + _wait_for_cache(cache_path) + return _read_jsonl(cache_path) + + +@torch.no_grad() +def run_eval(args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device): + model.eval() + total_cells = 0 + parse_ok = 0.0 + canonical_ok = 0.0 + exact_set_match = 0.0 + includes_gt = 0.0 + precision_sum = 0.0 + recall_sum = 0.0 + predicted_size_sum = 0.0 + good_count_sum = 0.0 + bad_count_sum = 0.0 + solve_ok = 0 + printed = 0 + for row in rows: + solved = make_solved_grid_from_row(row) + row_all_exact = True + for ex in build_cell_examples_from_row(row): + prompt = build_multi_output_cell_prompt( + ex.grid, + target_cell=ex.target_cell, + stage_i=args.stage_i, + tokenizer=tokenizer, + turn_idx=ex.turn_idx, + total_turns=ex.total_turns, + prev_output_flag=None, + total_empties_hint=args.total_empties_hint, + ) + enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) + enc = {k: v.to(device) for k, v in enc.items()} + out = model.generate( + **enc, + max_new_tokens=max(1, int(args.max_completion_length)), + do_sample=False, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() + info = score_prediction_text( + text=pred_text, + grid=ex.grid, + solved=solved, + target_cell=ex.target_cell, + stage_i=args.stage_i, + reward_good_value=1.0, + penalty_bad_value=1.75, + penalty_malformed=4.0, + penalty_empty=0.5, + penalty_singleton=1.5, + ) + total_cells += 1 + parse_ok += float(info["parse_ok"]) + canonical_ok += float(info["strict_canonical"]) + exact_set_match += float(info["exact_set_match"]) + includes_gt += float(info["includes_ground_truth"]) + precision_sum += float(info["value_precision"]) + recall_sum += float(info["value_recall"]) + predicted_size_sum += float(info["num_predicted_values"]) + good_count_sum += float(info["num_i_consistent_values"]) + bad_count_sum += float(info["num_non_i_consistent_values"]) + if float(info["exact_set_match"]) < 0.5: + row_all_exact = False + if printed < int(args.debug_print_limit): + rr, cc = ex.target_cell + print(f"[baseline sft eval debug] target=({rr+1},{cc+1}) output={pred_text!r}", flush=True) + print( + f"[baseline sft eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}", + flush=True, + ) + printed += 1 + solve_ok += int(row_all_exact) + out = { + "parse_rate": float(parse_ok / max(1, total_cells)), + "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), + "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), + "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), + "value_precision": float(precision_sum / max(1, total_cells)), + "value_recall": float(recall_sum / max(1, total_cells)), + "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), + "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), + "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), + "solve_rate": float(solve_ok / max(1, len(rows))), + } + print( + f"[baseline sft eval] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " + f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " + f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}", + flush=True, + ) + model.train() + return out + + +def parse_args() -> Args: + p = argparse.ArgumentParser() + p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") + p.add_argument("--train_jsonl", type=str, required=True) + p.add_argument("--output_dir", type=str, required=True) + p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum_cot/.hf_cache") + p.add_argument("--init_adapter_dir", type=str, default="") + p.add_argument("--seed", type=int, default=0) + p.add_argument("--gpu_id", type=int, default=0) + p.add_argument("--stage_i", type=int, default=1) + p.add_argument("--total_empties_hint", type=int, default=10) + p.add_argument("--per_device_train_batch_size", type=int, default=1) + p.add_argument("--gradient_accumulation_steps", type=int, default=8) + p.add_argument("--num_epochs", type=float, default=1.0) + p.add_argument("--learning_rate", type=float, default=2e-4) + p.add_argument("--weight_decay", type=float, default=0.0) + p.add_argument("--enable_gradient_checkpointing", action="store_true") + p.add_argument("--logging_steps", type=int, default=10) + p.add_argument("--save_steps", type=int, default=100) + p.add_argument("--eval_steps", type=int, default=100) + p.add_argument("--eval_rows", type=int, default=20) + p.add_argument("--max_completion_length", type=int, default=24) + p.add_argument("--lora_r", type=int, default=16) + p.add_argument("--lora_alpha", type=int, default=32) + p.add_argument("--lora_dropout", type=float, default=0.05) + p.add_argument("--use_wandb", action="store_true") + p.add_argument("--wandb_entity", type=str, default="") + p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-sft") + p.add_argument("--wandb_run_name", type=str, default="") + p.add_argument("--wandb_mode", type=str, default="online") + p.add_argument("--debug_print_limit", type=int, default=3) + p.add_argument("--limit_train_rows", type=int, default=0) + p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) + p.add_argument("--max_wall_clock_seconds", type=int, default=0) + p.add_argument("--max_steps", type=int, default=0) + return Args(**vars(p.parse_args())) + + +def save_checkpoint(model: torch.nn.Module, tokenizer: Any, output_dir: str, step: int) -> None: + save_checkpoint_and_update_final(model, tokenizer, output_dir, f"checkpoint-step-{step:05d}") + + +def main() -> None: + args = parse_args() + preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() + rank = int(os.environ.get("RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + is_distributed = world_size > 1 + is_main_process = rank == 0 + + if preset_visible_devices: + if is_main_process: + print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) + elif int(args.gpu_id) >= 0: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) + + if is_distributed: + if torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + dist.init_process_group(backend="nccl", timeout=timedelta(hours=2)) + + set_seed(args.seed + rank) + os.makedirs(args.output_dir, exist_ok=True) + ensure_final_checkpoint_dir(args.output_dir) + cache_dir = configure_hf_cache(args.cache_dir) + configure_wandb_dirs(args.output_dir) + + wb_run = None + if is_main_process and args.use_wandb and wandb is not None: + init_kwargs = { + "project": args.wandb_project, + "name": args.wandb_run_name or None, + "mode": args.wandb_mode, + } + if str(args.wandb_entity).strip(): + init_kwargs["entity"] = args.wandb_entity + wb_run = wandb.init(**init_kwargs) + print(f"W&B run id: {wb_run.id}", flush=True) + print(f"W&B run URL: {wb_run.url}", flush=True) + wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) + + rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) + eval_rows = rows[: max(1, int(args.eval_rows))] + tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" + if torch.cuda.is_available(): + device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") + else: + device = torch.device("cpu") + + model = AutoModelForCausalLM.from_pretrained( + args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True + ) + if str(args.init_adapter_dir).strip(): + model = PeftModel.from_pretrained(model, args.init_adapter_dir, is_trainable=True) + else: + lora = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + ) + model = get_peft_model(model, lora) + if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if hasattr(model, "config"): + model.config.use_cache = False + model.to(device) + model.train() + + def on_prep_progress(rows_done: int, total_rows: int, examples_built: int) -> None: + if is_main_process: + print( + f"[dataset build][sft stage {args.stage_i}] rows={rows_done}/{total_rows} examples={examples_built}", + flush=True, + ) + if wb_run is not None: + wandb.log({"prep/rows_done": float(rows_done), "prep/examples_built": float(examples_built)}) + + train_examples = load_or_build_sft_examples( + args, + rows=rows, + tokenizer=tokenizer, + rank=rank, + world_size=world_size, + progress_callback=on_prep_progress, + ) + if is_main_process and wb_run is not None: + wandb.log( + { + "prep/cache_hit": float(os.path.exists(_prepared_sft_cache_path(args))), + "prep/examples_final": float(len(train_examples)), + } + ) + + optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) + total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / max(1, args.gradient_accumulation_steps))) + if int(args.max_steps) > 0: + total_steps = min(total_steps, int(args.max_steps)) + step = 0 + start_time = time.time() + + def average_scalar(value: float) -> float: + if not is_distributed or not dist.is_initialized(): + return float(value) + tensor = torch.tensor(float(value), device=device, dtype=torch.float32) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return float((tensor / float(world_size)).item()) + + def all_reduce_gradients() -> None: + if not is_distributed or not dist.is_initialized(): + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) + param.grad.div_(float(world_size)) + + def sync_stop(local_stop: bool) -> bool: + if not is_distributed or not dist.is_initialized(): + return bool(local_stop) + tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return bool(int(tensor.item()) > 0) + + for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): + if is_distributed: + sampler = DistributedSampler( + train_examples, + num_replicas=world_size, + rank=rank, + shuffle=True, + seed=args.seed, + drop_last=False, + ) + sampler.set_epoch(epoch_idx) + order = list(iter(sampler)) + else: + generator = torch.Generator() + generator.manual_seed(args.seed + epoch_idx) + order = torch.randperm(len(train_examples), generator=generator).tolist() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + for ex_idx in order: + ex = train_examples[ex_idx] + loss = completion_ce_loss( + model, + tokenizer, + ex["prompt_text"], + ex["completion_text"], + device, + ) / max(1, int(args.gradient_accumulation_steps)) + loss.backward() + accum_count += 1 + if accum_count >= int(args.gradient_accumulation_steps): + all_reduce_gradients() + optimizer.step() + optimizer.zero_grad(set_to_none=True) + accum_count = 0 + step += 1 + if step % int(args.logging_steps) == 0: + loss_value = average_scalar(float(loss.item()) * args.gradient_accumulation_steps) + if is_main_process: + print(f"[baseline sft train step {step:05d}] loss={loss_value:.4f}", flush=True) + if wb_run is not None: + wandb.log({"train/loss": loss_value, "step": step}) + if step % int(args.eval_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + should_stop_eval = False + if is_main_process: + ev = run_eval(args, eval_rows, model, tokenizer, device) + if wb_run is not None: + wandb.log({f"eval/{k}": float(v) for k, v in ev.items()} | {"step": step}) + if ( + args.eval_exact_set_match_stop > 0.0 + and float(ev["exact_set_match_rate"]) >= args.eval_exact_set_match_stop + ): + save_checkpoint(model, tokenizer, args.output_dir, step) + should_stop_eval = True + should_stop_eval = sync_stop(should_stop_eval) + if is_distributed and dist.is_initialized(): + dist.barrier() + if should_stop_eval: + if is_main_process and wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.destroy_process_group() + return + if step % int(args.save_steps) == 0: + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if is_distributed and dist.is_initialized(): + dist.barrier() + reached_limit = step >= total_steps + exceeded_wall = bool(args.max_wall_clock_seconds) and ( + time.time() - start_time >= float(args.max_wall_clock_seconds) + ) + should_stop = sync_stop(reached_limit or exceeded_wall) + if should_stop: + break + if sync_stop(step >= total_steps): + break + + if is_distributed and dist.is_initialized(): + dist.barrier() + if is_main_process: + save_checkpoint(model, tokenizer, args.output_dir, step) + if wb_run is not None: + wb_run.finish() + if is_distributed and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/sudoku4x4_11empty/shared_cell_policy.py b/sudoku4x4_11empty/shared_cell_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..ab19747609e0a4e4afe7d71e64f862ada085c29d --- /dev/null +++ b/sudoku4x4_11empty/shared_cell_policy.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import json +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import numpy as np + + +_INT_RE = re.compile(r"-?\d+") +_TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)") + + +@dataclass(frozen=True) +class CellExample: + grid: np.ndarray + target_cell: tuple[int, int] + target_value: int + turn_idx: int + total_turns: int + + +def parse_n_value_prediction(text: str, n: int) -> Tuple[List[int] | None, bool]: + raw = str(text or '').strip() + if not raw: + return None, False + + try: + parsed = json.loads(raw) + if isinstance(parsed, dict) and isinstance(parsed.get('values'), list): + values = [int(v) for v in parsed['values']] + if len(values) == int(n): + return values, True + if isinstance(parsed, list): + values = [int(v) for v in parsed] + if len(values) == int(n): + return values, True + except Exception: + pass + + values = [int(match.group(0)) for match in _INT_RE.finditer(raw)] + if len(values) == int(n): + return values, True + return None, False + + +def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray: + triples = _TUPLE_PROMPT_RE.findall(str(prompt_text)) + if len(triples) < 16: + raise ValueError('Could not recover 16 (row,col,value) tuples from prompt.') + grid = np.zeros((4, 4), dtype=int) + for rr, cc, vv in triples[:16]: + grid[int(rr) - 1, int(cc) - 1] = int(vv) + return grid + + +def build_cell_examples_from_row(row: Dict[str, Any]) -> List[CellExample]: + prompt = str(row['prompt']) + grid = parse_grid_from_tuple_prompt(prompt) + metadata = dict(row.get('metadata', {})) + empty_locs = metadata.get('empty_locs_1based') + target_triples = metadata.get('target_triples_1based') + + if not empty_locs or not target_triples: + completion = str(row.get('completion', '')) + parsed, _ = parse_n_value_prediction(completion, int(metadata.get('empties', 0) or 0)) + if parsed is None: + raise ValueError('Row is missing metadata and completion could not be parsed.') + empty_locs = [(r + 1, c + 1) for r, c in np.argwhere(grid == 0).tolist()] + target_triples = [(int(r), int(c), int(v)) for (r, c), v in zip(empty_locs, parsed)] + + total_turns = len(target_triples) + out: List[CellExample] = [] + for idx, triple in enumerate(target_triples, start=1): + rr, cc, value = int(triple[0]) - 1, int(triple[1]) - 1, int(triple[2]) + out.append( + CellExample( + grid=np.asarray(grid, dtype=int).copy(), + target_cell=(rr, cc), + target_value=value, + turn_idx=idx, + total_turns=total_turns, + ) + ) + return out diff --git a/sudoku4x4_11empty/shared_multi_output_policy.py b/sudoku4x4_11empty/shared_multi_output_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..76c950a532f3cf83e2b7500061d5869b9349e0ee --- /dev/null +++ b/sudoku4x4_11empty/shared_multi_output_policy.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import itertools +import json +import math +import random +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, Iterable, List, Optional, Sequence + +import numpy as np +import torch + +from sudoku4x4_11empty.shared_cell_policy import CellExample, parse_grid_from_tuple_prompt +from formatting_icon import is_consistent_pair + + +GRID_SIZE = 4 +BOX_SIZE = 2 +ALL_VALUES = (1, 2, 3, 4) +NUM_CELLS = GRID_SIZE * GRID_SIZE + + +def all_remaining_empties_have_legal_value(grid: np.ndarray) -> bool: + g = np.asarray(grid, dtype=int).reshape(GRID_SIZE, GRID_SIZE) + for r in range(GRID_SIZE): + for c in range(GRID_SIZE): + if int(g[r, c]) != 0: + continue + cell = r * GRID_SIZE + c + has_legal = any(is_consistent_pair(g, cell=cell, value=v, t=BOX_SIZE, n=GRID_SIZE) for v in ALL_VALUES) + if not has_legal: + return False + return True + + +@dataclass(frozen=True) +class ParsedValues: + values: tuple[int, ...] + parse_ok: bool + strict_canonical: bool + + +def all_digit_values() -> List[int]: + return list(ALL_VALUES) + + +def make_solved_grid_from_row(row: Dict[str, Any]) -> np.ndarray: + grid = parse_grid_from_tuple_prompt(str(row['prompt'])) + solved = np.asarray(grid, dtype=int).copy() + triples = row.get('metadata', {}).get('target_triples_1based', []) + for rr, cc, value in triples: + solved[int(rr) - 1, int(cc) - 1] = int(value) + return solved + + +def _grid_state_key(grid: np.ndarray) -> tuple[int, ...]: + return tuple(int(v) for v in np.asarray(grid, dtype=int).reshape(-1)) + + +def _legal_values_for_cell(state: tuple[int, ...], cell: int) -> tuple[int, ...]: + rr, cc = divmod(int(cell), GRID_SIZE) + if int(state[cell]) != 0: + return tuple() + g = np.asarray(state, dtype=int).reshape(GRID_SIZE, GRID_SIZE) + return tuple( + int(value) + for value in all_digit_values() + if is_consistent_pair(g, cell=int(cell), value=int(value), t=BOX_SIZE, n=GRID_SIZE) + ) + + +@lru_cache(maxsize=200000) +def _stage_i_consistent_values_for_grid(state: tuple[int, ...], stage_i: int) -> tuple[tuple[int, ...], ...]: + stage_i = max(1, int(stage_i)) + out: List[tuple[int, ...]] = [tuple() for _ in range(NUM_CELLS)] + + for cell in range(NUM_CELLS): + legal_values = _legal_values_for_cell(state, cell) + if not legal_values: + continue + if stage_i <= 1: + out[cell] = legal_values + continue + + consistent_values: List[int] = [] + for value in legal_values: + child = list(state) + child[cell] = int(value) + child_state = tuple(child) + child_stage_values = _stage_i_consistent_values_for_grid(child_state, stage_i - 1) + if all(int(child_state[idx]) != 0 or len(child_stage_values[idx]) > 0 for idx in range(NUM_CELLS)): + consistent_values.append(int(value)) + out[cell] = tuple(consistent_values) + + return tuple(out) + + +def stage_i_consistent_values( + grid: np.ndarray, + *, + target_cell: tuple[int, int], + stage_i: int, +) -> List[int]: + g = np.asarray(grid, dtype=int).reshape(GRID_SIZE, GRID_SIZE) + rr, cc = int(target_cell[0]), int(target_cell[1]) + if int(g[rr, cc]) != 0: + return [] + cell = rr * GRID_SIZE + cc + stage_values = _stage_i_consistent_values_for_grid(_grid_state_key(g), int(stage_i)) + return [int(value) for value in stage_values[cell]] + + +def canonicalize_values(values: Iterable[int]) -> List[int]: + out: List[int] = [] + seen = set() + for value in values: + if isinstance(value, bool): + raise ValueError('Boolean values are not allowed.') + vv = int(value) + if vv < 1 or vv > GRID_SIZE: + raise ValueError(f'Value must be in [1,{GRID_SIZE}], got {vv}.') + if vv not in seen: + seen.add(vv) + out.append(vv) + return out + + +def values_json_text(values: Iterable[int]) -> str: + return json.dumps({'values': canonicalize_values(values)}, separators=(',', ':')) + + +def parse_values_json(text: str) -> ParsedValues: + raw = str(text).strip() + if not raw: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + try: + obj = json.loads(raw) + except Exception: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if not isinstance(obj, dict): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if set(obj.keys()) != {'values'}: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + values_obj = obj.get('values') + if not isinstance(values_obj, list): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + try: + values = canonicalize_values(values_obj) + except Exception: + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + if len(values) != len(values_obj): + return ParsedValues(values=tuple(), parse_ok=False, strict_canonical=False) + canonical = values_json_text(values) + return ParsedValues(values=tuple(values), parse_ok=True, strict_canonical=(canonical == raw)) + + +def compute_set_precision_recall(pred_values: Sequence[int], target_values: Sequence[int]) -> tuple[float, float]: + pred = set(int(v) for v in pred_values) + target = set(int(v) for v in target_values) + precision = 0.0 if not pred else float(len(pred & target) / max(1, len(pred))) + recall = 1.0 if not target else float(len(pred & target) / max(1, len(target))) + return precision, recall + + +def completion_ce_loss( + model: torch.nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors='pt', add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors='pt', add_special_tokens=False).input_ids.to(device) + labels = all_ids.clone() + labels[:, : int(prompt_ids.shape[1])] = -100 + out = model(input_ids=all_ids, labels=labels) + return out.loss + + +def completion_logprob( + model: torch.nn.Module, + tokenizer: Any, + prompt_text: str, + completion_text: str, + device: torch.device, +) -> torch.Tensor: + prompt_ids = tokenizer(prompt_text, return_tensors='pt', add_special_tokens=False).input_ids.to(device) + all_ids = tokenizer(prompt_text + completion_text, return_tensors='pt', add_special_tokens=False).input_ids.to(device) + labels = all_ids.clone() + labels[:, : int(prompt_ids.shape[1])] = -100 + out = model(input_ids=all_ids, labels=labels) + num_completion_tokens = int((labels != -100).sum().item()) + return -out.loss * max(1, num_completion_tokens) + + +def enumerate_value_permutations( + values: Sequence[int], + *, + max_permutations: int, + rng: Optional[random.Random] = None, +) -> List[tuple[int, ...]]: + uniq = tuple(canonicalize_values(values)) + if len(uniq) <= 1: + return [uniq] + total = math.factorial(len(uniq)) + if total <= max(1, int(max_permutations)): + return [tuple(p) for p in itertools.permutations(uniq)] + + rr = rng or random.Random(0) + perms = set() + base = list(uniq) + max_needed = max(1, int(max_permutations)) + while len(perms) < max_needed: + rr.shuffle(base) + perms.add(tuple(base)) + return list(perms) + + +def build_supervised_completion(ex: CellExample, *, stage_i: int) -> str: + values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) + return values_json_text(values)