# Korean LLM 3B parameters — FP8 (B200 TransformerEngine MXFP8)
#
# [Architecture rationale — 2026-02-27]
# - Geometry: d_model=2560, 32 layers, 32 heads, 8 KV heads
#   (NOTE(review): the original rationale line was mojibake-garbled; verify
#   which reference recipe it cited)
# - Parameters: ~2.39B ("3B-class" — lighter than Llama-3.2-3B; efficient
#   with a 64K Korean vocab)
# - d_ffn=6912: 2.7×d_model, multiple of 16 (FP8 alignment)
# - GQA 4:1 (32H:8KV) → inference efficiency + KV-cache savings
# - head_dim=80 (2560/32) → Flash Attention friendly
#
# [Data / training design]
# - Data: korean_train.bin, 8.91B tokens
# - Chinchilla-optimal: 2.4B × 20 = 48B tokens
# - Actual target: 60B tokens (≈6.7 epochs) — extra passes acceptable for a
#   monolingual Korean corpus
# - max_steps 57000 = 60B tokens / 1,048,576 tok/step
#
# [GPU memory estimate — 8× B200, 183 GB each]
# - Model (FP8): 2.4 GB
# - Optimizer (bf16 master + fp32 mom/var): 23.9 GB
# - Gradients (bf16): 4.8 GB
# - Activations (per GPU, bs=8): ~27 GB
# - Total: ~58 GB/GPU (31.7% utilization) → ample headroom
#   (NOTE(review): train.batch_size below is 4 and its comment claims
#   ~130 GB/GPU — inconsistent with this estimate; reconcile)
#
# Launch: bash scripts/launch_korean_3b.sh
# Test:   RUN_NAME=korean_3b_test bash scripts/launch_korean_3b.sh --max_steps 50
model:
  vocab_size: 64000  # must match tokenizer.vocab_size
  d_model: 2560
  n_layers: 32
  n_heads: 32
  n_kv_heads: 8  # GQA 4:1 — 75% reduction in K/V projection parameters
  d_ffn: 6912  # 2.7×d_model, multiple of 16 (FP8 alignment)
  max_seq_len: 4096
  rope_theta: 500000.0
  dropout: 0.0
  bias: false
  use_flash_attn: true
  use_fp8: true  # TransformerEngine MXFP8BlockScaling (B200 native)
train:
  # 57k steps × 1,048,576 tok/step = 59.8B tokens ≈ 6.7 epochs
  max_steps: 57000
  # Per GPU: 4 × 4096 = 16,384 tokens/micro-batch.
  # NOTE(review): the original comment claimed ~130 GB/GPU (71% of 183 GB)
  # here, while the file header estimates ~58 GB/GPU at bs=8 — these are
  # inconsistent; re-measure before trusting either number.
  batch_size: 4
  grad_accum_steps: 8  # effective batch: 4 × 8 GPU × 8 × 4096 = 1,048,576 tok/step
  lr: 1.5e-4  # 3B scale; GPT-3 scaling reference: 1B → 2e-4, 3B → 1.5e-4
  weight_decay: 0.1
  warmup_steps: 2000  # 3.5% of 57k steps — conservative, stable warmup
  max_grad_norm: 1.0
  log_interval: 10
  save_interval: 1000  # ~57 checkpoints over 57k steps
  eval_interval: 500  # val-loss monitoring
  use_amp: false  # fp8_autocast replaces AMP
  compile_model: false  # avoid TE 2.10 + DDP graph breaks
  fp8_amax_history_len: 16
  fp8_amax_compute_algo: "max"
  fp8_format: "MXFP8"  # B200 Blackwell native block scaling
tokenizer:
  vocab_size: 64000  # must match model.vocab_size
  type: sentencepiece_unigram