# NOTE: The lines below were captured from the Hugging Face file-listing page
# when this folder was uploaded (huggingface_hub PR #17, commit 48ecd01) and
# are kept only as provenance — they are not part of the script:
#   frankenstallm / scripts / launch_3b_sft.sh — uploaded by pathcosmos
#!/usr/bin/env bash
# =============================================================================
# launch_3b_sft.sh — 8-GPU FP8 SFT launcher for 3B Korean LLM
#
# Usage:
#   bash scripts/launch_3b_sft.sh
#   bash scripts/launch_3b_sft.sh --max_steps 200    # quick smoke test
#   bash scripts/launch_3b_sft.sh --resume checkpoints/korean_3b_sft_v1/checkpoint-0002000
#
# Base model : checkpoints/korean_3b_fp8_run1/checkpoint-XXXXXX (default;
#              can be overridden with the --base_checkpoint argument)
# SFT data   : data/sft_combined/train_filtered.jsonl
#              (first run scripts/prepare_sft_combined.sh, then data/filter_sft_v2.py)
#
# Effective batch: 2 (local) × 8 GPU × 4 (grad_accum) = 64 samples/step
# =============================================================================
# Abort on any error, on unset variables, and on failures in any pipeline stage.
set -euo pipefail
# ---- Configurable defaults --------------------------------------------------
# Every value here can be overridden from the environment, e.g.:
#   MAX_STEPS=200 bash scripts/launch_3b_sft.sh
RUN_NAME="${RUN_NAME:-korean_3b_sft_v1}"
CONFIG="${CONFIG:-configs/korean_3b_sft.yaml}"
BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}"
SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}"
VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}"
CKPT_DIR="checkpoints/${RUN_NAME}"
LOG_FILE="${CKPT_DIR}/train.log"
NPROC="${NPROC:-8}"
MASTER_PORT="${MASTER_PORT:-29503}"
MAX_STEPS="${MAX_STEPS:-33000}"
BATCH_SIZE="${BATCH_SIZE:-2}"
GRAD_ACCUM="${GRAD_ACCUM:-4}"
LR="${LR:-1.0e-5}"
WARMUP_STEPS="${WARMUP_STEPS:-500}"
SEED="${SEED:-42}"
# Remaining CLI arguments are forwarded verbatim to train/sft.py.
# A scalar cannot hold a list: "$*" joins the arguments with spaces, which is
# exactly what the previous "$@"-into-scalar assignment did (ShellCheck SC2124)
# but states the intent explicitly.
EXTRA_ARGS="$*"
# ---- NCCL tuning for a single B200 / NVSwitch node (mirrors the pretrain
# launcher so the two runs are comparable) -----------------------------------
export NCCL_IB_DISABLE=1                              # no InfiniBand on this box
export NCCL_ALGO=Ring NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=16 NCCL_MAX_NCHANNELS=16    # pin channel count
export NCCL_BUFFSIZE=67108864                         # 64 MiB NCCL buffers
# Keep CPU-side math libraries from oversubscribing cores.
export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4
# Save VRAM on the 3B model — let the allocator grow memory segments on demand.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Always operate from the repository root, no matter where we were invoked.
cd "$(dirname "$0")/.."
# ---- Pre-flight checks ------------------------------------------------------
# The base checkpoint directory must exist before we spend time spinning up.
if [[ ! -d "${BASE_CHECKPOINT}" ]]; then
  cat <<EOF
==================================================================
 ERROR: Base checkpoint ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.
 ๊ฒฝ๋กœ: ${BASE_CHECKPOINT}

 --base_checkpoint ์ธ์ž๋กœ ์‹ค์ œ ๊ฒฝ๋กœ๋ฅผ ์ง€์ •ํ•˜๊ฑฐ๋‚˜
 BASE_CHECKPOINT ํ™˜๊ฒฝ๋ณ€์ˆ˜๋ฅผ ์„ค์ •ํ•˜์„ธ์š”.
 ์˜ˆ: bash scripts/launch_3b_sft.sh --base_checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000
==================================================================
EOF
  exit 1
fi
# The filtered SFT training set is mandatory.
if [[ ! -f "${SFT_DATA}" ]]; then
  cat <<EOF
==================================================================
 ERROR: SFT ํ•™์Šต ๋ฐ์ดํ„ฐ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: ${SFT_DATA}

 ๋ฐ์ดํ„ฐ ์ค€๋น„ ์ˆœ์„œ:
 1. bash scripts/prepare_sft_combined.sh
 2. python data/filter_sft_v2.py \\
 --input data/sft_combined/train.jsonl \\
 --output data/sft_combined/train_filtered.jsonl
==================================================================
EOF
  exit 1
fi
# Validation set: fall back to the unfiltered val.jsonl when the filtered
# file is missing; abort only when neither exists.
if [[ ! -f "${VAL_DATA}" ]]; then
  fallback_val="data/sft_combined/val.jsonl"
  if [[ -f "${fallback_val}" ]]; then
    VAL_DATA="${fallback_val}"
    echo "[INFO] val_filtered ์—†์Œ, ํด๋ฐฑ: ${VAL_DATA}"
  else
    echo "ERROR: VAL_DATA ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: ${VAL_DATA}"
    exit 1
  fi
fi
# Create the run directory, then print the full launch configuration once so
# the settings are captured at the top of the console session.
mkdir -p "${CKPT_DIR}"
cat <<EOF
==================================================================
 3B SFT Fine-Tuning
 Run name : ${RUN_NAME}
 Config : ${CONFIG}
 Base checkpoint : ${BASE_CHECKPOINT}
 SFT data : ${SFT_DATA}
 Val data : ${VAL_DATA}
 CKPT dir : ${CKPT_DIR}
 Log file : ${LOG_FILE}
 Max steps : ${MAX_STEPS}
 Batch size : ${BATCH_SIZE} (local) ร— ${NPROC} GPU ร— ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch
 Learning rate : ${LR}
 Warmup : ${WARMUP_STEPS} steps
 Master port : ${MASTER_PORT}
 ALLOC_CONF : ${PYTORCH_CUDA_ALLOC_CONF}
 Started : $(date)
==================================================================
EOF
# Suppress torch.library UserWarnings in worker output.
export PYTHONWARNINGS="ignore::UserWarning:torch.library"
# Launch 8-way data-parallel SFT. stderr is merged into stdout so the noise
# filter sees everything; a single grep with multiple -e patterns strips the
# known-benign torch.library / flash-attn re-registration warnings (identical
# filtering to the nine chained `grep -v` processes this used to spawn, but
# with one process and only one stage that can trip `set -o pipefail` when
# every line happens to be filtered out). Thanks to pipefail, a torchrun
# failure still fails the whole pipeline. ${EXTRA_ARGS} is intentionally
# unquoted so that extra CLI flags word-split into separate arguments.
torchrun \
  --nproc_per_node="${NPROC}" \
  --master_port="${MASTER_PORT}" \
  train/sft.py \
  --config "${CONFIG}" \
  --base_checkpoint "${BASE_CHECKPOINT}" \
  --sft_data "${SFT_DATA}" \
  --val_data "${VAL_DATA}" \
  --checkpoint_dir "${CKPT_DIR}" \
  --log_file "${LOG_FILE}" \
  --max_steps "${MAX_STEPS}" \
  --batch_size "${BATCH_SIZE}" \
  --grad_accum "${GRAD_ACCUM}" \
  --lr "${LR}" \
  --warmup_steps "${WARMUP_STEPS}" \
  --seed "${SEED}" \
  --use_fp8 \
  ${EXTRA_ARGS} \
  2>&1 \
  | grep -v \
      -e "UserWarning" \
      -e "Warning only once" \
      -e "Overriding a previously" \
      -e "dispatch key:" \
      -e "previous kernel:" \
      -e "new kernel:" \
      -e "operator: flash_attn" \
      -e "registered at /usr/local" \
      -e "self.m.impl" \
  | tee -a "${LOG_FILE}"
echo "=================================================================="
echo " 3B SFT Done : $(date)"
echo "=================================================================="