Initial commit: LLM-1B-Lab project setup
LLaMA-style 1.1B parameter Decoder-Only Transformer for educational purposes.
Includes modularized llm_lab package, notebooks, and configuration files.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- .gitignore +47 -0
- CLAUDE.md +131 -0
- LLM_Foundation_Model.code-workspace +8 -0
- _archive/llm-1b-data-pipeline.py +906 -0
- _archive/llm-1b-evaluation.py +1455 -0
- _archive/llm-1b-model.py +791 -0
- _archive/llm-1b-trainer.py +1108 -0
- llm_lab/__init__.py +30 -0
- llm_lab/config/__init__.py +7 -0
- llm_lab/config/data_config.py +41 -0
- llm_lab/config/eval_config.py +20 -0
- llm_lab/config/model_config.py +53 -0
- llm_lab/config/train_config.py +114 -0
- llm_lab/data/__init__.py +11 -0
- llm_lab/data/dataset.py +218 -0
- llm_lab/data/diagnostics.py +153 -0
- llm_lab/data/pipeline.py +156 -0
- llm_lab/data/tokenizer.py +196 -0
- llm_lab/evaluation/__init__.py +21 -0
- llm_lab/evaluation/attention_viz.py +176 -0
- llm_lab/evaluation/checklist.py +99 -0
- llm_lab/evaluation/dynamics.py +242 -0
- llm_lab/evaluation/full_evaluator.py +222 -0
- llm_lab/evaluation/generation.py +200 -0
- llm_lab/evaluation/perplexity.py +172 -0
- llm_lab/evaluation/runner.py +56 -0
- llm_lab/evaluation/scaling.py +153 -0
- llm_lab/model/__init__.py +14 -0
- llm_lab/model/attention.py +134 -0
- llm_lab/model/feedforward.py +48 -0
- llm_lab/model/llm_model.py +200 -0
- llm_lab/model/norm.py +40 -0
- llm_lab/model/rope.py +103 -0
- llm_lab/model/transformer_block.py +65 -0
- llm_lab/model/utils.py +85 -0
- llm_lab/training/__init__.py +12 -0
- llm_lab/training/checkpoint.py +159 -0
- llm_lab/training/metrics.py +112 -0
- llm_lab/training/optimizer.py +54 -0
- llm_lab/training/runner.py +68 -0
- llm_lab/training/scheduler.py +68 -0
- llm_lab/training/trainer.py +351 -0
- llm_lab/utils/__init__.py +5 -0
- llm_lab/utils/device.py +94 -0
- llm_lab/utils/seed.py +9 -0
- notebooks/01_data_pipeline.ipynb +169 -0
- notebooks/02_model.ipynb +212 -0
- notebooks/03_training.ipynb +211 -0
- notebooks/04_evaluation.ipynb +188 -0
- requirements.txt +8 -0
.gitignore
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.egg-info/
|
| 6 |
+
*.egg
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
*.so
|
| 10 |
+
|
| 11 |
+
# Virtual environments
|
| 12 |
+
venv/
|
| 13 |
+
.venv/
|
| 14 |
+
env/
|
| 15 |
+
|
| 16 |
+
# IDE
|
| 17 |
+
.vscode/
|
| 18 |
+
.idea/
|
| 19 |
+
*.swp
|
| 20 |
+
*.swo
|
| 21 |
+
*~
|
| 22 |
+
|
| 23 |
+
# Jupyter Notebook
|
| 24 |
+
.ipynb_checkpoints/
|
| 25 |
+
|
| 26 |
+
# OS
|
| 27 |
+
.DS_Store
|
| 28 |
+
Thumbs.db
|
| 29 |
+
|
| 30 |
+
# ML / Training artifacts
|
| 31 |
+
*.pt
|
| 32 |
+
*.pth
|
| 33 |
+
*.bin
|
| 34 |
+
*.ckpt
|
| 35 |
+
checkpoints/
|
| 36 |
+
wandb/
|
| 37 |
+
runs/
|
| 38 |
+
|
| 39 |
+
# Data
|
| 40 |
+
*.log
|
| 41 |
+
*.csv
|
| 42 |
+
*.tsv
|
| 43 |
+
data/
|
| 44 |
+
|
| 45 |
+
# Secrets
|
| 46 |
+
.env
|
| 47 |
+
*.key
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM-1B-Lab
|
| 2 |
+
|
| 3 |
+
1.1B parameter LLaMA-style Decoder-Only Transformer ๊ต์ก์ฉ ๊ตฌํ.
|
| 4 |
+
๋ฅ๋ฌ๋ ์ด๋ณด์๊ฐ ์ฒ์๋ถํฐ ๋๊น์ง LLM์ ํ์ตํ๊ณ ํ๊ฐํ๋ ๊ณผ์ ์ ๊ฒฝํํ ์ ์๋๋ก ์ค๊ณ๋จ.
|
| 5 |
+
|
| 6 |
+
## ํ๋ก์ ํธ ๊ตฌ์กฐ
|
| 7 |
+
|
| 8 |
+
```
|
| 9 |
+
LLM_Foundation_Model/
|
| 10 |
+
โโโ CLAUDE.md
|
| 11 |
+
โโโ requirements.txt
|
| 12 |
+
โโโ llm_lab/ # Python ํจํค์ง (ํต์ฌ ์ฝ๋)
|
| 13 |
+
โ โโโ __init__.py
|
| 14 |
+
โ โโโ config/ # ์ค์ ๋ฐ์ดํฐํด๋์ค
|
| 15 |
+
โ โ โโโ model_config.py # ModelConfig (debug_10m / small_100m / base_1b ํ๋ฆฌ์
)
|
| 16 |
+
โ โ โโโ data_config.py # DataConfig (๋ฐ์ดํฐ์
, ํ ํฌ๋์ด์ , ๋ฐฐ์น ์ค์ )
|
| 17 |
+
โ โ โโโ train_config.py # TrainConfig (LR, ์ค์ผ์ค๋ฌ, ์ฒดํฌํฌ์ธํธ, wandb)
|
| 18 |
+
โ โ โโโ eval_config.py # EvalConfig (ํ๊ฐ ํ๋ผ๋ฏธํฐ)
|
| 19 |
+
โ โโโ model/ # ๋ชจ๋ธ ์ํคํ
์ฒ
|
| 20 |
+
โ โ โโโ norm.py # RMSNorm
|
| 21 |
+
โ โ โโโ rope.py # RotaryPositionalEmbedding (RoPE)
|
| 22 |
+
โ โ โโโ attention.py # GroupedQueryAttention (GQA)
|
| 23 |
+
โ โ โโโ feedforward.py # SwiGLUFeedForward
|
| 24 |
+
โ โ โโโ transformer_block.py # TransformerBlock (Pre-LN)
|
| 25 |
+
โ โ โโโ llm_model.py # LLMModel (์ ์ฒด ๋ชจ๋ธ + generate)
|
| 26 |
+
โ โ โโโ utils.py # count_parameters_detailed, estimate_memory_gb
|
| 27 |
+
โ โโโ data/ # ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ
|
| 28 |
+
โ โ โโโ tokenizer.py # Tokenizer (SentencePiece / BPE / HuggingFace)
|
| 29 |
+
โ โ โโโ dataset.py # PackedStreamingDataset, ValidationDataset, _collate_fn
|
| 30 |
+
โ โ โโโ pipeline.py # create_train_dataloader, setup_data_pipeline
|
| 31 |
+
โ โ โโโ diagnostics.py # DataPipelineDiagnostics
|
| 32 |
+
โ โโโ training/ # ํ์ต ๋ฃจํ
|
| 33 |
+
โ โ โโโ scheduler.py # CosineWarmupScheduler
|
| 34 |
+
โ โ โโโ checkpoint.py # CheckpointManager (Google Drive ์ง์)
|
| 35 |
+
โ โ โโโ metrics.py # MetricsTracker (wandb ์ฐ๋)
|
| 36 |
+
โ โ โโโ optimizer.py # create_optimizer (weight decay ๋ถ๋ฆฌ)
|
| 37 |
+
โ โ โโโ trainer.py # Trainer (gradient accumulation, mixed precision)
|
| 38 |
+
โ โ โโโ runner.py # start_training (ํ ์ค ์คํ ํฌํผ)
|
| 39 |
+
โ โโโ evaluation/ # ํ๊ฐ & ๋ถ์
|
| 40 |
+
โ โ โโโ perplexity.py # PerplexityEvaluator (์์น๋ณ Loss ํฌํจ)
|
| 41 |
+
โ โ โโโ generation.py # GenerationEvaluator (๋ค์ํ ํ๋กฌํํธ)
|
| 42 |
+
โ โ โโโ scaling.py # ScalingAnalyzer (Chinchilla Scaling Law)
|
| 43 |
+
โ โ โโโ dynamics.py # TrainingDynamicsAnalyzer (Loss/LR/Grad ์๊ฐํ)
|
| 44 |
+
โ โ โโโ attention_viz.py # AttentionVisualizer (ํค๋๋ณ heatmap)
|
| 45 |
+
โ โ โโโ full_evaluator.py # FullEvaluator (์ข
ํฉ ํ๊ฐ + ๋ฆฌํฌํธ)
|
| 46 |
+
โ โ โโโ checklist.py # InsightChecklist (ํ์ต ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ)
|
| 47 |
+
โ โ โโโ runner.py # run_evaluation (ํ ์ค ์คํ ํฌํผ)
|
| 48 |
+
โ โโโ utils/ # ๊ณตํต ์ ํธ๋ฆฌํฐ
|
| 49 |
+
โ โโโ device.py # auto_configure, get_device, detect_gpu_info
|
| 50 |
+
โ โโโ seed.py # set_seed
|
| 51 |
+
โโโ notebooks/ # Jupyter ๋
ธํธ๋ถ (์ค์ + ์คํ)
|
| 52 |
+
โ โโโ 01_data_pipeline.ipynb
|
| 53 |
+
โ โโโ 02_model.ipynb
|
| 54 |
+
โ โโโ 03_training.ipynb
|
| 55 |
+
โ โโโ 04_evaluation.ipynb
|
| 56 |
+
โโโ _archive/ # ์๋ณธ ๋จ์ผํ์ผ ๋ฐฑ์
|
| 57 |
+
โโโ llm-1b-model.py
|
| 58 |
+
โโโ llm-1b-data-pipeline.py
|
| 59 |
+
โโโ llm-1b-trainer.py
|
| 60 |
+
โโโ llm-1b-evaluation.py
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## ๊ธฐ์ ์คํ
|
| 64 |
+
|
| 65 |
+
- **๋ชจ๋ธ**: LLaMA-style Decoder-Only Transformer (RMSNorm, RoPE, GQA, SwiGLU, Weight Tying)
|
| 66 |
+
- **ํ์ต**: Gradient Accumulation, Mixed Precision (bf16/fp16), Cosine LR + Warmup, Activation Checkpointing
|
| 67 |
+
- **๋ฐ์ดํฐ**: HuggingFace Streaming (FineWeb-Edu), BPE ํ ํฌ๋์ด์ , ์ํ์ค ํจํน
|
| 68 |
+
- **์ฒดํฌํฌ์ธํธ**: Google Drive ์๋ ์ ์ฅ/๋ณต์ (Colab Pro+ ํ๊ฒฝ)
|
| 69 |
+
- **ํ๊ฐ**: Perplexity, ํ
์คํธ ์์ฑ, Scaling Law, Attention ์๊ฐํ
|
| 70 |
+
- **ํ๊ฒ ํ๊ฒฝ**: Google Colab Pro+ (A100 40GB)
|
| 71 |
+
|
| 72 |
+
## ์์กด์ฑ ๊ทธ๋ํ (์ํ ์์)
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
config (์์กด์ฑ ์์)
|
| 76 |
+
โ
|
| 77 |
+
utils โ config
|
| 78 |
+
โ
|
| 79 |
+
model โ config
|
| 80 |
+
โ
|
| 81 |
+
data โ config
|
| 82 |
+
โ
|
| 83 |
+
training โ config, utils
|
| 84 |
+
โ
|
| 85 |
+
evaluation โ config
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## ๋ชจ๋ธ ํ๋ฆฌ์
|
| 89 |
+
|
| 90 |
+
| ํ๋ฆฌ์
| ํ๋ผ๋ฏธํฐ | dim | layers | heads | kv_heads | ์ฉ๋ |
|
| 91 |
+
|--------|---------|-----|--------|-------|----------|------|
|
| 92 |
+
| `debug_10m` | ~10M | 256 | 6 | 8 | 4 | ๋น ๋ฅธ ๊ฒ์ฆ/๋๋ฒ๊ทธ |
|
| 93 |
+
| `small_100m` | ~100M | 768 | 12 | 12 | 4 | ์ค๊ฐ ์คํ |
|
| 94 |
+
| `base_1b` | ~1.1B | 2048 | 22 | 32 | 8 | ๋ณธ๊ฒฉ ํ์ต |
|
| 95 |
+
|
| 96 |
+
## Quick Start
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
from llm_lab.config import ModelConfig, DataConfig, TrainConfig
|
| 100 |
+
from llm_lab.model import LLMModel
|
| 101 |
+
from llm_lab.data import setup_data_pipeline
|
| 102 |
+
from llm_lab.training import start_training
|
| 103 |
+
from llm_lab.evaluation import run_evaluation
|
| 104 |
+
|
| 105 |
+
# 1. ๋ชจ๋ธ
|
| 106 |
+
model = LLMModel(ModelConfig.base_1b())
|
| 107 |
+
|
| 108 |
+
# 2. ๋ฐ์ดํฐ
|
| 109 |
+
tok, train_dl, val_dl = setup_data_pipeline("pretrained")
|
| 110 |
+
|
| 111 |
+
# 3. ํ์ต
|
| 112 |
+
trainer = start_training(model, train_dl, val_dl)
|
| 113 |
+
|
| 114 |
+
# 4. ํ๊ฐ
|
| 115 |
+
report = run_evaluation(model, tok, val_dl,
|
| 116 |
+
metrics_history=trainer.metrics.history)
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## ์ฝ๋ ์ปจ๋ฒค์
|
| 120 |
+
|
| 121 |
+
- **์ธ์ด**: ์ฝ๋๋ ์์ด, ์ฃผ์/๋
์คํธ๋ง์ ํ๊ตญ์ด (๊ต์ก์ ์ค๋ช
ํฌํจ)
|
| 122 |
+
- **ํ์
ํํธ**: ๋ชจ๋ ํจ์์ typing ์ด๋
ธํ
์ด์
์ฌ์ฉ
|
| 123 |
+
- **import ์์**: stdlib โ torch โ llm_lab (์ ๋ ๊ฒฝ๋ก) โ ๋ก์ปฌ (์๋ ๊ฒฝ๋ก)
|
| 124 |
+
- **๋ฐ์ดํฐํด๋์ค**: ๋ชจ๋ ์ค์ ์ `@dataclass`๋ก ์ ์, ๊ธฐ๋ณธ๊ฐ ํฌํจ
|
| 125 |
+
- **์๋ฌ ์ฒ๋ฆฌ**: ์ธ๋ถ ์์กด์ฑ(matplotlib, wandb ๋ฑ)์ `try/except ImportError`๋ก ์ ํ์ ์ฌ์ฉ
|
| 126 |
+
|
| 127 |
+
## ์ฃผ์์ฌํญ
|
| 128 |
+
|
| 129 |
+
- `torch`๋ ๋ก์ปฌ ํ๊ฒฝ์ ์ค์น๋์ด ์์ง ์์ ์ ์์ (Colab Pro+์์ ์คํ ์ ์ )
|
| 130 |
+
- `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
|
| 131 |
+
- ์๋ณธ 4๊ฐ ํ์ผ(`_archive/`)๊ณผ ๋ชจ๋ํ๋ `llm_lab/` ํจํค์ง์ ๋ก์ง์ ๋์ผ (import ๊ฒฝ๋ก๋ง ๋ณ๊ฒฝ)
|
LLM_Foundation_Model.code-workspace
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"folders": [
|
| 3 |
+
{
|
| 4 |
+
"path": "."
|
| 5 |
+
}
|
| 6 |
+
],
|
| 7 |
+
"settings": {}
|
| 8 |
+
}
|
_archive/llm-1b-data-pipeline.py
ADDED
|
@@ -0,0 +1,906 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM-1B-Lab: ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ
|
| 3 |
+
==============================
|
| 4 |
+
ํ ํฌ๋์ด์ ์ค๋น โ ๋ฐ์ดํฐ ์คํธ๋ฆฌ๋ฐ โ ์ํ์ค ํจํน โ ๋ฐฐ์น ๊ตฌ์ฑ
|
| 5 |
+
|
| 6 |
+
์ ์ฒด ํ๋ฆ:
|
| 7 |
+
FineWeb-Edu (HuggingFace)
|
| 8 |
+
โ Streaming์ผ๋ก ๋ก๋ (๋์คํฌ ์ ์ฅ ์์)
|
| 9 |
+
โ ํ ํฌ๋์ด์ง (BPE, vocab=32K)
|
| 10 |
+
โ ์ํ์ค ํจํน (์ฌ๋ฌ ๋ฌธ์๋ฅผ max_seq_len์ผ๋ก ์ฐ๊ฒฐ)
|
| 11 |
+
โ ๋ฐฐ์น ๊ตฌ์ฑ (input_ids, targets)
|
| 12 |
+
โ GPU ์ ์ก
|
| 13 |
+
|
| 14 |
+
์ค์น ํ์ ํจํค์ง:
|
| 15 |
+
pip install datasets tokenizers sentencepiece wandb
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
import json
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import Optional, Iterator, List, Dict, Any
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 27 |
+
|
| 28 |
+
# ============================================================================
|
| 29 |
+
# 1. ๋ฐ์ดํฐ ์ค์
|
| 30 |
+
# ============================================================================
|
| 31 |
+
|
| 32 |
+
@dataclass
class DataConfig:
    """Configuration for the data pipeline.

    Defaults are chosen for a Colab Pro+ environment:
    - streaming mode keeps disk usage minimal,
    - sequence packing avoids padding and maximizes GPU utilization,
    - preprocessing happens on-the-fly to save memory.
    """

    # -- dataset --
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"  # 10B-token sample subset
    dataset_split: str = "train"
    text_column: str = "text"  # column holding the raw text

    # -- tokenizer --
    tokenizer_type: str = "sentencepiece"  # "sentencepiece" or "hf"
    # Path to a pre-trained tokenizer (None => train a new one)
    tokenizer_path: Optional[str] = None
    vocab_size: int = 32_000

    # -- sequence --
    max_seq_len: int = 2048
    # Insert an EOS token at document boundaries when packing
    use_eos_separator: bool = True

    # -- batching --
    batch_size: int = 4  # micro batch (per GPU)
    num_workers: int = 2  # DataLoader workers
    prefetch_factor: int = 4  # batches prefetched per worker

    # -- tokenizer training (when training a new one) --
    tokenizer_train_samples: int = 50_000  # documents used for training
    tokenizer_save_dir: str = "./tokenizer"

    # -- validation data --
    val_ratio: float = 0.001  # fraction of data held out for validation
| 70 |
+
|
| 71 |
+
# ============================================================================
|
| 72 |
+
# 2. ํ ํฌ๋์ด์ ๋ํผ
|
| 73 |
+
# ============================================================================
|
| 74 |
+
|
| 75 |
+
class Tokenizer:
    """Unified tokenizer wrapper.

    Three backends are supported, all normalized to one interface
    (``encode`` / ``decode`` plus ``bos_id`` / ``eos_id`` / ``pad_id``):

    1) Load an existing SentencePiece model (:meth:`load_sentencepiece`).
    2) Train a new BPE tokenizer with the HuggingFace ``tokenizers``
       library (:meth:`train_bpe`), reloadable via :meth:`load_trained_hf`.
    3) Load a pretrained HuggingFace tokenizer (:meth:`load_pretrained_hf`).

    BPE in a nutshell: split text into bytes/characters, repeatedly merge
    the most frequent adjacent pair, stop at ``vocab_size`` — frequent
    words become single tokens, rare words split into several.
    """

    def __init__(self, config: "DataConfig"):
        self.config = config
        self._tokenizer = None  # backend object; set by a load_*/train_* method
        self.vocab_size = config.vocab_size

        # Special-token ids; defaults are overwritten once a backend loads.
        self.bos_id: int = 1  # beginning of sequence
        self.eos_id: int = 2  # end of sequence
        self.pad_id: int = 0  # padding

    # ------------------------------------------------------------------
    # Backend 1: load an existing SentencePiece model
    # ------------------------------------------------------------------

    def load_sentencepiece(self, model_path: str) -> None:
        """Load a pre-trained SentencePiece model from ``model_path``."""
        import sentencepiece as spm

        self._tokenizer = spm.SentencePieceProcessor()
        self._tokenizer.Load(model_path)

        self.vocab_size = self._tokenizer.GetPieceSize()
        # NOTE(review): SentencePiece returns -1 for ids of pieces that were
        # not defined at training time — confirm bos/eos/pad exist in the model.
        self.bos_id = self._tokenizer.bos_id()
        self.eos_id = self._tokenizer.eos_id()
        self.pad_id = self._tokenizer.pad_id()
        self._encode_fn = self._tokenizer.Encode
        self._decode_fn = self._tokenizer.Decode

        print(f"[Tokenizer] SentencePiece ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    # ------------------------------------------------------------------
    # Backend 2: train a BPE tokenizer with HuggingFace tokenizers
    # ------------------------------------------------------------------

    def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None) -> None:
        """Train a byte-level BPE tokenizer from scratch.

        Args:
            text_iterator: iterator yielding training texts.
            save_dir: output directory (defaults to ``config.tokenizer_save_dir``).

        Trade-off: a larger vocab shortens sequences but grows the embedding
        table; 32K is a good balance for English.
        """
        from tokenizers import Tokenizer as HFTokenizer
        from tokenizers.models import BPE
        from tokenizers.trainers import BpeTrainer
        from tokenizers.pre_tokenizers import ByteLevel
        from tokenizers.processors import TemplateProcessing

        print("[Tokenizer] BPE ํ ํฌ๋์ด์ ํ์ต ์์...")

        # Byte-level BPE model
        tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

        # Special tokens (order fixes their ids: pad=0, bos=1, eos=2, unk=3)
        special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]

        trainer = BpeTrainer(
            vocab_size=self.config.vocab_size,
            special_tokens=special_tokens,
            min_frequency=2,  # only merge pairs seen at least twice
            show_progress=True,
        )

        tokenizer.train_from_iterator(text_iterator, trainer=trainer)

        # Post-processing: wrap single sequences with BOS/EOS automatically.
        tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            special_tokens=[("<s>", 1), ("</s>", 2)],
        )

        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.get_vocab_size()
        self.pad_id = 0
        self.bos_id = 1
        self.eos_id = 2

        self._encode_fn = lambda text: tokenizer.encode(text).ids
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        # Persist the tokenizer plus its special-token metadata.
        save_dir = save_dir or self.config.tokenizer_save_dir
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
        meta = {
            "vocab_size": self.vocab_size,
            "bos_id": self.bos_id,
            "eos_id": self.eos_id,
            "pad_id": self.pad_id,
        }
        with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
            json.dump(meta, f, indent=2)

        print(f"[Tokenizer] ํ์ต ์๋ฃ: vocab_size={self.vocab_size}")
        print(f"[Tokenizer] ์ ์ฅ ์์น: {save_dir}")

    # ------------------------------------------------------------------
    # Backend 3: load a pretrained HuggingFace tokenizer
    # ------------------------------------------------------------------

    def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf") -> None:
        """Load a pretrained tokenizer from the HuggingFace hub.

        Easiest option. Note: meta-llama repos may require HF approval;
        mistralai/Mistral-7B-v0.1 is an alternative without gating.
        """
        from transformers import AutoTokenizer

        print(f"[Tokenizer] HF ํ ํฌ๋์ด์ ๋ก๋: {name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)

        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        # Fall back to conventional ids when the tokenizer defines none.
        self.bos_id = tokenizer.bos_token_id or 1
        self.eos_id = tokenizer.eos_token_id or 2
        self.pad_id = tokenizer.pad_token_id or 0

        self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        print(f"[Tokenizer] ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    def load_trained_hf(self, path: str) -> None:
        """Reload a tokenizer previously trained with :meth:`train_bpe`."""
        from tokenizers import Tokenizer as HFTokenizer

        tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
        with open(os.path.join(path, "tokenizer_meta.json"), "r") as f:
            meta = json.load(f)

        self._tokenizer = tokenizer
        self.vocab_size = meta["vocab_size"]
        self.bos_id = meta["bos_id"]
        self.eos_id = meta["eos_id"]
        self.pad_id = meta["pad_id"]

        self._encode_fn = lambda text: tokenizer.encode(text).ids
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        print(f"[Tokenizer] ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    # ------------------------------------------------------------------
    # Common interface
    # ------------------------------------------------------------------

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Text -> list of token ids; optionally wraps with BOS/EOS."""
        ids = self._encode_fn(text)
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids: List[int]) -> str:
        """List of token ids -> text."""
        return self._decode_fn(ids)

    def __len__(self) -> int:
        return self.vocab_size
| 262 |
+
|
| 263 |
+
|
| 264 |
+
# ============================================================================
|
| 265 |
+
# 3. ์ํ์ค ํจํน ์คํธ๋ฆฌ๋ฐ ๋ฐ์ดํฐ์
|
| 266 |
+
# ============================================================================
|
| 267 |
+
|
| 268 |
+
class PackedStreamingDataset(IterableDataset):
    """Streaming dataset with sequence packing.

    Why packing? Truncating/padding each document to ``max_seq_len`` wastes
    GPU compute on padding; packing concatenates documents back-to-back so
    every position carries a real token:

        doc1 (300) + doc2 (1500) + doc3 (248) = 2048 tokens
        -> [doc1][EOS][doc2][EOS][doc3][EOS] ... no padding

    Why streaming? The FineWeb-Edu 10B sample is tens of GB compressed —
    too large for Colab disk — so examples are read over the network on
    demand (``streaming=True``).

    EOS tokens at document boundaries let the model learn where documents
    end even without a cross-document attention mask.
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        split: str = "train",
        seed: int = 42,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.split = split
        self.seed = seed
        self.max_seq_len = config.max_seq_len

    def _load_dataset(self, seed: Optional[int] = None):
        """Load the HuggingFace dataset in streaming mode.

        Args:
            seed: shuffle seed; defaults to ``self.seed``. (Passed explicitly
                so per-worker seeding does not mutate instance state.)
        """
        from datasets import load_dataset

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,  # key: streaming mode, no full download
            trust_remote_code=True,
        )

        # Streaming shuffle is approximate: a buffer-based reservoir.
        ds = ds.shuffle(seed=self.seed if seed is None else seed, buffer_size=10_000)

        return ds

    def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
        """Tokenize documents and pack them into fixed-length sequences.

        Yields:
            {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}

        ``targets`` is ``input_ids`` shifted by one position:
            input_ids: [A, B, C, D, E]
            targets:   [B, C, D, E, F]
        so the model predicts the next token at every position.
        """
        buffer: List[int] = []  # rolling token buffer across documents

        for example in dataset:
            text = example[self.config.text_column]
            if not text or not text.strip():
                continue

            # Tokenize without special tokens; EOS is added explicitly below.
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue

            # Mark the document boundary.
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)

            buffer.extend(token_ids)

            # Emit sequences while the buffer holds max_seq_len + 1 tokens
            # (the extra token is needed to build the shifted targets).
            while len(buffer) >= self.max_seq_len + 1:
                chunk = buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]

                yield {
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "targets": torch.tensor(chunk[1:], dtype=torch.long),
                }

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterator entry point called by DataLoader.

        Multi-worker support: each worker shuffles with a distinct seed
        (base seed + worker id) to minimize duplicate data across workers.

        Fix vs. original: the per-worker seed is kept local instead of being
        written back to ``self.seed``, which previously made the seed drift
        on every re-iteration (epoch) of the same dataset object.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            worker_seed = self.seed + worker_info.id
        else:
            worker_seed = self.seed

        dataset = self._load_dataset(seed=worker_seed)

        return self._tokenize_and_pack(dataset)
+
|
| 386 |
+
|
| 387 |
+
# ============================================================================
|
| 388 |
+
# 4. ๊ฒ์ฆ์ฉ ๋ฐ์ดํฐ์
(๊ณ ์ ํฌ๊ธฐ)
|
| 389 |
+
# ============================================================================
|
| 390 |
+
|
| 391 |
+
class ValidationDataset:
    """Fixed-size validation dataset.

    Pre-fetches a fixed number of packed sequences from the streaming
    dataset and keeps them in memory: evaluating on the *same* data every
    epoch is what makes validation losses comparable over time.

    NOTE(review): samples are drawn from the same split as training with a
    different shuffle seed, so overlap with training data is unlikely but
    not strictly impossible — confirm this is acceptable.
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        num_samples: int = 100,
        seed: int = 9999,
    ):
        self.tokenizer = tokenizer
        self.config = config
        self.num_samples = num_samples
        self.samples: List[Dict[str, torch.Tensor]] = []

        self._prepare(seed)

    def _prepare(self, seed: int) -> None:
        """Extract ``num_samples`` packed validation sequences up front."""
        from datasets import load_dataset

        print(f"[Validation] {self.num_samples}๊ฐ ๊ฒ์ฆ ์ํ ์ค๋น ์ค...")

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )
        # Shuffle with a seed distinct from training to avoid overlap.
        ds = ds.shuffle(seed=seed, buffer_size=5_000)

        buffer: List[int] = []
        count = 0

        for example in ds:
            if count >= self.num_samples:
                break

            text = example[self.config.text_column]
            if not text or not text.strip():
                continue

            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue

            token_ids.append(self.tokenizer.eos_id)
            buffer.extend(token_ids)

            # Same packing scheme as training: max_seq_len + 1 tokens per
            # chunk, split into input_ids and one-step-shifted targets.
            while len(buffer) >= self.config.max_seq_len + 1 and count < self.num_samples:
                chunk = buffer[: self.config.max_seq_len + 1]
                buffer = buffer[self.config.max_seq_len + 1 :]

                self.samples.append({
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "targets": torch.tensor(chunk[1:], dtype=torch.long),
                })
                count += 1

        print(f"[Validation] {len(self.samples)}๊ฐ ์ํ ์ค๋น ์๋ฃ")

    def get_dataloader(self, batch_size: int) -> DataLoader:
        """Return a deterministic (unshuffled) validation DataLoader."""
        return DataLoader(
            self.samples,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=_collate_fn,  # module-level collate helper
        )
| 467 |
+
|
| 468 |
+
|
| 469 |
+
# ============================================================================
|
| 470 |
+
# 5. DataLoader ์์ฑ ์ ํธ๋ฆฌํฐ
|
| 471 |
+
# ============================================================================
|
| 472 |
+
|
| 473 |
+
def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 474 |
+
"""๋ฐฐ์น ๋ด ์ํ๋ค์ ํ๋์ ํ
์๋ก ํฉ์นฉ๋๋ค.
|
| 475 |
+
|
| 476 |
+
์ํ์ค ํจํน ๋๋ถ์ ๋ชจ๋ ์ํ์ด ๋์ผํ ๊ธธ์ด(max_seq_len)์ด๋ฏ๋ก
|
| 477 |
+
์ถ๊ฐ ํจ๋ฉ์ด ํ์ ์์ต๋๋ค.
|
| 478 |
+
"""
|
| 479 |
+
return {
|
| 480 |
+
"input_ids": torch.stack([s["input_ids"] for s in batch]),
|
| 481 |
+
"targets": torch.stack([s["targets"] for s in batch]),
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def create_train_dataloader(
    tokenizer: Tokenizer,
    config: DataConfig,
    seed: int = 42,
) -> DataLoader:
    """Build the training DataLoader.

    Returns:
        An endlessly-repeating streaming DataLoader.

    Usage:
        dataloader = create_train_dataloader(tokenizer, config)
        for step, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)  # (B, seq_len)
            targets = batch["targets"].to(device)      # (B, seq_len)
            logits, loss = model(input_ids, targets)
            ...
    """
    stream = PackedStreamingDataset(
        tokenizer=tokenizer,
        config=config,
        split="train",
        seed=seed,
    )

    # prefetch_factor is only legal when worker processes are in use.
    prefetch = config.prefetch_factor if config.num_workers > 0 else None

    return DataLoader(
        stream,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        prefetch_factor=prefetch,
        pin_memory=True,  # speeds up host-to-GPU transfer
        collate_fn=_collate_fn,
    )
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# ============================================================================
|
| 523 |
+
# 6. ํ ํฌ๋์ด์ ํ์ต ํฌํผ
|
| 524 |
+
# ============================================================================
|
| 525 |
+
|
| 526 |
+
def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
    """Trains a BPE tokenizer from the dataset.

    There is no need to consume the entire dataset; ~50K documents is enough,
    since the tokenizer vocabulary only has to reflect the overall corpus
    statistics.
    """
    from datasets import load_dataset

    print(f"[Train Tokenizer] {config.dataset_name}์์ ํ ํฌ๋์ด์ ํ์ต")
    print(f"[Train Tokenizer] ํ์ต ๋ฌธ์ ์: {config.tokenizer_train_samples:,}")

    # Build a text iterator over the streaming dataset.
    ds = load_dataset(
        config.dataset_name,
        name=config.dataset_subset,
        split=config.dataset_split,
        streaming=True,
        trust_remote_code=True,
    )

    def text_iterator():
        # Yields up to tokenizer_train_samples non-empty documents.
        count = 0
        for example in ds:
            if count >= config.tokenizer_train_samples:
                break
            text = example[config.text_column]
            if text and text.strip():
                yield text
                count += 1
                # Progress log every 10k documents (only counts valid docs).
                if count % 10_000 == 0:
                    print(f" ... {count:,} ๋ฌธ์ ์ฒ๋ฆฌ")

    # Train the tokenizer and persist it to config.tokenizer_save_dir.
    tokenizer = Tokenizer(config)
    tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)

    return tokenizer
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
# ============================================================================
|
| 566 |
+
# 7. ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ํต๊ณ/์ง๋จ ๋๊ตฌ
|
| 567 |
+
# ============================================================================
|
| 568 |
+
|
| 569 |
+
class DataPipelineDiagnostics:
    """Diagnoses the performance and quality of the data pipeline.

    Items that must be verified before training:
    1) Tokenizer quality: average tokens/document, unknown-token ratio
    2) Packing efficiency: ratio of real tokens vs padding
    3) Throughput: tokens/sec (detects data-loading bottlenecks)
    4) Batch shape: shape and dtype correctness
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 1000,
    ):
        """Diagnoses tokenizer quality on a sample of the dataset."""
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("๐ ํ ํฌ๋์ด์ ํ์ง ์ง๋จ")
        print("=" * 60)

        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )

        token_counts = []   # tokens per analyzed document
        char_counts = []    # characters per analyzed document
        sample_count = 0

        for example in ds:
            if sample_count >= num_samples:
                break
            text = example[config.text_column]
            if not text or not text.strip():
                continue

            tokens = tokenizer.encode(text)
            token_counts.append(len(tokens))
            char_counts.append(len(text))
            sample_count += 1

        # NOTE(review): if no document yields text, these divisions raise
        # ZeroDivisionError -- confirm the dataset always provides at least
        # one non-empty document.
        avg_tokens = sum(token_counts) / len(token_counts)
        avg_chars = sum(char_counts) / len(char_counts)
        compression_ratio = avg_chars / avg_tokens  # characters per token

        print(f" ๋ถ์ ๋ฌธ์ ์: {len(token_counts):,}")
        print(f" ํ๊ท ํ ํฐ/๋ฌธ์: {avg_tokens:.1f}")
        print(f" ํ๊ท ๋ฌธ์/๋ฌธ์: {avg_chars:.1f}")
        print(f" ์์ถ ๋น์จ (๋ฌธ์/ํ ํฐ): {compression_ratio:.2f}")
        print(f" โ ์์ด ๊ธฐ์ค 3.5~4.5๊ฐ ์ ์")
        print(f" ์ต์ ํ ํฐ: {min(token_counts)}, ์ต๋: {max(token_counts)}")

        # Encode/decode round-trip test.
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        # Substring check (not equality): decode may add/strip surrounding text.
        roundtrip_ok = test_text.strip() in decoded.strip()
        print(f"\n ์๋ณต ํ์คํธ: {'โ ํต๊ณผ' if roundtrip_ok else 'โ ์คํจ'}")
        print(f" ์๋ณธ: {test_text}")
        print(f" ์ธ์ฝ๋ฉ: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f" ๋์ฝ๋ฉ: {decoded}")

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ):
        """Measures data-loading throughput.

        Core diagnostic for whether data loading is the bottleneck relative
        to GPU training speed. Goal: data loading should be faster than the
        GPU compute (data loading != bottleneck).

        NOTE(review): the seq_len parameter is currently unused in the body.
        """
        print("\n" + "=" * 60)
        print("โก ๋ฐ์ดํฐ ๋ก๋ฉ ์ฒ๋ฆฌ๋ ๋ฒค์น๋งํฌ")
        print("=" * 60)

        total_tokens = 0
        start_time = time.time()

        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            batch_tokens = batch["input_ids"].numel()
            total_tokens += batch_tokens

            # Progress report every 10 batches.
            if (i + 1) % 10 == 0:
                elapsed = time.time() - start_time
                tps = total_tokens / elapsed
                print(f" Batch {i+1}: {tps:,.0f} tokens/sec")

        elapsed = time.time() - start_time
        tps = total_tokens / elapsed

        print(f"\n ์ด ๋ฐฐ์น ์: {num_batches}")
        print(f" ์ด ํ ํฐ ์: {total_tokens:,}")
        print(f" ์์ ์๊ฐ: {elapsed:.2f}์ด")
        print(f" ํ๊ท ์ฒ๋ฆฌ๋: {tps:,.0f} tokens/sec")
        print(f"\n ๐ก A100 ํ์ต ์ฒ๋ฆฌ๋ ~50-80K tokens/sec ๊ธฐ์ค:")
        # Thresholds based on the ~50-80K tokens/sec A100 training rate above.
        if tps > 80_000:
            print(f" โ ๋ฐ์ดํฐ ๋ก๋ฉ์ด ๋ณ๋ชฉ์ด ์๋๋๋ค")
        elif tps > 30_000:
            print(f" โ ๏ธ ๊ฒฝ๊ณ์ - num_workers ์ฆ๊ฐ๋ฅผ ๊ณ ๋ คํ์ธ์")
        else:
            print(f" โ ๋ฐ์ดํฐ ๋ก๋ฉ์ด ๋ณ๋ชฉ! num_workers/prefetch ์กฐ์ ํ์")

    @staticmethod
    def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
        """Inspects a single batch in detail (shapes, shift relation, EOS)."""
        print("\n" + "=" * 60)
        print("๐ ๋ฐฐ์น ์์ธ ๊ฒ์ฌ")
        print("=" * 60)

        input_ids = batch["input_ids"]
        targets = batch["targets"]

        print(f" input_ids shape: {input_ids.shape}")
        print(f" targets shape: {targets.shape}")
        print(f" dtype: {input_ids.dtype}")
        print(f" ๊ฐ ๋ฒ์: [{input_ids.min().item()}, {input_ids.max().item()}]")

        # Verify the shift relation: targets[i] == input_ids[i+1].
        shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        print(f" Shift ์ ํฉ์ฑ: {shift_correct*100:.1f}% (100%์ฌ์ผ ์ ์)")

        # EOS-token distribution (document boundaries within packed sequences).
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f" EOS ํ ํฐ ์: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Decode a preview of the first sample for eyeballing.
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print(f"\n ์ฒซ ์ํ ๋์ฝ๋ฉ (์ฒ์ 100 ํ ํฐ):")
        print(f" {decoded_preview[:300]}...")
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
# ============================================================================
|
| 713 |
+
# 8. ์ ์ฒด ํ์ดํ๋ผ์ธ ํตํฉ (Quick Start)
|
| 714 |
+
# ============================================================================
|
| 715 |
+
|
| 716 |
+
def setup_data_pipeline(
    tokenizer_mode: str = "train_new",
    tokenizer_path: Optional[str] = None,
    config: Optional[DataConfig] = None,
) -> tuple:
    """Sets up the whole data pipeline in one call.

    Args:
        tokenizer_mode:
            "train_new"    - train a fresh BPE tokenizer
            "load_trained" - load a previously trained tokenizer
            "pretrained"   - use a HuggingFace pretrained tokenizer
        tokenizer_path:
            "train_new"    -> save path (default: ./tokenizer)
            "load_trained" -> path of the saved tokenizer
            "pretrained"   -> HF model name (default: mistralai/Mistral-7B-v0.1)
        config: DataConfig to use; a default-constructed one when omitted.

    Returns:
        (tokenizer, train_dataloader, val_dataloader)

    Raises:
        ValueError: if tokenizer_mode is not one of the three known modes.

    Usage (Colab):
        # Option 1: train a new tokenizer
        tok, train_dl, val_dl = setup_data_pipeline("train_new")

        # Option 2: load an existing tokenizer
        tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")

        # Option 3: pretrained tokenizer (easiest)
        tok, train_dl, val_dl = setup_data_pipeline("pretrained")
    """
    config = config or DataConfig()

    print("=" * 60)
    print("๐ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ์ค์ ")
    print("=" * 60)

    # โโ Step 1: tokenizer โโ
    # Fix: construct the Tokenizer only in the branches that actually use it.
    # The previous code eagerly built Tokenizer(config) up front, discarding
    # it in the "train_new" branch and building it pointlessly before raising
    # in the error branch.
    if tokenizer_mode == "train_new":
        tokenizer = train_tokenizer_from_dataset(config)
    elif tokenizer_mode == "load_trained":
        tokenizer = Tokenizer(config)
        path = tokenizer_path or config.tokenizer_save_dir
        tokenizer.load_trained_hf(path)
    elif tokenizer_mode == "pretrained":
        tokenizer = Tokenizer(config)
        name = tokenizer_path or "mistralai/Mistral-7B-v0.1"
        tokenizer.load_pretrained_hf(name)
    else:
        raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")

    # โโ Step 2: training DataLoader โโ
    print("\n[DataLoader] ํ์ต DataLoader ์์ฑ...")
    train_dataloader = create_train_dataloader(tokenizer, config)

    # โโ Step 3: validation DataLoader โโ
    print("\n[DataLoader] ๊ฒ์ฆ DataLoader ์์ฑ...")
    val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
    val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)

    print("\n" + "=" * 60)
    print("โ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ์ค์ ์๋ฃ!")
    print(f" ํ ํฌ๋์ด์ vocab: {tokenizer.vocab_size:,}")
    print(f" ์ํ์ค ๊ธธ์ด: {config.max_seq_len}")
    print(f" ๋ฐฐ์น ํฌ๊ธฐ: {config.batch_size}")
    print(f" ํ ํฐ/๋ฐฐ์น: {config.batch_size * config.max_seq_len:,}")
    print("=" * 60)

    return tokenizer, train_dataloader, val_dataloader
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
# ============================================================================
|
| 787 |
+
# 9. ๊ฒ์ฆ ์คํฌ๋ฆฝํธ
|
| 788 |
+
# ============================================================================
|
| 789 |
+
|
| 790 |
+
if __name__ == "__main__":
    """
    Run locally / in Colab to validate the pipeline.

    How to run:
        python data_pipeline.py

    Or in Colab:
        !pip install datasets tokenizers sentencepiece
        %run data_pipeline.py
    """
    print("=" * 70)
    print("LLM-1B-Lab: ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ๊ฒ์ฆ")
    print("=" * 70)

    # โโ Quick validation: exercise the pipeline with a dummy tokenizer โโ
    print("\n[ํ์คํธ 1] ๋๋ฏธ ํ ํฌ๋์ด์ ๋ก ํ์ดํ๋ผ์ธ ๊ตฌ์กฐ ๊ฒ์ฆ")

    # Dummy tokenizer (lets us test without a real dataset).
    class DummyTokenizer:
        """Simple character-level tokenizer for testing."""
        def __init__(self, vocab_size=256):
            self.vocab_size = vocab_size
            self.eos_id = 2
            self.bos_id = 1
            self.pad_id = 0

        def encode(self, text, add_special_tokens=False):
            # Map each character to its code point, clamped to the vocab
            # (simple, test-only scheme).
            ids = [min(ord(c), self.vocab_size - 1) for c in text]
            if add_special_tokens:
                ids = [self.bos_id] + ids + [self.eos_id]
            return ids

        def decode(self, ids):
            # Skip special tokens (ids <= 2) and clamp to printable ASCII.
            return "".join(chr(min(i, 127)) for i in ids if i > 2)

        def __len__(self):
            return self.vocab_size

    config = DataConfig(max_seq_len=64, batch_size=2)  # small test settings
    dummy_tok = DummyTokenizer()

    # Packing test with dummy data.
    print("\n[ํ์คํธ 2] ์ํ์ค ํจํน ๋ก์ง ๊ฒ์ฆ")

    buffer = []
    test_docs = [
        "Hello world! This is document one. " * 5,
        "Second document here with different content. " * 8,
        "Third doc. " * 20,
        "A " * 200,
    ]

    for doc in test_docs:
        tokens = dummy_tok.encode(doc)
        tokens.append(dummy_tok.eos_id)  # mark the document boundary
        buffer.extend(tokens)

    seq_len = config.max_seq_len
    packed_count = 0
    # Same packing scheme as the real pipeline: (seq_len + 1)-token chunks,
    # split into inputs (chunk[:-1]) and targets (chunk[1:]).
    while len(buffer) >= seq_len + 1:
        chunk = buffer[: seq_len + 1]
        buffer = buffer[seq_len + 1 :]
        input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
        targets = torch.tensor(chunk[1:], dtype=torch.long)

        # Verify the one-token shift relation.
        assert (input_ids[1:] == targets[:-1]).all(), "Shift ๊ด๊ณ ์ค๋ฅ!"
        packed_count += 1

    print(f" ๋ฌธ์ ์: {len(test_docs)}")
    print(f" ์ด ํ ํฐ ์: {sum(len(dummy_tok.encode(d)) + 1 for d in test_docs)}")
    print(f" ํจํน๋ ์ํ์ค ์: {packed_count}")
    print(f" ์ํ์ค ๊ธธ์ด: {seq_len}")
    print(f" ๋จ์ ๋ฒํผ: {len(buffer)} ํ ํฐ")
    print(f" โ Shift ๊ด๊ณ ๊ฒ์ฆ ํต๊ณผ")

    # Batch construction test.
    print("\n[ํ์คํธ 3] ๋ฐฐ์น ๊ตฌ์ฑ ๊ฒ์ฆ")

    samples = []
    buffer2 = []
    for doc in test_docs * 10:  # generate enough data
        tokens = dummy_tok.encode(doc)
        tokens.append(dummy_tok.eos_id)
        buffer2.extend(tokens)

    while len(buffer2) >= seq_len + 1 and len(samples) < 10:
        chunk = buffer2[: seq_len + 1]
        buffer2 = buffer2[seq_len + 1 :]
        samples.append({
            "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
            "targets": torch.tensor(chunk[1:], dtype=torch.long),
        })

    batch = _collate_fn(samples[:config.batch_size])
    print(f" input_ids shape: {batch['input_ids'].shape}")
    print(f" targets shape: {batch['targets'].shape}")
    print(f" dtype: {batch['input_ids'].dtype}")

    expected_shape = (config.batch_size, seq_len)
    assert batch["input_ids"].shape == expected_shape, f"Shape ๋ถ์ผ์น: {batch['input_ids'].shape} != {expected_shape}"
    print(f" โ ๋ฐฐ์น shape ๊ฒ์ฆ ํต๊ณผ: {expected_shape}")

    # Confirm at least one EOS token (document boundary) survived packing.
    eos_found = (batch["input_ids"] == dummy_tok.eos_id).any().item()
    print(f" โ EOS ํ ํฐ ์กด์ฌ: {eos_found}")

    print("\n" + "=" * 70)
    print("โ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ๊ตฌ์กฐ ๊ฒ์ฆ ์๋ฃ!")
    print()
    print("๋ค์ ๋จ๊ณ: ์ค์ ๋ฐ์ดํฐ์์ผ๋ก ํ์คํธ")
    print(" tokenizer, train_dl, val_dl = setup_data_pipeline('pretrained')")
    print(" DataPipelineDiagnostics.check_tokenizer_quality(tokenizer, DataConfig())")
    print(" DataPipelineDiagnostics.benchmark_throughput(train_dl)")
    print("=" * 70)
|
_archive/llm-1b-evaluation.py
ADDED
|
@@ -0,0 +1,1455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM-1B-Lab: ํ๊ฐ ๋ชจ๋ (Evaluation)
|
| 3 |
+
=====================================
|
| 4 |
+
ํ์ต๋ ๋ชจ๋ธ์ ํ์ง์ ๋ค๊ฐ๋๋ก ํ๊ฐํ๊ณ ,
|
| 5 |
+
ํ์ต ๊ณผ์ ์์ ์ป์ ํต์ฐฐ์ ๋ถ์ํฉ๋๋ค.
|
| 6 |
+
|
| 7 |
+
ํ๊ฐ ์์ญ:
|
| 8 |
+
1. Perplexity ์ธก์ โ ์ธ์ด ๋ชจ๋ธ์ ํ์ค ์ ๋ ์งํ
|
| 9 |
+
2. ํ
์คํธ ์์ฑ ํ์ง โ ์ ์ฑ์ ํ๊ฐ (๋ค์ํ ํ๋กฌํํธ)
|
| 10 |
+
3. Scaling Law ๋ถ์ โ 10M โ 100M โ 1B ๋น๊ต
|
| 11 |
+
4. ํ์ต ์ญํ ๋ถ์ โ Loss ๊ณก์ , LR, Gradient ํจํด
|
| 12 |
+
5. Attention ์๊ฐํ โ ๋ชจ๋ธ์ด "์ด๋๋ฅผ ๋ณด๋์ง" ๋ถ์
|
| 13 |
+
6. ์ข
ํฉ ๋ฆฌํฌํธ ์์ฑ โ ํ์ต ์ธ์ฌ์ดํธ ์ ๋ฆฌ
|
| 14 |
+
|
| 15 |
+
์ค์น ํ์:
|
| 16 |
+
pip install matplotlib numpy
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import time
|
| 21 |
+
import json
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from torch.utils.data import DataLoader
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
import matplotlib
|
| 33 |
+
matplotlib.use("Agg") # Colab/์๋ฒ ํธํ
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
+
import matplotlib.ticker as ticker
|
| 36 |
+
HAS_MATPLOTLIB = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
HAS_MATPLOTLIB = False
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
import numpy as np
|
| 42 |
+
HAS_NUMPY = True
|
| 43 |
+
except ImportError:
|
| 44 |
+
HAS_NUMPY = False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ============================================================================
|
| 48 |
+
# 1. ํ๊ฐ ์ค์
|
| 49 |
+
# ============================================================================
|
| 50 |
+
|
| 51 |
+
@dataclass
class EvalConfig:
    """Evaluation hyper-parameters."""
    # โโ Perplexity โโ
    eval_batch_size: int = 4
    max_eval_batches: int = 100  # cap on the number of evaluation batches

    # โโ Generation โโ
    max_new_tokens: int = 200
    temperature: float = 0.8
    top_k: int = 50
    top_p: float = 0.9
    num_samples: int = 3  # number of generations per prompt

    # โโ Output โโ
    save_dir: str = "./eval_results"
    plot_dpi: int = 150
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ============================================================================
|
| 71 |
+
# 2. Perplexity ํ๊ฐ๊ธฐ
|
| 72 |
+
# ============================================================================
|
| 73 |
+
|
| 74 |
+
class PerplexityEvaluator:
|
| 75 |
+
"""Perplexity(PPL)๋ฅผ ์ธก์ ํฉ๋๋ค.
|
| 76 |
+
|
| 77 |
+
Perplexity๋?
|
| 78 |
+
PPL = exp(average cross-entropy loss)
|
| 79 |
+
|
| 80 |
+
์ง๊ด์ ์๋ฏธ:
|
| 81 |
+
- PPL = 1: ์๋ฒฝํ ์์ธก (๋ถ๊ฐ๋ฅ)
|
| 82 |
+
- PPL = 10: ๋งค๋ฒ 10๊ฐ ํ๋ณด ์ค ๊ณ ๋ฅด๋ ์์ค
|
| 83 |
+
- PPL = 100: 100๊ฐ ํ๋ณด ์ค ๊ณ ๋ฅด๋ ์์ค (๋ฌด์์์ ๊ฐ๊น์)
|
| 84 |
+
- PPL = 32000: vocab ์ ์ฒด์์ ๋๋ค ์ ํ (์ด๊ธฐ ๋๋ค ๋ชจ๋ธ)
|
| 85 |
+
|
| 86 |
+
์ข์ 1B ๋ชจ๋ธ ๊ธฐ์ค (์์ด ์น ํ
์คํธ):
|
| 87 |
+
- 5B ํ ํฐ ํ์ต: PPL ~30-40
|
| 88 |
+
- 10B ํ ํฐ ํ์ต: PPL ~20-30
|
| 89 |
+
- 20B ํ ํฐ ํ์ต: PPL ~15-25
|
| 90 |
+
|
| 91 |
+
์ธก์ ๋ฐฉ๋ฒ:
|
| 92 |
+
- ๊ฒ์ฆ ๋ฐ์ดํฐ์
์ ๋ชจ๋ ํ ํฐ์ ๋ํด cross-entropy ๊ณ์ฐ
|
| 93 |
+
- ํ ํฐ ๋จ์ ํ๊ท ํ exp() ์ ์ฉ
|
| 94 |
+
- ํจ๋ฉ ํ ํฐ์ ์ ์ธ (ignore_index=-100)
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(self, config: EvalConfig):
|
| 98 |
+
self.config = config
|
| 99 |
+
|
| 100 |
+
@torch.no_grad()
|
| 101 |
+
def evaluate(
|
| 102 |
+
self,
|
| 103 |
+
model: nn.Module,
|
| 104 |
+
dataloader: DataLoader,
|
| 105 |
+
device: torch.device,
|
| 106 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 107 |
+
desc: str = "Evaluation",
|
| 108 |
+
) -> Dict[str, float]:
|
| 109 |
+
"""Perplexity๋ฅผ ์ธก์ ํฉ๋๋ค.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
{
|
| 113 |
+
"loss": ํ๊ท cross-entropy loss,
|
| 114 |
+
"perplexity": exp(loss),
|
| 115 |
+
"num_tokens": ํ๊ฐ์ ์ฌ์ฉ๋ ์ด ํ ํฐ ์,
|
| 116 |
+
"num_batches": ํ๊ฐ์ ์ฌ์ฉ๋ ๋ฐฐ์น ์,
|
| 117 |
+
}
|
| 118 |
+
"""
|
| 119 |
+
model.eval()
|
| 120 |
+
|
| 121 |
+
total_loss = 0.0
|
| 122 |
+
total_tokens = 0
|
| 123 |
+
num_batches = 0
|
| 124 |
+
|
| 125 |
+
print(f"\n๐ {desc}")
|
| 126 |
+
start_time = time.time()
|
| 127 |
+
|
| 128 |
+
for i, batch in enumerate(dataloader):
|
| 129 |
+
if i >= self.config.max_eval_batches:
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
input_ids = batch["input_ids"].to(device)
|
| 133 |
+
targets = batch["targets"].to(device)
|
| 134 |
+
|
| 135 |
+
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 136 |
+
logits, _ = model(input_ids)
|
| 137 |
+
|
| 138 |
+
# ํ ํฐ๋ณ cross-entropy (reduction='none')
|
| 139 |
+
# logits: (B, S, V) โ (B*S, V)
|
| 140 |
+
# targets: (B, S) โ (B*S,)
|
| 141 |
+
loss_per_token = F.cross_entropy(
|
| 142 |
+
logits.view(-1, logits.size(-1)),
|
| 143 |
+
targets.view(-1),
|
| 144 |
+
ignore_index=-100,
|
| 145 |
+
reduction="none",
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# -100์ด ์๋ ์ ํจ ํ ํฐ๋ง ์นด์ดํธ
|
| 149 |
+
valid_mask = (targets.view(-1) != -100)
|
| 150 |
+
valid_tokens = valid_mask.sum().item()
|
| 151 |
+
|
| 152 |
+
total_loss += loss_per_token[valid_mask].sum().item()
|
| 153 |
+
total_tokens += valid_tokens
|
| 154 |
+
num_batches += 1
|
| 155 |
+
|
| 156 |
+
if (i + 1) % 20 == 0:
|
| 157 |
+
running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
|
| 158 |
+
print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")
|
| 159 |
+
|
| 160 |
+
elapsed = time.time() - start_time
|
| 161 |
+
avg_loss = total_loss / max(total_tokens, 1)
|
| 162 |
+
perplexity = math.exp(min(avg_loss, 100)) # overflow ๋ฐฉ์ง
|
| 163 |
+
|
| 164 |
+
results = {
|
| 165 |
+
"loss": round(avg_loss, 4),
|
| 166 |
+
"perplexity": round(perplexity, 2),
|
| 167 |
+
"num_tokens": total_tokens,
|
| 168 |
+
"num_batches": num_batches,
|
| 169 |
+
"eval_time_sec": round(elapsed, 1),
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
print(f" โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ")
|
| 173 |
+
print(f" Loss: {results['loss']:.4f}")
|
| 174 |
+
print(f" Perplexity: {results['perplexity']:.2f}")
|
| 175 |
+
print(f" ํ๊ฐ ํ ํฐ: {total_tokens:,}")
|
| 176 |
+
print(f" ์์ ์๊ฐ: {elapsed:.1f}์ด")
|
| 177 |
+
|
| 178 |
+
return results
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
|
| 181 |
+
def evaluate_per_position(
|
| 182 |
+
self,
|
| 183 |
+
model: nn.Module,
|
| 184 |
+
dataloader: DataLoader,
|
| 185 |
+
device: torch.device,
|
| 186 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 187 |
+
max_batches: int = 50,
|
| 188 |
+
) -> List[float]:
|
| 189 |
+
"""์ํ์ค ๋ด ์์น๋ณ Loss๋ฅผ ์ธก์ ํฉ๋๋ค.
|
| 190 |
+
|
| 191 |
+
ํ์ต ํฌ์ธํธ:
|
| 192 |
+
- ์์น 0~10: Loss๊ฐ ๋์ (๋ฌธ๋งฅ์ด ๋ถ์กฑ)
|
| 193 |
+
- ์์น 100+: Loss๊ฐ ์์ ์ ์ผ๋ก ๋ฎ์์ง (๋ฌธ๋งฅ ํ์ฉ)
|
| 194 |
+
- ์ด ํจํด์ด Transformer์ in-context learning ๋ฅ๋ ฅ์ ๋ณด์ฌ์ค
|
| 195 |
+
"""
|
| 196 |
+
model.eval()
|
| 197 |
+
seq_len = None
|
| 198 |
+
position_loss_sum = None
|
| 199 |
+
position_count = None
|
| 200 |
+
|
| 201 |
+
for i, batch in enumerate(dataloader):
|
| 202 |
+
if i >= max_batches:
|
| 203 |
+
break
|
| 204 |
+
|
| 205 |
+
input_ids = batch["input_ids"].to(device)
|
| 206 |
+
targets = batch["targets"].to(device)
|
| 207 |
+
B, S = targets.shape
|
| 208 |
+
|
| 209 |
+
if seq_len is None:
|
| 210 |
+
seq_len = S
|
| 211 |
+
position_loss_sum = torch.zeros(S, device=device)
|
| 212 |
+
position_count = torch.zeros(S, device=device)
|
| 213 |
+
|
| 214 |
+
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 215 |
+
logits, _ = model(input_ids)
|
| 216 |
+
|
| 217 |
+
# (B, S) ํํ์ ํ ํฐ๋ณ loss
|
| 218 |
+
loss_per_token = F.cross_entropy(
|
| 219 |
+
logits.view(-1, logits.size(-1)),
|
| 220 |
+
targets.view(-1),
|
| 221 |
+
ignore_index=-100,
|
| 222 |
+
reduction="none",
|
| 223 |
+
).view(B, S)
|
| 224 |
+
|
| 225 |
+
valid_mask = (targets != -100).float()
|
| 226 |
+
position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
|
| 227 |
+
position_count += valid_mask.sum(dim=0)
|
| 228 |
+
|
| 229 |
+
# ์์น๋ณ ํ๊ท loss
|
| 230 |
+
position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
|
| 231 |
+
return position_avg_loss
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# ============================================================================
|
| 235 |
+
# 3. ํ
์คํธ ์์ฑ ํ๊ฐ
|
| 236 |
+
# ============================================================================
|
| 237 |
+
|
| 238 |
+
class GenerationEvaluator:
    """Generates text from a battery of prompts to judge quality.

    Evaluation criteria:
      1) Grammaticality: are the sentences well-formed English?
      2) Coherence: does the continuation stay on topic?
      3) Diversity: do repeated samples for one prompt differ?
      4) Repetition avoidance: does it avoid looping on phrases?
      5) Knowledge: does training-data knowledge show up?

    Realistic expectations for a 1B model:
      - grammatically correct English sentences: yes
      - coherence within a short paragraph: yes
      - complex reasoning / long logical chains: no (needs a larger model)
      - factual accuracy: not guaranteed
    """

    # Test prompts spanning several domains.
    DEFAULT_PROMPTS = [
        # ── General knowledge ──
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",

        # ── Explanation / education ──
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",

        # ── Narrative / story ──
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",

        # ── Code / technical ──
        "def fibonacci(n):\n    \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",

        # ── Short completion ──
        "The capital of France is",
        "Water boils at a temperature of",

        # ── Long context ──
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: EvalConfig):
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Generate ``config.num_samples`` continuations for each prompt.

        Args:
            model: model exposing a ``generate(input_ids, max_new_tokens=...,
                temperature=..., top_k=..., top_p=...)`` method.
                NOTE(review): assumed to return a (1, prompt+new) token tensor
                — confirm against the model's generate implementation.
            tokenizer: object with ``encode``/``decode`` methods.
            device: device the prompt tensor is placed on.
            prompts: optional prompt list; defaults to DEFAULT_PROMPTS.
            verbose: print prompts, generations and metrics as we go.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        prompts = prompts or self.DEFAULT_PROMPTS
        results = []

        if verbose:
            print("\n" + "=" * 70)
            print("🎨 텍스트 생성 평가")
            print("=" * 70)

        for idx, prompt in enumerate(prompts):
            prompt_results = {
                "prompt": prompt,
                "generations": [],
                "metrics": {},
            }

            if verbose:
                print(f"\n{'─'*60}")
                print(f"프롬프트 [{idx+1}/{len(prompts)}]:")
                print(f"  \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
                print(f"{'─'*60}")

            # Encode the prompt once; reused for every sample.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            all_texts = []
            for sample_idx in range(self.config.num_samples):
                # Sample a continuation.
                generated_ids = model.generate(
                    input_tensor,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )

                # Decode only the continuation (tokens after the prompt).
                new_ids = generated_ids[0][len(prompt_ids):].tolist()
                generated_text = tokenizer.decode(new_ids)
                all_texts.append(generated_text)

                prompt_results["generations"].append(generated_text)

                if verbose:
                    print(f"\n  ✍️ 생성 #{sample_idx+1}:")
                    # Pretty-print (preserving line breaks), capped at 500 chars.
                    display_text = generated_text[:500]
                    for line in display_text.split("\n"):
                        print(f"    {line}")
                    if len(generated_text) > 500:
                        print(f"    ... (총 {len(generated_text)} 문자)")

            # Quality metrics over this prompt's samples.
            prompt_results["metrics"] = self._compute_generation_metrics(all_texts)

            if verbose and prompt_results["metrics"]:
                m = prompt_results["metrics"]
                print(f"\n  📊 메트릭: "
                      f"평균 길이={m['avg_length']:.0f}자, "
                      f"반복률={m['repetition_rate']:.1%}, "
                      f"어휘 다양성={m['lexical_diversity']:.2f}")

            results.append(prompt_results)

        return results

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Compute quality metrics for a list of generated texts.

        Metrics:
          - avg_length: mean generation length in characters
          - avg_word_count: mean word count
          - repetition_rate: 4-gram repetition rate (lower is better)
          - lexical_diversity: unique-word ratio (higher is more varied)
          - sample_diversity: cross-sample diversity (how different the
            samples are from each other; 1 - mean Jaccard similarity)
        """
        if not texts:
            return {}

        # Length statistics.
        lengths = [len(t) for t in texts]
        word_counts = [len(t.split()) for t in texts]

        # Repetition rate (4-gram based).
        rep_rates = []
        for text in texts:
            words = text.lower().split()
            if len(words) < 4:
                rep_rates.append(0.0)
                continue
            ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
            unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
            rep_rates.append(1.0 - unique_ratio)  # repetition = 1 - unique ratio

        # Lexical diversity (type-token ratio).
        diversities = []
        for text in texts:
            words = text.lower().split()
            if words:
                diversities.append(len(set(words)) / len(words))
            else:
                diversities.append(0.0)

        # Cross-sample diversity (inverse of mean pairwise Jaccard similarity).
        sample_div = 0.0
        if len(texts) > 1:
            word_sets = [set(t.lower().split()) for t in texts]
            similarities = []
            for i in range(len(word_sets)):
                for j in range(i+1, len(word_sets)):
                    inter = len(word_sets[i] & word_sets[j])
                    union = len(word_sets[i] | word_sets[j])
                    if union > 0:
                        similarities.append(inter / union)
            sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1))

        return {
            "avg_length": sum(lengths) / len(lengths),
            "avg_word_count": sum(word_counts) / len(word_counts),
            "repetition_rate": sum(rep_rates) / len(rep_rates),
            "lexical_diversity": sum(diversities) / len(diversities),
            "sample_diversity": round(sample_div, 3),
        }
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# ============================================================================
|
| 431 |
+
# 4. Scaling Law ๋ถ์๊ธฐ
|
| 432 |
+
# ============================================================================
|
| 433 |
+
|
| 434 |
+
class ScalingAnalyzer:
    """Analyzes scaling-law behavior across 10M → 100M → 1B models.

    Chinchilla scaling law (2022):
      - compute-optimal training: tokens ≈ 20 × parameters
      - Loss ∝ N^(-α) × D^(-β)  (N = parameters, D = data)
      - α ≈ 0.076, β ≈ 0.095 (per the paper)

    Purpose of this analysis:
      - check whether our models follow the scaling law
      - predict the effect of larger models / more data
      - understand compute-optimal budget allocation
    """

    def __init__(self, save_dir: str = "./eval_results"):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze(
        self,
        model_results: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Compare evaluation results across model sizes.

        Args:
            model_results: ordered smallest-to-largest, e.g. [
                {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
                {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
                {"name": "1B", "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
            ]

        Returns:
            Analysis dictionary; empty when fewer than two results are given.
        """
        if len(model_results) < 2:
            print("⚠️ Scaling 분석에는 최소 2개 모델 결과가 필요합니다.")
            return {}

        print("\n" + "=" * 70)
        print("📈 Scaling Law 분석")
        print("=" * 70)

        # ── Results table ──
        print(f"\n  {'모델':<8} {'파라미터':>12} {'토큰':>10} {'Loss':>8} {'PPL':>8}")
        print(f"  {'─'*52}")
        for r in model_results:
            params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
            tokens_str = f"{r['tokens']/1e9:.1f}B"
            print(f"  {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")

        # ── Scaling efficiency between consecutive sizes ──
        analysis = {"models": model_results, "scaling_efficiency": []}

        for i in range(1, len(model_results)):
            prev = model_results[i-1]
            curr = model_results[i]

            param_ratio = curr["params"] / prev["params"]
            loss_reduction = prev["loss"] - curr["loss"]
            ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]

            efficiency = {
                "from": prev["name"],
                "to": curr["name"],
                "param_multiplier": round(param_ratio, 1),
                "loss_reduction": round(loss_reduction, 4),
                "ppl_reduction_pct": round(ppl_reduction * 100, 1),
            }
            analysis["scaling_efficiency"].append(efficiency)

            print(f"\n  {prev['name']} → {curr['name']}:")
            print(f"    파라미터 ×{param_ratio:.1f}")
            print(f"    Loss 감소: {loss_reduction:.4f}")
            print(f"    PPL 감소: {ppl_reduction*100:.1f}%")

        # ── Chinchilla optimality check ──
        print(f"\n  Chinchilla 최적성 체크 (토큰 ≈ 20 × 파라미터):")
        for r in model_results:
            # IDIOM: removed an unused `optimal_tokens = r["params"] * 20` local;
            # the ratio below is compared directly against the ~20x optimum.
            actual_ratio = r["tokens"] / r["params"]
            status = "✅ 최적 범위" if 15 <= actual_ratio <= 25 else "⚠️ 범위 밖"
            print(f"    {r['name']}: 토큰/파라미터 = {actual_ratio:.1f}x "
                  f"(최적: 20x) {status}")

        analysis["chinchilla_ratios"] = [
            {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
            for r in model_results
        ]

        return analysis

    def plot_scaling_curves(
        self,
        model_results: List[Dict[str, Any]],
        save_path: Optional[str] = None,
    ):
        """Plot loss and PPL versus parameter count on log-log axes.

        Saves the figure to *save_path* (defaults to
        ``save_dir/scaling_curves.png``). No-op without matplotlib/numpy.
        """
        if not HAS_MATPLOTLIB or not HAS_NUMPY:
            print("⚠️ matplotlib/numpy가 필요합니다: pip install matplotlib numpy")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        params = [r["params"] for r in model_results]
        losses = [r["loss"] for r in model_results]
        ppls = [r["ppl"] for r in model_results]
        names = [r["name"] for r in model_results]

        # ── Loss vs parameters (log-log) ──
        ax = axes[0]
        ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
        for p, loss_val, n in zip(params, losses, names):
            ax.annotate(f"  {n}\n  Loss={loss_val:.2f}", (p, loss_val), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Validation Loss", fontsize=12)
        ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # ── PPL vs parameters (log-log) ──
        ax = axes[1]
        ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
        for p, pp, n in zip(params, ppls, names):
            ax.annotate(f"  {n}\n  PPL={pp:.1f}", (p, pp), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Perplexity", fontsize=12)
        ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "scaling_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n  📊 Scaling 곡선 저장: {save_path}")
        plt.close(fig)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# ============================================================================
|
| 571 |
+
# 5. ํ์ต ์ญํ ๋ถ์๊ธฐ
|
| 572 |
+
# ============================================================================
|
| 573 |
+
|
| 574 |
+
class TrainingDynamicsAnalyzer:
    """Analyzes and visualizes metrics collected during training.

    Covered:
      - loss curve: convergence pattern, spike detection
      - LR schedule: warmup + cosine-decay sanity check
      - gradient norm: stability, explosion/vanishing detection
      - throughput: tokens/sec stability, bottleneck detection
    """

    def __init__(self, save_dir: str = "./eval_results"):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Summarize loss, gradient-norm, and throughput statistics.

        Args:
            metrics_history: Trainer.metrics.history dictionary. Recognized
                keys: "train_loss", "step", "grad_norm", "tokens_per_sec".

        Returns:
            Nested analysis dict; only sections whose inputs exist appear.
        """
        print("\n" + "=" * 70)
        print("🔬 학습 역학 분석")
        print("=" * 70)

        analysis = {}

        # ── Loss ──
        if metrics_history.get("train_loss"):
            losses = metrics_history["train_loss"]
            analysis["loss"] = {
                "initial": round(losses[0], 4),
                "final": round(losses[-1], 4),
                "minimum": round(min(losses), 4),
                "total_reduction": round(losses[0] - losses[-1], 4),
            }

            # Spike detection: >50% jump versus the previous value.
            # ROBUSTNESS: fall back to the index when "step" is missing
            # or shorter than the loss list (previously could IndexError).
            spikes = []
            steps_list = metrics_history.get("step", [])
            for i in range(1, len(losses)):
                if losses[i] > losses[i-1] * 1.5:
                    step = steps_list[i] if i < len(steps_list) else i
                    spikes.append({"step": step, "loss": round(losses[i], 4)})

            analysis["loss"]["spikes"] = spikes

            print(f"\n  📉 Loss 분석:")
            print(f"     초기: {analysis['loss']['initial']:.4f}")
            print(f"     최종: {analysis['loss']['final']:.4f}")
            print(f"     최소: {analysis['loss']['minimum']:.4f}")
            print(f"     감소: {analysis['loss']['total_reduction']:.4f}")
            print(f"     스파이크: {len(spikes)}회")
            if spikes:
                for s in spikes[:5]:
                    print(f"       Step {s['step']}: Loss = {s['loss']}")

        # ── Gradient norm ──
        if metrics_history.get("grad_norm"):
            gnorms = metrics_history["grad_norm"]
            analysis["grad_norm"] = {
                "mean": round(sum(gnorms) / len(gnorms), 4),
                "max": round(max(gnorms), 4),
                "min": round(min(gnorms), 4),
                # Fraction of steps at/above the clip threshold of 1.0.
                "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
            }

            print(f"\n  📏 Gradient Norm 분석:")
            print(f"     평균: {analysis['grad_norm']['mean']:.4f}")
            print(f"     최대: {analysis['grad_norm']['max']:.4f}")
            print(f"     클리핑 비율: {analysis['grad_norm']['clipped_pct']:.1f}%")
            if analysis["grad_norm"]["clipped_pct"] > 30:
                print(f"     ⚠️ 클리핑이 잦음 → LR 하향 또는 warmup 연장 고려")

        # ── Throughput ──
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            tps_valid = [t for t in tps if t > 0]  # drop warmup/zero samples
            if tps_valid:
                analysis["throughput"] = {
                    "mean": round(sum(tps_valid) / len(tps_valid)),
                    "std": round((sum((t - sum(tps_valid)/len(tps_valid))**2 for t in tps_valid) / len(tps_valid))**0.5),
                    "min": round(min(tps_valid)),
                    "max": round(max(tps_valid)),
                }

                print(f"\n  ⚡ 처리량 분석:")
                print(f"     평균: {analysis['throughput']['mean']:,} tokens/sec")
                print(f"     표준편차: {analysis['throughput']['std']:,}")
                print(f"     범위: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")

        return analysis

    def plot_training_curves(
        self,
        metrics_history: Dict[str, list],
        save_path: Optional[str] = None,
    ):
        """Visualize training curves as a 4-panel chart.

        Panels: loss (+smoothed, +validation), learning rate,
        gradient norm, throughput. No-op without matplotlib.
        """
        if not HAS_MATPLOTLIB:
            print("⚠️ matplotlib가 필요합니다: pip install matplotlib")
            return

        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")

        steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))

        # ── (1) Loss ──
        ax = axes[0, 0]
        if metrics_history.get("train_loss"):
            ax.plot(steps[:len(metrics_history["train_loss"])],
                    metrics_history["train_loss"],
                    color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")

            # Moving average (smoothing).
            if len(metrics_history["train_loss"]) > 20:
                window = min(50, len(metrics_history["train_loss"]) // 5)
                smoothed = self._moving_average(metrics_history["train_loss"], window)
                ax.plot(steps[window-1:len(smoothed)+window-1],
                        smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")

        if metrics_history.get("val_loss"):
            # Spread validation points evenly along the recorded steps.
            val_steps = [steps[i] for i in range(0, len(steps),
                         max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
            ax.plot(val_steps, metrics_history["val_loss"],
                    "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")

        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # ── (2) Learning rate ──
        ax = axes[0, 1]
        if metrics_history.get("learning_rate"):
            ax.plot(steps[:len(metrics_history["learning_rate"])],
                    metrics_history["learning_rate"],
                    color="#059669", linewidth=2)
        ax.set_xlabel("Step")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0,0))
        ax.grid(True, alpha=0.3)

        # ── (3) Gradient norm ──
        ax = axes[1, 0]
        if metrics_history.get("grad_norm"):
            ax.plot(steps[:len(metrics_history["grad_norm"])],
                    metrics_history["grad_norm"],
                    color="#d97706", alpha=0.6, linewidth=0.8)
            ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Clip threshold")
            ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Gradient Norm")
        ax.set_title("Gradient Norm (clipped at 1.0)")
        ax.grid(True, alpha=0.3)

        # ── (4) Throughput ──
        ax = axes[1, 1]
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
            if tps:
                avg_tps = sum(tps) / len(tps)
                ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
                           label=f"Avg: {avg_tps:,.0f}")
                ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Tokens/sec")
        ax.set_title("Training Throughput")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "training_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n  📊 학습 곡선 저장: {save_path}")
        plt.close(fig)

    def plot_position_loss(
        self,
        position_losses: List[float],
        save_path: Optional[str] = None,
    ):
        """Visualize the loss-by-position profile.

        Intended input is PerplexityEvaluator.evaluate_per_position output.
        No-op without matplotlib.
        """
        if not HAS_MATPLOTLIB:
            return

        fig, ax = plt.subplots(figsize=(12, 5))

        positions = list(range(len(position_losses)))
        ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
        ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")

        ax.set_xlabel("Position in Sequence", fontsize=12)
        ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
        ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # Reference lines for the early/late regions.
        if len(position_losses) > 100:
            early_avg = sum(position_losses[:50]) / 50
            # BUGFIX: the tail may hold fewer than 200 points (len 101-199);
            # divide by its actual length instead of a hard-coded 200.
            tail = position_losses[-200:]
            late_avg = sum(tail) / len(tail)
            ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
                       label=f"Early avg (0-50): {early_avg:.2f}")
            ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
                       label=f"Late avg (-200): {late_avg:.2f}")
            ax.legend()

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "position_loss.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"  📊 위치별 Loss 저장: {save_path}")
        plt.close(fig)

    @staticmethod
    def _moving_average(data: list, window: int) -> list:
        """Simple trailing moving average; output is len(data)-window+1 long."""
        result = []
        for i in range(window - 1, len(data)):
            avg = sum(data[i - window + 1 : i + 1]) / window
            result.append(avg)
        return result
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
# ============================================================================
|
| 804 |
+
# 6. Attention ์๊ฐํ
|
| 805 |
+
# ============================================================================
|
| 806 |
+
|
| 807 |
+
class AttentionVisualizer:
|
| 808 |
+
"""Attention ํจํด์ ์๊ฐํํฉ๋๋ค.
|
| 809 |
+
|
| 810 |
+
ํ์ต ํฌ์ธํธ:
|
| 811 |
+
- Causal Mask: ํ์ผ๊ฐ ํจํด (๋ฏธ๋ ํ ํฐ์ ๋ณผ ์ ์์)
|
| 812 |
+
- ํค๋๋ณ ์ญํ ๋ถํ: ์ผ๋ถ๋ ๋ก์ปฌ(์ธ์ ), ์ผ๋ถ๋ ๊ธ๋ก๋ฒ(๋จผ ํ ํฐ) ์ฃผ๋ชฉ
|
| 813 |
+
- ๊ตฌ๋ฌธ๋ก ์ ํจํด: ๋์ฌโ์ฃผ์ด, ๋๋ช
์ฌโ์ ํ์ฌ ๋ฑ์ ๋์ attention
|
| 814 |
+
|
| 815 |
+
์ฃผ์: 1B ๋ชจ๋ธ์ ์ ์ฒด attention์ ์ ์ฅํ๋ฉด ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ!
|
| 816 |
+
โ ํน์ ๋ ์ด์ด/ํค๋๋ง ์ ํ์ ์ผ๋ก ์๊ฐํํฉ๋๋ค.
|
| 817 |
+
"""
|
| 818 |
+
|
| 819 |
+
def __init__(self, save_dir: str = "./eval_results"):
|
| 820 |
+
self.save_dir = Path(save_dir)
|
| 821 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 822 |
+
|
| 823 |
+
    @torch.no_grad()
    def extract_attention(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        layer_idx: int = 0,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Extract one layer's attention weights via a temporary hook.

        The layer's ``forward`` is swapped for a manual attention
        implementation (scaled_dot_product_attention does not expose the
        weights) and restored afterwards, even on error.

        NOTE(review): relies on the attention module exposing q_proj/k_proj/
        v_proj/o_proj, rope, head_dim, num_heads, num_kv_heads, num_kv_groups
        and _repeat_kv, and on ``model.layers[i].attention`` — confirm
        against llm_lab.model.attention.

        Returns:
            attention_weights: (num_heads, seq_len, seq_len) for batch 0, or
            None when the hooked layer never ran during the forward pass.
        """
        model.eval()
        captured_attn = {}

        # Capture attention weights by monkey-patching the target layer.
        target_layer = model.layers[layer_idx].attention

        # Replace scaled_dot_product_attention with a manual implementation.
        original_forward = target_layer.forward

        def hooked_forward(x, mask=None, position_offset=0):
            B, S, _ = x.shape
            hd = target_layer.head_dim

            # Project and split into heads: (B, S, D) -> (B, H, S, hd).
            q = target_layer.q_proj(x).view(B, S, target_layer.num_heads, hd).transpose(1, 2)
            k = target_layer.k_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
            v = target_layer.v_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)

            q, k = target_layer.rope(q, k, position_offset)

            # Expand grouped KV heads to match the query heads (GQA).
            if target_layer.num_kv_groups > 1:
                k = target_layer._repeat_kv(k)
                v = target_layer._repeat_kv(v)

            # Manual attention computation (for weight extraction).
            scale = 1.0 / math.sqrt(hd)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # Causal mask: each position may attend only to itself and earlier.
            causal = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
            scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))

            attn_weights = F.softmax(scores, dim=-1)
            captured_attn["weights"] = attn_weights[0].cpu()  # first batch only

            out = torch.matmul(attn_weights, v)
            out = out.transpose(1, 2).contiguous().view(B, S, -1)
            return target_layer.o_proj(out)

        # Apply the hook.
        target_layer.forward = hooked_forward

        try:
            model(input_ids.to(device))
        finally:
            # Always restore the original forward, even if the pass fails.
            target_layer.forward = original_forward

        return captured_attn.get("weights")  # (num_heads, S, S)
|
| 886 |
+
|
| 887 |
+
def plot_attention_heatmap(
|
| 888 |
+
self,
|
| 889 |
+
attn_weights: torch.Tensor,
|
| 890 |
+
tokens: List[str],
|
| 891 |
+
head_idx: int = 0,
|
| 892 |
+
save_path: Optional[str] = None,
|
| 893 |
+
title: str = "Attention Weights",
|
| 894 |
+
):
|
| 895 |
+
"""Attention heatmap์ ๊ทธ๋ฆฝ๋๋ค."""
|
| 896 |
+
if not HAS_MATPLOTLIB:
|
| 897 |
+
print("โ ๏ธ matplotlib๊ฐ ํ์ํฉ๋๋ค")
|
| 898 |
+
return
|
| 899 |
+
|
| 900 |
+
weights = attn_weights[head_idx].numpy()
|
| 901 |
+
max_len = min(len(tokens), 50) # ์ต๋ 50 ํ ํฐ๋ง ํ์
|
| 902 |
+
weights = weights[:max_len, :max_len]
|
| 903 |
+
display_tokens = tokens[:max_len]
|
| 904 |
+
|
| 905 |
+
fig, ax = plt.subplots(figsize=(12, 10))
|
| 906 |
+
im = ax.imshow(weights, cmap="Blues", aspect="auto")
|
| 907 |
+
|
| 908 |
+
ax.set_xticks(range(max_len))
|
| 909 |
+
ax.set_yticks(range(max_len))
|
| 910 |
+
ax.set_xticklabels(display_tokens, rotation=90, fontsize=7)
|
| 911 |
+
ax.set_yticklabels(display_tokens, fontsize=7)
|
| 912 |
+
|
| 913 |
+
ax.set_xlabel("Key (attended to)", fontsize=11)
|
| 914 |
+
ax.set_ylabel("Query (attending from)", fontsize=11)
|
| 915 |
+
ax.set_title(f"{title} โ Head {head_idx}", fontsize=13, fontweight="bold")
|
| 916 |
+
|
| 917 |
+
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 918 |
+
plt.tight_layout()
|
| 919 |
+
|
| 920 |
+
save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
|
| 921 |
+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 922 |
+
print(f" ๐ Attention ์๊ฐํ ์ ์ฅ: {save_path}")
|
| 923 |
+
plt.close(fig)
|
| 924 |
+
|
| 925 |
+
def plot_multi_head_summary(
|
| 926 |
+
self,
|
| 927 |
+
attn_weights: torch.Tensor,
|
| 928 |
+
num_heads_to_show: int = 8,
|
| 929 |
+
save_path: Optional[str] = None,
|
| 930 |
+
):
|
| 931 |
+
"""์ฌ๋ฌ ํค๋์ attention ํจํด์ ์์ฝ ๋น๊ตํฉ๋๋ค."""
|
| 932 |
+
if not HAS_MATPLOTLIB:
|
| 933 |
+
return
|
| 934 |
+
|
| 935 |
+
n_heads = min(attn_weights.shape[0], num_heads_to_show)
|
| 936 |
+
cols = 4
|
| 937 |
+
rows = math.ceil(n_heads / cols)
|
| 938 |
+
|
| 939 |
+
fig, axes = plt.subplots(rows, cols, figsize=(16, 4 * rows))
|
| 940 |
+
if rows == 1:
|
| 941 |
+
axes = axes.reshape(1, -1)
|
| 942 |
+
|
| 943 |
+
for idx in range(n_heads):
|
| 944 |
+
r, c = idx // cols, idx % cols
|
| 945 |
+
ax = axes[r, c]
|
| 946 |
+
w = attn_weights[idx].numpy()
|
| 947 |
+
ax.imshow(w, cmap="Blues", aspect="auto")
|
| 948 |
+
ax.set_title(f"Head {idx}", fontsize=10)
|
| 949 |
+
ax.set_xticks([])
|
| 950 |
+
ax.set_yticks([])
|
| 951 |
+
|
| 952 |
+
# ๋น subplot ์จ๊ธฐ๊ธฐ
|
| 953 |
+
for idx in range(n_heads, rows * cols):
|
| 954 |
+
r, c = idx // cols, idx % cols
|
| 955 |
+
axes[r, c].axis("off")
|
| 956 |
+
|
| 957 |
+
fig.suptitle("Attention Patterns by Head", fontsize=14, fontweight="bold")
|
| 958 |
+
plt.tight_layout()
|
| 959 |
+
|
| 960 |
+
save_path = save_path or str(self.save_dir / "attention_multi_head.png")
|
| 961 |
+
fig.savefig(save_path, dpi=150, bbox_inches="tight")
|
| 962 |
+
print(f" ๐ ๋ฉํฐ ํค๋ ์์ฝ ์ ์ฅ: {save_path}")
|
| 963 |
+
plt.close(fig)
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
# ============================================================================
|
| 967 |
+
# 7. ์ข
ํฉ ํ๊ฐ ์คํ๊ธฐ
|
| 968 |
+
# ============================================================================
|
| 969 |
+
|
| 970 |
+
class FullEvaluator:
    """Run every evaluation phase in one pass and write a JSON report.

    Phases: perplexity (+ per-position loss), text generation, training-dynamics
    analysis (optional, requires ``metrics_history``), and a best-effort
    attention visualization. Artifacts are written under ``config.save_dir``.

    Usage:
    ```python
    evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
    report = evaluator.run_full_evaluation()
    ```
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Any,
        val_dataloader: DataLoader,
        device: torch.device,
        config: Optional[EvalConfig] = None,
        dtype: torch.dtype = torch.bfloat16,
        metrics_history: Optional[Dict[str, list]] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.val_dataloader = val_dataloader
        self.device = device
        self.config = config or EvalConfig()
        self.dtype = dtype
        # Optional training-log dict; when absent, phase 3 is skipped.
        self.metrics_history = metrics_history

        # All plots and the JSON report land here; created eagerly.
        self.save_dir = Path(self.config.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def run_full_evaluation(self) -> Dict[str, Any]:
        """Run all evaluation phases and return the accumulated report dict.

        Returns:
            Dict with keys: "timestamp", "perplexity", "position_losses",
            "generation", and (when metrics_history was given) "training_dynamics".
            The same dict is also saved to ``save_dir/eval_report.json``.
        """
        report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}

        print("\n" + "=" * 70)
        print("๐ ์ขํฉ ํ๊ฐ ์์")
        print("=" * 70)

        # ── 1. Perplexity ──
        print("\n" + "โ" * 40)
        print("Phase 1/4: Perplexity ์ธก์ ")
        print("โ" * 40)
        ppl_evaluator = PerplexityEvaluator(self.config)
        report["perplexity"] = ppl_evaluator.evaluate(
            self.model, self.val_dataloader, self.device, self.dtype
        )

        # Per-position loss: early vs late positions shows whether the
        # model actually exploits longer context.
        print("\n ์์น๋ณ Loss ์ธก์ ์ค...")
        position_losses = ppl_evaluator.evaluate_per_position(
            self.model, self.val_dataloader, self.device, self.dtype
        )
        # max(..., 1) guards against division by zero on very short sequences.
        report["position_losses"] = {
            "early_avg": round(sum(position_losses[:50]) / max(len(position_losses[:50]), 1), 4),
            "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
        }

        # Plot the per-position loss curve.
        dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
        dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))

        # ── 2. Text generation ──
        print("\n" + "โ" * 40)
        print("Phase 2/4: ํ์คํธ ์์ฑ")
        print("โ" * 40)
        gen_evaluator = GenerationEvaluator(self.config)
        gen_results = gen_evaluator.generate_samples(
            self.model, self.tokenizer, self.device
        )
        report["generation"] = {
            "num_prompts": len(gen_results),
            "avg_metrics": self._average_gen_metrics(gen_results),
        }

        # ── 3. Training dynamics (only when a metrics history was supplied) ──
        if self.metrics_history:
            print("\n" + "โ" * 40)
            print("Phase 3/4: ํ์ต ์ญํ ๋ถ์")
            print("โ" * 40)
            report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
            dynamics.plot_training_curves(self.metrics_history,
                                          str(self.save_dir / "training_curves.png"))
        else:
            print("\n Phase 3/4: ๊ฑด๋๋ (metrics_history ์์)")

        # ── 4. Attention visualization (sample) ──
        # Best-effort: models without a compatible `layers[...]` structure
        # raise here, so failures are reported but never abort the run.
        print("\n" + "โ" * 40)
        print("Phase 4/4: Attention ์๊ฐํ")
        print("โ" * 40)
        try:
            self._visualize_attention_sample()
        except Exception as e:
            print(f" โ ๏ธ Attention ์๊ฐํ ์คํจ: {e}")

        # ── Persist the report ──
        # default=str lets non-JSON-native values (paths, tensors) serialize.
        report_path = self.save_dir / "eval_report.json"
        with open(report_path, "w") as f:
            json.dump(report, f, indent=2, default=str)
        print(f"\n๐ ๋ฆฌํฌํธ ์ ์ฅ: {report_path}")

        # ── Human-readable summary ──
        self._print_summary(report)

        return report

    def _visualize_attention_sample(self):
        """Extract and plot layer-0 attention for a fixed English sample sentence."""
        viz = AttentionVisualizer(str(self.save_dir))

        sample_text = "The cat sat on the mat and looked at the bird."
        token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
        input_tensor = torch.tensor([token_ids], dtype=torch.long)

        # Per-token strings used as axis labels; escape newlines so labels
        # never break the plot layout.
        tokens_str = []
        for tid in token_ids:
            decoded = self.tokenizer.decode([tid])
            tokens_str.append(decoded.replace("\n", "\\n"))

        # Extract layer-0 attention weights (may be None on failure).
        attn_weights = viz.extract_attention(
            self.model, input_tensor, layer_idx=0, device=self.device
        )

        if attn_weights is not None:
            viz.plot_attention_heatmap(
                attn_weights, tokens_str, head_idx=0,
                title="Layer 0 Attention"
            )
            viz.plot_multi_head_summary(attn_weights)

    @staticmethod
    def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
        """Average each generation metric across all prompts.

        Returns an empty dict when there are no results or no per-result
        metrics. Keys are taken from the first result; missing keys in
        later results count as 0.
        """
        if not gen_results:
            return {}

        all_metrics = [r["metrics"] for r in gen_results if r.get("metrics")]
        if not all_metrics:
            return {}

        keys = all_metrics[0].keys()
        return {
            k: round(sum(m.get(k, 0) for m in all_metrics) / len(all_metrics), 3)
            for k in keys
        }

    def _print_summary(self, report: Dict[str, Any]):
        """Print a human-readable summary of the report to stdout."""
        print("\n" + "=" * 70)
        print("๐ ํ๊ฐ ์์ฝ ๋ฆฌํฌํธ")
        print("=" * 70)

        # Perplexity section with a coarse letter-grade style rating.
        if "perplexity" in report:
            ppl = report["perplexity"]
            print(f"\n ๐ฏ Perplexity:")
            print(f" Loss: {ppl['loss']:.4f}")
            print(f" PPL: {ppl['perplexity']:.2f}")

            # NOTE(review): the 20/35/60 PPL cutoffs appear to be project
            # heuristics; no source for them is visible here.
            ppl_val = ppl["perplexity"]
            if ppl_val < 20:
                grade = "๐ ์ฐ์ (Strong)"
            elif ppl_val < 35:
                grade = "โ ์ํธ (Good)"
            elif ppl_val < 60:
                grade = "โ ๏ธ ๋ณดํต (Fair)"
            else:
                grade = "โ ๋ฏธํก (ํ์ต ์ถ๊ฐ ํ์)"
            print(f" ๋ฑ๊ธ: {grade}")

        # Per-position loss: early minus late = gain from longer context.
        if "position_losses" in report:
            pl = report["position_losses"]
            print(f"\n ๐ ์์น๋ณ Loss:")
            print(f" ์ด๋ฐ (0-50): {pl['early_avg']:.4f}")
            print(f" ํ๋ฐ (-200): {pl['late_avg']:.4f}")
            print(f" ์ปจํ์คํธ ํจ๊ณผ: {pl['early_avg'] - pl['late_avg']:.4f} ๊ฐ์")

        # Generation quality metrics averaged over prompts.
        if "generation" in report and report["generation"].get("avg_metrics"):
            gm = report["generation"]["avg_metrics"]
            print(f"\n โ๏ธ ์์ฑ ํ์ง:")
            print(f" ํ๊ท ๊ธธ์ด: {gm.get('avg_length', 0):.0f} ์")
            print(f" ๋ฐ๋ณต๋ฅ : {gm.get('repetition_rate', 0):.1%}")
            print(f" ์ดํ ๋ค์์ฑ: {gm.get('lexical_diversity', 0):.3f}")

        # Training dynamics (only present when phase 3 ran).
        if "training_dynamics" in report:
            td = report["training_dynamics"]
            if "loss" in td:
                print(f"\n ๐ ํ์ต ์ญํ:")
                print(f" Loss ๊ฐ์: {td['loss']['initial']:.4f} โ {td['loss']['final']:.4f}")
                print(f" ์คํ์ดํฌ: {len(td['loss']['spikes'])}ํ")

        # List every artifact written to save_dir with its size.
        print(f"\n ๐ ๊ฒฐ๊ณผ ํ์ผ:")
        for f in sorted(self.save_dir.glob("*")):
            size = f.stat().st_size / 1024
            print(f" {f.name} ({size:.1f} KB)")

        print("\n" + "=" * 70)
|
| 1174 |
+
|
| 1175 |
+
|
| 1176 |
+
# ============================================================================
|
| 1177 |
+
# 8. ํ์ต ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ ๊ฒ์ฆ๊ธฐ
|
| 1178 |
+
# ============================================================================
|
| 1179 |
+
|
| 1180 |
+
class InsightChecklist:
    """Verify the PRD's learning-insight checklist, automatically where possible.

    Items that can be judged from metrics are passed/failed automatically;
    the remaining items are printed as manual self-check questions.
    """

    @staticmethod
    def run_checklist(
        report: Dict[str, Any],
        metrics_history: Optional[Dict[str, list]] = None,
    ):
        """Run the checklist against an evaluation report.

        Args:
            report: the dict produced by FullEvaluator.run_full_evaluation().
            metrics_history: optional training log; enables the gradient-clip check.

        Returns:
            Dict with "passed", "failed" (auto checks) and "manual" item lists.
        """
        print("\n" + "=" * 70)
        print("โ ํ์ต ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ")
        print("=" * 70)

        checks = {
            "passed": [],
            "failed": [],
            "manual": [],
        }

        # ── Automatic checks ──

        # 1. Loss convergence. The default 99 makes a missing perplexity
        #    entry count as a failure rather than a crash.
        if report.get("perplexity", {}).get("loss", 99) < 4.0:
            checks["passed"].append("๋ชจ๋ธ Loss๊ฐ 4.0 ์ดํ๋ก ์๋ ด")
        else:
            checks["failed"].append("๋ชจ๋ธ Loss๊ฐ 4.0 ์ดํ๋ก ๋ฏธ์๋ ด")

        # 2. Loss spikes: fewer than 5 recorded spikes counts as stable.
        spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
        if len(spikes) < 5:
            checks["passed"].append(f"Loss ์คํ์ดํฌ {len(spikes)}ํ (< 5ํ)")
        else:
            checks["failed"].append(f"Loss ์คํ์ดํฌ {len(spikes)}ํ (โฅ 5ํ, ์์ ์ฑ ๊ฐ์ ํ์)")

        # 3. Per-position loss pattern: later positions should be easier
        #    than early ones if the model uses its context.
        if report.get("position_losses"):
            early = report["position_losses"]["early_avg"]
            late = report["position_losses"]["late_avg"]
            if early > late:
                checks["passed"].append("์์น๋ณ Loss ๊ฐ์ ํจํด ํ์ธ (์ปจํ์คํธ ํ์ฉ)")
            else:
                checks["failed"].append("์์น๋ณ Loss ํจํด ์ด์ (์ปจํ์คํธ ๋ฏธํ์ฉ?)")

        # 4. Generation repetition rate. Default 1.0 fails the check when
        #    the metric is missing.
        rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
        if rep < 0.3:
            checks["passed"].append(f"์์ฑ ๋ฐ๋ณต๋ฅ  {rep:.1%} (< 30%)")
        else:
            checks["failed"].append(f"์์ฑ ๋ฐ๋ณต๋ฅ  {rep:.1%} (โฅ 30%, temperature/top_p ์กฐ์ )")

        # 5. Gradient clipping rate: fraction of steps whose grad norm sat at
        #    the clip ceiling (>= 0.99 assumes max_norm=1.0 — TODO confirm
        #    against the trainer config).
        if metrics_history and metrics_history.get("grad_norm"):
            gnorms = metrics_history["grad_norm"]
            clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
            if clip_rate < 0.3:
                checks["passed"].append(f"Gradient ํด๋ฆฌํ ๋น์จ {clip_rate:.1%} (๊ฑด๊ฐ)")
            else:
                checks["failed"].append(f"Gradient ํด๋ฆฌํ ๋น์จ {clip_rate:.1%} (๋๋ฌด ์ฆ์)")

        # ── Manual self-check questions (cannot be judged from metrics) ──
        manual_items = [
            "Self-Attention์์ Q, K, V ๊ฐ๊ฐ์ ์ญํ ์ ์ค๋ชํ  ์ ์๋๊ฐ?",
            "RoPE๊ฐ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉํ๋ ์ํ์ ์๋ฆฌ๋ฅผ ์ดํดํ๋๊ฐ?",
            "GQA๊ฐ MHA ๋๋น ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ๋ ๋ฉ์ปค๋์ฆ์ ์ค๋ชํ  ์ ์๋๊ฐ?",
            "SwiGLU์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ์ด ReLU FFN๊ณผ ์ด๋ป๊ฒ ๋ค๋ฅธ์ง ์ดํดํ๋๊ฐ?",
            "Learning Rate Warmup์ด ์ ํ์ํ์ง ์ฒด๊ฐํ๋๊ฐ?",
            "Gradient Accumulation์ด ํฐ ๋ฐฐ์น๋ฅผ ์๋ฎฌ๋ ์ด์ํ๋ ์๋ฆฌ๋ฅผ ์ดํดํ๋๊ฐ?",
            "Mixed Precision(bf16)์ ๋ฉ๋ชจ๋ฆฌ-์๋ ํจ๊ณผ๋ฅผ ์ธก์ ํ๋๊ฐ?",
            "Activation Checkpointing์ ๋ฉ๋ชจ๋ฆฌ-์ฐ์ฐ ํธ๋ ์ด๋์คํ๋ฅผ ์ดํดํ๋๊ฐ?",
        ]
        checks["manual"] = manual_items

        # ── Print results ──
        total_auto = len(checks["passed"]) + len(checks["failed"])
        passed_auto = len(checks["passed"])

        print(f"\n ์๋ ๊ฒ์ฆ: {passed_auto}/{total_auto} ํต๊ณผ")
        for item in checks["passed"]:
            print(f" โ {item}")
        for item in checks["failed"]:
            print(f" โ {item}")

        print(f"\n ์๋ ํ์ธ ({len(manual_items)} ํญ๋ชฉ):")
        for i, item in enumerate(manual_items, 1):
            print(f" {i}. [ ] {item}")

        print(f"\n ์ด ์งํ๋ฅ : {passed_auto}/{total_auto + len(manual_items)} "
              f"(์๋ ํญ๋ชฉ ํฌํจ ์)")

        return checks
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
# ============================================================================
|
| 1277 |
+
# 9. Quick Start
|
| 1278 |
+
# ============================================================================
|
| 1279 |
+
|
| 1280 |
+
def run_evaluation(
    model: nn.Module,
    tokenizer: Any,
    val_dataloader: DataLoader,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
    metrics_history: Optional[Dict[str, list]] = None,
    config: Optional[EvalConfig] = None,
) -> Dict[str, Any]:
    """One-call entry point: run the full evaluation plus the insight checklist.

    When ``device`` is None, CUDA is used if available, otherwise CPU.

    Usage (Colab):
    ```python
    from evaluation import run_evaluation

    # after training completes
    report = run_evaluation(
        model=trainer.model,
        tokenizer=tokenizer,
        val_dataloader=val_dl,
        metrics_history=trainer.metrics.history,
    )
    ```

    Returns:
        The report dict produced by FullEvaluator.run_full_evaluation().
    """
    # Resolve the target device automatically when the caller omitted it.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    report = FullEvaluator(
        model=model,
        tokenizer=tokenizer,
        val_dataloader=val_dataloader,
        device=device,
        config=config,
        dtype=dtype,
        metrics_history=metrics_history,
    ).run_full_evaluation()

    # Follow up with the automatic/manual insight checklist.
    InsightChecklist.run_checklist(report, metrics_history)

    return report
|
| 1323 |
+
|
| 1324 |
+
|
| 1325 |
+
# ============================================================================
|
| 1326 |
+
# 10. ๊ฒ์ฆ ์คํฌ๋ฆฝํธ
|
| 1327 |
+
# ============================================================================
|
| 1328 |
+
|
| 1329 |
+
if __name__ == "__main__":
    # Self-verification script: exercises every evaluator in this module
    # with tiny dummy components so the structure can be checked on CPU
    # in seconds, without a trained model or real tokenizer.
    print("=" * 70)
    print("LLM-1B-Lab: ํ๊ฐ ๋ชจ๋ ๊ฒ์ฆ")
    print("=" * 70)

    # ── Structural check with a dummy model ──
    class TinyModel(nn.Module):
        # Minimal causal-LM stand-in: embedding + tied linear head.
        def __init__(self, vocab=100, dim=64):
            super().__init__()
            self.emb = nn.Embedding(vocab, dim)
            self.linear = nn.Linear(dim, vocab)
            # Weight tying: output head shares the embedding matrix.
            self.linear.weight = self.emb.weight
            self.layers = nn.ModuleList()  # attention-visualization compatibility

        def forward(self, input_ids, targets=None):
            # Returns (logits, loss) like the real model; loss is None
            # when no targets are supplied.
            h = self.emb(input_ids)
            logits = self.linear(h)
            loss = None
            if targets is not None:
                # NOTE(review): 100 here hard-codes the default vocab size
                # instead of reading it from the layer — fine for this
                # smoke test only.
                loss = F.cross_entropy(logits.view(-1, 100), targets.view(-1))
            return logits, loss

        def generate(self, input_ids, max_new_tokens=20, temperature=1.0, top_k=50, top_p=0.9):
            # Plain multinomial sampling over a 64-token sliding window;
            # top_k/top_p are accepted for interface parity but unused.
            generated = input_ids
            for _ in range(max_new_tokens):
                logits, _ = self(generated[:, -64:])
                next_logits = logits[:, -1, :] / temperature
                probs = F.softmax(next_logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
                generated = torch.cat([generated, nxt], dim=1)
            return generated

    model = TinyModel()
    device = torch.device("cpu")

    # Dummy tokenizer: maps characters to ids (capped at 99) and back.
    class DummyTok:
        eos_id = 2
        vocab_size = 100
        def encode(self, t, add_special_tokens=False):
            return [min(ord(c), 99) for c in t]
        def decode(self, ids):
            # Clamp to printable ASCII; ids <= 2 are treated as specials.
            return "".join(chr(max(min(i, 122), 32)) for i in ids if i > 2)

    tok = DummyTok()

    # Dummy validation data: random token sequences, targets shifted by one.
    val_data = []
    for _ in range(30):
        ids = torch.randint(3, 100, (65,))
        val_data.append({"input_ids": ids[:64], "targets": ids[1:65]})

    def collate(batch):
        # Stack the per-sample dicts into batched tensors.
        return {
            "input_ids": torch.stack([b["input_ids"] for b in batch]),
            "targets": torch.stack([b["targets"] for b in batch]),
        }

    val_dl = DataLoader(val_data, batch_size=4, collate_fn=collate)

    # ── 1. Perplexity test ──
    print("\n[ํ์คํธ 1] Perplexity ์ธก์ ")
    ppl_eval = PerplexityEvaluator(EvalConfig(max_eval_batches=5))
    result = ppl_eval.evaluate(model, val_dl, device, torch.float32, desc="Test Eval")
    print(f" โ Loss={result['loss']:.4f}, PPL={result['perplexity']:.2f}")
    # An untrained model over uniform-random tokens should sit near
    # PPL == vocab_size.
    expected_ppl = math.exp(math.log(100))  # vocab=100 → initial PPL ≈ 100
    print(f" โ ์์ ์ด๊ธฐ PPL โ {expected_ppl:.0f} (vocab=100 ๋๋ค)")

    # ── 2. Generation test ──
    print("\n[ํ์คํธ 2] ํ์คํธ ์์ฑ")
    gen_eval = GenerationEvaluator(EvalConfig(max_new_tokens=30, num_samples=1))
    gen_results = gen_eval.generate_samples(
        model, tok, device, prompts=["Hello world"], verbose=True
    )

    # ── 3. Scaling-law analysis test (fabricated loss/param points) ──
    print("\n[ํ์คํธ 3] Scaling Law ๋ถ์")
    analyzer = ScalingAnalyzer("./test_eval")
    dummy_scaling = [
        {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
        {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
        {"name": "1B", "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
    ]
    scaling_result = analyzer.analyze(dummy_scaling)

    # ── 4. Training-dynamics analysis test (synthetic training log) ──
    print("\n[ํ์คํธ 4] ํ์ต ์ญํ ๋ถ์")
    import random
    random.seed(42)

    dummy_history = {
        "step": list(range(0, 1000, 10)),
        # Exponentially decaying loss with Gaussian noise.
        "train_loss": [10.0 * (0.995 ** i) + random.gauss(0, 0.1) for i in range(100)],
        # Linear warmup for 20 steps, then cosine decay over the rest.
        "learning_rate": [min(3e-4 * i / 20, 3e-4) * (0.5 + 0.5 * math.cos(math.pi * max(0, i-20)/80))
                          for i in range(100)],
        # Grad norms capped at 1.0 to mimic clipping.
        "grad_norm": [min(random.gauss(0.5, 0.3), 1.0) for _ in range(100)],
        "tokens_per_sec": [50000 + random.gauss(0, 3000) for _ in range(100)],
        "val_loss": [8.0, 6.0, 4.5, 3.8, 3.5],
        "val_ppl": [2981, 403, 90, 44, 33],
    }

    dynamics = TrainingDynamicsAnalyzer("./test_eval")
    dynamics.analyze_metrics(dummy_history)

    # ── 5. Insight-checklist test (hand-built report that should pass) ──
    print("\n[ํ์คํธ 5] ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ")
    dummy_report = {
        "perplexity": {"loss": 3.5, "perplexity": 33.1},
        "position_losses": {"early_avg": 4.5, "late_avg": 3.2},
        "generation": {"avg_metrics": {"repetition_rate": 0.15}},
        "training_dynamics": {"loss": {"initial": 10.0, "final": 3.5, "spikes": []}},
    }
    InsightChecklist.run_checklist(dummy_report, dummy_history)

    # Clean up the scratch output directory created by the tests above.
    import shutil
    if os.path.exists("./test_eval"):
        shutil.rmtree("./test_eval")

    print("\n" + "=" * 70)
    print("โ ํ๊ฐ ๋ชจ๋ ๊ฒ์ฆ ์๋ฃ!")
    print()
    print("์ค์ ์ฌ์ฉ๋ฒ:")
    print(" from evaluation import run_evaluation")
    print(" report = run_evaluation(model, tokenizer, val_dl,")
    print(" metrics_history=trainer.metrics.history)")
    print("=" * 70)
|
_archive/llm-1b-model.py
ADDED
|
@@ -0,0 +1,791 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
|
| 3 |
+
================================================================
|
| 4 |
+
๋ฅ๋ฌ๋ ์ด๋ณด์๋ฅผ ์ํ ํ์ต์ฉ ๊ตฌํ.
|
| 5 |
+
๊ฐ ์ปดํฌ๋ํธ์ ์์ธ ์ฃผ์์ ๋ฌ์ "์ ์ด๋ ๊ฒ ํ๋์ง"๋ฅผ ์ค๋ช
ํฉ๋๋ค.
|
| 6 |
+
|
| 7 |
+
์ํคํ
์ฒ ์์ฝ:
|
| 8 |
+
- Decoder-Only Transformer (Causal LM)
|
| 9 |
+
- RMSNorm (Pre-Normalization)
|
| 10 |
+
- Rotary Positional Embedding (RoPE)
|
| 11 |
+
- Grouped Query Attention (GQA)
|
| 12 |
+
- SwiGLU Feed-Forward Network
|
| 13 |
+
- Weight Tying (Embedding โ Output Head)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ============================================================================
|
| 26 |
+
# 1. ๋ชจ๋ธ ์ค์ (Config)
|
| 27 |
+
# ============================================================================
|
| 28 |
+
|
| 29 |
+
@dataclass
class ModelConfig:
    """Model hyperparameters managed in a single dataclass.

    Size presets:
    - debug: ~10M parameters (pipeline verification)
    - small: ~100M parameters (intermediate validation)
    - base: ~1.1B parameters (final target)

    Raises:
        ValueError: on construction, if the head configuration is
            inconsistent (see ``__post_init__``).
    """
    vocab_size: int = 32_000
    hidden_dim: int = 2048        # d_model: base width of the model
    num_layers: int = 22          # number of Transformer blocks
    num_heads: int = 16           # number of query heads
    num_kv_heads: int = 4         # number of key/value heads (GQA)
    intermediate_dim: int = 5632  # FFN inner dim (≈ 2.75 × hidden_dim)
    max_seq_len: int = 2048       # maximum sequence length
    dropout: float = 0.0          # pretraining typically uses 0
    rope_theta: float = 10000.0   # RoPE frequency base
    norm_eps: float = 1e-6        # RMSNorm epsilon

    def __post_init__(self) -> None:
        """Fail fast on inconsistent head configuration.

        Without this check, ``head_dim`` silently truncates and GQA
        grouping silently misallocates heads, producing shape errors
        (or worse, wrong math) deep inside the model.

        Raises:
            ValueError: if ``hidden_dim`` is not divisible by ``num_heads``,
                or ``num_heads`` is not divisible by ``num_kv_heads``.
        """
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )

    @property
    def head_dim(self) -> int:
        """Dimension of each attention head."""
        return self.hidden_dim // self.num_heads

    @property
    def num_kv_groups(self) -> int:
        """Number of query heads served by one KV head in GQA."""
        return self.num_heads // self.num_kv_heads

    @classmethod
    def debug_10m(cls) -> "ModelConfig":
        """~10M parameters — for fast debugging."""
        return cls(
            hidden_dim=256, num_layers=6, num_heads=8,
            num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
        )

    @classmethod
    def small_100m(cls) -> "ModelConfig":
        """~100M parameters — for intermediate validation."""
        return cls(
            hidden_dim=768, num_layers=12, num_heads=12,
            num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
        )

    @classmethod
    def base_1b(cls) -> "ModelConfig":
        """~1.1B parameters — the final training target."""
        return cls()  # defaults ARE the 1B setting
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ============================================================================
|
| 82 |
+
# 2. RMSNorm (Root Mean Square Layer Normalization)
|
| 83 |
+
# ============================================================================
|
| 84 |
+
|
| 85 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization — LayerNorm without mean-centering.

    Differences from standard LayerNorm:
    - the mean is never subtracted (fewer ops)
    - the activation is divided by its RMS rather than its std-dev
    - there is no bias parameter, only a learnable scale

    Formula:
        RMSNorm(x) = (x / RMS(x)) * gamma
        RMS(x)     = sqrt(mean(x^2) + eps)

    Why normalize at all? Stacking many layers lets activation scales
    explode or vanish; normalization keeps each layer's input in a
    stable range.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma: learnable per-channel scale, starts at 1 (identity).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Do the reduction in float32: squaring in bf16/fp16 risks
        # overflow and loses precision in the mean.
        x32 = x.to(torch.float32)

        # 1 / RMS via rsqrt — a single fused op instead of sqrt + divide.
        mean_square = torch.mean(x32 * x32, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)

        # Normalize, drop back to the caller's dtype, then apply gamma.
        normalized = (x32 * inv_rms).to(x.dtype)
        return normalized * self.weight
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ============================================================================
|
| 122 |
+
# 3. Rotary Positional Embedding (RoPE)
|
| 123 |
+
# ============================================================================
|
| 124 |
+
|
| 125 |
+
class RotaryPositionalEmbedding(nn.Module):
    """RoPE: relative position encoding via rotation matrices.

    Core idea:
    - Treat each dimension pair (2i, 2i+1) as coordinates in a 2D plane
      and rotate them by an angle proportional to the token's position.
    - The attention score (Q.K) between two tokens then depends only on
      their relative distance.

    Why RoPE?
    - Absolute position embeddings: a fixed vector added per position ->
      poor length generalization.
    - Learned relative embeddings: more complex, extra parameters.
    - RoPE: parameter-free, encodes relative position naturally.

    Formula:
        theta_i = theta^(-2i/d)   (i = 0, 1, ..., d/2-1)
        RoPE(x, pos) = rotate x by pos * theta_i in each dimension pair
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Frequency vector, precomputed once (not learned -> plain buffer).
        # freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs, persistent=False)

        # Precompute (max_seq_len, dim/2) cos/sin tables.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """Precompute and cache cos/sin values for seq_len positions."""
        t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
        # Outer product: (seq_len,) x (dim/2,) -> (seq_len, dim/2)
        angles = torch.outer(t, self.freqs)
        # NOTE: re-registering in forward() replaces the buffer in place;
        # this is how on-the-fly cache extension works below.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary transform to Q and K.

        Args:
            q: (batch, num_heads, seq_len, head_dim)
            k: (batch, num_kv_heads, seq_len, head_dim)
            position_offset: starting position in the sequence
                (used at inference time, e.g. with a KV cache)

        Returns:
            (q_rotated, k_rotated) with the rotation applied.
        """
        seq_len = q.shape[2]

        # Grow the cache on demand if the request runs past it.
        if position_offset + seq_len > self.cos_cached.shape[0]:
            self._build_cache(position_offset + seq_len)

        # cos/sin slices for the current positions: (seq_len, dim/2)
        cos = self.cos_cached[position_offset : position_offset + seq_len]
        sin = self.sin_cached[position_offset : position_offset + seq_len]

        q_rotated = self._apply_rotation(q, cos, sin)
        k_rotated = self._apply_rotation(k, cos, sin)
        return q_rotated, k_rotated

    @staticmethod
    def _apply_rotation(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply the 2D rotation to every (even, odd) dimension pair.

        2D rotation matrix:
            [cos t, -sin t] [x1]   [x1*cos t - x2*sin t]
            [sin t,  cos t] [x2] = [x1*sin t + x2*cos t]

        implemented with vectorized tensor ops.
        """
        # x: (batch, heads, seq_len, head_dim)
        # Split even/odd indices: (x0, x1, x2, ...) -> (x0, x2, ...), (x1, x3, ...)
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        # Align for broadcasting: (seq_len, dim/2) -> (1, 1, seq_len, dim/2)
        cos = cos.unsqueeze(0).unsqueeze(0)
        sin = sin.unsqueeze(0).unsqueeze(0)

        # Apply the rotation.
        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos

        # Re-interleave: (even0, odd0, even1, odd1, ...)
        out = torch.stack([rotated_even, rotated_odd], dim=-1)
        # FIX: cast back to the input dtype. cos/sin are float32 buffers,
        # so the arithmetic above promotes bf16/fp16 q/k to float32; without
        # this cast, downstream scaled_dot_product_attention would receive
        # q/k in float32 but v in half precision and raise a dtype mismatch.
        # (No-op for float32 inputs.)
        return out.flatten(-2).type_as(x)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ============================================================================
|
| 223 |
+
# 4. Grouped Query Attention (GQA)
|
| 224 |
+
# ============================================================================
|
| 225 |
+
|
| 226 |
+
class GroupedQueryAttention(nn.Module):
    """GQA: a memory-efficient variant of Multi-Head Attention.

    MHA vs GQA vs MQA:
    - MHA (Multi-Head Attention): num_heads each of Q, K and V -> large KV memory
    - MQA (Multi-Query Attention): K, V share a single head -> possible quality loss
    - GQA (Grouped Query Attention): K, V grouped into num_kv_heads
      -> sits between MHA and MQA; good quality/efficiency trade-off

    Example (num_heads=16, num_kv_heads=4):
        Q heads:    [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
        K/V groups: [   0   ,    1   ,     2    ,      3     ]
        -> each group of 4 Q heads shares 1 K/V head

    Attention formula:
        Attention(Q, K, V) = softmax(Q.K^T / sqrt(d_k)) . V
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.num_kv_groups = config.num_kv_groups  # num_heads // num_kv_heads

        # Q/K/V projections.
        # Q: hidden_dim -> num_heads x head_dim
        self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
        # K, V: hidden_dim -> num_kv_heads x head_dim (smaller than Q!)
        self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)

        # Output projection: concatenated head outputs back to hidden_dim.
        self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)

        # RoPE
        self.rope = RotaryPositionalEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
        )

        # Attention dropout (usually 0 for pretraining).
        # NOTE(review): this module is never invoked in forward() --
        # config.dropout is passed straight to scaled_dot_product_attention
        # instead. Dead attribute; kept for interface compatibility.
        self.attn_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, hidden_dim)
            mask: (seq_len, seq_len) causal mask
            position_offset: position offset (used at inference time)

        Returns:
            (batch_size, seq_len, hidden_dim)
        """
        B, S, _ = x.shape

        # ----------------------------------------------
        # Step 1: Q, K, V projections
        # ----------------------------------------------
        q = self.q_proj(x)  # (B, S, num_heads x head_dim)
        k = self.k_proj(x)  # (B, S, num_kv_heads x head_dim)
        v = self.v_proj(x)  # (B, S, num_kv_heads x head_dim)

        # Reshape into multi-head layout.
        q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        # -> (B, num_heads, S, head_dim)
        k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # -> (B, num_kv_heads, S, head_dim)
        v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # ----------------------------------------------
        # Step 2: apply RoPE (to Q and K only -- never V!)
        # ----------------------------------------------
        # Position should influence only *where* to attend (Q.K),
        # not *what* is retrieved (V).
        q, k = self.rope(q, k, position_offset)

        # ----------------------------------------------
        # Step 3: GQA -- expand KV heads by repetition
        # ----------------------------------------------
        # e.g. num_kv_heads=4 -> num_heads=16: repeat each KV 4 times.
        if self.num_kv_groups > 1:
            k = self._repeat_kv(k)  # (B, num_heads, S, head_dim)
            v = self._repeat_kv(v)

        # ----------------------------------------------
        # Step 4: scaled dot-product attention
        # ----------------------------------------------
        # PyTorch >= 2.0 optimized kernel (Flash Attention when applicable).
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            dropout_p=self.config.dropout if self.training else 0.0,
            is_causal=(mask is None),  # automatic causal masking when no mask given
        )
        # -> (B, num_heads, S, head_dim)

        # ----------------------------------------------
        # Step 5: merge heads + output projection
        # ----------------------------------------------
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
        # -> (B, S, num_heads x head_dim)

        return self.o_proj(attn_out)  # -> (B, S, hidden_dim)

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads to match the number of Q heads.

        (B, num_kv_heads, S, head_dim) -> (B, num_heads, S, head_dim)

        e.g. num_kv_heads=4, num_kv_groups=4:
            [kv0, kv1, kv2, kv3] -> [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
        """
        B, H_kv, S, D = x.shape
        x = x[:, :, None, :, :]  # (B, H_kv, 1, S, D)
        x = x.expand(B, H_kv, self.num_kv_groups, S, D)  # (B, H_kv, groups, S, D)
        return x.reshape(B, self.num_heads, S, D)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# ============================================================================
|
| 351 |
+
# 5. SwiGLU Feed-Forward Network
|
| 352 |
+
# ============================================================================
|
| 353 |
+
|
| 354 |
+
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward network: a gated linear unit with Swish.

    Classic FFN:
        FFN(x) = ReLU(x.W1 + b1).W2 + b2      (plain nonlinearity)

    SwiGLU FFN:
        SwiGLU(x) = (Swish(x.W_gate) * (x.W_up)) . W_down
        (a gating mechanism controls the information flow)

    Why SwiGLU works better:
    - Swish(x) = x * sigmoid(x): smooth activation that passes part of
      the negative range.
    - The gate vector learns *which* features to let through.
    - PaLM, LLaMA and others report consistent gains over ReLU FFNs.

    Note: carrying both W_gate and W_up makes the parameter count 1.5x
    a classic FFN, so intermediate_dim is adjusted to keep the total
    parameter budget matched.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Gate projection: hidden_dim -> intermediate_dim
        self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Up projection: hidden_dim -> intermediate_dim
        self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Down projection: intermediate_dim -> hidden_dim
        self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU(x) = (Swish(gate(x)) * up(x)) . down
        # silu == Swish == x * sigmoid(x): the gate decides what passes,
        # the up projection carries the content, down maps back to the
        # model dimension.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# ============================================================================
# 6. Transformer Block (a single decoder layer)
# ============================================================================
|
| 398 |
+
|
| 399 |
+
class TransformerBlock(nn.Module):
    """One Transformer decoder block (pre-norm layout).

    Structure:
        x -> RMSNorm -> Attention -> +residual -> RMSNorm -> FFN -> +residual -> out

    Pre-norm vs post-norm:
    - Post-norm (original Transformer): normalization after the residual
      -> unstable training in deep stacks.
    - Pre-norm (standard since GPT-2): normalization before each sublayer
      -> smooth gradient flow, stable training.

    Role of the residual connections:
    - Adding the input to the output gives gradients a "highway" past
      every layer -- the key reason a 22-layer stack remains trainable.
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-norm before the attention sublayer.
        self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        # Self-attention.
        self.attention = GroupedQueryAttention(config)

        # Pre-norm before the feed-forward sublayer.
        self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        # Feed-forward network.
        self.feed_forward = SwiGLUFeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """Run both sublayers with residual connections.

        Args:
            x: (batch_size, seq_len, hidden_dim)
        Returns:
            (batch_size, seq_len, hidden_dim)
        """
        # Attention sublayer: x = x + Attention(RMSNorm(x))
        attn_out = self.attention(self.attn_norm(x), mask, position_offset)
        x = x + attn_out

        # FFN sublayer: x = x + FFN(RMSNorm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
# ============================================================================
|
| 454 |
+
# 7. Full Transformer Model (LLaMA-style)
|
| 455 |
+
# ============================================================================
|
| 456 |
+
|
| 457 |
+
class LLMModel(nn.Module):
    """1B-parameter LLaMA-style decoder-only Transformer.

    Overall structure:
        Input Token IDs
        -> Token Embedding
        -> [TransformerBlock] x num_layers (+ activation checkpointing)
        -> RMSNorm (final)
        -> Linear Head (-> vocab logits)

    Weight tying:
    - The input embedding and the output head share one weight matrix.
    - Saves parameters (~65M here) while matching or improving quality.
    - Intuition: "word meaning" and "word prediction" use the same space.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # -- Token embedding --
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)

        # -- Transformer blocks --
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # -- Final normalization --
        self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)

        # -- Output head (weight tying) --
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        # Weight tying: lm_head's weight is the embedding's weight tensor.
        self.lm_head.weight = self.token_embedding.weight

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        """GPT-2-style weight initialization.

        Why initialization matters:
        - Too large: activations explode -> NaN.
        - Too small: gradients vanish -> training stalls.
        - A good init keeps each layer's output variance roughly constant.

        Scheme:
        - Regular Linear: N(0, 0.02)
        - Residual projections: N(0, 0.02 / sqrt(2 x num_layers))
          -> shrinks each residual contribution as depth grows, for stability.
        """
        std = 0.02
        residual_std = std / math.sqrt(2 * self.config.num_layers)

        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=std)

        # Residual-projection layers get the shrunken init (overwrites
        # the generic Linear init applied above).
        for layer in self.layers:
            nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
            nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            input_ids: (batch_size, seq_len) - token IDs
            targets: (batch_size, seq_len) - target token IDs (during training)
            position_offset: position offset (during inference)

        Returns:
            logits: (batch_size, seq_len, vocab_size)
            loss: scalar (when targets given) or None
        """
        B, S = input_ids.shape

        # -- Step 1: token embedding --
        # Map each token ID to a hidden_dim-dimensional vector.
        h = self.token_embedding(input_ids)  # (B, S, hidden_dim)

        # -- Step 2: transformer blocks --
        # Activation checkpointing saves memory during training:
        # intermediate activations are not stored and are recomputed
        # during the backward pass instead.
        for layer in self.layers:
            if self.training and torch.is_grad_enabled():
                # Apply activation checkpointing.
                h = torch.utils.checkpoint.checkpoint(
                    layer, h, None, position_offset,
                    use_reentrant=False,  # recommended for PyTorch >= 2.0
                )
            else:
                h = layer(h, mask=None, position_offset=position_offset)

        # -- Step 3: final normalization --
        h = self.final_norm(h)

        # -- Step 4: output logits --
        logits = self.lm_head(h)  # (B, S, vocab_size)

        # -- Step 5: loss (training only) --
        loss = None
        if targets is not None:
            # Cross-entropy for next-token prediction.
            # logits: (B, S, V) -> (B*S, V)
            # targets: (B, S) -> (B*S,)
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                ignore_index=-100,  # ignore padding tokens
            )

        return logits, loss

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Return the number of (trainable) model parameters."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> torch.Tensor:
        """Text generation (inference).

        Autoregressive generation: predict and append one token at a time.

        Args:
            input_ids: (1, prompt_len) - the initial prompt
            max_new_tokens: maximum number of tokens to generate
            temperature: sharpens/flattens the distribution (lower = more conservative)
            top_k: consider only the k most probable tokens
            top_p: consider only tokens within cumulative probability p (nucleus sampling)
        """
        # NOTE(review): switches the module to eval mode and never
        # restores the previous mode for the caller.
        self.eval()
        generated = input_ids

        for _ in range(max_new_tokens):
            # Crop the running sequence if it exceeds max_seq_len.
            ctx = generated[:, -self.config.max_seq_len:]

            # Forward pass
            logits, _ = self(ctx)
            # Only the last position's logits matter (next-token prediction).
            next_logits = logits[:, -1, :] / temperature

            # -- Top-K filtering --
            if top_k > 0:
                top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                min_top_k = top_k_values[:, -1].unsqueeze(-1)
                next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))

            # -- Top-P (nucleus) filtering --
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Drop tokens whose cumulative probability *excluding
                # themselves* already reaches top_p (the subtraction shifts
                # the cutoff so at least one token always survives).
                remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[remove_mask] = float("-inf")
                # Unsort back to vocabulary order (sorted_indices is a full
                # permutation, so every position gets written).
                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            # Sample from the filtered distribution.
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append the sampled token.
            generated = torch.cat([generated, next_token], dim=1)

        return generated
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
# ============================================================================
# 8. Utility functions
# ============================================================================
|
| 648 |
+
|
| 649 |
+
def count_parameters_detailed(model: LLMModel) -> dict:
    """Break down the model's parameter count by component.

    Returns a dict with per-component counts plus a grand total. The
    LM head contributes no additional parameters because its weight is
    tied to the token embedding.
    """
    breakdown = {}

    # Embedding table.
    emb_params = model.token_embedding.weight.numel()
    breakdown["token_embedding"] = emb_params

    # Per-layer breakdown, measured on the first block (all blocks are
    # identical in shape).
    first_layer = model.layers[0]
    layer_detail = {name: p.numel() for name, p in first_layer.named_parameters()}
    layer_total = sum(layer_detail.values())

    breakdown["per_layer"] = layer_detail
    breakdown["per_layer_total"] = layer_total
    breakdown["all_layers_total"] = layer_total * len(model.layers)

    # Final normalization.
    norm_params = model.final_norm.weight.numel()
    breakdown["final_norm"] = norm_params

    # LM head shares the embedding weight -> 0 extra parameters.
    breakdown["lm_head"] = "weight tying (0 additional)"
    breakdown["total"] = emb_params + layer_total * len(model.layers) + norm_params

    return breakdown
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
    """Estimate the model's GPU memory footprint.

    Args:
        config: model configuration to size.
        batch_size: training batch size used for the activation estimate.
        dtype_bytes: 2 for bf16/fp16, 4 for fp32.

    Returns:
        Dict of per-category estimates in GB plus the parameter count.
    """
    # Rough parameter count, component by component.
    emb = config.vocab_size * config.hidden_dim
    qkv = config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim
    out_proj = config.num_heads * config.head_dim * config.hidden_dim
    swiglu = 3 * config.hidden_dim * config.intermediate_dim  # gate + up + down
    norms = 2 * config.hidden_dim  # 2 x RMSNorm per block
    per_layer = qkv + out_proj + swiglu + norms
    total_params = emb + per_layer * config.num_layers + config.hidden_dim

    model_gb = total_params * dtype_bytes / 1e9
    # AdamW keeps two fp32 states (m and v) per parameter: 8 bytes each.
    optimizer_gb = total_params * 8 / 1e9
    gradient_gb = total_params * dtype_bytes / 1e9

    # Activation memory, assuming activation checkpointing is on.
    # Very rough heuristic: batch x seq x hidden x 4 bytes, scaled by
    # sqrt(num_layers) for the checkpointing effect -- an estimate only.
    activation_gb = (
        batch_size * config.max_seq_len * config.hidden_dim * 4  # bytes
        * math.sqrt(config.num_layers)  # checkpointing factor
        / 1e9
    )

    return {
        "total_parameters": total_params,
        "model_weights_gb": round(model_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "gradients_gb": round(gradient_gb, 2),
        "activations_estimated_gb": round(activation_gb, 2),
        "total_estimated_gb": round(model_gb + optimizer_gb + gradient_gb + activation_gb, 2),
    }
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
# ============================================================================
# 9. Verification script (run this file directly)
# ============================================================================
|
| 726 |
+
|
| 727 |
+
if __name__ == "__main__":
    print("=" * 70)
    print("LLM-1B-Lab: ๋ชจ๋ธ ์ํคํ์ฒ ๊ฒ์ฆ")
    print("=" * 70)

    # -- Debug model (10M) smoke test --
    print("\n[1] Debug Model (~10M params)")
    cfg_debug = ModelConfig.debug_10m()
    model_debug = LLMModel(cfg_debug)
    n_params = model_debug.count_parameters()
    print(f" ํ๋ผ๋ฏธํฐ ์: {n_params:,} ({n_params / 1e6:.1f}M)")

    # Forward-pass test with random token IDs.
    dummy_input = torch.randint(0, cfg_debug.vocab_size, (2, 64))
    dummy_target = torch.randint(0, cfg_debug.vocab_size, (2, 64))
    logits, loss = model_debug(dummy_input, dummy_target)
    print(f" Input shape: {dummy_input.shape}")
    print(f" Logits shape: {logits.shape}")
    print(f" Loss: {loss.item():.4f}")
    # A healthy initial loss is about ln(vocab_size) = ln(32000) ~ 10.37.
    expected_loss = math.log(cfg_debug.vocab_size)
    print(f" Expected initial loss โ ln({cfg_debug.vocab_size}) = {expected_loss:.2f}")

    # -- 1B model: check the parameter count only --
    print("\n[2] Base Model (~1B params) โ ํ๋ผ๋ฏธํฐ ์๋ง ํ์ธ")
    cfg_1b = ModelConfig.base_1b()

    # Build on the meta device: allocates no real memory, so the 1B
    # model's structure can be inspected even on a small machine.
    with torch.device("meta"):
        model_1b = LLMModel(cfg_1b)
    n_params_1b = model_1b.count_parameters()
    print(f" ํ๋ผ๋ฏธํฐ ์: {n_params_1b:,} ({n_params_1b / 1e6:.1f}M โ {n_params_1b / 1e9:.2f}B)")

    # Detailed parameter breakdown.
    print("\n[3] ํ๋ผ๋ฏธํฐ ์์ธ ๋ถํด (1B)")
    detail = count_parameters_detailed(model_1b)
    print(f" Token Embedding: {detail['token_embedding']:,}")
    print(f" Per Layer Total: {detail['per_layer_total']:,}")
    print(f" All Layers ({cfg_1b.num_layers}): {detail['all_layers_total']:,}")
    print(f" Final Norm: {detail['final_norm']:,}")
    print(f" LM Head: {detail['lm_head']}")
    print(f" โโโโโโโโโโโโโโโโโโโโโโโโ")
    print(f" TOTAL: {detail['total']:,}")

    # Memory estimate.
    print("\n[4] GPU ๋ฉ๋ชจ๋ฆฌ ์ถ์ (A100 40GB, bf16, batch_size=4)")
    mem = estimate_memory_gb(cfg_1b, batch_size=4, dtype_bytes=2)
    print(f" ๋ชจ๋ธ ๊ฐ์ค์น: {mem['model_weights_gb']} GB")
    print(f" ์ตํฐ๋ง์ด์ : {mem['optimizer_states_gb']} GB")
    print(f" ๊ธฐ์ธ๊ธฐ: {mem['gradients_gb']} GB")
    print(f" ํ์ฑํ (์ถ์ ): {mem['activations_estimated_gb']} GB")
    print(f" โโโโโโโโโโโโโโโโโโโโโโโโ")
    print(f" ์ด ์ถ์ : {mem['total_estimated_gb']} GB")

    # Text-generation smoke test (debug model, random weights).
    print("\n[5] ํ์คํธ ์์ฑ ํ์คํธ (10M debug model, ๋๋ค ๊ฐ์ค์น)")
    prompt = torch.randint(0, cfg_debug.vocab_size, (1, 10))
    generated = model_debug.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)
    print(f" Prompt length: {prompt.shape[1]}")
    print(f" Generated length: {generated.shape[1]}")
    print(f" Generated token IDs: {generated[0].tolist()}")

    print("\n" + "=" * 70)
    print("โ ๋ชจ๋ ๊ฒ์ฆ ํต๊ณผ!")
    print("=" * 70)
|
_archive/llm-1b-trainer.py
ADDED
|
@@ -0,0 +1,1108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
LLM-1B-Lab: Training Loop
========================================
Complete training pipeline including gradient accumulation, mixed
precision, LR scheduling, checkpoint save/restore, and wandb logging.

Overall flow:
    fetch batch
      -> forward (bf16 autocast)
      -> loss / accumulation_steps (micro-batch mean)
      -> backward (accumulate gradients)
      -> [every accumulation_steps]
          -> gradient clipping
          -> optimizer step
          -> LR scheduler step
          -> logging
      -> [every checkpoint_interval]
          -> save checkpoint (Google Drive)
      -> [every eval_interval]
          -> measure validation loss / perplexity

Installation:
    pip install wandb torch
"""

import os
import math
import time
import json
import shutil
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
|
| 40 |
+
# ============================================================================
# 1. Training configuration
# ============================================================================

@dataclass
class TrainConfig:
    """Training hyper-parameters plus infrastructure settings.

    Defaults are tuned for Colab Pro+ (A100 40GB). Each value carries a
    short note explaining why it was chosen.
    """

    # -- Optimization --
    # Peak LR; 3e-4 is standard for ~1B models. The GPT-3 paper suggests
    # per-size optima (125M->6e-4, 350M->3e-4, 1.3B->2e-4); start at 3e-4
    # for this 1.1B model and drop to 2e-4 if training destabilizes.
    learning_rate: float = 3e-4
    # Cosine-decay floor, conventionally ~10% of peak. Too low stalls the
    # late phase of training; too high hurts convergence.
    min_learning_rate: float = 3e-5
    # AdamW decoupled L2 regularization; 0.1 is the LLM standard. By
    # convention it is NOT applied to embeddings or biases.
    weight_decay: float = 0.1
    # Adam momentum coefficients. beta2=0.95 is more stable than 0.999 for
    # large-batch, long LLM runs (large beta2 adapts too slowly).
    beta1: float = 0.9
    beta2: float = 0.95
    adam_eps: float = 1e-8
    # Rescale gradients whose global norm exceeds 1.0 — guards against
    # gradient spikes from noisy data, especially early in training.
    grad_clip: float = 1.0

    # -- Scheduling --
    # Linear warmup 0 -> peak over the first 2000 steps (~10% of the run):
    # random initial weights should not receive full-size updates.
    warmup_steps: int = 2000
    # Total optimizer steps. 10B tokens / (128 batch x 2048 seq) ~ 38k raw
    # steps, ~20k effective steps once gradient accumulation is counted.
    total_steps: int = 20_000

    # -- Batching --
    # Per-forward batch size; 4 fits a 1B bf16 model safely on A100 40GB.
    micro_batch_size: int = 4
    # Accumulation count; effective batch = 4 x 32 = 128. Bigger batches
    # mean less noisy gradient estimates (LLMs typically use 128-512).
    # If OOM: raise this value and shrink micro_batch_size.
    gradient_accumulation_steps: int = 32

    # -- Mixed precision --
    # bfloat16: A100-native and numerically safer than fp16 (same exponent
    # width as fp32, so little overflow/underflow risk). Switch to
    # "float16" when falling back to T4/V100.
    dtype: str = "bfloat16"

    # -- Checkpointing --
    # Google Drive path: survives Colab session expiry.
    checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
    # Save every 500 steps (~30 min on A100) — balances I/O overhead
    # against work lost when the session dies.
    checkpoint_interval: int = 500
    # Rolling retention; one checkpoint is ~8-10GB, so 3 ~ 30GB.
    max_checkpoints: int = 3

    # -- Logging --
    log_interval: int = 10    # console + wandb every 10 steps
    eval_interval: int = 500  # validation loss every 500 steps
    eval_steps: int = 20      # eval batches: 20 x 4 x 2048 ~ 160K tokens

    # -- wandb --
    wandb_project: str = "llm-1b-lab"
    wandb_run_name: Optional[str] = None
    use_wandb: bool = True

    # -- Reproducibility --
    seed: int = 42

    @property
    def effective_batch_size(self) -> int:
        """Sequences averaged into one optimizer step."""
        return self.micro_batch_size * self.gradient_accumulation_steps

    @property
    def tokens_per_step(self) -> int:
        """Tokens processed per optimizer step (seq_len hard-coded at 2048
        here; the real max_seq_len lives in ModelConfig)."""
        return self.effective_batch_size * 2048

    @property
    def torch_dtype(self) -> torch.dtype:
        """Resolve the dtype string to the matching torch.dtype."""
        table = {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
        }
        return table[self.dtype]
|
| 154 |
+
# ============================================================================
# 2. Learning-rate scheduler (cosine with warmup)
# ============================================================================

class CosineWarmupScheduler:
    """Linear warmup followed by cosine annealing down to a floor LR.

    For the first ``warmup_steps`` the LR ramps linearly from 0 to the
    peak; afterwards it follows half a cosine wave from the peak down to
    ``min_lr``, and stays at the floor once ``total_steps`` is passed.

    Why cosine decay?
    - Step decay: abrupt LR drops destabilize the loss.
    - Linear decay: LR shrinks too fast late in training.
    - Cosine: smooth decrease that keeps a usable LR near the end —
      the choice of GPT-3, LLaMA, Chinchilla, and most LLM recipes.

    Implemented by hand rather than with torch's built-in schedulers
    (e.g. CosineAnnealingLR) because combining warmup, an LR floor, and
    checkpoint resume is simpler to reason about this way.
    """

    def __init__(self, config: "TrainConfig"):
        self.peak_lr = config.learning_rate
        self.min_lr = config.min_learning_rate
        self.warmup_steps = config.warmup_steps
        self.total_steps = config.total_steps

    def get_lr(self, step: int) -> float:
        """Return the learning rate for optimizer step ``step`` (0-indexed)."""
        # Warmup phase: linear ramp 0 -> peak_lr.
        if step < self.warmup_steps:
            return self.peak_lr * (step / self.warmup_steps)

        # Cosine phase: fraction of the decay window consumed so far,
        # clamped to [0, 1] so steps beyond total_steps sit at min_lr.
        window = self.total_steps - self.warmup_steps
        frac = (step - self.warmup_steps) / max(window, 1)
        frac = min(frac, 1.0)

        # lr = min + (peak - min) * 0.5 * (1 + cos(pi * frac))
        weight = 0.5 * (1.0 + math.cos(math.pi * frac))
        return self.min_lr + (self.peak_lr - self.min_lr) * weight

    def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
        """Write the scheduled LR into every param group and return it."""
        lr = self.get_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = lr
        return lr
|
| 219 |
+
# ============================================================================
# 3. Checkpoint management
# ============================================================================

class CheckpointManager:
    """Saves and restores full training state.

    Why checkpoints matter on Colab:
    - Session expiry (max ~24h) wipes all in-memory state.
    - Saving to Google Drive lets training continue across sessions.
    - Optimizer state must be saved too, or AdamW momentum is lost.

    Saved contents:
    - model_state_dict: model weights
    - optimizer_state_dict: optimizer state (m, v moments)
    - step: current training step
    - best_val_loss: best validation loss so far
    - config: training config (reproducibility)
    - rng_states: random-generator states (exact resumption)
    - metrics_history: training-metric log
    - wandb_run_id: wandb run id (logging continuity)
    """

    def __init__(self, config: TrainConfig):
        self.config = config
        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = config.max_checkpoints

    def save(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        step: int,
        best_val_loss: float,
        metrics_history: Dict[str, list],
        wandb_run_id: Optional[str] = None,
    ):
        """Save a full checkpoint under checkpoint_dir/step_XXXXXX."""
        # Zero-padded name so the lexicographic sort in _find_latest()
        # is also chronological.
        ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
        ckpt_path.mkdir(parents=True, exist_ok=True)

        print(f"\n๐พ ์ฒดํฌํฌ์ธํธ ์ ์ฅ: {ckpt_path}")
        start = time.time()

        # 1) Model weights (kept in their current dtype, e.g. bf16).
        torch.save(model.state_dict(), ckpt_path / "model.pt")

        # 2) Optimizer state (includes fp32 moments — this file is large).
        torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")

        # 3) Training metadata (small, human-readable JSON).
        meta = {
            "step": step,
            "best_val_loss": best_val_loss,
            "wandb_run_id": wandb_run_id,
            "config": self.config.__dict__,
        }
        with open(ckpt_path / "meta.json", "w") as f:
            json.dump(meta, f, indent=2)

        # 4) Metrics history.
        torch.save(metrics_history, ckpt_path / "metrics.pt")

        # 5) RNG states, for exact resumption.
        # NOTE(review): the "python" key actually holds torch's CPU RNG
        # state (torch.random), not the `random` module's — confirm intent.
        rng_states = {
            "python": torch.random.get_rng_state(),
            "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        }
        torch.save(rng_states, ckpt_path / "rng_states.pt")

        elapsed = time.time() - start
        ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
        print(f" ์ ์ฅ ์๋ฃ: {ckpt_size:.2f} GB, {elapsed:.1f}์ด")

        # Drop the oldest checkpoints (rolling retention).
        self._cleanup_old_checkpoints()

    def load_latest(
        self,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        device: torch.device = torch.device("cpu"),
    ) -> Optional[Dict[str, Any]]:
        """Load the most recent checkpoint, if any.

        Returns:
            {"step", "best_val_loss", "wandb_run_id", "metrics_history"},
            or None when no checkpoint exists.
        """
        ckpt_path = self._find_latest()
        if ckpt_path is None:
            print("[Checkpoint] ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ ์์. ์ฒ์๋ถํฐ ์์ํฉ๋๋ค.")
            return None

        print(f"\n๐ ์ฒดํฌํฌ์ธํธ ๋ก๋: {ckpt_path}")
        start = time.time()

        # 1) Model weights. weights_only=True restricts deserialization to
        # tensors (safe against arbitrary-code pickle payloads).
        model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
        model.load_state_dict(model_state)
        del model_state  # free memory before loading the optimizer blob

        # 2) Optimizer state (skipped when called for inference-only use).
        if optimizer is not None:
            optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
            optimizer.load_state_dict(optim_state)
            del optim_state

        # 3) Metadata.
        with open(ckpt_path / "meta.json", "r") as f:
            meta = json.load(f)

        # 4) Metrics history (plain Python containers -> weights_only=False).
        metrics_history = {}
        metrics_path = ckpt_path / "metrics.pt"
        if metrics_path.exists():
            metrics_history = torch.load(metrics_path, weights_only=False)

        # 5) Restore RNG states so data order / dropout continue identically.
        rng_path = ckpt_path / "rng_states.pt"
        if rng_path.exists():
            rng_states = torch.load(rng_path, weights_only=False)
            torch.random.set_rng_state(rng_states["python"])
            if rng_states["cuda"] is not None and torch.cuda.is_available():
                torch.cuda.set_rng_state(rng_states["cuda"])

        elapsed = time.time() - start
        print(f" ๋ก๋ ์๋ฃ: step={meta['step']}, {elapsed:.1f}์ด")

        return {
            "step": meta["step"],
            "best_val_loss": meta["best_val_loss"],
            "wandb_run_id": meta.get("wandb_run_id"),
            "metrics_history": metrics_history,
        }

    def _find_latest(self) -> Optional[Path]:
        """Return the newest checkpoint directory, or None."""
        # Works because step_{:06d} names sort lexicographically by step.
        ckpts = sorted(self.checkpoint_dir.glob("step_*"))
        return ckpts[-1] if ckpts else None

    def _cleanup_old_checkpoints(self):
        """Delete oldest checkpoints beyond max_checkpoints (rolling)."""
        ckpts = sorted(self.checkpoint_dir.glob("step_*"))
        while len(ckpts) > self.max_checkpoints:
            old = ckpts.pop(0)
            print(f" ๐๏ธ ์ค๋๋ ์ฒดํฌํฌ์ธํธ ์ญ์ : {old.name}")
            shutil.rmtree(old)
|
| 370 |
+
# ============================================================================
# 4. Metrics tracker
# ============================================================================

class MetricsTracker:
    """Tracks and logs training metrics.

    Tracked items:
    - train/loss: training loss (cross-entropy)
    - train/lr: current learning rate
    - train/grad_norm: gradient L2 norm
    - train/tokens_per_sec: throughput
    - train/gpu_mem_gb: GPU memory usage
    - val/loss: validation loss
    - val/perplexity: validation perplexity (= exp(loss))
    """

    def __init__(self, config: TrainConfig):
        self.config = config
        # In-memory history; persisted inside checkpoints by the trainer.
        self.history: Dict[str, list] = {
            "step": [],
            "train_loss": [],
            "learning_rate": [],
            "grad_norm": [],
            "tokens_per_sec": [],
            "gpu_mem_gb": [],
            "val_loss": [],
            "val_ppl": [],
        }

        # wandb initialization (optional — console logging always works).
        self.wandb_run = None
        if config.use_wandb:
            self._init_wandb()

    def _init_wandb(self, resume_id: Optional[str] = None):
        """Initialize wandb; supports resuming a run across sessions."""
        try:
            # Imported lazily so the tracker works without wandb installed.
            import wandb

            run_id = resume_id or wandb.util.generate_id()
            self.wandb_run = wandb.init(
                project=self.config.wandb_project,
                name=self.config.wandb_run_name or f"1b-run-{run_id[:6]}",
                id=run_id,
                # "allow": reattach to run_id if it exists, else create it.
                resume="allow",
                config=self.config.__dict__,
            )
            print(f"[wandb] ์ด๊ธฐํ ์๋ฃ: {self.wandb_run.url}")
        except ImportError:
            # wandb not installed — degrade to console-only logging.
            print("[wandb] ์ค์น๋์ง ์์. ์ฝ์ ๋ก๊น๋ง ์ฌ์ฉํฉ๋๋ค.")
            self.config.use_wandb = False
        except Exception as e:
            # Any other init failure (auth, network) also degrades gracefully.
            print(f"[wandb] ์ด๊ธฐํ ์คํจ: {e}. ์ฝ์ ๋ก๊น๋ง ์ฌ์ฉํฉ๋๋ค.")
            self.config.use_wandb = False

    def resume_wandb(self, run_id: str):
        """Continue logging into a previous wandb run (after checkpoint load)."""
        if self.config.use_wandb:
            self._init_wandb(resume_id=run_id)

    def log_train_step(
        self,
        step: int,
        loss: float,
        lr: float,
        grad_norm: float,
        tokens_per_sec: float,
        gpu_mem_gb: float,
    ):
        """Record metrics for one optimizer step."""
        self.history["step"].append(step)
        self.history["train_loss"].append(loss)
        self.history["learning_rate"].append(lr)
        self.history["grad_norm"].append(grad_norm)
        self.history["tokens_per_sec"].append(tokens_per_sec)
        self.history["gpu_mem_gb"].append(gpu_mem_gb)

        if self.config.use_wandb and self.wandb_run:
            import wandb

            wandb.log({
                "train/loss": loss,
                "train/lr": lr,
                "train/grad_norm": grad_norm,
                "train/tokens_per_sec": tokens_per_sec,
                "train/gpu_mem_gb": gpu_mem_gb,
            }, step=step)

    def log_eval(self, step: int, val_loss: float, val_ppl: float):
        """Record validation metrics."""
        self.history["val_loss"].append(val_loss)
        self.history["val_ppl"].append(val_ppl)

        if self.config.use_wandb and self.wandb_run:
            import wandb

            wandb.log({
                "val/loss": val_loss,
                "val/perplexity": val_ppl,
            }, step=step)

    @property
    def wandb_run_id(self) -> Optional[str]:
        # Saved into checkpoints so a resumed session can reattach to the run.
        if self.wandb_run:
            return self.wandb_run.id
        return None
|
| 479 |
+
# ============================================================================
# 5. Optimizer creation (AdamW with selective weight decay)
# ============================================================================

def create_optimizer(model: nn.Module, config: "TrainConfig") -> torch.optim.AdamW:
    """Build an AdamW optimizer with weight decay applied selectively.

    Decay rules:
    - Applied to 2D+ weight matrices (attention projections, FFN, etc.).
    - Skipped for 1D tensors (biases, norm scales) and any parameter
      whose name contains "embedding".

    Why split?
    - Weight decay penalizes large weights to curb overfitting, but
      decaying a norm's scale parameter fights the normalization itself,
      and decaying embeddings shrinks rare-token vectors toward zero.
    - Excluding 1D parameters from decay is the established convention.
    """
    with_decay = []
    without_decay = []

    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1D tensors (bias / norm weight) and embeddings are exempt.
        if p.dim() <= 1 or "embedding" in name:
            without_decay.append(p)
        else:
            with_decay.append(p)

    groups = [
        {"params": with_decay, "weight_decay": config.weight_decay},
        {"params": without_decay, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in with_decay)
    n_no_decay = sum(p.numel() for p in without_decay)
    print(f"[Optimizer] Decay ํ๋ผ๋ฏธํฐ: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay ํ๋ผ๋ฏธํฐ: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")

    return torch.optim.AdamW(
        groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        fused=torch.cuda.is_available(),  # fused CUDA kernel AdamW (faster)
    )
|
| 531 |
+
# ============================================================================
|
| 532 |
+
# 6. Trainer (ํต์ฌ ํ์ต ๋ฃจํ)
|
| 533 |
+
# ============================================================================
|
| 534 |
+
|
| 535 |
+
class Trainer:
|
| 536 |
+
"""LLM ์ฌ์ ํ์ต ํธ๋ ์ด๋.
|
| 537 |
+
|
| 538 |
+
ํ์ต ๋ฃจํ์ ํต์ฌ ๊ตฌ์กฐ:
|
| 539 |
+
```
|
| 540 |
+
for step in range(total_steps):
|
| 541 |
+
# โโ Gradient Accumulation Loop โโ
|
| 542 |
+
for micro_step in range(accumulation_steps):
|
| 543 |
+
batch = next(dataloader)
|
| 544 |
+
with autocast(bf16):
|
| 545 |
+
logits, loss = model(input_ids, targets)
|
| 546 |
+
scaled_loss = loss / accumulation_steps
|
| 547 |
+
scaled_loss.backward() # gradient ๋์
|
| 548 |
+
|
| 549 |
+
# โโ Optimizer Step (accumulation ์๋ฃ ํ) โโ
|
| 550 |
+
clip_grad_norm(model, max_norm=1.0)
|
| 551 |
+
optimizer.step()
|
| 552 |
+
optimizer.zero_grad()
|
| 553 |
+
scheduler.set_lr(optimizer, step)
|
| 554 |
+
```
|
| 555 |
+
|
| 556 |
+
Gradient Accumulation์ด๋?
|
| 557 |
+
- GPU ๋ฉ๋ชจ๋ฆฌ์ ํฐ ๋ฐฐ์น๋ฅผ ํ ๋ฒ์ ์ฌ๋ฆด ์ ์์ ๋
|
| 558 |
+
- ์์ micro_batch๋ก ์ฌ๋ฌ ๋ฒ forward/backward โ gradient๋ฅผ ๋์
|
| 559 |
+
- ๋์ ํ ํ ๋ฒ์ optimizer step
|
| 560 |
+
- ๊ฒฐ๊ณผ์ ์ผ๋ก ํฐ effective_batch์ ๋์ผํ ํจ๊ณผ
|
| 561 |
+
- Loss๋ฅผ accumulation_steps๋ก ๋๋๋ ์ด์ :
|
| 562 |
+
gradient์ ํ๊ท ์ ๊ตฌํ๊ธฐ ์ํด (ํฉ์ด ์๋ ํ๊ท )
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
def __init__(
|
| 566 |
+
self,
|
| 567 |
+
model: nn.Module,
|
| 568 |
+
train_dataloader: DataLoader,
|
| 569 |
+
val_dataloader: Optional[DataLoader],
|
| 570 |
+
config: TrainConfig,
|
| 571 |
+
seq_len: int = 2048,
|
| 572 |
+
):
|
| 573 |
+
self.config = config
|
| 574 |
+
self.seq_len = seq_len
|
| 575 |
+
|
| 576 |
+
# โโ ๋๋ฐ์ด์ค ์ค์ โโ
|
| 577 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 578 |
+
print(f"[Trainer] ๋๋ฐ์ด์ค: {self.device}")
|
| 579 |
+
if torch.cuda.is_available():
|
| 580 |
+
print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
|
| 581 |
+
print(f"[Trainer] GPU ๋ฉ๋ชจ๋ฆฌ: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
|
| 582 |
+
|
| 583 |
+
# โโ ๋ชจ๋ธ โโ
|
| 584 |
+
self.model = model.to(self.device)
|
| 585 |
+
# torch.compile: PyTorch 2.0+ ๊ทธ๋ํ ์ต์ ํ (์๋ 10-30% ํฅ์)
|
| 586 |
+
if torch.cuda.is_available() and hasattr(torch, "compile"):
|
| 587 |
+
print("[Trainer] torch.compile ์ ์ฉ ์ค...")
|
| 588 |
+
self.model = torch.compile(self.model)
|
| 589 |
+
|
| 590 |
+
# โโ ๋ฐ์ดํฐ โโ
|
| 591 |
+
self.train_dataloader = train_dataloader
|
| 592 |
+
self.val_dataloader = val_dataloader
|
| 593 |
+
self.train_iter = iter(train_dataloader)
|
| 594 |
+
|
| 595 |
+
# โโ ์ตํฐ๋ง์ด์ โโ
|
| 596 |
+
self.optimizer = create_optimizer(self.model, config)
|
| 597 |
+
|
| 598 |
+
# โโ ์ค์ผ์ค๋ฌ โโ
|
| 599 |
+
self.scheduler = CosineWarmupScheduler(config)
|
| 600 |
+
|
| 601 |
+
# โโ ์ฒดํฌํฌ์ธํธ โโ
|
| 602 |
+
self.ckpt_manager = CheckpointManager(config)
|
| 603 |
+
|
| 604 |
+
# โโ ๋ฉํธ๋ฆญ โโ
|
| 605 |
+
self.metrics = MetricsTracker(config)
|
| 606 |
+
|
| 607 |
+
# โโ ํ์ต ์ํ โโ
|
| 608 |
+
self.global_step = 0
|
| 609 |
+
self.best_val_loss = float("inf")
|
| 610 |
+
self.tokens_seen = 0
|
| 611 |
+
|
| 612 |
+
# โโ Mixed Precision โโ
|
| 613 |
+
# bf16์ GradScaler๊ฐ ๋ถํ์ (fp16์ผ ๋๋ง ํ์)
|
| 614 |
+
self.use_amp = config.dtype != "float32"
|
| 615 |
+
self.amp_dtype = config.torch_dtype
|
| 616 |
+
|
| 617 |
+
# โโ ์๋ ๋ณต์ ์๋ โโ
|
| 618 |
+
self._try_resume()
|
| 619 |
+
|
| 620 |
+
def _try_resume(self):
|
| 621 |
+
"""์ด์ ์ฒดํฌํฌ์ธํธ๊ฐ ์์ผ๋ฉด ์๋์ผ๋ก ๋ณต์ํฉ๋๋ค."""
|
| 622 |
+
result = self.ckpt_manager.load_latest(
|
| 623 |
+
self.model, self.optimizer, self.device
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
if result is not None:
|
| 627 |
+
self.global_step = result["step"]
|
| 628 |
+
self.best_val_loss = result["best_val_loss"]
|
| 629 |
+
self.metrics.history = result.get("metrics_history", self.metrics.history)
|
| 630 |
+
|
| 631 |
+
# wandb ์ฐ์ ๋ก๊น
|
| 632 |
+
if result.get("wandb_run_id"):
|
| 633 |
+
self.metrics.resume_wandb(result["wandb_run_id"])
|
| 634 |
+
|
| 635 |
+
self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
|
| 636 |
+
print(f"[Trainer] ํ์ต ์ฌ๊ฐ: step={self.global_step}, "
|
| 637 |
+
f"tokens={self.tokens_seen/1e9:.2f}B, "
|
| 638 |
+
f"best_val_loss={self.best_val_loss:.4f}")
|
| 639 |
+
|
| 640 |
+
def _get_next_batch(self) -> Dict[str, torch.Tensor]:
|
| 641 |
+
"""๋ค์ ํ์ต ๋ฐฐ์น๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
|
| 642 |
+
|
| 643 |
+
Streaming DataLoader๋ ์ํญ ๊ฐ๋
์ด ์์ผ๋ฏ๋ก,
|
| 644 |
+
StopIteration ์ ์ ์ดํฐ๋ ์ดํฐ๋ฅผ ์์ฑํฉ๋๋ค.
|
| 645 |
+
"""
|
| 646 |
+
try:
|
| 647 |
+
batch = next(self.train_iter)
|
| 648 |
+
except StopIteration:
|
| 649 |
+
self.train_iter = iter(self.train_dataloader)
|
| 650 |
+
batch = next(self.train_iter)
|
| 651 |
+
|
| 652 |
+
return {
|
| 653 |
+
"input_ids": batch["input_ids"].to(self.device, non_blocking=True),
|
| 654 |
+
"targets": batch["targets"].to(self.device, non_blocking=True),
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
    def _train_step(self) -> Tuple[float, float]:
        """Run one optimizer step, including gradient accumulation.

        Returns:
            (loss, grad_norm): mean micro-batch loss over the accumulation
            window, and the pre-clip global gradient L2 norm.
        """
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)
        # set_to_none=True: gradients become None instead of zero tensors -> saves memory

        total_loss = 0.0

        # -- Gradient accumulation loop --
        for micro_step in range(self.config.gradient_accumulation_steps):
            batch = self._get_next_batch()

            # Mixed-precision forward pass
            with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
                logits, loss = self.model(batch["input_ids"], batch["targets"])

            # Scale the loss so the accumulated gradient is the mean
            # over the whole effective batch, not the sum.
            scaled_loss = loss / self.config.gradient_accumulation_steps
            total_loss += loss.item()

            # Backward pass (gradients accumulate across micro-steps)
            scaled_loss.backward()

        # -- Gradient clipping --
        # Views all parameter gradients as one vector and computes its L2 norm;
        # if the norm exceeds max_norm, gradients are scaled down proportionally.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config.grad_clip,
        ).item()

        # -- Optimizer step --
        self.optimizer.step()

        # -- LR update --
        self.scheduler.set_lr(self.optimizer, self.global_step)

        avg_loss = total_loss / self.config.gradient_accumulation_steps
        return avg_loss, grad_norm
@torch.no_grad()
|
| 702 |
+
def _evaluate(self) -> Tuple[float, float]:
|
| 703 |
+
"""๊ฒ์ฆ ๋ฐ์ดํฐ์์ Loss์ Perplexity๋ฅผ ์ธก์ ํฉ๋๋ค.
|
| 704 |
+
|
| 705 |
+
Perplexity = exp(loss)
|
| 706 |
+
- ์ง๊ด: "๋ชจ๋ธ์ด ๋ค์ ํ ํฐ์ ํ๊ท ๋ช ๊ฐ์ ํ๋ณด ์ค์์ ๊ณ ๋ฅด๋๊ฐ"
|
| 707 |
+
- PPL 100 โ 100๊ฐ ์ค 1๊ฐ๋ฅผ ๊ท ์ผํ๊ฒ ๊ณ ๋ฅด๋ ์์ค
|
| 708 |
+
- PPL 20 โ 20๊ฐ ์ค 1๊ฐ ์์ค (๊ฝค ์ข์)
|
| 709 |
+
- PPL 10 โ ๋งค์ฐ ์์ ์๊ฒ ์์ธก
|
| 710 |
+
"""
|
| 711 |
+
if self.val_dataloader is None:
|
| 712 |
+
return float("inf"), float("inf")
|
| 713 |
+
|
| 714 |
+
self.model.eval()
|
| 715 |
+
total_loss = 0.0
|
| 716 |
+
num_batches = 0
|
| 717 |
+
|
| 718 |
+
for i, batch in enumerate(self.val_dataloader):
|
| 719 |
+
if i >= self.config.eval_steps:
|
| 720 |
+
break
|
| 721 |
+
|
| 722 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 723 |
+
targets = batch["targets"].to(self.device)
|
| 724 |
+
|
| 725 |
+
with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
|
| 726 |
+
_, loss = self.model(input_ids, targets)
|
| 727 |
+
|
| 728 |
+
total_loss += loss.item()
|
| 729 |
+
num_batches += 1
|
| 730 |
+
|
| 731 |
+
avg_loss = total_loss / max(num_batches, 1)
|
| 732 |
+
perplexity = math.exp(min(avg_loss, 20)) # overflow ๋ฐฉ์ง (exp(20) โ 5์ต)
|
| 733 |
+
|
| 734 |
+
return avg_loss, perplexity
|
| 735 |
+
|
| 736 |
+
    def train(self):
        """Main training loop.

        Executes the full training schedule. If a Colab session expires
        mid-run, the run auto-resumes from the latest checkpoint on restart
        (see _try_resume, called in __init__).
        """
        config = self.config

        print("\n" + "=" * 70)
        print("๐ ํ์ต ์์")
        print("=" * 70)
        print(f" ์ด ์คํ: {config.total_steps:,}")
        print(f" ์์ ์คํ: {self.global_step}")
        print(f" Effective batch size: {config.effective_batch_size}")
        print(f" ํ ํฐ/์คํ: {config.effective_batch_size * self.seq_len:,}")
        print(f" ์ด ํ์ต ํ ํฐ (์์): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
        print(f" Mixed Precision: {config.dtype}")
        print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
        print(f" ์ฒดํฌํฌ์ธํธ: {config.checkpoint_dir}")
        print("=" * 70 + "\n")

        step_start_time = time.time()
        tokens_at_log_start = self.tokens_seen

        # --------------------------------------------
        # Main loop
        # --------------------------------------------

        while self.global_step < config.total_steps:

            # -- Train step --
            loss, grad_norm = self._train_step()
            self.global_step += 1
            self.tokens_seen += config.effective_batch_size * self.seq_len

            # -- Logging --
            if self.global_step % config.log_interval == 0:
                elapsed = time.time() - step_start_time
                tokens_delta = self.tokens_seen - tokens_at_log_start
                tokens_per_sec = tokens_delta / max(elapsed, 1e-6)

                # GPU memory high-water mark since process start
                gpu_mem_gb = 0.0
                if torch.cuda.is_available():
                    gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9

                # Current learning rate
                current_lr = self.scheduler.get_lr(self.global_step)

                # Remaining-time estimate from recent throughput
                remaining_steps = config.total_steps - self.global_step
                steps_per_sec = config.log_interval / max(elapsed, 1e-6)
                eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
                eta_hours = eta_seconds / 3600

                # Console output
                print(
                    f" Step {self.global_step:>6d}/{config.total_steps} โ "
                    f"Loss {loss:.4f} โ "
                    f"LR {current_lr:.2e} โ "
                    f"Grad {grad_norm:.2f} โ "
                    f"{tokens_per_sec:,.0f} tok/s โ "
                    f"GPU {gpu_mem_gb:.1f}GB โ "
                    f"ETA {eta_hours:.1f}h โ "
                    f"Tokens {self.tokens_seen/1e9:.2f}B"
                )

                # wandb logging
                self.metrics.log_train_step(
                    step=self.global_step,
                    loss=loss,
                    lr=current_lr,
                    grad_norm=grad_norm,
                    tokens_per_sec=tokens_per_sec,
                    gpu_mem_gb=gpu_mem_gb,
                )

                # Reset the throughput window.
                step_start_time = time.time()
                tokens_at_log_start = self.tokens_seen

            # -- Evaluation --
            if self.global_step % config.eval_interval == 0:
                val_loss, val_ppl = self._evaluate()

                print(f"\n ๐ Eval @ Step {self.global_step}: "
                      f"Val Loss = {val_loss:.4f}, "
                      f"Val PPL = {val_ppl:.2f}")

                self.metrics.log_eval(self.global_step, val_loss, val_ppl)

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    print(f" ๐ New best val loss: {val_loss:.4f}")

                print()

            # -- Checkpoint --
            if self.global_step % config.checkpoint_interval == 0:
                self.ckpt_manager.save(
                    model=self.model,
                    optimizer=self.optimizer,
                    step=self.global_step,
                    best_val_loss=self.best_val_loss,
                    metrics_history=self.metrics.history,
                    wandb_run_id=self.metrics.wandb_run_id,
                )

        # --------------------------------------------
        # Training finished
        # --------------------------------------------

        print("\n" + "=" * 70)
        print("๐ ํ์ต ์๋ฃ!")
        print("=" * 70)
        print(f" ์ด ์คํ: {self.global_step:,}")
        print(f" ์ด ํ ํฐ: {self.tokens_seen/1e9:.2f}B")
        print(f" ์ต์ Val Loss: {self.best_val_loss:.4f}")
        print(f" ์ต์ Val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
        print("=" * 70)

        # Final checkpoint
        self.ckpt_manager.save(
            model=self.model,
            optimizer=self.optimizer,
            step=self.global_step,
            best_val_loss=self.best_val_loss,
            metrics_history=self.metrics.history,
            wandb_run_id=self.metrics.wandb_run_id,
        )

        if self.config.use_wandb and self.metrics.wandb_run:
            import wandb
            wandb.finish()
|
| 871 |
+
# ============================================================================
|
| 872 |
+
# 7. GPU ํ๊ฒฝ ์๋ ๊ฐ์ง ๋ฐ ์ค์ ์กฐ์
|
| 873 |
+
# ============================================================================
|
| 874 |
+
|
| 875 |
+
def auto_configure(config: TrainConfig) -> TrainConfig:
|
| 876 |
+
"""GPU ์ข
๋ฅ์ ๋ฐ๋ผ ์ค์ ์ ์๋ ์กฐ์ ํฉ๋๋ค.
|
| 877 |
+
|
| 878 |
+
Colab Pro+์์ A100์ด ํญ์ ๋ฐฐ์ ๋์ง๋ ์์ต๋๋ค.
|
| 879 |
+
T4๋ V100์ด ๋ฐฐ์ ๋ ๊ฒฝ์ฐ ์๋์ผ๋ก ์ค์ ์ ์กฐ์ ํฉ๋๋ค.
|
| 880 |
+
|
| 881 |
+
Returns:
|
| 882 |
+
์กฐ์ ๋ TrainConfig
|
| 883 |
+
"""
|
| 884 |
+
if not torch.cuda.is_available():
|
| 885 |
+
print("โ ๏ธ GPU ์์! CPU ๋ชจ๋ (๋งค์ฐ ๋๋ฆผ)")
|
| 886 |
+
config.dtype = "float32"
|
| 887 |
+
config.micro_batch_size = 1
|
| 888 |
+
config.gradient_accumulation_steps = 4
|
| 889 |
+
return config
|
| 890 |
+
|
| 891 |
+
gpu_name = torch.cuda.get_device_name().lower()
|
| 892 |
+
gpu_mem = torch.cuda.get_device_properties(0).total_mem / 1e9
|
| 893 |
+
|
| 894 |
+
print(f"\n๐ GPU ๊ฐ์ง: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")
|
| 895 |
+
|
| 896 |
+
if "a100" in gpu_name:
|
| 897 |
+
# A100 40GB: ๊ธฐ๋ณธ ์ค์ ๊ทธ๋๋ก (์ต์ )
|
| 898 |
+
print(" โ A100 ๊ฐ์ง: ๊ธฐ๋ณธ ์ค์ ์ฌ์ฉ (bf16, batch=4)")
|
| 899 |
+
config.dtype = "bfloat16"
|
| 900 |
+
config.micro_batch_size = 4
|
| 901 |
+
|
| 902 |
+
elif "v100" in gpu_name:
|
| 903 |
+
# V100 16GB: bf16 ๋ฏธ์ง์, ๋ฐฐ์น ์ถ์
|
| 904 |
+
print(" โ V100 ๊ฐ์ง: fp16 ๋ชจ๋, ๋ฐฐ์น ์ถ์")
|
| 905 |
+
config.dtype = "float16"
|
| 906 |
+
config.micro_batch_size = 2
|
| 907 |
+
config.gradient_accumulation_steps = 64 # effective batch ์ ์ง
|
| 908 |
+
|
| 909 |
+
elif "t4" in gpu_name:
|
| 910 |
+
# T4 16GB: bf16 ๋ฏธ์ง์, ๋ ์์ ๋ฐฐ์น
|
| 911 |
+
print(" โ T4 ๊ฐ์ง: fp16 ๋ชจ๋, ์ต์ ๋ฐฐ์น")
|
| 912 |
+
config.dtype = "float16"
|
| 913 |
+
config.micro_batch_size = 1
|
| 914 |
+
config.gradient_accumulation_steps = 128
|
| 915 |
+
|
| 916 |
+
elif "l4" in gpu_name:
|
| 917 |
+
# L4 24GB: bf16 ์ง์
|
| 918 |
+
print(" โ L4 ๊ฐ์ง: bf16 ๋ชจ๋, ๋ฐฐ์น ์กฐ์ ")
|
| 919 |
+
config.dtype = "bfloat16"
|
| 920 |
+
config.micro_batch_size = 2
|
| 921 |
+
config.gradient_accumulation_steps = 64
|
| 922 |
+
|
| 923 |
+
else:
|
| 924 |
+
print(f" โ ์ ์ ์๋ GPU. ๋ฉ๋ชจ๋ฆฌ ๊ธฐ์ค์ผ๋ก ์ค์ ์กฐ์ ")
|
| 925 |
+
if gpu_mem >= 30:
|
| 926 |
+
config.micro_batch_size = 4
|
| 927 |
+
elif gpu_mem >= 16:
|
| 928 |
+
config.micro_batch_size = 2
|
| 929 |
+
else:
|
| 930 |
+
config.micro_batch_size = 1
|
| 931 |
+
config.gradient_accumulation_steps = 128
|
| 932 |
+
|
| 933 |
+
print(f" โ dtype: {config.dtype}")
|
| 934 |
+
print(f" โ micro_batch: {config.micro_batch_size}")
|
| 935 |
+
print(f" โ grad_accum: {config.gradient_accumulation_steps}")
|
| 936 |
+
print(f" โ effective_batch: {config.effective_batch_size}")
|
| 937 |
+
|
| 938 |
+
return config
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
# ============================================================================
|
| 942 |
+
# 8. Quick Start (Colab ์คํ์ฉ)
|
| 943 |
+
# ============================================================================
|
| 944 |
+
|
| 945 |
+
def start_training(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    config: Optional[TrainConfig] = None,
    seq_len: int = 2048,
    auto_config: bool = True,
) -> Trainer:
    """One-call entry point: build a Trainer and run the full training loop.

    Args:
        model: the model to train (must return (logits, loss) from forward).
        train_dataloader: training batch iterator.
        val_dataloader: optional validation iterator (skips eval if None).
        config: training settings; a default TrainConfig is created if None.
        seq_len: sequence length used for token accounting.
        auto_config: when True, adjust the config to the detected GPU.

    Usage (Colab):
    ```python
    from model import LLMModel, ModelConfig
    from data_pipeline import setup_data_pipeline, DataConfig
    from trainer import start_training, TrainConfig

    # 1. Build the model
    model_config = ModelConfig.base_1b()
    model = LLMModel(model_config)

    # 2. Data pipeline
    tok, train_dl, val_dl = setup_data_pipeline("pretrained")

    # 3. Start training (auto-resumes from checkpoints)
    trainer = start_training(model, train_dl, val_dl)
    ```
    """
    config = config or TrainConfig()

    # Detect the GPU and adapt dtype/batch settings.
    if auto_config:
        config = auto_configure(config)

    # Verify Google Drive is mounted when checkpoints target it (Colab).
    if "/content/drive" in config.checkpoint_dir:
        drive_path = Path("/content/drive/MyDrive")
        if not drive_path.exists():
            print("\nโ ๏ธ Google Drive๊ฐ ๋ง์ดํธ๋์ง ์์์ต๋๋ค!")
            print(" Colab์์ ์คํ: from google.colab import drive; drive.mount('/content/drive')")
            print(" ๋ก์ปฌ ๊ฒฝ๋ก๋ก ๋ณ๊ฒฝํฉ๋๋ค.")
            # Fall back to a local directory so training can still proceed.
            config.checkpoint_dir = "./checkpoints"

    # Seed RNGs for reproducibility.
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    # Build the Trainer (auto-resumes from the latest checkpoint).
    trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)

    # Run training.
    trainer.train()

    return trainer
| 1000 |
+
|
| 1001 |
+
# ============================================================================
|
| 1002 |
+
# 9. ๊ฒ์ฆ ์คํฌ๋ฆฝํธ
|
| 1003 |
+
# ============================================================================
|
| 1004 |
+
|
| 1005 |
+
if __name__ == "__main__":
    print("=" * 70)
    print("LLM-1B-Lab: Trainer ๊ฒ์ฆ")
    print("=" * 70)

    # -- Smoke-test the training loop with a tiny model --
    print("\n[ํ์คํธ 1] ๋ฏธ๋ ๋ชจ๋ธ ํ์ต ๋ฃจํ ๊ฒ์ฆ")

    # Minimal dummy model that honors the (logits, loss) forward contract
    # the Trainer expects.
    class TinyModel(nn.Module):
        def __init__(self, vocab_size=100, dim=64):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)
            self.linear = nn.Linear(dim, vocab_size)
            self.linear.weight = self.emb.weight  # weight tying

        def forward(self, input_ids, targets=None):
            import torch.nn.functional as F

            h = self.emb(input_ids)
            logits = self.linear(h)
            loss = None
            if targets is not None:
                loss = F.cross_entropy(logits.view(-1, 100), targets.view(-1))
            return logits, loss

        def count_parameters(self, trainable_only=True):
            # NOTE(review): trainable_only is accepted for interface parity
            # but not honored here — requires_grad is always filtered.
            return sum(p.numel() for p in self.parameters() if p.requires_grad)

    model = TinyModel()
    print(f" ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ: {model.count_parameters():,}")

    # Random batches shaped like the real pipeline's output
    # (inputs shifted one token against targets).
    def dummy_dataloader(num_batches=100, batch_size=4, seq_len=32, vocab=100):
        for _ in range(num_batches):
            ids = torch.randint(0, vocab, (batch_size, seq_len + 1))
            yield {
                "input_ids": ids[:, :-1],
                "targets": ids[:, 1:],
            }

    # Config for a very short CPU run.
    config = TrainConfig(
        total_steps=20,
        warmup_steps=5,
        micro_batch_size=4,
        gradient_accumulation_steps=2,
        log_interval=5,
        eval_interval=10,
        checkpoint_interval=10,
        checkpoint_dir="./test_checkpoints",
        use_wandb=False,
        dtype="float32",  # CPU test
    )

    # Scheduler check
    print("\n[ํ์คํธ 2] LR ์ค์ผ์ค๋ฌ ๊ฒ์ฆ")
    scheduler = CosineWarmupScheduler(config)
    test_steps = [0, 2, 5, 10, 15, 20]
    for s in test_steps:
        lr = scheduler.get_lr(s)
        phase = "warmup" if s < config.warmup_steps else "cosine"
        print(f" Step {s:3d}: LR = {lr:.6f} ({phase})")

    # Optimizer check
    print("\n[ํ์คํธ 3] Optimizer ์์ฑ ๊ฒ์ฆ")
    optimizer = create_optimizer(model, config)
    print(f" ํ๋ผ๋ฏธํฐ ๊ทธ๋ฃน ์: {len(optimizer.param_groups)}")
    for i, pg in enumerate(optimizer.param_groups):
        n_params = sum(p.numel() for p in pg["params"])
        print(f" ๊ทธ๋ฃน {i}: {n_params:,} params, weight_decay={pg['weight_decay']}")

    # Training-loop check (short run)
    print("\n[ํ์คํธ 4] ํ์ต ๋ฃจํ ์คํ (20 steps)")
    train_dl = list(dummy_dataloader(num_batches=200))

    # Minimal stand-in for a DataLoader (only needs __iter__).
    class SimpleLoader:
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    trainer = Trainer(
        model=model,
        train_dataloader=SimpleLoader(train_dl),
        val_dataloader=SimpleLoader(train_dl[:20]),
        config=config,
        seq_len=32,
    )
    trainer.train()

    # Clean up the throwaway checkpoints.
    import shutil
    if os.path.exists("./test_checkpoints"):
        shutil.rmtree("./test_checkpoints")

    print("\n" + "=" * 70)
    print("โ Trainer ๊ฒ์ฆ ์๋ฃ!")
    print()
    print("์ค์ ํ์ต ์คํ ๋ฐฉ๋ฒ:")
    print(" trainer = start_training(model, train_dl, val_dl)")
    print("=" * 70)
llm_lab/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
|
| 3 |
+
================================================================
|
| 4 |
+
๋ฅ๋ฌ๋ ์ด๋ณด์๋ฅผ ์ํ ํ์ต์ฉ ๊ตฌํ.
|
| 5 |
+
๊ฐ ์ปดํฌ๋ํธ์ ์์ธ ์ฃผ์์ ๋ฌ์ "์ ์ด๋ ๊ฒ ํ๋์ง"๋ฅผ ์ค๋ช
ํฉ๋๋ค.
|
| 6 |
+
|
| 7 |
+
๋ชจ๋ ๊ตฌ์กฐ:
|
| 8 |
+
llm_lab.config โ ๋ชจ๋ ์ค์ (ModelConfig, DataConfig, TrainConfig, EvalConfig)
|
| 9 |
+
llm_lab.model โ ๋ชจ๋ธ ์ํคํ
์ฒ (RMSNorm, RoPE, GQA, SwiGLU, Transformer)
|
| 10 |
+
llm_lab.data โ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ (ํ ํฌ๋์ด์ , ์คํธ๋ฆฌ๋ฐ, ํจํน)
|
| 11 |
+
llm_lab.training โ ํ์ต ๋ฃจํ (Trainer, ์ค์ผ์ค๋ฌ, ์ฒดํฌํฌ์ธํธ)
|
| 12 |
+
llm_lab.evaluation โ ํ๊ฐ (Perplexity, ์์ฑ, Scaling Law, Attention)
|
| 13 |
+
llm_lab.utils โ ๊ณตํต ์ ํธ๋ฆฌํฐ (๋๋ฐ์ด์ค ๊ฐ์ง, ์๋)
|
| 14 |
+
|
| 15 |
+
Quick Start:
|
| 16 |
+
from llm_lab.config import ModelConfig, DataConfig, TrainConfig
|
| 17 |
+
from llm_lab.model import LLMModel
|
| 18 |
+
from llm_lab.data import setup_data_pipeline
|
| 19 |
+
from llm_lab.training import start_training
|
| 20 |
+
from llm_lab.evaluation import run_evaluation
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
__version__ = "0.1.0"
|
| 24 |
+
|
| 25 |
+
from .config import ModelConfig, DataConfig, TrainConfig, EvalConfig
|
| 26 |
+
from .model import LLMModel
|
| 27 |
+
from .data import setup_data_pipeline
|
| 28 |
+
from .training import start_training
|
| 29 |
+
from .evaluation import run_evaluation
|
| 30 |
+
from .utils import get_device, auto_configure
|
llm_lab/config/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""์ค์ (Config) ๋ชจ๋ โ ๋ชจ๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ํ ๊ณณ์์ ๊ด๋ฆฌํฉ๋๋ค."""
|
| 2 |
+
from .model_config import ModelConfig
|
| 3 |
+
from .data_config import DataConfig
|
| 4 |
+
from .train_config import TrainConfig
|
| 5 |
+
from .eval_config import EvalConfig
|
| 6 |
+
|
| 7 |
+
__all__ = ["ModelConfig", "DataConfig", "TrainConfig", "EvalConfig"]
|
llm_lab/config/data_config.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
from typing import Optional


@dataclass
class DataConfig:
    """Data-pipeline settings.

    Defaults chosen for the Colab Pro+ constraints:
    - Streaming mode to minimize disk usage
    - Sequence packing so the GPU sees no padding
    - On-the-fly preprocessing to keep memory low
    """
    # -- Dataset --
    dataset_name: str = "HuggingFaceFW/fineweb-edu"
    dataset_subset: str = "sample-10BT"  # 10B-token sample
    dataset_split: str = "train"
    text_column: str = "text"  # column holding the raw text

    # -- Tokenizer --
    tokenizer_type: str = "sentencepiece"  # "sentencepiece" or "hf"
    # Path to a pretrained tokenizer (a new one is trained when None)
    tokenizer_path: Optional[str] = None
    vocab_size: int = 32_000

    # -- Sequences --
    max_seq_len: int = 2048
    # Insert a document-separator (EOS) token at packing boundaries
    use_eos_separator: bool = True

    # -- Batching --
    batch_size: int = 4  # micro batch (per GPU)
    num_workers: int = 2  # DataLoader worker count
    prefetch_factor: int = 4  # batches prefetched per worker

    # -- Tokenizer training (when training one from scratch) --
    tokenizer_train_samples: int = 50_000  # documents used for training
    tokenizer_save_dir: str = "./tokenizer"

    # -- Validation split --
    val_ratio: float = 0.001  # fraction of the stream held out for validation
|
llm_lab/config/eval_config.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass


@dataclass
class EvalConfig:
    """Evaluation parameters."""
    # -- Perplexity --
    eval_batch_size: int = 4
    max_eval_batches: int = 100  # cap on evaluation batches

    # -- Generation --
    max_new_tokens: int = 200
    temperature: float = 0.8
    top_k: int = 50
    top_p: float = 0.9
    num_samples: int = 3  # generations per prompt

    # -- Output --
    save_dir: str = "./eval_results"
    plot_dpi: int = 150
|
llm_lab/config/model_config.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """All model hyperparameters in one dataclass.

    Size presets:
    - debug: ~10M  (pipeline validation)
    - small: ~100M (intermediate check)
    - base:  ~1.1B (final target)
    """
    vocab_size: int = 32_000
    hidden_dim: int = 2048          # d_model: the model's base width
    num_layers: int = 22            # number of Transformer blocks
    num_heads: int = 16             # query-head count
    num_kv_heads: int = 4           # key/value-head count (GQA)
    intermediate_dim: int = 5632    # FFN inner width (~2.75 x hidden_dim)
    max_seq_len: int = 2048         # maximum sequence length
    dropout: float = 0.0            # pretraining normally uses 0
    rope_theta: float = 10000.0     # RoPE frequency base
    norm_eps: float = 1e-6          # RMSNorm epsilon

    def __post_init__(self) -> None:
        # Fail fast on configurations the attention layers cannot support:
        # the head width and the GQA grouping must both divide evenly,
        # otherwise head_dim / num_kv_groups silently truncate.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )

    @property
    def head_dim(self) -> int:
        """Width of a single attention head."""
        return self.hidden_dim // self.num_heads

    @property
    def num_kv_groups(self) -> int:
        """Number of query heads served by each KV head under GQA."""
        return self.num_heads // self.num_kv_heads

    @classmethod
    def debug_10m(cls) -> "ModelConfig":
        """~10M parameters -- for fast debugging."""
        return cls(
            hidden_dim=256, num_layers=6, num_heads=8,
            num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
        )

    @classmethod
    def small_100m(cls) -> "ModelConfig":
        """~100M parameters -- intermediate validation."""
        return cls(
            hidden_dim=768, num_layers=12, num_heads=12,
            num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
        )

    @classmethod
    def base_1b(cls) -> "ModelConfig":
        """~1.1B parameters -- the final training target."""
        return cls()  # the defaults ARE the 1B configuration
llm_lab/config/train_config.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TrainConfig:
    """Training hyperparameters plus infrastructure settings.

    Defaults are tuned for Colab Pro+ (A100 40GB). Every value carries a
    note on *why* it was chosen.
    """

    # -- Optimization --
    # Peak LR. 3e-4 is the standard for ~1B models. The GPT-3 paper suggests
    # 125M -> 6e-4, 350M -> 3e-4, 1.3B -> 2e-4; we start our 1.1B model at
    # 3e-4 and drop to 2e-4 if training becomes unstable.
    learning_rate: float = 3e-4

    # Cosine-decay floor, conventionally ~10% of peak. Too low stalls late
    # training; too high destabilizes convergence.
    min_learning_rate: float = 3e-5

    # AdamW L2 regularization; 0.1 is the LLM standard. By convention it is
    # not applied to embeddings or biases.
    weight_decay: float = 0.1

    # Adam momentum coefficients. beta2=0.95 is more stable than 0.999 for
    # LLM training: with huge batches and long runs, a too-large beta2 slows
    # adaptation.
    beta1: float = 0.9
    beta2: float = 0.95

    adam_eps: float = 1e-8

    # Gradient clipping: rescale whenever the global grad norm exceeds 1.0,
    # guarding against spikes from noisy data early in training.
    grad_clip: float = 1.0

    # -- Scheduling --
    # Warmup: LR rises linearly 0 -> peak over the first 2000 steps. Early
    # random weights make large updates unstable; a small LR first lets the
    # model find its footing. ~10% of total steps is the usual rule of thumb.
    warmup_steps: int = 2000

    # Total optimizer steps. 10B tokens / (128 batch x 2048 seq_len) would be
    # ~38,000 micro-steps, i.e. ~20,000 effective steps with accumulation.
    total_steps: int = 20_000

    # -- Batching --
    # Per-GPU batch loaded at once; 4 fits a 1B model in bf16 on A100 40GB.
    micro_batch_size: int = 4

    # Accumulation count; effective batch = 4 x 32 = 128. Larger batches give
    # less noisy gradient estimates (LLMs typically use 128-512). On OOM,
    # raise this and lower micro_batch_size.
    gradient_accumulation_steps: int = 32

    # -- Mixed precision --
    # bfloat16: A100-native and numerically safer than fp16 (fp32-sized
    # exponent -> far less over/underflow risk). Switch to "float16" as the
    # fallback on T4/V100.
    dtype: str = "bfloat16"

    # -- Checkpointing --
    # Google Drive path: survives Colab session expiry.
    checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"

    # Save every 500 steps (~30 min on A100): balances I/O overhead against
    # work lost when the session dies.
    checkpoint_interval: int = 500

    # Rolling retention; the oldest checkpoint is deleted first. One
    # checkpoint is 8-10GB, so three of them is ~30GB.
    max_checkpoints: int = 3

    # -- Logging --
    log_interval: int = 10    # console + wandb every 10 steps
    eval_interval: int = 500  # validation loss every 500 steps
    eval_steps: int = 20      # eval batches: 20 x 4 x 2048 ~ 160K tokens

    # -- wandb --
    wandb_project: str = "llm-1b-lab"
    wandb_run_name: Optional[str] = None
    use_wandb: bool = True

    # -- Reproducibility --
    seed: int = 42

    @property
    def effective_batch_size(self) -> int:
        """Micro batch times accumulation steps."""
        return self.micro_batch_size * self.gradient_accumulation_steps

    @property
    def tokens_per_step(self) -> int:
        """Tokens processed per optimizer step."""
        # max_seq_len lives in ModelConfig; the 2048 default is mirrored here.
        return self.effective_batch_size * 2048

    @property
    def torch_dtype(self) -> torch.dtype:
        """Map the dtype string onto the corresponding torch dtype object."""
        return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[self.dtype]
|
llm_lab/data/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ๋ชจ๋ โ ํ ํฌ๋์ด์ , ์คํธ๋ฆฌ๋ฐ, ์ํ์ค ํจํน."""
|
| 2 |
+
from .tokenizer import Tokenizer
|
| 3 |
+
from .dataset import PackedStreamingDataset, ValidationDataset
|
| 4 |
+
from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
|
| 5 |
+
from .diagnostics import DataPipelineDiagnostics
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"Tokenizer", "PackedStreamingDataset", "ValidationDataset",
|
| 9 |
+
"create_train_dataloader", "train_tokenizer_from_dataset",
|
| 10 |
+
"setup_data_pipeline", "DataPipelineDiagnostics",
|
| 11 |
+
]
|
llm_lab/data/dataset.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""์คํธ๋ฆฌ๋ฐ ๋ฐ์ดํฐ์
โ ์ํ์ค ํจํน, ๊ฒ์ฆ ๋ฐ์ดํฐ์
."""
|
| 2 |
+
|
| 3 |
+
from typing import Iterator, List, Dict, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 7 |
+
|
| 8 |
+
from llm_lab.config import DataConfig
|
| 9 |
+
from .tokenizer import Tokenizer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class PackedStreamingDataset(IterableDataset):
    """Streaming dataset with sequence packing.

    Why sequence packing?
    - Naive approach: truncate/pad each document to max_seq_len -> wasted GPU compute.
    - Packing: concatenate documents back to back until max_seq_len is
      completely filled -> ~100% token utilisation.

    How it packs:
        doc1 (300 tok) + doc2 (1500 tok) + doc3 (248 tok) = 2048 tokens
        -> [doc1][EOS][doc2][EOS][doc3][EOS][... exactly filled, no padding]

    Why streaming?
    - The corpus (e.g. a FineWeb-Edu 10B sample) is tens of GB even
      compressed and exceeds typical Colab disk limits, so only what is
      needed is read over the network.

    Training caveats:
    - An EOS token marks every document boundary so the model can learn
      where documents end.
    - No cross-document attention mask is used; EOS alone acts as a soft
      boundary.
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        split: str = "train",
        seed: int = 42,
    ):
        """
        Args:
            tokenizer: wrapper exposing ``encode()`` and ``eos_id``.
            config: data configuration (dataset name/split, max_seq_len, ...).
            split: informational label; the HF split actually loaded comes
                from ``config.dataset_split``.
            seed: base shuffle seed (each DataLoader worker derives its own).
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.split = split
        self.seed = seed
        self.max_seq_len = config.max_seq_len

    def _load_dataset(self, seed: Optional[int] = None):
        """Load the HuggingFace dataset in streaming mode.

        Args:
            seed: shuffle seed; defaults to ``self.seed`` when omitted.
        """
        from datasets import load_dataset

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,  # key: stream instead of downloading everything
            trust_remote_code=True,
        )

        # Buffer-based approximate shuffle (exact shuffling is impossible
        # on an unbounded stream).
        ds = ds.shuffle(seed=self.seed if seed is None else seed, buffer_size=10_000)

        return ds

    def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
        """Tokenize documents and pack them into fixed-length sequences.

        Yields:
            {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}

        targets is input_ids shifted by one position:
            input_ids: [A, B, C, D, E]
            targets:   [B, C, D, E, F]
        -> the model predicts B from A, C from B, ...
        """
        buffer: List[int] = []  # rolling token buffer across documents

        for example in dataset:
            text = example[self.config.text_column]
            if not text or not text.strip():
                continue

            # Tokenize without special tokens; boundaries are handled below.
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)

            if not token_ids:
                continue

            # EOS marks the document boundary.
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)

            buffer.extend(token_ids)

            # Emit sequences while the buffer holds enough tokens.
            # The +1 provides the extra token needed for the shifted targets.
            while len(buffer) >= self.max_seq_len + 1:
                chunk = buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]

                # input_ids: first max_seq_len tokens of the chunk
                input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
                # targets: the same tokens shifted left by one
                targets = torch.tensor(chunk[1:], dtype=torch.long)

                yield {"input_ids": input_ids, "targets": targets}

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterator invoked by the DataLoader.

        Multi-worker support:
        - Each worker shuffles its stream with a distinct seed
          (base seed + worker id) to reduce inter-worker duplication.

        Fix over the original: ``self.seed`` is no longer mutated here.
        Previously ``self.seed = worker_seed`` meant a second pass over the
        dataset (persistent workers / multiple epochs) computed
        ``seed + 2 * worker_id``, silently drifting the shuffle and risking
        seed collisions between workers. The per-worker seed is now derived
        fresh on every call and passed down explicitly.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            # Multi-worker: distinct seed per worker.
            worker_seed = self.seed + worker_info.id
        else:
            worker_seed = self.seed

        dataset = self._load_dataset(seed=worker_seed)
        return self._tokenize_and_pack(dataset)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class ValidationDataset:
    """Fixed in-memory validation set.

    Pre-fetches a fixed number of packed sequences from the streaming
    dataset and keeps them in memory: evaluation is only comparable across
    runs/epochs if it always uses exactly the same data.
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        num_samples: int = 100,
        seed: int = 9999,
    ):
        """
        Args:
            tokenizer: wrapper exposing ``encode()`` and ``eos_id``.
            config: data configuration.
            num_samples: number of packed sequences to cache in memory.
            seed: shuffle seed; deliberately distinct from the training
                seed to reduce overlap with training data.
        """
        self.tokenizer = tokenizer
        self.config = config
        self.num_samples = num_samples
        self.samples: List[Dict[str, torch.Tensor]] = []

        self._prepare(seed)

    def _prepare(self, seed: int):
        """Extract and cache validation samples from the streamed dataset."""
        from datasets import load_dataset

        print(f"[Validation] {self.num_samples}๊ฐ ๊ฒ์ฆ ์ํ ์ค๋น ์ค...")

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )
        # Different seed from training so the validation stream differs.
        ds = ds.shuffle(seed=seed, buffer_size=5_000)

        buffer: List[int] = []
        count = 0

        for example in ds:
            if count >= self.num_samples:
                break

            text = example[self.config.text_column]
            if not text or not text.strip():
                continue

            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue

            # Consistency fix: honour config.use_eos_separator exactly like
            # PackedStreamingDataset does (the original unconditionally
            # appended EOS here, so train/val packing could diverge).
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)
            buffer.extend(token_ids)

            # Same packing scheme as training: max_seq_len + 1 tokens per
            # chunk, split into input_ids / shifted targets.
            while len(buffer) >= self.config.max_seq_len + 1 and count < self.num_samples:
                chunk = buffer[: self.config.max_seq_len + 1]
                buffer = buffer[self.config.max_seq_len + 1 :]

                self.samples.append({
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "targets": torch.tensor(chunk[1:], dtype=torch.long),
                })
                count += 1

        print(f"[Validation] {len(self.samples)}๊ฐ ์ํ ์ค๋น ์๋ฃ")

    def get_dataloader(self, batch_size: int) -> DataLoader:
        """Return a deterministic (unshuffled) validation DataLoader."""
        return DataLoader(
            self.samples,
            batch_size=batch_size,
            shuffle=False,       # keep evaluation order stable
            num_workers=0,       # samples are already in memory
            collate_fn=_collate_fn,
        )
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 210 |
+
"""๋ฐฐ์น ๋ด ์ํ๋ค์ ํ๋์ ํ
์๋ก ํฉ์นฉ๋๋ค.
|
| 211 |
+
|
| 212 |
+
์ํ์ค ํจํน ๋๋ถ์ ๋ชจ๋ ์ํ์ด ๋์ผํ ๊ธธ์ด(max_seq_len)์ด๋ฏ๋ก
|
| 213 |
+
์ถ๊ฐ ํจ๋ฉ์ด ํ์ ์์ต๋๋ค.
|
| 214 |
+
"""
|
| 215 |
+
return {
|
| 216 |
+
"input_ids": torch.stack([s["input_ids"] for s in batch]),
|
| 217 |
+
"targets": torch.stack([s["targets"] for s in batch]),
|
| 218 |
+
}
|
llm_lab/data/diagnostics.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ์ง๋จ ๋๊ตฌ."""
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
from typing import Dict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from llm_lab.config import DataConfig
|
| 10 |
+
from .tokenizer import Tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DataPipelineDiagnostics:
    """Diagnoses data-pipeline performance and quality.

    Things worth checking before training:
    1) Tokenizer quality: avg tokens/doc, compression ratio.
    2) Packing efficiency: real-token vs padding ratio.
    3) Throughput: tokens/sec (detect data-loading bottlenecks).
    4) Batch shape: tensor shapes and dtypes are correct.
    """

    @staticmethod
    def check_tokenizer_quality(
        tokenizer: "Tokenizer",
        config: "DataConfig",
        num_samples: int = 1000,
    ):
        """Sample documents from the (streaming) dataset and print
        tokenizer statistics plus an encode/decode round-trip check.

        Network I/O: streams up to ``num_samples`` documents.
        """
        from datasets import load_dataset

        print("\n" + "=" * 60)
        print("๐ ํ ํฌ๋์ด์ ํ์ง ์ง๋จ")
        print("=" * 60)

        ds = load_dataset(
            config.dataset_name,
            name=config.dataset_subset,
            split=config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )

        token_counts = []
        char_counts = []
        sample_count = 0

        for example in ds:
            if sample_count >= num_samples:
                break
            text = example[config.text_column]
            if not text or not text.strip():
                continue

            tokens = tokenizer.encode(text)
            token_counts.append(len(tokens))
            char_counts.append(len(text))
            sample_count += 1

        # Robustness fix: the original divided by len(token_counts) and ran
        # min()/max() on the list, crashing with ZeroDivisionError/ValueError
        # when the stream produced no usable documents.
        if not token_counts:
            print(" [WARN] no usable documents sampled - skipping statistics")
            return

        avg_tokens = sum(token_counts) / len(token_counts)
        avg_chars = sum(char_counts) / len(char_counts)
        compression_ratio = avg_chars / avg_tokens  # chars per token

        print(f" ๋ถ์ ๋ฌธ์ ์: {len(token_counts):,}")
        print(f" ํ๊ท ํ ํฐ/๋ฌธ์: {avg_tokens:.1f}")
        print(f" ํ๊ท ๋ฌธ์/๋ฌธ์: {avg_chars:.1f}")
        print(f" ์์ถ ๋น์จ (๋ฌธ์/ํ ํฐ): {compression_ratio:.2f}")
        print(f" โ ์์ด ๊ธฐ์ค 3.5~4.5๊ฐ ์ ์")
        print(f" ์ต์ ํ ํฐ: {min(token_counts)}, ์ต๋: {max(token_counts)}")

        # Encode/decode round-trip sanity check.
        test_text = "The quick brown fox jumps over the lazy dog."
        encoded = tokenizer.encode(test_text)
        decoded = tokenizer.decode(encoded)
        roundtrip_ok = test_text.strip() in decoded.strip()
        print(f"\n ์๋ณต ํ์คํธ: {'โ ํต๊ณผ' if roundtrip_ok else 'โ ์คํจ'}")
        print(f" ์๋ณธ: {test_text}")
        print(f" ์ธ์ฝ๋ฉ: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
        print(f" ๋์ฝ๋ฉ: {decoded}")

    @staticmethod
    def benchmark_throughput(
        dataloader: DataLoader,
        num_batches: int = 50,
        seq_len: int = 2048,
    ):
        """Measure data-loading throughput in tokens/sec.

        Core diagnostic for whether training is data-bound: the loader
        should feed tokens faster than the GPU consumes them.
        """
        print("\n" + "=" * 60)
        print("โก ๋ฐ์ดํฐ ๋ก๋ฉ ์ฒ๋ฆฌ๋ ๋ฒค์น๋งํฌ")
        print("=" * 60)

        total_tokens = 0
        start_time = time.time()

        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            total_tokens += batch["input_ids"].numel()

            if (i + 1) % 10 == 0:
                # Guard against a zero elapsed interval on very fast clocks.
                elapsed = max(time.time() - start_time, 1e-9)
                print(f" Batch {i+1}: {total_tokens / elapsed:,.0f} tokens/sec")

        elapsed = max(time.time() - start_time, 1e-9)  # /0 guard
        tps = total_tokens / elapsed

        # NOTE(review): num_batches is the requested count; the loader may
        # have yielded fewer batches if it was exhausted early.
        print(f"\n ์ด ๋ฐฐ์น ์: {num_batches}")
        print(f" ์ด ํ ํฐ ์: {total_tokens:,}")
        print(f" ์์ ์๊ฐ: {elapsed:.2f}์ด")
        print(f" ํ๊ท ์ฒ๋ฆฌ๋: {tps:,.0f} tokens/sec")
        print(f"\n ๐ก A100 ํ์ต ์ฒ๋ฆฌ๋ ~50-80K tokens/sec ๊ธฐ์ค:")
        if tps > 80_000:
            print(f" โ ๋ฐ์ดํฐ ๋ก๋ฉ์ด ๋ณ๋ชฉ์ด ์๋๋๋ค")
        elif tps > 30_000:
            print(f" โ ๏ธ ๊ฒฝ๊ณ์  - num_workers ์ฆ๊ฐ๋ฅผ ๊ณ ๋ คํ์ธ์")
        else:
            print(f" โ ๋ฐ์ดํฐ ๋ก๋ฉ์ด ๋ณ๋ชฉ! num_workers/prefetch ์กฐ์  ํ์")

    @staticmethod
    def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: "Tokenizer"):
        """Inspect one batch in detail: shapes, dtype, shift consistency,
        EOS density, and a decoded preview of the first sample."""
        print("\n" + "=" * 60)
        print("๐ ๋ฐฐ์น ์์ธ ๊ฒ์ฌ")
        print("=" * 60)

        input_ids = batch["input_ids"]
        targets = batch["targets"]

        print(f" input_ids shape: {input_ids.shape}")
        print(f" targets shape: {targets.shape}")
        print(f" dtype: {input_ids.dtype}")
        print(f" ๊ฐ ๋ฒ์: [{input_ids.min().item()}, {input_ids.max().item()}]")

        # Verify the next-token shift relation: targets[i] == input_ids[i+1].
        shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
        print(f" Shift ์ ํฉ์ฑ: {shift_correct*100:.1f}% (100%์ฌ์ผ ์ ์)")

        # EOS token density (document boundaries within packed sequences).
        eos_count = (input_ids == tokenizer.eos_id).sum().item()
        total_tokens = input_ids.numel()
        print(f" EOS ํ ํฐ ์: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")

        # Decoded preview of the first sample's first 100 tokens.
        first_sample = input_ids[0][:100].tolist()
        decoded_preview = tokenizer.decode(first_sample)
        print(f"\n ์ฒซ ์ํ ๋์ฝ๋ฉ (์ฒ์ 100 ํ ํฐ):")
        print(f" {decoded_preview[:300]}...")
|
llm_lab/data/pipeline.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ํตํฉ โ DataLoader ์์ฑ, ํ ํฌ๋์ด์ ํ์ต, Quick Start."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
|
| 8 |
+
from llm_lab.config import DataConfig
|
| 9 |
+
from .tokenizer import Tokenizer
|
| 10 |
+
from .dataset import PackedStreamingDataset, ValidationDataset, _collate_fn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_train_dataloader(
    tokenizer: Tokenizer,
    config: DataConfig,
    seed: int = 42,
) -> DataLoader:
    """Build the training DataLoader.

    Returns:
        An effectively endless streaming DataLoader over packed sequences.

    Usage:
        dataloader = create_train_dataloader(tokenizer, config)
        for step, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)  # (B, seq_len)
            targets = batch["targets"].to(device)      # (B, seq_len)
            logits, loss = model(input_ids, targets)
            ...
    """
    packed_stream = PackedStreamingDataset(
        tokenizer=tokenizer,
        config=config,
        split="train",
        seed=seed,
    )

    workers = config.num_workers
    # prefetch_factor is only meaningful when worker processes exist.
    prefetch = config.prefetch_factor if workers > 0 else None

    return DataLoader(
        packed_stream,
        batch_size=config.batch_size,
        num_workers=workers,
        prefetch_factor=prefetch,
        pin_memory=True,  # faster host-to-GPU transfers
        collate_fn=_collate_fn,
    )
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
    """Train a BPE tokenizer from the streaming dataset.

    There is no need to consume the full corpus: ~50K documents are enough
    for the vocabulary statistics to converge.
    """
    from datasets import load_dataset

    print(f"[Train Tokenizer] {config.dataset_name}์์ ํ ํฌ๋์ด์  ํ์ต")
    print(f"[Train Tokenizer] ํ์ต ๋ฌธ์ ์: {config.tokenizer_train_samples:,}")

    # Streaming text source for the trainer.
    stream = load_dataset(
        config.dataset_name,
        name=config.dataset_subset,
        split=config.dataset_split,
        streaming=True,
        trust_remote_code=True,
    )

    def corpus():
        # Yield up to tokenizer_train_samples non-empty documents, with a
        # progress line every 10K documents.
        emitted = 0
        for record in stream:
            if emitted >= config.tokenizer_train_samples:
                break
            doc = record[config.text_column]
            if not doc or not doc.strip():
                continue
            yield doc
            emitted += 1
            if emitted % 10_000 == 0:
                print(f" ... {emitted:,} ๋ฌธ์ ์ฒ๋ฆฌ")

    # Train and persist the tokenizer.
    new_tokenizer = Tokenizer(config)
    new_tokenizer.train_bpe(corpus(), save_dir=config.tokenizer_save_dir)

    return new_tokenizer
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def setup_data_pipeline(
    tokenizer_mode: str = "train_new",
    tokenizer_path: Optional[str] = None,
    config: Optional[DataConfig] = None,
) -> tuple:
    """One-shot data-pipeline setup.

    Args:
        tokenizer_mode:
            "train_new"    - train a fresh BPE tokenizer
            "load_trained" - load a previously trained tokenizer
            "pretrained"   - use a pretrained HuggingFace tokenizer
        tokenizer_path:
            "train_new"    -> save directory (default: ./tokenizer)
            "load_trained" -> directory of the saved tokenizer
            "pretrained"   -> HF model name (default: mistralai/Mistral-7B-v0.1)
        config: data configuration; a default DataConfig() is used when omitted.

    Returns:
        (tokenizer, train_dataloader, val_dataloader)

    Raises:
        ValueError: on an unknown ``tokenizer_mode``.

    Example (Colab):
        tok, train_dl, val_dl = setup_data_pipeline("train_new")
        tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
        tok, train_dl, val_dl = setup_data_pipeline("pretrained")
    """
    config = config or DataConfig()

    print("=" * 60)
    print("๐ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ์ค์ ")
    print("=" * 60)

    # -- Step 1: tokenizer --
    # Fix: the original unconditionally built Tokenizer(config) up front and
    # then discarded it in the "train_new" branch; instantiate per branch.
    if tokenizer_mode == "train_new":
        tokenizer = train_tokenizer_from_dataset(config)
    elif tokenizer_mode == "load_trained":
        tokenizer = Tokenizer(config)
        tokenizer.load_trained_hf(tokenizer_path or config.tokenizer_save_dir)
    elif tokenizer_mode == "pretrained":
        tokenizer = Tokenizer(config)
        tokenizer.load_pretrained_hf(tokenizer_path or "mistralai/Mistral-7B-v0.1")
    else:
        raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")

    # -- Step 2: training DataLoader --
    print("\n[DataLoader] ํ์ต DataLoader ์์ฑ...")
    train_dataloader = create_train_dataloader(tokenizer, config)

    # -- Step 3: validation DataLoader (fixed in-memory samples) --
    print("\n[DataLoader] ๊ฒ์ฆ DataLoader ์์ฑ...")
    val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
    val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)

    print("\n" + "=" * 60)
    print("โ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ ์ค์  ์๋ฃ!")
    print(f" ํ ํฌ๋์ด์  vocab: {tokenizer.vocab_size:,}")
    print(f" ์ํ์ค ๊ธธ์ด: {config.max_seq_len}")
    print(f" ๋ฐฐ์น ํฌ๊ธฐ: {config.batch_size}")
    print(f" ํ ํฐ/๋ฐฐ์น: {config.batch_size * config.max_seq_len:,}")
    print("=" * 60)

    return tokenizer, train_dataloader, val_dataloader
|
llm_lab/data/tokenizer.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ ํฌ๋์ด์ ๋ํผ โ SentencePiece / HuggingFace BPE ํตํฉ."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from typing import Optional, Iterator, List
|
| 6 |
+
|
| 7 |
+
from llm_lab.config import DataConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Tokenizer:
    """Unified tokenizer wrapper.

    Three supported backends:
    1) load an existing SentencePiece model
    2) train a new BPE tokenizer with the HuggingFace ``tokenizers`` library
    3) load a pretrained HF tokenizer (e.g. the LLaMA tokenizer)

    Why not implement BPE from scratch?
    - BPE training is large-scale text statistics and is only loosely
      related to understanding the model architecture, though the merge
      rule itself is worth understanding.

    BPE (Byte Pair Encoding) in a nutshell:
    1) split text into bytes/characters
    2) repeatedly merge the most frequent adjacent pair
    3) stop once vocab_size is reached
    -> frequent words become single tokens; rare words split into several.
    """

    def __init__(self, config: "DataConfig"):
        self.config = config
        self._tokenizer = None    # backend object; set by a load_*/train_* call
        self._encode_fn = None    # backend-specific encode callable
        self._decode_fn = None    # backend-specific decode callable
        self.vocab_size = config.vocab_size

        # Special-token IDs (overwritten once a backend is loaded).
        self.bos_id: int = 1  # Beginning of Sequence
        self.eos_id: int = 2  # End of Sequence
        self.pad_id: int = 0  # Padding

    # ------------------------------------------------
    # Method 1: load a SentencePiece model
    # ------------------------------------------------

    def load_sentencepiece(self, model_path: str):
        """Load an existing SentencePiece model from *model_path*."""
        import sentencepiece as spm

        self._tokenizer = spm.SentencePieceProcessor()
        self._tokenizer.Load(model_path)

        self.vocab_size = self._tokenizer.GetPieceSize()
        self.bos_id = self._tokenizer.bos_id()
        self.eos_id = self._tokenizer.eos_id()
        self.pad_id = self._tokenizer.pad_id()
        self._encode_fn = self._tokenizer.Encode
        self._decode_fn = self._tokenizer.Decode

        print(f"[Tokenizer] SentencePiece ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    # ------------------------------------------------
    # Method 2: train BPE with HuggingFace tokenizers
    # ------------------------------------------------

    def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
        """Train a BPE tokenizer from scratch.

        Args:
            text_iterator: iterator yielding training texts.
            save_dir: where to persist the tokenizer (defaults to
                ``config.tokenizer_save_dir``).

        Trade-off notes:
        - Larger vocab_size: frequent phrases become one token -> shorter
          sequences, but a bigger embedding table.
        - Smaller vocab_size: fewer embedding parameters, longer sequences.
        - 32K is a good balance for English.
        """
        from tokenizers import Tokenizer as HFTokenizer
        from tokenizers.models import BPE
        from tokenizers.trainers import BpeTrainer
        from tokenizers.pre_tokenizers import ByteLevel
        from tokenizers.processors import TemplateProcessing

        print("[Tokenizer] BPE ํ ํฌ๋์ด์  ํ์ต ์์...")

        # Build the BPE model.
        tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

        # Special tokens get ids 0..3 in this order.
        special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]

        trainer = BpeTrainer(
            vocab_size=self.config.vocab_size,
            special_tokens=special_tokens,
            min_frequency=2,  # only merge pairs seen at least twice
            show_progress=True,
        )

        tokenizer.train_from_iterator(text_iterator, trainer=trainer)

        # Post-processing: wrap single sequences in <s> ... </s> when the
        # backend is asked for special tokens.
        tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            special_tokens=[("<s>", 1), ("</s>", 2)],
        )

        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.get_vocab_size()
        self.pad_id = 0
        self.bos_id = 1
        self.eos_id = 2

        # Fix: pass add_special_tokens=False to the backend encode. The
        # TemplateProcessing post-processor otherwise injects <s>/</s> on
        # EVERY call, so Tokenizer.encode(text, add_special_tokens=False)
        # silently still added specials, and the packing pipeline then
        # appended a second EOS. Specials are now controlled solely by the
        # wrapper-level flag in self.encode().
        self._encode_fn = lambda text: tokenizer.encode(
            text, add_special_tokens=False
        ).ids
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        # Persist tokenizer + metadata.
        save_dir = save_dir or self.config.tokenizer_save_dir
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
        meta = {
            "vocab_size": self.vocab_size,
            "bos_id": self.bos_id,
            "eos_id": self.eos_id,
            "pad_id": self.pad_id,
        }
        with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
            json.dump(meta, f, indent=2)

        print(f"[Tokenizer] ํ์ต ์๋ฃ: vocab_size={self.vocab_size}")
        print(f"[Tokenizer] ์ ์ฅ ์์น: {save_dir}")

    # ------------------------------------------------
    # Method 3: load a pretrained HF tokenizer
    # ------------------------------------------------

    def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
        """Load a pretrained tokenizer from HuggingFace.

        Simplest option. The LLaMA tokenizer is 32K-vocab, BPE-based.
        Note: meta-llama models may require HF approval;
        mistralai/Mistral-7B-v0.1 is an approval-free alternative.
        """
        from transformers import AutoTokenizer

        print(f"[Tokenizer] HF ํ ํฌ๋์ด์  ๋ก๋: {name_or_path}")
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)

        self._tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size
        # NOTE(review): `or` falls back when the id is None, but would also
        # mask a legitimate id of 0 (falsy); pad_token_id=0 is unaffected
        # only by coincidence. Kept as-is for backward compatibility.
        self.bos_id = tokenizer.bos_token_id or 1
        self.eos_id = tokenizer.eos_token_id or 2
        self.pad_id = tokenizer.pad_token_id or 0

        self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        print(f"[Tokenizer] ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    def load_trained_hf(self, path: str):
        """Reload a tokenizer previously trained with train_bpe()."""
        from tokenizers import Tokenizer as HFTokenizer

        tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
        with open(os.path.join(path, "tokenizer_meta.json"), "r") as f:
            meta = json.load(f)

        self._tokenizer = tokenizer
        self.vocab_size = meta["vocab_size"]
        self.bos_id = meta["bos_id"]
        self.eos_id = meta["eos_id"]
        self.pad_id = meta["pad_id"]

        # Same fix as train_bpe: the saved tokenizer carries the
        # TemplateProcessing post-processor, so suppress backend-level
        # special tokens and let self.encode() control them.
        self._encode_fn = lambda text: tokenizer.encode(
            text, add_special_tokens=False
        ).ids
        self._decode_fn = lambda ids: tokenizer.decode(ids)

        print(f"[Tokenizer] ๋ก๋ ์๋ฃ: vocab_size={self.vocab_size}")

    # ------------------------------------------------
    # Common interface
    # ------------------------------------------------

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Text -> list of token IDs.

        Raises:
            RuntimeError: if no backend has been loaded/trained yet
                (previously this surfaced as an opaque AttributeError).
        """
        if self._encode_fn is None:
            raise RuntimeError(
                "Tokenizer backend not initialised - call load_sentencepiece(), "
                "train_bpe(), load_pretrained_hf() or load_trained_hf() first."
            )
        ids = self._encode_fn(text)
        if add_special_tokens:
            ids = [self.bos_id] + ids + [self.eos_id]
        return ids

    def decode(self, ids: List[int]) -> str:
        """List of token IDs -> text.

        Raises:
            RuntimeError: if no backend has been loaded/trained yet.
        """
        if self._decode_fn is None:
            raise RuntimeError(
                "Tokenizer backend not initialised - call load_sentencepiece(), "
                "train_bpe(), load_pretrained_hf() or load_trained_hf() first."
            )
        return self._decode_fn(ids)

    def __len__(self) -> int:
        return self.vocab_size
|
llm_lab/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ๊ฐ ๋ชจ๋ โ Perplexity, ํ
์คํธ ์์ฑ, Scaling Law, Attention ์๊ฐํ."""
|
| 2 |
+
|
| 3 |
+
from .perplexity import PerplexityEvaluator
|
| 4 |
+
from .generation import GenerationEvaluator
|
| 5 |
+
from .scaling import ScalingAnalyzer
|
| 6 |
+
from .dynamics import TrainingDynamicsAnalyzer
|
| 7 |
+
from .attention_viz import AttentionVisualizer
|
| 8 |
+
from .full_evaluator import FullEvaluator
|
| 9 |
+
from .checklist import InsightChecklist
|
| 10 |
+
from .runner import run_evaluation
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"PerplexityEvaluator",
|
| 14 |
+
"GenerationEvaluator",
|
| 15 |
+
"ScalingAnalyzer",
|
| 16 |
+
"TrainingDynamicsAnalyzer",
|
| 17 |
+
"AttentionVisualizer",
|
| 18 |
+
"FullEvaluator",
|
| 19 |
+
"InsightChecklist",
|
| 20 |
+
"run_evaluation",
|
| 21 |
+
]
|
llm_lab/evaluation/attention_viz.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Attention ํจํด ์๊ฐํ."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use("Agg")
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
HAS_MATPLOTLIB = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
HAS_MATPLOTLIB = False
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AttentionVisualizer:
    """Visualizes attention patterns.

    Learning points:
    - Causal mask: lower-triangular pattern (future tokens are not visible)
    - Head specialization: some heads attend locally (adjacent tokens),
      others globally (distant tokens)
    - Syntactic patterns: strong attention for verb->subject,
      pronoun->antecedent pairs, etc.

    Caution: storing the full attention of a 1B model exhausts memory,
    so only a selected layer/head is visualized.
    """

    def __init__(self, save_dir: str = "./eval_results"):
        # All generated figures are written below this directory.
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    @torch.no_grad()
    def extract_attention(
        self,
        model: nn.Module,
        input_ids: torch.Tensor,
        layer_idx: int = 0,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Extract the attention weights of one layer.

        Temporarily replaces the attention module's ``forward`` so the
        softmax weight matrix can be captured (a fused SDPA kernel never
        materializes it).

        Args:
            model: Decoder model exposing ``model.layers[i].attention``
                with q/k/v/o projections, ``rope`` and GQA helpers.
            input_ids: (batch, seq_len) token ids; only batch 0 is captured.
            layer_idx: Index of the transformer layer to inspect.
            device: Device the forward pass runs on.

        Returns:
            attention_weights: (num_heads, seq_len, seq_len) on CPU, or
            ``None`` if the hooked forward was never invoked.
        """
        model.eval()
        captured_attn = {}

        # Capture attention weights by hooking this layer's forward.
        target_layer = model.layers[layer_idx].attention

        # Replace scaled_dot_product_attention with a manual implementation
        # so the intermediate softmax matrix is observable.
        original_forward = target_layer.forward

        def hooked_forward(x, mask=None, position_offset=0):
            # NOTE(review): the incoming `mask` is ignored; a fresh causal
            # mask is rebuilt below — fine for full-sequence evaluation.
            B, S, _ = x.shape
            hd = target_layer.head_dim

            q = target_layer.q_proj(x).view(B, S, target_layer.num_heads, hd).transpose(1, 2)
            k = target_layer.k_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
            v = target_layer.v_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)

            q, k = target_layer.rope(q, k, position_offset)

            # GQA: expand the KV heads to match the query head count.
            if target_layer.num_kv_groups > 1:
                k = target_layer._repeat_kv(k)
                v = target_layer._repeat_kv(v)

            # Manual attention computation (for weight extraction).
            scale = 1.0 / math.sqrt(hd)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale

            # Causal mask: strictly upper-triangular positions are blocked.
            causal = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
            scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))

            attn_weights = F.softmax(scores, dim=-1)
            captured_attn["weights"] = attn_weights[0].cpu()  # first batch only

            out = torch.matmul(attn_weights, v)
            out = out.transpose(1, 2).contiguous().view(B, S, -1)
            return target_layer.o_proj(out)

        # Install the hook.
        target_layer.forward = hooked_forward

        try:
            model(input_ids.to(device))
        finally:
            # Always restore the original forward, even if the pass fails.
            target_layer.forward = original_forward

        return captured_attn.get("weights")  # (num_heads, S, S)

    def plot_attention_heatmap(
        self,
        attn_weights: torch.Tensor,
        tokens: List[str],
        head_idx: int = 0,
        save_path: Optional[str] = None,
        title: str = "Attention Weights",
    ):
        """Draw a query-vs-key heatmap for a single attention head.

        Args:
            attn_weights: (num_heads, S, S) output of ``extract_attention``.
            tokens: Per-position token strings used as axis labels.
            head_idx: Which head to plot.
            save_path: Output PNG path; defaults into ``save_dir``.
            title: Figure title prefix.
        """
        if not HAS_MATPLOTLIB:
            print("โ ๏ธ matplotlib๊ฐ ํ์ํฉ๋๋ค")
            return

        weights = attn_weights[head_idx].numpy()
        max_len = min(len(tokens), 50)  # show at most 50 tokens
        weights = weights[:max_len, :max_len]
        display_tokens = tokens[:max_len]

        fig, ax = plt.subplots(figsize=(12, 10))
        im = ax.imshow(weights, cmap="Blues", aspect="auto")

        ax.set_xticks(range(max_len))
        ax.set_yticks(range(max_len))
        ax.set_xticklabels(display_tokens, rotation=90, fontsize=7)
        ax.set_yticklabels(display_tokens, fontsize=7)

        ax.set_xlabel("Key (attended to)", fontsize=11)
        ax.set_ylabel("Query (attending from)", fontsize=11)
        ax.set_title(f"{title} โ Head {head_idx}", fontsize=13, fontweight="bold")

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()

        save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" ๐ Attention ์๊ฐํ ์ ์ฅ: {save_path}")
        plt.close(fig)

    def plot_multi_head_summary(
        self,
        attn_weights: torch.Tensor,
        num_heads_to_show: int = 8,
        save_path: Optional[str] = None,
    ):
        """Compare the attention patterns of several heads side by side.

        Args:
            attn_weights: (num_heads, S, S) output of ``extract_attention``.
            num_heads_to_show: Upper bound on the number of heads plotted.
            save_path: Output PNG path; defaults into ``save_dir``.
        """
        if not HAS_MATPLOTLIB:
            return

        n_heads = min(attn_weights.shape[0], num_heads_to_show)
        cols = 4
        rows = math.ceil(n_heads / cols)

        fig, axes = plt.subplots(rows, cols, figsize=(16, 4 * rows))
        if rows == 1:
            # Normalize to a 2-D array so axes[r, c] indexing works below.
            axes = axes.reshape(1, -1)

        for idx in range(n_heads):
            r, c = idx // cols, idx % cols
            ax = axes[r, c]
            w = attn_weights[idx].numpy()
            ax.imshow(w, cmap="Blues", aspect="auto")
            ax.set_title(f"Head {idx}", fontsize=10)
            ax.set_xticks([])
            ax.set_yticks([])

        # Hide the unused trailing subplots.
        for idx in range(n_heads, rows * cols):
            r, c = idx // cols, idx % cols
            axes[r, c].axis("off")

        fig.suptitle("Attention Patterns by Head", fontsize=14, fontweight="bold")
        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "attention_multi_head.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" ๐ ๋ฉํฐ ํค๋ ์์ฝ ์ ์ฅ: {save_path}")
        plt.close(fig)
|
llm_lab/evaluation/checklist.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ ๊ฒ์ฆ๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class InsightChecklist:
    """Validates the learning-insight checklist defined in the PRD.

    Metric-based items are judged automatically from the evaluation
    report; the remaining items are presented as manual review questions.
    """

    @staticmethod
    def run_checklist(
        report: Dict[str, Any],
        metrics_history: Optional[Dict[str, list]] = None,
    ):
        """Run the checklist and print a pass/fail/manual summary.

        Args:
            report: Aggregated evaluation report (FullEvaluator shape).
            metrics_history: Optional raw training metrics ("grad_norm").

        Returns:
            Dict with "passed", "failed" and "manual" lists of item strings.
        """
        divider = "=" * 70
        print("\n" + divider)
        print("โ ํ์ต ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ")
        print(divider)

        passed = []
        failed = []

        def record(ok, pass_msg, fail_msg):
            # Route the message into the matching bucket.
            if ok:
                passed.append(pass_msg)
            else:
                failed.append(fail_msg)

        # -- automatic checks --

        # 1. Loss convergence
        record(
            report.get("perplexity", {}).get("loss", 99) < 4.0,
            "๋ชจ๋ธ Loss๊ฐ 4.0 ์ดํ๋ก ์๋ ด",
            "๋ชจ๋ธ Loss๊ฐ 4.0 ์ดํ๋ก ๋ฏธ์๋ ด",
        )

        # 2. Loss spikes
        spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
        record(
            len(spikes) < 5,
            f"Loss ์คํ์ดํฌ {len(spikes)}ํ (< 5ํ)",
            f"Loss ์คํ์ดํฌ {len(spikes)}ํ (โฅ 5ํ, ์์ ์ฑ ๊ฐ์ ํ์)",
        )

        # 3. Per-position loss pattern (only when measured)
        position_losses = report.get("position_losses")
        if position_losses:
            record(
                position_losses["early_avg"] > position_losses["late_avg"],
                "์์น๋ณ Loss ๊ฐ์ ํจํด ํ์ธ (์ปจํ์คํธ ํ์ฉ)",
                "์์น๋ณ Loss ํจํด ์ด์ (์ปจํ์คํธ ๋ฏธํ์ฉ?)",
            )

        # 4. Generation repetition rate
        rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
        record(
            rep < 0.3,
            f"์์ฑ ๋ฐ๋ณต๋ฅ {rep:.1%} (< 30%)",
            f"์์ฑ ๋ฐ๋ณต๋ฅ {rep:.1%} (โฅ 30%, temperature/top_p ์กฐ์ )",
        )

        # 5. Gradient clipping ratio (only when history is supplied)
        if metrics_history and metrics_history.get("grad_norm"):
            gnorms = metrics_history["grad_norm"]
            clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
            record(
                clip_rate < 0.3,
                f"Gradient ํด๋ฆฌํ ๋น์จ {clip_rate:.1%} (๊ฑด๊ฐ)",
                f"Gradient ํด๋ฆฌํ ๋น์จ {clip_rate:.1%} (๋๋ฌด ์ฆ์)",
            )

        # -- manual review items --
        manual_items = [
            "Self-Attention์์ Q, K, V ๊ฐ๊ฐ์ ์ญํ ์ ์ค๋ชํ ์ ์๋๊ฐ?",
            "RoPE๊ฐ ์์น ์ ๋ณด๋ฅผ ์ธ์ฝ๋ฉํ๋ ์ํ์ ์๋ฆฌ๋ฅผ ์ดํดํ๋๊ฐ?",
            "GQA๊ฐ MHA ๋๋น ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ๋ ๋ฉ์ปค๋์ฆ์ ์ค๋ชํ ์ ์๋๊ฐ?",
            "SwiGLU์ ๊ฒ์ดํ ๋ฉ์ปค๋์ฆ์ด ReLU FFN๊ณผ ์ด๋ป๊ฒ ๋ค๋ฅธ์ง ์ดํดํ๋๊ฐ?",
            "Learning Rate Warmup์ด ์ ํ์ํ์ง ์ฒด๊ฐํ๋๊ฐ?",
            "Gradient Accumulation์ด ํฐ ๋ฐฐ์น๋ฅผ ์๋ฎฌ๋ ์ด์ํ๋ ์๋ฆฌ๋ฅผ ์ดํดํ๋๊ฐ?",
            "Mixed Precision(bf16)์ ๋ฉ๋ชจ๋ฆฌ-์๋ ํจ๊ณผ๋ฅผ ์ธก์ ํ๋๊ฐ?",
            "Activation Checkpointing์ ๋ฉ๋ชจ๋ฆฌ-์ฐ์ฐ ํธ๋ ์ด๋์คํ๋ฅผ ์ดํดํ๋๊ฐ?",
        ]

        checks = {"passed": passed, "failed": failed, "manual": manual_items}

        # -- summary output --
        total_auto = len(passed) + len(failed)
        passed_auto = len(passed)

        print(f"\n ์๋ ๊ฒ์ฆ: {passed_auto}/{total_auto} ํต๊ณผ")
        for item in passed:
            print(f"   โ {item}")
        for item in failed:
            print(f"   โ {item}")

        print(f"\n ์๋ ํ์ธ ({len(manual_items)} ํญ๋ชฉ):")
        for i, item in enumerate(manual_items, 1):
            print(f"   {i}. [ ] {item}")

        print(f"\n ์ด ์งํ๋ฅ : {passed_auto}/{total_auto + len(manual_items)} "
              f"(์๋ ํญ๋ชฉ ํฌํจ ์)")

        return checks
|
llm_lab/evaluation/dynamics.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ์ญํ ๋ถ์๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
import matplotlib
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
HAS_MATPLOTLIB = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
HAS_MATPLOTLIB = False
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TrainingDynamicsAnalyzer:
    """Analyzes and visualizes training-run metrics.

    Analysis targets:
    - Loss curve: convergence pattern, spike detection
    - LR schedule: warmup + cosine decay sanity check
    - Gradient norm: training stability, explosion/vanishing detection
    - Throughput: tokens/sec stability, bottleneck detection
    """

    def __init__(self, save_dir: str = "./eval_results"):
        # All plots produced by this analyzer land under save_dir.
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
        """Analyze the recorded training metrics.

        Args:
            metrics_history: ``Trainer.metrics.history`` dictionary. Keys
                used (all optional): "train_loss", "step", "grad_norm",
                "tokens_per_sec".

        Returns:
            Nested dict with "loss", "grad_norm" and "throughput" sections,
            each present only when its source metric exists.
        """
        print("\n" + "=" * 70)
        print("๐ฌ ํ์ต ์ญํ ๋ถ์")
        print("=" * 70)

        analysis = {}

        # โโ Loss analysis โโ
        if metrics_history.get("train_loss"):
            losses = metrics_history["train_loss"]
            analysis["loss"] = {
                "initial": round(losses[0], 4),
                "final": round(losses[-1], 4),
                "minimum": round(min(losses), 4),
                "total_reduction": round(losses[0] - losses[-1], 4),
            }

            # Spike detection (>= 50% jump vs the previous value).
            steps_hist = metrics_history.get("step", [])
            spikes = []
            for i in range(1, len(losses)):
                if losses[i] > losses[i - 1] * 1.5:
                    # Robustness fix: fall back to the index when the "step"
                    # list is missing OR shorter than the loss list (the old
                    # code raised IndexError in the latter case).
                    step = steps_hist[i] if i < len(steps_hist) else i
                    spikes.append({"step": step, "loss": round(losses[i], 4)})

            analysis["loss"]["spikes"] = spikes

            print(f"\n ๐ Loss ๋ถ์:")
            print(f"   ์ด๊ธฐ: {analysis['loss']['initial']:.4f}")
            print(f"   ์ต์ข: {analysis['loss']['final']:.4f}")
            print(f"   ์ต์: {analysis['loss']['minimum']:.4f}")
            print(f"   ๊ฐ์: {analysis['loss']['total_reduction']:.4f}")
            print(f"   ์คํ์ดํฌ: {len(spikes)}ํ")
            if spikes:
                for s in spikes[:5]:
                    print(f"     Step {s['step']}: Loss = {s['loss']}")

        # โโ Gradient-norm analysis โโ
        if metrics_history.get("grad_norm"):
            gnorms = metrics_history["grad_norm"]
            analysis["grad_norm"] = {
                "mean": round(sum(gnorms) / len(gnorms), 4),
                "max": round(max(gnorms), 4),
                "min": round(min(gnorms), 4),
                # Share of steps at/above the clip threshold (1.0, with slack).
                "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
            }

            print(f"\n ๐ Gradient Norm ๋ถ์:")
            print(f"   ํ๊ท : {analysis['grad_norm']['mean']:.4f}")
            print(f"   ์ต๋: {analysis['grad_norm']['max']:.4f}")
            print(f"   ํด๋ฆฌํ ๋น์จ: {analysis['grad_norm']['clipped_pct']:.1f}%")
            if analysis["grad_norm"]["clipped_pct"] > 30:
                print(f"   โ ๏ธ ํด๋ฆฌํ์ด ์ฆ์ โ LR ํํฅ ๋๋ warmup ์ฐ์ฅ ๊ณ ๋ ค")

        # โโ Throughput analysis โโ
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            tps_valid = [t for t in tps if t > 0]
            if tps_valid:
                # Perf fix: hoist the mean — the old code recomputed
                # sum()/len() inside the variance generator for every
                # element (quadratic in len(tps_valid)).
                mean_tps = sum(tps_valid) / len(tps_valid)
                variance = sum((t - mean_tps) ** 2 for t in tps_valid) / len(tps_valid)
                analysis["throughput"] = {
                    "mean": round(mean_tps),
                    "std": round(variance ** 0.5),
                    "min": round(min(tps_valid)),
                    "max": round(max(tps_valid)),
                }

                print(f"\n โก ์ฒ๋ฆฌ๋ ๋ถ์:")
                print(f"   ํ๊ท : {analysis['throughput']['mean']:,} tokens/sec")
                print(f"   ํ์คํธ์ฐจ: {analysis['throughput']['std']:,}")
                print(f"   ๋ฒ์: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")

        return analysis

    def plot_training_curves(
        self,
        metrics_history: Dict[str, list],
        save_path: Optional[str] = None,
    ):
        """Render the 4-panel training chart (loss / LR / grad-norm / TPS).

        Args:
            metrics_history: Same shape as in :meth:`analyze_metrics`,
                optionally with "val_loss" and "learning_rate".
            save_path: Output PNG path; defaults into ``save_dir``.
        """
        if not HAS_MATPLOTLIB:
            print("โ ๏ธ matplotlib๊ฐ ํ์ํฉ๋๋ค: pip install matplotlib")
            return

        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")

        steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))

        # โโ (1) Loss โโ
        ax = axes[0, 0]
        if metrics_history.get("train_loss"):
            ax.plot(steps[:len(metrics_history["train_loss"])],
                    metrics_history["train_loss"],
                    color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")

            # Moving-average smoothing for readability.
            if len(metrics_history["train_loss"]) > 20:
                window = min(50, len(metrics_history["train_loss"]) // 5)
                smoothed = self._moving_average(metrics_history["train_loss"], window)
                ax.plot(steps[window-1:len(smoothed)+window-1],
                        smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")

        if metrics_history.get("val_loss"):
            # Validation is logged less often; spread its points over steps.
            val_steps = [steps[i] for i in range(0, len(steps),
                         max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
            ax.plot(val_steps, metrics_history["val_loss"],
                    "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")

        ax.set_xlabel("Step")
        ax.set_ylabel("Loss")
        ax.set_title("Training & Validation Loss")
        ax.legend()
        ax.grid(True, alpha=0.3)

        # โโ (2) Learning Rate โโ
        ax = axes[0, 1]
        if metrics_history.get("learning_rate"):
            ax.plot(steps[:len(metrics_history["learning_rate"])],
                    metrics_history["learning_rate"],
                    color="#059669", linewidth=2)
        ax.set_xlabel("Step")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.ticklabel_format(style="scientific", axis="y", scilimits=(0,0))
        ax.grid(True, alpha=0.3)

        # โโ (3) Gradient Norm โโ
        ax = axes[1, 0]
        if metrics_history.get("grad_norm"):
            ax.plot(steps[:len(metrics_history["grad_norm"])],
                    metrics_history["grad_norm"],
                    color="#d97706", alpha=0.6, linewidth=0.8)
            ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Clip threshold")
            ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Gradient Norm")
        ax.set_title("Gradient Norm (clipped at 1.0)")
        ax.grid(True, alpha=0.3)

        # โโ (4) Throughput โโ
        ax = axes[1, 1]
        if metrics_history.get("tokens_per_sec"):
            tps = metrics_history["tokens_per_sec"]
            ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
            if tps:
                avg_tps = sum(tps) / len(tps)
                ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
                           label=f"Avg: {avg_tps:,.0f}")
                ax.legend()
        ax.set_xlabel("Step")
        ax.set_ylabel("Tokens/sec")
        ax.set_title("Training Throughput")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "training_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n ๐ ํ์ต ๊ณก์ ์ ์ฅ: {save_path}")
        plt.close(fig)

    def plot_position_loss(
        self,
        position_losses: List[float],
        save_path: Optional[str] = None,
    ):
        """Plot the loss distribution across sequence positions.

        Args:
            position_losses: Mean cross-entropy per sequence position.
            save_path: Output PNG path; defaults into ``save_dir``.
        """
        if not HAS_MATPLOTLIB:
            return

        fig, ax = plt.subplots(figsize=(12, 5))

        positions = list(range(len(position_losses)))
        ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
        ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")

        ax.set_xlabel("Position in Sequence", fontsize=12)
        ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
        ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # Reference lines for the early/late regions.
        if len(position_losses) > 100:
            early_avg = sum(position_losses[:50]) / 50
            tail = position_losses[-200:]
            # BUG FIX: divide by the actual tail length. The old code divided
            # by a hard-coded 200, which understated the average whenever
            # 100 < len(position_losses) < 200.
            late_avg = sum(tail) / len(tail)
            ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
                       label=f"Early avg (0-50): {early_avg:.2f}")
            ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
                       label=f"Late avg (-200): {late_avg:.2f}")
            ax.legend()

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "position_loss.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f" ๐ ์์น๋ณ Loss ์ ์ฅ: {save_path}")
        plt.close(fig)

    @staticmethod
    def _moving_average(data: list, window: int) -> list:
        """Trailing moving average; returns len(data) - window + 1 points."""
        result = []
        for i in range(window - 1, len(data)):
            avg = sum(data[i - window + 1 : i + 1]) / window
            result.append(avg)
        return result
|
llm_lab/evaluation/full_evaluator.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""์ข
ํฉ ํ๊ฐ ์คํ๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from llm_lab.config import EvalConfig
|
| 13 |
+
from .perplexity import PerplexityEvaluator
|
| 14 |
+
from .generation import GenerationEvaluator
|
| 15 |
+
from .dynamics import TrainingDynamicsAnalyzer
|
| 16 |
+
from .attention_viz import AttentionVisualizer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FullEvaluator:
|
| 20 |
+
"""๋ชจ๋ ํ๊ฐ๋ฅผ ํ ๋ฒ์ ์คํํ๊ณ ๋ฆฌํฌํธ๋ฅผ ์์ฑํฉ๋๋ค.
|
| 21 |
+
|
| 22 |
+
์ฌ์ฉ๋ฒ:
|
| 23 |
+
```python
|
| 24 |
+
evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
|
| 25 |
+
report = evaluator.run_full_evaluation()
|
| 26 |
+
```
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
model: nn.Module,
|
| 32 |
+
tokenizer: Any,
|
| 33 |
+
val_dataloader: DataLoader,
|
| 34 |
+
device: torch.device,
|
| 35 |
+
config: Optional[EvalConfig] = None,
|
| 36 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 37 |
+
metrics_history: Optional[Dict[str, list]] = None,
|
| 38 |
+
):
|
| 39 |
+
self.model = model
|
| 40 |
+
self.tokenizer = tokenizer
|
| 41 |
+
self.val_dataloader = val_dataloader
|
| 42 |
+
self.device = device
|
| 43 |
+
self.config = config or EvalConfig()
|
| 44 |
+
self.dtype = dtype
|
| 45 |
+
self.metrics_history = metrics_history
|
| 46 |
+
|
| 47 |
+
self.save_dir = Path(self.config.save_dir)
|
| 48 |
+
self.save_dir.mkdir(parents=True, exist_ok=True)
|
| 49 |
+
|
| 50 |
+
    def run_full_evaluation(self) -> Dict[str, Any]:
        """Run all four evaluation phases and persist a JSON report.

        Phases: (1) perplexity + per-position loss, (2) text generation,
        (3) training-dynamics analysis (skipped when no metrics_history),
        (4) attention visualization (best-effort, failures are swallowed).

        Returns:
            The aggregated report dict; also written to ``eval_report.json``
            under ``save_dir``.
        """
        report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}

        print("\n" + "=" * 70)
        print("๐ ์ขํฉ ํ๊ฐ ์์")
        print("=" * 70)

        # โโ 1. Perplexity โโ
        print("\n" + "โ" * 40)
        print("Phase 1/4: Perplexity ์ธก์ ")
        print("โ" * 40)
        ppl_evaluator = PerplexityEvaluator(self.config)
        report["perplexity"] = ppl_evaluator.evaluate(
            self.model, self.val_dataloader, self.device, self.dtype
        )

        # Per-position loss (shows how extra context helps later tokens).
        print("\n ์์น๋ณ Loss ์ธก์ ์ค...")
        position_losses = ppl_evaluator.evaluate_per_position(
            self.model, self.val_dataloader, self.device, self.dtype
        )
        # max(..., 1) guards against an empty position-loss list.
        report["position_losses"] = {
            "early_avg": round(sum(position_losses[:50]) / max(len(position_losses[:50]), 1), 4),
            "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
        }

        # Plot the per-position loss curve.
        dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
        dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))

        # โโ 2. Text generation โโ
        print("\n" + "โ" * 40)
        print("Phase 2/4: ํ์คํธ ์์ฑ")
        print("โ" * 40)
        gen_evaluator = GenerationEvaluator(self.config)
        gen_results = gen_evaluator.generate_samples(
            self.model, self.tokenizer, self.device
        )
        report["generation"] = {
            "num_prompts": len(gen_results),
            "avg_metrics": self._average_gen_metrics(gen_results),
        }

        # โโ 3. Training-dynamics analysis โโ
        if self.metrics_history:
            print("\n" + "โ" * 40)
            print("Phase 3/4: ํ์ต ์ญํ ๋ถ์")
            print("โ" * 40)
            report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
            dynamics.plot_training_curves(self.metrics_history,
                                          str(self.save_dir / "training_curves.png"))
        else:
            print("\n Phase 3/4: ๊ฑด๋๋ (metrics_history ์์)")

        # โโ 4. Attention visualization (sample) โโ
        print("\n" + "โ" * 40)
        print("Phase 4/4: Attention ์๊ฐํ")
        print("โ" * 40)
        try:
            self._visualize_attention_sample()
        except Exception as e:
            # Best-effort: a visualization failure must not abort the report.
            print(f" โ ๏ธ Attention ์๊ฐํ ์คํจ: {e}")

        # โโ Persist the report โโ
        report_path = self.save_dir / "eval_report.json"
        with open(report_path, "w") as f:
            # default=str keeps non-JSON-native values serializable.
            json.dump(report, f, indent=2, default=str)
        print(f"\n๐ ๋ฆฌํฌํธ ์ ์ฅ: {report_path}")

        # โโ Print the human-readable summary โโ
        self._print_summary(report)

        return report
|
| 124 |
+
|
| 125 |
+
def _visualize_attention_sample(self):
|
| 126 |
+
"""์ํ ํ
์คํธ๋ก attention์ ์๊ฐํํฉ๋๋ค."""
|
| 127 |
+
viz = AttentionVisualizer(str(self.save_dir))
|
| 128 |
+
|
| 129 |
+
sample_text = "The cat sat on the mat and looked at the bird."
|
| 130 |
+
token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
|
| 131 |
+
input_tensor = torch.tensor([token_ids], dtype=torch.long)
|
| 132 |
+
|
| 133 |
+
# ํ ํฐ ๋ฌธ์์ด (์๊ฐํ ๋ผ๋ฒจ์ฉ)
|
| 134 |
+
tokens_str = []
|
| 135 |
+
for tid in token_ids:
|
| 136 |
+
decoded = self.tokenizer.decode([tid])
|
| 137 |
+
tokens_str.append(decoded.replace("\n", "\\n"))
|
| 138 |
+
|
| 139 |
+
# Layer 0 attention ์ถ์ถ
|
| 140 |
+
attn_weights = viz.extract_attention(
|
| 141 |
+
self.model, input_tensor, layer_idx=0, device=self.device
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if attn_weights is not None:
|
| 145 |
+
viz.plot_attention_heatmap(
|
| 146 |
+
attn_weights, tokens_str, head_idx=0,
|
| 147 |
+
title="Layer 0 Attention"
|
| 148 |
+
)
|
| 149 |
+
viz.plot_multi_head_summary(attn_weights)
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
|
| 153 |
+
"""๋ชจ๋ ํ๋กฌํํธ์ ์์ฑ ๋ฉํธ๋ฆญ ํ๊ท ."""
|
| 154 |
+
if not gen_results:
|
| 155 |
+
return {}
|
| 156 |
+
|
| 157 |
+
all_metrics = [r["metrics"] for r in gen_results if r.get("metrics")]
|
| 158 |
+
if not all_metrics:
|
| 159 |
+
return {}
|
| 160 |
+
|
| 161 |
+
keys = all_metrics[0].keys()
|
| 162 |
+
return {
|
| 163 |
+
k: round(sum(m.get(k, 0) for m in all_metrics) / len(all_metrics), 3)
|
| 164 |
+
for k in keys
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
def _print_summary(self, report: Dict[str, Any]):
    """Print a human-readable summary of the evaluation report to stdout.

    Args:
        report: Aggregated evaluation report. Each section is optional and
            is printed only when its key is present: "perplexity",
            "position_losses", "generation", "training_dynamics".
            Finally lists every file currently in ``self.save_dir``.
    """
    print("\n" + "=" * 70)
    print("๐ ํ๊ฐ ์์ฝ ๋ฆฌํฌํธ")
    print("=" * 70)

    # Perplexity section
    if "perplexity" in report:
        ppl = report["perplexity"]
        print(f"\n ๐ฏ Perplexity:")
        print(f" Loss: {ppl['loss']:.4f}")
        print(f" PPL: {ppl['perplexity']:.2f}")

        # Qualitative grade banding by PPL value (thresholds tuned for a
        # ~1B model on English web text).
        ppl_val = ppl["perplexity"]
        if ppl_val < 20:
            grade = "๐ ์ฐ์ (Strong)"
        elif ppl_val < 35:
            grade = "โ ์ํธ (Good)"
        elif ppl_val < 60:
            grade = "โ ๏ธ ๋ณดํต (Fair)"
        else:
            grade = "โ ๋ฏธํก (ํ์ต ์ถ๊ฐ ํ์)"
        print(f" ๋ฑ๊ธ: {grade}")

    # Per-position loss: early-vs-late gap shows the context effect.
    if "position_losses" in report:
        pl = report["position_losses"]
        print(f"\n ๐ ์์น๋ณ Loss:")
        print(f" ์ด๋ฐ (0-50): {pl['early_avg']:.4f}")
        print(f" ํ๋ฐ (-200): {pl['late_avg']:.4f}")
        print(f" ์ปจํ์คํธ ํจ๊ณผ: {pl['early_avg'] - pl['late_avg']:.4f} ๊ฐ์")

    # Generation quality (averaged across prompts by _average_gen_metrics)
    if "generation" in report and report["generation"].get("avg_metrics"):
        gm = report["generation"]["avg_metrics"]
        print(f"\n โ๏ธ ์์ฑ ํ์ง:")
        print(f" ํ๊ท ๊ธธ์ด: {gm.get('avg_length', 0):.0f} ์")
        print(f" ๋ฐ๋ณต๋ฅ : {gm.get('repetition_rate', 0):.1%}")
        print(f" ์ดํ ๋ค์์ฑ: {gm.get('lexical_diversity', 0):.3f}")

    # Training dynamics (only when a loss history was supplied)
    if "training_dynamics" in report:
        td = report["training_dynamics"]
        if "loss" in td:
            print(f"\n ๐ ํ์ต ์ญํ:")
            print(f" Loss ๊ฐ์: {td['loss']['initial']:.4f} โ {td['loss']['final']:.4f}")
            print(f" ์คํ์ดํฌ: {len(td['loss']['spikes'])}ํ")

    # List every artifact written to the results directory with its size.
    print(f"\n ๐ ๊ฒฐ๊ณผ ํ์ผ:")
    for f in sorted(self.save_dir.glob("*")):
        size = f.stat().st_size / 1024
        print(f" {f.name} ({size:.1f} KB)")

    print("\n" + "=" * 70)
|
llm_lab/evaluation/generation.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ
์คํธ ์์ฑ ํ๊ฐ๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from llm_lab.config import EvalConfig
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GenerationEvaluator:
    """Generate text from a varied prompt set and assess its quality.

    Evaluation angles:
      1) Grammaticality: does it produce well-formed English sentences?
      2) Coherence: does it stay on topic as the text continues?
      3) Diversity: do repeated samples for the same prompt differ?
      4) Repetition avoidance: does it avoid echoing the same phrases?
      5) Knowledge: does training-data knowledge surface in the outputs?

    Realistic expectations for a 1B model:
      - grammatically correct English sentences: yes
      - coherence within a short paragraph: yes
      - complex reasoning or long logical chains: no (needs a larger model)
      - factual accuracy: not guaranteed
    """

    # Test prompts spanning several domains so quality is probed broadly.
    DEFAULT_PROMPTS = [
        # -- general knowledge --
        "The theory of relativity states that",
        "In the history of computer science,",
        "The human brain is remarkable because",

        # -- explanation / education --
        "To understand machine learning, one must first",
        "The water cycle begins when",
        "Photosynthesis is the process by which",

        # -- narrative / story --
        "Once upon a time, in a small village near the mountains,",
        "The detective looked at the evidence and realized that",

        # -- code / technical --
        "def fibonacci(n):\n \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
        "The most important data structures in programming are",

        # -- short completion --
        "The capital of France is",
        "Water boils at a temperature of",

        # -- long context --
        ("Artificial intelligence has transformed many industries. "
         "In healthcare, AI is used for diagnosis and drug discovery. "
         "In finance, it powers algorithmic trading and fraud detection. "
         "Looking ahead, the most promising application of AI is"),
    ]

    def __init__(self, config: EvalConfig):
        # Sampling configuration: num_samples, max_new_tokens, temperature,
        # top_k, top_p are all read from this object during generation.
        self.config = config

    @torch.no_grad()
    def generate_samples(
        self,
        model: nn.Module,
        tokenizer: Any,
        device: torch.device,
        prompts: Optional[List[str]] = None,
        verbose: bool = True,
    ) -> List[Dict[str, Any]]:
        """Generate text for each prompt and score the samples.

        Args:
            model: must expose ``generate(input_ids, max_new_tokens=...,
                temperature=..., top_k=..., top_p=...)`` returning token ids
                (prompt + continuation).
            tokenizer: must expose ``encode``/``decode``.
            device: device the input tensor is placed on.
            prompts: overrides DEFAULT_PROMPTS when given.
            verbose: print each generation and its metrics.

        Returns:
            [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
        """
        model.eval()
        prompts = prompts or self.DEFAULT_PROMPTS
        results = []

        if verbose:
            print("\n" + "=" * 70)
            print("๐ ํ์คํธ ์์ฑ ํ๊ฐ")
            print("=" * 70)

        for idx, prompt in enumerate(prompts):
            prompt_results = {
                "prompt": prompt,
                "generations": [],
                "metrics": {},
            }

            if verbose:
                print(f"\n{'โ'*60}")
                print(f"ํ๋กฌํํธ [{idx+1}/{len(prompts)}]:")
                print(f" \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
                print(f"{'โ'*60}")

            # Encode the prompt once; reuse the tensor for every sample.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

            all_texts = []
            for sample_idx in range(self.config.num_samples):
                # Sample one continuation.
                generated_ids = model.generate(
                    input_tensor,
                    max_new_tokens=self.config.max_new_tokens,
                    temperature=self.config.temperature,
                    top_k=self.config.top_k,
                    top_p=self.config.top_p,
                )

                # Decode only the newly generated tail (skip prompt tokens).
                new_ids = generated_ids[0][len(prompt_ids):].tolist()
                generated_text = tokenizer.decode(new_ids)
                all_texts.append(generated_text)

                prompt_results["generations"].append(generated_text)

                if verbose:
                    print(f"\n โ๏ธ ์์ฑ #{sample_idx+1}:")
                    # Pretty-print, preserving line breaks; truncate long outputs.
                    display_text = generated_text[:500]
                    for line in display_text.split("\n"):
                        print(f" {line}")
                    if len(generated_text) > 500:
                        print(f" ... (์ด {len(generated_text)} ๋ฌธ์)")

            # Quality metrics over this prompt's samples.
            prompt_results["metrics"] = self._compute_generation_metrics(all_texts)

            if verbose and prompt_results["metrics"]:
                m = prompt_results["metrics"]
                print(f"\n ๐ ๋ฉํธ๋ฆญ: "
                      f"ํ๊ท ๊ธธ์ด={m['avg_length']:.0f}์, "
                      f"๋ฐ๋ณต๋ฅ ={m['repetition_rate']:.1%}, "
                      f"์ดํ ๋ค์์ฑ={m['lexical_diversity']:.2f}")

            results.append(prompt_results)

        return results

    @staticmethod
    def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
        """Compute quality metrics over a list of generated texts.

        Metrics:
          - avg_length: mean generation length in characters
          - avg_word_count: mean whitespace-word count
          - repetition_rate: 4-gram repetition rate (lower is better)
          - lexical_diversity: unique-word ratio, i.e. type-token ratio
            (higher means more varied vocabulary)
          - sample_diversity: cross-sample diversity, 1 - mean pairwise
            Jaccard similarity of word sets (how different samples are)
        """
        if not texts:
            return {}

        # Lengths
        lengths = [len(t) for t in texts]
        word_counts = [len(t.split()) for t in texts]

        # Repetition rate (4-gram based)
        rep_rates = []
        for text in texts:
            words = text.lower().split()
            if len(words) < 4:
                # Too short to form a 4-gram: treat as non-repetitive.
                rep_rates.append(0.0)
                continue
            ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
            unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
            rep_rates.append(1.0 - unique_ratio)  # repetition = 1 - unique ratio

        # Lexical diversity (type-token ratio)
        diversities = []
        for text in texts:
            words = text.lower().split()
            if words:
                diversities.append(len(set(words)) / len(words))
            else:
                diversities.append(0.0)

        # Cross-sample diversity: 1 minus the mean pairwise Jaccard
        # similarity of the samples' word sets.
        sample_div = 0.0
        if len(texts) > 1:
            word_sets = [set(t.lower().split()) for t in texts]
            similarities = []
            for i in range(len(word_sets)):
                for j in range(i+1, len(word_sets)):
                    inter = len(word_sets[i] & word_sets[j])
                    union = len(word_sets[i] | word_sets[j])
                    if union > 0:
                        similarities.append(inter / union)
            sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1))

        return {
            "avg_length": sum(lengths) / len(lengths),
            "avg_word_count": sum(word_counts) / len(word_counts),
            "repetition_rate": sum(rep_rates) / len(rep_rates),
            "lexical_diversity": sum(diversities) / len(diversities),
            "sample_diversity": round(sample_div, 3),
        }
|
llm_lab/evaluation/perplexity.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Perplexity(PPL) ํ๊ฐ๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import time
|
| 5 |
+
from typing import Dict, List
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from llm_lab.config import EvalConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PerplexityEvaluator:
    """Measure perplexity (PPL) over a validation dataloader.

    Perplexity:
        PPL = exp(average cross-entropy loss)

    Intuition:
      - PPL = 1: perfect prediction (impossible)
      - PPL = 10: like choosing among 10 candidates each step
      - PPL = 100: choosing among 100 candidates (close to random)
      - PPL = 32000: uniform choice over the whole vocab (untrained model)

    Rough targets for a good 1B model on English web text:
      - 5B training tokens:  PPL ~30-40
      - 10B training tokens: PPL ~20-30
      - 20B training tokens: PPL ~15-25

    Method:
      - cross-entropy over every token of the validation set
      - token-level average, then exp()
      - padding tokens excluded via ignore_index=-100
    """

    def __init__(self, config: EvalConfig):
        # config.max_eval_batches caps how much of the dataloader is consumed.
        self.config = config

    @torch.no_grad()
    def evaluate(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        desc: str = "Evaluation",
    ) -> Dict[str, float]:
        """Measure perplexity on up to ``config.max_eval_batches`` batches.

        Args:
            model: callable as ``model(input_ids) -> (logits, _)``.
            dataloader: yields dicts with "input_ids" and "targets" (B, S);
                positions with target -100 are ignored.
            device: device to run evaluation on (CPU or CUDA).
            dtype: autocast dtype; float32 disables autocast entirely.
            desc: label printed before the progress output.

        Returns:
            {
                "loss": mean cross-entropy loss,
                "perplexity": exp(loss),
                "num_tokens": total valid tokens evaluated,
                "num_batches": batches evaluated,
                "eval_time_sec": wall-clock time,
            }
        """
        model.eval()

        total_loss = 0.0
        total_tokens = 0
        num_batches = 0

        print(f"\n๐ {desc}")
        start_time = time.time()

        for i, batch in enumerate(dataloader):
            if i >= self.config.max_eval_batches:
                break

            input_ids = batch["input_ids"].to(device)
            targets = batch["targets"].to(device)

            # FIX: device_type was hard-coded to "cuda", which broke
            # evaluation on CPU (and any non-CUDA device). Use the actual
            # device's type so autocast matches where the model runs.
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                logits, _ = model(input_ids)

            # Per-token cross-entropy (reduction='none'):
            # logits: (B, S, V) -> (B*S, V); targets: (B, S) -> (B*S,)
            loss_per_token = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
                reduction="none",
            )

            # Count only valid (non -100) tokens.
            valid_mask = (targets.view(-1) != -100)
            valid_tokens = valid_mask.sum().item()

            total_loss += loss_per_token[valid_mask].sum().item()
            total_tokens += valid_tokens
            num_batches += 1

            if (i + 1) % 20 == 0:
                # Clamp the exponent so a diverging model cannot overflow exp().
                running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
                print(f" Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")

        elapsed = time.time() - start_time
        avg_loss = total_loss / max(total_tokens, 1)
        perplexity = math.exp(min(avg_loss, 100))  # overflow guard

        results = {
            "loss": round(avg_loss, 4),
            "perplexity": round(perplexity, 2),
            "num_tokens": total_tokens,
            "num_batches": num_batches,
            "eval_time_sec": round(elapsed, 1),
        }

        print(f" โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ")
        print(f" Loss: {results['loss']:.4f}")
        print(f" Perplexity: {results['perplexity']:.2f}")
        print(f" ํ๊ฐ ํ ํฐ: {total_tokens:,}")
        print(f" ์์ ์๊ฐ: {elapsed:.1f}์ด")

        return results

    @torch.no_grad()
    def evaluate_per_position(
        self,
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        max_batches: int = 50,
    ) -> List[float]:
        """Measure average loss at each position within the sequence.

        Learning point:
          - positions 0~10: loss is high (little context available)
          - position 100+: loss stabilizes lower (context is exploited)
          - this pattern illustrates the Transformer's in-context ability

        Returns:
            List of per-position mean losses, length = first batch's seq len.
        """
        model.eval()
        seq_len = None
        position_loss_sum = None
        position_count = None

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break

            input_ids = batch["input_ids"].to(device)
            targets = batch["targets"].to(device)
            B, S = targets.shape

            if seq_len is None:
                # Accumulators are sized from the first batch's sequence length.
                seq_len = S
                position_loss_sum = torch.zeros(S, device=device)
                position_count = torch.zeros(S, device=device)
            elif S != seq_len:
                # FIX: a batch with a different sequence length would have
                # crashed the += broadcasts below; skip it instead.
                continue

            # FIX: device_type was hard-coded to "cuda" (see evaluate()).
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                logits, _ = model(input_ids)

            # Per-token loss reshaped back to (B, S).
            loss_per_token = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
                reduction="none",
            ).view(B, S)

            valid_mask = (targets != -100).float()
            position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
            position_count += valid_mask.sum(dim=0)

        # Mean loss per position; clamp avoids division by zero for
        # positions that were always masked.
        position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
        return position_avg_loss
|
llm_lab/evaluation/runner.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ๊ฐ ์คํ ํฌํผ (Quick Start)."""
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from llm_lab.config import EvalConfig
|
| 10 |
+
from .full_evaluator import FullEvaluator
|
| 11 |
+
from .checklist import InsightChecklist
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def run_evaluation(
    model: nn.Module,
    tokenizer: Any,
    val_dataloader: DataLoader,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
    metrics_history: Optional[Dict[str, list]] = None,
    config: Optional[EvalConfig] = None,
) -> Dict[str, Any]:
    """Run the full evaluation suite in a single call.

    Args:
        model: trained model to evaluate.
        tokenizer: tokenizer exposing encode/decode.
        val_dataloader: validation data for perplexity measurement.
        device: defaults to CUDA when available, else CPU.
        dtype: autocast dtype used during evaluation.
        metrics_history: optional training-loss history for dynamics plots.
        config: optional EvalConfig; FullEvaluator falls back to defaults.

    Returns:
        The aggregated evaluation report dict.

    Usage (Colab):
    ```python
    from llm_lab.evaluation import run_evaluation

    # after training completes
    report = run_evaluation(
        model=trainer.model,
        tokenizer=tokenizer,
        val_dataloader=val_dl,
        metrics_history=trainer.metrics.history,
    )
    ```
    """
    # Pick the best available device when the caller did not choose one.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    evaluator = FullEvaluator(
        model=model,
        tokenizer=tokenizer,
        val_dataloader=val_dataloader,
        device=device,
        config=config,
        dtype=dtype,
        metrics_history=metrics_history,
    )

    report = evaluator.run_full_evaluation()

    # Print the qualitative insight checklist after the quantitative report.
    InsightChecklist.run_checklist(report, metrics_history)

    return report
|
llm_lab/evaluation/scaling.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scaling Law ๋ถ์๊ธฐ."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import matplotlib
|
| 8 |
+
matplotlib.use("Agg")
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
HAS_MATPLOTLIB = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
HAS_MATPLOTLIB = False
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import numpy as np
|
| 16 |
+
HAS_NUMPY = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
HAS_NUMPY = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ScalingAnalyzer:
    """Analyze scaling-law behavior across 10M -> 100M -> 1B models.

    Chinchilla scaling law (2022):
      - compute-optimal training: tokens ~= 20 x parameters
      - Loss ~ N^(-alpha) x D^(-beta)  (N = parameters, D = data tokens)
      - alpha ~= 0.076, beta ~= 0.095 (per the paper)

    Goals of this analysis:
      - check whether our models follow the scaling law
      - project the effect of larger models / more data
      - build intuition for compute-budget allocation
    """

    def __init__(self, save_dir: str = "./eval_results"):
        # Plots are written here; create it eagerly so savefig cannot fail
        # on a missing directory.
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def analyze(
        self,
        model_results: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Compare results across model sizes.

        Args:
            model_results: [
                {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
                {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
                {"name": "1B", "params": 1.1e9, "tokens": 10e9,"loss": 3.0, "ppl": 20.1},
            ]
            Assumed ordered from smallest to largest model.

        Returns:
            Analysis dict with "models", "scaling_efficiency", and
            "chinchilla_ratios"; empty dict when fewer than 2 results.
        """
        # At least two models are needed to compute any scaling trend.
        if len(model_results) < 2:
            print("โ ๏ธ Scaling ๋ถ์์๋ ์ต์ 2๊ฐ ๋ชจ๋ธ ๊ฒฐ๊ณผ๊ฐ ํ์ํฉ๋๋ค.")
            return {}

        print("\n" + "=" * 70)
        print("๐ Scaling Law ๋ถ์")
        print("=" * 70)

        # -- results table --
        print(f"\n {'๋ชจ๋ธ':<8} {'ํ๋ผ๋ฏธํฐ':>12} {'ํ ํฐ':>10} {'Loss':>8} {'PPL':>8}")
        print(f" {'โ'*52}")
        for r in model_results:
            # Humanize parameter counts: "340M" below 1e9, "1.1B" above.
            params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
            tokens_str = f"{r['tokens']/1e9:.1f}B"
            print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")

        # -- scaling efficiency between consecutive sizes --
        analysis = {"models": model_results, "scaling_efficiency": []}

        for i in range(1, len(model_results)):
            prev = model_results[i-1]
            curr = model_results[i]

            param_ratio = curr["params"] / prev["params"]
            loss_reduction = prev["loss"] - curr["loss"]
            # Relative PPL improvement going from the smaller to larger model.
            ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]

            efficiency = {
                "from": prev["name"],
                "to": curr["name"],
                "param_multiplier": round(param_ratio, 1),
                "loss_reduction": round(loss_reduction, 4),
                "ppl_reduction_pct": round(ppl_reduction * 100, 1),
            }
            analysis["scaling_efficiency"].append(efficiency)

            print(f"\n {prev['name']} โ {curr['name']}:")
            print(f" ํ๋ผ๋ฏธํฐ ร{param_ratio:.1f}")
            print(f" Loss ๊ฐ์: {loss_reduction:.4f}")
            print(f" PPL ๊ฐ์: {ppl_reduction*100:.1f}%")

        # -- Chinchilla optimality check (tokens ~ 20 x params) --
        print(f"\n Chinchilla ์ต์ ์ฑ ์ฒดํฌ (ํ ํฐ โ 20 ร ํ๋ผ๋ฏธํฐ):")
        for r in model_results:
            actual_ratio = r["tokens"] / r["params"]
            # 15x-25x is treated as "near-optimal" around the 20x target.
            status = "โ ์ต์ ๋ฒ์" if 15 <= actual_ratio <= 25 else "โ ๏ธ ๋ฒ์ ๋ฐ"
            print(f" {r['name']}: ํ ํฐ/ํ๋ผ๋ฏธํฐ = {actual_ratio:.1f}x "
                  f"(์ต์ : 20x) {status}")

        analysis["chinchilla_ratios"] = [
            {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
            for r in model_results
        ]

        return analysis

    def plot_scaling_curves(
        self,
        model_results: List[Dict[str, Any]],
        save_path: Optional[str] = None,
    ):
        """Plot loss and PPL vs parameter count on log-log axes.

        Args:
            model_results: same format as ``analyze``.
            save_path: output PNG path; defaults to
                ``<save_dir>/scaling_curves.png``.
        """
        # Both libraries are optional; bail out gracefully when missing.
        if not HAS_MATPLOTLIB or not HAS_NUMPY:
            print("โ ๏ธ matplotlib/numpy๊ฐ ํ์ํฉ๋๋ค: pip install matplotlib numpy")
            return

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        params = [r["params"] for r in model_results]
        losses = [r["loss"] for r in model_results]
        ppls = [r["ppl"] for r in model_results]
        names = [r["name"] for r in model_results]

        # -- Loss vs Parameters (log-log) --
        ax = axes[0]
        ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
        for p, l, n in zip(params, losses, names):
            ax.annotate(f" {n}\n Loss={l:.2f}", (p, l), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Validation Loss", fontsize=12)
        ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        # -- PPL vs Parameters (log-log) --
        ax = axes[1]
        ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
        for p, pp, n in zip(params, ppls, names):
            ax.annotate(f" {n}\n PPL={pp:.1f}", (p, pp), fontsize=9)
        ax.set_xlabel("Parameters", fontsize=12)
        ax.set_ylabel("Perplexity", fontsize=12)
        ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()

        save_path = save_path or str(self.save_dir / "scaling_curves.png")
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"\n ๐ Scaling ๊ณก์ ์ ์ฅ: {save_path}")
        # Close explicitly so repeated calls don't leak figures.
        plt.close(fig)
|
llm_lab/model/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๋ชจ๋ธ ์ํคํ
์ฒ ๋ชจ๋ โ LLaMA-style Decoder-Only Transformer."""
|
| 2 |
+
from .norm import RMSNorm
|
| 3 |
+
from .rope import RotaryPositionalEmbedding
|
| 4 |
+
from .attention import GroupedQueryAttention
|
| 5 |
+
from .feedforward import SwiGLUFeedForward
|
| 6 |
+
from .transformer_block import TransformerBlock
|
| 7 |
+
from .llm_model import LLMModel
|
| 8 |
+
from .utils import count_parameters_detailed, estimate_memory_gb
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"RMSNorm", "RotaryPositionalEmbedding", "GroupedQueryAttention",
|
| 12 |
+
"SwiGLUFeedForward", "TransformerBlock", "LLMModel",
|
| 13 |
+
"count_parameters_detailed", "estimate_memory_gb",
|
| 14 |
+
]
|
llm_lab/model/attention.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grouped Query Attention (GQA)."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from llm_lab.config import ModelConfig
|
| 10 |
+
from .rope import RotaryPositionalEmbedding
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GroupedQueryAttention(nn.Module):
|
| 14 |
+
"""GQA: Multi-Head Attention์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ ๋ณํ.
|
| 15 |
+
|
| 16 |
+
MHA vs GQA vs MQA:
|
| 17 |
+
- MHA (Multi-Head Attention): Q, K, V ๋ชจ๋ num_heads๊ฐ โ ๋ฉ๋ชจ๋ฆฌ ํผ
|
| 18 |
+
- MQA (Multi-Query Attention): K, V๋ 1๊ฐ ํค๋ ๊ณต์ โ ํ์ง ์ ํ ์ฐ๋ ค
|
| 19 |
+
- GQA (Grouped Query Attention): K, V๋ฅผ num_kv_heads๊ฐ๋ก ๊ทธ๋ฃนํ
|
| 20 |
+
โ MHA์ MQA์ ์ค๊ฐ, ์ข์ ํ์ง-ํจ์จ ๊ท ํ
|
| 21 |
+
|
| 22 |
+
์์ (num_heads=16, num_kv_heads=4):
|
| 23 |
+
Q ํค๋: [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
|
| 24 |
+
K/V ๊ทธ๋ฃน: [ 0 , 1 , 2 , 3 ]
|
| 25 |
+
โ Q ํค๋ 4๊ฐ๊ฐ K/V ํค๋ 1๊ฐ๋ฅผ ๊ณต์
|
| 26 |
+
|
| 27 |
+
Attention ์์:
|
| 28 |
+
Attention(Q, K, V) = softmax(QยทK^T / โd_k) ยท V
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, config: ModelConfig):
    """Build projection layers, RoPE, and dropout from the model config."""
    super().__init__()
    self.config = config
    self.head_dim = config.head_dim
    self.num_heads = config.num_heads
    self.num_kv_heads = config.num_kv_heads
    # How many query heads share one K/V head (num_heads // num_kv_heads).
    self.num_kv_groups = config.num_kv_groups

    model_width = config.hidden_dim
    q_width = config.num_heads * self.head_dim
    kv_width = config.num_kv_heads * self.head_dim

    # Query projection carries the full head count; key/value projections
    # only num_kv_heads worth of heads — the memory saving behind GQA.
    self.q_proj = nn.Linear(model_width, q_width, bias=False)
    self.k_proj = nn.Linear(model_width, kv_width, bias=False)
    self.v_proj = nn.Linear(model_width, kv_width, bias=False)

    # Output projection folds the concatenated heads back to model width.
    self.o_proj = nn.Linear(q_width, model_width, bias=False)

    # Rotary position embedding, applied to Q/K in forward().
    self.rope = RotaryPositionalEmbedding(
        dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
    )

    # Dropout on attention probabilities (typically 0 during pretraining).
    self.attn_dropout = nn.Dropout(config.dropout)
|
| 56 |
+
|
| 57 |
+
def forward(
    self,
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    position_offset: int = 0,
) -> torch.Tensor:
    """Run grouped-query self-attention over a batch of sequences.

    Args:
        x: Input activations, shape (batch, seq_len, hidden_dim).
        mask: Optional (seq_len, seq_len) attention mask. When None,
            causal masking is applied automatically by SDPA.
        position_offset: Starting position for RoPE (used at inference).

    Returns:
        Tensor of shape (batch, seq_len, hidden_dim).
    """
    batch, seq_len, _ = x.shape

    # Project into query/key/value spaces; K/V use fewer heads (GQA).
    queries = self.q_proj(x)  # (B, S, num_heads * head_dim)
    keys = self.k_proj(x)     # (B, S, num_kv_heads * head_dim)
    values = self.v_proj(x)   # (B, S, num_kv_heads * head_dim)

    # Split heads and move the head axis ahead of the sequence axis.
    queries = queries.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    keys = keys.view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
    values = values.view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

    # Rotate Q and K only: position should influence *where* to attend
    # (the Q.K dot product), never *what* is retrieved (V).
    queries, keys = self.rope(queries, keys, position_offset)

    # Expand shared KV heads so each query head has a matching KV head,
    # e.g. num_kv_heads=4 -> num_heads=16 repeats every KV head 4 times.
    if self.num_kv_groups > 1:
        keys = self._repeat_kv(keys)     # (B, num_heads, S, head_dim)
        values = self._repeat_kv(values)

    # Fused scaled-dot-product attention (PyTorch >= 2.0; uses Flash
    # Attention kernels automatically when available).
    context = F.scaled_dot_product_attention(
        queries,
        keys,
        values,
        attn_mask=mask,
        dropout_p=self.config.dropout if self.training else 0.0,
        is_causal=(mask is None),  # implicit causal mask when none given
    )  # -> (B, num_heads, S, head_dim)

    # Merge heads and project back to the model dimension.
    context = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
    return self.o_proj(context)  # (B, S, hidden_dim)
def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
    """Duplicate each KV head num_kv_groups times along the head axis.

    (B, num_kv_heads, S, head_dim) -> (B, num_heads, S, head_dim),
    e.g. with num_kv_groups=4: [kv0, kv1, ...] becomes
    [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1, ...].

    Uses expand() so nothing is copied until the final reshape.
    """
    batch, kv_heads, seq_len, dim = x.shape
    expanded = x.unsqueeze(2)  # (B, kv_heads, 1, S, D)
    expanded = expanded.expand(batch, kv_heads, self.num_kv_groups, seq_len, dim)
    return expanded.reshape(batch, self.num_heads, seq_len, dim)
|
llm_lab/model/feedforward.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SwiGLU Feed-Forward Network."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from llm_lab.config import ModelConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward block: (Swish(x W_gate) * (x W_up)) W_down.

    Classic FFN:
        FFN(x) = ReLU(x W1 + b1) W2 + b2      — plain nonlinearity.

    SwiGLU FFN:
        SwiGLU(x) = (Swish(x W_gate) * (x W_up)) W_down
        — the gate path learns *which* information flows through,
          separately from *what* is represented by the up path.

    Why it helps:
    - Swish(x) = x * sigmoid(x) is smooth and keeps a small negative region.
    - PaLM and LLaMA report consistent gains over ReLU FFNs.

    Note: two up-projections (W_gate and W_up) make this ~1.5x the size
    of a plain FFN, so intermediate_dim is usually reduced to keep the
    total parameter count comparable.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Two parallel projections: hidden_dim -> intermediate_dim.
        self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
        # Down projection: intermediate_dim -> hidden_dim.
        self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated transform; (B, S, hidden_dim) shape is preserved."""
        # F.silu is Swish: x * sigmoid(x) — decides what passes through.
        gated = F.silu(self.gate_proj(x))
        # Element-wise gating of the up-projected features, then back down.
        return self.down_proj(gated * self.up_proj(x))
|
llm_lab/model/llm_model.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Full Transformer Model (LLaMA-style)."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from llm_lab.config import ModelConfig
|
| 11 |
+
from .norm import RMSNorm
|
| 12 |
+
from .transformer_block import TransformerBlock
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LLMModel(nn.Module):
    """LLaMA-style decoder-only Transformer (~1B parameters).

    Architecture:
        token ids
          -> token embedding
          -> TransformerBlock x num_layers (+ activation checkpointing)
          -> final RMSNorm
          -> linear head producing vocabulary logits

    Weight tying:
    - The output head shares its weight matrix with the input embedding.
    - Saves parameters (~65M here) while matching or improving quality.
    - Intuition: "word meaning" and "word prediction" share one space.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # -- Token embedding: id -> hidden_dim vector --
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)

        # -- Decoder stack --
        self.layers = nn.ModuleList(
            TransformerBlock(config, layer_idx=idx)
            for idx in range(config.num_layers)
        )

        # -- Final normalization --
        self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)

        # -- Output head, weight-tied to the embedding --
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding.weight

        self._init_weights()

    def _init_weights(self):
        """GPT-2-style weight initialization.

        Why initialization matters:
        - too large: activations explode -> NaN,
        - too small: gradients vanish -> training stalls,
        - the goal is roughly constant output variance per layer.

        Scheme:
        - regular Linear/Embedding: N(0, 0.02),
        - residual projections: N(0, 0.02 / sqrt(2 * num_layers)), so
          each block's residual contribution shrinks with depth.
        """
        std = 0.02
        residual_std = std / math.sqrt(2 * self.config.num_layers)

        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=std)

        # Shrink the residual-path projections (attention out, FFN down).
        for block in self.layers:
            nn.init.normal_(block.attention.o_proj.weight, mean=0.0, std=residual_std)
            nn.init.normal_(block.feed_forward.down_proj.weight, mean=0.0, std=residual_std)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute logits (and loss, when targets are given).

        Args:
            input_ids: (batch_size, seq_len) token ids.
            targets: (batch_size, seq_len) gold token ids for training.
            position_offset: RoPE position offset (inference).

        Returns:
            logits: (batch_size, seq_len, vocab_size)
            loss: scalar when targets is given, else None.
        """
        batch_size, seq_len = input_ids.shape  # also validates 2-D input

        # Step 1: embed tokens -> (B, S, hidden_dim).
        h = self.token_embedding(input_ids)

        # Step 2: run the decoder stack.  During training, activation
        # checkpointing recomputes intermediate activations in the
        # backward pass instead of storing them (memory for compute).
        for layer in self.layers:
            if self.training and torch.is_grad_enabled():
                h = torch.utils.checkpoint.checkpoint(
                    layer, h, None, position_offset,
                    use_reentrant=False,  # recommended for PyTorch >= 2.0
                )
            else:
                h = layer(h, mask=None, position_offset=position_offset)

        # Step 3: final normalization.
        h = self.final_norm(h)

        # Step 4: vocabulary logits.
        logits = self.lm_head(h)  # (B, S, vocab_size)

        # Step 5: next-token cross entropy when training.
        loss = None
        if targets is not None:
            # Flatten (B, S, V) -> (B*S, V) and (B, S) -> (B*S,).
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                ignore_index=-100,  # padding positions are skipped
            )

        return logits, loss

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (trainable only by default)."""
        params = (
            p for p in self.parameters() if p.requires_grad or not trainable_only
        )
        return sum(p.numel() for p in params)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> torch.Tensor:
        """Autoregressive text generation (inference).

        Predicts one token at a time and appends it to the sequence.

        Args:
            input_ids: (1, prompt_len) initial prompt.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sharpens/flattens the distribution (lower = greedier).
            top_k: Keep only the k highest-probability tokens.
            top_p: Keep only the nucleus of cumulative probability p.
        """
        self.eval()  # NOTE: leaves the module in eval mode after returning
        generated = input_ids

        for _ in range(max_new_tokens):
            # Truncate the context if it exceeds the trained length.
            context = generated[:, -self.config.max_seq_len:]

            logits, _ = self(context)
            # Only the last position predicts the next token.
            next_logits = logits[:, -1, :] / temperature

            # -- Top-K filtering --
            if top_k > 0:
                topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                cutoff = topk_vals[:, -1].unsqueeze(-1)
                next_logits = next_logits.masked_fill(next_logits < cutoff, float("-inf"))

            # -- Top-P (nucleus) filtering --
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative = torch.cumsum(sorted_probs, dim=-1)
                # Drop tokens whose *preceding* cumulative mass already
                # reaches top_p (keeps at least the top token).
                sorted_logits[cumulative - sorted_probs >= top_p] = float("-inf")
                # Scatter back to vocabulary order; sorted_indices is a
                # full permutation, so every position is overwritten.
                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)

            # Sample from the filtered distribution.
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            generated = torch.cat([generated, next_token], dim=1)

        return generated
|
llm_lab/model/norm.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RMSNorm (Root Mean Square Layer Normalization)."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization — a lighter LayerNorm.

    Differences from standard LayerNorm:
    - no mean subtraction (saves computation),
    - normalizes by RMS instead of standard deviation,
    - no bias parameter, only a learnable scale.

    Formula:
        RMSNorm(x) = (x / RMS(x)) * gamma
        RMS(x)     = sqrt(mean(x^2) + eps)

    Why normalize at all?  Deep stacks let activation scales explode or
    vanish; normalization keeps every layer's input in a stable range.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # gamma: learnable per-channel scale, initialized to 1 (identity).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Work in float32: summing squares in bf16/fp16 risks overflow.
        x32 = x.float()
        # rsqrt = 1/sqrt — a multiply replaces the division (faster).
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Normalize, restore the input dtype, then apply the scale.
        return (x32 * inv_rms).to(x.dtype) * self.weight
|
llm_lab/model/rope.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Rotary Positional Embedding (RoPE)."""
|
| 2 |
+
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RotaryPositionalEmbedding(nn.Module):
    """RoPE: relative position encoding via rotation of feature pairs.

    Core idea:
    - Treat each dimension pair (2i, 2i+1) as 2-D coordinates and
      rotate them by an angle proportional to the token position.
    - The attention score Q.K then depends only on the *relative*
      distance between two tokens.

    Why RoPE?
    - Learned absolute embeddings generalize poorly to longer inputs.
    - Relative-position schemes add parameters and complexity.
    - RoPE is parameter-free and naturally relative.

    Angles:
        theta_i = theta ** (-2i / dim),  i = 0 .. dim/2 - 1
        RoPE(x, pos) rotates pair i by pos * theta_i.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Per-pair frequencies: freqs[i] = 1 / theta**(2i/dim).  These are
        # constants, so they live in a non-persistent buffer (not saved,
        # but moved along with the module's device/dtype).
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", inv_freq, persistent=False)

        # Precompute (max_seq_len, dim/2) cos/sin lookup tables.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """Precompute and cache cos/sin for positions [0, seq_len)."""
        positions = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
        # Outer product: (seq_len,) x (dim/2,) -> (seq_len, dim/2) angles.
        angle_grid = torch.outer(positions, self.freqs)
        self.register_buffer("cos_cached", angle_grid.cos(), persistent=False)
        self.register_buffer("sin_cached", angle_grid.sin(), persistent=False)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary transform to Q and K.

        Args:
            q: (batch, num_heads, seq_len, head_dim)
            k: (batch, num_kv_heads, seq_len, head_dim)
            position_offset: Sequence start offset (e.g. with a KV cache
                at inference time).

        Returns:
            (q_rotated, k_rotated) with the same shapes as the inputs.
        """
        seq_len = q.shape[2]
        needed = position_offset + seq_len

        # Grow the cache on demand when the request exceeds it.
        if needed > self.cos_cached.shape[0]:
            self._build_cache(needed)

        # Slice out the cos/sin rows for the requested positions.
        cos = self.cos_cached[position_offset:needed]  # (seq_len, dim/2)
        sin = self.sin_cached[position_offset:needed]

        return self._apply_rotation(q, cos, sin), self._apply_rotation(k, cos, sin)

    @staticmethod
    def _apply_rotation(
        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Rotate each (even, odd) dimension pair of x.

        Implements the 2-D rotation
            [cos t, -sin t] [x1]   [x1 cos t - x2 sin t]
            [sin t,  cos t] [x2] = [x1 sin t + x2 cos t]
        as vectorized tensor arithmetic.
        """
        # x: (batch, heads, seq_len, head_dim).  Split the interleaved
        # pairs: (x0, x1, x2, x3, ...) -> (x0, x2, ...), (x1, x3, ...).
        first = x[..., 0::2]
        second = x[..., 1::2]

        # Broadcast (seq_len, dim/2) against (batch, heads, seq_len, dim/2).
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]

        rot_first = first * cos - second * sin
        rot_second = first * sin + second * cos

        # Re-interleave: (even0, odd0, even1, odd1, ...) restores the shape.
        paired = torch.stack([rot_first, rot_second], dim=-1)
        return paired.flatten(-2)
|
llm_lab/model/transformer_block.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transformer Block (ํ๋์ ๋ ์ด์ด)."""
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from llm_lab.config import ModelConfig
|
| 9 |
+
from .norm import RMSNorm
|
| 10 |
+
from .attention import GroupedQueryAttention
|
| 11 |
+
from .feedforward import SwiGLUFeedForward
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TransformerBlock(nn.Module):
    """One pre-norm Transformer decoder layer.

    Structure (pre-norm):
        x -> RMSNorm -> Attention -> +x -> RMSNorm -> FFN -> + -> out

    Pre-norm vs post-norm:
    - Post-norm (original Transformer) normalizes *after* the residual
      add and destabilizes training in deep models.
    - Pre-norm (standard since GPT-2) normalizes *before* each sublayer,
      giving smooth gradient flow and stable training.

    Residual connections act as gradient "highways" past each sublayer —
    the key reason a stack of ~22 layers remains trainable.
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-norm in front of the attention sublayer.
        self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.attention = GroupedQueryAttention(config)

        # Pre-norm in front of the feed-forward sublayer.
        self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.feed_forward = SwiGLUFeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        position_offset: int = 0,
    ) -> torch.Tensor:
        """Apply the block; (batch, seq_len, hidden_dim) shape is preserved."""
        # h = x + Attention(Norm(x))
        h = x + self.attention(self.attn_norm(x), mask, position_offset)
        # out = h + FFN(Norm(h))
        return h + self.feed_forward(self.ffn_norm(h))
|
llm_lab/model/utils.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""๋ชจ๋ธ ์ ํธ๋ฆฌํฐ ํจ์."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
from llm_lab.config import ModelConfig
|
| 9 |
+
|
| 10 |
+
if TYPE_CHECKING:
|
| 11 |
+
from .llm_model import LLMModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def count_parameters_detailed(model: "LLMModel") -> dict:
    """Break down the model's parameter count by component.

    Fix over the original: no longer crashes with IndexError when
    ``model.layers`` is empty — per-layer counts are simply zero.

    Args:
        model: A model exposing ``token_embedding``, ``layers`` (all
            layers assumed identical), and ``final_norm``.

    Returns:
        Dict with per-component counts plus a grand total.  The LM head
        is reported as zero-cost because it is weight-tied to the
        token embedding.
    """
    breakdown: dict = {}
    total = 0

    # Token embedding.
    emb_params = model.token_embedding.weight.numel()
    breakdown["token_embedding"] = emb_params
    total += emb_params

    # Per-layer parameters; all layers are identical, so inspect the first.
    layer_detail: dict = {}
    layer_total = 0
    if len(model.layers) > 0:
        for name, param in model.layers[0].named_parameters():
            layer_detail[name] = param.numel()
            layer_total += param.numel()

    breakdown["per_layer"] = layer_detail
    breakdown["per_layer_total"] = layer_total
    breakdown["all_layers_total"] = layer_total * len(model.layers)
    total += layer_total * len(model.layers)

    # Final normalization.
    norm_params = model.final_norm.weight.numel()
    breakdown["final_norm"] = norm_params
    total += norm_params

    # LM head shares the embedding weight (weight tying): 0 extra params.
    breakdown["lm_head"] = "weight tying (0 additional)"
    breakdown["total"] = total

    return breakdown
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def estimate_memory_gb(config: "ModelConfig", batch_size: int = 4, dtype_bytes: int = 2) -> dict:
    """Estimate the model's GPU memory footprint for training.

    Args:
        config: Model hyperparameters.
        batch_size: Micro-batch size used for the activation estimate.
        dtype_bytes: Bytes per value — 2 for bf16/fp16, 4 for fp32.

    Returns:
        Dict with the approximate parameter count and per-category GB
        estimates (weights, optimizer states, gradients, activations).
    """
    # Rough parameter count, mirroring the actual module shapes.
    embedding_params = config.vocab_size * config.hidden_dim
    qkv_params = config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim
    out_proj_params = config.num_heads * config.head_dim * config.hidden_dim
    ffn_params = 3 * config.hidden_dim * config.intermediate_dim  # gate + up + down
    norm_params = 2 * config.hidden_dim  # two RMSNorms per block
    layer_params = qkv_params + out_proj_params + ffn_params + norm_params
    # The trailing hidden_dim term is the final RMSNorm.
    total_params = embedding_params + layer_params * config.num_layers + config.hidden_dim

    weights_gb = total_params * dtype_bytes / 1e9
    # AdamW keeps two fp32 moment tensors per parameter: 8 bytes total.
    optimizer_gb = total_params * 8 / 1e9
    gradients_gb = total_params * dtype_bytes / 1e9

    # Activation estimate assuming activation checkpointing is enabled:
    # batch x seq x hidden x 4 bytes, scaled by sqrt(num_layers)
    # (checkpointing reduces stored activations roughly by that factor).
    activations_gb = (
        batch_size * config.max_seq_len * config.hidden_dim * 4
        * math.sqrt(config.num_layers)
        / 1e9
    )

    return {
        "total_parameters": total_params,
        "model_weights_gb": round(weights_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "gradients_gb": round(gradients_gb, 2),
        "activations_estimated_gb": round(activations_gb, 2),
        "total_estimated_gb": round(
            weights_gb + optimizer_gb + gradients_gb + activations_gb, 2
        ),
    }
|
llm_lab/training/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ๋ชจ๋ โ Gradient Accumulation, Mixed Precision, ์ฒดํฌํฌ์ธํธ, wandb ๋ก๊น
."""
|
| 2 |
+
from .scheduler import CosineWarmupScheduler
|
| 3 |
+
from .checkpoint import CheckpointManager
|
| 4 |
+
from .metrics import MetricsTracker
|
| 5 |
+
from .optimizer import create_optimizer
|
| 6 |
+
from .trainer import Trainer
|
| 7 |
+
from .runner import start_training
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"CosineWarmupScheduler", "CheckpointManager", "MetricsTracker",
|
| 11 |
+
"create_optimizer", "Trainer", "start_training",
|
| 12 |
+
]
|
llm_lab/training/checkpoint.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ์ํ ์ ์ฅ/๋ณต์ ๊ด๋ฆฌ์."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from llm_lab.config import TrainConfig
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CheckpointManager:
|
| 16 |
+
"""ํ์ต ์ํ ์ ์ฅ/๋ณต์ ๊ด๋ฆฌ์.
|
| 17 |
+
|
| 18 |
+
Colab์์ ์ฒดํฌํฌ์ธํธ๊ฐ ์ค์ํ ์ด์ :
|
| 19 |
+
- ์ธ์
๋ง๋ฃ (์ต๋ ~24์๊ฐ) ์ ๋ชจ๋ ๋ฉ๋ชจ๋ฆฌ ์ํ ์๋ฉธ
|
| 20 |
+
- Google Drive์ ์ ์ฅํ๋ฉด ์ธ์
๊ฐ ์ฐ์ ํ์ต ๊ฐ๋ฅ
|
| 21 |
+
- ์ตํฐ๋ง์ด์ ์ํ๊น์ง ์ ์ฅํด์ผ AdamW ๋ชจ๋ฉํ
์ด ์ ์ง๋จ
|
| 22 |
+
|
| 23 |
+
์ ์ฅ ๋ด์ฉ:
|
| 24 |
+
- model_state_dict: ๋ชจ๋ธ ๊ฐ์ค์น
|
| 25 |
+
- optimizer_state_dict: ์ตํฐ๋ง์ด์ ์ํ (m, v ๋ชจ๋ฉํ
)
|
| 26 |
+
- step: ํ์ฌ ํ์ต ์คํ
|
| 27 |
+
- best_val_loss: ์ต์ ๊ฒ์ฆ Loss
|
| 28 |
+
- config: ํ์ต ์ค์ (์ฌํ์ฑ)
|
| 29 |
+
- rng_states: ๋๋ค ์๋ ์ํ (์์ ์ฌํ)
|
| 30 |
+
- metrics_history: ํ์ต ๋ฉํธ๋ฆญ ๊ธฐ๋ก
|
| 31 |
+
- wandb_run_id: wandb ์คํ ID (๋ก๊น
์ฐ์์ฑ)
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, config: TrainConfig):
    """Create the manager and ensure the checkpoint directory exists.

    Args:
        config: Training configuration providing ``checkpoint_dir``
            and ``max_checkpoints``.
    """
    self.config = config
    self.checkpoint_dir = Path(config.checkpoint_dir)
    # Create the directory tree up front so save() can write immediately.
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
    self.max_checkpoints = config.max_checkpoints
|
| 40 |
+
def save(
|
| 41 |
+
self,
|
| 42 |
+
model: nn.Module,
|
| 43 |
+
optimizer: torch.optim.Optimizer,
|
| 44 |
+
step: int,
|
| 45 |
+
best_val_loss: float,
|
| 46 |
+
metrics_history: Dict[str, list],
|
| 47 |
+
wandb_run_id: Optional[str] = None,
|
| 48 |
+
):
|
| 49 |
+
"""์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํฉ๋๋ค."""
|
| 50 |
+
ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
|
| 51 |
+
ckpt_path.mkdir(parents=True, exist_ok=True)
|
| 52 |
+
|
| 53 |
+
print(f"\n๐พ ์ฒดํฌํฌ์ธํธ ์ ์ฅ: {ckpt_path}")
|
| 54 |
+
start = time.time()
|
| 55 |
+
|
| 56 |
+
# 1) ๋ชจ๋ธ ๊ฐ์ค์น (bf16 ์ํ ๊ทธ๋๋ก)
|
| 57 |
+
torch.save(model.state_dict(), ckpt_path / "model.pt")
|
| 58 |
+
|
| 59 |
+
# 2) ์ตํฐ๋ง์ด์ ์ํ (fp32 ๋ชจ๋ฉํ
ํฌํจ, ํฌ๊ธฐ ํผ)
|
| 60 |
+
torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
|
| 61 |
+
|
| 62 |
+
# 3) ํ์ต ๋ฉํ ์ ๋ณด
|
| 63 |
+
meta = {
|
| 64 |
+
"step": step,
|
| 65 |
+
"best_val_loss": best_val_loss,
|
| 66 |
+
"wandb_run_id": wandb_run_id,
|
| 67 |
+
"config": self.config.__dict__,
|
| 68 |
+
}
|
| 69 |
+
with open(ckpt_path / "meta.json", "w") as f:
|
| 70 |
+
json.dump(meta, f, indent=2)
|
| 71 |
+
|
| 72 |
+
# 4) ๋ฉํธ๋ฆญ ๊ธฐ๋ก
|
| 73 |
+
torch.save(metrics_history, ckpt_path / "metrics.pt")
|
| 74 |
+
|
| 75 |
+
# 5) ๋๋ค ์ํ (์์ ์ฌํ์ ์ํด)
|
| 76 |
+
rng_states = {
|
| 77 |
+
"python": torch.random.get_rng_state(),
|
| 78 |
+
"cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
|
| 79 |
+
}
|
| 80 |
+
torch.save(rng_states, ckpt_path / "rng_states.pt")
|
| 81 |
+
|
| 82 |
+
elapsed = time.time() - start
|
| 83 |
+
ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
|
| 84 |
+
print(f" ์ ์ฅ ์๋ฃ: {ckpt_size:.2f} GB, {elapsed:.1f}์ด")
|
| 85 |
+
|
| 86 |
+
# ์ค๋๋ ์ฒดํฌํฌ์ธํธ ์ญ์ (๋กค๋ง)
|
| 87 |
+
self._cleanup_old_checkpoints()
|
| 88 |
+
|
| 89 |
+
def load_latest(
|
| 90 |
+
self,
|
| 91 |
+
model: nn.Module,
|
| 92 |
+
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 93 |
+
device: torch.device = torch.device("cpu"),
|
| 94 |
+
) -> Dict[str, Any]:
|
| 95 |
+
"""๊ฐ์ฅ ์ต๊ทผ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ก๋ํฉ๋๋ค.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
{"step", "best_val_loss", "wandb_run_id", "metrics_history"}
|
| 99 |
+
๋๋ ์ฒดํฌํฌ์ธํธ๊ฐ ์์ผ๋ฉด None
|
| 100 |
+
"""
|
| 101 |
+
ckpt_path = self._find_latest()
|
| 102 |
+
if ckpt_path is None:
|
| 103 |
+
print("[Checkpoint] ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ ์์. ์ฒ์๋ถํฐ ์์ํฉ๋๋ค.")
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
print(f"\n๐ ์ฒดํฌํฌ์ธํธ ๋ก๋: {ckpt_path}")
|
| 107 |
+
start = time.time()
|
| 108 |
+
|
| 109 |
+
# 1) ๋ชจ๋ธ ๊ฐ์ค์น
|
| 110 |
+
model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
|
| 111 |
+
model.load_state_dict(model_state)
|
| 112 |
+
del model_state # ๋ฉ๋ชจ๋ฆฌ ํด์
|
| 113 |
+
|
| 114 |
+
# 2) ์ตํฐ๋ง์ด์ ์ํ
|
| 115 |
+
if optimizer is not None:
|
| 116 |
+
optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
|
| 117 |
+
optimizer.load_state_dict(optim_state)
|
| 118 |
+
del optim_state
|
| 119 |
+
|
| 120 |
+
# 3) ๋ฉํ ์ ๋ณด
|
| 121 |
+
with open(ckpt_path / "meta.json", "r") as f:
|
| 122 |
+
meta = json.load(f)
|
| 123 |
+
|
| 124 |
+
# 4) ๋ฉํธ๋ฆญ ๊ธฐ๋ก
|
| 125 |
+
metrics_history = {}
|
| 126 |
+
metrics_path = ckpt_path / "metrics.pt"
|
| 127 |
+
if metrics_path.exists():
|
| 128 |
+
metrics_history = torch.load(metrics_path, weights_only=False)
|
| 129 |
+
|
| 130 |
+
# 5) ๋๋ค ์ํ ๋ณต์
|
| 131 |
+
rng_path = ckpt_path / "rng_states.pt"
|
| 132 |
+
if rng_path.exists():
|
| 133 |
+
rng_states = torch.load(rng_path, weights_only=False)
|
| 134 |
+
torch.random.set_rng_state(rng_states["python"])
|
| 135 |
+
if rng_states["cuda"] is not None and torch.cuda.is_available():
|
| 136 |
+
torch.cuda.set_rng_state(rng_states["cuda"])
|
| 137 |
+
|
| 138 |
+
elapsed = time.time() - start
|
| 139 |
+
print(f" ๋ก๋ ์๋ฃ: step={meta['step']}, {elapsed:.1f}์ด")
|
| 140 |
+
|
| 141 |
+
return {
|
| 142 |
+
"step": meta["step"],
|
| 143 |
+
"best_val_loss": meta["best_val_loss"],
|
| 144 |
+
"wandb_run_id": meta.get("wandb_run_id"),
|
| 145 |
+
"metrics_history": metrics_history,
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
def _find_latest(self) -> Optional[Path]:
|
| 149 |
+
"""๊ฐ์ฅ ์ต๊ทผ ์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก๋ฅผ ์ฐพ์ต๋๋ค."""
|
| 150 |
+
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 151 |
+
return ckpts[-1] if ckpts else None
|
| 152 |
+
|
| 153 |
+
def _cleanup_old_checkpoints(self):
|
| 154 |
+
"""์ค๋๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ญ์ ํฉ๋๋ค (๋กค๋ง)."""
|
| 155 |
+
ckpts = sorted(self.checkpoint_dir.glob("step_*"))
|
| 156 |
+
while len(ckpts) > self.max_checkpoints:
|
| 157 |
+
old = ckpts.pop(0)
|
| 158 |
+
print(f" ๐๏ธ ์ค๋๋ ์ฒดํฌํฌ์ธํธ ์ญ์ : {old.name}")
|
| 159 |
+
shutil.rmtree(old)
|
llm_lab/training/metrics.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ๋ฉํธ๋ฆญ ์ถ์ ๋ฐ ๋ก๊น
."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from llm_lab.config import TrainConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MetricsTracker:
    """Tracks training metrics in memory and optionally logs to wandb.

    Tracked items:
    - train/loss: training loss (cross-entropy)
    - train/lr: current learning rate
    - train/grad_norm: gradient L2 norm
    - train/tokens_per_sec: throughput
    - train/gpu_mem_gb: GPU memory usage
    - val/loss: validation loss
    - val/perplexity: validation perplexity (= exp(loss))
    """

    def __init__(self, config: TrainConfig):
        self.config = config
        # In-memory history; persisted elsewhere via CheckpointManager.
        # NOTE(review): val_loss/val_ppl are appended on the eval cadence,
        # not per logged step, so the lists are NOT index-aligned with
        # "step" — confirm before zipping them for plots.
        self.history: Dict[str, list] = {
            "step": [],
            "train_loss": [],
            "learning_rate": [],
            "grad_norm": [],
            "tokens_per_sec": [],
            "gpu_mem_gb": [],
            "val_loss": [],
            "val_ppl": [],
        }

        # wandb is optional: initialized here when enabled, and disabled
        # in _init_wandb on any failure.
        self.wandb_run = None
        if config.use_wandb:
            self._init_wandb()

    def _init_wandb(self, resume_id: Optional[str] = None):
        """Initialize wandb (supports resuming a run across sessions).

        On ImportError or any other init failure, wandb is turned off
        (config.use_wandb = False) and only console logging remains.
        """
        try:
            import wandb

            run_id = resume_id or wandb.util.generate_id()
            self.wandb_run = wandb.init(
                project=self.config.wandb_project,
                name=self.config.wandb_run_name or f"1b-run-{run_id[:6]}",
                id=run_id,
                resume="allow",  # continue the run when the id already exists
                config=self.config.__dict__,
            )
            print(f"[wandb] ์ด๊ธฐํ ์๋ฃ: {self.wandb_run.url}")
        except ImportError:
            print("[wandb] ์ค์น๋์ง ์์. ์ฝ์ ๋ก๊น๋ง ์ฌ์ฉํฉ๋๋ค.")
            self.config.use_wandb = False
        except Exception as e:
            print(f"[wandb] ์ด๊ธฐํ ์คํจ: {e}. ์ฝ์ ๋ก๊น๋ง ์ฌ์ฉํฉ๋๋ค.")
            self.config.use_wandb = False

    def resume_wandb(self, run_id: str):
        """Continue logging to a previous wandb run (after checkpoint resume)."""
        if self.config.use_wandb:
            self._init_wandb(resume_id=run_id)

    def log_train_step(
        self,
        step: int,
        loss: float,
        lr: float,
        grad_norm: float,
        tokens_per_sec: float,
        gpu_mem_gb: float,
    ):
        """Record metrics for one training step (history + wandb)."""
        self.history["step"].append(step)
        self.history["train_loss"].append(loss)
        self.history["learning_rate"].append(lr)
        self.history["grad_norm"].append(grad_norm)
        self.history["tokens_per_sec"].append(tokens_per_sec)
        self.history["gpu_mem_gb"].append(gpu_mem_gb)

        if self.config.use_wandb and self.wandb_run:
            import wandb

            wandb.log({
                "train/loss": loss,
                "train/lr": lr,
                "train/grad_norm": grad_norm,
                "train/tokens_per_sec": tokens_per_sec,
                "train/gpu_mem_gb": gpu_mem_gb,
            }, step=step)

    def log_eval(self, step: int, val_loss: float, val_ppl: float):
        """Record validation metrics (history + wandb)."""
        self.history["val_loss"].append(val_loss)
        self.history["val_ppl"].append(val_ppl)

        if self.config.use_wandb and self.wandb_run:
            import wandb

            wandb.log({
                "val/loss": val_loss,
                "val/perplexity": val_ppl,
            }, step=step)

    @property
    def wandb_run_id(self) -> Optional[str]:
        # Persisted in checkpoints so logging can resume in the same run.
        if self.wandb_run:
            return self.wandb_run.id
        return None
|
llm_lab/training/optimizer.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AdamW ์ตํฐ๋ง์ด์ ์์ฑ (Weight Decay ๋ถ๋ฆฌ)."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from llm_lab.config import TrainConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_optimizer(model: nn.Module, config: "TrainConfig") -> torch.optim.AdamW:
    """Create an AdamW optimizer with decoupled weight-decay groups.

    Weight-decay split:
    - Decay applied: Linear weights (attention projections, FFN, ...)
    - Decay NOT applied: embeddings, LayerNorm/RMSNorm scales, biases

    Why split?
    - Weight decay penalizes large weights to curb overfitting.
    - Applied to norm scale parameters it fights the normalization itself.
    - Applied to embeddings it shrinks rare-token representations toward 0.
    - Excluding 1D parameters (bias, norm weight) from decay is standard.

    Args:
        model: the model whose trainable parameters are optimized.
        config: training config providing weight_decay, learning_rate,
            beta1, beta2 and adam_eps.

    Returns:
        torch.optim.AdamW with two parameter groups (decay / no-decay).
    """
    # Split trainable parameters into decay / no-decay groups.
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # 1D tensors (bias, norm weight) or embeddings -> no decay.
        if param.dim() <= 1 or "embedding" in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"[Optimizer] Decay ํ๋ผ๋ฏธํฐ: {n_decay:,} ({n_decay/1e6:.1f}M)")
    print(f"[Optimizer] No-decay ํ๋ผ๋ฏธํฐ: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")

    # BUGFIX: fused AdamW requires every parameter to live on a CUDA
    # device. Gating only on torch.cuda.is_available() crashes for a
    # CPU-resident model on a machine that happens to have a GPU, so
    # check actual parameter placement as well.
    use_fused = torch.cuda.is_available() and all(
        p.is_cuda for group in param_groups for p in group["params"]
    )

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        eps=config.adam_eps,
        fused=use_fused,  # CUDA fused AdamW (faster) when safe
    )

    return optimizer
|
llm_lab/training/runner.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ํ์ต ์คํ ํฌํผ (Quick Start)."""
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
|
| 10 |
+
from llm_lab.config import TrainConfig
|
| 11 |
+
from .trainer import Trainer
|
| 12 |
+
from llm_lab.utils import auto_configure
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def start_training(
    model: nn.Module,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    config: Optional[TrainConfig] = None,
    seq_len: int = 2048,
    auto_config: bool = True,
) -> Trainer:
    """One-call training entry point.

    Builds (or accepts) a TrainConfig, optionally adapts it to the
    detected GPU, seeds torch for reproducibility, constructs a Trainer
    (which auto-resumes from the latest checkpoint), runs the full
    training loop, and returns the Trainer.

    Colab usage:
    ```python
    from model import LLMModel, ModelConfig
    from data_pipeline import setup_data_pipeline, DataConfig
    from trainer import start_training, TrainConfig

    model = LLMModel(ModelConfig.base_1b())
    tok, train_dl, val_dl = setup_data_pipeline("pretrained")
    trainer = start_training(model, train_dl, val_dl)
    ```
    """
    config = config or TrainConfig()

    # Adapt the config to the detected GPU, when requested.
    if auto_config:
        config = auto_configure(config)

    # Colab guard: warn when the checkpoint dir points at an unmounted
    # Google Drive, and fall back to a local directory.
    wants_drive = "/content/drive" in config.checkpoint_dir
    if wants_drive and not Path("/content/drive/MyDrive").exists():
        print("\nโ ๏ธ Google Drive๊ฐ ๋ง์ดํธ๋์ง ์์์ต๋๋ค!")
        print("  Colab์์ ์คํ: from google.colab import drive; drive.mount('/content/drive')")
        print("  ๋ก์ปฌ ๊ฒฝ๋ก๋ก ๋ณ๊ฒฝํฉ๋๋ค.")
        config.checkpoint_dir = "./checkpoints"

    # Seed for reproducibility.
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    # Build the trainer (auto-resume included) and run to completion.
    trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
    trainer.train()
    return trainer
|
llm_lab/training/scheduler.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cosine Annealing with Linear Warmup ์ค์ผ์ค๋ฌ."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from llm_lab.config import TrainConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CosineWarmupScheduler:
    """Linear warmup followed by cosine annealing.

    The learning rate climbs linearly from 0 to ``peak_lr`` over
    ``warmup_steps``, then follows a half-cosine down to ``min_lr`` at
    ``total_steps``.  The smooth cosine tail (used by GPT-3, LLaMA,
    Chinchilla, ...) avoids the loss instability of step decay and,
    unlike linear decay, keeps a usable LR late in training.

    Implemented by hand rather than with torch's built-in schedulers
    (e.g. CosineAnnealingLR) because combining warmup, a floor LR, and
    checkpoint resumption is simpler with an explicit step -> lr map.
    """

    def __init__(self, config: TrainConfig):
        self.peak_lr = config.learning_rate
        self.min_lr = config.min_learning_rate
        self.warmup_steps = config.warmup_steps
        self.total_steps = config.total_steps

    def get_lr(self, step: int) -> float:
        """Return the learning rate for optimizer step ``step`` (0-indexed)."""
        # Phase 1: linear warmup, 0 -> peak_lr.
        if step < self.warmup_steps:
            return self.peak_lr * (step / self.warmup_steps)

        # Phase 2: cosine decay. Fraction of the post-warmup schedule
        # completed, clamped so overshooting total_steps is harmless.
        span = self.total_steps - self.warmup_steps
        frac = min((step - self.warmup_steps) / max(span, 1), 1.0)

        # lr = min + (peak - min) * 0.5 * (1 + cos(pi * frac))
        weight = 0.5 * (1.0 + math.cos(math.pi * frac))
        return self.min_lr + (self.peak_lr - self.min_lr) * weight

    def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
        """Write the scheduled LR into every param group and return it."""
        new_lr = self.get_lr(step)
        for group in optimizer.param_groups:
            group["lr"] = new_lr
        return new_lr
|
llm_lab/training/trainer.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM ์ฌ์ ํ์ต ํธ๋ ์ด๋."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import time
|
| 5 |
+
from typing import Dict, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from llm_lab.config import TrainConfig
|
| 12 |
+
from .scheduler import CosineWarmupScheduler
|
| 13 |
+
from .checkpoint import CheckpointManager
|
| 14 |
+
from .metrics import MetricsTracker
|
| 15 |
+
from .optimizer import create_optimizer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Trainer:
    """LLM pre-training trainer.

    Core loop structure:
    ```
    for step in range(total_steps):
        # -- gradient accumulation loop --
        for micro_step in range(accumulation_steps):
            batch = next(dataloader)
            with autocast(bf16):
                logits, loss = model(input_ids, targets)
            (loss / accumulation_steps).backward()   # gradients accumulate

        # -- optimizer step (after accumulation) --
        clip_grad_norm(model, max_norm=grad_clip)
        scheduler.set_lr(optimizer, step)
        optimizer.step()
    ```

    Gradient accumulation: when the GPU cannot hold the full batch,
    several small micro-batches are run forward/backward, accumulating
    gradients, before a single optimizer step — equivalent to training
    with the large effective batch.  The loss is divided by
    accumulation_steps so the accumulated gradient is a mean, not a sum.

    Checkpoints are written on an interval, and __init__ auto-resumes
    from the latest one, so a Colab session expiry only loses the work
    since the last checkpoint.
    """

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader],
        config: TrainConfig,
        seq_len: int = 2048,
    ):
        self.config = config
        self.seq_len = seq_len

        # -- Device --
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[Trainer] ๋๋ฐ์ด์ค: {self.device}")
        if torch.cuda.is_available():
            print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
            # BUGFIX: the property is `total_memory` — `total_mem` does not
            # exist and raised AttributeError here.
            print(f"[Trainer] GPU ๋ฉ๋ชจ๋ฆฌ: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

        # -- Model --
        self.model = model.to(self.device)
        # torch.compile: PyTorch 2.0+ graph optimization (~10-30% speedup).
        if torch.cuda.is_available() and hasattr(torch, "compile"):
            print("[Trainer] torch.compile ์ ์ฉ ์ค...")
            self.model = torch.compile(self.model)

        # -- Data --
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.train_iter = iter(train_dataloader)

        # -- Optimizer --
        self.optimizer = create_optimizer(self.model, config)

        # -- Scheduler --
        self.scheduler = CosineWarmupScheduler(config)

        # -- Checkpointing --
        self.ckpt_manager = CheckpointManager(config)

        # -- Metrics --
        self.metrics = MetricsTracker(config)

        # -- Training state --
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.tokens_seen = 0

        # -- Mixed precision --
        # bf16 needs no GradScaler (only fp16 does), so bare autocast suffices.
        self.use_amp = config.dtype != "float32"
        self.amp_dtype = config.torch_dtype

        # -- Auto-resume --
        self._try_resume()

    def _try_resume(self):
        """Restore the latest checkpoint (if any) into model/optimizer.

        Also restores the step counter, best validation loss, metric
        history, and — when recorded — the wandb run id so logging
        continues in the same run.
        """
        result = self.ckpt_manager.load_latest(
            self.model, self.optimizer, self.device
        )

        if result is not None:
            self.global_step = result["step"]
            self.best_val_loss = result["best_val_loss"]
            self.metrics.history = result.get("metrics_history", self.metrics.history)

            # Continue logging into the same wandb run.
            if result.get("wandb_run_id"):
                self.metrics.resume_wandb(result["wandb_run_id"])

            # Reconstruct the token count from the step counter; the exact
            # count is not checkpointed.
            self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
            print(f"[Trainer] ํ์ต ์ฌ๊ฐ: step={self.global_step}, "
                  f"tokens={self.tokens_seen/1e9:.2f}B, "
                  f"best_val_loss={self.best_val_loss:.4f}")

    def _get_next_batch(self) -> Dict[str, torch.Tensor]:
        """Fetch the next training batch, moved to the training device.

        Streaming dataloaders have no epoch notion, so on StopIteration
        a fresh iterator is created and the fetch retried once.
        """
        try:
            batch = next(self.train_iter)
        except StopIteration:
            self.train_iter = iter(self.train_dataloader)
            batch = next(self.train_iter)

        return {
            "input_ids": batch["input_ids"].to(self.device, non_blocking=True),
            "targets": batch["targets"].to(self.device, non_blocking=True),
        }

    def _train_step(self) -> Tuple[float, float]:
        """Run one optimizer step, including gradient accumulation.

        Returns:
            (average unscaled loss over micro-batches, pre-clip grad L2 norm)
        """
        self.model.train()
        # set_to_none=True frees gradient memory instead of zero-filling it.
        self.optimizer.zero_grad(set_to_none=True)

        total_loss = 0.0

        # -- Gradient accumulation loop --
        for micro_step in range(self.config.gradient_accumulation_steps):
            batch = self._get_next_batch()

            # Mixed-precision forward.
            # BUGFIX: use the actual device type; the hard-coded "cuda"
            # broke CPU runs whenever autocast was enabled.
            with torch.amp.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                _, loss = self.model(batch["input_ids"], batch["targets"])

            # Scale so the accumulated gradient is a mean over the
            # effective batch (not a sum).
            scaled_loss = loss / self.config.gradient_accumulation_steps
            total_loss += loss.item()

            # Backward: gradients accumulate across micro-steps.
            scaled_loss.backward()

        # -- Gradient clipping --
        # All parameter grads are treated as one vector; if its L2 norm
        # exceeds max_norm it is scaled down proportionally.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config.grad_clip,
        ).item()

        # -- LR update, then optimizer step --
        # BUGFIX: the LR must be set BEFORE step(); setting it afterwards
        # made the very first update run at the optimizer's construction
        # LR (peak) instead of the warmup schedule's near-zero value.
        self.scheduler.set_lr(self.optimizer, self.global_step)
        self.optimizer.step()

        avg_loss = total_loss / self.config.gradient_accumulation_steps
        return avg_loss, grad_norm

    @torch.no_grad()
    def _evaluate(self) -> Tuple[float, float]:
        """Measure loss and perplexity on up to eval_steps validation batches.

        Perplexity = exp(loss):
        - Intuition: "out of how many candidates is the model effectively
          choosing the next token, on average".
        - PPL 100 ~ picking uniformly among 100; PPL 10 is very confident.

        Returns (inf, inf) when no validation dataloader is configured.
        """
        if self.val_dataloader is None:
            return float("inf"), float("inf")

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(self.val_dataloader):
            if i >= self.config.eval_steps:
                break

            input_ids = batch["input_ids"].to(self.device)
            targets = batch["targets"].to(self.device)

            # BUGFIX: device_type follows the actual device (see _train_step).
            with torch.amp.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.use_amp):
                _, loss = self.model(input_ids, targets)

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        perplexity = math.exp(min(avg_loss, 20))  # cap to avoid overflow (exp(20) ~ 5e8)

        return avg_loss, perplexity

    def train(self):
        """Main training loop.

        Runs until total_steps, logging / evaluating / checkpointing on
        their configured intervals.  Safe to interrupt (Colab session
        expiry): the next run resumes from the latest checkpoint.
        """
        config = self.config

        print("\n" + "=" * 70)
        print("๐ ํ์ต ์์")
        print("=" * 70)
        print(f"  ์ด ์คํ: {config.total_steps:,}")
        print(f"  ์์ ์คํ: {self.global_step}")
        print(f"  Effective batch size: {config.effective_batch_size}")
        print(f"  ํ ํฐ/์คํ: {config.effective_batch_size * self.seq_len:,}")
        print(f"  ์ด ํ์ต ํ ํฐ (์์): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
        print(f"  Mixed Precision: {config.dtype}")
        print(f"  Gradient Accumulation: {config.gradient_accumulation_steps}")
        print(f"  ์ฒดํฌํฌ์ธํธ: {config.checkpoint_dir}")
        print("=" * 70 + "\n")

        step_start_time = time.time()
        tokens_at_log_start = self.tokens_seen

        # ---------------------------------------------
        # Main loop
        # ---------------------------------------------

        while self.global_step < config.total_steps:

            # -- Train step --
            loss, grad_norm = self._train_step()
            self.global_step += 1
            self.tokens_seen += config.effective_batch_size * self.seq_len

            # -- Logging --
            if self.global_step % config.log_interval == 0:
                elapsed = time.time() - step_start_time
                tokens_delta = self.tokens_seen - tokens_at_log_start
                tokens_per_sec = tokens_delta / max(elapsed, 1e-6)

                # Peak GPU memory allocated so far.
                gpu_mem_gb = 0.0
                if torch.cuda.is_available():
                    gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9

                # Current LR per the schedule.
                current_lr = self.scheduler.get_lr(self.global_step)

                # ETA from the throughput of the last log window.
                remaining_steps = config.total_steps - self.global_step
                steps_per_sec = config.log_interval / max(elapsed, 1e-6)
                eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
                eta_hours = eta_seconds / 3600

                # Console output.
                print(
                    f"  Step {self.global_step:>6d}/{config.total_steps} โ "
                    f"Loss {loss:.4f} โ "
                    f"LR {current_lr:.2e} โ "
                    f"Grad {grad_norm:.2f} โ "
                    f"{tokens_per_sec:,.0f} tok/s โ "
                    f"GPU {gpu_mem_gb:.1f}GB โ "
                    f"ETA {eta_hours:.1f}h โ "
                    f"Tokens {self.tokens_seen/1e9:.2f}B"
                )

                # wandb logging.
                self.metrics.log_train_step(
                    step=self.global_step,
                    loss=loss,
                    lr=current_lr,
                    grad_norm=grad_norm,
                    tokens_per_sec=tokens_per_sec,
                    gpu_mem_gb=gpu_mem_gb,
                )

                step_start_time = time.time()
                tokens_at_log_start = self.tokens_seen

            # -- Evaluation --
            if self.global_step % config.eval_interval == 0:
                val_loss, val_ppl = self._evaluate()

                print(f"\n  ๐ Eval @ Step {self.global_step}: "
                      f"Val Loss = {val_loss:.4f}, "
                      f"Val PPL = {val_ppl:.2f}")

                self.metrics.log_eval(self.global_step, val_loss, val_ppl)

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    print(f"  ๐ New best val loss: {val_loss:.4f}")

                print()

            # -- Checkpoint --
            if self.global_step % config.checkpoint_interval == 0:
                self.ckpt_manager.save(
                    model=self.model,
                    optimizer=self.optimizer,
                    step=self.global_step,
                    best_val_loss=self.best_val_loss,
                    metrics_history=self.metrics.history,
                    wandb_run_id=self.metrics.wandb_run_id,
                )

        # ---------------------------------------------
        # Training complete
        # ---------------------------------------------

        print("\n" + "=" * 70)
        print("๐ ํ์ต ์๋ฃ!")
        print("=" * 70)
        print(f"  ์ด ์คํ: {self.global_step:,}")
        print(f"  ์ด ํ ํฐ: {self.tokens_seen/1e9:.2f}B")
        print(f"  ์ต์ Val Loss: {self.best_val_loss:.4f}")
        print(f"  ์ต์ Val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
        print("=" * 70)

        # Final checkpoint so the finished run can be reloaded.
        self.ckpt_manager.save(
            model=self.model,
            optimizer=self.optimizer,
            step=self.global_step,
            best_val_loss=self.best_val_loss,
            metrics_history=self.metrics.history,
            wandb_run_id=self.metrics.wandb_run_id,
        )

        if self.config.use_wandb and self.metrics.wandb_run:
            import wandb
            wandb.finish()
|
llm_lab/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utilities: device detection and seed setup."""
from .device import get_device, detect_gpu_info, auto_configure
from .seed import set_seed

__all__ = ["get_device", "detect_gpu_info", "auto_configure", "set_seed"]
|
llm_lab/utils/device.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Device detection and auto-configuration utilities."""
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Import only for type checkers; avoids a runtime circular import.
    from llm_lab.config import TrainConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_device() -> torch.device:
    """Pick the compute device: CUDA when available, otherwise CPU."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def detect_gpu_info() -> dict:
    """Return the GPU name and total memory.

    Returns:
        ``{"name": str, "memory_gb": float}``, or an empty dict when no
        CUDA device is available.
    """
    if not torch.cuda.is_available():
        return {}
    # BUG FIX: the device-properties attribute is ``total_memory`` (bytes);
    # ``total_mem`` does not exist and raised AttributeError on any GPU host.
    props = torch.cuda.get_device_properties(0)
    return {
        "name": torch.cuda.get_device_name(),
        "memory_gb": round(props.total_memory / 1e9, 1),
    }
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def auto_configure(config: "TrainConfig") -> "TrainConfig":
    """Tune training settings in place for the detected GPU.

    Colab Pro+ does not always allocate an A100; when a V100/T4/L4 (or an
    unknown GPU) is assigned, ``dtype``, ``micro_batch_size`` and
    ``gradient_accumulation_steps`` are adjusted so the effective batch
    size stays reasonable.

    Args:
        config: TrainConfig to mutate.

    Returns:
        The same (mutated) TrainConfig.
    """
    if not torch.cuda.is_available():
        print("โ ๏ธ GPU ์์! CPU ๋ชจ๋ (๋งค์ฐ ๋๋ฆผ)")
        config.dtype = "float32"
        config.micro_batch_size = 1
        config.gradient_accumulation_steps = 4
        return config

    gpu_name = torch.cuda.get_device_name().lower()
    # BUG FIX: the attribute is ``total_memory`` (bytes); ``total_mem`` does
    # not exist and raised AttributeError in the original code.
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9

    print(f"\n๐ GPU ๊ฐ์ง: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")

    if "a100" in gpu_name:
        # A100 40GB: keep the defaults (already optimal).
        print(" โ A100 ๊ฐ์ง: ๊ธฐ๋ณธ ์ค์ ์ฌ์ฉ (bf16, batch=4)")
        config.dtype = "bfloat16"
        config.micro_batch_size = 4

    elif "v100" in gpu_name:
        # V100 16GB: no bf16 support; shrink the batch.
        print(" โ V100 ๊ฐ์ง: fp16 ๋ชจ๋, ๋ฐฐ์น ์ถ์")
        config.dtype = "float16"
        config.micro_batch_size = 2
        config.gradient_accumulation_steps = 64  # keep effective batch size

    elif "t4" in gpu_name:
        # T4 16GB: no bf16 support; smallest batch.
        print(" โ T4 ๊ฐ์ง: fp16 ๋ชจ๋, ์ต์ ๋ฐฐ์น")
        config.dtype = "float16"
        config.micro_batch_size = 1
        config.gradient_accumulation_steps = 128

    elif "l4" in gpu_name:
        # L4 24GB: bf16 is supported.
        print(" โ L4 ๊ฐ์ง: bf16 ๋ชจ๋, ๋ฐฐ์น ์กฐ์ ")
        config.dtype = "bfloat16"
        config.micro_batch_size = 2
        config.gradient_accumulation_steps = 64

    else:
        # Unknown GPU: fall back to sizing the batch by available memory.
        print(f" โ ์ ์ ์๋ GPU. ๋ฉ๋ชจ๋ฆฌ ๊ธฐ์ค์ผ๋ก ์ค์ ์กฐ์ ")
        if gpu_mem >= 30:
            config.micro_batch_size = 4
        elif gpu_mem >= 16:
            config.micro_batch_size = 2
        else:
            config.micro_batch_size = 1
            config.gradient_accumulation_steps = 128

    print(f" โ dtype: {config.dtype}")
    print(f" โ micro_batch: {config.micro_batch_size}")
    print(f" โ grad_accum: {config.gradient_accumulation_steps}")
    print(f" โ effective_batch: {config.effective_batch_size}")

    return config
|
llm_lab/utils/seed.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Seeding utilities for reproducibility."""
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def set_seed(seed: int = 42) -> None:
    """Seed torch RNGs for reproducible runs.

    Args:
        seed: Value applied to the CPU RNG and, when CUDA is available,
            to every CUDA device.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # ``manual_seed_all`` seeds every GPU, not just the current one,
        # so multi-GPU runs stay reproducible as well.
        torch.cuda.manual_seed_all(seed)
|
notebooks/01_data_pipeline.ipynb
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 01. ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"ํ ํฌ๋์ด์ ์ค๋น โ ๋ฐ์ดํฐ ์คํธ๋ฆฌ๋ฐ โ ์ํ์ค ํจํน โ ๋ฐฐ์น ๊ตฌ์ฑ\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**ํ์ดํ๋ผ์ธ ํ๋ฆ:**\n",
|
| 12 |
+
"```\n",
|
| 13 |
+
"FineWeb-Edu (HuggingFace)\n",
|
| 14 |
+
" โ Streaming์ผ๋ก ๋ก๋ (๋์คํฌ ์ ์ฅ ์์)\n",
|
| 15 |
+
" โ ํ ํฌ๋์ด์ง (BPE, vocab=32K)\n",
|
| 16 |
+
" โ ์ํ์ค ํจํน (์ฌ๋ฌ ๋ฌธ์๋ฅผ max_seq_len์ผ๋ก ์ฐ๊ฒฐ)\n",
|
| 17 |
+
" โ ๋ฐฐ์น ๊ตฌ์ฑ (input_ids, targets)\n",
|
| 18 |
+
" โ GPU ์ ์ก\n",
|
| 19 |
+
"```"
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "code",
|
| 24 |
+
"execution_count": null,
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"# ํ์ ํจํค์ง ์ค์น\n",
|
| 29 |
+
"!pip install datasets tokenizers sentencepiece transformers -q"
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "code",
|
| 34 |
+
"execution_count": null,
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"outputs": [],
|
| 37 |
+
"source": [
|
| 38 |
+
"import sys\n",
|
| 39 |
+
"sys.path.insert(0, '..')\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"from llm_lab.config import DataConfig\n",
|
| 42 |
+
"from llm_lab.data import (\n",
|
| 43 |
+
" Tokenizer, setup_data_pipeline,\n",
|
| 44 |
+
" DataPipelineDiagnostics\n",
|
| 45 |
+
")"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "markdown",
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"source": [
|
| 52 |
+
"## 1. ๋ฐ์ดํฐ ์ค์ (Config)\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"์๋ ๊ฐ๋ค์ ํ๊ฒฝ์ ๋ง๊ฒ ์์ ํ์ธ์."
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": null,
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": [
|
| 63 |
+
"data_config = DataConfig(\n",
|
| 64 |
+
" dataset_name=\"HuggingFaceFW/fineweb-edu\",\n",
|
| 65 |
+
" dataset_subset=\"sample-10BT\",\n",
|
| 66 |
+
" vocab_size=32_000,\n",
|
| 67 |
+
" max_seq_len=2048,\n",
|
| 68 |
+
" batch_size=4,\n",
|
| 69 |
+
" num_workers=2,\n",
|
| 70 |
+
")\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"print(f\"๋ฐ์ดํฐ์
: {data_config.dataset_name} ({data_config.dataset_subset})\")\n",
|
| 73 |
+
"print(f\"์ํ์ค ๊ธธ์ด: {data_config.max_seq_len}\")\n",
|
| 74 |
+
"print(f\"๋ฐฐ์น ํฌ๊ธฐ: {data_config.batch_size}\")\n",
|
| 75 |
+
"print(f\"ํ ํฐ/๋ฐฐ์น: {data_config.batch_size * data_config.max_seq_len:,}\")"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "markdown",
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"source": [
|
| 82 |
+
"## 2. ํ ํฌ๋์ด์ ์ค์ \n",
|
| 83 |
+
"\n",
|
| 84 |
+
"์ธ ๊ฐ์ง ๋ฐฉ๋ฒ ์ค ์ ํ:\n",
|
| 85 |
+
"- `\"pretrained\"` โ HuggingFace ์ฌ์ ํ์ต ํ ํฌ๋์ด์ (๊ฐ์ฅ ๊ฐํธ)\n",
|
| 86 |
+
"- `\"train_new\"` โ BPE ํ ํฌ๋์ด์ ์๋ก ํ์ต\n",
|
| 87 |
+
"- `\"load_trained\"` โ ์ด์ ์ ํ์ตํ ํ ํฌ๋์ด์ ๋ก๋"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {},
|
| 94 |
+
"outputs": [],
|
| 95 |
+
"source": [
|
| 96 |
+
"tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
|
| 97 |
+
" tokenizer_mode=\"pretrained\", # \"train_new\" ๋๋ \"load_trained\"๋ก ๋ณ๊ฒฝ ๊ฐ๋ฅ\n",
|
| 98 |
+
" config=data_config,\n",
|
| 99 |
+
")"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"cell_type": "markdown",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"source": [
|
| 106 |
+
"## 3. ํ์ดํ๋ผ์ธ ์ง๋จ"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": [
|
| 115 |
+
"# ํ ํฌ๋์ด์ ํ์ง ์ง๋จ\n",
|
| 116 |
+
"DataPipelineDiagnostics.check_tokenizer_quality(tokenizer, data_config)"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"cell_type": "code",
|
| 121 |
+
"execution_count": null,
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"outputs": [],
|
| 124 |
+
"source": [
|
| 125 |
+
"# ๋ฐ์ดํฐ ๋ก๋ฉ ์ฒ๋ฆฌ๋ ๋ฒค์น๋งํฌ\n",
|
| 126 |
+
"DataPipelineDiagnostics.benchmark_throughput(train_dl, num_batches=50)"
|
| 127 |
+
]
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"cell_type": "markdown",
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"source": [
|
| 133 |
+
"## 4. ๋ฐฐ์น ๊ฒ์ฌ"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [],
|
| 141 |
+
"source": [
|
| 142 |
+
"# ์ฒซ ๋ฐฐ์น๋ฅผ ๊ฐ์ ธ์์ ์์ธ ๊ฒ์ฌ\n",
|
| 143 |
+
"batch = next(iter(train_dl))\n",
|
| 144 |
+
"DataPipelineDiagnostics.inspect_batch(batch, tokenizer)"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "markdown",
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"source": [
|
| 151 |
+
"---\n",
|
| 152 |
+
"**๋ค์ ๋จ๊ณ:** `02_model.ipynb`์์ ๋ชจ๋ธ ์ํคํ
์ฒ๋ฅผ ์์ฑํฉ๋๋ค."
|
| 153 |
+
]
|
| 154 |
+
}
|
| 155 |
+
],
|
| 156 |
+
"metadata": {
|
| 157 |
+
"kernelspec": {
|
| 158 |
+
"display_name": "Python 3",
|
| 159 |
+
"language": "python",
|
| 160 |
+
"name": "python3"
|
| 161 |
+
},
|
| 162 |
+
"language_info": {
|
| 163 |
+
"name": "python",
|
| 164 |
+
"version": "3.10.0"
|
| 165 |
+
}
|
| 166 |
+
},
|
| 167 |
+
"nbformat": 4,
|
| 168 |
+
"nbformat_minor": 4
|
| 169 |
+
}
|
notebooks/02_model.ipynb
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 02. ๋ชจ๋ธ ์ํคํ
์ฒ\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"1.1B ํ๋ผ๋ฏธํฐ LLaMA-style Decoder-Only Transformer ์์ฑ ๋ฐ ๊ฒ์ฆ.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**๋ชจ๋ธ ๊ตฌ์กฐ:**\n",
|
| 12 |
+
"```\n",
|
| 13 |
+
"Input Token IDs\n",
|
| 14 |
+
" โ Token Embedding\n",
|
| 15 |
+
" โ [TransformerBlock] ร num_layers\n",
|
| 16 |
+
" โ โโโ RMSNorm โ GroupedQueryAttention (+ RoPE) โ Residual\n",
|
| 17 |
+
" โ โโโ RMSNorm โ SwiGLU FFN โ Residual\n",
|
| 18 |
+
" โ RMSNorm (์ต์ข
)\n",
|
| 19 |
+
" โ Linear Head (Weight Tying)\n",
|
| 20 |
+
" โ Vocab Logits\n",
|
| 21 |
+
"```"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"import sys\n",
|
| 31 |
+
"sys.path.insert(0, '..')\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"import torch\n",
|
| 34 |
+
"import math\n",
|
| 35 |
+
"from llm_lab.config import ModelConfig\n",
|
| 36 |
+
"from llm_lab.model import LLMModel, count_parameters_detailed, estimate_memory_gb"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"source": [
|
| 43 |
+
"## 1. ๋ชจ๋ธ ์ค์ ์ ํ\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"| ํ๋ฆฌ์
| ํ๋ผ๋ฏธํฐ | ์ฉ๋ |\n",
|
| 46 |
+
"|--------|----------|------|\n",
|
| 47 |
+
"| `debug_10m()` | ~10M | ํ์ดํ๋ผ์ธ ๊ฒ์ฆ |\n",
|
| 48 |
+
"| `small_100m()` | ~100M | ์ค๊ฐ ๊ฒ์ฆ |\n",
|
| 49 |
+
"| `base_1b()` | ~1.1B | ์ต์ข
ํ์ต |"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": null,
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"# --- ๋ชจ๋ธ ์ค์ผ์ผ ์ ํ ---\n",
|
| 59 |
+
"# model_config = ModelConfig.debug_10m() # ~10M (๋น ๋ฅธ ๊ฒ์ฆ)\n",
|
| 60 |
+
"# model_config = ModelConfig.small_100m() # ~100M (์ค๊ฐ ๊ฒ์ฆ)\n",
|
| 61 |
+
"model_config = ModelConfig.base_1b() # ~1.1B (์ต์ข
๋ชฉํ)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"print(f\"hidden_dim: {model_config.hidden_dim}\")\n",
|
| 64 |
+
"print(f\"num_layers: {model_config.num_layers}\")\n",
|
| 65 |
+
"print(f\"num_heads: {model_config.num_heads}\")\n",
|
| 66 |
+
"print(f\"num_kv_heads: {model_config.num_kv_heads} (GQA ๊ทธ๋ฃน: {model_config.num_kv_groups})\")\n",
|
| 67 |
+
"print(f\"intermediate_dim: {model_config.intermediate_dim}\")\n",
|
| 68 |
+
"print(f\"max_seq_len: {model_config.max_seq_len}\")"
|
| 69 |
+
]
|
| 70 |
+
},
|
| 71 |
+
{
|
| 72 |
+
"cell_type": "markdown",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"source": [
|
| 75 |
+
"## 2. ๋ชจ๋ธ ์์ฑ ๋ฐ ํ๋ผ๋ฏธํฐ ํ์ธ"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "code",
|
| 80 |
+
"execution_count": null,
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"source": [
|
| 84 |
+
"# Debug ๋ชจ๋ธ ์ค์ ์์ฑ (๋ฉ๋ชจ๋ฆฌ ํ์ธ ์ฉ๋)\n",
|
| 85 |
+
"debug_config = ModelConfig.debug_10m()\n",
|
| 86 |
+
"model = LLMModel(debug_config)\n",
|
| 87 |
+
"print(f\"Debug ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ์: {model.count_parameters():,}\")\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"# 1B ๋ชจ๋ธ์ meta device์์ ํ๋ผ๋ฏธํฐ ์๋ง ํ์ธ\n",
|
| 90 |
+
"with torch.device(\"meta\"):\n",
|
| 91 |
+
" model_1b = LLMModel(ModelConfig.base_1b())\n",
|
| 92 |
+
"n_params_1b = model_1b.count_parameters()\n",
|
| 93 |
+
"print(f\"1B ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ์: {n_params_1b:,} ({n_params_1b/1e9:.2f}B)\")"
|
| 94 |
+
]
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"cell_type": "markdown",
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"source": [
|
| 100 |
+
"## 3. ์์ธ ํ๋ผ๋ฏธํฐ ๋ถํด"
|
| 101 |
+
]
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"cell_type": "code",
|
| 105 |
+
"execution_count": null,
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"outputs": [],
|
| 108 |
+
"source": [
|
| 109 |
+
"detail = count_parameters_detailed(model_1b)\n",
|
| 110 |
+
"cfg_1b = ModelConfig.base_1b()\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"print(f\"Token Embedding: {detail['token_embedding']:,}\")\n",
|
| 113 |
+
"print(f\"Per Layer Total: {detail['per_layer_total']:,}\")\n",
|
| 114 |
+
"print(f\"All Layers ({cfg_1b.num_layers}): {detail['all_layers_total']:,}\")\n",
|
| 115 |
+
"print(f\"Final Norm: {detail['final_norm']:,}\")\n",
|
| 116 |
+
"print(f\"LM Head: {detail['lm_head']}\")\n",
|
| 117 |
+
"print(f\"{'โ' * 30}\")\n",
|
| 118 |
+
"print(f\"TOTAL: {detail['total']:,}\")"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "markdown",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"source": [
|
| 125 |
+
"## 4. GPU ๋ฉ๋ชจ๋ฆฌ ์ถ์ "
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"mem = estimate_memory_gb(ModelConfig.base_1b(), batch_size=4, dtype_bytes=2)\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"print(f\"๋ชจ๋ธ ๊ฐ์ค์น: {mem['model_weights_gb']} GB\")\n",
|
| 137 |
+
"print(f\"์ตํฐ๋ง์ด์ : {mem['optimizer_states_gb']} GB\")\n",
|
| 138 |
+
"print(f\"๊ธฐ์ธ๊ธฐ: {mem['gradients_gb']} GB\")\n",
|
| 139 |
+
"print(f\"ํ์ฑํ (์ถ์ ): {mem['activations_estimated_gb']} GB\")\n",
|
| 140 |
+
"print(f\"{'โ' * 30}\")\n",
|
| 141 |
+
"print(f\"์ด ์ถ์ : {mem['total_estimated_gb']} GB\")"
|
| 142 |
+
]
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"cell_type": "markdown",
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"source": [
|
| 148 |
+
"## 5. Forward Pass ๊ฒ์ฆ"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "code",
|
| 153 |
+
"execution_count": null,
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"outputs": [],
|
| 156 |
+
"source": [
|
| 157 |
+
"# Debug ๋ชจ๋ธ๋ก forward/backward ๊ฒ์ฆ\n",
|
| 158 |
+
"dummy_input = torch.randint(0, debug_config.vocab_size, (2, 64))\n",
|
| 159 |
+
"dummy_target = torch.randint(0, debug_config.vocab_size, (2, 64))\n",
|
| 160 |
+
"logits, loss = model(dummy_input, dummy_target)\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"print(f\"Input shape: {dummy_input.shape}\")\n",
|
| 163 |
+
"print(f\"Logits shape: {logits.shape}\")\n",
|
| 164 |
+
"print(f\"Loss: {loss.item():.4f}\")\n",
|
| 165 |
+
"expected_loss = math.log(debug_config.vocab_size)\n",
|
| 166 |
+
"print(f\"Expected initial loss (ln({debug_config.vocab_size})): {expected_loss:.2f}\")"
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"cell_type": "markdown",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"source": [
|
| 173 |
+
"## 6. ํ
์คํธ ์์ฑ ํ
์คํธ (๋๋ค ๊ฐ์ค์น)"
|
| 174 |
+
]
|
| 175 |
+
},
|
| 176 |
+
{
|
| 177 |
+
"cell_type": "code",
|
| 178 |
+
"execution_count": null,
|
| 179 |
+
"metadata": {},
|
| 180 |
+
"outputs": [],
|
| 181 |
+
"source": [
|
| 182 |
+
"prompt = torch.randint(0, debug_config.vocab_size, (1, 10))\n",
|
| 183 |
+
"generated = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"print(f\"Prompt length: {prompt.shape[1]}\")\n",
|
| 186 |
+
"print(f\"Generated length: {generated.shape[1]}\")\n",
|
| 187 |
+
"print(f\"Token IDs: {generated[0].tolist()}\")"
|
| 188 |
+
]
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"cell_type": "markdown",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"source": [
|
| 194 |
+
"---\n",
|
| 195 |
+
"**๋ค์ ๋จ๊ณ:** `03_training.ipynb`์์ ํ์ต์ ์คํํฉ๋๋ค."
|
| 196 |
+
]
|
| 197 |
+
}
|
| 198 |
+
],
|
| 199 |
+
"metadata": {
|
| 200 |
+
"kernelspec": {
|
| 201 |
+
"display_name": "Python 3",
|
| 202 |
+
"language": "python",
|
| 203 |
+
"name": "python3"
|
| 204 |
+
},
|
| 205 |
+
"language_info": {
|
| 206 |
+
"name": "python",
|
| 207 |
+
"version": "3.10.0"
|
| 208 |
+
}
|
| 209 |
+
},
|
| 210 |
+
"nbformat": 4,
|
| 211 |
+
"nbformat_minor": 4
|
| 212 |
+
}
|
notebooks/03_training.ipynb
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 03. ํ์ต (Training)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Gradient Accumulation, Mixed Precision, Cosine LR Scheduling,\n",
|
| 10 |
+
"์ฒดํฌํฌ์ธํธ ์ ์ฅ/๋ณต์, wandb ๋ก๊น
์ ํฌํจํ ํ์ต ํ์ดํ๋ผ์ธ.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"**ํ์ต ํ๋ฆ:**\n",
|
| 13 |
+
"```\n",
|
| 14 |
+
"๋ฐฐ์น ๊ฐ์ ธ์ค๊ธฐ\n",
|
| 15 |
+
" โ Forward (bf16 autocast)\n",
|
| 16 |
+
" โ Loss / accumulation_steps\n",
|
| 17 |
+
" โ Backward (gradient ๋์ )\n",
|
| 18 |
+
" โ [accumulation_steps๋ง๋ค] Gradient Clipping โ Optimizer Step โ LR Update\n",
|
| 19 |
+
" โ [checkpoint_interval๋ง๋ค] ์ฒดํฌํฌ์ธํธ ์ ์ฅ (Google Drive)\n",
|
| 20 |
+
" โ [eval_interval๋ง๋ค] ๊ฒ์ฆ Loss/Perplexity ์ธก์ \n",
|
| 21 |
+
"```"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"!pip install wandb -q"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"metadata": {},
|
| 37 |
+
"outputs": [],
|
| 38 |
+
"source": [
|
| 39 |
+
"import sys\n",
|
| 40 |
+
"sys.path.insert(0, '..')\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n",
|
| 43 |
+
"from llm_lab.model import LLMModel\n",
|
| 44 |
+
"from llm_lab.data import setup_data_pipeline\n",
|
| 45 |
+
"from llm_lab.training import start_training, Trainer\n",
|
| 46 |
+
"from llm_lab.utils import auto_configure, get_device"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "markdown",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"source": [
|
| 53 |
+
"## 0. Google Drive ๋ง์ดํธ (Colab)\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"์ฒดํฌํฌ์ธํธ๋ฅผ Google Drive์ ์ ์ฅํ์ฌ ์ธ์
๋ง๋ฃ ์์๋ ๋ณด์กดํฉ๋๋ค."
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"# Colab์์๋ง ์คํ\n",
|
| 65 |
+
"# from google.colab import drive\n",
|
| 66 |
+
"# drive.mount('/content/drive')"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "markdown",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"source": [
|
| 73 |
+
"## 1. ์ค์ "
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "code",
|
| 78 |
+
"execution_count": null,
|
| 79 |
+
"metadata": {},
|
| 80 |
+
"outputs": [],
|
| 81 |
+
"source": [
|
| 82 |
+
"# --- ๋ชจ๋ธ ์ค์ ---\n",
|
| 83 |
+
"model_config = ModelConfig.debug_10m() # ๊ฒ์ฆ ์ debug, ์ค์ ํ์ต ์ base_1b()\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# --- ๋ฐ์ดํฐ ์ค์ ---\n",
|
| 86 |
+
"data_config = DataConfig(\n",
|
| 87 |
+
" max_seq_len=model_config.max_seq_len,\n",
|
| 88 |
+
" batch_size=4,\n",
|
| 89 |
+
")\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# --- ํ์ต ์ค์ ---\n",
|
| 92 |
+
"train_config = TrainConfig(\n",
|
| 93 |
+
" total_steps=20_000,\n",
|
| 94 |
+
" warmup_steps=2_000,\n",
|
| 95 |
+
" learning_rate=3e-4,\n",
|
| 96 |
+
" min_learning_rate=3e-5,\n",
|
| 97 |
+
" weight_decay=0.1,\n",
|
| 98 |
+
" grad_clip=1.0,\n",
|
| 99 |
+
" micro_batch_size=4,\n",
|
| 100 |
+
" gradient_accumulation_steps=32,\n",
|
| 101 |
+
" checkpoint_dir=\"/content/drive/MyDrive/llm-1b-lab/checkpoints\",\n",
|
| 102 |
+
" checkpoint_interval=500,\n",
|
| 103 |
+
" eval_interval=500,\n",
|
| 104 |
+
" log_interval=10,\n",
|
| 105 |
+
" use_wandb=True,\n",
|
| 106 |
+
")\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"print(f\"Effective batch size: {train_config.effective_batch_size}\")\n",
|
| 109 |
+
"print(f\"Total steps: {train_config.total_steps:,}\")"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"cell_type": "markdown",
|
| 114 |
+
"metadata": {},
|
| 115 |
+
"source": [
|
| 116 |
+
"## 2. GPU ์๋ ๊ฐ์ง\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"GPU ์ข
๋ฅ(A100/V100/T4/L4)์ ๋ฐ๋ผ dtype, batch_size, gradient_accumulation์ ์๋ ์กฐ์ ํฉ๋๋ค."
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "code",
|
| 123 |
+
"execution_count": null,
|
| 124 |
+
"metadata": {},
|
| 125 |
+
"outputs": [],
|
| 126 |
+
"source": [
|
| 127 |
+
"train_config = auto_configure(train_config)"
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"cell_type": "markdown",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"source": [
|
| 134 |
+
"## 3. ๋ชจ๋ธ + ๋ฐ์ดํฐ ์ด๊ธฐํ"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"# ๋ชจ๋ธ ์์ฑ\n",
|
| 144 |
+
"model = LLMModel(model_config)\n",
|
| 145 |
+
"print(f\"๋ชจ๋ธ ํ๋ผ๋ฏธํฐ: {model.count_parameters():,}\")\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"# ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ\n",
|
| 148 |
+
"tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
|
| 149 |
+
" tokenizer_mode=\"pretrained\",\n",
|
| 150 |
+
" config=data_config,\n",
|
| 151 |
+
")"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"source": [
|
| 158 |
+
"## 4. ํ์ต ์์\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"์ฒดํฌํฌ์ธํธ๊ฐ ์์ผ๋ฉด ์๋์ผ๋ก ๋ณต์ํ์ฌ ์ด์ด์ ํ์ตํฉ๋๋ค."
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": null,
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"outputs": [],
|
| 168 |
+
"source": [
|
| 169 |
+
"trainer = start_training(\n",
|
| 170 |
+
" model=model,\n",
|
| 171 |
+
" train_dataloader=train_dl,\n",
|
| 172 |
+
" val_dataloader=val_dl,\n",
|
| 173 |
+
" config=train_config,\n",
|
| 174 |
+
" seq_len=model_config.max_seq_len,\n",
|
| 175 |
+
")"
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "markdown",
|
| 180 |
+
"metadata": {},
|
| 181 |
+
"source": [
|
| 182 |
+
"## 5. ํ์ต ์ฌ๊ฐ (์ธ์
๋ง๋ฃ ํ)\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"Colab ์ธ์
์ด ๋ง๋ฃ๋ ํ ๋ค์ ์คํํ๋ฉด CheckpointManager๊ฐ ์๋์ผ๋ก ์ต์ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฐพ์ ๋ณต์ํฉ๋๋ค.\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"์์ ์
๋ค์ ์์๋๋ก ๋ค์ ์คํํ๋ฉด ๋ฉ๋๋ค."
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "markdown",
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"source": [
|
| 193 |
+
"---\n",
|
| 194 |
+
"**๋ค์ ๋จ๊ณ:** `04_evaluation.ipynb`์์ ํ์ต๋ ๋ชจ๋ธ์ ํ๊ฐํฉ๋๋ค."
|
| 195 |
+
]
|
| 196 |
+
}
|
| 197 |
+
],
|
| 198 |
+
"metadata": {
|
| 199 |
+
"kernelspec": {
|
| 200 |
+
"display_name": "Python 3",
|
| 201 |
+
"language": "python",
|
| 202 |
+
"name": "python3"
|
| 203 |
+
},
|
| 204 |
+
"language_info": {
|
| 205 |
+
"name": "python",
|
| 206 |
+
"version": "3.10.0"
|
| 207 |
+
}
|
| 208 |
+
},
|
| 209 |
+
"nbformat": 4,
|
| 210 |
+
"nbformat_minor": 4
|
| 211 |
+
}
|
notebooks/04_evaluation.ipynb
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 04. ํ๊ฐ (Evaluation)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"ํ์ต๋ ๋ชจ๋ธ์ ํ์ง์ ๋ค๊ฐ๋๋ก ํ๊ฐํฉ๋๋ค.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**ํ๊ฐ ์์ญ:**\n",
|
| 12 |
+
"1. Perplexity ์ธก์ โ ์ธ์ด ๋ชจ๋ธ์ ํ์ค ์ ๋ ์งํ\n",
|
| 13 |
+
"2. ํ
์คํธ ์์ฑ ํ์ง โ ๋ค์ํ ํ๋กฌํํธ๋ก ์ ์ฑ์ ํ๊ฐ\n",
|
| 14 |
+
"3. Scaling Law ๋ถ์ โ 10M โ 100M โ 1B ๋น๊ต\n",
|
| 15 |
+
"4. Attention ์๊ฐํ โ ๋ชจ๋ธ์ด \"์ด๋๋ฅผ ๋ณด๋์ง\" ๋ถ์\n",
|
| 16 |
+
"5. ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ โ ํ์ต ๋ชฉํ ๋ฌ์ฑ ํ์ธ"
|
| 17 |
+
]
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"execution_count": null,
|
| 22 |
+
"metadata": {},
|
| 23 |
+
"outputs": [],
|
| 24 |
+
"source": [
|
| 25 |
+
"!pip install matplotlib numpy -q"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"import sys\n",
|
| 35 |
+
"sys.path.insert(0, '..')\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"import torch\n",
|
| 38 |
+
"from llm_lab.config import ModelConfig, EvalConfig\n",
|
| 39 |
+
"from llm_lab.model import LLMModel\n",
|
| 40 |
+
"from llm_lab.evaluation import (\n",
|
| 41 |
+
" run_evaluation, PerplexityEvaluator, GenerationEvaluator,\n",
|
| 42 |
+
" ScalingAnalyzer, AttentionVisualizer, InsightChecklist\n",
|
| 43 |
+
")\n",
|
| 44 |
+
"from llm_lab.utils import get_device"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "markdown",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"source": [
|
| 51 |
+
"## 1. ๋ชจ๋ธ ๋ก๋\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"ํ์ต๋ ์ฒดํฌํฌ์ธํธ์์ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ๋ก๋ํฉ๋๋ค."
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": null,
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"outputs": [],
|
| 61 |
+
"source": [
|
| 62 |
+
"device = get_device()\n",
|
| 63 |
+
"model_config = ModelConfig.base_1b()\n",
|
| 64 |
+
"model = LLMModel(model_config).to(device)\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"# ์ฒดํฌํฌ์ธํธ ๋ก๋ (๊ฒฝ๋ก๋ฅผ ์ค์ ์ฒดํฌํฌ์ธํธ ๊ฒฝ๋ก๋ก ๋ณ๊ฒฝ)\n",
|
| 67 |
+
"# ckpt = torch.load(\"path/to/step_XXXXXX/model.pt\", map_location=device)\n",
|
| 68 |
+
"# model.load_state_dict(ckpt)\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"print(f\"๋ชจ๋ธ ํ๋ผ๋ฏธํฐ: {model.count_parameters():,}\")\n",
|
| 71 |
+
"print(f\"๋๋ฐ์ด์ค: {device}\")"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "markdown",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"source": [
|
| 78 |
+
"## 2. ์ข
ํฉ ํ๊ฐ (ํ ์ค ์คํ)\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"Perplexity, ํ
์คํธ ์์ฑ, ํ์ต ์ญํ, Attention ์๊ฐํ๋ฅผ ํ ๋ฒ์ ์คํํฉ๋๋ค."
|
| 81 |
+
]
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"cell_type": "code",
|
| 85 |
+
"execution_count": null,
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"outputs": [],
|
| 88 |
+
"source": [
|
| 89 |
+
"# ํ์ต ์ ์ฌ์ฉํ tokenizer, val_dl, metrics_history๊ฐ ํ์ํฉ๋๋ค\n",
|
| 90 |
+
"# report = run_evaluation(\n",
|
| 91 |
+
"# model=model,\n",
|
| 92 |
+
"# tokenizer=tokenizer,\n",
|
| 93 |
+
"# val_dataloader=val_dl,\n",
|
| 94 |
+
"# metrics_history=trainer.metrics.history,\n",
|
| 95 |
+
"# )"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "markdown",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"## 3. Scaling Law ๋ถ์\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"10M โ 100M โ 1B ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋น๊ตํ์ฌ Scaling Law๋ฅผ ํ์ธํฉ๋๋ค.\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"Chinchilla Scaling Law: ์ต์ ํ์ต ํ ํฐ ์ โ 20 ร ํ๋ผ๋ฏธํฐ ์"
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": [
|
| 115 |
+
"analyzer = ScalingAnalyzer()\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"# ๊ฐ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๋ฅผ ์
๋ ฅ (์ค์ ํ์ต ๊ฒฐ๊ณผ๋ก ๋์ฒด)\n",
|
| 118 |
+
"scaling_results = [\n",
|
| 119 |
+
" {\"name\": \"10M\", \"params\": 10e6, \"tokens\": 1e9, \"loss\": 4.2, \"ppl\": 66.7},\n",
|
| 120 |
+
" {\"name\": \"100M\", \"params\": 100e6, \"tokens\": 5e9, \"loss\": 3.5, \"ppl\": 33.1},\n",
|
| 121 |
+
" {\"name\": \"1B\", \"params\": 1.1e9, \"tokens\": 10e9, \"loss\": 3.0, \"ppl\": 20.1},\n",
|
| 122 |
+
"]\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"analysis = analyzer.analyze(scaling_results)\n",
|
| 125 |
+
"analyzer.plot_scaling_curves(scaling_results)"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"source": [
|
| 132 |
+
"## 4. Attention ์๊ฐํ\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"๋ชจ๋ธ์ด ๊ฐ ํ ํฐ์ ๋ํด \"์ด๋๋ฅผ ๋ณด๋์ง\" ์๊ฐํํฉ๋๋ค."
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"# viz = AttentionVisualizer()\n",
|
| 144 |
+
"# sample_text = \"The cat sat on the mat and looked at the bird.\"\n",
|
| 145 |
+
"# token_ids = tokenizer.encode(sample_text)\n",
|
| 146 |
+
"# input_tensor = torch.tensor([token_ids], dtype=torch.long)\n",
|
| 147 |
+
"# \n",
|
| 148 |
+
"# attn_weights = viz.extract_attention(model, input_tensor, layer_idx=0, device=device)\n",
|
| 149 |
+
"# if attn_weights is not None:\n",
|
| 150 |
+
"# tokens_str = [tokenizer.decode([tid]) for tid in token_ids]\n",
|
| 151 |
+
"# viz.plot_attention_heatmap(attn_weights, tokens_str, head_idx=0)\n",
|
| 152 |
+
"# viz.plot_multi_head_summary(attn_weights)"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "markdown",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"source": [
|
| 159 |
+
"## 5. ์ธ์ฌ์ดํธ ์ฒดํฌ๋ฆฌ์คํธ\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"ํ์ต ๋ชฉํ ๋ฌ์ฑ ์ฌ๋ถ๋ฅผ ์๋/์๋์ผ๋ก ํ์ธํฉ๋๋ค."
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": null,
|
| 167 |
+
"metadata": {},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"# report๊ฐ ์๋ ๊ฒฝ์ฐ ์ฒดํฌ๋ฆฌ์คํธ ์คํ\n",
|
| 171 |
+
"# InsightChecklist.run_checklist(report, metrics_history)"
|
| 172 |
+
]
|
| 173 |
+
}
|
| 174 |
+
],
|
| 175 |
+
"metadata": {
|
| 176 |
+
"kernelspec": {
|
| 177 |
+
"display_name": "Python 3",
|
| 178 |
+
"language": "python",
|
| 179 |
+
"name": "python3"
|
| 180 |
+
},
|
| 181 |
+
"language_info": {
|
| 182 |
+
"name": "python",
|
| 183 |
+
"version": "3.10.0"
|
| 184 |
+
}
|
| 185 |
+
},
|
| 186 |
+
"nbformat": 4,
|
| 187 |
+
"nbformat_minor": 4
|
| 188 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
datasets
|
| 3 |
+
tokenizers
|
| 4 |
+
sentencepiece
|
| 5 |
+
transformers
|
| 6 |
+
wandb
|
| 7 |
+
matplotlib
|
| 8 |
+
numpy
|