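#!/bin/bash
#
# Run the "Student Simulation v2" pipeline end to end, one stage at a time.
#
# Output of every stage is shown on the terminal and appended to
# $DATA_ROOT/logs/runall.log.
#
# Optional environment overrides:
#   DATA_ROOT             data/output directory (default: <script dir>/data)
#   CUDA_VISIBLE_DEVICES  GPU selection (default: 0)
#   N_TRAIN, N_MATH_TEST, N_AIME, N_GPQA
#                         sample counts passed to the CoT generation stage
#   N_SWEEP               number of test problems for the steering sweep
#   JOINT=1               pass --joint to the steering sweep
#   STAGES                comma-separated ids of the stages to run
#   SKIP_DOWNLOAD         if non-empty, the model download stage is not run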
|
|
# Strict mode: abort on a failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# Resolve the directory that contains this script and run from there, so the
# relative scripts/... paths used by the stages resolve.
# CDPATH is cleared for the lookup so `cd` cannot print a path or jump to an
# unrelated directory; `--` protects against a path starting with a dash.
PROJECT_ROOT="$(CDPATH= cd -- "$(dirname -- "${BASH_SOURCE[0]:-$0}")" && pwd)"
cd -- "$PROJECT_ROOT"

# Environment seen by the Python stages. Values already present in the
# caller's environment win over the defaults.
export DATA_ROOT="${DATA_ROOT:-$PROJECT_ROOT/data}"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
# Prepend the project root; only add the ':' separator when PYTHONPATH already
# has content, so no empty (current-directory) entry is appended.
export PYTHONPATH="$PROJECT_ROOT${PYTHONPATH:+:$PYTHONPATH}"
export TOKENIZERS_PARALLELISM=false
|
|
# Sample counts, each overridable from the environment.
# N_TRAIN / N_MATH_TEST / N_AIME / N_GPQA are passed to the CoT generation
# stage; N_SWEEP is passed to the steering sweep as --n_test.
N_TRAIN="${N_TRAIN:-150}"
N_MATH_TEST="${N_MATH_TEST:-50}"
N_AIME="${N_AIME:-30}"
N_GPQA="${N_GPQA:-20}"
N_SWEEP="${N_SWEEP:-30}"

# JOINT=1 adds --joint to the steering sweep; any other value (or unset)
# leaves the flag empty.
JOINT_FLAG=""
if [[ "${JOINT:-0}" == "1" ]]; then
  JOINT_FLAG="--joint"
fi
|
|
# All pipeline output is mirrored into one append-only log file.
mkdir -p "$DATA_ROOT/logs"
RUNALL_LOG="$DATA_ROOT/logs/runall.log"

# Record the effective configuration for this run. The lines are grouped so a
# single tee appends the whole banner instead of one tee process per line.
{
  echo "========================================================="
  echo "Student Simulation v2 - $(date)"
  echo "PROJECT_ROOT: $PROJECT_ROOT"
  echo "DATA_ROOT: $DATA_ROOT"
  echo "CUDA: $CUDA_VISIBLE_DEVICES"
  echo "N_TRAIN: $N_TRAIN"
  echo "N_SWEEP: $N_SWEEP"
  echo "JOINT: ${JOINT:-0}"
  echo "========================================================="
} | tee -a "$RUNALL_LOG"

# Run the configs.paths module and log whatever it prints.
# NOTE(review): presumably this reports/validates the resolved data paths --
# confirm against configs/paths.py.
python -m configs.paths 2>&1 | tee -a "$RUNALL_LOG"
|
|
| |
# Comma-separated ids of the stages to run; run_stage matches an id as a whole
# comma-delimited token. Override from the environment to run a subset, e.g.
# STAGES=9,13.
# 10b is listed next to 10 so the monitoring sanity check runs by default and
# both infer_sanity_*.json files named in the final summary are produced.
# NOTE(review): id 0 has no matching run_stage call in this script; it is kept
# so existing behaviour for the other ids is untouched.
STAGES="${STAGES:-0,1,2,3,4,5,6,7,8,8b,9,10,10b,12,13}"
|
|
#######################################
# Run one pipeline stage if it is selected in STAGES, timing it and mirroring
# its output into the run log.
# Globals:
#   STAGES      (read) comma-separated list of enabled stage ids
#   RUNALL_LOG  (read) log file every message is appended to
# Arguments:
#   $1   stage id, matched as a whole comma-delimited token of STAGES
#   $2   human-readable stage name used in log messages
#   $3+  command (with its arguments) to execute for the stage
# Outputs:
#   Skip/start/duration messages and the command's combined stdout+stderr on
#   stdout, all appended to RUNALL_LOG. A failure line goes to stderr and to
#   RUNALL_LOG.
# Returns:
#   0 if the stage was skipped or succeeded, 2 if a selected stage was given
#   no command, otherwise the non-zero status of the stage pipeline.
#######################################
run_stage() {
  local stage_num="$1"
  local stage_name="$2"
  shift 2

  # Wrapping both sides in commas makes this an exact-token match, so "8"
  # does not match "8b" and "1" does not match "10".
  if [[ ",$STAGES," != *",$stage_num,"* ]]; then
    echo "[skip] Stage $stage_num: $stage_name" | tee -a "$RUNALL_LOG"
    return 0
  fi

  # An empty command would "succeed" silently and be reported as a 0s stage.
  if (( $# == 0 )); then
    echo "run_stage: no command given for stage $stage_num" \
      | tee -a "$RUNALL_LOG" >&2
    return 2
  fi

  echo "" | tee -a "$RUNALL_LOG"
  echo "==================== Stage $stage_num: $stage_name ====================" | tee -a "$RUNALL_LOG"

  local t_start t_end
  local rc=0
  t_start=$(date +%s)
  # Capture the status instead of letting errexit abort here, so the failure
  # and its duration reach the log. Seeing the command's status through the
  # pipe relies on `set -o pipefail` being active.
  "$@" 2>&1 | tee -a "$RUNALL_LOG" || rc=$?
  t_end=$(date +%s)

  if (( rc != 0 )); then
    echo "Stage $stage_num FAILED (exit $rc) after $((t_end - t_start))s" \
      | tee -a "$RUNALL_LOG" >&2
    # Propagate the status; under `set -e` the caller aborts exactly as before.
    return "$rc"
  fi
  echo "Stage $stage_num took $((t_end - t_start))s" | tee -a "$RUNALL_LOG"
}
|
|
| |
# Pipeline stages, in execution order. Each call is a no-op (logged as
# "[skip]") unless its id is listed in STAGES.

# Stage 1 is additionally bypassed entirely when SKIP_DOWNLOAD is non-empty.
if [[ -z "${SKIP_DOWNLOAD:-}" ]]; then
  run_stage 1 "Download model" \
    python scripts/01_download_model.py
fi

run_stage 2 "Generate CoTs (MATH training + test sets)" \
  python scripts/02_generate_cots.py \
  --n_train "$N_TRAIN" --n_math_test "$N_MATH_TEST" \
  --n_aime "$N_AIME" --n_gpqa "$N_GPQA" --resume

run_stage 3 "Label CoTs (decision points)" \
  python scripts/03_label_cots.py --resume

run_stage 4 "Capture routing dumps (Stage 1)" \
  python scripts/04_capture_routing.py --resume

run_stage 5 "Select top experts" \
  python scripts/05_select_top_experts.py --resume

run_stage 6 "Interaction analysis (Stage 2)" \
  python scripts/06_interaction_analysis.py

run_stage 7 "Capture decision-point residuals" \
  python scripts/07_capture_residuals.py --resume

run_stage 8 "Compute direction vectors (v1_raw + v_pca_subspace)" \
  python scripts/08_compute_directions.py --resume

run_stage 8b "Attention diagnostic (informational)" \
  python scripts/08b_attention_diagnostic.py --resume

# ${JOINT_FLAG:+"..."} expands to nothing when the flag is empty and to one
# properly quoted word otherwise, so no unquoted word splitting is needed.
run_stage 9 "Steering sweep (α ∈ [0,1], with text saving)" \
  python scripts/09_steering_sweep.py \
  --n_test "$N_SWEEP" --resume \
  --save_texts ${JOINT_FLAG:+"$JOINT_FLAG"}

run_stage 10 "Sanity-check inference on 2 default problems" \
  python scripts/10_infer.py \
  --auto_problems \
  --dim planning --version v_pca_subspace \
  --alphas 1.0 0.5 0.0 \
  --save_to "$DATA_ROOT/results/infer_sanity_planning.json"

# Same inference script as stage 10, for the "monitoring" dimension; selected
# only when STAGES contains the id 10b.
run_stage 10b "Sanity-check inference (monitoring)" \
  python scripts/10_infer.py \
  --auto_problems \
  --dim monitoring --version v_pca_subspace \
  --alphas 1.0 0.5 0.0 \
  --save_to "$DATA_ROOT/results/infer_sanity_monitoring.json"

run_stage 12 "Downstream accuracy eval" \
  python scripts/12_downstream_eval.py --resume

run_stage 13 "Final analysis + report" \
  python scripts/13_analyze_and_report.py
|
|
# Closing summary pointing at the main artefacts. Grouped so a single tee
# appends it to the log. The '*' sits inside double quotes and is printed
# literally, not glob-expanded.
{
  echo ""
  echo "========================================================="
  echo "Pipeline complete - $(date)"
  echo "Final report: $DATA_ROOT/results/final_report.md"
  echo "Sweep curves: $DATA_ROOT/results/sweep_curves.png"
  echo "Attn diagnostic: $DATA_ROOT/results/attention_diagnostic.png"
  echo "Sanity inference: $DATA_ROOT/results/infer_sanity_*.json"
  echo "========================================================="
} | tee -a "$RUNALL_LOG"
|
|