Initial code dump (rebuttal-ready snapshot)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env.example +70 -0
- .gitignore +26 -0
- REBUTTAL_REPORT.md +589 -0
- REBUTTAL_REPORT.tex +711 -0
- addition/README.md +114 -0
- addition/__init__.py +1 -0
- addition/config.py +294 -0
- addition/data.py +390 -0
- addition/eval.py +326 -0
- addition/model.py +190 -0
- addition/plots.py +135 -0
- addition/run_comparison.py +122 -0
- addition/train.py +369 -0
- aligned_cell_policy/shared_cell_policy.py +69 -0
- analysis/eval_saved_hard9x9_checkpoints.py +273 -0
- checkpoint_utils.py +127 -0
- format_utils_icon.py +39 -0
- formatting_icon.py +34 -0
- hard_9x9_10empty/launch_baseline_stage3_pipeline.sh +102 -0
- hard_9x9_15empty/launch_baseline_pipeline.sh +65 -0
- hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh +103 -0
- hard_9x9_7empty/launch_stage1_sft.sh +99 -0
- hard_9x9_curriculum/build_stage3_hard_dataset.py +448 -0
- hard_9x9_stage1_consistency_queue/README.md +117 -0
- hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh +158 -0
- hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh +62 -0
- hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh +365 -0
- hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh +112 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh +125 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh +62 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh +341 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh +279 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh +368 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh +112 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh +187 -0
- hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh +394 -0
- hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh +419 -0
- hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh +372 -0
- hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh +113 -0
- hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md +68 -0
- hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md +83 -0
- hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh +84 -0
- hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md +65 -0
- large_baseline_extension/README.md +35 -0
- large_baseline_extension/launch_nonlocation_grpo.sh +103 -0
- large_baseline_extension/launch_nonlocation_pipeline.sh +80 -0
- large_baseline_extension/launch_nonlocation_sft.sh +87 -0
- large_latent_extension/README.md +32 -0
- large_latent_extension/launch_nonlocation_grpo.sh +101 -0
- large_latent_extension/launch_nonlocation_pipeline.sh +82 -0
.env.example
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy to .env for local runs. Do not put real secrets in this example file.
|
| 2 |
+
|
| 3 |
+
# Hugging Face token, if needed for private/gated repos or higher rate limits.
|
| 4 |
+
# HF_TOKEN=hf_xxx
|
| 5 |
+
|
| 6 |
+
# Weights & Biases configuration.
|
| 7 |
+
WANDB_MODE=online
|
| 8 |
+
WANDB_ENTITY=training-dynamics
|
| 9 |
+
WANDB_PROJECT=
|
| 10 |
+
WANDB_RUN_NAME=
|
| 11 |
+
WANDB__SERVICE_WAIT=300
|
| 12 |
+
|
| 13 |
+
# GPU/distributed launch defaults.
|
| 14 |
+
GPU_IDS=0,1,2,3,4,5,6,7
|
| 15 |
+
NUM_PROCESSES=8
|
| 16 |
+
NPROC_PER_JOB=2
|
| 17 |
+
MASTER_PORT=29501
|
| 18 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
| 19 |
+
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
| 20 |
+
|
| 21 |
+
# Model/cache/output defaults used by launch scripts.
|
| 22 |
+
MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
|
| 23 |
+
CACHE_DIR=.hf_cache
|
| 24 |
+
RUN_TAG=
|
| 25 |
+
CHECKPOINT_ROOT=
|
| 26 |
+
OUTPUT_ROOT=
|
| 27 |
+
|
| 28 |
+
# Warm-baseline all-latent stage pipeline defaults.
|
| 29 |
+
EMPTIES=20
|
| 30 |
+
MODES_SPEC=recurrent_hidden
|
| 31 |
+
GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7
|
| 32 |
+
TRAIN_PUZZLES=10000
|
| 33 |
+
EVAL_PUZZLES=100
|
| 34 |
+
SOLVE_TARGET=0.95
|
| 35 |
+
VALUE_TARGET=0
|
| 36 |
+
MIN_STEPS_BEFORE_STOP=50
|
| 37 |
+
BASELINE_WARM_MAX_STEPS=1000
|
| 38 |
+
LATENT_SFT_MAX_STEPS=1000
|
| 39 |
+
LATENT_GRPO_MAX_STEPS=500
|
| 40 |
+
SFT_NUM_EPOCHS=64
|
| 41 |
+
GRPO_NUM_TRAIN_EPOCHS=50
|
| 42 |
+
|
| 43 |
+
# SFT/GRPO batch and LoRA defaults.
|
| 44 |
+
SFT_PER_DEVICE_BS=8
|
| 45 |
+
SFT_GRAD_ACCUM=2
|
| 46 |
+
BASELINE_PER_DEVICE_BS=16
|
| 47 |
+
BASELINE_GRAD_ACCUM=2
|
| 48 |
+
GRPO_PER_DEVICE_BS=4
|
| 49 |
+
GRPO_GRAD_ACCUM=2
|
| 50 |
+
LORA_R=32
|
| 51 |
+
LORA_ALPHA=64
|
| 52 |
+
LORA_DROPOUT=0.05
|
| 53 |
+
GRPO_BETA=0.0
|
| 54 |
+
|
| 55 |
+
# Optional resume adapters.
|
| 56 |
+
STAGE1_BASELINE_ADAPTER_DIR=
|
| 57 |
+
STAGE1_LATENT_SFT_ADAPTER_DIR=
|
| 58 |
+
STAGE1_LATENT_GRPO_ADAPTER_DIR=
|
| 59 |
+
STAGE2_BASELINE_WARM_ADAPTER_DIR=
|
| 60 |
+
STAGE2_LATENT_SFT_ADAPTER_DIR=
|
| 61 |
+
STAGE2_LATENT_GRPO_ADAPTER_DIR=
|
| 62 |
+
STAGE3_BASELINE_WARM_ADAPTER_DIR=
|
| 63 |
+
STAGE3_LATENT_SFT_ADAPTER_DIR=
|
| 64 |
+
|
| 65 |
+
# Optional debug knobs.
|
| 66 |
+
FIXED_SLOT_DEBUG_LIMIT=0
|
| 67 |
+
FIXED_SLOT_DECODE_DEBUG_LIMIT=0
|
| 68 |
+
LATENT_VOCAB_DEBUG_TOPK=1
|
| 69 |
+
ATTN_DENSITY_DEBUG_LIMIT=0
|
| 70 |
+
ATTN_DENSITY_THRESHOLD_MULT=1.0
|
.gitignore
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
|
| 4 |
+
.env
|
| 5 |
+
.env.*
|
| 6 |
+
!.env.example
|
| 7 |
+
|
| 8 |
+
_prepared_data/
|
| 9 |
+
data/
|
| 10 |
+
checkpoints/
|
| 11 |
+
final_checkpoint/
|
| 12 |
+
location_learner/
|
| 13 |
+
|
| 14 |
+
*.jsonl
|
| 15 |
+
|
| 16 |
+
.wandb/
|
| 17 |
+
wandb/
|
| 18 |
+
**/.wandb/
|
| 19 |
+
**/wandb/
|
| 20 |
+
**/wandb_runtime/
|
| 21 |
+
|
| 22 |
+
.venv/
|
| 23 |
+
.hf_cache/
|
| 24 |
+
addition_runs/
|
| 25 |
+
tmp_latent_debug/
|
| 26 |
+
logs/
|
REBUTTAL_REPORT.md
ADDED
|
@@ -0,0 +1,589 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Curriculum CoT for 9x9 Sudoku — Rebuttal/Paper-Section Material
|
| 2 |
+
|
| 3 |
+
_Last updated: 2026-05-24_
|
| 4 |
+
|
| 5 |
+
This document is a comprehensive, paper-ready reference of (a) the data pipeline,
|
| 6 |
+
(b) the instruction-tuning prompt format, (c) the curriculum and reward design,
|
| 7 |
+
(d) the latent thought-token architecture, (e) the multi-stage SFT-then-GRPO
|
| 8 |
+
training recipe, and (f) the headline numerical results — so a rebuttal section
|
| 9 |
+
can be assembled directly from this document.
|
| 10 |
+
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## 1. Task
|
| 14 |
+
|
| 15 |
+
We use the model as a **per-cell value policy** for 9×9 Sudoku. For a fixed
|
| 16 |
+
target empty cell, the model emits a JSON set of candidate digits that are
|
| 17 |
+
"i-consistent" with the current grid (definition in §4). We evaluate two
|
| 18 |
+
metrics:
|
| 19 |
+
|
| 20 |
+
- **per-cell exact set match** (`exact_set_match`) — predicted set equals the
|
| 21 |
+
ground-truth i-consistent set;
|
| 22 |
+
- **whole-puzzle solve rate** (`solve`) — every empty cell on a 20-empty puzzle
|
| 23 |
+
produces an exact set match.
|
| 24 |
+
|
| 25 |
+
Because solve = ∏ exact_set_match across the ~20 empty cells of a puzzle,
|
| 26 |
+
the two metrics are non-linearly coupled:
|
| 27 |
+
|
| 28 |
+
$$ \text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}} $$
|
| 29 |
+
|
| 30 |
+
so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ — every percentage
|
| 31 |
+
point of per-cell exact maps to a much larger swing in solve.
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 2. Data pipeline
|
| 36 |
+
|
| 37 |
+
### 2.1 Puzzle generation
|
| 38 |
+
|
| 39 |
+
Generated by `simple_9x9_curriculum/build_dataset.py`:
|
| 40 |
+
|
| 41 |
+
- Start from a base Latin-square grid; randomly relabel digits, permute
|
| 42 |
+
rows and columns within bands, and transpose.
|
| 43 |
+
- Sample `empties=20` cell positions uniformly at random and erase them.
|
| 44 |
+
- Save 10 000 train + 1 000 eval puzzles (seed 0, seed 1).
|
| 45 |
+
- Output JSONL files
|
| 46 |
+
`data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`.
|
| 47 |
+
|
| 48 |
+
A single record contains:
|
| 49 |
+
|
| 50 |
+
```json
|
| 51 |
+
{
|
| 52 |
+
"prompt": "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>",
|
| 53 |
+
"completion": "[7,3,8,2,6,9,4,5,...]",
|
| 54 |
+
"metadata": {
|
| 55 |
+
"grid_size": 9, "box_size": 3, "empties": 20,
|
| 56 |
+
"empty_locs_1based": [[1,4],[1,9],...],
|
| 57 |
+
"target_triples_1based": [[1,4,7],[1,9,3],...]
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
The 20 `target_triples` give the **solved** value at each of the 20 empty
|
| 63 |
+
positions, so per-cell training targets are always available. At training
|
| 64 |
+
time we expand each puzzle into 20 (puzzle, target_cell) examples.
|
| 65 |
+
|
| 66 |
+
### 2.2 Cell-policy framing
|
| 67 |
+
|
| 68 |
+
The model is never asked to solve a whole puzzle in one shot. Each example
|
| 69 |
+
is one (current_grid, target_cell) pair, and the supervised target is the
|
| 70 |
+
set of digits that are "i-consistent" with the current grid (see §4). This
|
| 71 |
+
turns Sudoku into a **classification-into-a-set** problem and lets us share
|
| 72 |
+
parameters across cells, stages, and puzzle sizes.
|
| 73 |
+
|
| 74 |
+
### 2.3 Multi-value oversampling (data-side trick)
|
| 75 |
+
|
| 76 |
+
Implemented in `multi_output_cell_policy/sft_multi_output_train.py` via
|
| 77 |
+
`tokenizer._multi_value_oversample_factor` and the CLI flags
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
--multi_value_oversample_factor INT (default 1)
|
| 81 |
+
--train_target_size_min INT (default 0)
|
| 82 |
+
--train_target_size_max INT (default 0)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Inside the dataset builder, examples whose target set has more than one
|
| 86 |
+
digit are repeated `multi_value_oversample_factor` times in the training
|
| 87 |
+
mix. This biases gradient steps toward exactly the cells the model gets
|
| 88 |
+
wrong (multi-value cells). Empirically, this is the single biggest data-side
|
| 89 |
+
lever — see §10.
|
| 90 |
+
|
| 91 |
+
### 2.4 Where the bottleneck lives
|
| 92 |
+
|
| 93 |
+
For 20-empty puzzles in stage 3, only ~25 % of empty cells have a
|
| 94 |
+
multi-value target set (the rest collapse to one i-consistent value). Yet
|
| 95 |
+
those multi-value cells are responsible for the entire solve-rate gap:
|
| 96 |
+
they are the cells where the model under-predicts (returns a singleton
|
| 97 |
+
when the target is a 2- or 3-element set), and a single failed cell kills
|
| 98 |
+
the whole-puzzle solve. The reward shaping in §6 and the oversample in 2.3
|
| 99 |
+
both attack this single failure mode.
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## 3. Instruction format
|
| 104 |
+
|
| 105 |
+
### 3.1 System prompt
|
| 106 |
+
|
| 107 |
+
(verbatim from `multi_output_cell_policy/prompt_builder.py`)
|
| 108 |
+
|
| 109 |
+
```text
|
| 110 |
+
You are a Sudoku value policy.
|
| 111 |
+
This setup uses puzzles with about 20 empty cells.
|
| 112 |
+
You will be given one target empty cell.
|
| 113 |
+
Return ONLY one JSON object of the form {"values":[...]}.
|
| 114 |
+
The JSON object must contain exactly one key named "values".
|
| 115 |
+
The "values" field must be a JSON array of unique integers in [1,9].
|
| 116 |
+
You may return as many candidate values as you want, including one, several,
|
| 117 |
+
or many values.
|
| 118 |
+
Choose the number of returned values yourself based on which values seem
|
| 119 |
+
i-consistent.
|
| 120 |
+
The order of the values does not matter.
|
| 121 |
+
Do not output any explanation, markdown, punctuation outside JSON, or extra text.
|
| 122 |
+
Current stage objective: i={i} consistency.
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
### 3.2 User message
|
| 126 |
+
|
| 127 |
+
```text
|
| 128 |
+
Sudoku grid (0 means empty):
|
| 129 |
+
<grid_to_text(grid)>
|
| 130 |
+
Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), …
|
| 131 |
+
Target cell to fill now: (R,C).
|
| 132 |
+
Turn: t/T.
|
| 133 |
+
Return only JSON with candidate values for this target cell: {"values":[...]}
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
We use the Qwen2.5-Instruct chat template (`tokenizer.apply_chat_template`,
|
| 137 |
+
`add_generation_prompt=True`) to wrap system + user into the actual prompt
|
| 138 |
+
ids. `max_prompt_length = 768`.
|
| 139 |
+
|
| 140 |
+
### 3.3 Output format
|
| 141 |
+
|
| 142 |
+
```json
|
| 143 |
+
{"values":[3,7]}
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
Strictly canonical JSON (single key `values`, sorted unique digit list,
|
| 147 |
+
no whitespace). Outputs are scored by `parse_values_json`
|
| 148 |
+
(`shared_multi_output_policy.py`); any deviation collapses the whole
|
| 149 |
+
prediction to `parse_ok=0` and a hard-coded malformed penalty.
|
| 150 |
+
|
| 151 |
+
`max_completion_length = 24` tokens — enough to emit any 9-digit set.
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
## 4. Curriculum: stage-i consistency
|
| 156 |
+
|
| 157 |
+
The curriculum lives in `_stage_i_consistent_values_for_grid`:
|
| 158 |
+
|
| 159 |
+
- **Stage 1 — i=1 (legal moves).** A value v is i=1 consistent at cell c
|
| 160 |
+
iff placing v at c violates no Sudoku constraint (row, column, 3×3 box).
|
| 161 |
+
This is just "legal candidates".
|
| 162 |
+
|
| 163 |
+
- **Stage 2 — i=2.** v is i=2 consistent at c iff (a) it is i=1 consistent
|
| 164 |
+
AND (b) after placing v, every other empty cell in the grid still has at
|
| 165 |
+
least one i=1-consistent value (i.e. placing v does not immediately make
|
| 166 |
+
the puzzle unsolvable by 1-step propagation).
|
| 167 |
+
|
| 168 |
+
- **Stage 3 — i=3.** Same recursion one more level deep: v is i=3 consistent
|
| 169 |
+
iff after placing v, every other empty cell still has at least one i=2
|
| 170 |
+
consistent value.
|
| 171 |
+
|
| 172 |
+
This is bounded look-ahead constraint propagation. Stage-3 sets are tighter
|
| 173 |
+
than stage-2 sets which are tighter than stage-1 sets. The curriculum
|
| 174 |
+
goal at deployment time is stage-3.
|
| 175 |
+
|
| 176 |
+
In data, we use the same source records and just change `--stage_i`; the
|
| 177 |
+
target set is regenerated on the fly by `stage_i_consistent_values`.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## 5. Latent thought-token architecture
|
| 182 |
+
|
| 183 |
+
Base model: **Qwen/Qwen2.5-1.5B-Instruct** + LoRA (r=32, α=64, dropout=0.05)
|
| 184 |
+
on `q,k,v,o,gate,up,down`. The latent variant adds **k thought-token slots**
|
| 185 |
+
between the prompt and the next-token logits.
|
| 186 |
+
|
| 187 |
+
Four modes are implemented (`latent_multi_output_cell_policy/`); the winning
|
| 188 |
+
mode for the final number is **`recurrent_hidden`**:
|
| 189 |
+
|
| 190 |
+
> `build_recurrent_hidden_latent_hidden(model, ids, mask, k)`
|
| 191 |
+
>
|
| 192 |
+
> 1. Run the backbone once on the prompt. Keep `base_hidden = h[:,-1,:]`.
|
| 193 |
+
> 2. Set `latent_token = base_hidden`.
|
| 194 |
+
> 3. Repeat k times: append `latent_token` (as an embedding) to the running
|
| 195 |
+
> sequence, run the backbone again on the extended sequence, and replace
|
| 196 |
+
> `latent_token` with the new last hidden state.
|
| 197 |
+
> 4. After k recursions, `latent_hidden` is fed through the LM head to
|
| 198 |
+
> produce the next-token distribution.
|
| 199 |
+
|
| 200 |
+
In equations, with E the input embedding lookup, f_θ the LoRA-decorated
|
| 201 |
+
backbone, U the LM head:
|
| 202 |
+
|
| 203 |
+
$$ z_0 = f_\theta(E([x_1,\dots,x_T]))_{T} $$
|
| 204 |
+
$$ z_{j+1} = f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1,\dots,z_j]\bigr)_{T+j+1},\ j=0,\dots,k-1 $$
|
| 205 |
+
$$ p(\cdot \mid x_{1:T}) = \mathrm{softmax}(U z_k) $$
|
| 206 |
+
|
| 207 |
+
The model can therefore "iterate" k extra forward passes on the same prompt
|
| 208 |
+
before committing to a token, with the k extra hidden states carrying
|
| 209 |
+
intermediate computation. Setting k=0 recovers the vanilla baseline.
|
| 210 |
+
|
| 211 |
+
The other three latent modes are alternatives that we ablated:
|
| 212 |
+
`fixed_slots` (concatenate k trainable seed embeddings — Option-2),
|
| 213 |
+
`latent_seeds` (similar to fixed_slots), and `residual` (project k extra
|
| 214 |
+
hidden states back onto the base hidden state via a learned residual). All
|
| 215 |
+
modes share the SFT and GRPO trainers; only the next-token logit function
|
| 216 |
+
changes.
|
| 217 |
+
|
| 218 |
+
For the curriculum, we grow k stage by stage:
|
| 219 |
+
|
| 220 |
+
| stage | num_cot_tokens | comment |
|
| 221 |
+
| --- | ---: | --- |
|
| 222 |
+
| 1 | 1 | one extra recursion as soon as the model has the surface form |
|
| 223 |
+
| 2 | 2 | two — needed for 1-step propagation reasoning |
|
| 224 |
+
| 3 | 3 | three — needed for 2-step propagation reasoning |
|
| 225 |
+
|
| 226 |
+
---
|
| 227 |
+
|
| 228 |
+
## 6. The reward function
|
| 229 |
+
|
| 230 |
+
Defined in `multi_output_cell_policy/rewards.py`.
|
| 231 |
+
|
| 232 |
+
Given target set T, predicted set P (after JSON parse), let
|
| 233 |
+
|
| 234 |
+
- `num_good = |P ∩ T|`
|
| 235 |
+
- `num_bad = |P \ T|`
|
| 236 |
+
- `num_missing = max(0, |T| − num_good)`
|
| 237 |
+
- `is_exact = (P ≠ ∅) ∧ (P = T)`
|
| 238 |
+
- `tri(n) = n(n+1)/2` (rewards larger correct sets superlinearly)
|
| 239 |
+
|
| 240 |
+
Then
|
| 241 |
+
|
| 242 |
+
$$
|
| 243 |
+
r = \mathrm{tri}(\mathrm{num\_good}) \cdot R_g \;-\; \mathrm{num\_bad} \cdot P_b \\
|
| 244 |
+
- \mathbb{1}[P=\varnothing]\, P_e \;-\; \mathbb{1}[|P|=1, |T|>1, i<2]\, P_s \\
|
| 245 |
+
- \mathrm{num\_missing}\cdot P_m \;+\; \mathbb{1}[\text{is\_exact}]\, B_x \\
|
| 246 |
+
- \mathbb{1}[|P|<|T|, |T|>1]\, P_c
|
| 247 |
+
$$
|
| 248 |
+
|
| 249 |
+
with parameters (this is the recipe that produced the 0.58/0.68 latent solve):
|
| 250 |
+
|
| 251 |
+
| symbol | flag | value | role |
|
| 252 |
+
|---|---|---:|---|
|
| 253 |
+
| $R_g$ | `--reward_good_value` | 1.25 | per-correct-value reward (with triangular shape) |
|
| 254 |
+
| $P_b$ | `--penalty_bad_value` | 1.0 | per-extra-wrong-value penalty |
|
| 255 |
+
| $P_{\!\text{mal}}$ | `--penalty_malformed` | 4.0 | flat penalty if JSON parse fails |
|
| 256 |
+
| $P_e$ | `--penalty_empty` | 0.5 | flat penalty if predicted set is empty |
|
| 257 |
+
| $P_s$ | `--penalty_singleton` | 1.5 | only at stage<2: punishes singleton on multi-value targets |
|
| 258 |
+
| $P_m$ | `--penalty_missing` | **0.75** | per-missing-value (recall pressure) — **NEW** |
|
| 259 |
+
| $B_x$ | `--exact_match_bonus` | **2.0** | only when P = T — **NEW** |
|
| 260 |
+
| $P_c$ | `--cardinality_mismatch_penalty` | **1.0** | when |P| < |T| and |T|>1 — **NEW** |
|
| 261 |
+
|
| 262 |
+
Parse failures short-circuit to `r = -P_mal` and zero per-cell metrics.
|
| 263 |
+
|
| 264 |
+
### 6.1 Why those three new terms exist (the breakthrough)
|
| 265 |
+
|
| 266 |
+
Diagnosis: at the v3/v4 plateau, eval reported
|
| 267 |
+
|
| 268 |
+
```
|
| 269 |
+
exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
across all checkpoints. Per-cell exact and precision/recall were all near
|
| 273 |
+
0.95 but the model **always predicted a single digit** (`avg_set_size=1.000`).
|
| 274 |
+
On a multi-value target $T=\{8,9\}$, predicting $\{8\}$ keeps precision=1.0,
|
| 275 |
+
recall=0.5 and yet `exact_set_match=0`. Solve = exact_set_match^N is
|
| 276 |
+
catastrophic in $N$ (=20), so even a small fraction of multi-value cells
|
| 277 |
+
killed it.
|
| 278 |
+
|
| 279 |
+
Without any of the new terms the optimum of $r$ on a multi-value cell is
|
| 280 |
+
trivially "predict the singleton you are most confident about" — there is
|
| 281 |
+
no upside to enumerate the second value. The three new terms close exactly
|
| 282 |
+
that hole:
|
| 283 |
+
|
| 284 |
+
- $P_m$ (`penalty_missing`) directly penalises recall;
|
| 285 |
+
- $B_x$ (`exact_match_bonus`) makes $P=T$ strictly dominate any singleton;
|
| 286 |
+
- $P_c$ (`cardinality_mismatch_penalty`) is a flat hammer whenever $|P|<|T|$.
|
| 287 |
+
|
| 288 |
+
After these terms were added, GRPO on the latent variant moved solve from
|
| 289 |
+
~0.30 to ~0.58 (100-puzzle eval) over ~200 steps. The same fix is what we
|
| 290 |
+
ported back into the baseline pipeline this evening (see §10).
|
| 291 |
+
|
| 292 |
+
---
|
| 293 |
+
|
| 294 |
+
## 7. Multi-stage warm-baseline pipeline (the recipe that worked)
|
| 295 |
+
|
| 296 |
+
Master script:
|
| 297 |
+
`hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`.
|
| 298 |
+
|
| 299 |
+
For each curriculum stage we run **three sub-phases in order**:
|
| 300 |
+
|
| 301 |
+
```
|
| 302 |
+
[stage i]
|
| 303 |
+
(1) baseline warm SFT (no latent tokens, k=0, vanilla LM)
|
| 304 |
+
(2) latent SFT (k = i, latent mode = recurrent_hidden)
|
| 305 |
+
(3) latent GRPO (k = i)
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
**The warm baseline phase (1) is the trick that makes the curriculum work.**
|
| 309 |
+
At every stage transition the data distribution changes (i increases →
|
| 310 |
+
target sets shrink) and a new latent slot appears. Doing a vanilla SFT on
|
| 311 |
+
the new distribution first lets the LM relearn the surface form on familiar
|
| 312 |
+
parameters; THEN the latent SFT adds the extra thought slot on top of an
|
| 313 |
+
already-good policy. When we tried to add a new latent slot directly on
|
| 314 |
+
top of the previous stage's GRPO checkpoint, training loss did NOT
|
| 315 |
+
decrease.
|
| 316 |
+
|
| 317 |
+
Concrete LR schedule used for the champion run:
|
| 318 |
+
|
| 319 |
+
| phase | init from | LR | k |
|
| 320 |
+
|---|---|---:|---:|
|
| 321 |
+
| S1 baseline SFT | base Qwen | 2e-4 | 0 |
|
| 322 |
+
| S1 latent SFT | S1 baseline | 2e-4 | 1 |
|
| 323 |
+
| S1 latent GRPO | S1 latent SFT | 1e-6 | 1 |
|
| 324 |
+
| S2 baseline warm SFT | S1 GRPO | 5e-5 | 0 |
|
| 325 |
+
| S2 latent SFT | S2 baseline | 5e-5 | 2 |
|
| 326 |
+
| S2 latent GRPO | S2 latent SFT | 1e-6 | 2 |
|
| 327 |
+
| S3 baseline warm SFT | S2 GRPO | 5e-5 | 0 |
|
| 328 |
+
| S3 latent SFT | S3 baseline | 5e-5 → 1e-5 (champion) | 3 |
|
| 329 |
+
| S3 latent GRPO | S3 latent SFT | 5e-6 (β=0) | 3 |
|
| 330 |
+
|
| 331 |
+
Other shared knobs:
|
| 332 |
+
|
| 333 |
+
```
|
| 334 |
+
LoRA: r=32 α=64 dropout=0.05 on q,k,v,o,gate,up,down
|
| 335 |
+
SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128
|
| 336 |
+
GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64
|
| 337 |
+
num_generations=4 beta=0.0 max_prompt_length=1024
|
| 338 |
+
max_completion_length=24
|
| 339 |
+
multi_value_oversample_factor=5, exact_match_bonus=2.0,
|
| 340 |
+
penalty_missing=0.75, cardinality_mismatch_penalty=1.0
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
---
|
| 344 |
+
|
| 345 |
+
## 8. GRPO settings that mattered
|
| 346 |
+
|
| 347 |
+
- **β = 0.** The KL anchor was harmful in every sweep where we tried β>0.
|
| 348 |
+
`s3_grpo_kl04` (β=0.04) peaked at solve=0.625 (40p) at step 100 and
|
| 349 |
+
regressed to 0.525 by step 500.
|
| 350 |
+
|
| 351 |
+
- **num_generations = 4.** With num_generations=2 we routinely saw
|
| 352 |
+
`reward_std = 0` (all sampled completions identical → no gradient).
|
| 353 |
+
Bumping to 4 fixed it.
|
| 354 |
+
|
| 355 |
+
- **Low LR.** `lr=5e-6` was the steadiest. `lr=1e-5` peaked at step 200
|
| 356 |
+
(solve 0.65) then collapsed back to 0.54 — classic mode collapse.
|
| 357 |
+
|
| 358 |
+
- **Effective bs ≥ 64.** TRL's GRPOConfig requires
|
| 359 |
+
`eff_bs * grad_accum % num_generations == 0`; with 8 GPUs we hit this
|
| 360 |
+
trivially, but we caution single-GPU rerunners to set
|
| 361 |
+
`per_device_bs=4 grad_accum=2 num_generations=4`.
|
| 362 |
+
|
| 363 |
+
- **`enable_input_require_grads()` on the wrapped backbone.** Required for
|
| 364 |
+
TRL 0.15.x + PEFT LoRA + gradient checkpointing — otherwise the loss
|
| 365 |
+
tensor produced by GRPOTrainer has `requires_grad=False` and `.backward()`
|
| 366 |
+
raises. Also `unwrapped.config.use_cache = False`.
|
| 367 |
+
|
| 368 |
+
---
|
| 369 |
+
|
| 370 |
+
## 9. Final hyperparameters table — champion latent run
|
| 371 |
+
|
| 372 |
+
| group | hyperparameter | value |
|
| 373 |
+
|---|---|---|
|
| 374 |
+
| Backbone | model | Qwen/Qwen2.5-1.5B-Instruct |
|
| 375 |
+
| Backbone | dtype | bf16 |
|
| 376 |
+
| Backbone | LoRA target modules | q,k,v,o,gate,up,down |
|
| 377 |
+
| Backbone | LoRA r / α / dropout | 32 / 64 / 0.05 |
|
| 378 |
+
| Latent | mode | recurrent_hidden |
|
| 379 |
+
| Latent | num_cot_tokens (S1/S2/S3) | 1 / 2 / 3 |
|
| 380 |
+
| Latent | max_latent_slots / seeds | 8 / 8 |
|
| 381 |
+
| Data | total empties | 20 |
|
| 382 |
+
| Data | train rows / eval rows | 10 000 / 100 |
|
| 383 |
+
| Data | multi_value_oversample_factor | 5 |
|
| 384 |
+
| Data | mixed_stage1_ratio (S1) | 1 |
|
| 385 |
+
| Data | mixed_stage2_ratio (S≥2) | 1 |
|
| 386 |
+
| SFT | per_device_bs / grad_accum | 8 / 2 |
|
| 387 |
+
| SFT | num_epochs (cap) | 64 |
|
| 388 |
+
| SFT | LR (S1 latent) | 2e-4 |
|
| 389 |
+
| SFT | LR (S2/S3 baseline warm + latent) | 5e-5 |
|
| 390 |
+
| SFT | LR (S3 latent champion `s3b_lr1e5_o5`) | 1e-5 |
|
| 391 |
+
| SFT | weight_decay | 0.0 |
|
| 392 |
+
| SFT | gradient checkpointing | on |
|
| 393 |
+
| GRPO | per_device_bs / grad_accum | 4 / 2 |
|
| 394 |
+
| GRPO | num_generations | 4 |
|
| 395 |
+
| GRPO | LR | 5e-6 (S3); 1e-6 (S1, S2) |
|
| 396 |
+
| GRPO | β (KL) | 0.0 |
|
| 397 |
+
| GRPO | max_prompt_length | 1024 |
|
| 398 |
+
| GRPO | max_completion_length | 24 |
|
| 399 |
+
| Reward | reward_good_value | 1.25 |
|
| 400 |
+
| Reward | penalty_bad_value | 1.0 |
|
| 401 |
+
| Reward | penalty_malformed | 4.0 |
|
| 402 |
+
| Reward | penalty_empty | 0.5 |
|
| 403 |
+
| Reward | penalty_singleton | 1.5 |
|
| 404 |
+
| Reward | penalty_missing | 0.75 |
|
| 405 |
+
| Reward | exact_match_bonus | 2.0 |
|
| 406 |
+
| Reward | cardinality_mismatch_penalty | 1.0 |
|
| 407 |
+
| Eval | early-stop on prec/recall | 0.98 |
|
| 408 |
+
|
| 409 |
+
---
|
| 410 |
+
|
| 411 |
+
## 10. Headline results
|
| 412 |
+
|
| 413 |
+
### 10.1 Latent (with thought tokens, recurrent_hidden)
|
| 414 |
+
|
| 415 |
+
| eval | model / phase | step | exact | prec | recall | **solve** |
|
| 416 |
+
|---|---|---:|---:|---:|---:|---:|
|
| 417 |
+
| **100p (auth.)** | `s3_grpo_baseline` (S3 GRPO, β=0, lr=5e-6) | 200 | 0.9665 | 0.9673 | 0.9680 | **0.580 (58/100)** |
|
| 418 |
+
| 40p | `s3_grpo_sharp_rwd` (exact_b=4, card_pen=3) | 300 | — | — | — | **0.675 (27/40)** |
|
| 419 |
+
| 40p | `s3_grpo_lr1e5` | 200 | 0.978 | 0.978 | 0.979 | 0.650 |
|
| 420 |
+
| 40p | `s3b_lr1e5_o5` (S3 SFT champion) | 2400 | 0.974 | 0.974 | 0.975 | 0.600 |
|
| 421 |
+
|
| 422 |
+
### 10.2 Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)
|
| 423 |
+
|
| 424 |
+
| sweep | best variant | best step | exact | **solve (100p)** |
|
| 425 |
+
|---|---|---:|---:|---:|
|
| 426 |
+
| v3 (single-GPU LR=2e-5, no oversample, no new reward terms) | `baseline_3stage_20260522` | — | 0.730 | **0.000** |
|
| 427 |
+
| v4 (LR sweep, multi-GPU, original reward) | `pipe_v_sft_extend` (S3 SFT extended) | 4000 | 0.948 | **0.400** |
|
| 428 |
+
| **v6 (this evening, ports latent reward + oversample)** | `v6_i_sft_v_oversample10` (oversample=10) | running | 0.952+ | **0.440 (best so far)** |
|
| 429 |
+
|
| 430 |
+
The v6 sweep is still running — `v6_e/f/i` are in S3 SFT continuation,
|
| 431 |
+
GRPO follow-on phases queued. The v6_i variant has hit **solve=0.44** at
|
| 432 |
+
SFT eval (new baseline best, +0.04 over v4) and is still climbing.
|
| 433 |
+
|
| 434 |
+
### 10.3 Stage-by-stage trajectory (latent, 40-puzzle eval)
|
| 435 |
+
|
| 436 |
+
```
|
| 437 |
+
S1 SFT : exact ≈ 0.85, solve ≈ 0.20
|
| 438 |
+
S1 GRPO : exact ≈ 0.90, solve ≈ 0.20
|
| 439 |
+
S2 SFT (no oversample) : exact ≈ 0.94, solve ≈ 0.20-0.25 <- the wall
|
| 440 |
+
S2 SFT + multi_value_oversample=5 : exact ≈ 0.96, solve ≈ 0.30-0.35
|
| 441 |
+
S2 GRPO + new reward terms : exact ≈ 0.96, solve ≈ 0.35-0.40
|
| 442 |
+
S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion
|
| 443 |
+
S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion
|
| 444 |
+
S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak
|
| 445 |
+
```
|
| 446 |
+
|
| 447 |
+
### 10.4 Latent vs baseline gap (head-to-head, same 100p eval, same prompts)
|
| 448 |
+
|
| 449 |
+
| model | exact | prec | recall | **solve** | solved/100 |
|
| 450 |
+
|---|---:|---:|---:|---:|---:|
|
| 451 |
+
| Latent recurrent_hidden, S3 GRPO | 0.9665 | 0.9673 | 0.9680 | **0.580** | 58 |
|
| 452 |
+
| Vanilla baseline, v6_i (best at time of writing) | 0.952 | 0.952 | 0.952 | **0.440** | 44 |
|
| 453 |
+
|
| 454 |
+
Gap on 100-puzzle solve: ≈ **+0.14 absolute / +32 % relative** for latent
|
| 455 |
+
over the strongest baseline we have.
|
| 456 |
+
|
| 457 |
+
---
|
| 458 |
+
|
| 459 |
+
## 11. Why the latent works (interpretation hypotheses)
|
| 460 |
+
|
| 461 |
+
These are the working hypotheses the experiments are consistent with;
|
| 462 |
+
none is fully proven and ablations are still WIP.
|
| 463 |
+
|
| 464 |
+
1. **Constraint-propagation depth.** Stage-3 i-consistency is essentially
|
| 465 |
+
2-ply lookahead. With $k=3$ recurrent hidden tokens the model gets
|
| 466 |
+
exactly three extra forward passes between prompt and output — one for
|
| 467 |
+
the legality check, one for 1-step propagation, one for the second
|
| 468 |
+
step of propagation. Empirically the gap to the no-thought-token
|
| 469 |
+
baseline appears at stages where multi-step propagation matters
|
| 470 |
+
(stage 2 onward; stage 1 numbers are essentially identical).
|
| 471 |
+
|
| 472 |
+
2. **Multi-value cells require enumeration, which a singleton softmax can't
|
| 473 |
+
do in one forward pass.** A vanilla LM at 1.5B parameters predicts
|
| 474 |
+
essentially deterministically once temperature is low; for a target set
|
| 475 |
+
{8, 9} the LM picks one of the two and stops. The latent model can use
|
| 476 |
+
one of the recurrent hidden steps to "consider" each option without
|
| 477 |
+
committing yet, which is exactly the failure mode in the data
|
| 478 |
+
(`avg_set_size = 1.000` for the baseline, `≈ 1.05` for the latent S3
|
| 479 |
+
model on the same eval).
|
| 480 |
+
|
| 481 |
+
3. **Stable curriculum capacity growth.** Adding a new latent slot at every
|
| 482 |
+
stage gives the model a "fresh slate" of representational capacity at the
|
| 483 |
+
exact transition where the task gets harder. The warm-baseline SFT
|
| 484 |
+
between stages prevents the new slot from corrupting the previously
|
| 485 |
+
learned policy. Without warm baseline, training loss did not decrease
|
| 486 |
+
at all (we observed this directly when we tried to skip the warm
|
| 487 |
+
baseline yesterday).
|
| 488 |
+
|
| 489 |
+
4. **GRPO without latent slots is starved of variance.** With max_completion
|
| 490 |
+
length 24 and the model essentially deterministic, GRPO's 4 sampled
|
| 491 |
+
completions per prompt collapse to a single answer — `reward_std = 0`,
|
| 492 |
+
no gradient. With latent recurrence + the new exact_match_bonus reward,
|
| 493 |
+
the model occasionally samples a 2-element set, gets a much higher
|
| 494 |
+
reward, and that prompt gets a real gradient signal.
|
| 495 |
+
|
| 496 |
+
---
|
| 497 |
+
|
| 498 |
+
## 12. Reproducibility
|
| 499 |
+
|
| 500 |
+
Code repository: `https://github.com/Avra98/curriculum_cot`
|
| 501 |
+
Latent checkpoints: `https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages`
|
| 502 |
+
Baseline checkpoints: `https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep`
|
| 503 |
+
|
| 504 |
+
Key scripts:
|
| 505 |
+
|
| 506 |
+
- Master orchestrator (latent, 9-phase warm-baseline pipeline):
|
| 507 |
+
`hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh`
|
| 508 |
+
- Vanilla baseline pipeline:
|
| 509 |
+
`_runs/baseline_1p5b_pipeline_v4.sh` (with v6 launchers
|
| 510 |
+
`_runs/launch_baseline_push_v6.sh`)
|
| 511 |
+
- SFT trainer (vanilla): `multi_output_cell_policy/sft_multi_output_train.py`
|
| 512 |
+
- GRPO trainer (vanilla): `multi_output_cell_policy/grpo_multi_output_train.py`
|
| 513 |
+
- SFT trainer (latent): `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`
|
| 514 |
+
- GRPO trainer (latent): `latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py`
|
| 515 |
+
- Reward function: `multi_output_cell_policy/rewards.py`
|
| 516 |
+
- Prompt builder: `multi_output_cell_policy/prompt_builder.py`
|
| 517 |
+
- Stage-i consistency: `multi_output_cell_policy/shared_multi_output_policy.py`
|
| 518 |
+
- 100-puzzle evaluator: `analysis/eval_stage2_checkpoint.py`
|
| 519 |
+
|
| 520 |
+
To reproduce the latent champion (1.5B, 9-phase, ~16 GPU·h on 8×H100 80GB):
|
| 521 |
+
|
| 522 |
+
```bash
|
| 523 |
+
export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter
|
| 524 |
+
bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
|
| 525 |
+
```
|
| 526 |
+
|
| 527 |
+
To reproduce the v6 baseline push (single-GPU per variant, ~6 GPU·h):
|
| 528 |
+
|
| 529 |
+
```bash
|
| 530 |
+
bash _runs/launch_baseline_push_v6.sh
|
| 531 |
+
```
|
| 532 |
+
|
| 533 |
+
---
|
| 534 |
+
|
| 535 |
+
## Appendix A. The reward fix as a one-line patch
|
| 536 |
+
|
| 537 |
+
The single most consequential code change in this whole project, as a
|
| 538 |
+
self-contained patch on `multi_output_cell_policy/rewards.py`:
|
| 539 |
+
|
| 540 |
+
```python
|
| 541 |
+
# new args (default 0 preserves legacy behaviour)
|
| 542 |
+
penalty_missing: float = 0.0
|
| 543 |
+
exact_match_bonus: float = 0.0
|
| 544 |
+
cardinality_mismatch_penalty: float = 0.0
|
| 545 |
+
|
| 546 |
+
num_missing = max(0, len(target_set) - num_good)
|
| 547 |
+
is_exact = bool(predicted_values) and (set(predicted_values) == target_set)
|
| 548 |
+
|
| 549 |
+
# ... base reward (triangular_number(num_good)*reward_good_value - num_bad*penalty_bad_value)
|
| 550 |
+
|
| 551 |
+
if num_missing > 0:
|
| 552 |
+
reward -= num_missing * penalty_missing
|
| 553 |
+
if is_exact:
|
| 554 |
+
reward += exact_match_bonus
|
| 555 |
+
if len(predicted_values) < len(target_values) and len(target_values) > 1:
|
| 556 |
+
reward -= cardinality_mismatch_penalty
|
| 557 |
+
```
|
| 558 |
+
|
| 559 |
+
Defaults are zero so old runs are unaffected; the recipe sets
|
| 560 |
+
`(P_m, B_x, P_c) = (0.75, 2.0, 1.0)` for the vanilla recipe and
|
| 561 |
+
`(1.0, 4.0, 3.0)` for the "sharp_rwd" variant.
|
| 562 |
+
|
| 563 |
+
## Appendix B. The warm-baseline trick as a sequence diagram
|
| 564 |
+
|
| 565 |
+
```
|
| 566 |
+
Stage 1 Stage 2 Stage 3
|
| 567 |
+
───────── ───────── ─────────
|
| 568 |
+
[base Qwen] ↓ ↓
|
| 569 |
+
↓ ↓ ↓
|
| 570 |
+
S1 baseline SFT → S2 baseline SFT → S3 baseline SFT
|
| 571 |
+
(no latent, k=0) (no latent, k=0) (no latent, k=0)
|
| 572 |
+
↓ ↓ ↓
|
| 573 |
+
S1 latent SFT → S2 latent SFT → S3 latent SFT
|
| 574 |
+
(k=1) (k=2) (k=3)
|
| 575 |
+
↓ ↓ ↓
|
| 576 |
+
S1 latent GRPO → S2 latent GRPO → S3 latent GRPO
|
| 577 |
+
(k=1, β=0, lr 1e-6) (k=2, β=0) (k=3, β=0, lr 5e-6)
|
| 578 |
+
↓ ↓ ↓
|
| 579 |
+
[final policy]
|
| 580 |
+
```
|
| 581 |
+
|
| 582 |
+
Every arrow is `init_adapter_dir = <previous output>`. Each row is a
|
| 583 |
+
"slot in the curriculum"; the column adds reasoning capacity (k+=1) and
|
| 584 |
+
moves to a harder target distribution (i+=1). The diagonal across the
|
| 585 |
+
diagram is the actual training trajectory.
|
| 586 |
+
|
| 587 |
+
---
|
| 588 |
+
|
| 589 |
+
_End of report._
|
REBUTTAL_REPORT.tex
ADDED
|
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
\documentclass[11pt]{article}
|
| 2 |
+
\usepackage[a4paper,margin=1in]{geometry}
|
| 3 |
+
\usepackage[utf8]{inputenc}
|
| 4 |
+
\usepackage[T1]{fontenc}
|
| 5 |
+
\usepackage{lmodern}
|
| 6 |
+
\usepackage{microtype}
|
| 7 |
+
\usepackage{amsmath,amssymb}
|
| 8 |
+
\usepackage{booktabs}
|
| 9 |
+
\usepackage{array}
|
| 10 |
+
\usepackage{longtable}
|
| 11 |
+
\usepackage{tabularx}
|
| 12 |
+
\usepackage{xcolor}
|
| 13 |
+
\usepackage{listings}
|
| 14 |
+
\usepackage{hyperref}
|
| 15 |
+
\usepackage{enumitem}
|
| 16 |
+
\usepackage{parskip}
|
| 17 |
+
|
| 18 |
+
\hypersetup{colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue}
|
| 19 |
+
|
| 20 |
+
\lstset{
|
| 21 |
+
basicstyle=\ttfamily\small,
|
| 22 |
+
breaklines=true,
|
| 23 |
+
columns=fullflexible,
|
| 24 |
+
keepspaces=true,
|
| 25 |
+
frame=single,
|
| 26 |
+
framerule=0.4pt,
|
| 27 |
+
xleftmargin=0.5em,
|
| 28 |
+
xrightmargin=0.5em,
|
| 29 |
+
showstringspaces=false,
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
\newcommand{\code}[1]{\texttt{#1}}
|
| 33 |
+
\newcommand{\indic}{\mathbf{1}}
|
| 34 |
+
|
| 35 |
+
\title{Curriculum CoT for $9{\times}9$ Sudoku\\[2pt]
|
| 36 |
+
\large Rebuttal / Paper-Section Material}
|
| 37 |
+
\author{}
|
| 38 |
+
\date{Last updated: 2026--05--24}
|
| 39 |
+
|
| 40 |
+
\begin{document}
|
| 41 |
+
\maketitle
|
| 42 |
+
|
| 43 |
+
\noindent
|
| 44 |
+
This document is a comprehensive, paper-ready reference of (a) the data
|
| 45 |
+
pipeline, (b) the instruction-tuning prompt format, (c) the curriculum
|
| 46 |
+
and reward design, (d) the latent thought-token architecture, (e) the
|
| 47 |
+
multi-stage SFT-then-GRPO training recipe, and (f) the headline numerical
|
| 48 |
+
results --- so a rebuttal section can be assembled directly from this
|
| 49 |
+
document.
|
| 50 |
+
|
| 51 |
+
\bigskip
|
| 52 |
+
\hrule
|
| 53 |
+
\bigskip
|
| 54 |
+
|
| 55 |
+
\section{Task}
|
| 56 |
+
|
| 57 |
+
We use the model as a \textbf{per-cell value policy} for $9\times 9$
|
| 58 |
+
Sudoku. For a fixed target empty cell, the model emits a JSON set of
|
| 59 |
+
candidate digits that are ``i-consistent'' with the current grid
|
| 60 |
+
(definition in \S 4). We evaluate two metrics:
|
| 61 |
+
|
| 62 |
+
\begin{itemize}[leftmargin=*]
|
| 63 |
+
\item \textbf{per-cell exact set match} (\code{exact\_set\_match}) ---
|
| 64 |
+
predicted set equals the ground-truth i-consistent set;
|
| 65 |
+
\item \textbf{whole-puzzle solve rate} (\code{solve}) --- every empty
|
| 66 |
+
cell on a 20-empty puzzle produces an exact set match.
|
| 67 |
+
\end{itemize}
|
| 68 |
+
|
| 69 |
+
Because $\text{solve} = \prod \text{exact\_set\_match}$ across the $\sim 20$
|
| 70 |
+
empty cells of a puzzle, the two metrics are non-linearly coupled:
|
| 71 |
+
\[
|
| 72 |
+
\text{solve} \approx \text{exact\_set\_match}^{N_{\text{empty}}}
|
| 73 |
+
\]
|
| 74 |
+
so $0.95^{20} \approx 0.358$ and $0.97^{20} \approx 0.544$ --- every
|
| 75 |
+
percentage point of per-cell exact maps to a much larger swing in solve.
|
| 76 |
+
|
| 77 |
+
\section{Data pipeline}
|
| 78 |
+
|
| 79 |
+
\subsection{Puzzle generation}
|
| 80 |
+
|
| 81 |
+
Generated by \code{simple\_9x9\_curriculum/build\_dataset.py}:
|
| 82 |
+
|
| 83 |
+
\begin{itemize}[leftmargin=*]
|
| 84 |
+
\item Start from a base Latin-square grid; randomly relabel digits,
|
| 85 |
+
permute rows and columns within bands, and transpose.
|
| 86 |
+
\item Sample \code{empties=20} cell positions uniformly at random and
|
| 87 |
+
erase them.
|
| 88 |
+
\item Save 10\,000 train + 1\,000 eval puzzles (seed 0, seed 1).
|
| 89 |
+
\item Output JSONL files \code{data/sudoku\_t3\_20empty\_value\_qwen\_text\_stage1\_\{train,eval\}.jsonl}.
|
| 90 |
+
\end{itemize}
|
| 91 |
+
|
| 92 |
+
A single record contains:
|
| 93 |
+
|
| 94 |
+
\begin{lstlisting}
|
| 95 |
+
{
|
| 96 |
+
"prompt": "<full Qwen chat-templated prompt for one (puzzle, target_cell) pair>",
|
| 97 |
+
"completion": "[7,3,8,2,6,9,4,5,...]",
|
| 98 |
+
"metadata": {
|
| 99 |
+
"grid_size": 9, "box_size": 3, "empties": 20,
|
| 100 |
+
"empty_locs_1based": [[1,4],[1,9],...],
|
| 101 |
+
"target_triples_1based": [[1,4,7],[1,9,3],...]
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
\end{lstlisting}
|
| 105 |
+
|
| 106 |
+
The 20 \code{target\_triples} give the \textbf{solved} value at each of
|
| 107 |
+
the 20 empty positions, so per-cell training targets are always
|
| 108 |
+
available. At training time we expand each puzzle into 20 (puzzle,
|
| 109 |
+
target\_cell) examples.
|
| 110 |
+
|
| 111 |
+
\subsection{Cell-policy framing}
|
| 112 |
+
|
| 113 |
+
The model is never asked to solve a whole puzzle in one shot. Each
|
| 114 |
+
example is one (current\_grid, target\_cell) pair, and the supervised
|
| 115 |
+
target is the set of digits that are ``i-consistent'' with the current
|
| 116 |
+
grid (see \S 4). This turns Sudoku into a
|
| 117 |
+
\textbf{classification-into-a-set} problem and lets us share parameters
|
| 118 |
+
across cells, stages, and puzzle sizes.
|
| 119 |
+
|
| 120 |
+
\subsection{Multi-value oversampling (data-side trick)}
|
| 121 |
+
|
| 122 |
+
Implemented in \code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py}
|
| 123 |
+
via \code{tokenizer.\_multi\_value\_oversample\_factor} and the CLI flags
|
| 124 |
+
|
| 125 |
+
\begin{lstlisting}
|
| 126 |
+
--multi_value_oversample_factor INT (default 1)
|
| 127 |
+
--train_target_size_min INT (default 0)
|
| 128 |
+
--train_target_size_max INT (default 0)
|
| 129 |
+
\end{lstlisting}
|
| 130 |
+
|
| 131 |
+
Inside the dataset builder, examples whose target set has more than one
|
| 132 |
+
digit are repeated \code{multi\_value\_oversample\_factor} times in the
|
| 133 |
+
training mix. This biases gradient steps toward exactly the cells the
|
| 134 |
+
model gets wrong (multi-value cells). Empirically, this is the single
|
| 135 |
+
biggest data-side lever --- see \S 10.
|
| 136 |
+
|
| 137 |
+
\subsection{Where the bottleneck lives}
|
| 138 |
+
|
| 139 |
+
For 20-empty puzzles in stage 3, only $\sim 25\%$ of empty cells have a
|
| 140 |
+
multi-value target set (the rest collapse to one i-consistent value).
|
| 141 |
+
Yet those multi-value cells are responsible for the entire solve-rate
|
| 142 |
+
gap: they are the cells where the model under-predicts (returns a
|
| 143 |
+
singleton when the target is a 2- or 3-element set), and a single
|
| 144 |
+
failed cell kills the whole-puzzle solve. The reward shaping in \S 6
|
| 145 |
+
and the oversample in 2.3 both attack this single failure mode.
|
| 146 |
+
|
| 147 |
+
\section{Instruction format}
|
| 148 |
+
|
| 149 |
+
\subsection{System prompt}
|
| 150 |
+
|
| 151 |
+
(verbatim from \code{multi\_output\_cell\_policy/prompt\_builder.py})
|
| 152 |
+
|
| 153 |
+
\begin{lstlisting}
|
| 154 |
+
You are a Sudoku value policy.
|
| 155 |
+
This setup uses puzzles with about 20 empty cells.
|
| 156 |
+
You will be given one target empty cell.
|
| 157 |
+
Return ONLY one JSON object of the form {"values":[...]}.
|
| 158 |
+
The JSON object must contain exactly one key named "values".
|
| 159 |
+
The "values" field must be a JSON array of unique integers in [1,9].
|
| 160 |
+
You may return as many candidate values as you want, including one,
|
| 161 |
+
several, or many values.
|
| 162 |
+
Choose the number of returned values yourself based on which values seem
|
| 163 |
+
i-consistent.
|
| 164 |
+
The order of the values does not matter.
|
| 165 |
+
Do not output any explanation, markdown, punctuation outside JSON, or
|
| 166 |
+
extra text.
|
| 167 |
+
Current stage objective: i={i} consistency.
|
| 168 |
+
\end{lstlisting}
|
| 169 |
+
|
| 170 |
+
\subsection{User message}
|
| 171 |
+
|
| 172 |
+
\begin{lstlisting}
|
| 173 |
+
Sudoku grid (0 means empty):
|
| 174 |
+
<grid_to_text(grid)>
|
| 175 |
+
Empty cells in row-major order (20 total): (1,4), (1,9), (2,8), ...
|
| 176 |
+
Target cell to fill now: (R,C).
|
| 177 |
+
Turn: t/T.
|
| 178 |
+
Return only JSON with candidate values for this target cell: {"values":[...]}
|
| 179 |
+
\end{lstlisting}
|
| 180 |
+
|
| 181 |
+
We use the Qwen2.5-Instruct chat template
|
| 182 |
+
(\code{tokenizer.apply\_chat\_template}, \code{add\_generation\_prompt=True})
|
| 183 |
+
to wrap system + user into the actual prompt ids.
|
| 184 |
+
\code{max\_prompt\_length = 768}.
|
| 185 |
+
|
| 186 |
+
\subsection{Output format}
|
| 187 |
+
|
| 188 |
+
\begin{lstlisting}
|
| 189 |
+
{"values":[3,7]}
|
| 190 |
+
\end{lstlisting}
|
| 191 |
+
|
| 192 |
+
Strictly canonical JSON (single key \code{values}, sorted unique digit
|
| 193 |
+
list, no whitespace). Outputs are scored by \code{parse\_values\_json}
|
| 194 |
+
(\code{shared\_multi\_output\_policy.py}); any deviation collapses the
|
| 195 |
+
whole prediction to \code{parse\_ok=0} and a hard-coded malformed
|
| 196 |
+
penalty.
|
| 197 |
+
|
| 198 |
+
\code{max\_completion\_length = 24} tokens --- enough to emit any
|
| 199 |
+
9-digit set.
|
| 200 |
+
|
| 201 |
+
\section{Curriculum: stage-i consistency}
|
| 202 |
+
|
| 203 |
+
The curriculum lives in \code{\_stage\_i\_consistent\_values\_for\_grid}:
|
| 204 |
+
|
| 205 |
+
\begin{itemize}[leftmargin=*]
|
| 206 |
+
\item \textbf{Stage 1 --- $i=1$ (legal moves).} A value $v$ is $i=1$
|
| 207 |
+
consistent at cell $c$ iff placing $v$ at $c$ violates no Sudoku
|
| 208 |
+
constraint (row, column, $3\times 3$ box). This is just ``legal
|
| 209 |
+
candidates''.
|
| 210 |
+
|
| 211 |
+
\item \textbf{Stage 2 --- $i=2$.} $v$ is $i=2$ consistent at $c$ iff
|
| 212 |
+
(a) it is $i=1$ consistent AND (b) after placing $v$, every other
|
| 213 |
+
empty cell in the grid still has at least one $i=1$-consistent value
|
| 214 |
+
(i.e.\ placing $v$ does not immediately make the puzzle unsolvable
|
| 215 |
+
by 1-step propagation).
|
| 216 |
+
|
| 217 |
+
\item \textbf{Stage 3 --- $i=3$.} Same recursion one more level deep:
|
| 218 |
+
$v$ is $i=3$ consistent iff after placing $v$, every other empty cell
|
| 219 |
+
still has at least one $i=2$ consistent value.
|
| 220 |
+
\end{itemize}
|
| 221 |
+
|
| 222 |
+
This is bounded look-ahead constraint propagation. Stage-3 sets are
|
| 223 |
+
tighter than stage-2 sets which are tighter than stage-1 sets. The
|
| 224 |
+
curriculum goal at deployment time is stage-3.
|
| 225 |
+
|
| 226 |
+
In data, we use the same source records and just change \code{--stage\_i};
|
| 227 |
+
the target set is regenerated on the fly by
|
| 228 |
+
\code{stage\_i\_consistent\_values}.
|
| 229 |
+
|
| 230 |
+
\section{Latent thought-token architecture}
|
| 231 |
+
|
| 232 |
+
Base model: \textbf{Qwen/Qwen2.5-1.5B-Instruct} + LoRA
|
| 233 |
+
($r=32$, $\alpha=64$, dropout $=0.05$) on
|
| 234 |
+
\code{q,k,v,o,gate,up,down}. The latent variant adds \textbf{$k$
|
| 235 |
+
thought-token slots} between the prompt and the next-token logits.
|
| 236 |
+
|
| 237 |
+
Four modes are implemented (\code{latent\_multi\_output\_cell\_policy/});
|
| 238 |
+
the winning mode for the final number is \textbf{\code{recurrent\_hidden}}:
|
| 239 |
+
|
| 240 |
+
\begin{quote}
|
| 241 |
+
\code{build\_recurrent\_hidden\_latent\_hidden(model, ids, mask, k)}
|
| 242 |
+
\begin{enumerate}[leftmargin=*,nosep]
|
| 243 |
+
\item Run the backbone once on the prompt. Keep
|
| 244 |
+
\code{base\_hidden = h[:,-1,:]}.
|
| 245 |
+
\item Set \code{latent\_token = base\_hidden}.
|
| 246 |
+
\item Repeat $k$ times: append \code{latent\_token} (as an embedding)
|
| 247 |
+
to the running sequence, run the backbone again on the extended
|
| 248 |
+
sequence, and replace \code{latent\_token} with the new last hidden
|
| 249 |
+
state.
|
| 250 |
+
\item After $k$ recursions, \code{latent\_hidden} is fed through the LM
|
| 251 |
+
head to produce the next-token distribution.
|
| 252 |
+
\end{enumerate}
|
| 253 |
+
\end{quote}
|
| 254 |
+
|
| 255 |
+
In equations, with $E$ the input embedding lookup, $f_\theta$ the
|
| 256 |
+
LoRA-decorated backbone, $U$ the LM head:
|
| 257 |
+
\begin{align*}
|
| 258 |
+
z_0 &= f_\theta\bigl(E([x_1,\dots,x_T])\bigr)_T \\
|
| 259 |
+
z_{j+1} &= f_\theta\bigl([E(x_1),\dots,E(x_T), z_0, z_1, \dots, z_j]\bigr)_{T+j+1},\quad j=0,\dots,k-1 \\
|
| 260 |
+
p(\cdot \mid x_{1:T}) &= \mathrm{softmax}(U z_k)
|
| 261 |
+
\end{align*}
|
| 262 |
+
|
| 263 |
+
The model can therefore ``iterate'' $k$ extra forward passes on the
|
| 264 |
+
same prompt before committing to a token, with the $k$ extra hidden
|
| 265 |
+
states carrying intermediate computation. Setting $k=0$ recovers the
|
| 266 |
+
vanilla baseline.
|
| 267 |
+
|
| 268 |
+
The other three latent modes are alternatives that we ablated:
|
| 269 |
+
\code{fixed\_slots} (concatenate $k$ trainable seed embeddings ---
|
| 270 |
+
Option-2), \code{latent\_seeds} (similar to \code{fixed\_slots}), and
|
| 271 |
+
\code{residual} (project $k$ extra hidden states back onto the base
|
| 272 |
+
hidden state via a learned residual). All modes share the SFT and GRPO
|
| 273 |
+
trainers; only the next-token logit function changes.
|
| 274 |
+
|
| 275 |
+
For the curriculum, we grow $k$ stage by stage:
|
| 276 |
+
|
| 277 |
+
\begin{center}
|
| 278 |
+
\begin{tabular}{ccl}
|
| 279 |
+
\toprule
|
| 280 |
+
\textbf{stage} & \textbf{num\_cot\_tokens} & \textbf{comment} \\
|
| 281 |
+
\midrule
|
| 282 |
+
1 & 1 & one extra recursion as soon as the model has the surface form \\
|
| 283 |
+
2 & 2 & two --- needed for 1-step propagation reasoning \\
|
| 284 |
+
3 & 3 & three --- needed for 2-step propagation reasoning \\
|
| 285 |
+
\bottomrule
|
| 286 |
+
\end{tabular}
|
| 287 |
+
\end{center}
|
| 288 |
+
|
| 289 |
+
\section{The reward function}
|
| 290 |
+
|
| 291 |
+
Defined in \code{multi\_output\_cell\_policy/rewards.py}.
|
| 292 |
+
|
| 293 |
+
Given target set $T$, predicted set $P$ (after JSON parse), let
|
| 294 |
+
\begin{itemize}[leftmargin=*,nosep]
|
| 295 |
+
\item \code{num\_good} $= |P \cap T|$
|
| 296 |
+
\item \code{num\_bad} $= |P \setminus T|$
|
| 297 |
+
\item \code{num\_missing} $= \max(0, |T| - \text{num\_good})$
|
| 298 |
+
\item \code{is\_exact} $= (P \neq \varnothing) \land (P = T)$
|
| 299 |
+
\item $\mathrm{tri}(n) = n(n+1)/2$ (rewards larger correct sets superlinearly)
|
| 300 |
+
\end{itemize}
|
| 301 |
+
|
| 302 |
+
Then
|
| 303 |
+
\begin{align*}
|
| 304 |
+
r &= \mathrm{tri}(\text{num\_good}) \cdot R_g \;-\; \text{num\_bad} \cdot P_b \\
|
| 305 |
+
&\quad - \indic[P=\varnothing]\, P_e \;-\; \indic[|P|=1, |T|>1, i<2]\, P_s \\
|
| 306 |
+
&\quad - \text{num\_missing}\cdot P_m \;+\; \indic[\text{is\_exact}]\, B_x \\
|
| 307 |
+
&\quad - \indic[|P|<|T|, |T|>1]\, P_c
|
| 308 |
+
\end{align*}
|
| 309 |
+
|
| 310 |
+
with parameters (this is the recipe that produced the 0.58/0.68 latent
|
| 311 |
+
solve):
|
| 312 |
+
|
| 313 |
+
\begin{center}
|
| 314 |
+
\begin{tabular}{cllr}
|
| 315 |
+
\toprule
|
| 316 |
+
\textbf{symbol} & \textbf{flag} & \textbf{role} & \textbf{value} \\
|
| 317 |
+
\midrule
|
| 318 |
+
$R_g$ & \code{--reward\_good\_value} & per-correct-value reward (triangular shape) & 1.25 \\
|
| 319 |
+
$P_b$ & \code{--penalty\_bad\_value} & per-extra-wrong-value penalty & 1.0 \\
|
| 320 |
+
$P_{\text{mal}}$ & \code{--penalty\_malformed} & flat penalty if JSON parse fails & 4.0 \\
|
| 321 |
+
$P_e$ & \code{--penalty\_empty} & flat penalty if predicted set is empty & 0.5 \\
|
| 322 |
+
$P_s$ & \code{--penalty\_singleton} & only at stage$<$2: punishes singleton on multi-value targets & 1.5 \\
|
| 323 |
+
$P_m$ & \code{--penalty\_missing} & per-missing-value (recall pressure) --- \textbf{NEW} & \textbf{0.75} \\
|
| 324 |
+
$B_x$ & \code{--exact\_match\_bonus} & only when $P = T$ --- \textbf{NEW} & \textbf{2.0} \\
|
| 325 |
+
$P_c$ & \code{--cardinality\_mismatch\_penalty} & when $|P| < |T|$ and $|T|>1$ --- \textbf{NEW} & \textbf{1.0} \\
|
| 326 |
+
\bottomrule
|
| 327 |
+
\end{tabular}
|
| 328 |
+
\end{center}
|
| 329 |
+
|
| 330 |
+
Parse failures short-circuit to $r = -P_{\text{mal}}$ and zero per-cell
|
| 331 |
+
metrics.
|
| 332 |
+
|
| 333 |
+
\subsection{Why those three new terms exist (the breakthrough)}
|
| 334 |
+
|
| 335 |
+
Diagnosis: at the v3/v4 plateau, eval reported
|
| 336 |
+
|
| 337 |
+
\begin{lstlisting}
|
| 338 |
+
exact=0.95 precision=0.95 recall=0.95 solve=0.30 avg_set_size=1.000
|
| 339 |
+
\end{lstlisting}
|
| 340 |
+
|
| 341 |
+
across all checkpoints. Per-cell exact and precision/recall were all
|
| 342 |
+
near 0.95 but the model \textbf{always predicted a single digit}
|
| 343 |
+
(\code{avg\_set\_size=1.000}). On a multi-value target $T=\{8,9\}$,
|
| 344 |
+
predicting $\{8\}$ keeps precision $=1.0$, recall $=0.5$ and yet
|
| 345 |
+
\code{exact\_set\_match}$=0$. Solve $= \text{exact\_set\_match}^N$ is
|
| 346 |
+
catastrophic in $N$ ($=20$), so even a small fraction of multi-value
|
| 347 |
+
cells killed it.
|
| 348 |
+
|
| 349 |
+
Without any of the new terms the optimum of $r$ on a multi-value cell
|
| 350 |
+
is trivially ``predict the singleton you are most confident about'' ---
|
| 351 |
+
there is no upside to enumerate the second value. The three new terms
|
| 352 |
+
close exactly that hole:
|
| 353 |
+
|
| 354 |
+
\begin{itemize}[leftmargin=*,nosep]
|
| 355 |
+
\item $P_m$ (\code{penalty\_missing}) directly penalises recall;
|
| 356 |
+
\item $B_x$ (\code{exact\_match\_bonus}) makes $P=T$ strictly dominate any singleton;
|
| 357 |
+
\item $P_c$ (\code{cardinality\_mismatch\_penalty}) is a flat hammer whenever $|P|<|T|$.
|
| 358 |
+
\end{itemize}
|
| 359 |
+
|
| 360 |
+
After these terms were added, GRPO on the latent variant moved solve
|
| 361 |
+
from $\sim 0.30$ to $\sim 0.58$ (100-puzzle eval) over $\sim 200$
|
| 362 |
+
steps. The same fix is what we ported back into the baseline pipeline
|
| 363 |
+
this evening (see \S 10).
|
| 364 |
+
|
| 365 |
+
\section{Multi-stage warm-baseline pipeline (the recipe that worked)}
|
| 366 |
+
|
| 367 |
+
Master script:
|
| 368 |
+
\code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}.
|
| 369 |
+
|
| 370 |
+
For each curriculum stage we run \textbf{three sub-phases in order}:
|
| 371 |
+
|
| 372 |
+
\begin{lstlisting}
|
| 373 |
+
[stage i]
|
| 374 |
+
(1) baseline warm SFT (no latent tokens, k=0, vanilla LM)
|
| 375 |
+
(2) latent SFT (k = i, latent mode = recurrent_hidden)
|
| 376 |
+
(3) latent GRPO (k = i)
|
| 377 |
+
\end{lstlisting}
|
| 378 |
+
|
| 379 |
+
\textbf{The warm baseline phase (1) is the trick that makes the
|
| 380 |
+
curriculum work.} At every stage transition the data distribution
|
| 381 |
+
changes ($i$ increases $\Rightarrow$ target sets shrink) and a new
|
| 382 |
+
latent slot appears. Doing a vanilla SFT on the new distribution first
|
| 383 |
+
lets the LM relearn the surface form on familiar parameters; THEN the
|
| 384 |
+
latent SFT adds the extra thought slot on top of an already-good policy.
|
| 385 |
+
When we tried to add a new latent slot directly on top of the previous
|
| 386 |
+
stage's GRPO checkpoint, training loss did NOT decrease.
|
| 387 |
+
|
| 388 |
+
Concrete LR schedule used for the champion run:
|
| 389 |
+
|
| 390 |
+
\begin{center}
|
| 391 |
+
\begin{tabular}{lllc}
|
| 392 |
+
\toprule
|
| 393 |
+
\textbf{phase} & \textbf{init from} & \textbf{LR} & \textbf{k} \\
|
| 394 |
+
\midrule
|
| 395 |
+
S1 baseline SFT & base Qwen & 2e-4 & 0 \\
|
| 396 |
+
S1 latent SFT & S1 baseline & 2e-4 & 1 \\
|
| 397 |
+
S1 latent GRPO & S1 latent SFT & 1e-6 & 1 \\
|
| 398 |
+
S2 baseline warm SFT & S1 GRPO & 5e-5 & 0 \\
|
| 399 |
+
S2 latent SFT & S2 baseline & 5e-5 & 2 \\
|
| 400 |
+
S2 latent GRPO & S2 latent SFT & 1e-6 & 2 \\
|
| 401 |
+
S3 baseline warm SFT & S2 GRPO & 5e-5 & 0 \\
|
| 402 |
+
S3 latent SFT & S3 baseline & 5e-5 $\rightarrow$ 1e-5 (champion) & 3 \\
|
| 403 |
+
S3 latent GRPO & S3 latent SFT & 5e-6 ($\beta=0$) & 3 \\
|
| 404 |
+
\bottomrule
|
| 405 |
+
\end{tabular}
|
| 406 |
+
\end{center}
|
| 407 |
+
|
| 408 |
+
Other shared knobs:
|
| 409 |
+
|
| 410 |
+
\begin{lstlisting}
|
| 411 |
+
LoRA: r=32 a=64 dropout=0.05 on q,k,v,o,gate,up,down
|
| 412 |
+
SFT: per_device_bs=8 grad_accum=2 nproc=8 -> eff_bs=128
|
| 413 |
+
GRPO: per_device_bs=4 grad_accum=2 nproc=8 -> eff_bs=64
|
| 414 |
+
num_generations=4 beta=0.0 max_prompt_length=1024
|
| 415 |
+
max_completion_length=24
|
| 416 |
+
multi_value_oversample_factor=5, exact_match_bonus=2.0,
|
| 417 |
+
penalty_missing=0.75, cardinality_mismatch_penalty=1.0
|
| 418 |
+
\end{lstlisting}
|
| 419 |
+
|
| 420 |
+
\section{GRPO settings that mattered}
|
| 421 |
+
|
| 422 |
+
\begin{itemize}[leftmargin=*]
|
| 423 |
+
\item \textbf{$\beta = 0$.} The KL anchor was harmful in every sweep
|
| 424 |
+
where we tried $\beta>0$. \code{s3\_grpo\_kl04} ($\beta=0.04$) peaked
|
| 425 |
+
at solve $=0.625$ (40p) at step 100 and regressed to $0.525$ by step
|
| 426 |
+
500.
|
| 427 |
+
|
| 428 |
+
\item \textbf{\code{num\_generations} $= 4$.} With \code{num\_generations}$=2$
|
| 429 |
+
we routinely saw \code{reward\_std}$=0$ (all sampled completions
|
| 430 |
+
identical $\Rightarrow$ no gradient). Bumping to 4 fixed it.
|
| 431 |
+
|
| 432 |
+
\item \textbf{Low LR.} \code{lr=5e-6} was the steadiest. \code{lr=1e-5}
|
| 433 |
+
peaked at step 200 (solve $0.65$) then collapsed back to $0.54$ ---
|
| 434 |
+
classic mode collapse.
|
| 435 |
+
|
| 436 |
+
\item \textbf{Effective bs $\geq 64$.} TRL's GRPOConfig requires
|
| 437 |
+
\code{eff\_bs * grad\_accum \% num\_generations == 0}; with 8 GPUs we
|
| 438 |
+
hit this trivially, but we caution single-GPU rerunners to set
|
| 439 |
+
\code{per\_device\_bs=4 grad\_accum=2 num\_generations=4}.
|
| 440 |
+
|
| 441 |
+
\item \textbf{\code{enable\_input\_require\_grads()} on the wrapped backbone.}
|
| 442 |
+
Required for TRL 0.15.x + PEFT LoRA + gradient checkpointing ---
|
| 443 |
+
otherwise the loss tensor produced by GRPOTrainer has
|
| 444 |
+
\code{requires\_grad=False} and \code{.backward()} raises. Also
|
| 445 |
+
\code{unwrapped.config.use\_cache = False}.
|
| 446 |
+
\end{itemize}
|
| 447 |
+
|
| 448 |
+
\section{Final hyperparameters table --- champion latent run}
|
| 449 |
+
|
| 450 |
+
\begin{center}
|
| 451 |
+
\begin{longtable}{lll}
|
| 452 |
+
\toprule
|
| 453 |
+
\textbf{group} & \textbf{hyperparameter} & \textbf{value} \\
|
| 454 |
+
\midrule
|
| 455 |
+
\endfirsthead
|
| 456 |
+
\toprule
|
| 457 |
+
\textbf{group} & \textbf{hyperparameter} & \textbf{value} \\
|
| 458 |
+
\midrule
|
| 459 |
+
\endhead
|
| 460 |
+
Backbone & model & Qwen/Qwen2.5-1.5B-Instruct \\
|
| 461 |
+
Backbone & dtype & bf16 \\
|
| 462 |
+
Backbone & LoRA target modules & q,k,v,o,gate,up,down \\
|
| 463 |
+
Backbone & LoRA $r$ / $\alpha$ / dropout & 32 / 64 / 0.05 \\
|
| 464 |
+
Latent & mode & \code{recurrent\_hidden} \\
|
| 465 |
+
Latent & \code{num\_cot\_tokens} (S1/S2/S3) & 1 / 2 / 3 \\
|
| 466 |
+
Latent & \code{max\_latent\_slots} / seeds & 8 / 8 \\
|
| 467 |
+
Data & total empties & 20 \\
|
| 468 |
+
Data & train rows / eval rows & 10\,000 / 100 \\
|
| 469 |
+
Data & \code{multi\_value\_oversample\_factor} & 5 \\
|
| 470 |
+
Data & \code{mixed\_stage1\_ratio} (S1) & 1 \\
|
| 471 |
+
Data & \code{mixed\_stage2\_ratio} (S$\geq 2$) & 1 \\
|
| 472 |
+
SFT & per\_device\_bs / grad\_accum & 8 / 2 \\
|
| 473 |
+
SFT & \code{num\_epochs} (cap) & 64 \\
|
| 474 |
+
SFT & LR (S1 latent) & 2e-4 \\
|
| 475 |
+
SFT & LR (S2/S3 baseline warm + latent) & 5e-5 \\
|
| 476 |
+
SFT & LR (S3 latent champion \code{s3b\_lr1e5\_o5}) & 1e-5 \\
|
| 477 |
+
SFT & weight\_decay & 0.0 \\
|
| 478 |
+
SFT & gradient checkpointing & on \\
|
| 479 |
+
GRPO & per\_device\_bs / grad\_accum & 4 / 2 \\
|
| 480 |
+
GRPO & \code{num\_generations} & 4 \\
|
| 481 |
+
GRPO & LR & 5e-6 (S3); 1e-6 (S1, S2) \\
|
| 482 |
+
GRPO & $\beta$ (KL) & 0.0 \\
|
| 483 |
+
GRPO & \code{max\_prompt\_length} & 1024 \\
|
| 484 |
+
GRPO & \code{max\_completion\_length} & 24 \\
|
| 485 |
+
Reward & \code{reward\_good\_value} & 1.25 \\
|
| 486 |
+
Reward & \code{penalty\_bad\_value} & 1.0 \\
|
| 487 |
+
Reward & \code{penalty\_malformed} & 4.0 \\
|
| 488 |
+
Reward & \code{penalty\_empty} & 0.5 \\
|
| 489 |
+
Reward & \code{penalty\_singleton} & 1.5 \\
|
| 490 |
+
Reward & \code{penalty\_missing} & 0.75 \\
|
| 491 |
+
Reward & \code{exact\_match\_bonus} & 2.0 \\
|
| 492 |
+
Reward & \code{cardinality\_mismatch\_penalty} & 1.0 \\
|
| 493 |
+
Eval & early-stop on prec/recall & 0.98 \\
|
| 494 |
+
\bottomrule
|
| 495 |
+
\end{longtable}
|
| 496 |
+
\end{center}
|
| 497 |
+
|
| 498 |
+
\section{Headline results}
|
| 499 |
+
|
| 500 |
+
\subsection{Latent (with thought tokens, \code{recurrent\_hidden})}
|
| 501 |
+
|
| 502 |
+
\begin{center}
|
| 503 |
+
\begin{tabular}{llrrrrr}
|
| 504 |
+
\toprule
|
| 505 |
+
\textbf{eval} & \textbf{model / phase} & \textbf{step} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} \\
|
| 506 |
+
\midrule
|
| 507 |
+
\textbf{100p (auth.)} & \code{s3\_grpo\_baseline} (S3 GRPO, $\beta=0$, lr=5e-6) & 200 & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580 (58/100)} \\
|
| 508 |
+
40p & \code{s3\_grpo\_sharp\_rwd} ($B_x{=}4$, $P_c{=}3$) & 300 & --- & --- & --- & \textbf{0.675 (27/40)} \\
|
| 509 |
+
40p & \code{s3\_grpo\_lr1e5} & 200 & 0.978 & 0.978 & 0.979 & 0.650 \\
|
| 510 |
+
40p & \code{s3b\_lr1e5\_o5} (S3 SFT champion) & 2400 & 0.974 & 0.974 & 0.975 & 0.600 \\
|
| 511 |
+
\bottomrule
|
| 512 |
+
\end{tabular}
|
| 513 |
+
\end{center}
|
| 514 |
+
|
| 515 |
+
\subsection{Vanilla baseline (no thought tokens, same Qwen2.5-1.5B + LoRA)}
|
| 516 |
+
|
| 517 |
+
\begin{center}
|
| 518 |
+
\begin{tabular}{llrrr}
|
| 519 |
+
\toprule
|
| 520 |
+
\textbf{sweep} & \textbf{best variant} & \textbf{best step} & \textbf{exact} & \textbf{solve (100p)} \\
|
| 521 |
+
\midrule
|
| 522 |
+
v3 (single-GPU, no oversample, no new reward) & \code{baseline\_3stage\_20260522} & --- & 0.730 & \textbf{0.000} \\
|
| 523 |
+
v4 (LR sweep, multi-GPU, original reward) & \code{pipe\_v\_sft\_extend} (S3 SFT extended) & 4000 & 0.948 & \textbf{0.400} \\
|
| 524 |
+
\textbf{v6 (this evening; ports latent reward + oversample)} & \code{v6\_i\_sft\_v\_oversample10} & running & 0.952$+$ & \textbf{0.440 (best so far)} \\
|
| 525 |
+
\bottomrule
|
| 526 |
+
\end{tabular}
|
| 527 |
+
\end{center}
|
| 528 |
+
|
| 529 |
+
The v6 sweep is still running --- \code{v6\_e/f/i} are in S3 SFT
|
| 530 |
+
continuation, GRPO follow-on phases queued. The \code{v6\_i} variant
|
| 531 |
+
has hit \textbf{solve $=0.44$} at SFT eval (new baseline best,
|
| 532 |
+
$+0.04$ over v4) and is still climbing.
|
| 533 |
+
|
| 534 |
+
\subsection{Stage-by-stage trajectory (latent, 40-puzzle eval)}
|
| 535 |
+
|
| 536 |
+
\begin{lstlisting}
|
| 537 |
+
S1 SFT : exact ~ 0.85, solve ~ 0.20
|
| 538 |
+
S1 GRPO : exact ~ 0.90, solve ~ 0.20
|
| 539 |
+
S2 SFT (no oversample) : exact ~ 0.94, solve ~ 0.20-0.25 <- the wall
|
| 540 |
+
S2 SFT + multi_value_oversample=5 : exact ~ 0.96, solve ~ 0.30-0.35
|
| 541 |
+
S2 GRPO + new reward terms : exact ~ 0.96, solve ~ 0.35-0.40
|
| 542 |
+
S3 SFT (s3b_lr1e5_o5 step 2400) : exact 0.974, solve 0.600 <- SFT champion
|
| 543 |
+
S3 GRPO (s3_grpo_baseline step 200,100p): exact 0.967, solve 0.580 <- 100p champion
|
| 544 |
+
S3 GRPO (s3_grpo_sharp_rwd step 300,40p): solve 0.675 <- 40p peak
|
| 545 |
+
\end{lstlisting}
|
| 546 |
+
|
| 547 |
+
\subsection{Latent vs baseline gap (head-to-head, same 100p eval, same prompts)}
|
| 548 |
+
|
| 549 |
+
\begin{center}
|
| 550 |
+
\begin{tabular}{lrrrrr}
|
| 551 |
+
\toprule
|
| 552 |
+
\textbf{model} & \textbf{exact} & \textbf{prec} & \textbf{recall} & \textbf{solve} & \textbf{solved/100} \\
|
| 553 |
+
\midrule
|
| 554 |
+
Latent \code{recurrent\_hidden}, S3 GRPO & 0.9665 & 0.9673 & 0.9680 & \textbf{0.580} & 58 \\
|
| 555 |
+
Vanilla baseline, \code{v6\_i} (best at time of writing) & 0.952 & 0.952 & 0.952 & \textbf{0.440} & 44 \\
|
| 556 |
+
\bottomrule
|
| 557 |
+
\end{tabular}
|
| 558 |
+
\end{center}
|
| 559 |
+
|
| 560 |
+
Gap on 100-puzzle solve: $\approx$ \textbf{$+0.14$ absolute / $+32\%$
|
| 561 |
+
relative} for latent over the strongest baseline we have.
|
| 562 |
+
|
| 563 |
+
\section{Why the latent works (interpretation hypotheses)}
|
| 564 |
+
|
| 565 |
+
These are the working hypotheses the experiments are consistent with;
|
| 566 |
+
none is fully proven and ablations are still WIP.
|
| 567 |
+
|
| 568 |
+
\begin{enumerate}[leftmargin=*]
|
| 569 |
+
\item \textbf{Constraint-propagation depth.} Stage-3 i-consistency is
|
| 570 |
+
essentially 2-ply lookahead. With $k=3$ recurrent hidden tokens the
|
| 571 |
+
model gets exactly three extra forward passes between prompt and
|
| 572 |
+
output --- one for the legality check, one for 1-step propagation,
|
| 573 |
+
one for the second step of propagation. Empirically the gap to the
|
| 574 |
+
no-thought-token baseline appears at stages where multi-step
|
| 575 |
+
propagation matters (stage 2 onward; stage 1 numbers are essentially
|
| 576 |
+
identical).
|
| 577 |
+
|
| 578 |
+
\item \textbf{Multi-value cells require enumeration, which a singleton
|
| 579 |
+
softmax can't do in one forward pass.} A vanilla LM at 1.5B
|
| 580 |
+
parameters predicts essentially deterministically once temperature is
|
| 581 |
+
low; for a target set $\{8, 9\}$ the LM picks one of the two and
|
| 582 |
+
stops. The latent model can use one of the recurrent hidden steps to
|
| 583 |
+
``consider'' each option without committing yet, which is exactly
|
| 584 |
+
the failure mode in the data (\code{avg\_set\_size} $= 1.000$ for the
|
| 585 |
+
baseline, $\approx 1.05$ for the latent S3 model on the same eval).
|
| 586 |
+
|
| 587 |
+
\item \textbf{Stable curriculum capacity growth.} Adding a new latent
|
| 588 |
+
slot at every stage gives the model a ``fresh slate'' of
|
| 589 |
+
representational capacity at the exact transition where the task
|
| 590 |
+
gets harder. The warm-baseline SFT between stages prevents the new
|
| 591 |
+
slot from corrupting the previously learned policy. Without warm
|
| 592 |
+
baseline, training loss did not decrease at all (we observed this
|
| 593 |
+
directly when we tried to skip the warm baseline).
|
| 594 |
+
|
| 595 |
+
\item \textbf{GRPO without latent slots is starved of variance.} With
|
| 596 |
+
\code{max\_completion\_length} 24 and the model essentially
|
| 597 |
+
deterministic, GRPO's 4 sampled completions per prompt collapse to a
|
| 598 |
+
single answer --- \code{reward\_std}$=0$, no gradient. With latent
|
| 599 |
+
recurrence + the new \code{exact\_match\_bonus} reward, the model
|
| 600 |
+
occasionally samples a 2-element set, gets a much higher reward, and
|
| 601 |
+
that prompt gets a real gradient signal.
|
| 602 |
+
\end{enumerate}
|
| 603 |
+
|
| 604 |
+
\section{Reproducibility}
|
| 605 |
+
|
| 606 |
+
\noindent
|
| 607 |
+
Code repository: \url{https://github.com/Avra98/curriculum_cot} \\
|
| 608 |
+
Latent checkpoints: \url{https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages} \\
|
| 609 |
+
Baseline checkpoints: \url{https://huggingface.co/Avra98/sudoku-9x9-20empty-baseline-1p5b-sweep}
|
| 610 |
+
|
| 611 |
+
Key scripts:
|
| 612 |
+
|
| 613 |
+
\begin{itemize}[leftmargin=*,nosep]
|
| 614 |
+
\item Master orchestrator (latent, 9-phase warm-baseline pipeline):
|
| 615 |
+
\code{hard\_9x9\_stage1\_consistency\_queue/launch\_20empty\_warm\_baseline\_all\_latent\_modes\_stages123.sh}
|
| 616 |
+
\item Vanilla baseline pipeline:
|
| 617 |
+
\code{\_runs/baseline\_1p5b\_pipeline\_v4.sh} (with v6 launchers
|
| 618 |
+
\code{\_runs/launch\_baseline\_push\_v6.sh})
|
| 619 |
+
\item SFT trainer (vanilla):
|
| 620 |
+
\code{multi\_output\_cell\_policy/sft\_multi\_output\_train.py}
|
| 621 |
+
\item GRPO trainer (vanilla):
|
| 622 |
+
\code{multi\_output\_cell\_policy/grpo\_multi\_output\_train.py}
|
| 623 |
+
\item SFT trainer (latent):
|
| 624 |
+
\code{latent\_multi\_output\_cell\_policy/sft\_latent\_multi\_output\_train.py}
|
| 625 |
+
\item GRPO trainer (latent):
|
| 626 |
+
\code{latent\_multi\_output\_cell\_policy/grpo\_residual\_projector\_latent\_train.py}
|
| 627 |
+
\item Reward function: \code{multi\_output\_cell\_policy/rewards.py}
|
| 628 |
+
\item Prompt builder: \code{multi\_output\_cell\_policy/prompt\_builder.py}
|
| 629 |
+
\item Stage-i consistency:
|
| 630 |
+
\code{multi\_output\_cell\_policy/shared\_multi\_output\_policy.py}
|
| 631 |
+
\item 100-puzzle evaluator: \code{analysis/eval\_stage2\_checkpoint.py}
|
| 632 |
+
\end{itemize}
|
| 633 |
+
|
| 634 |
+
To reproduce the latent champion (1.5B, 9-phase, $\sim 16$ GPU$\cdot$h
|
| 635 |
+
on $8\times$H100 80GB):
|
| 636 |
+
|
| 637 |
+
\begin{lstlisting}
|
| 638 |
+
export STAGE1_BASELINE_ADAPTER_DIR=/path/to/stage1_baseline_seed_adapter
|
| 639 |
+
bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
|
| 640 |
+
\end{lstlisting}
|
| 641 |
+
|
| 642 |
+
To reproduce the v6 baseline push (single-GPU per variant, $\sim 6$
|
| 643 |
+
GPU$\cdot$h):
|
| 644 |
+
|
| 645 |
+
\begin{lstlisting}
|
| 646 |
+
bash _runs/launch_baseline_push_v6.sh
|
| 647 |
+
\end{lstlisting}
|
| 648 |
+
|
| 649 |
+
\appendix
|
| 650 |
+
|
| 651 |
+
\section{The reward fix as a one-line patch}
|
| 652 |
+
|
| 653 |
+
The single most consequential code change in this whole project, as a
|
| 654 |
+
self-contained patch on \code{multi\_output\_cell\_policy/rewards.py}:
|
| 655 |
+
|
| 656 |
+
\begin{lstlisting}[language=Python]
|
| 657 |
+
# new args (default 0 preserves legacy behaviour)
|
| 658 |
+
penalty_missing: float = 0.0
|
| 659 |
+
exact_match_bonus: float = 0.0
|
| 660 |
+
cardinality_mismatch_penalty: float = 0.0
|
| 661 |
+
|
| 662 |
+
num_missing = max(0, len(target_set) - num_good)
|
| 663 |
+
is_exact = bool(predicted_values) and (set(predicted_values) == target_set)
|
| 664 |
+
|
| 665 |
+
# ... base reward (triangular_number(num_good)*reward_good_value
|
| 666 |
+
# - num_bad*penalty_bad_value)
|
| 667 |
+
|
| 668 |
+
if num_missing > 0:
|
| 669 |
+
reward -= num_missing * penalty_missing
|
| 670 |
+
if is_exact:
|
| 671 |
+
reward += exact_match_bonus
|
| 672 |
+
if len(predicted_values) < len(target_values) and len(target_values) > 1:
|
| 673 |
+
reward -= cardinality_mismatch_penalty
|
| 674 |
+
\end{lstlisting}
|
| 675 |
+
|
| 676 |
+
Defaults are zero so old runs are unaffected; the recipe sets
|
| 677 |
+
$(P_m, B_x, P_c) = (0.75, 2.0, 1.0)$ for the vanilla recipe and
|
| 678 |
+
$(1.0, 4.0, 3.0)$ for the ``sharp\_rwd'' variant.
|
| 679 |
+
|
| 680 |
+
\section{The warm-baseline trick as a sequence diagram}
|
| 681 |
+
|
| 682 |
+
\begin{lstlisting}
|
| 683 |
+
Stage 1 Stage 2 Stage 3
|
| 684 |
+
--------- --------- ---------
|
| 685 |
+
[base Qwen] | |
|
| 686 |
+
| | |
|
| 687 |
+
v v v
|
| 688 |
+
S1 baseline SFT -> S2 baseline SFT -> S3 baseline SFT
|
| 689 |
+
(no latent, k=0) (no latent, k=0) (no latent, k=0)
|
| 690 |
+
| | |
|
| 691 |
+
v v v
|
| 692 |
+
S1 latent SFT -> S2 latent SFT -> S3 latent SFT
|
| 693 |
+
(k=1) (k=2) (k=3)
|
| 694 |
+
| | |
|
| 695 |
+
v v v
|
| 696 |
+
S1 latent GRPO -> S2 latent GRPO -> S3 latent GRPO
|
| 697 |
+
(k=1, b=0, lr 1e-6) (k=2, b=0) (k=3, b=0, lr 5e-6)
|
| 698 |
+
| | |
|
| 699 |
+
[final policy]
|
| 700 |
+
\end{lstlisting}
|
| 701 |
+
|
| 702 |
+
Every arrow is \code{init\_adapter\_dir = <previous output>}. Each row
|
| 703 |
+
is a ``slot in the curriculum''; the column adds reasoning capacity
|
| 704 |
+
($k\mathrel{+}=1$) and moves to a harder target distribution
|
| 705 |
+
($i\mathrel{+}=1$). The diagonal across the diagram is the actual
|
| 706 |
+
training trajectory.
|
| 707 |
+
|
| 708 |
+
\bigskip
|
| 709 |
+
\noindent\emph{End of report.}
|
| 710 |
+
|
| 711 |
+
\end{document}
|
addition/README.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Addition Carry Experiment
|
| 2 |
+
|
| 3 |
+
This folder contains a standalone PyTorch experiment for algorithmic addition with carry on a one-layer decoder-only Transformer.
|
| 4 |
+
|
| 5 |
+
The comparison includes exactly three methods:
|
| 6 |
+
|
| 7 |
+
- `nocurr_nocot`: no curriculum, no latent chain-of-thought
|
| 8 |
+
- `curr_nocot`: digit-length curriculum, no latent chain-of-thought
|
| 9 |
+
- `curr_cot`: same one-layer backbone plus recurrent latent scratchpad tokens
|
| 10 |
+
|
| 11 |
+
## Task
|
| 12 |
+
|
| 13 |
+
Each example adds two reversed digit sequences in a configurable radix. Stage `k` means only the first `k` least-significant positions vary and the rest are zero. Every method now trains on the full example in one forward pass:
|
| 14 |
+
|
| 15 |
+
- predict all `k` active sum digits
|
| 16 |
+
- predict the final carry bit as an additional output slot
|
| 17 |
+
- compute masked loss over the active digits plus the final carry
|
| 18 |
+
|
| 19 |
+
This means the baseline and both curriculum variants learn whole-example addition rather than a single queried digit at a time. Internal carry targets are still kept for diagnostics and linear probing, but not as an auxiliary training loss.
|
| 20 |
+
|
| 21 |
+
The latent method reuses the same one-layer Transformer recurrently. After an initial pass over the inputs and output slots, the model appends continuous latent scratchpad tokens before the output slots and reruns the same layer, giving later curriculum stages more internal workspace for carry-like computation.
|
| 22 |
+
|
| 23 |
+
## Files
|
| 24 |
+
|
| 25 |
+
- `config.py`: experiment config and CLI handling
|
| 26 |
+
- `data.py`: synthetic data generation, curriculum stages, carry-heavy subsets
|
| 27 |
+
- `model.py`: one-layer decoder-only Transformer and latent recurrence
|
| 28 |
+
- `train.py`: single-run training entrypoint
|
| 29 |
+
- `eval.py`: evaluation and diagnostics
|
| 30 |
+
- `plots.py`: local plotting
|
| 31 |
+
- `run_comparison.py`: multi-seed comparison across all three methods
|
| 32 |
+
|
| 33 |
+
## Outputs
|
| 34 |
+
|
| 35 |
+
Each run writes:
|
| 36 |
+
|
| 37 |
+
- `config.json`
|
| 38 |
+
- `artifacts/history.jsonl`
|
| 39 |
+
- `artifacts/summary.json`
|
| 40 |
+
- `checkpoints/best.pt`
|
| 41 |
+
- `checkpoints/last.pt`
|
| 42 |
+
- local plots under `plots/`
|
| 43 |
+
|
| 44 |
+
If W&B is enabled, the same run also logs metrics there.
|
| 45 |
+
|
| 46 |
+
## Run A Single Method
|
| 47 |
+
|
| 48 |
+
Default settings:
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
python addition/train.py --model nocurr_nocot --use_wandb
|
| 52 |
+
python addition/train.py --model curr_nocot --use_wandb
|
| 53 |
+
python addition/train.py --model curr_cot --use_wandb
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
The default backbone now uses a single attention head. To run a harder hexadecimal setting:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
python addition/train.py --model curr_cot --radix 16 --use_wandb --output_dir addition_runs/hex_curr_cot
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Run offline or local-only:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
python addition/train.py --model curr_cot --wandb_mode offline
|
| 66 |
+
python addition/train.py --model curr_cot --no_wandb
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Smoke Test
|
| 70 |
+
|
| 71 |
+
Use the smoke preset to verify the whole pipeline quickly:
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
python addition/train.py --model curr_cot --preset smoke --no_wandb --output_dir addition_runs/smoke_curr_cot
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## Run The Full Comparison
|
| 78 |
+
|
| 79 |
+
This runs all three methods across multiple seeds and saves aggregate plots and JSON:
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
python addition/run_comparison.py --preset default --use_wandb --comparison_output_dir addition_runs/comparison_default
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Small fast comparison:
|
| 86 |
+
|
| 87 |
+
```bash
|
| 88 |
+
python addition/run_comparison.py --preset smoke --no_wandb --comparison_output_dir addition_runs/comparison_smoke
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
## Main Metrics
|
| 92 |
+
|
| 93 |
+
The experiment reports:
|
| 94 |
+
|
| 95 |
+
- digit accuracy by output position
|
| 96 |
+
- final-carry accuracy
|
| 97 |
+
- exact whole-sum accuracy by active length
|
| 98 |
+
- average digit accuracy by length
|
| 99 |
+
- in-distribution results up to `train_max_digits`
|
| 100 |
+
- OOD results on longer lengths
|
| 101 |
+
- separate uniform and carry-heavy evaluations
|
| 102 |
+
|
| 103 |
+
## Diagnostics
|
| 104 |
+
|
| 105 |
+
The evaluation also includes:
|
| 106 |
+
|
| 107 |
+
- a linear probe on output-slot hidden states for carry prediction
|
| 108 |
+
- attention summaries showing how strongly the final carry readout attends to operand digits, previous output slots, and latent tokens
|
| 109 |
+
|
| 110 |
+
## Notes
|
| 111 |
+
|
| 112 |
+
- The first version is intentionally small enough to iterate locally.
|
| 113 |
+
- The backbone depth stays fixed at one layer in all methods.
|
| 114 |
+
- The latent method gets more recurrent compute, not more layers.
|
addition/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Standalone addition-with-carry experiment package."""
|
addition/config.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import dataclasses
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Iterable
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
VALID_MODELS = ("nocurr_nocot", "curr_nocot", "curr_cot")
|
| 14 |
+
VALID_PRESETS = ("default", "smoke")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
|
| 18 |
+
class ExperimentConfig:
|
| 19 |
+
model: str = "nocurr_nocot"
|
| 20 |
+
output_dir: str = "addition_runs/default"
|
| 21 |
+
seed: int = 0
|
| 22 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
+
preset: str = "default"
|
| 24 |
+
run_name: str = ""
|
| 25 |
+
notes: str = ""
|
| 26 |
+
use_wandb: bool = True
|
| 27 |
+
wandb_project: str = "addition-carry"
|
| 28 |
+
wandb_entity: str = ""
|
| 29 |
+
wandb_mode: str = "online"
|
| 30 |
+
radix: int = 10
|
| 31 |
+
train_max_digits: int = 12
|
| 32 |
+
eval_max_digits: int = 20
|
| 33 |
+
ood_lengths: tuple[int, ...] = (14, 16, 20)
|
| 34 |
+
train_batch_size: int = 256
|
| 35 |
+
eval_batch_size: int = 512
|
| 36 |
+
learning_rate: float = 3e-4
|
| 37 |
+
weight_decay: float = 1e-2
|
| 38 |
+
grad_clip_norm: float = 1.0
|
| 39 |
+
carry_loss_weight: float = 0.0
|
| 40 |
+
train_steps: int = 3600
|
| 41 |
+
max_steps_per_stage: int = 300
|
| 42 |
+
validation_interval: int = 100
|
| 43 |
+
stage_accuracy_threshold: float = 0.99
|
| 44 |
+
initial_stage: int = 1
|
| 45 |
+
eval_examples_per_length: int = 256
|
| 46 |
+
carry_heavy_examples_per_length: int = 256
|
| 47 |
+
train_carry_heavy_prob: float = 0.15
|
| 48 |
+
d_model: int = 512
|
| 49 |
+
n_heads: int = 1
|
| 50 |
+
ff_dim: int = 2048
|
| 51 |
+
dropout: float = 0.0
|
| 52 |
+
max_latent_steps: int = 12
|
| 53 |
+
attention_probe_examples: int = 256
|
| 54 |
+
linear_probe_epochs: int = 150
|
| 55 |
+
linear_probe_lr: float = 1e-2
|
| 56 |
+
comparison_num_seeds: int = 5
|
| 57 |
+
|
| 58 |
+
def __post_init__(self) -> None:
|
| 59 |
+
if self.model not in VALID_MODELS:
|
| 60 |
+
raise ValueError(f"Unsupported model: {self.model}")
|
| 61 |
+
if self.preset not in VALID_PRESETS:
|
| 62 |
+
raise ValueError(f"Unsupported preset: {self.preset}")
|
| 63 |
+
if self.train_max_digits > self.eval_max_digits:
|
| 64 |
+
raise ValueError("train_max_digits must be <= eval_max_digits")
|
| 65 |
+
if self.max_latent_steps < 0:
|
| 66 |
+
raise ValueError("max_latent_steps must be non-negative")
|
| 67 |
+
if self.radix < 2 or self.radix > 16:
|
| 68 |
+
raise ValueError("radix must be between 2 and 16")
|
| 69 |
+
if self.initial_stage < 1 or self.initial_stage > self.train_max_digits:
|
| 70 |
+
raise ValueError("initial_stage must be between 1 and train_max_digits")
|
| 71 |
+
self.ood_lengths = tuple(int(v) for v in self.ood_lengths if int(v) > self.train_max_digits)
|
| 72 |
+
if not self.ood_lengths:
|
| 73 |
+
self.ood_lengths = (self.eval_max_digits,)
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def uses_curriculum(self) -> bool:
|
| 77 |
+
return self.model in {"curr_nocot", "curr_cot"}
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def uses_latent_cot(self) -> bool:
|
| 81 |
+
return self.model == "curr_cot"
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def discrete_vocab_size(self) -> int:
|
| 85 |
+
return self.radix + 2
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def digit_vocab_size(self) -> int:
|
| 89 |
+
return self.radix
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def input_sequence_length(self) -> int:
|
| 93 |
+
return self.input_sequence_length_for_digits(self.eval_max_digits)
|
| 94 |
+
|
| 95 |
+
@property
|
| 96 |
+
def output_sequence_length(self) -> int:
|
| 97 |
+
return self.output_sequence_length_for_digits(self.eval_max_digits)
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def base_sequence_length(self) -> int:
|
| 101 |
+
return self.base_sequence_length_for_digits(self.eval_max_digits)
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def max_sequence_length(self) -> int:
|
| 105 |
+
return self.base_sequence_length + self.max_latent_steps
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def effective_run_name(self) -> str:
|
| 109 |
+
if self.run_name:
|
| 110 |
+
return self.run_name
|
| 111 |
+
return f"{self.model}_base{self.radix}_seed{self.seed}"
|
| 112 |
+
|
| 113 |
+
def input_sequence_length_for_digits(self, active_digits: int) -> int:
|
| 114 |
+
return (int(active_digits) * 2) + 2
|
| 115 |
+
|
| 116 |
+
def output_sequence_length_for_digits(self, active_digits: int) -> int:
|
| 117 |
+
return int(active_digits) + 1
|
| 118 |
+
|
| 119 |
+
def base_sequence_length_for_digits(self, active_digits: int) -> int:
|
| 120 |
+
return self.input_sequence_length_for_digits(active_digits) + self.output_sequence_length_for_digits(active_digits)
|
| 121 |
+
|
| 122 |
+
def latent_steps_for_stage(self, stage: int) -> int:
|
| 123 |
+
if not self.uses_latent_cot:
|
| 124 |
+
return 0
|
| 125 |
+
return max(0, min(int(stage), int(self.max_latent_steps)))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def default_output_root() -> Path:
|
| 129 |
+
return Path("addition_runs")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def apply_preset(config: ExperimentConfig) -> ExperimentConfig:
|
| 133 |
+
config = dataclasses.replace(config)
|
| 134 |
+
if config.preset == "smoke":
|
| 135 |
+
config.output_dir = config.output_dir or str(default_output_root() / "smoke")
|
| 136 |
+
config.train_batch_size = 64
|
| 137 |
+
config.eval_batch_size = 128
|
| 138 |
+
config.d_model = 128
|
| 139 |
+
config.n_heads = 1
|
| 140 |
+
config.ff_dim = 512
|
| 141 |
+
config.train_steps = 180
|
| 142 |
+
config.max_steps_per_stage = 40
|
| 143 |
+
config.validation_interval = 20
|
| 144 |
+
config.eval_examples_per_length = 64
|
| 145 |
+
config.carry_heavy_examples_per_length = 64
|
| 146 |
+
config.attention_probe_examples = 64
|
| 147 |
+
config.linear_probe_epochs = 60
|
| 148 |
+
config.comparison_num_seeds = 2
|
| 149 |
+
return config
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def config_to_dict(config: ExperimentConfig) -> dict:
|
| 153 |
+
data = dataclasses.asdict(config)
|
| 154 |
+
data["ood_lengths"] = list(config.ood_lengths)
|
| 155 |
+
data["uses_curriculum"] = config.uses_curriculum
|
| 156 |
+
data["uses_latent_cot"] = config.uses_latent_cot
|
| 157 |
+
data["discrete_vocab_size"] = config.discrete_vocab_size
|
| 158 |
+
data["input_sequence_length"] = config.input_sequence_length
|
| 159 |
+
data["output_sequence_length"] = config.output_sequence_length
|
| 160 |
+
data["base_sequence_length"] = config.base_sequence_length
|
| 161 |
+
data["max_sequence_length"] = config.max_sequence_length
|
| 162 |
+
data["effective_run_name"] = config.effective_run_name
|
| 163 |
+
return data
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def save_config(config: ExperimentConfig, output_dir: Path) -> None:
|
| 167 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 168 |
+
with (output_dir / "config.json").open("w", encoding="utf-8") as handle:
|
| 169 |
+
json.dump(config_to_dict(config), handle, indent=2, sort_keys=True)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def add_config_arguments(parser: argparse.ArgumentParser) -> None:
|
| 173 |
+
parser.add_argument("--model", choices=VALID_MODELS, default="nocurr_nocot")
|
| 174 |
+
parser.add_argument("--output_dir", type=str, default="")
|
| 175 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 176 |
+
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
| 177 |
+
parser.add_argument("--preset", choices=VALID_PRESETS, default="default")
|
| 178 |
+
parser.add_argument("--run_name", type=str, default="")
|
| 179 |
+
parser.add_argument("--notes", type=str, default="")
|
| 180 |
+
parser.add_argument("--use_wandb", action="store_true")
|
| 181 |
+
parser.add_argument("--no_wandb", action="store_true")
|
| 182 |
+
parser.add_argument("--wandb_project", type=str, default="addition-carry")
|
| 183 |
+
parser.add_argument("--wandb_entity", type=str, default="")
|
| 184 |
+
parser.add_argument("--wandb_mode", type=str, default="online", choices=("online", "offline", "disabled"))
|
| 185 |
+
parser.add_argument("--radix", type=int, default=10)
|
| 186 |
+
parser.add_argument("--train_max_digits", type=int, default=12)
|
| 187 |
+
parser.add_argument("--eval_max_digits", type=int, default=20)
|
| 188 |
+
parser.add_argument("--ood_lengths", type=int, nargs="*", default=[14, 16, 20])
|
| 189 |
+
parser.add_argument("--train_batch_size", type=int, default=256)
|
| 190 |
+
parser.add_argument("--eval_batch_size", type=int, default=512)
|
| 191 |
+
parser.add_argument("--learning_rate", type=float, default=3e-4)
|
| 192 |
+
parser.add_argument("--weight_decay", type=float, default=1e-2)
|
| 193 |
+
parser.add_argument("--grad_clip_norm", type=float, default=1.0)
|
| 194 |
+
parser.add_argument("--carry_loss_weight", type=float, default=0.0)
|
| 195 |
+
parser.add_argument("--train_steps", type=int, default=3600)
|
| 196 |
+
parser.add_argument("--max_steps_per_stage", type=int, default=300)
|
| 197 |
+
parser.add_argument("--validation_interval", type=int, default=100)
|
| 198 |
+
parser.add_argument("--stage_accuracy_threshold", type=float, default=0.99)
|
| 199 |
+
parser.add_argument("--initial_stage", type=int, default=1)
|
| 200 |
+
parser.add_argument("--eval_examples_per_length", type=int, default=256)
|
| 201 |
+
parser.add_argument("--carry_heavy_examples_per_length", type=int, default=256)
|
| 202 |
+
parser.add_argument("--train_carry_heavy_prob", type=float, default=0.15)
|
| 203 |
+
parser.add_argument("--d_model", type=int, default=512)
|
| 204 |
+
parser.add_argument("--n_heads", type=int, default=1)
|
| 205 |
+
parser.add_argument("--ff_dim", type=int, default=2048)
|
| 206 |
+
parser.add_argument("--dropout", type=float, default=0.0)
|
| 207 |
+
parser.add_argument("--max_latent_steps", type=int, default=12)
|
| 208 |
+
parser.add_argument("--attention_probe_examples", type=int, default=256)
|
| 209 |
+
parser.add_argument("--linear_probe_epochs", type=int, default=150)
|
| 210 |
+
parser.add_argument("--linear_probe_lr", type=float, default=1e-2)
|
| 211 |
+
parser.add_argument("--comparison_num_seeds", type=int, default=5)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def build_config_from_args(args: argparse.Namespace) -> ExperimentConfig:
|
| 215 |
+
use_wandb = bool(args.use_wandb or not args.no_wandb)
|
| 216 |
+
if args.wandb_mode == "disabled":
|
| 217 |
+
use_wandb = False
|
| 218 |
+
output_dir = args.output_dir or str(default_output_root() / f"{args.model}_base{args.radix}_seed{args.seed}")
|
| 219 |
+
config = ExperimentConfig(
|
| 220 |
+
model=args.model,
|
| 221 |
+
output_dir=output_dir,
|
| 222 |
+
seed=args.seed,
|
| 223 |
+
device=args.device,
|
| 224 |
+
preset=args.preset,
|
| 225 |
+
run_name=args.run_name,
|
| 226 |
+
notes=args.notes,
|
| 227 |
+
use_wandb=use_wandb,
|
| 228 |
+
wandb_project=args.wandb_project,
|
| 229 |
+
wandb_entity=args.wandb_entity,
|
| 230 |
+
wandb_mode=args.wandb_mode,
|
| 231 |
+
radix=args.radix,
|
| 232 |
+
train_max_digits=args.train_max_digits,
|
| 233 |
+
eval_max_digits=args.eval_max_digits,
|
| 234 |
+
ood_lengths=tuple(args.ood_lengths),
|
| 235 |
+
train_batch_size=args.train_batch_size,
|
| 236 |
+
eval_batch_size=args.eval_batch_size,
|
| 237 |
+
learning_rate=args.learning_rate,
|
| 238 |
+
weight_decay=args.weight_decay,
|
| 239 |
+
grad_clip_norm=args.grad_clip_norm,
|
| 240 |
+
carry_loss_weight=args.carry_loss_weight,
|
| 241 |
+
train_steps=args.train_steps,
|
| 242 |
+
max_steps_per_stage=args.max_steps_per_stage,
|
| 243 |
+
validation_interval=args.validation_interval,
|
| 244 |
+
stage_accuracy_threshold=args.stage_accuracy_threshold,
|
| 245 |
+
initial_stage=args.initial_stage,
|
| 246 |
+
eval_examples_per_length=args.eval_examples_per_length,
|
| 247 |
+
carry_heavy_examples_per_length=args.carry_heavy_examples_per_length,
|
| 248 |
+
train_carry_heavy_prob=args.train_carry_heavy_prob,
|
| 249 |
+
d_model=args.d_model,
|
| 250 |
+
n_heads=args.n_heads,
|
| 251 |
+
ff_dim=args.ff_dim,
|
| 252 |
+
dropout=args.dropout,
|
| 253 |
+
max_latent_steps=args.max_latent_steps,
|
| 254 |
+
attention_probe_examples=args.attention_probe_examples,
|
| 255 |
+
linear_probe_epochs=args.linear_probe_epochs,
|
| 256 |
+
linear_probe_lr=args.linear_probe_lr,
|
| 257 |
+
comparison_num_seeds=args.comparison_num_seeds,
|
| 258 |
+
)
|
| 259 |
+
return apply_preset(config)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def build_arg_parser(description: str) -> argparse.ArgumentParser:
|
| 263 |
+
parser = argparse.ArgumentParser(description=description)
|
| 264 |
+
add_config_arguments(parser)
|
| 265 |
+
return parser
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def parse_config(description: str) -> ExperimentConfig:
|
| 269 |
+
parser = build_arg_parser(description)
|
| 270 |
+
args = parser.parse_args()
|
| 271 |
+
return build_config_from_args(args)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def ensure_output_dirs(config: ExperimentConfig) -> dict[str, Path]:
|
| 275 |
+
root = Path(config.output_dir)
|
| 276 |
+
directories = {
|
| 277 |
+
"root": root,
|
| 278 |
+
"checkpoints": root / "checkpoints",
|
| 279 |
+
"stage_checkpoints": root / "checkpoints" / "stages",
|
| 280 |
+
"plots": root / "plots",
|
| 281 |
+
"artifacts": root / "artifacts",
|
| 282 |
+
}
|
| 283 |
+
for directory in directories.values():
|
| 284 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 285 |
+
return directories
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def flatten_metric_dict(prefix: str, metrics: dict[str, float | int | str]) -> dict[str, float | int | str]:
|
| 289 |
+
return {f"{prefix}{key}": value for key, value in metrics.items()}
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def iter_stage_lengths(config: ExperimentConfig) -> Iterable[int]:
|
| 293 |
+
for stage in range(1, config.train_max_digits + 1):
|
| 294 |
+
yield stage
|
addition/data.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from addition.config import ExperimentConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DIGIT_OFFSET = 0
|
| 15 |
+
DEFAULT_SYMBOLS = "0123456789ABCDEF"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class AdditionProblem:
|
| 20 |
+
a_digits: list[int]
|
| 21 |
+
b_digits: list[int]
|
| 22 |
+
sum_digits: list[int]
|
| 23 |
+
carry_out: list[int]
|
| 24 |
+
active_digits: int
|
| 25 |
+
is_carry_heavy: bool
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class Batch:
|
| 30 |
+
input_ids: torch.Tensor
|
| 31 |
+
target_digits: torch.Tensor
|
| 32 |
+
target_digit_mask: torch.Tensor
|
| 33 |
+
target_carry: torch.Tensor
|
| 34 |
+
target_final_carry: torch.Tensor
|
| 35 |
+
active_digits: torch.Tensor
|
| 36 |
+
is_carry_heavy: torch.Tensor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
|
| 40 |
+
class EvaluationSuite:
|
| 41 |
+
validation_uniform: dict[int, list[AdditionProblem]]
|
| 42 |
+
test_uniform: dict[int, list[AdditionProblem]]
|
| 43 |
+
test_carry_heavy: dict[int, list[AdditionProblem]]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def a_token_id(radix: int) -> int:
|
| 47 |
+
return radix
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def b_token_id(radix: int) -> int:
|
| 51 |
+
return radix + 1
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def seed_everything(seed: int) -> None:
|
| 55 |
+
random.seed(seed)
|
| 56 |
+
torch.manual_seed(seed)
|
| 57 |
+
if torch.cuda.is_available():
|
| 58 |
+
torch.cuda.manual_seed_all(seed)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def compute_sum_and_carry(a_digits: list[int], b_digits: list[int], radix: int) -> tuple[list[int], list[int]]:
|
| 62 |
+
sum_digits: list[int] = []
|
| 63 |
+
carry_out: list[int] = []
|
| 64 |
+
carry = 0
|
| 65 |
+
for a_digit, b_digit in zip(a_digits, b_digits):
|
| 66 |
+
total = int(a_digit) + int(b_digit) + carry
|
| 67 |
+
sum_digits.append(total % radix)
|
| 68 |
+
carry = total // radix
|
| 69 |
+
carry_out.append(carry)
|
| 70 |
+
return sum_digits, carry_out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def sample_uniform_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
|
| 74 |
+
a_digits = [0] * max_digits
|
| 75 |
+
b_digits = [0] * max_digits
|
| 76 |
+
for index in range(active_digits):
|
| 77 |
+
a_digits[index] = rng.randint(0, radix - 1)
|
| 78 |
+
b_digits[index] = rng.randint(0, radix - 1)
|
| 79 |
+
sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix)
|
| 80 |
+
return AdditionProblem(
|
| 81 |
+
a_digits=a_digits,
|
| 82 |
+
b_digits=b_digits,
|
| 83 |
+
sum_digits=sum_digits,
|
| 84 |
+
carry_out=carry_out,
|
| 85 |
+
active_digits=active_digits,
|
| 86 |
+
is_carry_heavy=False,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def sample_carry_heavy_problem(max_digits: int, active_digits: int, radix: int, rng: random.Random) -> AdditionProblem:
|
| 91 |
+
a_digits = [0] * max_digits
|
| 92 |
+
b_digits = [0] * max_digits
|
| 93 |
+
carry = 0
|
| 94 |
+
for index in range(active_digits):
|
| 95 |
+
high_floor = max(0, radix // 2)
|
| 96 |
+
a_digit = rng.randint(high_floor, radix - 1)
|
| 97 |
+
if carry == 0:
|
| 98 |
+
min_b = max(0, radix - a_digit)
|
| 99 |
+
else:
|
| 100 |
+
min_b = max(0, (radix - 1) - a_digit)
|
| 101 |
+
b_digit = rng.randint(min_b, radix - 1)
|
| 102 |
+
a_digits[index] = a_digit
|
| 103 |
+
b_digits[index] = b_digit
|
| 104 |
+
total = a_digit + b_digit + carry
|
| 105 |
+
carry = total // radix
|
| 106 |
+
sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=radix)
|
| 107 |
+
return AdditionProblem(
|
| 108 |
+
a_digits=a_digits,
|
| 109 |
+
b_digits=b_digits,
|
| 110 |
+
sum_digits=sum_digits,
|
| 111 |
+
carry_out=carry_out,
|
| 112 |
+
active_digits=active_digits,
|
| 113 |
+
is_carry_heavy=True,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def sample_problem(
|
| 118 |
+
max_digits: int,
|
| 119 |
+
active_digits: int,
|
| 120 |
+
radix: int,
|
| 121 |
+
rng: random.Random,
|
| 122 |
+
carry_heavy: bool = False,
|
| 123 |
+
) -> AdditionProblem:
|
| 124 |
+
if carry_heavy:
|
| 125 |
+
return sample_carry_heavy_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng)
|
| 126 |
+
return sample_uniform_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def encode_problem_tokens(problem: AdditionProblem, radix: int) -> list[int]:
|
| 130 |
+
return (
|
| 131 |
+
[a_token_id(radix)]
|
| 132 |
+
+ [DIGIT_OFFSET + digit for digit in problem.a_digits[: problem.active_digits]]
|
| 133 |
+
+ [b_token_id(radix)]
|
| 134 |
+
+ [DIGIT_OFFSET + digit for digit in problem.b_digits[: problem.active_digits]]
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def build_batch(
|
| 139 |
+
problems: list[AdditionProblem],
|
| 140 |
+
radix: int,
|
| 141 |
+
device: str,
|
| 142 |
+
) -> Batch:
|
| 143 |
+
active_digits = problems[0].active_digits if problems else 0
|
| 144 |
+
input_ids = torch.tensor(
|
| 145 |
+
[
|
| 146 |
+
encode_problem_tokens(problem=problem, radix=radix)
|
| 147 |
+
for problem in problems
|
| 148 |
+
],
|
| 149 |
+
dtype=torch.long,
|
| 150 |
+
device=device,
|
| 151 |
+
)
|
| 152 |
+
target_digits = torch.tensor(
|
| 153 |
+
[problem.sum_digits[:active_digits] for problem in problems],
|
| 154 |
+
dtype=torch.long,
|
| 155 |
+
device=device,
|
| 156 |
+
)
|
| 157 |
+
target_digit_mask = torch.tensor(
|
| 158 |
+
[[1] * active_digits for _ in problems],
|
| 159 |
+
dtype=torch.bool,
|
| 160 |
+
device=device,
|
| 161 |
+
)
|
| 162 |
+
target_carry = torch.tensor(
|
| 163 |
+
[problem.carry_out[:active_digits] for problem in problems],
|
| 164 |
+
dtype=torch.long,
|
| 165 |
+
device=device,
|
| 166 |
+
)
|
| 167 |
+
target_final_carry = torch.tensor(
|
| 168 |
+
[problem.carry_out[problem.active_digits - 1] for problem in problems],
|
| 169 |
+
dtype=torch.long,
|
| 170 |
+
device=device,
|
| 171 |
+
)
|
| 172 |
+
return Batch(
|
| 173 |
+
input_ids=input_ids,
|
| 174 |
+
target_digits=target_digits,
|
| 175 |
+
target_digit_mask=target_digit_mask,
|
| 176 |
+
target_carry=target_carry,
|
| 177 |
+
target_final_carry=target_final_carry,
|
| 178 |
+
active_digits=torch.tensor([problem.active_digits for problem in problems], dtype=torch.long, device=device),
|
| 179 |
+
is_carry_heavy=torch.tensor([int(problem.is_carry_heavy) for problem in problems], dtype=torch.bool, device=device),
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def sample_training_batch(
|
| 184 |
+
config: ExperimentConfig,
|
| 185 |
+
stage: int,
|
| 186 |
+
rng: random.Random,
|
| 187 |
+
device: str,
|
| 188 |
+
) -> Batch:
|
| 189 |
+
problems: list[AdditionProblem] = []
|
| 190 |
+
for _ in range(config.train_batch_size):
|
| 191 |
+
carry_heavy = rng.random() < config.train_carry_heavy_prob
|
| 192 |
+
problem = sample_problem(
|
| 193 |
+
max_digits=stage,
|
| 194 |
+
active_digits=stage,
|
| 195 |
+
radix=config.radix,
|
| 196 |
+
rng=rng,
|
| 197 |
+
carry_heavy=carry_heavy,
|
| 198 |
+
)
|
| 199 |
+
problems.append(problem)
|
| 200 |
+
return build_batch(
|
| 201 |
+
problems=problems,
|
| 202 |
+
radix=config.radix,
|
| 203 |
+
device=device,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def build_problem_set(
|
| 208 |
+
*,
|
| 209 |
+
max_digits: int,
|
| 210 |
+
active_digits: int,
|
| 211 |
+
radix: int,
|
| 212 |
+
count: int,
|
| 213 |
+
seed: int,
|
| 214 |
+
carry_heavy: bool,
|
| 215 |
+
) -> list[AdditionProblem]:
|
| 216 |
+
rng = random.Random(seed)
|
| 217 |
+
return [
|
| 218 |
+
sample_problem(max_digits=max_digits, active_digits=active_digits, radix=radix, rng=rng, carry_heavy=carry_heavy)
|
| 219 |
+
for _ in range(count)
|
| 220 |
+
]
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def build_evaluation_suite(config: ExperimentConfig) -> EvaluationSuite:
|
| 224 |
+
validation_uniform: dict[int, list[AdditionProblem]] = {}
|
| 225 |
+
test_uniform: dict[int, list[AdditionProblem]] = {}
|
| 226 |
+
test_carry_heavy: dict[int, list[AdditionProblem]] = {}
|
| 227 |
+
all_lengths = sorted(set(range(1, config.train_max_digits + 1)).union(config.ood_lengths))
|
| 228 |
+
for length in all_lengths:
|
| 229 |
+
validation_uniform[length] = build_problem_set(
|
| 230 |
+
max_digits=length,
|
| 231 |
+
active_digits=length,
|
| 232 |
+
radix=config.radix,
|
| 233 |
+
count=config.eval_examples_per_length,
|
| 234 |
+
seed=10_000 + length,
|
| 235 |
+
carry_heavy=False,
|
| 236 |
+
)
|
| 237 |
+
test_uniform[length] = build_problem_set(
|
| 238 |
+
max_digits=length,
|
| 239 |
+
active_digits=length,
|
| 240 |
+
radix=config.radix,
|
| 241 |
+
count=config.eval_examples_per_length,
|
| 242 |
+
seed=20_000 + length,
|
| 243 |
+
carry_heavy=False,
|
| 244 |
+
)
|
| 245 |
+
test_carry_heavy[length] = build_problem_set(
|
| 246 |
+
max_digits=length,
|
| 247 |
+
active_digits=length,
|
| 248 |
+
radix=config.radix,
|
| 249 |
+
count=config.carry_heavy_examples_per_length,
|
| 250 |
+
seed=30_000 + length,
|
| 251 |
+
carry_heavy=True,
|
| 252 |
+
)
|
| 253 |
+
return EvaluationSuite(
|
| 254 |
+
validation_uniform=validation_uniform,
|
| 255 |
+
test_uniform=test_uniform,
|
| 256 |
+
test_carry_heavy=test_carry_heavy,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def digits_to_string(digits: Iterable[int], final_carry: int, radix: int) -> str:
|
| 261 |
+
digits = list(digits)
|
| 262 |
+
significant_digits = list(digits)
|
| 263 |
+
if final_carry:
|
| 264 |
+
significant_digits.append(final_carry)
|
| 265 |
+
while len(significant_digits) > 1 and significant_digits[-1] == 0:
|
| 266 |
+
significant_digits.pop()
|
| 267 |
+
symbols = DEFAULT_SYMBOLS[:radix]
|
| 268 |
+
return "".join(symbols[digit] for digit in reversed(significant_digits))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def value_from_digits(digits: Iterable[int], final_carry: int, radix: int) -> int:
|
| 272 |
+
value = 0
|
| 273 |
+
place = 1
|
| 274 |
+
for digit in digits:
|
| 275 |
+
value += int(digit) * place
|
| 276 |
+
place *= radix
|
| 277 |
+
if final_carry:
|
| 278 |
+
value += int(final_carry) * place
|
| 279 |
+
return value
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def exact_sum_matches(
|
| 283 |
+
predicted_digits: list[int],
|
| 284 |
+
predicted_final_carry: int,
|
| 285 |
+
truth_digits: list[int],
|
| 286 |
+
truth_final_carry: int,
|
| 287 |
+
) -> bool:
|
| 288 |
+
return predicted_digits == truth_digits and int(predicted_final_carry) == int(truth_final_carry)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def summarize_problem(problem: AdditionProblem, radix: int) -> dict[str, int | str]:
|
| 292 |
+
final_carry = problem.carry_out[problem.active_digits - 1]
|
| 293 |
+
return {
|
| 294 |
+
"a": digits_to_string(problem.a_digits[: problem.active_digits], final_carry=0, radix=radix),
|
| 295 |
+
"b": digits_to_string(problem.b_digits[: problem.active_digits], final_carry=0, radix=radix),
|
| 296 |
+
"sum": digits_to_string(problem.sum_digits[: problem.active_digits], final_carry=final_carry, radix=radix),
|
| 297 |
+
"radix": radix,
|
| 298 |
+
"active_digits": problem.active_digits,
|
| 299 |
+
"carry_heavy": int(problem.is_carry_heavy),
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def count_carry_chain(problem: AdditionProblem) -> int:
|
| 304 |
+
longest = 0
|
| 305 |
+
current = 0
|
| 306 |
+
for index in range(problem.active_digits):
|
| 307 |
+
if problem.carry_out[index]:
|
| 308 |
+
current += 1
|
| 309 |
+
longest = max(longest, current)
|
| 310 |
+
else:
|
| 311 |
+
current = 0
|
| 312 |
+
return longest
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def carry_density(problem: AdditionProblem) -> float:
|
| 316 |
+
if problem.active_digits <= 0:
|
| 317 |
+
return 0.0
|
| 318 |
+
return float(sum(problem.carry_out[: problem.active_digits])) / float(problem.active_digits)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def curriculum_stage_lengths(config: ExperimentConfig) -> list[int]:
|
| 322 |
+
if config.uses_curriculum:
|
| 323 |
+
return list(range(1, config.train_max_digits + 1))
|
| 324 |
+
return [config.train_max_digits]
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def infer_eval_lengths(config: ExperimentConfig) -> list[int]:
|
| 328 |
+
return sorted(set(range(1, config.train_max_digits + 1)).union(config.ood_lengths))
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def estimate_train_tokens_per_step(config: ExperimentConfig, stage: int) -> int:
|
| 332 |
+
latent_steps = config.latent_steps_for_stage(stage)
|
| 333 |
+
return config.train_batch_size * (config.base_sequence_length_for_digits(stage) + latent_steps)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def stage_fraction(stage: int, max_stage: int) -> float:
|
| 337 |
+
if max_stage <= 1:
|
| 338 |
+
return 1.0
|
| 339 |
+
return float(stage - 1) / float(max_stage - 1)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def maybe_trim_examples(problems: list[AdditionProblem], limit: int) -> list[AdditionProblem]:
|
| 343 |
+
if limit <= 0 or len(problems) <= limit:
|
| 344 |
+
return list(problems)
|
| 345 |
+
return list(problems[:limit])
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def stage_display_name(stage: int) -> str:
|
| 349 |
+
suffix = "th"
|
| 350 |
+
if stage % 10 == 1 and stage % 100 != 11:
|
| 351 |
+
suffix = "st"
|
| 352 |
+
elif stage % 10 == 2 and stage % 100 != 12:
|
| 353 |
+
suffix = "nd"
|
| 354 |
+
elif stage % 10 == 3 and stage % 100 != 13:
|
| 355 |
+
suffix = "rd"
|
| 356 |
+
return f"{stage}{suffix}-digit"
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def ideal_carry_chain_examples(config: ExperimentConfig, active_digits: int) -> list[AdditionProblem]:
|
| 360 |
+
examples: list[AdditionProblem] = []
|
| 361 |
+
for base_digit in (max(0, config.radix - 2), config.radix - 1):
|
| 362 |
+
a_digits = [base_digit] * active_digits
|
| 363 |
+
b_digits = [1] * active_digits
|
| 364 |
+
sum_digits, carry_out = compute_sum_and_carry(a_digits, b_digits, radix=config.radix)
|
| 365 |
+
examples.append(
|
| 366 |
+
AdditionProblem(
|
| 367 |
+
a_digits=a_digits,
|
| 368 |
+
b_digits=b_digits,
|
| 369 |
+
sum_digits=sum_digits,
|
| 370 |
+
carry_out=carry_out,
|
| 371 |
+
active_digits=active_digits,
|
| 372 |
+
is_carry_heavy=True,
|
| 373 |
+
)
|
| 374 |
+
)
|
| 375 |
+
return examples
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def expected_sum_length(problem: AdditionProblem) -> int:
|
| 379 |
+
final_carry = problem.carry_out[problem.active_digits - 1]
|
| 380 |
+
return problem.active_digits + int(final_carry > 0)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def average_query_count(config: ExperimentConfig) -> float:
|
| 384 |
+
lengths = curriculum_stage_lengths(config)
|
| 385 |
+
return sum(lengths) / float(len(lengths))
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def token_budget(config: ExperimentConfig) -> int:
|
| 389 |
+
avg_stage = int(math.ceil(average_query_count(config)))
|
| 390 |
+
return config.base_sequence_length_for_digits(avg_stage) + config.latent_steps_for_stage(avg_stage)
|
addition/eval.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Iterable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from addition.config import ExperimentConfig
|
| 11 |
+
from addition.data import (
|
| 12 |
+
AdditionProblem,
|
| 13 |
+
EvaluationSuite,
|
| 14 |
+
build_batch,
|
| 15 |
+
carry_density,
|
| 16 |
+
count_carry_chain,
|
| 17 |
+
exact_sum_matches,
|
| 18 |
+
maybe_trim_examples,
|
| 19 |
+
)
|
| 20 |
+
from addition.model import AdditionTransformer
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class LengthMetrics:
|
| 25 |
+
digit_accuracy: float
|
| 26 |
+
final_carry_accuracy: float
|
| 27 |
+
exact_match: float
|
| 28 |
+
avg_carry_chain: float
|
| 29 |
+
avg_carry_density: float
|
| 30 |
+
example_count: int
|
| 31 |
+
per_position_digit_accuracy: list[float]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _chunked(sequence: list[AdditionProblem], chunk_size: int) -> Iterable[list[AdditionProblem]]:
|
| 35 |
+
for start in range(0, len(sequence), chunk_size):
|
| 36 |
+
yield sequence[start : start + chunk_size]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def evaluate_problem_set(
|
| 41 |
+
model: AdditionTransformer,
|
| 42 |
+
config: ExperimentConfig,
|
| 43 |
+
problems: list[AdditionProblem],
|
| 44 |
+
active_digits: int,
|
| 45 |
+
*,
|
| 46 |
+
device: str,
|
| 47 |
+
return_attention: bool = False,
|
| 48 |
+
) -> tuple[LengthMetrics, dict[str, float] | None]:
|
| 49 |
+
model.eval()
|
| 50 |
+
latent_steps = config.latent_steps_for_stage(active_digits)
|
| 51 |
+
num_examples = len(problems)
|
| 52 |
+
if num_examples == 0:
|
| 53 |
+
empty = LengthMetrics(
|
| 54 |
+
digit_accuracy=0.0,
|
| 55 |
+
final_carry_accuracy=0.0,
|
| 56 |
+
exact_match=0.0,
|
| 57 |
+
avg_carry_chain=0.0,
|
| 58 |
+
avg_carry_density=0.0,
|
| 59 |
+
example_count=0,
|
| 60 |
+
per_position_digit_accuracy=[0.0] * active_digits,
|
| 61 |
+
)
|
| 62 |
+
return empty, None
|
| 63 |
+
|
| 64 |
+
predicted_digits = torch.zeros(num_examples, active_digits, dtype=torch.long)
|
| 65 |
+
predicted_final_carry = torch.zeros(num_examples, dtype=torch.long)
|
| 66 |
+
truth_digits = torch.tensor([[problem.sum_digits[position] for position in range(active_digits)] for problem in problems], dtype=torch.long)
|
| 67 |
+
truth_final_carry = torch.tensor([problem.carry_out[active_digits - 1] for problem in problems], dtype=torch.long)
|
| 68 |
+
attention_stats: dict[str, float] | None = None
|
| 69 |
+
|
| 70 |
+
offset = 0
|
| 71 |
+
for problem_chunk in _chunked(problems, config.eval_batch_size):
|
| 72 |
+
batch = build_batch(
|
| 73 |
+
problems=problem_chunk,
|
| 74 |
+
radix=config.radix,
|
| 75 |
+
device=device,
|
| 76 |
+
)
|
| 77 |
+
outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=return_attention)
|
| 78 |
+
chunk_size = len(problem_chunk)
|
| 79 |
+
predicted_digits[offset : offset + chunk_size] = outputs.digit_logits.argmax(dim=-1)[:, :active_digits].cpu()
|
| 80 |
+
predicted_final_carry[offset : offset + chunk_size] = outputs.final_carry_logits.argmax(dim=-1).cpu()
|
| 81 |
+
if return_attention and attention_stats is None:
|
| 82 |
+
attention_stats = summarize_attention(
|
| 83 |
+
attention_weights=outputs.attention_weights,
|
| 84 |
+
active_digits=active_digits,
|
| 85 |
+
input_sequence_length=batch.input_ids.shape[1],
|
| 86 |
+
output_sequence_length=outputs.output_hidden.shape[1],
|
| 87 |
+
)
|
| 88 |
+
offset += chunk_size
|
| 89 |
+
|
| 90 |
+
exact_matches = []
|
| 91 |
+
for example_index, problem in enumerate(problems):
|
| 92 |
+
exact_matches.append(
|
| 93 |
+
exact_sum_matches(
|
| 94 |
+
predicted_digits=predicted_digits[example_index].tolist(),
|
| 95 |
+
predicted_final_carry=int(predicted_final_carry[example_index].item()),
|
| 96 |
+
truth_digits=problem.sum_digits[:active_digits],
|
| 97 |
+
truth_final_carry=problem.carry_out[active_digits - 1],
|
| 98 |
+
)
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
per_position_digit = (predicted_digits == truth_digits).float().mean(dim=0).tolist()
|
| 102 |
+
metrics = LengthMetrics(
|
| 103 |
+
digit_accuracy=float((predicted_digits == truth_digits).float().mean().item()),
|
| 104 |
+
final_carry_accuracy=float((predicted_final_carry == truth_final_carry).float().mean().item()),
|
| 105 |
+
exact_match=float(torch.tensor(exact_matches, dtype=torch.float32).mean().item()),
|
| 106 |
+
avg_carry_chain=float(sum(count_carry_chain(problem) for problem in problems) / len(problems)),
|
| 107 |
+
avg_carry_density=float(sum(carry_density(problem) for problem in problems) / len(problems)),
|
| 108 |
+
example_count=len(problems),
|
| 109 |
+
per_position_digit_accuracy=[float(value) for value in per_position_digit],
|
| 110 |
+
)
|
| 111 |
+
return metrics, attention_stats
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def summarize_attention(
|
| 115 |
+
attention_weights: torch.Tensor | None,
|
| 116 |
+
*,
|
| 117 |
+
active_digits: int,
|
| 118 |
+
input_sequence_length: int,
|
| 119 |
+
output_sequence_length: int,
|
| 120 |
+
) -> dict[str, float]:
|
| 121 |
+
if attention_weights is None:
|
| 122 |
+
return {}
|
| 123 |
+
# Shape: [batch, heads, target_len, source_len]
|
| 124 |
+
final_attention = attention_weights[:, :, -1, :]
|
| 125 |
+
attention_mean = final_attention.mean(dim=(0, 1))
|
| 126 |
+
active_last_a_index = active_digits
|
| 127 |
+
active_last_b_index = input_sequence_length // 2 + active_digits
|
| 128 |
+
latent_slice = attention_mean[input_sequence_length : -output_sequence_length]
|
| 129 |
+
output_slice = attention_mean[-output_sequence_length:-1]
|
| 130 |
+
entropy = -torch.sum(attention_mean * torch.log(attention_mean.clamp_min(1e-9))).item()
|
| 131 |
+
summary = {
|
| 132 |
+
"lsd_a_attention": float(attention_mean[1].item()),
|
| 133 |
+
"msd_a_attention": float(attention_mean[active_last_a_index].item()),
|
| 134 |
+
"lsd_b_attention": float(attention_mean[(input_sequence_length // 2) + 1].item()),
|
| 135 |
+
"msd_b_attention": float(attention_mean[active_last_b_index].item()),
|
| 136 |
+
"attention_entropy": float(entropy),
|
| 137 |
+
"all_latent_attention": float(latent_slice.sum().item()) if latent_slice.numel() else 0.0,
|
| 138 |
+
"previous_output_attention": float(output_slice.sum().item()) if output_slice.numel() else 0.0,
|
| 139 |
+
}
|
| 140 |
+
return summary
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
|
| 144 |
+
def evaluate_length_dict(
|
| 145 |
+
model: AdditionTransformer,
|
| 146 |
+
config: ExperimentConfig,
|
| 147 |
+
problems_by_length: dict[int, list[AdditionProblem]],
|
| 148 |
+
*,
|
| 149 |
+
device: str,
|
| 150 |
+
attention_length: int | None = None,
|
| 151 |
+
) -> dict[str, dict]:
|
| 152 |
+
structured: dict[str, dict] = {}
|
| 153 |
+
for length, problems in sorted(problems_by_length.items()):
|
| 154 |
+
length_metrics, attention = evaluate_problem_set(
|
| 155 |
+
model=model,
|
| 156 |
+
config=config,
|
| 157 |
+
problems=problems,
|
| 158 |
+
active_digits=length,
|
| 159 |
+
device=device,
|
| 160 |
+
return_attention=attention_length is not None and attention_length == length,
|
| 161 |
+
)
|
| 162 |
+
structured[str(length)] = {
|
| 163 |
+
"digit_accuracy": length_metrics.digit_accuracy,
|
| 164 |
+
"final_carry_accuracy": length_metrics.final_carry_accuracy,
|
| 165 |
+
"exact_match": length_metrics.exact_match,
|
| 166 |
+
"avg_carry_chain": length_metrics.avg_carry_chain,
|
| 167 |
+
"avg_carry_density": length_metrics.avg_carry_density,
|
| 168 |
+
"example_count": length_metrics.example_count,
|
| 169 |
+
"per_position_digit_accuracy": length_metrics.per_position_digit_accuracy,
|
| 170 |
+
}
|
| 171 |
+
if attention is not None:
|
| 172 |
+
structured[str(length)]["attention_summary"] = attention
|
| 173 |
+
return structured
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def collect_hidden_dataset(
|
| 177 |
+
model: AdditionTransformer,
|
| 178 |
+
config: ExperimentConfig,
|
| 179 |
+
problems: list[AdditionProblem],
|
| 180 |
+
*,
|
| 181 |
+
active_digits: int,
|
| 182 |
+
device: str,
|
| 183 |
+
limit_examples: int,
|
| 184 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 185 |
+
model.eval()
|
| 186 |
+
latent_steps = config.latent_steps_for_stage(active_digits)
|
| 187 |
+
selected = maybe_trim_examples(problems, limit_examples)
|
| 188 |
+
hidden_states: list[torch.Tensor] = []
|
| 189 |
+
carry_targets: list[torch.Tensor] = []
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
for problem_chunk in _chunked(selected, config.eval_batch_size):
|
| 192 |
+
batch = build_batch(
|
| 193 |
+
problems=problem_chunk,
|
| 194 |
+
radix=config.radix,
|
| 195 |
+
device=device,
|
| 196 |
+
)
|
| 197 |
+
outputs = model(batch.input_ids, latent_steps=latent_steps, return_attention=False)
|
| 198 |
+
slot_hidden = outputs.output_hidden[:, :active_digits, :]
|
| 199 |
+
slot_mask = batch.target_digit_mask
|
| 200 |
+
hidden_states.append(slot_hidden[slot_mask].detach().cpu())
|
| 201 |
+
carry_targets.append(batch.target_carry[slot_mask].detach().cpu())
|
| 202 |
+
return torch.cat(hidden_states, dim=0), torch.cat(carry_targets, dim=0)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def fit_linear_probe(
|
| 206 |
+
hidden_states: torch.Tensor,
|
| 207 |
+
carry_targets: torch.Tensor,
|
| 208 |
+
*,
|
| 209 |
+
epochs: int,
|
| 210 |
+
learning_rate: float,
|
| 211 |
+
) -> dict[str, float]:
|
| 212 |
+
if hidden_states.numel() == 0:
|
| 213 |
+
return {"probe_accuracy": 0.0}
|
| 214 |
+
indices = torch.randperm(hidden_states.shape[0])
|
| 215 |
+
hidden_states = hidden_states[indices]
|
| 216 |
+
carry_targets = carry_targets[indices]
|
| 217 |
+
split_index = max(1, int(0.8 * hidden_states.shape[0]))
|
| 218 |
+
train_hidden = hidden_states[:split_index]
|
| 219 |
+
train_targets = carry_targets[:split_index]
|
| 220 |
+
test_hidden = hidden_states[split_index:]
|
| 221 |
+
test_targets = carry_targets[split_index:]
|
| 222 |
+
if test_hidden.numel() == 0:
|
| 223 |
+
test_hidden = train_hidden
|
| 224 |
+
test_targets = train_targets
|
| 225 |
+
|
| 226 |
+
probe = nn.Linear(hidden_states.shape[-1], 2)
|
| 227 |
+
optimizer = torch.optim.AdamW(probe.parameters(), lr=learning_rate)
|
| 228 |
+
loss_fn = nn.CrossEntropyLoss()
|
| 229 |
+
for _ in range(epochs):
|
| 230 |
+
logits = probe(train_hidden)
|
| 231 |
+
loss = loss_fn(logits, train_targets)
|
| 232 |
+
optimizer.zero_grad()
|
| 233 |
+
loss.backward()
|
| 234 |
+
optimizer.step()
|
| 235 |
+
|
| 236 |
+
with torch.no_grad():
|
| 237 |
+
predictions = probe(test_hidden).argmax(dim=-1)
|
| 238 |
+
accuracy = float((predictions == test_targets).float().mean().item())
|
| 239 |
+
return {"probe_accuracy": accuracy}
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def evaluate_suite(
|
| 243 |
+
model: AdditionTransformer,
|
| 244 |
+
config: ExperimentConfig,
|
| 245 |
+
suite: EvaluationSuite,
|
| 246 |
+
*,
|
| 247 |
+
device: str,
|
| 248 |
+
) -> dict[str, dict]:
|
| 249 |
+
id_lengths = list(range(1, config.train_max_digits + 1))
|
| 250 |
+
ood_lengths = list(config.ood_lengths)
|
| 251 |
+
max_attention_length = max(ood_lengths) if ood_lengths else config.train_max_digits
|
| 252 |
+
|
| 253 |
+
validation = evaluate_length_dict(
|
| 254 |
+
model=model,
|
| 255 |
+
config=config,
|
| 256 |
+
problems_by_length={length: suite.validation_uniform[length] for length in id_lengths},
|
| 257 |
+
device=device,
|
| 258 |
+
)
|
| 259 |
+
uniform_all = evaluate_length_dict(
|
| 260 |
+
model=model,
|
| 261 |
+
config=config,
|
| 262 |
+
problems_by_length={length: suite.test_uniform[length] for length in sorted(set(id_lengths + ood_lengths))},
|
| 263 |
+
device=device,
|
| 264 |
+
attention_length=max_attention_length,
|
| 265 |
+
)
|
| 266 |
+
carry_heavy_all = evaluate_length_dict(
|
| 267 |
+
model=model,
|
| 268 |
+
config=config,
|
| 269 |
+
problems_by_length={length: suite.test_carry_heavy[length] for length in sorted(set(id_lengths + ood_lengths))},
|
| 270 |
+
device=device,
|
| 271 |
+
attention_length=max_attention_length,
|
| 272 |
+
)
|
| 273 |
+
probe_hidden, probe_targets = collect_hidden_dataset(
|
| 274 |
+
model=model,
|
| 275 |
+
config=config,
|
| 276 |
+
problems=suite.test_carry_heavy[max_attention_length],
|
| 277 |
+
active_digits=max_attention_length,
|
| 278 |
+
device=device,
|
| 279 |
+
limit_examples=config.attention_probe_examples,
|
| 280 |
+
)
|
| 281 |
+
diagnostics = fit_linear_probe(
|
| 282 |
+
hidden_states=probe_hidden,
|
| 283 |
+
carry_targets=probe_targets,
|
| 284 |
+
epochs=config.linear_probe_epochs,
|
| 285 |
+
learning_rate=config.linear_probe_lr,
|
| 286 |
+
)
|
| 287 |
+
diagnostics["attention_uniform"] = uniform_all[str(max_attention_length)].get("attention_summary", {})
|
| 288 |
+
diagnostics["attention_carry_heavy"] = carry_heavy_all[str(max_attention_length)].get("attention_summary", {})
|
| 289 |
+
return {
|
| 290 |
+
"validation_uniform": validation,
|
| 291 |
+
"test_uniform": uniform_all,
|
| 292 |
+
"test_carry_heavy": carry_heavy_all,
|
| 293 |
+
"diagnostics": diagnostics,
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def stage_validation_metric(results: dict[str, dict], stage: int) -> float:
|
| 298 |
+
stage_metrics = results["validation_uniform"][str(stage)]
|
| 299 |
+
return float(stage_metrics["digit_accuracy"])
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def flatten_nested_metrics(prefix: str, nested: dict[str, dict]) -> dict[str, float]:
|
| 303 |
+
flat: dict[str, float] = {}
|
| 304 |
+
for split_name, split_metrics in nested.items():
|
| 305 |
+
if split_name == "diagnostics":
|
| 306 |
+
for key, value in split_metrics.items():
|
| 307 |
+
if isinstance(value, dict):
|
| 308 |
+
for inner_key, inner_value in value.items():
|
| 309 |
+
flat[f"{prefix}{split_name}/{key}/{inner_key}"] = float(inner_value)
|
| 310 |
+
else:
|
| 311 |
+
flat[f"{prefix}{split_name}/{key}"] = float(value)
|
| 312 |
+
continue
|
| 313 |
+
for length, length_metrics in split_metrics.items():
|
| 314 |
+
if not isinstance(length_metrics, dict):
|
| 315 |
+
continue
|
| 316 |
+
for metric_name, metric_value in length_metrics.items():
|
| 317 |
+
if isinstance(metric_value, list):
|
| 318 |
+
if metric_value:
|
| 319 |
+
flat[f"{prefix}{split_name}/length_{length}/{metric_name}_mean"] = float(sum(metric_value) / len(metric_value))
|
| 320 |
+
continue
|
| 321 |
+
if isinstance(metric_value, dict):
|
| 322 |
+
for inner_key, inner_value in metric_value.items():
|
| 323 |
+
flat[f"{prefix}{split_name}/length_{length}/{metric_name}/{inner_key}"] = float(inner_value)
|
| 324 |
+
continue
|
| 325 |
+
flat[f"{prefix}{split_name}/length_{length}/{metric_name}"] = float(metric_value)
|
| 326 |
+
return flat
|
addition/model.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from addition.config import ExperimentConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class ModelOutput:
|
| 13 |
+
digit_logits: torch.Tensor
|
| 14 |
+
final_carry_logits: torch.Tensor
|
| 15 |
+
output_hidden: torch.Tensor
|
| 16 |
+
latent_history: list[torch.Tensor]
|
| 17 |
+
attention_weights: torch.Tensor | None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TransformerBlock(nn.Module):
|
| 21 |
+
def __init__(self, d_model: int, n_heads: int, ff_dim: int, dropout: float) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.ln_1 = nn.LayerNorm(d_model)
|
| 24 |
+
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
|
| 25 |
+
self.dropout = nn.Dropout(dropout)
|
| 26 |
+
self.ln_2 = nn.LayerNorm(d_model)
|
| 27 |
+
self.mlp = nn.Sequential(
|
| 28 |
+
nn.Linear(d_model, ff_dim),
|
| 29 |
+
nn.GELU(),
|
| 30 |
+
nn.Linear(ff_dim, d_model),
|
| 31 |
+
nn.Dropout(dropout),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def forward(self, hidden_states: torch.Tensor, need_weights: bool = False) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 35 |
+
seq_len = hidden_states.shape[1]
|
| 36 |
+
causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool).triu(1)
|
| 37 |
+
normed = self.ln_1(hidden_states)
|
| 38 |
+
attn_output, attn_weights = self.attn(
|
| 39 |
+
normed,
|
| 40 |
+
normed,
|
| 41 |
+
normed,
|
| 42 |
+
need_weights=need_weights,
|
| 43 |
+
average_attn_weights=False,
|
| 44 |
+
attn_mask=causal_mask,
|
| 45 |
+
)
|
| 46 |
+
hidden_states = hidden_states + self.dropout(attn_output)
|
| 47 |
+
hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
|
| 48 |
+
return hidden_states, attn_weights if need_weights else None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AdditionTransformer(nn.Module):
|
| 52 |
+
def __init__(self, config: ExperimentConfig) -> None:
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.config = config
|
| 55 |
+
self.token_embedding = nn.Embedding(config.discrete_vocab_size, config.d_model)
|
| 56 |
+
self.position_embedding = nn.Embedding(config.max_sequence_length, config.d_model)
|
| 57 |
+
self.latent_type_embedding = nn.Parameter(torch.zeros(config.d_model))
|
| 58 |
+
self.output_slot_embeddings = nn.Parameter(torch.zeros(config.output_sequence_length, config.d_model))
|
| 59 |
+
self.block = TransformerBlock(
|
| 60 |
+
d_model=config.d_model,
|
| 61 |
+
n_heads=config.n_heads,
|
| 62 |
+
ff_dim=config.ff_dim,
|
| 63 |
+
dropout=config.dropout,
|
| 64 |
+
)
|
| 65 |
+
self.final_ln = nn.LayerNorm(config.d_model)
|
| 66 |
+
self.digit_head = nn.Linear(config.d_model, config.digit_vocab_size)
|
| 67 |
+
self.final_carry_head = nn.Linear(config.d_model, 2)
|
| 68 |
+
self.reset_parameters()
|
| 69 |
+
|
| 70 |
+
def reset_parameters(self) -> None:
|
| 71 |
+
nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
|
| 72 |
+
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
| 73 |
+
nn.init.normal_(self.latent_type_embedding, mean=0.0, std=0.02)
|
| 74 |
+
nn.init.normal_(self.output_slot_embeddings, mean=0.0, std=0.02)
|
| 75 |
+
nn.init.xavier_uniform_(self.digit_head.weight)
|
| 76 |
+
nn.init.zeros_(self.digit_head.bias)
|
| 77 |
+
nn.init.xavier_uniform_(self.final_carry_head.weight)
|
| 78 |
+
nn.init.zeros_(self.final_carry_head.bias)
|
| 79 |
+
|
| 80 |
+
def embed_discrete_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
seq_len = input_ids.shape[1]
|
| 82 |
+
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
|
| 83 |
+
return self.token_embedding(input_ids) + self.position_embedding(positions)
|
| 84 |
+
|
| 85 |
+
def embed_output_slots(
|
| 86 |
+
self,
|
| 87 |
+
batch_size: int,
|
| 88 |
+
output_length: int,
|
| 89 |
+
latent_count: int,
|
| 90 |
+
input_length: int,
|
| 91 |
+
device: torch.device,
|
| 92 |
+
) -> torch.Tensor:
|
| 93 |
+
positions = torch.arange(output_length, device=device) + input_length + latent_count
|
| 94 |
+
positioned = self.output_slot_embeddings[:output_length] + self.position_embedding(positions)
|
| 95 |
+
return positioned.unsqueeze(0).expand(batch_size, -1, -1)
|
| 96 |
+
|
| 97 |
+
def _run_block(
|
| 98 |
+
self,
|
| 99 |
+
embeddings: torch.Tensor,
|
| 100 |
+
*,
|
| 101 |
+
need_attention: bool = False,
|
| 102 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 103 |
+
hidden_states, attention_weights = self.block(embeddings, need_weights=need_attention)
|
| 104 |
+
hidden_states = self.final_ln(hidden_states)
|
| 105 |
+
return hidden_states, attention_weights
|
| 106 |
+
|
| 107 |
+
def forward(
|
| 108 |
+
self,
|
| 109 |
+
input_ids: torch.Tensor,
|
| 110 |
+
*,
|
| 111 |
+
latent_steps: int = 0,
|
| 112 |
+
return_attention: bool = False,
|
| 113 |
+
) -> ModelOutput:
|
| 114 |
+
base_embeddings = self.embed_discrete_tokens(input_ids)
|
| 115 |
+
latent_history: list[torch.Tensor] = []
|
| 116 |
+
attention_weights: torch.Tensor | None = None
|
| 117 |
+
batch_size = input_ids.shape[0]
|
| 118 |
+
input_length = input_ids.shape[1]
|
| 119 |
+
active_digits = max(1, (input_length - 2) // 2)
|
| 120 |
+
output_length = active_digits + 1
|
| 121 |
+
output_embeddings = self.embed_output_slots(
|
| 122 |
+
batch_size=batch_size,
|
| 123 |
+
output_length=output_length,
|
| 124 |
+
latent_count=0,
|
| 125 |
+
input_length=input_length,
|
| 126 |
+
device=input_ids.device,
|
| 127 |
+
)
|
| 128 |
+
hidden_states, attention_weights = self._run_block(
|
| 129 |
+
torch.cat([base_embeddings, output_embeddings], dim=1),
|
| 130 |
+
need_attention=return_attention,
|
| 131 |
+
)
|
| 132 |
+
output_hidden = hidden_states[:, -output_length:, :]
|
| 133 |
+
summary_hidden = output_hidden[:, -1, :]
|
| 134 |
+
latent_history.append(summary_hidden)
|
| 135 |
+
|
| 136 |
+
latent_embeddings: list[torch.Tensor] = []
|
| 137 |
+
for step_index in range(int(latent_steps)):
|
| 138 |
+
latent_token = summary_hidden.unsqueeze(1) + self.latent_type_embedding.view(1, 1, -1)
|
| 139 |
+
latent_position_index = input_length + step_index
|
| 140 |
+
latent_token = latent_token + self.position_embedding.weight[latent_position_index].view(1, 1, -1)
|
| 141 |
+
latent_embeddings.append(latent_token)
|
| 142 |
+
output_embeddings = self.embed_output_slots(
|
| 143 |
+
batch_size=batch_size,
|
| 144 |
+
output_length=output_length,
|
| 145 |
+
latent_count=len(latent_embeddings),
|
| 146 |
+
input_length=input_length,
|
| 147 |
+
device=input_ids.device,
|
| 148 |
+
)
|
| 149 |
+
hidden_states, attention_weights = self._run_block(
|
| 150 |
+
torch.cat([base_embeddings] + latent_embeddings + [output_embeddings], dim=1),
|
| 151 |
+
need_attention=return_attention,
|
| 152 |
+
)
|
| 153 |
+
latent_index = input_length + step_index
|
| 154 |
+
summary_hidden = hidden_states[:, latent_index, :]
|
| 155 |
+
output_hidden = hidden_states[:, -output_length:, :]
|
| 156 |
+
latent_history.append(summary_hidden)
|
| 157 |
+
|
| 158 |
+
digit_logits = self.digit_head(output_hidden[:, :active_digits, :])
|
| 159 |
+
final_carry_logits = self.final_carry_head(output_hidden[:, -1, :])
|
| 160 |
+
return ModelOutput(
|
| 161 |
+
digit_logits=digit_logits,
|
| 162 |
+
final_carry_logits=final_carry_logits,
|
| 163 |
+
output_hidden=output_hidden,
|
| 164 |
+
latent_history=latent_history,
|
| 165 |
+
attention_weights=attention_weights,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def parameter_count(self) -> int:
|
| 169 |
+
return sum(parameter.numel() for parameter in self.parameters())
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def build_model(config: ExperimentConfig, device: str | None = None) -> AdditionTransformer:
|
| 173 |
+
model = AdditionTransformer(config)
|
| 174 |
+
if device is not None:
|
| 175 |
+
model = model.to(device)
|
| 176 |
+
return model
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@torch.no_grad()
|
| 180 |
+
def describe_model(config: ExperimentConfig) -> dict[str, int]:
|
| 181 |
+
model = build_model(config)
|
| 182 |
+
total_params = model.parameter_count()
|
| 183 |
+
head_params = sum(parameter.numel() for name, parameter in model.named_parameters() if "head" in name)
|
| 184 |
+
embedding_params = sum(parameter.numel() for name, parameter in model.named_parameters() if "embedding" in name)
|
| 185 |
+
return {
|
| 186 |
+
"total_params": int(total_params),
|
| 187 |
+
"embedding_params": int(embedding_params),
|
| 188 |
+
"head_params": int(head_params),
|
| 189 |
+
"backbone_params": int(total_params - head_params),
|
| 190 |
+
}
|
addition/plots.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _load_pyplot():
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
return plt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def plot_training_history(history: list[dict[str, Any]], output_dir: Path) -> list[Path]:
|
| 15 |
+
if not history:
|
| 16 |
+
return []
|
| 17 |
+
plt = _load_pyplot()
|
| 18 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
steps = [entry["global_step"] for entry in history]
|
| 20 |
+
digit_acc = [entry["validation_digit_accuracy"] for entry in history]
|
| 21 |
+
carry_acc = [entry["validation_final_carry_accuracy"] for entry in history]
|
| 22 |
+
exact_match = [entry["validation_exact_match"] for entry in history]
|
| 23 |
+
stages = [entry["stage"] for entry in history]
|
| 24 |
+
|
| 25 |
+
saved_paths: list[Path] = []
|
| 26 |
+
|
| 27 |
+
plt.figure(figsize=(8, 4.5))
|
| 28 |
+
plt.plot(steps, digit_acc, label="Val digit acc")
|
| 29 |
+
plt.plot(steps, carry_acc, label="Val final carry acc")
|
| 30 |
+
plt.plot(steps, exact_match, label="Val exact match")
|
| 31 |
+
plt.xlabel("Global step")
|
| 32 |
+
plt.ylabel("Accuracy")
|
| 33 |
+
plt.ylim(0.0, 1.01)
|
| 34 |
+
plt.legend()
|
| 35 |
+
plt.tight_layout()
|
| 36 |
+
metrics_path = output_dir / "training_curves.png"
|
| 37 |
+
plt.savefig(metrics_path, dpi=160)
|
| 38 |
+
plt.close()
|
| 39 |
+
saved_paths.append(metrics_path)
|
| 40 |
+
|
| 41 |
+
plt.figure(figsize=(8, 4.5))
|
| 42 |
+
plt.step(steps, stages, where="post")
|
| 43 |
+
plt.xlabel("Global step")
|
| 44 |
+
plt.ylabel("Curriculum stage")
|
| 45 |
+
plt.tight_layout()
|
| 46 |
+
stage_path = output_dir / "stage_progression.png"
|
| 47 |
+
plt.savefig(stage_path, dpi=160)
|
| 48 |
+
plt.close()
|
| 49 |
+
saved_paths.append(stage_path)
|
| 50 |
+
|
| 51 |
+
return saved_paths
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _collect_length_metric(aggregate: dict[str, Any], method: str, split: str, metric: str) -> tuple[list[int], list[float], list[float]]:
|
| 55 |
+
lengths = sorted(int(length) for length in aggregate[method][split].keys())
|
| 56 |
+
means = [aggregate[method][split][str(length)][metric]["mean"] for length in lengths]
|
| 57 |
+
stds = [aggregate[method][split][str(length)][metric]["std"] for length in lengths]
|
| 58 |
+
return lengths, means, stds
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def plot_method_comparison(aggregate: dict[str, Any], output_dir: Path) -> list[Path]:
|
| 62 |
+
plt = _load_pyplot()
|
| 63 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 64 |
+
saved_paths: list[Path] = []
|
| 65 |
+
methods = list(aggregate.keys())
|
| 66 |
+
splits = [
|
| 67 |
+
("test_uniform", "uniform_exact_match.png", "Uniform exact-match by length"),
|
| 68 |
+
("test_carry_heavy", "carry_heavy_exact_match.png", "Carry-heavy exact-match by length"),
|
| 69 |
+
]
|
| 70 |
+
for split, filename, title in splits:
|
| 71 |
+
plt.figure(figsize=(8, 4.5))
|
| 72 |
+
for method in methods:
|
| 73 |
+
lengths, means, stds = _collect_length_metric(aggregate, method, split, "exact_match")
|
| 74 |
+
plt.plot(lengths, means, marker="o", label=method)
|
| 75 |
+
lower = [max(0.0, mean - std) for mean, std in zip(means, stds)]
|
| 76 |
+
upper = [min(1.0, mean + std) for mean, std in zip(means, stds)]
|
| 77 |
+
plt.fill_between(lengths, lower, upper, alpha=0.15)
|
| 78 |
+
plt.xlabel("Active digits")
|
| 79 |
+
plt.ylabel("Exact-match accuracy")
|
| 80 |
+
plt.title(title)
|
| 81 |
+
plt.ylim(0.0, 1.01)
|
| 82 |
+
plt.legend()
|
| 83 |
+
plt.tight_layout()
|
| 84 |
+
path = output_dir / filename
|
| 85 |
+
plt.savefig(path, dpi=160)
|
| 86 |
+
plt.close()
|
| 87 |
+
saved_paths.append(path)
|
| 88 |
+
|
| 89 |
+
plt.figure(figsize=(8, 4.5))
|
| 90 |
+
for method in methods:
|
| 91 |
+
stages = sorted(int(stage) for stage in aggregate[method]["stage_progression"].keys())
|
| 92 |
+
means = [aggregate[method]["stage_progression"][str(stage)]["validation_digit_accuracy"]["mean"] for stage in stages]
|
| 93 |
+
stds = [aggregate[method]["stage_progression"][str(stage)]["validation_digit_accuracy"]["std"] for stage in stages]
|
| 94 |
+
plt.plot(stages, means, marker="o", label=method)
|
| 95 |
+
plt.fill_between(
|
| 96 |
+
stages,
|
| 97 |
+
[max(0.0, mean - std) for mean, std in zip(means, stds)],
|
| 98 |
+
[min(1.0, mean + std) for mean, std in zip(means, stds)],
|
| 99 |
+
alpha=0.15,
|
| 100 |
+
)
|
| 101 |
+
plt.xlabel("Curriculum stage")
|
| 102 |
+
plt.ylabel("Best validation digit accuracy")
|
| 103 |
+
plt.ylim(0.0, 1.01)
|
| 104 |
+
plt.title("Validation digit accuracy vs stage")
|
| 105 |
+
plt.legend()
|
| 106 |
+
plt.tight_layout()
|
| 107 |
+
stage_curve_path = output_dir / "validation_digit_accuracy_by_stage.png"
|
| 108 |
+
plt.savefig(stage_curve_path, dpi=160)
|
| 109 |
+
plt.close()
|
| 110 |
+
saved_paths.append(stage_curve_path)
|
| 111 |
+
return saved_paths
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def plot_single_run_results(summary: dict[str, Any], output_dir: Path) -> list[Path]:
|
| 115 |
+
plt = _load_pyplot()
|
| 116 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 117 |
+
saved_paths = plot_training_history(summary.get("history", []), output_dir)
|
| 118 |
+
uniform = summary["final_results"]["test_uniform"]
|
| 119 |
+
carry_heavy = summary["final_results"]["test_carry_heavy"]
|
| 120 |
+
lengths = sorted(int(length) for length in uniform.keys())
|
| 121 |
+
uniform_exact = [uniform[str(length)]["exact_match"] for length in lengths]
|
| 122 |
+
carry_exact = [carry_heavy[str(length)]["exact_match"] for length in lengths]
|
| 123 |
+
plt.figure(figsize=(8, 4.5))
|
| 124 |
+
plt.plot(lengths, uniform_exact, marker="o", label="Uniform")
|
| 125 |
+
plt.plot(lengths, carry_exact, marker="o", label="Carry-heavy")
|
| 126 |
+
plt.xlabel("Active digits")
|
| 127 |
+
plt.ylabel("Exact-match accuracy")
|
| 128 |
+
plt.ylim(0.0, 1.01)
|
| 129 |
+
plt.legend()
|
| 130 |
+
plt.tight_layout()
|
| 131 |
+
final_curve_path = output_dir / "final_exact_match_by_length.png"
|
| 132 |
+
plt.savefig(final_curve_path, dpi=160)
|
| 133 |
+
plt.close()
|
| 134 |
+
saved_paths.append(final_curve_path)
|
| 135 |
+
return saved_paths
|
addition/run_comparison.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from statistics import mean, pstdev
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from addition.config import VALID_MODELS, add_config_arguments, apply_preset, build_config_from_args
|
| 11 |
+
from addition.plots import plot_method_comparison
|
| 12 |
+
from addition.train import run_experiment
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _mean_std(values: list[float]) -> dict[str, float]:
|
| 16 |
+
if not values:
|
| 17 |
+
return {"mean": 0.0, "std": 0.0}
|
| 18 |
+
if len(values) == 1:
|
| 19 |
+
return {"mean": float(values[0]), "std": 0.0}
|
| 20 |
+
return {"mean": float(mean(values)), "std": float(pstdev(values))}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _aggregate_split_metrics(run_summaries: list[dict[str, Any]], split_name: str) -> dict[str, Any]:
|
| 24 |
+
lengths = sorted(run_summaries[0]["final_results"][split_name].keys(), key=int)
|
| 25 |
+
metric_names = ["digit_accuracy", "final_carry_accuracy", "exact_match", "avg_carry_chain", "avg_carry_density"]
|
| 26 |
+
aggregated: dict[str, Any] = {}
|
| 27 |
+
for length in lengths:
|
| 28 |
+
aggregated[length] = {}
|
| 29 |
+
for metric_name in metric_names:
|
| 30 |
+
values = [float(summary["final_results"][split_name][length][metric_name]) for summary in run_summaries]
|
| 31 |
+
aggregated[length][metric_name] = _mean_std(values)
|
| 32 |
+
return aggregated
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _aggregate_stage_progression(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
|
| 36 |
+
max_stage = max(int(entry["stage"]) for summary in run_summaries for entry in summary["history"])
|
| 37 |
+
aggregated: dict[str, Any] = {}
|
| 38 |
+
for stage in range(1, max_stage + 1):
|
| 39 |
+
stage_values = []
|
| 40 |
+
stage_exact = []
|
| 41 |
+
for summary in run_summaries:
|
| 42 |
+
stage_entries = [entry for entry in summary["history"] if int(entry["stage"]) == stage]
|
| 43 |
+
if not stage_entries:
|
| 44 |
+
continue
|
| 45 |
+
stage_values.append(max(float(entry["validation_digit_accuracy"]) for entry in stage_entries))
|
| 46 |
+
stage_exact.append(max(float(entry["validation_exact_match"]) for entry in stage_entries))
|
| 47 |
+
aggregated[str(stage)] = {
|
| 48 |
+
"validation_digit_accuracy": _mean_std(stage_values),
|
| 49 |
+
"validation_exact_match": _mean_std(stage_exact),
|
| 50 |
+
}
|
| 51 |
+
return aggregated
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _aggregate_diagnostics(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
|
| 55 |
+
diagnostics = [summary["final_results"]["diagnostics"] for summary in run_summaries]
|
| 56 |
+
output: dict[str, Any] = {
|
| 57 |
+
"probe_accuracy": _mean_std([float(diag["probe_accuracy"]) for diag in diagnostics]),
|
| 58 |
+
}
|
| 59 |
+
for attention_key in ("attention_uniform", "attention_carry_heavy"):
|
| 60 |
+
attention_values = [diag.get(attention_key, {}) for diag in diagnostics]
|
| 61 |
+
metric_names = sorted({metric for diag in attention_values for metric in diag.keys()})
|
| 62 |
+
output[attention_key] = {
|
| 63 |
+
metric_name: _mean_std([float(diag.get(metric_name, 0.0)) for diag in attention_values]) for metric_name in metric_names
|
| 64 |
+
}
|
| 65 |
+
return output
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def aggregate_runs(results_by_method: dict[str, list[dict[str, Any]]]) -> dict[str, Any]:
|
| 69 |
+
aggregate: dict[str, Any] = {}
|
| 70 |
+
for method, run_summaries in results_by_method.items():
|
| 71 |
+
aggregate[method] = {
|
| 72 |
+
"test_uniform": _aggregate_split_metrics(run_summaries, "test_uniform"),
|
| 73 |
+
"test_carry_heavy": _aggregate_split_metrics(run_summaries, "test_carry_heavy"),
|
| 74 |
+
"stage_progression": _aggregate_stage_progression(run_summaries),
|
| 75 |
+
"diagnostics": _aggregate_diagnostics(run_summaries),
|
| 76 |
+
}
|
| 77 |
+
return aggregate
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_parser() -> argparse.ArgumentParser:
|
| 81 |
+
parser = argparse.ArgumentParser(description="Run the full addition comparison across methods and seeds.")
|
| 82 |
+
add_config_arguments(parser)
|
| 83 |
+
parser.add_argument("--methods", nargs="*", default=list(VALID_MODELS), choices=VALID_MODELS)
|
| 84 |
+
parser.add_argument("--seeds", nargs="*", type=int, default=None)
|
| 85 |
+
parser.add_argument("--comparison_output_dir", type=str, default="")
|
| 86 |
+
return parser
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main() -> None:
|
| 90 |
+
parser = build_parser()
|
| 91 |
+
args = parser.parse_args()
|
| 92 |
+
base_config = apply_preset(build_config_from_args(args))
|
| 93 |
+
seeds = args.seeds or list(range(base_config.comparison_num_seeds))
|
| 94 |
+
comparison_root = Path(args.comparison_output_dir or f"addition_runs/comparison_{base_config.preset}")
|
| 95 |
+
comparison_root.mkdir(parents=True, exist_ok=True)
|
| 96 |
+
|
| 97 |
+
results_by_method: dict[str, list[dict[str, Any]]] = {}
|
| 98 |
+
for method in args.methods:
|
| 99 |
+
results_by_method[method] = []
|
| 100 |
+
for seed in seeds:
|
| 101 |
+
args.model = method
|
| 102 |
+
args.seed = seed
|
| 103 |
+
args.output_dir = str(comparison_root / f"{method}_seed{seed}")
|
| 104 |
+
config = apply_preset(build_config_from_args(args))
|
| 105 |
+
config.output_dir = str(comparison_root / f"{method}_seed{seed}")
|
| 106 |
+
print(f"[addition comparison] running method={method} seed={seed}", flush=True)
|
| 107 |
+
summary = run_experiment(config)
|
| 108 |
+
results_by_method[method].append(summary)
|
| 109 |
+
|
| 110 |
+
aggregate = aggregate_runs(results_by_method)
|
| 111 |
+
aggregate_payload = {
|
| 112 |
+
"methods": args.methods,
|
| 113 |
+
"seeds": seeds,
|
| 114 |
+
"aggregate": aggregate,
|
| 115 |
+
}
|
| 116 |
+
with (comparison_root / "aggregate_results.json").open("w", encoding="utf-8") as handle:
|
| 117 |
+
json.dump(aggregate_payload, handle, indent=2, sort_keys=True)
|
| 118 |
+
plot_method_comparison(aggregate, comparison_root / "plots")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
|
| 122 |
+
main()
|
addition/train.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from addition.config import ExperimentConfig, ensure_output_dirs, parse_config, save_config
|
| 12 |
+
from addition.data import build_batch, build_evaluation_suite, digits_to_string, exact_sum_matches, sample_training_batch, seed_everything
|
| 13 |
+
from addition.eval import evaluate_problem_set, evaluate_suite, flatten_nested_metrics
|
| 14 |
+
from addition.model import build_model, describe_model
|
| 15 |
+
from addition.plots import plot_single_run_results
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _maybe_init_wandb(config: ExperimentConfig, output_dir: Path):
|
| 19 |
+
if not config.use_wandb or config.wandb_mode == "disabled":
|
| 20 |
+
return None
|
| 21 |
+
try:
|
| 22 |
+
import wandb
|
| 23 |
+
except ImportError:
|
| 24 |
+
print("wandb is not installed; continuing with local logging only.")
|
| 25 |
+
return None
|
| 26 |
+
run = wandb.init(
|
| 27 |
+
project=config.wandb_project,
|
| 28 |
+
entity=config.wandb_entity or None,
|
| 29 |
+
name=config.effective_run_name,
|
| 30 |
+
mode=config.wandb_mode,
|
| 31 |
+
config={"experiment": config.__dict__},
|
| 32 |
+
dir=str(output_dir),
|
| 33 |
+
reinit=True,
|
| 34 |
+
)
|
| 35 |
+
return run
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _save_json(path: Path, payload: dict[str, Any]) -> None:
|
| 39 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 40 |
+
json.dump(payload, handle, indent=2, sort_keys=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _save_checkpoint(path: Path, model: nn.Module, optimizer: torch.optim.Optimizer, metadata: dict[str, Any]) -> None:
|
| 44 |
+
torch.save(
|
| 45 |
+
{
|
| 46 |
+
"model_state": model.state_dict(),
|
| 47 |
+
"optimizer_state": optimizer.state_dict(),
|
| 48 |
+
"metadata": metadata,
|
| 49 |
+
},
|
| 50 |
+
path,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _stage_checkpoint_path(stage_directory: Path, stage: int) -> Path:
|
| 55 |
+
return stage_directory / f"stage_{stage:02d}_passed.pt"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _evaluate_current_stage(
|
| 59 |
+
model: nn.Module,
|
| 60 |
+
config: ExperimentConfig,
|
| 61 |
+
suite,
|
| 62 |
+
stage: int,
|
| 63 |
+
device: str,
|
| 64 |
+
) -> dict[str, float]:
|
| 65 |
+
stage_metrics, _ = evaluate_problem_set(
|
| 66 |
+
model=model,
|
| 67 |
+
config=config,
|
| 68 |
+
problems=suite.validation_uniform[stage],
|
| 69 |
+
active_digits=stage,
|
| 70 |
+
device=device,
|
| 71 |
+
return_attention=False,
|
| 72 |
+
)
|
| 73 |
+
return {
|
| 74 |
+
"digit_accuracy": stage_metrics.digit_accuracy,
|
| 75 |
+
"final_carry_accuracy": stage_metrics.final_carry_accuracy,
|
| 76 |
+
"exact_match": stage_metrics.exact_match,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _masked_digit_loss(
|
| 81 |
+
logits: torch.Tensor,
|
| 82 |
+
targets: torch.Tensor,
|
| 83 |
+
mask: torch.Tensor,
|
| 84 |
+
loss_fn: nn.Module,
|
| 85 |
+
) -> torch.Tensor:
|
| 86 |
+
masked_logits = logits[mask]
|
| 87 |
+
masked_targets = targets[mask]
|
| 88 |
+
if masked_logits.numel() == 0:
|
| 89 |
+
return logits.new_zeros(())
|
| 90 |
+
return loss_fn(masked_logits, masked_targets)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torch.no_grad()
|
| 94 |
+
def _print_model_debug_format(
|
| 95 |
+
model: nn.Module,
|
| 96 |
+
config: ExperimentConfig,
|
| 97 |
+
*,
|
| 98 |
+
stage: int,
|
| 99 |
+
rng,
|
| 100 |
+
device: str,
|
| 101 |
+
) -> None:
|
| 102 |
+
debug_batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
|
| 103 |
+
outputs = model(debug_batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False)
|
| 104 |
+
print("[addition debug] model_architecture", flush=True)
|
| 105 |
+
print(model, flush=True)
|
| 106 |
+
print(
|
| 107 |
+
"[addition debug] batch_format "
|
| 108 |
+
f"stage={stage} input_shape={tuple(debug_batch.input_ids.shape)} "
|
| 109 |
+
f"target_digits_shape={tuple(debug_batch.target_digits.shape)} "
|
| 110 |
+
f"target_mask_shape={tuple(debug_batch.target_digit_mask.shape)} "
|
| 111 |
+
f"target_final_carry_shape={tuple(debug_batch.target_final_carry.shape)} "
|
| 112 |
+
f"digit_logits_shape={tuple(outputs.digit_logits.shape)} "
|
| 113 |
+
f"final_carry_logits_shape={tuple(outputs.final_carry_logits.shape)} "
|
| 114 |
+
f"output_hidden_shape={tuple(outputs.output_hidden.shape)}",
|
| 115 |
+
flush=True,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@torch.no_grad()
|
| 120 |
+
def _print_validation_samples(
|
| 121 |
+
model: nn.Module,
|
| 122 |
+
config: ExperimentConfig,
|
| 123 |
+
problems,
|
| 124 |
+
*,
|
| 125 |
+
stage: int,
|
| 126 |
+
device: str,
|
| 127 |
+
limit: int = 3,
|
| 128 |
+
) -> None:
|
| 129 |
+
sample_problems = list(problems[:limit])
|
| 130 |
+
if not sample_problems:
|
| 131 |
+
return
|
| 132 |
+
batch = build_batch(problems=sample_problems, radix=config.radix, device=device)
|
| 133 |
+
outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False)
|
| 134 |
+
predicted_digits = outputs.digit_logits.argmax(dim=-1).cpu().tolist()
|
| 135 |
+
predicted_final_carry = outputs.final_carry_logits.argmax(dim=-1).cpu().tolist()
|
| 136 |
+
|
| 137 |
+
for example_index, problem in enumerate(sample_problems):
|
| 138 |
+
truth_digits = problem.sum_digits[:stage]
|
| 139 |
+
truth_final_carry = problem.carry_out[stage - 1]
|
| 140 |
+
pred_digits = predicted_digits[example_index][:stage]
|
| 141 |
+
pred_final_carry = int(predicted_final_carry[example_index])
|
| 142 |
+
exact = exact_sum_matches(
|
| 143 |
+
predicted_digits=pred_digits,
|
| 144 |
+
predicted_final_carry=pred_final_carry,
|
| 145 |
+
truth_digits=truth_digits,
|
| 146 |
+
truth_final_carry=truth_final_carry,
|
| 147 |
+
)
|
| 148 |
+
a_text = digits_to_string(problem.a_digits[:stage], final_carry=0, radix=config.radix)
|
| 149 |
+
b_text = digits_to_string(problem.b_digits[:stage], final_carry=0, radix=config.radix)
|
| 150 |
+
pred_text = digits_to_string(pred_digits, final_carry=pred_final_carry, radix=config.radix)
|
| 151 |
+
truth_text = digits_to_string(truth_digits, final_carry=truth_final_carry, radix=config.radix)
|
| 152 |
+
print(
|
| 153 |
+
f"[addition sample] stage={stage} idx={example_index} "
|
| 154 |
+
f"a={a_text} b={b_text} pred={pred_text} true={truth_text} "
|
| 155 |
+
f"pred_digits={pred_digits} pred_carry={pred_final_carry} "
|
| 156 |
+
f"true_digits={truth_digits} true_carry={truth_final_carry} exact={int(exact)}",
|
| 157 |
+
flush=True,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def run_experiment(config: ExperimentConfig) -> dict[str, Any]:
|
| 162 |
+
directories = ensure_output_dirs(config)
|
| 163 |
+
save_config(config, directories["root"])
|
| 164 |
+
seed_everything(config.seed)
|
| 165 |
+
device = config.device
|
| 166 |
+
model = build_model(config, device=device)
|
| 167 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
|
| 168 |
+
digit_loss_fn = nn.CrossEntropyLoss()
|
| 169 |
+
final_carry_loss_fn = nn.CrossEntropyLoss()
|
| 170 |
+
suite = build_evaluation_suite(config)
|
| 171 |
+
rng = __import__("random").Random(config.seed + 12345)
|
| 172 |
+
history: list[dict[str, Any]] = []
|
| 173 |
+
best_validation = -1.0
|
| 174 |
+
best_checkpoint_path = directories["checkpoints"] / "best.pt"
|
| 175 |
+
last_checkpoint_path = directories["checkpoints"] / "last.pt"
|
| 176 |
+
stage = config.initial_stage if config.uses_curriculum else config.train_max_digits
|
| 177 |
+
stage_steps = 0
|
| 178 |
+
global_step = 0
|
| 179 |
+
stop_reason = "train_steps_exhausted"
|
| 180 |
+
wandb_run = _maybe_init_wandb(config, directories["root"])
|
| 181 |
+
started_at = time.time()
|
| 182 |
+
param_counts = describe_model(config)
|
| 183 |
+
print(
|
| 184 |
+
f"[addition train] model={config.model} seed={config.seed} device={device} "
|
| 185 |
+
f"params={param_counts['total_params']} stage={stage}",
|
| 186 |
+
flush=True,
|
| 187 |
+
)
|
| 188 |
+
_print_model_debug_format(model=model, config=config, stage=stage, rng=rng, device=device)
|
| 189 |
+
|
| 190 |
+
while global_step < config.train_steps:
|
| 191 |
+
model.train()
|
| 192 |
+
batch = sample_training_batch(config=config, stage=stage, rng=rng, device=device)
|
| 193 |
+
optimizer.zero_grad(set_to_none=True)
|
| 194 |
+
outputs = model(batch.input_ids, latent_steps=config.latent_steps_for_stage(stage), return_attention=False)
|
| 195 |
+
digit_loss = _masked_digit_loss(
|
| 196 |
+
logits=outputs.digit_logits,
|
| 197 |
+
targets=batch.target_digits,
|
| 198 |
+
mask=batch.target_digit_mask,
|
| 199 |
+
loss_fn=digit_loss_fn,
|
| 200 |
+
)
|
| 201 |
+
final_carry_loss = final_carry_loss_fn(outputs.final_carry_logits, batch.target_final_carry)
|
| 202 |
+
loss = digit_loss + final_carry_loss
|
| 203 |
+
loss.backward()
|
| 204 |
+
if config.grad_clip_norm > 0:
|
| 205 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
|
| 206 |
+
optimizer.step()
|
| 207 |
+
|
| 208 |
+
global_step += 1
|
| 209 |
+
stage_steps += 1
|
| 210 |
+
|
| 211 |
+
if global_step % max(1, config.validation_interval // 2) == 0:
|
| 212 |
+
train_message = (
|
| 213 |
+
f"[addition train] step={global_step} stage={stage} "
|
| 214 |
+
f"loss={loss.item():.4f} digit_loss={digit_loss.item():.4f} "
|
| 215 |
+
f"final_carry_loss={final_carry_loss.item():.4f}"
|
| 216 |
+
)
|
| 217 |
+
print(train_message, flush=True)
|
| 218 |
+
|
| 219 |
+
should_validate = (
|
| 220 |
+
global_step % config.validation_interval == 0
|
| 221 |
+
or global_step == config.train_steps
|
| 222 |
+
or (config.uses_curriculum and stage_steps == config.max_steps_per_stage)
|
| 223 |
+
)
|
| 224 |
+
if not should_validate:
|
| 225 |
+
continue
|
| 226 |
+
|
| 227 |
+
validation = _evaluate_current_stage(model=model, config=config, suite=suite, stage=stage, device=device)
|
| 228 |
+
history_entry = {
|
| 229 |
+
"global_step": global_step,
|
| 230 |
+
"stage": stage,
|
| 231 |
+
"stage_steps": stage_steps,
|
| 232 |
+
"loss": float(loss.item()),
|
| 233 |
+
"digit_loss": float(digit_loss.item()),
|
| 234 |
+
"final_carry_loss": float(final_carry_loss.item()),
|
| 235 |
+
"validation_digit_accuracy": validation["digit_accuracy"],
|
| 236 |
+
"validation_final_carry_accuracy": validation["final_carry_accuracy"],
|
| 237 |
+
"validation_exact_match": validation["exact_match"],
|
| 238 |
+
"latent_steps": config.latent_steps_for_stage(stage),
|
| 239 |
+
}
|
| 240 |
+
history.append(history_entry)
|
| 241 |
+
print(
|
| 242 |
+
f"[addition val] step={global_step} stage={stage} "
|
| 243 |
+
f"digit_acc={validation['digit_accuracy']:.4f} final_carry_acc={validation['final_carry_accuracy']:.4f} "
|
| 244 |
+
f"exact={validation['exact_match']:.4f}",
|
| 245 |
+
flush=True,
|
| 246 |
+
)
|
| 247 |
+
_print_validation_samples(
|
| 248 |
+
model=model,
|
| 249 |
+
config=config,
|
| 250 |
+
problems=suite.validation_uniform[stage],
|
| 251 |
+
stage=stage,
|
| 252 |
+
device=device,
|
| 253 |
+
)
|
| 254 |
+
if wandb_run is not None:
|
| 255 |
+
payload = {
|
| 256 |
+
"train/loss": float(loss.item()),
|
| 257 |
+
"train/digit_loss": float(digit_loss.item()),
|
| 258 |
+
"train/final_carry_loss": float(final_carry_loss.item()),
|
| 259 |
+
"train/stage": stage,
|
| 260 |
+
"train/latent_steps": config.latent_steps_for_stage(stage),
|
| 261 |
+
"validation/digit_accuracy": validation["digit_accuracy"],
|
| 262 |
+
"validation/final_carry_accuracy": validation["final_carry_accuracy"],
|
| 263 |
+
"validation/exact_match": validation["exact_match"],
|
| 264 |
+
"step": global_step,
|
| 265 |
+
}
|
| 266 |
+
wandb_run.log(payload)
|
| 267 |
+
|
| 268 |
+
if validation["exact_match"] >= best_validation:
|
| 269 |
+
best_validation = validation["exact_match"]
|
| 270 |
+
_save_checkpoint(
|
| 271 |
+
best_checkpoint_path,
|
| 272 |
+
model,
|
| 273 |
+
optimizer,
|
| 274 |
+
metadata={
|
| 275 |
+
"global_step": global_step,
|
| 276 |
+
"stage": stage,
|
| 277 |
+
"best_validation_exact_match": best_validation,
|
| 278 |
+
},
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
reached_threshold = validation["exact_match"] >= config.stage_accuracy_threshold
|
| 282 |
+
reached_cap = stage_steps >= config.max_steps_per_stage
|
| 283 |
+
|
| 284 |
+
if config.uses_curriculum:
|
| 285 |
+
if stage < config.train_max_digits and reached_threshold:
|
| 286 |
+
_save_checkpoint(
|
| 287 |
+
_stage_checkpoint_path(directories["stage_checkpoints"], stage),
|
| 288 |
+
model,
|
| 289 |
+
optimizer,
|
| 290 |
+
metadata={
|
| 291 |
+
"global_step": global_step,
|
| 292 |
+
"stage": stage,
|
| 293 |
+
"validation_exact_match": validation["exact_match"],
|
| 294 |
+
"validation_digit_accuracy": validation["digit_accuracy"],
|
| 295 |
+
"validation_final_carry_accuracy": validation["final_carry_accuracy"],
|
| 296 |
+
},
|
| 297 |
+
)
|
| 298 |
+
print(
|
| 299 |
+
f"[addition curriculum] advance {stage} -> {stage + 1} "
|
| 300 |
+
f"(exact_match={validation['exact_match']:.4f})",
|
| 301 |
+
flush=True,
|
| 302 |
+
)
|
| 303 |
+
stage += 1
|
| 304 |
+
stage_steps = 0
|
| 305 |
+
continue
|
| 306 |
+
if reached_cap and not reached_threshold:
|
| 307 |
+
print(
|
| 308 |
+
f"[addition curriculum] hold stage={stage} after {stage_steps} steps "
|
| 309 |
+
f"(exact_match={validation['exact_match']:.4f} < threshold={config.stage_accuracy_threshold:.2f})",
|
| 310 |
+
flush=True,
|
| 311 |
+
)
|
| 312 |
+
if stage == config.train_max_digits and reached_threshold:
|
| 313 |
+
stop_reason = "final_stage_threshold"
|
| 314 |
+
break
|
| 315 |
+
|
| 316 |
+
_save_checkpoint(
|
| 317 |
+
last_checkpoint_path,
|
| 318 |
+
model,
|
| 319 |
+
optimizer,
|
| 320 |
+
metadata={
|
| 321 |
+
"global_step": global_step,
|
| 322 |
+
"stage": stage,
|
| 323 |
+
"stop_reason": stop_reason,
|
| 324 |
+
},
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
best_payload = torch.load(best_checkpoint_path, map_location=device)
|
| 328 |
+
model.load_state_dict(best_payload["model_state"])
|
| 329 |
+
final_results = evaluate_suite(model=model, config=config, suite=suite, device=device)
|
| 330 |
+
flat_final_metrics = flatten_nested_metrics("", final_results)
|
| 331 |
+
summary = {
|
| 332 |
+
"config": config.__dict__,
|
| 333 |
+
"param_counts": param_counts,
|
| 334 |
+
"best_validation_exact_match": best_validation,
|
| 335 |
+
"global_step": global_step,
|
| 336 |
+
"final_stage": stage,
|
| 337 |
+
"stop_reason": stop_reason,
|
| 338 |
+
"elapsed_seconds": time.time() - started_at,
|
| 339 |
+
"history": history,
|
| 340 |
+
"final_results": final_results,
|
| 341 |
+
"flat_final_metrics": flat_final_metrics,
|
| 342 |
+
}
|
| 343 |
+
_save_json(directories["artifacts"] / "summary.json", summary)
|
| 344 |
+
with (directories["artifacts"] / "history.jsonl").open("w", encoding="utf-8") as handle:
|
| 345 |
+
for entry in history:
|
| 346 |
+
handle.write(json.dumps(entry, sort_keys=True) + "\n")
|
| 347 |
+
plot_single_run_results(summary, directories["plots"])
|
| 348 |
+
|
| 349 |
+
if wandb_run is not None:
|
| 350 |
+
wandb_run.log(flat_final_metrics | {"step": global_step})
|
| 351 |
+
wandb_run.summary.update(
|
| 352 |
+
{
|
| 353 |
+
"best_validation_exact_match": best_validation,
|
| 354 |
+
"final_stage": stage,
|
| 355 |
+
"stop_reason": stop_reason,
|
| 356 |
+
}
|
| 357 |
+
)
|
| 358 |
+
wandb_run.finish()
|
| 359 |
+
|
| 360 |
+
return summary
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def main() -> None:
|
| 364 |
+
config = parse_config("Train the addition carry experiment.")
|
| 365 |
+
run_experiment(config)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if __name__ == "__main__":
|
| 369 |
+
main()
|
aligned_cell_policy/shared_cell_policy.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Dict, List
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
+
PARENT_DIR = os.path.dirname(CURRENT_DIR)
|
| 13 |
+
if PARENT_DIR not in sys.path:
|
| 14 |
+
sys.path.insert(0, PARENT_DIR)
|
| 15 |
+
|
| 16 |
+
from format_utils_icon import parse_n_value_prediction
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class CellExample:
|
| 21 |
+
grid: np.ndarray
|
| 22 |
+
target_cell: tuple[int, int]
|
| 23 |
+
target_value: int
|
| 24 |
+
turn_idx: int
|
| 25 |
+
total_turns: int
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_TUPLE_PROMPT_RE = re.compile(r"\((\d+),(\d+),(\d+)\)")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def parse_grid_from_tuple_prompt(prompt_text: str) -> np.ndarray:
|
| 32 |
+
triples = _TUPLE_PROMPT_RE.findall(str(prompt_text))
|
| 33 |
+
if len(triples) < 81:
|
| 34 |
+
raise ValueError("Could not recover 81 (row,col,value) tuples from prompt.")
|
| 35 |
+
grid = np.zeros((9, 9), dtype=int)
|
| 36 |
+
for rr, cc, vv in triples[:81]:
|
| 37 |
+
grid[int(rr) - 1, int(cc) - 1] = int(vv)
|
| 38 |
+
return grid
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def build_cell_examples_from_row(row: Dict[str, Any]) -> List[CellExample]:
|
| 42 |
+
prompt = str(row["prompt"])
|
| 43 |
+
grid = parse_grid_from_tuple_prompt(prompt)
|
| 44 |
+
metadata = dict(row.get("metadata", {}))
|
| 45 |
+
empty_locs = metadata.get("empty_locs_1based")
|
| 46 |
+
target_triples = metadata.get("target_triples_1based")
|
| 47 |
+
|
| 48 |
+
if not empty_locs or not target_triples:
|
| 49 |
+
completion = str(row.get("completion", ""))
|
| 50 |
+
parsed, _ = parse_n_value_prediction(completion, int(metadata.get("empties", 0) or 0))
|
| 51 |
+
if parsed is None:
|
| 52 |
+
raise ValueError("Row is missing metadata and completion could not be parsed.")
|
| 53 |
+
empty_locs = [(r + 1, c + 1) for r, c in np.argwhere(grid == 0).tolist()]
|
| 54 |
+
target_triples = [(int(r), int(c), int(v)) for (r, c), v in zip(empty_locs, parsed)]
|
| 55 |
+
|
| 56 |
+
total_turns = len(target_triples)
|
| 57 |
+
out: List[CellExample] = []
|
| 58 |
+
for idx, triple in enumerate(target_triples, start=1):
|
| 59 |
+
rr, cc, value = int(triple[0]) - 1, int(triple[1]) - 1, int(triple[2])
|
| 60 |
+
out.append(
|
| 61 |
+
CellExample(
|
| 62 |
+
grid=np.asarray(grid, dtype=int).copy(),
|
| 63 |
+
target_cell=(rr, cc),
|
| 64 |
+
target_value=value,
|
| 65 |
+
turn_idx=idx,
|
| 66 |
+
total_turns=total_turns,
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
return out
|
analysis/eval_saved_hard9x9_checkpoints.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from types import SimpleNamespace
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from peft import PeftModel
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
+
|
| 12 |
+
if ROOT := "/home/ubuntu/curriculum_cot":
|
| 13 |
+
if ROOT not in sys.path:
|
| 14 |
+
sys.path.insert(0, ROOT)
|
| 15 |
+
|
| 16 |
+
from multi_output_cell_policy import grpo_multi_output_train as baseline_grpo
|
| 17 |
+
from multi_output_cell_policy import sft_multi_output_train as baseline_sft
|
| 18 |
+
from latent_multi_output_cell_policy import grpo_residual_projector_latent_train as latent_grpo
|
| 19 |
+
from latent_multi_output_cell_policy import residual_projector_warmstart_sft_latent_multi_output_train as latent_sft
|
| 20 |
+
|
| 21 |
+
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
|
| 22 |
+
CACHE_DIR = os.path.join(ROOT, ".hf_cache")
|
| 23 |
+
DATA_PATH = os.path.join(ROOT, "data", "sudoku_t3_30empty_value_qwen_text.jsonl")
|
| 24 |
+
EVAL_ROWS = 20
|
| 25 |
+
TOTAL_EMPTIES_HINT = 30
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def make_tokenizer() -> AutoTokenizer:
|
| 29 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, use_fast=True)
|
| 30 |
+
if tokenizer.pad_token_id is None:
|
| 31 |
+
tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
|
| 32 |
+
return tokenizer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def make_device() -> torch.device:
|
| 36 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_baseline_sft_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
|
| 40 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 41 |
+
MODEL_NAME,
|
| 42 |
+
cache_dir=CACHE_DIR,
|
| 43 |
+
torch_dtype=baseline_sft.pick_dtype() if torch.cuda.is_available() else torch.float32,
|
| 44 |
+
low_cpu_mem_usage=True,
|
| 45 |
+
)
|
| 46 |
+
model = PeftModel.from_pretrained(base, checkpoint_dir, is_trainable=False)
|
| 47 |
+
model.to(device)
|
| 48 |
+
model.eval()
|
| 49 |
+
return model
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def make_baseline_grpo_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
|
| 53 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 54 |
+
MODEL_NAME,
|
| 55 |
+
cache_dir=CACHE_DIR,
|
| 56 |
+
torch_dtype=baseline_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32,
|
| 57 |
+
low_cpu_mem_usage=True,
|
| 58 |
+
)
|
| 59 |
+
model = baseline_grpo.load_trainable_adapter(base, checkpoint_dir)
|
| 60 |
+
model.to(device)
|
| 61 |
+
model.eval()
|
| 62 |
+
return model
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def make_latent_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
|
| 66 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 67 |
+
MODEL_NAME,
|
| 68 |
+
cache_dir=CACHE_DIR,
|
| 69 |
+
torch_dtype=latent_grpo.pick_dtype() if torch.cuda.is_available() else torch.float32,
|
| 70 |
+
low_cpu_mem_usage=True,
|
| 71 |
+
)
|
| 72 |
+
model = latent_grpo.load_trainable_adapter(base, checkpoint_dir)
|
| 73 |
+
projector_hidden = latent_grpo.infer_projector_hidden_from_state(checkpoint_dir) or latent_grpo.PROJECTOR_HIDDEN
|
| 74 |
+
latent_grpo.attach_residual_projector_modules(
|
| 75 |
+
model,
|
| 76 |
+
hidden_size=int(latent_grpo.unwrap_backbone(model).config.hidden_size),
|
| 77 |
+
projector_hidden=projector_hidden,
|
| 78 |
+
)
|
| 79 |
+
latent_grpo.maybe_load_projector_state(model, checkpoint_dir)
|
| 80 |
+
model.to(device)
|
| 81 |
+
model.eval()
|
| 82 |
+
return model
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def common_reward_args() -> dict:
|
| 86 |
+
return {
|
| 87 |
+
"reward_good_value": 1.0,
|
| 88 |
+
"penalty_bad_value": 1.75,
|
| 89 |
+
"penalty_malformed": 4.0,
|
| 90 |
+
"penalty_empty": 0.5,
|
| 91 |
+
"penalty_singleton": 1.5,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def eval_baseline_sft(checkpoint_dir: str, stage_i: int) -> dict:
|
| 96 |
+
device = make_device()
|
| 97 |
+
tokenizer = make_tokenizer()
|
| 98 |
+
model = make_baseline_sft_model(checkpoint_dir, device)
|
| 99 |
+
rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
|
| 100 |
+
args = SimpleNamespace(
|
| 101 |
+
stage_i=int(stage_i),
|
| 102 |
+
total_empties_hint=TOTAL_EMPTIES_HINT,
|
| 103 |
+
max_completion_length=24,
|
| 104 |
+
debug_print_limit=0,
|
| 105 |
+
)
|
| 106 |
+
metrics = baseline_sft.run_eval(args, rows, model, tokenizer, device)
|
| 107 |
+
del model
|
| 108 |
+
if torch.cuda.is_available():
|
| 109 |
+
torch.cuda.empty_cache()
|
| 110 |
+
return metrics
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def eval_baseline_grpo(checkpoint_dir: str, stage_i: int) -> dict:
|
| 114 |
+
device = make_device()
|
| 115 |
+
tokenizer = make_tokenizer()
|
| 116 |
+
model = make_baseline_grpo_model(checkpoint_dir, device)
|
| 117 |
+
rows = baseline_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
|
| 118 |
+
args = SimpleNamespace(
|
| 119 |
+
stage_i=int(stage_i),
|
| 120 |
+
total_empties_hint=TOTAL_EMPTIES_HINT,
|
| 121 |
+
max_completion_length=24,
|
| 122 |
+
debug_print_limit=0,
|
| 123 |
+
**common_reward_args(),
|
| 124 |
+
)
|
| 125 |
+
metrics = baseline_grpo.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device)
|
| 126 |
+
del model
|
| 127 |
+
if torch.cuda.is_available():
|
| 128 |
+
torch.cuda.empty_cache()
|
| 129 |
+
return metrics
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def eval_latent_sft(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
|
| 133 |
+
device = make_device()
|
| 134 |
+
tokenizer = make_tokenizer()
|
| 135 |
+
model = make_latent_model(checkpoint_dir, device)
|
| 136 |
+
rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
|
| 137 |
+
args = SimpleNamespace(
|
| 138 |
+
stage_i=int(stage_i),
|
| 139 |
+
num_cot_tokens=int(num_cot_tokens),
|
| 140 |
+
total_empties_hint=TOTAL_EMPTIES_HINT,
|
| 141 |
+
max_completion_length=32,
|
| 142 |
+
debug_print_limit=0,
|
| 143 |
+
**common_reward_args(),
|
| 144 |
+
)
|
| 145 |
+
metrics = latent_sft.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device, eval_stage_i=int(stage_i))
|
| 146 |
+
del model
|
| 147 |
+
if torch.cuda.is_available():
|
| 148 |
+
torch.cuda.empty_cache()
|
| 149 |
+
return metrics
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def eval_latent_grpo(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
|
| 153 |
+
device = make_device()
|
| 154 |
+
tokenizer = make_tokenizer()
|
| 155 |
+
model = make_latent_model(checkpoint_dir, device)
|
| 156 |
+
rows = latent_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
|
| 157 |
+
args = SimpleNamespace(
|
| 158 |
+
stage_i=int(stage_i),
|
| 159 |
+
num_cot_tokens=int(num_cot_tokens),
|
| 160 |
+
total_empties_hint=TOTAL_EMPTIES_HINT,
|
| 161 |
+
max_completion_length=32,
|
| 162 |
+
debug_print_limit=0,
|
| 163 |
+
**common_reward_args(),
|
| 164 |
+
)
|
| 165 |
+
metrics = latent_grpo.run_eval(args=args, rows=rows, model=model, tokenizer=tokenizer, device=device, eval_stage_i=int(stage_i))
|
| 166 |
+
del model
|
| 167 |
+
if torch.cuda.is_available():
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
return metrics
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def main() -> None:
|
| 173 |
+
# Explicit step dirs (not run roots) so metrics match the agreed endpoints.
|
| 174 |
+
checkpoints = [
|
| 175 |
+
{
|
| 176 |
+
"label": "baseline_stage1_sft",
|
| 177 |
+
"stage_i": 1,
|
| 178 |
+
"kind": "baseline_sft",
|
| 179 |
+
"checkpoint_dir": os.path.join(
|
| 180 |
+
ROOT,
|
| 181 |
+
"final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/20260404_023600_baseline30_clean/baseline_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty/checkpoint-step-01000",
|
| 182 |
+
),
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"label": "baseline_stage1_grpo",
|
| 186 |
+
"stage_i": 1,
|
| 187 |
+
"kind": "baseline_grpo",
|
| 188 |
+
"checkpoint_dir": os.path.join(
|
| 189 |
+
ROOT,
|
| 190 |
+
"final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i1_20260404_fixed_baseline_grpo_i1/checkpoint-5350",
|
| 191 |
+
),
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"label": "baseline_stage2_sft",
|
| 195 |
+
"stage_i": 2,
|
| 196 |
+
"kind": "baseline_sft",
|
| 197 |
+
"checkpoint_dir": os.path.join(
|
| 198 |
+
ROOT,
|
| 199 |
+
"final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft/i2_20260404_stage2_baseline_sft_from_grpo5350/checkpoint-step-13100",
|
| 200 |
+
),
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"label": "baseline_stage2_grpo",
|
| 204 |
+
"stage_i": 2,
|
| 205 |
+
"kind": "baseline_grpo",
|
| 206 |
+
"checkpoint_dir": os.path.join(
|
| 207 |
+
ROOT,
|
| 208 |
+
"final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i2_20260405_stage2_baseline_grpo_from_sft13100/checkpoint-4325",
|
| 209 |
+
),
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"label": "latent_stage1_sft",
|
| 213 |
+
"stage_i": 1,
|
| 214 |
+
"kind": "latent_sft",
|
| 215 |
+
"num_cot_tokens": 1,
|
| 216 |
+
"checkpoint_dir": os.path.join(
|
| 217 |
+
ROOT,
|
| 218 |
+
"final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/20260404_013500_latent30_frombaseline/latent_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty_residual_projector/checkpoint-step-00200",
|
| 219 |
+
),
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"label": "latent_stage1_grpo",
|
| 223 |
+
"stage_i": 1,
|
| 224 |
+
"kind": "latent_grpo",
|
| 225 |
+
"num_cot_tokens": 1,
|
| 226 |
+
"checkpoint_dir": os.path.join(
|
| 227 |
+
ROOT,
|
| 228 |
+
"final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740",
|
| 229 |
+
),
|
| 230 |
+
},
|
| 231 |
+
{
|
| 232 |
+
"label": "latent_stage2_sft",
|
| 233 |
+
"stage_i": 2,
|
| 234 |
+
"kind": "latent_sft",
|
| 235 |
+
"num_cot_tokens": 2,
|
| 236 |
+
"checkpoint_dir": os.path.join(
|
| 237 |
+
ROOT,
|
| 238 |
+
"final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft/i2_cot2_20260404_stage2_latent_sft_from_grpo2740/checkpoint-step-00700",
|
| 239 |
+
),
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"label": "latent_stage2_grpo",
|
| 243 |
+
"stage_i": 2,
|
| 244 |
+
"kind": "latent_grpo",
|
| 245 |
+
"num_cot_tokens": 2,
|
| 246 |
+
"checkpoint_dir": os.path.join(
|
| 247 |
+
ROOT,
|
| 248 |
+
"final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i2_cot2_20260405_stage2_latent_grpo_from_sft00700/checkpoint-1620",
|
| 249 |
+
),
|
| 250 |
+
},
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
results: dict[str, dict] = {}
|
| 254 |
+
for item in checkpoints:
|
| 255 |
+
label = item["label"]
|
| 256 |
+
print(f"[eval] starting {label}", flush=True)
|
| 257 |
+
if item["kind"] == "baseline_sft":
|
| 258 |
+
metrics = eval_baseline_sft(item["checkpoint_dir"], item["stage_i"])
|
| 259 |
+
elif item["kind"] == "baseline_grpo":
|
| 260 |
+
metrics = eval_baseline_grpo(item["checkpoint_dir"], item["stage_i"])
|
| 261 |
+
elif item["kind"] == "latent_sft":
|
| 262 |
+
metrics = eval_latent_sft(item["checkpoint_dir"], item["stage_i"], item["num_cot_tokens"])
|
| 263 |
+
else:
|
| 264 |
+
metrics = eval_latent_grpo(item["checkpoint_dir"], item["stage_i"], item["num_cot_tokens"])
|
| 265 |
+
results[label] = metrics
|
| 266 |
+
print(json.dumps({"label": label, "metrics": metrics}, sort_keys=True), flush=True)
|
| 267 |
+
|
| 268 |
+
print("[eval] complete", flush=True)
|
| 269 |
+
print(json.dumps(results, sort_keys=True, indent=2), flush=True)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
main()
|
checkpoint_utils.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
from typing import Any, Callable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from peft import get_peft_model_state_dict
|
| 9 |
+
from safetensors.torch import save_file as save_safetensors_file
|
| 10 |
+
|
| 11 |
+
FINAL_CHECKPOINT_DIRNAME = "final_checkpoint"
|
| 12 |
+
_WEIGHT_FILENAMES = (
|
| 13 |
+
"adapter_model.safetensors",
|
| 14 |
+
"adapter_model.bin",
|
| 15 |
+
"model.safetensors",
|
| 16 |
+
"pytorch_model.bin",
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def ensure_final_checkpoint_dir(output_dir: str) -> str:
|
| 21 |
+
repo_root = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
output_dir_abs = os.path.abspath(output_dir)
|
| 23 |
+
try:
|
| 24 |
+
rel_output_dir = os.path.relpath(output_dir_abs, repo_root)
|
| 25 |
+
except Exception:
|
| 26 |
+
rel_output_dir = os.path.basename(output_dir_abs.rstrip(os.sep))
|
| 27 |
+
rel_parts = [part for part in rel_output_dir.split(os.sep) if part not in ("", ".")]
|
| 28 |
+
if rel_parts and rel_parts[0] == FINAL_CHECKPOINT_DIRNAME:
|
| 29 |
+
rel_parts = rel_parts[1:]
|
| 30 |
+
if rel_parts and rel_parts[0] == "checkpoints":
|
| 31 |
+
rel_parts = rel_parts[1:]
|
| 32 |
+
if not rel_parts:
|
| 33 |
+
rel_parts = [os.path.basename(output_dir_abs.rstrip(os.sep)) or "run"]
|
| 34 |
+
final_dir = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *rel_parts)
|
| 35 |
+
os.makedirs(final_dir, exist_ok=True)
|
| 36 |
+
return final_dir
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def final_checkpoint_root(*parts: str) -> str:
|
| 40 |
+
repo_root = os.path.dirname(os.path.abspath(__file__))
|
| 41 |
+
root = os.path.join(repo_root, FINAL_CHECKPOINT_DIRNAME, *parts)
|
| 42 |
+
os.makedirs(root, exist_ok=True)
|
| 43 |
+
return root
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def normalize_to_final_checkpoint_root(path: str, *default_parts: str) -> str:
|
| 47 |
+
raw = str(path or "").strip()
|
| 48 |
+
if not raw:
|
| 49 |
+
return final_checkpoint_root(*default_parts)
|
| 50 |
+
abs_path = os.path.abspath(raw)
|
| 51 |
+
repo_root = os.path.dirname(os.path.abspath(__file__))
|
| 52 |
+
rel_path = os.path.relpath(abs_path, repo_root)
|
| 53 |
+
rel_parts = [part for part in rel_path.split(os.sep) if part not in ("", ".")]
|
| 54 |
+
if rel_parts[:1] == [FINAL_CHECKPOINT_DIRNAME]:
|
| 55 |
+
return abs_path
|
| 56 |
+
if rel_parts[:1] == ["checkpoints"]:
|
| 57 |
+
rel_parts = rel_parts[1:]
|
| 58 |
+
return final_checkpoint_root(*rel_parts)
|
| 59 |
+
return abs_path
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _has_saved_weights(target_dir: str) -> bool:
|
| 63 |
+
return any(os.path.exists(os.path.join(target_dir, name)) for name in _WEIGHT_FILENAMES)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _fallback_save_adapter_weights(model: Any, target_dir: str) -> None:
|
| 67 |
+
if _has_saved_weights(target_dir):
|
| 68 |
+
return
|
| 69 |
+
state = get_peft_model_state_dict(model)
|
| 70 |
+
cpu_state = {
|
| 71 |
+
key: value.detach().cpu().contiguous()
|
| 72 |
+
for key, value in state.items()
|
| 73 |
+
if torch.is_tensor(value)
|
| 74 |
+
}
|
| 75 |
+
if cpu_state:
|
| 76 |
+
save_safetensors_file(cpu_state, os.path.join(target_dir, "adapter_model.safetensors"))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def save_model_artifacts(
|
| 80 |
+
model: Any,
|
| 81 |
+
tokenizer: Any,
|
| 82 |
+
target_dir: str,
|
| 83 |
+
*,
|
| 84 |
+
extra_save_fn: Callable[[Any, str], None] | None = None,
|
| 85 |
+
) -> str:
|
| 86 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 87 |
+
model.save_pretrained(target_dir)
|
| 88 |
+
if tokenizer is not None:
|
| 89 |
+
tokenizer.save_pretrained(target_dir)
|
| 90 |
+
_fallback_save_adapter_weights(model, target_dir)
|
| 91 |
+
if extra_save_fn is not None:
|
| 92 |
+
extra_save_fn(model, target_dir)
|
| 93 |
+
return target_dir
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _replace_dir_contents(src_dir: str, dst_dir: str) -> None:
|
| 97 |
+
os.makedirs(dst_dir, exist_ok=True)
|
| 98 |
+
src_dir_abs = os.path.abspath(src_dir)
|
| 99 |
+
for name in os.listdir(dst_dir):
|
| 100 |
+
path = os.path.join(dst_dir, name)
|
| 101 |
+
if os.path.abspath(path) == src_dir_abs:
|
| 102 |
+
continue
|
| 103 |
+
if os.path.isdir(path) and not os.path.islink(path):
|
| 104 |
+
shutil.rmtree(path)
|
| 105 |
+
else:
|
| 106 |
+
os.unlink(path)
|
| 107 |
+
for name in os.listdir(src_dir):
|
| 108 |
+
src_path = os.path.join(src_dir, name)
|
| 109 |
+
dst_path = os.path.join(dst_dir, name)
|
| 110 |
+
if os.path.isdir(src_path) and not os.path.islink(src_path):
|
| 111 |
+
shutil.copytree(src_path, dst_path)
|
| 112 |
+
else:
|
| 113 |
+
shutil.copy2(src_path, dst_path)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def save_checkpoint_and_update_final(
|
| 117 |
+
model: Any,
|
| 118 |
+
tokenizer: Any,
|
| 119 |
+
output_dir: str,
|
| 120 |
+
checkpoint_name: str,
|
| 121 |
+
*,
|
| 122 |
+
extra_save_fn: Callable[[Any, str], None] | None = None,
|
| 123 |
+
) -> str:
|
| 124 |
+
checkpoint_dir = os.path.join(output_dir, checkpoint_name)
|
| 125 |
+
save_model_artifacts(model, tokenizer, checkpoint_dir, extra_save_fn=extra_save_fn)
|
| 126 |
+
_replace_dir_contents(checkpoint_dir, ensure_final_checkpoint_dir(output_dir))
|
| 127 |
+
return checkpoint_dir
|
format_utils_icon.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_INT_RE = re.compile(r"-?\d+")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def grid_to_text(grid_9x9: np.ndarray) -> str:
|
| 14 |
+
grid = np.asarray(grid_9x9, dtype=int).reshape(9, 9)
|
| 15 |
+
return "\n".join(" ".join(str(int(value)) for value in row) for row in grid.tolist())
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def parse_n_value_prediction(text: str, n: int) -> Tuple[List[int] | None, bool]:
|
| 19 |
+
raw = str(text or "").strip()
|
| 20 |
+
if not raw:
|
| 21 |
+
return None, False
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
parsed = json.loads(raw)
|
| 25 |
+
if isinstance(parsed, dict) and isinstance(parsed.get("values"), list):
|
| 26 |
+
values = [int(v) for v in parsed["values"]]
|
| 27 |
+
if len(values) == int(n):
|
| 28 |
+
return values, True
|
| 29 |
+
if isinstance(parsed, list):
|
| 30 |
+
values = [int(v) for v in parsed]
|
| 31 |
+
if len(values) == int(n):
|
| 32 |
+
return values, True
|
| 33 |
+
except Exception:
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
values = [int(match.group(0)) for match in _INT_RE.finditer(raw)]
|
| 37 |
+
if len(values) == int(n):
|
| 38 |
+
return values, True
|
| 39 |
+
return None, False
|
formatting_icon.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def is_consistent_pair(grid, *, cell: int, value: int, t: int = 3, n: int = 9) -> bool:
|
| 7 |
+
g = np.asarray(grid, dtype=int).reshape(int(n), int(n))
|
| 8 |
+
cell = int(cell)
|
| 9 |
+
value = int(value)
|
| 10 |
+
if value < 1 or value > int(n):
|
| 11 |
+
return False
|
| 12 |
+
rr, cc = divmod(cell, int(n))
|
| 13 |
+
current = int(g[rr, cc])
|
| 14 |
+
if current != 0 and current != value:
|
| 15 |
+
return False
|
| 16 |
+
|
| 17 |
+
row = g[rr, :]
|
| 18 |
+
for idx, existing in enumerate(row):
|
| 19 |
+
if idx != cc and int(existing) == value:
|
| 20 |
+
return False
|
| 21 |
+
|
| 22 |
+
col = g[:, cc]
|
| 23 |
+
for idx, existing in enumerate(col):
|
| 24 |
+
if idx != rr and int(existing) == value:
|
| 25 |
+
return False
|
| 26 |
+
|
| 27 |
+
box_r = (rr // int(t)) * int(t)
|
| 28 |
+
box_c = (cc // int(t)) * int(t)
|
| 29 |
+
for r in range(box_r, box_r + int(t)):
|
| 30 |
+
for c in range(box_c, box_c + int(t)):
|
| 31 |
+
if (r != rr or c != cc) and int(g[r, c]) == value:
|
| 32 |
+
return False
|
| 33 |
+
|
| 34 |
+
return True
|
hard_9x9_10empty/launch_baseline_stage3_pipeline.sh
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 8 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 9 |
+
PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py"
|
| 10 |
+
|
| 11 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_10empty_value_qwen_text_longrun.jsonl}"
|
| 12 |
+
NUM_PUZZLES="${NUM_PUZZLES:-5000}"
|
| 13 |
+
DATASET_SEED="${DATASET_SEED:-0}"
|
| 14 |
+
|
| 15 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 16 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 17 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 18 |
+
|
| 19 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_qwen05b/baseline}"
|
| 20 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_10empty_3stage_hard9x9}"
|
| 21 |
+
|
| 22 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 23 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 24 |
+
WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}"
|
| 25 |
+
WAIT_SECONDS="${WAIT_SECONDS:-60}"
|
| 26 |
+
|
| 27 |
+
if [[ ! -f "${TRAIN_JSONL}" ]]; then
|
| 28 |
+
mkdir -p "$(dirname "${TRAIN_JSONL}")"
|
| 29 |
+
printf 'Building 10-empty dataset: %s\n' "${TRAIN_JSONL}"
|
| 30 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 31 |
+
--output "${TRAIN_JSONL}" \
|
| 32 |
+
--num_puzzles "${NUM_PUZZLES}" \
|
| 33 |
+
--empties 10 \
|
| 34 |
+
--seed "${DATASET_SEED}"
|
| 35 |
+
fi
|
| 36 |
+
|
| 37 |
+
if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then
|
| 38 |
+
while pgrep -f "/home/ubuntu/curriculum_cot/.venv/bin/python.*(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)" >/dev/null; do
|
| 39 |
+
printf 'Existing training detected; waiting %ss before launching 10-empty baseline pipeline...\n' "${WAIT_SECONDS}"
|
| 40 |
+
sleep "${WAIT_SECONDS}"
|
| 41 |
+
done
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
mkdir -p "${CHECKPOINT_ROOT}"
|
| 45 |
+
|
| 46 |
+
cmd=(
|
| 47 |
+
"${PYTHON_BIN}" "${PIPELINE}"
|
| 48 |
+
--python_executable "${PYTHON_BIN}"
|
| 49 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 50 |
+
--cache_dir "${ROOT}/.hf_cache"
|
| 51 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct"
|
| 52 |
+
--checkpoint_root "${CHECKPOINT_ROOT}"
|
| 53 |
+
--output_root "${OUTPUT_ROOT}"
|
| 54 |
+
--run_tag "${RUN_TAG}"
|
| 55 |
+
--min_stage 1
|
| 56 |
+
--max_stage 3
|
| 57 |
+
--distributed_gpu_ids "${GPU_IDS}"
|
| 58 |
+
--sft_num_processes "${NUM_PROCESSES}"
|
| 59 |
+
--grpo_num_processes "${NUM_PROCESSES}"
|
| 60 |
+
--total_empties_hint 10
|
| 61 |
+
--limit_train_rows 5000
|
| 62 |
+
--sft_num_epochs 3.0
|
| 63 |
+
--grpo_num_train_epochs 1.5
|
| 64 |
+
--sft_gradient_accumulation_steps 8
|
| 65 |
+
--grpo_per_device_train_batch_size 8
|
| 66 |
+
--grpo_gradient_accumulation_steps 2
|
| 67 |
+
--grpo_num_generations 4
|
| 68 |
+
--sft_enable_gradient_checkpointing
|
| 69 |
+
--grpo_enable_gradient_checkpointing
|
| 70 |
+
--sft_eval_steps 100
|
| 71 |
+
--sft_save_steps 100
|
| 72 |
+
--grpo_eval_steps 50
|
| 73 |
+
--grpo_save_steps 50
|
| 74 |
+
--sft_eval_rows 100
|
| 75 |
+
--grpo_eval_rows 100
|
| 76 |
+
--sft_stage_max_steps "1:2000,2:2000,3:2000"
|
| 77 |
+
--grpo_stage_max_steps "1:1200,2:1200,3:1200"
|
| 78 |
+
--sft_eval_solve_rate_stop 0.8
|
| 79 |
+
--sft_min_steps_before_stop 100
|
| 80 |
+
--grpo_eval_solve_rate_stop 0.8
|
| 81 |
+
--grpo_min_steps_before_stop 50
|
| 82 |
+
--grpo_reward_good_value 1.25
|
| 83 |
+
--grpo_penalty_bad_value 1.0
|
| 84 |
+
--grpo_penalty_malformed 4.0
|
| 85 |
+
--grpo_penalty_empty 0.5
|
| 86 |
+
--grpo_penalty_singleton 1.0
|
| 87 |
+
--phase_max_wall_clock_seconds 36000
|
| 88 |
+
--wandb_mode "${WANDB_MODE}"
|
| 89 |
+
--use_wandb
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
if [[ -n "${WANDB_ENTITY}" ]]; then
|
| 93 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 94 |
+
fi
|
| 95 |
+
|
| 96 |
+
printf 'Launching 10-empty baseline stage-3 pipeline\n'
|
| 97 |
+
printf 'Dataset: %s\n' "${TRAIN_JSONL}"
|
| 98 |
+
printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}"
|
| 99 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 100 |
+
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"
|
| 101 |
+
|
| 102 |
+
exec "${cmd[@]}"
|
hard_9x9_15empty/launch_baseline_pipeline.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 8 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 9 |
+
PIPELINE_LAUNCHER="${ROOT}/large_baseline_extension/launch_nonlocation_pipeline.sh"
|
| 10 |
+
|
| 11 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text.jsonl}"
|
| 12 |
+
NUM_PUZZLES="${NUM_PUZZLES:-20000}"
|
| 13 |
+
DATASET_SEED="${DATASET_SEED:-0}"
|
| 14 |
+
|
| 15 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 16 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 17 |
+
MIN_STAGE="${MIN_STAGE:-1}"
|
| 18 |
+
MAX_STAGE="${MAX_STAGE:-4}"
|
| 19 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 20 |
+
|
| 21 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline}"
|
| 22 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_15empty_4stage_hard9x9}"
|
| 23 |
+
|
| 24 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 25 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 26 |
+
WAIT_FOR_EXISTING_TRAINING="${WAIT_FOR_EXISTING_TRAINING:-1}"
|
| 27 |
+
WAIT_SECONDS="${WAIT_SECONDS:-60}"
|
| 28 |
+
|
| 29 |
+
if [[ ! -f "${TRAIN_JSONL}" ]]; then
|
| 30 |
+
mkdir -p "$(dirname "${TRAIN_JSONL}")"
|
| 31 |
+
printf 'Building 15-empty dataset: %s\n' "${TRAIN_JSONL}"
|
| 32 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 33 |
+
--output "${TRAIN_JSONL}" \
|
| 34 |
+
--num_puzzles "${NUM_PUZZLES}" \
|
| 35 |
+
--empties 15 \
|
| 36 |
+
--seed "${DATASET_SEED}"
|
| 37 |
+
fi
|
| 38 |
+
|
| 39 |
+
if [[ "${WAIT_FOR_EXISTING_TRAINING}" == "1" ]]; then
|
| 40 |
+
while pgrep -f "/home/ubuntu/curriculum_cot/.venv/bin/python.*(run_baseline_multi_output_pipeline_resume.py|run_latent_residual_projector_pipeline.py|sft_multi_output_train.py|grpo_multi_output_train.py|residual_projector_warmstart_sft_latent_multi_output_train.py|grpo_residual_projector_latent_train.py)" >/dev/null; do
|
| 41 |
+
printf 'Existing training detected; waiting %ss before launching 15-empty baseline...\n' "${WAIT_SECONDS}"
|
| 42 |
+
sleep "${WAIT_SECONDS}"
|
| 43 |
+
done
|
| 44 |
+
fi
|
| 45 |
+
|
| 46 |
+
mkdir -p "${CHECKPOINT_ROOT}"
|
| 47 |
+
|
| 48 |
+
export TRAIN_JSONL
|
| 49 |
+
export TOTAL_EMPTIES_HINT=15
|
| 50 |
+
export GPU_IDS
|
| 51 |
+
export NUM_PROCESSES
|
| 52 |
+
export MIN_STAGE
|
| 53 |
+
export MAX_STAGE
|
| 54 |
+
export RUN_TAG
|
| 55 |
+
export CHECKPOINT_ROOT
|
| 56 |
+
export OUTPUT_ROOT
|
| 57 |
+
export WANDB_MODE
|
| 58 |
+
export WANDB_ENTITY
|
| 59 |
+
|
| 60 |
+
printf 'Launching 15-empty hard 9x9 baseline pipeline\n'
|
| 61 |
+
printf 'Dataset: %s\n' "${TRAIN_JSONL}"
|
| 62 |
+
printf 'Checkpoint root: %s\n' "${CHECKPOINT_ROOT}"
|
| 63 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 64 |
+
|
| 65 |
+
exec "${PIPELINE_LAUNCHER}"
|
hard_9x9_15empty_multivalue_stage1/launch_stage1_size2_sft.sh
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 8 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 9 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 10 |
+
|
| 11 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_train.jsonl}"
|
| 12 |
+
EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_15empty_value_qwen_text_stage1_eval.jsonl}"
|
| 13 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 14 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-2000}"
|
| 15 |
+
TRAIN_SEED="${TRAIN_SEED:-0}"
|
| 16 |
+
EVAL_SEED="${EVAL_SEED:-1}"
|
| 17 |
+
|
| 18 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 19 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 20 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 21 |
+
|
| 22 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_15empty_qwen05b/baseline_stage1_multivalue}"
|
| 23 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_15empty_size2only}"
|
| 24 |
+
|
| 25 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 26 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 27 |
+
|
| 28 |
+
if [[ ! -f "${TRAIN_JSONL}" ]]; then
|
| 29 |
+
mkdir -p "$(dirname "${TRAIN_JSONL}")"
|
| 30 |
+
printf 'Building 15-empty train dataset: %s\n' "${TRAIN_JSONL}"
|
| 31 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 32 |
+
--output "${TRAIN_JSONL}" \
|
| 33 |
+
--num_puzzles "${TRAIN_PUZZLES}" \
|
| 34 |
+
--empties 15 \
|
| 35 |
+
--seed "${TRAIN_SEED}"
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
if [[ ! -f "${EVAL_JSONL}" ]]; then
|
| 39 |
+
mkdir -p "$(dirname "${EVAL_JSONL}")"
|
| 40 |
+
printf 'Building 15-empty eval dataset: %s\n' "${EVAL_JSONL}"
|
| 41 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 42 |
+
--output "${EVAL_JSONL}" \
|
| 43 |
+
--num_puzzles "${EVAL_PUZZLES}" \
|
| 44 |
+
--empties 15 \
|
| 45 |
+
--seed "${EVAL_SEED}"
|
| 46 |
+
fi
|
| 47 |
+
|
| 48 |
+
mkdir -p "${CHECKPOINT_ROOT}"
|
| 49 |
+
|
| 50 |
+
cmd=(
|
| 51 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}"
|
| 52 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct"
|
| 53 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 54 |
+
--eval_jsonl "${EVAL_JSONL}"
|
| 55 |
+
--output_dir "${OUTPUT_DIR}"
|
| 56 |
+
--cache_dir "${ROOT}/.hf_cache"
|
| 57 |
+
--seed 0
|
| 58 |
+
--gpu_id 0
|
| 59 |
+
--stage_i 1
|
| 60 |
+
--total_empties_hint 15
|
| 61 |
+
--per_device_train_batch_size 16
|
| 62 |
+
--gradient_accumulation_steps 2
|
| 63 |
+
--num_epochs 4.0
|
| 64 |
+
--learning_rate 2e-4
|
| 65 |
+
--enable_gradient_checkpointing
|
| 66 |
+
--logging_steps 10
|
| 67 |
+
--eval_steps 50
|
| 68 |
+
--save_steps 50
|
| 69 |
+
--eval_rows "${EVAL_PUZZLES}"
|
| 70 |
+
--max_completion_length 24
|
| 71 |
+
--limit_train_rows "${TRAIN_PUZZLES}"
|
| 72 |
+
--lora_r 32
|
| 73 |
+
--lora_alpha 64
|
| 74 |
+
--lora_dropout 0.05
|
| 75 |
+
--multi_value_oversample_factor 1
|
| 76 |
+
--train_target_size_min 2
|
| 77 |
+
--train_target_size_max 2
|
| 78 |
+
--eval_target_size_min 2
|
| 79 |
+
--eval_target_size_max 2
|
| 80 |
+
--eval_value_precision_stop 0.95
|
| 81 |
+
--eval_value_recall_stop 0.95
|
| 82 |
+
--min_steps_before_stop 100
|
| 83 |
+
--max_wall_clock_seconds 7200
|
| 84 |
+
--max_steps 600
|
| 85 |
+
--use_wandb
|
| 86 |
+
--wandb_project "sudoku-multi-output-sft"
|
| 87 |
+
--wandb_run_name "baseline_stage01_sft_i1_15empty_size2only"
|
| 88 |
+
--wandb_mode "${WANDB_MODE}"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if [[ -n "${WANDB_ENTITY}" ]]; then
|
| 92 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 93 |
+
fi
|
| 94 |
+
|
| 95 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 96 |
+
|
| 97 |
+
printf 'Launching 15-empty stage-1 size-2-only SFT baseline\n'
|
| 98 |
+
printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}"
|
| 99 |
+
printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}"
|
| 100 |
+
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
|
| 101 |
+
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"
|
| 102 |
+
|
| 103 |
+
exec "${cmd[@]}"
|
hard_9x9_7empty/launch_stage1_sft.sh
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 8 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 9 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 10 |
+
|
| 11 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_train.jsonl}"
|
| 12 |
+
EVAL_JSONL="${EVAL_JSONL:-${ROOT}/data/sudoku_t3_7empty_value_qwen_text_eval.jsonl}"
|
| 13 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-400}"
|
| 14 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-200}"
|
| 15 |
+
TRAIN_SEED="${TRAIN_SEED:-0}"
|
| 16 |
+
EVAL_SEED="${EVAL_SEED:-1}"
|
| 17 |
+
|
| 18 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 19 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 20 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 21 |
+
|
| 22 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_qwen05b/baseline_stage1}"
|
| 23 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/stage01_sft_i1_7empty}"
|
| 24 |
+
|
| 25 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 26 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 27 |
+
|
| 28 |
+
if [[ ! -f "${TRAIN_JSONL}" ]]; then
|
| 29 |
+
mkdir -p "$(dirname "${TRAIN_JSONL}")"
|
| 30 |
+
printf 'Building 7-empty train dataset: %s\n' "${TRAIN_JSONL}"
|
| 31 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 32 |
+
--output "${TRAIN_JSONL}" \
|
| 33 |
+
--num_puzzles "${TRAIN_PUZZLES}" \
|
| 34 |
+
--empties 7 \
|
| 35 |
+
--seed "${TRAIN_SEED}"
|
| 36 |
+
fi
|
| 37 |
+
|
| 38 |
+
if [[ ! -f "${EVAL_JSONL}" ]]; then
|
| 39 |
+
mkdir -p "$(dirname "${EVAL_JSONL}")"
|
| 40 |
+
printf 'Building 7-empty eval dataset: %s\n' "${EVAL_JSONL}"
|
| 41 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" \
|
| 42 |
+
--output "${EVAL_JSONL}" \
|
| 43 |
+
--num_puzzles "${EVAL_PUZZLES}" \
|
| 44 |
+
--empties 7 \
|
| 45 |
+
--seed "${EVAL_SEED}"
|
| 46 |
+
fi
|
| 47 |
+
|
| 48 |
+
mkdir -p "${CHECKPOINT_ROOT}"
|
| 49 |
+
|
| 50 |
+
cmd=(
|
| 51 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}"
|
| 52 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct"
|
| 53 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 54 |
+
--eval_jsonl "${EVAL_JSONL}"
|
| 55 |
+
--output_dir "${OUTPUT_DIR}"
|
| 56 |
+
--cache_dir "${ROOT}/.hf_cache"
|
| 57 |
+
--seed 0
|
| 58 |
+
--gpu_id 0
|
| 59 |
+
--stage_i 1
|
| 60 |
+
--total_empties_hint 7
|
| 61 |
+
--num_epochs 3.0
|
| 62 |
+
--learning_rate 2e-4
|
| 63 |
+
--gradient_accumulation_steps 8
|
| 64 |
+
--enable_gradient_checkpointing
|
| 65 |
+
--logging_steps 10
|
| 66 |
+
--eval_steps 25
|
| 67 |
+
--save_steps 25
|
| 68 |
+
--eval_rows "${EVAL_PUZZLES}"
|
| 69 |
+
--max_completion_length 24
|
| 70 |
+
--limit_train_rows "${TRAIN_PUZZLES}"
|
| 71 |
+
--lora_r 32
|
| 72 |
+
--lora_alpha 64
|
| 73 |
+
--lora_dropout 0.05
|
| 74 |
+
--multi_value_oversample_factor 16
|
| 75 |
+
--eval_exact_set_match_stop 0.999
|
| 76 |
+
--eval_value_precision_stop 0.999
|
| 77 |
+
--eval_value_recall_stop 0.999
|
| 78 |
+
--min_steps_before_stop 50
|
| 79 |
+
--max_wall_clock_seconds 1800
|
| 80 |
+
--max_steps 250
|
| 81 |
+
--use_wandb
|
| 82 |
+
--wandb_project "sudoku-multi-output-sft"
|
| 83 |
+
--wandb_run_name "baseline_stage01_sft_i1_7empty"
|
| 84 |
+
--wandb_mode "${WANDB_MODE}"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
if [[ -n "${WANDB_ENTITY}" ]]; then
|
| 88 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 89 |
+
fi
|
| 90 |
+
|
| 91 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 92 |
+
|
| 93 |
+
printf 'Launching 7-empty stage-1 SFT baseline\n'
|
| 94 |
+
printf 'Train dataset: %s (%s puzzles)\n' "${TRAIN_JSONL}" "${TRAIN_PUZZLES}"
|
| 95 |
+
printf 'Eval dataset: %s (%s puzzles)\n' "${EVAL_JSONL}" "${EVAL_PUZZLES}"
|
| 96 |
+
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
|
| 97 |
+
printf 'GPUs: %s processes=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}"
|
| 98 |
+
|
| 99 |
+
exec "${cmd[@]}"
|
hard_9x9_curriculum/build_stage3_hard_dataset.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
from dataclasses import asdict, dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, Iterable, List, Sequence, Tuple
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
CURRENT_DIR = Path(__file__).resolve().parent
|
| 14 |
+
PARENT_DIR = CURRENT_DIR.parent
|
| 15 |
+
if str(PARENT_DIR) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(PARENT_DIR))
|
| 17 |
+
|
| 18 |
+
from formatting_icon import is_consistent_pair
|
| 19 |
+
from multi_output_cell_policy.shared_multi_output_policy import stage_i_consistent_values
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
GRID_SIZE = 9
|
| 23 |
+
BOX_SIZE = 3
|
| 24 |
+
ALL_VALUES = tuple(range(1, 10))
|
| 25 |
+
DEFAULT_BASE_NAME = "sudoku_t3_30empty_stage3hard"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass(frozen=True)
|
| 29 |
+
class DifficultyProfile:
|
| 30 |
+
stage1_solved: bool
|
| 31 |
+
stage2_solved: bool
|
| 32 |
+
stage3_solved: bool
|
| 33 |
+
stage1_steps: int
|
| 34 |
+
stage2_steps: int
|
| 35 |
+
stage3_steps: int
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass(frozen=True)
|
| 39 |
+
class SeedMask:
|
| 40 |
+
mask_cells: tuple[int, ...]
|
| 41 |
+
profile: DifficultyProfile
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def parse_args() -> argparse.Namespace:
|
| 45 |
+
root = PARENT_DIR
|
| 46 |
+
default_train = root / "data" / f"{DEFAULT_BASE_NAME}_value_qwen_text.jsonl"
|
| 47 |
+
default_eval = root / "data" / f"{DEFAULT_BASE_NAME}_eval_value_qwen_text.jsonl"
|
| 48 |
+
default_manifest = root / "data" / f"{DEFAULT_BASE_NAME}_manifest.json"
|
| 49 |
+
p = argparse.ArgumentParser()
|
| 50 |
+
p.add_argument("--train_output", type=str, default=str(default_train))
|
| 51 |
+
p.add_argument("--eval_output", type=str, default=str(default_eval))
|
| 52 |
+
p.add_argument("--manifest_output", type=str, default=str(default_manifest))
|
| 53 |
+
p.add_argument("--num_train_puzzles", type=int, default=4000)
|
| 54 |
+
p.add_argument("--num_eval_puzzles", type=int, default=200)
|
| 55 |
+
p.add_argument("--empties", type=int, default=30)
|
| 56 |
+
p.add_argument("--seed", type=int, default=0)
|
| 57 |
+
p.add_argument("--max_attempts", type=int, default=200000)
|
| 58 |
+
p.add_argument("--progress_every", type=int, default=250)
|
| 59 |
+
p.add_argument("--num_seed_masks", type=int, default=8)
|
| 60 |
+
return p.parse_args()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def permute_groups(rng: random.Random, values: Sequence[int], group_size: int) -> List[int]:
|
| 64 |
+
groups = [list(values[idx : idx + group_size]) for idx in range(0, len(values), group_size)]
|
| 65 |
+
rng.shuffle(groups)
|
| 66 |
+
out: List[int] = []
|
| 67 |
+
for group in groups:
|
| 68 |
+
rng.shuffle(group)
|
| 69 |
+
out.extend(group)
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def base_solved_grid() -> np.ndarray:
|
| 74 |
+
return np.asarray(
|
| 75 |
+
[[((rr * BOX_SIZE + rr // BOX_SIZE + cc) % GRID_SIZE) + 1 for cc in range(GRID_SIZE)] for rr in range(GRID_SIZE)],
|
| 76 |
+
dtype=int,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def row_major_empty_locs(grid: np.ndarray) -> List[Tuple[int, int]]:
|
| 81 |
+
return [(int(r), int(c)) for r, c in np.argwhere(np.asarray(grid, dtype=int) == 0).tolist()]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def make_prompt(grid: np.ndarray) -> str:
|
| 85 |
+
tuples = [f"({r + 1},{c + 1},{int(grid[r, c])})" for r in range(GRID_SIZE) for c in range(GRID_SIZE)]
|
| 86 |
+
return (
|
| 87 |
+
"9x9 Sudoku board encoded as (row,col,value) tuples in row-major order.\n"
|
| 88 |
+
"Value 0 means the cell is empty.\n"
|
| 89 |
+
+ " ".join(tuples)
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def legal_values(grid: np.ndarray, row: int, col: int) -> List[int]:
|
| 94 |
+
cell = int(row) * GRID_SIZE + int(col)
|
| 95 |
+
return [int(value) for value in ALL_VALUES if is_consistent_pair(grid, cell=cell, value=int(value), t=3, n=9)]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def count_solutions(grid: np.ndarray, *, limit: int = 2) -> int:
|
| 99 |
+
board = np.asarray(grid, dtype=int).copy()
|
| 100 |
+
solutions = 0
|
| 101 |
+
|
| 102 |
+
def backtrack() -> None:
|
| 103 |
+
nonlocal solutions
|
| 104 |
+
if solutions >= int(limit):
|
| 105 |
+
return
|
| 106 |
+
best_cell: Tuple[int, int] | None = None
|
| 107 |
+
best_values: List[int] | None = None
|
| 108 |
+
for rr, cc in row_major_empty_locs(board):
|
| 109 |
+
values = legal_values(board, rr, cc)
|
| 110 |
+
if not values:
|
| 111 |
+
return
|
| 112 |
+
if best_values is None or len(values) < len(best_values):
|
| 113 |
+
best_cell = (rr, cc)
|
| 114 |
+
best_values = values
|
| 115 |
+
if len(best_values) == 1:
|
| 116 |
+
break
|
| 117 |
+
if best_cell is None:
|
| 118 |
+
solutions += 1
|
| 119 |
+
return
|
| 120 |
+
rr, cc = best_cell
|
| 121 |
+
for value in best_values or []:
|
| 122 |
+
board[rr, cc] = int(value)
|
| 123 |
+
backtrack()
|
| 124 |
+
board[rr, cc] = 0
|
| 125 |
+
if solutions >= int(limit):
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
backtrack()
|
| 129 |
+
return int(solutions)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def propagate_stage(grid: np.ndarray, *, stage_i: int) -> Tuple[np.ndarray | None, int]:
|
| 133 |
+
board = np.asarray(grid, dtype=int).copy()
|
| 134 |
+
num_assignments = 0
|
| 135 |
+
while True:
|
| 136 |
+
chosen: Tuple[int, int, int] | None = None
|
| 137 |
+
for rr, cc in row_major_empty_locs(board):
|
| 138 |
+
values = stage_i_consistent_values(board, target_cell=(rr, cc), stage_i=int(stage_i))
|
| 139 |
+
if not values:
|
| 140 |
+
return None, num_assignments
|
| 141 |
+
if len(values) == 1:
|
| 142 |
+
chosen = (rr, cc, int(values[0]))
|
| 143 |
+
break
|
| 144 |
+
if chosen is None:
|
| 145 |
+
return board, num_assignments
|
| 146 |
+
rr, cc, value = chosen
|
| 147 |
+
board[rr, cc] = int(value)
|
| 148 |
+
num_assignments += 1
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def build_difficulty_profile(puzzle: np.ndarray, solved: np.ndarray) -> DifficultyProfile | None:
|
| 152 |
+
stage1_board, stage1_steps = propagate_stage(puzzle, stage_i=1)
|
| 153 |
+
if stage1_board is None:
|
| 154 |
+
return None
|
| 155 |
+
stage2_board, stage2_steps = propagate_stage(puzzle, stage_i=2)
|
| 156 |
+
if stage2_board is None:
|
| 157 |
+
return None
|
| 158 |
+
stage3_board, stage3_steps = propagate_stage(puzzle, stage_i=3)
|
| 159 |
+
if stage3_board is None:
|
| 160 |
+
return None
|
| 161 |
+
return DifficultyProfile(
|
| 162 |
+
stage1_solved=bool(np.array_equal(stage1_board, solved)),
|
| 163 |
+
stage2_solved=bool(np.array_equal(stage2_board, solved)),
|
| 164 |
+
stage3_solved=bool(np.array_equal(stage3_board, solved)),
|
| 165 |
+
stage1_steps=int(stage1_steps),
|
| 166 |
+
stage2_steps=int(stage2_steps),
|
| 167 |
+
stage3_steps=int(stage3_steps),
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def qualifies(profile: DifficultyProfile) -> bool:
|
| 172 |
+
return (not profile.stage1_solved) and (not profile.stage2_solved) and profile.stage3_solved
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def build_puzzle_from_mask(solved: np.ndarray, mask_cells: Sequence[int]) -> np.ndarray:
|
| 176 |
+
puzzle = np.asarray(solved, dtype=int).copy()
|
| 177 |
+
for cell in mask_cells:
|
| 178 |
+
rr, cc = divmod(int(cell), GRID_SIZE)
|
| 179 |
+
puzzle[rr, cc] = 0
|
| 180 |
+
return puzzle
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def sample_mask_cells(*, empties: int, rng: random.Random) -> tuple[int, ...]:
|
| 184 |
+
cells = list(range(GRID_SIZE * GRID_SIZE))
|
| 185 |
+
rng.shuffle(cells)
|
| 186 |
+
return tuple(sorted(int(cell) for cell in cells[: int(empties)]))
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def greedy_find_seed_mask(
|
| 190 |
+
*,
|
| 191 |
+
empties: int,
|
| 192 |
+
max_attempts: int,
|
| 193 |
+
rng: random.Random,
|
| 194 |
+
progress_every: int,
|
| 195 |
+
) -> Tuple[SeedMask | None, Dict[str, int]]:
|
| 196 |
+
solved = base_solved_grid()
|
| 197 |
+
attempts = 0
|
| 198 |
+
restarts = 0
|
| 199 |
+
while attempts < int(max_attempts):
|
| 200 |
+
restarts += 1
|
| 201 |
+
mask: List[int] = []
|
| 202 |
+
remaining = list(range(GRID_SIZE * GRID_SIZE))
|
| 203 |
+
rng.shuffle(remaining)
|
| 204 |
+
current_profile: DifficultyProfile | None = None
|
| 205 |
+
|
| 206 |
+
while len(mask) < int(empties) and attempts < int(max_attempts):
|
| 207 |
+
best_cell: int | None = None
|
| 208 |
+
best_profile: DifficultyProfile | None = None
|
| 209 |
+
best_score: Tuple[int, int, int] | None = None
|
| 210 |
+
candidate_cells = list(remaining[: min(len(remaining), 12)])
|
| 211 |
+
if not candidate_cells:
|
| 212 |
+
break
|
| 213 |
+
|
| 214 |
+
for cell in candidate_cells:
|
| 215 |
+
attempts += 1
|
| 216 |
+
trial_mask = tuple(sorted(mask + [int(cell)]))
|
| 217 |
+
puzzle = build_puzzle_from_mask(solved, trial_mask)
|
| 218 |
+
profile = build_difficulty_profile(puzzle, solved)
|
| 219 |
+
if profile is None or not profile.stage3_solved:
|
| 220 |
+
continue
|
| 221 |
+
score = (
|
| 222 |
+
int(not profile.stage2_solved),
|
| 223 |
+
int(not profile.stage1_solved),
|
| 224 |
+
int(profile.stage3_steps - profile.stage2_steps),
|
| 225 |
+
)
|
| 226 |
+
if best_score is None or score > best_score:
|
| 227 |
+
best_cell = int(cell)
|
| 228 |
+
best_profile = profile
|
| 229 |
+
best_score = score
|
| 230 |
+
|
| 231 |
+
if attempts == 1 or attempts % max(1, int(progress_every)) == 0:
|
| 232 |
+
print(
|
| 233 |
+
f"[search hard 9x9 masks] attempts={attempts} restarts={restarts} current_empties={len(mask)}",
|
| 234 |
+
flush=True,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
if best_cell is None or best_profile is None:
|
| 238 |
+
break
|
| 239 |
+
|
| 240 |
+
mask.append(int(best_cell))
|
| 241 |
+
mask.sort()
|
| 242 |
+
remaining.remove(int(best_cell))
|
| 243 |
+
current_profile = best_profile
|
| 244 |
+
|
| 245 |
+
if len(mask) != int(empties) or current_profile is None:
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
final_mask = tuple(sorted(int(cell) for cell in mask))
|
| 249 |
+
final_puzzle = build_puzzle_from_mask(solved, final_mask)
|
| 250 |
+
final_profile = build_difficulty_profile(final_puzzle, solved)
|
| 251 |
+
if final_profile is None or not qualifies(final_profile):
|
| 252 |
+
continue
|
| 253 |
+
if count_solutions(final_puzzle, limit=2) != 1:
|
| 254 |
+
continue
|
| 255 |
+
return SeedMask(mask_cells=final_mask, profile=final_profile), {
|
| 256 |
+
"attempts": int(attempts),
|
| 257 |
+
"restarts": int(restarts),
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
return None, {"attempts": int(attempts), "restarts": int(restarts)}
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def random_symmetry(
|
| 264 |
+
rng: random.Random, *, solved: np.ndarray, mask_cells: Sequence[int]
|
| 265 |
+
) -> Tuple[np.ndarray, tuple[int, ...]]:
|
| 266 |
+
digits = list(ALL_VALUES)
|
| 267 |
+
rng.shuffle(digits)
|
| 268 |
+
digit_map = {src: dst for src, dst in zip(ALL_VALUES, digits, strict=True)}
|
| 269 |
+
transformed = np.vectorize(lambda value: digit_map[int(value)], otypes=[int])(np.asarray(solved, dtype=int).copy())
|
| 270 |
+
|
| 271 |
+
row_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE)
|
| 272 |
+
col_order = permute_groups(rng, list(range(GRID_SIZE)), BOX_SIZE)
|
| 273 |
+
inverse_row = {old: new for new, old in enumerate(row_order)}
|
| 274 |
+
inverse_col = {old: new for new, old in enumerate(col_order)}
|
| 275 |
+
|
| 276 |
+
transformed = transformed[row_order, :]
|
| 277 |
+
transformed = transformed[:, col_order]
|
| 278 |
+
|
| 279 |
+
transformed_cells: List[int] = []
|
| 280 |
+
for cell in mask_cells:
|
| 281 |
+
rr, cc = divmod(int(cell), GRID_SIZE)
|
| 282 |
+
new_r = int(inverse_row[int(rr)])
|
| 283 |
+
new_c = int(inverse_col[int(cc)])
|
| 284 |
+
transformed_cells.append(new_r * GRID_SIZE + new_c)
|
| 285 |
+
|
| 286 |
+
if rng.random() < 0.5:
|
| 287 |
+
transformed = transformed.T
|
| 288 |
+
transformed_cells = [int(cc) * GRID_SIZE + int(rr) for rr, cc in (divmod(cell, GRID_SIZE) for cell in transformed_cells)]
|
| 289 |
+
|
| 290 |
+
return np.asarray(transformed, dtype=int), tuple(sorted(int(cell) for cell in transformed_cells))
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def make_example(solved: np.ndarray, mask_cells: Sequence[int], *, empties: int, profile: DifficultyProfile) -> Dict[str, object]:
|
| 294 |
+
puzzle = build_puzzle_from_mask(solved, mask_cells)
|
| 295 |
+
empty_locs_1based = [(rr + 1, cc + 1) for rr, cc in row_major_empty_locs(puzzle)]
|
| 296 |
+
target_triples_1based = [(rr + 1, cc + 1, int(solved[rr, cc])) for rr, cc in row_major_empty_locs(puzzle)]
|
| 297 |
+
completion_values = [int(value) for _, _, value in target_triples_1based]
|
| 298 |
+
return {
|
| 299 |
+
"prompt": make_prompt(puzzle),
|
| 300 |
+
"completion": json.dumps(completion_values, separators=(",", ":")),
|
| 301 |
+
"metadata": {
|
| 302 |
+
"grid_size": GRID_SIZE,
|
| 303 |
+
"box_size": BOX_SIZE,
|
| 304 |
+
"empties": int(empties),
|
| 305 |
+
"empty_locs_1based": empty_locs_1based,
|
| 306 |
+
"target_triples_1based": target_triples_1based,
|
| 307 |
+
"required_consistency_stage": 3,
|
| 308 |
+
"difficulty_profile": asdict(profile),
|
| 309 |
+
},
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def search_seed_masks(
|
| 314 |
+
*,
|
| 315 |
+
num_seed_masks: int,
|
| 316 |
+
empties: int,
|
| 317 |
+
max_attempts: int,
|
| 318 |
+
seed: int,
|
| 319 |
+
progress_every: int,
|
| 320 |
+
) -> Tuple[List[SeedMask], Dict[str, int]]:
|
| 321 |
+
rng = random.Random(int(seed))
|
| 322 |
+
seeds: List[SeedMask] = []
|
| 323 |
+
seen = set()
|
| 324 |
+
total_attempts = 0
|
| 325 |
+
total_restarts = 0
|
| 326 |
+
|
| 327 |
+
while len(seeds) < int(num_seed_masks) and total_attempts < int(max_attempts):
|
| 328 |
+
mask_seed, stats = greedy_find_seed_mask(
|
| 329 |
+
empties=int(empties),
|
| 330 |
+
max_attempts=max(1, int(max_attempts) - int(total_attempts)),
|
| 331 |
+
rng=rng,
|
| 332 |
+
progress_every=int(progress_every),
|
| 333 |
+
)
|
| 334 |
+
total_attempts += int(stats.get("attempts", 0))
|
| 335 |
+
total_restarts += int(stats.get("restarts", 0))
|
| 336 |
+
if mask_seed is None:
|
| 337 |
+
break
|
| 338 |
+
if mask_seed.mask_cells in seen:
|
| 339 |
+
continue
|
| 340 |
+
seen.add(mask_seed.mask_cells)
|
| 341 |
+
seeds.append(mask_seed)
|
| 342 |
+
print(
|
| 343 |
+
f"[search hard 9x9 masks] attempts={total_attempts} accepted={len(seeds)}/{num_seed_masks}",
|
| 344 |
+
flush=True,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
stats = {
|
| 348 |
+
"attempts": int(total_attempts),
|
| 349 |
+
"restarts": int(total_restarts),
|
| 350 |
+
"accepted_seed_masks": int(len(seeds)),
|
| 351 |
+
}
|
| 352 |
+
return seeds, stats
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def generate_examples(
|
| 356 |
+
*,
|
| 357 |
+
num_examples: int,
|
| 358 |
+
empties: int,
|
| 359 |
+
seed_masks: Sequence[SeedMask],
|
| 360 |
+
seed: int,
|
| 361 |
+
) -> List[Dict[str, object]]:
|
| 362 |
+
if not seed_masks:
|
| 363 |
+
raise ValueError("seed_masks must not be empty")
|
| 364 |
+
rng = random.Random(int(seed) + 1)
|
| 365 |
+
solved = base_solved_grid()
|
| 366 |
+
rows: List[Dict[str, object]] = []
|
| 367 |
+
for idx in range(int(num_examples)):
|
| 368 |
+
seed_mask = seed_masks[idx % len(seed_masks)]
|
| 369 |
+
transformed_solved, transformed_mask = random_symmetry(
|
| 370 |
+
rng, solved=solved, mask_cells=seed_mask.mask_cells
|
| 371 |
+
)
|
| 372 |
+
rows.append(
|
| 373 |
+
make_example(
|
| 374 |
+
transformed_solved,
|
| 375 |
+
transformed_mask,
|
| 376 |
+
empties=int(empties),
|
| 377 |
+
profile=seed_mask.profile,
|
| 378 |
+
)
|
| 379 |
+
)
|
| 380 |
+
return rows
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def write_jsonl(path: Path, rows: Iterable[Dict[str, object]]) -> None:
|
| 384 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 385 |
+
with path.open("w", encoding="utf-8") as f:
|
| 386 |
+
for row in rows:
|
| 387 |
+
f.write(json.dumps(row, separators=(",", ":")) + "\n")
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def main() -> None:
|
| 391 |
+
args = parse_args()
|
| 392 |
+
total_needed = int(args.num_train_puzzles) + int(args.num_eval_puzzles)
|
| 393 |
+
num_seed_masks = min(max(1, int(args.num_seed_masks)), total_needed)
|
| 394 |
+
seed_masks, search_stats = search_seed_masks(
|
| 395 |
+
num_seed_masks=num_seed_masks,
|
| 396 |
+
empties=int(args.empties),
|
| 397 |
+
max_attempts=int(args.max_attempts),
|
| 398 |
+
seed=int(args.seed),
|
| 399 |
+
progress_every=int(args.progress_every),
|
| 400 |
+
)
|
| 401 |
+
if len(seed_masks) < num_seed_masks:
|
| 402 |
+
raise RuntimeError(
|
| 403 |
+
f"Only found {len(seed_masks)} qualifying seed masks out of requested {num_seed_masks}. "
|
| 404 |
+
f"Try increasing --max_attempts or reducing --num_seed_masks."
|
| 405 |
+
)
|
| 406 |
+
rows = generate_examples(
|
| 407 |
+
num_examples=total_needed,
|
| 408 |
+
empties=int(args.empties),
|
| 409 |
+
seed_masks=seed_masks,
|
| 410 |
+
seed=int(args.seed),
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
eval_rows = rows[: int(args.num_eval_puzzles)]
|
| 414 |
+
train_rows = rows[int(args.num_eval_puzzles) :]
|
| 415 |
+
|
| 416 |
+
train_output = Path(args.train_output).resolve()
|
| 417 |
+
eval_output = Path(args.eval_output).resolve()
|
| 418 |
+
manifest_output = Path(args.manifest_output).resolve()
|
| 419 |
+
|
| 420 |
+
write_jsonl(train_output, train_rows)
|
| 421 |
+
write_jsonl(eval_output, eval_rows)
|
| 422 |
+
manifest_output.parent.mkdir(parents=True, exist_ok=True)
|
| 423 |
+
manifest_output.write_text(
|
| 424 |
+
json.dumps(
|
| 425 |
+
{
|
| 426 |
+
"train_output": str(train_output),
|
| 427 |
+
"eval_output": str(eval_output),
|
| 428 |
+
"num_train_puzzles": int(len(train_rows)),
|
| 429 |
+
"num_eval_puzzles": int(len(eval_rows)),
|
| 430 |
+
"empties": int(args.empties),
|
| 431 |
+
"seed": int(args.seed),
|
| 432 |
+
"required_consistency_stage": 3,
|
| 433 |
+
"num_seed_masks": int(num_seed_masks),
|
| 434 |
+
"search_stats": search_stats,
|
| 435 |
+
},
|
| 436 |
+
indent=2,
|
| 437 |
+
sort_keys=True,
|
| 438 |
+
)
|
| 439 |
+
+ "\n",
|
| 440 |
+
encoding="utf-8",
|
| 441 |
+
)
|
| 442 |
+
print(f"Wrote {len(train_rows)} train puzzles to {train_output}")
|
| 443 |
+
print(f"Wrote {len(eval_rows)} eval puzzles to {eval_output}")
|
| 444 |
+
print(f"Wrote manifest to {manifest_output}")
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
if __name__ == "__main__":
|
| 448 |
+
main()
|
hard_9x9_stage1_consistency_queue/README.md
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Stage-1 Latent SFT Mode Sweep
|
| 2 |
+
|
| 3 |
+
This folder contains launchers for the 9x9 Sudoku curriculum experiments. The
|
| 4 |
+
stage-1 latent sweep launcher is:
|
| 5 |
+
|
| 6 |
+
```bash
|
| 7 |
+
bash hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
The goal of this sweep is to compare the four latent implementations under the
|
| 11 |
+
same stage-1 SFT setup and measure which one gives the fastest useful
|
| 12 |
+
convergence. The main comparison should include training loss, held-out value
|
| 13 |
+
precision/recall, completion quality, wall-clock time, and GPU efficiency. In
|
| 14 |
+
particular, compare both loss vs. optimizer step and loss vs. elapsed time,
|
| 15 |
+
because some methods do more transformer forward passes per step.
|
| 16 |
+
|
| 17 |
+
## Four Latent Modes
|
| 18 |
+
|
| 19 |
+
### `residual`
|
| 20 |
+
|
| 21 |
+
The residual mode performs a dynamic latent hidden rollout, then projects the
|
| 22 |
+
difference between the latent hidden state and the base hidden state back into
|
| 23 |
+
the model hidden space. This projected delta is added to the base next-token
|
| 24 |
+
hidden state before computing logits. It is expressive, but it is slower because
|
| 25 |
+
the latent rollout requires repeated transformer passes.
|
| 26 |
+
|
| 27 |
+
### `fixed_slots`
|
| 28 |
+
|
| 29 |
+
The fixed-slots mode learns a bank of trainable latent slot embeddings plus a
|
| 30 |
+
separate final readout slot. For each prediction, the model runs once on:
|
| 31 |
+
|
| 32 |
+
```text
|
| 33 |
+
[prompt tokens, slot_1, ..., slot_k, final_slot]
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
The next token is predicted from the hidden state at `final_slot`. This is a
|
| 37 |
+
parallel latent method: all latent slots are inserted at once, so it avoids the
|
| 38 |
+
recursive pass used by recurrent methods.
|
| 39 |
+
|
| 40 |
+
### `recurrent_hidden`
|
| 41 |
+
|
| 42 |
+
The recurrent-hidden mode generates latent tokens dynamically from the current
|
| 43 |
+
example. It appends a hidden latent token, reruns the transformer, takes the new
|
| 44 |
+
last hidden state as the next latent token, and repeats for `num_cot_tokens`.
|
| 45 |
+
This is the closest to iterative hidden reasoning, but it is usually the
|
| 46 |
+
slowest because the latent steps are serial.
|
| 47 |
+
|
| 48 |
+
### `latent_seeds`
|
| 49 |
+
|
| 50 |
+
The latent-seeds mode learns a bank of trainable seed embeddings. For each
|
| 51 |
+
prediction, the model runs once on:
|
| 52 |
+
|
| 53 |
+
```text
|
| 54 |
+
[prompt tokens, seed_1, ..., seed_k]
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
The next token is predicted from the hidden state at the last seed position.
|
| 58 |
+
Like fixed slots, this is parallel and avoids recursive transformer passes. The
|
| 59 |
+
main difference from `fixed_slots` is that there is no separate final readout
|
| 60 |
+
slot; the last seed position acts as the readout.
|
| 61 |
+
|
| 62 |
+
## Experimental Strategy
|
| 63 |
+
|
| 64 |
+
Run all four modes in parallel on stage 1 with the same dataset, LoRA settings,
|
| 65 |
+
number of latent tokens, stopping rule, and evaluation set. The default launcher
|
| 66 |
+
splits an 8-GPU node into four two-GPU jobs:
|
| 67 |
+
|
| 68 |
+
```text
|
| 69 |
+
residual -> GPUs 0,1
|
| 70 |
+
fixed_slots -> GPUs 2,3
|
| 71 |
+
recurrent_hidden -> GPUs 4,5
|
| 72 |
+
latent_seeds -> GPUs 6,7
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
Use the results to decide which one or two methods should be promoted to deeper
|
| 76 |
+
curriculum stages. The expected practical tradeoff is that `fixed_slots` and
|
| 77 |
+
`latent_seeds` should be much faster per wall-clock time, while `residual` and
|
| 78 |
+
`recurrent_hidden` test more iterative, example-dependent latent computation.
|
| 79 |
+
|
| 80 |
+
## Warm Baseline Stages 1-3 Pipeline
|
| 81 |
+
|
| 82 |
+
The full warm-baseline launcher is:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
STAGE1_BASELINE_ADAPTER_DIR=/path/to/warmed/stage1/baseline/checkpoint \
|
| 86 |
+
bash hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
It runs all four latent modes in parallel, two GPUs per mode:
|
| 90 |
+
|
| 91 |
+
```text
|
| 92 |
+
residual -> GPUs 0,1
|
| 93 |
+
fixed_slots -> GPUs 2,3
|
| 94 |
+
recurrent_hidden -> GPUs 4,5
|
| 95 |
+
latent_seeds -> GPUs 6,7
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
For each mode, the intended sequence is:
|
| 99 |
+
|
| 100 |
+
```text
|
| 101 |
+
stage1 latent SFT
|
| 102 |
+
-> stage1 latent GRPO
|
| 103 |
+
-> stage2 baseline warm-up SFT
|
| 104 |
+
-> stage2 latent SFT
|
| 105 |
+
-> stage2 latent GRPO
|
| 106 |
+
-> stage3 baseline warm-up SFT
|
| 107 |
+
-> stage3 latent SFT
|
| 108 |
+
-> stage3 latent GRPO
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
The run is capped by fixed step budgets by default (`1000` SFT steps and `500`
|
| 112 |
+
GRPO steps per phase) and can stop early when the configured solve-rate target
|
| 113 |
+
is reached. The current 1.5B run uses the warmed Stage-1 baseline adapter from
|
| 114 |
+
`hard_9x9_20empty_baseline_1p5b_warmup`.
|
| 115 |
+
|
| 116 |
+
See `warm_baseline_all_latent_modes_stages123_results.md` for the current
|
| 117 |
+
solve-rate snapshot from the ongoing full-pipeline run.
|
hard_9x9_stage1_consistency_queue/debug_fixed_slot_latent_one_example.sh
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
|
| 7 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 8 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 9 |
+
DATA_PATH="${DATA_PATH:-${ROOT}/data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl}"
|
| 10 |
+
GPU_ID="${GPU_ID:-0}"
|
| 11 |
+
NUM_COT="${NUM_COT:-3}"
|
| 12 |
+
MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}"
|
| 13 |
+
LIMIT_ROWS="${LIMIT_ROWS:-1}"
|
| 14 |
+
TRAIN_STEPS="${TRAIN_STEPS:-60}"
|
| 15 |
+
LR="${LR:-1e-1}"
|
| 16 |
+
LORA_R="${LORA_R:-32}"
|
| 17 |
+
LORA_ALPHA="${LORA_ALPHA:-64}"
|
| 18 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 19 |
+
|
| 20 |
+
export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
|
| 21 |
+
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${GPU_ID}}"
|
| 22 |
+
|
| 23 |
+
exec "${PYTHON_BIN}" - <<'PY'
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 29 |
+
|
| 30 |
+
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row
|
| 31 |
+
from latent_multi_output_cell_policy.grpo_residual_projector_latent_train import (
|
| 32 |
+
attach_fixed_latent_slot_modules,
|
| 33 |
+
fixed_slot_next_token_logits_from_ids,
|
| 34 |
+
load_jsonl_rows,
|
| 35 |
+
load_trainable_adapter,
|
| 36 |
+
pick_dtype,
|
| 37 |
+
sample_fixed_slot_completion,
|
| 38 |
+
unwrap_backbone,
|
| 39 |
+
)
|
| 40 |
+
from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt
|
| 41 |
+
from multi_output_cell_policy.shared_multi_output_policy import build_supervised_completion
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def env_int(name: str, default: int) -> int:
|
| 45 |
+
return int(os.environ.get(name, str(default)))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def env_float(name: str, default: float) -> float:
|
| 49 |
+
return float(os.environ.get(name, str(default)))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 53 |
+
data_path = os.environ.get("DATA_PATH", "data/sudoku_t3_20empty_value_qwen_text_stage1_train.jsonl")
|
| 54 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 55 |
+
num_cot = env_int("NUM_COT", 5)
|
| 56 |
+
max_latent_slots = env_int("MAX_LATENT_SLOTS", 8)
|
| 57 |
+
limit_rows = env_int("LIMIT_ROWS", 1)
|
| 58 |
+
train_steps = env_int("TRAIN_STEPS", 60)
|
| 59 |
+
lr = env_float("LR", 1e-1)
|
| 60 |
+
lora_r = env_int("LORA_R", 32)
|
| 61 |
+
lora_alpha = env_int("LORA_ALPHA", 64)
|
| 62 |
+
lora_dropout = env_float("LORA_DROPOUT", 0.05)
|
| 63 |
+
|
| 64 |
+
rows = load_jsonl_rows(data_path, limit_rows=limit_rows)
|
| 65 |
+
ex = build_cell_examples_from_row(rows[0])[0]
|
| 66 |
+
|
| 67 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
| 68 |
+
if tokenizer.pad_token_id is None:
|
| 69 |
+
tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
|
| 70 |
+
|
| 71 |
+
prompt = build_multi_output_cell_prompt(
|
| 72 |
+
ex.grid,
|
| 73 |
+
target_cell=ex.target_cell,
|
| 74 |
+
stage_i=1,
|
| 75 |
+
tokenizer=tokenizer,
|
| 76 |
+
turn_idx=ex.turn_idx,
|
| 77 |
+
total_turns=ex.total_turns,
|
| 78 |
+
prev_output_flag=None,
|
| 79 |
+
total_empties_hint=20,
|
| 80 |
+
)
|
| 81 |
+
target_text = build_supervised_completion(ex, stage_i=1) + (tokenizer.eos_token or "")
|
| 82 |
+
print("target_text", target_text)
|
| 83 |
+
|
| 84 |
+
base = AutoModelForCausalLM.from_pretrained(
|
| 85 |
+
model_name,
|
| 86 |
+
torch_dtype=pick_dtype(),
|
| 87 |
+
low_cpu_mem_usage=True,
|
| 88 |
+
)
|
| 89 |
+
model = load_trainable_adapter(base, "", lora_r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
|
| 90 |
+
attach_fixed_latent_slot_modules(
|
| 91 |
+
model,
|
| 92 |
+
hidden_size=int(unwrap_backbone(model).config.hidden_size),
|
| 93 |
+
max_latent_slots=max_latent_slots,
|
| 94 |
+
)
|
| 95 |
+
if hasattr(model, "config"):
|
| 96 |
+
model.config.use_cache = False
|
| 97 |
+
backbone = unwrap_backbone(model)
|
| 98 |
+
if hasattr(backbone, "config"):
|
| 99 |
+
backbone.config.use_cache = False
|
| 100 |
+
model.to(device)
|
| 101 |
+
|
| 102 |
+
for p in model.parameters():
|
| 103 |
+
p.requires_grad = False
|
| 104 |
+
model.fixed_latent_slots.requires_grad_(True)
|
| 105 |
+
model.fixed_final_slot_embed.requires_grad_(True)
|
| 106 |
+
optimizer = torch.optim.AdamW([model.fixed_latent_slots, model.fixed_final_slot_embed], lr=lr)
|
| 107 |
+
|
| 108 |
+
prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
|
| 109 |
+
completion_ids = tokenizer(target_text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@torch.no_grad()
|
| 113 |
+
def sample_now(tag: str) -> None:
|
| 114 |
+
model.eval()
|
| 115 |
+
attn = torch.ones_like(prompt_ids, device=device)
|
| 116 |
+
logits = fixed_slot_next_token_logits_from_ids(model, prompt_ids, attn, num_cot)
|
| 117 |
+
probs = torch.softmax(logits[0].float(), dim=-1)
|
| 118 |
+
top_probs, top_ids = torch.topk(probs, k=5)
|
| 119 |
+
out_ids = sample_fixed_slot_completion(
|
| 120 |
+
model,
|
| 121 |
+
tokenizer,
|
| 122 |
+
prompt_ids,
|
| 123 |
+
attn,
|
| 124 |
+
num_cot_tokens=num_cot,
|
| 125 |
+
max_new_tokens=12,
|
| 126 |
+
do_sample=False,
|
| 127 |
+
)
|
| 128 |
+
top_next = [(tokenizer.decode([int(i)]), round(float(p), 4)) for i, p in zip(top_ids.tolist(), top_probs.tolist())]
|
| 129 |
+
print(tag, tokenizer.decode(out_ids[0], skip_special_tokens=True), "top_next", top_next)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
sample_now("before:")
|
| 133 |
+
|
| 134 |
+
for step in range(1, train_steps + 1):
|
| 135 |
+
model.train()
|
| 136 |
+
cur_ids = prompt_ids
|
| 137 |
+
cur_mask = torch.ones_like(prompt_ids, device=device)
|
| 138 |
+
losses = []
|
| 139 |
+
for idx in range(int(completion_ids.shape[1])):
|
| 140 |
+
logits = fixed_slot_next_token_logits_from_ids(model, cur_ids, cur_mask, num_cot)
|
| 141 |
+
target = completion_ids[:, idx]
|
| 142 |
+
losses.append(F.cross_entropy(logits.float(), target, reduction="mean"))
|
| 143 |
+
cur_ids = torch.cat([cur_ids, completion_ids[:, idx : idx + 1]], dim=1)
|
| 144 |
+
cur_mask = torch.cat(
|
| 145 |
+
[
|
| 146 |
+
cur_mask,
|
| 147 |
+
torch.ones((cur_mask.shape[0], 1), dtype=cur_mask.dtype, device=cur_mask.device),
|
| 148 |
+
],
|
| 149 |
+
dim=1,
|
| 150 |
+
)
|
| 151 |
+
loss = torch.stack(losses).mean()
|
| 152 |
+
optimizer.zero_grad(set_to_none=True)
|
| 153 |
+
loss.backward()
|
| 154 |
+
optimizer.step()
|
| 155 |
+
if step == 1 or step % 10 == 0 or step == train_steps:
|
| 156 |
+
print(f"step={step} loss={float(loss.item()):.6f}")
|
| 157 |
+
sample_now(f"after_step_{step}:")
|
| 158 |
+
PY
|
hard_9x9_stage1_consistency_queue/launch_10empty_full_pipeline_stages123_value98.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Full 10-empty baseline pipeline, matching the successful 7-empty procedure:
|
| 3 |
+
# 1) Stage-1 SFT to value precision/recall >= 0.98
|
| 4 |
+
# 2) Stage-1 GRPO
|
| 5 |
+
# 3) Stage-2 SFT
|
| 6 |
+
# 4) Stage-2 GRPO
|
| 7 |
+
# 5) Stage-3 SFT
|
| 8 |
+
# 6) Stage-3 GRPO
|
| 9 |
+
#
|
| 10 |
+
# This is a wrapper around:
|
| 11 |
+
# - launch_10empty_sft_stage1_98p.sh
|
| 12 |
+
# - launch_10empty_post_s1sft_stages123_value98.sh
|
| 13 |
+
#
|
| 14 |
+
set -euo pipefail
|
| 15 |
+
|
| 16 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 17 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 18 |
+
|
| 19 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 20 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_full_stages123_value98}"
|
| 21 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 22 |
+
|
| 23 |
+
SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_10empty_sft_stage1_98p.sh"
|
| 24 |
+
POST_S1_SCRIPT="${SCRIPT_DIR}/launch_10empty_post_s1sft_stages123_value98.sh"
|
| 25 |
+
S1_DIR="${OUTPUT_ROOT}/10empty/stage01_sft_i1_10empty_sft98"
|
| 26 |
+
|
| 27 |
+
latest_checkpoint_in_dir() {
|
| 28 |
+
local d="$1"
|
| 29 |
+
shopt -s nullglob
|
| 30 |
+
local checkpoints=("${d}"/checkpoint-step-*)
|
| 31 |
+
shopt -u nullglob
|
| 32 |
+
if (( ${#checkpoints[@]} == 0 )); then
|
| 33 |
+
printf ''
|
| 34 |
+
return 1
|
| 35 |
+
fi
|
| 36 |
+
set +o pipefail
|
| 37 |
+
printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1
|
| 38 |
+
set -o pipefail
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
printf '=== 10-empty full baseline pipeline (stage1 SFT -> stages123) ===\n'
|
| 42 |
+
printf 'run_tag=%s\n' "${RUN_TAG}"
|
| 43 |
+
printf 'output_root=%s\n' "${OUTPUT_ROOT}"
|
| 44 |
+
|
| 45 |
+
OUTPUT_DIR="${S1_DIR}" \
|
| 46 |
+
RUN_TAG="${RUN_TAG}" \
|
| 47 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
|
| 48 |
+
"${SFT_STAGE1_SCRIPT}"
|
| 49 |
+
|
| 50 |
+
STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")"
|
| 51 |
+
if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 52 |
+
printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2
|
| 53 |
+
exit 1
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
|
| 57 |
+
|
| 58 |
+
STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \
|
| 59 |
+
RUN_TAG="${RUN_TAG}" \
|
| 60 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
|
| 61 |
+
OUTPUT_ROOT="${OUTPUT_ROOT}" \
|
| 62 |
+
"${POST_S1_SCRIPT}"
|
hard_9x9_stage1_consistency_queue/launch_10empty_post_s1sft_stages123_value98.sh
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Run AFTER stage-1 SFT finishes (10-empty). Order:
|
| 3 |
+
# 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
|
| 4 |
+
# 2) Stage-2 SFT (init = stage-1 GRPO adapter)
|
| 5 |
+
# 3) Stage-2 GRPO (init = stage-2 SFT adapter)
|
| 6 |
+
# 4) Stage-3 SFT (init = stage-2 GRPO adapter)
|
| 7 |
+
# 5) Stage-3 GRPO (init = stage-3 SFT adapter)
|
| 8 |
+
#
|
| 9 |
+
# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
|
| 10 |
+
# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
|
| 11 |
+
# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
|
| 12 |
+
# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
|
| 13 |
+
#
|
| 14 |
+
# Required (full pipeline from stage-1 SFT):
|
| 15 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
|
| 16 |
+
#
|
| 17 |
+
# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
|
| 18 |
+
# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_10empty
|
| 19 |
+
# (OUTPUT_ROOT defaults to dirname of that dir.)
|
| 20 |
+
#
|
| 21 |
+
# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
|
| 22 |
+
# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_10empty
|
| 23 |
+
#
|
| 24 |
+
# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
|
| 25 |
+
# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_10empty
|
| 26 |
+
#
|
| 27 |
+
# Optional:
|
| 28 |
+
# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
|
| 29 |
+
# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
|
| 30 |
+
#
|
| 31 |
+
set -euo pipefail
|
| 32 |
+
|
| 33 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 34 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 35 |
+
|
| 36 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 37 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 38 |
+
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
|
| 39 |
+
|
| 40 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 41 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 42 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 43 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 44 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 45 |
+
|
| 46 |
+
EMPTIES=10
|
| 47 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 48 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 49 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 50 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 51 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 52 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 53 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 54 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 55 |
+
|
| 56 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_stages123_value98}"
|
| 57 |
+
START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
|
| 58 |
+
START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
|
| 59 |
+
RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
|
| 60 |
+
|
| 61 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 62 |
+
if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 63 |
+
printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 64 |
+
exit 1
|
| 65 |
+
fi
|
| 66 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
|
| 67 |
+
elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 68 |
+
if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 69 |
+
printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 70 |
+
exit 1
|
| 71 |
+
fi
|
| 72 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
|
| 73 |
+
elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 74 |
+
if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 75 |
+
printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
|
| 76 |
+
exit 1
|
| 77 |
+
fi
|
| 78 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
|
| 79 |
+
else
|
| 80 |
+
if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 81 |
+
printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
|
| 82 |
+
exit 1
|
| 83 |
+
fi
|
| 84 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 85 |
+
fi
|
| 86 |
+
|
| 87 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 88 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 89 |
+
|
| 90 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 91 |
+
|
| 92 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 93 |
+
|
| 94 |
+
latest_sft_step_ckpt() {
|
| 95 |
+
local d="$1"
|
| 96 |
+
shopt -s nullglob
|
| 97 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 98 |
+
shopt -u nullglob
|
| 99 |
+
if (( ${#cks[@]} == 0 )); then
|
| 100 |
+
printf ''
|
| 101 |
+
return 1
|
| 102 |
+
fi
|
| 103 |
+
set +o pipefail
|
| 104 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 105 |
+
set -o pipefail
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
resolve_grpo_adapter() {
|
| 109 |
+
local d="$1"
|
| 110 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 111 |
+
printf '%s\n' "${d}"
|
| 112 |
+
return 0
|
| 113 |
+
fi
|
| 114 |
+
local best="" step=-1
|
| 115 |
+
shopt -s nullglob
|
| 116 |
+
local c
|
| 117 |
+
for c in "${d}"/checkpoint-*; do
|
| 118 |
+
[[ -d "${c}" ]] || continue
|
| 119 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 120 |
+
local n
|
| 121 |
+
n="${c##*checkpoint-}"
|
| 122 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 123 |
+
step=$((10#${n}))
|
| 124 |
+
best="${c}"
|
| 125 |
+
fi
|
| 126 |
+
done
|
| 127 |
+
shopt -u nullglob
|
| 128 |
+
if [[ -n "${best}" ]]; then
|
| 129 |
+
printf '%s\n' "${best}"
|
| 130 |
+
return 0
|
| 131 |
+
fi
|
| 132 |
+
printf ''
|
| 133 |
+
return 1
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
GC_FLAGS=()
|
| 137 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 138 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
run_sft() {
|
| 142 |
+
local stage="$1"
|
| 143 |
+
local init_adapter="$2"
|
| 144 |
+
local out_dir="$3"
|
| 145 |
+
local lr="$4"
|
| 146 |
+
mkdir -p "${out_dir}"
|
| 147 |
+
printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
|
| 148 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 149 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 150 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 151 |
+
--train_jsonl "${train_jsonl}" \
|
| 152 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 153 |
+
--output_dir "${out_dir}" \
|
| 154 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 155 |
+
--init_adapter_dir "${init_adapter}" \
|
| 156 |
+
--seed 0 \
|
| 157 |
+
--gpu_id 0 \
|
| 158 |
+
--stage_i "${stage}" \
|
| 159 |
+
--total_empties_hint "${EMPTIES}" \
|
| 160 |
+
--per_device_train_batch_size 16 \
|
| 161 |
+
--gradient_accumulation_steps 2 \
|
| 162 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 163 |
+
--learning_rate "${lr}" \
|
| 164 |
+
--max_grad_norm 1.0 \
|
| 165 |
+
"${GC_FLAGS[@]}" \
|
| 166 |
+
--logging_steps 20 \
|
| 167 |
+
--eval_steps 250 \
|
| 168 |
+
--save_steps 200 \
|
| 169 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 170 |
+
--max_completion_length 24 \
|
| 171 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 172 |
+
--lora_r 32 \
|
| 173 |
+
--lora_alpha 64 \
|
| 174 |
+
--lora_dropout 0.05 \
|
| 175 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 176 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 177 |
+
--eval_exact_set_match_stop 0 \
|
| 178 |
+
--eval_solve_rate_stop 0 \
|
| 179 |
+
--min_steps_before_stop 50 \
|
| 180 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 181 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 182 |
+
--use_wandb \
|
| 183 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 184 |
+
--wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 185 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 186 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
run_grpo() {
|
| 190 |
+
local stage="$1"
|
| 191 |
+
local init_adapter="$2"
|
| 192 |
+
local out_dir="$3"
|
| 193 |
+
mkdir -p "${out_dir}"
|
| 194 |
+
printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2
|
| 195 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 196 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 197 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 198 |
+
--train_jsonl "${train_jsonl}" \
|
| 199 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 200 |
+
--output_dir "${out_dir}" \
|
| 201 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 202 |
+
--init_adapter_dir "${init_adapter}" \
|
| 203 |
+
--seed 0 \
|
| 204 |
+
--gpu_id 0 \
|
| 205 |
+
--stage_i "${stage}" \
|
| 206 |
+
--total_empties_hint "${EMPTIES}" \
|
| 207 |
+
--per_device_train_batch_size 8 \
|
| 208 |
+
--gradient_accumulation_steps 2 \
|
| 209 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 210 |
+
--learning_rate 1e-6 \
|
| 211 |
+
--logging_steps 20 \
|
| 212 |
+
--save_steps 200 \
|
| 213 |
+
--eval_steps 500 \
|
| 214 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 215 |
+
--num_generations 4 \
|
| 216 |
+
--max_prompt_length 1024 \
|
| 217 |
+
--max_completion_length 24 \
|
| 218 |
+
--beta 0.0 \
|
| 219 |
+
--enable_gradient_checkpointing \
|
| 220 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 221 |
+
--reward_good_value 1.25 \
|
| 222 |
+
--penalty_bad_value 1.0 \
|
| 223 |
+
--penalty_malformed 4.0 \
|
| 224 |
+
--penalty_empty 0.5 \
|
| 225 |
+
--penalty_singleton 1.5 \
|
| 226 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 227 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 228 |
+
--eval_solve_rate_stop 0 \
|
| 229 |
+
--min_steps_before_stop 50 \
|
| 230 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 231 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 232 |
+
--use_wandb \
|
| 233 |
+
--wandb_project "sudoku-multi-output-grpo" \
|
| 234 |
+
--wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 235 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 236 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 240 |
+
printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
|
| 241 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 242 |
+
exit 1
|
| 243 |
+
fi
|
| 244 |
+
|
| 245 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 246 |
+
printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 247 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 248 |
+
S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
|
| 249 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 250 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 251 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 252 |
+
exit 1
|
| 253 |
+
fi
|
| 254 |
+
printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
|
| 255 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 256 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 257 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 258 |
+
if [[ -z "${A2}" ]]; then
|
| 259 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 260 |
+
exit 1
|
| 261 |
+
fi
|
| 262 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 263 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 264 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 265 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 266 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 267 |
+
exit 1
|
| 268 |
+
fi
|
| 269 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 270 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 271 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 272 |
+
if [[ -z "${A3}" ]]; then
|
| 273 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 274 |
+
exit 1
|
| 275 |
+
fi
|
| 276 |
+
printf '\nAll phases finished (started at stage-2 GRPO).\n'
|
| 277 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 278 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 279 |
+
exit 0
|
| 280 |
+
fi
|
| 281 |
+
|
| 282 |
+
if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 283 |
+
printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 284 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 285 |
+
A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
|
| 286 |
+
if [[ -z "${A2}" ]]; then
|
| 287 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 288 |
+
exit 1
|
| 289 |
+
fi
|
| 290 |
+
printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
|
| 291 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 292 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 293 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 294 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 295 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 296 |
+
exit 1
|
| 297 |
+
fi
|
| 298 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 299 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 300 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 301 |
+
if [[ -z "${A3}" ]]; then
|
| 302 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 303 |
+
exit 1
|
| 304 |
+
fi
|
| 305 |
+
printf '\nAll phases finished (started after stage-2 GRPO).\n'
|
| 306 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 307 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 308 |
+
exit 0
|
| 309 |
+
fi
|
| 310 |
+
|
| 311 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 312 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 313 |
+
printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
|
| 314 |
+
else
|
| 315 |
+
printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
|
| 316 |
+
fi
|
| 317 |
+
printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
|
| 318 |
+
"${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 319 |
+
|
| 320 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
|
| 321 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 322 |
+
A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
|
| 323 |
+
else
|
| 324 |
+
run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
|
| 325 |
+
A1="$(resolve_grpo_adapter "${G1_DIR}")"
|
| 326 |
+
fi
|
| 327 |
+
if [[ -z "${A1}" ]]; then
|
| 328 |
+
printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
|
| 329 |
+
exit 1
|
| 330 |
+
fi
|
| 331 |
+
printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 332 |
+
|
| 333 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
|
| 334 |
+
run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
|
| 335 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 336 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 337 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 338 |
+
exit 1
|
| 339 |
+
fi
|
| 340 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 341 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 342 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 343 |
+
if [[ -z "${A2}" ]]; then
|
| 344 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 345 |
+
exit 1
|
| 346 |
+
fi
|
| 347 |
+
|
| 348 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 349 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 350 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 351 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 352 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 353 |
+
exit 1
|
| 354 |
+
fi
|
| 355 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 356 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 357 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 358 |
+
if [[ -z "${A3}" ]]; then
|
| 359 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 360 |
+
exit 1
|
| 361 |
+
fi
|
| 362 |
+
|
| 363 |
+
printf '\nAll phases finished.\n'
|
| 364 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 365 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_10empty_sft_stage1_98p.sh
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Stage-1 SFT only for 10-empty: train until eval value_precision AND value_recall
|
| 3 |
+
# both reach 0.98 (or max_steps / optional wall clock). Use the resulting
|
| 4 |
+
# checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for
|
| 5 |
+
# launch_10empty_post_s1sft_stages123_value98.sh.
|
| 6 |
+
#
|
| 7 |
+
# Fresh LoRA on base model:
|
| 8 |
+
# ./launch_10empty_sft_stage1_98p.sh
|
| 9 |
+
#
|
| 10 |
+
# Continue from a prior SFT checkpoint:
|
| 11 |
+
# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_10empty_sft_stage1_98p.sh
|
| 12 |
+
#
|
| 13 |
+
set -euo pipefail
|
| 14 |
+
|
| 15 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 16 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 17 |
+
|
| 18 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 19 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 20 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 21 |
+
|
| 22 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 23 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 24 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 25 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 26 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 27 |
+
|
| 28 |
+
EMPTIES=10
|
| 29 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 30 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 31 |
+
SFT_TARGET="${SFT_TARGET:-0.98}"
|
| 32 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 33 |
+
MAX_STEPS="${MAX_STEPS:-30000}"
|
| 34 |
+
|
| 35 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_10empty_sft98_stage1}"
|
| 36 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}"
|
| 37 |
+
|
| 38 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 39 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 40 |
+
|
| 41 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 42 |
+
|
| 43 |
+
if [[ ! -f "${train_jsonl}" ]]; then
|
| 44 |
+
mkdir -p "$(dirname "${train_jsonl}")"
|
| 45 |
+
printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
|
| 46 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
|
| 47 |
+
fi
|
| 48 |
+
if [[ ! -f "${eval_jsonl}" ]]; then
|
| 49 |
+
mkdir -p "$(dirname "${eval_jsonl}")"
|
| 50 |
+
printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
|
| 51 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
|
| 52 |
+
fi
|
| 53 |
+
|
| 54 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 55 |
+
|
| 56 |
+
INIT_FLAGS=()
|
| 57 |
+
if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
|
| 58 |
+
INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
|
| 59 |
+
printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
GC_FLAGS=()
|
| 63 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 64 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 65 |
+
printf 'NOTE: USE_GC=1 — slower, less VRAM.\n'
|
| 66 |
+
fi
|
| 67 |
+
|
| 68 |
+
if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
|
| 69 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 70 |
+
else
|
| 71 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
|
| 72 |
+
fi
|
| 73 |
+
printf 'Output: %s\n' "${OUTPUT_DIR}"
|
| 74 |
+
|
| 75 |
+
exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 76 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 77 |
+
--train_jsonl "${train_jsonl}" \
|
| 78 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 79 |
+
--output_dir "${OUTPUT_DIR}" \
|
| 80 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 81 |
+
"${INIT_FLAGS[@]}" \
|
| 82 |
+
--seed 0 \
|
| 83 |
+
--gpu_id 0 \
|
| 84 |
+
--stage_i 1 \
|
| 85 |
+
--total_empties_hint "${EMPTIES}" \
|
| 86 |
+
--per_device_train_batch_size 16 \
|
| 87 |
+
--gradient_accumulation_steps 2 \
|
| 88 |
+
--num_epochs 64.0 \
|
| 89 |
+
--learning_rate 2e-4 \
|
| 90 |
+
--max_grad_norm 1.0 \
|
| 91 |
+
"${GC_FLAGS[@]}" \
|
| 92 |
+
--logging_steps 20 \
|
| 93 |
+
--eval_steps 250 \
|
| 94 |
+
--save_steps 100 \
|
| 95 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 96 |
+
--max_completion_length 24 \
|
| 97 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 98 |
+
--lora_r 32 \
|
| 99 |
+
--lora_alpha 64 \
|
| 100 |
+
--lora_dropout 0.05 \
|
| 101 |
+
--eval_value_precision_stop "${SFT_TARGET}" \
|
| 102 |
+
--eval_value_recall_stop "${SFT_TARGET}" \
|
| 103 |
+
--eval_exact_set_match_stop 0 \
|
| 104 |
+
--eval_solve_rate_stop 0 \
|
| 105 |
+
--min_steps_before_stop 50 \
|
| 106 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 107 |
+
--max_steps "${MAX_STEPS}" \
|
| 108 |
+
--use_wandb \
|
| 109 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 110 |
+
--wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
|
| 111 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 112 |
+
--wandb_entity "${WANDB_ENTITY}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_fixed_slot_sft_stage1_98p.sh
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Stage-1 fixed-slot latent SFT for 20-empty: train until eval value_precision AND
|
| 3 |
+
# value_recall both reach 0.98. This uses prompt + z1 + final_slot during stage 1,
|
| 4 |
+
# while still updating LoRA weights so the transformer can learn how to use z1.
|
| 5 |
+
#
|
| 6 |
+
# Fresh run:
|
| 7 |
+
# ./launch_20empty_fixed_slot_sft_stage1_98p.sh
|
| 8 |
+
#
|
| 9 |
+
# Warm-start from a prior checkpoint:
|
| 10 |
+
# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_fixed_slot_sft_stage1_98p.sh
|
| 11 |
+
#
|
| 12 |
+
set -euo pipefail
|
| 13 |
+
|
| 14 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 15 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 16 |
+
|
| 17 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 18 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 19 |
+
SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
|
| 20 |
+
|
| 21 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6}"
|
| 22 |
+
NUM_PROCESSES="${NUM_PROCESSES:-7}"
|
| 23 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 24 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 25 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 26 |
+
|
| 27 |
+
EMPTIES=20
|
| 28 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 29 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 30 |
+
SFT_TARGET="${SFT_TARGET:-0.98}"
|
| 31 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 32 |
+
MAX_STEPS="${MAX_STEPS:-30000}"
|
| 33 |
+
|
| 34 |
+
LORA_R="${LORA_R:-32}"
|
| 35 |
+
LORA_ALPHA="${LORA_ALPHA:-64}"
|
| 36 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 37 |
+
MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-3}"
|
| 38 |
+
|
| 39 |
+
PER_DEVICE_TRAIN_BATCH_SIZE="${PER_DEVICE_TRAIN_BATCH_SIZE:-4}"
|
| 40 |
+
GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-8}"
|
| 41 |
+
|
| 42 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_fixed_slot_sft98_stage1}"
|
| 43 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_fixed_slot_sft98_i1_${EMPTIES}empty}"
|
| 44 |
+
|
| 45 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 46 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 47 |
+
|
| 48 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 49 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 50 |
+
|
| 51 |
+
if [[ ! -f "${train_jsonl}" ]]; then
|
| 52 |
+
mkdir -p "$(dirname "${train_jsonl}")"
|
| 53 |
+
printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
|
| 54 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
|
| 55 |
+
fi
|
| 56 |
+
if [[ ! -f "${eval_jsonl}" ]]; then
|
| 57 |
+
mkdir -p "$(dirname "${eval_jsonl}")"
|
| 58 |
+
printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
|
| 59 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 63 |
+
|
| 64 |
+
INIT_FLAGS=()
|
| 65 |
+
if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
|
| 66 |
+
INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
|
| 67 |
+
printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
GC_FLAGS=()
|
| 71 |
+
if [[ "${USE_GC:-1}" == "1" ]]; then
|
| 72 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 73 |
+
printf 'NOTE: USE_GC=1 - slower, less VRAM.\n'
|
| 74 |
+
fi
|
| 75 |
+
|
| 76 |
+
if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
|
| 77 |
+
printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 78 |
+
else
|
| 79 |
+
printf '\n=== Stage1 fixed-slot SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
|
| 80 |
+
fi
|
| 81 |
+
printf 'Output: %s\n' "${OUTPUT_DIR}"
|
| 82 |
+
printf 'LoRA: r=%s alpha=%s dropout=%s | latent_mode=fixed_slots | active_z=1 | max_latent_slots=%s\n' "${LORA_R}" "${LORA_ALPHA}" "${LORA_DROPOUT}" "${MAX_LATENT_SLOTS}"
|
| 83 |
+
printf 'DDP: visible_gpus=%s nproc=%s | batch/device=%s grad_accum=%s\n' "${GPU_IDS}" "${NUM_PROCESSES}" "${PER_DEVICE_TRAIN_BATCH_SIZE}" "${GRADIENT_ACCUMULATION_STEPS}"
|
| 84 |
+
|
| 85 |
+
exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 86 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 87 |
+
--train_jsonl "${train_jsonl}" \
|
| 88 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 89 |
+
--output_dir "${OUTPUT_DIR}" \
|
| 90 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 91 |
+
"${INIT_FLAGS[@]}" \
|
| 92 |
+
--seed 0 \
|
| 93 |
+
--gpu_id 0 \
|
| 94 |
+
--stage_i 1 \
|
| 95 |
+
--num_cot_tokens 1 \
|
| 96 |
+
--latent_mode fixed_slots \
|
| 97 |
+
--max_latent_slots "${MAX_LATENT_SLOTS}" \
|
| 98 |
+
--total_empties_hint "${EMPTIES}" \
|
| 99 |
+
--per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE}" \
|
| 100 |
+
--gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS}" \
|
| 101 |
+
--num_epochs 64.0 \
|
| 102 |
+
--learning_rate 2e-4 \
|
| 103 |
+
--max_grad_norm 1.0 \
|
| 104 |
+
"${GC_FLAGS[@]}" \
|
| 105 |
+
--logging_steps 20 \
|
| 106 |
+
--eval_steps 250 \
|
| 107 |
+
--save_steps 100 \
|
| 108 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 109 |
+
--max_completion_length 24 \
|
| 110 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 111 |
+
--lora_r "${LORA_R}" \
|
| 112 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 113 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 114 |
+
--eval_value_precision_stop "${SFT_TARGET}" \
|
| 115 |
+
--eval_value_recall_stop "${SFT_TARGET}" \
|
| 116 |
+
--eval_exact_set_match_stop 0 \
|
| 117 |
+
--eval_solve_rate_stop 0 \
|
| 118 |
+
--min_steps_before_stop 50 \
|
| 119 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 120 |
+
--max_steps "${MAX_STEPS}" \
|
| 121 |
+
--use_wandb \
|
| 122 |
+
--wandb_project "sudoku-fixed-slot-sft" \
|
| 123 |
+
--wandb_run_name "${WANDB_RUN_NAME:-stage01_fixed_slot_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
|
| 124 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 125 |
+
--wandb_entity "${WANDB_ENTITY}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_full_pipeline_stages123_value98.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Full 20-empty baseline pipeline, matching the successful 10-empty procedure:
|
| 3 |
+
# 1) Stage-1 SFT to value precision/recall >= 0.98
|
| 4 |
+
# 2) Stage-1 GRPO
|
| 5 |
+
# 3) Stage-2 SFT
|
| 6 |
+
# 4) Stage-2 GRPO
|
| 7 |
+
# 5) Stage-3 SFT
|
| 8 |
+
# 6) Stage-3 GRPO
|
| 9 |
+
#
|
| 10 |
+
# This is a wrapper around:
|
| 11 |
+
# - launch_20empty_sft_stage1_98p.sh
|
| 12 |
+
# - launch_20empty_post_s1sft_stages123_value98.sh
|
| 13 |
+
#
|
| 14 |
+
set -euo pipefail
|
| 15 |
+
|
| 16 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 17 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 18 |
+
|
| 19 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 20 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_full_stages123_value98}"
|
| 21 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 22 |
+
|
| 23 |
+
SFT_STAGE1_SCRIPT="${SCRIPT_DIR}/launch_20empty_sft_stage1_98p.sh"
|
| 24 |
+
POST_S1_SCRIPT="${SCRIPT_DIR}/launch_20empty_post_s1sft_stages123_value98.sh"
|
| 25 |
+
S1_DIR="${OUTPUT_ROOT}/20empty/stage01_sft_i1_20empty_sft98"
|
| 26 |
+
|
| 27 |
+
latest_checkpoint_in_dir() {
|
| 28 |
+
local d="$1"
|
| 29 |
+
shopt -s nullglob
|
| 30 |
+
local checkpoints=("${d}"/checkpoint-step-*)
|
| 31 |
+
shopt -u nullglob
|
| 32 |
+
if (( ${#checkpoints[@]} == 0 )); then
|
| 33 |
+
printf ''
|
| 34 |
+
return 1
|
| 35 |
+
fi
|
| 36 |
+
set +o pipefail
|
| 37 |
+
printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1
|
| 38 |
+
set -o pipefail
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
printf '=== 20-empty full baseline pipeline (stage1 SFT -> stages123) ===\n'
|
| 42 |
+
printf 'run_tag=%s\n' "${RUN_TAG}"
|
| 43 |
+
printf 'output_root=%s\n' "${OUTPUT_ROOT}"
|
| 44 |
+
|
| 45 |
+
OUTPUT_DIR="${S1_DIR}" \
|
| 46 |
+
RUN_TAG="${RUN_TAG}" \
|
| 47 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
|
| 48 |
+
"${SFT_STAGE1_SCRIPT}"
|
| 49 |
+
|
| 50 |
+
STAGE1_SFT_ADAPTER_DIR="$(latest_checkpoint_in_dir "${S1_DIR}")"
|
| 51 |
+
if [[ -z "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 52 |
+
printf 'ERROR: No checkpoint-step-* found under %s\n' "${S1_DIR}" >&2
|
| 53 |
+
exit 1
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
printf '\nStage-1 SFT complete. Using checkpoint: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
|
| 57 |
+
|
| 58 |
+
STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR}" \
|
| 59 |
+
RUN_TAG="${RUN_TAG}" \
|
| 60 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT}" \
|
| 61 |
+
OUTPUT_ROOT="${OUTPUT_ROOT}" \
|
| 62 |
+
"${POST_S1_SCRIPT}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_latent_recurrent_stages123_value98.sh
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Latent recurrent-hidden (Coconut-style) pipeline for 20-empty Sudoku.
|
| 3 |
+
#
|
| 4 |
+
# Per-stage latent token count grows with curriculum:
|
| 5 |
+
# stage 1 -> num_cot_tokens = 1
|
| 6 |
+
# stage 2 -> num_cot_tokens = 2
|
| 7 |
+
# stage 3 -> num_cot_tokens = 3
|
| 8 |
+
#
|
| 9 |
+
# Pipeline:
|
| 10 |
+
# Stage 1 SFT (cot=1, fresh LoRA + random latent state)
|
| 11 |
+
# -> Stage 1 GRPO (cot=1)
|
| 12 |
+
# -> Stage 2 SFT (cot=2)
|
| 13 |
+
# -> Stage 2 GRPO (cot=2)
|
| 14 |
+
# -> Stage 3 SFT (cot=3)
|
| 15 |
+
# -> Stage 3 GRPO (cot=3)
|
| 16 |
+
#
|
| 17 |
+
# Mirrors the hyperparameters of the successful 20-empty recurrent-hidden stage-1
|
| 18 |
+
# run (bs=8 per-device, gradient accumulation 2, gradient checkpointing ON).
|
| 19 |
+
#
|
| 20 |
+
# Optional overrides:
|
| 21 |
+
# STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
|
| 22 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX
|
| 23 |
+
# VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=...
|
| 24 |
+
set -euo pipefail
|
| 25 |
+
|
| 26 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 27 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 28 |
+
|
| 29 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 30 |
+
SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
|
| 31 |
+
GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py"
|
| 32 |
+
|
| 33 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 34 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 35 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 36 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 37 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 38 |
+
|
| 39 |
+
LATENT_MODE="recurrent_hidden"
|
| 40 |
+
EMPTIES=20
|
| 41 |
+
TAG_SUFFIX="latent_recurrent"
|
| 42 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 43 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 44 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 45 |
+
# Per-phase early-stop bars. Default behavior preserved: both phases use
|
| 46 |
+
# VALUE_TARGET unless explicitly overridden. Recommended: SFT_VALUE_TARGET=0.95
|
| 47 |
+
# (let SFT do bulk learning quickly) and GRPO_VALUE_TARGET=0.98 (let GRPO push
|
| 48 |
+
# the last few percent of value precision/recall).
|
| 49 |
+
SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}"
|
| 50 |
+
GRPO_VALUE_TARGET="${GRPO_VALUE_TARGET:-${VALUE_TARGET}}"
|
| 51 |
+
MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
|
| 52 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 53 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 54 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 55 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 56 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 57 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 58 |
+
# -1 resolves in code to hidden_size, and alpha=-1 resolves to 2 * resolved rank.
|
| 59 |
+
LORA_R="${LORA_R:--1}"
|
| 60 |
+
LORA_ALPHA="${LORA_ALPHA:--1}"
|
| 61 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 62 |
+
STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
|
| 63 |
+
SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
|
| 64 |
+
SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
|
| 65 |
+
GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-8}"
|
| 66 |
+
GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}"
|
| 67 |
+
|
| 68 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_recurrent_stages123_value98}"
|
| 69 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 70 |
+
STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
|
| 71 |
+
STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
|
| 72 |
+
# When set, skip both Stage-1 SFT and Stage-1 GRPO and use this adapter
|
| 73 |
+
# directly as the init for Stage-2 SFT. Useful for resuming after a Stage-1
|
| 74 |
+
# GRPO post-training eval hangs but the LoRA adapter is already on disk.
|
| 75 |
+
STAGE1_GRPO_ADAPTER_DIR="${STAGE1_GRPO_ADAPTER_DIR:-}"
|
| 76 |
+
STAGE2_SFT_ADAPTER_DIR="${STAGE2_SFT_ADAPTER_DIR:-}"
|
| 77 |
+
STAGE2_GRPO_ADAPTER_DIR="${STAGE2_GRPO_ADAPTER_DIR:-}"
|
| 78 |
+
# When set, skip Stage-3 SFT and use this adapter directly as the init for
|
| 79 |
+
# Stage-3 GRPO. Useful when SFT plateaus mid-training and we want GRPO to push
|
| 80 |
+
# the last few percentage points without burning more SFT compute.
|
| 81 |
+
STAGE3_SFT_ADAPTER_DIR="${STAGE3_SFT_ADAPTER_DIR:-}"
|
| 82 |
+
# KL anchor for GRPO. Setting > 0 keeps the policy close to the SFT reference
|
| 83 |
+
# and prevents singleton/mode collapse seen in Stage-2 GRPO. 0.0 = no KL.
|
| 84 |
+
GRPO_BETA="${GRPO_BETA:-0.0}"
|
| 85 |
+
|
| 86 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 87 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 88 |
+
|
| 89 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 90 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 91 |
+
|
| 92 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 93 |
+
|
| 94 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 95 |
+
printf 'ERROR: Missing train or eval jsonl.\n' >&2
|
| 96 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 97 |
+
exit 1
|
| 98 |
+
fi
|
| 99 |
+
|
| 100 |
+
latest_sft_step_ckpt() {
|
| 101 |
+
local d="$1"
|
| 102 |
+
shopt -s nullglob
|
| 103 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 104 |
+
shopt -u nullglob
|
| 105 |
+
if (( ${#cks[@]} == 0 )); then
|
| 106 |
+
printf ''
|
| 107 |
+
return 1
|
| 108 |
+
fi
|
| 109 |
+
set +o pipefail
|
| 110 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 111 |
+
set -o pipefail
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
resolve_latent_grpo_adapter() {
|
| 115 |
+
local d="$1"
|
| 116 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 117 |
+
printf '%s\n' "${d}"
|
| 118 |
+
return 0
|
| 119 |
+
fi
|
| 120 |
+
local best="" step=-1
|
| 121 |
+
shopt -s nullglob
|
| 122 |
+
local c
|
| 123 |
+
for c in "${d}"/checkpoint-*; do
|
| 124 |
+
[[ -d "${c}" ]] || continue
|
| 125 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 126 |
+
local n
|
| 127 |
+
n="${c##*checkpoint-}"
|
| 128 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 129 |
+
step=$((10#${n}))
|
| 130 |
+
best="${c}"
|
| 131 |
+
fi
|
| 132 |
+
done
|
| 133 |
+
shopt -u nullglob
|
| 134 |
+
if [[ -n "${best}" ]]; then
|
| 135 |
+
printf '%s\n' "${best}"
|
| 136 |
+
return 0
|
| 137 |
+
fi
|
| 138 |
+
printf ''
|
| 139 |
+
return 1
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
run_latent_sft() {
|
| 143 |
+
local stage="$1"
|
| 144 |
+
local init_adapter="$2"
|
| 145 |
+
local out_dir="$3"
|
| 146 |
+
local lr="$4"
|
| 147 |
+
local cot="$5"
|
| 148 |
+
local ms1=0 ms2=1
|
| 149 |
+
if [[ "${stage}" == "1" ]]; then
|
| 150 |
+
ms1=1
|
| 151 |
+
ms2=0
|
| 152 |
+
fi
|
| 153 |
+
mkdir -p "${out_dir}"
|
| 154 |
+
printf '\n=== Latent(recurrent) stage %s SFT -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${SFT_VALUE_TARGET}" "${cot}" >&2
|
| 155 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2
|
| 156 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 157 |
+
--model_name "${MODEL_NAME}" \
|
| 158 |
+
--train_jsonl "${train_jsonl}" \
|
| 159 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 160 |
+
--output_dir "${out_dir}" \
|
| 161 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 162 |
+
--init_adapter_dir "${init_adapter}" \
|
| 163 |
+
--seed 0 \
|
| 164 |
+
--gpu_id 0 \
|
| 165 |
+
--stage_i "${stage}" \
|
| 166 |
+
--num_cot_tokens "${cot}" \
|
| 167 |
+
--latent_mode "${LATENT_MODE}" \
|
| 168 |
+
--total_empties_hint "${EMPTIES}" \
|
| 169 |
+
--mixed_stage1_ratio "${ms1}" \
|
| 170 |
+
--mixed_stage2_ratio "${ms2}" \
|
| 171 |
+
--per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \
|
| 172 |
+
--gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \
|
| 173 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 174 |
+
--learning_rate "${lr}" \
|
| 175 |
+
--weight_decay 0.0 \
|
| 176 |
+
--enable_gradient_checkpointing \
|
| 177 |
+
--logging_steps 20 \
|
| 178 |
+
--eval_steps 250 \
|
| 179 |
+
--save_steps 200 \
|
| 180 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 181 |
+
--max_completion_length 24 \
|
| 182 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 183 |
+
--eval_value_precision_stop "${SFT_VALUE_TARGET}" \
|
| 184 |
+
--eval_value_recall_stop "${SFT_VALUE_TARGET}" \
|
| 185 |
+
--eval_exact_set_match_stop 0 \
|
| 186 |
+
--eval_solve_rate_stop 0 \
|
| 187 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 188 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 189 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 190 |
+
--reward_good_value 1.25 \
|
| 191 |
+
--penalty_bad_value 1.0 \
|
| 192 |
+
--penalty_malformed 4.0 \
|
| 193 |
+
--penalty_empty 0.5 \
|
| 194 |
+
--penalty_singleton 1.5 \
|
| 195 |
+
--lora_r "${LORA_R}" \
|
| 196 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 197 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 198 |
+
--use_wandb \
|
| 199 |
+
--wandb_project "sudoku-latent-multi-output-sft-recurrent" \
|
| 200 |
+
--wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_cot${cot}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \
|
| 201 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 202 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
run_latent_grpo() {
|
| 206 |
+
local stage="$1"
|
| 207 |
+
local init_adapter="$2"
|
| 208 |
+
local out_dir="$3"
|
| 209 |
+
local cot="$4"
|
| 210 |
+
mkdir -p "${out_dir}"
|
| 211 |
+
printf '\n=== Latent(recurrent) stage %s GRPO -> stop value prec+recall >= %s (cot=%s) ===\n' "${stage}" "${GRPO_VALUE_TARGET}" "${cot}" >&2
|
| 212 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
|
| 213 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 214 |
+
--model_name "${MODEL_NAME}" \
|
| 215 |
+
--train_jsonl "${train_jsonl}" \
|
| 216 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 217 |
+
--output_dir "${out_dir}" \
|
| 218 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 219 |
+
--init_adapter_dir "${init_adapter}" \
|
| 220 |
+
--seed 0 \
|
| 221 |
+
--gpu_id 0 \
|
| 222 |
+
--stage_i "${stage}" \
|
| 223 |
+
--num_cot_tokens "${cot}" \
|
| 224 |
+
--latent_mode "${LATENT_MODE}" \
|
| 225 |
+
--total_empties_hint "${EMPTIES}" \
|
| 226 |
+
--mixed_stage1_ratio 0 \
|
| 227 |
+
--mixed_stage2_ratio 1 \
|
| 228 |
+
--per_device_train_batch_size "${GRPO_PER_DEVICE_BS}" \
|
| 229 |
+
--gradient_accumulation_steps "${GRPO_GRAD_ACCUM}" \
|
| 230 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 231 |
+
--learning_rate 1e-6 \
|
| 232 |
+
--logging_steps 20 \
|
| 233 |
+
--save_steps 200 \
|
| 234 |
+
--eval_steps 500 \
|
| 235 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 236 |
+
--num_generations 4 \
|
| 237 |
+
--max_prompt_length 1024 \
|
| 238 |
+
--max_completion_length 24 \
|
| 239 |
+
--beta "${GRPO_BETA}" \
|
| 240 |
+
--enable_gradient_checkpointing \
|
| 241 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 242 |
+
--reward_good_value 1.25 \
|
| 243 |
+
--penalty_bad_value 1.0 \
|
| 244 |
+
--penalty_malformed 4.0 \
|
| 245 |
+
--penalty_empty 0.5 \
|
| 246 |
+
--penalty_singleton 1.5 \
|
| 247 |
+
--eval_value_precision_stop "${GRPO_VALUE_TARGET}" \
|
| 248 |
+
--eval_value_recall_stop "${GRPO_VALUE_TARGET}" \
|
| 249 |
+
--eval_solve_rate_stop 0 \
|
| 250 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 251 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 252 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 253 |
+
--lora_r "${LORA_R}" \
|
| 254 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 255 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 256 |
+
--use_wandb \
|
| 257 |
+
--wandb_project "sudoku-latent-multi-output-grpo-recurrent" \
|
| 258 |
+
--wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_cot${cot}_val${GRPO_VALUE_TARGET}_${RUN_TAG}" \
|
| 259 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 260 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 264 |
+
printf 'Latent mode: %s (cot grows 1->2->3 per stage)\n' "${LATENT_MODE}"
|
| 265 |
+
printf 'Value gate: SFT prec+recall >= %s ; GRPO prec+recall >= %s (min_steps=%s) ; GRPO_BETA=%s\n' "${SFT_VALUE_TARGET}" "${GRPO_VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}" "${GRPO_BETA}"
|
| 266 |
+
printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
|
| 267 |
+
|
| 268 |
+
S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 269 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 270 |
+
if [[ -n "${STAGE1_GRPO_ADAPTER_DIR}" ]]; then
|
| 271 |
+
A1="${STAGE1_GRPO_ADAPTER_DIR}"
|
| 272 |
+
printf 'Using existing stage-1 GRPO adapter (skipping stage-1 SFT + GRPO): %s\n' "${A1}" >&2
|
| 273 |
+
elif [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 274 |
+
G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
|
| 275 |
+
printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
|
| 276 |
+
run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
|
| 277 |
+
A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
|
| 278 |
+
else
|
| 279 |
+
run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
|
| 280 |
+
G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
|
| 281 |
+
if [[ -z "${G1_SFT_CKPT}" ]]; then
|
| 282 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
|
| 283 |
+
exit 1
|
| 284 |
+
fi
|
| 285 |
+
run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
|
| 286 |
+
A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
|
| 287 |
+
fi
|
| 288 |
+
if [[ -z "${A1}" ]]; then
|
| 289 |
+
printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
|
| 290 |
+
exit 1
|
| 291 |
+
fi
|
| 292 |
+
printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 293 |
+
|
| 294 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 295 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 296 |
+
if [[ -n "${STAGE2_GRPO_ADAPTER_DIR}" ]]; then
|
| 297 |
+
A2="${STAGE2_GRPO_ADAPTER_DIR}"
|
| 298 |
+
printf 'Using existing stage-2 GRPO adapter (skipping stage-2 SFT + GRPO): %s\n' "${A2}" >&2
|
| 299 |
+
elif [[ -n "${STAGE2_SFT_ADAPTER_DIR}" ]]; then
|
| 300 |
+
CKPT_S2="${STAGE2_SFT_ADAPTER_DIR}"
|
| 301 |
+
printf 'Using existing stage-2 SFT checkpoint as GRPO init (skipping stage-2 SFT train): %s\n' "${CKPT_S2}" >&2
|
| 302 |
+
run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
|
| 303 |
+
A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
|
| 304 |
+
else
|
| 305 |
+
run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
|
| 306 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 307 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 308 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 309 |
+
exit 1
|
| 310 |
+
fi
|
| 311 |
+
run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
|
| 312 |
+
A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
|
| 313 |
+
fi
|
| 314 |
+
if [[ -z "${A2}" ]]; then
|
| 315 |
+
printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 316 |
+
exit 1
|
| 317 |
+
fi
|
| 318 |
+
|
| 319 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 320 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 321 |
+
if [[ -n "${STAGE3_SFT_ADAPTER_DIR}" ]]; then
|
| 322 |
+
CKPT_S3="${STAGE3_SFT_ADAPTER_DIR}"
|
| 323 |
+
printf 'Using existing stage-3 SFT checkpoint as GRPO init (skipping stage-3 SFT train): %s\n' "${CKPT_S3}" >&2
|
| 324 |
+
else
|
| 325 |
+
run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
|
| 326 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 327 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 328 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 329 |
+
exit 1
|
| 330 |
+
fi
|
| 331 |
+
fi
|
| 332 |
+
run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
|
| 333 |
+
A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
|
| 334 |
+
if [[ -z "${A3}" ]]; then
|
| 335 |
+
printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 336 |
+
exit 1
|
| 337 |
+
fi
|
| 338 |
+
|
| 339 |
+
printf '\nAll latent(recurrent) phases finished.\n'
|
| 340 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 341 |
+
printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_latent_residual_stages123_value98.sh
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Latent residual-projector pipeline for 20-empty Sudoku.
|
| 3 |
+
# Default behavior:
|
| 4 |
+
# - Stage 1 SFT starts from fresh LoRA + random latent residual modules
|
| 5 |
+
# - Then stage 1 GRPO -> stage 2 SFT -> stage 2 GRPO -> stage 3 SFT -> stage 3 GRPO
|
| 6 |
+
# Optional:
|
| 7 |
+
# STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
|
| 8 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft/checkpoint-step-XXXX
|
| 9 |
+
# VALUE_TARGET=0.98 TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=...
|
| 10 |
+
set -euo pipefail
|
| 11 |
+
|
| 12 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 13 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 14 |
+
|
| 15 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 16 |
+
SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
|
| 17 |
+
GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
|
| 18 |
+
|
| 19 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 20 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 21 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 22 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 23 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 24 |
+
|
| 25 |
+
EMPTIES=20
|
| 26 |
+
TAG_SUFFIX="latent_residual"
|
| 27 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 28 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 29 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 30 |
+
MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
|
| 31 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 32 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 33 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 34 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 35 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 36 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 37 |
+
LORA_R="${LORA_R:-32}"
|
| 38 |
+
LORA_ALPHA="${LORA_ALPHA:-64}"
|
| 39 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 40 |
+
STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
|
| 41 |
+
|
| 42 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_latent_residual_stages123_value98}"
|
| 43 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 44 |
+
STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
|
| 45 |
+
STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
|
| 46 |
+
|
| 47 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 48 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 49 |
+
|
| 50 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 51 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 52 |
+
|
| 53 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 54 |
+
|
| 55 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 56 |
+
printf 'ERROR: Missing train or eval jsonl.\n' >&2
|
| 57 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 58 |
+
exit 1
|
| 59 |
+
fi
|
| 60 |
+
|
| 61 |
+
latest_sft_step_ckpt() {
|
| 62 |
+
local d="$1"
|
| 63 |
+
shopt -s nullglob
|
| 64 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 65 |
+
shopt -u nullglob
|
| 66 |
+
if (( ${#cks[@]} == 0 )); then
|
| 67 |
+
printf ''
|
| 68 |
+
return 1
|
| 69 |
+
fi
|
| 70 |
+
set +o pipefail
|
| 71 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 72 |
+
set -o pipefail
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
resolve_latent_grpo_adapter() {
|
| 76 |
+
local d="$1"
|
| 77 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 78 |
+
printf '%s\n' "${d}"
|
| 79 |
+
return 0
|
| 80 |
+
fi
|
| 81 |
+
local best="" step=-1
|
| 82 |
+
shopt -s nullglob
|
| 83 |
+
local c
|
| 84 |
+
for c in "${d}"/checkpoint-*; do
|
| 85 |
+
[[ -d "${c}" ]] || continue
|
| 86 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 87 |
+
[[ -f "${c}/latent_cot_state.pt" ]] || continue
|
| 88 |
+
local n
|
| 89 |
+
n="${c##*checkpoint-}"
|
| 90 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 91 |
+
step=$((10#${n}))
|
| 92 |
+
best="${c}"
|
| 93 |
+
fi
|
| 94 |
+
done
|
| 95 |
+
shopt -u nullglob
|
| 96 |
+
if [[ -n "${best}" ]]; then
|
| 97 |
+
printf '%s\n' "${best}"
|
| 98 |
+
return 0
|
| 99 |
+
fi
|
| 100 |
+
printf ''
|
| 101 |
+
return 1
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
run_latent_sft() {
|
| 105 |
+
local stage="$1"
|
| 106 |
+
local init_adapter="$2"
|
| 107 |
+
local out_dir="$3"
|
| 108 |
+
local lr="$4"
|
| 109 |
+
local cot="$5"
|
| 110 |
+
local ms1=0 ms2=1
|
| 111 |
+
if [[ "${stage}" == "1" ]]; then
|
| 112 |
+
ms1=1
|
| 113 |
+
ms2=0
|
| 114 |
+
fi
|
| 115 |
+
mkdir -p "${out_dir}"
|
| 116 |
+
printf '\n=== Latent stage %s SFT -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
|
| 117 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2
|
| 118 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 119 |
+
--model_name "${MODEL_NAME}" \
|
| 120 |
+
--train_jsonl "${train_jsonl}" \
|
| 121 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 122 |
+
--output_dir "${out_dir}" \
|
| 123 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 124 |
+
--init_adapter_dir "${init_adapter}" \
|
| 125 |
+
--seed 0 \
|
| 126 |
+
--gpu_id 0 \
|
| 127 |
+
--stage_i "${stage}" \
|
| 128 |
+
--num_cot_tokens "${cot}" \
|
| 129 |
+
--total_empties_hint "${EMPTIES}" \
|
| 130 |
+
--mixed_stage1_ratio "${ms1}" \
|
| 131 |
+
--mixed_stage2_ratio "${ms2}" \
|
| 132 |
+
--gradient_accumulation_steps 2 \
|
| 133 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 134 |
+
--learning_rate "${lr}" \
|
| 135 |
+
--weight_decay 0.0 \
|
| 136 |
+
--enable_gradient_checkpointing \
|
| 137 |
+
--logging_steps 20 \
|
| 138 |
+
--eval_steps 250 \
|
| 139 |
+
--save_steps 200 \
|
| 140 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 141 |
+
--max_completion_length 24 \
|
| 142 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 143 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 144 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 145 |
+
--eval_exact_set_match_stop 0 \
|
| 146 |
+
--eval_solve_rate_stop 0 \
|
| 147 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 148 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 149 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 150 |
+
--reward_good_value 1.25 \
|
| 151 |
+
--penalty_bad_value 1.0 \
|
| 152 |
+
--penalty_malformed 4.0 \
|
| 153 |
+
--penalty_empty 0.5 \
|
| 154 |
+
--penalty_singleton 1.5 \
|
| 155 |
+
--lora_r "${LORA_R}" \
|
| 156 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 157 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 158 |
+
--use_wandb \
|
| 159 |
+
--wandb_project "sudoku-latent-multi-output-sft-residual-projector" \
|
| 160 |
+
--wandb_run_name "latent20_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 161 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 162 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
run_latent_grpo() {
|
| 166 |
+
local stage="$1"
|
| 167 |
+
local init_adapter="$2"
|
| 168 |
+
local out_dir="$3"
|
| 169 |
+
local cot="$4"
|
| 170 |
+
mkdir -p "${out_dir}"
|
| 171 |
+
printf '\n=== Latent stage %s GRPO -> stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
|
| 172 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
|
| 173 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 174 |
+
--model_name "${MODEL_NAME}" \
|
| 175 |
+
--train_jsonl "${train_jsonl}" \
|
| 176 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 177 |
+
--output_dir "${out_dir}" \
|
| 178 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 179 |
+
--init_adapter_dir "${init_adapter}" \
|
| 180 |
+
--seed 0 \
|
| 181 |
+
--gpu_id 0 \
|
| 182 |
+
--stage_i "${stage}" \
|
| 183 |
+
--num_cot_tokens "${cot}" \
|
| 184 |
+
--total_empties_hint "${EMPTIES}" \
|
| 185 |
+
--mixed_stage1_ratio 0 \
|
| 186 |
+
--mixed_stage2_ratio 1 \
|
| 187 |
+
--per_device_train_batch_size 8 \
|
| 188 |
+
--gradient_accumulation_steps 2 \
|
| 189 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 190 |
+
--learning_rate 1e-6 \
|
| 191 |
+
--logging_steps 20 \
|
| 192 |
+
--save_steps 200 \
|
| 193 |
+
--eval_steps 500 \
|
| 194 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 195 |
+
--num_generations 4 \
|
| 196 |
+
--max_prompt_length 1024 \
|
| 197 |
+
--max_completion_length 24 \
|
| 198 |
+
--beta 0.0 \
|
| 199 |
+
--enable_gradient_checkpointing \
|
| 200 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 201 |
+
--reward_good_value 1.25 \
|
| 202 |
+
--penalty_bad_value 1.0 \
|
| 203 |
+
--penalty_malformed 4.0 \
|
| 204 |
+
--penalty_empty 0.5 \
|
| 205 |
+
--penalty_singleton 1.5 \
|
| 206 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 207 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 208 |
+
--eval_solve_rate_stop 0 \
|
| 209 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 210 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 211 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 212 |
+
--lora_r "${LORA_R}" \
|
| 213 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 214 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 215 |
+
--use_wandb \
|
| 216 |
+
--wandb_project "sudoku-latent-multi-output-grpo-residual-projector" \
|
| 217 |
+
--wandb_run_name "latent20_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 218 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 219 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 223 |
+
printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}"
|
| 224 |
+
printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
|
| 225 |
+
|
| 226 |
+
S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 227 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 228 |
+
if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 229 |
+
G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
|
| 230 |
+
printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
|
| 231 |
+
else
|
| 232 |
+
run_latent_sft 1 "${STAGE1_INIT_ADAPTER_DIR}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
|
| 233 |
+
G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
|
| 234 |
+
if [[ -z "${G1_SFT_CKPT}" ]]; then
|
| 235 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
|
| 236 |
+
exit 1
|
| 237 |
+
fi
|
| 238 |
+
fi
|
| 239 |
+
run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
|
| 240 |
+
A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
|
| 241 |
+
if [[ -z "${A1}" ]]; then
|
| 242 |
+
printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
|
| 243 |
+
exit 1
|
| 244 |
+
fi
|
| 245 |
+
printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 246 |
+
|
| 247 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 248 |
+
run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
|
| 249 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 250 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 251 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 252 |
+
exit 1
|
| 253 |
+
fi
|
| 254 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 255 |
+
run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
|
| 256 |
+
A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
|
| 257 |
+
if [[ -z "${A2}" ]]; then
|
| 258 |
+
printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 259 |
+
exit 1
|
| 260 |
+
fi
|
| 261 |
+
|
| 262 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 263 |
+
run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
|
| 264 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 265 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 266 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 267 |
+
exit 1
|
| 268 |
+
fi
|
| 269 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 270 |
+
run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
|
| 271 |
+
A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
|
| 272 |
+
if [[ -z "${A3}" ]]; then
|
| 273 |
+
printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 274 |
+
exit 1
|
| 275 |
+
fi
|
| 276 |
+
|
| 277 |
+
printf '\nAll latent residual phases finished.\n'
|
| 278 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 279 |
+
printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_post_s1sft_stages123_value98.sh
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Run AFTER stage-1 SFT finishes (20-empty). Order:
|
| 3 |
+
# 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
|
| 4 |
+
# 2) Stage-2 SFT (init = stage-1 GRPO adapter)
|
| 5 |
+
# 3) Stage-2 GRPO (init = stage-2 SFT adapter)
|
| 6 |
+
# 4) Stage-3 SFT (init = stage-2 GRPO adapter)
|
| 7 |
+
# 5) Stage-3 GRPO (init = stage-3 SFT adapter)
|
| 8 |
+
#
|
| 9 |
+
# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
|
| 10 |
+
# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
|
| 11 |
+
# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
|
| 12 |
+
# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
|
| 13 |
+
#
|
| 14 |
+
# Required (full pipeline from stage-1 SFT):
|
| 15 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
|
| 16 |
+
#
|
| 17 |
+
# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
|
| 18 |
+
# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_20empty
|
| 19 |
+
# (OUTPUT_ROOT defaults to dirname of that dir.)
|
| 20 |
+
#
|
| 21 |
+
# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
|
| 22 |
+
# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_20empty
|
| 23 |
+
#
|
| 24 |
+
# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
|
| 25 |
+
# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_20empty
|
| 26 |
+
#
|
| 27 |
+
# Optional:
|
| 28 |
+
# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
|
| 29 |
+
# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
|
| 30 |
+
#
|
| 31 |
+
set -euo pipefail
|
| 32 |
+
|
| 33 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 34 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 35 |
+
|
| 36 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 37 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 38 |
+
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
|
| 39 |
+
|
| 40 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 41 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 42 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 43 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 44 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 45 |
+
|
| 46 |
+
EMPTIES=20
|
| 47 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 48 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 49 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 50 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 51 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 52 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 53 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 54 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 55 |
+
|
| 56 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_stages123_value98}"
|
| 57 |
+
START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
|
| 58 |
+
START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
|
| 59 |
+
RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
|
| 60 |
+
|
| 61 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 62 |
+
if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 63 |
+
printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 64 |
+
exit 1
|
| 65 |
+
fi
|
| 66 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
|
| 67 |
+
elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 68 |
+
if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 69 |
+
printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 70 |
+
exit 1
|
| 71 |
+
fi
|
| 72 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
|
| 73 |
+
elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 74 |
+
if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 75 |
+
printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
|
| 76 |
+
exit 1
|
| 77 |
+
fi
|
| 78 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
|
| 79 |
+
else
|
| 80 |
+
if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 81 |
+
printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
|
| 82 |
+
exit 1
|
| 83 |
+
fi
|
| 84 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 85 |
+
fi
|
| 86 |
+
|
| 87 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 88 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 89 |
+
|
| 90 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 91 |
+
|
| 92 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 93 |
+
|
| 94 |
+
latest_sft_step_ckpt() {
|
| 95 |
+
local d="$1"
|
| 96 |
+
shopt -s nullglob
|
| 97 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 98 |
+
shopt -u nullglob
|
| 99 |
+
if (( ${#cks[@]} == 0 )); then
|
| 100 |
+
printf ''
|
| 101 |
+
return 1
|
| 102 |
+
fi
|
| 103 |
+
set +o pipefail
|
| 104 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 105 |
+
set -o pipefail
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
resolve_grpo_adapter() {
|
| 109 |
+
local d="$1"
|
| 110 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 111 |
+
printf '%s\n' "${d}"
|
| 112 |
+
return 0
|
| 113 |
+
fi
|
| 114 |
+
local best="" step=-1
|
| 115 |
+
shopt -s nullglob
|
| 116 |
+
local c
|
| 117 |
+
for c in "${d}"/checkpoint-*; do
|
| 118 |
+
[[ -d "${c}" ]] || continue
|
| 119 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 120 |
+
local n
|
| 121 |
+
n="${c##*checkpoint-}"
|
| 122 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 123 |
+
step=$((10#${n}))
|
| 124 |
+
best="${c}"
|
| 125 |
+
fi
|
| 126 |
+
done
|
| 127 |
+
shopt -u nullglob
|
| 128 |
+
if [[ -n "${best}" ]]; then
|
| 129 |
+
printf '%s\n' "${best}"
|
| 130 |
+
return 0
|
| 131 |
+
fi
|
| 132 |
+
printf ''
|
| 133 |
+
return 1
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
GC_FLAGS=()
|
| 137 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 138 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
run_sft() {
|
| 142 |
+
local stage="$1"
|
| 143 |
+
local init_adapter="$2"
|
| 144 |
+
local out_dir="$3"
|
| 145 |
+
local lr="$4"
|
| 146 |
+
mkdir -p "${out_dir}"
|
| 147 |
+
printf '\n=== Stage %s SFT -> stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
|
| 148 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 149 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 150 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 151 |
+
--train_jsonl "${train_jsonl}" \
|
| 152 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 153 |
+
--output_dir "${out_dir}" \
|
| 154 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 155 |
+
--init_adapter_dir "${init_adapter}" \
|
| 156 |
+
--seed 0 \
|
| 157 |
+
--gpu_id 0 \
|
| 158 |
+
--stage_i "${stage}" \
|
| 159 |
+
--total_empties_hint "${EMPTIES}" \
|
| 160 |
+
--per_device_train_batch_size 16 \
|
| 161 |
+
--gradient_accumulation_steps 2 \
|
| 162 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 163 |
+
--learning_rate "${lr}" \
|
| 164 |
+
--max_grad_norm 1.0 \
|
| 165 |
+
"${GC_FLAGS[@]}" \
|
| 166 |
+
--logging_steps 20 \
|
| 167 |
+
--eval_steps 250 \
|
| 168 |
+
--save_steps 200 \
|
| 169 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 170 |
+
--max_completion_length 24 \
|
| 171 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 172 |
+
--lora_r 32 \
|
| 173 |
+
--lora_alpha 64 \
|
| 174 |
+
--lora_dropout 0.05 \
|
| 175 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 176 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 177 |
+
--eval_exact_set_match_stop 0 \
|
| 178 |
+
--eval_solve_rate_stop 0 \
|
| 179 |
+
--min_steps_before_stop 50 \
|
| 180 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 181 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 182 |
+
--use_wandb \
|
| 183 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 184 |
+
--wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 185 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 186 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
run_grpo() {
|
| 190 |
+
local stage="$1"
|
| 191 |
+
local init_adapter="$2"
|
| 192 |
+
local out_dir="$3"
|
| 193 |
+
mkdir -p "${out_dir}"
|
| 194 |
+
printf '\n=== Stage %s GRPO -> stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2
|
| 195 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 196 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 197 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 198 |
+
--train_jsonl "${train_jsonl}" \
|
| 199 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 200 |
+
--output_dir "${out_dir}" \
|
| 201 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 202 |
+
--init_adapter_dir "${init_adapter}" \
|
| 203 |
+
--seed 0 \
|
| 204 |
+
--gpu_id 0 \
|
| 205 |
+
--stage_i "${stage}" \
|
| 206 |
+
--total_empties_hint "${EMPTIES}" \
|
| 207 |
+
--per_device_train_batch_size 8 \
|
| 208 |
+
--gradient_accumulation_steps 2 \
|
| 209 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 210 |
+
--learning_rate 1e-6 \
|
| 211 |
+
--logging_steps 20 \
|
| 212 |
+
--save_steps 200 \
|
| 213 |
+
--eval_steps 500 \
|
| 214 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 215 |
+
--num_generations 4 \
|
| 216 |
+
--max_prompt_length 1024 \
|
| 217 |
+
--max_completion_length 24 \
|
| 218 |
+
--beta 0.0 \
|
| 219 |
+
--enable_gradient_checkpointing \
|
| 220 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 221 |
+
--lora_r 32 \
|
| 222 |
+
--lora_alpha 64 \
|
| 223 |
+
--lora_dropout 0.05 \
|
| 224 |
+
--reward_good_value 1.25 \
|
| 225 |
+
--penalty_bad_value 1.0 \
|
| 226 |
+
--penalty_malformed 4.0 \
|
| 227 |
+
--penalty_empty 0.5 \
|
| 228 |
+
--penalty_singleton 1.5 \
|
| 229 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 230 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 231 |
+
--eval_solve_rate_stop 0 \
|
| 232 |
+
--min_steps_before_stop 50 \
|
| 233 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 234 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 235 |
+
--use_wandb \
|
| 236 |
+
--wandb_project "sudoku-multi-output-grpo" \
|
| 237 |
+
--wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 238 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 239 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 243 |
+
printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
|
| 244 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 245 |
+
exit 1
|
| 246 |
+
fi
|
| 247 |
+
|
| 248 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 249 |
+
printf 'Fast-forward: stage-2 SFT dir %s -> stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 250 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 251 |
+
S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
|
| 252 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 253 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 254 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 255 |
+
exit 1
|
| 256 |
+
fi
|
| 257 |
+
printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
|
| 258 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 259 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 260 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 261 |
+
if [[ -z "${A2}" ]]; then
|
| 262 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 263 |
+
exit 1
|
| 264 |
+
fi
|
| 265 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 266 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 267 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 268 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 269 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 270 |
+
exit 1
|
| 271 |
+
fi
|
| 272 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 273 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 274 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 275 |
+
if [[ -z "${A3}" ]]; then
|
| 276 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 277 |
+
exit 1
|
| 278 |
+
fi
|
| 279 |
+
printf '\nAll phases finished (started at stage-2 GRPO).\n'
|
| 280 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 281 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 282 |
+
exit 0
|
| 283 |
+
fi
|
| 284 |
+
|
| 285 |
+
if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 286 |
+
printf 'Fast-forward: stage-2 GRPO dir %s -> stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 287 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 288 |
+
A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
|
| 289 |
+
if [[ -z "${A2}" ]]; then
|
| 290 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 291 |
+
exit 1
|
| 292 |
+
fi
|
| 293 |
+
printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
|
| 294 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 295 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 296 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 297 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 298 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 299 |
+
exit 1
|
| 300 |
+
fi
|
| 301 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 302 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 303 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 304 |
+
if [[ -z "${A3}" ]]; then
|
| 305 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 306 |
+
exit 1
|
| 307 |
+
fi
|
| 308 |
+
printf '\nAll phases finished (started after stage-2 GRPO).\n'
|
| 309 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 310 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 311 |
+
exit 0
|
| 312 |
+
fi
|
| 313 |
+
|
| 314 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 315 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 316 |
+
printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
|
| 317 |
+
else
|
| 318 |
+
printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
|
| 319 |
+
fi
|
| 320 |
+
printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
|
| 321 |
+
"${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 322 |
+
|
| 323 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
|
| 324 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 325 |
+
A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
|
| 326 |
+
else
|
| 327 |
+
run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
|
| 328 |
+
A1="$(resolve_grpo_adapter "${G1_DIR}")"
|
| 329 |
+
fi
|
| 330 |
+
if [[ -z "${A1}" ]]; then
|
| 331 |
+
printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
|
| 332 |
+
exit 1
|
| 333 |
+
fi
|
| 334 |
+
printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 335 |
+
|
| 336 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
|
| 337 |
+
run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
|
| 338 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 339 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 340 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 341 |
+
exit 1
|
| 342 |
+
fi
|
| 343 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 344 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 345 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 346 |
+
if [[ -z "${A2}" ]]; then
|
| 347 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 348 |
+
exit 1
|
| 349 |
+
fi
|
| 350 |
+
|
| 351 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 352 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 353 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 354 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 355 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 356 |
+
exit 1
|
| 357 |
+
fi
|
| 358 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 359 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 360 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 361 |
+
if [[ -z "${A3}" ]]; then
|
| 362 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 363 |
+
exit 1
|
| 364 |
+
fi
|
| 365 |
+
|
| 366 |
+
printf '\nAll phases finished.\n'
|
| 367 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 368 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_sft_stage1_98p.sh
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Stage-1 SFT only for 20-empty: train until eval value_precision AND value_recall
|
| 3 |
+
# both reach 0.98 (or max_steps / optional wall clock). Use the resulting
|
| 4 |
+
# checkpoint-step-* directory as STAGE1_SFT_ADAPTER_DIR for
|
| 5 |
+
# launch_20empty_post_s1sft_stages123_value98.sh.
|
| 6 |
+
#
|
| 7 |
+
# Fresh LoRA on base model:
|
| 8 |
+
# ./launch_20empty_sft_stage1_98p.sh
|
| 9 |
+
#
|
| 10 |
+
# Continue from a prior SFT checkpoint:
|
| 11 |
+
# INIT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX ./launch_20empty_sft_stage1_98p.sh
|
| 12 |
+
#
|
| 13 |
+
set -euo pipefail
|
| 14 |
+
|
| 15 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 16 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 17 |
+
|
| 18 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 19 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 20 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 21 |
+
|
| 22 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 23 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 24 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 25 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 26 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 27 |
+
|
| 28 |
+
EMPTIES=20
|
| 29 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 30 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 31 |
+
SFT_TARGET="${SFT_TARGET:-0.98}"
|
| 32 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 33 |
+
MAX_STEPS="${MAX_STEPS:-30000}"
|
| 34 |
+
|
| 35 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_20empty_sft98_stage1}"
|
| 36 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft98}"
|
| 37 |
+
|
| 38 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 39 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 40 |
+
|
| 41 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 42 |
+
|
| 43 |
+
if [[ ! -f "${train_jsonl}" ]]; then
|
| 44 |
+
mkdir -p "$(dirname "${train_jsonl}")"
|
| 45 |
+
printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
|
| 46 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles "${TRAIN_PUZZLES}" --empties "${EMPTIES}" --seed 0
|
| 47 |
+
fi
|
| 48 |
+
if [[ ! -f "${eval_jsonl}" ]]; then
|
| 49 |
+
mkdir -p "$(dirname "${eval_jsonl}")"
|
| 50 |
+
printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
|
| 51 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
|
| 52 |
+
fi
|
| 53 |
+
|
| 54 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 55 |
+
|
| 56 |
+
INIT_FLAGS=()
|
| 57 |
+
if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
|
| 58 |
+
INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
|
| 59 |
+
printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
GC_FLAGS=()
|
| 63 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 64 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 65 |
+
printf 'NOTE: USE_GC=1 - slower, less VRAM.\n'
|
| 66 |
+
fi
|
| 67 |
+
|
| 68 |
+
if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
|
| 69 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 70 |
+
else
|
| 71 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
|
| 72 |
+
fi
|
| 73 |
+
printf 'Output: %s\n' "${OUTPUT_DIR}"
|
| 74 |
+
|
| 75 |
+
exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 76 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 77 |
+
--train_jsonl "${train_jsonl}" \
|
| 78 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 79 |
+
--output_dir "${OUTPUT_DIR}" \
|
| 80 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 81 |
+
"${INIT_FLAGS[@]}" \
|
| 82 |
+
--seed 0 \
|
| 83 |
+
--gpu_id 0 \
|
| 84 |
+
--stage_i 1 \
|
| 85 |
+
--total_empties_hint "${EMPTIES}" \
|
| 86 |
+
--per_device_train_batch_size 16 \
|
| 87 |
+
--gradient_accumulation_steps 2 \
|
| 88 |
+
--num_epochs 64.0 \
|
| 89 |
+
--learning_rate 2e-4 \
|
| 90 |
+
--max_grad_norm 1.0 \
|
| 91 |
+
"${GC_FLAGS[@]}" \
|
| 92 |
+
--logging_steps 20 \
|
| 93 |
+
--eval_steps 250 \
|
| 94 |
+
--save_steps 100 \
|
| 95 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 96 |
+
--max_completion_length 24 \
|
| 97 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 98 |
+
--lora_r 32 \
|
| 99 |
+
--lora_alpha 64 \
|
| 100 |
+
--lora_dropout 0.05 \
|
| 101 |
+
--eval_value_precision_stop "${SFT_TARGET}" \
|
| 102 |
+
--eval_value_recall_stop "${SFT_TARGET}" \
|
| 103 |
+
--eval_exact_set_match_stop 0 \
|
| 104 |
+
--eval_solve_rate_stop 0 \
|
| 105 |
+
--min_steps_before_stop 50 \
|
| 106 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 107 |
+
--max_steps "${MAX_STEPS}" \
|
| 108 |
+
--use_wandb \
|
| 109 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 110 |
+
--wandb_run_name "${WANDB_RUN_NAME:-stage01_sft98_i1_${EMPTIES}empty_${RUN_TAG}}" \
|
| 111 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 112 |
+
--wandb_entity "${WANDB_ENTITY}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_stage1_sft_all_latent_modes_parallel.sh
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Stage-1 SFT sweep over all latent modes for 20-empty Sudoku.
|
| 3 |
+
#
|
| 4 |
+
# Runs four independent SFT jobs in parallel:
|
| 5 |
+
# residual, fixed_slots, recurrent_hidden, latent_seeds
|
| 6 |
+
#
|
| 7 |
+
# Default GPU split on an 8-GPU node:
|
| 8 |
+
# residual -> CUDA_VISIBLE_DEVICES=0,1
|
| 9 |
+
# fixed_slots -> CUDA_VISIBLE_DEVICES=2,3
|
| 10 |
+
# recurrent_hidden -> CUDA_VISIBLE_DEVICES=4,5
|
| 11 |
+
# latent_seeds -> CUDA_VISIBLE_DEVICES=6,7
|
| 12 |
+
#
|
| 13 |
+
# Useful overrides:
|
| 14 |
+
# RUN_TAG=... CHECKPOINT_ROOT=...
|
| 15 |
+
# GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1
|
| 16 |
+
# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 SFT_VALUE_TARGET=0.98
|
| 17 |
+
# STAGE1_INIT_ADAPTER_DIR=/path/to/init_adapter
|
| 18 |
+
set -euo pipefail
|
| 19 |
+
|
| 20 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 21 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 22 |
+
|
| 23 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 24 |
+
SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
|
| 25 |
+
|
| 26 |
+
MODES=("residual" "fixed_slots" "recurrent_hidden" "latent_seeds")
|
| 27 |
+
MODE_TAGS=("latent_residual" "latent_fixed_slots" "latent_recurrent_hidden" "latent_seeds")
|
| 28 |
+
|
| 29 |
+
# Space-separated list of CUDA_VISIBLE_DEVICES groups, one per latent mode.
|
| 30 |
+
# Example for one GPU per method: GPU_GROUPS_SPEC="0 1 2 3" NPROC_PER_JOB=1
|
| 31 |
+
GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}"
|
| 32 |
+
read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}"
|
| 33 |
+
|
| 34 |
+
NPROC_PER_JOB="${NPROC_PER_JOB:-2}"
|
| 35 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 36 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 37 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 38 |
+
|
| 39 |
+
EMPTIES="${EMPTIES:-20}"
|
| 40 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 41 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 42 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 43 |
+
SFT_VALUE_TARGET="${SFT_VALUE_TARGET:-${VALUE_TARGET}}"
|
| 44 |
+
MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
|
| 45 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 46 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 47 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 48 |
+
|
| 49 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 50 |
+
# Match the recurrent 20-empty launcher defaults: -1 resolves inside the
|
| 51 |
+
# trainer to hidden_size, and alpha=-1 resolves to 2 * resolved rank.
|
| 52 |
+
LORA_R="${LORA_R:--1}"
|
| 53 |
+
LORA_ALPHA="${LORA_ALPHA:--1}"
|
| 54 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 55 |
+
STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
|
| 56 |
+
SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
|
| 57 |
+
SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
|
| 58 |
+
NUM_COT_TOKENS="${NUM_COT_TOKENS:-1}"
|
| 59 |
+
MAX_LATENT_SLOTS="${MAX_LATENT_SLOTS:-8}"
|
| 60 |
+
MAX_LATENT_SEEDS="${MAX_LATENT_SEEDS:-8}"
|
| 61 |
+
STAGE1_INIT_ADAPTER_DIR="${STAGE1_INIT_ADAPTER_DIR:-}"
|
| 62 |
+
|
| 63 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_stage1_sft_all_latent_modes}"
|
| 64 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 65 |
+
|
| 66 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 67 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 68 |
+
|
| 69 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 70 |
+
export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
|
| 71 |
+
|
| 72 |
+
if [[ ${#GPU_GROUPS[@]} -ne ${#MODES[@]} ]]; then
|
| 73 |
+
printf 'ERROR: expected %d GPU groups, got %d.\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2
|
| 74 |
+
printf 'Example: GPU_GROUPS_SPEC="0,1 2,3 4,5 6,7"\n' >&2
|
| 75 |
+
exit 1
|
| 76 |
+
fi
|
| 77 |
+
|
| 78 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 79 |
+
printf 'ERROR: Missing train or eval jsonl.\n' >&2
|
| 80 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 81 |
+
exit 1
|
| 82 |
+
fi
|
| 83 |
+
|
| 84 |
+
if [[ -n "${STAGE1_INIT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then
|
| 85 |
+
printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2
|
| 86 |
+
exit 1
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 90 |
+
|
| 91 |
+
run_stage1_sft_for_mode() {
|
| 92 |
+
local mode="$1"
|
| 93 |
+
local tag="$2"
|
| 94 |
+
local gpu_group="$3"
|
| 95 |
+
local out_dir="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${tag}"
|
| 96 |
+
local log_dir="${OUTPUT_ROOT}/logs"
|
| 97 |
+
local log_file="${log_dir}/stage01_sft_${mode}.log"
|
| 98 |
+
|
| 99 |
+
mkdir -p "${out_dir}" "${log_dir}"
|
| 100 |
+
printf '\n=== launching stage-1 SFT: mode=%s gpus=%s out=%s ===\n' "${mode}" "${gpu_group}" "${out_dir}" >&2
|
| 101 |
+
|
| 102 |
+
(
|
| 103 |
+
export CUDA_VISIBLE_DEVICES="${gpu_group}"
|
| 104 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${SFT_SCRIPT}" \
|
| 105 |
+
--model_name "${MODEL_NAME}" \
|
| 106 |
+
--train_jsonl "${train_jsonl}" \
|
| 107 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 108 |
+
--output_dir "${out_dir}" \
|
| 109 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 110 |
+
--init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}" \
|
| 111 |
+
--seed 0 \
|
| 112 |
+
--gpu_id 0 \
|
| 113 |
+
--stage_i 1 \
|
| 114 |
+
--num_cot_tokens "${NUM_COT_TOKENS}" \
|
| 115 |
+
--latent_mode "${mode}" \
|
| 116 |
+
--max_latent_slots "${MAX_LATENT_SLOTS}" \
|
| 117 |
+
--max_latent_seeds "${MAX_LATENT_SEEDS}" \
|
| 118 |
+
--total_empties_hint "${EMPTIES}" \
|
| 119 |
+
--mixed_stage1_ratio 1 \
|
| 120 |
+
--mixed_stage2_ratio 0 \
|
| 121 |
+
--per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \
|
| 122 |
+
--gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \
|
| 123 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 124 |
+
--learning_rate "${STAGE1_SFT_LR}" \
|
| 125 |
+
--weight_decay 0.0 \
|
| 126 |
+
--enable_gradient_checkpointing \
|
| 127 |
+
--logging_steps 20 \
|
| 128 |
+
--eval_steps 250 \
|
| 129 |
+
--save_steps 200 \
|
| 130 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 131 |
+
--max_completion_length 24 \
|
| 132 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 133 |
+
--eval_value_precision_stop "${SFT_VALUE_TARGET}" \
|
| 134 |
+
--eval_value_recall_stop "${SFT_VALUE_TARGET}" \
|
| 135 |
+
--eval_exact_set_match_stop 0 \
|
| 136 |
+
--eval_solve_rate_stop 0 \
|
| 137 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 138 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 139 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 140 |
+
--reward_good_value 1.25 \
|
| 141 |
+
--penalty_bad_value 1.0 \
|
| 142 |
+
--penalty_malformed 4.0 \
|
| 143 |
+
--penalty_empty 0.5 \
|
| 144 |
+
--penalty_singleton 1.5 \
|
| 145 |
+
--lora_r "${LORA_R}" \
|
| 146 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 147 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 148 |
+
--use_wandb \
|
| 149 |
+
--wandb_project "sudoku-latent-stage1-sft-all-modes" \
|
| 150 |
+
--wandb_run_name "latent20_stage1_sft_${mode}_cot${NUM_COT_TOKENS}_val${SFT_VALUE_TARGET}_${RUN_TAG}" \
|
| 151 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 152 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 153 |
+
) >"${log_file}" 2>&1 &
|
| 154 |
+
|
| 155 |
+
printf '%s\n' "$!"
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 159 |
+
printf 'Stage-1 init adapter: %s\n' "${STAGE1_INIT_ADAPTER_DIR:-<fresh-lora-random-latent>}"
|
| 160 |
+
printf 'Modes: %s\n' "${MODES[*]}"
|
| 161 |
+
printf 'GPU groups: %s\n' "${GPU_GROUPS[*]}"
|
| 162 |
+
printf 'Processes per job: %s\n' "${NPROC_PER_JOB}"
|
| 163 |
+
|
| 164 |
+
pids=()
|
| 165 |
+
names=()
|
| 166 |
+
for i in "${!MODES[@]}"; do
|
| 167 |
+
pid="$(run_stage1_sft_for_mode "${MODES[$i]}" "${MODE_TAGS[$i]}" "${GPU_GROUPS[$i]}")"
|
| 168 |
+
pids+=("${pid}")
|
| 169 |
+
names+=("${MODES[$i]}")
|
| 170 |
+
done
|
| 171 |
+
|
| 172 |
+
failed=0
|
| 173 |
+
for i in "${!pids[@]}"; do
|
| 174 |
+
if wait "${pids[$i]}"; then
|
| 175 |
+
printf 'DONE: %s\n' "${names[$i]}"
|
| 176 |
+
else
|
| 177 |
+
printf 'FAILED: %s (pid=%s). See logs under %s/logs\n' "${names[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2
|
| 178 |
+
failed=1
|
| 179 |
+
fi
|
| 180 |
+
done
|
| 181 |
+
|
| 182 |
+
if [[ "${failed}" -ne 0 ]]; then
|
| 183 |
+
exit 1
|
| 184 |
+
fi
|
| 185 |
+
|
| 186 |
+
printf '\nAll stage-1 latent SFT jobs finished.\n'
|
| 187 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
hard_9x9_stage1_consistency_queue/launch_20empty_warm_baseline_all_latent_modes_stages123.sh
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Full 20-empty latent comparison with baseline warm-up before latent stages.
|
| 3 |
+
#
|
| 4 |
+
# Required:
|
| 5 |
+
# STAGE1_BASELINE_ADAPTER_DIR=/path/to/baseline/stage1/checkpoint-step-XXXXX
|
| 6 |
+
#
|
| 7 |
+
# Default mode split on 8 GPUs:
|
| 8 |
+
# residual -> GPUs 0,1
|
| 9 |
+
# fixed_slots -> GPUs 2,3
|
| 10 |
+
# recurrent_hidden -> GPUs 4,5
|
| 11 |
+
# latent_seeds -> GPUs 6,7
|
| 12 |
+
#
|
| 13 |
+
# Per mode:
|
| 14 |
+
# stage1 latent SFT -> stage1 latent GRPO
|
| 15 |
+
# stage2 baseline SFT warm-up -> stage2 latent SFT -> stage2 latent GRPO
|
| 16 |
+
# stage3 baseline SFT warm-up -> stage3 latent SFT -> stage3 latent GRPO
|
| 17 |
+
#
|
| 18 |
+
# Optional resume adapters, intended for single-mode resumes:
|
| 19 |
+
# STAGE1_LATENT_SFT_ADAPTER_DIR=/path/to/stage01_latent_sft_or_checkpoint
|
| 20 |
+
# STAGE1_LATENT_GRPO_ADAPTER_DIR=/path/to/stage01_latent_grpo
|
| 21 |
+
# STAGE2_BASELINE_WARM_ADAPTER_DIR=/path/to/stage02_baseline_warm_sft
|
| 22 |
+
# STAGE2_LATENT_SFT_INIT_ADAPTER_DIR=/path/to/stage02_latent_sft_checkpoint_to_continue_training
|
| 23 |
+
# STAGE2_LATENT_SFT_ADAPTER_DIR=/path/to/stage02_latent_sft_or_checkpoint
|
| 24 |
+
# STAGE2_LATENT_GRPO_ADAPTER_DIR=/path/to/stage02_latent_grpo
|
| 25 |
+
# STAGE3_BASELINE_WARM_ADAPTER_DIR=/path/to/stage03_baseline_warm_sft
|
| 26 |
+
# STAGE3_LATENT_SFT_ADAPTER_DIR=/path/to/stage03_latent_sft_or_checkpoint
|
| 27 |
+
set -euo pipefail
|
| 28 |
+
|
| 29 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 30 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 31 |
+
|
| 32 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 33 |
+
BASELINE_SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 34 |
+
LATENT_SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/sft_latent_multi_output_train.py"
|
| 35 |
+
LATENT_GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_multimode_latent_train.py"
|
| 36 |
+
|
| 37 |
+
EMPTIES="${EMPTIES:-20}"
|
| 38 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-1.5B-Instruct}"
|
| 39 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 40 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 41 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 42 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 43 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 44 |
+
|
| 45 |
+
MODES_SPEC="${MODES_SPEC:-residual fixed_slots recurrent_hidden latent_seeds}"
|
| 46 |
+
GPU_GROUPS_SPEC="${GPU_GROUPS_SPEC:-0,1 2,3 4,5 6,7}"
|
| 47 |
+
NPROC_PER_JOB="${NPROC_PER_JOB:-2}"
|
| 48 |
+
|
| 49 |
+
STAGE1_BASELINE_ADAPTER_DIR="${STAGE1_BASELINE_ADAPTER_DIR:-}"
|
| 50 |
+
STAGE1_LATENT_SFT_ADAPTER_DIR="${STAGE1_LATENT_SFT_ADAPTER_DIR:-}"
|
| 51 |
+
STAGE1_LATENT_GRPO_ADAPTER_DIR="${STAGE1_LATENT_GRPO_ADAPTER_DIR:-}"
|
| 52 |
+
STAGE2_BASELINE_WARM_ADAPTER_DIR="${STAGE2_BASELINE_WARM_ADAPTER_DIR:-}"
|
| 53 |
+
STAGE2_LATENT_SFT_INIT_ADAPTER_DIR="${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR:-}"
|
| 54 |
+
STAGE2_LATENT_SFT_ADAPTER_DIR="${STAGE2_LATENT_SFT_ADAPTER_DIR:-}"
|
| 55 |
+
STAGE2_LATENT_GRPO_ADAPTER_DIR="${STAGE2_LATENT_GRPO_ADAPTER_DIR:-}"
|
| 56 |
+
STAGE3_BASELINE_WARM_ADAPTER_DIR="${STAGE3_BASELINE_WARM_ADAPTER_DIR:-}"
|
| 57 |
+
STAGE3_LATENT_SFT_ADAPTER_DIR="${STAGE3_LATENT_SFT_ADAPTER_DIR:-}"
|
| 58 |
+
if [[ -z "${STAGE1_LATENT_SFT_ADAPTER_DIR}" && -z "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then
|
| 59 |
+
if [[ -z "${STAGE1_BASELINE_ADAPTER_DIR}" ]] || [[ ! -d "${STAGE1_BASELINE_ADAPTER_DIR}" ]]; then
|
| 60 |
+
printf 'ERROR: Set STAGE1_BASELINE_ADAPTER_DIR to a finished baseline SFT checkpoint directory, or provide a stage-1 latent resume adapter.\n' >&2
|
| 61 |
+
exit 1
|
| 62 |
+
fi
|
| 63 |
+
fi
|
| 64 |
+
|
| 65 |
+
SFT_PER_DEVICE_BS="${SFT_PER_DEVICE_BS:-8}"
|
| 66 |
+
SFT_GRAD_ACCUM="${SFT_GRAD_ACCUM:-2}"
|
| 67 |
+
BASELINE_PER_DEVICE_BS="${BASELINE_PER_DEVICE_BS:-16}"
|
| 68 |
+
BASELINE_GRAD_ACCUM="${BASELINE_GRAD_ACCUM:-2}"
|
| 69 |
+
GRPO_PER_DEVICE_BS="${GRPO_PER_DEVICE_BS:-4}"
|
| 70 |
+
GRPO_GRAD_ACCUM="${GRPO_GRAD_ACCUM:-2}"
|
| 71 |
+
|
| 72 |
+
BASELINE_WARM_MAX_STEPS="${BASELINE_WARM_MAX_STEPS:-1000}"
|
| 73 |
+
LATENT_SFT_MAX_STEPS="${LATENT_SFT_MAX_STEPS:-1000}"
|
| 74 |
+
LATENT_GRPO_MAX_STEPS="${LATENT_GRPO_MAX_STEPS:-500}"
|
| 75 |
+
LATENT_SFT_EVAL_STEPS="${LATENT_SFT_EVAL_STEPS:-250}"
|
| 76 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-64}"
|
| 77 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-50}"
|
| 78 |
+
|
| 79 |
+
SOLVE_TARGET="${SOLVE_TARGET:-0.95}"
|
| 80 |
+
VALUE_TARGET="${VALUE_TARGET:-0}"
|
| 81 |
+
MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
|
| 82 |
+
GRPO_BETA="${GRPO_BETA:-0.0}"
|
| 83 |
+
|
| 84 |
+
LORA_R="${LORA_R:-32}"
|
| 85 |
+
LORA_ALPHA="${LORA_ALPHA:-64}"
|
| 86 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 87 |
+
|
| 88 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_${EMPTIES}empty_warm_baseline_all_latent_modes_stages123}"
|
| 89 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 90 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 91 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 92 |
+
|
| 93 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 94 |
+
export CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}"
|
| 95 |
+
|
| 96 |
+
read -r -a MODES <<< "${MODES_SPEC}"
|
| 97 |
+
read -r -a GPU_GROUPS <<< "${GPU_GROUPS_SPEC}"
|
| 98 |
+
if [[ ${#MODES[@]} -ne ${#GPU_GROUPS[@]} ]]; then
|
| 99 |
+
printf 'ERROR: expected one GPU group per mode. modes=%d gpu_groups=%d\n' "${#MODES[@]}" "${#GPU_GROUPS[@]}" >&2
|
| 100 |
+
exit 1
|
| 101 |
+
fi
|
| 102 |
+
|
| 103 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 104 |
+
printf 'ERROR: Missing train or eval jsonl.\n %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 105 |
+
exit 1
|
| 106 |
+
fi
|
| 107 |
+
|
| 108 |
+
mkdir -p "${OUTPUT_ROOT}/logs"
|
| 109 |
+
|
| 110 |
+
mode_tag() {
|
| 111 |
+
case "$1" in
|
| 112 |
+
residual) printf 'latent_residual' ;;
|
| 113 |
+
fixed_slots) printf 'latent_fixed_slots' ;;
|
| 114 |
+
recurrent_hidden) printf 'latent_recurrent_hidden' ;;
|
| 115 |
+
latent_seeds) printf 'latent_seeds' ;;
|
| 116 |
+
*) printf 'latent_%s' "$1" ;;
|
| 117 |
+
esac
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
latest_checkpoint_or_dir() {
|
| 121 |
+
local d="$1"
|
| 122 |
+
shopt -s nullglob
|
| 123 |
+
local checkpoints=("${d}"/checkpoint-step-*)
|
| 124 |
+
shopt -u nullglob
|
| 125 |
+
if (( ${#checkpoints[@]} > 0 )); then
|
| 126 |
+
printf '%s\n' "${checkpoints[@]}" | sort -V | tail -n 1
|
| 127 |
+
return 0
|
| 128 |
+
fi
|
| 129 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 130 |
+
printf '%s\n' "${d}"
|
| 131 |
+
return 0
|
| 132 |
+
fi
|
| 133 |
+
printf ''
|
| 134 |
+
return 1
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
run_baseline_sft() {
|
| 138 |
+
local stage="$1" init_adapter="$2" out_dir="$3" lr="$4" run_name="$5"
|
| 139 |
+
mkdir -p "${out_dir}"
|
| 140 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${BASELINE_SFT_SCRIPT}" \
|
| 141 |
+
--model_name "${MODEL_NAME}" \
|
| 142 |
+
--train_jsonl "${train_jsonl}" \
|
| 143 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 144 |
+
--output_dir "${out_dir}" \
|
| 145 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 146 |
+
--init_adapter_dir "${init_adapter}" \
|
| 147 |
+
--seed 0 \
|
| 148 |
+
--gpu_id 0 \
|
| 149 |
+
--stage_i "${stage}" \
|
| 150 |
+
--total_empties_hint "${EMPTIES}" \
|
| 151 |
+
--per_device_train_batch_size "${BASELINE_PER_DEVICE_BS}" \
|
| 152 |
+
--gradient_accumulation_steps "${BASELINE_GRAD_ACCUM}" \
|
| 153 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 154 |
+
--learning_rate "${lr}" \
|
| 155 |
+
--max_grad_norm 1.0 \
|
| 156 |
+
--logging_steps 20 \
|
| 157 |
+
--eval_steps 250 \
|
| 158 |
+
--save_steps 200 \
|
| 159 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 160 |
+
--max_completion_length 24 \
|
| 161 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 162 |
+
--lora_r "${LORA_R}" \
|
| 163 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 164 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 165 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 166 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 167 |
+
--eval_exact_set_match_stop 0 \
|
| 168 |
+
--eval_solve_rate_stop "${SOLVE_TARGET}" \
|
| 169 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 170 |
+
--max_wall_clock_seconds 0 \
|
| 171 |
+
--max_steps "${BASELINE_WARM_MAX_STEPS}" \
|
| 172 |
+
--use_wandb \
|
| 173 |
+
--wandb_project "sudoku-baseline-stage-warmups" \
|
| 174 |
+
--wandb_run_name "${run_name}" \
|
| 175 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 176 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
run_latent_sft() {
|
| 180 |
+
local mode="$1" stage="$2" cot="$3" init_adapter="$4" out_dir="$5" lr="$6" run_name="$7"
|
| 181 |
+
local ms1=0 ms2=1
|
| 182 |
+
if [[ "${stage}" == "1" ]]; then
|
| 183 |
+
ms1=1
|
| 184 |
+
ms2=0
|
| 185 |
+
fi
|
| 186 |
+
mkdir -p "${out_dir}"
|
| 187 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${LATENT_SFT_SCRIPT}" \
|
| 188 |
+
--model_name "${MODEL_NAME}" \
|
| 189 |
+
--train_jsonl "${train_jsonl}" \
|
| 190 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 191 |
+
--output_dir "${out_dir}" \
|
| 192 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 193 |
+
--init_adapter_dir "${init_adapter}" \
|
| 194 |
+
--seed 0 \
|
| 195 |
+
--gpu_id 0 \
|
| 196 |
+
--stage_i "${stage}" \
|
| 197 |
+
--num_cot_tokens "${cot}" \
|
| 198 |
+
--latent_mode "${mode}" \
|
| 199 |
+
--max_latent_slots 8 \
|
| 200 |
+
--max_latent_seeds 8 \
|
| 201 |
+
--total_empties_hint "${EMPTIES}" \
|
| 202 |
+
--mixed_stage1_ratio "${ms1}" \
|
| 203 |
+
--mixed_stage2_ratio "${ms2}" \
|
| 204 |
+
--per_device_train_batch_size "${SFT_PER_DEVICE_BS}" \
|
| 205 |
+
--gradient_accumulation_steps "${SFT_GRAD_ACCUM}" \
|
| 206 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 207 |
+
--learning_rate "${lr}" \
|
| 208 |
+
--weight_decay 0.0 \
|
| 209 |
+
--enable_gradient_checkpointing \
|
| 210 |
+
--logging_steps 20 \
|
| 211 |
+
--eval_steps "${LATENT_SFT_EVAL_STEPS}" \
|
| 212 |
+
--save_steps 200 \
|
| 213 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 214 |
+
--max_completion_length 24 \
|
| 215 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 216 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 217 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 218 |
+
--eval_exact_set_match_stop 0 \
|
| 219 |
+
--eval_solve_rate_stop "${SOLVE_TARGET}" \
|
| 220 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 221 |
+
--max_wall_clock_seconds 0 \
|
| 222 |
+
--max_steps "${LATENT_SFT_MAX_STEPS}" \
|
| 223 |
+
--reward_good_value 1.25 \
|
| 224 |
+
--penalty_bad_value 1.0 \
|
| 225 |
+
--penalty_malformed 4.0 \
|
| 226 |
+
--penalty_empty 0.5 \
|
| 227 |
+
--penalty_singleton 1.5 \
|
| 228 |
+
--lora_r "${LORA_R}" \
|
| 229 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 230 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 231 |
+
--use_wandb \
|
| 232 |
+
--wandb_project "sudoku-latent-stage-sft-warm-baseline" \
|
| 233 |
+
--wandb_run_name "${run_name}" \
|
| 234 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 235 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
run_latent_grpo() {
|
| 239 |
+
local mode="$1" stage="$2" cot="$3" init_adapter="$4" out_dir="$5" run_name="$6"
|
| 240 |
+
mkdir -p "${out_dir}"
|
| 241 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NPROC_PER_JOB}" "${LATENT_GRPO_SCRIPT}" \
|
| 242 |
+
--model_name "${MODEL_NAME}" \
|
| 243 |
+
--train_jsonl "${train_jsonl}" \
|
| 244 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 245 |
+
--output_dir "${out_dir}" \
|
| 246 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 247 |
+
--init_adapter_dir "${init_adapter}" \
|
| 248 |
+
--seed 0 \
|
| 249 |
+
--gpu_id 0 \
|
| 250 |
+
--stage_i "${stage}" \
|
| 251 |
+
--num_cot_tokens "${cot}" \
|
| 252 |
+
--latent_mode "${mode}" \
|
| 253 |
+
--max_latent_seeds 8 \
|
| 254 |
+
--total_empties_hint "${EMPTIES}" \
|
| 255 |
+
--mixed_stage1_ratio 0 \
|
| 256 |
+
--mixed_stage2_ratio 1 \
|
| 257 |
+
--per_device_train_batch_size "${GRPO_PER_DEVICE_BS}" \
|
| 258 |
+
--gradient_accumulation_steps "${GRPO_GRAD_ACCUM}" \
|
| 259 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 260 |
+
--learning_rate 1e-6 \
|
| 261 |
+
--logging_steps 20 \
|
| 262 |
+
--save_steps 200 \
|
| 263 |
+
--eval_steps 500 \
|
| 264 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 265 |
+
--num_generations 4 \
|
| 266 |
+
--max_prompt_length 1024 \
|
| 267 |
+
--max_completion_length 24 \
|
| 268 |
+
--beta "${GRPO_BETA}" \
|
| 269 |
+
--enable_gradient_checkpointing \
|
| 270 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 271 |
+
--reward_good_value 1.25 \
|
| 272 |
+
--penalty_bad_value 1.0 \
|
| 273 |
+
--penalty_malformed 4.0 \
|
| 274 |
+
--penalty_empty 0.5 \
|
| 275 |
+
--penalty_singleton 1.5 \
|
| 276 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 277 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 278 |
+
--eval_solve_rate_stop "${SOLVE_TARGET}" \
|
| 279 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 280 |
+
--max_wall_clock_seconds 0 \
|
| 281 |
+
--max_steps "${LATENT_GRPO_MAX_STEPS}" \
|
| 282 |
+
--lora_r "${LORA_R}" \
|
| 283 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 284 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 285 |
+
--use_wandb \
|
| 286 |
+
--wandb_project "sudoku-latent-stage-grpo-warm-baseline" \
|
| 287 |
+
--wandb_run_name "${run_name}" \
|
| 288 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 289 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
run_mode_pipeline() {
|
| 293 |
+
local mode="$1" gpu_group="$2" tag
|
| 294 |
+
tag="$(mode_tag "${mode}")"
|
| 295 |
+
local mode_root="${OUTPUT_ROOT}/${tag}"
|
| 296 |
+
local log="${OUTPUT_ROOT}/logs/${tag}.log"
|
| 297 |
+
mkdir -p "${mode_root}"
|
| 298 |
+
export CUDA_VISIBLE_DEVICES="${gpu_group}"
|
| 299 |
+
printf 'Mode %s on GPUs %s\n' "${mode}" "${gpu_group}"
|
| 300 |
+
|
| 301 |
+
local s1_lat="${mode_root}/stage01_latent_sft_i1_${EMPTIES}empty_${tag}"
|
| 302 |
+
local g1="${mode_root}/stage01_latent_grpo_i1_${EMPTIES}empty_${tag}"
|
| 303 |
+
local a_s1_lat a_g1
|
| 304 |
+
if [[ -n "${STAGE1_LATENT_GRPO_ADAPTER_DIR}" ]]; then
|
| 305 |
+
a_g1="$(latest_checkpoint_or_dir "${STAGE1_LATENT_GRPO_ADAPTER_DIR}")"
|
| 306 |
+
printf 'Using existing stage-1 latent GRPO adapter for %s: %s\n' "${mode}" "${a_g1}" | tee -a "${log}"
|
| 307 |
+
else
|
| 308 |
+
if [[ -n "${STAGE1_LATENT_SFT_ADAPTER_DIR}" ]]; then
|
| 309 |
+
a_s1_lat="$(latest_checkpoint_or_dir "${STAGE1_LATENT_SFT_ADAPTER_DIR}")"
|
| 310 |
+
printf 'Using existing stage-1 latent SFT adapter for %s: %s\n' "${mode}" "${a_s1_lat}" | tee -a "${log}"
|
| 311 |
+
else
|
| 312 |
+
run_latent_sft "${mode}" 1 1 "${STAGE1_BASELINE_ADAPTER_DIR}" "${s1_lat}" "2e-4" "warmfull_${mode}_st1_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 313 |
+
a_s1_lat="$(latest_checkpoint_or_dir "${s1_lat}")"
|
| 314 |
+
fi
|
| 315 |
+
run_latent_grpo "${mode}" 1 1 "${a_s1_lat}" "${g1}" "warmfull_${mode}_st1_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 316 |
+
a_g1="$(latest_checkpoint_or_dir "${g1}")"
|
| 317 |
+
fi
|
| 318 |
+
|
| 319 |
+
local b2="${mode_root}/stage02_baseline_warm_sft_i2_${EMPTIES}empty_${tag}"
|
| 320 |
+
local s2_lat="${mode_root}/stage02_latent_sft_i2_${EMPTIES}empty_${tag}"
|
| 321 |
+
local g2="${mode_root}/stage02_latent_grpo_i2_${EMPTIES}empty_${tag}"
|
| 322 |
+
local a_b2 a_s2_lat a_g2
|
| 323 |
+
if [[ -n "${STAGE2_LATENT_GRPO_ADAPTER_DIR}" ]]; then
|
| 324 |
+
a_g2="$(latest_checkpoint_or_dir "${STAGE2_LATENT_GRPO_ADAPTER_DIR}")"
|
| 325 |
+
printf 'Using existing stage-2 latent GRPO adapter for %s: %s\n' "${mode}" "${a_g2}" | tee -a "${log}"
|
| 326 |
+
else
|
| 327 |
+
if [[ -n "${STAGE2_LATENT_SFT_ADAPTER_DIR}" ]]; then
|
| 328 |
+
a_s2_lat="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_ADAPTER_DIR}")"
|
| 329 |
+
printf 'Using existing stage-2 latent SFT adapter for %s: %s\n' "${mode}" "${a_s2_lat}" | tee -a "${log}"
|
| 330 |
+
else
|
| 331 |
+
if [[ -n "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}" ]]; then
|
| 332 |
+
a_b2="$(latest_checkpoint_or_dir "${STAGE2_LATENT_SFT_INIT_ADAPTER_DIR}")"
|
| 333 |
+
printf 'Continuing stage-2 latent SFT for %s from adapter: %s\n' "${mode}" "${a_b2}" | tee -a "${log}"
|
| 334 |
+
elif [[ -n "${STAGE2_BASELINE_WARM_ADAPTER_DIR}" ]]; then
|
| 335 |
+
a_b2="$(latest_checkpoint_or_dir "${STAGE2_BASELINE_WARM_ADAPTER_DIR}")"
|
| 336 |
+
printf 'Using existing stage-2 baseline warm adapter for %s: %s\n' "${mode}" "${a_b2}" | tee -a "${log}"
|
| 337 |
+
else
|
| 338 |
+
run_baseline_sft 2 "${a_g1}" "${b2}" "5e-5" "warmfull_${mode}_st2_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 339 |
+
a_b2="$(latest_checkpoint_or_dir "${b2}")"
|
| 340 |
+
fi
|
| 341 |
+
run_latent_sft "${mode}" 2 2 "${a_b2}" "${s2_lat}" "5e-5" "warmfull_${mode}_st2_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 342 |
+
a_s2_lat="$(latest_checkpoint_or_dir "${s2_lat}")"
|
| 343 |
+
fi
|
| 344 |
+
run_latent_grpo "${mode}" 2 2 "${a_s2_lat}" "${g2}" "warmfull_${mode}_st2_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 345 |
+
a_g2="$(latest_checkpoint_or_dir "${g2}")"
|
| 346 |
+
fi
|
| 347 |
+
|
| 348 |
+
local b3="${mode_root}/stage03_baseline_warm_sft_i3_${EMPTIES}empty_${tag}"
|
| 349 |
+
local s3_lat="${mode_root}/stage03_latent_sft_i3_${EMPTIES}empty_${tag}"
|
| 350 |
+
local g3="${mode_root}/stage03_latent_grpo_i3_${EMPTIES}empty_${tag}"
|
| 351 |
+
local a_b3 a_s3_lat
|
| 352 |
+
if [[ -n "${STAGE3_LATENT_SFT_ADAPTER_DIR}" ]]; then
|
| 353 |
+
a_s3_lat="$(latest_checkpoint_or_dir "${STAGE3_LATENT_SFT_ADAPTER_DIR}")"
|
| 354 |
+
printf 'Using existing stage-3 latent SFT adapter for %s: %s\n' "${mode}" "${a_s3_lat}" | tee -a "${log}"
|
| 355 |
+
else
|
| 356 |
+
if [[ -n "${STAGE3_BASELINE_WARM_ADAPTER_DIR}" ]]; then
|
| 357 |
+
a_b3="$(latest_checkpoint_or_dir "${STAGE3_BASELINE_WARM_ADAPTER_DIR}")"
|
| 358 |
+
printf 'Using existing stage-3 baseline warm adapter for %s: %s\n' "${mode}" "${a_b3}" | tee -a "${log}"
|
| 359 |
+
else
|
| 360 |
+
run_baseline_sft 3 "${a_g2}" "${b3}" "5e-5" "warmfull_${mode}_st3_baseline_warm_sft_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 361 |
+
a_b3="$(latest_checkpoint_or_dir "${b3}")"
|
| 362 |
+
fi
|
| 363 |
+
run_latent_sft "${mode}" 3 3 "${a_b3}" "${s3_lat}" "5e-5" "warmfull_${mode}_st3_latent_sft_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 364 |
+
a_s3_lat="$(latest_checkpoint_or_dir "${s3_lat}")"
|
| 365 |
+
fi
|
| 366 |
+
run_latent_grpo "${mode}" 3 3 "${a_s3_lat}" "${g3}" "warmfull_${mode}_st3_latent_grpo_${RUN_TAG}" 2>&1 | tee -a "${log}"
|
| 367 |
+
|
| 368 |
+
printf 'Mode %s finished. Output: %s\n' "${mode}" "${mode_root}" | tee -a "${log}"
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 372 |
+
printf 'Stage-1 baseline adapter: %s\n' "${STAGE1_BASELINE_ADAPTER_DIR}"
|
| 373 |
+
printf 'Solve target: %s (value target: %s)\n' "${SOLVE_TARGET}" "${VALUE_TARGET}"
|
| 374 |
+
|
| 375 |
+
pids=()
|
| 376 |
+
for i in "${!MODES[@]}"; do
|
| 377 |
+
(
|
| 378 |
+
run_mode_pipeline "${MODES[$i]}" "${GPU_GROUPS[$i]}"
|
| 379 |
+
) >"${OUTPUT_ROOT}/logs/$(mode_tag "${MODES[$i]}").supervisor.log" 2>&1 &
|
| 380 |
+
pids+=("$!")
|
| 381 |
+
printf 'Launched mode=%s pid=%s gpus=%s\n' "${MODES[$i]}" "${pids[-1]}" "${GPU_GROUPS[$i]}"
|
| 382 |
+
done
|
| 383 |
+
|
| 384 |
+
failed=0
|
| 385 |
+
for i in "${!pids[@]}"; do
|
| 386 |
+
if wait "${pids[$i]}"; then
|
| 387 |
+
printf 'DONE: %s\n' "${MODES[$i]}"
|
| 388 |
+
else
|
| 389 |
+
printf 'FAILED: %s (pid=%s). See %s/logs\n' "${MODES[$i]}" "${pids[$i]}" "${OUTPUT_ROOT}" >&2
|
| 390 |
+
failed=1
|
| 391 |
+
fi
|
| 392 |
+
done
|
| 393 |
+
|
| 394 |
+
exit "${failed}"
|
hard_9x9_stage1_consistency_queue/launch_7empty_latent_residual_stages123_value98.sh
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Latent residual projector pipeline (7-empty), aligned with the text
|
| 3 |
+
# launch_7empty_post_s1sft_stages123_value98.sh order and value gate:
|
| 4 |
+
# 1) Stage-1 SFT (default: init = STAGE1_INIT_ADAPTER_DIR or fresh LoRA + random residual)
|
| 5 |
+
# 2) Stage-1 GRPO (init = stage-1 SFT checkpoint-step-* dir, or STAGE1_SFT_ADAPTER_DIR if set)
|
| 6 |
+
# 3) Stage-2 SFT (init = stage-1 GRPO adapter)
|
| 7 |
+
# 4) Stage-2 GRPO
|
| 8 |
+
# 5) Stage-3 SFT
|
| 9 |
+
# 6) Stage-3 GRPO
|
| 10 |
+
#
|
| 11 |
+
# Legacy GRPO-first (skip training stage-1 SFT): STAGE1_GRPO_FIRST=1
|
| 12 |
+
#
|
| 13 |
+
# Latent structure (implemented in latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py):
|
| 14 |
+
# - attach_residual_projector_modules(): adds trainable special_thought_embed, latent_mix_logit,
|
| 15 |
+
# and MLP latent_projector_in/out (hidden→4096→hidden) on the Peft-wrapped model.
|
| 16 |
+
# - build_latent_hidden() / residual_next_token_logits_from_ids(): append num_cot_tokens "latent"
|
| 17 |
+
# virtual tokens, run backbone, take (latent_hidden - base_hidden), project through the MLP,
|
| 18 |
+
# mix with base hidden (sigmoid(latent_mix_logit)), then lm_head logits (with optional fallback).
|
| 19 |
+
# - sample_latent_completion() / GRPO use this path for generation; SFT uses the same via
|
| 20 |
+
# residual_projector_warmstart_sft_latent_multi_output_train.py (latent_residual_completion_ce_loss).
|
| 21 |
+
# - latent_cot_state.pt saves/loads the projector + special_thought_embed + mix logit.
|
| 22 |
+
#
|
| 23 |
+
# Each phase stops when eval value_precision AND value_recall are both >= VALUE_TARGET
|
| 24 |
+
# (default 0.98), after MIN_STEPS_BEFORE_STOP optimizer steps (SFT) / GRPO steps (GRPO).
|
| 25 |
+
# Eval rows come from eval_jsonl (same held-out file as the text pipeline).
|
| 26 |
+
#
|
| 27 |
+
# Stage-1 SFT init (when not using STAGE1_SFT_ADAPTER_DIR or STAGE1_GRPO_FIRST):
|
| 28 |
+
# Default: omit STAGE1_INIT_ADAPTER_DIR → fresh LoRA + random residual (same as trainers --init_adapter_dir "").
|
| 29 |
+
# Optional: STAGE1_INIT_ADAPTER_DIR=/path/to/adapter
|
| 30 |
+
#
|
| 31 |
+
# Skip running stage-1 SFT (you already have a finished SFT checkpoint-step-*):
|
| 32 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/stage01_sft_.../checkpoint-step-XXXX
|
| 33 |
+
# → first trained phase is stage-1 GRPO with that init.
|
| 34 |
+
#
|
| 35 |
+
# Resume:
|
| 36 |
+
# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty_latent_residual
|
| 37 |
+
# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty_latent_residual
|
| 38 |
+
# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty_latent_residual
|
| 39 |
+
#
|
| 40 |
+
# Optional env: VALUE_TARGET, TRAIN_PUZZLES, EVAL_PUZZLES, RUN_TAG, CHECKPOINT_ROOT, GPU_IDS,
|
| 41 |
+
# WANDB_MODE, WANDB_ENTITY, SFT_NUM_EPOCHS, GRPO_NUM_TRAIN_EPOCHS, SFT_MAX_STEPS, GRPO_MAX_STEPS,
|
| 42 |
+
# STAGE1_SFT_LR (default 2e-4), STAGE1_GRPO_FIRST, STAGE1_SFT_ADAPTER_DIR
|
| 43 |
+
#
|
| 44 |
+
set -euo pipefail
|
| 45 |
+
|
| 46 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 47 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 48 |
+
|
| 49 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 50 |
+
SFT_SCRIPT="${ROOT}/latent_multi_output_cell_policy/residual_projector_warmstart_sft_latent_multi_output_train.py"
|
| 51 |
+
GRPO_SCRIPT="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
|
| 52 |
+
|
| 53 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 54 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 55 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 56 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 57 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 58 |
+
|
| 59 |
+
EMPTIES=7
|
| 60 |
+
TAG_SUFFIX="latent_residual"
|
| 61 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 62 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 63 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 64 |
+
MIN_STEPS_BEFORE_STOP="${MIN_STEPS_BEFORE_STOP:-50}"
|
| 65 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 66 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 67 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 68 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 69 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 70 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 71 |
+
# Fresh-LoRA defaults (match text 7-empty SFT scale); override if you use a different init checkpoint.
|
| 72 |
+
LORA_R="${LORA_R:-32}"
|
| 73 |
+
LORA_ALPHA="${LORA_ALPHA:-64}"
|
| 74 |
+
LORA_DROPOUT="${LORA_DROPOUT:-0.05}"
|
| 75 |
+
|
| 76 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_latent_residual_stages123_value98}"
|
| 77 |
+
START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
|
| 78 |
+
START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
|
| 79 |
+
RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
|
| 80 |
+
STAGE1_SFT_ADAPTER_DIR="${STAGE1_SFT_ADAPTER_DIR:-}"
|
| 81 |
+
STAGE1_GRPO_FIRST="${STAGE1_GRPO_FIRST:-0}"
|
| 82 |
+
STAGE1_SFT_LR="${STAGE1_SFT_LR:-2e-4}"
|
| 83 |
+
|
| 84 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 85 |
+
if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 86 |
+
printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 87 |
+
exit 1
|
| 88 |
+
fi
|
| 89 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
|
| 90 |
+
elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 91 |
+
if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 92 |
+
printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 93 |
+
exit 1
|
| 94 |
+
fi
|
| 95 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
|
| 96 |
+
elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 97 |
+
if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 98 |
+
printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
|
| 99 |
+
exit 1
|
| 100 |
+
fi
|
| 101 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
|
| 102 |
+
else
|
| 103 |
+
if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]] && [[ ! -d "${STAGE1_INIT_ADAPTER_DIR}" ]]; then
|
| 104 |
+
printf 'ERROR: STAGE1_INIT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_INIT_ADAPTER_DIR}" >&2
|
| 105 |
+
exit 1
|
| 106 |
+
fi
|
| 107 |
+
if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]] && [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 108 |
+
printf 'ERROR: STAGE1_SFT_ADAPTER_DIR is not a directory: %s\n' "${STAGE1_SFT_ADAPTER_DIR}" >&2
|
| 109 |
+
exit 1
|
| 110 |
+
fi
|
| 111 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 112 |
+
fi
|
| 113 |
+
|
| 114 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 115 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 116 |
+
|
| 117 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 118 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 119 |
+
|
| 120 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 121 |
+
|
| 122 |
+
latest_sft_step_ckpt() {
|
| 123 |
+
local d="$1"
|
| 124 |
+
shopt -s nullglob
|
| 125 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 126 |
+
shopt -u nullglob
|
| 127 |
+
if (( ${#cks[@]} == 0 )); then
|
| 128 |
+
printf ''
|
| 129 |
+
return 1
|
| 130 |
+
fi
|
| 131 |
+
set +o pipefail
|
| 132 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 133 |
+
set -o pipefail
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
resolve_latent_grpo_adapter() {
|
| 137 |
+
local d="$1"
|
| 138 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 139 |
+
printf '%s\n' "${d}"
|
| 140 |
+
return 0
|
| 141 |
+
fi
|
| 142 |
+
local best="" step=-1
|
| 143 |
+
shopt -s nullglob
|
| 144 |
+
local c
|
| 145 |
+
for c in "${d}"/checkpoint-*; do
|
| 146 |
+
[[ -d "${c}" ]] || continue
|
| 147 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 148 |
+
local n
|
| 149 |
+
n="${c##*checkpoint-}"
|
| 150 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 151 |
+
step=$((10#${n}))
|
| 152 |
+
best="${c}"
|
| 153 |
+
fi
|
| 154 |
+
done
|
| 155 |
+
shopt -u nullglob
|
| 156 |
+
if [[ -n "${best}" ]]; then
|
| 157 |
+
printf '%s\n' "${best}"
|
| 158 |
+
return 0
|
| 159 |
+
fi
|
| 160 |
+
printf ''
|
| 161 |
+
return 1
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
run_latent_sft() {
|
| 165 |
+
local stage="$1"
|
| 166 |
+
local init_adapter="$2"
|
| 167 |
+
local out_dir="$3"
|
| 168 |
+
local lr="$4"
|
| 169 |
+
local cot="$5"
|
| 170 |
+
# Stage-1 SFT must weight stage-1 rows only (mixed 1/0). Stages 2–3 use stage-i curriculum (mixed 0/1).
|
| 171 |
+
local ms1=0 ms2=1
|
| 172 |
+
if [[ "${stage}" == "1" ]]; then
|
| 173 |
+
ms1=1
|
| 174 |
+
ms2=0
|
| 175 |
+
fi
|
| 176 |
+
mkdir -p "${out_dir}"
|
| 177 |
+
printf '\n=== Latent stage %s SFT (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
|
| 178 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s mixed_s1/s2=%s/%s\n' "${init_adapter}" "${out_dir}" "${cot}" "${ms1}" "${ms2}" >&2
|
| 179 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 180 |
+
--model_name "${MODEL_NAME}" \
|
| 181 |
+
--train_jsonl "${train_jsonl}" \
|
| 182 |
+
--output_dir "${out_dir}" \
|
| 183 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 184 |
+
--init_adapter_dir "${init_adapter}" \
|
| 185 |
+
--seed 0 \
|
| 186 |
+
--gpu_id 0 \
|
| 187 |
+
--stage_i "${stage}" \
|
| 188 |
+
--num_cot_tokens "${cot}" \
|
| 189 |
+
--total_empties_hint "${EMPTIES}" \
|
| 190 |
+
--mixed_stage1_ratio "${ms1}" \
|
| 191 |
+
--mixed_stage2_ratio "${ms2}" \
|
| 192 |
+
--gradient_accumulation_steps 2 \
|
| 193 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 194 |
+
--learning_rate "${lr}" \
|
| 195 |
+
--weight_decay 0.0 \
|
| 196 |
+
--enable_gradient_checkpointing \
|
| 197 |
+
--logging_steps 20 \
|
| 198 |
+
--eval_steps 250 \
|
| 199 |
+
--save_steps 200 \
|
| 200 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 201 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 202 |
+
--max_completion_length 24 \
|
| 203 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 204 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 205 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 206 |
+
--eval_exact_set_match_stop 0 \
|
| 207 |
+
--eval_solve_rate_stop 0 \
|
| 208 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 209 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 210 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 211 |
+
--reward_good_value 1.25 \
|
| 212 |
+
--penalty_bad_value 1.0 \
|
| 213 |
+
--penalty_malformed 4.0 \
|
| 214 |
+
--penalty_empty 0.5 \
|
| 215 |
+
--penalty_singleton 1.5 \
|
| 216 |
+
--lora_r "${LORA_R}" \
|
| 217 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 218 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 219 |
+
--use_wandb \
|
| 220 |
+
--wandb_project "sudoku-latent-multi-output-sft-residual-projector" \
|
| 221 |
+
--wandb_run_name "latent7_st${stage}_sft_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 222 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 223 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
run_latent_grpo() {
|
| 227 |
+
local stage="$1"
|
| 228 |
+
local init_adapter="$2"
|
| 229 |
+
local out_dir="$3"
|
| 230 |
+
local cot="$4"
|
| 231 |
+
mkdir -p "${out_dir}"
|
| 232 |
+
printf '\n=== Latent stage %s GRPO (residual) → stop value prec+recall >= %s ===\n' "${stage}" "${VALUE_TARGET}" >&2
|
| 233 |
+
printf 'init=%s\nout=%s num_cot_tokens=%s\n' "${init_adapter}" "${out_dir}" "${cot}" >&2
|
| 234 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 235 |
+
--model_name "${MODEL_NAME}" \
|
| 236 |
+
--train_jsonl "${train_jsonl}" \
|
| 237 |
+
--output_dir "${out_dir}" \
|
| 238 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 239 |
+
--init_adapter_dir "${init_adapter}" \
|
| 240 |
+
--seed 0 \
|
| 241 |
+
--gpu_id 0 \
|
| 242 |
+
--stage_i "${stage}" \
|
| 243 |
+
--num_cot_tokens "${cot}" \
|
| 244 |
+
--total_empties_hint "${EMPTIES}" \
|
| 245 |
+
--mixed_stage1_ratio 0 \
|
| 246 |
+
--mixed_stage2_ratio 1 \
|
| 247 |
+
--per_device_train_batch_size 8 \
|
| 248 |
+
--gradient_accumulation_steps 2 \
|
| 249 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 250 |
+
--learning_rate 1e-6 \
|
| 251 |
+
--logging_steps 20 \
|
| 252 |
+
--save_steps 200 \
|
| 253 |
+
--eval_steps 500 \
|
| 254 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 255 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 256 |
+
--num_generations 4 \
|
| 257 |
+
--max_prompt_length 1024 \
|
| 258 |
+
--max_completion_length 24 \
|
| 259 |
+
--beta 0.0 \
|
| 260 |
+
--enable_gradient_checkpointing \
|
| 261 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 262 |
+
--reward_good_value 1.25 \
|
| 263 |
+
--penalty_bad_value 1.0 \
|
| 264 |
+
--penalty_malformed 4.0 \
|
| 265 |
+
--penalty_empty 0.5 \
|
| 266 |
+
--penalty_singleton 1.5 \
|
| 267 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 268 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 269 |
+
--eval_solve_rate_stop 0 \
|
| 270 |
+
--min_steps_before_stop "${MIN_STEPS_BEFORE_STOP}" \
|
| 271 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 272 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 273 |
+
--lora_r "${LORA_R}" \
|
| 274 |
+
--lora_alpha "${LORA_ALPHA}" \
|
| 275 |
+
--lora_dropout "${LORA_DROPOUT}" \
|
| 276 |
+
--use_wandb \
|
| 277 |
+
--wandb_project "sudoku-latent-multi-output-grpo-residual-projector" \
|
| 278 |
+
--wandb_run_name "latent7_st${stage}_grpo_i${stage}_${TAG_SUFFIX}_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 279 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 280 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 284 |
+
printf 'ERROR: Missing train or eval jsonl.\n' >&2
|
| 285 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 286 |
+
exit 1
|
| 287 |
+
fi
|
| 288 |
+
|
| 289 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 290 |
+
printf 'Fast-forward: stage-2 latent SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 291 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 292 |
+
S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
|
| 293 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 294 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 295 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 296 |
+
exit 1
|
| 297 |
+
fi
|
| 298 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 299 |
+
run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
|
| 300 |
+
A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
|
| 301 |
+
if [[ -z "${A2}" ]]; then
|
| 302 |
+
printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 303 |
+
exit 1
|
| 304 |
+
fi
|
| 305 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 306 |
+
run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
|
| 307 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 308 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 309 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 310 |
+
exit 1
|
| 311 |
+
fi
|
| 312 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 313 |
+
run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
|
| 314 |
+
A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
|
| 315 |
+
if [[ -z "${A3}" ]]; then
|
| 316 |
+
printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 317 |
+
exit 1
|
| 318 |
+
fi
|
| 319 |
+
printf '\nAll latent phases finished (started at stage-2 GRPO).\n'
|
| 320 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 321 |
+
printf 'Final latent GRPO adapter: %s\n' "${A3}"
|
| 322 |
+
exit 0
|
| 323 |
+
fi
|
| 324 |
+
|
| 325 |
+
if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 326 |
+
printf 'Fast-forward: stage-2 latent GRPO dir %s → stage-3 SFT + GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 327 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 328 |
+
A2="$(resolve_latent_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
|
| 329 |
+
if [[ -z "${A2}" ]]; then
|
| 330 |
+
printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 331 |
+
exit 1
|
| 332 |
+
fi
|
| 333 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 334 |
+
run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
|
| 335 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 336 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 337 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 338 |
+
exit 1
|
| 339 |
+
fi
|
| 340 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 341 |
+
run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
|
| 342 |
+
A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
|
| 343 |
+
if [[ -z "${A3}" ]]; then
|
| 344 |
+
printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 345 |
+
exit 1
|
| 346 |
+
fi
|
| 347 |
+
printf '\nAll latent phases finished (started after stage-2 GRPO).\n'
|
| 348 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 349 |
+
printf 'Final latent GRPO adapter: %s\n' "${A3}"
|
| 350 |
+
exit 0
|
| 351 |
+
fi
|
| 352 |
+
|
| 353 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 354 |
+
printf 'Value gate: precision AND recall >= %s (min_steps=%s)\n' "${VALUE_TARGET}" "${MIN_STEPS_BEFORE_STOP}"
|
| 355 |
+
|
| 356 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 357 |
+
S1_SFT_DIR="${OUTPUT_ROOT}/stage01_sft_i1_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 358 |
+
STAGE1_INIT="${STAGE1_INIT_ADAPTER_DIR:-}"
|
| 359 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 360 |
+
A1="$(resolve_latent_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
|
| 361 |
+
elif [[ "${STAGE1_GRPO_FIRST}" == "1" ]]; then
|
| 362 |
+
# Legacy: stage-1 GRPO first (fresh LoRA + random residual unless STAGE1_INIT_ADAPTER_DIR set).
|
| 363 |
+
run_latent_grpo 1 "${STAGE1_INIT}" "${G1_DIR}" 1
|
| 364 |
+
A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
|
| 365 |
+
else
|
| 366 |
+
# Default: stage-1 SFT → stage-1 GRPO (matches text post-s1sft pipeline).
|
| 367 |
+
if [[ -n "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 368 |
+
G1_SFT_CKPT="${STAGE1_SFT_ADAPTER_DIR}"
|
| 369 |
+
printf 'Using existing stage-1 SFT checkpoint as GRPO init (skipping stage-1 SFT train): %s\n' "${G1_SFT_CKPT}" >&2
|
| 370 |
+
else
|
| 371 |
+
run_latent_sft 1 "${STAGE1_INIT}" "${S1_SFT_DIR}" "${STAGE1_SFT_LR}" 1
|
| 372 |
+
G1_SFT_CKPT="$(latest_sft_step_ckpt "${S1_SFT_DIR}")"
|
| 373 |
+
if [[ -z "${G1_SFT_CKPT}" ]]; then
|
| 374 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S1_SFT_DIR}" >&2
|
| 375 |
+
exit 1
|
| 376 |
+
fi
|
| 377 |
+
fi
|
| 378 |
+
run_latent_grpo 1 "${G1_SFT_CKPT}" "${G1_DIR}" 1
|
| 379 |
+
A1="$(resolve_latent_grpo_adapter "${G1_DIR}")"
|
| 380 |
+
fi
|
| 381 |
+
if [[ -z "${A1}" ]]; then
|
| 382 |
+
printf 'ERROR: Could not resolve stage-1 latent GRPO adapter.\n' >&2
|
| 383 |
+
exit 1
|
| 384 |
+
fi
|
| 385 |
+
printf 'Stage-1 latent GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 386 |
+
|
| 387 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 388 |
+
run_latent_sft 2 "${A1}" "${S2_DIR}" "5e-5" 2
|
| 389 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 390 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 391 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 392 |
+
exit 1
|
| 393 |
+
fi
|
| 394 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 395 |
+
run_latent_grpo 2 "${CKPT_S2}" "${G2_DIR}" 2
|
| 396 |
+
A2="$(resolve_latent_grpo_adapter "${G2_DIR}")"
|
| 397 |
+
if [[ -z "${A2}" ]]; then
|
| 398 |
+
printf 'ERROR: Could not resolve stage-2 latent GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 399 |
+
exit 1
|
| 400 |
+
fi
|
| 401 |
+
|
| 402 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 403 |
+
run_latent_sft 3 "${A2}" "${S3_DIR}" "5e-5" 3
|
| 404 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 405 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 406 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 407 |
+
exit 1
|
| 408 |
+
fi
|
| 409 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty_${TAG_SUFFIX}"
|
| 410 |
+
run_latent_grpo 3 "${CKPT_S3}" "${G3_DIR}" 3
|
| 411 |
+
A3="$(resolve_latent_grpo_adapter "${G3_DIR}")"
|
| 412 |
+
if [[ -z "${A3}" ]]; then
|
| 413 |
+
printf 'ERROR: Could not resolve stage-3 latent GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 414 |
+
exit 1
|
| 415 |
+
fi
|
| 416 |
+
|
| 417 |
+
printf '\nAll latent residual phases finished.\n'
|
| 418 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 419 |
+
printf 'Final latent GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_7empty_post_s1sft_stages123_value98.sh
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Run AFTER stage-1 SFT finishes (7-empty). Order:
|
| 3 |
+
# 1) Stage-1 GRPO (init = your stage-1 SFT adapter)
|
| 4 |
+
# 2) Stage-2 SFT (init = stage-1 GRPO adapter)
|
| 5 |
+
# 3) Stage-2 GRPO (init = stage-2 SFT adapter)
|
| 6 |
+
# 4) Stage-3 SFT (init = stage-2 GRPO adapter)
|
| 7 |
+
# 5) Stage-3 GRPO (init = stage-3 SFT adapter)
|
| 8 |
+
#
|
| 9 |
+
# Each SFT/GRPO phase stops early only when BOTH eval value_precision AND value_recall
|
| 10 |
+
# are >= VALUE_TARGET (default 0.98). Other metric gates are disabled (0). Defaults use
|
| 11 |
+
# very large max_steps / epochs so in practice you exit on the 0.98 gate, not a low cap
|
| 12 |
+
# (override SFT_MAX_STEPS / GRPO_MAX_STEPS if you want a hard ceiling).
|
| 13 |
+
#
|
| 14 |
+
# Required (full pipeline from stage-1 SFT):
|
| 15 |
+
# STAGE1_SFT_ADAPTER_DIR=/path/to/checkpoint-step-XXXXX
|
| 16 |
+
#
|
| 17 |
+
# Resume after stage-1 GRPO already ran (skip GRPO i=1, start at stage-2 SFT):
|
| 18 |
+
# RESUME_FROM_STAGE1_GRPO_DIR=/path/to/stage01_grpo_i1_7empty
|
| 19 |
+
# (OUTPUT_ROOT defaults to dirname of that dir.)
|
| 20 |
+
#
|
| 21 |
+
# Resume after stage-2 SFT already ran (skip through stage-2 SFT, start at stage-2 GRPO):
|
| 22 |
+
# START_AT_STAGE2_GRPO_DIR=/path/to/stage02_sft_i2_7empty
|
| 23 |
+
#
|
| 24 |
+
# Resume after stage-2 GRPO finished (stage-3 SFT + stage-3 GRPO only):
|
| 25 |
+
# START_AFTER_STAGE2_GRPO_DIR=/path/to/stage02_grpo_i2_7empty
|
| 26 |
+
#
|
| 27 |
+
# Optional:
|
| 28 |
+
# VALUE_TARGET=0.98 SFT_MAX_STEPS=... GRPO_MAX_STEPS=... SFT_NUM_EPOCHS=... GRPO_NUM_TRAIN_EPOCHS=...
|
| 29 |
+
# TRAIN_PUZZLES=10000 EVAL_PUZZLES=100 RUN_TAG=... CHECKPOINT_ROOT=... USE_GC=1 PHASE_WALL_CLOCK_SECONDS=0
|
| 30 |
+
#
|
| 31 |
+
set -euo pipefail
|
| 32 |
+
|
| 33 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 34 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 35 |
+
|
| 36 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 37 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 38 |
+
GRPO_SCRIPT="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
|
| 39 |
+
|
| 40 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 41 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 42 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 43 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 44 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 45 |
+
|
| 46 |
+
EMPTIES=7
|
| 47 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 48 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 49 |
+
VALUE_TARGET="${VALUE_TARGET:-0.98}"
|
| 50 |
+
# Large caps so training is gated by VALUE_TARGET, not an arbitrary step limit.
|
| 51 |
+
SFT_MAX_STEPS="${SFT_MAX_STEPS:-10000000}"
|
| 52 |
+
GRPO_MAX_STEPS="${GRPO_MAX_STEPS:-10000000}"
|
| 53 |
+
SFT_NUM_EPOCHS="${SFT_NUM_EPOCHS:-512}"
|
| 54 |
+
GRPO_NUM_TRAIN_EPOCHS="${GRPO_NUM_TRAIN_EPOCHS:-200}"
|
| 55 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 56 |
+
|
| 57 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_7empty_stages123_value98}"
|
| 58 |
+
START_AT_STAGE2_GRPO_DIR="${START_AT_STAGE2_GRPO_DIR:-}"
|
| 59 |
+
START_AFTER_STAGE2_GRPO_DIR="${START_AFTER_STAGE2_GRPO_DIR:-}"
|
| 60 |
+
RESUME_FROM_STAGE1_GRPO_DIR="${RESUME_FROM_STAGE1_GRPO_DIR:-}"
|
| 61 |
+
|
| 62 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 63 |
+
if [[ ! -d "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 64 |
+
printf 'ERROR: START_AT_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 65 |
+
exit 1
|
| 66 |
+
fi
|
| 67 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AT_STAGE2_GRPO_DIR}")}"
|
| 68 |
+
elif [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 69 |
+
if [[ ! -d "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 70 |
+
printf 'ERROR: START_AFTER_STAGE2_GRPO_DIR is not a directory: %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 71 |
+
exit 1
|
| 72 |
+
fi
|
| 73 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${START_AFTER_STAGE2_GRPO_DIR}")}"
|
| 74 |
+
elif [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 75 |
+
if [[ ! -d "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 76 |
+
printf 'ERROR: RESUME_FROM_STAGE1_GRPO_DIR is not a directory: %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}" >&2
|
| 77 |
+
exit 1
|
| 78 |
+
fi
|
| 79 |
+
# Keep stage02+ next to stage01 (same parent dir). Override with OUTPUT_ROOT=... if needed.
|
| 80 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-$(dirname "${RESUME_FROM_STAGE1_GRPO_DIR}")}"
|
| 81 |
+
else
|
| 82 |
+
if [[ -z "${STAGE1_SFT_ADAPTER_DIR:-}" ]] || [[ ! -d "${STAGE1_SFT_ADAPTER_DIR}" ]]; then
|
| 83 |
+
printf 'ERROR: Set STAGE1_SFT_ADAPTER_DIR to a finished stage-1 SFT checkpoint directory, or RESUME_FROM_STAGE1_GRPO_DIR, START_AT_STAGE2_GRPO_DIR, or START_AFTER_STAGE2_GRPO_DIR.\n' >&2
|
| 84 |
+
exit 1
|
| 85 |
+
fi
|
| 86 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}}"
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 90 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 91 |
+
|
| 92 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 93 |
+
|
| 94 |
+
mkdir -p "${OUTPUT_ROOT}"
|
| 95 |
+
|
| 96 |
+
latest_sft_step_ckpt() {
|
| 97 |
+
local d="$1"
|
| 98 |
+
shopt -s nullglob
|
| 99 |
+
local cks=("${d}"/checkpoint-step-*)
|
| 100 |
+
shopt -u nullglob
|
| 101 |
+
if (( ${#cks[@]} == 0 )); then
|
| 102 |
+
printf ''
|
| 103 |
+
return 1
|
| 104 |
+
fi
|
| 105 |
+
# Avoid set -o pipefail treating a pipeline edge case as failure (was breaking the caller).
|
| 106 |
+
set +o pipefail
|
| 107 |
+
printf '%s\n' "${cks[@]}" | sort -V | tail -n 1
|
| 108 |
+
set -o pipefail
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# GRPO: prefer final root adapter; else latest TRL checkpoint dir with adapter weights.
|
| 112 |
+
resolve_grpo_adapter() {
|
| 113 |
+
local d="$1"
|
| 114 |
+
if [[ -f "${d}/adapter_model.safetensors" ]]; then
|
| 115 |
+
printf '%s\n' "${d}"
|
| 116 |
+
return 0
|
| 117 |
+
fi
|
| 118 |
+
local best="" step=-1
|
| 119 |
+
shopt -s nullglob
|
| 120 |
+
local c
|
| 121 |
+
for c in "${d}"/checkpoint-*; do
|
| 122 |
+
[[ -d "${c}" ]] || continue
|
| 123 |
+
[[ -f "${c}/adapter_model.safetensors" ]] || continue
|
| 124 |
+
local n
|
| 125 |
+
n="${c##*checkpoint-}"
|
| 126 |
+
if [[ "${n}" =~ ^[0-9]+$ ]] && (( 10#${n} >= step )); then
|
| 127 |
+
step=$((10#${n}))
|
| 128 |
+
best="${c}"
|
| 129 |
+
fi
|
| 130 |
+
done
|
| 131 |
+
shopt -u nullglob
|
| 132 |
+
if [[ -n "${best}" ]]; then
|
| 133 |
+
printf '%s\n' "${best}"
|
| 134 |
+
return 0
|
| 135 |
+
fi
|
| 136 |
+
printf ''
|
| 137 |
+
return 1
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
GC_FLAGS=()
|
| 141 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 142 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 143 |
+
fi
|
| 144 |
+
|
| 145 |
+
run_sft() {
|
| 146 |
+
local stage="$1"
|
| 147 |
+
local init_adapter="$2"
|
| 148 |
+
local out_dir="$3"
|
| 149 |
+
local lr="$4"
|
| 150 |
+
mkdir -p "${out_dir}"
|
| 151 |
+
printf '\n=== Stage %s SFT → stop when value prec+recall >= %s (max_steps=%s epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" >&2
|
| 152 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 153 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 154 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 155 |
+
--train_jsonl "${train_jsonl}" \
|
| 156 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 157 |
+
--output_dir "${out_dir}" \
|
| 158 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 159 |
+
--init_adapter_dir "${init_adapter}" \
|
| 160 |
+
--seed 0 \
|
| 161 |
+
--gpu_id 0 \
|
| 162 |
+
--stage_i "${stage}" \
|
| 163 |
+
--total_empties_hint "${EMPTIES}" \
|
| 164 |
+
--per_device_train_batch_size 16 \
|
| 165 |
+
--gradient_accumulation_steps 2 \
|
| 166 |
+
--num_epochs "${SFT_NUM_EPOCHS}" \
|
| 167 |
+
--learning_rate "${lr}" \
|
| 168 |
+
--max_grad_norm 1.0 \
|
| 169 |
+
"${GC_FLAGS[@]}" \
|
| 170 |
+
--logging_steps 20 \
|
| 171 |
+
--eval_steps 250 \
|
| 172 |
+
--save_steps 200 \
|
| 173 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 174 |
+
--max_completion_length 24 \
|
| 175 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 176 |
+
--lora_r 32 \
|
| 177 |
+
--lora_alpha 64 \
|
| 178 |
+
--lora_dropout 0.05 \
|
| 179 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 180 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 181 |
+
--eval_exact_set_match_stop 0 \
|
| 182 |
+
--eval_solve_rate_stop 0 \
|
| 183 |
+
--min_steps_before_stop 50 \
|
| 184 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 185 |
+
--max_steps "${SFT_MAX_STEPS}" \
|
| 186 |
+
--use_wandb \
|
| 187 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 188 |
+
--wandb_run_name "postS1_st${stage}_sft_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 189 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 190 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
run_grpo() {
|
| 194 |
+
local stage="$1"
|
| 195 |
+
local init_adapter="$2"
|
| 196 |
+
local out_dir="$3"
|
| 197 |
+
mkdir -p "${out_dir}"
|
| 198 |
+
printf '\n=== Stage %s GRPO → stop when value prec+recall >= %s (max_steps=%s num_train_epochs=%s) ===\n' "${stage}" "${VALUE_TARGET}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" >&2
|
| 199 |
+
printf 'init=%s\nout=%s\n' "${init_adapter}" "${out_dir}" >&2
|
| 200 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${GRPO_SCRIPT}" \
|
| 201 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 202 |
+
--train_jsonl "${train_jsonl}" \
|
| 203 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 204 |
+
--output_dir "${out_dir}" \
|
| 205 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 206 |
+
--init_adapter_dir "${init_adapter}" \
|
| 207 |
+
--seed 0 \
|
| 208 |
+
--gpu_id 0 \
|
| 209 |
+
--stage_i "${stage}" \
|
| 210 |
+
--total_empties_hint "${EMPTIES}" \
|
| 211 |
+
--per_device_train_batch_size 8 \
|
| 212 |
+
--gradient_accumulation_steps 2 \
|
| 213 |
+
--num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS}" \
|
| 214 |
+
--learning_rate 1e-6 \
|
| 215 |
+
--logging_steps 20 \
|
| 216 |
+
--save_steps 200 \
|
| 217 |
+
--eval_steps 500 \
|
| 218 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 219 |
+
--num_generations 4 \
|
| 220 |
+
--max_prompt_length 1024 \
|
| 221 |
+
--max_completion_length 24 \
|
| 222 |
+
--beta 0.0 \
|
| 223 |
+
--enable_gradient_checkpointing \
|
| 224 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 225 |
+
--reward_good_value 1.25 \
|
| 226 |
+
--penalty_bad_value 1.0 \
|
| 227 |
+
--penalty_malformed 4.0 \
|
| 228 |
+
--penalty_empty 0.5 \
|
| 229 |
+
--penalty_singleton 1.5 \
|
| 230 |
+
--eval_value_precision_stop "${VALUE_TARGET}" \
|
| 231 |
+
--eval_value_recall_stop "${VALUE_TARGET}" \
|
| 232 |
+
--eval_solve_rate_stop 0 \
|
| 233 |
+
--min_steps_before_stop 50 \
|
| 234 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 235 |
+
--max_steps "${GRPO_MAX_STEPS}" \
|
| 236 |
+
--use_wandb \
|
| 237 |
+
--wandb_project "sudoku-multi-output-grpo" \
|
| 238 |
+
--wandb_run_name "postS1_st${stage}_grpo_i${stage}_${EMPTIES}empty_val${VALUE_TARGET}_${RUN_TAG}" \
|
| 239 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 240 |
+
--wandb_entity "${WANDB_ENTITY}"
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
if [[ ! -f "${train_jsonl}" ]] || [[ ! -f "${eval_jsonl}" ]]; then
|
| 244 |
+
printf 'ERROR: Missing train/eval jsonl. Build stage-1 datasets first (see launch_sft_stage1_95p.sh / build_dataset.py).\n' >&2
|
| 245 |
+
printf ' %s\n %s\n' "${train_jsonl}" "${eval_jsonl}" >&2
|
| 246 |
+
exit 1
|
| 247 |
+
fi
|
| 248 |
+
|
| 249 |
+
if [[ -n "${START_AT_STAGE2_GRPO_DIR}" ]]; then
|
| 250 |
+
printf 'Fast-forward: stage-2 SFT dir %s → stage-2 GRPO, then stage 3.\n' "${START_AT_STAGE2_GRPO_DIR}" >&2
|
| 251 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 252 |
+
S2_DIR="${START_AT_STAGE2_GRPO_DIR}"
|
| 253 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 254 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 255 |
+
printf 'ERROR: No checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 256 |
+
exit 1
|
| 257 |
+
fi
|
| 258 |
+
printf 'Using SFT checkpoint: %s\n' "${CKPT_S2}" >&2
|
| 259 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 260 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 261 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 262 |
+
if [[ -z "${A2}" ]]; then
|
| 263 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 264 |
+
exit 1
|
| 265 |
+
fi
|
| 266 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 267 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 268 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 269 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 270 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 271 |
+
exit 1
|
| 272 |
+
fi
|
| 273 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 274 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 275 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 276 |
+
if [[ -z "${A3}" ]]; then
|
| 277 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 278 |
+
exit 1
|
| 279 |
+
fi
|
| 280 |
+
printf '\nAll phases finished (started at stage-2 GRPO).\n'
|
| 281 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 282 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 283 |
+
exit 0
|
| 284 |
+
fi
|
| 285 |
+
|
| 286 |
+
if [[ -n "${START_AFTER_STAGE2_GRPO_DIR}" ]]; then
|
| 287 |
+
printf 'Fast-forward: stage-2 GRPO dir %s → stage-3 SFT + stage-3 GRPO.\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 288 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 289 |
+
A2="$(resolve_grpo_adapter "${START_AFTER_STAGE2_GRPO_DIR}")"
|
| 290 |
+
if [[ -z "${A2}" ]]; then
|
| 291 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${START_AFTER_STAGE2_GRPO_DIR}" >&2
|
| 292 |
+
exit 1
|
| 293 |
+
fi
|
| 294 |
+
printf 'Using stage-2 GRPO adapter: %s\n' "${A2}" >&2
|
| 295 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 296 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 297 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 298 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 299 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 300 |
+
exit 1
|
| 301 |
+
fi
|
| 302 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 303 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 304 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 305 |
+
if [[ -z "${A3}" ]]; then
|
| 306 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 307 |
+
exit 1
|
| 308 |
+
fi
|
| 309 |
+
printf '\nAll phases finished (started after stage-2 GRPO).\n'
|
| 310 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 311 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
| 312 |
+
exit 0
|
| 313 |
+
fi
|
| 314 |
+
|
| 315 |
+
printf 'Pipeline root: %s\n' "${OUTPUT_ROOT}"
|
| 316 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 317 |
+
printf 'Resume: using existing stage-1 GRPO dir %s\n' "${RESUME_FROM_STAGE1_GRPO_DIR}"
|
| 318 |
+
else
|
| 319 |
+
printf 'Stage-1 SFT adapter: %s\n' "${STAGE1_SFT_ADAPTER_DIR}"
|
| 320 |
+
fi
|
| 321 |
+
printf 'Value gate: precision AND recall >= %s | SFT max_steps=%s epochs=%s | GRPO max_steps=%s train_epochs=%s | wall=%s\n' \
|
| 322 |
+
"${VALUE_TARGET}" "${SFT_MAX_STEPS}" "${SFT_NUM_EPOCHS}" "${GRPO_MAX_STEPS}" "${GRPO_NUM_TRAIN_EPOCHS}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 323 |
+
|
| 324 |
+
# --- Stage 1 GRPO (skip if resuming) ---
|
| 325 |
+
G1_DIR="${OUTPUT_ROOT}/stage01_grpo_i1_${EMPTIES}empty"
|
| 326 |
+
if [[ -n "${RESUME_FROM_STAGE1_GRPO_DIR}" ]]; then
|
| 327 |
+
A1="$(resolve_grpo_adapter "${RESUME_FROM_STAGE1_GRPO_DIR}")"
|
| 328 |
+
else
|
| 329 |
+
run_grpo 1 "${STAGE1_SFT_ADAPTER_DIR}" "${G1_DIR}"
|
| 330 |
+
A1="$(resolve_grpo_adapter "${G1_DIR}")"
|
| 331 |
+
fi
|
| 332 |
+
if [[ -z "${A1}" ]]; then
|
| 333 |
+
printf 'ERROR: Could not resolve stage-1 GRPO adapter (resume dir or %s)\n' "${G1_DIR}" >&2
|
| 334 |
+
exit 1
|
| 335 |
+
fi
|
| 336 |
+
printf 'Stage-1 GRPO adapter for stage-2 SFT init: %s\n' "${A1}"
|
| 337 |
+
|
| 338 |
+
# --- Stage 2 SFT + GRPO ---
|
| 339 |
+
S2_DIR="${OUTPUT_ROOT}/stage02_sft_i2_${EMPTIES}empty"
|
| 340 |
+
run_sft 2 "${A1}" "${S2_DIR}" "5e-5"
|
| 341 |
+
CKPT_S2="$(latest_sft_step_ckpt "${S2_DIR}")"
|
| 342 |
+
if [[ -z "${CKPT_S2}" ]]; then
|
| 343 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S2_DIR}" >&2
|
| 344 |
+
exit 1
|
| 345 |
+
fi
|
| 346 |
+
G2_DIR="${OUTPUT_ROOT}/stage02_grpo_i2_${EMPTIES}empty"
|
| 347 |
+
run_grpo 2 "${CKPT_S2}" "${G2_DIR}"
|
| 348 |
+
A2="$(resolve_grpo_adapter "${G2_DIR}")"
|
| 349 |
+
if [[ -z "${A2}" ]]; then
|
| 350 |
+
printf 'ERROR: Could not resolve stage-2 GRPO adapter under %s\n' "${G2_DIR}" >&2
|
| 351 |
+
exit 1
|
| 352 |
+
fi
|
| 353 |
+
|
| 354 |
+
# --- Stage 3 SFT + GRPO ---
|
| 355 |
+
S3_DIR="${OUTPUT_ROOT}/stage03_sft_i3_${EMPTIES}empty"
|
| 356 |
+
run_sft 3 "${A2}" "${S3_DIR}" "5e-5"
|
| 357 |
+
CKPT_S3="$(latest_sft_step_ckpt "${S3_DIR}")"
|
| 358 |
+
if [[ -z "${CKPT_S3}" ]]; then
|
| 359 |
+
printf 'ERROR: No SFT checkpoint-step-* under %s\n' "${S3_DIR}" >&2
|
| 360 |
+
exit 1
|
| 361 |
+
fi
|
| 362 |
+
G3_DIR="${OUTPUT_ROOT}/stage03_grpo_i3_${EMPTIES}empty"
|
| 363 |
+
run_grpo 3 "${CKPT_S3}" "${G3_DIR}"
|
| 364 |
+
A3="$(resolve_grpo_adapter "${G3_DIR}")"
|
| 365 |
+
if [[ -z "${A3}" ]]; then
|
| 366 |
+
printf 'ERROR: Could not resolve stage-3 GRPO adapter under %s\n' "${G3_DIR}" >&2
|
| 367 |
+
exit 1
|
| 368 |
+
fi
|
| 369 |
+
|
| 370 |
+
printf '\nAll phases finished.\n'
|
| 371 |
+
printf 'Outputs under: %s\n' "${OUTPUT_ROOT}"
|
| 372 |
+
printf 'Final GRPO adapter (stage 3): %s\n' "${A3}"
|
hard_9x9_stage1_consistency_queue/launch_sft_stage1_95p.sh
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Stage-1 SFT only (no GRPO): train until eval value_precision AND value_recall both >= SFT_TARGET
|
| 3 |
+
# (default 0.95), or max_steps / optional wall clock. Use this before tuning GRPO.
|
| 4 |
+
#
|
| 5 |
+
# Fresh LoRA on base model:
|
| 6 |
+
# ./launch_sft_stage1_95p.sh
|
| 7 |
+
# EMPTIES=10 ./launch_sft_stage1_95p.sh
|
| 8 |
+
#
|
| 9 |
+
# Continue from a prior SFT checkpoint (recommended after a 90% run plateaus):
|
| 10 |
+
# INIT_ADAPTER_DIR=/path/to/checkpoint-step-01200 ./launch_sft_stage1_95p.sh
|
| 11 |
+
#
|
| 12 |
+
# Optional W&B display name:
|
| 13 |
+
# WANDB_RUN_NAME=my_run_name SFT_TARGET=0.99 MAX_STEPS=30000 INIT_ADAPTER_DIR=... ./launch_sft_stage1_95p.sh
|
| 14 |
+
#
|
| 15 |
+
set -euo pipefail
|
| 16 |
+
|
| 17 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 18 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 19 |
+
|
| 20 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 21 |
+
DATASET_BUILDER="${ROOT}/simple_9x9_curriculum/build_dataset.py"
|
| 22 |
+
SFT_SCRIPT="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 23 |
+
|
| 24 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 25 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 26 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 27 |
+
WANDB_MODE="${WANDB_MODE:-online}"
|
| 28 |
+
WANDB_ENTITY="${WANDB_ENTITY:-training-dynamics}"
|
| 29 |
+
|
| 30 |
+
EMPTIES="${EMPTIES:-7}"
|
| 31 |
+
TRAIN_PUZZLES="${TRAIN_PUZZLES:-10000}"
|
| 32 |
+
# Held-out eval: 100 rows lines up with the baseline queue and stabilizes 95% gate vs tiny eval.
|
| 33 |
+
EVAL_PUZZLES="${EVAL_PUZZLES:-100}"
|
| 34 |
+
SFT_TARGET="${SFT_TARGET:-0.95}"
|
| 35 |
+
PHASE_WALL_CLOCK_SECONDS="${PHASE_WALL_CLOCK_SECONDS:-0}"
|
| 36 |
+
MAX_STEPS="${MAX_STEPS:-12000}"
|
| 37 |
+
|
| 38 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/hard_9x9_sft95_stage1}"
|
| 39 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${CHECKPOINT_ROOT}/${RUN_TAG}/${EMPTIES}empty/stage01_sft_i1_${EMPTIES}empty_sft95}"
|
| 40 |
+
|
| 41 |
+
train_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_train.jsonl"
|
| 42 |
+
eval_jsonl="${ROOT}/data/sudoku_t3_${EMPTIES}empty_value_qwen_text_stage1_eval.jsonl"
|
| 43 |
+
|
| 44 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 45 |
+
|
| 46 |
+
if [[ ! -f "${train_jsonl}" ]]; then
|
| 47 |
+
mkdir -p "$(dirname "${train_jsonl}")"
|
| 48 |
+
printf 'Building %s-empty train dataset: %s\n' "${EMPTIES}" "${train_jsonl}"
|
| 49 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${train_jsonl}" --num_puzzles 10000 --empties "${EMPTIES}" --seed 0
|
| 50 |
+
fi
|
| 51 |
+
if [[ ! -f "${eval_jsonl}" ]]; then
|
| 52 |
+
mkdir -p "$(dirname "${eval_jsonl}")"
|
| 53 |
+
printf 'Building %s-empty eval dataset: %s\n' "${EMPTIES}" "${eval_jsonl}"
|
| 54 |
+
"${PYTHON_BIN}" "${DATASET_BUILDER}" --output "${eval_jsonl}" --num_puzzles "${EVAL_PUZZLES}" --empties "${EMPTIES}" --seed 1
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 58 |
+
|
| 59 |
+
INIT_FLAGS=()
|
| 60 |
+
if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
|
| 61 |
+
INIT_FLAGS+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
|
| 62 |
+
printf 'Warm-start from adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 63 |
+
fi
|
| 64 |
+
|
| 65 |
+
GC_FLAGS=()
|
| 66 |
+
if [[ "${USE_GC:-0}" == "1" ]]; then
|
| 67 |
+
GC_FLAGS+=(--enable_gradient_checkpointing)
|
| 68 |
+
printf 'NOTE: USE_GC=1 — slower, less VRAM.\n'
|
| 69 |
+
fi
|
| 70 |
+
|
| 71 |
+
if [[ "${PHASE_WALL_CLOCK_SECONDS}" -gt 0 ]]; then
|
| 72 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, wall %ss) ===\n' "${EMPTIES}" "${SFT_TARGET}" "${PHASE_WALL_CLOCK_SECONDS}"
|
| 73 |
+
else
|
| 74 |
+
printf '\n=== Stage1 SFT %s-empty (prec+recall >= %s, no wall cap) ===\n' "${EMPTIES}" "${SFT_TARGET}"
|
| 75 |
+
fi
|
| 76 |
+
printf 'Output: %s\n' "${OUTPUT_DIR}"
|
| 77 |
+
|
| 78 |
+
exec "${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${SFT_SCRIPT}" \
|
| 79 |
+
--model_name "Qwen/Qwen2.5-0.5B-Instruct" \
|
| 80 |
+
--train_jsonl "${train_jsonl}" \
|
| 81 |
+
--eval_jsonl "${eval_jsonl}" \
|
| 82 |
+
--output_dir "${OUTPUT_DIR}" \
|
| 83 |
+
--cache_dir "${ROOT}/.hf_cache" \
|
| 84 |
+
"${INIT_FLAGS[@]}" \
|
| 85 |
+
--seed 0 \
|
| 86 |
+
--gpu_id 0 \
|
| 87 |
+
--stage_i 1 \
|
| 88 |
+
--total_empties_hint "${EMPTIES}" \
|
| 89 |
+
--per_device_train_batch_size 16 \
|
| 90 |
+
--gradient_accumulation_steps 2 \
|
| 91 |
+
--num_epochs 24.0 \
|
| 92 |
+
--learning_rate 2e-4 \
|
| 93 |
+
--max_grad_norm 1.0 \
|
| 94 |
+
"${GC_FLAGS[@]}" \
|
| 95 |
+
--logging_steps 20 \
|
| 96 |
+
--eval_steps 250 \
|
| 97 |
+
--save_steps 100 \
|
| 98 |
+
--eval_rows "${EVAL_PUZZLES}" \
|
| 99 |
+
--max_completion_length 24 \
|
| 100 |
+
--limit_train_rows "${TRAIN_PUZZLES}" \
|
| 101 |
+
--lora_r 32 \
|
| 102 |
+
--lora_alpha 64 \
|
| 103 |
+
--lora_dropout 0.05 \
|
| 104 |
+
--eval_value_precision_stop "${SFT_TARGET}" \
|
| 105 |
+
--eval_value_recall_stop "${SFT_TARGET}" \
|
| 106 |
+
--min_steps_before_stop 50 \
|
| 107 |
+
--max_wall_clock_seconds "${PHASE_WALL_CLOCK_SECONDS}" \
|
| 108 |
+
--max_steps "${MAX_STEPS}" \
|
| 109 |
+
--use_wandb \
|
| 110 |
+
--wandb_project "sudoku-multi-output-sft" \
|
| 111 |
+
--wandb_run_name "${WANDB_RUN_NAME:-stage01_sft95_i1_${EMPTIES}empty_${RUN_TAG}}" \
|
| 112 |
+
--wandb_mode "${WANDB_MODE}" \
|
| 113 |
+
--wandb_entity "${WANDB_ENTITY}"
|
hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2_resume_summary_20260516.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Recurrent Hidden Stage 2 Resume Summary
|
| 2 |
+
|
| 3 |
+
This note summarizes the May 16, 2026 stage-2 recurrent-hidden SFT recovery and monitoring changes.
|
| 4 |
+
|
| 5 |
+
## What Changed
|
| 6 |
+
|
| 7 |
+
- Added explicit eval lifecycle logging in `latent_multi_output_cell_policy/sft_latent_multi_output_train.py`.
|
| 8 |
+
- W&B now receives `eval/in_progress`, `eval/rows`, and `eval/duration_seconds`.
|
| 9 |
+
- Local logs now print `[latent sft eval start ...]` and `[latent sft eval end ...]` markers.
|
| 10 |
+
|
| 11 |
+
## Why
|
| 12 |
+
|
| 13 |
+
The previous resumed stage-2 run reached step 2000 but appeared silent during validation. The validation metrics only logged after the whole eval completed, and the old `eval_rows=100` setting made a single validation take roughly 35 minutes. The run then crashed before producing the step-2000 eval metrics or checkpoint.
|
| 14 |
+
|
| 15 |
+
## Probe Result
|
| 16 |
+
|
| 17 |
+
A one-GPU eval probe from `checkpoint-step-01800` measured validation cost:
|
| 18 |
+
|
| 19 |
+
- Eval rows: 20 puzzles
|
| 20 |
+
- Eval duration: 427.3 seconds, about 7.1 minutes
|
| 21 |
+
- Exact set match: 0.9225
|
| 22 |
+
- Value precision: 0.945
|
| 23 |
+
- Value recall: 0.934
|
| 24 |
+
- Solve rate: 0.15
|
| 25 |
+
- W&B run: `xudqbjqh`
|
| 26 |
+
|
| 27 |
+
## Active Resume Run
|
| 28 |
+
|
| 29 |
+
The main run was restarted from:
|
| 30 |
+
|
| 31 |
+
`final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_from200_20260515_205857/latent_recurrent_hidden/stage02_latent_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01800`
|
| 32 |
+
|
| 33 |
+
Run settings:
|
| 34 |
+
|
| 35 |
+
- Stage: 2
|
| 36 |
+
- Latent mode: recurrent_hidden
|
| 37 |
+
- GPUs: 8
|
| 38 |
+
- Eval rows: 20
|
| 39 |
+
- Eval interval: every 100 steps
|
| 40 |
+
- Checkpoint interval: every 100 steps
|
| 41 |
+
- Max steps: 5000
|
| 42 |
+
- Early stop: disabled for solve rate; precision and recall target set to 0.9999
|
| 43 |
+
- W&B run: `h3lxi62v`
|
| 44 |
+
|
| 45 |
+
At the first eval:
|
| 46 |
+
|
| 47 |
+
- Step: 100
|
| 48 |
+
- Eval duration: 427.3 seconds
|
| 49 |
+
- Exact set match: 0.935
|
| 50 |
+
- Value precision: 0.95875
|
| 51 |
+
- Value recall: 0.94875
|
| 52 |
+
- Solve rate: 0.25
|
| 53 |
+
|
| 54 |
+
## Checkpoint Sync
|
| 55 |
+
|
| 56 |
+
The run output is periodically synced to Hugging Face every 10 minutes:
|
| 57 |
+
|
| 58 |
+
`Avra98/sudoku-latent-recurrent-hidden-20empty-stages/resume_runs/recurrent_hidden_resume_stage2sft_from1800_eval20_long_20260516_090446`
|
| 59 |
+
|
| 60 |
+
Confirmed uploaded checkpoint:
|
| 61 |
+
|
| 62 |
+
- `checkpoint-step-00100/adapter_model.safetensors`
|
| 63 |
+
- `checkpoint-step-00100/adapter_config.json`
|
| 64 |
+
- `checkpoint-step-00100/tokenizer.json`
|
| 65 |
+
- `checkpoint-step-00100/tokenizer_config.json`
|
| 66 |
+
- `checkpoint-step-00100/chat_template.jinja`
|
| 67 |
+
- `checkpoint-step-00100/README.md`
|
| 68 |
+
|
hard_9x9_stage1_consistency_queue/recurrent_hidden_stage2sft_resume.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Recurrent-Hidden 20-Empty Stage-2 SFT Resume
|
| 2 |
+
|
| 3 |
+
This note records the recovered recurrent-hidden run restarted on May 15, 2026.
|
| 4 |
+
|
| 5 |
+
## Source Checkpoints
|
| 6 |
+
|
| 7 |
+
Recovered adapters were downloaded from:
|
| 8 |
+
|
| 9 |
+
```text
|
| 10 |
+
https://huggingface.co/Avra98/sudoku-latent-recurrent-hidden-20empty-stages
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
Local snapshot path:
|
| 14 |
+
|
| 15 |
+
```text
|
| 16 |
+
/home/ubuntu/curriculum_cot/final_checkpoint/hf_sudoku_latent_recurrent_hidden_20empty_stages
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
Available recovered folders:
|
| 20 |
+
|
| 21 |
+
```text
|
| 22 |
+
stage01_latent_sft_i1_20empty_latent_recurrent_hidden
|
| 23 |
+
stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
|
| 24 |
+
stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden
|
| 25 |
+
stage02_latent_sft_i2_20empty_latent_recurrent_hidden
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
The uploaded stage-2 latent SFT checkpoint did not include `trainer_state.json`
|
| 29 |
+
or solve-rate metadata, so the restart intentionally resumes from the stage-2
|
| 30 |
+
baseline warm-up adapter and reruns stage-2 latent SFT instead of jumping to
|
| 31 |
+
stage-2 GRPO.
|
| 32 |
+
|
| 33 |
+
## Active Resume Run
|
| 34 |
+
|
| 35 |
+
Output root:
|
| 36 |
+
|
| 37 |
+
```text
|
| 38 |
+
/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
W&B run:
|
| 42 |
+
|
| 43 |
+
```text
|
| 44 |
+
https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1vyq1a1n
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Launch settings:
|
| 48 |
+
|
| 49 |
+
```text
|
| 50 |
+
MODEL_NAME=Qwen/Qwen2.5-1.5B-Instruct
|
| 51 |
+
MODES_SPEC=recurrent_hidden
|
| 52 |
+
GPU_GROUPS_SPEC=0,1,2,3,4,5,6,7
|
| 53 |
+
NPROC_PER_JOB=8
|
| 54 |
+
STAGE1_LATENT_GRPO_ADAPTER_DIR=<HF snapshot>/stage01_latent_grpo_i1_20empty_latent_recurrent_hidden
|
| 55 |
+
STAGE2_BASELINE_WARM_ADAPTER_DIR=<HF snapshot>/stage02_baseline_warm_sft_i2_20empty_latent_recurrent_hidden/checkpoint-step-01000
|
| 56 |
+
LATENT_SFT_MAX_STEPS=5000
|
| 57 |
+
LATENT_GRPO_MAX_STEPS=500
|
| 58 |
+
SOLVE_TARGET=0.95
|
| 59 |
+
VALUE_TARGET=0
|
| 60 |
+
MIN_STEPS_BEFORE_STOP=50
|
| 61 |
+
WANDB_MODE=online
|
| 62 |
+
WANDB_ENTITY=training-dynamics
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Backup Plan
|
| 66 |
+
|
| 67 |
+
Code changes are pushed to GitHub branch:
|
| 68 |
+
|
| 69 |
+
```text
|
| 70 |
+
llm-policy-icon-code
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
Checkpoint backups should be pushed periodically to the same Hugging Face repo
|
| 74 |
+
using:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
HF_TOKEN=hf_xxx \
|
| 78 |
+
RUN_OUTPUT_DIR=/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_warm_baseline_all_latent_modes_stages123/recurrent_hidden_resume_stage2sft_20260515_184858 \
|
| 79 |
+
bash hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
The sync script uploads checkpoint folders, adapter files, tokenizer files, and
|
| 83 |
+
logs while ignoring W&B runtime directories and prepared-data caches.
|
hard_9x9_stage1_consistency_queue/sync_recurrent_hidden_checkpoints_to_hf.sh
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Periodically upload the active recurrent-hidden resume output to Hugging Face.
|
| 3 |
+
#
|
| 4 |
+
# Required:
|
| 5 |
+
# RUN_OUTPUT_DIR=/path/to/recurrent_hidden_resume_stage2sft_...
|
| 6 |
+
#
|
| 7 |
+
# Optional:
|
| 8 |
+
# HF_TOKEN=hf_... # otherwise uses `hf auth login` / cached login
|
| 9 |
+
# HF_REPO_ID=Avra98/sudoku-latent-recurrent-hidden-20empty-stages
|
| 10 |
+
# HF_REPO_PREFIX=resume_runs/<run_name>
|
| 11 |
+
# SYNC_INTERVAL_SECONDS=900
|
| 12 |
+
|
| 13 |
+
set -euo pipefail
|
| 14 |
+
|
| 15 |
+
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
| 16 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 17 |
+
HF_REPO_ID="${HF_REPO_ID:-Avra98/sudoku-latent-recurrent-hidden-20empty-stages}"
|
| 18 |
+
RUN_OUTPUT_DIR="${RUN_OUTPUT_DIR:-}"
|
| 19 |
+
SYNC_INTERVAL_SECONDS="${SYNC_INTERVAL_SECONDS:-900}"
|
| 20 |
+
|
| 21 |
+
if [[ -z "${RUN_OUTPUT_DIR}" ]] || [[ ! -d "${RUN_OUTPUT_DIR}" ]]; then
|
| 22 |
+
printf 'ERROR: Set RUN_OUTPUT_DIR to an existing run output directory.\n' >&2
|
| 23 |
+
exit 1
|
| 24 |
+
fi
|
| 25 |
+
|
| 26 |
+
if [[ ! -x "${PYTHON_BIN}" ]]; then
|
| 27 |
+
printf 'ERROR: Python not found at %s\n' "${PYTHON_BIN}" >&2
|
| 28 |
+
exit 1
|
| 29 |
+
fi
|
| 30 |
+
|
| 31 |
+
RUN_NAME="$(basename "${RUN_OUTPUT_DIR}")"
|
| 32 |
+
HF_REPO_PREFIX="${HF_REPO_PREFIX:-resume_runs/${RUN_NAME}}"
|
| 33 |
+
export HF_REPO_ID RUN_OUTPUT_DIR HF_REPO_PREFIX
|
| 34 |
+
|
| 35 |
+
upload_once() {
|
| 36 |
+
"${PYTHON_BIN}" - <<'PY'
|
| 37 |
+
import os
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from huggingface_hub import HfApi, get_token
|
| 40 |
+
|
| 41 |
+
repo_id = os.environ["HF_REPO_ID"]
|
| 42 |
+
folder = Path(os.environ["RUN_OUTPUT_DIR"]).resolve()
|
| 43 |
+
path_in_repo = os.environ["HF_REPO_PREFIX"].strip("/")
|
| 44 |
+
|
| 45 |
+
token = os.environ.get("HF_TOKEN") or get_token()
|
| 46 |
+
if not token:
|
| 47 |
+
raise SystemExit("No Hugging Face token found. Run `hf auth login` or set HF_TOKEN.")
|
| 48 |
+
|
| 49 |
+
api = HfApi(token=token)
|
| 50 |
+
api.upload_folder(
|
| 51 |
+
repo_id=repo_id,
|
| 52 |
+
repo_type="model",
|
| 53 |
+
folder_path=str(folder),
|
| 54 |
+
path_in_repo=path_in_repo,
|
| 55 |
+
commit_message=f"Sync recurrent-hidden resume checkpoints: {folder.name}",
|
| 56 |
+
allow_patterns=[
|
| 57 |
+
"logs/**",
|
| 58 |
+
"**/checkpoint*/**",
|
| 59 |
+
"**/adapter_config.json",
|
| 60 |
+
"**/adapter_model.safetensors",
|
| 61 |
+
"**/tokenizer.json",
|
| 62 |
+
"**/tokenizer_config.json",
|
| 63 |
+
"**/chat_template.jinja",
|
| 64 |
+
"**/README.md",
|
| 65 |
+
"**/training_args.bin",
|
| 66 |
+
],
|
| 67 |
+
ignore_patterns=[
|
| 68 |
+
"**/wandb_runtime/**",
|
| 69 |
+
"**/.wandb/**",
|
| 70 |
+
"**/wandb/**",
|
| 71 |
+
"**/optimizer.pt",
|
| 72 |
+
"**/scheduler.pt",
|
| 73 |
+
"**/rng_state_*.pth",
|
| 74 |
+
],
|
| 75 |
+
)
|
| 76 |
+
print(f"Uploaded {folder} to {repo_id}/{path_in_repo}")
|
| 77 |
+
PY
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
while true; do
|
| 81 |
+
date -Is
|
| 82 |
+
upload_once
|
| 83 |
+
sleep "${SYNC_INTERVAL_SECONDS}"
|
| 84 |
+
done
|
hard_9x9_stage1_consistency_queue/warm_baseline_all_latent_modes_stages123_results.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Warm Baseline All-Latent Stages 1-3 Results
|
| 2 |
+
|
| 3 |
+
Run tag: `warmbaseline_alllatent_stages123_20260512_1620`
|
| 4 |
+
|
| 5 |
+
Base model: `Qwen/Qwen2.5-1.5B-Instruct`
|
| 6 |
+
|
| 7 |
+
Stage-1 warm baseline adapter:
|
| 8 |
+
|
| 9 |
+
```text
|
| 10 |
+
/home/ubuntu/curriculum_cot/final_checkpoint/hard_9x9_20empty_baseline_1p5b_warmup/baseline_1p5b_warmup_bs32_eval100_20260512_203845/20empty/stage01_sft_i1_20empty_1p5b_warmup/checkpoint-step-01000
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
This file records the solve-rate snapshot from the ongoing full pipeline. Later
|
| 14 |
+
stages should be updated when all modes finish.
|
| 15 |
+
|
| 16 |
+
## Current Phase Snapshot
|
| 17 |
+
|
| 18 |
+
| Mode | Current phase at snapshot |
|
| 19 |
+
| --- | --- |
|
| 20 |
+
| `residual` | Stage-2 latent SFT |
|
| 21 |
+
| `fixed_slots` | Stage-2 latent SFT |
|
| 22 |
+
| `recurrent_hidden` | Stage-2 baseline warm-up SFT |
|
| 23 |
+
| `latent_seeds` | Stage-3 baseline warm-up SFT |
|
| 24 |
+
|
| 25 |
+
## Latest Solve Rates By Phase
|
| 26 |
+
|
| 27 |
+
| Mode | Stage 1 latent SFT | Stage 1 latent GRPO | Stage 2 baseline warm-up | Stage 2 latent SFT | Stage 2 latent GRPO | Stage 3 baseline warm-up |
|
| 28 |
+
| --- | ---: | ---: | ---: | ---: | ---: | ---: |
|
| 29 |
+
| `residual` | 0.470 latest / 0.610 best | 0.620 | 0.110 latest / 0.150 best | in progress | not reached | not reached |
|
| 30 |
+
| `fixed_slots` | 0.770 latest / 0.770 best | 0.870 | 0.140 latest / 0.140 best | 0.100 latest / 0.100 best | not reached | not reached |
|
| 31 |
+
| `recurrent_hidden` | 0.860 latest / 0.860 best | 0.950 | 0.110 latest / 0.110 best | not reached | not reached | not reached |
|
| 32 |
+
| `latent_seeds` | 0.740 latest / 0.740 best | 0.860 | 0.090 latest / 0.100 best | 0.120 latest / 0.120 best | 0.090 | started, no eval yet |
|
| 33 |
+
|
| 34 |
+
## Stage 1 Solve Trajectories
|
| 35 |
+
|
| 36 |
+
| Mode | Latent SFT solve rates | Post-GRPO solve rate |
|
| 37 |
+
| --- | --- | ---: |
|
| 38 |
+
| `residual` | 0.320 -> 0.610 -> 0.520 -> 0.470 | 0.620 |
|
| 39 |
+
| `fixed_slots` | 0.650 -> 0.200 -> 0.660 -> 0.770 | 0.870 |
|
| 40 |
+
| `recurrent_hidden` | 0.400 -> 0.600 -> 0.800 -> 0.860 | 0.950 |
|
| 41 |
+
| `latent_seeds` | 0.290 -> 0.500 -> 0.640 -> 0.740 | 0.860 |
|
| 42 |
+
|
| 43 |
+
## Stage 2 Solve Trajectories So Far
|
| 44 |
+
|
| 45 |
+
| Mode | Baseline warm-up solve rates | Latent SFT solve rates | Post-GRPO solve rate |
|
| 46 |
+
| --- | --- | --- | ---: |
|
| 47 |
+
| `residual` | 0.050 -> 0.150 -> 0.110 -> 0.110 | in progress | not reached |
|
| 48 |
+
| `fixed_slots` | 0.090 -> 0.120 -> 0.080 -> 0.140 | 0.080 -> 0.100 | not reached |
|
| 49 |
+
| `recurrent_hidden` | 0.060 -> 0.090 -> 0.100 -> 0.110 | not reached | not reached |
|
| 50 |
+
| `latent_seeds` | 0.090 -> 0.100 -> 0.080 -> 0.090 | 0.080 -> 0.090 -> 0.110 -> 0.120 | 0.090 |
|
| 51 |
+
|
| 52 |
+
## W&B Links
|
| 53 |
+
|
| 54 |
+
Stage 1 latent SFT:
|
| 55 |
+
|
| 56 |
+
- `residual`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/sp4seb59
|
| 57 |
+
- `fixed_slots`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/d62aiu1g
|
| 58 |
+
- `recurrent_hidden`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/cv3nr7ie
|
| 59 |
+
- `latent_seeds`: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline/runs/1f818jfg
|
| 60 |
+
|
| 61 |
+
Additional stage runs are logged under:
|
| 62 |
+
|
| 63 |
+
- SFT project: https://wandb.ai/training-dynamics/sudoku-latent-stage-sft-warm-baseline
|
| 64 |
+
- GRPO project: https://wandb.ai/training-dynamics/sudoku-latent-stage-grpo-warm-baseline
|
| 65 |
+
- Baseline warm-up project: https://wandb.ai/training-dynamics/sudoku-baseline-stage-warmups
|
large_baseline_extension/README.md
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Large Baseline Extension Launchers
|
| 2 |
+
|
| 3 |
+
This folder contains launch scripts for the non-location baseline multi-output runs.
|
| 4 |
+
|
| 5 |
+
- `launch_nonlocation_pipeline.sh`
|
| 6 |
+
- `launch_nonlocation_sft.sh`
|
| 7 |
+
- `launch_nonlocation_grpo.sh`
|
| 8 |
+
|
| 9 |
+
The main entry point for a full staged resume run is `launch_nonlocation_pipeline.sh`.
|
| 10 |
+
|
| 11 |
+
Useful environment variables:
|
| 12 |
+
|
| 13 |
+
- `MIN_STAGE`
|
| 14 |
+
- `MAX_STAGE`
|
| 15 |
+
- `NUM_PROCESSES`
|
| 16 |
+
- `GPU_IDS`
|
| 17 |
+
- `BOOTSTRAP_ADAPTER_DIR`
|
| 18 |
+
- `OUTPUT_ROOT`
|
| 19 |
+
- `RUN_TAG`
|
| 20 |
+
- `LIMIT_TRAIN_ROWS`
|
| 21 |
+
- `WANDB_MODE`
|
| 22 |
+
- `WANDB_ENTITY`
|
| 23 |
+
|
| 24 |
+
Example:
|
| 25 |
+
|
| 26 |
+
```bash
|
| 27 |
+
MIN_STAGE=3 \
|
| 28 |
+
MAX_STAGE=5 \
|
| 29 |
+
NUM_PROCESSES=8 \
|
| 30 |
+
GPU_IDS=0,1,2,3,4,5,6,7 \
|
| 31 |
+
BOOTSTRAP_ADAPTER_DIR=/path/to/stage02_grpo \
|
| 32 |
+
WANDB_MODE=online \
|
| 33 |
+
WANDB_ENTITY=training-dynamics \
|
| 34 |
+
bash launch_nonlocation_pipeline.sh
|
| 35 |
+
```
|
large_baseline_extension/launch_nonlocation_grpo.sh
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 7 |
+
TRAINER="${ROOT}/multi_output_cell_policy/grpo_multi_output_train.py"
|
| 8 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
|
| 9 |
+
CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
|
| 10 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 11 |
+
GPU_ID="${GPU_ID:-0}"
|
| 12 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 13 |
+
NUM_PROCESSES="${NUM_PROCESSES:-1}"
|
| 14 |
+
STAGE_I="${STAGE_I:-2}"
|
| 15 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 16 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo}"
|
| 17 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}"
|
| 18 |
+
WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-grpo}"
|
| 19 |
+
WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_grpo_i${STAGE_I}_${RUN_TAG}}"
|
| 20 |
+
WANDB_GROUP="${WANDB_GROUP:-large_baseline_extension_noloc_grpo_i${STAGE_I}}"
|
| 21 |
+
|
| 22 |
+
case "${STAGE_I}" in
|
| 23 |
+
2) default_bs=4; default_gas=2 ;;
|
| 24 |
+
3) default_bs=3; default_gas=2 ;;
|
| 25 |
+
4) default_bs=2; default_gas=4 ;;
|
| 26 |
+
5) default_bs=2; default_gas=4 ;;
|
| 27 |
+
*) default_bs=2; default_gas=4 ;;
|
| 28 |
+
esac
|
| 29 |
+
|
| 30 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 31 |
+
export CUDA_DEVICE_ORDER=PCI_BUS_ID
|
| 32 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 33 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 34 |
+
else
|
| 35 |
+
export CUDA_VISIBLE_DEVICES="${GPU_ID}"
|
| 36 |
+
fi
|
| 37 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 38 |
+
|
| 39 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 40 |
+
cmd=(
|
| 41 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
|
| 42 |
+
)
|
| 43 |
+
else
|
| 44 |
+
cmd=(
|
| 45 |
+
"${PYTHON_BIN}" -u "${TRAINER}"
|
| 46 |
+
)
|
| 47 |
+
fi
|
| 48 |
+
|
| 49 |
+
cmd+=(
|
| 50 |
+
--model_name "${MODEL_NAME}"
|
| 51 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 52 |
+
--output_dir "${OUTPUT_DIR}"
|
| 53 |
+
--init_adapter_dir "${INIT_ADAPTER_DIR:?INIT_ADAPTER_DIR must be set}"
|
| 54 |
+
--cache_dir "${CACHE_DIR}"
|
| 55 |
+
--gpu_id 0
|
| 56 |
+
--stage_i "${STAGE_I}"
|
| 57 |
+
--total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
|
| 58 |
+
--per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}"
|
| 59 |
+
--gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}"
|
| 60 |
+
--num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}"
|
| 61 |
+
--learning_rate "${LEARNING_RATE:-1e-6}"
|
| 62 |
+
--logging_steps "${LOGGING_STEPS:-5}"
|
| 63 |
+
--save_steps "${SAVE_STEPS:-25}"
|
| 64 |
+
--eval_steps "${EVAL_STEPS:-25}"
|
| 65 |
+
--eval_rows "${EVAL_ROWS:-20}"
|
| 66 |
+
--num_generations "${NUM_GENERATIONS:-2}"
|
| 67 |
+
--max_prompt_length "${MAX_PROMPT_LENGTH:-1024}"
|
| 68 |
+
--max_completion_length "${MAX_COMPLETION_LENGTH:-24}"
|
| 69 |
+
--beta "${BETA:-0.0}"
|
| 70 |
+
--enable_gradient_checkpointing
|
| 71 |
+
--wandb_project "${WANDB_PROJECT}"
|
| 72 |
+
--wandb_run_name "${WANDB_RUN_NAME}"
|
| 73 |
+
--wandb_group "${WANDB_GROUP}"
|
| 74 |
+
--wandb_mode "${WANDB_MODE:-offline}"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
|
| 78 |
+
cmd+=(--use_wandb)
|
| 79 |
+
fi
|
| 80 |
+
|
| 81 |
+
if [[ -n "${WANDB_ENTITY:-}" ]]; then
|
| 82 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 83 |
+
fi
|
| 84 |
+
|
| 85 |
+
if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
|
| 86 |
+
cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
|
| 87 |
+
fi
|
| 88 |
+
|
| 89 |
+
if [[ -n "${MAX_STEPS:-}" ]]; then
|
| 90 |
+
cmd+=(--max_steps "${MAX_STEPS}")
|
| 91 |
+
fi
|
| 92 |
+
|
| 93 |
+
if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then
|
| 94 |
+
cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}")
|
| 95 |
+
fi
|
| 96 |
+
|
| 97 |
+
printf 'Launching hard 9x9 baseline GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
|
| 98 |
+
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
|
| 99 |
+
printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 100 |
+
printf 'stage_i=%s batch=%s grad_accum=%s processes=%s\n' \
|
| 101 |
+
"${STAGE_I}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${NUM_PROCESSES}"
|
| 102 |
+
|
| 103 |
+
"${cmd[@]}"
|
large_baseline_extension/launch_nonlocation_pipeline.sh
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 7 |
+
PIPELINE="${ROOT}/multi_output_cell_policy/run_baseline_multi_output_pipeline_resume.py"
|
| 8 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
|
| 9 |
+
CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
|
| 10 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 11 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 12 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 13 |
+
MIN_STAGE="${MIN_STAGE:-1}"
|
| 14 |
+
MAX_STAGE="${MAX_STAGE:-4}"
|
| 15 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 16 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline}"
|
| 17 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/baseline_pipeline_30empty_4stage_hard9x9}"
|
| 18 |
+
|
| 19 |
+
cmd=(
|
| 20 |
+
"${PYTHON_BIN}" "${PIPELINE}"
|
| 21 |
+
--python_executable "${PYTHON_BIN}"
|
| 22 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 23 |
+
--cache_dir "${CACHE_DIR}"
|
| 24 |
+
--model_name "${MODEL_NAME}"
|
| 25 |
+
--checkpoint_root "${CHECKPOINT_ROOT}"
|
| 26 |
+
--output_root "${OUTPUT_ROOT}"
|
| 27 |
+
--run_tag "${RUN_TAG}"
|
| 28 |
+
--min_stage "${MIN_STAGE}"
|
| 29 |
+
--max_stage "${MAX_STAGE}"
|
| 30 |
+
--distributed_gpu_ids "${GPU_IDS}"
|
| 31 |
+
--sft_num_processes "${NUM_PROCESSES}"
|
| 32 |
+
--grpo_num_processes "${NUM_PROCESSES}"
|
| 33 |
+
--total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
|
| 34 |
+
--sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}"
|
| 35 |
+
--grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}"
|
| 36 |
+
--sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}"
|
| 37 |
+
--grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}"
|
| 38 |
+
--grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}"
|
| 39 |
+
--grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}"
|
| 40 |
+
--grpo_eval_solve_rate_stop "${GRPO_EVAL_SOLVE_RATE_STOP:-0.8}"
|
| 41 |
+
--grpo_min_steps_before_stop "${GRPO_MIN_STEPS_BEFORE_STOP:-100}"
|
| 42 |
+
--sft_enable_gradient_checkpointing
|
| 43 |
+
--grpo_enable_gradient_checkpointing
|
| 44 |
+
--sft_save_steps "${SFT_SAVE_STEPS:-100}"
|
| 45 |
+
--sft_eval_steps "${SFT_EVAL_STEPS:-100}"
|
| 46 |
+
--grpo_save_steps "${GRPO_SAVE_STEPS:-25}"
|
| 47 |
+
--grpo_eval_steps "${GRPO_EVAL_STEPS:-25}"
|
| 48 |
+
--phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}"
|
| 49 |
+
--wandb_mode "${WANDB_MODE:-offline}"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then
|
| 53 |
+
cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}")
|
| 54 |
+
fi
|
| 55 |
+
|
| 56 |
+
if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
|
| 57 |
+
cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then
|
| 61 |
+
cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}")
|
| 62 |
+
fi
|
| 63 |
+
|
| 64 |
+
if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then
|
| 65 |
+
cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}")
|
| 66 |
+
fi
|
| 67 |
+
|
| 68 |
+
if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
|
| 69 |
+
cmd+=(--use_wandb)
|
| 70 |
+
fi
|
| 71 |
+
|
| 72 |
+
if [[ -n "${WANDB_ENTITY:-}" ]]; then
|
| 73 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 74 |
+
fi
|
| 75 |
+
|
| 76 |
+
printf 'Launching hard 9x9 baseline pipeline on GPUs %s\n' "${GPU_IDS}"
|
| 77 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 78 |
+
printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}"
|
| 79 |
+
|
| 80 |
+
"${cmd[@]}"
|
large_baseline_extension/launch_nonlocation_sft.sh
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 7 |
+
TRAINER="${ROOT}/multi_output_cell_policy/sft_multi_output_train.py"
|
| 8 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
|
| 9 |
+
CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
|
| 10 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 11 |
+
GPU_ID="${GPU_ID:-0}"
|
| 12 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 13 |
+
NUM_PROCESSES="${NUM_PROCESSES:-1}"
|
| 14 |
+
STAGE_I="${STAGE_I:-2}"
|
| 15 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 16 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft}"
|
| 17 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_${RUN_TAG}}"
|
| 18 |
+
WANDB_PROJECT="${WANDB_PROJECT:-sudoku-multi-output-sft}"
|
| 19 |
+
WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_baseline_noloc_sft_i${STAGE_I}_${RUN_TAG}}"
|
| 20 |
+
|
| 21 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 22 |
+
export CUDA_DEVICE_ORDER=PCI_BUS_ID
|
| 23 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 24 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 25 |
+
else
|
| 26 |
+
export CUDA_VISIBLE_DEVICES="${GPU_ID}"
|
| 27 |
+
fi
|
| 28 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 29 |
+
|
| 30 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 31 |
+
cmd=(
|
| 32 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
|
| 33 |
+
)
|
| 34 |
+
else
|
| 35 |
+
cmd=(
|
| 36 |
+
"${PYTHON_BIN}" -u "${TRAINER}"
|
| 37 |
+
)
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
cmd+=(
|
| 41 |
+
--model_name "${MODEL_NAME}"
|
| 42 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 43 |
+
--output_dir "${OUTPUT_DIR}"
|
| 44 |
+
--cache_dir "${CACHE_DIR}"
|
| 45 |
+
--gpu_id 0
|
| 46 |
+
--stage_i "${STAGE_I}"
|
| 47 |
+
--total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
|
| 48 |
+
--gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-4}"
|
| 49 |
+
--num_epochs "${NUM_EPOCHS:-1.0}"
|
| 50 |
+
--learning_rate "${LEARNING_RATE:-2e-4}"
|
| 51 |
+
--weight_decay "${WEIGHT_DECAY:-0.0}"
|
| 52 |
+
--enable_gradient_checkpointing
|
| 53 |
+
--logging_steps "${LOGGING_STEPS:-10}"
|
| 54 |
+
--save_steps "${SAVE_STEPS:-100}"
|
| 55 |
+
--eval_steps "${EVAL_STEPS:-100}"
|
| 56 |
+
--eval_rows "${EVAL_ROWS:-20}"
|
| 57 |
+
--max_completion_length "${MAX_COMPLETION_LENGTH:-24}"
|
| 58 |
+
--wandb_project "${WANDB_PROJECT}"
|
| 59 |
+
--wandb_run_name "${WANDB_RUN_NAME}"
|
| 60 |
+
--wandb_mode "${WANDB_MODE:-offline}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
if [[ -n "${INIT_ADAPTER_DIR:-}" ]]; then
|
| 64 |
+
cmd+=(--init_adapter_dir "${INIT_ADAPTER_DIR}")
|
| 65 |
+
fi
|
| 66 |
+
|
| 67 |
+
if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
|
| 68 |
+
cmd+=(--use_wandb)
|
| 69 |
+
fi
|
| 70 |
+
|
| 71 |
+
if [[ -n "${WANDB_ENTITY:-}" ]]; then
|
| 72 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 73 |
+
fi
|
| 74 |
+
|
| 75 |
+
if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
|
| 76 |
+
cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
|
| 77 |
+
fi
|
| 78 |
+
|
| 79 |
+
if [[ -n "${MAX_STEPS:-}" ]]; then
|
| 80 |
+
cmd+=(--max_steps "${MAX_STEPS}")
|
| 81 |
+
fi
|
| 82 |
+
|
| 83 |
+
printf 'Launching hard 9x9 baseline SFT on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
|
| 84 |
+
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
|
| 85 |
+
printf 'Stage=%s processes=%s\n' "${STAGE_I}" "${NUM_PROCESSES}"
|
| 86 |
+
|
| 87 |
+
"${cmd[@]}"
|
large_latent_extension/README.md
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Large Latent Extension Launchers
|
| 2 |
+
|
| 3 |
+
This folder contains the launch scripts for the non-location latent CoT runs.
|
| 4 |
+
|
| 5 |
+
- `launch_nonlocation_sft.sh`
|
| 6 |
+
- `launch_nonlocation_grpo.sh`
|
| 7 |
+
|
| 8 |
+
These are the scripts used for the distributed multi-GPU non-location curriculum.
|
| 9 |
+
|
| 10 |
+
Useful environment variables:
|
| 11 |
+
|
| 12 |
+
- `NUM_COT_TOKENS`
|
| 13 |
+
- `STAGE_I`
|
| 14 |
+
- `NUM_PROCESSES`
|
| 15 |
+
- `GPU_IDS`
|
| 16 |
+
- `INIT_ADAPTER_DIR`
|
| 17 |
+
- `OUTPUT_DIR`
|
| 18 |
+
- `LIMIT_TRAIN_ROWS`
|
| 19 |
+
- `WANDB_MODE`
|
| 20 |
+
- `WANDB_ENTITY`
|
| 21 |
+
|
| 22 |
+
Example:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
NUM_COT_TOKENS=3 \
|
| 26 |
+
STAGE_I=3 \
|
| 27 |
+
NUM_PROCESSES=8 \
|
| 28 |
+
GPU_IDS=0,1,2,3,4,5,6,7 \
|
| 29 |
+
WANDB_MODE=online \
|
| 30 |
+
WANDB_ENTITY=training-dynamics \
|
| 31 |
+
bash launch_nonlocation_sft.sh
|
| 32 |
+
```
|
large_latent_extension/launch_nonlocation_grpo.sh
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 7 |
+
TRAINER="${ROOT}/latent_multi_output_cell_policy/grpo_residual_projector_latent_train.py"
|
| 8 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
|
| 9 |
+
CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
|
| 10 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 11 |
+
GPU_ID="${GPU_ID:-0}"
|
| 12 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 13 |
+
NUM_PROCESSES="${NUM_PROCESSES:-1}"
|
| 14 |
+
NUM_COT_TOKENS="${NUM_COT_TOKENS:?NUM_COT_TOKENS must be set}"
|
| 15 |
+
STAGE_I="${STAGE_I:-2}"
|
| 16 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 17 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo}"
|
| 18 |
+
OUTPUT_DIR="${OUTPUT_DIR:-${OUTPUT_ROOT}/i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}"
|
| 19 |
+
INIT_ADAPTER_DIR="${INIT_ADAPTER_DIR:-}"
|
| 20 |
+
WANDB_PROJECT="${WANDB_PROJECT:-sudoku-latent-multi-output-grpo-residual-projector}"
|
| 21 |
+
WANDB_RUN_NAME="${WANDB_RUN_NAME:-large_latent_noloc_grpo_i${STAGE_I}_cot${NUM_COT_TOKENS}_${RUN_TAG}}"
|
| 22 |
+
WANDB_GROUP="${WANDB_GROUP:-large_latent_extension_noloc_grpo_i${STAGE_I}}"
|
| 23 |
+
|
| 24 |
+
case "${NUM_COT_TOKENS}" in
|
| 25 |
+
2) default_bs=4; default_gas=2 ;;
|
| 26 |
+
4) default_bs=2; default_gas=4 ;;
|
| 27 |
+
5) default_bs=2; default_gas=4 ;;
|
| 28 |
+
*) default_bs=2; default_gas=4 ;;
|
| 29 |
+
esac
|
| 30 |
+
|
| 31 |
+
mkdir -p "${OUTPUT_DIR}"
|
| 32 |
+
export CUDA_DEVICE_ORDER=PCI_BUS_ID
|
| 33 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 34 |
+
export CUDA_VISIBLE_DEVICES="${GPU_IDS}"
|
| 35 |
+
else
|
| 36 |
+
export CUDA_VISIBLE_DEVICES="${GPU_ID}"
|
| 37 |
+
fi
|
| 38 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 39 |
+
|
| 40 |
+
if [[ "${NUM_PROCESSES}" -gt 1 ]]; then
|
| 41 |
+
cmd=(
|
| 42 |
+
"${PYTHON_BIN}" -m torch.distributed.run --standalone --nproc_per_node "${NUM_PROCESSES}" "${TRAINER}"
|
| 43 |
+
)
|
| 44 |
+
else
|
| 45 |
+
cmd=(
|
| 46 |
+
"${PYTHON_BIN}" -u "${TRAINER}"
|
| 47 |
+
)
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
cmd+=(
|
| 51 |
+
--model_name "${MODEL_NAME}"
|
| 52 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 53 |
+
--output_dir "${OUTPUT_DIR}"
|
| 54 |
+
--init_adapter_dir "${INIT_ADAPTER_DIR}"
|
| 55 |
+
--cache_dir "${CACHE_DIR}"
|
| 56 |
+
--gpu_id 0
|
| 57 |
+
--stage_i "${STAGE_I}"
|
| 58 |
+
--num_cot_tokens "${NUM_COT_TOKENS}"
|
| 59 |
+
--total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
|
| 60 |
+
--per_device_train_batch_size "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}"
|
| 61 |
+
--gradient_accumulation_steps "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}"
|
| 62 |
+
--num_train_epochs "${NUM_TRAIN_EPOCHS:-0.5}"
|
| 63 |
+
--learning_rate "${LEARNING_RATE:-7e-7}"
|
| 64 |
+
--logging_steps "${LOGGING_STEPS:-5}"
|
| 65 |
+
--save_steps "${SAVE_STEPS:-10}"
|
| 66 |
+
--eval_steps "${EVAL_STEPS:-25}"
|
| 67 |
+
--eval_rows "${EVAL_ROWS:-20}"
|
| 68 |
+
--num_generations "${NUM_GENERATIONS:-2}"
|
| 69 |
+
--max_prompt_length "${MAX_PROMPT_LENGTH:-1024}"
|
| 70 |
+
--max_completion_length "${MAX_COMPLETION_LENGTH:-32}"
|
| 71 |
+
--beta "${BETA:-0.01}"
|
| 72 |
+
--enable_gradient_checkpointing
|
| 73 |
+
--wandb_project "${WANDB_PROJECT}"
|
| 74 |
+
--wandb_run_name "${WANDB_RUN_NAME}"
|
| 75 |
+
--wandb_group "${WANDB_GROUP}"
|
| 76 |
+
--wandb_mode "${WANDB_MODE:-offline}"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
|
| 80 |
+
cmd+=(--use_wandb)
|
| 81 |
+
fi
|
| 82 |
+
|
| 83 |
+
if [[ -n "${WANDB_ENTITY:-}" ]]; then
|
| 84 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 85 |
+
fi
|
| 86 |
+
|
| 87 |
+
if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
|
| 88 |
+
cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
|
| 89 |
+
fi
|
| 90 |
+
|
| 91 |
+
if [[ -n "${RESUME_FROM_CHECKPOINT:-}" ]]; then
|
| 92 |
+
cmd+=(--resume_from_checkpoint "${RESUME_FROM_CHECKPOINT}")
|
| 93 |
+
fi
|
| 94 |
+
|
| 95 |
+
printf 'Launching hard 9x9 latent GRPO on GPUs %s\n' "${CUDA_VISIBLE_DEVICES}"
|
| 96 |
+
printf 'Output dir: %s\n' "${OUTPUT_DIR}"
|
| 97 |
+
printf 'Init adapter: %s\n' "${INIT_ADAPTER_DIR}"
|
| 98 |
+
printf 'num_cot_tokens=%s batch=%s grad_accum=%s stage_i=%s num_processes=%s\n' \
|
| 99 |
+
"${NUM_COT_TOKENS}" "${PER_DEVICE_TRAIN_BATCH_SIZE:-${default_bs}}" "${GRADIENT_ACCUMULATION_STEPS:-${default_gas}}" "${STAGE_I}" "${NUM_PROCESSES}"
|
| 100 |
+
|
| 101 |
+
"${cmd[@]}"
|
large_latent_extension/launch_nonlocation_pipeline.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
| 5 |
+
ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
| 6 |
+
PYTHON_BIN="${PYTHON_BIN:-${ROOT}/.venv/bin/python}"
|
| 7 |
+
PIPELINE="${ROOT}/run_latent_residual_projector_pipeline.py"
|
| 8 |
+
TRAIN_JSONL="${TRAIN_JSONL:-${ROOT}/data/sudoku_t3_30empty_value_qwen_text.jsonl}"
|
| 9 |
+
CACHE_DIR="${CACHE_DIR:-${ROOT}/.hf_cache}"
|
| 10 |
+
MODEL_NAME="${MODEL_NAME:-Qwen/Qwen2.5-0.5B-Instruct}"
|
| 11 |
+
GPU_IDS="${GPU_IDS:-0,1,2,3,4,5,6,7}"
|
| 12 |
+
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
| 13 |
+
MIN_STAGE="${MIN_STAGE:-1}"
|
| 14 |
+
MAX_STAGE="${MAX_STAGE:-4}"
|
| 15 |
+
RUN_TAG="${RUN_TAG:-$(date +%Y%m%d_%H%M%S)}"
|
| 16 |
+
CHECKPOINT_ROOT="${CHECKPOINT_ROOT:-${ROOT}/final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent}"
|
| 17 |
+
OUTPUT_ROOT="${OUTPUT_ROOT:-${CHECKPOINT_ROOT}/${RUN_TAG}/latent_pipeline_30empty_4stage_hard9x9}"
|
| 18 |
+
|
| 19 |
+
cmd=(
|
| 20 |
+
"${PYTHON_BIN}" "${PIPELINE}"
|
| 21 |
+
--python_executable "${PYTHON_BIN}"
|
| 22 |
+
--train_jsonl "${TRAIN_JSONL}"
|
| 23 |
+
--cache_dir "${CACHE_DIR}"
|
| 24 |
+
--model_name "${MODEL_NAME}"
|
| 25 |
+
--checkpoint_root "${CHECKPOINT_ROOT}"
|
| 26 |
+
--output_root "${OUTPUT_ROOT}"
|
| 27 |
+
--run_tag "${RUN_TAG}"
|
| 28 |
+
--min_stage "${MIN_STAGE}"
|
| 29 |
+
--max_stage "${MAX_STAGE}"
|
| 30 |
+
--distributed_gpu_ids "${GPU_IDS}"
|
| 31 |
+
--sft_num_processes "${NUM_PROCESSES}"
|
| 32 |
+
--grpo_num_processes "${NUM_PROCESSES}"
|
| 33 |
+
--total_empties_hint "${TOTAL_EMPTIES_HINT:-30}"
|
| 34 |
+
--sft_num_epochs "${SFT_NUM_EPOCHS:-1.0}"
|
| 35 |
+
--grpo_num_train_epochs "${GRPO_NUM_TRAIN_EPOCHS:-0.5}"
|
| 36 |
+
--sft_gradient_accumulation_steps "${SFT_GRADIENT_ACCUMULATION_STEPS:-8}"
|
| 37 |
+
--grpo_per_device_train_batch_size "${GRPO_PER_DEVICE_TRAIN_BATCH_SIZE:-2}"
|
| 38 |
+
--grpo_gradient_accumulation_steps "${GRPO_GRADIENT_ACCUMULATION_STEPS:-4}"
|
| 39 |
+
--grpo_num_generations "${GRPO_NUM_GENERATIONS:-2}"
|
| 40 |
+
--sft_enable_gradient_checkpointing
|
| 41 |
+
--grpo_enable_gradient_checkpointing
|
| 42 |
+
--sft_save_steps "${SFT_SAVE_STEPS:-100}"
|
| 43 |
+
--sft_eval_steps "${SFT_EVAL_STEPS:-100}"
|
| 44 |
+
--grpo_save_steps "${GRPO_SAVE_STEPS:-25}"
|
| 45 |
+
--grpo_eval_steps "${GRPO_EVAL_STEPS:-25}"
|
| 46 |
+
--phase_max_wall_clock_seconds "${PHASE_MAX_WALL_CLOCK_SECONDS:-21600}"
|
| 47 |
+
--wandb_mode "${WANDB_MODE:-offline}"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
if [[ -n "${BOOTSTRAP_ADAPTER_DIR:-}" ]]; then
|
| 51 |
+
cmd+=(--bootstrap_adapter_dir "${BOOTSTRAP_ADAPTER_DIR}")
|
| 52 |
+
fi
|
| 53 |
+
|
| 54 |
+
if [[ -n "${STAGE1_INIT_ADAPTER_DIR:-}" ]]; then
|
| 55 |
+
cmd+=(--stage1_init_adapter_dir "${STAGE1_INIT_ADAPTER_DIR}")
|
| 56 |
+
fi
|
| 57 |
+
|
| 58 |
+
if [[ -n "${LIMIT_TRAIN_ROWS:-}" ]]; then
|
| 59 |
+
cmd+=(--limit_train_rows "${LIMIT_TRAIN_ROWS}")
|
| 60 |
+
fi
|
| 61 |
+
|
| 62 |
+
if [[ -n "${SFT_STAGE_MAX_STEPS:-}" ]]; then
|
| 63 |
+
cmd+=(--sft_stage_max_steps "${SFT_STAGE_MAX_STEPS}")
|
| 64 |
+
fi
|
| 65 |
+
|
| 66 |
+
if [[ -n "${GRPO_STAGE_MAX_STEPS:-}" ]]; then
|
| 67 |
+
cmd+=(--grpo_stage_max_steps "${GRPO_STAGE_MAX_STEPS}")
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
if [[ "${WANDB_MODE:-offline}" != "offline" ]]; then
|
| 71 |
+
cmd+=(--use_wandb)
|
| 72 |
+
fi
|
| 73 |
+
|
| 74 |
+
if [[ -n "${WANDB_ENTITY:-}" ]]; then
|
| 75 |
+
cmd+=(--wandb_entity "${WANDB_ENTITY}")
|
| 76 |
+
fi
|
| 77 |
+
|
| 78 |
+
printf 'Launching hard 9x9 latent pipeline on GPUs %s\n' "${GPU_IDS}"
|
| 79 |
+
printf 'Output root: %s\n' "${OUTPUT_ROOT}"
|
| 80 |
+
printf 'Stages: %s -> %s, processes=%s\n' "${MIN_STAGE}" "${MAX_STAGE}" "${NUM_PROCESSES}"
|
| 81 |
+
|
| 82 |
+
"${cmd[@]}"
|