#!/usr/bin/env bash
# =============================================================================
# launch_3b_sft_v2.sh — 8-GPU FP8 SFT v2 launcher for 3B Korean LLM
#
# SFT v2 improvements over v1:
# - LR: 1e-5 → 5e-5 (5x, to address underfitting)
# - Effective batch: 64 → 256 (4x)
# - Data mixing: 70% SFT + 30% pretrain (prevents catastrophic forgetting)
# - Weight decay: 0.01 → 0.05
# - Warmup: 500 → 2000 steps
# - Max steps: 33000 → 15000
#
# Usage:
# bash scripts/launch_3b_sft_v2.sh
# bash scripts/launch_3b_sft_v2.sh --max_steps 200 # quick test
# bash scripts/launch_3b_sft_v2.sh --resume checkpoints/korean_3b_sft_v2/checkpoint-0002000
#
# Effective batch: 4 (local) x 8 GPU x 8 (grad_accum) = 256 samples/step
# =============================================================================
set -euo pipefail
# ---- Configurable defaults --------------------------------------------------
RUN_NAME="${RUN_NAME:-korean_3b_sft_v2}"
CONFIG="${CONFIG:-configs/korean_3b_sft_v2.yaml}"
BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}"
SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}"
VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}"
PRETRAIN_DATA="${PRETRAIN_DATA:-data/3b_train.bin}"
CKPT_DIR="checkpoints/${RUN_NAME}"
LOG_FILE="${CKPT_DIR}/train.log"
NPROC=8
MASTER_PORT="${MASTER_PORT:-29504}"
MAX_STEPS=15000
BATCH_SIZE=4
GRAD_ACCUM=8
LR="5.0e-5"
WARMUP_STEPS=2000
WEIGHT_DECAY=0.05
PRETRAIN_MIX_RATIO=0.3
SEED=42
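# Any of the ${VAR:-default} values above can be overridden via the environment,
# e.g. (run name below is just an example):
#   MASTER_PORT=29510 RUN_NAME=korean_3b_sft_v2_test bash scripts/launch_3b_sft_v2.sh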
# Forward any extra CLI args (e.g. --max_steps 200, --resume ...) to train/sft.py.
# An array preserves quoting of arguments; a plain scalar "$@" would not.
EXTRA_ARGS=("$@")
# ---- B200 / NVSwitch NCCL tuning (same as pretrain) -------------------------
export NCCL_IB_DISABLE=1
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=16   # min == max pins the channel count
export NCCL_MAX_NCHANNELS=16
export NCCL_BUFFSIZE=67108864  # 64 MiB
export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4
# Reduce VRAM fragmentation for the 3B model at local batch size 4
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
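# Optional pre-flight GPU check (an added safeguard, not required by train/sft.py):
# verify at least NPROC GPUs are visible; assumes nvidia-smi is on PATH, else skip.
if command -v nvidia-smi >/dev/null 2>&1; then
  GPU_COUNT="$(nvidia-smi -L | wc -l)"
  if (( GPU_COUNT < NPROC )); then
    echo "ERROR: need ${NPROC} GPUs, found ${GPU_COUNT}."
    exit 1
  fi
fi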
cd "$(dirname "$0")/.."
# ---- Pre-flight checks ------------------------------------------------------
if [[ ! -d "${BASE_CHECKPOINT}" ]]; then
  echo "=================================================================="
  echo " ERROR: Base checkpoint not found: ${BASE_CHECKPOINT}"
  echo " Set BASE_CHECKPOINT env var or use --base_checkpoint CLI arg."
  echo "=================================================================="
  exit 1
fi
if [[ ! -f "${SFT_DATA}" ]]; then
  echo "=================================================================="
  echo " ERROR: SFT data not found: ${SFT_DATA}"
  echo " Run: bash scripts/prepare_sft_combined.sh"
  echo "=================================================================="
  exit 1
fi
if [[ ! -f "${PRETRAIN_DATA}" ]]; then
  echo "=================================================================="
  echo " ERROR: Pretrain data not found: ${PRETRAIN_DATA}"
  echo " Set PRETRAIN_DATA env var to the correct path."
  echo "=================================================================="
  exit 1
fi
# Fall back to the unfiltered val set if the filtered one is missing.
if [[ ! -f "${VAL_DATA}" ]]; then
  VAL_FALLBACK="data/sft_combined/val.jsonl"
  if [[ -f "${VAL_FALLBACK}" ]]; then
    VAL_DATA="${VAL_FALLBACK}"
    echo "[INFO] val_filtered not found; falling back to: ${VAL_DATA}"
  else
    echo "ERROR: VAL_DATA not found: ${VAL_DATA}"
    exit 1
  fi
fi
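# Informational only: report sample counts. Assumes standard jsonl layout
# (one JSON sample per line), so wc -l equals the number of samples.
echo "[INFO] SFT train samples: $(wc -l < "${SFT_DATA}")"
echo "[INFO] SFT val samples  : $(wc -l < "${VAL_DATA}")"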
mkdir -p "${CKPT_DIR}"
echo "=================================================================="
echo " 3B SFT v2 Fine-Tuning"
echo " Run name : ${RUN_NAME}"
echo " Config : ${CONFIG}"
echo " Base checkpoint : ${BASE_CHECKPOINT}"
echo " SFT data : ${SFT_DATA}"
echo " Pretrain data : ${PRETRAIN_DATA}"
echo " Val data : ${VAL_DATA}"
echo " CKPT dir : ${CKPT_DIR}"
echo " Log file : ${LOG_FILE}"
echo " Max steps : ${MAX_STEPS}"
echo " Batch size : ${BATCH_SIZE} (local) x ${NPROC} GPU x ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch"
echo " Learning rate : ${LR}"
echo " Weight decay : ${WEIGHT_DECAY}"
echo " Warmup : ${WARMUP_STEPS} steps"
echo " Data mixing : $((100 - ${PRETRAIN_MIX_RATIO%.*}0))% SFT + ${PRETRAIN_MIX_RATIO}00% pretrain"
echo " Master port : ${MASTER_PORT}"
echo " ALLOC_CONF : ${PYTORCH_CUDA_ALLOC_CONF}"
echo " Started : $(date)"
echo "=================================================================="
export PYTHONWARNINGS="ignore::UserWarning:torch.library"
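# Tip: follow progress from another shell with: tail -f "${LOG_FILE}"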
# The grep -v chain below strips noisy torch.library / flash_attn kernel
# re-registration warnings from the console; tee appends the filtered stream
# to ${LOG_FILE}.
torchrun \
  --nproc_per_node=${NPROC} \
  --master_port=${MASTER_PORT} \
  train/sft.py \
  --config "${CONFIG}" \
  --base_checkpoint "${BASE_CHECKPOINT}" \
  --sft_data "${SFT_DATA}" \
  --val_data "${VAL_DATA}" \
  --pretrain_data "${PRETRAIN_DATA}" \
  --pretrain_mix_ratio ${PRETRAIN_MIX_RATIO} \
  --checkpoint_dir "${CKPT_DIR}" \
  --log_file "${LOG_FILE}" \
  --max_steps ${MAX_STEPS} \
  --batch_size ${BATCH_SIZE} \
  --grad_accum ${GRAD_ACCUM} \
  --lr ${LR} \
  --weight_decay ${WEIGHT_DECAY} \
  --warmup_steps ${WARMUP_STEPS} \
  --seed ${SEED} \
  --use_fp8 \
  ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"} \
  2>&1 | grep -v "UserWarning" \
       | grep -v "Warning only once" \
       | grep -v "Overriding a previously" \
       | grep -v "dispatch key:" \
       | grep -v "previous kernel:" \
       | grep -v "new kernel:" \
       | grep -v "operator: flash_attn" \
       | grep -v "registered at /usr/local" \
       | grep -v "self.m.impl" \
       | tee -a "${LOG_FILE}"
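# With pipefail, torchrun's exit status propagates through the grep/tee chain,
# so a failed run aborts above (set -e) instead of printing the banner below.
# (Caveat: grep -v exits non-zero if it filters out every line, which can show
# up as a spurious failure on runs that produce almost no output.)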
echo "=================================================================="
echo " 3B SFT v2 Done : $(date)"
echo "=================================================================="