#!/usr/bin/env bash
# =============================================================================
# launch_3b_sft_v2.sh — 8-GPU FP8 SFT v2 launcher for 3B Korean LLM
#
# SFT v2 improvements over v1:
#   - LR: 1e-5 → 5e-5 (5x higher, addresses underfitting)
#   - Effective batch: 64 → 256 (4x larger)
#   - Data mixing: 70% SFT + 30% pretrain (mitigates catastrophic forgetting)
#   - Weight decay: 0.01 → 0.05
#   - Warmup: 500 → 2000 steps
#   - Max steps: 33000 → 15000
#
# Usage:
#   bash scripts/launch_3b_sft_v2.sh
#   bash scripts/launch_3b_sft_v2.sh --max_steps 200    # quick test
#   bash scripts/launch_3b_sft_v2.sh --resume checkpoints/korean_3b_sft_v2/checkpoint-0002000
#
# Effective batch: 4 (local) x 8 GPU x 8 (grad_accum) = 256 samples/step
# =============================================================================
set -euo pipefail

# ---- Configurable defaults --------------------------------------------------
RUN_NAME="${RUN_NAME:-korean_3b_sft_v2}"
CONFIG="${CONFIG:-configs/korean_3b_sft_v2.yaml}"
BASE_CHECKPOINT="${BASE_CHECKPOINT:-checkpoints/korean_3b_fp8_run1/checkpoint-0057000}"
SFT_DATA="${SFT_DATA:-data/sft_combined/train_filtered.jsonl}"
VAL_DATA="${VAL_DATA:-data/sft_combined/val_filtered.jsonl}"
PRETRAIN_DATA="${PRETRAIN_DATA:-data/3b_train.bin}"
CKPT_DIR="checkpoints/${RUN_NAME}"
LOG_FILE="${CKPT_DIR}/train.log"
NPROC=8
MASTER_PORT="${MASTER_PORT:-29504}"

MAX_STEPS=15000
BATCH_SIZE=4
GRAD_ACCUM=8
LR="5.0e-5"
WARMUP_STEPS=2000
WEIGHT_DECAY=0.05
PRETRAIN_MIX_RATIO=0.3
SEED=42

# Extra CLI flags, captured as an array so quoted arguments survive intact.
# They are appended last to the torchrun command, so e.g. --max_steps 200
# overrides the defaults above (assuming last-one-wins argparse semantics).
EXTRA_ARGS=("$@")

# ---- B200 / NVSwitch NCCL tuning (same as pretrain) -------------------------
export NCCL_IB_DISABLE=1
export NCCL_ALGO=Ring
export NCCL_PROTO=Simple
export NCCL_MIN_NCHANNELS=16
export NCCL_MAX_NCHANNELS=16
export NCCL_BUFFSIZE=67108864
export OMP_NUM_THREADS=4
export MKL_NUM_THREADS=4
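
# Optional sanity check (assumes nvidia-smi is available): confirm the GPUs
# are NVSwitch-connected before relying on the Ring/Simple tuning above.
#   nvidia-smi topo -m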

# Allocator tuning for the 3B model at local batch size 4.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
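# expandable_segments lets the caching allocator grow existing segments
# instead of reserving new fixed-size blocks, reducing fragmentation when
# SFT sequence lengths vary from batch to batch.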

cd "$(dirname "$0")/.."

# ---- Pre-flight checks ------------------------------------------------------
if [[ ! -d "${BASE_CHECKPOINT}" ]]; then
    echo "=================================================================="
    echo "  ERROR: Base checkpoint not found: ${BASE_CHECKPOINT}"
    echo "  Set BASE_CHECKPOINT env var or use --base_checkpoint CLI arg."
    echo "=================================================================="
    exit 1
fi

if [[ ! -f "${SFT_DATA}" ]]; then
    echo "=================================================================="
    echo "  ERROR: SFT data not found: ${SFT_DATA}"
    echo "  Run: bash scripts/prepare_sft_combined.sh"
    echo "=================================================================="
    exit 1
fi

if [[ ! -f "${PRETRAIN_DATA}" ]]; then
    echo "=================================================================="
    echo "  ERROR: Pretrain data not found: ${PRETRAIN_DATA}"
    echo "  Set PRETRAIN_DATA env var to the correct path."
    echo "=================================================================="
    exit 1
fi

# Fall back to the unfiltered validation set if the filtered one is missing.
if [[ ! -f "${VAL_DATA}" ]]; then
    VAL_FALLBACK="data/sft_combined/val.jsonl"
    if [[ -f "${VAL_FALLBACK}" ]]; then
        VAL_DATA="${VAL_FALLBACK}"
        echo "[INFO] val_filtered not found, fallback: ${VAL_DATA}"
    else
        echo "ERROR: VAL_DATA not found: ${VAL_DATA}"
        exit 1
    fi
fi
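
# GPU count pre-flight (a minimal sketch; assumes nvidia-smi is on PATH and
# lists one GPU per line).
GPU_COUNT="$(nvidia-smi -L | wc -l)"
if (( GPU_COUNT < NPROC )); then
    echo "ERROR: ${NPROC} GPUs required for this launcher, found ${GPU_COUNT}"
    exit 1
fi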

mkdir -p "${CKPT_DIR}"

echo "=================================================================="
echo "  3B SFT v2 Fine-Tuning"
echo "  Run name        : ${RUN_NAME}"
echo "  Config          : ${CONFIG}"
echo "  Base checkpoint : ${BASE_CHECKPOINT}"
echo "  SFT data        : ${SFT_DATA}"
echo "  Pretrain data   : ${PRETRAIN_DATA}"
echo "  Val data        : ${VAL_DATA}"
echo "  CKPT dir        : ${CKPT_DIR}"
echo "  Log file        : ${LOG_FILE}"
echo "  Max steps       : ${MAX_STEPS}"
echo "  Batch size      : ${BATCH_SIZE} (local) x ${NPROC} GPU x ${GRAD_ACCUM} grad_accum = $((BATCH_SIZE * NPROC * GRAD_ACCUM)) eff_batch"
echo "  Learning rate   : ${LR}"
echo "  Weight decay    : ${WEIGHT_DECAY}"
echo "  Warmup          : ${WARMUP_STEPS} steps"
echo "  Data mixing     : $((100 - ${PRETRAIN_MIX_RATIO%.*}0))% SFT + ${PRETRAIN_MIX_RATIO}00% pretrain"
echo "  Master port     : ${MASTER_PORT}"
echo "  ALLOC_CONF      : ${PYTORCH_CUDA_ALLOC_CONF}"
echo "  Started         : $(date)"
echo "=================================================================="

export PYTHONWARNINGS="ignore::UserWarning:torch.library"
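# Filter format is action:message:category:module; the empty message field
# above means "ignore every UserWarning raised from torch.library".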

torchrun \
    --nproc_per_node=${NPROC} \
    --master_port=${MASTER_PORT} \
    train/sft.py \
    --config "${CONFIG}" \
    --base_checkpoint "${BASE_CHECKPOINT}" \
    --sft_data "${SFT_DATA}" \
    --val_data "${VAL_DATA}" \
    --pretrain_data "${PRETRAIN_DATA}" \
    --pretrain_mix_ratio ${PRETRAIN_MIX_RATIO} \
    --checkpoint_dir "${CKPT_DIR}" \
    --log_file "${LOG_FILE}" \
    --max_steps ${MAX_STEPS} \
    --batch_size ${BATCH_SIZE} \
    --grad_accum ${GRAD_ACCUM} \
    --lr ${LR} \
    --weight_decay ${WEIGHT_DECAY} \
    --warmup_steps ${WARMUP_STEPS} \
    --seed ${SEED} \
    --use_fp8 \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"} \
    2>&1 |
    # Known-noisy torch.library / flash_attn registration warnings are
    # filtered in one grep; "|| true" keeps pipefail from marking a
    # successful run as failed when grep -v emits no lines.
    { grep -vE "UserWarning|Warning only once|Overriding a previously|dispatch key:|previous kernel:|new kernel:|operator: flash_attn|registered at /usr/local|self\.m\.impl" || true; } |
    tee -a "${LOG_FILE}"

echo "=================================================================="
echo "  3B SFT v2 Done : $(date)"
echo "=================================================================="