File size: 5,507 Bytes
48ecd01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env bash
# =============================================================================
# launch_3b_sft.sh — 8-GPU FP8 SFT launcher for 3B Korean LLM
#
# Usage:
#   bash scripts/launch_3b_sft.sh
#   bash scripts/launch_3b_sft.sh --max_steps 200    # quick test
#   bash scripts/launch_3b_sft.sh --resume checkpoints/korean_3b_sft_v1/checkpoint-0002000
#
# Base model : checkpoints/korean_3b_fp8_run1/checkpoint-XXXXXX  (default)
#              can be overridden with the --base_checkpoint argument
#              (or the BASE_CHECKPOINT environment variable)
# SFT data   : data/sft_combined/train_filtered.jsonl
#              (first run scripts/prepare_sft_combined.sh → data/filter_sft_v2.py)
#
# Effective batch: 2 (local) × 8 GPU × 4 (grad_accum) = 64 samples/step
# =============================================================================
set -euo pipefail

# ---- Configurable defaults --------------------------------------------------
# Every value written as ${VAR:-default} can be overridden from the environment.
RUN_NAME="${RUN_NAME:-korean_3b_sft_v1}"
CONFIG="${CONFIG:-configs/korean_3b_sft.yaml}"
BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}"
SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}"
VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}"
CKPT_DIR="checkpoints/${RUN_NAME}"
LOG_FILE="${CKPT_DIR}/train.log"
NPROC=8
MASTER_PORT="${MASTER_PORT:-29503}"

# Fixed training hyper-parameters passed to train/sft.py.
MAX_STEPS=33000
BATCH_SIZE=2
GRAD_ACCUM=4
LR="1.0e-5"
WARMUP_STEPS=500
SEED=42

#######################################
# Honour a --base_checkpoint override given on the command line.
# Accepts both "--base_checkpoint PATH" and "--base_checkpoint=PATH"; when the
# flag appears more than once the last occurrence wins. All other arguments
# are ignored here and the caller's positional parameters are left untouched.
# Globals:   BASE_CHECKPOINT (written when the flag is present)
# Arguments: the script's command-line arguments
# Returns:   0 always
#######################################
apply_base_checkpoint_override() {
    local arg
    local take_next="no"
    for arg in "$@"; do
        if [[ "${take_next}" == "yes" ]]; then
            BASE_CHECKPOINT="${arg}"
            take_next="no"
            continue
        fi
        case "${arg}" in
            --base_checkpoint)
                take_next="yes"
                ;;
            --base_checkpoint=*)
                BASE_CHECKPOINT="${arg#--base_checkpoint=}"
                ;;
        esac
    done
    return 0
}

# Without this, the directory check on BASE_CHECKPOINT would only ever look at
# the default/env path, even though the usage text offers --base_checkpoint.
apply_base_checkpoint_override "$@"

# All CLI arguments joined into one space-separated string ("$*" is what
# "$@" degrades to in a scalar assignment under the default IFS).
EXTRA_ARGS="$*"

# ---- B200 / NVSwitch NCCL tuning (same as pretrain) -------------------------
# NCCL settings: ring algorithm, simple protocol, channel count pinned to 16
# (min == max) and a 64 MiB buffer size.
export \
    NCCL_IB_DISABLE=1 \
    NCCL_ALGO=Ring \
    NCCL_PROTO=Simple \
    NCCL_MIN_NCHANNELS=16 \
    NCCL_MAX_NCHANNELS=16 \
    NCCL_BUFFSIZE=67108864

# Thread caps for OpenMP and MKL.
export OMP_NUM_THREADS=4 MKL_NUM_THREADS=4

# Save VRAM for the 3B model — allow dynamic expansion of memory segments.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

# Work from the parent of this script's directory so that the relative paths
# used below resolve the same way regardless of the caller's cwd.
cd "$(dirname "$0")/.."

# ---- Pre-flight checks ------------------------------------------------------
# Validate the input paths before anything is launched. Fatal diagnostics go
# to stderr; the informational fallback notice stays on stdout.

#######################################
# Print the given lines to stderr, framed by a horizontal rule.
# Arguments: one message line per argument
# Outputs:   the framed message on stderr
#######################################
preflight_error_box() {
    local rule="=================================================================="
    printf '%s\n' "${rule}" "$@" "${rule}" >&2
}

# The base checkpoint must be an existing directory.
if [[ ! -d "${BASE_CHECKPOINT}" ]]; then
    preflight_error_box \
        "  ERROR: Base checkpoint 디렉토리를 찾을 수 없습니다." \
        "  경로: ${BASE_CHECKPOINT}" \
        "" \
        "  --base_checkpoint 인자로 실제 경로를 지정하거나" \
        "  BASE_CHECKPOINT 환경변수를 설정하세요." \
        "  예: bash scripts/launch_3b_sft.sh --base_checkpoint checkpoints/korean_3b_fp8_run1/checkpoint-0057000"
    exit 1
fi

# The training data must be an existing regular file; the message lists the
# preparation steps that produce it.
if [[ ! -f "${SFT_DATA}" ]]; then
    preflight_error_box \
        "  ERROR: SFT 학습 데이터를 찾을 수 없습니다: ${SFT_DATA}" \
        "" \
        "  데이터 준비 순서:" \
        "    1. bash scripts/prepare_sft_combined.sh" \
        '    2. python data/filter_sft_v2.py \' \
        '           --input  data/sft_combined/train.jsonl \' \
        "           --output data/sft_combined/train_filtered.jsonl"
    exit 1
fi

# If the validation file is missing, fall back to the original val.jsonl.
if [[ ! -f "${VAL_DATA}" ]]; then
    VAL_FALLBACK="data/sft_combined/val.jsonl"
    if [[ -f "${VAL_FALLBACK}" ]]; then
        VAL_DATA="${VAL_FALLBACK}"
        echo "[INFO] val_filtered 없음, 폴백: ${VAL_DATA}"
    else
        # Report the path that was originally requested, not the fallback.
        echo "ERROR: VAL_DATA 파일을 찾을 수 없습니다: ${VAL_DATA}" >&2
        exit 1
    fi
fi

# Make sure the checkpoint directory (which also holds the log file) exists.
mkdir -p "${CKPT_DIR}"

#######################################
# Print a summary of the run configuration.
# Only this script's variables are shown; flags forwarded on the command line
# are not reflected in the banner.
# Globals (read): RUN_NAME, CONFIG, BASE_CHECKPOINT, SFT_DATA, VAL_DATA,
#                 CKPT_DIR, LOG_FILE, MAX_STEPS, BATCH_SIZE, NPROC, GRAD_ACCUM,
#                 LR, WARMUP_STEPS, MASTER_PORT, PYTORCH_CUDA_ALLOC_CONF
# Outputs:        the banner on stdout
#######################################
print_launch_banner() {
    local rule="=================================================================="
    # Effective batch = per-GPU batch x number of GPUs x accumulation steps.
    local eff_batch=$((BATCH_SIZE * NPROC * GRAD_ACCUM))
    printf '%s\n' \
        "${rule}" \
        "  3B SFT Fine-Tuning" \
        "  Run name        : ${RUN_NAME}" \
        "  Config          : ${CONFIG}" \
        "  Base checkpoint : ${BASE_CHECKPOINT}" \
        "  SFT data        : ${SFT_DATA}" \
        "  Val data        : ${VAL_DATA}" \
        "  CKPT dir        : ${CKPT_DIR}" \
        "  Log file        : ${LOG_FILE}" \
        "  Max steps       : ${MAX_STEPS}" \
        "  Batch size      : ${BATCH_SIZE} (local) × ${NPROC} GPU × ${GRAD_ACCUM} grad_accum = ${eff_batch} eff_batch" \
        "  Learning rate   : ${LR}" \
        "  Warmup          : ${WARMUP_STEPS} steps" \
        "  Master port     : ${MASTER_PORT}" \
        "  ALLOC_CONF      : ${PYTORCH_CUDA_ALLOC_CONF}" \
        "  Started         : $(date)" \
        "${rule}"
}
print_launch_banner

# Python warning filter exported to child processes: ignore UserWarning
# entries attributed to the torch.library module.
export PYTHONWARNINGS="ignore::UserWarning:torch.library"

#######################################
# Drop known-noisy warning lines from the training output.
# A single line-buffered grep replaces a chain of separate filters, so log
# lines reach the terminal and the log file as they are produced instead of
# waiting for each stage's pipe buffer to fill. Patterns are basic regular
# expressions; a line is removed if it matches any of them.
# Inputs:  raw training output on stdin
# Outputs: filtered lines on stdout
# Returns: 0 on success, including when every line was filtered out;
#          non-zero only on a real grep error
#######################################
filter_training_noise() {
    # grep exits 1 when it selects no lines; under pipefail + set -e that
    # would abort the script even though training itself succeeded.
    grep --line-buffered -v \
        -e "UserWarning" \
        -e "Warning only once" \
        -e "Overriding a previously" \
        -e "dispatch key:" \
        -e "previous kernel:" \
        -e "new kernel:" \
        -e "operator: flash_attn" \
        -e "registered at /usr/local" \
        -e "self.m.impl" \
        || [[ $? -eq 1 ]]
}

# Launch one training process per GPU. "$@" forwards the caller's arguments
# exactly as given (no re-splitting or glob expansion); they come after the
# defaults on the command line.
# NOTE(review): presumably train/sft.py lets a later duplicate flag override
# an earlier one — confirm against its argument parser.
# NOTE(review): LOG_FILE is both passed to train/sft.py and appended to by
# tee — confirm this does not duplicate log lines.
torchrun \
    --nproc_per_node="${NPROC}" \
    --master_port="${MASTER_PORT}" \
    train/sft.py \
    --config "${CONFIG}" \
    --base_checkpoint "${BASE_CHECKPOINT}" \
    --sft_data "${SFT_DATA}" \
    --val_data "${VAL_DATA}" \
    --checkpoint_dir "${CKPT_DIR}" \
    --log_file "${LOG_FILE}" \
    --max_steps "${MAX_STEPS}" \
    --batch_size "${BATCH_SIZE}" \
    --grad_accum "${GRAD_ACCUM}" \
    --lr "${LR}" \
    --warmup_steps "${WARMUP_STEPS}" \
    --seed "${SEED}" \
    --use_fp8 \
    "$@" \
    2>&1 | filter_training_noise \
         | tee -a "${LOG_FILE}"

echo "=================================================================="
echo "  3B SFT Done : $(date)"
echo "=================================================================="