Vjeong Claude Opus 4.6 committed on
Commit
8a58ffe
·
1 Parent(s): f494c9e

Initial commit: LLM-1B-Lab project setup


LLaMA-style 1.1B parameter Decoder-Only Transformer for educational purposes.
Includes modularized llm_lab package, notebooks, and configuration files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (50)
  1. .gitignore +47 -0
  2. CLAUDE.md +131 -0
  3. LLM_Foundation_Model.code-workspace +8 -0
  4. _archive/llm-1b-data-pipeline.py +906 -0
  5. _archive/llm-1b-evaluation.py +1455 -0
  6. _archive/llm-1b-model.py +791 -0
  7. _archive/llm-1b-trainer.py +1108 -0
  8. llm_lab/__init__.py +30 -0
  9. llm_lab/config/__init__.py +7 -0
  10. llm_lab/config/data_config.py +41 -0
  11. llm_lab/config/eval_config.py +20 -0
  12. llm_lab/config/model_config.py +53 -0
  13. llm_lab/config/train_config.py +114 -0
  14. llm_lab/data/__init__.py +11 -0
  15. llm_lab/data/dataset.py +218 -0
  16. llm_lab/data/diagnostics.py +153 -0
  17. llm_lab/data/pipeline.py +156 -0
  18. llm_lab/data/tokenizer.py +196 -0
  19. llm_lab/evaluation/__init__.py +21 -0
  20. llm_lab/evaluation/attention_viz.py +176 -0
  21. llm_lab/evaluation/checklist.py +99 -0
  22. llm_lab/evaluation/dynamics.py +242 -0
  23. llm_lab/evaluation/full_evaluator.py +222 -0
  24. llm_lab/evaluation/generation.py +200 -0
  25. llm_lab/evaluation/perplexity.py +172 -0
  26. llm_lab/evaluation/runner.py +56 -0
  27. llm_lab/evaluation/scaling.py +153 -0
  28. llm_lab/model/__init__.py +14 -0
  29. llm_lab/model/attention.py +134 -0
  30. llm_lab/model/feedforward.py +48 -0
  31. llm_lab/model/llm_model.py +200 -0
  32. llm_lab/model/norm.py +40 -0
  33. llm_lab/model/rope.py +103 -0
  34. llm_lab/model/transformer_block.py +65 -0
  35. llm_lab/model/utils.py +85 -0
  36. llm_lab/training/__init__.py +12 -0
  37. llm_lab/training/checkpoint.py +159 -0
  38. llm_lab/training/metrics.py +112 -0
  39. llm_lab/training/optimizer.py +54 -0
  40. llm_lab/training/runner.py +68 -0
  41. llm_lab/training/scheduler.py +68 -0
  42. llm_lab/training/trainer.py +351 -0
  43. llm_lab/utils/__init__.py +5 -0
  44. llm_lab/utils/device.py +94 -0
  45. llm_lab/utils/seed.py +9 -0
  46. notebooks/01_data_pipeline.ipynb +169 -0
  47. notebooks/02_model.ipynb +212 -0
  48. notebooks/03_training.ipynb +211 -0
  49. notebooks/04_evaluation.ipynb +188 -0
  50. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,47 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.egg-info/
+ *.egg
+ dist/
+ build/
+ *.so
+
+ # Virtual environments
+ venv/
+ .venv/
+ env/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Jupyter Notebook
+ .ipynb_checkpoints/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # ML / Training artifacts
+ *.pt
+ *.pth
+ *.bin
+ *.ckpt
+ checkpoints/
+ wandb/
+ runs/
+
+ # Data
+ *.log
+ *.csv
+ *.tsv
+ data/
+
+ # Secrets
+ .env
+ *.key
CLAUDE.md ADDED
@@ -0,0 +1,131 @@
+ # LLM-1B-Lab
+
+ Educational implementation of a 1.1B-parameter LLaMA-style decoder-only Transformer.
+ Designed so that a deep learning beginner can experience the full process of training and evaluating an LLM from start to finish.
+
+ ## Project Structure
+
+ ```
+ LLM_Foundation_Model/
+ ├── CLAUDE.md
+ ├── requirements.txt
+ ├── llm_lab/                      # Python package (core code)
+ │   ├── __init__.py
+ │   ├── config/                   # Config dataclasses
+ │   │   ├── model_config.py       # ModelConfig (debug_10m / small_100m / base_1b presets)
+ │   │   ├── data_config.py        # DataConfig (dataset, tokenizer, batch settings)
+ │   │   ├── train_config.py       # TrainConfig (LR, scheduler, checkpoints, wandb)
+ │   │   └── eval_config.py        # EvalConfig (evaluation parameters)
+ │   ├── model/                    # Model architecture
+ │   │   ├── norm.py               # RMSNorm
+ │   │   ├── rope.py               # RotaryPositionalEmbedding (RoPE)
+ │   │   ├── attention.py          # GroupedQueryAttention (GQA)
+ │   │   ├── feedforward.py        # SwiGLUFeedForward
+ │   │   ├── transformer_block.py  # TransformerBlock (Pre-LN)
+ │   │   ├── llm_model.py          # LLMModel (full model + generate)
+ │   │   └── utils.py              # count_parameters_detailed, estimate_memory_gb
+ │   ├── data/                     # Data pipeline
+ │   │   ├── tokenizer.py          # Tokenizer (SentencePiece / BPE / HuggingFace)
+ │   │   ├── dataset.py            # PackedStreamingDataset, ValidationDataset, _collate_fn
+ │   │   ├── pipeline.py           # create_train_dataloader, setup_data_pipeline
+ │   │   └── diagnostics.py        # DataPipelineDiagnostics
+ │   ├── training/                 # Training loop
+ │   │   ├── scheduler.py          # CosineWarmupScheduler
+ │   │   ├── checkpoint.py         # CheckpointManager (Google Drive support)
+ │   │   ├── metrics.py            # MetricsTracker (wandb integration)
+ │   │   ├── optimizer.py          # create_optimizer (weight-decay split)
+ │   │   ├── trainer.py            # Trainer (gradient accumulation, mixed precision)
+ │   │   └── runner.py             # start_training (one-line run helper)
+ │   ├── evaluation/               # Evaluation & analysis
+ │   │   ├── perplexity.py         # PerplexityEvaluator (incl. per-position loss)
+ │   │   ├── generation.py         # GenerationEvaluator (varied prompts)
+ │   │   ├── scaling.py            # ScalingAnalyzer (Chinchilla scaling law)
+ │   │   ├── dynamics.py           # TrainingDynamicsAnalyzer (loss/LR/grad visualization)
+ │   │   ├── attention_viz.py      # AttentionVisualizer (per-head heatmaps)
+ │   │   ├── full_evaluator.py     # FullEvaluator (comprehensive evaluation + report)
+ │   │   ├── checklist.py          # InsightChecklist (training-insight checklist)
+ │   │   └── runner.py             # run_evaluation (one-line run helper)
+ │   └── utils/                    # Shared utilities
+ │       ├── device.py             # auto_configure, get_device, detect_gpu_info
+ │       └── seed.py               # set_seed
+ ├── notebooks/                    # Jupyter notebooks (setup + execution)
+ │   ├── 01_data_pipeline.ipynb
+ │   ├── 02_model.ipynb
+ │   ├── 03_training.ipynb
+ │   └── 04_evaluation.ipynb
+ └── _archive/                     # Original single-file backups
+     ├── llm-1b-model.py
+     ├── llm-1b-data-pipeline.py
+     ├── llm-1b-trainer.py
+     └── llm-1b-evaluation.py
+ ```
+
+ ## Tech Stack
+
+ - **Model**: LLaMA-style decoder-only Transformer (RMSNorm, RoPE, GQA, SwiGLU, weight tying)
+ - **Training**: gradient accumulation, mixed precision (bf16/fp16), cosine LR + warmup, activation checkpointing (loop sketched below)
+ - **Data**: HuggingFace streaming (FineWeb-Edu), BPE tokenizer, sequence packing
+ - **Checkpointing**: automatic save/restore to Google Drive (Colab Pro+ environment)
+ - **Evaluation**: perplexity, text generation, scaling law, attention visualization
+ - **Target environment**: Google Colab Pro+ (A100 40GB)
+
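The training loop itself lives in `llm_lab/training/trainer.py` and is not shown on this page. The sketch below only illustrates the gradient-accumulation step named above, with a tiny stand-in model and random data so it runs on CPU; the real Trainer additionally applies bf16 autocast, warmup scheduling, and checkpointing.

```python
import torch
import torch.nn as nn

# Minimal gradient-accumulation sketch: average gradients over
# `accum_steps` micro-batches before each optimizer step.
model = nn.Linear(16, 16)                       # stand-in for LLMModel
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
micro_batches = [torch.randn(4, 16) for _ in range(8)]
accum_steps = 4                                 # effective batch = 4 micro-batches

opt.zero_grad(set_to_none=True)
for step, x in enumerate(micro_batches):
    loss = model(x).pow(2).mean()
    (loss / accum_steps).backward()             # scale so accumulated grads average
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
```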
+ ## Dependency Graph (no cycles)
+
+ ```
+ config       (no dependencies)
+   ↓
+ utils      → config
+   ↓
+ model      → config
+   ↓
+ data       → config
+   ↓
+ training   → config, utils
+   ↓
+ evaluation → config
+ ```
+
+ ## Model Presets
+
+ | Preset | Params | dim | layers | heads | kv_heads | Purpose |
+ |--------|--------|-----|--------|-------|----------|---------|
+ | `debug_10m` | ~10M | 256 | 6 | 8 | 4 | Quick checks / debugging |
+ | `small_100m` | ~100M | 768 | 12 | 12 | 4 | Intermediate experiments |
+ | `base_1b` | ~1.1B | 2048 | 22 | 32 | 8 | Full-scale training |
+
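As a sanity check on the "~1.1B" figure, the arithmetic below reproduces it from the `base_1b` row, assuming weight tying (listed in the tech stack) and a SwiGLU hidden size of roughly 8/3·dim rounded up to a multiple of 256 — the exact hidden size lives in `model_config.py` and may differ.

```python
vocab, dim, layers, heads, kv_heads = 32_000, 2048, 22, 32, 8
head_dim = dim // heads                            # 64
kv_dim = kv_heads * head_dim                       # 512 (GQA: only 8 K/V heads)
ffn_hidden = ((8 * dim // 3 + 255) // 256) * 256   # 5632 (assumed rounding rule)

embed = vocab * dim                                # counted once (weight tying)
attn = 2 * dim * dim + 2 * dim * kv_dim            # Q and O full, K and V reduced
ffn = 3 * dim * ffn_hidden                         # gate, up, down projections
total = embed + layers * (attn + ffn)              # RMSNorm weights are negligible
print(f"~{total / 1e9:.2f}B")                      # ≈ 1.06B, i.e. "~1.1B"
```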
+ ## Quick Start
+
+ ```python
+ from llm_lab.config import ModelConfig, DataConfig, TrainConfig
+ from llm_lab.model import LLMModel
+ from llm_lab.data import setup_data_pipeline
+ from llm_lab.training import start_training
+ from llm_lab.evaluation import run_evaluation
+
+ # 1. Model
+ model = LLMModel(ModelConfig.base_1b())
+
+ # 2. Data
+ tok, train_dl, val_dl = setup_data_pipeline("pretrained")
+
+ # 3. Training
+ trainer = start_training(model, train_dl, val_dl)
+
+ # 4. Evaluation
+ report = run_evaluation(model, tok, val_dl,
+                         metrics_history=trainer.metrics.history)
+ ```
+
+ ## Code Conventions
+
+ - **Language**: code in English, comments/docstrings in Korean (with educational explanations)
+ - **Type hints**: typing annotations on every function
+ - **Import order**: stdlib → torch → llm_lab (absolute paths) → local (relative paths)
+ - **Dataclasses**: every config is defined as a `@dataclass` with defaults
+ - **Error handling**: external dependencies (matplotlib, wandb, etc.) are optional via `try/except ImportError` (see the sketch after this list)
+
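The optional-dependency convention in the last bullet looks like this in practice; the `try/except` blocks at the top of `_archive/llm-1b-evaluation.py` follow the same shape. `log_metrics` here is a hypothetical helper for illustration only.

```python
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False  # the package still works; logging degrades to stdout

def log_metrics(step: int, metrics: dict) -> None:
    """Hypothetical helper: log to wandb when available, else print."""
    if HAS_WANDB:
        wandb.log(metrics, step=step)
    else:
        print(f"step {step}: {metrics}")
```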
+ ## Caveats
+
+ - `torch` may not be installed in the local environment (the project assumes it runs on Colab Pro+)
+ - `pip install torch datasets tokenizers sentencepiece transformers wandb matplotlib numpy`
+ - The four original files (`_archive/`) and the modularized `llm_lab/` package contain identical logic (only the import paths changed)
LLM_Foundation_Model.code-workspace ADDED
@@ -0,0 +1,8 @@
+ {
+     "folders": [
+         {
+             "path": "."
+         }
+     ],
+     "settings": {}
+ }
_archive/llm-1b-data-pipeline.py ADDED
@@ -0,0 +1,906 @@
+ """
+ LLM-1B-Lab: Data Pipeline
+ ==============================
+ Tokenizer preparation → data streaming → sequence packing → batch construction
+
+ Overall flow:
+     FineWeb-Edu (HuggingFace)
+     → load via streaming (nothing stored on disk)
+     → tokenization (BPE, vocab=32K)
+     → sequence packing (concatenate documents up to max_seq_len)
+     → batch construction (input_ids, targets)
+     → transfer to GPU
+
+ Required packages:
+     pip install datasets tokenizers sentencepiece wandb
+ """
+
+ import os
+ import time
+ import json
+ from pathlib import Path
+ from dataclasses import dataclass, field
+ from typing import Optional, Iterator, List, Dict, Any
+
+ import torch
+ from torch.utils.data import IterableDataset, DataLoader
+
+ # ============================================================================
+ # 1. Data configuration
+ # ============================================================================
+
+ @dataclass
+ class DataConfig:
+     """Data pipeline settings.
+
+     Defaults chosen for the constraints of a Colab Pro+ environment:
+     - streaming mode to minimize disk usage
+     - sequence packing to maximize GPU utilization without padding
+     - on-the-fly preprocessing to save memory
+     """
+     # ── Dataset ──
+     dataset_name: str = "HuggingFaceFW/fineweb-edu"
+     dataset_subset: str = "sample-10BT"  # 10B-token sample
+     dataset_split: str = "train"
+     text_column: str = "text"  # name of the column holding the text
+
+     # ── Tokenizer ──
+     tokenizer_type: str = "sentencepiece"  # "sentencepiece" or "hf"
+     # path to a previously trained tokenizer (train a new one if unset)
+     tokenizer_path: Optional[str] = None
+     vocab_size: int = 32_000
+
+     # ── Sequences ──
+     max_seq_len: int = 2048
+     # whether to insert a document-separator token (marks boundaries when packing)
+     use_eos_separator: bool = True
+
+     # ── Batching ──
+     batch_size: int = 4       # micro-batch (per GPU)
+     num_workers: int = 2      # number of DataLoader workers
+     prefetch_factor: int = 4  # number of batches to prepare in advance
+
+     # ── Tokenizer training (when training a new one) ──
+     tokenizer_train_samples: int = 50_000  # documents used for training
+     tokenizer_save_dir: str = "./tokenizer"
+
+     # ── Validation data ──
+     val_ratio: float = 0.001  # hold out 0.1% of the data for validation
+
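Since every field above carries a default, a smoke-test configuration is a one-liner. This is only a usage sketch of the dataclass defined above, not a recommended setting.

```python
# Small overrides for a quick local smoke test; every other field keeps
# its default from the DataConfig dataclass above.
config = DataConfig(max_seq_len=512, batch_size=2, num_workers=0)
print(config.dataset_name, config.vocab_size)  # HuggingFaceFW/fineweb-edu 32000
```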
+ # ============================================================================
+ # 2. Tokenizer wrapper
+ # ============================================================================
+
+ class Tokenizer:
+     """Unified tokenizer wrapper.
+
+     Three supported paths:
+       1) load an existing SentencePiece model
+       2) train a new tokenizer with the HuggingFace tokenizers library
+       3) load a pretrained HF tokenizer (e.g., the LLaMA tokenizer)
+
+     Why not implement one from scratch?
+     - Training a BPE tokenizer is large-scale text statistics, with little
+       direct bearing on understanding the model architecture.
+     - You should still understand how the tokenizer works (BPE merge rules).
+
+     BPE (Byte Pair Encoding) in a nutshell:
+       1) split the text into bytes/characters
+       2) repeatedly merge the most frequent adjacent pair
+       3) repeat until vocab_size is reached
+       → frequent words become a single token; rare words split into several
+     """
+
+     def __init__(self, config: DataConfig):
+         self.config = config
+         self._tokenizer = None
+         self.vocab_size = config.vocab_size
+
+         # special token IDs (overwritten once a tokenizer is loaded)
+         self.bos_id: int = 1  # beginning of sequence
+         self.eos_id: int = 2  # end of sequence
+         self.pad_id: int = 0  # padding
+
+     # ──────────────────────────────────────────────
+     # Method 1: load a SentencePiece model
+     # ──────────────────────────────────────────────
+
+     def load_sentencepiece(self, model_path: str):
+         """Load an existing SentencePiece model."""
+         import sentencepiece as spm
+
+         self._tokenizer = spm.SentencePieceProcessor()
+         self._tokenizer.Load(model_path)
+
+         self.vocab_size = self._tokenizer.GetPieceSize()
+         self.bos_id = self._tokenizer.bos_id()
+         self.eos_id = self._tokenizer.eos_id()
+         self.pad_id = self._tokenizer.pad_id()
+         self._encode_fn = self._tokenizer.Encode
+         self._decode_fn = self._tokenizer.Decode
+
+         print(f"[Tokenizer] SentencePiece loaded: vocab_size={self.vocab_size}")
+
+     # ──────────────────────────────────────────────
+     # Method 2: train BPE with HuggingFace tokenizers
+     # ──────────────────────────────────────────────
+
+     def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
+         """Train a BPE tokenizer from scratch.
+
+         Args:
+             text_iterator: iterator yielding the training text
+             save_dir: where to save the result
+
+         Key trade-offs:
+         - larger vocab_size: common expressions become one token → shorter sequences
+         - smaller vocab_size: fewer embedding parameters, but longer sequences
+         - 32K is a good balance for English
+         """
+         from tokenizers import Tokenizer as HFTokenizer
+         from tokenizers.models import BPE
+         from tokenizers.trainers import BpeTrainer
+         from tokenizers.pre_tokenizers import ByteLevel
+         from tokenizers.processors import TemplateProcessing
+
+         print("[Tokenizer] starting BPE tokenizer training...")
+
+         # create the BPE model
+         tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
+         tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
+
+         # define special tokens
+         special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
+
+         # configure the trainer
+         trainer = BpeTrainer(
+             vocab_size=self.config.vocab_size,
+             special_tokens=special_tokens,
+             min_frequency=2,  # merge only pairs seen at least twice
+             show_progress=True,
+         )
+
+         # run training
+         tokenizer.train_from_iterator(text_iterator, trainer=trainer)
+
+         # post-processing: add BOS/EOS automatically
+         tokenizer.post_processor = TemplateProcessing(
+             single="<s> $A </s>",
+             special_tokens=[("<s>", 1), ("</s>", 2)],
+         )
+
+         self._tokenizer = tokenizer
+         self.vocab_size = tokenizer.get_vocab_size()
+         self.pad_id = 0
+         self.bos_id = 1
+         self.eos_id = 2
+
+         self._encode_fn = lambda text: tokenizer.encode(text).ids
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         # save
+         save_dir = save_dir or self.config.tokenizer_save_dir
+         os.makedirs(save_dir, exist_ok=True)
+         tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
+         # save metadata
+         meta = {
+             "vocab_size": self.vocab_size,
+             "bos_id": self.bos_id,
+             "eos_id": self.eos_id,
+             "pad_id": self.pad_id,
+         }
+         with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
+             json.dump(meta, f, indent=2)
+
+         print(f"[Tokenizer] training complete: vocab_size={self.vocab_size}")
+         print(f"[Tokenizer] saved to: {save_dir}")
+
+     # ──────────────────────────────────────────────
+     # Method 3: load a pretrained HF tokenizer
+     # ──────────────────────────────────────────────
+
+     def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
+         """Load a pretrained tokenizer from HuggingFace.
+
+         The most convenient option. The LLaMA tokenizer is BPE-based with a 32K vocab.
+         Note: meta-llama models may require HF access approval.
+         Alternative: mistralai/Mistral-7B-v0.1 (no approval required).
+         """
+         from transformers import AutoTokenizer
+
+         print(f"[Tokenizer] loading HF tokenizer: {name_or_path}")
+         tokenizer = AutoTokenizer.from_pretrained(name_or_path)
+
+         self._tokenizer = tokenizer
+         self.vocab_size = tokenizer.vocab_size
+         self.bos_id = tokenizer.bos_token_id or 1
+         self.eos_id = tokenizer.eos_token_id or 2
+         self.pad_id = tokenizer.pad_token_id or 0
+
+         self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         print(f"[Tokenizer] loaded: vocab_size={self.vocab_size}")
+
+     def load_trained_hf(self, path: str):
+         """Reload a tokenizer previously trained with train_bpe()."""
+         from tokenizers import Tokenizer as HFTokenizer
+
+         tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
+         with open(os.path.join(path, "tokenizer_meta.json"), "r") as f:
+             meta = json.load(f)
+
+         self._tokenizer = tokenizer
+         self.vocab_size = meta["vocab_size"]
+         self.bos_id = meta["bos_id"]
+         self.eos_id = meta["eos_id"]
+         self.pad_id = meta["pad_id"]
+
+         self._encode_fn = lambda text: tokenizer.encode(text).ids
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         print(f"[Tokenizer] loaded: vocab_size={self.vocab_size}")
+
+     # ──────────────────────────────────────────────
+     # Common interface
+     # ──────────────────────────────────────────────
+
+     def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
+         """Text → list of token IDs."""
+         ids = self._encode_fn(text)
+         if add_special_tokens:
+             ids = [self.bos_id] + ids + [self.eos_id]
+         return ids
+
+     def decode(self, ids: List[int]) -> str:
+         """List of token IDs → text."""
+         return self._decode_fn(ids)
+
+     def __len__(self) -> int:
+         return self.vocab_size
+
+
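Before moving on, here is the merge loop from the docstring's "BPE in a nutshell" made concrete — a toy sketch of the principle, not the tokenizers library's actual (much faster, byte-level) implementation.

```python
from collections import Counter

def apply_merge(symbols, pair, merged):
    """Replace every occurrence of `pair` in a symbol list with `merged`."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
            out.append(merged); i += 2
        else:
            out.append(symbols[i]); i += 1
    return out

def toy_bpe(words, num_merges=5):
    """Learn merge rules by repeatedly fusing the most frequent adjacent pair."""
    corpus = [list(w) for w in words]             # 1) split into characters
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols in corpus:
            pairs.update(zip(symbols, symbols[1:]))
        if not pairs:
            break
        best = max(pairs, key=pairs.get)          # 2) most frequent pair
        merges.append(best)
        corpus = [apply_merge(s, best, best[0] + best[1]) for s in corpus]
    return merges                                 # 3) stop at the merge budget

print(toy_bpe(["low", "lower", "lowest", "low"]))
# [('l', 'o'), ('lo', 'w'), ...] — the frequent word "low" quickly becomes one token
```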
+ # ============================================================================
+ # 3. Packed streaming dataset
+ # ============================================================================
+
+ class PackedStreamingDataset(IterableDataset):
+     """Streaming dataset with sequence packing.
+
+     Why sequence packing?
+     - The usual approach truncates each document to max_seq_len and pads
+       → wasted GPU.
+     - Sequence packing concatenates documents until max_seq_len is exactly
+       full → 100% utilization.
+
+     How it works:
+         doc1 (300 tokens) + doc2 (1500 tokens) + doc3 (248 tokens) = 2048 tokens
+         → [doc1][EOS][doc2][EOS][doc3][EOS][...an exact fit, no padding]
+
+     Why streaming?
+     - The FineWeb-Edu 10B-token sample is tens of GB even compressed.
+     - It cannot be fully downloaded within Colab's disk limit (~200GB).
+     - Streaming reads only as much as needed over the network.
+
+     Training caveats:
+     - EOS tokens at document boundaries let the model recognize where
+       documents end.
+     - Even without a cross-document attention mask, EOS acts as a natural
+       boundary.
+     """
+
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         split: str = "train",
+         seed: int = 42,
+     ):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.config = config
+         self.split = split
+         self.seed = seed
+         self.max_seq_len = config.max_seq_len
+
+     def _load_dataset(self):
+         """Load the HuggingFace dataset in streaming mode."""
+         from datasets import load_dataset
+
+         ds = load_dataset(
+             self.config.dataset_name,
+             name=self.config.dataset_subset,
+             split=self.config.dataset_split,
+             streaming=True,  # the key: streaming mode
+             trust_remote_code=True,
+         )
+
+         # shuffle (buffer-based approximate shuffle when streaming)
+         ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
+
+         return ds
+
+     def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
+         """Tokenize documents and pack them into sequences.
+
+         Yields:
+             {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
+
+         targets are input_ids shifted by one position:
+             input_ids: [A, B, C, D, E]
+             targets:   [B, C, D, E, F]
+         → the model sees A and predicts B, sees B and predicts C, ...
+         """
+         buffer: List[int] = []  # token buffer
+
+         for example in dataset:
+             text = example[self.config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             # tokenize (without special tokens)
+             token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+
+             if not token_ids:
+                 continue
+
+             # append EOS (marks the document boundary)
+             if self.config.use_eos_separator:
+                 token_ids.append(self.tokenizer.eos_id)
+
+             # add to the buffer
+             buffer.extend(token_ids)
+
+             # once the buffer is full enough, emit sequences
+             # the +1 is for building targets (input + the next token)
+             while len(buffer) >= self.max_seq_len + 1:
+                 # take max_seq_len + 1 tokens
+                 chunk = buffer[: self.max_seq_len + 1]
+                 buffer = buffer[self.max_seq_len + 1 :]
+
+                 # input_ids: first through second-to-last
+                 input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
+                 # targets: second through last (shifted by one)
+                 targets = torch.tensor(chunk[1:], dtype=torch.long)
+
+                 yield {"input_ids": input_ids, "targets": targets}
+
+     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+         """The iterator the DataLoader invokes.
+
+         Multi-worker support:
+         - each worker processes a stream shuffled with a different seed
+         - this minimizes data duplication across workers
+         """
+         worker_info = torch.utils.data.get_worker_info()
+
+         if worker_info is not None:
+             # multi-worker: a different seed per worker
+             worker_seed = self.seed + worker_info.id
+         else:
+             worker_seed = self.seed
+
+         # load the dataset with the per-worker seed
+         self.seed = worker_seed
+         dataset = self._load_dataset()
+
+         return self._tokenize_and_pack(dataset)
+
+
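The `max_seq_len + 1` bookkeeping in `_tokenize_and_pack`, shown on a five-token chunk (so max_seq_len would be 4):

```python
chunk = [10, 11, 12, 13, 14]   # max_seq_len + 1 tokens pulled from the buffer
input_ids = chunk[:-1]         # [10, 11, 12, 13]
targets = chunk[1:]            # [11, 12, 13, 14]
for x, y in zip(input_ids, targets):
    print(f"see {x} -> predict {y}")  # every position gets a training signal
```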
+ # ============================================================================
+ # 4. Validation dataset (fixed size)
+ # ============================================================================
+
+ class ValidationDataset:
+     """Validation dataset.
+
+     Pulls a fixed amount of data from the streaming dataset up front and
+     keeps it in memory. Comparisons are only meaningful when every
+     evaluation runs on identical data.
+     """
+
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         num_samples: int = 100,
+         seed: int = 9999,
+     ):
+         self.tokenizer = tokenizer
+         self.config = config
+         self.num_samples = num_samples
+         self.samples: List[Dict[str, torch.Tensor]] = []
+
+         self._prepare(seed)
+
+     def _prepare(self, seed: int):
+         """Pre-extract validation samples from the dataset."""
+         from datasets import load_dataset
+
+         print(f"[Validation] preparing {self.num_samples} validation samples...")
+
+         ds = load_dataset(
+             self.config.dataset_name,
+             name=self.config.dataset_subset,
+             split=self.config.dataset_split,
+             streaming=True,
+             trust_remote_code=True,
+         )
+         # different seed so we do not overlap the training data; skips past the start
+         ds = ds.shuffle(seed=seed, buffer_size=5_000)
+
+         buffer: List[int] = []
+         count = 0
+
+         for example in ds:
+             if count >= self.num_samples:
+                 break
+
+             text = example[self.config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+             if not token_ids:
+                 continue
+
+             token_ids.append(self.tokenizer.eos_id)
+             buffer.extend(token_ids)
+
+             while len(buffer) >= self.config.max_seq_len + 1 and count < self.num_samples:
+                 chunk = buffer[: self.config.max_seq_len + 1]
+                 buffer = buffer[self.config.max_seq_len + 1 :]
+
+                 self.samples.append({
+                     "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
+                     "targets": torch.tensor(chunk[1:], dtype=torch.long),
+                 })
+                 count += 1
+
+         print(f"[Validation] {len(self.samples)} samples ready")
+
+     def get_dataloader(self, batch_size: int) -> DataLoader:
+         """Return the validation DataLoader."""
+         return DataLoader(
+             self.samples,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=0,
+             collate_fn=_collate_fn,
+         )
+
+
+ # ============================================================================
+ # 5. DataLoader creation utilities
+ # ============================================================================
+
+ def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+     """Stack the samples of a batch into single tensors.
+
+     Thanks to sequence packing every sample already has the same length
+     (max_seq_len), so no extra padding is needed.
+     """
+     return {
+         "input_ids": torch.stack([s["input_ids"] for s in batch]),
+         "targets": torch.stack([s["targets"] for s in batch]),
+     }
+
+
+ def create_train_dataloader(
+     tokenizer: Tokenizer,
+     config: DataConfig,
+     seed: int = 42,
+ ) -> DataLoader:
+     """Create the training DataLoader.
+
+     Returns:
+         an effectively endless streaming DataLoader
+
+     Usage:
+         dataloader = create_train_dataloader(tokenizer, config)
+         for step, batch in enumerate(dataloader):
+             input_ids = batch["input_ids"].to(device)  # (B, seq_len)
+             targets = batch["targets"].to(device)      # (B, seq_len)
+             logits, loss = model(input_ids, targets)
+             ...
+     """
+     dataset = PackedStreamingDataset(
+         tokenizer=tokenizer,
+         config=config,
+         split="train",
+         seed=seed,
+     )
+
+     dataloader = DataLoader(
+         dataset,
+         batch_size=config.batch_size,
+         num_workers=config.num_workers,
+         prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
+         pin_memory=True,  # speeds up host-to-GPU transfer
+         collate_fn=_collate_fn,
+     )
+
+     return dataloader
+
+
+ # ============================================================================
+ # 6. Tokenizer training helper
+ # ============================================================================
+
+ def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
+     """Train a BPE tokenizer from the dataset.
+
+     There is no need to use all the data; 50K documents are plenty, since
+     the tokenizer vocab only has to reflect the statistics of the corpus.
+     """
+     from datasets import load_dataset
+
+     print(f"[Train Tokenizer] training a tokenizer on {config.dataset_name}")
+     print(f"[Train Tokenizer] training documents: {config.tokenizer_train_samples:,}")
+
+     # build the text iterator
+     ds = load_dataset(
+         config.dataset_name,
+         name=config.dataset_subset,
+         split=config.dataset_split,
+         streaming=True,
+         trust_remote_code=True,
+     )
+
+     def text_iterator():
+         count = 0
+         for example in ds:
+             if count >= config.tokenizer_train_samples:
+                 break
+             text = example[config.text_column]
+             if text and text.strip():
+                 yield text
+                 count += 1
+                 if count % 10_000 == 0:
+                     print(f"  ... {count:,} documents processed")
+
+     # train the tokenizer
+     tokenizer = Tokenizer(config)
+     tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
+
+     return tokenizer
+
+
+ # ============================================================================
+ # 7. Data pipeline statistics / diagnostics
+ # ============================================================================
+
+ class DataPipelineDiagnostics:
+     """Diagnoses the performance and quality of the data pipeline.
+
+     Things to verify before training:
+       1) tokenizer quality: average tokens/document, unknown-token ratio
+       2) packing efficiency: real-token ratio vs. padding ratio
+       3) throughput: tokens/sec (checks for a data-loading bottleneck)
+       4) batch layout: shape and dtype correctness
+     """
+
+     @staticmethod
+     def check_tokenizer_quality(
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         num_samples: int = 1000,
+     ):
+         """Diagnose tokenizer quality."""
+         from datasets import load_dataset
+
+         print("\n" + "=" * 60)
+         print("📊 Tokenizer quality diagnostics")
+         print("=" * 60)
+
+         ds = load_dataset(
+             config.dataset_name,
+             name=config.dataset_subset,
+             split=config.dataset_split,
+             streaming=True,
+             trust_remote_code=True,
+         )
+
+         token_counts = []
+         char_counts = []
+         sample_count = 0
+
+         for example in ds:
+             if sample_count >= num_samples:
+                 break
+             text = example[config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             tokens = tokenizer.encode(text)
+             token_counts.append(len(tokens))
+             char_counts.append(len(text))
+             sample_count += 1
+
+         avg_tokens = sum(token_counts) / len(token_counts)
+         avg_chars = sum(char_counts) / len(char_counts)
+         compression_ratio = avg_chars / avg_tokens  # characters per token
+
+         print(f"  documents analyzed: {len(token_counts):,}")
+         print(f"  avg tokens/document: {avg_tokens:.1f}")
+         print(f"  avg chars/document: {avg_chars:.1f}")
+         print(f"  compression ratio (chars/token): {compression_ratio:.2f}")
+         print("  → 3.5-4.5 is normal for English")
+         print(f"  min tokens: {min(token_counts)}, max: {max(token_counts)}")
+
+         # decode round-trip test
+         test_text = "The quick brown fox jumps over the lazy dog."
+         encoded = tokenizer.encode(test_text)
+         decoded = tokenizer.decode(encoded)
+         roundtrip_ok = test_text.strip() in decoded.strip()
+         print(f"\n  round-trip test: {'✅ passed' if roundtrip_ok else '❌ failed'}")
+         print(f"  original: {test_text}")
+         print(f"  encoded: {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
+         print(f"  decoded: {decoded}")
+
+     @staticmethod
+     def benchmark_throughput(
+         dataloader: DataLoader,
+         num_batches: int = 50,
+         seq_len: int = 2048,
+     ):
+         """Measure data-loading throughput.
+
+         The key diagnostic for whether data loading bottlenecks GPU training.
+         Goal: data loading must be faster than the GPU compute
+         (data loading ≠ bottleneck).
+         """
+         print("\n" + "=" * 60)
+         print("⚡ Data-loading throughput benchmark")
+         print("=" * 60)
+
+         total_tokens = 0
+         start_time = time.time()
+
+         for i, batch in enumerate(dataloader):
+             if i >= num_batches:
+                 break
+             batch_tokens = batch["input_ids"].numel()
+             total_tokens += batch_tokens
+
+             if (i + 1) % 10 == 0:
+                 elapsed = time.time() - start_time
+                 tps = total_tokens / elapsed
+                 print(f"  Batch {i+1}: {tps:,.0f} tokens/sec")
+
+         elapsed = time.time() - start_time
+         tps = total_tokens / elapsed
+
+         print(f"\n  total batches: {num_batches}")
+         print(f"  total tokens: {total_tokens:,}")
+         print(f"  elapsed: {elapsed:.2f}s")
+         print(f"  average throughput: {tps:,.0f} tokens/sec")
+         print("\n  💡 against an A100 training throughput of ~50-80K tokens/sec:")
+         if tps > 80_000:
+             print("  ✅ data loading is not the bottleneck")
+         elif tps > 30_000:
+             print("  ⚠️ borderline - consider increasing num_workers")
+         else:
+             print("  ❌ data loading is the bottleneck! tune num_workers/prefetch")
+
+     @staticmethod
+     def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
+         """Inspect a single batch in detail."""
+         print("\n" + "=" * 60)
+         print("🔍 Batch inspection")
+         print("=" * 60)
+
+         input_ids = batch["input_ids"]
+         targets = batch["targets"]
+
+         print(f"  input_ids shape: {input_ids.shape}")
+         print(f"  targets shape: {targets.shape}")
+         print(f"  dtype: {input_ids.dtype}")
+         print(f"  value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
+
+         # verify the shift relation: targets[i] == input_ids[i+1]
+         shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
+         print(f"  shift consistency: {shift_correct*100:.1f}% (should be 100%)")
+
+         # EOS token distribution (document boundaries)
+         eos_count = (input_ids == tokenizer.eos_id).sum().item()
+         total_tokens = input_ids.numel()
+         print(f"  EOS tokens: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
+
+         # decoded preview of the first sample
+         first_sample = input_ids[0][:100].tolist()
+         decoded_preview = tokenizer.decode(first_sample)
+         print(f"\n  first sample decoded (first 100 tokens):")
+         print(f"  {decoded_preview[:300]}...")
+
+
+ # ============================================================================
+ # 8. Full pipeline integration (Quick Start)
+ # ============================================================================
+
+ def setup_data_pipeline(
+     tokenizer_mode: str = "train_new",
+     tokenizer_path: Optional[str] = None,
+     config: Optional[DataConfig] = None,
+ ) -> tuple:
+     """Set up the whole data pipeline in one call.
+
+     Args:
+         tokenizer_mode:
+             "train_new"    - train a new BPE tokenizer
+             "load_trained" - load a previously trained tokenizer
+             "pretrained"   - use a pretrained HuggingFace tokenizer
+         tokenizer_path:
+             "train_new"    → save path (default: ./tokenizer)
+             "load_trained" → path to the saved tokenizer
+             "pretrained"   → HF model name (default: mistralai/Mistral-7B-v0.1)
+
+     Returns:
+         (tokenizer, train_dataloader, val_dataloader)
+
+     Usage (Colab):
+         # Option 1: train a new tokenizer
+         tok, train_dl, val_dl = setup_data_pipeline("train_new")
+
+         # Option 2: load an existing tokenizer
+         tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
+
+         # Option 3: pretrained tokenizer (simplest)
+         tok, train_dl, val_dl = setup_data_pipeline("pretrained")
+     """
+     config = config or DataConfig()
+
+     print("=" * 60)
+     print("🚀 Data pipeline setup")
+     print("=" * 60)
+
+     # ── Step 1: tokenizer ──
+     tokenizer = Tokenizer(config)
+
+     if tokenizer_mode == "train_new":
+         tokenizer = train_tokenizer_from_dataset(config)
+     elif tokenizer_mode == "load_trained":
+         path = tokenizer_path or config.tokenizer_save_dir
+         tokenizer.load_trained_hf(path)
+     elif tokenizer_mode == "pretrained":
+         name = tokenizer_path or "mistralai/Mistral-7B-v0.1"
+         tokenizer.load_pretrained_hf(name)
+     else:
+         raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
+
+     # ── Step 2: training DataLoader ──
+     print("\n[DataLoader] creating the training DataLoader...")
+     train_dataloader = create_train_dataloader(tokenizer, config)
+
+     # ── Step 3: validation DataLoader ──
+     print("\n[DataLoader] creating the validation DataLoader...")
+     val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
+     val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
+
+     print("\n" + "=" * 60)
+     print("✅ Data pipeline ready!")
+     print(f"  tokenizer vocab: {tokenizer.vocab_size:,}")
+     print(f"  sequence length: {config.max_seq_len}")
+     print(f"  batch size: {config.batch_size}")
+     print(f"  tokens/batch: {config.batch_size * config.max_seq_len:,}")
+     print("=" * 60)
+
+     return tokenizer, train_dataloader, val_dataloader
+
+
+ # ============================================================================
+ # 9. Verification script
+ # ============================================================================
+
+ if __name__ == "__main__":
+     """
+     Run locally or on Colab to verify the pipeline.
+
+     How to run:
+         python data_pipeline.py
+
+     or on Colab:
+         !pip install datasets tokenizers sentencepiece
+         %run data_pipeline.py
+     """
+     print("=" * 70)
+     print("LLM-1B-Lab: data pipeline verification")
+     print("=" * 70)
+
+     # ── quick check: exercise the pipeline with a dummy tokenizer ──
+     print("\n[Test 1] verify the pipeline structure with a dummy tokenizer")
+
+     # dummy tokenizer (tests without the real dataset)
+     class DummyTokenizer:
+         """Simple character-level tokenizer for testing."""
+         def __init__(self, vocab_size=256):
+             self.vocab_size = vocab_size
+             self.eos_id = 2
+             self.bos_id = 1
+             self.pad_id = 0
+
+         def encode(self, text, add_special_tokens=False):
+             # map each character to its ASCII value (test-only shortcut)
+             ids = [min(ord(c), self.vocab_size - 1) for c in text]
+             if add_special_tokens:
+                 ids = [self.bos_id] + ids + [self.eos_id]
+             return ids
+
+         def decode(self, ids):
+             return "".join(chr(min(i, 127)) for i in ids if i > 2)
+
+         def __len__(self):
+             return self.vocab_size
+
+     config = DataConfig(max_seq_len=64, batch_size=2)  # small settings
+     dummy_tok = DummyTokenizer()
+
+     # packing test on dummy data
+     print("\n[Test 2] verify the sequence-packing logic")
+
+     buffer = []
+     test_docs = [
+         "Hello world! This is document one. " * 5,
+         "Second document here with different content. " * 8,
+         "Third doc. " * 20,
+         "A " * 200,
+     ]
+
+     for doc in test_docs:
+         tokens = dummy_tok.encode(doc)
+         tokens.append(dummy_tok.eos_id)
+         buffer.extend(tokens)
+
+     seq_len = config.max_seq_len
+     packed_count = 0
+     while len(buffer) >= seq_len + 1:
+         chunk = buffer[: seq_len + 1]
+         buffer = buffer[seq_len + 1 :]
+         input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
+         targets = torch.tensor(chunk[1:], dtype=torch.long)
+
+         # verify the shift relation
+         assert (input_ids[1:] == targets[:-1]).all(), "Shift relation broken!"
+         packed_count += 1
+
+     print(f"  documents: {len(test_docs)}")
+     print(f"  total tokens: {sum(len(dummy_tok.encode(d)) + 1 for d in test_docs)}")
+     print(f"  packed sequences: {packed_count}")
+     print(f"  sequence length: {seq_len}")
+     print(f"  leftover buffer: {len(buffer)} tokens")
+     print("  ✅ shift-relation check passed")
+
+     # batch construction test
+     print("\n[Test 3] verify batch construction")
+
+     samples = []
+     buffer2 = []
+     for doc in test_docs * 10:  # generate enough data
+         tokens = dummy_tok.encode(doc)
+         tokens.append(dummy_tok.eos_id)
+         buffer2.extend(tokens)
+
+     while len(buffer2) >= seq_len + 1 and len(samples) < 10:
+         chunk = buffer2[: seq_len + 1]
+         buffer2 = buffer2[seq_len + 1 :]
+         samples.append({
+             "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
+             "targets": torch.tensor(chunk[1:], dtype=torch.long),
+         })
+
+     batch = _collate_fn(samples[:config.batch_size])
+     print(f"  input_ids shape: {batch['input_ids'].shape}")
+     print(f"  targets shape: {batch['targets'].shape}")
+     print(f"  dtype: {batch['input_ids'].dtype}")
+
+     expected_shape = (config.batch_size, seq_len)
+     assert batch["input_ids"].shape == expected_shape, f"Shape mismatch: {batch['input_ids'].shape} != {expected_shape}"
+     print(f"  ✅ batch shape check passed: {expected_shape}")
+
+     # verify EOS tokens are present
+     eos_found = (batch["input_ids"] == dummy_tok.eos_id).any().item()
+     print(f"  ✅ EOS tokens present: {eos_found}")
+
+     print("\n" + "=" * 70)
+     print("✅ Data pipeline structure verified!")
+     print()
+     print("Next step: test against the real dataset")
+     print("  tokenizer, train_dl, val_dl = setup_data_pipeline('pretrained')")
+     print("  DataPipelineDiagnostics.check_tokenizer_quality(tokenizer, DataConfig())")
+     print("  DataPipelineDiagnostics.benchmark_throughput(train_dl)")
+     print("=" * 70)
_archive/llm-1b-evaluation.py ADDED
@@ -0,0 +1,1455 @@
+ """
+ LLM-1B-Lab: Evaluation module
+ =====================================
+ Evaluates the quality of the trained model from multiple angles and
+ analyzes the insights gathered during training.
+
+ Evaluation areas:
+   1. Perplexity measurement     — the standard quantitative metric for language models
+   2. Text generation quality    — qualitative evaluation (varied prompts)
+   3. Scaling-law analysis       — comparing 10M → 100M → 1B
+   4. Training-dynamics analysis — loss curves, LR, gradient patterns
+   5. Attention visualization    — analyzing where the model "looks"
+   6. Comprehensive report       — summarizing training insights
+
+ Required:
+     pip install matplotlib numpy
+ """
+
+ import math
+ import time
+ import json
+ from pathlib import Path
+ from dataclasses import dataclass, field
+ from typing import Optional, List, Dict, Any, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ try:
+     import matplotlib
+     matplotlib.use("Agg")  # Colab/server friendly (headless backend)
+     import matplotlib.pyplot as plt
+     import matplotlib.ticker as ticker
+     HAS_MATPLOTLIB = True
+ except ImportError:
+     HAS_MATPLOTLIB = False
+
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+
+
+ # ============================================================================
+ # 1. Evaluation configuration
+ # ============================================================================
+
+ @dataclass
+ class EvalConfig:
+     """Evaluation parameters."""
+     # ── Perplexity ──
+     eval_batch_size: int = 4
+     max_eval_batches: int = 100  # cap on evaluation batches
+
+     # ── Generation ──
+     max_new_tokens: int = 200
+     temperature: float = 0.8
+     top_k: int = 50
+     top_p: float = 0.9
+     num_samples: int = 3  # generations per prompt
+
+     # ── Output ──
+     save_dir: str = "./eval_results"
+     plot_dpi: int = 150
+
+
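The `temperature` / `top_k` / `top_p` fields above are consumed by `LLMModel.generate` (defined in the model file, not shown on this page). As a reference for what those knobs conventionally do, here is a standalone sketch of the usual filtering order — the common recipe, not necessarily the project's exact code.

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.8,
                      top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
    """logits: (vocab,). Apply temperature, then top-k, then nucleus (top-p)."""
    logits = logits / temperature                      # <1 sharpens, >1 flattens
    # top-k: mask everything below the k-th largest logit
    kth = torch.topk(logits, top_k).values[-1]
    logits = logits.masked_fill(logits < kth, float("-inf"))
    # top-p (nucleus): drop the tail once cumulative probability exceeds p
    sorted_logits, idx = torch.sort(logits, descending=True)
    cum = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove = cum > top_p
    remove[1:] = remove[:-1].clone()   # shift right: always keep the top token
    remove[0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    logits = torch.full_like(logits, float("-inf")).scatter(0, idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

print(sample_next_token(torch.randn(32_000)))  # one sampled token id
```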
+ # ============================================================================
+ # 2. Perplexity evaluator
+ # ============================================================================
+
+ class PerplexityEvaluator:
+     """Measures perplexity (PPL).
+
+     What is perplexity?
+         PPL = exp(average cross-entropy loss)
+
+     Intuition:
+     - PPL = 1: perfect prediction (impossible)
+     - PPL = 10: like picking among 10 candidates at every step
+     - PPL = 100: like picking among 100 candidates (close to random)
+     - PPL = 32000: random choice over the whole vocab (an untrained model)
+
+     Rough targets for a good 1B model (English web text):
+     - trained on 5B tokens:  PPL ~30-40
+     - trained on 10B tokens: PPL ~20-30
+     - trained on 20B tokens: PPL ~15-25
+
+     Measurement:
+     - compute cross-entropy over every token of the validation set
+     - average per token, then apply exp()
+     - padding tokens are excluded (ignore_index=-100)
+     """
+
+     def __init__(self, config: EvalConfig):
+         self.config = config
+
+     @torch.no_grad()
+     def evaluate(
+         self,
+         model: nn.Module,
+         dataloader: DataLoader,
+         device: torch.device,
+         dtype: torch.dtype = torch.bfloat16,
+         desc: str = "Evaluation",
+     ) -> Dict[str, float]:
+         """Measure perplexity.
+
+         Returns:
+             {
+                 "loss": mean cross-entropy loss,
+                 "perplexity": exp(loss),
+                 "num_tokens": total tokens evaluated,
+                 "num_batches": batches evaluated,
+             }
+         """
+         model.eval()
+
+         total_loss = 0.0
+         total_tokens = 0
+         num_batches = 0
+
+         print(f"\n📊 {desc}")
+         start_time = time.time()
+
+         for i, batch in enumerate(dataloader):
+             if i >= self.config.max_eval_batches:
+                 break
+
+             input_ids = batch["input_ids"].to(device)
+             targets = batch["targets"].to(device)
+
+             with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
+                 logits, _ = model(input_ids)
+
+             # per-token cross-entropy (reduction='none')
+             # logits: (B, S, V) → (B*S, V)
+             # targets: (B, S) → (B*S,)
+             loss_per_token = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100,
+                 reduction="none",
+             )
+
+             # count only valid tokens (those not equal to -100)
+             valid_mask = (targets.view(-1) != -100)
+             valid_tokens = valid_mask.sum().item()
+
+             total_loss += loss_per_token[valid_mask].sum().item()
+             total_tokens += valid_tokens
+             num_batches += 1
+
+             if (i + 1) % 20 == 0:
+                 running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
+                 print(f"  Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")
+
+         elapsed = time.time() - start_time
+         avg_loss = total_loss / max(total_tokens, 1)
+         perplexity = math.exp(min(avg_loss, 100))  # guard against overflow
+
+         results = {
+             "loss": round(avg_loss, 4),
+             "perplexity": round(perplexity, 2),
+             "num_tokens": total_tokens,
+             "num_batches": num_batches,
+             "eval_time_sec": round(elapsed, 1),
+         }
+
+         print("  ────────────────────────────────")
+         print(f"  Loss:       {results['loss']:.4f}")
+         print(f"  Perplexity: {results['perplexity']:.2f}")
+         print(f"  tokens evaluated: {total_tokens:,}")
+         print(f"  elapsed: {elapsed:.1f}s")
+
+         return results
+
+     @torch.no_grad()
+     def evaluate_per_position(
+         self,
+         model: nn.Module,
+         dataloader: DataLoader,
+         device: torch.device,
+         dtype: torch.dtype = torch.bfloat16,
+         max_batches: int = 50,
+     ) -> List[float]:
+         """Measure loss by position within the sequence.
+
+         Key points:
+         - positions 0-10: loss is high (not enough context yet)
+         - positions 100+: loss settles lower (context is being used)
+         - this pattern illustrates the Transformer's in-context learning ability
+         """
+         model.eval()
+         seq_len = None
+         position_loss_sum = None
+         position_count = None
+
+         for i, batch in enumerate(dataloader):
+             if i >= max_batches:
+                 break
+
+             input_ids = batch["input_ids"].to(device)
+             targets = batch["targets"].to(device)
+             B, S = targets.shape
+
+             if seq_len is None:
+                 seq_len = S
+                 position_loss_sum = torch.zeros(S, device=device)
+                 position_count = torch.zeros(S, device=device)
+
+             with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
+                 logits, _ = model(input_ids)
+
+             # per-token loss of shape (B, S)
+             loss_per_token = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100,
+                 reduction="none",
+             ).view(B, S)
+
+             valid_mask = (targets != -100).float()
+             position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
+             position_count += valid_mask.sum(dim=0)
+
+         # mean loss per position
+         position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
+         return position_avg_loss
+
+
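To make the docstring's intuition table concrete: since PPL = exp(mean token-level loss), each loss band maps directly to a perplexity.

```python
import math

for loss in (10.37, 4.6, 3.4):
    print(f"loss {loss:5.2f} -> PPL {math.exp(loss):10.1f}")
# loss 10.37 -> ~32000  (ln 32000 ≈ 10.37: a random model over a 32K vocab)
# loss  4.60 ->   ~100  (picking among ~100 candidates)
# loss  3.40 ->    ~30  (the "good 1B model on 5-10B tokens" band)
```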
+ # ============================================================================
+ # 3. Text generation evaluation
+ # ============================================================================
+
+ class GenerationEvaluator:
+     """Generates text from varied prompts to assess quality.
+
+     Evaluation criteria:
+       1) grammaticality: does it produce grammatical English sentences?
+       2) coherence: does it continue while keeping track of the context?
+       3) diversity: does the same prompt yield different outputs?
+       4) repetition avoidance: does it avoid repeating the same phrases?
+       5) knowledge: does knowledge from the training data show through?
+
+     Realistic expectations for a 1B model:
+     - grammatically correct English sentences ✅
+     - coherence within short paragraphs ✅
+     - complex reasoning or long logical chains ❌ (needs a larger model)
+     - factual accuracy is not guaranteed ⚠️
+     """
+
+     # test prompts spanning several domains
+     DEFAULT_PROMPTS = [
+         # ── General knowledge ──
+         "The theory of relativity states that",
+         "In the history of computer science,",
+         "The human brain is remarkable because",
+
+         # ── Explanation / education ──
+         "To understand machine learning, one must first",
+         "The water cycle begins when",
+         "Photosynthesis is the process by which",
+
+         # ── Narrative / story ──
+         "Once upon a time, in a small village near the mountains,",
+         "The detective looked at the evidence and realized that",
+
+         # ── Code / technical ──
+         "def fibonacci(n):\n    \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
+         "The most important data structures in programming are",
+
+         # ── Short completion ──
+         "The capital of France is",
+         "Water boils at a temperature of",
+
+         # ── Long context ──
+         ("Artificial intelligence has transformed many industries. "
+          "In healthcare, AI is used for diagnosis and drug discovery. "
+          "In finance, it powers algorithmic trading and fraud detection. "
+          "Looking ahead, the most promising application of AI is"),
+     ]
+
+     def __init__(self, config: EvalConfig):
+         self.config = config
+
+     @torch.no_grad()
+     def generate_samples(
+         self,
+         model: nn.Module,
+         tokenizer: Any,
+         device: torch.device,
+         prompts: Optional[List[str]] = None,
+         verbose: bool = True,
+     ) -> List[Dict[str, Any]]:
+         """Generate text for each prompt.
+
+         Returns:
+             [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
+         """
+         model.eval()
+         prompts = prompts or self.DEFAULT_PROMPTS
+         results = []
+
+         if verbose:
+             print("\n" + "=" * 70)
+             print("📝 Text generation evaluation")
+             print("=" * 70)
+
+         for idx, prompt in enumerate(prompts):
+             prompt_results = {
+                 "prompt": prompt,
+                 "generations": [],
+                 "metrics": {},
+             }
+
+             if verbose:
+                 print(f"\n{'─'*60}")
+                 print(f"Prompt [{idx+1}/{len(prompts)}]:")
+                 print(f"  \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
+                 print(f"{'─'*60}")
+
+             # encode the prompt
+             prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+             input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+             all_texts = []
+             for sample_idx in range(self.config.num_samples):
+                 # generate
+                 generated_ids = model.generate(
+                     input_tensor,
+                     max_new_tokens=self.config.max_new_tokens,
+                     temperature=self.config.temperature,
+                     top_k=self.config.top_k,
+                     top_p=self.config.top_p,
+                 )
+
+                 # decode (only the part after the prompt)
+                 new_ids = generated_ids[0][len(prompt_ids):].tolist()
+                 generated_text = tokenizer.decode(new_ids)
+                 all_texts.append(generated_text)
+
+                 prompt_results["generations"].append(generated_text)
+
+                 if verbose:
+                     print(f"\n  ✍️ Generation #{sample_idx+1}:")
+                     # clean display (including line breaks)
+                     display_text = generated_text[:500]
351
+ for line in display_text.split("\n"):
352
+ print(f" {line}")
353
+ if len(generated_text) > 500:
354
+ print(f" ... (์ด {len(generated_text)} ๋ฌธ์ž)")
355
+
356
+ # ์ƒ์„ฑ ํ’ˆ์งˆ ๋ฉ”ํŠธ๋ฆญ
357
+ prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
358
+
359
+ if verbose and prompt_results["metrics"]:
360
+ m = prompt_results["metrics"]
361
+ print(f"\n ๐Ÿ“Š ๋ฉ”ํŠธ๋ฆญ: "
362
+ f"ํ‰๊ท  ๊ธธ์ด={m['avg_length']:.0f}์ž, "
363
+ f"๋ฐ˜๋ณต๋ฅ ={m['repetition_rate']:.1%}, "
364
+ f"์–ดํœ˜ ๋‹ค์–‘์„ฑ={m['lexical_diversity']:.2f}")
365
+
366
+ results.append(prompt_results)
367
+
368
+ return results
369
+
370
+ @staticmethod
371
+ def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
372
+ """์ƒ์„ฑ ํ…์ŠคํŠธ์˜ ํ’ˆ์งˆ ๋ฉ”ํŠธ๋ฆญ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
373
+
374
+ ๋ฉ”ํŠธ๋ฆญ:
375
+ - avg_length: ํ‰๊ท  ์ƒ์„ฑ ๊ธธ์ด (๋ฌธ์ž)
376
+ - avg_word_count: ํ‰๊ท  ๋‹จ์–ด ์ˆ˜
377
+ - repetition_rate: n-gram ๋ฐ˜๋ณต๋ฅ  (๋‚ฎ์„์ˆ˜๋ก ์ข‹์Œ)
378
+ - lexical_diversity: ๊ณ ์œ  ๋‹จ์–ด ๋น„์œจ (๋†’์„์ˆ˜๋ก ๋‹ค์–‘)
379
+ - sample_diversity: ์ƒ˜ํ”Œ ๊ฐ„ ๋‹ค์–‘์„ฑ (๋‹ค๋ฅธ ์ƒ์„ฑ๋ผ๋ฆฌ ์–ผ๋งˆ๋‚˜ ๋‹ค๋ฅธ๊ฐ€)
380
+ """
381
+ if not texts:
382
+ return {}
383
+
384
+ # ๊ธธ์ด
385
+ lengths = [len(t) for t in texts]
386
+ word_counts = [len(t.split()) for t in texts]
387
+
388
+ # ๋ฐ˜๋ณต๋ฅ  (4-gram ๊ธฐ์ค€)
389
+ rep_rates = []
390
+ for text in texts:
391
+ words = text.lower().split()
392
+ if len(words) < 4:
393
+ rep_rates.append(0.0)
394
+ continue
395
+ ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
396
+ unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
397
+ rep_rates.append(1.0 - unique_ratio) # ๋ฐ˜๋ณต๋ฅ  = 1 - ๊ณ ์œ ๋น„์œจ
398
+
399
+ # ์–ดํœ˜ ๋‹ค์–‘์„ฑ (Type-Token Ratio)
400
+ diversities = []
401
+ for text in texts:
402
+ words = text.lower().split()
403
+ if words:
404
+ diversities.append(len(set(words)) / len(words))
405
+ else:
406
+ diversities.append(0.0)
407
+
408
+ # ์ƒ˜ํ”Œ ๊ฐ„ ๋‹ค์–‘์„ฑ (์ž์นด๋“œ ์œ ์‚ฌ๋„์˜ ์—ญ)
409
+ sample_div = 0.0
410
+ if len(texts) > 1:
411
+ word_sets = [set(t.lower().split()) for t in texts]
412
+ similarities = []
413
+ for i in range(len(word_sets)):
414
+ for j in range(i+1, len(word_sets)):
415
+ inter = len(word_sets[i] & word_sets[j])
416
+ union = len(word_sets[i] | word_sets[j])
417
+ if union > 0:
418
+ similarities.append(inter / union)
419
+ sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1))
420
+
421
+ return {
422
+ "avg_length": sum(lengths) / len(lengths),
423
+ "avg_word_count": sum(word_counts) / len(word_counts),
424
+ "repetition_rate": sum(rep_rates) / len(rep_rates),
425
+ "lexical_diversity": sum(diversities) / len(diversities),
426
+ "sample_diversity": round(sample_div, 3),
427
+ }
428
+
429
+
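+ # Illustrative usage sketch for _compute_generation_metrics. The two strings
+ # are invented toy inputs: the second repeats the same 4-grams, so it should
+ # score a high repetition_rate and a low lexical_diversity.
+ def _generation_metrics_demo() -> Dict[str, float]:
+     texts = [
+         "the cat sat on the mat and watched the quiet street below",
+         "the cat sat on the mat the cat sat on the mat the cat sat",
+     ]
+     return GenerationEvaluator._compute_generation_metrics(texts)
+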
430
+ # ============================================================================
431
+ # 4. Scaling Law ๋ถ„์„๊ธฐ
432
+ # ============================================================================
433
+
434
+ class ScalingAnalyzer:
435
+ """10M โ†’ 100M โ†’ 1B ๋ชจ๋ธ์˜ Scaling Law๋ฅผ ๋ถ„์„ํ•ฉ๋‹ˆ๋‹ค.
436
+
437
+ Chinchilla Scaling Law (2022):
438
+ - ์ตœ์  ํ•™์Šต: ํ† ํฐ ์ˆ˜ โ‰ˆ 20 ร— ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜
439
+ - Loss(N, D) = E + A/N^ฮฑ + B/D^ฮฒ (N=ํŒŒ๋ผ๋ฏธํ„ฐ, D=๋ฐ์ดํ„ฐ)
+ - ฮฑ โ‰ˆ 0.34, ฮฒ โ‰ˆ 0.28 (Chinchilla ๊ธฐ์ค€; ฮฑ โ‰ˆ 0.076, ฮฒ โ‰ˆ 0.095๋Š” Kaplan et al. 2020)
441
+
442
+ ์ด ๋ถ„์„์˜ ๋ชฉ์ :
443
+ - ์šฐ๋ฆฌ ๋ชจ๋ธ์ด Scaling Law๋ฅผ ๋”ฐ๋ฅด๋Š”์ง€ ํ™•์ธ
444
+ - ๋” ํฐ ๋ชจ๋ธ/๋” ๋งŽ์€ ๋ฐ์ดํ„ฐ์˜ ํšจ๊ณผ๋ฅผ ์˜ˆ์ธก
445
+ - ์ปดํ“จํŒ… ์ž์› ๋ฐฐ๋ถ„์˜ ์ตœ์ ์  ์ดํ•ด
446
+ """
447
+
448
+ def __init__(self, save_dir: str = "./eval_results"):
449
+ self.save_dir = Path(save_dir)
450
+ self.save_dir.mkdir(parents=True, exist_ok=True)
451
+
452
+ def analyze(
453
+ self,
454
+ model_results: List[Dict[str, Any]],
455
+ ) -> Dict[str, Any]:
456
+ """์—ฌ๋Ÿฌ ๋ชจ๋ธ ํฌ๊ธฐ์˜ ๊ฒฐ๊ณผ๋ฅผ ๋น„๊ต ๋ถ„์„ํ•ฉ๋‹ˆ๋‹ค.
457
+
458
+ Args:
459
+ model_results: [
460
+ {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
461
+ {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
462
+ {"name": "1B", "params": 1.1e9, "tokens": 10e9,"loss": 3.0, "ppl": 20.1},
463
+ ]
464
+
465
+ Returns:
466
+ ๋ถ„์„ ๊ฒฐ๊ณผ ๋”•์…”๋„ˆ๋ฆฌ
467
+ """
468
+ if len(model_results) < 2:
469
+ print("โš ๏ธ Scaling ๋ถ„์„์—๋Š” ์ตœ์†Œ 2๊ฐœ ๋ชจ๋ธ ๊ฒฐ๊ณผ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.")
470
+ return {}
471
+
472
+ print("\n" + "=" * 70)
473
+ print("๐Ÿ“ˆ Scaling Law ๋ถ„์„")
474
+ print("=" * 70)
475
+
476
+ # โ”€โ”€ ๊ฒฐ๊ณผ ํ…Œ์ด๋ธ” โ”€โ”€
477
+ print(f"\n {'๋ชจ๋ธ':<8} {'ํŒŒ๋ผ๋ฏธํ„ฐ':>12} {'ํ† ํฐ':>10} {'Loss':>8} {'PPL':>8}")
478
+ print(f" {'โ”€'*52}")
479
+ for r in model_results:
480
+ params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
481
+ tokens_str = f"{r['tokens']/1e9:.1f}B"
482
+ print(f" {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
483
+
484
+ # โ”€โ”€ Scaling ํšจ์œจ ๊ณ„์‚ฐ โ”€โ”€
485
+ analysis = {"models": model_results, "scaling_efficiency": []}
486
+
487
+ for i in range(1, len(model_results)):
488
+ prev = model_results[i-1]
489
+ curr = model_results[i]
490
+
491
+ param_ratio = curr["params"] / prev["params"]
492
+ loss_reduction = prev["loss"] - curr["loss"]
493
+ ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]
494
+
495
+ efficiency = {
496
+ "from": prev["name"],
497
+ "to": curr["name"],
498
+ "param_multiplier": round(param_ratio, 1),
499
+ "loss_reduction": round(loss_reduction, 4),
500
+ "ppl_reduction_pct": round(ppl_reduction * 100, 1),
501
+ }
502
+ analysis["scaling_efficiency"].append(efficiency)
503
+
504
+ print(f"\n {prev['name']} โ†’ {curr['name']}:")
505
+ print(f" ํŒŒ๋ผ๋ฏธํ„ฐ ร—{param_ratio:.1f}")
506
+ print(f" Loss ๊ฐ์†Œ: {loss_reduction:.4f}")
507
+ print(f" PPL ๊ฐ์†Œ: {ppl_reduction*100:.1f}%")
508
+
509
+ # โ”€โ”€ Chinchilla ์ตœ์ ์„ฑ ์ฒดํฌ โ”€โ”€
510
+ print(f"\n Chinchilla ์ตœ์ ์„ฑ ์ฒดํฌ (ํ† ํฐ โ‰ˆ 20 ร— ํŒŒ๋ผ๋ฏธํ„ฐ):")
511
+ for r in model_results:
513
+ actual_ratio = r["tokens"] / r["params"]
514
+ status = "โœ… ์ตœ์  ๋ฒ”์œ„" if 15 <= actual_ratio <= 25 else "โš ๏ธ ๋ฒ”์œ„ ๋ฐ–"
515
+ print(f" {r['name']}: ํ† ํฐ/ํŒŒ๋ผ๋ฏธํ„ฐ = {actual_ratio:.1f}x "
516
+ f"(์ตœ์ : 20x) {status}")
517
+
518
+ analysis["chinchilla_ratios"] = [
519
+ {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
520
+ for r in model_results
521
+ ]
522
+
523
+ return analysis
524
+
525
+ def plot_scaling_curves(
526
+ self,
527
+ model_results: List[Dict[str, Any]],
528
+ save_path: Optional[str] = None,
529
+ ):
530
+ """Scaling ๊ณก์„ ์„ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค."""
531
+ if not HAS_MATPLOTLIB or not HAS_NUMPY:
532
+ print("โš ๏ธ matplotlib/numpy๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค: pip install matplotlib numpy")
533
+ return
534
+
535
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
536
+
537
+ params = [r["params"] for r in model_results]
538
+ losses = [r["loss"] for r in model_results]
539
+ ppls = [r["ppl"] for r in model_results]
540
+ names = [r["name"] for r in model_results]
541
+
542
+ # โ”€โ”€ Loss vs Parameters (log-log) โ”€โ”€
543
+ ax = axes[0]
544
+ ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
545
+ for p, l, n in zip(params, losses, names):
546
+ ax.annotate(f" {n}\n Loss={l:.2f}", (p, l), fontsize=9)
547
+ ax.set_xlabel("Parameters", fontsize=12)
548
+ ax.set_ylabel("Validation Loss", fontsize=12)
549
+ ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
550
+ ax.grid(True, alpha=0.3)
551
+
552
+ # โ”€โ”€ PPL vs Parameters (log-log) โ”€โ”€
553
+ ax = axes[1]
554
+ ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
555
+ for p, pp, n in zip(params, ppls, names):
556
+ ax.annotate(f" {n}\n PPL={pp:.1f}", (p, pp), fontsize=9)
557
+ ax.set_xlabel("Parameters", fontsize=12)
558
+ ax.set_ylabel("Perplexity", fontsize=12)
559
+ ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
560
+ ax.grid(True, alpha=0.3)
561
+
562
+ plt.tight_layout()
563
+
564
+ save_path = save_path or str(self.save_dir / "scaling_curves.png")
565
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
566
+ print(f"\n ๐Ÿ“Š Scaling ๊ณก์„  ์ €์žฅ: {save_path}")
567
+ plt.close(fig)
568
+
569
+
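+ # Illustrative helper (our own sketch, not code from the Chinchilla paper):
+ # the "tokens ~ 20 x params" rule of thumb behind the optimality check above.
+ def chinchilla_optimal_tokens(num_params: float, tokens_per_param: float = 20.0) -> float:
+     """Rough compute-optimal token budget for a given parameter count."""
+     return num_params * tokens_per_param
+
+ # Example: chinchilla_optimal_tokens(1.1e9) -> 2.2e10, i.e. ~22B tokens for the
+ # 1.1B model, which is the center of the 15x-25x band checked in analyze().
+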
570
+ # ============================================================================
571
+ # 5. ํ•™์Šต ์—ญํ•™ ๋ถ„์„๊ธฐ
572
+ # ============================================================================
573
+
574
+ class TrainingDynamicsAnalyzer:
575
+ """ํ•™์Šต ๊ณผ์ •์˜ ๋ฉ”ํŠธ๋ฆญ์„ ๋ถ„์„ํ•˜๊ณ  ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
576
+
577
+ ๋ถ„์„ ํ•ญ๋ชฉ:
578
+ - Loss ๊ณก์„ : ์ˆ˜๋ ด ํŒจํ„ด, ์ŠคํŒŒ์ดํฌ ๊ฐ์ง€
579
+ - LR ์Šค์ผ€์ค„: Warmup + Cosine decay ํ™•์ธ
580
+ - Gradient Norm: ํ•™์Šต ์•ˆ์ •์„ฑ, ํญ๋ฐœ/์†Œ๋ฉธ ๊ฐ์ง€
581
+ - ์ฒ˜๋ฆฌ๋Ÿ‰: tokens/sec ์•ˆ์ •์„ฑ, ๋ณ‘๋ชฉ ๊ฐ์ง€
582
+ """
583
+
584
+ def __init__(self, save_dir: str = "./eval_results"):
585
+ self.save_dir = Path(save_dir)
586
+ self.save_dir.mkdir(parents=True, exist_ok=True)
587
+
588
+ def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
589
+ """ํ•™์Šต ๋ฉ”ํŠธ๋ฆญ์„ ๋ถ„์„ํ•ฉ๋‹ˆ๋‹ค.
590
+
591
+ Args:
592
+ metrics_history: Trainer.metrics.history ๋”•์…”๋„ˆ๋ฆฌ
593
+
594
+ Returns:
595
+ ๋ถ„์„ ๊ฒฐ๊ณผ
596
+ """
597
+ print("\n" + "=" * 70)
598
+ print("๐Ÿ”ฌ ํ•™์Šต ์—ญํ•™ ๋ถ„์„")
599
+ print("=" * 70)
600
+
601
+ analysis = {}
602
+
603
+ # โ”€โ”€ Loss ๋ถ„์„ โ”€โ”€
604
+ if metrics_history.get("train_loss"):
605
+ losses = metrics_history["train_loss"]
606
+ analysis["loss"] = {
607
+ "initial": round(losses[0], 4),
608
+ "final": round(losses[-1], 4),
609
+ "minimum": round(min(losses), 4),
610
+ "total_reduction": round(losses[0] - losses[-1], 4),
611
+ }
612
+
613
+ # ์ŠคํŒŒ์ดํฌ ๊ฐ์ง€ (์ด์ „ ๊ฐ’ ๋Œ€๋น„ 50% ์ด์ƒ ๊ธ‰์ฆ)
614
+ spikes = []
615
+ for i in range(1, len(losses)):
616
+ if losses[i] > losses[i-1] * 1.5:
617
+ step = metrics_history["step"][i] if "step" in metrics_history else i
618
+ spikes.append({"step": step, "loss": round(losses[i], 4)})
619
+
620
+ analysis["loss"]["spikes"] = spikes
621
+
622
+ print(f"\n ๐Ÿ“‰ Loss ๋ถ„์„:")
623
+ print(f" ์ดˆ๊ธฐ: {analysis['loss']['initial']:.4f}")
624
+ print(f" ์ตœ์ข…: {analysis['loss']['final']:.4f}")
625
+ print(f" ์ตœ์†Œ: {analysis['loss']['minimum']:.4f}")
626
+ print(f" ๊ฐ์†Œ: {analysis['loss']['total_reduction']:.4f}")
627
+ print(f" ์ŠคํŒŒ์ดํฌ: {len(spikes)}ํšŒ")
628
+ if spikes:
629
+ for s in spikes[:5]:
630
+ print(f" Step {s['step']}: Loss = {s['loss']}")
631
+
632
+ # โ”€โ”€ Gradient Norm ๋ถ„์„ โ”€โ”€
633
+ if metrics_history.get("grad_norm"):
634
+ gnorms = metrics_history["grad_norm"]
635
+ analysis["grad_norm"] = {
636
+ "mean": round(sum(gnorms) / len(gnorms), 4),
637
+ "max": round(max(gnorms), 4),
638
+ "min": round(min(gnorms), 4),
639
+ "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
640
+ }
641
+
642
+ print(f"\n ๐Ÿ“ Gradient Norm ๋ถ„์„:")
643
+ print(f" ํ‰๊ท : {analysis['grad_norm']['mean']:.4f}")
644
+ print(f" ์ตœ๋Œ€: {analysis['grad_norm']['max']:.4f}")
645
+ print(f" ํด๋ฆฌํ•‘ ๋น„์œจ: {analysis['grad_norm']['clipped_pct']:.1f}%")
646
+ if analysis["grad_norm"]["clipped_pct"] > 30:
647
+ print(f" โš ๏ธ ํด๋ฆฌํ•‘์ด ์žฆ์Œ โ†’ LR ํ•˜ํ–ฅ ๋˜๋Š” warmup ์—ฐ์žฅ ๊ณ ๋ ค")
648
+
649
+ # โ”€โ”€ ์ฒ˜๋ฆฌ๋Ÿ‰ ๋ถ„์„ โ”€โ”€
650
+ if metrics_history.get("tokens_per_sec"):
651
+ tps = metrics_history["tokens_per_sec"]
652
+ tps_valid = [t for t in tps if t > 0]
+ if tps_valid:
+ mean_tps = sum(tps_valid) / len(tps_valid)  # compute the mean once instead of per-term
+ analysis["throughput"] = {
+ "mean": round(mean_tps),
+ "std": round((sum((t - mean_tps) ** 2 for t in tps_valid) / len(tps_valid)) ** 0.5),
+ "min": round(min(tps_valid)),
+ "max": round(max(tps_valid)),
+ }
660
+
661
+ print(f"\n โšก ์ฒ˜๋ฆฌ๋Ÿ‰ ๋ถ„์„:")
662
+ print(f" ํ‰๊ท : {analysis['throughput']['mean']:,} tokens/sec")
663
+ print(f" ํ‘œ์ค€ํŽธ์ฐจ: {analysis['throughput']['std']:,}")
664
+ print(f" ๋ฒ”์œ„: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")
665
+
666
+ return analysis
667
+
668
+ def plot_training_curves(
669
+ self,
670
+ metrics_history: Dict[str, list],
671
+ save_path: Optional[str] = None,
672
+ ):
673
+ """ํ•™์Šต ๊ณก์„ ์„ 4-panel ์ฐจํŠธ๋กœ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค."""
674
+ if not HAS_MATPLOTLIB:
675
+ print("โš ๏ธ matplotlib๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค: pip install matplotlib")
676
+ return
677
+
678
+ fig, axes = plt.subplots(2, 2, figsize=(16, 10))
679
+ fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")
680
+
681
+ steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))
682
+
683
+ # โ”€โ”€ (1) Loss โ”€โ”€
684
+ ax = axes[0, 0]
685
+ if metrics_history.get("train_loss"):
686
+ ax.plot(steps[:len(metrics_history["train_loss"])],
687
+ metrics_history["train_loss"],
688
+ color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
689
+
690
+ # ์ด๋™ ํ‰๊ท  (์Šค๋ฌด๋”ฉ)
691
+ if len(metrics_history["train_loss"]) > 20:
692
+ window = min(50, len(metrics_history["train_loss"]) // 5)
693
+ smoothed = self._moving_average(metrics_history["train_loss"], window)
694
+ ax.plot(steps[window-1:len(smoothed)+window-1],
695
+ smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")
696
+
697
+ if metrics_history.get("val_loss"):
698
+ val_steps = [steps[i] for i in range(0, len(steps),
699
+ max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
700
+ ax.plot(val_steps, metrics_history["val_loss"],
701
+ "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")
702
+
703
+ ax.set_xlabel("Step")
704
+ ax.set_ylabel("Loss")
705
+ ax.set_title("Training & Validation Loss")
706
+ ax.legend()
707
+ ax.grid(True, alpha=0.3)
708
+
709
+ # โ”€โ”€ (2) Learning Rate โ”€โ”€
710
+ ax = axes[0, 1]
711
+ if metrics_history.get("learning_rate"):
712
+ ax.plot(steps[:len(metrics_history["learning_rate"])],
713
+ metrics_history["learning_rate"],
714
+ color="#059669", linewidth=2)
715
+ ax.set_xlabel("Step")
716
+ ax.set_ylabel("Learning Rate")
717
+ ax.set_title("Learning Rate Schedule")
718
+ ax.ticklabel_format(style="scientific", axis="y", scilimits=(0,0))
719
+ ax.grid(True, alpha=0.3)
720
+
721
+ # โ”€โ”€ (3) Gradient Norm โ”€โ”€
722
+ ax = axes[1, 0]
723
+ if metrics_history.get("grad_norm"):
724
+ ax.plot(steps[:len(metrics_history["grad_norm"])],
725
+ metrics_history["grad_norm"],
726
+ color="#d97706", alpha=0.6, linewidth=0.8)
727
+ ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Clip threshold")
728
+ ax.legend()
729
+ ax.set_xlabel("Step")
730
+ ax.set_ylabel("Gradient Norm")
731
+ ax.set_title("Gradient Norm (clipped at 1.0)")
732
+ ax.grid(True, alpha=0.3)
733
+
734
+ # โ”€โ”€ (4) Throughput โ”€โ”€
735
+ ax = axes[1, 1]
736
+ if metrics_history.get("tokens_per_sec"):
737
+ tps = metrics_history["tokens_per_sec"]
738
+ ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
739
+ if tps:
740
+ avg_tps = sum(tps) / len(tps)
741
+ ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
742
+ label=f"Avg: {avg_tps:,.0f}")
743
+ ax.legend()
744
+ ax.set_xlabel("Step")
745
+ ax.set_ylabel("Tokens/sec")
746
+ ax.set_title("Training Throughput")
747
+ ax.grid(True, alpha=0.3)
748
+
749
+ plt.tight_layout()
750
+
751
+ save_path = save_path or str(self.save_dir / "training_curves.png")
752
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
753
+ print(f"\n ๐Ÿ“Š ํ•™์Šต ๊ณก์„  ์ €์žฅ: {save_path}")
754
+ plt.close(fig)
755
+
756
+ def plot_position_loss(
757
+ self,
758
+ position_losses: List[float],
759
+ save_path: Optional[str] = None,
760
+ ):
761
+ """์œ„์น˜๋ณ„ Loss ๋ถ„ํฌ๋ฅผ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค."""
762
+ if not HAS_MATPLOTLIB:
763
+ return
764
+
765
+ fig, ax = plt.subplots(figsize=(12, 5))
766
+
767
+ positions = list(range(len(position_losses)))
768
+ ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
769
+ ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")
770
+
771
+ ax.set_xlabel("Position in Sequence", fontsize=12)
772
+ ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
773
+ ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
774
+ ax.grid(True, alpha=0.3)
775
+
776
+ # ์ฃผ์š” ๊ตฌ๊ฐ„ ํ‘œ์‹œ
777
+ if len(position_losses) > 100:
778
+ early_avg = sum(position_losses[:50]) / 50
+ late = position_losses[-200:]
+ late_avg = sum(late) / len(late)  # safe even if the sequence is shorter than 200
780
+ ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
781
+ label=f"Early avg (0-50): {early_avg:.2f}")
782
+ ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
783
+ label=f"Late avg (-200): {late_avg:.2f}")
784
+ ax.legend()
785
+
786
+ plt.tight_layout()
787
+
788
+ save_path = save_path or str(self.save_dir / "position_loss.png")
789
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
790
+ print(f" ๐Ÿ“Š ์œ„์น˜๋ณ„ Loss ์ €์žฅ: {save_path}")
791
+ plt.close(fig)
792
+
793
+ @staticmethod
794
+ def _moving_average(data: list, window: int) -> list:
795
+ """์ด๋™ ํ‰๊ท  ๊ณ„์‚ฐ."""
796
+ result = []
797
+ for i in range(window - 1, len(data)):
798
+ avg = sum(data[i - window + 1 : i + 1]) / window
799
+ result.append(avg)
800
+ return result
801
+
802
+
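+ # Illustrative sketch of the Warmup + Cosine schedule this analyzer expects to
+ # see in the LR curve. This is our own minimal version with assumed default
+ # values; the actual Trainer scheduler may differ in details.
+ def _warmup_cosine_lr_demo(step: int, max_lr: float = 3e-4,
+                            warmup_steps: int = 2000,
+                            total_steps: int = 100_000,
+                            min_lr: float = 3e-5) -> float:
+     if step < warmup_steps:                           # linear warmup from ~0
+         return max_lr * (step + 1) / warmup_steps
+     progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
+     cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
+     return min_lr + (max_lr - min_lr) * cosine        # decays toward min_lr
+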
803
+ # ============================================================================
804
+ # 6. Attention ์‹œ๊ฐํ™”
805
+ # ============================================================================
806
+
807
+ class AttentionVisualizer:
808
+ """Attention ํŒจํ„ด์„ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
809
+
810
+ ํ•™์Šต ํฌ์ธํŠธ:
811
+ - Causal Mask: ํ•˜์‚ผ๊ฐ ํŒจํ„ด (๋ฏธ๋ž˜ ํ† ํฐ์€ ๋ณผ ์ˆ˜ ์—†์Œ)
812
+ - ํ—ค๋“œ๋ณ„ ์—ญํ•  ๋ถ„ํ™”: ์ผ๋ถ€๋Š” ๋กœ์ปฌ(์ธ์ ‘), ์ผ๋ถ€๋Š” ๊ธ€๋กœ๋ฒŒ(๋จผ ํ† ํฐ) ์ฃผ๋ชฉ
813
+ - ๊ตฌ๋ฌธ๋ก ์  ํŒจํ„ด: ๋™์‚ฌโ†’์ฃผ์–ด, ๋Œ€๋ช…์‚ฌโ†’์„ ํ–‰์‚ฌ ๋“ฑ์— ๋†’์€ attention
814
+
815
+ ์ฃผ์˜: 1B ๋ชจ๋ธ์˜ ์ „์ฒด attention์„ ์ €์žฅํ•˜๋ฉด ๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ!
816
+ โ†’ ํŠน์ • ๋ ˆ์ด์–ด/ํ—ค๋“œ๋งŒ ์„ ํƒ์ ์œผ๋กœ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
817
+ """
818
+
819
+ def __init__(self, save_dir: str = "./eval_results"):
820
+ self.save_dir = Path(save_dir)
821
+ self.save_dir.mkdir(parents=True, exist_ok=True)
822
+
823
+ @torch.no_grad()
824
+ def extract_attention(
825
+ self,
826
+ model: nn.Module,
827
+ input_ids: torch.Tensor,
828
+ layer_idx: int = 0,
829
+ device: torch.device = torch.device("cpu"),
830
+ ) -> torch.Tensor:
831
+ """ํŠน์ • ๋ ˆ์ด์–ด์˜ attention weight๋ฅผ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค.
832
+
833
+ ๋ชจ๋ธ์˜ attention ๋ชจ๋“ˆ์„ ์ผ์‹œ์ ์œผ๋กœ ์ˆ˜์ •ํ•˜์—ฌ
834
+ attention weight๋ฅผ ์บก์ฒ˜ํ•ฉ๋‹ˆ๋‹ค.
835
+
836
+ Returns:
837
+ attention_weights: (num_heads, seq_len, seq_len)
838
+ """
839
+ model.eval()
840
+ captured_attn = {}
841
+
842
+ # Temporarily swap out the layer's forward (a monkey-patch, not a torch hook) to capture weights
843
+ target_layer = model.layers[layer_idx].attention
844
+
845
+ # scaled_dot_product_attention์„ ์ˆ˜๋™ ๊ตฌํ˜„์œผ๋กœ ๋Œ€์ฒด
846
+ original_forward = target_layer.forward
847
+
848
+ def hooked_forward(x, mask=None, position_offset=0):
849
+ B, S, _ = x.shape
850
+ hd = target_layer.head_dim
851
+
852
+ q = target_layer.q_proj(x).view(B, S, target_layer.num_heads, hd).transpose(1, 2)
853
+ k = target_layer.k_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
854
+ v = target_layer.v_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
855
+
856
+ q, k = target_layer.rope(q, k, position_offset)
857
+
858
+ if target_layer.num_kv_groups > 1:
859
+ k = target_layer._repeat_kv(k)
860
+ v = target_layer._repeat_kv(v)
861
+
862
+ # ์ˆ˜๋™ attention ๊ณ„์‚ฐ (weight ์ถ”์ถœ์šฉ)
863
+ scale = 1.0 / math.sqrt(hd)
864
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
865
+
866
+ # Causal mask
867
+ causal = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
868
+ scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
869
+
870
+ attn_weights = F.softmax(scores, dim=-1)
871
+ captured_attn["weights"] = attn_weights[0].cpu() # ์ฒซ ๋ฐฐ์น˜๋งŒ
872
+
873
+ out = torch.matmul(attn_weights, v)
874
+ out = out.transpose(1, 2).contiguous().view(B, S, -1)
875
+ return target_layer.o_proj(out)
876
+
877
+ # Install the patched forward
878
+ target_layer.forward = hooked_forward
879
+
880
+ try:
881
+ model(input_ids.to(device))
882
+ finally:
883
+ target_layer.forward = original_forward
884
+
885
+ return captured_attn.get("weights") # (num_heads, S, S)
886
+
887
+ def plot_attention_heatmap(
888
+ self,
889
+ attn_weights: torch.Tensor,
890
+ tokens: List[str],
891
+ head_idx: int = 0,
892
+ save_path: Optional[str] = None,
893
+ title: str = "Attention Weights",
894
+ ):
895
+ """Attention heatmap์„ ๊ทธ๋ฆฝ๋‹ˆ๋‹ค."""
896
+ if not HAS_MATPLOTLIB:
897
+ print("โš ๏ธ matplotlib๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค")
898
+ return
899
+
900
+ weights = attn_weights[head_idx].numpy()
901
+ max_len = min(len(tokens), 50) # ์ตœ๋Œ€ 50 ํ† ํฐ๋งŒ ํ‘œ์‹œ
902
+ weights = weights[:max_len, :max_len]
903
+ display_tokens = tokens[:max_len]
904
+
905
+ fig, ax = plt.subplots(figsize=(12, 10))
906
+ im = ax.imshow(weights, cmap="Blues", aspect="auto")
907
+
908
+ ax.set_xticks(range(max_len))
909
+ ax.set_yticks(range(max_len))
910
+ ax.set_xticklabels(display_tokens, rotation=90, fontsize=7)
911
+ ax.set_yticklabels(display_tokens, fontsize=7)
912
+
913
+ ax.set_xlabel("Key (attended to)", fontsize=11)
914
+ ax.set_ylabel("Query (attending from)", fontsize=11)
915
+ ax.set_title(f"{title} โ€” Head {head_idx}", fontsize=13, fontweight="bold")
916
+
917
+ fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
918
+ plt.tight_layout()
919
+
920
+ save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
921
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
922
+ print(f" ๐Ÿ“Š Attention ์‹œ๊ฐํ™” ์ €์žฅ: {save_path}")
923
+ plt.close(fig)
924
+
925
+ def plot_multi_head_summary(
926
+ self,
927
+ attn_weights: torch.Tensor,
928
+ num_heads_to_show: int = 8,
929
+ save_path: Optional[str] = None,
930
+ ):
931
+ """์—ฌ๋Ÿฌ ํ—ค๋“œ์˜ attention ํŒจํ„ด์„ ์š”์•ฝ ๋น„๊ตํ•ฉ๋‹ˆ๋‹ค."""
932
+ if not HAS_MATPLOTLIB:
933
+ return
934
+
935
+ n_heads = min(attn_weights.shape[0], num_heads_to_show)
936
+ cols = 4
937
+ rows = math.ceil(n_heads / cols)
938
+
939
+ fig, axes = plt.subplots(rows, cols, figsize=(16, 4 * rows))
940
+ if rows == 1:
941
+ axes = axes.reshape(1, -1)
942
+
943
+ for idx in range(n_heads):
944
+ r, c = idx // cols, idx % cols
945
+ ax = axes[r, c]
946
+ w = attn_weights[idx].numpy()
947
+ ax.imshow(w, cmap="Blues", aspect="auto")
948
+ ax.set_title(f"Head {idx}", fontsize=10)
949
+ ax.set_xticks([])
950
+ ax.set_yticks([])
951
+
952
+ # ๋นˆ subplot ์ˆจ๊ธฐ๊ธฐ
953
+ for idx in range(n_heads, rows * cols):
954
+ r, c = idx // cols, idx % cols
955
+ axes[r, c].axis("off")
956
+
957
+ fig.suptitle("Attention Patterns by Head", fontsize=14, fontweight="bold")
958
+ plt.tight_layout()
959
+
960
+ save_path = save_path or str(self.save_dir / "attention_multi_head.png")
961
+ fig.savefig(save_path, dpi=150, bbox_inches="tight")
962
+ print(f" ๐Ÿ“Š ๋ฉ€ํ‹ฐ ํ—ค๋“œ ์š”์•ฝ ์ €์žฅ: {save_path}")
963
+ plt.close(fig)
964
+
965
+
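+ # Illustrative sketch (toy tensors only) of why the heatmaps above are lower
+ # triangular: masking future positions with -inf before softmax zeroes them out.
+ def _causal_mask_demo(seq_len: int = 4) -> torch.Tensor:
+     scores = torch.zeros(seq_len, seq_len)            # identical raw scores
+     future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
+     scores = scores.masked_fill(future, float("-inf"))
+     attn = F.softmax(scores, dim=-1)
+     # Row i spreads its weight uniformly over positions 0..i; the upper
+     # triangle is exactly zero and every row still sums to 1.
+     return attn
+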
966
+ # ============================================================================
967
+ # 7. ์ข…ํ•ฉ ํ‰๊ฐ€ ์‹คํ–‰๊ธฐ
968
+ # ============================================================================
969
+
970
+ class FullEvaluator:
971
+ """๋ชจ๋“  ํ‰๊ฐ€๋ฅผ ํ•œ ๋ฒˆ์— ์‹คํ–‰ํ•˜๊ณ  ๋ฆฌํฌํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
972
+
973
+ ์‚ฌ์šฉ๋ฒ•:
974
+ ```python
975
+ evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
976
+ report = evaluator.run_full_evaluation()
977
+ ```
978
+ """
979
+
980
+ def __init__(
981
+ self,
982
+ model: nn.Module,
983
+ tokenizer: Any,
984
+ val_dataloader: DataLoader,
985
+ device: torch.device,
986
+ config: Optional[EvalConfig] = None,
987
+ dtype: torch.dtype = torch.bfloat16,
988
+ metrics_history: Optional[Dict[str, list]] = None,
989
+ ):
990
+ self.model = model
991
+ self.tokenizer = tokenizer
992
+ self.val_dataloader = val_dataloader
993
+ self.device = device
994
+ self.config = config or EvalConfig()
995
+ self.dtype = dtype
996
+ self.metrics_history = metrics_history
997
+
998
+ self.save_dir = Path(self.config.save_dir)
999
+ self.save_dir.mkdir(parents=True, exist_ok=True)
1000
+
1001
+ def run_full_evaluation(self) -> Dict[str, Any]:
1002
+ """์ „์ฒด ํ‰๊ฐ€๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค."""
1003
+ report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
1004
+
1005
+ print("\n" + "=" * 70)
1006
+ print("๐Ÿ” ์ข…ํ•ฉ ํ‰๊ฐ€ ์‹œ์ž‘")
1007
+ print("=" * 70)
1008
+
1009
+ # โ”€โ”€ 1. Perplexity โ”€โ”€
1010
+ print("\n" + "โ”" * 40)
1011
+ print("Phase 1/4: Perplexity ์ธก์ •")
1012
+ print("โ”" * 40)
1013
+ ppl_evaluator = PerplexityEvaluator(self.config)
1014
+ report["perplexity"] = ppl_evaluator.evaluate(
1015
+ self.model, self.val_dataloader, self.device, self.dtype
1016
+ )
1017
+
1018
+ # ์œ„์น˜๋ณ„ Loss
1019
+ print("\n ์œ„์น˜๋ณ„ Loss ์ธก์ • ์ค‘...")
1020
+ position_losses = ppl_evaluator.evaluate_per_position(
1021
+ self.model, self.val_dataloader, self.device, self.dtype
1022
+ )
1023
+ report["position_losses"] = {
1024
+ "early_avg": round(sum(position_losses[:50]) / max(len(position_losses[:50]), 1), 4),
1025
+ "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
1026
+ }
1027
+
1028
+ # ์œ„์น˜๋ณ„ Loss ์‹œ๊ฐํ™”
1029
+ dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
1030
+ dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
1031
+
1032
+ # โ”€โ”€ 2. ํ…์ŠคํŠธ ์ƒ์„ฑ โ”€โ”€
1033
+ print("\n" + "โ”" * 40)
1034
+ print("Phase 2/4: ํ…์ŠคํŠธ ์ƒ์„ฑ")
1035
+ print("โ”" * 40)
1036
+ gen_evaluator = GenerationEvaluator(self.config)
1037
+ gen_results = gen_evaluator.generate_samples(
1038
+ self.model, self.tokenizer, self.device
1039
+ )
1040
+ report["generation"] = {
1041
+ "num_prompts": len(gen_results),
1042
+ "avg_metrics": self._average_gen_metrics(gen_results),
1043
+ }
1044
+
1045
+ # โ”€โ”€ 3. ํ•™์Šต ์—ญํ•™ ๋ถ„์„ โ”€โ”€
1046
+ if self.metrics_history:
1047
+ print("\n" + "โ”" * 40)
1048
+ print("Phase 3/4: ํ•™์Šต ์—ญํ•™ ๋ถ„์„")
1049
+ print("โ”" * 40)
1050
+ report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
1051
+ dynamics.plot_training_curves(self.metrics_history,
1052
+ str(self.save_dir / "training_curves.png"))
1053
+ else:
1054
+ print("\n Phase 3/4: ๊ฑด๋„ˆ๋œ€ (metrics_history ์—†์Œ)")
1055
+
1056
+ # โ”€โ”€ 4. Attention ์‹œ๊ฐํ™” (์ƒ˜ํ”Œ) โ”€โ”€
1057
+ print("\n" + "โ”" * 40)
1058
+ print("Phase 4/4: Attention ์‹œ๊ฐํ™”")
1059
+ print("โ”" * 40)
1060
+ try:
1061
+ self._visualize_attention_sample()
1062
+ except Exception as e:
1063
+ print(f" โš ๏ธ Attention ์‹œ๊ฐํ™” ์‹คํŒจ: {e}")
1064
+
1065
+ # โ”€โ”€ ๋ฆฌํฌํŠธ ์ €์žฅ โ”€โ”€
1066
+ report_path = self.save_dir / "eval_report.json"
1067
+ with open(report_path, "w") as f:
1068
+ json.dump(report, f, indent=2, default=str)
1069
+ print(f"\n๐Ÿ“‹ ๋ฆฌํฌํŠธ ์ €์žฅ: {report_path}")
1070
+
1071
+ # โ”€โ”€ ์š”์•ฝ ์ถœ๋ ฅ โ”€โ”€
1072
+ self._print_summary(report)
1073
+
1074
+ return report
1075
+
1076
+ def _visualize_attention_sample(self):
1077
+ """์ƒ˜ํ”Œ ํ…์ŠคํŠธ๋กœ attention์„ ์‹œ๊ฐํ™”ํ•ฉ๋‹ˆ๋‹ค."""
1078
+ viz = AttentionVisualizer(str(self.save_dir))
1079
+
1080
+ sample_text = "The cat sat on the mat and looked at the bird."
1081
+ token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
1082
+ input_tensor = torch.tensor([token_ids], dtype=torch.long)
1083
+
1084
+ # ํ† ํฐ ๋ฌธ์ž์—ด (์‹œ๊ฐํ™” ๋ผ๋ฒจ์šฉ)
1085
+ tokens_str = []
1086
+ for tid in token_ids:
1087
+ decoded = self.tokenizer.decode([tid])
1088
+ tokens_str.append(decoded.replace("\n", "\\n"))
1089
+
1090
+ # Layer 0 attention ์ถ”์ถœ
1091
+ attn_weights = viz.extract_attention(
1092
+ self.model, input_tensor, layer_idx=0, device=self.device
1093
+ )
1094
+
1095
+ if attn_weights is not None:
1096
+ viz.plot_attention_heatmap(
1097
+ attn_weights, tokens_str, head_idx=0,
1098
+ title="Layer 0 Attention"
1099
+ )
1100
+ viz.plot_multi_head_summary(attn_weights)
1101
+
1102
+ @staticmethod
1103
+ def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
1104
+ """๋ชจ๋“  ํ”„๋กฌํ”„ํŠธ์˜ ์ƒ์„ฑ ๋ฉ”ํŠธ๋ฆญ ํ‰๊ท ."""
1105
+ if not gen_results:
1106
+ return {}
1107
+
1108
+ all_metrics = [r["metrics"] for r in gen_results if r.get("metrics")]
1109
+ if not all_metrics:
1110
+ return {}
1111
+
1112
+ keys = all_metrics[0].keys()
1113
+ return {
1114
+ k: round(sum(m.get(k, 0) for m in all_metrics) / len(all_metrics), 3)
1115
+ for k in keys
1116
+ }
1117
+
1118
+ def _print_summary(self, report: Dict[str, Any]):
1119
+ """์ตœ์ข… ์š”์•ฝ์„ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค."""
1120
+ print("\n" + "=" * 70)
1121
+ print("๐Ÿ“‹ ํ‰๊ฐ€ ์š”์•ฝ ๋ฆฌํฌํŠธ")
1122
+ print("=" * 70)
1123
+
1124
+ # Perplexity
1125
+ if "perplexity" in report:
1126
+ ppl = report["perplexity"]
1127
+ print(f"\n ๐ŸŽฏ Perplexity:")
1128
+ print(f" Loss: {ppl['loss']:.4f}")
1129
+ print(f" PPL: {ppl['perplexity']:.2f}")
1130
+
1131
+ # ๋“ฑ๊ธ‰ ํŒ์ •
1132
+ ppl_val = ppl["perplexity"]
1133
+ if ppl_val < 20:
1134
+ grade = "๐ŸŒŸ ์šฐ์ˆ˜ (Strong)"
1135
+ elif ppl_val < 35:
1136
+ grade = "โœ… ์–‘ํ˜ธ (Good)"
1137
+ elif ppl_val < 60:
1138
+ grade = "โš ๏ธ ๋ณดํ†ต (Fair)"
1139
+ else:
1140
+ grade = "โŒ ๋ฏธํก (ํ•™์Šต ์ถ”๊ฐ€ ํ•„์š”)"
1141
+ print(f" ๋“ฑ๊ธ‰: {grade}")
1142
+
1143
+ # ์œ„์น˜๋ณ„ Loss
1144
+ if "position_losses" in report:
1145
+ pl = report["position_losses"]
1146
+ print(f"\n ๐Ÿ“ ์œ„์น˜๋ณ„ Loss:")
1147
+ print(f" ์ดˆ๋ฐ˜ (0-50): {pl['early_avg']:.4f}")
1148
+ print(f" ํ›„๋ฐ˜ (-200): {pl['late_avg']:.4f}")
1149
+ print(f" ์ปจํ…์ŠคํŠธ ํšจ๊ณผ: {pl['early_avg'] - pl['late_avg']:.4f} ๊ฐ์†Œ")
1150
+
1151
+ # ์ƒ์„ฑ ํ’ˆ์งˆ
1152
+ if "generation" in report and report["generation"].get("avg_metrics"):
1153
+ gm = report["generation"]["avg_metrics"]
1154
+ print(f"\n โœ๏ธ ์ƒ์„ฑ ํ’ˆ์งˆ:")
1155
+ print(f" ํ‰๊ท  ๊ธธ์ด: {gm.get('avg_length', 0):.0f} ์ž")
1156
+ print(f" ๋ฐ˜๋ณต๋ฅ : {gm.get('repetition_rate', 0):.1%}")
1157
+ print(f" ์–ดํœ˜ ๋‹ค์–‘์„ฑ: {gm.get('lexical_diversity', 0):.3f}")
1158
+
1159
+ # ํ•™์Šต ์—ญํ•™
1160
+ if "training_dynamics" in report:
1161
+ td = report["training_dynamics"]
1162
+ if "loss" in td:
1163
+ print(f"\n ๐Ÿ“‰ ํ•™์Šต ์—ญํ•™:")
1164
+ print(f" Loss ๊ฐ์†Œ: {td['loss']['initial']:.4f} โ†’ {td['loss']['final']:.4f}")
1165
+ print(f" ์ŠคํŒŒ์ดํฌ: {len(td['loss']['spikes'])}ํšŒ")
1166
+
1167
+ # ์ƒ์„ฑ๋œ ํŒŒ์ผ
1168
+ print(f"\n ๐Ÿ“‚ ๊ฒฐ๊ณผ ํŒŒ์ผ:")
1169
+ for f in sorted(self.save_dir.glob("*")):
1170
+ size = f.stat().st_size / 1024
1171
+ print(f" {f.name} ({size:.1f} KB)")
1172
+
1173
+ print("\n" + "=" * 70)
1174
+
1175
+
1176
+ # ============================================================================
1177
+ # 8. ํ•™์Šต ์ธ์‚ฌ์ดํŠธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ ๊ฒ€์ฆ๊ธฐ
1178
+ # ============================================================================
1179
+
1180
+ class InsightChecklist:
1181
+ """PRD์— ์ •์˜๋œ ํ•™์Šต ์ธ์‚ฌ์ดํŠธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ๋ฅผ ์ž๋™/์ˆ˜๋™์œผ๋กœ ๊ฒ€์ฆํ•ฉ๋‹ˆ๋‹ค.
1182
+
1183
+ ์ž๋™ ๊ฒ€์ฆ ๊ฐ€๋Šฅ ํ•ญ๋ชฉ์€ ๋ฉ”ํŠธ๋ฆญ ๊ธฐ๋ฐ˜์œผ๋กœ ํŒ์ •ํ•˜๊ณ ,
1184
+ ์ˆ˜๋™ ํ•ญ๋ชฉ์€ ์งˆ๋ฌธ์œผ๋กœ ์ œ์‹œํ•ฉ๋‹ˆ๋‹ค.
1185
+ """
1186
+
1187
+ @staticmethod
1188
+ def run_checklist(
1189
+ report: Dict[str, Any],
1190
+ metrics_history: Optional[Dict[str, list]] = None,
1191
+ ):
1192
+ """์ฒดํฌ๋ฆฌ์ŠคํŠธ๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค."""
1193
+ print("\n" + "=" * 70)
1194
+ print("โœ… ํ•™์Šต ์ธ์‚ฌ์ดํŠธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ")
1195
+ print("=" * 70)
1196
+
1197
+ checks = {
1198
+ "passed": [],
1199
+ "failed": [],
1200
+ "manual": [],
1201
+ }
1202
+
1203
+ # โ”€โ”€ ์ž๋™ ๊ฒ€์ฆ โ”€โ”€
1204
+
1205
+ # 1. Loss ์ˆ˜๋ ด
1206
+ if report.get("perplexity", {}).get("loss", 99) < 4.0:
1207
+ checks["passed"].append("๋ชจ๋ธ Loss๊ฐ€ 4.0 ์ดํ•˜๋กœ ์ˆ˜๋ ด")
1208
+ else:
1209
+ checks["failed"].append("๋ชจ๋ธ Loss๊ฐ€ 4.0 ์ดํ•˜๋กœ ๋ฏธ์ˆ˜๋ ด")
1210
+
1211
+ # 2. Loss ์ŠคํŒŒ์ดํฌ
1212
+ spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
1213
+ if len(spikes) < 5:
1214
+ checks["passed"].append(f"Loss ์ŠคํŒŒ์ดํฌ {len(spikes)}ํšŒ (< 5ํšŒ)")
1215
+ else:
1216
+ checks["failed"].append(f"Loss ์ŠคํŒŒ์ดํฌ {len(spikes)}ํšŒ (โ‰ฅ 5ํšŒ, ์•ˆ์ •์„ฑ ๊ฐœ์„  ํ•„์š”)")
1217
+
1218
+ # 3. ์œ„์น˜๋ณ„ Loss ํŒจํ„ด
1219
+ if report.get("position_losses"):
1220
+ early = report["position_losses"]["early_avg"]
1221
+ late = report["position_losses"]["late_avg"]
1222
+ if early > late:
1223
+ checks["passed"].append("์œ„์น˜๋ณ„ Loss ๊ฐ์†Œ ํŒจํ„ด ํ™•์ธ (์ปจํ…์ŠคํŠธ ํ™œ์šฉ)")
1224
+ else:
1225
+ checks["failed"].append("์œ„์น˜๋ณ„ Loss ํŒจํ„ด ์ด์ƒ (์ปจํ…์ŠคํŠธ ๋ฏธํ™œ์šฉ?)")
1226
+
1227
+ # 4. ์ƒ์„ฑ ๋ฐ˜๋ณต๋ฅ 
1228
+ rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
1229
+ if rep < 0.3:
1230
+ checks["passed"].append(f"์ƒ์„ฑ ๋ฐ˜๋ณต๋ฅ  {rep:.1%} (< 30%)")
1231
+ else:
1232
+ checks["failed"].append(f"์ƒ์„ฑ ๋ฐ˜๋ณต๋ฅ  {rep:.1%} (โ‰ฅ 30%, temperature/top_p ์กฐ์ •)")
1233
+
1234
+ # 5. Gradient ํด๋ฆฌํ•‘ ๋น„์œจ
1235
+ if metrics_history and metrics_history.get("grad_norm"):
1236
+ gnorms = metrics_history["grad_norm"]
1237
+ clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
1238
+ if clip_rate < 0.3:
1239
+ checks["passed"].append(f"Gradient ํด๋ฆฌํ•‘ ๋น„์œจ {clip_rate:.1%} (๊ฑด๊ฐ•)")
1240
+ else:
1241
+ checks["failed"].append(f"Gradient ํด๋ฆฌํ•‘ ๋น„์œจ {clip_rate:.1%} (๋„ˆ๋ฌด ์žฆ์Œ)")
1242
+
1243
+ # โ”€โ”€ ์ˆ˜๋™ ํ™•์ธ ํ•ญ๋ชฉ โ”€โ”€
1244
+ manual_items = [
1245
+ "Self-Attention์—์„œ Q, K, V ๊ฐ๊ฐ์˜ ์—ญํ• ์„ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ๋Š”๊ฐ€?",
1246
+ "RoPE๊ฐ€ ์œ„์น˜ ์ •๋ณด๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๋Š” ์ˆ˜ํ•™์  ์›๋ฆฌ๋ฅผ ์ดํ•ดํ•˜๋Š”๊ฐ€?",
1247
+ "GQA๊ฐ€ MHA ๋Œ€๋น„ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ ˆ์•ฝํ•˜๋Š” ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์„ค๋ช…ํ•  ์ˆ˜ ์žˆ๋Š”๊ฐ€?",
1248
+ "SwiGLU์˜ ๊ฒŒ์ดํŒ… ๋ฉ”์ปค๋‹ˆ์ฆ˜์ด ReLU FFN๊ณผ ์–ด๋–ป๊ฒŒ ๋‹ค๋ฅธ์ง€ ์ดํ•ดํ•˜๋Š”๊ฐ€?",
1249
+ "Learning Rate Warmup์ด ์™œ ํ•„์š”ํ•œ์ง€ ์ฒด๊ฐํ–ˆ๋Š”๊ฐ€?",
1250
+ "Gradient Accumulation์ด ํฐ ๋ฐฐ์น˜๋ฅผ ์‹œ๋ฎฌ๋ ˆ์ด์…˜ํ•˜๋Š” ์›๋ฆฌ๋ฅผ ์ดํ•ดํ•˜๋Š”๊ฐ€?",
1251
+ "Mixed Precision(bf16)์˜ ๋ฉ”๋ชจ๋ฆฌ-์†๋„ ํšจ๊ณผ๋ฅผ ์ธก์ •ํ–ˆ๋Š”๊ฐ€?",
1252
+ "Activation Checkpointing์˜ ๋ฉ”๋ชจ๋ฆฌ-์—ฐ์‚ฐ ํŠธ๋ ˆ์ด๋“œ์˜คํ”„๋ฅผ ์ดํ•ดํ•˜๋Š”๊ฐ€?",
1253
+ ]
1254
+ checks["manual"] = manual_items
1255
+
1256
+ # โ”€โ”€ ์ถœ๋ ฅ โ”€โ”€
1257
+ total_auto = len(checks["passed"]) + len(checks["failed"])
1258
+ passed_auto = len(checks["passed"])
1259
+
1260
+ print(f"\n ์ž๋™ ๊ฒ€์ฆ: {passed_auto}/{total_auto} ํ†ต๊ณผ")
1261
+ for item in checks["passed"]:
1262
+ print(f" โœ… {item}")
1263
+ for item in checks["failed"]:
1264
+ print(f" โŒ {item}")
1265
+
1266
+ print(f"\n ์ˆ˜๋™ ํ™•์ธ ({len(manual_items)} ํ•ญ๋ชฉ):")
1267
+ for i, item in enumerate(manual_items, 1):
1268
+ print(f" {i}. [ ] {item}")
1269
+
1270
+ print(f"\n ์ด ์ง„ํ–‰๋ฅ : {passed_auto}/{total_auto + len(manual_items)} "
1271
+ f"(์ˆ˜๋™ ํ•ญ๋ชฉ ํฌํ•จ ์‹œ)")
1272
+
1273
+ return checks
1274
+
1275
+
1276
+ # ============================================================================
1277
+ # 9. Quick Start
1278
+ # ============================================================================
1279
+
1280
+ def run_evaluation(
1281
+ model: nn.Module,
1282
+ tokenizer: Any,
1283
+ val_dataloader: DataLoader,
1284
+ device: Optional[torch.device] = None,
1285
+ dtype: torch.dtype = torch.bfloat16,
1286
+ metrics_history: Optional[Dict[str, list]] = None,
1287
+ config: Optional[EvalConfig] = None,
1288
+ ) -> Dict[str, Any]:
1289
+ """ํ‰๊ฐ€๋ฅผ ํ•œ ๋ฒˆ์— ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
1290
+
1291
+ ์‚ฌ์šฉ๋ฒ• (Colab):
1292
+ ```python
1293
+ from evaluation import run_evaluation
1294
+
1295
+ # ํ•™์Šต ์™„๋ฃŒ ํ›„
1296
+ report = run_evaluation(
1297
+ model=trainer.model,
1298
+ tokenizer=tokenizer,
1299
+ val_dataloader=val_dl,
1300
+ metrics_history=trainer.metrics.history,
1301
+ )
1302
+ ```
1303
+ """
1304
+ if device is None:
1305
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1306
+
1307
+ evaluator = FullEvaluator(
1308
+ model=model,
1309
+ tokenizer=tokenizer,
1310
+ val_dataloader=val_dataloader,
1311
+ device=device,
1312
+ config=config,
1313
+ dtype=dtype,
1314
+ metrics_history=metrics_history,
1315
+ )
1316
+
1317
+ report = evaluator.run_full_evaluation()
1318
+
1319
+ # ์ธ์‚ฌ์ดํŠธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ
1320
+ InsightChecklist.run_checklist(report, metrics_history)
1321
+
1322
+ return report
1323
+
1324
+
1325
+ # ============================================================================
1326
+ # 10. ๊ฒ€์ฆ ์Šคํฌ๋ฆฝํŠธ
1327
+ # ============================================================================
1328
+
1329
+ if __name__ == "__main__":
1330
+ print("=" * 70)
1331
+ print("LLM-1B-Lab: ํ‰๊ฐ€ ๋ชจ๋“ˆ ๊ฒ€์ฆ")
1332
+ print("=" * 70)
1333
+
1334
+ # โ”€โ”€ ๋”๋ฏธ ๋ชจ๋ธ๋กœ ๊ตฌ์กฐ ๊ฒ€์ฆ โ”€โ”€
1335
+ class TinyModel(nn.Module):
1336
+ def __init__(self, vocab=100, dim=64):
1337
+ super().__init__()
1338
+ self.emb = nn.Embedding(vocab, dim)
1339
+ self.linear = nn.Linear(dim, vocab)
1340
+ self.linear.weight = self.emb.weight
1341
+ self.layers = nn.ModuleList() # attention ์‹œ๊ฐํ™” ํ˜ธํ™˜์šฉ
1342
+
1343
+ def forward(self, input_ids, targets=None):
1344
+ h = self.emb(input_ids)
1345
+ logits = self.linear(h)
1346
+ loss = None
1347
+ if targets is not None:
1348
+ loss = F.cross_entropy(logits.view(-1, 100), targets.view(-1))
1349
+ return logits, loss
1350
+
1351
+ def generate(self, input_ids, max_new_tokens=20, temperature=1.0, top_k=50, top_p=0.9):
1352
+ generated = input_ids
1353
+ for _ in range(max_new_tokens):
1354
+ logits, _ = self(generated[:, -64:])
1355
+ next_logits = logits[:, -1, :] / temperature
1356
+ probs = F.softmax(next_logits, dim=-1)
1357
+ nxt = torch.multinomial(probs, 1)
1358
+ generated = torch.cat([generated, nxt], dim=1)
1359
+ return generated
1360
+
1361
+ model = TinyModel()
1362
+ device = torch.device("cpu")
1363
+
1364
+ # ๋”๋ฏธ ํ† ํฌ๋‚˜์ด์ €
1365
+ class DummyTok:
1366
+ eos_id = 2
1367
+ vocab_size = 100
1368
+ def encode(self, t, add_special_tokens=False):
1369
+ return [min(ord(c), 99) for c in t]
1370
+ def decode(self, ids):
1371
+ return "".join(chr(max(min(i, 122), 32)) for i in ids if i > 2)
1372
+
1373
+ tok = DummyTok()
1374
+
1375
+ # ๋”๋ฏธ ๋ฐ์ดํ„ฐ
1376
+ val_data = []
1377
+ for _ in range(30):
1378
+ ids = torch.randint(3, 100, (65,))
1379
+ val_data.append({"input_ids": ids[:64], "targets": ids[1:65]})
1380
+
1381
+ def collate(batch):
1382
+ return {
1383
+ "input_ids": torch.stack([b["input_ids"] for b in batch]),
1384
+ "targets": torch.stack([b["targets"] for b in batch]),
1385
+ }
1386
+
1387
+ val_dl = DataLoader(val_data, batch_size=4, collate_fn=collate)
1388
+
1389
+ # โ”€โ”€ 1. Perplexity ํ…Œ์ŠคํŠธ โ”€โ”€
1390
+ print("\n[ํ…Œ์ŠคํŠธ 1] Perplexity ์ธก์ •")
1391
+ ppl_eval = PerplexityEvaluator(EvalConfig(max_eval_batches=5))
1392
+ result = ppl_eval.evaluate(model, val_dl, device, torch.float32, desc="Test Eval")
1393
+ print(f" โ†’ Loss={result['loss']:.4f}, PPL={result['perplexity']:.2f}")
1394
+ expected_ppl = math.exp(math.log(100)) # vocab=100 โ†’ ์ดˆ๊ธฐ PPL โ‰ˆ 100
1395
+ print(f" โ†’ ์˜ˆ์ƒ ์ดˆ๊ธฐ PPL โ‰ˆ {expected_ppl:.0f} (vocab=100 ๋žœ๋ค)")
1396
+
1397
+ # โ”€โ”€ 2. ์ƒ์„ฑ ํ…Œ์ŠคํŠธ โ”€โ”€
1398
+ print("\n[ํ…Œ์ŠคํŠธ 2] ํ…์ŠคํŠธ ์ƒ์„ฑ")
1399
+ gen_eval = GenerationEvaluator(EvalConfig(max_new_tokens=30, num_samples=1))
1400
+ gen_results = gen_eval.generate_samples(
1401
+ model, tok, device, prompts=["Hello world"], verbose=True
1402
+ )
1403
+
1404
+ # โ”€โ”€ 3. Scaling ๋ถ„์„ ํ…Œ์ŠคํŠธ โ”€โ”€
1405
+ print("\n[ํ…Œ์ŠคํŠธ 3] Scaling Law ๋ถ„์„")
1406
+ analyzer = ScalingAnalyzer("./test_eval")
1407
+ dummy_scaling = [
1408
+ {"name": "10M", "params": 10e6, "tokens": 1e9, "loss": 4.2, "ppl": 66.7},
1409
+ {"name": "100M", "params": 100e6, "tokens": 5e9, "loss": 3.5, "ppl": 33.1},
1410
+ {"name": "1B", "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
1411
+ ]
1412
+ scaling_result = analyzer.analyze(dummy_scaling)
1413
+
1414
+ # โ”€โ”€ 4. ํ•™์Šต ์—ญํ•™ ๋ถ„์„ ํ…Œ์ŠคํŠธ โ”€โ”€
1415
+ print("\n[ํ…Œ์ŠคํŠธ 4] ํ•™์Šต ์—ญํ•™ ๋ถ„์„")
1416
+ import random
1417
+ random.seed(42)
1418
+
1419
+ dummy_history = {
1420
+ "step": list(range(0, 1000, 10)),
1421
+ "train_loss": [10.0 * (0.995 ** i) + random.gauss(0, 0.1) for i in range(100)],
1422
+ "learning_rate": [min(3e-4 * i / 20, 3e-4) * (0.5 + 0.5 * math.cos(math.pi * max(0, i-20)/80))
1423
+ for i in range(100)],
1424
+ "grad_norm": [min(random.gauss(0.5, 0.3), 1.0) for _ in range(100)],
1425
+ "tokens_per_sec": [50000 + random.gauss(0, 3000) for _ in range(100)],
1426
+ "val_loss": [8.0, 6.0, 4.5, 3.8, 3.5],
1427
+ "val_ppl": [2981, 403, 90, 44, 33],
1428
+ }
1429
+
1430
+ dynamics = TrainingDynamicsAnalyzer("./test_eval")
1431
+ dynamics.analyze_metrics(dummy_history)
1432
+
1433
+ # โ”€โ”€ 5. ์ฒดํฌ๋ฆฌ์ŠคํŠธ ํ…Œ์ŠคํŠธ โ”€โ”€
1434
+ print("\n[ํ…Œ์ŠคํŠธ 5] ์ธ์‚ฌ์ดํŠธ ์ฒดํฌ๋ฆฌ์ŠคํŠธ")
1435
+ dummy_report = {
1436
+ "perplexity": {"loss": 3.5, "perplexity": 33.1},
1437
+ "position_losses": {"early_avg": 4.5, "late_avg": 3.2},
1438
+ "generation": {"avg_metrics": {"repetition_rate": 0.15}},
1439
+ "training_dynamics": {"loss": {"initial": 10.0, "final": 3.5, "spikes": []}},
1440
+ }
1441
+ InsightChecklist.run_checklist(dummy_report, dummy_history)
1442
+
1443
+ # ์ •๋ฆฌ
1444
+ import shutil
1445
+ if os.path.exists("./test_eval"):
1446
+ shutil.rmtree("./test_eval")
1447
+
1448
+ print("\n" + "=" * 70)
1449
+ print("โœ… ํ‰๊ฐ€ ๋ชจ๋“ˆ ๊ฒ€์ฆ ์™„๋ฃŒ!")
1450
+ print()
1451
+ print("์‹ค์ œ ์‚ฌ์šฉ๋ฒ•:")
1452
+ print(" from evaluation import run_evaluation")
1453
+ print(" report = run_evaluation(model, tokenizer, val_dl,")
1454
+ print(" metrics_history=trainer.metrics.history)")
1455
+ print("=" * 70)
_archive/llm-1b-model.py ADDED
@@ -0,0 +1,791 @@
1
+ """
2
+ LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
3
+ ================================================================
4
+ ๋”ฅ๋Ÿฌ๋‹ ์ดˆ๋ณด์ž๋ฅผ ์œ„ํ•œ ํ•™์Šต์šฉ ๊ตฌํ˜„.
5
+ ๊ฐ ์ปดํฌ๋„ŒํŠธ์— ์ƒ์„ธ ์ฃผ์„์„ ๋‹ฌ์•„ "์™œ ์ด๋ ‡๊ฒŒ ํ•˜๋Š”์ง€"๋ฅผ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.
6
+
7
+ ์•„ํ‚คํ…์ฒ˜ ์š”์•ฝ:
8
+ - Decoder-Only Transformer (Causal LM)
9
+ - RMSNorm (Pre-Normalization)
10
+ - Rotary Positional Embedding (RoPE)
11
+ - Grouped Query Attention (GQA)
12
+ - SwiGLU Feed-Forward Network
13
+ - Weight Tying (Embedding โ†” Output Head)
14
+ """
15
+
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+
25
+ # ============================================================================
26
+ # 1. ๋ชจ๋ธ ์„ค์ • (Config)
27
+ # ============================================================================
28
+
29
+ @dataclass
30
+ class ModelConfig:
31
+ """๋ชจ๋ธ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ํ•˜๋‚˜์˜ ๋ฐ์ดํ„ฐํด๋ž˜์Šค๋กœ ๊ด€๋ฆฌํ•ฉ๋‹ˆ๋‹ค.
32
+
33
+ ๊ทœ๋ชจ๋ณ„ ํ”„๋ฆฌ์…‹:
34
+ - debug: ~10M (ํŒŒ์ดํ”„๋ผ์ธ ๊ฒ€์ฆ์šฉ)
35
+ - small: ~100M (์ค‘๊ฐ„ ๊ฒ€์ฆ์šฉ)
36
+ - base: ~1.1B (์ตœ์ข… ๋ชฉํ‘œ)
37
+ """
38
+ vocab_size: int = 32_000
39
+ hidden_dim: int = 2048 # d_model: ๋ชจ๋ธ์˜ ๊ธฐ๋ณธ ์ฐจ์›
40
+ num_layers: int = 22 # Transformer ๋ธ”๋ก ์ˆ˜
41
+ num_heads: int = 16 # Query ํ—ค๋“œ ์ˆ˜
42
+ num_kv_heads: int = 4 # Key/Value ํ—ค๋“œ ์ˆ˜ (GQA)
43
+ intermediate_dim: int = 5632 # FFN ์ค‘๊ฐ„ ์ฐจ์› (โ‰ˆ 2.75 ร— hidden_dim)
44
+ max_seq_len: int = 2048 # ์ตœ๋Œ€ ์‹œํ€€์Šค ๊ธธ์ด
45
+ dropout: float = 0.0 # Pretraining์—์„œ๋Š” ๋ณดํ†ต 0 ์‚ฌ์šฉ
46
+ rope_theta: float = 10000.0 # RoPE ์ฃผํŒŒ์ˆ˜ ๋ฒ ์ด์Šค
47
+ norm_eps: float = 1e-6 # RMSNorm epsilon
48
+
49
+ @property
50
+ def head_dim(self) -> int:
51
+ """๊ฐ ์–ดํ…์…˜ ํ—ค๋“œ์˜ ์ฐจ์›."""
52
+ return self.hidden_dim // self.num_heads
53
+
54
+ @property
55
+ def num_kv_groups(self) -> int:
56
+ """GQA์—์„œ ํ•˜๋‚˜์˜ KV ํ—ค๋“œ๊ฐ€ ๋‹ด๋‹นํ•˜๋Š” Q ํ—ค๋“œ ์ˆ˜."""
57
+ return self.num_heads // self.num_kv_heads
58
+
59
+ @classmethod
60
+ def debug_10m(cls) -> "ModelConfig":
61
+ """~10M ํŒŒ๋ผ๋ฏธํ„ฐ - ๋น ๋ฅธ ๋””๋ฒ„๊น…์šฉ."""
62
+ return cls(
63
+ hidden_dim=256, num_layers=6, num_heads=8,
64
+ num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
65
+ )
66
+
67
+ @classmethod
68
+ def small_100m(cls) -> "ModelConfig":
69
+ """~100M ํŒŒ๋ผ๋ฏธํ„ฐ - ์ค‘๊ฐ„ ๊ฒ€์ฆ์šฉ."""
70
+ return cls(
71
+ hidden_dim=768, num_layers=12, num_heads=12,
72
+ num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
73
+ )
74
+
75
+ @classmethod
76
+ def base_1b(cls) -> "ModelConfig":
77
+ """~1.1B ํŒŒ๋ผ๋ฏธํ„ฐ - ์ตœ์ข… ํ•™์Šต ๋ชฉํ‘œ."""
78
+ return cls() # ๊ธฐ๋ณธ๊ฐ’์ด 1B ์„ค์ •
79
+
80
+
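+ # Illustrative parameter-count estimate (our own back-of-the-envelope sketch;
+ # it may differ slightly from the real model's count). Assumes weight tying,
+ # a 3-matrix SwiGLU FFN, and bias-free linears, as in this architecture.
+ def estimate_params(cfg: "ModelConfig") -> int:
+     d = cfg.hidden_dim
+     kv_dim = cfg.num_kv_heads * cfg.head_dim
+     attn = d * d + 2 * d * kv_dim + d * d          # q_proj + k/v_proj + o_proj
+     ffn = 3 * d * cfg.intermediate_dim             # SwiGLU: gate, up, down
+     per_layer = attn + ffn + 2 * d                 # + two RMSNorm weights
+     embed = cfg.vocab_size * d                     # output head shares this (tying)
+     return cfg.num_layers * per_layer + embed + d  # + final norm
+
+ # estimate_params(ModelConfig.base_1b()) comes out to ~1.06e9, i.e. ~1.1B.
+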
81
+ # ============================================================================
82
+ # 2. RMSNorm (Root Mean Square Layer Normalization)
83
+ # ============================================================================
84
+
85
+ class RMSNorm(nn.Module):
86
+ """RMSNorm: LayerNorm์˜ ๊ฒฝ๋Ÿ‰ํ™” ๋ฒ„์ „.
87
+
88
+ ์ผ๋ฐ˜ LayerNorm๊ณผ์˜ ์ฐจ์ด:
89
+ - ํ‰๊ท (mean)์„ ๋นผ์ง€ ์•Š์Œ โ†’ ์—ฐ์‚ฐ ์ ˆ์•ฝ
90
+ - ๋ถ„์‚ฐ ๋Œ€์‹  RMS(Root Mean Square)๋กœ ์ •๊ทœํ™”
91
+ - bias ํŒŒ๋ผ๋ฏธํ„ฐ ์—†์Œ
92
+
93
+ ์ˆ˜์‹:
94
+ RMSNorm(x) = (x / RMS(x)) * ฮณ
95
+ RMS(x) = sqrt(mean(xยฒ) + ฮต)
96
+
97
+ ์™œ ์ •๊ทœํ™”๊ฐ€ ํ•„์š”ํ•œ๊ฐ€?
98
+ โ†’ ๋ ˆ์ด์–ด๋ฅผ ๊นŠ๊ฒŒ ์Œ“์œผ๋ฉด ํ™œ์„ฑํ™” ๊ฐ’์˜ ์Šค์ผ€์ผ์ด ํญ๋ฐœํ•˜๊ฑฐ๋‚˜ ์†Œ๋ฉธํ•ฉ๋‹ˆ๋‹ค.
99
+ โ†’ ์ •๊ทœํ™”๋กœ ๊ฐ ๋ ˆ์ด์–ด์˜ ์ž…๋ ฅ์„ ์•ˆ์ •์ ์ธ ๋ฒ”์œ„๋กœ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.
100
+ """
101
+
102
+ def __init__(self, dim: int, eps: float = 1e-6):
103
+ super().__init__()
104
+ self.eps = eps
105
+ # ฮณ (gamma): ํ•™์Šต ๊ฐ€๋Šฅํ•œ ์Šค์ผ€์ผ ํŒŒ๋ผ๋ฏธํ„ฐ, 1๋กœ ์ดˆ๊ธฐํ™”
106
+ self.weight = nn.Parameter(torch.ones(dim))
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ # 1) ์ž…๋ ฅ์„ float32๋กœ ๋ณ€ํ™˜ (์ˆ˜์น˜ ์•ˆ์ •์„ฑ)
110
+ # bf16/fp16 ์ƒํƒœ์—์„œ ์ œ๊ณฑํ•ฉ์„ ๊ตฌํ•˜๋ฉด ์˜ค๋ฒ„ํ”Œ๋กœ์šฐ ์œ„ํ—˜
111
+ x_float = x.float()
112
+
113
+ # 2) 1/RMS ๊ณ„์‚ฐ: rsqrt(mean(x^2) + eps)
114
+ rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
115
+ # rsqrt = 1/sqrt(x) โ†’ ๋‚˜๋ˆ—์…ˆ ๋Œ€์‹  ๊ณฑ์…ˆ์œผ๋กœ ๋Œ€์ฒด (๋” ๋น ๋ฆ„)
116
+
117
+ # 3) ์ •๊ทœํ™” ํ›„ ์›๋ž˜ dtype์œผ๋กœ ๋ณต์›, ์Šค์ผ€์ผ ์ ์šฉ
118
+ return (x_float * rms).to(x.dtype) * self.weight
119
+
120
+
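+ # Illustrative sanity check (toy tensors, our own sketch): with the initial
+ # gamma = 1, the output of RMSNorm should itself have RMS ~ 1 per vector.
+ def _rmsnorm_sanity_check() -> None:
+     x = torch.randn(2, 8, 64) * 10.0               # deliberately large scale
+     y = RMSNorm(64)(x)
+     rms = y.pow(2).mean(dim=-1).sqrt()             # RMS over the feature dim
+     assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
+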
121
+ # ============================================================================
122
+ # 3. Rotary Positional Embedding (RoPE)
123
+ # ============================================================================
124
+
125
+ class RotaryPositionalEmbedding(nn.Module):
126
+ """RoPE: ํšŒ์ „ ํ–‰๋ ฌ์„ ์ด์šฉํ•œ ์ƒ๋Œ€ ์œ„์น˜ ์ธ์ฝ”๋”ฉ.
127
+
128
+ ํ•ต์‹ฌ ์•„์ด๋””์–ด:
129
+ - ๊ฐ ์ฐจ์› ์Œ(2i, 2i+1)์„ 2D ํ‰๋ฉด์˜ ์ขŒํ‘œ๋กœ ๋ณด๊ณ ,
130
+ ์œ„์น˜(position)์— ๋น„๋ก€ํ•œ ๊ฐ๋„๋งŒํผ ํšŒ์ „์‹œํ‚ต๋‹ˆ๋‹ค.
131
+ - ๋‘ ํ† ํฐ์˜ ์–ดํ…์…˜ ์Šค์ฝ”์–ด(QยทK)๋Š” ์ƒ๋Œ€ ๊ฑฐ๋ฆฌ์—๋งŒ ์˜์กดํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.
132
+
133
+ ์™œ RoPE์ธ๊ฐ€?
134
+ - ์ ˆ๋Œ€ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ: ๊ฐ ์œ„์น˜์— ๊ณ ์ • ๋ฒกํ„ฐ๋ฅผ ๋”ํ•จ โ†’ ๊ธธ์ด ์ผ๋ฐ˜ํ™” ์–ด๋ ค์›€
135
+ - ์ƒ๋Œ€ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ: ๊ตฌํ˜„ ๋ณต์žก, ์ถ”๊ฐ€ ํŒŒ๋ผ๋ฏธํ„ฐ ํ•„์š”
136
+ - RoPE: ํŒŒ๋ผ๋ฏธํ„ฐ ์—†์ด, ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ƒ๋Œ€ ์œ„์น˜ ์ •๋ณด ์ธ์ฝ”๋”ฉ
137
+
138
+ ์ˆ˜์‹:
139
+ ฮธ_i = theta^(-2i/d) (i = 0, 1, ..., d/2-1)
140
+ RoPE(x, pos) = x๋ฅผ ๊ฐ ์ฐจ์› ์Œ์—์„œ pos ร— ฮธ_i ๋งŒํผ ํšŒ์ „
141
+ """
142
+
143
+ def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
144
+ super().__init__()
145
+ self.dim = dim
146
+ self.max_seq_len = max_seq_len
147
+ self.theta = theta
148
+
149
+ # ์ฃผํŒŒ์ˆ˜ ๋ฒกํ„ฐ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ (ํ•™์Šต ๋ถˆํ•„์š” โ†’ buffer๋กœ ๋“ฑ๋ก)
150
+ # freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
151
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
152
+ self.register_buffer("freqs", freqs, persistent=False)
153
+
154
+ # (max_seq_len, dim/2) ํฌ๊ธฐ์˜ cos/sin ํ…Œ์ด๋ธ” ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ
155
+ self._build_cache(max_seq_len)
156
+
157
+ def _build_cache(self, seq_len: int):
158
+ """cos/sin ๊ฐ’์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐํ•˜์—ฌ ์บ์‹ฑํ•ฉ๋‹ˆ๋‹ค."""
159
+ t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
160
+ # outer product: (seq_len,) ร— (dim/2,) โ†’ (seq_len, dim/2)
161
+ angles = torch.outer(t, self.freqs)
162
+ self.register_buffer("cos_cached", angles.cos(), persistent=False)
163
+ self.register_buffer("sin_cached", angles.sin(), persistent=False)
164
+
165
+ def forward(
166
+ self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
167
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
168
+ """Q, K์— ํšŒ์ „ ๋ณ€ํ™˜์„ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
169
+
170
+ Args:
171
+ q: (batch, num_heads, seq_len, head_dim)
172
+ k: (batch, num_kv_heads, seq_len, head_dim)
173
+ position_offset: ์‹œํ€€์Šค ์‹œ์ž‘ ์œ„์น˜ ์˜คํ”„์…‹ (์ถ”๋ก  ์‹œ KV ์บ์‹œ ์‚ฌ์šฉ ์‹œ)
174
+
175
+ Returns:
176
+ ํšŒ์ „ ๋ณ€ํ™˜์ด ์ ์šฉ๋œ (q_rotated, k_rotated)
177
+ """
178
+ seq_len = q.shape[2]
179
+
180
+ # ํ•„์š” ์‹œ ์บ์‹œ ํ™•์žฅ
181
+ if position_offset + seq_len > self.cos_cached.shape[0]:
182
+ self._build_cache(position_offset + seq_len)
183
+
184
+ # ํ˜„์žฌ ์œ„์น˜์— ํ•ด๋‹นํ•˜๋Š” cos/sin ์Šฌ๋ผ์ด์Šค
185
+ cos = self.cos_cached[position_offset : position_offset + seq_len] # (seq_len, dim/2)
186
+ sin = self.sin_cached[position_offset : position_offset + seq_len]
187
+
188
+ q_rotated = self._apply_rotation(q, cos, sin)
189
+ k_rotated = self._apply_rotation(k, cos, sin)
190
+ return q_rotated, k_rotated
191
+
192
+ @staticmethod
193
+ def _apply_rotation(
194
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
195
+ ) -> torch.Tensor:
196
+ """ํšŒ์ „ ๋ณ€ํ™˜ ์ ์šฉ.
197
+
198
+ 2D ํšŒ์ „ ํ–‰๋ ฌ:
199
+ [cos ฮธ, -sin ฮธ] [x1] [x1ยทcos ฮธ - x2ยทsin ฮธ]
200
+ [sin ฮธ, cos ฮธ] [x2] = [x1ยทsin ฮธ + x2ยทcos ฮธ]
201
+
202
+ ์ด๋ฅผ ๋ฒกํ„ฐ ์—ฐ์‚ฐ์œผ๋กœ ํšจ์œจ์ ์œผ๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
203
+ """
204
+ # x: (batch, heads, seq_len, head_dim)
205
+ # ์ง์ˆ˜/ํ™€์ˆ˜ ์ธ๋ฑ์Šค๋ฅผ ๋ถ„๋ฆฌ: (x0, x1, x2, x3, ...) โ†’ (x0, x2, ...), (x1, x3, ...)
206
+ x_even = x[..., 0::2] # ์ง์ˆ˜ ์ธ๋ฑ์Šค
207
+ x_odd = x[..., 1::2] # ํ™€์ˆ˜ ์ธ๋ฑ์Šค
208
+
209
+ # ๋ธŒ๋กœ๋“œ์บ์ŠคํŒ…์„ ์œ„ํ•ด ์ฐจ์› ๋งž์ถค: (seq_len, dim/2) โ†’ (1, 1, seq_len, dim/2)
210
+ cos = cos.unsqueeze(0).unsqueeze(0)
211
+ sin = sin.unsqueeze(0).unsqueeze(0)
212
+
213
+ # ํšŒ์ „ ์ ์šฉ
214
+ rotated_even = x_even * cos - x_odd * sin
215
+ rotated_odd = x_even * sin + x_odd * cos
216
+
217
+ # ๋‹ค์‹œ ์ธํ„ฐ๋ฆฌ๋น™: (even0, odd0, even1, odd1, ...)
218
+ out = torch.stack([rotated_even, rotated_odd], dim=-1)
219
+ return out.flatten(-2) # ๋งˆ์ง€๋ง‰ ๋‘ ์ฐจ์›์„ ํ•ฉ์ณ ์›๋ž˜ shape ๋ณต์›
220
+
221
+
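+ # Illustrative check of RoPE's key property (toy tensors, our own sketch):
+ # the q-k dot product depends only on relative distance, so the same pair
+ # placed at positions (0, 2) and at (5, 7) yields the same score.
+ def _rope_relative_position_check() -> None:
+     rope = RotaryPositionalEmbedding(dim=8, max_seq_len=16)
+     q = torch.randn(1, 1, 1, 8)                    # (B, heads, S=1, head_dim)
+     k = torch.randn(1, 1, 1, 8)
+     q0, _ = rope(q, k, position_offset=0)          # q rotated as position 0
+     _, k2 = rope(q, k, position_offset=2)          # k rotated as position 2
+     q5, _ = rope(q, k, position_offset=5)          # q rotated as position 5
+     _, k7 = rope(q, k, position_offset=7)          # k rotated as position 7
+     assert torch.allclose((q0 * k2).sum(), (q5 * k7).sum(), atol=1e-5)
+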
222
+ # ============================================================================
223
+ # 4. Grouped Query Attention (GQA)
224
+ # ============================================================================
225
+
226
+ class GroupedQueryAttention(nn.Module):
227
+ """GQA: Multi-Head Attention์˜ ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์  ๋ณ€ํ˜•.
228
+
229
+ MHA vs GQA vs MQA:
230
+ - MHA (Multi-Head Attention): Q, K, V ๋ชจ๋‘ num_heads๊ฐœ โ†’ ๋ฉ”๋ชจ๋ฆฌ ํผ
231
+ - MQA (Multi-Query Attention): K, V๋Š” 1๊ฐœ ํ—ค๋“œ ๊ณต์œ  โ†’ ํ’ˆ์งˆ ์ €ํ•˜ ์šฐ๋ ค
232
+ - GQA (Grouped Query Attention): K, V๋ฅผ num_kv_heads๊ฐœ๋กœ ๊ทธ๋ฃนํ™”
233
+ โ†’ MHA์™€ MQA์˜ ์ค‘๊ฐ„, ์ข‹์€ ํ’ˆ์งˆ-ํšจ์œจ ๊ท ํ˜•
234
+
235
+ ์˜ˆ์‹œ (num_heads=16, num_kv_heads=4):
236
+ Q ํ—ค๋“œ: [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15]
237
+ K/V ๊ทธ๋ฃน: [ 0 , 1 , 2 , 3 ]
238
+ โ†’ Q ํ—ค๋“œ 4๊ฐœ๊ฐ€ K/V ํ—ค๋“œ 1๊ฐœ๋ฅผ ๊ณต์œ 
239
+
240
+ Attention ์ˆ˜์‹:
241
+ Attention(Q, K, V) = softmax(QยทK^T / โˆšd_k) ยท V
242
+ """
243
+
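+     # Back-of-the-envelope KV-cache arithmetic (illustrative, assumed values):
+     #   cache bytes ~ 2 (K and V) x num_kv_heads x head_dim x seq_len
+     #                 x num_layers x bytes_per_elem
+     #   For base_1b in bf16 at seq_len=2048:
+     #     MHA-style (16 KV heads): 2 x 16 x 128 x 2048 x 22 x 2 ~ 369 MB / sequence
+     #     GQA (4 KV heads): 2 x 4 x 128 x 2048 x 22 x 2 ~ 92 MB / sequence
+     #   i.e. a num_heads / num_kv_heads = 4x reduction.
+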
244
+ def __init__(self, config: ModelConfig):
245
+ super().__init__()
246
+ self.config = config
247
+ self.head_dim = config.head_dim
248
+ self.num_heads = config.num_heads
249
+ self.num_kv_heads = config.num_kv_heads
250
+ self.num_kv_groups = config.num_kv_groups # num_heads // num_kv_heads
251
+
252
+ # Q/K/V ํ”„๋กœ์ ์…˜
253
+ # Q: hidden_dim โ†’ num_heads ร— head_dim
254
+ self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
255
+ # K, V: hidden_dim โ†’ num_kv_heads ร— head_dim (Q๋ณด๋‹ค ์ž‘์Œ!)
256
+ self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
257
+ self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
258
+
259
+ # ์ถœ๋ ฅ ํ”„๋กœ์ ์…˜: ๋ชจ๋“  ํ—ค๋“œ์˜ ์ถœ๋ ฅ์„ ๋‹ค์‹œ hidden_dim์œผ๋กœ
260
+ self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)
261
+
262
+ # RoPE
263
+ self.rope = RotaryPositionalEmbedding(
264
+ dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
265
+ )
266
+
267
+ # Attention dropout (usually 0 for pretraining) is applied inside
268
+ # F.scaled_dot_product_attention via dropout_p in forward(), so no separate nn.Dropout module is needed
269
+
270
+ def forward(
271
+ self,
272
+ x: torch.Tensor,
273
+ mask: Optional[torch.Tensor] = None,
274
+ position_offset: int = 0,
275
+ ) -> torch.Tensor:
276
+ """
277
+ Args:
278
+ x: (batch_size, seq_len, hidden_dim)
279
+ mask: (seq_len, seq_len) causal mask
280
+ position_offset: ์œ„์น˜ ์˜คํ”„์…‹ (์ถ”๋ก  ์‹œ ์‚ฌ์šฉ)
281
+
282
+ Returns:
283
+ (batch_size, seq_len, hidden_dim)
284
+ """
285
+ B, S, _ = x.shape
286
+
287
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
288
+ # Step 1: Q, K, V ํ”„๋กœ์ ์…˜
289
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
290
+ q = self.q_proj(x) # (B, S, num_heads ร— head_dim)
291
+ k = self.k_proj(x) # (B, S, num_kv_heads ร— head_dim)
292
+ v = self.v_proj(x) # (B, S, num_kv_heads ร— head_dim)
293
+
294
+ # ๋ฉ€ํ‹ฐํ—ค๋“œ ํ˜•ํƒœ๋กœ reshape
295
+ q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
296
+ # โ†’ (B, num_heads, S, head_dim)
297
+ k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
298
+ # โ†’ (B, num_kv_heads, S, head_dim)
299
+ v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
300
+
301
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
302
+ # Step 2: RoPE ์ ์šฉ (Q, K์—๋งŒ! V์—๋Š” ์ ์šฉํ•˜์ง€ ์•Š์Œ)
303
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
304
+ # ์œ„์น˜ ์ •๋ณด๋Š” "์–ด๋””๋ฅผ ๋ณผ์ง€"(QยทK)์—๋งŒ ์˜ํ–ฅ์„ ์ค˜์•ผ ํ•˜๊ณ ,
305
+ # "๋ฌด์—‡์„ ๊ฐ€์ ธ์˜ฌ์ง€"(V)์—๋Š” ์˜ํ–ฅ์„ ์ฃผ๋ฉด ์•ˆ ๋ฉ๋‹ˆ๋‹ค.
306
+ q, k = self.rope(q, k, position_offset)
307
+
308
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
309
+ # Step 3: GQA - KV ํ—ค๋“œ ํ™•์žฅ (repeat)
310
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
311
+ # num_kv_heads=4 โ†’ num_heads=16: ๊ฐ KV๋ฅผ 4๋ฒˆ ๋ฐ˜๋ณต
312
+ if self.num_kv_groups > 1:
313
+ k = self._repeat_kv(k) # (B, num_heads, S, head_dim)
314
+ v = self._repeat_kv(v)
315
+
316
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
317
+ # Step 4: Scaled Dot-Product Attention
318
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
319
+ # PyTorch >= 2.0์˜ ์ตœ์ ํ™”๋œ ๊ตฌํ˜„ ์‚ฌ์šฉ (Flash Attention ์ž๋™ ์ ์šฉ)
320
+ attn_out = F.scaled_dot_product_attention(
321
+ q, k, v,
322
+ attn_mask=mask,
323
+ dropout_p=self.config.dropout if self.training else 0.0,
324
+ is_causal=(mask is None), # mask๊ฐ€ ์—†์œผ๋ฉด ์ž๋™ causal masking
325
+ )
326
+ # โ†’ (B, num_heads, S, head_dim)
327
+
328
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
329
+ # Step 5: ํ—ค๋“œ ํ•ฉ์น˜๊ธฐ + ์ถœ๋ ฅ ํ”„๋กœ์ ์…˜
330
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
331
+ attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
332
+ # โ†’ (B, S, num_heads ร— head_dim)
333
+
334
+ return self.o_proj(attn_out) # โ†’ (B, S, hidden_dim)
335
+
336
+ def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
337
+ """KV ํ—ค๋“œ๋ฅผ Q ํ—ค๋“œ ์ˆ˜์— ๋งž๊ฒŒ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
338
+
339
+ (B, num_kv_heads, S, head_dim) โ†’ (B, num_heads, S, head_dim)
340
+
341
+ ์˜ˆ: num_kv_heads=4, num_kv_groups=4
342
+ [kv0, kv1, kv2, kv3] โ†’ [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
343
+ """
344
+ B, H_kv, S, D = x.shape
345
+ x = x[:, :, None, :, :] # (B, H_kv, 1, S, D)
346
+ x = x.expand(B, H_kv, self.num_kv_groups, S, D) # (B, H_kv, groups, S, D)
347
+ return x.reshape(B, self.num_heads, S, D)
348
+
349
+
350
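To put numbers on the MHA/MQA/GQA trade-off sketched in the docstring, here is a back-of-the-envelope KV-cache comparison. The helper is hypothetical, and the dimensions (22 layers, head_dim 128, bf16) are assumptions matching this project's rough 1B shape:

```python
def kv_cache_gb(num_kv_heads: int, num_layers: int = 22, batch: int = 4,
                seq_len: int = 2048, head_dim: int = 128, dtype_bytes: int = 2) -> float:
    # bytes = 2 (K and V) x batch x kv_heads x seq_len x head_dim x layers x bytes/elem
    return 2 * batch * num_kv_heads * seq_len * head_dim * num_layers * dtype_bytes / 1e9

print(f"MHA-style cache (16 KV heads): {kv_cache_gb(16):.2f} GB")  # 1.48 GB
print(f"GQA cache       ( 4 KV heads): {kv_cache_gb(4):.2f} GB")   # 0.37 GB (4x smaller)
```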
+ # ============================================================================
351
+ # 5. SwiGLU Feed-Forward Network
352
+ # ============================================================================
353
+
354
+ class SwiGLUFeedForward(nn.Module):
355
+ """SwiGLU: Gated Linear Unit with Swish ํ™œ์„ฑํ™” ํ•จ์ˆ˜.
356
+
357
+ ๊ธฐ์กด FFN:
358
+ FFN(x) = ReLU(xยทW1 + b1)ยทW2 + b2
359
+ โ†’ ๋‹จ์ˆœํ•œ ๋น„์„ ํ˜• ๋ณ€ํ™˜
360
+
361
+ SwiGLU FFN:
362
+ SwiGLU(x) = (Swish(xยทW_gate) โŠ™ (xยทW_up)) ยท W_down
363
+ โ†’ ๊ฒŒ์ดํŒ… ๋ฉ”์ปค๋‹ˆ์ฆ˜์œผ๋กœ ์ •๋ณด ํ๋ฆ„์„ ์ œ์–ด
364
+
365
+ ์™œ SwiGLU๊ฐ€ ๋” ์ข‹์€๊ฐ€?
366
+ - Swish(x) = x ยท sigmoid(x): ๋ถ€๋“œ๋Ÿฌ์šด ํ™œ์„ฑํ™”, ์Œ์ˆ˜ ์˜์—ญ ์ผ๋ถ€ ํ—ˆ์šฉ
367
+ - Gate ๋ฒกํ„ฐ๊ฐ€ "์–ด๋–ค ์ •๋ณด๋ฅผ ํ†ต๊ณผ์‹œํ‚ฌ์ง€" ํ•™์Šต
368
+ - PaLM, LLaMA ๋“ฑ์—์„œ ReLU FFN ๋Œ€๋น„ ์ผ๊ด€๋œ ์„ฑ๋Šฅ ํ–ฅ์ƒ ๋ณด๊ณ 
369
+
370
+ ์ฐธ๊ณ : W_gate์™€ W_up ๋‘ ๊ฐœ์˜ up-projection์ด ์žˆ์–ด์„œ
371
+ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๊ฐ€ ๊ธฐ์กด FFN ๋Œ€๋น„ 1.5๋ฐฐ์ด์ง€๋งŒ, intermediate_dim์„
372
+ ์กฐ์ •ํ•˜์—ฌ ์ด ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ๋งž์ถฅ๋‹ˆ๋‹ค.
373
+ """
374
+
375
+ def __init__(self, config: ModelConfig):
376
+ super().__init__()
377
+ # ๊ฒŒ์ดํŠธ ํ”„๋กœ์ ์…˜: hidden_dim โ†’ intermediate_dim
378
+ self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
379
+ # ์—… ํ”„๋กœ์ ์…˜: hidden_dim โ†’ intermediate_dim
380
+ self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
381
+ # ๋‹ค์šด ํ”„๋กœ์ ์…˜: intermediate_dim โ†’ hidden_dim
382
+ self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
383
+
384
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
385
+ # SwiGLU(x) = (Swish(gate(x)) โŠ™ up(x)) ยท down
386
+ #
387
+ # 1) gate: ์–ด๋–ค ์ •๋ณด๋ฅผ ํ†ต๊ณผ์‹œํ‚ฌ์ง€ ๊ฒฐ์ • (Swish ํ™œ์„ฑํ™”)
388
+ gate = F.silu(self.gate_proj(x)) # silu = Swish = x * sigmoid(x)
389
+ # 2) up: ์ •๋ณด๋ฅผ ๊ณ ์ฐจ์›์œผ๋กœ ์‚ฌ์˜
390
+ up = self.up_proj(x)
391
+ # 3) element-wise ๊ณฑ (๊ฒŒ์ดํŒ…) โ†’ ๋‹ค์‹œ ์›๋ž˜ ์ฐจ์›์œผ๋กœ
392
+ return self.down_proj(gate * up)
393
+
394
+
395
+ # ============================================================================
396
+ # 6. Transformer Block (ํ•˜๋‚˜์˜ ๋ ˆ์ด์–ด)
397
+ # ============================================================================
398
+
399
+ class TransformerBlock(nn.Module):
400
+ """ํ•˜๋‚˜์˜ Transformer ๋””์ฝ”๋” ๋ธ”๋ก.
401
+
402
+ ๊ตฌ์กฐ (Pre-Norm ๋ฐฉ์‹):
403
+ x โ†’ RMSNorm โ†’ Attention โ†’ + (residual) โ†’ RMSNorm โ†’ FFN โ†’ + (residual) โ†’ out
404
+
405
+ Pre-Norm vs Post-Norm:
406
+ - Post-Norm (์›๋ž˜ Transformer): LayerNorm์ด residual ์ดํ›„
407
+ โ†’ ๊นŠ์€ ๋ชจ๋ธ์—์„œ ํ•™์Šต ๋ถˆ์•ˆ์ •
408
+ - Pre-Norm (GPT-2 ์ดํ›„ ํ‘œ์ค€): LayerNorm์ด sublayer ์ด์ „
409
+ โ†’ gradient ํ๋ฆ„์ด ์›ํ™œ, ํ•™์Šต์ด ์•ˆ์ •์ 
410
+
411
+ Residual Connection์˜ ์—ญํ• :
412
+ - ์ž…๋ ฅ์„ ์ถœ๋ ฅ์— ๋”ํ•จ โ†’ gradient๊ฐ€ ๋ ˆ์ด์–ด๋ฅผ ๊ฑด๋„ˆ๋›ธ ์ˆ˜ ์žˆ๋Š” "๊ณ ์†๋„๋กœ"
413
+ - 22๊ฐœ ๋ ˆ์ด์–ด๋ฅผ ์Œ“์•„๋„ ํ•™์Šต์ด ๊ฐ€๋Šฅํ•œ ํ•ต์‹ฌ ์ด์œ 
414
+ """
415
+
416
+ def __init__(self, config: ModelConfig, layer_idx: int):
417
+ super().__init__()
418
+ self.layer_idx = layer_idx
419
+
420
+ # Pre-Norm: Attention ์ „ ์ •๊ทœํ™”
421
+ self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
422
+ # Self-Attention
423
+ self.attention = GroupedQueryAttention(config)
424
+
425
+ # Pre-Norm: FFN ์ „ ์ •๊ทœํ™”
426
+ self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
427
+ # Feed-Forward Network
428
+ self.feed_forward = SwiGLUFeedForward(config)
429
+
430
+ def forward(
431
+ self,
432
+ x: torch.Tensor,
433
+ mask: Optional[torch.Tensor] = None,
434
+ position_offset: int = 0,
435
+ ) -> torch.Tensor:
436
+ """
437
+ Args:
438
+ x: (batch_size, seq_len, hidden_dim)
439
+ Returns:
440
+ (batch_size, seq_len, hidden_dim)
441
+ """
442
+ # โ”€โ”€ Attention sublayer with residual โ”€โ”€
443
+ # h = x + Attention(RMSNorm(x))
444
+ h = x + self.attention(self.attn_norm(x), mask, position_offset)
445
+
446
+ # โ”€โ”€ FFN sublayer with residual โ”€โ”€
447
+ # out = h + FFN(RMSNorm(h))
448
+ out = h + self.feed_forward(self.ffn_norm(h))
449
+
450
+ return out
451
+
452
+
453
+ # ============================================================================
454
+ # 7. Full Transformer Model (LLaMA-style)
455
+ # ============================================================================
456
+
457
+ class LLMModel(nn.Module):
458
+ """1B ํŒŒ๋ผ๋ฏธํ„ฐ LLaMA-style Decoder-Only Transformer.
459
+
460
+ ์ „์ฒด ๊ตฌ์กฐ:
461
+ Input Token IDs
462
+ โ†’ Token Embedding
463
+ โ†’ [TransformerBlock] ร— num_layers (+ Activation Checkpointing)
464
+ โ†’ RMSNorm (์ตœ์ข…)
465
+ โ†’ Linear Head (โ†’ vocab logits)
466
+
467
+ Weight Tying:
468
+ - The input Embedding and the output Linear head share one weight matrix
469
+ - Saves parameters (~65M) with equal or better quality
470
+ - Intuition: "representing a word's meaning" and "predicting a word" live in the same space
471
+ """
472
+
473
+ def __init__(self, config: ModelConfig):
474
+ super().__init__()
475
+ self.config = config
476
+
477
+ # โ”€โ”€ Token Embedding โ”€โ”€
478
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
479
+
480
+ # โ”€โ”€ Transformer Blocks โ”€โ”€
481
+ self.layers = nn.ModuleList([
482
+ TransformerBlock(config, layer_idx=i)
483
+ for i in range(config.num_layers)
484
+ ])
485
+
486
+ # โ”€โ”€ ์ตœ์ข… ์ •๊ทœํ™” โ”€โ”€
487
+ self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
488
+
489
+ # โ”€โ”€ ์ถœ๋ ฅ ํ—ค๋“œ (Weight Tying) โ”€โ”€
490
+ self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
491
+ # Weight Tying: lm_head์˜ ๊ฐ€์ค‘์น˜ = token_embedding์˜ ๊ฐ€์ค‘์น˜
492
+ self.lm_head.weight = self.token_embedding.weight
493
+
494
+ # ๊ฐ€์ค‘์น˜ ์ดˆ๊ธฐํ™”
495
+ self._init_weights()
496
+
497
+ def _init_weights(self):
498
+ """๊ฐ€์ค‘์น˜ ์ดˆ๊ธฐํ™” ์ „๋žต.
499
+
500
+ ์™œ ์ดˆ๊ธฐํ™”๊ฐ€ ์ค‘์š”ํ•œ๊ฐ€?
501
+ - ๋„ˆ๋ฌด ํฌ๋ฉด: ํ™œ์„ฑํ™” ํญ๋ฐœ โ†’ NaN
502
+ - ๋„ˆ๋ฌด ์ž‘์œผ๋ฉด: gradient ์†Œ๋ฉธ โ†’ ํ•™์Šต ์ •์ฒด
503
+ - ์ ์ ˆํ•œ ์ดˆ๊ธฐํ™”: ๊ฐ ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ ๋ถ„์‚ฐ์„ ์ผ์ •ํ•˜๊ฒŒ ์œ ์ง€
504
+
505
+ GPT-2 ์Šคํƒ€์ผ ์ดˆ๊ธฐํ™”:
506
+ - ์ผ๋ฐ˜ Linear: N(0, 0.02)
507
+ - Residual projection: N(0, 0.02 / โˆš(2 ร— num_layers))
508
+ โ†’ ๋ ˆ์ด์–ด๊ฐ€ ๊นŠ์–ด์งˆ์ˆ˜๋ก residual ๊ธฐ์—ฌ๋ฅผ ์ค„์—ฌ ์•ˆ์ •ํ™”
509
+ """
510
+ std = 0.02
511
+ residual_std = std / math.sqrt(2 * self.config.num_layers)
512
+
513
+ for module in self.modules():
514
+ if isinstance(module, nn.Linear):
515
+ nn.init.normal_(module.weight, mean=0.0, std=std)
516
+ if module.bias is not None:
517
+ nn.init.zeros_(module.bias)
518
+ elif isinstance(module, nn.Embedding):
519
+ nn.init.normal_(module.weight, mean=0.0, std=std)
520
+
521
+ # Residual projection ๋ ˆ์ด์–ด์— ์ถ•์†Œ๋œ ์ดˆ๊ธฐํ™” ์ ์šฉ
522
+ for layer in self.layers:
523
+ nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
524
+ nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
525
+
526
+ def forward(
527
+ self,
528
+ input_ids: torch.Tensor,
529
+ targets: Optional[torch.Tensor] = None,
530
+ position_offset: int = 0,
531
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
532
+ """
533
+ Args:
534
+ input_ids: (batch_size, seq_len) - ํ† ํฐ ID
535
+ targets: (batch_size, seq_len) - ์ •๋‹ต ํ† ํฐ ID (ํ•™์Šต ์‹œ)
536
+ position_offset: ์œ„์น˜ ์˜คํ”„์…‹ (์ถ”๋ก  ์‹œ)
537
+
538
+ Returns:
539
+ logits: (batch_size, seq_len, vocab_size)
540
+ loss: ์Šค์นผ๋ผ (targets ์ œ๊ณต ์‹œ) ๋˜๋Š” None
541
+ """
542
+ B, S = input_ids.shape
543
+
544
+ # โ”€โ”€ Step 1: Token Embedding โ”€โ”€
545
+ # ๊ฐ ํ† ํฐ ID๋ฅผ hidden_dim ์ฐจ์›์˜ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜
546
+ h = self.token_embedding(input_ids) # (B, S, hidden_dim)
547
+
548
+ # โ”€โ”€ Step 2: Transformer Blocks โ”€โ”€
549
+ # Activation Checkpointing: ํ•™์Šต ์‹œ ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ
550
+ # (์ค‘๊ฐ„ ํ™œ์„ฑํ™”๋ฅผ ์ €์žฅํ•˜์ง€ ์•Š๊ณ , backward ์‹œ ์žฌ๊ณ„์‚ฐ)
551
+ for layer in self.layers:
552
+ if self.training and torch.is_grad_enabled():
553
+ # Activation Checkpointing ์ ์šฉ
554
+ h = torch.utils.checkpoint.checkpoint(
555
+ layer, h, None, position_offset,
556
+ use_reentrant=False, # PyTorch >= 2.0 ๊ถŒ์žฅ
557
+ )
558
+ else:
559
+ h = layer(h, mask=None, position_offset=position_offset)
560
+
561
+ # โ”€โ”€ Step 3: ์ตœ์ข… ์ •๊ทœํ™” โ”€โ”€
562
+ h = self.final_norm(h)
563
+
564
+ # โ”€โ”€ Step 4: ์ถœ๋ ฅ ๋กœ์ง“ ๊ณ„์‚ฐ โ”€โ”€
565
+ logits = self.lm_head(h) # (B, S, vocab_size)
566
+
567
+ # โ”€โ”€ Step 5: Loss ๊ณ„์‚ฐ (ํ•™์Šต ์‹œ) โ”€โ”€
568
+ loss = None
569
+ if targets is not None:
570
+ # Cross-Entropy Loss: ๋‹ค์Œ ํ† ํฐ ์˜ˆ์ธก
571
+ # logits: (B, S, V) โ†’ (B*S, V)
572
+ # targets: (B, S) โ†’ (B*S,)
573
+ loss = F.cross_entropy(
574
+ logits.view(-1, self.config.vocab_size),
575
+ targets.view(-1),
576
+ ignore_index=-100, # ํŒจ๋”ฉ ํ† ํฐ ๋ฌด์‹œ
577
+ )
578
+
579
+ return logits, loss
580
+
581
+ def count_parameters(self, trainable_only: bool = True) -> int:
582
+ """๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ ๊ณ„์‚ฐ."""
583
+ if trainable_only:
584
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
585
+ return sum(p.numel() for p in self.parameters())
586
+
587
+ @torch.no_grad()
588
+ def generate(
589
+ self,
590
+ input_ids: torch.Tensor,
591
+ max_new_tokens: int = 100,
592
+ temperature: float = 1.0,
593
+ top_k: int = 50,
594
+ top_p: float = 0.9,
595
+ ) -> torch.Tensor:
596
+ """ํ…์ŠคํŠธ ์ƒ์„ฑ (์ถ”๋ก ).
597
+
598
+ Autoregressive ์ƒ์„ฑ: ํ•œ ํ† ํฐ์”ฉ ์˜ˆ์ธกํ•˜์—ฌ ์ด์–ด๋ถ™์ด๊ธฐ.
599
+
600
+ Args:
601
+ input_ids: (1, prompt_len) - ์ดˆ๊ธฐ ํ”„๋กฌํ”„ํŠธ
602
+ max_new_tokens: ์ƒ์„ฑํ•  ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜
603
+ temperature: ํ™•๋ฅ  ๋ถ„ํฌ ๋‚ ์นด๋กœ์›€ ์กฐ์ ˆ (๋‚ฎ์„์ˆ˜๋ก ๋ณด์ˆ˜์ )
604
+ top_k: ํ™•๋ฅ  ์ƒ์œ„ k๊ฐœ๋งŒ ๊ณ ๋ ค
605
+ top_p: ๋ˆ„์  ํ™•๋ฅ  p๊นŒ์ง€๋งŒ ๊ณ ๋ ค (nucleus sampling)
606
+ """
607
+ self.eval()
608
+ generated = input_ids
609
+
610
+ for _ in range(max_new_tokens):
611
+ # ํ˜„์žฌ ์‹œํ€€์Šค๊ฐ€ max_seq_len์„ ์ดˆ๊ณผํ•˜๋ฉด ์ž˜๋ผ๋‚ด๊ธฐ
612
+ ctx = generated[:, -self.config.max_seq_len:]
613
+
614
+ # Forward pass
615
+ logits, _ = self(ctx)
616
+ # ๋งˆ์ง€๋ง‰ ํ† ํฐ์˜ logits๋งŒ ์‚ฌ์šฉ (๋‹ค์Œ ํ† ํฐ ์˜ˆ์ธก)
617
+ next_logits = logits[:, -1, :] / temperature
618
+
619
+ # โ”€โ”€ Top-K ํ•„ํ„ฐ๋ง โ”€โ”€
620
+ if top_k > 0:
621
+ top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
622
+ min_top_k = top_k_values[:, -1].unsqueeze(-1)
623
+ next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))
624
+
625
+ # โ”€โ”€ Top-P (Nucleus) ํ•„ํ„ฐ๋ง โ”€โ”€
626
+ if top_p < 1.0:
627
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
628
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
629
+ # ๋ˆ„์  ํ™•๋ฅ ์ด top_p๋ฅผ ์ดˆ๊ณผํ•˜๋Š” ํ† ํฐ ์ œ๊ฑฐ
630
+ remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
631
+ sorted_logits[remove_mask] = float("-inf")
632
+ # ์›๋ž˜ ์ˆœ์„œ๋กœ ๋ณต์›
633
+ next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
634
+
635
+ # ํ™•๋ฅ  ๋ถ„ํฌ์—์„œ ์ƒ˜ํ”Œ๋ง
636
+ probs = F.softmax(next_logits, dim=-1)
637
+ next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
638
+
639
+ # ์ƒ์„ฑ๋œ ํ† ํฐ ์ด์–ด๋ถ™์ด๊ธฐ
640
+ generated = torch.cat([generated, next_token], dim=1)
641
+
642
+ return generated
643
+
644
+
645
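Weight tying is easy to break silently (for example, by assigning a fresh tensor to `lm_head.weight` when loading weights by hand), so an identity check is worth knowing. A minimal sketch using the debug config:

```python
cfg = ModelConfig.debug_10m()
model = LLMModel(cfg)

# Tied weights are the *same* tensor object, not a copy:
assert model.lm_head.weight is model.token_embedding.weight
# ... which is also why count_parameters() counts the shared matrix once.
print(f"{model.count_parameters():,}")
```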
+ # ============================================================================
646
+ # 8. ์œ ํ‹ธ๋ฆฌํ‹ฐ ํ•จ์ˆ˜
647
+ # ============================================================================
648
+
649
+ def count_parameters_detailed(model: LLMModel) -> dict:
650
+ """๋ชจ๋ธ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ์ปดํฌ๋„ŒํŠธ๋ณ„๋กœ ์ƒ์„ธ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค."""
651
+ total = 0
652
+ breakdown = {}
653
+
654
+ # Embedding
655
+ emb_params = model.token_embedding.weight.numel()
656
+ breakdown["token_embedding"] = emb_params
657
+ total += emb_params
658
+
659
+ # ๊ฐ ๋ ˆ์ด์–ด
660
+ layer_total = 0
661
+ layer_detail = {}
662
+ layer = model.layers[0]
663
+
664
+ for name, param in layer.named_parameters():
665
+ layer_detail[name] = param.numel()
666
+ layer_total += param.numel()
667
+
668
+ breakdown["per_layer"] = layer_detail
669
+ breakdown["per_layer_total"] = layer_total
670
+ breakdown["all_layers_total"] = layer_total * len(model.layers)
671
+ total += layer_total * len(model.layers)
672
+
673
+ # Final norm
674
+ norm_params = model.final_norm.weight.numel()
675
+ breakdown["final_norm"] = norm_params
676
+ total += norm_params
677
+
678
+ # LM head (weight tying์ด๋ฏ€๋กœ ์‹ค์ œ ์ถ”๊ฐ€ ํŒŒ๋ผ๋ฏธํ„ฐ 0)
679
+ breakdown["lm_head"] = "weight tying (0 additional)"
680
+ breakdown["total"] = total
681
+
682
+ return breakdown
683
+
684
+
685
+ def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
686
+ """๋ชจ๋ธ์˜ GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ถ”์ •ํ•ฉ๋‹ˆ๋‹ค.
687
+
688
+ Args:
689
+ dtype_bytes: 2 (bf16/fp16) ๋˜๋Š” 4 (fp32)
690
+ """
691
+ # ๋Œ€๋žต์ ์ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ ๊ณ„์‚ฐ
692
+ emb = config.vocab_size * config.hidden_dim
693
+ per_layer = (
694
+ config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV
695
+ + config.num_heads * config.head_dim * config.hidden_dim # O proj
696
+ + 3 * config.hidden_dim * config.intermediate_dim # SwiGLU (gate + up + down)
697
+ + 2 * config.hidden_dim # 2 ร— RMSNorm
698
+ )
699
+ total_params = emb + per_layer * config.num_layers + config.hidden_dim
700
+
701
+ model_gb = total_params * dtype_bytes / 1e9
702
+ optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states ร— fp32
703
+ gradient_gb = total_params * dtype_bytes / 1e9
704
+
705
+ # ํ™œ์„ฑํ™” ๋ฉ”๋ชจ๋ฆฌ (activation checkpointing ์ ์šฉ ๊ฐ€์ •)
706
+ # ๋Œ€๋žต์  ์ถ”์ •: batch_size ร— seq_len ร— hidden_dim ร— num_layers ร— factor
707
+ activation_gb = (
708
+ batch_size * config.max_seq_len * config.hidden_dim * 4 # ๋ฐ”์ดํŠธ
709
+ * math.sqrt(config.num_layers) # checkpointing ํšจ๊ณผ
710
+ / 1e9
711
+ )
712
+
713
+ return {
714
+ "total_parameters": total_params,
715
+ "model_weights_gb": round(model_gb, 2),
716
+ "optimizer_states_gb": round(optimizer_gb, 2),
717
+ "gradients_gb": round(gradient_gb, 2),
718
+ "activations_estimated_gb": round(activation_gb, 2),
719
+ "total_estimated_gb": round(model_gb + optimizer_gb + gradient_gb + activation_gb, 2),
720
+ }
721
+
722
+
723
+ # ============================================================================
724
+ # 9. ๊ฒ€์ฆ ์Šคํฌ๋ฆฝํŠธ (์‹คํ–‰ ์‹œ)
725
+ # ============================================================================
726
+
727
+ if __name__ == "__main__":
728
+ print("=" * 70)
729
+ print("LLM-1B-Lab: ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ๊ฒ€์ฆ")
730
+ print("=" * 70)
731
+
732
+ # โ”€โ”€ ๋””๋ฒ„๊ทธ ๋ชจ๋ธ (10M) ํ…Œ์ŠคํŠธ โ”€โ”€
733
+ print("\n[1] Debug Model (~10M params)")
734
+ cfg_debug = ModelConfig.debug_10m()
735
+ model_debug = LLMModel(cfg_debug)
736
+ n_params = model_debug.count_parameters()
737
+ print(f" ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {n_params:,} ({n_params / 1e6:.1f}M)")
738
+
739
+ # Forward pass ํ…Œ์ŠคํŠธ
740
+ dummy_input = torch.randint(0, cfg_debug.vocab_size, (2, 64))
741
+ dummy_target = torch.randint(0, cfg_debug.vocab_size, (2, 64))
742
+ logits, loss = model_debug(dummy_input, dummy_target)
743
+ print(f" Input shape: {dummy_input.shape}")
744
+ print(f" Logits shape: {logits.shape}")
745
+ print(f" Loss: {loss.item():.4f}")
746
+ # ์ดˆ๊ธฐ loss โ‰ˆ ln(vocab_size) โ‰ˆ ln(32000) โ‰ˆ 10.37 ์ด๋ฉด ์ •์ƒ
747
+ expected_loss = math.log(cfg_debug.vocab_size)
748
+ print(f" Expected initial loss โ‰ˆ ln({cfg_debug.vocab_size}) = {expected_loss:.2f}")
749
+
750
+ # โ”€โ”€ 1B ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ ํ™•์ธ โ”€โ”€
751
+ print("\n[2] Base Model (~1B params) โ€” ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋งŒ ํ™•์ธ")
752
+ cfg_1b = ModelConfig.base_1b()
753
+
754
+ # ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ ๋ถ€์กฑํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ meta device์—์„œ ์ƒ์„ฑ
755
+ with torch.device("meta"):
756
+ model_1b = LLMModel(cfg_1b)
757
+ n_params_1b = model_1b.count_parameters()
758
+ print(f" ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {n_params_1b:,} ({n_params_1b / 1e6:.1f}M โ‰ˆ {n_params_1b / 1e9:.2f}B)")
759
+
760
+ # ์ƒ์„ธ ํŒŒ๋ผ๋ฏธํ„ฐ ๋ถ„ํ•ด
761
+ print("\n[3] ํŒŒ๋ผ๋ฏธํ„ฐ ์ƒ์„ธ ๋ถ„ํ•ด (1B)")
762
+ detail = count_parameters_detailed(model_1b)
763
+ print(f" Token Embedding: {detail['token_embedding']:,}")
764
+ print(f" Per Layer Total: {detail['per_layer_total']:,}")
765
+ print(f" All Layers ({cfg_1b.num_layers}): {detail['all_layers_total']:,}")
766
+ print(f" Final Norm: {detail['final_norm']:,}")
767
+ print(f" LM Head: {detail['lm_head']}")
768
+ print(f" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€")
769
+ print(f" TOTAL: {detail['total']:,}")
770
+
771
+ # ๋ฉ”๋ชจ๋ฆฌ ์ถ”์ •
772
+ print("\n[4] GPU ๋ฉ”๋ชจ๋ฆฌ ์ถ”์ • (A100 40GB, bf16, batch_size=4)")
773
+ mem = estimate_memory_gb(cfg_1b, batch_size=4, dtype_bytes=2)
774
+ print(f" ๋ชจ๋ธ ๊ฐ€์ค‘์น˜: {mem['model_weights_gb']} GB")
775
+ print(f" ์˜ตํ‹ฐ๋งˆ์ด์ €: {mem['optimizer_states_gb']} GB")
776
+ print(f" ๊ธฐ์šธ๊ธฐ: {mem['gradients_gb']} GB")
777
+ print(f" ํ™œ์„ฑํ™” (์ถ”์ •): {mem['activations_estimated_gb']} GB")
778
+ print(f" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€")
779
+ print(f" ์ด ์ถ”์ •: {mem['total_estimated_gb']} GB")
780
+
781
+ # ํ…์ŠคํŠธ ์ƒ์„ฑ ํ…Œ์ŠคํŠธ (๋””๋ฒ„๊ทธ ๋ชจ๋ธ)
782
+ print("\n[5] ํ…์ŠคํŠธ ์ƒ์„ฑ ํ…Œ์ŠคํŠธ (10M debug model, ๋žœ๋ค ๊ฐ€์ค‘์น˜)")
783
+ prompt = torch.randint(0, cfg_debug.vocab_size, (1, 10))
784
+ generated = model_debug.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)
785
+ print(f" Prompt length: {prompt.shape[1]}")
786
+ print(f" Generated length: {generated.shape[1]}")
787
+ print(f" Generated token IDs: {generated[0].tolist()}")
788
+
789
+ print("\n" + "=" * 70)
790
+ print("โœ… ๋ชจ๋“  ๊ฒ€์ฆ ํ†ต๊ณผ!")
791
+ print("=" * 70)
_archive/llm-1b-trainer.py ADDED
@@ -0,0 +1,1108 @@
1
+ """
2
+ LLM-1B-Lab: ํ•™์Šต ๋ฃจํ”„ (Training Loop)
3
+ ========================================
4
+ Gradient Accumulation, Mixed Precision, LR Scheduling,
5
+ ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ/๋ณต์›, wandb ๋กœ๊น…์„ ํฌํ•จํ•œ ์™„์ „ํ•œ ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ.
6
+
7
+ ์ „์ฒด ํ๋ฆ„:
8
+ ๋ฐฐ์น˜ ๊ฐ€์ ธ์˜ค๊ธฐ
9
+ โ†’ Forward (bf16 autocast)
10
+ โ†’ Loss / accumulation_steps (๋ฏธ๋‹ˆ๋ฐฐ์น˜ ํ‰๊ท )
11
+ โ†’ Backward (gradient ๋ˆ„์ )
12
+ โ†’ [accumulation_steps๋งˆ๋‹ค]
13
+ โ†’ Gradient Clipping
14
+ โ†’ Optimizer Step
15
+ โ†’ LR Scheduler Step
16
+ โ†’ Logging
17
+ โ†’ [checkpoint_interval๋งˆ๋‹ค]
18
+ โ†’ ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ (Google Drive)
19
+ โ†’ [eval_interval๋งˆ๋‹ค]
20
+ โ†’ ๊ฒ€์ฆ Loss/Perplexity ์ธก์ •
21
+
22
+ ์„ค์น˜ ํ•„์š”:
23
+ pip install wandb torch
24
+ """
25
+
26
+ import os
27
+ import math
28
+ import time
29
+ import json
30
+ import shutil
31
+ from pathlib import Path
32
+ from dataclasses import dataclass, field
33
+ from typing import Optional, Dict, Any, Tuple
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ from torch.utils.data import DataLoader
38
+
39
+
40
+ # ============================================================================
41
+ # 1. ํ•™์Šต ์„ค์ •
42
+ # ============================================================================
43
+
44
+ @dataclass
45
+ class TrainConfig:
46
+ """ํ•™์Šต ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ + ์ธํ”„๋ผ ์„ค์ •.
47
+
48
+ Colab Pro+ (A100 40GB) ๊ธฐ์ค€ ์ตœ์ ํ™”๋œ ๊ธฐ๋ณธ๊ฐ’.
49
+ ๋ชจ๋“  ๊ฐ’์— '์™œ ์ด ๊ฐ’์ธ์ง€' ์„ค๋ช…์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
50
+ """
51
+
52
+ # โ”€โ”€ ์ตœ์ ํ™” โ”€โ”€
53
+ learning_rate: float = 3e-4
54
+ """Peak LR. 1B ๋ชจ๋ธ ๊ธฐ์ค€ 3e-4๊ฐ€ ํ‘œ์ค€.
55
+ GPT-3 ๋…ผ๋ฌธ์—์„œ ๋ชจ๋ธ ํฌ๊ธฐ๋ณ„ ์ตœ์  LR์„ ์ œ์‹œ:
56
+ 125M โ†’ 6e-4, 350M โ†’ 3e-4, 1.3B โ†’ 2e-4
57
+ ์šฐ๋ฆฌ ๋ชจ๋ธ(1.1B)์€ 3e-4์—์„œ ์‹œ์ž‘, ๋ถˆ์•ˆ์ •ํ•˜๋ฉด 2e-4๋กœ ํ•˜ํ–ฅ."""
58
+
59
+ min_learning_rate: float = 3e-5
60
+ """Cosine decay ์ตœ์ €์ . ๋ณดํ†ต peak์˜ 10%.
61
+ ๋„ˆ๋ฌด ๋‚ฎ์œผ๋ฉด ํ•™์Šต ํ›„๋ฐ˜ ์ •์ฒด, ๋„ˆ๋ฌด ๋†’์œผ๋ฉด ์ˆ˜๋ ด ๋ถˆ์•ˆ์ •."""
62
+
63
+ weight_decay: float = 0.1
64
+ """AdamW์˜ L2 ์ •๊ทœํ™”. 0.1์ด LLM ํ‘œ์ค€.
65
+ Embedding๊ณผ Bias์—๋Š” ์ ์šฉํ•˜์ง€ ์•Š์Œ (๊ด€๋ก€)."""
66
+
67
+ beta1: float = 0.9
68
+ beta2: float = 0.95
69
+ """Adam ๋ชจ๋ฉ˜ํ…€ ๊ณ„์ˆ˜. ฮฒ2=0.95๋Š” LLM ํ•™์Šต์—์„œ ฮฒ2=0.999๋ณด๋‹ค ์•ˆ์ •์ .
70
+ ํฐ ๋ฐฐ์น˜ + ๊ธด ํ•™์Šต์—์„œ ฮฒ2๊ฐ€ ๋„ˆ๋ฌด ํฌ๋ฉด ์ ์‘ ์†๋„๊ฐ€ ๋А๋ฆผ."""
71
+
72
+ adam_eps: float = 1e-8
73
+ grad_clip: float = 1.0
74
+ """Gradient Clipping: gradient norm์ด 1.0์„ ์ดˆ๊ณผํ•˜๋ฉด ์Šค์ผ€์ผ๋ง.
75
+ ํ•™์Šต ์ดˆ๋ฐ˜์ด๋‚˜ ๋…ธ์ด์ฆˆ ๋ฐ์ดํ„ฐ์—์„œ ๋ฐœ์ƒํ•˜๋Š” gradient spike ๋ฐฉ์ง€."""
76
+
77
+ # โ”€โ”€ ์Šค์ผ€์ค„๋ง โ”€โ”€
78
+ warmup_steps: int = 2000
79
+ """Warmup: ์ฒ˜์Œ 2000 ์Šคํ… ๋™์•ˆ LR์„ 0 โ†’ peak๋กœ ์„ ํ˜• ์ฆ๊ฐ€.
80
+ ์™œ ํ•„์š”ํ•œ๊ฐ€?
81
+ - ์ดˆ๊ธฐ ๊ฐ€์ค‘์น˜๊ฐ€ ๋žœ๋ค โ†’ ํฐ LR์€ ๋ถˆ์•ˆ์ •ํ•œ ์—…๋ฐ์ดํŠธ ์œ ๋ฐœ
82
+ - ์ž‘์€ LR๋กœ ์‹œ์ž‘ํ•ด ๋ชจ๋ธ์ด '๋ฐฉํ–ฅ'์„ ์žก๊ฒŒ ํ•œ ํ›„ ๋ณธ๊ฒฉ ํ•™์Šต
83
+ - 2000์€ ์ „์ฒด ํ•™์Šต์˜ ~10%๊ฐ€ ์ ๋‹น (๊ฒฝํ—˜์  ๊ทœ์น™)."""
84
+
85
+ total_steps: int = 20_000
86
+ """์ด ํ•™์Šต ์Šคํ… ์ˆ˜.
87
+ 10B tokens / (128 batch ร— 2048 seq_len) โ‰ˆ 38,000 ์ด์ง€๋งŒ,
88
+ gradient accumulation ํฌํ•จ effective step ๊ธฐ์ค€ ~20,000."""
89
+
90
+ # โ”€โ”€ ๋ฐฐ์น˜ โ”€โ”€
91
+ micro_batch_size: int = 4
92
+ """GPU์— ํ•œ ๋ฒˆ์— ์˜ฌ๋ฆฌ๋Š” ๋ฐฐ์น˜ ํฌ๊ธฐ.
93
+ A100 40GB์—์„œ 1B ๋ชจ๋ธ bf16 ๊ธฐ์ค€ 4๊ฐ€ ์•ˆ์ „ํ•œ ์ƒํ•œ."""
94
+
95
+ gradient_accumulation_steps: int = 32
96
+ """Gradient ๋ˆ„์  ํšŸ์ˆ˜. Effective batch = 4 ร— 32 = 128.
97
+ ์™œ ํฐ ๋ฐฐ์น˜๊ฐ€ ์ข‹์€๊ฐ€?
98
+ - gradient ์ถ”์ •์ด ์•ˆ์ •์  (๋…ธ์ด์ฆˆ ๊ฐ์†Œ)
99
+ - LLM ํ•™์Šต์€ ๋ณดํ†ต effective batch 128~512
100
+ - ๋ฉ”๋ชจ๋ฆฌ ๋ถ€์กฑ ์‹œ ์ด ๊ฐ’์„ ๋Š˜๋ฆฌ๊ณ  micro_batch๋ฅผ ์ค„์ž„."""
101
+
102
+ # โ”€โ”€ Mixed Precision โ”€โ”€
103
+ dtype: str = "bfloat16"
104
+ """bfloat16: A100์—์„œ ์ง€์›, fp16๋ณด๋‹ค ์ˆ˜์น˜ ์•ˆ์ •์„ฑ ์šฐ์ˆ˜.
105
+ exponent ๋น„ํŠธ๊ฐ€ fp32์™€ ๋™์ผ โ†’ overflow/underflow ์œ„ํ—˜ ์ ์Œ.
106
+ T4/V100 ํด๋ฐฑ ์‹œ 'float16'์œผ๋กœ ๋ณ€๊ฒฝ."""
107
+
108
+ # โ”€โ”€ ์ฒดํฌํฌ์ธํŠธ โ”€โ”€
109
+ checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
110
+ """Google Drive ๊ฒฝ๋กœ. Colab ์„ธ์…˜ ๋งŒ๋ฃŒ ์‹œ์—๋„ ๋ณด์กด๋จ."""
111
+
112
+ checkpoint_interval: int = 500
113
+ """500 ์Šคํ…๋งˆ๋‹ค ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ.
114
+ A100 ๊ธฐ์ค€ ~30๋ถ„ ๊ฐ„๊ฒฉ. ๋„ˆ๋ฌด ์žฆ์œผ๋ฉด I/O ์˜ค๋ฒ„ํ—ค๋“œ,
115
+ ๋„ˆ๋ฌด ๋“œ๋ฌผ๋ฉด ์„ธ์…˜ ๋งŒ๋ฃŒ ์‹œ ์†์‹ค ํผ."""
116
+
117
+ max_checkpoints: int = 3
118
+ """๋กค๋ง ๋ณด๊ด€ ์ˆ˜. ์˜ค๋ž˜๋œ ๊ฒƒ๋ถ€ํ„ฐ ์‚ญ์ œ.
119
+ ์ฒดํฌํฌ์ธํŠธ 1๊ฐœ โ‰ˆ 8-10GB โ†’ 3๊ฐœ๋ฉด ~30GB."""
120
+
121
+ # โ”€โ”€ ๋กœ๊น… โ”€โ”€
122
+ log_interval: int = 10
123
+ """10 ์Šคํ…๋งˆ๋‹ค ์ฝ˜์†” + wandb ๋กœ๊น…."""
124
+
125
+ eval_interval: int = 500
126
+ """500 ์Šคํ…๋งˆ๋‹ค ๊ฒ€์ฆ Loss ์ธก์ •."""
127
+
128
+ eval_steps: int = 20
129
+ """๊ฒ€์ฆ ์‹œ ์‚ฌ์šฉํ•  ๋ฐฐ์น˜ ์ˆ˜. 20 ร— 4 ร— 2048 โ‰ˆ 160K ํ† ํฐ."""
130
+
131
+ # โ”€โ”€ wandb โ”€โ”€
132
+ wandb_project: str = "llm-1b-lab"
133
+ wandb_run_name: Optional[str] = None
134
+ use_wandb: bool = True
135
+
136
+ # โ”€โ”€ ์žฌํ˜„์„ฑ โ”€โ”€
137
+ seed: int = 42
138
+
139
+ @property
140
+ def effective_batch_size(self) -> int:
141
+ return self.micro_batch_size * self.gradient_accumulation_steps
142
+
143
+ @property
144
+ def tokens_per_step(self) -> int:
145
+ """ํ•œ optimizer step๋‹น ์ฒ˜๋ฆฌ ํ† ํฐ ์ˆ˜."""
146
+ # max_seq_len์€ ์™ธ๋ถ€์—์„œ ์ฃผ์ž… (ModelConfig ์ฐธ์กฐ)
147
+ return self.effective_batch_size * 2048
148
+
149
+ @property
150
+ def torch_dtype(self) -> torch.dtype:
151
+ return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[self.dtype]
152
+
153
+
154
+ # ============================================================================
155
+ # 2. ํ•™์Šต๋ฅ  ์Šค์ผ€์ค„๋Ÿฌ (Cosine with Warmup)
156
+ # ============================================================================
157
+
158
+ class CosineWarmupScheduler:
159
+ """Cosine Annealing with Linear Warmup.
160
+
161
+ LR ๊ณก์„ :
162
+ โ”Œโ”€โ”€โ”€ peak_lr โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฒ
163
+ โ”‚ โ•ฒ cosine decay
164
+ โ”‚ warmup (linear) โ•ฒ
165
+ โ”‚/ โ•ฒ_______ min_lr
166
+ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ steps
167
+
168
+ ์™œ Cosine Decay์ธ๊ฐ€?
169
+ - Step decay: ๊ฐ‘์ž‘์Šค๋Ÿฌ์šด LR ํ•˜๋ฝ โ†’ Loss ๋ถˆ์•ˆ์ •
170
+ - Linear decay: ํ›„๋ฐ˜๋ถ€ LR์ด ๋„ˆ๋ฌด ๋นจ๋ฆฌ ๊ฐ์†Œ
171
+ - Cosine: ๋ถ€๋“œ๋Ÿฌ์šด ๊ฐ์†Œ, ํ•™์Šต ํ›„๋ฐ˜์—๋„ ์ ์ ˆํ•œ LR ์œ ์ง€
172
+ - GPT-3, LLaMA, Chinchilla ๋“ฑ ๋Œ€๋ถ€๋ถ„์˜ LLM์ด ์‚ฌ์šฉ
173
+
174
+ ๊ตฌํ˜„ ์ฐธ๊ณ :
175
+ PyTorch ๋‚ด์žฅ ์Šค์ผ€์ค„๋Ÿฌ(CosineAnnealingLR ๋“ฑ)๋„ ์žˆ์ง€๋งŒ,
176
+ warmup + min_lr + ์ฒดํฌํฌ์ธํŠธ ๋ณต์›์„ ์œ„ํ•ด ์ง์ ‘ ๊ตฌํ˜„์ด ๋” ์œ ์—ฐํ•ฉ๋‹ˆ๋‹ค.
177
+ """
178
+
179
+ def __init__(self, config: TrainConfig):
180
+ self.peak_lr = config.learning_rate
181
+ self.min_lr = config.min_learning_rate
182
+ self.warmup_steps = config.warmup_steps
183
+ self.total_steps = config.total_steps
184
+
185
+ def get_lr(self, step: int) -> float:
186
+ """ํ˜„์žฌ step์— ํ•ด๋‹นํ•˜๋Š” ํ•™์Šต๋ฅ ์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
187
+
188
+ Args:
189
+ step: ํ˜„์žฌ optimizer step (0-indexed)
190
+
191
+ Returns:
192
+ ํ•™์Šต๋ฅ  (float)
193
+ """
194
+ # Phase 1: Linear Warmup
195
+ if step < self.warmup_steps:
196
+ # 0 โ†’ peak_lr ์„ ํ˜• ์ฆ๊ฐ€
197
+ return self.peak_lr * (step / self.warmup_steps)
198
+
199
+ # Phase 2: Cosine Decay
200
+ # warmup ์ดํ›„ ๋‚จ์€ ์ง„ํ–‰๋ฅ  (0.0 โ†’ 1.0)
201
+ decay_steps = self.total_steps - self.warmup_steps
202
+ progress = (step - self.warmup_steps) / max(decay_steps, 1)
203
+ progress = min(progress, 1.0) # ์•ˆ์ „์žฅ์น˜
204
+
205
+ # Cosine ๊ณต์‹: min_lr + 0.5 ร— (peak - min) ร— (1 + cos(ฯ€ ร— progress))
206
+ cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
207
+ lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
208
+
209
+ return lr
210
+
211
+ def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
212
+ """Optimizer์˜ ํ•™์Šต๋ฅ ์„ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค."""
213
+ lr = self.get_lr(step)
214
+ for param_group in optimizer.param_groups:
215
+ param_group["lr"] = lr
216
+ return lr
217
+
218
+
219
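Sampling the schedule at a few steps is the quickest way to confirm the warmup/decay shape drawn above; the expected values below follow from the formula with the default config:

```python
cfg = TrainConfig()
sched = CosineWarmupScheduler(cfg)
for step in (0, 1000, 2000, 10000, 20000):
    print(f"step {step:>6d}: lr = {sched.get_lr(step):.2e}")
# step      0: lr = 0.00e+00  (warmup starts at zero)
# step   1000: lr = 1.50e-04  (halfway through warmup)
# step   2000: lr = 3.00e-04  (peak)
# step  10000: lr = 1.88e-04  (cosine decay in progress)
# step  20000: lr = 3.00e-05  (floor = min_learning_rate)
```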
+ # ============================================================================
220
+ # 3. ์ฒดํฌํฌ์ธํŠธ ๊ด€๋ฆฌ
221
+ # ============================================================================
222
+
223
+ class CheckpointManager:
224
+ """ํ•™์Šต ์ƒํƒœ ์ €์žฅ/๋ณต์› ๊ด€๋ฆฌ์ž.
225
+
226
+ Colab์—์„œ ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์ค‘์š”ํ•œ ์ด์œ :
227
+ - ์„ธ์…˜ ๋งŒ๋ฃŒ (์ตœ๋Œ€ ~24์‹œ๊ฐ„) ์‹œ ๋ชจ๋“  ๋ฉ”๋ชจ๋ฆฌ ์ƒํƒœ ์†Œ๋ฉธ
228
+ - Google Drive์— ์ €์žฅํ•˜๋ฉด ์„ธ์…˜ ๊ฐ„ ์—ฐ์† ํ•™์Šต ๊ฐ€๋Šฅ
229
+ - ์˜ตํ‹ฐ๋งˆ์ด์ € ์ƒํƒœ๊นŒ์ง€ ์ €์žฅํ•ด์•ผ AdamW ๋ชจ๋ฉ˜ํ…€์ด ์œ ์ง€๋จ
230
+
231
+ ์ €์žฅ ๋‚ด์šฉ:
232
+ - model_state_dict: ๋ชจ๋ธ ๊ฐ€์ค‘์น˜
233
+ - optimizer_state_dict: ์˜ตํ‹ฐ๋งˆ์ด์ € ์ƒํƒœ (m, v ๋ชจ๋ฉ˜ํ…€)
234
+ - step: ํ˜„์žฌ ํ•™์Šต ์Šคํ…
235
+ - best_val_loss: ์ตœ์ € ๊ฒ€์ฆ Loss
236
+ - config: ํ•™์Šต ์„ค์ • (์žฌํ˜„์„ฑ)
237
+ - rng_states: ๋žœ๋ค ์‹œ๋“œ ์ƒํƒœ (์™„์ „ ์žฌํ˜„)
238
+ - metrics_history: ํ•™์Šต ๋ฉ”ํŠธ๋ฆญ ๊ธฐ๋ก
239
+ - wandb_run_id: wandb ์‹คํ–‰ ID (๋กœ๊น… ์—ฐ์†์„ฑ)
240
+ """
241
+
242
+ def __init__(self, config: TrainConfig):
243
+ self.config = config
244
+ self.checkpoint_dir = Path(config.checkpoint_dir)
245
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
246
+ self.max_checkpoints = config.max_checkpoints
247
+
248
+ def save(
249
+ self,
250
+ model: nn.Module,
251
+ optimizer: torch.optim.Optimizer,
252
+ step: int,
253
+ best_val_loss: float,
254
+ metrics_history: Dict[str, list],
255
+ wandb_run_id: Optional[str] = None,
256
+ ):
257
+ """์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค."""
258
+ ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
259
+ ckpt_path.mkdir(parents=True, exist_ok=True)
260
+
261
+ print(f"\n๐Ÿ’พ ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ: {ckpt_path}")
262
+ start = time.time()
263
+
264
+ # 1) ๋ชจ๋ธ ๊ฐ€์ค‘์น˜ (bf16 ์ƒํƒœ ๊ทธ๋Œ€๋กœ)
265
+ torch.save(model.state_dict(), ckpt_path / "model.pt")
266
+
267
+ # 2) ์˜ตํ‹ฐ๋งˆ์ด์ € ์ƒํƒœ (fp32 ๋ชจ๋ฉ˜ํ…€ ํฌํ•จ, ํฌ๊ธฐ ํผ)
268
+ torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
269
+
270
+ # 3) ํ•™์Šต ๋ฉ”ํƒ€ ์ •๋ณด
271
+ meta = {
272
+ "step": step,
273
+ "best_val_loss": best_val_loss,
274
+ "wandb_run_id": wandb_run_id,
275
+ "config": self.config.__dict__,
276
+ }
277
+ with open(ckpt_path / "meta.json", "w") as f:
278
+ json.dump(meta, f, indent=2)
279
+
280
+ # 4) ๋ฉ”ํŠธ๋ฆญ ๊ธฐ๋ก
281
+ torch.save(metrics_history, ckpt_path / "metrics.pt")
282
+
283
+ # 5) ๋žœ๋ค ์ƒํƒœ (์™„์ „ ์žฌํ˜„์„ ์œ„ํ•ด)
284
+ rng_states = {
285
+ "python": torch.random.get_rng_state(),
286
+ "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
287
+ }
288
+ torch.save(rng_states, ckpt_path / "rng_states.pt")
289
+
290
+ elapsed = time.time() - start
291
+ ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
292
+ print(f" ์ €์žฅ ์™„๋ฃŒ: {ckpt_size:.2f} GB, {elapsed:.1f}์ดˆ")
293
+
294
+ # ์˜ค๋ž˜๋œ ์ฒดํฌํฌ์ธํŠธ ์‚ญ์ œ (๋กค๋ง)
295
+ self._cleanup_old_checkpoints()
296
+
297
+ def load_latest(
298
+ self,
299
+ model: nn.Module,
300
+ optimizer: Optional[torch.optim.Optimizer] = None,
301
+ device: torch.device = torch.device("cpu"),
302
+ ) -> Optional[Dict[str, Any]]:
303
+ """๊ฐ€์žฅ ์ตœ๊ทผ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
304
+
305
+ Returns:
306
+ {"step", "best_val_loss", "wandb_run_id", "metrics_history"}
307
+ ๋˜๋Š” ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์—†์œผ๋ฉด None
308
+ """
309
+ ckpt_path = self._find_latest()
310
+ if ckpt_path is None:
311
+ print("[Checkpoint] ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ ์—†์Œ. ์ฒ˜์Œ๋ถ€ํ„ฐ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.")
312
+ return None
313
+
314
+ print(f"\n๐Ÿ“‚ ์ฒดํฌํฌ์ธํŠธ ๋กœ๋“œ: {ckpt_path}")
315
+ start = time.time()
316
+
317
+ # 1) ๋ชจ๋ธ ๊ฐ€์ค‘์น˜
318
+ model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
319
+ model.load_state_dict(model_state)
320
+ del model_state # ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ
321
+
322
+ # 2) ์˜ตํ‹ฐ๋งˆ์ด์ € ์ƒํƒœ
323
+ if optimizer is not None:
324
+ optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
325
+ optimizer.load_state_dict(optim_state)
326
+ del optim_state
327
+
328
+ # 3) ๋ฉ”ํƒ€ ์ •๋ณด
329
+ with open(ckpt_path / "meta.json", "r") as f:
330
+ meta = json.load(f)
331
+
332
+ # 4) ๋ฉ”ํŠธ๋ฆญ ๊ธฐ๋ก
333
+ metrics_history = {}
334
+ metrics_path = ckpt_path / "metrics.pt"
335
+ if metrics_path.exists():
336
+ metrics_history = torch.load(metrics_path, weights_only=False)
337
+
338
+ # 5) ๋žœ๋ค ์ƒํƒœ ๋ณต์›
339
+ rng_path = ckpt_path / "rng_states.pt"
340
+ if rng_path.exists():
341
+ rng_states = torch.load(rng_path, weights_only=False)
342
+ torch.random.set_rng_state(rng_states["torch_cpu"])
343
+ if rng_states["cuda"] is not None and torch.cuda.is_available():
344
+ torch.cuda.set_rng_state(rng_states["cuda"])
345
+
346
+ elapsed = time.time() - start
347
+ print(f" ๋กœ๋“œ ์™„๋ฃŒ: step={meta['step']}, {elapsed:.1f}์ดˆ")
348
+
349
+ return {
350
+ "step": meta["step"],
351
+ "best_val_loss": meta["best_val_loss"],
352
+ "wandb_run_id": meta.get("wandb_run_id"),
353
+ "metrics_history": metrics_history,
354
+ }
355
+
356
+ def _find_latest(self) -> Optional[Path]:
357
+ """๊ฐ€์žฅ ์ตœ๊ทผ ์ฒดํฌํฌ์ธํŠธ ๊ฒฝ๋กœ๋ฅผ ์ฐพ์Šต๋‹ˆ๋‹ค."""
358
+ ckpts = sorted(self.checkpoint_dir.glob("step_*"))
359
+ return ckpts[-1] if ckpts else None
360
+
361
+ def _cleanup_old_checkpoints(self):
362
+ """์˜ค๋ž˜๋œ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์‚ญ์ œํ•ฉ๋‹ˆ๋‹ค (๋กค๋ง)."""
363
+ ckpts = sorted(self.checkpoint_dir.glob("step_*"))
364
+ while len(ckpts) > self.max_checkpoints:
365
+ old = ckpts.pop(0)
366
+ print(f" ๐Ÿ—‘๏ธ ์˜ค๋ž˜๋œ ์ฒดํฌํฌ์ธํŠธ ์‚ญ์ œ: {old.name}")
367
+ shutil.rmtree(old)
368
+
369
+
370
+ # ============================================================================
371
+ # 4. ๋ฉ”ํŠธ๋ฆญ ์ถ”์ ๊ธฐ
372
+ # ============================================================================
373
+
374
+ class MetricsTracker:
375
+ """ํ•™์Šต ๋ฉ”ํŠธ๋ฆญ์„ ์ถ”์ ํ•˜๊ณ  ๋กœ๊น…ํ•ฉ๋‹ˆ๋‹ค.
376
+
377
+ ์ถ”์  ํ•ญ๋ชฉ:
378
+ - train/loss: ํ•™์Šต Loss (Cross-Entropy)
379
+ - train/lr: ํ˜„์žฌ ํ•™์Šต๋ฅ 
380
+ - train/grad_norm: Gradient L2 Norm
381
+ - train/tokens_per_sec: ์ฒ˜๋ฆฌ๋Ÿ‰
382
+ - train/gpu_mem_gb: GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰
383
+ - val/loss: ๊ฒ€์ฆ Loss
384
+ - val/perplexity: ๊ฒ€์ฆ Perplexity (= exp(loss))
385
+ """
386
+
387
+ def __init__(self, config: TrainConfig):
388
+ self.config = config
389
+ self.history: Dict[str, list] = {
390
+ "step": [],
391
+ "train_loss": [],
392
+ "learning_rate": [],
393
+ "grad_norm": [],
394
+ "tokens_per_sec": [],
395
+ "gpu_mem_gb": [],
396
+ "val_loss": [],
397
+ "val_ppl": [],
398
+ }
399
+
400
+ # wandb ์ดˆ๊ธฐํ™”
401
+ self.wandb_run = None
402
+ if config.use_wandb:
403
+ self._init_wandb()
404
+
405
+ def _init_wandb(self, resume_id: Optional[str] = None):
406
+ """wandb ์ดˆ๊ธฐํ™” (์„ธ์…˜ ๊ฐ„ ์—ฐ์† ๋กœ๊น… ์ง€์›)."""
407
+ try:
408
+ import wandb
409
+
410
+ run_id = resume_id or wandb.util.generate_id()
411
+ self.wandb_run = wandb.init(
412
+ project=self.config.wandb_project,
413
+ name=self.config.wandb_run_name or f"1b-run-{run_id[:6]}",
414
+ id=run_id,
415
+ resume="allow",
416
+ config=self.config.__dict__,
417
+ )
418
+ print(f"[wandb] ์ดˆ๊ธฐํ™” ์™„๋ฃŒ: {self.wandb_run.url}")
419
+ except ImportError:
420
+ print("[wandb] ์„ค์น˜๋˜์ง€ ์•Š์Œ. ์ฝ˜์†” ๋กœ๊น…๋งŒ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
421
+ self.config.use_wandb = False
422
+ except Exception as e:
423
+ print(f"[wandb] ์ดˆ๊ธฐํ™” ์‹คํŒจ: {e}. ์ฝ˜์†” ๋กœ๊น…๋งŒ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.")
424
+ self.config.use_wandb = False
425
+
426
+ def resume_wandb(self, run_id: str):
427
+ """์ด์ „ wandb ์‹คํ–‰์„ ์ด์–ด์„œ ๋กœ๊น…ํ•ฉ๋‹ˆ๋‹ค."""
428
+ if self.config.use_wandb:
429
+ self._init_wandb(resume_id=run_id)
430
+
431
+ def log_train_step(
432
+ self,
433
+ step: int,
434
+ loss: float,
435
+ lr: float,
436
+ grad_norm: float,
437
+ tokens_per_sec: float,
438
+ gpu_mem_gb: float,
439
+ ):
440
+ """ํ•™์Šต ์Šคํ… ๋ฉ”ํŠธ๋ฆญ์„ ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค."""
441
+ self.history["step"].append(step)
442
+ self.history["train_loss"].append(loss)
443
+ self.history["learning_rate"].append(lr)
444
+ self.history["grad_norm"].append(grad_norm)
445
+ self.history["tokens_per_sec"].append(tokens_per_sec)
446
+ self.history["gpu_mem_gb"].append(gpu_mem_gb)
447
+
448
+ if self.config.use_wandb and self.wandb_run:
449
+ import wandb
450
+
451
+ wandb.log({
452
+ "train/loss": loss,
453
+ "train/lr": lr,
454
+ "train/grad_norm": grad_norm,
455
+ "train/tokens_per_sec": tokens_per_sec,
456
+ "train/gpu_mem_gb": gpu_mem_gb,
457
+ }, step=step)
458
+
459
+ def log_eval(self, step: int, val_loss: float, val_ppl: float):
460
+ """๊ฒ€์ฆ ๋ฉ”ํŠธ๋ฆญ์„ ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค."""
461
+ self.history["val_loss"].append(val_loss)
462
+ self.history["val_ppl"].append(val_ppl)
463
+
464
+ if self.config.use_wandb and self.wandb_run:
465
+ import wandb
466
+
467
+ wandb.log({
468
+ "val/loss": val_loss,
469
+ "val/perplexity": val_ppl,
470
+ }, step=step)
471
+
472
+ @property
473
+ def wandb_run_id(self) -> Optional[str]:
474
+ if self.wandb_run:
475
+ return self.wandb_run.id
476
+ return None
477
+
478
+
479
+ # ============================================================================
480
+ # 5. Optimizer ์ƒ์„ฑ (AdamW with weight decay ๋ถ„๋ฆฌ)
481
+ # ============================================================================
482
+
483
+ def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
484
+ """AdamW ์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
485
+
486
+ Weight Decay ๋ถ„๋ฆฌ ๊ทœ์น™:
487
+ - Decay ์ ์šฉ: Linear ๊ฐ€์ค‘์น˜ (attention proj, FFN ๋“ฑ)
488
+ - Decay ๋ฏธ์ ์šฉ: Embedding, LayerNorm/RMSNorm, Bias
489
+
490
+ ์™œ ๋ถ„๋ฆฌํ•˜๋Š”๊ฐ€?
491
+ - Weight Decay๋Š” ํฐ ๊ฐ€์ค‘์น˜์— ํŒจ๋„ํ‹ฐ๋ฅผ ์ฃผ์–ด ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€
492
+ - ํ•˜์ง€๋งŒ Norm์˜ scale ํŒŒ๋ผ๋ฏธํ„ฐ์— ์ ์šฉํ•˜๋ฉด ์ •๊ทœํ™” ํšจ๊ณผ๋ฅผ ๋ฐฉํ•ด
493
+ - Embedding์— ์ ์šฉํ•˜๋ฉด ํฌ๊ท€ ํ† ํฐ์˜ ํ‘œํ˜„์ด 0์œผ๋กœ ์ˆ˜์ถ•
494
+ - 1D ํŒŒ๋ผ๋ฏธํ„ฐ(bias, norm weight)๋Š” decay์—์„œ ์ œ์™ธํ•˜๋Š” ๊ฒƒ์ด ๊ด€๋ก€
495
+ """
496
+ # ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ decay/no-decay ๊ทธ๋ฃน์œผ๋กœ ๋ถ„๋ฆฌ
497
+ decay_params = []
498
+ no_decay_params = []
499
+
500
+ for name, param in model.named_parameters():
501
+ if not param.requires_grad:
502
+ continue
503
+
504
+ # 1D ํ…์„œ(bias, norm weight) ๋˜๋Š” embedding โ†’ no decay
505
+ if param.dim() <= 1 or "embedding" in name:
506
+ no_decay_params.append(param)
507
+ else:
508
+ decay_params.append(param)
509
+
510
+ param_groups = [
511
+ {"params": decay_params, "weight_decay": config.weight_decay},
512
+ {"params": no_decay_params, "weight_decay": 0.0},
513
+ ]
514
+
515
+ n_decay = sum(p.numel() for p in decay_params)
516
+ n_no_decay = sum(p.numel() for p in no_decay_params)
517
+ print(f"[Optimizer] Decay ํŒŒ๋ผ๋ฏธํ„ฐ: {n_decay:,} ({n_decay/1e6:.1f}M)")
518
+ print(f"[Optimizer] No-decay ํŒŒ๋ผ๋ฏธํ„ฐ: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
519
+
520
+ optimizer = torch.optim.AdamW(
521
+ param_groups,
522
+ lr=config.learning_rate,
523
+ betas=(config.beta1, config.beta2),
524
+ eps=config.adam_eps,
525
+ fused=torch.cuda.is_available(), # CUDA fused AdamW (๋” ๋น ๋ฆ„)
526
+ )
527
+
528
+ return optimizer
529
+
530
+
531
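The 1-D rule used above is easy to verify on a toy module before trusting it on the full model. A standalone sketch (plain PyTorch, no project code needed):

```python
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
decay = [n for n, p in toy.named_parameters() if p.dim() > 1]
no_decay = [n for n, p in toy.named_parameters() if p.dim() <= 1]
print(decay)     # ['0.weight']                      -> weight decay applied
print(no_decay)  # ['0.bias', '1.weight', '1.bias']  -> excluded (1-D params)
```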
+ # ============================================================================
532
+ # 6. Trainer (ํ•ต์‹ฌ ํ•™์Šต ๋ฃจํ”„)
533
+ # ============================================================================
534
+
535
+ class Trainer:
536
+ """LLM ์‚ฌ์ „ํ•™์Šต ํŠธ๋ ˆ์ด๋„ˆ.
537
+
538
+ ํ•™์Šต ๋ฃจํ”„์˜ ํ•ต์‹ฌ ๊ตฌ์กฐ:
539
+ ```
540
+ for step in range(total_steps):
541
+ # โ”€โ”€ Gradient Accumulation Loop โ”€โ”€
542
+ for micro_step in range(accumulation_steps):
543
+ batch = next(dataloader)
544
+ with autocast(bf16):
545
+ logits, loss = model(input_ids, targets)
546
+ scaled_loss = loss / accumulation_steps
547
+ scaled_loss.backward() # gradient ๋ˆ„์ 
548
+
549
+ # โ”€โ”€ Optimizer Step (accumulation ์™„๋ฃŒ ํ›„) โ”€โ”€
550
+ clip_grad_norm(model, max_norm=1.0)
551
+ optimizer.step()
552
+ optimizer.zero_grad()
553
+ scheduler.set_lr(optimizer, step)
554
+ ```
555
+
556
+ Gradient Accumulation์ด๋ž€?
557
+ - GPU ๋ฉ”๋ชจ๋ฆฌ์— ํฐ ๋ฐฐ์น˜๋ฅผ ํ•œ ๋ฒˆ์— ์˜ฌ๋ฆด ์ˆ˜ ์—†์„ ๋•Œ
558
+ - ์ž‘์€ micro_batch๋กœ ์—ฌ๋Ÿฌ ๋ฒˆ forward/backward โ†’ gradient๋ฅผ ๋ˆ„์ 
559
+ - ๋ˆ„์  ํ›„ ํ•œ ๋ฒˆ์— optimizer step
560
+ - ๊ฒฐ๊ณผ์ ์œผ๋กœ ํฐ effective_batch์™€ ๋™์ผํ•œ ํšจ๊ณผ
561
+ - Loss๋ฅผ accumulation_steps๋กœ ๋‚˜๋ˆ„๋Š” ์ด์œ :
562
+ gradient์˜ ํ‰๊ท ์„ ๊ตฌํ•˜๊ธฐ ์œ„ํ•ด (ํ•ฉ์ด ์•„๋‹Œ ํ‰๊ท )
563
+ """
564
+
565
+ def __init__(
566
+ self,
567
+ model: nn.Module,
568
+ train_dataloader: DataLoader,
569
+ val_dataloader: Optional[DataLoader],
570
+ config: TrainConfig,
571
+ seq_len: int = 2048,
572
+ ):
573
+ self.config = config
574
+ self.seq_len = seq_len
575
+
576
+ # โ”€โ”€ ๋””๋ฐ”์ด์Šค ์„ค์ • โ”€โ”€
577
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
578
+ print(f"[Trainer] ๋””๋ฐ”์ด์Šค: {self.device}")
579
+ if torch.cuda.is_available():
580
+ print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
581
+ print(f"[Trainer] GPU ๋ฉ”๋ชจ๋ฆฌ: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
582
+
583
+ # โ”€โ”€ ๋ชจ๋ธ โ”€โ”€
584
+ self.model = model.to(self.device)
585
+ # torch.compile: PyTorch 2.0+ ๊ทธ๋ž˜ํ”„ ์ตœ์ ํ™” (์†๋„ 10-30% ํ–ฅ์ƒ)
586
+ if torch.cuda.is_available() and hasattr(torch, "compile"):
587
+ print("[Trainer] torch.compile ์ ์šฉ ์ค‘...")
588
+ self.model = torch.compile(self.model)
589
+
590
+ # โ”€โ”€ ๋ฐ์ดํ„ฐ โ”€โ”€
591
+ self.train_dataloader = train_dataloader
592
+ self.val_dataloader = val_dataloader
593
+ self.train_iter = iter(train_dataloader)
594
+
595
+ # โ”€โ”€ ์˜ตํ‹ฐ๋งˆ์ด์ € โ”€โ”€
596
+ self.optimizer = create_optimizer(self.model, config)
597
+
598
+ # โ”€โ”€ ์Šค์ผ€์ค„๋Ÿฌ โ”€โ”€
599
+ self.scheduler = CosineWarmupScheduler(config)
600
+
601
+ # โ”€โ”€ ์ฒดํฌํฌ์ธํŠธ โ”€โ”€
602
+ self.ckpt_manager = CheckpointManager(config)
603
+
604
+ # โ”€โ”€ ๋ฉ”ํŠธ๋ฆญ โ”€โ”€
605
+ self.metrics = MetricsTracker(config)
606
+
607
+ # โ”€โ”€ ํ•™์Šต ์ƒํƒœ โ”€โ”€
608
+ self.global_step = 0
609
+ self.best_val_loss = float("inf")
610
+ self.tokens_seen = 0
611
+
612
+ # โ”€โ”€ Mixed Precision โ”€โ”€
613
+ # bf16์€ GradScaler๊ฐ€ ๋ถˆํ•„์š” (fp16์ผ ๋•Œ๋งŒ ํ•„์š”)
614
+ self.use_amp = config.dtype != "float32"
615
+ self.amp_dtype = config.torch_dtype
616
+
617
+ # โ”€โ”€ ์ž๋™ ๋ณต์› ์‹œ๋„ โ”€โ”€
618
+ self._try_resume()
619
+
620
+ def _try_resume(self):
621
+ """์ด์ „ ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์žˆ์œผ๋ฉด ์ž๋™์œผ๋กœ ๋ณต์›ํ•ฉ๋‹ˆ๋‹ค."""
622
+ result = self.ckpt_manager.load_latest(
623
+ self.model, self.optimizer, self.device
624
+ )
625
+
626
+ if result is not None:
627
+ self.global_step = result["step"]
628
+ self.best_val_loss = result["best_val_loss"]
629
+ self.metrics.history = result.get("metrics_history", self.metrics.history)
630
+
631
+ # wandb ์—ฐ์† ๋กœ๊น…
632
+ if result.get("wandb_run_id"):
633
+ self.metrics.resume_wandb(result["wandb_run_id"])
634
+
635
+ self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
636
+ print(f"[Trainer] ํ•™์Šต ์žฌ๊ฐœ: step={self.global_step}, "
637
+ f"tokens={self.tokens_seen/1e9:.2f}B, "
638
+ f"best_val_loss={self.best_val_loss:.4f}")
639
+
640
+ def _get_next_batch(self) -> Dict[str, torch.Tensor]:
641
+ """๋‹ค์Œ ํ•™์Šต ๋ฐฐ์น˜๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
642
+
643
+ Streaming DataLoader๋Š” ์—ํญ ๊ฐœ๋…์ด ์—†์œผ๋ฏ€๋กœ,
644
+ StopIteration ์‹œ ์ƒˆ ์ดํ„ฐ๋ ˆ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
645
+ """
646
+ try:
647
+ batch = next(self.train_iter)
648
+ except StopIteration:
649
+ self.train_iter = iter(self.train_dataloader)
650
+ batch = next(self.train_iter)
651
+
652
+ return {
653
+ "input_ids": batch["input_ids"].to(self.device, non_blocking=True),
654
+ "targets": batch["targets"].to(self.device, non_blocking=True),
655
+ }
656
+
657
+ def _train_step(self) -> Tuple[float, float]:
658
+ """ํ•˜๋‚˜์˜ optimizer step์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
659
+
660
+ Returns:
661
+ (loss, grad_norm)
662
+ """
663
+ self.model.train()
664
+ self.optimizer.zero_grad(set_to_none=True)
665
+ # set_to_none=True: gradient๋ฅผ None์œผ๋กœ ์„ค์ • โ†’ ๋ฉ”๋ชจ๋ฆฌ ์ ˆ์•ฝ
666
+
667
+ total_loss = 0.0
668
+
669
+ # โ”€โ”€ Gradient Accumulation Loop โ”€โ”€
670
+ for micro_step in range(self.config.gradient_accumulation_steps):
671
+ batch = self._get_next_batch()
672
+
673
+ # Mixed Precision Forward
674
+ with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
675
+ logits, loss = self.model(batch["input_ids"], batch["targets"])
676
+
677
+ # Loss ์Šค์ผ€์ผ๋ง: effective batch์˜ ํ‰๊ท ์„ ์œ„ํ•ด
678
+ scaled_loss = loss / self.config.gradient_accumulation_steps
679
+ total_loss += loss.item()
680
+
681
+ # Backward (gradient ๋ˆ„์ )
682
+ scaled_loss.backward()
683
+
684
+ # โ”€โ”€ Gradient Clipping โ”€โ”€
685
+ # ๋ชจ๋“  ํŒŒ๋ผ๋ฏธํ„ฐ์˜ gradient๋ฅผ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ๋กœ ๋ณด๊ณ  L2 norm ๊ณ„์‚ฐ
686
+ # norm์ด max_norm์„ ์ดˆ๊ณผํ•˜๋ฉด ๋น„๋ก€์ ์œผ๋กœ ์Šค์ผ€์ผ ๋‹ค์šด
687
+ grad_norm = torch.nn.utils.clip_grad_norm_(
688
+ self.model.parameters(),
689
+ max_norm=self.config.grad_clip,
690
+ ).item()
691
+
692
+ # โ”€โ”€ LR update โ”€โ”€ (set BEFORE optimizer.step; otherwise the first step
693
+ # would run at the peak LR from optimizer creation, not the warmup value)
694
+ self.scheduler.set_lr(self.optimizer, self.global_step)
695
+ # โ”€โ”€ Optimizer Step โ”€โ”€
696
+ self.optimizer.step()
697
+
698
+ avg_loss = total_loss / self.config.gradient_accumulation_steps
699
+ return avg_loss, grad_norm
700
+
701
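Why dividing each micro-loss by accumulation_steps yields the *mean* gradient can be seen on a two-micro-batch toy example (standalone sketch, independent of the Trainer):

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
micro_batches = [torch.tensor([2.0]), torch.tensor([4.0])]

for x in micro_batches:                      # two forward/backward passes
    loss = (w * x).sum()                     # d(loss)/dw = x
    (loss / len(micro_batches)).backward()   # scale by 1/N; gradients accumulate

print(w.grad)  # tensor([3.]) = mean of the per-micro-batch gradients (2 and 4)
```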
+ @torch.no_grad()
702
+ def _evaluate(self) -> Tuple[float, float]:
703
+ """๊ฒ€์ฆ ๋ฐ์ดํ„ฐ์—์„œ Loss์™€ Perplexity๋ฅผ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค.
704
+
705
+ Perplexity = exp(loss)
706
+ - ์ง๊ด€: "๋ชจ๋ธ์ด ๋‹ค์Œ ํ† ํฐ์„ ํ‰๊ท  ๋ช‡ ๊ฐœ์˜ ํ›„๋ณด ์ค‘์—์„œ ๊ณ ๋ฅด๋Š”๊ฐ€"
707
+ - PPL 100 โ†’ 100๊ฐœ ์ค‘ 1๊ฐœ๋ฅผ ๊ท ์ผํ•˜๊ฒŒ ๊ณ ๋ฅด๋Š” ์ˆ˜์ค€
708
+ - PPL 20 โ†’ 20๊ฐœ ์ค‘ 1๊ฐœ ์ˆ˜์ค€ (๊ฝค ์ข‹์Œ)
709
+ - PPL 10 โ†’ ๋งค์šฐ ์ž์‹ ์žˆ๊ฒŒ ์˜ˆ์ธก
710
+ """
711
+ if self.val_dataloader is None:
712
+ return float("inf"), float("inf")
713
+
714
+ self.model.eval()
715
+ total_loss = 0.0
716
+ num_batches = 0
717
+
718
+ for i, batch in enumerate(self.val_dataloader):
719
+ if i >= self.config.eval_steps:
720
+ break
721
+
722
+ input_ids = batch["input_ids"].to(self.device)
723
+ targets = batch["targets"].to(self.device)
724
+
725
+ with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
726
+ _, loss = self.model(input_ids, targets)
727
+
728
+ total_loss += loss.item()
729
+ num_batches += 1
730
+
731
+ avg_loss = total_loss / max(num_batches, 1)
732
+ perplexity = math.exp(min(avg_loss, 20)) # overflow ๋ฐฉ์ง€ (exp(20) โ‰ˆ 5์–ต)
733
+
734
+ return avg_loss, perplexity
735
+
736
+ def train(self):
737
+ """๋ฉ”์ธ ํ•™์Šต ๋ฃจํ”„.
738
+
739
+ ์ด ๋ฉ”์„œ๋“œ๊ฐ€ ์ „์ฒด ํ•™์Šต์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
740
+ Colab ์„ธ์…˜ ๋งŒ๋ฃŒ ์‹œ ์ค‘๋‹จ๋˜์–ด๋„ ์ฒดํฌํฌ์ธํŠธ์—์„œ ์ž๋™ ์žฌ๊ฐœ๋ฉ๋‹ˆ๋‹ค.
741
+ """
742
+ config = self.config
743
+
744
+ print("\n" + "=" * 70)
745
+ print("๐Ÿš€ ํ•™์Šต ์‹œ์ž‘")
746
+ print("=" * 70)
747
+ print(f" ์ด ์Šคํ…: {config.total_steps:,}")
748
+ print(f" ์‹œ์ž‘ ์Šคํ…: {self.global_step}")
749
+ print(f" Effective batch size: {config.effective_batch_size}")
750
+ print(f" ํ† ํฐ/์Šคํ…: {config.effective_batch_size * self.seq_len:,}")
751
+ print(f" ์ด ํ•™์Šต ํ† ํฐ (์˜ˆ์ƒ): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
752
+ print(f" Mixed Precision: {config.dtype}")
753
+ print(f" Gradient Accumulation: {config.gradient_accumulation_steps}")
754
+ print(f" ์ฒดํฌํฌ์ธํŠธ: {config.checkpoint_dir}")
755
+ print("=" * 70 + "\n")
756
+
757
+ step_start_time = time.time()
758
+ tokens_at_log_start = self.tokens_seen
759
+
760
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
761
+ # ๋ฉ”์ธ ๋ฃจํ”„
762
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
763
+
764
+ while self.global_step < config.total_steps:
765
+
766
+ # โ”€โ”€ Train Step โ”€โ”€
767
+ loss, grad_norm = self._train_step()
768
+ self.global_step += 1
769
+ self.tokens_seen += config.effective_batch_size * self.seq_len
770
+
771
+ # โ”€โ”€ Logging โ”€โ”€
772
+             if self.global_step % config.log_interval == 0:
+                 elapsed = time.time() - step_start_time
+                 tokens_delta = self.tokens_seen - tokens_at_log_start
+                 tokens_per_sec = tokens_delta / max(elapsed, 1e-6)
+
+                 # GPU memory
+                 gpu_mem_gb = 0.0
+                 if torch.cuda.is_available():
+                     gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9
+
+                 # Current LR
+                 current_lr = self.scheduler.get_lr(self.global_step)
+
+                 # Remaining-time estimate
+                 remaining_steps = config.total_steps - self.global_step
+                 steps_per_sec = config.log_interval / max(elapsed, 1e-6)
+                 eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
+                 eta_hours = eta_seconds / 3600
+
+                 # Console output
+                 print(
+                     f"  Step {self.global_step:>6d}/{config.total_steps} โ”‚ "
+                     f"Loss {loss:.4f} โ”‚ "
+                     f"LR {current_lr:.2e} โ”‚ "
+                     f"Grad {grad_norm:.2f} โ”‚ "
+                     f"{tokens_per_sec:,.0f} tok/s โ”‚ "
+                     f"GPU {gpu_mem_gb:.1f}GB โ”‚ "
+                     f"ETA {eta_hours:.1f}h โ”‚ "
+                     f"Tokens {self.tokens_seen/1e9:.2f}B"
+                 )
+
+                 # wandb logging
+                 self.metrics.log_train_step(
+                     step=self.global_step,
+                     loss=loss,
+                     lr=current_lr,
+                     grad_norm=grad_norm,
+                     tokens_per_sec=tokens_per_sec,
+                     gpu_mem_gb=gpu_mem_gb,
+                 )
+
+                 step_start_time = time.time()
+                 tokens_at_log_start = self.tokens_seen
+
+             # โ”€โ”€ Evaluation โ”€โ”€
+             if self.global_step % config.eval_interval == 0:
+                 val_loss, val_ppl = self._evaluate()
+
+                 print(f"\n  ๐Ÿ“Š Eval @ Step {self.global_step}: "
+                       f"Val Loss = {val_loss:.4f}, "
+                       f"Val PPL = {val_ppl:.2f}")
+
+                 self.metrics.log_eval(self.global_step, val_loss, val_ppl)
+
+                 if val_loss < self.best_val_loss:
+                     self.best_val_loss = val_loss
+                     print(f"  ๐Ÿ† New best val loss: {val_loss:.4f}")
+
+                 print()
+
+             # โ”€โ”€ Checkpoint โ”€โ”€
+             if self.global_step % config.checkpoint_interval == 0:
+                 self.ckpt_manager.save(
+                     model=self.model,
+                     optimizer=self.optimizer,
+                     step=self.global_step,
+                     best_val_loss=self.best_val_loss,
+                     metrics_history=self.metrics.history,
+                     wandb_run_id=self.metrics.wandb_run_id,
+                 )
+
+         # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
+         # Training complete
+         # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
+
+         print("\n" + "=" * 70)
+         print("๐ŸŽ‰ Training complete!")
+         print("=" * 70)
+         print(f"  Total steps: {self.global_step:,}")
+         print(f"  Total tokens: {self.tokens_seen/1e9:.2f}B")
+         print(f"  Best val loss: {self.best_val_loss:.4f}")
+         print(f"  Best val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
+         print("=" * 70)
+
+         # Save the final checkpoint
+         self.ckpt_manager.save(
+             model=self.model,
+             optimizer=self.optimizer,
+             step=self.global_step,
+             best_val_loss=self.best_val_loss,
+             metrics_history=self.metrics.history,
+             wandb_run_id=self.metrics.wandb_run_id,
+         )
+
+         if self.config.use_wandb and self.metrics.wandb_run:
+             import wandb
+             wandb.finish()
+
+
+ # ============================================================================
+ # 7. Auto-detect the GPU environment and adjust the config
+ # ============================================================================
+
+ def auto_configure(config: TrainConfig) -> TrainConfig:
+     """Adjusts the config automatically based on the GPU type.
+
+     Colab Pro+ does not always allocate an A100.
+     When a T4 or V100 is assigned instead, the settings adapt automatically.
+
+     Returns:
+         The adjusted TrainConfig
+     """
+     if not torch.cuda.is_available():
+         print("โš ๏ธ No GPU! Running in CPU mode (very slow)")
+         config.dtype = "float32"
+         config.micro_batch_size = 1
+         config.gradient_accumulation_steps = 4
+         return config
+
+     gpu_name = torch.cuda.get_device_name().lower()
+     gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
+
+     print(f"\n๐Ÿ” GPU detected: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")
+
+     if "a100" in gpu_name:
+         # A100 40GB: keep the defaults (optimal)
+         print("  โ†’ A100 detected: using defaults (bf16, batch=4)")
+         config.dtype = "bfloat16"
+         config.micro_batch_size = 4
+
+     elif "v100" in gpu_name:
+         # V100 16GB: no bf16 support, shrink the batch
+         print("  โ†’ V100 detected: fp16 mode, reduced batch")
+         config.dtype = "float16"
+         config.micro_batch_size = 2
+         config.gradient_accumulation_steps = 64  # keep the effective batch
+
+     elif "t4" in gpu_name:
+         # T4 16GB: no bf16 support, smallest batch
+         print("  โ†’ T4 detected: fp16 mode, minimum batch")
+         config.dtype = "float16"
+         config.micro_batch_size = 1
+         config.gradient_accumulation_steps = 128
+
+     elif "l4" in gpu_name:
+         # L4 24GB: bf16 supported
+         print("  โ†’ L4 detected: bf16 mode, adjusted batch")
+         config.dtype = "bfloat16"
+         config.micro_batch_size = 2
+         config.gradient_accumulation_steps = 64
+
+     else:
+         print("  โ†’ Unknown GPU. Adjusting by memory size")
+         if gpu_mem >= 30:
+             config.micro_batch_size = 4
+         elif gpu_mem >= 16:
+             config.micro_batch_size = 2
+         else:
+             config.micro_batch_size = 1
+             config.gradient_accumulation_steps = 128
+
+     print(f"  โ†’ dtype: {config.dtype}")
+     print(f"  โ†’ micro_batch: {config.micro_batch_size}")
+     print(f"  โ†’ grad_accum: {config.gradient_accumulation_steps}")
+     print(f"  โ†’ effective_batch: {config.effective_batch_size}")
+
+     return config
+
+
+ # ============================================================================
+ # 8. Quick Start (for running in Colab)
+ # ============================================================================
+
+ def start_training(
+     model: nn.Module,
+     train_dataloader: DataLoader,
+     val_dataloader: Optional[DataLoader] = None,
+     config: Optional[TrainConfig] = None,
+     seq_len: int = 2048,
+     auto_config: bool = True,
+ ) -> Trainer:
+     """Starts training (one-line entry point).
+
+     Usage (Colab):
+         ```python
+         from model import LLMModel, ModelConfig
+         from data_pipeline import setup_data_pipeline, DataConfig
+         from trainer import start_training, TrainConfig
+
+         # 1. Create the model
+         model_config = ModelConfig.base_1b()
+         model = LLMModel(model_config)
+
+         # 2. Data pipeline
+         tok, train_dl, val_dl = setup_data_pipeline("pretrained")
+
+         # 3. Start training (checkpoints resume automatically)
+         trainer = start_training(model, train_dl, val_dl)
+         ```
+     """
+     config = config or TrainConfig()
+
+     # Auto-detect the GPU and adjust the config
+     if auto_config:
+         config = auto_configure(config)
+
+     # Check the Google Drive mount (Colab)
+     if "/content/drive" in config.checkpoint_dir:
+         drive_path = Path("/content/drive/MyDrive")
+         if not drive_path.exists():
+             print("\nโš ๏ธ Google Drive is not mounted!")
+             print("   In Colab, run: from google.colab import drive; drive.mount('/content/drive')")
+             print("   Falling back to a local path.")
+             config.checkpoint_dir = "./checkpoints"
+
+     # Seed for reproducibility
+     torch.manual_seed(config.seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(config.seed)
+
+     # Create the Trainer (includes automatic checkpoint resume)
+     trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
+
+     # Run training
+     trainer.train()
+
+     return trainer
+
+
+ # ============================================================================
+ # 9. Verification script
+ # ============================================================================
+
+ if __name__ == "__main__":
+     print("=" * 70)
+     print("LLM-1B-Lab: Trainer verification")
+     print("=" * 70)
+
+     # โ”€โ”€ Verify the training loop with a mini model โ”€โ”€
+     print("\n[Test 1] Mini-model training-loop check")
+
+     # A simple dummy model
+     class TinyModel(nn.Module):
+         def __init__(self, vocab_size=100, dim=64):
+             super().__init__()
+             self.emb = nn.Embedding(vocab_size, dim)
+             self.linear = nn.Linear(dim, vocab_size)
+             self.linear.weight = self.emb.weight  # weight tying
+
+         def forward(self, input_ids, targets=None):
+             import torch.nn.functional as F
+
+             h = self.emb(input_ids)
+             logits = self.linear(h)
+             loss = None
+             if targets is not None:
+                 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+             return logits, loss
+
+         def count_parameters(self, trainable_only=True):
+             return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+     model = TinyModel()
+     print(f"  Model parameters: {model.count_parameters():,}")
+
+     # Generate dummy data
+     def dummy_dataloader(num_batches=100, batch_size=4, seq_len=32, vocab=100):
+         for _ in range(num_batches):
+             ids = torch.randint(0, vocab, (batch_size, seq_len + 1))
+             yield {
+                 "input_ids": ids[:, :-1],
+                 "targets": ids[:, 1:],
+             }
+
+     # Config (a very short run)
+     config = TrainConfig(
+         total_steps=20,
+         warmup_steps=5,
+         micro_batch_size=4,
+         gradient_accumulation_steps=2,
+         log_interval=5,
+         eval_interval=10,
+         checkpoint_interval=10,
+         checkpoint_dir="./test_checkpoints",
+         use_wandb=False,
+         dtype="float32",  # CPU test
+     )
+
+     # Scheduler test
+     print("\n[Test 2] LR scheduler check")
+     scheduler = CosineWarmupScheduler(config)
+     test_steps = [0, 2, 5, 10, 15, 20]
+     for s in test_steps:
+         lr = scheduler.get_lr(s)
+         phase = "warmup" if s < config.warmup_steps else "cosine"
+         print(f"  Step {s:3d}: LR = {lr:.6f} ({phase})")
+
+     # Optimizer test
+     print("\n[Test 3] Optimizer creation check")
+     optimizer = create_optimizer(model, config)
+     print(f"  Number of parameter groups: {len(optimizer.param_groups)}")
+     for i, pg in enumerate(optimizer.param_groups):
+         n_params = sum(p.numel() for p in pg["params"])
+         print(f"  Group {i}: {n_params:,} params, weight_decay={pg['weight_decay']}")
+
+     # Training-loop test (short version)
+     print("\n[Test 4] Running the training loop (20 steps)")
+     train_dl = list(dummy_dataloader(num_batches=200))
+
+     # Simulate a DataLoader
+     class SimpleLoader:
+         def __init__(self, data):
+             self.data = data
+
+         def __iter__(self):
+             return iter(self.data)
+
+     trainer = Trainer(
+         model=model,
+         train_dataloader=SimpleLoader(train_dl),
+         val_dataloader=SimpleLoader(train_dl[:20]),
+         config=config,
+         seq_len=32,
+     )
+     trainer.train()
+
+     # Cleanup
+     import shutil
+     if os.path.exists("./test_checkpoints"):
+         shutil.rmtree("./test_checkpoints")
+
+     print("\n" + "=" * 70)
+     print("โœ… Trainer verification complete!")
+     print()
+     print("To run a real training job:")
+     print("  trainer = start_training(model, train_dl, val_dl)")
+     print("=" * 70)
llm_lab/__init__.py ADDED
@@ -0,0 +1,30 @@
+ """
+ LLM-1B-Lab: 1B Parameter LLaMA-style Transformer (from scratch)
+ ================================================================
+ An educational implementation for deep-learning beginners.
+ Every component carries detailed comments explaining *why* it is built this way.
+
+ Module layout:
+     llm_lab.config      โ€” all settings (ModelConfig, DataConfig, TrainConfig, EvalConfig)
+     llm_lab.model       โ€” model architecture (RMSNorm, RoPE, GQA, SwiGLU, Transformer)
+     llm_lab.data        โ€” data pipeline (tokenizer, streaming, packing)
+     llm_lab.training    โ€” training loop (Trainer, scheduler, checkpoints)
+     llm_lab.evaluation  โ€” evaluation (perplexity, generation, scaling laws, attention)
+     llm_lab.utils       โ€” shared utilities (device detection, seeding)
+
+ Quick Start:
+     from llm_lab.config import ModelConfig, DataConfig, TrainConfig
+     from llm_lab.model import LLMModel
+     from llm_lab.data import setup_data_pipeline
+     from llm_lab.training import start_training
+     from llm_lab.evaluation import run_evaluation
+ """
+
+ __version__ = "0.1.0"
+
+ from .config import ModelConfig, DataConfig, TrainConfig, EvalConfig
+ from .model import LLMModel
+ from .data import setup_data_pipeline
+ from .training import start_training
+ from .evaluation import run_evaluation
+ from .utils import get_device, auto_configure
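Read end to end, those Quick Start imports wire together as below — a sketch assuming the `debug_10m` preset and the "pretrained" tokenizer path (both defined later in this diff), not a verbatim notebook cell:

```python
from llm_lab.config import DataConfig, ModelConfig, TrainConfig
from llm_lab.model import LLMModel
from llm_lab.data import setup_data_pipeline
from llm_lab.training import start_training

# Tiny preset so the full loop can be smoke-tested before the 1.1B run.
model = LLMModel(ModelConfig.debug_10m())

# Match the data sequence length to the debug model's max_seq_len.
data_cfg = DataConfig(max_seq_len=512)
tok, train_dl, val_dl = setup_data_pipeline("pretrained", config=data_cfg)

config = TrainConfig(total_steps=50, warmup_steps=5, use_wandb=False)
trainer = start_training(model, train_dl, val_dl, config=config, seq_len=512)
```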
llm_lab/config/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Config module โ€” every hyperparameter is managed in one place."""
+ from .model_config import ModelConfig
+ from .data_config import DataConfig
+ from .train_config import TrainConfig
+ from .eval_config import EvalConfig
+
+ __all__ = ["ModelConfig", "DataConfig", "TrainConfig", "EvalConfig"]
llm_lab/config/data_config.py ADDED
@@ -0,0 +1,41 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class DataConfig:
+     """Data-pipeline settings.
+
+     Defaults chosen for the constraints of a Colab Pro+ environment:
+     - Streaming mode to minimize disk usage
+     - Sequence packing to maximize GPU utilization with no padding
+     - On-the-fly preprocessing to save memory
+     """
+     # โ”€โ”€ Dataset โ”€โ”€
+     dataset_name: str = "HuggingFaceFW/fineweb-edu"
+     dataset_subset: str = "sample-10BT"  # 10B-token sample
+     dataset_split: str = "train"
+     text_column: str = "text"  # column that holds the text
+
+     # โ”€โ”€ Tokenizer โ”€โ”€
+     tokenizer_type: str = "sentencepiece"  # "sentencepiece" or "hf"
+     # Path to an already-trained tokenizer (a new one is trained if absent)
+     tokenizer_path: Optional[str] = None
+     vocab_size: int = 32_000
+
+     # โ”€โ”€ Sequences โ”€โ”€
+     max_seq_len: int = 2048
+     # Whether to insert a document-separator token (marks boundaries when packing)
+     use_eos_separator: bool = True
+
+     # โ”€โ”€ Batching โ”€โ”€
+     batch_size: int = 4        # micro batch (per GPU)
+     num_workers: int = 2       # number of DataLoader workers
+     prefetch_factor: int = 4   # batches to prepare ahead of time
+
+     # โ”€โ”€ Tokenizer training (when training a new one) โ”€โ”€
+     tokenizer_train_samples: int = 50_000  # number of documents to train on
+     tokenizer_save_dir: str = "./tokenizer"
+
+     # โ”€โ”€ Validation data โ”€โ”€
+     val_ratio: float = 0.001  # hold out 0.1% of the data for validation
llm_lab/config/eval_config.py ADDED
@@ -0,0 +1,20 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class EvalConfig:
+     """Evaluation parameters."""
+     # โ”€โ”€ Perplexity โ”€โ”€
+     eval_batch_size: int = 4
+     max_eval_batches: int = 100  # cap on the number of evaluation batches
+
+     # โ”€โ”€ Generation โ”€โ”€
+     max_new_tokens: int = 200
+     temperature: float = 0.8
+     top_k: int = 50
+     top_p: float = 0.9
+     num_samples: int = 3  # generations per prompt
+
+     # โ”€โ”€ Output โ”€โ”€
+     save_dir: str = "./eval_results"
+     plot_dpi: int = 150
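The generation knobs above (temperature, top_k, top_p) combine in the standard filtering order: sharpen with temperature, keep the k most likely tokens, then truncate the nucleus tail. A minimal sketch of one decoding step under those defaults — the function name and shapes are illustrative, not the repo's generation code:

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_k=50, top_p=0.9):
    """One decoding step: temperature -> top-k -> top-p (nucleus) -> sample.
    `logits` is the last-position logit vector of shape (vocab_size,)."""
    logits = logits / temperature                 # sharpen/flatten the distribution
    # Top-k: keep only the k highest logits.
    topk_vals, topk_idx = torch.topk(logits, top_k)
    probs = F.softmax(topk_vals, dim=-1)
    # Top-p: drop the tail once cumulative probability reaches p.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = cumulative - sorted_probs < top_p      # the first token always survives
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return topk_idx[sorted_idx[choice]].item()

# Usage: next_id = sample_next_token(logits[0, -1])
```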
llm_lab/config/model_config.py ADDED
@@ -0,0 +1,53 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class ModelConfig:
+     """Keeps every model hyperparameter in a single dataclass.
+
+     Presets by scale:
+     - debug: ~10M  (pipeline verification)
+     - small: ~100M (intermediate check)
+     - base:  ~1.1B (the final target)
+     """
+     vocab_size: int = 32_000
+     hidden_dim: int = 2048        # d_model: the model's base width
+     num_layers: int = 22          # number of Transformer blocks
+     num_heads: int = 16           # number of query heads
+     num_kv_heads: int = 4         # number of key/value heads (GQA)
+     intermediate_dim: int = 5632  # FFN inner width (โ‰ˆ 2.75 ร— hidden_dim)
+     max_seq_len: int = 2048       # maximum sequence length
+     dropout: float = 0.0          # usually 0 for pretraining
+     rope_theta: float = 10000.0   # RoPE frequency base
+     norm_eps: float = 1e-6        # RMSNorm epsilon
+
+     @property
+     def head_dim(self) -> int:
+         """Width of each attention head."""
+         return self.hidden_dim // self.num_heads
+
+     @property
+     def num_kv_groups(self) -> int:
+         """Number of query heads served by one KV head in GQA."""
+         return self.num_heads // self.num_kv_heads
+
+     @classmethod
+     def debug_10m(cls) -> "ModelConfig":
+         """~10M parameters - for fast debugging."""
+         return cls(
+             hidden_dim=256, num_layers=6, num_heads=8,
+             num_kv_heads=4, intermediate_dim=704, max_seq_len=512,
+         )
+
+     @classmethod
+     def small_100m(cls) -> "ModelConfig":
+         """~100M parameters - intermediate check."""
+         return cls(
+             hidden_dim=768, num_layers=12, num_heads=12,
+             num_kv_heads=4, intermediate_dim=2048, max_seq_len=1024,
+         )
+
+     @classmethod
+     def base_1b(cls) -> "ModelConfig":
+         """~1.1B parameters - the final training target."""
+         return cls()  # the defaults are the 1B setting
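As a sanity check on the "~1.1B" label, the defaults can be tallied by hand. A sketch, assuming tied input/output embeddings, a LLaMA-style SwiGLU FFN with three projections, and GQA-shaped K/V projections (all consistent with the module names elsewhere in this commit):

```python
def approx_params(vocab=32_000, d=2048, layers=22, heads=16, kv_heads=4, ffn=5632):
    head_dim = d // heads                      # 128
    emb = vocab * d                            # tied embedding, counted once
    attn = d * d                               # q_proj
    attn += 2 * d * (kv_heads * head_dim)      # k_proj + v_proj (GQA: 4 KV heads)
    attn += d * d                              # o_proj
    swiglu = 3 * d * ffn                       # gate, up, down projections
    return emb + layers * (attn + swiglu)

print(f"{approx_params() / 1e9:.2f}B")  # -> ~1.06B, i.e. "~1.1B" once norms etc. are added
```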
llm_lab/config/train_config.py ADDED
@@ -0,0 +1,114 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+
+
+ @dataclass
+ class TrainConfig:
+     """Training hyperparameters + infrastructure settings.
+
+     Defaults tuned for Colab Pro+ (A100 40GB).
+     Every value carries a note on *why* it is set that way.
+     """
+
+     # โ”€โ”€ Optimization โ”€โ”€
+     learning_rate: float = 3e-4
+     """Peak LR. 3e-4 is the standard for a ~1B model.
+     The GPT-3 paper lists optimal LR by model size:
+         125M โ†’ 6e-4, 350M โ†’ 3e-4, 1.3B โ†’ 2e-4
+     Our model (1.1B) starts at 3e-4; drop to 2e-4 if training is unstable."""
+
+     min_learning_rate: float = 3e-5
+     """Floor of the cosine decay, typically 10% of the peak.
+     Too low and training stalls late on; too high and convergence gets unstable."""
+
+     weight_decay: float = 0.1
+     """AdamW L2 regularization. 0.1 is the LLM standard.
+     Not applied to embeddings or biases (by convention)."""
+
+     beta1: float = 0.9
+     beta2: float = 0.95
+     """Adam momentum coefficients. ฮฒ2=0.95 is more stable than ฮฒ2=0.999 for LLM training.
+     With large batches and long runs, too large a ฮฒ2 makes adaptation sluggish."""
+
+     adam_eps: float = 1e-8
+     grad_clip: float = 1.0
+     """Gradient clipping: rescale whenever the gradient norm exceeds 1.0.
+     Guards against gradient spikes early in training or on noisy data."""
+
+     # โ”€โ”€ Scheduling โ”€โ”€
+     warmup_steps: int = 2000
+     """Warmup: LR rises linearly 0 โ†’ peak over the first 2000 steps.
+     Why is it needed?
+     - the initial weights are random โ†’ a large LR makes unstable updates
+     - starting small lets the model find its 'direction' before training in earnest
+     - 2000 is roughly 10% of the run, a reasonable rule of thumb."""
+
+     total_steps: int = 20_000
+     """Total number of training steps.
+     10B tokens / (128 effective batch ร— 2048 seq_len) โ‰ˆ 38,000 steps would
+     consume the full sample; ~20,000 effective (post-accumulation) steps is
+     the practical budget here."""
+
+     # โ”€โ”€ Batching โ”€โ”€
+     micro_batch_size: int = 4
+     """Batch size placed on the GPU at once.
+     4 is the safe ceiling for a 1B model in bf16 on an A100 40GB."""
+
+     gradient_accumulation_steps: int = 32
+     """Number of gradient accumulations. Effective batch = 4 ร— 32 = 128.
+     Why is a large batch good?
+     - the gradient estimate is more stable (less noise)
+     - LLM training typically uses an effective batch of 128~512
+     - when memory runs short, raise this and lower micro_batch."""
+
+     # โ”€โ”€ Mixed precision โ”€โ”€
+     dtype: str = "bfloat16"
+     """bfloat16: supported on A100, numerically more stable than fp16.
+     Its exponent bits match fp32 โ†’ little overflow/underflow risk.
+     Switch to 'float16' when falling back to T4/V100."""
+
+     # โ”€โ”€ Checkpoints โ”€โ”€
+     checkpoint_dir: str = "/content/drive/MyDrive/llm-1b-lab/checkpoints"
+     """Google Drive path, so checkpoints survive Colab session expiry."""
+
+     checkpoint_interval: int = 500
+     """Save a checkpoint every 500 steps.
+     Roughly every ~30 minutes on an A100. Too frequent means I/O overhead,
+     too rare means a big loss when the session expires."""
+
+     max_checkpoints: int = 3
+     """Rolling retention count; the oldest is deleted first.
+     One checkpoint โ‰ˆ 8-10GB โ†’ three of them โ‰ˆ 30GB."""
+
+     # โ”€โ”€ Logging โ”€โ”€
+     log_interval: int = 10
+     """Console + wandb logging every 10 steps."""
+
+     eval_interval: int = 500
+     """Measure validation loss every 500 steps."""
+
+     eval_steps: int = 20
+     """Number of batches per validation pass. 20 ร— 4 ร— 2048 โ‰ˆ 160K tokens."""
+
+     # โ”€โ”€ wandb โ”€โ”€
+     wandb_project: str = "llm-1b-lab"
+     wandb_run_name: Optional[str] = None
+     use_wandb: bool = True
+
+     # โ”€โ”€ Reproducibility โ”€โ”€
+     seed: int = 42
+
+     @property
+     def effective_batch_size(self) -> int:
+         return self.micro_batch_size * self.gradient_accumulation_steps
+
+     @property
+     def tokens_per_step(self) -> int:
+         """Tokens processed per optimizer step."""
+         # max_seq_len lives in ModelConfig; 2048 is assumed here
+         return self.effective_batch_size * 2048
+
+     @property
+     def torch_dtype(self) -> torch.dtype:
+         return {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[self.dtype]
llm_lab/data/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """Data-pipeline module โ€” tokenizer, streaming, sequence packing."""
+ from .tokenizer import Tokenizer
+ from .dataset import PackedStreamingDataset, ValidationDataset
+ from .pipeline import create_train_dataloader, train_tokenizer_from_dataset, setup_data_pipeline
+ from .diagnostics import DataPipelineDiagnostics
+
+ __all__ = [
+     "Tokenizer", "PackedStreamingDataset", "ValidationDataset",
+     "create_train_dataloader", "train_tokenizer_from_dataset",
+     "setup_data_pipeline", "DataPipelineDiagnostics",
+ ]
llm_lab/data/dataset.py ADDED
@@ -0,0 +1,218 @@
+ """Streaming datasets โ€” sequence packing and the validation dataset."""
+
+ from typing import Iterator, List, Dict, Optional
+
+ import torch
+ from torch.utils.data import IterableDataset, DataLoader
+
+ from llm_lab.config import DataConfig
+ from .tokenizer import Tokenizer
+
+
+ class PackedStreamingDataset(IterableDataset):
+     """Streaming + sequence-packing dataset.
+
+     Why sequence packing?
+     - The usual approach: cut each document to max_seq_len and pad โ†’ wasted GPU
+     - Sequence packing: concatenate documents until max_seq_len is exactly full โ†’ 100% utilization
+
+     How it works:
+         doc1 (300 tokens) + doc2 (1500 tokens) + doc3 (248 tokens) = 2048 tokens
+         โ†’ [doc1][EOS][doc2][EOS][doc3][EOS][...an exact fit, no padding]
+
+     Why streaming?
+     - The FineWeb-Edu 10B sample is tens of GB even compressed
+     - It cannot be fully downloaded within Colab's disk limit (~200GB)
+     - Streaming: read from the network only as much as needed
+
+     Caveats for training:
+     - An EOS token is inserted at document boundaries so the model can see where documents end
+     - Even without masking attention across document boundaries, EOS acts as a natural separator
+     """
+
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         split: str = "train",
+         seed: int = 42,
+     ):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.config = config
+         self.split = split
+         self.seed = seed
+         self.max_seq_len = config.max_seq_len
+
+     def _load_dataset(self):
+         """Loads the HuggingFace dataset in streaming mode."""
+         from datasets import load_dataset
+
+         ds = load_dataset(
+             self.config.dataset_name,
+             name=self.config.dataset_subset,
+             split=self.config.dataset_split,
+             streaming=True,  # the key: streaming mode
+             trust_remote_code=True,
+         )
+
+         # Shuffle (streaming supports only approximate, buffer-based shuffling)
+         ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
+
+         return ds
+
+     def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
+         """Tokenizes documents and packs them into sequences.
+
+         Yields:
+             {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}
+
+         targets = input_ids shifted by one position:
+             input_ids: [A, B, C, D, E]
+             targets:   [B, C, D, E, F]
+             โ†’ the model sees A and predicts B, sees B and predicts C, ...
+         """
+         buffer: List[int] = []  # token buffer
+
+         for example in dataset:
+             text = example[self.config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             # Tokenize (without special tokens)
+             token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+
+             if not token_ids:
+                 continue
+
+             # Append an EOS token (marks the document boundary)
+             if self.config.use_eos_separator:
+                 token_ids.append(self.tokenizer.eos_id)
+
+             # Add to the buffer
+             buffer.extend(token_ids)
+
+             # Emit sequences once the buffer has filled up enough;
+             # the +1 is for building targets (input + the next token)
+             while len(buffer) >= self.max_seq_len + 1:
+                 # take max_seq_len + 1 tokens
+                 chunk = buffer[: self.max_seq_len + 1]
+                 buffer = buffer[self.max_seq_len + 1 :]
+
+                 # input_ids: first token through the second-to-last
+                 input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
+                 # targets: second token through the last (shifted by one)
+                 targets = torch.tensor(chunk[1:], dtype=torch.long)
+
+                 yield {"input_ids": input_ids, "targets": targets}
+
+     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
+         """The iterator invoked by the DataLoader.
+
+         Multi-worker support:
+         - each worker processes a stream shuffled with a different seed
+         - minimizes data duplication across workers
+         """
+         worker_info = torch.utils.data.get_worker_info()
+
+         if worker_info is not None:
+             # multiple workers: give each a distinct seed
+             worker_seed = self.seed + worker_info.id
+         else:
+             worker_seed = self.seed
+
+         # Load the dataset with the per-worker seed
+         self.seed = worker_seed
+         dataset = self._load_dataset()
+
+         return self._tokenize_and_pack(dataset)
+
+
+ class ValidationDataset:
+     """The validation dataset.
+
+     Pre-fetches a fixed amount from the streaming dataset and keeps it in memory.
+     Evaluating on identical data every time is what makes comparisons meaningful.
+     """
+
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         num_samples: int = 100,
+         seed: int = 9999,
+     ):
+         self.tokenizer = tokenizer
+         self.config = config
+         self.num_samples = num_samples
+         self.samples: List[Dict[str, torch.Tensor]] = []
+
+         self._prepare(seed)
+
+     def _prepare(self, seed: int):
+         """Pre-extracts validation samples from the dataset."""
+         from datasets import load_dataset
+
+         print(f"[Validation] Preparing {self.num_samples} validation samples...")
+
+         ds = load_dataset(
+             self.config.dataset_name,
+             name=self.config.dataset_subset,
+             split=self.config.dataset_split,
+             streaming=True,
+             trust_remote_code=True,
+         )
+         # Shuffle with a different seed so the samples are unlikely to overlap the training stream
+         ds = ds.shuffle(seed=seed, buffer_size=5_000)
+
+         buffer: List[int] = []
+         count = 0
+
+         for example in ds:
+             if count >= self.num_samples:
+                 break
+
+             text = example[self.config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             token_ids = self.tokenizer.encode(text, add_special_tokens=False)
+             if not token_ids:
+                 continue
+
+             token_ids.append(self.tokenizer.eos_id)
+             buffer.extend(token_ids)
+
+             while len(buffer) >= self.config.max_seq_len + 1 and count < self.num_samples:
+                 chunk = buffer[: self.config.max_seq_len + 1]
+                 buffer = buffer[self.config.max_seq_len + 1 :]
+
+                 self.samples.append({
+                     "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
+                     "targets": torch.tensor(chunk[1:], dtype=torch.long),
+                 })
+                 count += 1
+
+         print(f"[Validation] {len(self.samples)} samples ready")
+
+     def get_dataloader(self, batch_size: int) -> DataLoader:
+         """Returns the validation DataLoader."""
+         return DataLoader(
+             self.samples,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=0,
+             collate_fn=_collate_fn,
+         )
+
+
+ def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+     """Stacks the samples in a batch into single tensors.
+
+     Thanks to sequence packing every sample has the same length (max_seq_len),
+     so no extra padding is needed.
+     """
+     return {
+         "input_ids": torch.stack([s["input_ids"] for s in batch]),
+         "targets": torch.stack([s["targets"] for s in batch]),
+     }
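The shift relation the docstring describes can be checked on a toy buffer. A sketch (independent of the classes above) showing that packing max_seq_len + 1 tokens yields aligned input/target pairs:

```python
import torch

max_seq_len = 8
buffer = list(range(20))               # stand-in for tokenized, EOS-joined documents

chunk = buffer[: max_seq_len + 1]      # take max_seq_len + 1 tokens
input_ids = torch.tensor(chunk[:-1])   # [0..7]
targets = torch.tensor(chunk[1:])      # [1..8]

# The invariant inspect_batch() later verifies: targets[i] == input_ids[i+1]
assert torch.equal(input_ids[1:], targets[:-1])
```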
llm_lab/data/diagnostics.py ADDED
@@ -0,0 +1,153 @@
+ """Diagnostics for the data pipeline."""
+
+ import time
+ from typing import Dict
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from llm_lab.config import DataConfig
+ from .tokenizer import Tokenizer
+
+
+ class DataPipelineDiagnostics:
+     """Diagnoses the performance and quality of the data pipeline.
+
+     Things to check before training, without fail:
+     1) Tokenizer quality: average tokens/document, unknown-token ratio
+     2) Packing efficiency: real-token ratio vs padding ratio
+     3) Throughput: tokens/sec (spot data-loading bottlenecks)
+     4) Batch shape: shape and dtype correctness
+     """
+
+     @staticmethod
+     def check_tokenizer_quality(
+         tokenizer: Tokenizer,
+         config: DataConfig,
+         num_samples: int = 1000,
+     ):
+         """Diagnoses tokenizer quality."""
+         from datasets import load_dataset
+
+         print("\n" + "=" * 60)
+         print("๐Ÿ“Š Tokenizer quality check")
+         print("=" * 60)
+
+         ds = load_dataset(
+             config.dataset_name,
+             name=config.dataset_subset,
+             split=config.dataset_split,
+             streaming=True,
+             trust_remote_code=True,
+         )
+
+         token_counts = []
+         char_counts = []
+         sample_count = 0
+
+         for example in ds:
+             if sample_count >= num_samples:
+                 break
+             text = example[config.text_column]
+             if not text or not text.strip():
+                 continue
+
+             tokens = tokenizer.encode(text)
+             token_counts.append(len(tokens))
+             char_counts.append(len(text))
+             sample_count += 1
+
+         avg_tokens = sum(token_counts) / len(token_counts)
+         avg_chars = sum(char_counts) / len(char_counts)
+         compression_ratio = avg_chars / avg_tokens  # chars per token
+
+         print(f"  Documents analyzed: {len(token_counts):,}")
+         print(f"  Avg tokens/document: {avg_tokens:.1f}")
+         print(f"  Avg chars/document: {avg_chars:.1f}")
+         print(f"  Compression ratio (chars/token): {compression_ratio:.2f}")
+         print("    โ†’ 3.5~4.5 is normal for English")
+         print(f"  Min tokens: {min(token_counts)}, max: {max(token_counts)}")
+
+         # Encode/decode round-trip test
+         test_text = "The quick brown fox jumps over the lazy dog."
+         encoded = tokenizer.encode(test_text)
+         decoded = tokenizer.decode(encoded)
+         roundtrip_ok = test_text.strip() in decoded.strip()
+         print(f"\n  Round-trip test: {'โœ… passed' if roundtrip_ok else 'โŒ failed'}")
+         print(f"    original: {test_text}")
+         print(f"    encoded:  {encoded[:20]}{'...' if len(encoded) > 20 else ''}")
+         print(f"    decoded:  {decoded}")
+
+     @staticmethod
+     def benchmark_throughput(
+         dataloader: DataLoader,
+         num_batches: int = 50,
+         seq_len: int = 2048,
+     ):
+         """Measures data-loading throughput.
+
+         The key diagnostic for whether data loading bottlenecks GPU training speed.
+         Goal: data loading must outpace the GPU compute (loading โ‰  bottleneck).
+         """
+         print("\n" + "=" * 60)
+         print("โšก Data-loading throughput benchmark")
+         print("=" * 60)
+
+         total_tokens = 0
+         start_time = time.time()
+
+         for i, batch in enumerate(dataloader):
+             if i >= num_batches:
+                 break
+             batch_tokens = batch["input_ids"].numel()
+             total_tokens += batch_tokens
+
+             if (i + 1) % 10 == 0:
+                 elapsed = time.time() - start_time
+                 tps = total_tokens / elapsed
+                 print(f"  Batch {i+1}: {tps:,.0f} tokens/sec")
+
+         elapsed = time.time() - start_time
+         tps = total_tokens / elapsed
+
+         print(f"\n  Total batches: {num_batches}")
+         print(f"  Total tokens: {total_tokens:,}")
+         print(f"  Elapsed: {elapsed:.2f}s")
+         print(f"  Average throughput: {tps:,.0f} tokens/sec")
+         print("\n  ๐Ÿ’ก Against an A100 training throughput of ~50-80K tokens/sec:")
+         if tps > 80_000:
+             print("    โœ… data loading is not the bottleneck")
+         elif tps > 30_000:
+             print("    โš ๏ธ borderline - consider raising num_workers")
+         else:
+             print("    โŒ data loading IS the bottleneck! tune num_workers/prefetch")
+
+     @staticmethod
+     def inspect_batch(batch: Dict[str, torch.Tensor], tokenizer: Tokenizer):
+         """Inspects a single batch in detail."""
+         print("\n" + "=" * 60)
+         print("๐Ÿ” Batch inspection")
+         print("=" * 60)
+
+         input_ids = batch["input_ids"]
+         targets = batch["targets"]
+
+         print(f"  input_ids shape: {input_ids.shape}")
+         print(f"  targets shape: {targets.shape}")
+         print(f"  dtype: {input_ids.dtype}")
+         print(f"  value range: [{input_ids.min().item()}, {input_ids.max().item()}]")
+
+         # Verify the shift relation: targets[i] == input_ids[i+1]
+         shift_correct = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
+         print(f"  Shift consistency: {shift_correct*100:.1f}% (should be 100%)")
+
+         # EOS-token distribution (document boundaries)
+         eos_count = (input_ids == tokenizer.eos_id).sum().item()
+         total_tokens = input_ids.numel()
+         print(f"  EOS tokens: {eos_count} / {total_tokens} ({eos_count/total_tokens*100:.2f}%)")
+
+         # Decoded preview of the first sample
+         first_sample = input_ids[0][:100].tolist()
+         decoded_preview = tokenizer.decode(first_sample)
+         print("\n  First sample decoded (first 100 tokens):")
+         print(f"    {decoded_preview[:300]}...")
llm_lab/data/pipeline.py ADDED
@@ -0,0 +1,156 @@
+ """Data-pipeline glue โ€” DataLoader creation, tokenizer training, Quick Start."""
+
+ from typing import Optional
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from llm_lab.config import DataConfig
+ from .tokenizer import Tokenizer
+ from .dataset import PackedStreamingDataset, ValidationDataset, _collate_fn
+
+
+ def create_train_dataloader(
+     tokenizer: Tokenizer,
+     config: DataConfig,
+     seed: int = 42,
+ ) -> DataLoader:
+     """Creates the training DataLoader.
+
+     Returns:
+         An endlessly repeating streaming DataLoader
+
+     Usage:
+         dataloader = create_train_dataloader(tokenizer, config)
+         for step, batch in enumerate(dataloader):
+             input_ids = batch["input_ids"].to(device)  # (B, seq_len)
+             targets = batch["targets"].to(device)      # (B, seq_len)
+             logits, loss = model(input_ids, targets)
+             ...
+     """
+     dataset = PackedStreamingDataset(
+         tokenizer=tokenizer,
+         config=config,
+         split="train",
+         seed=seed,
+     )
+
+     dataloader = DataLoader(
+         dataset,
+         batch_size=config.batch_size,
+         num_workers=config.num_workers,
+         prefetch_factor=config.prefetch_factor if config.num_workers > 0 else None,
+         pin_memory=True,  # speeds up host-to-GPU transfer
+         collate_fn=_collate_fn,
+     )
+
+     return dataloader
+
+
+ def train_tokenizer_from_dataset(config: DataConfig) -> Tokenizer:
+     """Trains a BPE tokenizer on the dataset.
+
+     There is no need to use the full data; 50K documents are plenty,
+     since the tokenizer vocab only has to reflect the overall statistics.
+     """
+     from datasets import load_dataset
+
+     print(f"[Train Tokenizer] Training a tokenizer on {config.dataset_name}")
+     print(f"[Train Tokenizer] Training documents: {config.tokenizer_train_samples:,}")
+
+     # Build a text iterator
+     ds = load_dataset(
+         config.dataset_name,
+         name=config.dataset_subset,
+         split=config.dataset_split,
+         streaming=True,
+         trust_remote_code=True,
+     )
+
+     def text_iterator():
+         count = 0
+         for example in ds:
+             if count >= config.tokenizer_train_samples:
+                 break
+             text = example[config.text_column]
+             if text and text.strip():
+                 yield text
+                 count += 1
+                 if count % 10_000 == 0:
+                     print(f"  ... {count:,} documents processed")
+
+     # Train the tokenizer
+     tokenizer = Tokenizer(config)
+     tokenizer.train_bpe(text_iterator(), save_dir=config.tokenizer_save_dir)
+
+     return tokenizer
+
+
+ def setup_data_pipeline(
+     tokenizer_mode: str = "train_new",
+     tokenizer_path: Optional[str] = None,
+     config: Optional[DataConfig] = None,
+ ) -> tuple:
+     """Sets up the whole data pipeline in one call.
+
+     Args:
+         tokenizer_mode:
+             "train_new"    - train a fresh BPE tokenizer
+             "load_trained" - load a previously trained tokenizer
+             "pretrained"   - use a pretrained HuggingFace tokenizer
+         tokenizer_path:
+             "train_new"    โ†’ save path (default: ./tokenizer)
+             "load_trained" โ†’ path of the saved tokenizer
+             "pretrained"   โ†’ HF model name (default: mistralai/Mistral-7B-v0.1)
+
+     Returns:
+         (tokenizer, train_dataloader, val_dataloader)
+
+     Examples (Colab):
+         # Option 1: train a new tokenizer
+         tok, train_dl, val_dl = setup_data_pipeline("train_new")
+
+         # Option 2: load an existing tokenizer
+         tok, train_dl, val_dl = setup_data_pipeline("load_trained", "./tokenizer")
+
+         # Option 3: pretrained tokenizer (simplest)
+         tok, train_dl, val_dl = setup_data_pipeline("pretrained")
+     """
+     config = config or DataConfig()
+
+     print("=" * 60)
+     print("๐Ÿš€ Data-pipeline setup")
+     print("=" * 60)
+
+     # โ”€โ”€ Step 1: Tokenizer โ”€โ”€
+     tokenizer = Tokenizer(config)
+
+     if tokenizer_mode == "train_new":
+         tokenizer = train_tokenizer_from_dataset(config)
+     elif tokenizer_mode == "load_trained":
+         path = tokenizer_path or config.tokenizer_save_dir
+         tokenizer.load_trained_hf(path)
+     elif tokenizer_mode == "pretrained":
+         name = tokenizer_path or "mistralai/Mistral-7B-v0.1"
+         tokenizer.load_pretrained_hf(name)
+     else:
+         raise ValueError(f"Unknown tokenizer_mode: {tokenizer_mode}")
+
+     # โ”€โ”€ Step 2: Training DataLoader โ”€โ”€
+     print("\n[DataLoader] Creating the training DataLoader...")
+     train_dataloader = create_train_dataloader(tokenizer, config)
+
+     # โ”€โ”€ Step 3: Validation DataLoader โ”€โ”€
+     print("\n[DataLoader] Creating the validation DataLoader...")
+     val_dataset = ValidationDataset(tokenizer, config, num_samples=100)
+     val_dataloader = val_dataset.get_dataloader(batch_size=config.batch_size)
+
+     print("\n" + "=" * 60)
+     print("โœ… Data pipeline ready!")
+     print(f"  Tokenizer vocab: {tokenizer.vocab_size:,}")
+     print(f"  Sequence length: {config.max_seq_len}")
+     print(f"  Batch size: {config.batch_size}")
+     print(f"  Tokens/batch: {config.batch_size * config.max_seq_len:,}")
+     print("=" * 60)
+
+     return tokenizer, train_dataloader, val_dataloader
llm_lab/data/tokenizer.py ADDED
@@ -0,0 +1,196 @@
+ """Tokenizer wrapper โ€” unifies SentencePiece / HuggingFace BPE."""
+
+ import os
+ import json
+ from typing import Optional, Iterator, List
+
+ from llm_lab.config import DataConfig
+
+
+ class Tokenizer:
+     """Unified tokenizer wrapper.
+
+     Three supported paths:
+     1) load an existing SentencePiece model
+     2) train a new tokenizer with the HuggingFace tokenizers library
+     3) load a pretrained HF tokenizer (e.g. the LLaMA tokenizer)
+
+     Why not implement it ourselves?
+     - Training a BPE tokenizer is large-scale text statistics, and has
+       little to do with understanding the model architecture.
+     - Still, the working principle (BPE merge rules) is worth understanding.
+
+     BPE (Byte Pair Encoding) in a nutshell:
+     1) split text into bytes/characters
+     2) repeatedly merge the most frequent adjacent pair
+     3) repeat until vocab_size is reached
+     โ†’ frequent words become one token; rare words split into several
+     """
+
+     def __init__(self, config: DataConfig):
+         self.config = config
+         self._tokenizer = None
+         self.vocab_size = config.vocab_size
+
+         # Special-token IDs (set after initialization)
+         self.bos_id: int = 1  # Beginning of Sequence
+         self.eos_id: int = 2  # End of Sequence
+         self.pad_id: int = 0  # Padding
+
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+     # Path 1: load a SentencePiece model
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+
+     def load_sentencepiece(self, model_path: str):
+         """Loads an existing SentencePiece model."""
+         import sentencepiece as spm
+
+         self._tokenizer = spm.SentencePieceProcessor()
+         self._tokenizer.Load(model_path)
+
+         self.vocab_size = self._tokenizer.GetPieceSize()
+         self.bos_id = self._tokenizer.bos_id()
+         self.eos_id = self._tokenizer.eos_id()
+         self.pad_id = self._tokenizer.pad_id()
+         self._encode_fn = self._tokenizer.Encode
+         self._decode_fn = self._tokenizer.Decode
+
+         print(f"[Tokenizer] SentencePiece loaded: vocab_size={self.vocab_size}")
+
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+     # Path 2: train BPE with HuggingFace tokenizers
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+
+     def train_bpe(self, text_iterator: Iterator[str], save_dir: Optional[str] = None):
+         """Trains a BPE tokenizer from scratch.
+
+         Args:
+             text_iterator: iterator that yields the training text
+             save_dir: where to save the result
+
+         Teaching points:
+         - larger vocab_size: common expressions become single tokens โ†’ shorter sequences
+         - smaller vocab_size: fewer embedding parameters, but longer sequences
+         - 32K is a good balance for English
+         """
+         from tokenizers import Tokenizer as HFTokenizer
+         from tokenizers.models import BPE
+         from tokenizers.trainers import BpeTrainer
+         from tokenizers.pre_tokenizers import ByteLevel
+         from tokenizers.processors import TemplateProcessing
+
+         print("[Tokenizer] Starting BPE tokenizer training...")
+
+         # Create the BPE model
+         tokenizer = HFTokenizer(BPE(unk_token="<unk>"))
+         tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
+
+         # Define the special tokens
+         special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
+
+         # Configure the trainer
+         trainer = BpeTrainer(
+             vocab_size=self.config.vocab_size,
+             special_tokens=special_tokens,
+             min_frequency=2,  # only merge pairs seen at least twice
+             show_progress=True,
+         )
+
+         # Run the training
+         tokenizer.train_from_iterator(text_iterator, trainer=trainer)
+
+         # Post-processing: add BOS/EOS automatically
+         tokenizer.post_processor = TemplateProcessing(
+             single="<s> $A </s>",
+             special_tokens=[("<s>", 1), ("</s>", 2)],
+         )
+
+         self._tokenizer = tokenizer
+         self.vocab_size = tokenizer.get_vocab_size()
+         self.pad_id = 0
+         self.bos_id = 1
+         self.eos_id = 2
+
+         self._encode_fn = lambda text: tokenizer.encode(text).ids
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         # Save
+         save_dir = save_dir or self.config.tokenizer_save_dir
+         os.makedirs(save_dir, exist_ok=True)
+         tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
+         # Save the metadata too
+         meta = {
+             "vocab_size": self.vocab_size,
+             "bos_id": self.bos_id,
+             "eos_id": self.eos_id,
+             "pad_id": self.pad_id,
+         }
+         with open(os.path.join(save_dir, "tokenizer_meta.json"), "w") as f:
+             json.dump(meta, f, indent=2)
+
+         print(f"[Tokenizer] Training complete: vocab_size={self.vocab_size}")
+         print(f"[Tokenizer] Saved to: {save_dir}")
+
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+     # Path 3: load a pretrained HF tokenizer
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+
+     def load_pretrained_hf(self, name_or_path: str = "meta-llama/Llama-2-7b-hf"):
+         """Loads a pretrained tokenizer from HuggingFace.
+
+         The simplest route. The LLaMA tokenizer is a 32K-vocab, BPE-based tokenizer.
+         Note: meta-llama models may require HF approval.
+         Alternative: mistralai/Mistral-7B-v0.1 (no approval needed)
+         """
+         from transformers import AutoTokenizer
+
+         print(f"[Tokenizer] Loading HF tokenizer: {name_or_path}")
+         tokenizer = AutoTokenizer.from_pretrained(name_or_path)
+
+         self._tokenizer = tokenizer
+         self.vocab_size = tokenizer.vocab_size
+         self.bos_id = tokenizer.bos_token_id or 1
+         self.eos_id = tokenizer.eos_token_id or 2
+         self.pad_id = tokenizer.pad_token_id or 0
+
+         self._encode_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
+
+     def load_trained_hf(self, path: str):
+         """Re-loads a tokenizer trained with train_bpe()."""
+         from tokenizers import Tokenizer as HFTokenizer
+
+         tokenizer = HFTokenizer.from_file(os.path.join(path, "tokenizer.json"))
+         with open(os.path.join(path, "tokenizer_meta.json"), "r") as f:
+             meta = json.load(f)
+
+         self._tokenizer = tokenizer
+         self.vocab_size = meta["vocab_size"]
+         self.bos_id = meta["bos_id"]
+         self.eos_id = meta["eos_id"]
+         self.pad_id = meta["pad_id"]
+
+         self._encode_fn = lambda text: tokenizer.encode(text).ids
+         self._decode_fn = lambda ids: tokenizer.decode(ids)
+
+         print(f"[Tokenizer] Loaded: vocab_size={self.vocab_size}")
+
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+     # Shared interface
+     # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
+
+     def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
+         """Text โ†’ list of token IDs."""
+         ids = self._encode_fn(text)
+         if add_special_tokens:
+             ids = [self.bos_id] + ids + [self.eos_id]
+         return ids
+
+     def decode(self, ids: List[int]) -> str:
+         """List of token IDs โ†’ text."""
+         return self._decode_fn(ids)
+
+     def __len__(self) -> int:
+         return self.vocab_size
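The "merge the most frequent adjacent pair" loop in the docstring is small enough to write out. A toy sketch of one BPE training round over a tiny corpus — illustrative only; the real training is delegated to the `tokenizers` library above:

```python
from collections import Counter

def most_frequent_pair(corpus):
    """corpus: list of symbol sequences, e.g. [["l","o","w"], ...]."""
    pairs = Counter()
    for word in corpus:
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += 1
    return pairs.most_common(1)[0][0]

def merge_pair(corpus, pair):
    merged = []
    for word in corpus:
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])  # fuse the pair into one symbol
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged.append(out)
    return merged

corpus = [list("lower"), list("lowest"), list("low")]
pair = most_frequent_pair(corpus)   # ('l','o'), tied with ('o','w') at 3; first seen wins
corpus = merge_pair(corpus, pair)   # [['lo','w','e','r'], ['lo','w','e','s','t'], ['lo','w']]
```

Repeating this until the vocabulary hits vocab_size is exactly why frequent words end up as single tokens while rare words stay split.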
llm_lab/evaluation/__init__.py ADDED
@@ -0,0 +1,21 @@
+ """Evaluation module โ€” perplexity, text generation, scaling laws, attention visualization."""
+
+ from .perplexity import PerplexityEvaluator
+ from .generation import GenerationEvaluator
+ from .scaling import ScalingAnalyzer
+ from .dynamics import TrainingDynamicsAnalyzer
+ from .attention_viz import AttentionVisualizer
+ from .full_evaluator import FullEvaluator
+ from .checklist import InsightChecklist
+ from .runner import run_evaluation
+
+ __all__ = [
+     "PerplexityEvaluator",
+     "GenerationEvaluator",
+     "ScalingAnalyzer",
+     "TrainingDynamicsAnalyzer",
+     "AttentionVisualizer",
+     "FullEvaluator",
+     "InsightChecklist",
+     "run_evaluation",
+ ]
llm_lab/evaluation/attention_viz.py ADDED
@@ -0,0 +1,176 @@
+ """Attention-pattern visualization."""
+
+ import math
+ from pathlib import Path
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     import matplotlib
+     matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     HAS_MATPLOTLIB = True
+ except ImportError:
+     HAS_MATPLOTLIB = False
+
+
+ class AttentionVisualizer:
+     """Visualizes attention patterns.
+
+     Teaching points:
+     - Causal mask: lower-triangular pattern (future tokens are invisible)
+     - Head specialization: some heads attend locally (neighbors), others globally (distant tokens)
+     - Syntactic patterns: high attention verbโ†’subject, pronounโ†’antecedent, etc.
+
+     Caution: storing every attention map of a 1B model runs out of memory!
+     โ†’ visualize only selected layers/heads.
+     """
+
+     def __init__(self, save_dir: str = "./eval_results"):
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+     @torch.no_grad()
+     def extract_attention(
+         self,
+         model: nn.Module,
+         input_ids: torch.Tensor,
+         layer_idx: int = 0,
+         device: torch.device = torch.device("cpu"),
+     ) -> torch.Tensor:
+         """Extracts the attention weights of one layer.
+
+         Temporarily patches the model's attention module so the
+         attention weights can be captured.
+
+         Returns:
+             attention_weights: (num_heads, seq_len, seq_len)
+         """
+         model.eval()
+         captured_attn = {}
+
+         # Capture attention weights via a hook
+         target_layer = model.layers[layer_idx].attention
+
+         # Replace scaled_dot_product_attention with a manual implementation
+         original_forward = target_layer.forward
+
+         def hooked_forward(x, mask=None, position_offset=0):
+             B, S, _ = x.shape
+             hd = target_layer.head_dim
+
+             q = target_layer.q_proj(x).view(B, S, target_layer.num_heads, hd).transpose(1, 2)
+             k = target_layer.k_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
+             v = target_layer.v_proj(x).view(B, S, target_layer.num_kv_heads, hd).transpose(1, 2)
+
+             q, k = target_layer.rope(q, k, position_offset)
+
+             if target_layer.num_kv_groups > 1:
+                 k = target_layer._repeat_kv(k)
+                 v = target_layer._repeat_kv(v)
+
+             # Manual attention computation (to expose the weights)
+             scale = 1.0 / math.sqrt(hd)
+             scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+             # Causal mask
+             causal = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
+             scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+             attn_weights = F.softmax(scores, dim=-1)
+             captured_attn["weights"] = attn_weights[0].cpu()  # first batch element only
+
+             out = torch.matmul(attn_weights, v)
+             out = out.transpose(1, 2).contiguous().view(B, S, -1)
+             return target_layer.o_proj(out)
+
+         # Install the hook
+         target_layer.forward = hooked_forward
+
+         try:
+             model(input_ids.to(device))
+         finally:
+             target_layer.forward = original_forward
+
+         return captured_attn.get("weights")  # (num_heads, S, S)
+
+     def plot_attention_heatmap(
+         self,
+         attn_weights: torch.Tensor,
+         tokens: List[str],
+         head_idx: int = 0,
+         save_path: Optional[str] = None,
+         title: str = "Attention Weights",
+     ):
+         """Draws an attention heatmap."""
+         if not HAS_MATPLOTLIB:
+             print("โš ๏ธ matplotlib is required")
+             return
+
+         weights = attn_weights[head_idx].numpy()
+         max_len = min(len(tokens), 50)  # show at most 50 tokens
+         weights = weights[:max_len, :max_len]
+         display_tokens = tokens[:max_len]
+
+         fig, ax = plt.subplots(figsize=(12, 10))
+         im = ax.imshow(weights, cmap="Blues", aspect="auto")
+
+         ax.set_xticks(range(max_len))
+         ax.set_yticks(range(max_len))
+         ax.set_xticklabels(display_tokens, rotation=90, fontsize=7)
+         ax.set_yticklabels(display_tokens, fontsize=7)
+
+         ax.set_xlabel("Key (attended to)", fontsize=11)
+         ax.set_ylabel("Query (attending from)", fontsize=11)
+         ax.set_title(f"{title} โ€” Head {head_idx}", fontsize=13, fontweight="bold")
+
+         fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+         plt.tight_layout()
+
+         save_path = save_path or str(self.save_dir / f"attention_head{head_idx}.png")
+         fig.savefig(save_path, dpi=150, bbox_inches="tight")
+         print(f"  ๐Ÿ“Š Attention heatmap saved: {save_path}")
+         plt.close(fig)
+
+     def plot_multi_head_summary(
+         self,
+         attn_weights: torch.Tensor,
+         num_heads_to_show: int = 8,
+         save_path: Optional[str] = None,
+     ):
+         """Compares the attention patterns of several heads side by side."""
+         if not HAS_MATPLOTLIB:
+             return
+
+         n_heads = min(attn_weights.shape[0], num_heads_to_show)
+         cols = 4
+         rows = math.ceil(n_heads / cols)
+
+         fig, axes = plt.subplots(rows, cols, figsize=(16, 4 * rows))
+         if rows == 1:
+             axes = axes.reshape(1, -1)
+
+         for idx in range(n_heads):
+             r, c = idx // cols, idx % cols
+             ax = axes[r, c]
+             w = attn_weights[idx].numpy()
+             ax.imshow(w, cmap="Blues", aspect="auto")
+             ax.set_title(f"Head {idx}", fontsize=10)
+             ax.set_xticks([])
+             ax.set_yticks([])
+
+         # Hide the empty subplots
+         for idx in range(n_heads, rows * cols):
+             r, c = idx // cols, idx % cols
+             axes[r, c].axis("off")
+
+         fig.suptitle("Attention Patterns by Head", fontsize=14, fontweight="bold")
+         plt.tight_layout()
+
+         save_path = save_path or str(self.save_dir / "attention_multi_head.png")
+         fig.savefig(save_path, dpi=150, bbox_inches="tight")
+         print(f"  ๐Ÿ“Š Multi-head summary saved: {save_path}")
+         plt.close(fig)
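A sketch of how the two plotting entry points chain together, assuming a trained LLMModel and tokenizer are already in scope. Decoding each ID individually to get per-token labels is an assumption about the tokenizer, not repo code:

```python
import torch

viz = AttentionVisualizer(save_dir="./eval_results")

prompt = "The quick brown fox jumps over the lazy dog."
ids = torch.tensor([tokenizer.encode(prompt)])           # (1, S)
tokens = [tokenizer.decode([i]) for i in ids[0].tolist()]

attn = viz.extract_attention(model, ids, layer_idx=0)    # (num_heads, S, S)
viz.plot_attention_heatmap(attn, tokens, head_idx=0)
viz.plot_multi_head_summary(attn, num_heads_to_show=8)
```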
llm_lab/evaluation/checklist.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """Validator for the training-insight checklist."""
+
+ from typing import Any, Dict, Optional
+
+
+ class InsightChecklist:
+     """Verifies the training-insight checklist defined in the PRD, automatically and manually.
+
+     Items that can be checked automatically are judged from metrics;
+     manual items are presented as questions.
+     """
+
+     @staticmethod
+     def run_checklist(
+         report: Dict[str, Any],
+         metrics_history: Optional[Dict[str, list]] = None,
+     ):
+         """Run the checklist."""
+         print("\n" + "=" * 70)
+         print("✅ Training Insight Checklist")
+         print("=" * 70)
+
+         checks = {
+             "passed": [],
+             "failed": [],
+             "manual": [],
+         }
+
+         # ── Automatic checks ──
+
+         # 1. Loss convergence
+         if report.get("perplexity", {}).get("loss", 99) < 4.0:
+             checks["passed"].append("Model loss converged below 4.0")
+         else:
+             checks["failed"].append("Model loss did not converge below 4.0")
+
+         # 2. Loss spikes
+         spikes = report.get("training_dynamics", {}).get("loss", {}).get("spikes", [])
+         if len(spikes) < 5:
+             checks["passed"].append(f"{len(spikes)} loss spikes (< 5)")
+         else:
+             checks["failed"].append(f"{len(spikes)} loss spikes (≥ 5, stability needs work)")
+
+         # 3. Per-position loss pattern
+         if report.get("position_losses"):
+             early = report["position_losses"]["early_avg"]
+             late = report["position_losses"]["late_avg"]
+             if early > late:
+                 checks["passed"].append("Per-position loss decreases (context is being used)")
+             else:
+                 checks["failed"].append("Unexpected per-position loss pattern (context unused?)")
+
+         # 4. Generation repetition rate
+         rep = report.get("generation", {}).get("avg_metrics", {}).get("repetition_rate", 1.0)
+         if rep < 0.3:
+             checks["passed"].append(f"Generation repetition rate {rep:.1%} (< 30%)")
+         else:
+             checks["failed"].append(f"Generation repetition rate {rep:.1%} (≥ 30%, tune temperature/top_p)")
+
+         # 5. Gradient clipping ratio
+         if metrics_history and metrics_history.get("grad_norm"):
+             gnorms = metrics_history["grad_norm"]
+             clip_rate = sum(1 for g in gnorms if g >= 0.99) / max(len(gnorms), 1)
+             if clip_rate < 0.3:
+                 checks["passed"].append(f"Gradient clipping ratio {clip_rate:.1%} (healthy)")
+             else:
+                 checks["failed"].append(f"Gradient clipping ratio {clip_rate:.1%} (too frequent)")
+
+         # ── Manual items ──
+         manual_items = [
+             "Can you explain the roles of Q, K, and V in self-attention?",
+             "Do you understand the math behind how RoPE encodes position?",
+             "Can you explain how GQA saves memory compared to MHA?",
+             "Do you understand how SwiGLU's gating differs from a ReLU FFN?",
+             "Did you observe why learning-rate warmup is needed?",
+             "Do you understand how gradient accumulation simulates a large batch?",
+             "Did you measure the memory/speed effect of mixed precision (bf16)?",
+             "Do you understand the memory-compute trade-off of activation checkpointing?",
+         ]
+         checks["manual"] = manual_items
+
+         # ── Output ──
+         total_auto = len(checks["passed"]) + len(checks["failed"])
+         passed_auto = len(checks["passed"])
+
+         print(f"\n  Automatic checks: {passed_auto}/{total_auto} passed")
+         for item in checks["passed"]:
+             print(f"    ✅ {item}")
+         for item in checks["failed"]:
+             print(f"    ❌ {item}")
+
+         print(f"\n  Manual checks ({len(manual_items)} items):")
+         for i, item in enumerate(manual_items, 1):
+             print(f"    {i}. [ ] {item}")
+
+         print(f"\n  Overall progress: {passed_auto}/{total_auto + len(manual_items)} "
+               f"(counting manual items)")
+
+         return checks
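
Since `run_checklist` only reads plain dictionaries, it can be exercised with a hand-built report before any real training run. A minimal sketch with made-up numbers chosen so all five automatic checks pass:

```python
from llm_lab.evaluation.checklist import InsightChecklist

toy_report = {
    "perplexity": {"loss": 3.2},
    "training_dynamics": {"loss": {"spikes": [{"step": 120, "loss": 5.1}]}},
    "position_losses": {"early_avg": 3.8, "late_avg": 3.1},
    "generation": {"avg_metrics": {"repetition_rate": 0.12}},
}
toy_history = {"grad_norm": [0.4, 0.6, 1.0, 0.5]}  # one clipped step out of four → 25%

checks = InsightChecklist.run_checklist(toy_report, toy_history)
assert len(checks["passed"]) == 5  # all five automatic checks pass on these numbers
```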
llm_lab/evaluation/dynamics.py ADDED
@@ -0,0 +1,242 @@
+ """Training dynamics analyzer."""
+
+ import math
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ try:
+     import matplotlib
+     matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     HAS_MATPLOTLIB = True
+ except ImportError:
+     HAS_MATPLOTLIB = False
+
+
+ class TrainingDynamicsAnalyzer:
+     """Analyzes and visualizes the metrics of a training run.
+
+     What it looks at:
+     - Loss curve: convergence pattern, spike detection
+     - LR schedule: confirm warmup + cosine decay
+     - Gradient norm: training stability, explosion/vanishing detection
+     - Throughput: tokens/sec stability, bottleneck detection
+     """
+
+     def __init__(self, save_dir: str = "./eval_results"):
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+     def analyze_metrics(self, metrics_history: Dict[str, list]) -> Dict[str, Any]:
+         """Analyze the training metrics.
+
+         Args:
+             metrics_history: the Trainer.metrics.history dictionary
+
+         Returns:
+             Analysis results
+         """
+         print("\n" + "=" * 70)
+         print("🔬 Training Dynamics Analysis")
+         print("=" * 70)
+
+         analysis = {}
+
+         # ── Loss analysis ──
+         if metrics_history.get("train_loss"):
+             losses = metrics_history["train_loss"]
+             analysis["loss"] = {
+                 "initial": round(losses[0], 4),
+                 "final": round(losses[-1], 4),
+                 "minimum": round(min(losses), 4),
+                 "total_reduction": round(losses[0] - losses[-1], 4),
+             }
+
+             # Spike detection (a jump of 50%+ over the previous value)
+             spikes = []
+             for i in range(1, len(losses)):
+                 if losses[i] > losses[i-1] * 1.5:
+                     step = metrics_history["step"][i] if "step" in metrics_history else i
+                     spikes.append({"step": step, "loss": round(losses[i], 4)})
+
+             analysis["loss"]["spikes"] = spikes
+
+             print(f"\n  📉 Loss analysis:")
+             print(f"     Initial:   {analysis['loss']['initial']:.4f}")
+             print(f"     Final:     {analysis['loss']['final']:.4f}")
+             print(f"     Minimum:   {analysis['loss']['minimum']:.4f}")
+             print(f"     Reduction: {analysis['loss']['total_reduction']:.4f}")
+             print(f"     Spikes:    {len(spikes)}")
+             if spikes:
+                 for s in spikes[:5]:
+                     print(f"       Step {s['step']}: Loss = {s['loss']}")
+
+         # ── Gradient norm analysis ──
+         if metrics_history.get("grad_norm"):
+             gnorms = metrics_history["grad_norm"]
+             analysis["grad_norm"] = {
+                 "mean": round(sum(gnorms) / len(gnorms), 4),
+                 "max": round(max(gnorms), 4),
+                 "min": round(min(gnorms), 4),
+                 "clipped_pct": round(sum(1 for g in gnorms if g >= 0.99) / len(gnorms) * 100, 1),
+             }
+
+             print(f"\n  📏 Gradient norm analysis:")
+             print(f"     Mean: {analysis['grad_norm']['mean']:.4f}")
+             print(f"     Max:  {analysis['grad_norm']['max']:.4f}")
+             print(f"     Clipping ratio: {analysis['grad_norm']['clipped_pct']:.1f}%")
+             if analysis["grad_norm"]["clipped_pct"] > 30:
+                 print(f"     ⚠️ Clipping is frequent → consider lowering the LR or extending warmup")
+
+         # ── Throughput analysis ──
+         if metrics_history.get("tokens_per_sec"):
+             tps = metrics_history["tokens_per_sec"]
+             tps_valid = [t for t in tps if t > 0]
+             if tps_valid:
+                 analysis["throughput"] = {
+                     "mean": round(sum(tps_valid) / len(tps_valid)),
+                     "std": round((sum((t - sum(tps_valid)/len(tps_valid))**2 for t in tps_valid) / len(tps_valid))**0.5),
+                     "min": round(min(tps_valid)),
+                     "max": round(max(tps_valid)),
+                 }
+
+                 print(f"\n  ⚡ Throughput analysis:")
+                 print(f"     Mean: {analysis['throughput']['mean']:,} tokens/sec")
+                 print(f"     Std:  {analysis['throughput']['std']:,}")
+                 print(f"     Range: [{analysis['throughput']['min']:,}, {analysis['throughput']['max']:,}]")
+
+         return analysis
+
+     def plot_training_curves(
+         self,
+         metrics_history: Dict[str, list],
+         save_path: Optional[str] = None,
+     ):
+         """Visualize the training curves as a 4-panel chart."""
+         if not HAS_MATPLOTLIB:
+             print("⚠️ matplotlib is required: pip install matplotlib")
+             return
+
+         fig, axes = plt.subplots(2, 2, figsize=(16, 10))
+         fig.suptitle("Training Dynamics", fontsize=16, fontweight="bold")
+
+         steps = metrics_history.get("step", list(range(len(metrics_history.get("train_loss", [])))))
+
+         # ── (1) Loss ──
+         ax = axes[0, 0]
+         if metrics_history.get("train_loss"):
+             ax.plot(steps[:len(metrics_history["train_loss"])],
+                     metrics_history["train_loss"],
+                     color="#2563eb", alpha=0.6, linewidth=0.8, label="Train Loss")
+
+             # Moving average (smoothing)
+             if len(metrics_history["train_loss"]) > 20:
+                 window = min(50, len(metrics_history["train_loss"]) // 5)
+                 smoothed = self._moving_average(metrics_history["train_loss"], window)
+                 ax.plot(steps[window-1:len(smoothed)+window-1],
+                         smoothed, color="#1d4ed8", linewidth=2, label=f"Smoothed (window={window})")
+
+         if metrics_history.get("val_loss"):
+             val_steps = [steps[i] for i in range(0, len(steps),
+                          max(1, len(steps)//len(metrics_history["val_loss"])))][:len(metrics_history["val_loss"])]
+             ax.plot(val_steps, metrics_history["val_loss"],
+                     "o-", color="#dc2626", linewidth=2, markersize=5, label="Val Loss")
+
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Loss")
+         ax.set_title("Training & Validation Loss")
+         ax.legend()
+         ax.grid(True, alpha=0.3)
+
+         # ── (2) Learning Rate ──
+         ax = axes[0, 1]
+         if metrics_history.get("learning_rate"):
+             ax.plot(steps[:len(metrics_history["learning_rate"])],
+                     metrics_history["learning_rate"],
+                     color="#059669", linewidth=2)
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Learning Rate")
+         ax.set_title("Learning Rate Schedule")
+         ax.ticklabel_format(style="scientific", axis="y", scilimits=(0, 0))
+         ax.grid(True, alpha=0.3)
+
+         # ── (3) Gradient Norm ──
+         ax = axes[1, 0]
+         if metrics_history.get("grad_norm"):
+             ax.plot(steps[:len(metrics_history["grad_norm"])],
+                     metrics_history["grad_norm"],
+                     color="#d97706", alpha=0.6, linewidth=0.8)
+             ax.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="Clip threshold")
+             ax.legend()
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Gradient Norm")
+         ax.set_title("Gradient Norm (clipped at 1.0)")
+         ax.grid(True, alpha=0.3)
+
+         # ── (4) Throughput ──
+         ax = axes[1, 1]
+         if metrics_history.get("tokens_per_sec"):
+             tps = metrics_history["tokens_per_sec"]
+             ax.plot(steps[:len(tps)], tps, color="#7c3aed", alpha=0.6, linewidth=0.8)
+             if tps:
+                 avg_tps = sum(tps) / len(tps)
+                 ax.axhline(y=avg_tps, color="#7c3aed", linestyle="--", alpha=0.5,
+                            label=f"Avg: {avg_tps:,.0f}")
+                 ax.legend()
+         ax.set_xlabel("Step")
+         ax.set_ylabel("Tokens/sec")
+         ax.set_title("Training Throughput")
+         ax.grid(True, alpha=0.3)
+
+         plt.tight_layout()
+
+         save_path = save_path or str(self.save_dir / "training_curves.png")
+         fig.savefig(save_path, dpi=150, bbox_inches="tight")
+         print(f"\n  📊 Training curves saved: {save_path}")
+         plt.close(fig)
+
+     def plot_position_loss(
+         self,
+         position_losses: List[float],
+         save_path: Optional[str] = None,
+     ):
+         """Visualize the loss distribution by sequence position."""
+         if not HAS_MATPLOTLIB:
+             return
+
+         fig, ax = plt.subplots(figsize=(12, 5))
+
+         positions = list(range(len(position_losses)))
+         ax.plot(positions, position_losses, color="#2563eb", linewidth=1.5)
+         ax.fill_between(positions, position_losses, alpha=0.1, color="#2563eb")
+
+         ax.set_xlabel("Position in Sequence", fontsize=12)
+         ax.set_ylabel("Cross-Entropy Loss", fontsize=12)
+         ax.set_title("Loss by Position (earlier positions have less context)", fontsize=13, fontweight="bold")
+         ax.grid(True, alpha=0.3)
+
+         # Mark the key regions
+         if len(position_losses) > 100:
+             early_avg = sum(position_losses[:50]) / 50
+             late_avg = sum(position_losses[-200:]) / len(position_losses[-200:])
+             ax.axhline(y=early_avg, color="red", linestyle="--", alpha=0.4,
+                        label=f"Early avg (0-50): {early_avg:.2f}")
+             ax.axhline(y=late_avg, color="green", linestyle="--", alpha=0.4,
+                        label=f"Late avg (-200): {late_avg:.2f}")
+             ax.legend()
+
+         plt.tight_layout()
+
+         save_path = save_path or str(self.save_dir / "position_loss.png")
+         fig.savefig(save_path, dpi=150, bbox_inches="tight")
+         print(f"  📊 Per-position loss saved: {save_path}")
+         plt.close(fig)
+
+     @staticmethod
+     def _moving_average(data: list, window: int) -> list:
+         """Compute a simple trailing moving average."""
+         result = []
+         for i in range(window - 1, len(data)):
+             avg = sum(data[i - window + 1 : i + 1]) / window
+             result.append(avg)
+         return result
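
The spike rule above is purely local: any step whose loss exceeds 1.5× the previous step is flagged. A tiny sketch of how the detector and the smoother behave on synthetic data:

```python
from llm_lab.evaluation.dynamics import TrainingDynamicsAnalyzer

history = {
    "step": [0, 10, 20, 30, 40, 50],
    "train_loss": [8.0, 5.0, 4.0, 6.5, 3.6, 3.4],  # 4.0 → 6.5 is a >50% jump
}
analyzer = TrainingDynamicsAnalyzer(save_dir="./eval_results")
analysis = analyzer.analyze_metrics(history)
print(analysis["loss"]["spikes"])  # [{'step': 30, 'loss': 6.5}]

# The smoother is a plain trailing window:
print(TrainingDynamicsAnalyzer._moving_average([1, 2, 3, 4], window=2))  # [1.5, 2.5, 3.5]
```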
llm_lab/evaluation/full_evaluator.py ADDED
@@ -0,0 +1,222 @@
+ """End-to-end evaluation runner."""
+
+ import json
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+
+ from llm_lab.config import EvalConfig
+ from .perplexity import PerplexityEvaluator
+ from .generation import GenerationEvaluator
+ from .dynamics import TrainingDynamicsAnalyzer
+ from .attention_viz import AttentionVisualizer
+
+
+ class FullEvaluator:
+     """Runs every evaluation in one pass and produces a report.
+
+     Usage:
+     ```python
+     evaluator = FullEvaluator(model, tokenizer, val_dataloader, device)
+     report = evaluator.run_full_evaluation()
+     ```
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         tokenizer: Any,
+         val_dataloader: DataLoader,
+         device: torch.device,
+         config: Optional[EvalConfig] = None,
+         dtype: torch.dtype = torch.bfloat16,
+         metrics_history: Optional[Dict[str, list]] = None,
+     ):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.val_dataloader = val_dataloader
+         self.device = device
+         self.config = config or EvalConfig()
+         self.dtype = dtype
+         self.metrics_history = metrics_history
+
+         self.save_dir = Path(self.config.save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+     def run_full_evaluation(self) -> Dict[str, Any]:
+         """Run the full evaluation suite."""
+         report = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
+
+         print("\n" + "=" * 70)
+         print("🔍 Starting full evaluation")
+         print("=" * 70)
+
+         # ── 1. Perplexity ──
+         print("\n" + "━" * 40)
+         print("Phase 1/4: Measuring perplexity")
+         print("━" * 40)
+         ppl_evaluator = PerplexityEvaluator(self.config)
+         report["perplexity"] = ppl_evaluator.evaluate(
+             self.model, self.val_dataloader, self.device, self.dtype
+         )
+
+         # Per-position loss
+         print("\n  Measuring per-position loss...")
+         position_losses = ppl_evaluator.evaluate_per_position(
+             self.model, self.val_dataloader, self.device, self.dtype
+         )
+         report["position_losses"] = {
+             "early_avg": round(sum(position_losses[:50]) / max(len(position_losses[:50]), 1), 4),
+             "late_avg": round(sum(position_losses[-200:]) / max(len(position_losses[-200:]), 1), 4),
+         }
+
+         # Per-position loss plot
+         dynamics = TrainingDynamicsAnalyzer(str(self.save_dir))
+         dynamics.plot_position_loss(position_losses, str(self.save_dir / "position_loss.png"))
+
+         # ── 2. Text generation ──
+         print("\n" + "━" * 40)
+         print("Phase 2/4: Text generation")
+         print("━" * 40)
+         gen_evaluator = GenerationEvaluator(self.config)
+         gen_results = gen_evaluator.generate_samples(
+             self.model, self.tokenizer, self.device
+         )
+         report["generation"] = {
+             "num_prompts": len(gen_results),
+             "avg_metrics": self._average_gen_metrics(gen_results),
+         }
+
+         # ── 3. Training dynamics ──
+         if self.metrics_history:
+             print("\n" + "━" * 40)
+             print("Phase 3/4: Training dynamics analysis")
+             print("━" * 40)
+             report["training_dynamics"] = dynamics.analyze_metrics(self.metrics_history)
+             dynamics.plot_training_curves(self.metrics_history,
+                                           str(self.save_dir / "training_curves.png"))
+         else:
+             print("\n  Phase 3/4: skipped (no metrics_history)")
+
+         # ── 4. Attention visualization (sample) ──
+         print("\n" + "━" * 40)
+         print("Phase 4/4: Attention visualization")
+         print("━" * 40)
+         try:
+             self._visualize_attention_sample()
+         except Exception as e:
+             print(f"  ⚠️ Attention visualization failed: {e}")
+
+         # ── Save the report ──
+         report_path = self.save_dir / "eval_report.json"
+         with open(report_path, "w") as f:
+             json.dump(report, f, indent=2, default=str)
+         print(f"\n📋 Report saved: {report_path}")
+
+         # ── Print the summary ──
+         self._print_summary(report)
+
+         return report
+
+     def _visualize_attention_sample(self):
+         """Visualize attention on a sample text."""
+         viz = AttentionVisualizer(str(self.save_dir))
+
+         sample_text = "The cat sat on the mat and looked at the bird."
+         token_ids = self.tokenizer.encode(sample_text, add_special_tokens=False)
+         input_tensor = torch.tensor([token_ids], dtype=torch.long)
+
+         # Token strings (labels for the plot)
+         tokens_str = []
+         for tid in token_ids:
+             decoded = self.tokenizer.decode([tid])
+             tokens_str.append(decoded.replace("\n", "\\n"))
+
+         # Extract layer-0 attention
+         attn_weights = viz.extract_attention(
+             self.model, input_tensor, layer_idx=0, device=self.device
+         )
+
+         if attn_weights is not None:
+             viz.plot_attention_heatmap(
+                 attn_weights, tokens_str, head_idx=0,
+                 title="Layer 0 Attention"
+             )
+             viz.plot_multi_head_summary(attn_weights)
+
+     @staticmethod
+     def _average_gen_metrics(gen_results: List[Dict]) -> Dict[str, float]:
+         """Average the generation metrics across all prompts."""
+         if not gen_results:
+             return {}
+
+         all_metrics = [r["metrics"] for r in gen_results if r.get("metrics")]
+         if not all_metrics:
+             return {}
+
+         keys = all_metrics[0].keys()
+         return {
+             k: round(sum(m.get(k, 0) for m in all_metrics) / len(all_metrics), 3)
+             for k in keys
+         }
+
+     def _print_summary(self, report: Dict[str, Any]):
+         """Print the final summary."""
+         print("\n" + "=" * 70)
+         print("📋 Evaluation Summary")
+         print("=" * 70)
+
+         # Perplexity
+         if "perplexity" in report:
+             ppl = report["perplexity"]
+             print(f"\n  🎯 Perplexity:")
+             print(f"     Loss: {ppl['loss']:.4f}")
+             print(f"     PPL:  {ppl['perplexity']:.2f}")
+
+             # Grading
+             ppl_val = ppl["perplexity"]
+             if ppl_val < 20:
+                 grade = "🌟 Strong"
+             elif ppl_val < 35:
+                 grade = "✅ Good"
+             elif ppl_val < 60:
+                 grade = "⚠️ Fair"
+             else:
+                 grade = "❌ Poor (needs more training)"
+             print(f"     Grade: {grade}")
+
+         # Per-position loss
+         if "position_losses" in report:
+             pl = report["position_losses"]
+             print(f"\n  📏 Per-position loss:")
+             print(f"     Early (0-50): {pl['early_avg']:.4f}")
+             print(f"     Late (-200):  {pl['late_avg']:.4f}")
+             print(f"     Context effect: {pl['early_avg'] - pl['late_avg']:.4f} lower")
+
+         # Generation quality
+         if "generation" in report and report["generation"].get("avg_metrics"):
+             gm = report["generation"]["avg_metrics"]
+             print(f"\n  ✍️ Generation quality:")
+             print(f"     Avg length: {gm.get('avg_length', 0):.0f} chars")
+             print(f"     Repetition rate: {gm.get('repetition_rate', 0):.1%}")
+             print(f"     Lexical diversity: {gm.get('lexical_diversity', 0):.3f}")
+
+         # Training dynamics
+         if "training_dynamics" in report:
+             td = report["training_dynamics"]
+             if "loss" in td:
+                 print(f"\n  📉 Training dynamics:")
+                 print(f"     Loss: {td['loss']['initial']:.4f} → {td['loss']['final']:.4f}")
+                 print(f"     Spikes: {len(td['loss']['spikes'])}")
+
+         # Output files
+         print(f"\n  📂 Result files:")
+         for f in sorted(self.save_dir.glob("*")):
+             size = f.stat().st_size / 1024
+             print(f"     {f.name} ({size:.1f} KB)")
+
+         print("\n" + "=" * 70)
llm_lab/evaluation/generation.py ADDED
@@ -0,0 +1,200 @@
+ """Text generation evaluator."""
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import torch.nn as nn
+
+ from llm_lab.config import EvalConfig
+
+
+ class GenerationEvaluator:
+     """Generates text from a range of prompts and scores its quality.
+
+     Evaluation criteria:
+     1) Grammaticality: does it produce well-formed English sentences?
+     2) Coherence: does it stay on topic as it continues?
+     3) Diversity: does the same prompt yield different outputs?
+     4) Repetition avoidance: does it avoid repeating the same phrases?
+     5) Knowledge: is knowledge from the training data reflected?
+
+     Realistic expectations for a 1B model:
+     - Grammatically correct English sentences ✅
+     - Coherence within a short paragraph ✅
+     - Complex reasoning or long chains of logic ❌ (needs a larger model)
+     - Factual accuracy is not guaranteed ⚠️
+     """
+
+     # Test prompts across several domains
+     DEFAULT_PROMPTS = [
+         # ── General knowledge ──
+         "The theory of relativity states that",
+         "In the history of computer science,",
+         "The human brain is remarkable because",
+
+         # ── Explanation / education ──
+         "To understand machine learning, one must first",
+         "The water cycle begins when",
+         "Photosynthesis is the process by which",
+
+         # ── Narrative / story ──
+         "Once upon a time, in a small village near the mountains,",
+         "The detective looked at the evidence and realized that",
+
+         # ── Code / technical ──
+         "def fibonacci(n):\n    \"\"\"Calculate the nth Fibonacci number.\"\"\"\n",
+         "The most important data structures in programming are",
+
+         # ── Short completions ──
+         "The capital of France is",
+         "Water boils at a temperature of",
+
+         # ── Long context ──
+         ("Artificial intelligence has transformed many industries. "
+          "In healthcare, AI is used for diagnosis and drug discovery. "
+          "In finance, it powers algorithmic trading and fraud detection. "
+          "Looking ahead, the most promising application of AI is"),
+     ]
+
+     def __init__(self, config: EvalConfig):
+         self.config = config
+
+     @torch.no_grad()
+     def generate_samples(
+         self,
+         model: nn.Module,
+         tokenizer: Any,
+         device: torch.device,
+         prompts: Optional[List[str]] = None,
+         verbose: bool = True,
+     ) -> List[Dict[str, Any]]:
+         """Generate text for each prompt.
+
+         Returns:
+             [{"prompt": str, "generations": [str, ...], "metrics": {...}}, ...]
+         """
+         model.eval()
+         prompts = prompts or self.DEFAULT_PROMPTS
+         results = []
+
+         if verbose:
+             print("\n" + "=" * 70)
+             print("📝 Text Generation Evaluation")
+             print("=" * 70)
+
+         for idx, prompt in enumerate(prompts):
+             prompt_results = {
+                 "prompt": prompt,
+                 "generations": [],
+                 "metrics": {},
+             }
+
+             if verbose:
+                 print(f"\n{'─'*60}")
+                 print(f"Prompt [{idx+1}/{len(prompts)}]:")
+                 print(f"  \"{prompt[:80]}{'...' if len(prompt) > 80 else ''}\"")
+                 print(f"{'─'*60}")
+
+             # Encode the prompt
+             prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+             input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+             all_texts = []
+             for sample_idx in range(self.config.num_samples):
+                 # Generate
+                 generated_ids = model.generate(
+                     input_tensor,
+                     max_new_tokens=self.config.max_new_tokens,
+                     temperature=self.config.temperature,
+                     top_k=self.config.top_k,
+                     top_p=self.config.top_p,
+                 )
+
+                 # Decode (only the part after the prompt)
+                 new_ids = generated_ids[0][len(prompt_ids):].tolist()
+                 generated_text = tokenizer.decode(new_ids)
+                 all_texts.append(generated_text)
+
+                 prompt_results["generations"].append(generated_text)
+
+                 if verbose:
+                     print(f"\n  ✍️ Sample #{sample_idx+1}:")
+                     # Tidy output (preserving line breaks)
+                     display_text = generated_text[:500]
+                     for line in display_text.split("\n"):
+                         print(f"     {line}")
+                     if len(generated_text) > 500:
+                         print(f"     ... ({len(generated_text)} chars total)")
+
+             # Generation quality metrics
+             prompt_results["metrics"] = self._compute_generation_metrics(all_texts)
+
+             if verbose and prompt_results["metrics"]:
+                 m = prompt_results["metrics"]
+                 print(f"\n  📊 Metrics: "
+                       f"avg length={m['avg_length']:.0f} chars, "
+                       f"repetition={m['repetition_rate']:.1%}, "
+                       f"lexical diversity={m['lexical_diversity']:.2f}")
+
+             results.append(prompt_results)
+
+         return results
+
+     @staticmethod
+     def _compute_generation_metrics(texts: List[str]) -> Dict[str, float]:
+         """Compute quality metrics for the generated texts.
+
+         Metrics:
+         - avg_length: average generation length (chars)
+         - avg_word_count: average word count
+         - repetition_rate: n-gram repetition rate (lower is better)
+         - lexical_diversity: unique-word ratio (higher is more diverse)
+         - sample_diversity: diversity across samples (how different the generations are)
+         """
+         if not texts:
+             return {}
+
+         # Length
+         lengths = [len(t) for t in texts]
+         word_counts = [len(t.split()) for t in texts]
+
+         # Repetition rate (4-gram based)
+         rep_rates = []
+         for text in texts:
+             words = text.lower().split()
+             if len(words) < 4:
+                 rep_rates.append(0.0)
+                 continue
+             ngrams = [tuple(words[i:i+4]) for i in range(len(words)-3)]
+             unique_ratio = len(set(ngrams)) / len(ngrams) if ngrams else 1.0
+             rep_rates.append(1.0 - unique_ratio)  # repetition rate = 1 - unique ratio
+
+         # Lexical diversity (type-token ratio)
+         diversities = []
+         for text in texts:
+             words = text.lower().split()
+             if words:
+                 diversities.append(len(set(words)) / len(words))
+             else:
+                 diversities.append(0.0)
+
+         # Cross-sample diversity (one minus mean Jaccard similarity)
+         sample_div = 0.0
+         if len(texts) > 1:
+             word_sets = [set(t.lower().split()) for t in texts]
+             similarities = []
+             for i in range(len(word_sets)):
+                 for j in range(i+1, len(word_sets)):
+                     inter = len(word_sets[i] & word_sets[j])
+                     union = len(word_sets[i] | word_sets[j])
+                     if union > 0:
+                         similarities.append(inter / union)
+             sample_div = 1.0 - (sum(similarities) / max(len(similarities), 1))
+
+         return {
+             "avg_length": sum(lengths) / len(lengths),
+             "avg_word_count": sum(word_counts) / len(word_counts),
+             "repetition_rate": sum(rep_rates) / len(rep_rates),
+             "lexical_diversity": sum(diversities) / len(diversities),
+             "sample_diversity": round(sample_div, 3),
+         }
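
To make the metric definitions concrete, here is the 4-gram repetition computation applied by hand to one degenerate string (standalone, no project imports):

```python
# Worked example of the 4-gram repetition metric from _compute_generation_metrics
text = "the cat sat on the mat the cat sat on the mat"
words = text.lower().split()                                   # 12 words
ngrams = [tuple(words[i:i+4]) for i in range(len(words) - 3)]  # 9 4-grams, 6 unique
repetition_rate = 1.0 - len(set(ngrams)) / len(ngrams)         # 1 - 6/9
print(f"{repetition_rate:.3f}")  # 0.333 → above the 0.3 checklist threshold

# Type-token ratio (lexical_diversity) on the same string
print(len(set(words)) / len(words))  # 5 unique / 12 total ≈ 0.417
```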
llm_lab/evaluation/perplexity.py ADDED
@@ -0,0 +1,172 @@
+ """Perplexity (PPL) evaluator."""
+
+ import math
+ import time
+ from typing import Dict, List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from llm_lab.config import EvalConfig
+
+
+ class PerplexityEvaluator:
+     """Measures perplexity (PPL).
+
+     What is perplexity?
+         PPL = exp(average cross-entropy loss)
+
+     Intuition:
+     - PPL = 1: perfect prediction (impossible)
+     - PPL = 10: like choosing among 10 candidates each step
+     - PPL = 100: like choosing among 100 candidates (close to random)
+     - PPL = 32000: random choice over the whole vocab (an untrained model)
+
+     Reference points for a good 1B model (English web text):
+     - trained on 5B tokens:  PPL ~30-40
+     - trained on 10B tokens: PPL ~20-30
+     - trained on 20B tokens: PPL ~15-25
+
+     How it is measured:
+     - Compute cross-entropy over every token of the validation set
+     - Average per token, then apply exp()
+     - Padding tokens are excluded (ignore_index=-100)
+     """
+
+     def __init__(self, config: EvalConfig):
+         self.config = config
+
+     @torch.no_grad()
+     def evaluate(
+         self,
+         model: nn.Module,
+         dataloader: DataLoader,
+         device: torch.device,
+         dtype: torch.dtype = torch.bfloat16,
+         desc: str = "Evaluation",
+     ) -> Dict[str, float]:
+         """Measure perplexity.
+
+         Returns:
+             {
+                 "loss": average cross-entropy loss,
+                 "perplexity": exp(loss),
+                 "num_tokens": total tokens used for evaluation,
+                 "num_batches": batches used for evaluation,
+             }
+         """
+         model.eval()
+
+         total_loss = 0.0
+         total_tokens = 0
+         num_batches = 0
+
+         print(f"\n📊 {desc}")
+         start_time = time.time()
+
+         for i, batch in enumerate(dataloader):
+             if i >= self.config.max_eval_batches:
+                 break
+
+             input_ids = batch["input_ids"].to(device)
+             targets = batch["targets"].to(device)
+
+             with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
+                 logits, _ = model(input_ids)
+
+             # Per-token cross-entropy (reduction='none')
+             # logits:  (B, S, V) → (B*S, V)
+             # targets: (B, S)   → (B*S,)
+             loss_per_token = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100,
+                 reduction="none",
+             )
+
+             # Count only valid tokens (those not set to -100)
+             valid_mask = (targets.view(-1) != -100)
+             valid_tokens = valid_mask.sum().item()
+
+             total_loss += loss_per_token[valid_mask].sum().item()
+             total_tokens += valid_tokens
+             num_batches += 1
+
+             if (i + 1) % 20 == 0:
+                 running_ppl = math.exp(min(total_loss / max(total_tokens, 1), 20))
+                 print(f"  Batch {i+1}/{self.config.max_eval_batches}: running PPL = {running_ppl:.2f}")
+
+         elapsed = time.time() - start_time
+         avg_loss = total_loss / max(total_tokens, 1)
+         perplexity = math.exp(min(avg_loss, 100))  # guard against overflow
+
+         results = {
+             "loss": round(avg_loss, 4),
+             "perplexity": round(perplexity, 2),
+             "num_tokens": total_tokens,
+             "num_batches": num_batches,
+             "eval_time_sec": round(elapsed, 1),
+         }
+
+         print(f"  ────────────────────────────────")
+         print(f"  Loss:        {results['loss']:.4f}")
+         print(f"  Perplexity:  {results['perplexity']:.2f}")
+         print(f"  Eval tokens: {total_tokens:,}")
+         print(f"  Elapsed:     {elapsed:.1f}s")
+
+         return results
+
+     @torch.no_grad()
+     def evaluate_per_position(
+         self,
+         model: nn.Module,
+         dataloader: DataLoader,
+         device: torch.device,
+         dtype: torch.dtype = torch.bfloat16,
+         max_batches: int = 50,
+     ) -> List[float]:
+         """Measure the loss at each position within the sequence.
+
+         Learning point:
+         - Positions 0-10: loss is high (little context available)
+         - Positions 100+: loss settles lower (context is exploited)
+         - This pattern illustrates the Transformer's in-context learning ability
+         """
+         model.eval()
+         seq_len = None
+         position_loss_sum = None
+         position_count = None
+
+         for i, batch in enumerate(dataloader):
+             if i >= max_batches:
+                 break
+
+             input_ids = batch["input_ids"].to(device)
+             targets = batch["targets"].to(device)
+             B, S = targets.shape
+
+             if seq_len is None:
+                 seq_len = S
+                 position_loss_sum = torch.zeros(S, device=device)
+                 position_count = torch.zeros(S, device=device)
+
+             with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
+                 logits, _ = model(input_ids)
+
+             # Per-token loss shaped (B, S)
+             loss_per_token = F.cross_entropy(
+                 logits.view(-1, logits.size(-1)),
+                 targets.view(-1),
+                 ignore_index=-100,
+                 reduction="none",
+             ).view(B, S)
+
+             valid_mask = (targets != -100).float()
+             position_loss_sum += (loss_per_token * valid_mask).sum(dim=0)
+             position_count += valid_mask.sum(dim=0)
+
+         # Average loss per position
+         position_avg_loss = (position_loss_sum / position_count.clamp(min=1)).cpu().tolist()
+         return position_avg_loss
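
A worked example of the bookkeeping above: PPL is exp of the *token-weighted* mean loss, never the mean of per-batch perplexities. The batch numbers are made up for illustration:

```python
import math

# Two batches with different valid-token counts (padding excluded via -100)
batch_losses = [(2.8, 900), (3.4, 300)]  # (mean loss, valid tokens) — assumed numbers

total_loss = sum(loss * n for loss, n in batch_losses)  # 2520 + 1020 = 3540
total_tokens = sum(n for _, n in batch_losses)          # 1200
avg_loss = total_loss / total_tokens                    # 2.95
print(math.exp(avg_loss))                               # ≈ 19.1 → the correct PPL
print(sum(math.exp(l) for l, _ in batch_losses) / 2)    # ≈ 23.2 → wrong: mean of PPLs
```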
llm_lab/evaluation/runner.py ADDED
@@ -0,0 +1,56 @@
+ """Evaluation runner helper (quick start)."""
+
+ from typing import Any, Dict, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+
+ from llm_lab.config import EvalConfig
+ from .full_evaluator import FullEvaluator
+ from .checklist import InsightChecklist
+
+
+ def run_evaluation(
+     model: nn.Module,
+     tokenizer: Any,
+     val_dataloader: DataLoader,
+     device: Optional[torch.device] = None,
+     dtype: torch.dtype = torch.bfloat16,
+     metrics_history: Optional[Dict[str, list]] = None,
+     config: Optional[EvalConfig] = None,
+ ) -> Dict[str, Any]:
+     """Run the whole evaluation in one call.
+
+     Usage (Colab):
+     ```python
+     from llm_lab.evaluation import run_evaluation
+
+     # after training completes
+     report = run_evaluation(
+         model=trainer.model,
+         tokenizer=tokenizer,
+         val_dataloader=val_dl,
+         metrics_history=trainer.metrics.history,
+     )
+     ```
+     """
+     if device is None:
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     evaluator = FullEvaluator(
+         model=model,
+         tokenizer=tokenizer,
+         val_dataloader=val_dataloader,
+         device=device,
+         config=config,
+         dtype=dtype,
+         metrics_history=metrics_history,
+     )
+
+     report = evaluator.run_full_evaluation()
+
+     # Insight checklist
+     InsightChecklist.run_checklist(report, metrics_history)
+
+     return report
llm_lab/evaluation/scaling.py ADDED
@@ -0,0 +1,153 @@
+ """Scaling-law analyzer."""
+
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ try:
+     import matplotlib
+     matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     HAS_MATPLOTLIB = True
+ except ImportError:
+     HAS_MATPLOTLIB = False
+
+ try:
+     import numpy as np
+     HAS_NUMPY = True
+ except ImportError:
+     HAS_NUMPY = False
+
+
+ class ScalingAnalyzer:
+     """Analyzes the scaling law across the 10M → 100M → 1B models.
+
+     Chinchilla scaling law (Hoffmann et al., 2022):
+     - Compute-optimal training: tokens ≈ 20 × parameters
+     - Loss falls as a power law in model size N and data D
+     - α ≈ 0.076, β ≈ 0.095 (the Kaplan et al. 2020 exponent fits)
+
+     Purpose of this analysis:
+     - Check whether our models follow the scaling law
+     - Predict the effect of a larger model / more data
+     - Understand the optimal allocation of compute
+     """
+
+     def __init__(self, save_dir: str = "./eval_results"):
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+
+     def analyze(
+         self,
+         model_results: List[Dict[str, Any]],
+     ) -> Dict[str, Any]:
+         """Compare results across several model sizes.
+
+         Args:
+             model_results: [
+                 {"name": "10M",  "params": 10e6,  "tokens": 1e9,  "loss": 4.2, "ppl": 66.7},
+                 {"name": "100M", "params": 100e6, "tokens": 5e9,  "loss": 3.5, "ppl": 33.1},
+                 {"name": "1B",   "params": 1.1e9, "tokens": 10e9, "loss": 3.0, "ppl": 20.1},
+             ]
+
+         Returns:
+             Analysis results dictionary
+         """
+         if len(model_results) < 2:
+             print("⚠️ Scaling analysis needs results from at least 2 models.")
+             return {}
+
+         print("\n" + "=" * 70)
+         print("📈 Scaling Law Analysis")
+         print("=" * 70)
+
+         # ── Results table ──
+         print(f"\n  {'Model':<8} {'Params':>12} {'Tokens':>10} {'Loss':>8} {'PPL':>8}")
+         print(f"  {'─'*52}")
+         for r in model_results:
+             params_str = f"{r['params']/1e6:.0f}M" if r["params"] < 1e9 else f"{r['params']/1e9:.1f}B"
+             tokens_str = f"{r['tokens']/1e9:.1f}B"
+             print(f"  {r['name']:<8} {params_str:>12} {tokens_str:>10} {r['loss']:>8.4f} {r['ppl']:>8.2f}")
+
+         # ── Scaling efficiency ──
+         analysis = {"models": model_results, "scaling_efficiency": []}
+
+         for i in range(1, len(model_results)):
+             prev = model_results[i-1]
+             curr = model_results[i]
+
+             param_ratio = curr["params"] / prev["params"]
+             loss_reduction = prev["loss"] - curr["loss"]
+             ppl_reduction = (prev["ppl"] - curr["ppl"]) / prev["ppl"]
+
+             efficiency = {
+                 "from": prev["name"],
+                 "to": curr["name"],
+                 "param_multiplier": round(param_ratio, 1),
+                 "loss_reduction": round(loss_reduction, 4),
+                 "ppl_reduction_pct": round(ppl_reduction * 100, 1),
+             }
+             analysis["scaling_efficiency"].append(efficiency)
+
+             print(f"\n  {prev['name']} → {curr['name']}:")
+             print(f"     Parameters ×{param_ratio:.1f}")
+             print(f"     Loss reduction: {loss_reduction:.4f}")
+             print(f"     PPL reduction:  {ppl_reduction*100:.1f}%")
+
+         # ── Chinchilla-optimality check ──
+         print(f"\n  Chinchilla-optimality check (tokens ≈ 20 × parameters):")
+         for r in model_results:
+             actual_ratio = r["tokens"] / r["params"]
+             status = "✅ optimal range" if 15 <= actual_ratio <= 25 else "⚠️ out of range"
+             print(f"     {r['name']}: tokens/params = {actual_ratio:.1f}x "
+                   f"(optimal: 20x) {status}")
+
+         analysis["chinchilla_ratios"] = [
+             {"name": r["name"], "ratio": round(r["tokens"] / r["params"], 1)}
+             for r in model_results
+         ]
+
+         return analysis
+
+     def plot_scaling_curves(
+         self,
+         model_results: List[Dict[str, Any]],
+         save_path: Optional[str] = None,
+     ):
+         """Plot the scaling curves."""
+         if not HAS_MATPLOTLIB or not HAS_NUMPY:
+             print("⚠️ matplotlib/numpy are required: pip install matplotlib numpy")
+             return
+
+         fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+         params = [r["params"] for r in model_results]
+         losses = [r["loss"] for r in model_results]
+         ppls = [r["ppl"] for r in model_results]
+         names = [r["name"] for r in model_results]
+
+         # ── Loss vs Parameters (log-log) ──
+         ax = axes[0]
+         ax.loglog(params, losses, "o-", color="#2563eb", linewidth=2, markersize=10)
+         for p, l, n in zip(params, losses, names):
+             ax.annotate(f" {n}\n Loss={l:.2f}", (p, l), fontsize=9)
+         ax.set_xlabel("Parameters", fontsize=12)
+         ax.set_ylabel("Validation Loss", fontsize=12)
+         ax.set_title("Loss vs Model Size (log-log)", fontsize=13, fontweight="bold")
+         ax.grid(True, alpha=0.3)
+
+         # ── PPL vs Parameters (log-log) ──
+         ax = axes[1]
+         ax.loglog(params, ppls, "s-", color="#dc2626", linewidth=2, markersize=10)
+         for p, pp, n in zip(params, ppls, names):
+             ax.annotate(f" {n}\n PPL={pp:.1f}", (p, pp), fontsize=9)
+         ax.set_xlabel("Parameters", fontsize=12)
+         ax.set_ylabel("Perplexity", fontsize=12)
+         ax.set_title("Perplexity vs Model Size (log-log)", fontsize=13, fontweight="bold")
+         ax.grid(True, alpha=0.3)
+
+         plt.tight_layout()
+
+         save_path = save_path or str(self.save_dir / "scaling_curves.png")
+         fig.savefig(save_path, dpi=150, bbox_inches="tight")
+         print(f"\n  📊 Scaling curves saved: {save_path}")
+         plt.close(fig)
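
As a quick sanity check of the 20-tokens-per-parameter rule used above, this is what compute-optimal token budgets look like for the three lab model sizes (illustrative arithmetic, not measured results):

```python
# Chinchilla-style budget check: tokens ≈ 20 × parameters
for name, params in [("10M", 10e6), ("100M", 100e6), ("1B", 1.1e9)]:
    optimal_tokens = 20 * params
    print(f"{name}: {optimal_tokens/1e9:.1f}B tokens for compute-optimal training")
# 10M:  0.2B tokens
# 100M: 2.0B tokens
# 1B:   22.0B tokens
```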
llm_lab/model/__init__.py ADDED
@@ -0,0 +1,14 @@
+ """Model architecture module — LLaMA-style Decoder-Only Transformer."""
+ from .norm import RMSNorm
+ from .rope import RotaryPositionalEmbedding
+ from .attention import GroupedQueryAttention
+ from .feedforward import SwiGLUFeedForward
+ from .transformer_block import TransformerBlock
+ from .llm_model import LLMModel
+ from .utils import count_parameters_detailed, estimate_memory_gb
+
+ __all__ = [
+     "RMSNorm", "RotaryPositionalEmbedding", "GroupedQueryAttention",
+     "SwiGLUFeedForward", "TransformerBlock", "LLMModel",
+     "count_parameters_detailed", "estimate_memory_gb",
+ ]
llm_lab/model/attention.py ADDED
@@ -0,0 +1,134 @@
+ """Grouped Query Attention (GQA)."""
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from llm_lab.config import ModelConfig
+ from .rope import RotaryPositionalEmbedding
+
+
+ class GroupedQueryAttention(nn.Module):
+     """GQA: a memory-efficient variant of multi-head attention.
+
+     MHA vs GQA vs MQA:
+     - MHA (Multi-Head Attention): num_heads heads for Q, K, and V → large memory
+     - MQA (Multi-Query Attention): K and V share a single head → quality may drop
+     - GQA (Grouped Query Attention): K and V grouped into num_kv_heads heads
+       → sits between MHA and MQA; a good quality-efficiency balance
+
+     Example (num_heads=16, num_kv_heads=4):
+         Q heads:    [0,1,2,3,  4,5,6,7,  8,9,10,11,  12,13,14,15]
+         K/V groups: [   0   ,     1   ,      2    ,       3     ]
+         → every 4 Q heads share 1 K/V head
+
+     Attention formula:
+         Attention(Q, K, V) = softmax(Q·K^T / √d_k) · V
+     """
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.head_dim = config.head_dim
+         self.num_heads = config.num_heads
+         self.num_kv_heads = config.num_kv_heads
+         self.num_kv_groups = config.num_kv_groups  # num_heads // num_kv_heads
+
+         # Q/K/V projections
+         # Q: hidden_dim → num_heads × head_dim
+         self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim, bias=False)
+         # K, V: hidden_dim → num_kv_heads × head_dim (smaller than Q!)
+         self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim, bias=False)
+
+         # Output projection: all heads back to hidden_dim
+         self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim, bias=False)
+
+         # RoPE
+         self.rope = RotaryPositionalEmbedding(
+             dim=self.head_dim, max_seq_len=config.max_seq_len, theta=config.rope_theta
+         )
+
+         # Attention dropout (usually 0 for pretraining)
+         self.attn_dropout = nn.Dropout(config.dropout)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+         position_offset: int = 0,
+     ) -> torch.Tensor:
+         """
+         Args:
+             x: (batch_size, seq_len, hidden_dim)
+             mask: (seq_len, seq_len) causal mask
+             position_offset: position offset (used during inference)
+
+         Returns:
+             (batch_size, seq_len, hidden_dim)
+         """
+         B, S, _ = x.shape
+
+         # ────────────────────────────────────────────────
+         # Step 1: Q, K, V projections
+         # ────────────────────────────────────────────────
+         q = self.q_proj(x)  # (B, S, num_heads × head_dim)
+         k = self.k_proj(x)  # (B, S, num_kv_heads × head_dim)
+         v = self.v_proj(x)  # (B, S, num_kv_heads × head_dim)
+
+         # Reshape into multi-head form
+         q = q.view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+         # → (B, num_heads, S, head_dim)
+         k = k.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         # → (B, num_kv_heads, S, head_dim)
+         v = v.view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         # ────────────────────────────────────────────────
+         # Step 2: Apply RoPE (to Q and K only — never to V)
+         # ────────────────────────────────────────────────
+         # Position information should only affect "where to look" (Q·K),
+         # not "what to fetch" (V).
+         q, k = self.rope(q, k, position_offset)
+
+         # ────────────────────────────────────────────────
+         # Step 3: GQA — expand the KV heads (repeat)
+         # ────────────────────────────────────────────────
+         # num_kv_heads=4 → num_heads=16: repeat each KV head 4 times
+         if self.num_kv_groups > 1:
+             k = self._repeat_kv(k)  # (B, num_heads, S, head_dim)
+             v = self._repeat_kv(v)
+
+         # ────────────────────────────────────────────────
+         # Step 4: Scaled dot-product attention
+         # ────────────────────────────────────────────────
+         # Uses the optimized PyTorch >= 2.0 kernel (Flash Attention when available)
+         attn_out = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=mask,
+             dropout_p=self.config.dropout if self.training else 0.0,
+             is_causal=(mask is None),  # automatic causal masking when no mask is given
+         )
+         # → (B, num_heads, S, head_dim)
+
+         # ────────────────────────────────────────────────
+         # Step 5: Merge heads + output projection
+         # ────────────────────────────────────────────────
+         attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
+         # → (B, S, num_heads × head_dim)
+
+         return self.o_proj(attn_out)  # → (B, S, hidden_dim)
+
+     def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
+         """Repeat the KV heads to match the number of Q heads.
+
+         (B, num_kv_heads, S, head_dim) → (B, num_heads, S, head_dim)
+
+         E.g. num_kv_heads=4, num_kv_groups=4:
+             [kv0, kv1, kv2, kv3] → [kv0,kv0,kv0,kv0, kv1,kv1,kv1,kv1, ...]
+         """
+         B, H_kv, S, D = x.shape
+         x = x[:, :, None, :, :]                          # (B, H_kv, 1, S, D)
+         x = x.expand(B, H_kv, self.num_kv_groups, S, D)  # (B, H_kv, groups, S, D)
+         return x.reshape(B, self.num_heads, S, D)
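
A back-of-the-envelope sketch of what the grouping buys, using the head layout from the docstring (num_heads=16, num_kv_heads=4) with an *assumed* hidden_dim=2048 and head_dim=128:

```python
hidden_dim, head_dim = 2048, 128  # assumed 1B-lab-style sizes
num_heads, num_kv_heads = 16, 4

# K/V projection weights: hidden_dim × (heads × head_dim), no bias
kv_params_mha = 2 * hidden_dim * num_heads * head_dim     # K and V at full width
kv_params_gqa = 2 * hidden_dim * num_kv_heads * head_dim  # grouped K and V
print(kv_params_mha, kv_params_gqa)  # 8,388,608 vs 2,097,152 → 4× smaller

# The same 4× factor applies to the KV cache at inference time:
# per token, 2 × num_kv_heads × head_dim values instead of 2 × num_heads × head_dim
```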
llm_lab/model/feedforward.py ADDED
@@ -0,0 +1,48 @@
+ """SwiGLU Feed-Forward Network."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from llm_lab.config import ModelConfig
+
+
+ class SwiGLUFeedForward(nn.Module):
+     """SwiGLU: a gated linear unit with the Swish activation.
+
+     Classic FFN:
+         FFN(x) = ReLU(x·W1 + b1)·W2 + b2
+         → a simple nonlinear transform
+
+     SwiGLU FFN:
+         SwiGLU(x) = (Swish(x·W_gate) ⊙ (x·W_up)) · W_down
+         → a gating mechanism controls the flow of information
+
+     Why is SwiGLU better?
+     - Swish(x) = x · sigmoid(x): smooth, and lets part of the negative range through
+     - The gate vector learns "which information to pass"
+     - PaLM, LLaMA, and others report consistent gains over ReLU FFNs
+
+     Note: with two up-projections (W_gate and W_up) the parameter count is
+     1.5× a classic FFN, so intermediate_dim is shrunk to keep the total
+     parameter count matched.
+     """
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         # Gate projection: hidden_dim → intermediate_dim
+         self.gate_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
+         # Up projection: hidden_dim → intermediate_dim
+         self.up_proj = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
+         # Down projection: intermediate_dim → hidden_dim
+         self.down_proj = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # SwiGLU(x) = (Swish(gate(x)) ⊙ up(x)) · down
+         #
+         # 1) gate: decides what information passes through (Swish activation)
+         gate = F.silu(self.gate_proj(x))  # silu = Swish = x * sigmoid(x)
+         # 2) up: projects the input to the higher dimension
+         up = self.up_proj(x)
+         # 3) element-wise product (gating) → back to the original dimension
+         return self.down_proj(gate * up)
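
The parameter-matching note at the top of the file can be checked directly: a classic FFN with the usual 4× expansion uses 2 × hidden × 4·hidden weights, so a three-matrix SwiGLU matches it at intermediate_dim ≈ (8/3) × hidden_dim. A sketch with an assumed hidden_dim of 2048:

```python
hidden = 2048  # assumed

ffn_classic = 2 * hidden * (4 * hidden)  # W1 + W2 with 4× expansion
inter = int(8 / 3 * hidden)              # ≈ 5461; in practice often rounded to a multiple of 256
ffn_swiglu = 3 * hidden * inter          # W_gate + W_up + W_down

print(ffn_classic, ffn_swiglu)  # 33,554,432 vs 33,552,384 → nearly identical
```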
llm_lab/model/llm_model.py ADDED
@@ -0,0 +1,200 @@
1
+ """Full Transformer Model (LLaMA-style)."""
2
+
3
+ import math
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from llm_lab.config import ModelConfig
11
+ from .norm import RMSNorm
12
+ from .transformer_block import TransformerBlock
13
+
14
+
15
+ class LLMModel(nn.Module):
16
+ """1B ํŒŒ๋ผ๋ฏธํ„ฐ LLaMA-style Decoder-Only Transformer.
17
+
18
+ ์ „์ฒด ๊ตฌ์กฐ:
19
+ Input Token IDs
20
+ โ†’ Token Embedding
21
+ โ†’ [TransformerBlock] ร— num_layers (+ Activation Checkpointing)
22
+ โ†’ RMSNorm (์ตœ์ข…)
23
+ โ†’ Linear Head (โ†’ vocab logits)
24
+
25
+ Weight Tying:
26
+ - ์ž…๋ ฅ Embedding๊ณผ ์ถœ๋ ฅ Linear Head์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ณต์œ 
27
+ - ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜ ์ ˆ์•ฝ (~65M) + ์„ฑ๋Šฅ ์œ ์ง€/ํ–ฅ์ƒ
28
+ - ์ง๊ด€: "๋‹จ์–ด์˜ ์˜๋ฏธ ํ‘œํ˜„"๊ณผ "๋‹จ์–ด ์˜ˆ์ธก"์ด ๊ฐ™์€ ๊ณต๊ฐ„์„ ์‚ฌ์šฉ
29
+ """
30
+
31
+ def __init__(self, config: ModelConfig):
32
+ super().__init__()
33
+ self.config = config
34
+
35
+ # โ”€โ”€ Token Embedding โ”€โ”€
36
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
37
+
38
+ # โ”€โ”€ Transformer Blocks โ”€โ”€
39
+ self.layers = nn.ModuleList([
40
+ TransformerBlock(config, layer_idx=i)
41
+ for i in range(config.num_layers)
42
+ ])
43
+
44
+ # โ”€โ”€ ์ตœ์ข… ์ •๊ทœํ™” โ”€โ”€
45
+ self.final_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
46
+
47
+ # โ”€โ”€ ์ถœ๋ ฅ ํ—ค๋“œ (Weight Tying) โ”€โ”€
48
+ self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
49
+ # Weight Tying: lm_head์˜ ๊ฐ€์ค‘์น˜ = token_embedding์˜ ๊ฐ€์ค‘์น˜
50
+ self.lm_head.weight = self.token_embedding.weight
51
+
52
+ # ๊ฐ€์ค‘์น˜ ์ดˆ๊ธฐํ™”
53
+ self._init_weights()
54
+
55
+ def _init_weights(self):
56
+ """๊ฐ€์ค‘์น˜ ์ดˆ๊ธฐํ™” ์ „๋žต.
57
+
58
+ ์™œ ์ดˆ๊ธฐํ™”๊ฐ€ ์ค‘์š”ํ•œ๊ฐ€?
59
+ - ๋„ˆ๋ฌด ํฌ๋ฉด: ํ™œ์„ฑํ™” ํญ๋ฐœ โ†’ NaN
60
+ - ๋„ˆ๋ฌด ์ž‘์œผ๋ฉด: gradient ์†Œ๋ฉธ โ†’ ํ•™์Šต ์ •์ฒด
61
+ - ์ ์ ˆํ•œ ์ดˆ๊ธฐํ™”: ๊ฐ ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ ๋ถ„์‚ฐ์„ ์ผ์ •ํ•˜๊ฒŒ ์œ ์ง€
62
+
63
+ GPT-2 ์Šคํƒ€์ผ ์ดˆ๊ธฐํ™”:
64
+ - ์ผ๋ฐ˜ Linear: N(0, 0.02)
65
+ - Residual projection: N(0, 0.02 / โˆš(2 ร— num_layers))
66
+ โ†’ ๋ ˆ์ด์–ด๊ฐ€ ๊นŠ์–ด์งˆ์ˆ˜๋ก residual ๊ธฐ์—ฌ๋ฅผ ์ค„์—ฌ ์•ˆ์ •ํ™”
67
+ """
68
+ std = 0.02
69
+ residual_std = std / math.sqrt(2 * self.config.num_layers)
70
+
71
+ for module in self.modules():
72
+ if isinstance(module, nn.Linear):
73
+ nn.init.normal_(module.weight, mean=0.0, std=std)
74
+ if module.bias is not None:
75
+ nn.init.zeros_(module.bias)
76
+ elif isinstance(module, nn.Embedding):
77
+ nn.init.normal_(module.weight, mean=0.0, std=std)
78
+
79
+ # Residual projection ๋ ˆ์ด์–ด์— ์ถ•์†Œ๋œ ์ดˆ๊ธฐํ™” ์ ์šฉ
80
+ for layer in self.layers:
81
+ nn.init.normal_(layer.attention.o_proj.weight, mean=0.0, std=residual_std)
82
+ nn.init.normal_(layer.feed_forward.down_proj.weight, mean=0.0, std=residual_std)
83
+
84
+ def forward(
85
+ self,
86
+ input_ids: torch.Tensor,
87
+ targets: Optional[torch.Tensor] = None,
88
+ position_offset: int = 0,
89
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
90
+ """
91
+        Args:
+            input_ids: (batch_size, seq_len), token IDs
+            targets: (batch_size, seq_len), target token IDs (during training)
+            position_offset: position offset (during inference)
+
+        Returns:
+            logits: (batch_size, seq_len, vocab_size)
+            loss: scalar (when targets are given) or None
+        """
+        B, S = input_ids.shape
+
+        # ── Step 1: Token Embedding ──
+        # Convert each token ID into a hidden_dim-dimensional vector
+        h = self.token_embedding(input_ids)  # (B, S, hidden_dim)
+
+        # ── Step 2: Transformer Blocks ──
+        # Activation checkpointing: saves memory during training
+        # (intermediate activations are not stored; they are recomputed in backward)
+        for layer in self.layers:
+            if self.training and torch.is_grad_enabled():
+                # Apply activation checkpointing
+                h = torch.utils.checkpoint.checkpoint(
+                    layer, h, None, position_offset,
+                    use_reentrant=False,  # recommended for PyTorch >= 2.0
+                )
+            else:
+                h = layer(h, mask=None, position_offset=position_offset)
+
+        # ── Step 3: Final normalization ──
+        h = self.final_norm(h)
+
+        # ── Step 4: Output logits ──
+        logits = self.lm_head(h)  # (B, S, vocab_size)
+
+        # ── Step 5: Loss (during training) ──
+        loss = None
+        if targets is not None:
+            # Cross-entropy loss: next-token prediction
+            # logits: (B, S, V) → (B*S, V)
+            # targets: (B, S) → (B*S,)
+            loss = F.cross_entropy(
+                logits.view(-1, self.config.vocab_size),
+                targets.view(-1),
+                ignore_index=-100,  # ignore padding tokens
+            )
+
+        return logits, loss
+
+    def count_parameters(self, trainable_only: bool = True) -> int:
+        """Count model parameters."""
+        if trainable_only:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad)
+        return sum(p.numel() for p in self.parameters())
+
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        max_new_tokens: int = 100,
+        temperature: float = 1.0,
+        top_k: int = 50,
+        top_p: float = 0.9,
+    ) -> torch.Tensor:
+        """Generate text (inference).
+
+        Autoregressive generation: predict one token at a time and append it.
+
+        Args:
+            input_ids: (1, prompt_len), initial prompt
+            max_new_tokens: maximum number of tokens to generate
+            temperature: sharpness of the distribution (lower = more conservative)
+            top_k: consider only the k most probable tokens
+            top_p: consider tokens only up to cumulative probability p (nucleus sampling)
+        """
+        self.eval()
+        generated = input_ids
+
+        for _ in range(max_new_tokens):
+            # If the current sequence exceeds max_seq_len, truncate it
+            ctx = generated[:, -self.config.max_seq_len:]
+
+            # Forward pass
+            logits, _ = self(ctx)
+            # Use only the last token's logits (next-token prediction)
+            next_logits = logits[:, -1, :] / temperature
+
+            # ── Top-K filtering ──
+            if top_k > 0:
+                top_k_values, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
+                min_top_k = top_k_values[:, -1].unsqueeze(-1)
+                next_logits = next_logits.masked_fill(next_logits < min_top_k, float("-inf"))
+
+            # ── Top-P (nucleus) filtering ──
+            if top_p < 1.0:
+                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
+                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                # Drop tokens whose cumulative probability, excluding themselves,
+                # already reaches top_p (the most likely token always survives)
+                remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= top_p
+                sorted_logits[remove_mask] = float("-inf")
+                # Restore the original ordering (sorted_indices is a full permutation,
+                # so scatter overwrites every position)
+                next_logits = sorted_logits.scatter(1, sorted_indices, sorted_logits)
+
+            # Sample from the probability distribution
+            probs = F.softmax(next_logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
+
+            # Append the generated token
+            generated = torch.cat([generated, next_token], dim=1)
+
+        return generated
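
For readers who want to poke at the sampling filters in isolation, here is a standalone sketch of the same top-k/top-p masking on a made-up logits vector (plain PyTorch, no model required; the tensor values are invented for illustration):

```python
import torch
import torch.nn.functional as F

# Toy next-token logits for a vocabulary of 6 tokens (values are made up).
next_logits = torch.tensor([[2.0, 1.0, 0.5, 0.0, -1.0, -2.0]])
top_k, top_p = 4, 0.9

# Top-K: keep only the 4 largest logits, mask the rest to -inf.
kth_value = torch.topk(next_logits, top_k).values[:, -1, None]
next_logits = next_logits.masked_fill(next_logits < kth_value, float("-inf"))

# Top-P: sort, accumulate probabilities, drop tokens once the mass
# *before* them already reaches top_p (the top token always survives).
sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
remove = cum_probs - sorted_probs >= top_p
sorted_logits[remove] = float("-inf")
next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

probs = F.softmax(next_logits, dim=-1)
print(probs)  # masked tokens have probability exactly 0
```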
llm_lab/model/norm.py ADDED
@@ -0,0 +1,40 @@
+"""RMSNorm (Root Mean Square Layer Normalization)."""
+
+import torch
+import torch.nn as nn
+
+
+class RMSNorm(nn.Module):
+    """RMSNorm: a lightweight variant of LayerNorm.
+
+    Differences from standard LayerNorm:
+    - Does not subtract the mean → fewer operations
+    - Normalizes by the RMS (root mean square) instead of the variance
+    - No bias parameter
+
+    Formula:
+        RMSNorm(x) = (x / RMS(x)) * γ
+        RMS(x) = sqrt(mean(x²) + ε)
+
+    Why is normalization needed?
+    → As layers get deeper, activation scales explode or vanish.
+    → Normalization keeps each layer's input in a stable range.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        # γ (gamma): learnable scale parameter, initialized to 1
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # 1) Cast the input to float32 (numerical stability):
+        #    summing squares in bf16/fp16 risks overflow
+        x_float = x.float()
+
+        # 2) Inverse RMS: 1 / sqrt(mean(x²) + ε)
+        #    rsqrt replaces division with multiplication (faster)
+        rms = torch.rsqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+        # 3) Normalize, cast back to the original dtype, apply the scale
+        return (x_float * rms).to(x.dtype) * self.weight
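
A minimal sanity check, assuming the class above is importable as `llm_lab.model.norm`, that the forward pass matches the docstring formula on random data:

```python
import torch

from llm_lab.model.norm import RMSNorm  # path as defined in this commit

x = torch.randn(2, 8, 16)
norm = RMSNorm(dim=16)

# Manual version of the docstring formula: (x / sqrt(mean(x²) + ε)) * γ
manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + norm.eps) * norm.weight

assert torch.allclose(norm(x), manual, atol=1e-6)
```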
llm_lab/model/rope.py ADDED
@@ -0,0 +1,103 @@
+"""Rotary Positional Embedding (RoPE)."""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+
+class RotaryPositionalEmbedding(nn.Module):
+    """RoPE: relative position encoding via rotation matrices.
+
+    Core idea:
+    - Treat each dimension pair (2i, 2i+1) as coordinates in a 2D plane
+      and rotate them by an angle proportional to the position.
+    - The attention score between two tokens (Q·K) then depends only on
+      their relative distance.
+
+    Why RoPE?
+    - Absolute position embeddings: add a fixed vector per position
+      → generalizes poorly to longer sequences
+    - Relative position embeddings: complex to implement, extra parameters
+    - RoPE: encodes relative position naturally, with no parameters
+
+    Formula:
+        θ_i = theta^(-2i/d)          (i = 0, 1, ..., d/2-1)
+        RoPE(x, pos) = rotate each dimension pair of x by pos × θ_i
+    """
+
+    def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0):
+        super().__init__()
+        self.dim = dim
+        self.max_seq_len = max_seq_len
+        self.theta = theta
+
+        # Precompute the frequency vector (not learned → register as a buffer)
+        # freqs[i] = 1 / (theta^(2i/dim)), i = 0, 1, ..., dim/2-1
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("freqs", freqs, persistent=False)
+
+        # Precompute (max_seq_len, dim/2) cos/sin tables
+        self._build_cache(max_seq_len)
+
+    def _build_cache(self, seq_len: int):
+        """Precompute and cache the cos/sin values."""
+        t = torch.arange(seq_len, device=self.freqs.device, dtype=torch.float32)
+        # Outer product: (seq_len,) × (dim/2,) → (seq_len, dim/2)
+        angles = torch.outer(t, self.freqs)
+        self.register_buffer("cos_cached", angles.cos(), persistent=False)
+        self.register_buffer("sin_cached", angles.sin(), persistent=False)
+
+    def forward(
+        self, q: torch.Tensor, k: torch.Tensor, position_offset: int = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Apply the rotation to Q and K.
+
+        Args:
+            q: (batch, num_heads, seq_len, head_dim)
+            k: (batch, num_kv_heads, seq_len, head_dim)
+            position_offset: offset of the sequence start (when using a KV cache at inference)
+
+        Returns:
+            (q_rotated, k_rotated) with the rotation applied
+        """
+        seq_len = q.shape[2]
+
+        # Grow the cache if needed
+        if position_offset + seq_len > self.cos_cached.shape[0]:
+            self._build_cache(position_offset + seq_len)
+
+        # Slice the cos/sin values for the current positions
+        cos = self.cos_cached[position_offset : position_offset + seq_len]  # (seq_len, dim/2)
+        sin = self.sin_cached[position_offset : position_offset + seq_len]
+
+        q_rotated = self._apply_rotation(q, cos, sin)
+        k_rotated = self._apply_rotation(k, cos, sin)
+        return q_rotated, k_rotated
+
+    @staticmethod
+    def _apply_rotation(
+        x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        """Apply the rotation.
+
+        2D rotation matrix:
+            [cos θ, -sin θ] [x1]   [x1·cos θ - x2·sin θ]
+            [sin θ,  cos θ] [x2] = [x1·sin θ + x2·cos θ]
+
+        Implemented efficiently with vectorized operations.
+        """
+        # x: (batch, heads, seq_len, head_dim)
+        # Split even/odd indices: (x0, x1, x2, x3, ...) → (x0, x2, ...), (x1, x3, ...)
+        x_even = x[..., 0::2]  # even indices
+        x_odd = x[..., 1::2]   # odd indices
+
+        # Reshape for broadcasting: (seq_len, dim/2) → (1, 1, seq_len, dim/2)
+        cos = cos.unsqueeze(0).unsqueeze(0)
+        sin = sin.unsqueeze(0).unsqueeze(0)
+
+        # Apply the rotation
+        rotated_even = x_even * cos - x_odd * sin
+        rotated_odd = x_even * sin + x_odd * cos
+
+        # Re-interleave: (even0, odd0, even1, odd1, ...)
+        out = torch.stack([rotated_even, rotated_odd], dim=-1)
+        return out.flatten(-2)  # merge the last two dims to restore the original shape
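
The claim that attention scores depend only on relative distance can be checked directly. A minimal sketch, assuming the module above is importable as `llm_lab.model.rope`: rotating the same Q/K at positions shifted by a constant offset leaves Q·K unchanged.

```python
import torch

from llm_lab.model.rope import RotaryPositionalEmbedding  # path as in this commit

rope = RotaryPositionalEmbedding(dim=64)
q = torch.randn(1, 1, 8, 64)
k = torch.randn(1, 1, 8, 64)

# Same Q/K rotated at positions [0..7] vs. positions [100..107].
q0, k0 = rope(q, k, position_offset=0)
q1, k1 = rope(q, k, position_offset=100)

# Scores depend only on relative distance, so the two match.
s0 = q0 @ k0.transpose(-2, -1)
s1 = q1 @ k1.transpose(-2, -1)
assert torch.allclose(s0, s1, atol=1e-4)
```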
llm_lab/model/transformer_block.py ADDED
@@ -0,0 +1,65 @@
+"""Transformer block (a single layer)."""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from llm_lab.config import ModelConfig
+from .norm import RMSNorm
+from .attention import GroupedQueryAttention
+from .feedforward import SwiGLUFeedForward
+
+
+class TransformerBlock(nn.Module):
+    """A single Transformer decoder block.
+
+    Structure (Pre-Norm):
+        x → RMSNorm → Attention → + (residual) → RMSNorm → FFN → + (residual) → out
+
+    Pre-Norm vs Post-Norm:
+    - Post-Norm (original Transformer): LayerNorm after the residual
+      → unstable training in deep models
+    - Pre-Norm (standard since GPT-2): LayerNorm before the sublayer
+      → smoother gradient flow, more stable training
+
+    Role of the residual connection:
+    - Adds the input to the output → a "highway" that lets gradients
+      skip over layers
+    - The key reason a 22-layer stack remains trainable
+    """
+
+    def __init__(self, config: ModelConfig, layer_idx: int):
+        super().__init__()
+        self.layer_idx = layer_idx
+
+        # Pre-Norm: normalize before attention
+        self.attn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
+        # Self-attention
+        self.attention = GroupedQueryAttention(config)
+
+        # Pre-Norm: normalize before the FFN
+        self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
+        # Feed-forward network
+        self.feed_forward = SwiGLUFeedForward(config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        position_offset: int = 0,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (batch_size, seq_len, hidden_dim)
+        Returns:
+            (batch_size, seq_len, hidden_dim)
+        """
+        # ── Attention sublayer with residual ──
+        # h = x + Attention(RMSNorm(x))
+        h = x + self.attention(self.attn_norm(x), mask, position_offset)
+
+        # ── FFN sublayer with residual ──
+        # out = h + FFN(RMSNorm(h))
+        out = h + self.feed_forward(self.ffn_norm(h))
+
+        return out
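
As a quick sanity check of the pre-norm residual structure, a sketch assuming the `ModelConfig.debug_10m()` preset that the notebooks use: both sublayers are shape-preserving, so the block maps its input shape to itself.

```python
import torch

from llm_lab.config import ModelConfig
from llm_lab.model.transformer_block import TransformerBlock  # path as in this commit

cfg = ModelConfig.debug_10m()  # small preset used by the notebooks
block = TransformerBlock(cfg, layer_idx=0)

x = torch.randn(2, 16, cfg.hidden_dim)
out = block(x)
assert out.shape == x.shape  # residual blocks preserve the shape
```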
llm_lab/model/utils.py ADDED
@@ -0,0 +1,85 @@
+"""Model utility functions."""
+
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING
+
+from llm_lab.config import ModelConfig
+
+if TYPE_CHECKING:
+    from .llm_model import LLMModel
+
+
+def count_parameters_detailed(model: "LLMModel") -> dict:
+    """Break down the model's parameter count by component."""
+    total = 0
+    breakdown = {}
+
+    # Embedding
+    emb_params = model.token_embedding.weight.numel()
+    breakdown["token_embedding"] = emb_params
+    total += emb_params
+
+    # Per-layer parameters (all layers are identical, so inspect the first)
+    layer_total = 0
+    layer_detail = {}
+    layer = model.layers[0]
+
+    for name, param in layer.named_parameters():
+        layer_detail[name] = param.numel()
+        layer_total += param.numel()
+
+    breakdown["per_layer"] = layer_detail
+    breakdown["per_layer_total"] = layer_total
+    breakdown["all_layers_total"] = layer_total * len(model.layers)
+    total += layer_total * len(model.layers)
+
+    # Final norm
+    norm_params = model.final_norm.weight.numel()
+    breakdown["final_norm"] = norm_params
+    total += norm_params
+
+    # LM head (weight tying → zero additional parameters)
+    breakdown["lm_head"] = "weight tying (0 additional)"
+    breakdown["total"] = total
+
+    return breakdown
+
+
+def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
+    """Estimate the model's GPU memory usage.
+
+    Args:
+        dtype_bytes: 2 (bf16/fp16) or 4 (fp32)
+    """
+    # Rough parameter count
+    emb = config.vocab_size * config.hidden_dim
+    per_layer = (
+        config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim  # QKV
+        + config.num_heads * config.head_dim * config.hidden_dim  # O proj
+        + 3 * config.hidden_dim * config.intermediate_dim  # SwiGLU (gate + up + down)
+        + 2 * config.hidden_dim  # 2 × RMSNorm
+    )
+    total_params = emb + per_layer * config.num_layers + config.hidden_dim
+
+    model_gb = total_params * dtype_bytes / 1e9
+    optimizer_gb = total_params * 8 / 1e9  # AdamW: 2 states × fp32
+    gradient_gb = total_params * dtype_bytes / 1e9
+
+    # Activation memory (assuming activation checkpointing)
+    # Rough estimate: batch_size × seq_len × hidden_dim × sqrt(num_layers)
+    activation_gb = (
+        batch_size * config.max_seq_len * config.hidden_dim * 4  # bytes
+        * math.sqrt(config.num_layers)  # checkpointing effect
+        / 1e9
+    )
+
+    return {
+        "total_parameters": total_params,
+        "model_weights_gb": round(model_gb, 2),
+        "optimizer_states_gb": round(optimizer_gb, 2),
+        "gradients_gb": round(gradient_gb, 2),
+        "activations_estimated_gb": round(activation_gb, 2),
+        "total_estimated_gb": round(model_gb + optimizer_gb + gradient_gb + activation_gb, 2),
+    }
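
A usage sketch for the estimator above (the same call `02_model.ipynb` makes); the resulting numbers depend on the `ModelConfig` presets, so none are asserted here:

```python
from llm_lab.config import ModelConfig
from llm_lab.model import estimate_memory_gb  # exported as in the notebooks

# bf16 training estimate for the 1.1B preset at micro-batch 4.
mem = estimate_memory_gb(ModelConfig.base_1b(), batch_size=4, dtype_bytes=2)
for key, value in mem.items():
    print(f"{key}: {value}")
# Weights and gradients scale with dtype_bytes; AdamW states stay fp32 (8 bytes/param).
```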
llm_lab/training/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""Training module: gradient accumulation, mixed precision, checkpointing, wandb logging."""
+from .scheduler import CosineWarmupScheduler
+from .checkpoint import CheckpointManager
+from .metrics import MetricsTracker
+from .optimizer import create_optimizer
+from .trainer import Trainer
+from .runner import start_training
+
+__all__ = [
+    "CosineWarmupScheduler", "CheckpointManager", "MetricsTracker",
+    "create_optimizer", "Trainer", "start_training",
+]
llm_lab/training/checkpoint.py ADDED
@@ -0,0 +1,159 @@
+"""Manager for saving/restoring training state."""
+
+import json
+import shutil
+import time
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from llm_lab.config import TrainConfig
+
+
+class CheckpointManager:
+    """Manager for saving/restoring training state.
+
+    Why checkpoints matter on Colab:
+    - When a session expires (after at most ~24 hours), all in-memory state is lost
+    - Saving to Google Drive allows training to continue across sessions
+    - The optimizer state must be saved too, so AdamW momentum is preserved
+
+    What gets saved:
+    - model_state_dict: model weights
+    - optimizer_state_dict: optimizer state (m, v moments)
+    - step: current training step
+    - best_val_loss: lowest validation loss so far
+    - config: training configuration (for reproducibility)
+    - rng_states: random seed states (for exact reproduction)
+    - metrics_history: training metric history
+    - wandb_run_id: wandb run ID (for logging continuity)
+    """
+
+    def __init__(self, config: TrainConfig):
+        self.config = config
+        self.checkpoint_dir = Path(config.checkpoint_dir)
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        self.max_checkpoints = config.max_checkpoints
+
+    def save(
+        self,
+        model: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        step: int,
+        best_val_loss: float,
+        metrics_history: Dict[str, list],
+        wandb_run_id: Optional[str] = None,
+    ):
+        """Save a checkpoint."""
+        ckpt_path = self.checkpoint_dir / f"step_{step:06d}"
+        ckpt_path.mkdir(parents=True, exist_ok=True)
+
+        print(f"\n💾 Saving checkpoint: {ckpt_path}")
+        start = time.time()
+
+        # 1) Model weights (kept in bf16 as-is)
+        torch.save(model.state_dict(), ckpt_path / "model.pt")
+
+        # 2) Optimizer state (includes fp32 moments, so it is large)
+        torch.save(optimizer.state_dict(), ckpt_path / "optimizer.pt")
+
+        # 3) Training metadata
+        meta = {
+            "step": step,
+            "best_val_loss": best_val_loss,
+            "wandb_run_id": wandb_run_id,
+            "config": self.config.__dict__,
+        }
+        with open(ckpt_path / "meta.json", "w") as f:
+            json.dump(meta, f, indent=2)
+
+        # 4) Metric history
+        torch.save(metrics_history, ckpt_path / "metrics.pt")
+
+        # 5) RNG states (for exact reproduction)
+        rng_states = {
+            "python": torch.random.get_rng_state(),
+            "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
+        }
+        torch.save(rng_states, ckpt_path / "rng_states.pt")
+
+        elapsed = time.time() - start
+        ckpt_size = sum(f.stat().st_size for f in ckpt_path.rglob("*")) / 1e9
+        print(f"   Saved: {ckpt_size:.2f} GB in {elapsed:.1f}s")
+
+        # Delete old checkpoints (rolling window)
+        self._cleanup_old_checkpoints()
+
+    def load_latest(
+        self,
+        model: nn.Module,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+        device: torch.device = torch.device("cpu"),
+    ) -> Optional[Dict[str, Any]]:
+        """Load the most recent checkpoint.
+
+        Returns:
+            {"step", "best_val_loss", "wandb_run_id", "metrics_history"},
+            or None if no checkpoint exists
+        """
+        ckpt_path = self._find_latest()
+        if ckpt_path is None:
+            print("[Checkpoint] No saved checkpoint. Starting from scratch.")
+            return None
+
+        print(f"\n📂 Loading checkpoint: {ckpt_path}")
+        start = time.time()
+
+        # 1) Model weights
+        model_state = torch.load(ckpt_path / "model.pt", map_location=device, weights_only=True)
+        model.load_state_dict(model_state)
+        del model_state  # free memory
+
+        # 2) Optimizer state
+        if optimizer is not None:
+            optim_state = torch.load(ckpt_path / "optimizer.pt", map_location=device, weights_only=True)
+            optimizer.load_state_dict(optim_state)
+            del optim_state
+
+        # 3) Metadata
+        with open(ckpt_path / "meta.json", "r") as f:
+            meta = json.load(f)
+
+        # 4) Metric history
+        metrics_history = {}
+        metrics_path = ckpt_path / "metrics.pt"
+        if metrics_path.exists():
+            metrics_history = torch.load(metrics_path, weights_only=False)
+
+        # 5) Restore RNG states
+        rng_path = ckpt_path / "rng_states.pt"
+        if rng_path.exists():
+            rng_states = torch.load(rng_path, weights_only=False)
+            torch.random.set_rng_state(rng_states["python"])
+            if rng_states["cuda"] is not None and torch.cuda.is_available():
+                torch.cuda.set_rng_state(rng_states["cuda"])
+
+        elapsed = time.time() - start
+        print(f"   Loaded: step={meta['step']} in {elapsed:.1f}s")
+
+        return {
+            "step": meta["step"],
+            "best_val_loss": meta["best_val_loss"],
+            "wandb_run_id": meta.get("wandb_run_id"),
+            "metrics_history": metrics_history,
+        }
+
+    def _find_latest(self) -> Optional[Path]:
+        """Find the most recent checkpoint path."""
+        ckpts = sorted(self.checkpoint_dir.glob("step_*"))
+        return ckpts[-1] if ckpts else None
+
+    def _cleanup_old_checkpoints(self):
+        """Delete old checkpoints (rolling window)."""
+        ckpts = sorted(self.checkpoint_dir.glob("step_*"))
+        while len(ckpts) > self.max_checkpoints:
+            old = ckpts.pop(0)
+            print(f"   🗑️ Deleting old checkpoint: {old.name}")
+            shutil.rmtree(old)
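
A round-trip sketch on a toy module; `SimpleNamespace` stands in for `TrainConfig`, supplying only the two fields `CheckpointManager` actually reads (a demo assumption, not the real config class):

```python
from types import SimpleNamespace

import torch
import torch.nn as nn

from llm_lab.training.checkpoint import CheckpointManager  # path as in this commit

# Stand-in for TrainConfig: only checkpoint_dir and max_checkpoints are read.
cfg = SimpleNamespace(checkpoint_dir="./ckpt_demo", max_checkpoints=2)

model = nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

mgr = CheckpointManager(cfg)
mgr.save(model, opt, step=10, best_val_loss=3.5, metrics_history={"step": [10]})

# A fresh model picks up where the old one left off.
state = mgr.load_latest(nn.Linear(4, 4))
print(state["step"], state["best_val_loss"])  # 10 3.5
```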
llm_lab/training/metrics.py ADDED
@@ -0,0 +1,112 @@
+"""Training metric tracking and logging."""
+
+from typing import Dict, Optional
+
+from llm_lab.config import TrainConfig
+
+
+class MetricsTracker:
+    """Tracks and logs training metrics.
+
+    Tracked items:
+    - train/loss: training loss (cross-entropy)
+    - train/lr: current learning rate
+    - train/grad_norm: gradient L2 norm
+    - train/tokens_per_sec: throughput
+    - train/gpu_mem_gb: GPU memory usage
+    - val/loss: validation loss
+    - val/perplexity: validation perplexity (= exp(loss))
+    """
+
+    def __init__(self, config: TrainConfig):
+        self.config = config
+        self.history: Dict[str, list] = {
+            "step": [],
+            "train_loss": [],
+            "learning_rate": [],
+            "grad_norm": [],
+            "tokens_per_sec": [],
+            "gpu_mem_gb": [],
+            "val_loss": [],
+            "val_ppl": [],
+        }
+
+        # Initialize wandb
+        self.wandb_run = None
+        if config.use_wandb:
+            self._init_wandb()
+
+    def _init_wandb(self, resume_id: Optional[str] = None):
+        """Initialize wandb (supports resumed logging across sessions)."""
+        try:
+            import wandb
+
+            run_id = resume_id or wandb.util.generate_id()
+            self.wandb_run = wandb.init(
+                project=self.config.wandb_project,
+                name=self.config.wandb_run_name or f"1b-run-{run_id[:6]}",
+                id=run_id,
+                resume="allow",
+                config=self.config.__dict__,
+            )
+            print(f"[wandb] Initialized: {self.wandb_run.url}")
+        except ImportError:
+            print("[wandb] Not installed. Using console logging only.")
+            self.config.use_wandb = False
+        except Exception as e:
+            print(f"[wandb] Initialization failed: {e}. Using console logging only.")
+            self.config.use_wandb = False
+
+    def resume_wandb(self, run_id: str):
+        """Resume logging to a previous wandb run."""
+        if self.config.use_wandb:
+            self._init_wandb(resume_id=run_id)
+
+    def log_train_step(
+        self,
+        step: int,
+        loss: float,
+        lr: float,
+        grad_norm: float,
+        tokens_per_sec: float,
+        gpu_mem_gb: float,
+    ):
+        """Record metrics for a training step."""
+        self.history["step"].append(step)
+        self.history["train_loss"].append(loss)
+        self.history["learning_rate"].append(lr)
+        self.history["grad_norm"].append(grad_norm)
+        self.history["tokens_per_sec"].append(tokens_per_sec)
+        self.history["gpu_mem_gb"].append(gpu_mem_gb)
+
+        if self.config.use_wandb and self.wandb_run:
+            import wandb
+
+            wandb.log({
+                "train/loss": loss,
+                "train/lr": lr,
+                "train/grad_norm": grad_norm,
+                "train/tokens_per_sec": tokens_per_sec,
+                "train/gpu_mem_gb": gpu_mem_gb,
+            }, step=step)
+
+    def log_eval(self, step: int, val_loss: float, val_ppl: float):
+        """Record validation metrics."""
+        self.history["val_loss"].append(val_loss)
+        self.history["val_ppl"].append(val_ppl)
+
+        if self.config.use_wandb and self.wandb_run:
+            import wandb
+
+            wandb.log({
+                "val/loss": val_loss,
+                "val/perplexity": val_ppl,
+            }, step=step)
+
+    @property
+    def wandb_run_id(self) -> Optional[str]:
+        if self.wandb_run:
+            return self.wandb_run.id
+        return None
llm_lab/training/optimizer.py ADDED
@@ -0,0 +1,54 @@
+"""Create the AdamW optimizer (with weight-decay separation)."""
+
+import torch
+import torch.nn as nn
+
+from llm_lab.config import TrainConfig
+
+
+def create_optimizer(model: nn.Module, config: TrainConfig) -> torch.optim.AdamW:
+    """Create the AdamW optimizer.
+
+    Weight-decay separation rules:
+    - Decay applied: linear weights (attention projections, FFN, etc.)
+    - Decay not applied: embeddings, LayerNorm/RMSNorm, biases
+
+    Why separate them?
+    - Weight decay penalizes large weights to prevent overfitting
+    - Applied to norm scale parameters, it interferes with normalization
+    - Applied to embeddings, it shrinks rare-token representations toward 0
+    - Excluding 1D parameters (bias, norm weight) from decay is the convention
+    """
+    # Split parameters into decay / no-decay groups
+    decay_params = []
+    no_decay_params = []
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # 1D tensors (bias, norm weight) or embeddings → no decay
+        if param.dim() <= 1 or "embedding" in name:
+            no_decay_params.append(param)
+        else:
+            decay_params.append(param)
+
+    param_groups = [
+        {"params": decay_params, "weight_decay": config.weight_decay},
+        {"params": no_decay_params, "weight_decay": 0.0},
+    ]
+
+    n_decay = sum(p.numel() for p in decay_params)
+    n_no_decay = sum(p.numel() for p in no_decay_params)
+    print(f"[Optimizer] Decay parameters: {n_decay:,} ({n_decay/1e6:.1f}M)")
+    print(f"[Optimizer] No-decay parameters: {n_no_decay:,} ({n_no_decay/1e6:.1f}M)")
+
+    optimizer = torch.optim.AdamW(
+        param_groups,
+        lr=config.learning_rate,
+        betas=(config.beta1, config.beta2),
+        eps=config.adam_eps,
+        fused=torch.cuda.is_available(),  # CUDA fused AdamW (faster)
+    )
+
+    return optimizer
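
The grouping rule is easy to verify on a toy model; `TinyModel` below is hypothetical and exists only to show which parameters land in which group:

```python
import torch.nn as nn

# Hypothetical toy model covering all three cases of the grouping rule.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(100, 8)  # name contains "embedding" → no decay
        self.proj = nn.Linear(8, 8)                  # 2D weight → decay; 1D bias → no decay
        self.norm = nn.LayerNorm(8)                  # 1D weight/bias → no decay

model = TinyModel()
for name, p in model.named_parameters():
    group = "no_decay" if p.dim() <= 1 or "embedding" in name else "decay"
    print(f"{name:30s} dim={p.dim()} → {group}")
# Only proj.weight receives weight decay under the rule above.
```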
llm_lab/training/runner.py ADDED
@@ -0,0 +1,68 @@
+"""Training launch helper (quick start)."""
+
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from llm_lab.config import TrainConfig
+from .trainer import Trainer
+from llm_lab.utils import auto_configure
+
+
+def start_training(
+    model: nn.Module,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader] = None,
+    config: Optional[TrainConfig] = None,
+    seq_len: int = 2048,
+    auto_config: bool = True,
+) -> Trainer:
+    """Start training (one-line launch).
+
+    Usage (Colab):
+    ```python
+    from llm_lab.config import ModelConfig
+    from llm_lab.model import LLMModel
+    from llm_lab.data import setup_data_pipeline
+    from llm_lab.training import start_training
+
+    # 1. Create the model
+    model_config = ModelConfig.base_1b()
+    model = LLMModel(model_config)
+
+    # 2. Data pipeline
+    tok, train_dl, val_dl = setup_data_pipeline("pretrained")
+
+    # 3. Start training (auto-resumes from checkpoints)
+    trainer = start_training(model, train_dl, val_dl)
+    ```
+    """
+    config = config or TrainConfig()
+
+    # Auto-detect the GPU and adjust settings
+    if auto_config:
+        config = auto_configure(config)
+
+    # Check that Google Drive is mounted (Colab)
+    if "/content/drive" in config.checkpoint_dir:
+        drive_path = Path("/content/drive/MyDrive")
+        if not drive_path.exists():
+            print("\n⚠️ Google Drive is not mounted!")
+            print("   In Colab, run: from google.colab import drive; drive.mount('/content/drive')")
+            print("   Falling back to a local path.")
+            config.checkpoint_dir = "./checkpoints"
+
+    # Seed for reproducibility
+    torch.manual_seed(config.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(config.seed)
+
+    # Create the Trainer (includes automatic checkpoint resume)
+    trainer = Trainer(model, train_dataloader, val_dataloader, config, seq_len)
+
+    # Run training
+    trainer.train()
+
+    return trainer
llm_lab/training/scheduler.py ADDED
@@ -0,0 +1,68 @@
+"""Cosine annealing with linear warmup scheduler."""
+
+import math
+
+import torch
+
+from llm_lab.config import TrainConfig
+
+
+class CosineWarmupScheduler:
+    """Cosine annealing with linear warmup.
+
+    LR curve:
+        ┌─── peak_lr ───────╲
+        │                    ╲  cosine decay
+        │ warmup (linear)     ╲
+        │/                     ╲_______ min_lr
+        └──────────────────────────────────→ steps
+
+    Why cosine decay?
+    - Step decay: abrupt LR drops → unstable loss
+    - Linear decay: LR shrinks too fast late in training
+    - Cosine: smooth decay, keeps a useful LR late in training
+    - Used by most LLMs: GPT-3, LLaMA, Chinchilla, etc.
+
+    Implementation note:
+        PyTorch has built-in schedulers (CosineAnnealingLR, etc.), but a
+        hand-rolled version is more flexible for combining warmup, min_lr,
+        and checkpoint restoration.
+    """
+
+    def __init__(self, config: TrainConfig):
+        self.peak_lr = config.learning_rate
+        self.min_lr = config.min_learning_rate
+        self.warmup_steps = config.warmup_steps
+        self.total_steps = config.total_steps
+
+    def get_lr(self, step: int) -> float:
+        """Return the learning rate for the given step.
+
+        Args:
+            step: current optimizer step (0-indexed)
+
+        Returns:
+            learning rate (float)
+        """
+        # Phase 1: linear warmup
+        if step < self.warmup_steps:
+            # Linear increase from 0 → peak_lr
+            return self.peak_lr * (step / self.warmup_steps)
+
+        # Phase 2: cosine decay
+        # Progress through the post-warmup phase (0.0 → 1.0)
+        decay_steps = self.total_steps - self.warmup_steps
+        progress = (step - self.warmup_steps) / max(decay_steps, 1)
+        progress = min(progress, 1.0)  # safety clamp
+
+        # Cosine formula: min_lr + 0.5 × (peak - min) × (1 + cos(π × progress))
+        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
+        lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay
+
+        return lr
+
+    def set_lr(self, optimizer: torch.optim.Optimizer, step: int):
+        """Update the optimizer's learning rate."""
+        lr = self.get_lr(step)
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = lr
+        return lr
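
A quick numeric check of the three phases; `SimpleNamespace` stands in for `TrainConfig`, supplying only the four fields the scheduler reads (a demo assumption):

```python
from types import SimpleNamespace

from llm_lab.training.scheduler import CosineWarmupScheduler  # path as in this commit

cfg = SimpleNamespace(
    learning_rate=3e-4, min_learning_rate=3e-5,
    warmup_steps=100, total_steps=1000,
)
sched = CosineWarmupScheduler(cfg)

print(sched.get_lr(0))     # 0.0       (warmup start)
print(sched.get_lr(50))    # 1.5e-4    (halfway up the warmup ramp)
print(sched.get_lr(100))   # 3e-4      (peak, cosine phase begins)
print(sched.get_lr(550))   # 1.65e-4   (midpoint of the cosine decay)
print(sched.get_lr(1000))  # 3e-5      (min_lr at the end)
```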
llm_lab/training/trainer.py ADDED
@@ -0,0 +1,351 @@
+"""LLM pretraining trainer."""
+
+import math
+import time
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from llm_lab.config import TrainConfig
+from .scheduler import CosineWarmupScheduler
+from .checkpoint import CheckpointManager
+from .metrics import MetricsTracker
+from .optimizer import create_optimizer
+
+
+class Trainer:
+    """LLM pretraining trainer.
+
+    Core structure of the training loop:
+    ```
+    for step in range(total_steps):
+        # ── Gradient Accumulation Loop ──
+        for micro_step in range(accumulation_steps):
+            batch = next(dataloader)
+            with autocast(bf16):
+                logits, loss = model(input_ids, targets)
+            scaled_loss = loss / accumulation_steps
+            scaled_loss.backward()  # accumulate gradients
+
+        # ── Optimizer Step (after accumulation) ──
+        clip_grad_norm(model, max_norm=1.0)
+        optimizer.step()
+        optimizer.zero_grad()
+        scheduler.set_lr(optimizer, step)
+    ```
+
+    What is gradient accumulation?
+    - Used when a large batch cannot fit in GPU memory at once
+    - Run forward/backward on several small micro-batches → accumulate gradients
+    - Take a single optimizer step after accumulating
+    - The result is equivalent to one large effective batch
+    - Why divide the loss by accumulation_steps:
+      to average the gradients (mean rather than sum)
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        config: TrainConfig,
+        seq_len: int = 2048,
+    ):
+        self.config = config
+        self.seq_len = seq_len
+
+        # ── Device setup ──
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"[Trainer] Device: {self.device}")
+        if torch.cuda.is_available():
+            print(f"[Trainer] GPU: {torch.cuda.get_device_name()}")
+            print(f"[Trainer] GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
+
+        # ── Model ──
+        self.model = model.to(self.device)
+        # torch.compile: PyTorch 2.0+ graph optimization (10-30% speedup)
+        if torch.cuda.is_available() and hasattr(torch, "compile"):
+            print("[Trainer] Applying torch.compile...")
+            self.model = torch.compile(self.model)
+
+        # ── Data ──
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
+        self.train_iter = iter(train_dataloader)
+
+        # ── Optimizer ──
+        self.optimizer = create_optimizer(self.model, config)
+
+        # ── Scheduler ──
+        self.scheduler = CosineWarmupScheduler(config)
+
+        # ── Checkpointing ──
+        self.ckpt_manager = CheckpointManager(config)
+
+        # ── Metrics ──
+        self.metrics = MetricsTracker(config)
+
+        # ── Training state ──
+        self.global_step = 0
+        self.best_val_loss = float("inf")
+        self.tokens_seen = 0
+
+        # ── Mixed precision ──
+        # bf16 needs no GradScaler (only fp16 does; omitted here for simplicity)
+        self.use_amp = config.dtype != "float32"
+        self.amp_dtype = config.torch_dtype
+
+        # ── Attempt auto-resume ──
+        self._try_resume()
+
+    def _try_resume(self):
+        """Automatically restore from a previous checkpoint if one exists."""
+        result = self.ckpt_manager.load_latest(
+            self.model, self.optimizer, self.device
+        )
+
+        if result is not None:
+            self.global_step = result["step"]
+            self.best_val_loss = result["best_val_loss"]
+            self.metrics.history = result.get("metrics_history", self.metrics.history)
+
+            # Resume wandb logging
+            if result.get("wandb_run_id"):
+                self.metrics.resume_wandb(result["wandb_run_id"])
+
+            self.tokens_seen = self.global_step * self.config.effective_batch_size * self.seq_len
+            print(f"[Trainer] Resuming training: step={self.global_step}, "
+                  f"tokens={self.tokens_seen/1e9:.2f}B, "
+                  f"best_val_loss={self.best_val_loss:.4f}")
+
+    def _get_next_batch(self) -> Dict[str, torch.Tensor]:
+        """Fetch the next training batch.
+
+        A streaming DataLoader has no notion of epochs, so on
+        StopIteration a fresh iterator is created.
+        """
+        try:
+            batch = next(self.train_iter)
+        except StopIteration:
+            self.train_iter = iter(self.train_dataloader)
+            batch = next(self.train_iter)
+
+        return {
+            "input_ids": batch["input_ids"].to(self.device, non_blocking=True),
+            "targets": batch["targets"].to(self.device, non_blocking=True),
+        }
+
+    def _train_step(self) -> Tuple[float, float]:
+        """Perform a single optimizer step.
+
+        Returns:
+            (loss, grad_norm)
+        """
+        self.model.train()
+        self.optimizer.zero_grad(set_to_none=True)
+        # set_to_none=True: sets gradients to None → saves memory
+
+        total_loss = 0.0
+
+        # ── Gradient Accumulation Loop ──
+        for micro_step in range(self.config.gradient_accumulation_steps):
+            batch = self._get_next_batch()
+
+            # Mixed-precision forward
+            with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
+                logits, loss = self.model(batch["input_ids"], batch["targets"])
+
+            # Loss scaling: yields the mean over the effective batch
+            scaled_loss = loss / self.config.gradient_accumulation_steps
+            total_loss += loss.item()
+
+            # Backward (accumulate gradients)
+            scaled_loss.backward()
+
+        # ── Gradient clipping ──
+        # Treat all parameters' gradients as one vector and compute its L2 norm;
+        # if the norm exceeds max_norm, scale it down proportionally
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            self.model.parameters(),
+            max_norm=self.config.grad_clip,
+        ).item()
+
+        # ── Optimizer step ──
+        self.optimizer.step()
+
+        # ── LR update ──
+        self.scheduler.set_lr(self.optimizer, self.global_step)
+
+        avg_loss = total_loss / self.config.gradient_accumulation_steps
+        return avg_loss, grad_norm
+
+    @torch.no_grad()
+    def _evaluate(self) -> Tuple[float, float]:
+        """Measure loss and perplexity on the validation data.
+
+        Perplexity = exp(loss)
+        - Intuition: "from how many candidates does the model effectively
+          choose the next token, on average"
+        - PPL 100 → like picking uniformly among 100 options
+        - PPL 20  → among 20 options (quite good)
+        - PPL 10  → very confident predictions
+        """
+        if self.val_dataloader is None:
+            return float("inf"), float("inf")
+
+        self.model.eval()
+        total_loss = 0.0
+        num_batches = 0
+
+        for i, batch in enumerate(self.val_dataloader):
+            if i >= self.config.eval_steps:
+                break
+
+            input_ids = batch["input_ids"].to(self.device)
+            targets = batch["targets"].to(self.device)
+
+            with torch.amp.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
+                _, loss = self.model(input_ids, targets)
+
+            total_loss += loss.item()
+            num_batches += 1
+
+        avg_loss = total_loss / max(num_batches, 1)
+        perplexity = math.exp(min(avg_loss, 20))  # overflow guard (exp(20) ≈ 5e8)
+
+        return avg_loss, perplexity
+
+    def train(self):
+        """Main training loop.
+
+        This method runs the entire training. If a Colab session expires,
+        training resumes automatically from the latest checkpoint.
+        """
+        config = self.config
+
+        print("\n" + "=" * 70)
+        print("🚀 Training started")
+        print("=" * 70)
+        print(f"  Total steps: {config.total_steps:,}")
+        print(f"  Starting step: {self.global_step}")
+        print(f"  Effective batch size: {config.effective_batch_size}")
+        print(f"  Tokens/step: {config.effective_batch_size * self.seq_len:,}")
+        print(f"  Total training tokens (projected): {config.total_steps * config.effective_batch_size * self.seq_len / 1e9:.1f}B")
+        print(f"  Mixed precision: {config.dtype}")
+        print(f"  Gradient accumulation: {config.gradient_accumulation_steps}")
+        print(f"  Checkpoints: {config.checkpoint_dir}")
+        print("=" * 70 + "\n")
+
+        step_start_time = time.time()
+        tokens_at_log_start = self.tokens_seen
+
+        # ════════════════════════════════════════════
+        # Main loop
+        # ════════════════════════════════════════════
+
+        while self.global_step < config.total_steps:
+
+            # ── Train step ──
+            loss, grad_norm = self._train_step()
+            self.global_step += 1
+            self.tokens_seen += config.effective_batch_size * self.seq_len
+
+            # ── Logging ──
+            if self.global_step % config.log_interval == 0:
+                elapsed = time.time() - step_start_time
+                tokens_delta = self.tokens_seen - tokens_at_log_start
+                tokens_per_sec = tokens_delta / max(elapsed, 1e-6)
+
+                # GPU memory
+                gpu_mem_gb = 0.0
+                if torch.cuda.is_available():
+                    gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9
+
+                # Current LR
+                current_lr = self.scheduler.get_lr(self.global_step)
+
+                # Estimate remaining time
+                remaining_steps = config.total_steps - self.global_step
+                steps_per_sec = config.log_interval / max(elapsed, 1e-6)
+                eta_seconds = remaining_steps / max(steps_per_sec, 1e-6)
+                eta_hours = eta_seconds / 3600
+
+                # Console output
+                print(
+                    f"  Step {self.global_step:>6d}/{config.total_steps} │ "
+                    f"Loss {loss:.4f} │ "
+                    f"LR {current_lr:.2e} │ "
+                    f"Grad {grad_norm:.2f} │ "
+                    f"{tokens_per_sec:,.0f} tok/s │ "
+                    f"GPU {gpu_mem_gb:.1f}GB │ "
+                    f"ETA {eta_hours:.1f}h │ "
+                    f"Tokens {self.tokens_seen/1e9:.2f}B"
+                )
+
+                # wandb logging
+                self.metrics.log_train_step(
+                    step=self.global_step,
+                    loss=loss,
+                    lr=current_lr,
+                    grad_norm=grad_norm,
+                    tokens_per_sec=tokens_per_sec,
+                    gpu_mem_gb=gpu_mem_gb,
+                )
+
+                step_start_time = time.time()
+                tokens_at_log_start = self.tokens_seen
+
+            # ── Evaluation ──
+            if self.global_step % config.eval_interval == 0:
+                val_loss, val_ppl = self._evaluate()
+
+                print(f"\n  📊 Eval @ Step {self.global_step}: "
+                      f"Val Loss = {val_loss:.4f}, "
+                      f"Val PPL = {val_ppl:.2f}")
+
+                self.metrics.log_eval(self.global_step, val_loss, val_ppl)
+
+                if val_loss < self.best_val_loss:
+                    self.best_val_loss = val_loss
+                    print(f"  🏆 New best val loss: {val_loss:.4f}")
+
+                print()
+
+            # ── Checkpoint ──
+            if self.global_step % config.checkpoint_interval == 0:
+                self.ckpt_manager.save(
+                    model=self.model,
+                    optimizer=self.optimizer,
+                    step=self.global_step,
+                    best_val_loss=self.best_val_loss,
+                    metrics_history=self.metrics.history,
+                    wandb_run_id=self.metrics.wandb_run_id,
+                )
+
+        # ════════════════════════════════════════════
+        # Training complete
+        # ════════════════════════════════════════════
+
+        print("\n" + "=" * 70)
+        print("🎉 Training complete!")
+        print("=" * 70)
+        print(f"  Total steps: {self.global_step:,}")
+        print(f"  Total tokens: {self.tokens_seen/1e9:.2f}B")
+        print(f"  Best val loss: {self.best_val_loss:.4f}")
+        print(f"  Best val PPL: {math.exp(min(self.best_val_loss, 20)):.2f}")
+        print("=" * 70)
+
+        # Save the final checkpoint
+        self.ckpt_manager.save(
+            model=self.model,
+            optimizer=self.optimizer,
+            step=self.global_step,
+            best_val_loss=self.best_val_loss,
+            metrics_history=self.metrics.history,
+            wandb_run_id=self.metrics.wandb_run_id,
+        )
+
+        if self.config.use_wandb and self.metrics.wandb_run:
+            import wandb
+            wandb.finish()
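
The "loss / accumulation_steps gives the mean gradient" claim from the class docstring can be demonstrated without any project code: accumulated micro-batch gradients match a single full-batch backward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model_a = nn.Linear(4, 1)
model_b = nn.Linear(4, 1)
model_b.load_state_dict(model_a.state_dict())  # identical weights

x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = nn.MSELoss()

# (a) One full-batch backward over all 8 samples.
loss_fn(model_a(x), y).backward()

# (b) Four micro-batches of 2, each loss divided by accumulation_steps;
#     backward() accumulates into .grad across the loop.
accumulation_steps = 4
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (loss_fn(model_b(xb), yb) / accumulation_steps).backward()

assert torch.allclose(model_a.weight.grad, model_b.weight.grad, atol=1e-6)
```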
llm_lab/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Common utilities: device detection, seed setup."""
+from .device import get_device, detect_gpu_info, auto_configure
+from .seed import set_seed
+
+__all__ = ["get_device", "detect_gpu_info", "auto_configure", "set_seed"]
llm_lab/utils/device.py ADDED
@@ -0,0 +1,94 @@
+"""Device detection and auto-configuration utilities."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+
+if TYPE_CHECKING:
+    from llm_lab.config import TrainConfig
+
+
+def get_device() -> torch.device:
+    """Return the available device (cuda or cpu)."""
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def detect_gpu_info() -> dict:
+    """Return the GPU name and memory information.
+
+    Returns:
+        {"name": str, "memory_gb": float}, or an empty dict if no GPU
+    """
+    if not torch.cuda.is_available():
+        return {}
+    return {
+        "name": torch.cuda.get_device_name(),
+        "memory_gb": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1),
+    }
+
+
+def auto_configure(config: "TrainConfig") -> "TrainConfig":
+    """Adjust the configuration automatically based on the GPU type.
+
+    Even on Colab Pro+, an A100 is not always assigned.
+    If a T4 or V100 is assigned instead, the settings adapt automatically.
+
+    Returns:
+        The adjusted TrainConfig
+    """
+    if not torch.cuda.is_available():
+        print("⚠️ No GPU! CPU mode (very slow)")
+        config.dtype = "float32"
+        config.micro_batch_size = 1
+        config.gradient_accumulation_steps = 4
+        return config
+
+    gpu_name = torch.cuda.get_device_name().lower()
+    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
+
+    print(f"\n🔍 GPU detected: {torch.cuda.get_device_name()} ({gpu_mem:.1f} GB)")
+
+    if "a100" in gpu_name:
+        # A100 40GB: keep the defaults (optimal)
+        print("   → A100 detected: using defaults (bf16, batch=4)")
+        config.dtype = "bfloat16"
+        config.micro_batch_size = 4
+
+    elif "v100" in gpu_name:
+        # V100 16GB: no bf16 support, smaller batch
+        print("   → V100 detected: fp16 mode, reduced batch")
+        config.dtype = "float16"
+        config.micro_batch_size = 2
+        config.gradient_accumulation_steps = 64  # keeps the effective batch
+
+    elif "t4" in gpu_name:
+        # T4 16GB: no bf16 support, minimal batch
+        print("   → T4 detected: fp16 mode, minimal batch")
+        config.dtype = "float16"
+        config.micro_batch_size = 1
+        config.gradient_accumulation_steps = 128
+
+    elif "l4" in gpu_name:
+        # L4 24GB: bf16 supported
+        print("   → L4 detected: bf16 mode, adjusted batch")
+        config.dtype = "bfloat16"
+        config.micro_batch_size = 2
+        config.gradient_accumulation_steps = 64
+
+    else:
+        print("   → Unknown GPU. Adjusting settings based on memory")
+        if gpu_mem >= 30:
+            config.micro_batch_size = 4
+        elif gpu_mem >= 16:
+            config.micro_batch_size = 2
+        else:
+            config.micro_batch_size = 1
+            config.gradient_accumulation_steps = 128
+
+    print(f"   → dtype: {config.dtype}")
+    print(f"   → micro_batch: {config.micro_batch_size}")
+    print(f"   → grad_accum: {config.gradient_accumulation_steps}")
+    print(f"   → effective_batch: {config.effective_batch_size}")
+
+    return config
llm_lab/utils/seed.py ADDED
@@ -0,0 +1,9 @@
+"""Seed utilities for reproducibility."""
+import torch
+
+
+def set_seed(seed: int = 42):
+    """Set the seed for reproducibility."""
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
notebooks/01_data_pipeline.ipynb ADDED
@@ -0,0 +1,169 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 01. Data Pipeline\n",
+    "\n",
+    "Tokenizer preparation → data streaming → sequence packing → batch construction\n",
+    "\n",
+    "**Pipeline flow:**\n",
+    "```\n",
+    "FineWeb-Edu (HuggingFace)\n",
+    "  → load via streaming (nothing stored on disk)\n",
+    "  → tokenization (BPE, vocab=32K)\n",
+    "  → sequence packing (concatenate documents up to max_seq_len)\n",
+    "  → batch construction (input_ids, targets)\n",
+    "  → transfer to GPU\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install required packages\n",
+    "!pip install datasets tokenizers sentencepiece transformers -q"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.insert(0, '..')\n",
+    "\n",
+    "from llm_lab.config import DataConfig\n",
+    "from llm_lab.data import (\n",
+    "    Tokenizer, setup_data_pipeline,\n",
+    "    DataPipelineDiagnostics\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Data Configuration\n",
+    "\n",
+    "Adjust the values below to your environment."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_config = DataConfig(\n",
+    "    dataset_name=\"HuggingFaceFW/fineweb-edu\",\n",
+    "    dataset_subset=\"sample-10BT\",\n",
+    "    vocab_size=32_000,\n",
+    "    max_seq_len=2048,\n",
+    "    batch_size=4,\n",
+    "    num_workers=2,\n",
+    ")\n",
+    "\n",
+    "print(f\"Dataset: {data_config.dataset_name} ({data_config.dataset_subset})\")\n",
+    "print(f\"Sequence length: {data_config.max_seq_len}\")\n",
+    "print(f\"Batch size: {data_config.batch_size}\")\n",
+    "print(f\"Tokens/batch: {data_config.batch_size * data_config.max_seq_len:,}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Tokenizer Setup\n",
+    "\n",
+    "Choose one of three modes:\n",
+    "- `\"pretrained\"`: use a pretrained HuggingFace tokenizer (simplest)\n",
+    "- `\"train_new\"`: train a new BPE tokenizer\n",
+    "- `\"load_trained\"`: load a previously trained tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
+    "    tokenizer_mode=\"pretrained\",  # switch to \"train_new\" or \"load_trained\" if needed\n",
+    "    config=data_config,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Pipeline Diagnostics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Tokenizer quality diagnostics\n",
+    "DataPipelineDiagnostics.check_tokenizer_quality(tokenizer, data_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Data-loading throughput benchmark\n",
+    "DataPipelineDiagnostics.benchmark_throughput(train_dl, num_batches=50)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Batch Inspection"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Fetch the first batch and inspect it in detail\n",
+    "batch = next(iter(train_dl))\n",
+    "DataPipelineDiagnostics.inspect_batch(batch, tokenizer)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "**Next step:** build the model architecture in `02_model.ipynb`."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
notebooks/02_model.ipynb ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 02. ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜\n",
8
+ "\n",
9
+ "1.1B ํŒŒ๋ผ๋ฏธํ„ฐ LLaMA-style Decoder-Only Transformer ์ƒ์„ฑ ๋ฐ ๊ฒ€์ฆ.\n",
10
+ "\n",
11
+ "**๋ชจ๋ธ ๊ตฌ์กฐ:**\n",
12
+ "```\n",
13
+ "Input Token IDs\n",
14
+ " โ†’ Token Embedding\n",
15
+ " โ†’ [TransformerBlock] ร— num_layers\n",
16
+ " โ”‚ โ”œโ”€โ”€ RMSNorm โ†’ GroupedQueryAttention (+ RoPE) โ†’ Residual\n",
17
+ " โ”‚ โ””โ”€โ”€ RMSNorm โ†’ SwiGLU FFN โ†’ Residual\n",
18
+ " โ†’ RMSNorm (์ตœ์ข…)\n",
19
+ " โ†’ Linear Head (Weight Tying)\n",
20
+ " โ†’ Vocab Logits\n",
21
+ "```"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "import sys\n",
31
+ "sys.path.insert(0, '..')\n",
32
+ "\n",
33
+ "import torch\n",
34
+ "import math\n",
35
+ "from llm_lab.config import ModelConfig\n",
36
+ "from llm_lab.model import LLMModel, count_parameters_detailed, estimate_memory_gb"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {},
42
+ "source": [
43
+ "## 1. ๋ชจ๋ธ ์„ค์ • ์„ ํƒ\n",
44
+ "\n",
45
+ "| ํ”„๋ฆฌ์…‹ | ํŒŒ๋ผ๋ฏธํ„ฐ | ์šฉ๋„ |\n",
46
+ "|--------|----------|------|\n",
47
+ "| `debug_10m()` | ~10M | ํŒŒ์ดํ”„๋ผ์ธ ๊ฒ€์ฆ |\n",
48
+ "| `small_100m()` | ~100M | ์ค‘๊ฐ„ ๊ฒ€์ฆ |\n",
49
+ "| `base_1b()` | ~1.1B | ์ตœ์ข… ํ•™์Šต |"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# --- ๋ชจ๋ธ ์Šค์ผ€์ผ ์„ ํƒ ---\n",
59
+ "# model_config = ModelConfig.debug_10m() # ~10M (๋น ๋ฅธ ๊ฒ€์ฆ)\n",
60
+ "# model_config = ModelConfig.small_100m() # ~100M (์ค‘๊ฐ„ ๊ฒ€์ฆ)\n",
61
+ "model_config = ModelConfig.base_1b() # ~1.1B (์ตœ์ข… ๋ชฉํ‘œ)\n",
62
+ "\n",
63
+ "print(f\"hidden_dim: {model_config.hidden_dim}\")\n",
64
+ "print(f\"num_layers: {model_config.num_layers}\")\n",
65
+ "print(f\"num_heads: {model_config.num_heads}\")\n",
66
+ "print(f\"num_kv_heads: {model_config.num_kv_heads} (GQA ๊ทธ๋ฃน: {model_config.num_kv_groups})\")\n",
67
+ "print(f\"intermediate_dim: {model_config.intermediate_dim}\")\n",
68
+ "print(f\"max_seq_len: {model_config.max_seq_len}\")"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "metadata": {},
74
+ "source": [
75
+ "## 2. ๋ชจ๋ธ ์ƒ์„ฑ ๋ฐ ํŒŒ๋ผ๋ฏธํ„ฐ ํ™•์ธ"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "# Debug ๋ชจ๋ธ ์‹ค์ œ ์ƒ์„ฑ (๋ฉ”๋ชจ๋ฆฌ ํ™•์ธ ์šฉ๋„)\n",
85
+ "debug_config = ModelConfig.debug_10m()\n",
86
+ "model = LLMModel(debug_config)\n",
87
+ "print(f\"Debug ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {model.count_parameters():,}\")\n",
88
+ "\n",
89
+ "# 1B ๋ชจ๋ธ์€ meta device์—์„œ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋งŒ ํ™•์ธ\n",
90
+ "with torch.device(\"meta\"):\n",
91
+ " model_1b = LLMModel(ModelConfig.base_1b())\n",
92
+ "n_params_1b = model_1b.count_parameters()\n",
93
+ "print(f\"1B ๋ชจ๋ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: {n_params_1b:,} ({n_params_1b/1e9:.2f}B)\")"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "markdown",
98
+ "metadata": {},
99
+ "source": [
100
+ "## 3. ์ƒ์„ธ ํŒŒ๋ผ๋ฏธํ„ฐ ๋ถ„ํ•ด"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "detail = count_parameters_detailed(model_1b)\n",
110
+ "cfg_1b = ModelConfig.base_1b()\n",
111
+ "\n",
112
+ "print(f\"Token Embedding: {detail['token_embedding']:,}\")\n",
113
+ "print(f\"Per Layer Total: {detail['per_layer_total']:,}\")\n",
114
+ "print(f\"All Layers ({cfg_1b.num_layers}): {detail['all_layers_total']:,}\")\n",
115
+ "print(f\"Final Norm: {detail['final_norm']:,}\")\n",
116
+ "print(f\"LM Head: {detail['lm_head']}\")\n",
117
+ "print(f\"{'โ”€' * 30}\")\n",
118
+ "print(f\"TOTAL: {detail['total']:,}\")"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "## 4. GPU ๋ฉ”๋ชจ๋ฆฌ ์ถ”์ •"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "mem = estimate_memory_gb(ModelConfig.base_1b(), batch_size=4, dtype_bytes=2)\n",
135
+ "\n",
136
+ "print(f\"๋ชจ๋ธ ๊ฐ€์ค‘์น˜: {mem['model_weights_gb']} GB\")\n",
137
+ "print(f\"์˜ตํ‹ฐ๋งˆ์ด์ €: {mem['optimizer_states_gb']} GB\")\n",
138
+ "print(f\"๊ธฐ์šธ๊ธฐ: {mem['gradients_gb']} GB\")\n",
139
+ "print(f\"ํ™œ์„ฑํ™” (์ถ”์ •): {mem['activations_estimated_gb']} GB\")\n",
140
+ "print(f\"{'โ”€' * 30}\")\n",
141
+ "print(f\"์ด ์ถ”์ •: {mem['total_estimated_gb']} GB\")"
142
+ ]
143
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Forward Pass Verification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Verify forward/backward with the debug model\n",
+ "dummy_input = torch.randint(0, debug_config.vocab_size, (2, 64))\n",
+ "dummy_target = torch.randint(0, debug_config.vocab_size, (2, 64))\n",
+ "logits, loss = model(dummy_input, dummy_target)\n",
+ "\n",
+ "print(f\"Input shape: {dummy_input.shape}\")\n",
+ "print(f\"Logits shape: {logits.shape}\")\n",
+ "print(f\"Loss: {loss.item():.4f}\")\n",
+ "expected_loss = math.log(debug_config.vocab_size)\n",
+ "print(f\"Expected initial loss (ln({debug_config.vocab_size})): {expected_loss:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Text Generation Test (Random Weights)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt = torch.randint(0, debug_config.vocab_size, (1, 10))\n",
+ "generated = model.generate(prompt, max_new_tokens=20, temperature=1.0, top_k=50)\n",
+ "\n",
+ "print(f\"Prompt length: {prompt.shape[1]}\")\n",
+ "print(f\"Generated length: {generated.shape[1]}\")\n",
+ "print(f\"Token IDs: {generated[0].tolist()}\")"
+ ]
+ },
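+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For intuition about the `temperature` and `top_k` arguments, a standalone sketch of one top-k sampling step — not `model.generate`'s actual internals, which may differ:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# One top-k sampling step (sketch; generate()'s internals may differ)\n",
+ "fake_logits = torch.randn(debug_config.vocab_size)  # stand-in for next-token logits\n",
+ "fake_logits = fake_logits / 1.0                     # temperature: <1 sharpens, >1 flattens\n",
+ "topk_vals, topk_idx = torch.topk(fake_logits, k=50) # keep the 50 most likely tokens\n",
+ "probs = torch.softmax(topk_vals, dim=-1)            # renormalize over the top-k only\n",
+ "next_token = topk_idx[torch.multinomial(probs, num_samples=1)]\n",
+ "print(f\"sampled token id: {next_token.item()}\")"
+ ]
+ },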
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "**Next step:** run training in `03_training.ipynb`."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
notebooks/03_training.ipynb ADDED
@@ -0,0 +1,211 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 03. Training\n",
+ "\n",
+ "A training pipeline with gradient accumulation, mixed precision, cosine LR scheduling,\n",
+ "checkpoint save/restore, and wandb logging.\n",
+ "\n",
+ "**Training flow (a minimal code sketch follows below):**\n",
+ "```\n",
+ "Fetch batch\n",
+ " → Forward (bf16 autocast)\n",
+ " → Loss / accumulation_steps\n",
+ " → Backward (accumulate gradients)\n",
+ " → [every accumulation_steps] Gradient Clipping → Optimizer Step → LR Update\n",
+ " → [every checkpoint_interval] Save checkpoint (Google Drive)\n",
+ " → [every eval_interval] Measure validation loss/perplexity\n",
+ "```"
+ ]
+ },
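+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The sketch below condenses the flow above into one optimizer step — simplified; the real loop, with checkpointing, eval, and wandb logging, is what `llm_lab.training`'s `Trainer`/`start_training` implement. It assumes the model returns `(logits, loss)`, as in notebook 02."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# One optimizer step with gradient accumulation + bf16 autocast (simplified sketch)\n",
+ "def accumulation_step(model, optimizer, scheduler, micro_batches, accumulation_steps, grad_clip=1.0):\n",
+ "    optimizer.zero_grad(set_to_none=True)\n",
+ "    for input_ids, targets in micro_batches:  # accumulation_steps micro-batches\n",
+ "        with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
+ "            _, loss = model(input_ids, targets)\n",
+ "        (loss / accumulation_steps).backward()  # scale so gradients average, not sum\n",
+ "    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
+ "    optimizer.step()   # one update per accumulation window\n",
+ "    scheduler.step()   # cosine LR update"
+ ]
+ },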
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install wandb -q"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.insert(0, '..')\n",
+ "\n",
+ "from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n",
+ "from llm_lab.model import LLMModel\n",
+ "from llm_lab.data import setup_data_pipeline\n",
+ "from llm_lab.training import start_training, Trainer\n",
+ "from llm_lab.utils import auto_configure, get_device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 0. Mount Google Drive (Colab)\n",
+ "\n",
+ "Checkpoints are saved to Google Drive so they survive session expiry."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run on Colab only\n",
+ "# from google.colab import drive\n",
+ "# drive.mount('/content/drive')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# --- Model config ---\n",
+ "model_config = ModelConfig.debug_10m() # debug for verification; base_1b() for the real run\n",
+ "\n",
+ "# --- Data config ---\n",
+ "data_config = DataConfig(\n",
+ " max_seq_len=model_config.max_seq_len,\n",
+ " batch_size=4,\n",
+ ")\n",
+ "\n",
+ "# --- Training config ---\n",
+ "train_config = TrainConfig(\n",
+ " total_steps=20_000,\n",
+ " warmup_steps=2_000,\n",
+ " learning_rate=3e-4,\n",
+ " min_learning_rate=3e-5,\n",
+ " weight_decay=0.1,\n",
+ " grad_clip=1.0,\n",
+ " micro_batch_size=4,\n",
+ " gradient_accumulation_steps=32,\n",
+ " checkpoint_dir=\"/content/drive/MyDrive/llm-1b-lab/checkpoints\",\n",
+ " checkpoint_interval=500,\n",
+ " eval_interval=500,\n",
+ " log_interval=10,\n",
+ " use_wandb=True,\n",
+ ")\n",
+ "\n",
+ "print(f\"Effective batch size: {train_config.effective_batch_size}\")\n",
+ "print(f\"Total steps: {train_config.total_steps:,}\")"
+ ]
+ },
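+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick arithmetic check on these settings — a sketch with the values hard-coded; the 2048 sequence length is an assumption, so use `model_config.max_seq_len` for the real value:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sanity arithmetic for the config above (sketch; seq_len=2048 is assumed)\n",
+ "effective_batch = 4 * 32                  # micro_batch_size * gradient_accumulation_steps = 128\n",
+ "tokens_per_step = effective_batch * 2048  # ~262K tokens per optimizer step\n",
+ "total_tokens = tokens_per_step * 20_000   # ~5.2B tokens over the whole run\n",
+ "print(f\"{effective_batch} seqs/step, {tokens_per_step:,} tokens/step, ~{total_tokens/1e9:.1f}B tokens total\")"
+ ]
+ },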
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Automatic GPU Detection\n",
+ "\n",
+ "Automatically adjusts dtype, batch_size, and gradient_accumulation for the detected GPU (A100/V100/T4/L4)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_config = auto_configure(train_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Model + Data Initialization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create the model\n",
+ "model = LLMModel(model_config)\n",
+ "print(f\"Model parameters: {model.count_parameters():,}\")\n",
+ "\n",
+ "# Data pipeline\n",
+ "tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
+ " tokenizer_mode=\"pretrained\",\n",
+ " config=data_config,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Start Training\n",
+ "\n",
+ "If a checkpoint exists, it is restored automatically and training resumes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = start_training(\n",
+ " model=model,\n",
+ " train_dataloader=train_dl,\n",
+ " val_dataloader=val_dl,\n",
+ " config=train_config,\n",
+ " seq_len=model_config.max_seq_len,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Resuming Training (After Session Expiry)\n",
+ "\n",
+ "When you rerun the notebook after a Colab session expires, CheckpointManager automatically finds and restores the latest checkpoint.\n",
+ "\n",
+ "Just rerun the cells above in order."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "**Next step:** evaluate the trained model in `04_evaluation.ipynb`."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
notebooks/04_evaluation.ipynb ADDED
@@ -0,0 +1,188 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 04. Evaluation\n",
+ "\n",
+ "Evaluates the trained model from multiple angles.\n",
+ "\n",
+ "**Evaluation areas:**\n",
+ "1. Perplexity — the standard quantitative metric for language models (see the sketch below)\n",
+ "2. Text generation quality — qualitative evaluation across varied prompts\n",
+ "3. Scaling-law analysis — comparing 10M → 100M → 1B\n",
+ "4. Attention visualization — analyzing where the model \"looks\"\n",
+ "5. Insight checklist — verifying the learning goals were met"
+ ]
+ },
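+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a reminder of what item 1 measures — a sketch with made-up numbers; `PerplexityEvaluator` does this over the real validation set — perplexity is the exponential of the mean per-token cross-entropy:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "\n",
+ "# Perplexity = exp(mean per-token negative log-likelihood); hypothetical NLLs below\n",
+ "token_nlls = [3.1, 2.8, 3.4, 2.9]\n",
+ "ppl = math.exp(sum(token_nlls) / len(token_nlls))\n",
+ "print(f\"perplexity: {ppl:.1f}\")  # ~21.1 for a mean NLL of 3.05"
+ ]
+ },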
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install matplotlib numpy -q"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.insert(0, '..')\n",
+ "\n",
+ "import torch\n",
+ "from llm_lab.config import ModelConfig, EvalConfig\n",
+ "from llm_lab.model import LLMModel\n",
+ "from llm_lab.evaluation import (\n",
+ " run_evaluation, PerplexityEvaluator, GenerationEvaluator,\n",
+ " ScalingAnalyzer, AttentionVisualizer, InsightChecklist\n",
+ ")\n",
+ "from llm_lab.utils import get_device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Load the Model\n",
+ "\n",
+ "Loads model weights from a trained checkpoint."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = get_device()\n",
+ "model_config = ModelConfig.base_1b()\n",
+ "model = LLMModel(model_config).to(device)\n",
+ "\n",
+ "# Load a checkpoint (change the path to your actual checkpoint)\n",
+ "# ckpt = torch.load(\"path/to/step_XXXXXX/model.pt\", map_location=device)\n",
+ "# model.load_state_dict(ckpt)\n",
+ "\n",
+ "print(f\"Model parameters: {model.count_parameters():,}\")\n",
+ "print(f\"Device: {device}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Comprehensive Evaluation (One-Line Run)\n",
+ "\n",
+ "Runs perplexity, text generation, training dynamics, and attention visualization in one call."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Requires the tokenizer, val_dl, and metrics_history used during training\n",
+ "# report = run_evaluation(\n",
+ "# model=model,\n",
+ "# tokenizer=tokenizer,\n",
+ "# val_dataloader=val_dl,\n",
+ "# metrics_history=trainer.metrics.history,\n",
+ "# )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Scaling Law Analysis\n",
+ "\n",
+ "Compares the 10M → 100M → 1B models to check the scaling law (a worked example of the rule of thumb follows below).\n",
+ "\n",
+ "Chinchilla scaling law: optimal training tokens ≈ 20 × parameter count"
+ ]
+ },
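+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A worked instance of the 20-tokens-per-parameter rule of thumb (approximate — Chinchilla's fitted exponents vary with setup). Note that the placeholder numbers in the next cell train the 1B model on fewer tokens than this optimum."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Chinchilla rule of thumb: compute-optimal tokens ~ 20 x parameters (approximate)\n",
+ "for name, params in [(\"10M\", 10e6), (\"100M\", 100e6), (\"1B\", 1.1e9)]:\n",
+ "    print(f\"{name}: ~{20 * params / 1e9:.1f}B tokens compute-optimal\")"
+ ]
+ },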
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "analyzer = ScalingAnalyzer()\n",
+ "\n",
+ "# Enter each model's results (replace with your actual training results)\n",
+ "scaling_results = [\n",
+ " {\"name\": \"10M\", \"params\": 10e6, \"tokens\": 1e9, \"loss\": 4.2, \"ppl\": 66.7},\n",
+ " {\"name\": \"100M\", \"params\": 100e6, \"tokens\": 5e9, \"loss\": 3.5, \"ppl\": 33.1},\n",
+ " {\"name\": \"1B\", \"params\": 1.1e9, \"tokens\": 10e9, \"loss\": 3.0, \"ppl\": 20.1},\n",
+ "]\n",
+ "\n",
+ "analysis = analyzer.analyze(scaling_results)\n",
+ "analyzer.plot_scaling_curves(scaling_results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Attention Visualization\n",
+ "\n",
+ "Visualizes where the model \"looks\" for each token."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# viz = AttentionVisualizer()\n",
+ "# sample_text = \"The cat sat on the mat and looked at the bird.\"\n",
+ "# token_ids = tokenizer.encode(sample_text)\n",
+ "# input_tensor = torch.tensor([token_ids], dtype=torch.long)\n",
+ "# \n",
+ "# attn_weights = viz.extract_attention(model, input_tensor, layer_idx=0, device=device)\n",
+ "# if attn_weights is not None:\n",
+ "# tokens_str = [tokenizer.decode([tid]) for tid in token_ids]\n",
+ "# viz.plot_attention_heatmap(attn_weights, tokens_str, head_idx=0)\n",
+ "# viz.plot_multi_head_summary(attn_weights)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Insight Checklist\n",
+ "\n",
+ "Checks automatically and manually whether the learning goals were met."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Run the checklist if a report is available\n",
+ "# InsightChecklist.run_checklist(report, metrics_history)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ datasets
+ tokenizers
+ sentencepiece
+ transformers
+ wandb
+ matplotlib
+ numpy