# Source: frankenstallm/source/configs/3b_pretrain.yaml
# (uploaded via huggingface_hub, commit 09ea133)
# Korean LLM, 3B parameters — FP8 (B200 TransformerEngine MXFP8)
#
# [Architecture rationale — 2026-02-27]
# - Based on the proposed design: d_model=2560, 32 layers, 32 heads, 8 KV heads
# - Parameters: ~2.39B ("3B-class" — lighter than Llama-3.2-3B; efficient with a 64K Korean vocab)
# - d_ffn=6912: 2.7x d_model, multiple of 16 for FP8 alignment
# - GQA 4:1 (32H:8KV) — inference efficiency + KV-cache savings
# - head_dim=80 (2560/32) — efficient for Flash Attention
#
# [Data / training design]
# - Data: korean_train.bin, 8.91B tokens
# - Chinchilla-optimal: 2.4B params x 20 = 48B tokens
# - Actual target: 60B tokens (6.7 epochs) — extra epochs are acceptable for a Korean-only corpus
# - max_steps 57000 = 60B tokens / 1,048,576 tok/step
#
# [GPU memory estimate — 8x B200, 183 GB each]
# - Model weights (FP8): 2.4 GB
# - Optimizer (bf16 master + fp32 momentum/variance): 23.9 GB
# - Gradients (bf16): 4.8 GB
# - Activations (per GPU, bs=8): ~27 GB
# - Total: ~58 GB/GPU (31.7% utilization) -> ample headroom
# NOTE(review): the estimate above assumes bs=8, but train.batch_size below is 4
# with a ~130 GB/GPU inline estimate — the two figures disagree; confirm which
# reflects the current run.
#
# Run:  bash scripts/launch_korean_3b.sh
# Test: RUN_NAME=korean_3b_test bash scripts/launch_korean_3b.sh --max_steps 50
# Model architecture — ~2.39B params ("3B-class"); see header rationale.
model:
  vocab_size: 64000
  d_model: 2560
  n_layers: 32
  n_heads: 32
  n_kv_heads: 8  # GQA 4:1 (75% reduction in K/V projection parameters)
  d_ffn: 6912  # 2.7x d_model, multiple of 16 (FP8 alignment)
  max_seq_len: 4096
  rope_theta: 500000.0
  dropout: 0.0
  bias: false
  use_flash_attn: true
  use_fp8: true  # TransformerEngine MXFP8BlockScaling (native on B200)
# Training schedule and FP8 recipe settings.
train:
  # 57k steps x 1,048,576 tok/step = 59.8B tokens, about 6.7 epochs
  max_steps: 57000
  batch_size: 4  # per GPU: 4 x 4096 = 16,384 tokens | VRAM ~130 GB (71% of 183 GB)
  grad_accum_steps: 8  # effective batch: 4 x 8 GPUs x 8 x 4096 = 1,048,576 tok/step
  lr: 1.5e-4  # 3B scale: GPT-3-style scaling, 1B (2e-4) -> 3B (1.5e-4)
  weight_decay: 0.1
  warmup_steps: 2000  # 3.5% of 57k steps — conservative, stable warmup
  max_grad_norm: 1.0
  log_interval: 10
  save_interval: 1000  # ~57 checkpoints over the 57k-step run
  eval_interval: 500  # monitor validation loss
  use_amp: false  # superseded by fp8_autocast
  compile_model: false  # torch.compile risks graph breaks with TE 2.10 + DDP
  fp8_amax_history_len: 16
  fp8_amax_compute_algo: "max"
  fp8_format: "MXFP8"  # native block scaling on B200 (Blackwell)
# Tokenizer settings.
tokenizer:
  vocab_size: 64000  # must match model.vocab_size
  type: sentencepiece_unigram