| # SFT ํ์ดํผํ๋ผ๋ฏธํฐ ๋ถ์ & ๋ค์ ํ๋ ์ต์
์กฐ์ฌ |
|
|
| > ์์ฑ์ผ: 2026-02-26 |
| > ๋ชจ๋ธ: korean_1b_sft (1.19B params, base: korean_1b_fp8_run1/checkpoint-0034000) |
| > ํ์ต: 5000 steps, 39๋ถ, 8ร B200 |
| |
| --- |
| |
| ## 1. Loss Curve ๋ถ์ |
| |
| ### 1-1. ๊ธฐ๋ณธ ํต๊ณ |
| |
| | ๊ตฌ๊ฐ | Steps | n | Loss Mean | Loss Stdev | Loss Min | Loss Max | GNorm Mean | |
| |------|-------|---|-----------|------------|----------|----------|------------| |
| | Warmup | 10โ150 | 15 | 2.3100 | 0.1144 | 2.1129 | 2.5229 | 1.414 | |
| | Post-warmup ์ ์ฒด | 160โ5000 | 485 | 1.9984 | 0.0942 | 1.7305 | 2.3413 | 1.133 | |
| | Q1 (์ด๊ธฐ) | 160โ1360 | 121 | 2.0698 | 0.0860 | 1.8850 | 2.3413 | 1.138 | |
| | Q2 (์ค๋ฐ1) | 1370โ2570 | 121 | 1.9915 | 0.0801 | 1.7960 | 2.2088 | 1.131 | |
| | Q3 (์ค๋ฐ2) | 2580โ3780 | 121 | 1.9583 | 0.0870 | 1.7384 | 2.1293 | 1.119 | |
| | Q4 (ํ๋ฐ) | 3790โ5000 | 122 | **1.9739** | 0.0835 | 1.7305 | 2.1635 | 1.142 | |
| |
| ### 1-2. 500-step ์ด๋ ํ๊ท Loss (ยฑ50 step ์๋์ฐ) |
| |
| | Step | Loss(avg) | GNorm(avg) | ํด์ | |
| |------|-----------|------------|------| |
| | ~500 | 2.0658 | 1.098 | ์ด๊ธฐ ํ๊ฐ ๋จ๊ณ | |
| | ~1000 | 2.0281 | 1.121 | ๋น ๋ฅธ ํ๊ฐ ์ง์ | |
| | ~1500 | 1.9663 | 1.092 | โ
์ต์ด <2.0 ์ง์
| |
| | ~2000 | 1.9802 | 1.158 | ์ํญ ๋ฐ๋ฑ (์ ์) | |
| | ~2500 | 1.9882 | 1.140 | ์์ ํ ๊ตฌ๊ฐ ์์ | |
| | ~3000 | 1.9628 | 1.083 | ์ต์ ์ ๊ทผ๋ฐฉ | |
| | ~3500 | 1.9668 | 1.151 | ์๋ ด ์ ํธ | |
| | ~4000 | 1.9679 | 1.161 | ๊ณ ์ ์ง์
| |
| | ~4500 | 1.9555 | 1.142 | ๋ฏธ์ธ ํ๊ฐ ์ง์ | |
| | ~5000 | 1.9718 | 1.195 | **์ต์ข
: 1.9677** | |
| |
| ### 1-3. ํด์ |
| |
| **Warmup ๊ตฌ๊ฐ (step 10โ150):** |
| - LR์ด 1.33e-6 โ 2e-5๋ก ์ ํ ์ฆ๊ฐํ๋ ๋์ loss๊ฐ 2.11โ2.52 ๋ฒ์์์ ๋ถ๊ท์นํจ |
| - Warmup ์งํ step 160์์ loss spike (2.34, 3.6ฯ) ๋ฐ์ โ warmup ์ข
๋ฃ ์งํ full LR ์ถฉ๊ฒฉ. ์ ์์ ์ด๊ณ ํํ ํจํด |
| - Warmup 150 steps๋ ์ด 5000 steps์ 3% โ ์ ์ |
| |
| **์ ์ ํ์ต ๊ตฌ๊ฐ (step 160โ5000):** |
| - Loss๊ฐ Q1โQ3 ๊ตฌ๊ฐ์์ 2.07โ1.96์ผ๋ก ์ง์ ํ๊ฐ (์ด 0.11 ๊ฐ์) |
| - Q3โQ4๋ 1.958โ1.974์ผ๋ก **์คํ๋ ค ์ํญ ์์น** โ cosine LR์ด ์ถฉ๋ถํ ๋ฎ์์ง๋ฉด์ ํ์ต ์๋ ์ ํ, ์๋ ด ์งํ |
| - ํ์คํธ์ฐจ 0.094๋ ์์ ์ (SFT ๊ธฐ์ค 0.05โ0.15 ์ ์ ๋ฒ์) |
| |
| **Outlier ๋ถ์:** |
| - Mean+2ฯ = 2.187 ์ด๊ณผ: 10๊ฐ / 485 = **2.1%** โ ์ ์ ์์ค |
| - ๋ชจ๋ ์ด๊ธฐ(step 160โ800)์ ์ง์ค + step 2190 1๊ฐ โ ๋ฐ์ดํฐ ๋ค์์ฑ์ ์ํ ์ ์ ๋ณ๋ |
| - gnorm spike์ ๋๋ฐํ์ง ์์ gradient ํญ๋ฐ ์์ |
| |
| **GNorm ํจํด:** |
| - ์ ์ฒด ํ๊ท 1.13, max_grad_norm=1.0์ผ๋ก ์ค์ ๋์ด ์์ผ๋ ๋ก๊ทธ๊ฐ์ 0.89โ1.53 |
| - ๋ก๊ทธ๋๋ gnorm์ clip **์ด์ ** ๊ฐ์ผ๋ก ์ถ์ ; ์ค์ 1.0 ์ด๊ณผ ์ clip ๋ฐ์ |
| - Warmup ๊ตฌ๊ฐ(ํ๊ท 1.41)์ด ์ดํ(ํ๊ท 1.13)๋ณด๋ค ๋์ โ ์ ์ ํจํด |
| - ํ์ต ์ ๋ฐ์ ๊ฑธ์ณ ๊ฐ์ ์ถ์ธ (gnorm ์์ ํ = ํ์ต์ด ์๋ ด ์ค) |
| |
| **ํต์ฌ ๊ฒฐ๋ก :** ํ์ต์ ๊ฑด๊ฐํ๊ฒ ์งํ๋จ. Step ~3000 ์ดํ ์๋ ด ์ ํธ๊ฐ ์์ผ๋ loss๋ ์ฌ์ ํ ๋ฏธ์ธ ํ๊ฐ ์ค. 5000 steps ์ข
๋ฃ ์์ ์ด ์ ์ ํ stopping point์๊ฑฐ๋ ์ถ๊ฐ ํ์ต ์ฌ์ง ์์. |
| |
| --- |
| |
| ## 2. ํ์ดํผํ๋ผ๋ฏธํฐ ์ํฅ ๋ถ์ |
| |
| ### 2-1. Learning Rate: **2e-5** โ โ
์ ์ (์
๊ณ ํ์ค ๋ฒ์) |
| |
| | ๋ชจ๋ธ/ํ๋ ์์ํฌ | LR | ๊ท๋ชจ | |
| |---|---|---| |
| | Meta Alpaca (Llama 7B) | 2e-5 | 7B | |
| | WizardLM (Vicuna 13B) | 2e-5 | 13B | |
| | OpenHermes (Mistral 7B) | 2e-5 | 7B | |
| | LIMA (65B) | 1e-5 | 65B | |
| | TinyLlama SFT (1.1B) | 2e-5 | 1.1B | |
| | **ํ์ฌ ์ค์ ** | **2e-5** | **1.2B** | |
| |
| - 1B ๊ท๋ชจ์์ 2e-5๋ ์
๊ณ ํ์ค๊ฐ๊ณผ ์ ํํ ์ผ์น |
| - pretrain LR(2e-4)์ 1/10์ผ๋ก ์ค์ โ catastrophic forgetting ๋ฐฉ์ง ์์น ์ถฉ์กฑ |
| - ๋จ, ์ถ๊ฐ epoch ์์๋ 1e-5๋ก ๋ฎ์ถ๋ ๊ฒ์ด ์์ |
| |
| **๊ฐ์ ๋ฐฉํฅ:** ํ์ฌ ์ค์ ์ ์ง. 2์ฐจ ํ์ต ์ 1e-5 ์ถ์ฒ. |
| |
| ### 2-2. Cosine Decay ์ค์ผ์ค โ โ
์ ์ (๋จ, ์ต์ข
LR ์ฝ๊ฐ ๋์) |
| |
| - ์ต์ข
LR: 2.00e-6 (peak์ 10%) |
| - ํ์ค cosine schedule: min_lr = 0.1 ร peak_lr |
| - 5000 steps์ ๋ง๋ ์ค์ : warmup 150 + cosine decay 4850 steps |
| - step 5000์์ LR์ด 2e-6์ผ๋ก ์์ฐ ์๋ ด โ ํ์ต์ด ๋ง๋ฌด๋ฆฌ๋ ๋๋ |
| |
| **๊ฐ์ ๋ฐฉํฅ:** min_lr์ 0 ๋๋ 1e-7๋ก ๋ฎ์ถ๋ฉด ๋ง์ง๋ง ๊ตฌ๊ฐ ๋ ์์ ์ ์๋ ด ๊ฐ๋ฅ. ํ์ฌ ์ค์ ๋ ๋ฌด๋ฐฉ. |
|
|
| ### 2-3. Effective Batch Size: **64 sequences** (=262K tokens/step) โ โ
์ ์ |
|
|
| - 64 seqs ร ํ๊ท ~500 tokens (dynamic padding) โ 32,000 tokens/step ์ค์ ์ฒ๋ฆฌ๋ |
| - max_seq_len=4096 ๊ธฐ์ค ์ด๋ก ๊ฐ์ 262,144 tok/step์ด์ง๋ง ๋์ ํจ๋ฉ์ผ๋ก ์ค์ ๋ ๋ฎ์ |
| - SFT ๋ฐฐ์น ํฌ๊ธฐ ์ฐธ๊ณ : Alpaca=128 seqs, WizardLM=64 seqs, LIMA=64 seqs |
| - **64๋ ์
๊ณ ํ์ค๊ฐ๊ณผ ์ ํ ์ผ์น** |
|
|
| **๊ฐ์ ๋ฐฉํฅ:** ํ์ฌ ์ค์ ์ ์ง. ๋ฐฐ์น๊ฐ ๋๋ฌด ํฌ๋ฉด generalization ์ ํ ๊ฐ๋ฅ์ฑ ์์. |
|
|
| ### 2-4. Epochs: **~2 epoch** โ โ ๏ธ ๋ถ์กฑ ๊ฐ๋ฅ์ฑ (์์ ์ ํจ) |
|
|
| - 5000 steps ร 64 seqs = 320,000 ์์ ์ฒ๋ฆฌ / 159,000 ์ํ = **์ฝ 2.0 epoch** |
| - SFT ์
๊ณ ๊ธฐ์ค: |
| - LIMA: 15 epoch (์๋ ๋ฐ์ดํฐ 1K๊ฐ) |
| - Alpaca, WizardLM: **3 epoch** |
| - OpenHermes, Hermes: 3โ5 epoch |
| - ๋๊ท๋ชจ ๋ฐ์ดํฐ(>100K): 1โ3 epoch |
|
|
| - 2 epoch๋ **๊ณผ์ํ์ต ๊ฐ๋ฅ์ฑ** ์์ (ํนํ ๋ฎ์ ๋น๋ ๋ฐ์ดํฐ ํจํด ํ์ต ๋ถ์กฑ) |
| - Q4 loss(1.974)๊ฐ Q3(1.958)๋ณด๋ค ์ด์ง ๋์์ง ๊ฒ์ cosine LR ๊ฐ์ ํจ๊ณผ + ์์ง ์๋ ด ์ ์ผ ๊ฐ๋ฅ์ฑ ๊ณต์กด |
| - Val loss๊ฐ ์์ด ๊ณผ์ ํฉ ์ฌ๋ถ ํ์ธ ๋ถ๊ฐ (โ
eval_interval=100์ผ๋ก ์ค์ ์ ๋์ด ์์์ผ๋ ๊ฒฐ๊ณผ ์์) |
| |
| **๊ฐ์ ๋ฐฉํฅ:** 3โ4 epoch (7500โ10000 steps) ์ถ๊ฐ ์คํ ๊ถ์ฅ. ๋จ val split ํ์ ํ๋ณด ํ ์งํ. |
| |
| ### 2-5. NEFTune alpha=10 โ โ
์ด ๋ฐ์ดํฐ์
ํฌ๊ธฐ์ ์ ํฉ |
| |
| - ์๋
ผ๋ฌธ(Jain et al., 2023) ๊ถ์ฅ๊ฐ: ์๊ท๋ชจ(<10K) โ 5, ์ค๊ท๋ชจ(10Kโ500K) โ 10, ๋๊ท๋ชจ(>500K) โ 15 |
| - 159K ์ํ โ **alpha=10 ์ ํฉ** |
| - Noise magnitude = alpha / sqrt(seq_len ร d_model) = 10 / sqrt(500 ร 2048) โ 0.0099 |
| - ์ค์ embedding ๊ฐ ๋๋น ์ ์ ํ noise ๋น์จ |
| - Loss curve ์์ ์ฑ(stdev 0.094)์ผ๋ก ๋ณผ ๋ NEFTune์ด ํ์ต์ ๋ถ์์ ํ๊ฒ ๋ง๋ค์ง ์์์ |
| |
| **๊ฐ์ ๋ฐฉํฅ:** ํ์ฌ ์ค์ ์ ์ง. ๋ฐ์ดํฐ ์ฆ๊ฐ(500K+) ์ alpha=15๋ก ์ํฅ ๊ณ ๋ ค. |
| |
| ### 2-6. max_seq_len: **4096** โ โ
์ ์ (๋จ, ํ์ฉ๋ ํ์ธ ํ์) |
| |
| - ์ค์ : max_seq_len=4096, dynamic padding ์ ์ฉ |
| - ํ๊ตญ์ด instruction ๋ฐ์ดํฐ ํ๊ท ๊ธธ์ด: 200โ1000 tokens (kullm/KoAlpaca ๊ธฐ์ค) |
| - Dynamic padding ๋๋ถ์ ์งง์ ์ํ์ค๋ค์ ์ค์ ๋ก 4096์ ์ฑ์ฐ์ง ์์ โ compute ํจ์จ์ |
| - rope_theta=500000 (Llama-3 ์คํ์ผ) โ 4096 ์ด์ ์ธ์ฝ๋ ์ง์ |
|
|
| **์ ์ฌ ๋ฌธ์ :** |
| - ๋ฐ์ดํฐ์
์ 4096 ์ด๊ณผ ๋ํ๊ฐ ์๋ค๋ฉด truncation ๋ฐ์ โ ๊ธด multi-turn ๋ํ ์์ค |
| - ํ์ฌ ๋ฐ์ดํฐ์
(kullm, KoAlpaca, LIMA ๋ฑ)์ ๋๋ถ๋ถ 2048 ์ดํ์ด๋ฏ๋ก ์ค์ง์ ์ํฅ ์ ์ |
|
|
| **๊ฐ์ ๋ฐฉํฅ:** ํ์ฌ ์ค์ ์ ์ง. ์ฅ๋ฌธ ๋ํ ๋ฐ์ดํฐ ์ถ๊ฐ ์ 8192 ๊ณ ๋ ค. |
|
|
| --- |
|
|
| ## 3. ๋ค์ ํ๋ ์ต์
ํ๋ณด๊ตฐ |
|
|
| ### A. ์ถ๊ฐ SFT Epoch (5000 โ 10000 steps, epoch 4) |
|
|
| **Pros:** |
| - ํ์ฌ loss๊ฐ ์ฌ์ ํ ํ๊ฐ ์ถ์ธ โ ์ถ๊ฐ ํ์ต ์ฌ์ง ์์ |
| - epoch 3โ4๋ SFT ์
๊ณ ํ์ค (Alpaca, WizardLM ๊ธฐ์ค) |
| - ๊ธฐ์กด ์ฒดํฌํฌ์ธํธ์์ resume ๊ฐ๋ฅ, 39๋ถ ์ถ๊ฐ๋ฉด ์ถฉ๋ถ (B200 ์๋ ๊ธฐ์ค) |
| - ๊ตฌํ ๊ฐ๋ฅ: `--resume checkpoints/korean_1b_sft/checkpoint-5000 --max_steps 10000` |
|
|
| **Cons:** |
| - Val loss ์์ด ์งํ ์ ๊ณผ์ ํฉ ๊ฐ์ง ๋ถ๊ฐ |
| - cosine schedule์ด ์ด๋ฏธ step 5000 ๊ธฐ์ค์ผ๋ก ์ค๊ณ๋์ด ์์ โ resume ์ LR ์ค์ผ์ค ์ฌ์ค์ ํ์ |
| - epoch 4 ์ดํ ๊ณผ์ ํฉ ์ํ (ํนํ ๋ฐ๋ณต ํจํด memorization) |
|
|
| **์ถ์ฒ:** โ
**์กฐ๊ฑด๋ถ ์ถ์ฒ** โ val split 5โ10% ํ๋ณด ํ, LR=1e-5๋ก ์ cosine schedule ์ค์ ํ์ฌ ์ถ๊ฐ ํ์ต. Resume๋ณด๋ค fresh start ๊ถ์ฅ. |
|
|
| **๊ตฌ์ฒด์ ์ค์ :** |
| ```yaml |
| max_steps: 5000 # ์ถ๊ฐ 5000 steps (epoch 3-4) |
| lr: 1.0e-5 # ์ด์ ์ ์ ๋ฐ |
| warmup_steps: 50 # ์งง์ warmup |
| ``` |
|
|
| --- |
|
|
| ### B. LR ํ๋: 2e-5 vs 1e-5 vs 5e-6 |
|
|
| | LR | ์ฅ์ | ๋จ์ | ์ถ์ฒ | |
| |----|------|------|------| |
| | 5e-6 | ๋งค์ฐ ์์ , ๊ณผ์ ํฉ ๋ฐฉ์ง | 5000 steps์์ ๊ฐ์ ํญ ์ ์ ์ ์์ | โ ๋๋ฌด ๋ณด์์ | |
| | **1e-5** | **๊ท ํ์กํ ์ ํ, 2์ฐจ ํ์ต ํ์ค** | ํ์ฌ ๋๋น ํ์ต ์๋ ์ ๋ฐ | โ
**์ถ์ฒ** | |
| | 2e-5 (ํ์ฌ) | 1์ฐจ ํ์ต์์ ์ข์ ๊ฒฐ๊ณผ | ์ถ๊ฐ epoch์์ ๊ณผ์ ํฉ ์ํ | โ ๏ธ ์ถ๊ฐ ํ์ต์ ๋ถ๋ฆฌ | |
|
|
| **๊ฒฐ๋ก :** 2์ฐจ ํ์ต ์ **lr=1e-5** ์ฌ์ฉ. ํ์ฌ lr=2e-5๋ 1์ฐจ ํ์ต์ ์ต์ . |
|
|
| --- |
|
|
| ### C. ORPO (Odds Ratio Preference Optimization) |
|
|
| **๊ฐ์:** SFT + preference alignment์ ๋จ์ผ ๋จ๊ณ์์ ๋์ ์ํ. Reference model ๋ถํ์. |
|
|
| **Pros:** |
| - Reference model ์์ด ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ (DPO ๋๋น VRAM ์ฝ 40% ์ ์ฝ) |
| - SFT์ preference๋ฅผ ๋์์ ์ต์ ํ โ ๋ชจ๋ธ ํ์ง ์ ํ ์์ด alignment ๊ฐ๋ฅ |
| - 1-stage ํ์ดํ๋ผ์ธ โ ์ด์ ๋จ์ํ |
| - `trl` ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ์ฝ๊ฒ ๊ตฌํ ๊ฐ๋ฅ |
|
|
| **Cons:** |
| - Chosen/rejected ์ ๋ฐ์ดํฐ ํ์ (ํ์ฌ ์์) |
| - ํ๊ตญ์ด preference ๋ฐ์ดํฐ ์ ํ์ง๊ฐ ์ ํ์ |
|
|
| **ํ๊ตญ์ด Preference ๋ฐ์ดํฐ ํํฉ (HuggingFace ๊ธฐ์ค):** |
| | ๋ฐ์ดํฐ์
| ์ํ ์ | ํน์ง | |
| |---------|---------|------| |
| | `maywell/ko_Ultrafeedback` | ~60K | UltraFeedback ํ๊ตญ์ด ๋ฒ์ญ | |
| | `ChuGyouk/korean-ultrafeedback-armorm` | ~60K | ArmoRM ์ค์ฝ์ด ํฌํจ | |
| | `HAERAE-HUB/K2-Align` | ~10K | ํ๊ตญ์ด RLHF alignment | |
| | `heegyu/KORANI-v1` | ~20K | Korean RANI (human feedback) | |
| | `trl-lib/ultrafeedback_binarized` | ~60K | ์์ด (๋ฒ์ญ ํ์) | |
|
|
| **์ถ์ฒ:** โ
**์ถ์ฒ** โ `maywell/ko_Ultrafeedback` ๋๋ `ChuGyouk/korean-ultrafeedback-armorm` ํ๋ณด ํ TRL `ORPOTrainer`๋ก ๊ตฌํ. SFT ํ ORPO ์ ์ฉ ๋๋ from scratch ORPO ๋ชจ๋ ๊ฐ๋ฅ. |
|
|
| **๊ตฌํ ์์:** |
| ```python |
| from trl import ORPOConfig, ORPOTrainer |
| config = ORPOConfig(learning_rate=5e-7, num_train_epochs=1, ...) |
| trainer = ORPOTrainer(model, config, train_dataset=preference_data) |
| ``` |
|
|
| --- |
|
|
| ### D. DPO (Direct Preference Optimization) |
|
|
| **๊ฐ์:** SFT ์๋ฃ ๋ชจ๋ธ ์์ preference alignment์ ์ถ๊ฐ ํ์ต. Reference model(=SFT ๋ชจ๋ธ frozen) ํ์. |
|
|
| **vs ORPO:** |
| | | DPO | ORPO | |
| |--|-----|------| |
| | Reference model | ํ์ (VRAM +40%) | ๋ถํ์ | |
| | SFT ๋จ๊ณ | ๋ณ๋ ํ์ | ํตํฉ ๊ฐ๋ฅ | |
| | ์์ ์ฑ | ๊ฒ์ฆ๋ ๋ฐฉ๋ฒ | ์๋์ ์ผ๋ก ์ ๊ท | |
| | ๋ฐ์ดํฐ | chosen/rejected | chosen/rejected | |
| | ๊ตฌํ ๋ณต์ก๋ | ์ค๊ฐ | ๋ฎ์ | |
|
|
| **Pros:** |
| - ๊ฐ์ฅ ๋๋ฆฌ ๊ฒ์ฆ๋ preference optimization ๋ฐฉ๋ฒ |
| - `trl` ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์์ ์ง์ |
| - Llama, Mistral ๊ธฐ๋ฐ ๋ชจ๋ ์ฃผ์ ๋ชจ๋ธ์ ์ ์ฉ๋จ |
|
|
| **Cons:** |
| - SFT ๋ชจ๋ธ์ reference๋ก ๋๊ณ ์ถ๊ฐ ํ์ต โ ๋ฉ๋ชจ๋ฆฌ 2๋ฐฐ (1.2B ร 2 = ~16GB, B200 192GB์์ ๋ฌด๋ฆฌ ์์) |
| - 2๋จ๊ณ ํ์ต ํ์ดํ๋ผ์ธ ๋ณต์ก์ฑ |
|
|
| **์ถ์ฒ:** โ
**์ถ์ฒ** โ ORPO๋ณด๋ค ๊ฒ์ฆ๋ ๋ฐฉ๋ฒ. B200 ร 8์์ ๋ฉ๋ชจ๋ฆฌ ์ด์ ์์. ORPO์ A/B ํ
์คํธ ๊ฐ์น ์์. |
|
|
| --- |
|
|
| ### E. LoRA/QLoRA |
|
|
| **๋งฅ๋ฝ:** ์ด๋ฏธ full fine-tuning ์๋ฃ. LoRA์ ์ญํ ์? |
|
|
| **Pros:** |
| - ๋น ๋ฅธ ํ์ดํผํ๋ผ๋ฏธํฐ ์คํ (LR, epoch, alpha ์กฐํฉ): full FT ๋๋น 3-5x ๋น ๋ฆ |
| - ์ฌ๋ฌ adaptation ๋์ ๊ด๋ฆฌ (domain-specific LoRA weights) |
| - DPO/ORPO ๋จ๊ณ์์ adapter๋ง ํ์ต ๊ฐ๋ฅ |
| - VRAM ์ฌ์ฉ ์ ์ฝ โ batch size ์ฆ๊ฐ ๊ฐ๋ฅ |
|
|
| **Cons:** |
| - ์ด๋ฏธ full FT๋ ๋ชจ๋ธ์ด ์์ผ๋ฏ๋ก LoRA ์ฑ๋ฅ ์ํ โค full FT |
| - 1B ๋ชจ๋ธ์ ์ด๋ฏธ ์์์ QLoRA์ 4-bit quantization ์ด์ ์ด ํฌ์ง ์์ |
| - Fine-tuning quality๋ full FT๊ฐ ํญ์ ์ฐ์ธ |
|
|
| **์ถ์ฒ:** โ ๏ธ **์กฐ๊ฑด๋ถ ์ถ์ฒ** โ ํ์ดํผํ๋ผ๋ฏธํฐ ํ์(lr ๊ทธ๋ฆฌ๋์์น, epoch sweep)์ LoRA ํ์ฉ. ์ต์ข
๋ชจ๋ธ์ full FT. |
|
|
| **์ค์ฉ์ ์ฌ์ฉ๋ฒ:** |
| ```python |
| # ๋น ๋ฅธ ์คํ: LoRA rank=64๋ก LR ๊ทธ๋ฆฌ๋์์น |
| # rank=64, alpha=128, dropout=0.05 |
| # ์ฝ 5-10๋ถ / ์คํ (B200 ๊ธฐ์ค) |
| ``` |
|
|
| --- |
|
|
| ### F. ๋ฐ์ดํฐ ํ์ง ๊ฐ์ |
|
|
| **ํ์ฌ ๋ฐ์ดํฐ ๊ตฌ์ฑ:** |
| - kullm: ๋๊ท๋ชจ ํ๊ตญ์ด instruction (ํ์ง ํผ์ฌ) |
| - KoAlpaca: Alpaca ํ๊ตญ์ด ๋ฒ์ญ (๋ฒ์ญ ํ์ง ์ด์) |
| - safe_conv: ์์ ๋ํ ๋ฐ์ดํฐ |
| - LIMA: ๊ณ ํ์ง ์์ด instruction (1000๊ฐ) |
| - evol_instruct: GPT-4 ์์ฑ (๊ณ ํ์ง) |
| - kovast: ํ๊ตญ์ด ๋ํ |
|
|
| **๊ฐ์ ๋ฐฉํฅ:** |
|
|
| 1. **Deduplication (MinHash LSH):** |
| - instruction text์ ๋ํด locality-sensitive hashing |
| - ์์ ์ค๋ณต ์ ๊ฑฐ์จ: 5โ15% (159K โ 135โ150K ์ ๋) |
| - ํ์ง ํฅ์ ํจ๊ณผ: ์ค๋ณต ํจํด memorization ๋ฐฉ์ง |
|
|
| 2. **Quality Filtering:** |
| - Perplexity ๊ธฐ๋ฐ ํํฐ: ๋๋ฌด ๋ฎ๊ฑฐ๋ ๋๋ฌด ๋์ perplexity ์ ๊ฑฐ |
| - ์ธ์ด ํ์ธ: ํ๊ตญ์ด ๋น์จ ์ฒดํฌ (`langdetect`) |
| - ๊ธธ์ด ํํฐ: ๋๋ฌด ์งง์ ์๋ต(<50 tokens) ์ ๊ฑฐ |
| - ๋ฐ๋ณต ํจํด ์ ๊ฑฐ: `n-gram repetition score` ๊ธฐ๋ฐ |
|
|
| 3. **Domain Mixing ์กฐ์ :** |
| - LIMA-style: ์๋์ ๊ณ ํ์ง ๋ฐ์ดํฐ๊ฐ ๋๋์ ์ ํ์ง๋ณด๋ค ํจ๊ณผ์ |
| - evol_instruct ๋น์จ โ (GPT-4 ์์ฑ์ด๋ฏ๋ก ๊ณ ํ์ง) |
| - ๋จ์ ๋ฒ์ญ ๋ฐ์ดํฐ(KoAlpaca) ๋น์จ โ |
| |
| **์ถ์ฒ:** โ
**๊ฐ๋ ฅ ์ถ์ฒ** โ ๋ฐ์ดํฐ ํ์ง์ด epoch ์๋ณด๋ค ์ค์. 1์ฃผ์ผ ํฌ์๋ก ์ค์ง์ ์ฑ๋ฅ ํฅ์ ๊ธฐ๋. |
| |
| --- |
| |
| ### G. ๋ ๋ง์ SFT ๋ฐ์ดํฐ (159K โ 500K+) |
| |
| **HuggingFace ์ถ๊ฐ ๊ฐ๋ฅ ๋ฐ์ดํฐ์
:** |
| |
| | ๋ฐ์ดํฐ์
| ์ํ ์ | ์ธ์ด | ํ์ง | ๋น๊ณ | |
| |---------|---------|------|------|------| |
| | `HAERAE-HUB/qarv-instruct-100k` | 100K | ํ๊ตญ์ด | ์ค์ | ํ๊ตญ์ด instruction 100K | |
| | `nayohan/llama3-instruct-ko-dataset` | 58K | ํ๊ตญ์ด | ์ | Llama-3 instruction ํ๊ตญ์ด | |
| | `hPark/orca-ko` | 200K+ | ํ๊ตญ์ด | ์ | Orca ์คํ์ผ ํ๊ตญ์ด | |
| | `maywell/synatra-orca` | 300K+ | ํ๊ตญ์ด | ์ | ํฉ์ฑ ๋ฐ์ดํฐ, ๊ณ ํ์ง | |
| | `FreedomIntelligence/evol-instruct-korean` | 70K | ํ๊ตญ์ด | ์ | GPT-4 ์์ฑ ํ๊ตญ์ด | |
| | `Bingsu/ko_alpaca_data` | 52K | ํ๊ตญ์ด | ์ค | Alpaca ํ๊ตญ์ด (๋ฒ์ญ) | |
| | `HAERAE-HUB/KoInstruct` | 50K+ | ํ๊ตญ์ด | ์ค์ | ํ๊ตญ์ด instruction | |
| | `Open-Orca/OpenOrca` | 1M+ | ์์ด | ์ต์ | ๊ณ ํ์ง ์์ด (ํ๊ตญ์ด ๋ชจ๋ธ์ ํผํฉ ๊ฐ๋ฅ) | |
| |
| **500K ๋ฌ์ฑ ๊ฒฝ๋ก:** |
| 1. ํ์ฌ 159K |
| 2. `hPark/orca-ko` + `maywell/synatra-orca` ์ถ๊ฐ: +200K = 359K |
| 3. `HAERAE-HUB/qarv-instruct-100k` + `nayohan/llama3-instruct-ko-dataset`: +158K = 517K |
| 4. ํ์ง ํํฐ ํ ์ ์ง ๋น์จ ~80% โ **์ฝ 400K ์ ๋ฐ์ดํฐ** |
| |
| **Pros:** |
| - ๋ ๋ง์ ๋๋ฉ์ธ ์ปค๋ฒ๋ฆฌ์ง |
| - ๋๋ฌธ ํจํด ํ์ต ๊ธฐํ ์ฆ๊ฐ |
| - Generalization ํฅ์ |
| |
| **Cons:** |
| - ๋ฐ์ดํฐ ํ์ง ๊ฒ์ฆ ํ์ (๋ฌด๋ถ๋ณ ์ถ๊ฐ๋ ์ญํจ๊ณผ) |
| - ํ์ต ์๊ฐ ์ฆ๊ฐ (๊ฐ์ epoch ๊ธฐ์ค 3๋ฐฐ โ 2์๊ฐ+) |
| - ๊ณ ํ์ง ์๋ vs ์ ํ์ง ๋ค๋ ํธ๋ ์ด๋์คํ |
| |
| **์ถ์ฒ:** โ
**์ถ์ฒ (ํ์ง ํํฐ ์ ์ )** โ `hPark/orca-ko`๋ `maywell/synatra-orca` ๊ฐ์ ๊ณ ํ์ง ํฉ์ฑ ๋ฐ์ดํฐ ์ฐ์ ์ถ๊ฐ. ๋จ์ ๋ฒ์ญ ๋ฐ์ดํฐ ๋น์จ ์ฃผ์. |
| |
| --- |
| |
| ## 4. ์ฆ์ ์คํ ๊ฐ๋ฅํ ์คํ Top 3 |
| |
| ### ๐ฅ 1์์: **ํ์ฌ ๋ชจ๋ธ ์ข
ํฉ ํ๊ฐ (eval ์คํ)** |
| |
| **์ด์ :** |
| - Loss 1.9677์ด ์ค์ ๋ก ์ข์ ๋ชจ๋ธ์ธ์ง ์ ์ ์์ |
| - ์ถ๊ฐ ํ์ต ๋ฐฉํฅ ๊ฒฐ์ ์ baseline ํ์ |
| - ์ด๋ฏธ `eval/comprehensive_eval.py` ์กด์ฌ |
|
|
| **์ฆ์ ์คํ:** |
| ```bash |
| cd /PROJECT/0325120031_A/ghong/taketimes/llm-bang |
| |
| # Perplexity ํ๊ฐ |
| python eval/perplexity.py \ |
| --checkpoint checkpoints/korean_1b_sft/checkpoint-5000 \ |
| --data data/sft/val.jsonl # val split ํ์ |
| |
| # ์์ฑ ํ์ง ๋น ๋ฅธ ์ฒดํฌ |
| python eval/generate.py \ |
| --checkpoint checkpoints/korean_1b_sft/checkpoint-5000 \ |
| --prompts "์๋
ํ์ธ์, ์ ๋ AI ๋ชจ๋ธ์
๋๋ค. ์ค๋ ๋ ์จ์ ๋ํด ์ค๋ช
ํด์ฃผ์ธ์." |
| ``` |
|
|
| **์์ ์๊ฐ:** 10โ30๋ถ |
|
|
| --- |
|
|
| ### ๐ฅ 2์์: **lr=1e-5๋ก ์ถ๊ฐ SFT (epoch 3โ4๊น์ง)** |
|
|
| **์ด์ :** |
| - Loss curve๊ฐ ์์ง ์๋ ดํ์ง ์์๊ณ epoch 2๋ ์
๊ณ ํ์ค๋ณด๋ค ๋ถ์กฑ |
| - ๊ตฌํ ๋น์ฉ ์ต์ (๊ธฐ์กด ์ฝ๋ ์ฌ์ฌ์ฉ) |
| - B200 ร 8์์ ์ฝ 40โ60๋ถ ์ถ๊ฐ (39๋ถ/5000steps ๊ธฐ์ค) |
|
|
| **๊ตฌ์ฒด์ ์ค์ :** |
| ```bash |
| # ์ run์ผ๋ก checkpoint-5000์์ ์์ |
| RUN_NAME=korean_1b_sft_v2 \ |
| BASE_CHECKPOINT=checkpoints/korean_1b_sft/checkpoint-5000 \ |
| LR=1.0e-5 \ |
| MAX_STEPS=5000 \ # epoch 3-4 |
| WARMUP_STEPS=50 \ # ์งง์ warmup |
| bash scripts/launch_sft.sh |
| ``` |
|
|
| **์ฃผ์:** val split ์์ผ๋ฉด step 3000โ5000์์ val loss ์ฒดํฌํ๋ฉฐ early stop ๊ธฐ์ค ์๋ ์ค์ ํ์. |
|
|
| **์์ ๊ฒฐ๊ณผ:** loss 1.90โ1.93 (ํ์ฌ 1.97 ๋๋น ์ฝ 2โ3% ๊ฐ์ ), ์์ฑ ํ์ง ์ฒด๊ฐ ํฅ์ ๊ธฐ๋. |
|
|
| --- |
|
|
| ### ๐ฅ 3์์: **๋ฐ์ดํฐ ํ์ง ๊ฐ์ + ์ถ๊ฐ ๋ฐ์ดํฐ ์์ง** |
|
|
| **์ด์ :** |
| - ๋ฐ์ดํฐ ํ์ง์ด ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋๋ณด๋ค ์ฅ๊ธฐ์ ์ผ๋ก ์ค์ |
| - ํ์ฌ ๋ฐ์ดํฐ์ ์ค๋ณต/์ ํ์ง ํฌํจ ๊ฐ๋ฅ์ฑ ์์ |
| - ORPO/DPO ํ์ดํ๋ผ์ธ ์ค๋น๋ฅผ ์ํด preference ๋ฐ์ดํฐ๋ ๋์์ ์์ง |
|
|
| **์ฆ์ ์คํ ๊ฐ๋ฅํ ์์
:** |
|
|
| ```python |
| # 1. Deduplication (MinHash) |
| pip install datasketch |
| # instruction text ๊ธฐ์ค MinHash dedup, threshold=0.8 |
| |
| # 2. ์ถ๊ฐ ๋ฐ์ดํฐ ๋ค์ด๋ก๋ |
| from datasets import load_dataset |
| ds = load_dataset("hPark/orca-ko") # ~200K ๊ณ ํ์ง ํ๊ตญ์ด |
| ds2 = load_dataset("maywell/synatra-orca") # ~300K ํฉ์ฑ |
| |
| # 3. ํ๊ตญ์ด Preference ๋ฐ์ดํฐ ์์ง (ORPO/DPO ์ค๋น) |
| pref = load_dataset("maywell/ko_Ultrafeedback") # ~60K preference ์ |
| ``` |
|
|
| **์์ ์๊ฐ:** ๋ฐ์ดํฐ ์ค๋น 2โ4์๊ฐ, ์ฌํ์ต์ ์ถ๊ฐ ์ค์ ํ ์งํ. |
|
|
| --- |
|
|
| ## 5. ์ข
ํฉ ํ๊ฐ ์์ฝ |
|
|
| ### ํ์ฌ ์ค์ ํ๊ฐ |
|
|
| | ํญ๋ชฉ | ์ค์ ๊ฐ | ํ๊ฐ | ๋น๊ณ | |
| |------|--------|------|------| |
| | Learning Rate | 2e-5 | โ
์ ์ | ์
๊ณ ํ์ค ์ ์ค์ | |
| | Cosine Decay | 5000 steps | โ
์ ์ | min_lr ~10% | |
| | Warmup | 150 steps (3%) | โ
์ ์ | 3-5% ๊ถ์ฅ ๋ฒ์ | |
| | Effective Batch | 64 seqs | โ
์ ์ | ์
๊ณ ํ์ค | |
| | Epochs | ~2 | โ ๏ธ ๋ถ์กฑ ๊ฐ๋ฅ | 3 epoch ํ์ค | |
| | NEFTune alpha | 10 | โ
์ ์ | 159K ๋ฐ์ดํฐ์ ๋ง์ | |
| | max_seq_len | 4096 | โ
์ ์ | ๋์ ํจ๋ฉ์ผ๋ก ํจ์จ์ | |
| | Weight Decay | 0.01 | โ
์ ์ | pretrain(0.1)์ 1/10 | |
| |
| ### ์ต์
๋ณ ์ถ์ฒ ์ฐ์ ์์ |
| |
| | ์ต์
| ์ถ์ฒ | ์ด์ | |
| |------|------|------| |
| | A. ์ถ๊ฐ SFT (epoch 4) | โ
๋์ | epoch ๋ถ์กฑ, ์ฆ์ ์คํ ๊ฐ๋ฅ | |
| | B. LR 1e-5๋ก ์ฌํ์ต | โ
๋์ | ์ถ๊ฐ ํ์ต ์ ํ์ | |
| | C. ORPO | โ
์ค๊ฐ | ๋ฐ์ดํฐ ์ค๋น ํ์ | |
| | D. DPO | โ
์ค๊ฐ | ORPO ๋์, ๋ ๊ฒ์ฆ๋จ | |
| | E. LoRA | โ ๏ธ ๋ฎ์ | ํ์ดํผํ๋ผ๋ฏธํฐ ํ์์๋ง ์ ์ฉ | |
| | F. ๋ฐ์ดํฐ ํ์ง ๊ฐ์ | โ
๋์ | ์ฅ๊ธฐ ํฌ์ ๋๋น ํจ๊ณผ ํผ | |
| | G. ๋ฐ์ดํฐ ์ถ๊ฐ (500K) | โ
์ค๊ฐ | ๊ณ ํ์ง ์์ค ์ ์ | |
| |
| ### ํ์ต ๊ณก์ ์ดํ |
| |
| ํ์ฌ SFT๋ **๊ฑด๊ฐํ๊ฒ ์๋ฃ**๋จ: |
| - Gradient norm ์์ , spike ์์ |
| - Loss ๋จ์กฐ ๊ฐ์ (๋ฏธ์์ ๋ณ๋์ ์ ์) |
| - Outlier 2.1%๋ ์ ์ ๋ฒ์ |
| - ์๋ ด ์ ํธ๊ฐ step 3000+ ์ดํ ๋ํ๋์ง๋ง ์์ง plateau๋ ์๋ |
| |
| **๊ฐ์ฅ ์ฐ๋ ค๋๋ ์ :** Validation loss ์์ โ ๊ณผ์ ํฉ ์ฌ๋ถ ๋ถ๋ช
ํ. **์ฆ์ val split ํ๋ณด ํ์.** |
| |
| --- |
| |
| *๋ถ์ ์๋ฃ. ๋ค์ ์คํ ์ ์ด ํ์ผ์ ๊ธฐ๋ฐ์ผ๋ก ์คํ ๋ฐฉํฅ ๊ฒฐ์ ๊ถ์ฅ.* |
| |