| | #!/usr/bin/env bash |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | set -euo pipefail |
| |
|
| | |
| | RUN_NAME="${RUN_NAME:-korean_3b_sft_v1}" |
| | CONFIG="${CONFIG:-configs/korean_3b_sft.yaml}" |
| | BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}" |
| | SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}" |
| | VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}" |
| | CKPT_DIR="checkpoints/${RUN_NAME}" |
| | LOG_FILE="${CKPT_DIR}/train.log" |
| | NPROC=8 |
| | MASTER_PORT="${MASTER_PORT:-29503}" |
| |
|
| | MAX_STEPS=33000 |
| | BATCH_SIZE=2 |
| | GRAD_ACCUM=4 |
| | LR="1.0e-5" |
| | WARMUP_STEPS=500 |
| | SEED=42 |
| |
|
| | EXTRA_ARGS="$@" |
| |
|
| | |
| | export NCCL_IB_DISABLE=1 |
| | export NCCL_ALGO=Ring |
| | export NCCL_PROTO=Simple |
| | export NCCL_MIN_NCHANNELS=16 |
| | export NCCL_MAX_NCHANNELS=16 |
| | export NCCL_BUFFSIZE=67108864 |
| | export OMP_NUM_THREADS=4 |
| | export MKL_NUM_THREADS=4 |
| |
|
| | |
| | export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True |
| |
|
| | cd "$(dirname "$0")/.." |
| |
|
| | |
| | if [[ ! -d "${BASE_CHECKPOINT}" ]]; then |
| | echo "==================================================================" |
| | echo " ERROR: Base checkpoint ๋๋ ํ ๋ฆฌ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค." |
| | echo " ๊ฒฝ๋ก: ${BASE_CHECKPOINT}" |
| | echo "" |
| | echo " --base_checkpoint ์ธ์๋ก ์ค์ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ๊ฑฐ๋" |
| | echo " BASE_CHECKPOINT ํ๊ฒฝ๋ณ์๋ฅผ ์ค์ ํ์ธ์." |
| | echo " ์: bash scripts/launch_3b_sft.sh --base_checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000" |
| | echo "==================================================================" |
| | exit 1 |
| | fi |
| |
|
| | if [[ ! -f "${SFT_DATA}" ]]; then |
| | echo "==================================================================" |
| | echo " ERROR: SFT ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: ${SFT_DATA}" |
| | echo "" |
| | echo " ๋ฐ์ดํฐ ์ค๋น ์์:" |
| | echo " 1. bash scripts/prepare_sft_combined.sh" |
| | echo " 2. python data/filter_sft_v2.py \\" |
| | echo " --input data/sft_combined/train.jsonl \\" |
| | echo " --output data/sft_combined/train_filtered.jsonl" |
| | echo "==================================================================" |
| | exit 1 |
| | fi |
| |
|
| | |
| | if [[ ! -f "${VAL_DATA}" ]]; then |
| | VAL_FALLBACK="data/sft_combined/val.jsonl" |
| | if [[ -f "${VAL_FALLBACK}" ]]; then |
| | VAL_DATA="${VAL_FALLBACK}" |
| | echo "[INFO] val_filtered ์์, ํด๋ฐฑ: ${VAL_DATA}" |
| | else |
| | echo "ERROR: VAL_DATA ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: ${VAL_DATA}" |
| | exit 1 |
| | fi |
| | fi |
| |
|
| | mkdir -p "${CKPT_DIR}" |
| |
|
| | echo "==================================================================" |
| | echo " 3B SFT Fine-Tuning" |
| | echo " Run name : ${RUN_NAME}" |
| | echo " Config : ${CONFIG}" |
| | echo " Base checkpoint : ${BASE_CHECKPOINT}" |
| | echo " SFT data : ${SFT_DATA}" |
| | echo " Val data : ${VAL_DATA}" |
| | echo " CKPT dir : ${CKPT_DIR}" |
| | echo " Log file : ${LOG_FILE}" |
| | echo " Max steps : ${MAX_STEPS}" |
| | echo " Batch size : ${BATCH_SIZE} (local) ร ${NPROC} GPU ร ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch" |
| | echo " Learning rate : ${LR}" |
| | echo " Warmup : ${WARMUP_STEPS} steps" |
| | echo " Master port : ${MASTER_PORT}" |
| | echo " ALLOC_CONF : ${PYTORCH_CUDA_ALLOC_CONF}" |
| | echo " Started : $(date)" |
| | echo "==================================================================" |
| |
|
| | export PYTHONWARNINGS="ignore::UserWarning:torch.library" |
| |
|
| | torchrun \ |
| | --nproc_per_node=${NPROC} \ |
| | --master_port=${MASTER_PORT} \ |
| | train/sft.py \ |
| | --config "${CONFIG}" \ |
| | --base_checkpoint "${BASE_CHECKPOINT}" \ |
| | --sft_data "${SFT_DATA}" \ |
| | --val_data "${VAL_DATA}" \ |
| | --checkpoint_dir "${CKPT_DIR}" \ |
| | --log_file "${LOG_FILE}" \ |
| | --max_steps ${MAX_STEPS} \ |
| | --batch_size ${BATCH_SIZE} \ |
| | --grad_accum ${GRAD_ACCUM} \ |
| | --lr ${LR} \ |
| | --warmup_steps ${WARMUP_STEPS} \ |
| | --seed ${SEED} \ |
| | --use_fp8 \ |
| | ${EXTRA_ARGS} \ |
| | 2>&1 | grep -v "UserWarning" \ |
| | | grep -v "Warning only once" \ |
| | | grep -v "Overriding a previously" \ |
| | | grep -v "dispatch key:" \ |
| | | grep -v "previous kernel:" \ |
| | | grep -v "new kernel:" \ |
| | | grep -v "operator: flash_attn" \ |
| | | grep -v "registered at /usr/local" \ |
| | | grep -v "self.m.impl" \ |
| | | tee -a "${LOG_FILE}" |
| |
|
| | echo "==================================================================" |
| | echo " 3B SFT Done : $(date)" |
| | echo "==================================================================" |
| |
|