hypernet-sp-distill / narrowdeep_wrapper.sh
baya1116's picture
Upload narrowdeep_wrapper.sh with huggingface_hub
2cebb87 verified
#!/bin/bash
# (B) NARROW+DEEP output-space KL: the only config with a proven positive signal
# (overfit_e2e reached KL~0.02 on focused scope). Tests whether restricting the
# hypernet's amortization scope to the N longest answers + FULL BPTT (tbptt=8 covers
# all chunks of a <=512-tok answer) lets it fit BELOW the 0.16 broad-training floor,
# and whether that fit generalizes to the held-out diag eval samples.
# - warm start from GOLD hn_step7750 (NOT the degraded fullbptt 8050)
# - pure KL (carry/distill aux losses OFF -> symmetry-safe)
# - --fresh_opt healthy lr ; --no_skip ignore ckpt curriculum position
# - OOM -> drop batch by 1 ; crash -> retry ; auto-resume from latest narrowdeep ckpt
# --- configuration -----------------------------------------------------------
# Everything below uses absolute paths except the relative
# `python3 train_qwen_distill.py` launch in the retry loop, which needs the cwd.
set -u
export HF_HOME=/workspace/.hf
# No `set -e` here, so an unchecked cd failure would leave us in the caller's
# cwd, where the relative train_qwen_distill.py path presumably does not
# resolve; every attempt would then be logged as a "crash" and retried until
# the give-up limit. Fail fast instead.
cd /workspace || { echo "narrowdeep: cannot cd to /workspace" >&2; exit 1; }
LOG=/workspace/narrowdeep.log
CKPT_DIR=/workspace/hypernet_qwen_narrowdeep
GOLD=/workspace/hypernet_qwen/hn_step7750.pt   # warm start used when CKPT_DIR has no checkpoint
# Overridable from the environment, e.g. `BATCH=2 bash narrowdeep_wrapper.sh`.
BATCH="${BATCH:-3}"   # passed as --batch; lowered by 1 after each OOM
LIMIT="${LIMIT:-48}"  # passed as --limit_samples
KCH="${KCH:-8}"       # passed as --tbptt_chunks
# BATCH is used in shell arithmetic on the OOM path; validate it up front
# rather than discovering a bad value only after a training attempt fails.
if ! [[ "$BATCH" =~ ^[1-9][0-9]*$ ]]; then
  echo "narrowdeep: BATCH must be a positive integer (got '$BATCH')" >&2
  exit 1
fi
FAILS=0               # count of non-OOM failures seen by the retry loop
mkdir -p "$CKPT_DIR" || { echo "narrowdeep: cannot create $CKPT_DIR" >&2; exit 1; }
# --- supervised retry loop ---------------------------------------------------
# Each pass resumes from the highest-step checkpoint in CKPT_DIR (or from GOLD
# when there is none), runs one training attempt, and then:
#   * exit status 0             -> log "clean exit" and stop
#   * OOM in THIS attempt's log -> lower BATCH by 1 and retry (give up below 1)
#   * any other failure         -> retry after 20 s; give up after 6 of them
# Reads LOG, CKPT_DIR, GOLD, BATCH, LIMIT, KCH, FAILS; updates BATCH and FAILS.
while true; do
  # Choose the checkpoint with the largest numeric step. Matching the whole
  # file name (instead of parsing `ls -t` + sed) guarantees START is an
  # integer even if a stray name such as hn_step_final.pt shows up.
  CK=""
  best=-1
  for f in "$CKPT_DIR"/hn_step*.pt; do
    # Also skips the literal, unexpanded pattern when nothing matches.
    [[ "${f##*/}" =~ ^hn_step([0-9]+)\.pt$ ]] || continue
    step=$((10#${BASH_REMATCH[1]}))  # 10#: leading zeros are not octal
    if (( step > best )); then
      best=$step
      CK=$f
    fi
  done
  if [[ -z "$CK" ]]; then
    RESUME="$GOLD"
    START=7750
  else
    RESUME="$CK"
    START=$best
  fi
  echo "[$(date '+%F %T')] narrowdeep resume=$RESUME start=$START batch=$BATCH limit=$LIMIT K=$KCH" >> "$LOG"
  # Byte offset where this attempt's output begins. The OOM check below only
  # looks past it, so an OOM left in the log by an earlier attempt cannot
  # trigger another batch reduction.
  LOG_OFF=$(wc -c < "$LOG") || LOG_OFF=0
  # NOTE(review): --fresh_opt is passed on every attempt, including resumes
  # from this run's own checkpoints, so optimizer state is presumably reset on
  # each restart — confirm that is intended.
  PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python3 train_qwen_distill.py \
    --base_model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --resume "$RESUME" --start_step "$START" --output_dir "$CKPT_DIR" \
    --batch "$BATCH" --chunk_size 64 --raw_window 32 --tbptt_chunks "$KCH" --kl_temperature 1.0 \
    --max_ans_len 512 --filter_cjk --filter_toolcall --dataset local:/workspace/dolphin_subset.jsonl \
    --lr 1e-4 --epochs 120 --limit_samples "$LIMIT" --fresh_opt --no_skip >> "$LOG" 2>&1
  EXIT=$?
  if [[ "$EXIT" -eq 0 ]]; then
    # Success wins even if the run logged (and survived) an OOM message.
    echo "[$(date '+%F %T')] clean exit" >> "$LOG"
    exit 0
  elif tail -c "+$((LOG_OFF + 1))" "$LOG" | tail -n 200 | grep -qiE "out of memory|outofmemory"; then
    NB=$((BATCH - 1))
    if (( NB < 1 )); then
      echo "[$(date '+%F %T')] batch<1 giveup" >> "$LOG"
      exit 1
    fi
    BATCH=$NB
    echo "[$(date '+%F %T')] OOM -> batch $BATCH" >> "$LOG"
    sleep 8
  else
    FAILS=$((FAILS + 1))
    if (( FAILS >= 6 )); then
      echo "[$(date '+%F %T')] 6 fails giveup" >> "$LOG"
      exit 1
    fi
    echo "[$(date '+%F %T')] crash ($FAILS/6) retry 20s" >> "$LOG"
    sleep 20
  fi
done