CausalGrok / code /scripts /run_harder_spurious.sh
nileshsarkar-ai's picture
Upload code/scripts
42c0d23 verified
#!/usr/bin/env bash
# scripts/run_harder_spurious.sh
#
# The "harder spurious" variant. Forces the network into a true
# memorization-then-generalization regime by stacking three knobs:
#
#   n_train      = 100      (lung anatomy too noisy to learn from 100 samples)
#   patch_size   = 10       (10×10 = 100 of 784 pixels ≈ 12.8% of image)
#   spurious_rho = 0.95     (patch is correct 95% of the time → val ceiling at 0.95)
#   weight_decay = 5e-3     (carryover from clean-grokking — Goldilocks-zone fix)
#   condition    = grokking (Grokfast, init α=4, 3000 epochs)
#
# What we expect to see:
#   ep    1: tr 0.55 vl 0.50 sc ≈ 1.0 (random)
#   ep   30: tr 0.98 vl 0.93 sc > 1.5 (learned the patch — corner reliance)
#   ep  500: tr 1.00 vl 0.94 sc ≈ 1.7 (plateau at rho ceiling, memorization)
#   ep  900: tr 1.00 vl 0.94 sc ≈ 1.4 (sc declining — leading indicator)
#   ep 1200: tr 1.00 vl 0.96 sc < 1.0 (TRANSITION — model abandons patch)
#   ep 1500: tr 1.00 vl 0.98 sc ≈ 0.7 (grokked, anatomy-driven)
#
# The first epoch where `sc` crosses below 1.0 BEFORE val_acc jumps is
# the predictive progress measure that makes the paper.
# Strict mode: stop on any failing command, unset variable, or failed
# pipeline stage.
set -euo pipefail

# Work from the repository root (the parent of this script's directory) so the
# relative scripts/ paths below resolve no matter where the caller invoked us.
script_dir="$(dirname "$0")"
ROOT="$(cd "${script_dir}/.." && pwd)"
cd "${ROOT}"

# Tunables. A value supplied by the caller's environment wins; otherwise the
# default on the right is assigned (':=' fires when unset or empty).
: "${GPU:=0}"
: "${N_TRAIN:=100}"
: "${SEED:=42}"
: "${RHO:=0.95}"
: "${PATCH_SIZE:=10}"
: "${WD:=5e-3}"

# WANDB_MODE is exported so child processes inherit it; "offline" unless the
# caller already chose a mode.
: "${WANDB_MODE:=offline}"
export WANDB_MODE

# Banner summarising the effective configuration, one line per argument.
printf '%s\n' \
  "Harder spurious variant" \
  " n_train : ${N_TRAIN}" \
  " rho : ${RHO}" \
  " patch_size : ${PATCH_SIZE}" \
  " weight_decay: ${WD}" \
  " GPU : ${GPU}"
# RHO 0.95 → "spurious095_ps10" (slug encodes both rho and patch size since
# both differ from the rho=0.8 / ps=4 default — the dir name should be
# enough to identify the variant).
#
# The dots are removed with parameter expansion rather than `echo | tr`: no
# subshell or pipeline is spawned, and the value never passes through `echo`,
# which would swallow inputs such as "-n" or "-e" as options.
RHO_SLUG="spurious${RHO//./}_ps${PATCH_SIZE}"

# EXTRA_ARGS and RUN_TAG are set in the environment of this one command only
# (prefix-assignment form); the remaining settings go as positional arguments.
# NOTE(review): how launch.sh consumes EXTRA_ARGS / RUN_TAG is not visible
# from this script — presumably it word-splits EXTRA_ARGS into trainer flags,
# so the values must stay free of whitespace.
EXTRA_ARGS="--spurious_rho ${RHO} --spurious_patch_size ${PATCH_SIZE} --weight_decay ${WD}" \
RUN_TAG="${RHO_SLUG}" \
  bash scripts/launch.sh grokking "${N_TRAIN}" "${SEED}" "${GPU}"
# Follow-up instructions, printed only after the launch command succeeded
# (strict mode aborts earlier otherwise). The quoted 'EOF' delimiter keeps the
# text literal, so the awk program is shown verbatim for copy/paste and is
# NOT executed here. That one-liner scans log lines matching /^ ep /, strips
# "x" from the field following "sc", and prints the first line where that
# value is numerically below 1.0.
cat <<'EOF'

Harder-spurious run detached. Watch with:
 bash scripts/list_runs.sh
 tail -f experiments/runs/<latest>/logs/train.log

When DONE, the leading-indicator check is:
awk '/^ ep / { for (i=1;i<=NF;i++) if ($i=="sc") { gsub("x","",$(i+1)); if ($(i+1)+0 < 1.0) { print "FIRST sc<1.0:", $0; exit } } }' \
experiments/runs/<latest>/logs/train.log
EOF