
SFT Training Pipeline: Fundamental Redesign Report

Date: 2026-02-26
Target project: /PROJECT/0325120031_A/ghong/taketimes/llm-bang/
Current status: 57% repetition degeneration rate, unstable EOS generation


1. Fundamental Problems in the Current Implementation

🔴 Critical #1: Dynamic padding does not work (the biggest performance waste)

Files: data/sft_dataset.py L139-146, train/sft.py L198-230

SFTDataset.__init__ pre-pads every sample to max_seq_len (4096):

# sft_dataset.py L139-141
input_ids = torch.full(
    (max_seq_len,), fill_value=pad_token_id, dtype=torch.long
)

However, dynamic_collate_fn is designed to pad to the longest sequence in the batch; since every tensor is already max_seq_len long, raw_max = max_seq_len is always fixed. Dynamic padding is effectively disabled.

Impact: if the average sequence length is ~500 tokens, every step processes ~3600 unnecessary padding tokens per sample. Training is ~3-8x slower, wasting GPU FLOPs.

Fix: do not pad in __getitem__; return tensors at their actual lengths and let dynamic_collate_fn pad per batch.
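The intended per-batch padding can be sketched as follows. This is a minimal illustration, assuming the dataset yields `(input_ids, labels)` tensors at their true lengths and that `-1` is the ignored label value; the project's actual `dynamic_collate_fn` signature may differ:

```python
import torch

def dynamic_collate_fn(batch, pad_token_id=0, ignore_label=-1):
    """Pad each batch only to its own longest sequence, not to max_seq_len."""
    raw_max = max(ids.size(0) for ids, _ in batch)
    input_ids = torch.full((len(batch), raw_max), pad_token_id, dtype=torch.long)
    labels = torch.full((len(batch), raw_max), ignore_label, dtype=torch.long)
    attention_mask = torch.zeros(len(batch), raw_max, dtype=torch.long)
    for i, (ids, lab) in enumerate(batch):
        n = ids.size(0)
        input_ids[i, :n] = ids
        labels[i, :n] = lab
        attention_mask[i, :n] = 1
    return input_ids, labels, attention_mask

# two variable-length samples pad only to length 5, not 4096
batch = [(torch.arange(3), torch.arange(3)), (torch.arange(5), torch.arange(5))]
ids, labs, mask = dynamic_collate_fn(batch)
print(ids.shape)  # → torch.Size([2, 5])
```

With this collator, per-step compute scales with the batch's longest real sequence instead of the global maximum.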

🔴 Critical #2: EOS token lost on truncation

File: data/sft_dataset.py L130-134

# when truncation occurs
response_ids = response_ids[:allowed_response]

response_ids encodes "{output}</s>", so when it is cut, the trailing </s> token is removed. That sample can never learn EOS.

Impact: truncated_count samples are trained without EOS. Long response → truncated → EOS never learned → endless repetitive generation. This is a direct cause of the repetition problem.

Fix:

response_ids = response_ids[:allowed_response - 1] + [self.eos_token_id]

🟡 Important #3: Subtle potential bug in the label-shift logic

Files: data/sft_dataset.py L148-157, train/trainer.py L263-274

Current implementation:

  • labels[resp_start-1 : resp_start-1+len(response_ids)] = response_ids
  • _compute_loss calls cross_entropy(logits[i], labels[i]) with no separate shift

This pattern is correct under the assumption "logits[i] predicts position i+1", but the prediction at the position after the last response token is never trained. In other words, the model is never taught to predict padding (or nothing at all) after emitting EOS, which can leave a tendency to keep generating tokens past EOS.

Recommendation: also set the label at the position after the last response token (EOS) to EOS, so the model learns "EOS after EOS".
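The recommended label layout can be illustrated with plain lists; `build_labels` and the token ids below are hypothetical, not the project's actual code:

```python
IGNORE = -1  # label value ignored by the loss

def build_labels(prompt_len, response_ids, eos_id, seq_len):
    """Shifted labels: logits[i] predicts token i+1, so response labels start
    one position before the response, and the slot after the final EOS is
    also labeled EOS ("EOS after EOS")."""
    labels = [IGNORE] * seq_len
    start = prompt_len - 1  # logits at the last prompt position predict the first response token
    for j, tok in enumerate(response_ids):
        labels[start + j] = tok
    after = start + len(response_ids)
    if after < seq_len:
        labels[after] = eos_id  # teach the model to emit EOS again after EOS
    return labels

# example: 3-token prompt, response [7, 8, 2] where 2 is EOS, seq_len 8
print(build_labels(3, [7, 8, 2], 2, 8))  # → [-1, -1, 7, 8, 2, 2, -1, -1]
```

Note the second `2` at index 5: that is the extra "EOS after EOS" label the recommendation adds.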

🟡 Important #4: No validation split

File: scripts/launch_sft.sh — the --val_data argument is never passed

The Trainer has a validation loop (trainer.py L200-220) and saves the best checkpoint, but since val_data is never actually provided, overfitting cannot be monitored.

🟡 Important #5: Too few training epochs

  • Effective batch = 4 × 2 × 8 = 64 seqs/step
  • 5000 steps × 64 = 320,000 sample-steps
  • 159k samples → ~2 epochs

SFT typically works best at 3-5 epochs; 2 epochs does not cover the data sufficiently.
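Spelled out, the epoch estimate above (all numbers taken from the report's configuration):

```python
# Epoch arithmetic: per-GPU batch 4, grad accum 2, 8 GPUs, 5000 steps, 159k samples.
per_gpu_batch, grad_accum, num_gpus = 4, 2, 8
effective_batch = per_gpu_batch * grad_accum * num_gpus  # 64 seqs/step

steps, num_samples = 5000, 159_000
epochs = steps * effective_batch / num_samples
print(f"{epochs:.2f} epochs")  # → 2.01 epochs
```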

🟢 Minor #6: Warmup ratio

150/5000 = 3%, within the appropriate range for SFT (3-10%).

🟢 Minor #7: NEFTune in use

NEFTune is applied with noise_alpha=10.0, which is a good setting. However, its effect is limited when repetition degeneration is severe.


2. Comparison with Best-in-Class SFT Implementations

LLaMA-Factory

  • completion_only_loss: masks prompt tokens and trains on the response only (same as the current implementation)
  • EOS guarantee: the EOS token is always preserved on truncation
  • Packing: packs several short samples into one sequence to maximize GPU utilization
  • Data filtering: automatically drops samples that are too short or low quality

TRL SFTTrainer

  • packing=True: packs multiple samples into one sequence (eliminates padding waste)
  • DataCollatorForCompletionOnlyLM: a collator that trains on response tokens only
  • Dynamic per-batch padding up to max_seq_length

Axolotl

  • Sample packing: packs several conversations into one long sequence
  • Sequence length warmup: starts with short sequences and gradually increases length
  • Repetition penalty during training: optional repetition penalty applied while training

OpenInstruct (Allen AI)

  • Data quality first: emphasizes the data-filtering pipeline
  • Length-based normalization: loss normalization for long responses
  • Multi-stage training: general SFT → domain SFT → DPO/RLHF

๋ฐ˜๋ณต ํ‡ดํ™” ๋ฐฉ์ง€๋ฅผ ์œ„ํ•œ ์—…๊ณ„ ๊ณตํ†ต ํŒจํ„ด

  1. EOS ํ† ํฐ ํ™•์‹คํ•œ ํ•™์Šต: ๋ชจ๋“  ์ƒ˜ํ”Œ์—์„œ EOS๊ฐ€ label์— ํฌํ•จ๋˜๋„๋ก ๋ณด์žฅ
  2. Repetition penalty loss: ๋™์ผ n-gram ๋ฐ˜๋ณต ์‹œ ์ถ”๊ฐ€ ํŽ˜๋„ํ‹ฐ
  3. Length normalization: ๊ธด ์‘๋‹ต์˜ loss๊ฐ€ ๊ณผ๋„ํ•˜๊ฒŒ ์ปค์ง€์ง€ ์•Š๋„๋ก ์ •๊ทœํ™”
  4. ๋ฐ์ดํ„ฐ ํ’ˆ์งˆ ํ•„ํ„ฐ๋ง: ๋ฐ˜๋ณต์ด ์žˆ๋Š” ํ•™์Šต ๋ฐ์ดํ„ฐ ์ž์ฒด๋ฅผ ์ œ๊ฑฐ
  5. KL divergence regularization: base model๊ณผ์˜ KL divergence ์ œ์•ฝ
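Pattern 3 (length normalization) fits in a few lines. `length_normalized_loss` is a hypothetical illustration, assuming you already have each sample's summed token loss and its response token count:

```python
def length_normalized_loss(token_loss_sums, token_counts):
    """Average each sample's summed token loss by its own response length,
    then average across the batch, so long responses don't dominate."""
    per_sample = [s / max(n, 1) for s, n in zip(token_loss_sums, token_counts)]
    return sum(per_sample) / len(per_sample)

# two samples: a short response (10 tokens) and a long one (100 tokens)
# contribute equally once normalized
print(length_normalized_loss([25.0, 250.0], [10, 100]))  # → 2.5
```

Without the per-sample division, the 100-token sample would contribute 10x the gradient signal of the short one.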

3. Curriculum Learning Methodology

IFD (Instruction Following Difficulty) Score

# For each sample:
# 1. compute the perplexity of the response with the base model
# 2. compute the perplexity of the response alone, without the instruction
# 3. IFD = conditional_ppl / unconditional_ppl
# Lower IFD = an "easier" sample

def compute_ifd(model, tokenizer, instruction, response):
    cond_ppl = compute_perplexity(model, instruction + response)
    uncond_ppl = compute_perplexity(model, response)
    return cond_ppl / uncond_ppl
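`compute_perplexity` is left undefined above. The perplexity formula itself is simple; a minimal stdlib sketch, assuming per-token log-probabilities have already been extracted from the model:

```python
import math

def perplexity_from_logprobs(token_logprobs):
    """ppl = exp(-mean log p(token)); lower means the model finds the text easier."""
    avg_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(avg_nll)

# a model assigning every token probability 0.25 has perplexity ~4
print(perplexity_from_logprobs([math.log(0.25)] * 5))
```

The model-facing wrapper (running a forward pass and gathering each target token's log-probability) is the part that depends on the project's LLM class and is omitted here.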

KL-Divergence-Based Data Selection

  • Samples whose response distribution differs greatly from the base model = hard samples
  • Start with small-KL (easy) samples, move to large-KL samples later

Practical approach: length-based curriculum

  • Phase 1 (epoch 1): only samples with response length < 256 tokens
  • Phase 2 (epoch 2): response length < 1024 tokens
  • Phase 3 (epoch 3+): the full dataset

4. Top 5 Immediately Applicable Fixes

Fix #1: Make dynamic padding actually work (3-8x performance gain)

# data/sft_dataset.py — modify __getitem__
# before: returns a tensor pre-padded to a fixed max_seq_len
# after: returns a tensor at its actual length

class SFTDataset(Dataset):
    def __init__(self, ...):
        # ...existing code...
        # store samples without padding
        self.samples: list[tuple[torch.Tensor, torch.Tensor]] = []

        for prompt_text, response_text in raw_samples:
            # ...existing encoding code...

            seq_len = len(full_ids)
            # store only the actual length, no padding
            input_ids = torch.tensor(full_ids, dtype=torch.long)

            labels = torch.full((seq_len,), fill_value=-1, dtype=torch.long)
            resp_start = len(prompt_ids)
            resp_label_start = max(0, resp_start - 1)
            resp_label_end = resp_label_start + len(response_ids)
            labels[resp_label_start:resp_label_end] = torch.tensor(
                response_ids, dtype=torch.long
            )

            self.samples.append((input_ids, labels))

Fix #2: Preserve EOS on truncation

# data/sft_dataset.py L130-134
# before:
#   response_ids = response_ids[:allowed_response]
# after:
if len(full_ids) > max_seq_len:
    truncated_count += 1
    allowed_response = max_seq_len - len(prompt_ids)
    if allowed_response <= 1:
        skipped_too_long += 1
        continue
    # always keep EOS as the last token
    response_ids = response_ids[:allowed_response - 1] + [self.eos_token_id]
    full_ids = prompt_ids + response_ids

Fix #3: Add a validation split

# data split script
python -c "
import json, random
random.seed(42)
with open('data/sft/train.jsonl') as f:
    lines = f.readlines()
random.shuffle(lines)
split = int(len(lines) * 0.9)
with open('data/sft/train_split.jsonl', 'w') as f:
    f.writelines(lines[:split])
with open('data/sft/val_split.jsonl', 'w') as f:
    f.writelines(lines[split:])
print(f'Train: {split}, Val: {len(lines)-split}')
"

# add to launch_sft.sh:
# --val_data data/sft/val_split.jsonl

Fix #4: Increase max steps (3-5 epochs)

# modify launch_sft.sh
# 159k samples / 64 effective_batch โ‰ˆ 2484 steps/epoch
# 4 epochs โ†’ ~10,000 steps
MAX_STEPS=10000
WARMUP_STEPS=300  # 3% of 10000

Fix #5: Filter repetitive patterns from the training data

# data/filter_repetitive.py
import json

def has_repetition(text, n=3, threshold=0.3):
    """Return True if the n-gram repetition ratio is at or above threshold."""
    words = text.split()
    if len(words) < n * 2:
        return False
    ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
    unique_ratio = len(set(ngrams)) / len(ngrams)
    return unique_ratio < (1 - threshold)

filtered = 0
with open('data/sft/train.jsonl') as fin, \
     open('data/sft/train_clean.jsonl', 'w') as fout:
    for line in fin:
        obj = json.loads(line)
        text = obj.get('output', '') or ''
        if 'conversations' in obj:
            text = ' '.join(t['content'] for t in obj['conversations'] 
                          if t['role'] == 'assistant')
        if not has_repetition(text):
            fout.write(line)
        else:
            filtered += 1

print(f"Filtered {filtered} repetitive samples")

5. Full Retraining Scenario

Phase 1: Data preparation (1-2 hours)

  1. Filter out data with repetitive patterns
  2. Train/val split (90/10)
  3. Check data statistics (length distribution, category distribution)
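Step 3's length check needs nothing beyond the stdlib; `length_stats` is a hypothetical helper, with `lengths` assumed to come from tokenizing each response:

```python
import statistics

def length_stats(lengths):
    """Summarize a list of token counts: mean, median, 95th percentile, max."""
    qs = statistics.quantiles(lengths, n=100)  # cut points for percentiles 1..99
    return {
        "mean": statistics.mean(lengths),
        "p50": qs[49],
        "p95": qs[94],
        "max": max(lengths),
    }

stats = length_stats(list(range(1, 101)))  # dummy lengths 1..100
print(stats["max"])  # → 100
```

The p95 value is what matters for padding waste: if p95 is far below max_seq_len, most of a fixed-length batch is padding.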

Phase 2: Code changes (1-2 hours)

  1. ✅ sft_dataset.py: remove padding, preserve EOS, return variable lengths
  2. ✅ dynamic_collate_fn in sft.py: already implemented (works once the dataset is fixed)
  3. ✅ launch_sft.sh: add val_data, max_steps 10000

Phase 3: Training (8-12 hours on 8×B200)

# modified launch_sft.sh
MAX_STEPS=10000
BATCH_SIZE=4
GRAD_ACCUM=2
LR="2.0e-5"
WARMUP_STEPS=300

torchrun --nproc_per_node=8 train/sft.py \
    --base_checkpoint checkpoints/korean_1b_fp8_run1/checkpoint-0034000 \
    --sft_data data/sft/train_clean.jsonl \
    --val_data data/sft/val_split.jsonl \
    --max_steps 10000 \
    --batch_size 4 \
    --grad_accum 2 \
    --lr 2.0e-5 \
    --warmup_steps 300 \
    --use_fp8

Phase 4: Evaluation (1 hour)

  1. Monitor val loss (TensorBoard)
  2. Measure the repetition degeneration rate (existing eval script)
  3. Select the best checkpoint (by val_loss)

Expected time budget

  • Training speed after the dynamic padding fix: 3-5x faster (assuming 4096 → ~600 avg tokens)
  • 10000 steps: similar to or faster than the current run in wall-clock time
  • Total 12-16 hours (data preparation through evaluation)

6. Expected Improvements

| Fix | Current | Expected improvement | Rationale |
| --- | --- | --- | --- |
| Dynamic padding | fixed 4096 | training speed 3-8x↑ | ~85% less padding at ~600-token average |
| EOS preservation | lost on truncation | repetition 57% → ~20% | missing EOS training is a direct cause of repetition |
| Data filtering | repetitive data included | repetition → ~10% | repetitive training data removed |
| Val split | none | early overfitting detection | enables best-checkpoint selection |
| More epochs | ~2 epochs | better convergence | 3-5 epochs is the SFT standard |
| Combined | 57% repetition | <10% repetition | with all of the above applied |

7. Recommended Architecture (Long Term)

Replacing the hand-rolled SFT pipeline with TRL SFTTrainer or LLaMA-Factory is recommended:

  1. Packing support: completely eliminates padding waste
  2. Battle-tested EOS handling: validated across hundreds of projects
  3. DPO/RLHF pipeline integration: natural transition from SFT to alignment training
  4. Sample packing + Flash Attention: optimized memory and speed

However, compatibility with the current custom model (the LLM class) must be verified first. Once converted to the HuggingFace format, these frameworks can be used directly.


Summary: What to Do Now (by priority)

  1. 🔴 sft_dataset.py: remove padding + preserve EOS (30 min)
  2. 🔴 Data filtering: drop samples with repetitive patterns (30 min)
  3. 🟡 Create the val split + update launch_sft.sh (15 min)
  4. 🟡 Increase max steps to 10000 (config change only)
  5. 🟢 Start retraining (after the fixes are done)