SFT ํ์ต ํ์ดํ๋ผ์ธ ๊ทผ๋ณธ ์ฌ์ค๊ณ ๋ณด๊ณ ์
์์ฑ์ผ: 2026-02-26
๋์ ํ๋ก์ ํธ: /PROJECT/0325120031_A/ghong/taketimes/llm-bang/
ํ์ฌ ์ํ: ๋ฐ๋ณต ํดํ 57%, EOS ์์ฑ ๋ถ์์
1. ํ์ฌ ๊ตฌํ์ ๊ทผ๋ณธ์ ๋ฌธ์ ์
๐ด Critical #1: Dynamic Padding์ด ์๋ํ์ง ์์ (๊ฐ์ฅ ํฐ ์ฑ๋ฅ ๋ญ๋น)
ํ์ผ: data/sft_dataset.py L139-146, train/sft.py L198-230
SFTDataset.__init__์์ ๋ชจ๋ ์ํ์ max_seq_len(4096)์ผ๋ก ๋ฏธ๋ฆฌ ํจ๋ฉํ๋ค:
# sft_dataset.py L139-141
input_ids = torch.full(
(max_seq_len,), fill_value=pad_token_id, dtype=torch.long
)
๊ทธ๋ฐ๋ฐ dynamic_collate_fn์ ๋ฐฐ์น ๋ด ์ต๋ ๊ธธ์ด์ ๋ง์ถฐ ํจ๋ฉํ๋๋ก ์ค๊ณ๋์ง๋ง, ์ด๋ฏธ ๋ชจ๋ ํ
์๊ฐ max_seq_len ๊ธธ์ด์ด๋ฏ๋ก raw_max = max_seq_len ํญ์ ๊ณ ์ . Dynamic padding์ด ์ฌ์ค์ ๋ฌดํจํ๋ ์ํ.
์ํฅ: ํ๊ท ์ํ์ค ๊ธธ์ด๊ฐ ~500 ํ ํฐ์ด๋ผ๋ฉด, ๋งค ์คํ ๋ง๋ค ~3600๊ฐ์ ํจ๋ฉ ํ ํฐ์ ๋ถํ์ํ๊ฒ ์ฒ๋ฆฌ. ํ์ต ์๋ ~3-8x ์ ํ, GPU FLOPs ๋ญ๋น.
์์ : __getitem__์์ ํจ๋ฉํ์ง ๋ง๊ณ , ์ค์ ๊ธธ์ด์ ํ
์๋ฅผ ๋ฐํ. dynamic_collate_fn์ด ๋ฐฐ์น๋ณ๋ก ํจ๋ฉํ๋๋ก.
๐ด Critical #2: ํธ๋ ์ผ์ด์ ์ EOS ํ ํฐ ์์ค
ํ์ผ: data/sft_dataset.py L130-134
# truncation์ด ๋ฐ์ํ๋ฉด
response_ids = response_ids[:allowed_response]
response_ids๋ "{output}</s>"๋ฅผ ์ธ์ฝ๋ฉํ ๊ฒ์ธ๋ฐ, ์๋ฆฌ๋ฉด ๋ง์ง๋ง์ </s> ํ ํฐ์ด ์ ๊ฑฐ๋๋ค. ํด๋น ์ํ์ EOS๋ฅผ ํ์ตํ ์ ์๋ค.
์ํฅ: truncated_count๊ฐ์ ์ํ์ด EOS ์์ด ํ์ต๋จ. ๊ธด ์๋ต โ ์๋ฆผ โ EOS ๋ฏธํ์ต โ ์์ฑ ์ ๋์์ด ๋ฐ๋ณต ์์ฑ์ ์ง์ ์ ์์ธ.
์์ :
response_ids = response_ids[:allowed_response - 1] + [self.eos_token_id]
๐ก Important #3: Label Shift ๋ก์ง ๋ฏธ๋ฌํ ๋ฒ๊ทธ ๊ฐ๋ฅ์ฑ
ํ์ผ: data/sft_dataset.py L148-157, train/trainer.py L263-274
ํ์ฌ ๊ตฌํ:
labels[resp_start-1 : resp_start-1+len(response_ids)] = response_ids_compute_loss์์ ๋ณ๋์ shift ์์ดcross_entropy(logits[i], labels[i])์ํ
์ด ํจํด์ "logits[i]๊ฐ position i+1์ ์์ธกํ๋ค"๋ ๊ฐ์ ํ์ ์ฌ๋ฐ๋ฅด์ง๋ง, ๋ง์ง๋ง response ํ ํฐ ์ดํ position์ ๋ํ ์์ธก์ด ํ์ต๋์ง ์๋๋ค. ์ฆ, EOS๋ฅผ ์ถ๋ ฅํ ํ ๋ค์ ํ ํฐ์ ํจ๋ฉ(๋๋ ์๋ฌด๊ฒ๋ ์๋ ๊ฒ)์ผ๋ก ์์ธกํ๋๋ก ํ์ตํ์ง ์์์, EOS ์ดํ์๋ ๊ณ์ ํ ํฐ์ ์์ฑํ๋ ๊ฒฝํฅ์ด ์์ ์ ์๋ค.
๊ถ์ฅ: response_ids์ ๋ง์ง๋ง ํ ํฐ(EOS) ๋ค์ position์ label๋ EOS๋ก ์ค์ ํ์ฌ "EOS ์ดํ์๋ EOS"๋ฅผ ํ์ต์ํด.
๐ก Important #4: Validation Split ์์
ํ์ผ: scripts/launch_sft.sh โ --val_data ์ธ์ ๋ฏธ์ ๋ฌ
Trainer์ validation ๋ฃจํ๊ฐ ๊ตฌํ๋ผ ์๊ณ (trainer.py L200-220), best checkpoint ์ ์ฅ๋ ๋์ง๋ง, ์ค์ ๋ก val_data๋ฅผ ์ ๋ฌํ์ง ์์ ๊ณผ์ ํฉ ๋ชจ๋ํฐ๋ง์ด ๋ถ๊ฐ๋ฅ.
๐ก Important #5: ํ์ต ์ํฌํฌ ๋ถ์กฑ
- Effective batch = 4 ร 2 ร 8 = 64 seqs/step
- 5000 steps ร 64 = 320,000 sample-steps
- 159k ์ํ โ ~2 epochs
SFT๋ ๋ณดํต 3-5 epochs๊ฐ ์ ์ . 2 epochs๋ ๋ฐ์ดํฐ๋ฅผ ์ถฉ๋ถํ ํ์ตํ์ง ๋ชปํจ.
๐ข Minor #6: Warmup ๋น์จ
150/5000 = 3%. SFT์์๋ ์ ์ ํ ๋ฒ์ (3-10%).
๐ข Minor #7: NEFTune ์ฌ์ฉ ์ค
noise_alpha=10.0์ผ๋ก NEFTune ์ ์ฉ ์ค. ์ด๊ฑด ์ข์ ์ค์ . ๋ค๋ง ๋ฐ๋ณต ํดํ๊ฐ ์ฌํ ๊ฒฝ์ฐ ํจ๊ณผ๊ฐ ์ ํ์ .
2. ์ ๊ณ ์ต๊ณ ์์ค SFT ๊ตฌํ๊ณผ์ ๋น๊ต
LLaMA-Factory
- completion_only_loss: prompt ํ ํฐ masking + response๋ง ํ์ต (ํ์ฌ ๊ตฌํ๊ณผ ๋์ผ)
- EOS ๋ณด์ฅ: ํธ๋ ์ผ์ด์ ์ ๋ฐ๋์ EOS ํ ํฐ ์ ์ง
- Packing: ์ฌ๋ฌ ์งง์ ์ํ์ ํ๋์ ์ํ์ค์ ํจํนํ์ฌ GPU ํ์ฉ๋ฅ ๊ทน๋ํ
- ๋ฐ์ดํฐ ํํฐ๋ง: ๋๋ฌด ์งง๊ฑฐ๋ ํ์ง ๋ฎ์ ์ํ ์๋ ์ ๊ฑฐ
TRL SFTTrainer
packing=True: ์ฌ๋ฌ ์ํ์ ํ๋์ ์ํ์ค์ ํจํน (ํจ๋ฉ ๋ญ๋น ์ ๊ฑฐ)DataCollatorForCompletionOnlyLM: response ํ ํฐ๋ง ํ์ตํ๋ collator- max_seq_length์ ๋ง์ถฐ ๋์ ํจ๋ฉ (๋ฐฐ์น๋ณ)
Axolotl
- Sample packing: ๊ธด ์ํ์ค ๋ด์ ์ฌ๋ฌ ๋ํ๋ฅผ ํจํน
- Sequence length warmup: ์งง์ ์ํ์ค๋ถํฐ ์์ํด์ ์ ์ง์ ์ผ๋ก ๊ธธ์ด ์ฆ๊ฐ
- Repetition penalty during training: ํ์ต ์ค ๋ฐ๋ณต ํจ๋ํฐ ์ ์ฉ ์ต์
OpenInstruct (Allen AI)
- ๋ฐ์ดํฐ ํ์ง ์ฐ์ : ๋ฐ์ดํฐ ํํฐ๋ง ํ์ดํ๋ผ์ธ ๊ฐ์กฐ
- ๊ธธ์ด ๊ธฐ๋ฐ ์ ๊ทํ: ๊ธด ์๋ต์ ๋ํ loss normalization
- ๋ค๋จ๊ณ ํ์ต: ์ผ๋ฐ SFT โ ๋๋ฉ์ธ SFT โ DPO/RLHF
๋ฐ๋ณต ํดํ ๋ฐฉ์ง๋ฅผ ์ํ ์ ๊ณ ๊ณตํต ํจํด
- EOS ํ ํฐ ํ์คํ ํ์ต: ๋ชจ๋ ์ํ์์ EOS๊ฐ label์ ํฌํจ๋๋๋ก ๋ณด์ฅ
- Repetition penalty loss: ๋์ผ n-gram ๋ฐ๋ณต ์ ์ถ๊ฐ ํ๋ํฐ
- Length normalization: ๊ธด ์๋ต์ loss๊ฐ ๊ณผ๋ํ๊ฒ ์ปค์ง์ง ์๋๋ก ์ ๊ทํ
- ๋ฐ์ดํฐ ํ์ง ํํฐ๋ง: ๋ฐ๋ณต์ด ์๋ ํ์ต ๋ฐ์ดํฐ ์์ฒด๋ฅผ ์ ๊ฑฐ
- KL divergence regularization: base model๊ณผ์ KL divergence ์ ์ฝ
3. Curriculum Learning ๋ฐฉ๋ฒ๋ก
IFD (Instruction Following Difficulty) Score
# ๊ฐ ์ํ์ ๋ํด:
# 1. base model๋ก response์ perplexity ๊ณ์ฐ
# 2. instruction ์์ด response๋ง์ perplexity ๊ณ์ฐ
# 3. IFD = conditional_ppl / unconditional_ppl
# IFD๊ฐ ๋ฎ์์๋ก "์ฌ์ด" ์ํ
def compute_ifd(model, tokenizer, instruction, response):
cond_ppl = compute_perplexity(model, instruction + response)
uncond_ppl = compute_perplexity(model, response)
return cond_ppl / uncond_ppl
KL-Divergence ๊ธฐ๋ฐ ๋ฐ์ดํฐ ์ ํ
- Base model ๋๋น ์๋ต ๋ถํฌ๊ฐ ํฌ๊ฒ ๋ค๋ฅธ ์ํ = ์ด๋ ค์ด ์ํ
- ์ฒ์์๋ KL์ด ์์ (์ฌ์ด) ์ํ, ๋์ค์ ํฐ ์ํ
์ค์ฉ์ ์ ๊ทผ: ๊ธธ์ด ๊ธฐ๋ฐ Curriculum
- Phase 1 (epoch 1): ์๋ต ๊ธธ์ด < 256 ํ ํฐ ์ํ๋ง
- Phase 2 (epoch 2): ์๋ต ๊ธธ์ด < 1024 ํ ํฐ
- Phase 3 (epoch 3+): ์ ์ฒด ๋ฐ์ดํฐ
4. ์ฆ์ ์ ์ฉ ๊ฐ๋ฅํ ์์ Top 5
Fix #1: Dynamic Padding ์ค์ ์๋ํ๋๋ก ์์ (์ฑ๋ฅ 3-8x ๊ฐ์ )
# data/sft_dataset.py โ __getitem__ ์์
# ๊ธฐ์กด: ๊ณ ์ max_seq_len ํจ๋ฉ๋ ํ
์ ๋ฐํ
# ์์ : ์ค์ ๊ธธ์ด์ ํ
์๋ง ๋ฐํ
class SFTDataset(Dataset):
def __init__(self, ...):
# ...๊ธฐ์กด ์ฝ๋...
# samples ์ ์ฅ ์ ํจ๋ฉ ์ ๊ฑฐ
self.samples: list[tuple[torch.Tensor, torch.Tensor]] = []
for prompt_text, response_text in raw_samples:
# ...๊ธฐ์กด ์ธ์ฝ๋ฉ ์ฝ๋...
seq_len = len(full_ids)
# ํจ๋ฉ ์์ด ์ค์ ๊ธธ์ด๋ง ์ ์ฅ
input_ids = torch.tensor(full_ids, dtype=torch.long)
labels = torch.full((seq_len,), fill_value=-1, dtype=torch.long)
resp_start = len(prompt_ids)
resp_label_start = max(0, resp_start - 1)
resp_label_end = resp_label_start + len(response_ids)
labels[resp_label_start:resp_label_end] = torch.tensor(
response_ids, dtype=torch.long
)
self.samples.append((input_ids, labels))
Fix #2: ํธ๋ ์ผ์ด์ ์ EOS ๋ณด์กด
# data/sft_dataset.py L130-134 ์์
# ๊ธฐ์กด:
# response_ids = response_ids[:allowed_response]
# ์์ :
if len(full_ids) > max_seq_len:
truncated_count += 1
allowed_response = max_seq_len - len(prompt_ids)
if allowed_response <= 1:
skipped_too_long += 1
continue
# EOS๋ฅผ ๋ฐ๋์ ๋ง์ง๋ง์ ์ ์ง
response_ids = response_ids[:allowed_response - 1] + [self.eos_token_id]
full_ids = prompt_ids + response_ids
Fix #3: Validation Split ์ถ๊ฐ
# ๋ฐ์ดํฐ ๋ถ๋ฆฌ ์คํฌ๋ฆฝํธ
python -c "
import json, random
random.seed(42)
with open('data/sft/train.jsonl') as f:
lines = f.readlines()
random.shuffle(lines)
split = int(len(lines) * 0.9)
with open('data/sft/train_split.jsonl', 'w') as f:
f.writelines(lines[:split])
with open('data/sft/val_split.jsonl', 'w') as f:
f.writelines(lines[split:])
print(f'Train: {split}, Val: {len(lines)-split}')
"
# launch_sft.sh์ ์ถ๊ฐ:
# --val_data data/sft/val_split.jsonl
Fix #4: Max Steps ์ฆ๊ฐ (3-5 epochs)
# launch_sft.sh ์์
# 159k samples / 64 effective_batch โ 2484 steps/epoch
# 4 epochs โ ~10,000 steps
MAX_STEPS=10000
WARMUP_STEPS=300 # 3% of 10000
Fix #5: ํ์ต ๋ฐ์ดํฐ์์ ๋ฐ๋ณต ํจํด ํํฐ๋ง
# data/filter_repetitive.py
import json, re
def has_repetition(text, n=3, threshold=0.3):
"""n-gram ๋ฐ๋ณต ๋น์จ์ด threshold ์ด์์ด๋ฉด True"""
words = text.split()
if len(words) < n * 2:
return False
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
unique_ratio = len(set(ngrams)) / len(ngrams)
return unique_ratio < (1 - threshold)
filtered = 0
with open('data/sft/train.jsonl') as fin, \
open('data/sft/train_clean.jsonl', 'w') as fout:
for line in fin:
obj = json.loads(line)
text = obj.get('output', '') or ''
if 'conversations' in obj:
text = ' '.join(t['content'] for t in obj['conversations']
if t['role'] == 'assistant')
if not has_repetition(text):
fout.write(line)
else:
filtered += 1
print(f"Filtered {filtered} repetitive samples")
5. ๊ทผ๋ณธ์ ์ฌํ์ต ์๋๋ฆฌ์ค
Phase 1: ๋ฐ์ดํฐ ์ค๋น (1-2์๊ฐ)
- ๋ฐ๋ณต ํจํด์ด ์๋ ๋ฐ์ดํฐ ํํฐ๋ง
- Train/Val split (90/10)
- ๋ฐ์ดํฐ ํต๊ณ ํ์ธ (๊ธธ์ด ๋ถํฌ, ์นดํ ๊ณ ๋ฆฌ ๋ถํฌ)
Phase 2: ์ฝ๋ ์์ (1-2์๊ฐ)
- โ
sft_dataset.py: ํจ๋ฉ ์ ๊ฑฐ, EOS ๋ณด์กด, ๊ฐ๋ณ ๊ธธ์ด ๋ฐํ - โ
sft.py์dynamic_collate_fn: ์ด๋ฏธ ๊ตฌํ๋จ (dataset ์์ ๋ง ํ๋ฉด ์๋) - โ
launch_sft.sh: val_data ์ถ๊ฐ, max_steps 10000
Phase 3: ํ์ต (8-12์๊ฐ, 8รB200 ๊ธฐ์ค)
# ์์ ๋ launch_sft.sh
MAX_STEPS=10000
BATCH_SIZE=4
GRAD_ACCUM=2
LR="2.0e-5"
WARMUP_STEPS=300
torchrun --nproc_per_node=8 train/sft.py \
--base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \
--sft_data data/sft/train_clean.jsonl \
--val_data data/sft/val_split.jsonl \
--max_steps 10000 \
--batch_size 4 \
--grad_accum 2 \
--lr 2.0e-5 \
--warmup_steps 300 \
--use_fp8
Phase 4: ํ๊ฐ (1์๊ฐ)
- Val loss ๋ชจ๋ํฐ๋ง (TensorBoard)
- ๋ฐ๋ณต ํดํ์จ ์ธก์ (๊ธฐ์กด eval ์คํฌ๋ฆฝํธ)
- Best checkpoint ์ ํ (val_loss ๊ธฐ์ค)
์์ ์์์๊ฐ
- Dynamic padding ์์ ํ ํ์ต ์๋: 3-5x ํฅ์ (4096โ~600 avg ๊ธฐ์ค)
- 10000 steps: ๊ธฐ์กด ๋๋น ๋น์ทํ๊ฑฐ๋ ๋ ๋น ๋ฅผ ์ ์์
- ์ด 12-16์๊ฐ (๋ฐ์ดํฐ ์ค๋น ~ ํ๊ฐ ์๋ฃ)
6. ์์ ๊ฐ์ ํจ๊ณผ
| ์์ ํญ๋ชฉ | ํ์ฌ | ์์ ๊ฐ์ | ๊ทผ๊ฑฐ |
|---|---|---|---|
| Dynamic padding | 4096 ๊ณ ์ | ํ์ต์๋ 3-8xโ | ํ๊ท ~600 ํ ํฐ ์ ํจ๋ฉ 85% ๊ฐ์ |
| EOS ๋ณด์กด | ํธ๋ ์ผ์ด์ ์ ์์ค | ๋ฐ๋ณตํดํ 57%โ~20% | EOS ํ์ต ๋๋ฝ์ด ๋ฐ๋ณต์ ์ง์ ์์ธ |
| ๋ฐ์ดํฐ ํํฐ๋ง | ๋ฐ๋ณต ๋ฐ์ดํฐ ํฌํจ | ๋ฐ๋ณตํดํ โ ~10% | ๋ฐ๋ณต ํ์ต๋ฐ์ดํฐ ์ ๊ฑฐ |
| Val split | ์์ | ๊ณผ์ ํฉ ์กฐ๊ธฐ ๊ฐ์ง | best checkpoint ์ ํ ๊ฐ๋ฅ |
| Epoch ์ฆ๊ฐ | ~2 epochs | ์๋ ด ๊ฐ์ | 3-5 epochs์ด SFT ํ์ค |
| ์ข ํฉ | ๋ฐ๋ณต 57% | ๋ฐ๋ณต <10% | ์ ์์ ๋ชจ๋ ์ ์ฉ ์ |
7. ๊ถ์ฅ ์ํคํ ์ฒ (์ฅ๊ธฐ)
ํ์ฌ ์ง์ ๊ตฌํํ SFT ํ์ดํ๋ผ์ธ์ TRL SFTTrainer ๋๋ LLaMA-Factory๋ก ๊ต์ฒด ๊ถ์ฅ:
- Packing ์ง์: ํจ๋ฉ ๋ญ๋น ์์ ์ ๊ฑฐ
- ๊ฒ์ฆ๋ EOS ์ฒ๋ฆฌ: ์๋ฐฑ ๊ฐ ํ๋ก์ ํธ์์ ๊ฒ์ฆ
- DPO/RLHF ํ์ดํ๋ผ์ธ ์ฐ๊ณ: SFT ํ alignment ํ์ต์ผ๋ก ์์ฐ์ค๋ฝ๊ฒ ์ดํ
- Sample packing + Flash Attention: ์ต์ ํ๋ ๋ฉ๋ชจ๋ฆฌ/์๋
๋จ, ํ์ฌ ์ปค์คํ
๋ชจ๋ธ(LLM ํด๋์ค)๊ณผ์ ํธํ์ฑ ํ์ธ ํ์. HuggingFace ํฌ๋งท์ผ๋ก ๋ณํ ํ ์ฌ์ฉํ๋ฉด ๋ฐ๋ก ์ ์ฉ ๊ฐ๋ฅ.
์์ฝ: ์ฆ์ ํด์ผ ํ ์ผ (์ฐ์ ์์)
- ๐ด sft_dataset.py: ํจ๋ฉ ์ ๊ฑฐ + EOS ๋ณด์กด (30๋ถ)
- ๐ด ๋ฐ์ดํฐ ํํฐ๋ง: ๋ฐ๋ณต ํจํด ์๋ ์ํ ์ ๊ฑฐ (30๋ถ)
- ๐ก Val split ์์ฑ + launch_sft.sh ์์ (15๋ถ)
- ๐ก Max steps 10000์ผ๋ก ์ฆ๊ฐ (์ค์ ๋ง ๋ณ๊ฒฝ)
- ๐ข ์ฌํ์ต ์์ (์์ ์๋ฃ ํ)