#!/usr/bin/env bash
# Waits for longer-SFT to finish, then runs slot-decorrelation experiment from
# that ckpt. Runs both ablations + capacity_diagnostic on the result, then
# pushes everything to HF.
#
# Flow of this section:
#   1. Poll once a minute for the longer-SFT ablation JSON ($WAIT_FOR).
#      If it has not appeared after 6 hours, fall back to the GRPO checkpoint.
#   2. Train with $CFG, resuming from the selected checkpoint.
#   3. Run the capacity diagnostic and the teacher-forced ablation on
#      "$OUT/final".
#
# `set -e` is deliberately NOT enabled: each stage's exit status is written to
# $LOG and the script carries on to the next stage regardless.
set -uo pipefail

REPO="LauraGG/blt-reasoner-pilot1"
OUT="/home/ubuntu/work/blt_decorr_exp"
CFG="/home/ubuntu/experiments/blt_reasoner/configs/exp7b_decorr.json"
WAIT_FOR="/home/ubuntu/work/blt_longer_sft/final/ablation_n200_K16.json"
RESUME_FROM="/home/ubuntu/work/blt_longer_sft/final"
LOG="/home/ubuntu/work/queue_decorr.log"

#######################################
# Print a timestamped message to stdout and append it to $LOG.
# Globals:   LOG (read)
# Arguments: message words, joined with single spaces
# Outputs:   "[HH:MM:SS] message" on stdout and in $LOG
#######################################
log() {
  printf '[%s] %s\n' "$(date +%T)" "$*" | tee -a "$LOG"
}

# Every later stage redirects its output into $OUT, so failing to create it is
# fatal rather than best-effort.
mkdir -p "$OUT" || {
  printf 'cannot create output directory %s\n' "$OUT" >&2
  exit 1
}

# The python modules below are addressed as `experiments.blt_reasoner...`, i.e.
# relative to /home/ubuntu; running them from any other directory would fail.
cd /home/ubuntu || {
  log "cannot cd to /home/ubuntu; aborting"
  exit 1
}

export TOKENIZERS_PARALLELISM=false TRANSFORMERS_NO_ADVISORY_WARNINGS=1 HF_HUB_DISABLE_PROGRESS_BARS=1
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

log "queue_decorr.sh starting; waiting for longer-SFT to finish at $WAIT_FOR"

# Absolute deadline: now + 6 hours, in epoch seconds.
DEADLINE=$(( $(date +%s) + 6*3600 ))
while [[ ! -f "$WAIT_FOR" ]]; do
  if (( $(date +%s) > DEADLINE )); then
    log "deadline reached; using GRPO ckpt as resume baseline"
    RESUME_FROM="/home/ubuntu/work/blt_grpo_opt13/final"
    break
  fi
  sleep 60
done

# Surface a missing checkpoint early; training is still attempted so that its
# own error ends up in train.log.
if [[ ! -e "$RESUME_FROM" ]]; then
  log "WARNING: resume checkpoint $RESUME_FROM does not exist"
fi

log "proceeding from RESUME_FROM=$RESUME_FROM; 30s GPU drain"
sleep 30

log "==========================================="
log "DECORR experiment (lambda_decorr=0.5) from $RESUME_FROM"
log "==========================================="
python3 -u -m experiments.blt_reasoner.train --config "$CFG" \
  --resume_from "$RESUME_FROM" \
  > "$OUT/train.log" 2>&1
train_rc=$?
log "train exit=$train_rc"
if (( train_rc != 0 )); then
  # Best-effort: keep going so the logs of the failed run are still collected.
  log "WARNING: training failed; evaluations of $OUT/final may fail as well"
fi

log "Capacity diagnostic on decorr-final"
python3 -u -m experiments.blt_reasoner.scripts.capacity_diagnostic \
  --ckpt "$OUT/final" --config "$CFG" --n 100 --K 16 --max_new_tokens 128 \
  --out "$OUT/final/capacity_diagnostic.json" \
  > "$OUT/capacity.log" 2>&1
log "capacity diag exit=$?"

log "TF ablation"
python3 -u -m experiments.blt_reasoner.scripts.ablate_teacher_forced \
  --ckpt "$OUT/final" --config "$CFG" --n 200 --K 16 \
  --out "$OUT/final/ablation_teacher_forced.json" \
  > "$OUT/tf_eval.log" 2>&1
log "TF ablate exit=$?"
log "AR ablation" python3 -u -m experiments.blt_reasoner.eval \ --ckpt "$OUT/final" --config "$CFG" --n 200 --K 16 \ --max_new_tokens 192 --temperature 0.0 \ --out "$OUT/final/ablation_n200_K16.json" \ > "$OUT/ar_eval.log" 2>&1 log "AR ablate exit=$?" log "pushing decorr_exp/ to HF" python3 - <