#!/usr/bin/env bash
# =============================================================================
# launch_3b_sft.sh — 8-GPU FP8 SFT launcher for 3B Korean LLM
#
# Usage:
#   bash scripts/launch_3b_sft.sh
#   bash scripts/launch_3b_sft.sh --max_steps 200   # quick test
#   bash scripts/launch_3b_sft.sh --resume checkpoints/korean_3b_sft_v1/checkpoint-0002000
#
# Base model : checkpoints/korean_3b_fp8_run1/checkpoint-XXXXXX (default)
#              can be overridden with the --base_checkpoint argument
# SFT data   : data/sft_combined/train_filtered.jsonl
#              (run scripts/prepare_sft_combined.sh then data/filter_sft_v2.py first)
#
# Effective batch: 2 (local) × 8 GPU × 4 (grad_accum) = 64 samples/step
# =============================================================================
set -euo pipefail

# ---- Configurable defaults --------------------------------------------------
# Every knob below can be overridden from the environment, e.g.:
#   MAX_STEPS=200 BATCH_SIZE=1 bash scripts/launch_3b_sft.sh
RUN_NAME="${RUN_NAME:-korean_3b_sft_v1}"
CONFIG="${CONFIG:-configs/korean_3b_sft.yaml}"
BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}"
SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}"
VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}"
CKPT_DIR="checkpoints/${RUN_NAME}"
LOG_FILE="${CKPT_DIR}/train.log"
NPROC="${NPROC:-8}"                  # GPUs per node / torchrun ranks
MASTER_PORT="${MASTER_PORT:-29503}"
MAX_STEPS="${MAX_STEPS:-33000}"
BATCH_SIZE="${BATCH_SIZE:-2}"        # per-GPU micro-batch
GRAD_ACCUM="${GRAD_ACCUM:-4}"        # gradient accumulation steps
LR="${LR:-1.0e-5}"
WARMUP_STEPS="${WARMUP_STEPS:-500}"
SEED="${SEED:-42}"
# Extra CLI flags forwarded verbatim to train/sft.py.
# NOTE(review): kept as a single string (as the original did via "$@" in a
# scalar assignment); simple "--key value" pairs survive the later
# word-splitting, but a single argument containing spaces would not.
EXTRA_ARGS="$*"
# ---- B200 / NVSwitch NCCL tuning (same as pretrain) -------------------------
# Single-node NVSwitch box: disable the InfiniBand transport and pin NCCL to
# a fixed Ring/Simple configuration with exactly 16 channels and 64 MiB
# buffers. NOTE(review): values mirror the pretrain launcher and are assumed
# tuned for this B200 host — confirm before reusing on other hardware.
export NCCL_IB_DISABLE=1
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=16
export NCCL_MAX_NCHANNELS=16
export NCCL_BUFFSIZE=67108864
# Keep CPU-side thread pools small — 8 ranks share this host.
export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4
# 3B model is VRAM-tight -> allow the CUDA caching allocator to expand
# memory segments instead of failing on fragmentation.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
cd "$(dirname "$0")/.."
# ---- Pre-flight checks ------------------------------------------------------
if [[ ! -d "${BASE_CHECKPOINT}" ]]; then
echo "=================================================================="
echo " ERROR: Base checkpoint ๋๋ ํ ๋ฆฌ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
echo " ๊ฒฝ๋ก: ${BASE_CHECKPOINT}"
echo ""
echo " --base_checkpoint ์ธ์๋ก ์ค์ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ๊ฑฐ๋"
echo " BASE_CHECKPOINT ํ๊ฒฝ๋ณ์๋ฅผ ์ค์ ํ์ธ์."
echo " ์: bash scripts/launch_3b_sft.sh --base_checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000"
echo "=================================================================="
exit 1
fi
if [[ ! -f "${SFT_DATA}" ]]; then
echo "=================================================================="
echo " ERROR: SFT ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: ${SFT_DATA}"
echo ""
echo " ๋ฐ์ดํฐ ์ค๋น ์์:"
echo " 1. bash scripts/prepare_sft_combined.sh"
echo " 2. python data/filter_sft_v2.py \\"
echo " --input data/sft_combined/train.jsonl \\"
echo " --output data/sft_combined/train_filtered.jsonl"
echo "=================================================================="
exit 1
fi
# If the filtered val split is missing, fall back to the unfiltered val.jsonl;
# abort (to stderr) when neither file exists.
if [[ ! -f "${VAL_DATA}" ]]; then
  VAL_FALLBACK="data/sft_combined/val.jsonl"
  if [[ -f "${VAL_FALLBACK}" ]]; then
    VAL_DATA="${VAL_FALLBACK}"
    echo "[INFO] val_filtered ์์, ํด๋ฐฑ: ${VAL_DATA}"
  else
    echo "ERROR: VAL_DATA ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: ${VAL_DATA}" >&2
    exit 1
  fi
fi
mkdir -p "${CKPT_DIR}"
echo "=================================================================="
echo " 3B SFT Fine-Tuning"
echo " Run name : ${RUN_NAME}"
echo " Config : ${CONFIG}"
echo " Base checkpoint : ${BASE_CHECKPOINT}"
echo " SFT data : ${SFT_DATA}"
echo " Val data : ${VAL_DATA}"
echo " CKPT dir : ${CKPT_DIR}"
echo " Log file : ${LOG_FILE}"
echo " Max steps : ${MAX_STEPS}"
echo " Batch size : ${BATCH_SIZE} (local) ร ${NPROC} GPU ร ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch"
echo " Learning rate : ${LR}"
echo " Warmup : ${WARMUP_STEPS} steps"
echo " Master port : ${MASTER_PORT}"
echo " ALLOC_CONF : ${PYTORCH_CUDA_ALLOC_CONF}"
echo " Started : $(date)"
echo "=================================================================="
export PYTHONWARNINGS="ignore::UserWarning:torch.library"
# ---- Launch -----------------------------------------------------------------
# Known-noisy lines (flash-attn kernel re-registration spam) are filtered out
# of the stream before it reaches the terminal and, via tee, the log file —
# same filtering the original nine-stage grep chain performed, in one grep.
NOISE_FILTER='UserWarning|Warning only once|Overriding a previously|dispatch key:|previous kernel:|new kernel:|operator: flash_attn|registered at /usr/local|self.m.impl'

# Relax -e around the pipeline: under 'pipefail' a grep stage that emits no
# lines exits 1 and would abort the script even when training succeeded.
# torchrun's own exit status is read from PIPESTATUS[0] instead.
set +e
# shellcheck disable=SC2086  # EXTRA_ARGS is intentionally word-split
torchrun \
  --nproc_per_node="${NPROC}" \
  --master_port="${MASTER_PORT}" \
  train/sft.py \
  --config "${CONFIG}" \
  --base_checkpoint "${BASE_CHECKPOINT}" \
  --sft_data "${SFT_DATA}" \
  --val_data "${VAL_DATA}" \
  --checkpoint_dir "${CKPT_DIR}" \
  --log_file "${LOG_FILE}" \
  --max_steps "${MAX_STEPS}" \
  --batch_size "${BATCH_SIZE}" \
  --grad_accum "${GRAD_ACCUM}" \
  --lr "${LR}" \
  --warmup_steps "${WARMUP_STEPS}" \
  --seed "${SEED}" \
  --use_fp8 \
  ${EXTRA_ARGS} \
  2>&1 | grep -Ev "${NOISE_FILTER}" | tee -a "${LOG_FILE}"
train_status=${PIPESTATUS[0]}
set -e

# Propagate the trainer's real exit code with an explicit failure banner
# (the original silently died mid-pipeline under set -e).
if (( train_status != 0 )); then
  echo "==================================================================" >&2
  echo " 3B SFT FAILED (torchrun exit ${train_status}) : $(date)" >&2
  echo "==================================================================" >&2
  exit "${train_status}"
fi

echo "=================================================================="
echo " 3B SFT Done : $(date)"
echo "=================================================================="