feat(training): add LossDebugger 5-level diagnostic framework
Systematic debugging tool for diagnosing LLM training loss issues:
- Level 0: auto-classify training health status (6 categories)
- Level 1: data/implementation bug checks (shift, token range, overfit test)
- Level 2: numerical stability (mixed precision config, gradient/activation
NaN/Inf detection, common issues reference table)
- Level 3: hyperparameter analysis (LR, β₂, weight decay, batch-LR scaling,
GPT-3 LR reference table, warmup reference, LR range test)
- Level 4: overfitting vs underfitting diagnosis (detailed sub-cause analysis,
Chinchilla token ratio, dropout note for pretraining)
- Level 5: architecture checks (per-layer activation stats, weight init
distribution analysis, ablation study reference)
- Scenario auto-detection (A/B/C/D) with step-by-step recommendations
- Study roadmap with recommended experiments and key references
Also adds 05_debugging.ipynb with mock scenarios A/B/C/D for all
diagnostic levels and educational content from the optimization guide.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- CLAUDE.md +4 -2
- llm_lab/training/__init__.py +2 -1
- llm_lab/training/debugger.py +1442 -0
- notebooks/05_debugging.ipynb +384 -0
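The Level 0 classifier leans on one anchor: a randomly initialized model should start near cross-entropy ln(vocab_size). A minimal standalone sketch of that heuristic (hypothetical simplification — the real logic lives in `llm_lab/training/debugger.py` and covers six statuses, validation trends, and spike detection):

```python
import math

# Simplified sketch of the Level 0 heuristic: compare the loss curve
# against the random-init baseline ln(vocab_size).
def classify(vocab_size, train_losses):
    expected_initial = math.log(vocab_size)  # random-init loss ~ ln(V)
    first, last = train_losses[0], train_losses[-1]
    if last > expected_initial + 1.0:        # loss above the random baseline
        return "DIVERGING"
    if (first - last) < 0.1 and first > expected_initial - 2.0:
        return "NO_DECREASE"                 # likely data/implementation bug
    return "NORMAL"

# A 32k-vocab model starts near ln(32000) ~ 10.37; a flat curve there
# points at Level 1 (data bugs), not at hyperparameter tuning.
print(classify(32000, [10.4, 10.38, 10.39]))  # NO_DECREASE
print(classify(32000, [10.4, 6.1, 3.2]))      # NORMAL
```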
CLAUDE.md

@@ -35,7 +35,8 @@ LLM_Foundation_Model/
 │  │  ├── metrics.py        # MetricsTracker (wandb integration)
 │  │  ├── optimizer.py      # create_optimizer (weight decay separation)
 │  │  ├── trainer.py        # Trainer (gradient accumulation, mixed precision)
-│  │
+│  │  ├── runner.py         # start_training (one-line helper)
+│  │  └── debugger.py       # LossDebugger (5-level diagnostic framework)
 │  ├── evaluation/          # Evaluation & analysis
 │  │  ├── perplexity.py     # PerplexityEvaluator (including per-position loss)
 │  │  ├── generation.py     # GenerationEvaluator (various prompts)
@@ -52,7 +53,8 @@ LLM_Foundation_Model/
 │  ├── 01_data_pipeline.ipynb
 │  ├── 02_model.ipynb
 │  ├── 03_training.ipynb
-│
+│  ├── 04_evaluation.ipynb
+│  └── 05_debugging.ipynb
 ├── _archive/                # Original single-file backups
 │  ├── llm-1b-model.py
 │  ├── llm-1b-data-pipeline.py
llm_lab/training/__init__.py

@@ -5,8 +5,9 @@ from .metrics import MetricsTracker
 from .optimizer import create_optimizer
 from .trainer import Trainer
 from .runner import start_training
+from .debugger import LossDebugger

 __all__ = [
     "CosineWarmupScheduler", "CheckpointManager", "MetricsTracker",
-    "create_optimizer", "Trainer", "start_training",
+    "create_optimizer", "Trainer", "start_training", "LossDebugger",
 ]
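Check 1 of the new debugger verifies the next-token shift invariant: `targets[t] == input_ids[t+1]`. A pure-Python sketch of that same invariant (the real check runs on torch tensors; names here are illustrative):

```python
# Sketch of the "shift relationship" invariant from the debugger's Check 1:
# for next-token prediction, targets must be the inputs shifted left by one,
# i.e. targets[t] == input_ids[t+1]. Returns the fraction of matching positions.
def shift_consistency(input_ids, targets):
    pairs = list(zip(input_ids[1:], targets[:-1]))
    matches = sum(1 for a, b in pairs if a == b)
    return matches / len(pairs)  # should be ~1.0 for a healthy pipeline

ids = [5, 9, 2, 7, 3]
good_targets = [9, 2, 7, 3, 0]  # inputs shifted left, last position padded
bad_targets = [5, 9, 2, 7, 3]   # off-by-one bug: targets equal the inputs
print(shift_consistency(ids, good_targets))  # 1.0
print(shift_consistency(ids, bad_targets))   # 0.0
```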
llm_lab/training/debugger.py

@@ -0,0 +1,1442 @@
| 1 |
+
"""LLM Loss Debugging & Optimization Framework.
|
| 2 |
+
|
| 3 |
+
A systematic 5-level debugging framework for diagnosing training issues.
|
| 4 |
+
Always start from Level 1 โ fixing lower-level bugs before tuning
|
| 5 |
+
hyperparameters saves time.
|
| 6 |
+
|
| 7 |
+
Levels:
|
| 8 |
+
0. Status Diagnosis โ classify current training health
|
| 9 |
+
1. Data/Implementation โ most common cause (70% of issues)
|
| 10 |
+
2. Numerical Stability โ dtype, normalization, gradient health
|
| 11 |
+
3. Hyperparameters โ LR, batch size, warmup
|
| 12 |
+
4. Fitting Diagnosis โ overfitting vs underfitting
|
| 13 |
+
5. Architecture โ initialization, component checks
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch.utils.data import DataLoader
|
| 24 |
+
|
| 25 |
+
from llm_lab.config import TrainConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 29 |
+
# Constants
|
| 30 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 31 |
+
|
| 32 |
+
# Normal convergence ranges for a 1B model trained on ~10B tokens
|
| 33 |
+
_EXPECTED_TRAIN_LOSS = (2.8, 3.5)
|
| 34 |
+
_EXPECTED_VAL_LOSS = (3.0, 3.8)
|
| 35 |
+
_EXPECTED_VAL_PPL = (20, 45)
|
| 36 |
+
|
| 37 |
+
# Status labels
|
| 38 |
+
STATUS_NORMAL = "NORMAL"
|
| 39 |
+
STATUS_NO_DECREASE = "NO_DECREASE"
|
| 40 |
+
STATUS_DIVERGING = "DIVERGING"
|
| 41 |
+
STATUS_PLATEAU = "PLATEAU"
|
| 42 |
+
STATUS_OVERFITTING = "OVERFITTING"
|
| 43 |
+
STATUS_UNSTABLE = "UNSTABLE"
|
| 44 |
+
|
| 45 |
+
# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
|
| 46 |
+
# (param_count, recommended_lr, batch_tokens_str)
|
| 47 |
+
_GPT3_LR_REFERENCE = [
|
| 48 |
+
(125e6, 6e-4, "0.5M"),
|
| 49 |
+
(350e6, 3e-4, "0.5M"),
|
| 50 |
+
(1.3e9, 2e-4, "1M"),
|
| 51 |
+
(2.7e9, 1.6e-4, "1M"),
|
| 52 |
+
(6.7e9, 1.2e-4, "2M"),
|
| 53 |
+
(175e9, 6e-5, "3.2M"),
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
# Known LLM training references
|
| 57 |
+
_LLM_TRAINING_REFS = {
|
| 58 |
+
"TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
|
| 59 |
+
"LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
|
| 60 |
+
"Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1},
|
| 61 |
+
"OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1},
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# Recommended ฮฒโ for LLM training
|
| 65 |
+
_RECOMMENDED_BETA2 = 0.95
|
| 66 |
+
_DEFAULT_PYTORCH_BETA2 = 0.999
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _header(title: str) -> str:
|
| 70 |
+
return f"\n{'=' * 60}\n{title}\n{'=' * 60}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]:
|
| 74 |
+
return {"name": name, "passed": passed, "detail": detail}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 78 |
+
# LossDebugger
|
| 79 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class LossDebugger:
|
| 83 |
+
"""5-level loss debugging framework for LLM training.
|
| 84 |
+
|
| 85 |
+
Usage::
|
| 86 |
+
|
| 87 |
+
from llm_lab.training.debugger import LossDebugger
|
| 88 |
+
|
| 89 |
+
# Quick status check
|
| 90 |
+
status = LossDebugger.diagnose_status(vocab_size=32000,
|
| 91 |
+
metrics_history=trainer.metrics.history)
|
| 92 |
+
|
| 93 |
+
# Full diagnostics
|
| 94 |
+
report = LossDebugger.run_diagnostics(
|
| 95 |
+
model=model, dataloader=train_dl, tokenizer=tok,
|
| 96 |
+
train_config=train_cfg, metrics_history=trainer.metrics.history,
|
| 97 |
+
device=device, dtype=torch.bfloat16,
|
| 98 |
+
)
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 102 |
+
# Level 0: Status Diagnosis
|
| 103 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def diagnose_status(
|
| 107 |
+
vocab_size: int,
|
| 108 |
+
metrics_history: Dict[str, list],
|
| 109 |
+
) -> Dict[str, Any]:
|
| 110 |
+
"""Classify current training health from metrics history.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
vocab_size: model vocabulary size (e.g. 32000)
|
| 114 |
+
metrics_history: dict with keys 'train_loss', 'val_loss', etc.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
dict with 'status', 'severity', 'details', 'recommended_levels'
|
| 118 |
+
"""
|
| 119 |
+
print(_header("Level 0: Training Status Diagnosis"))
|
| 120 |
+
|
| 121 |
+
expected_initial = math.log(vocab_size)
|
| 122 |
+
print(f" Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}")
|
| 123 |
+
print(f" Normal convergence range (1B, 10B tokens):")
|
| 124 |
+
print(f" Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}")
|
| 125 |
+
print(f" Val Loss: {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}")
|
| 126 |
+
print(f" Val PPL: {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}")
|
| 127 |
+
|
| 128 |
+
train_losses = metrics_history.get("train_loss", [])
|
| 129 |
+
val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
|
| 130 |
+
|
| 131 |
+
if len(train_losses) < 2:
|
| 132 |
+
print("\n [!] Not enough training data to diagnose. Run more steps first.")
|
| 133 |
+
return {
|
| 134 |
+
"status": "INSUFFICIENT_DATA",
|
| 135 |
+
"severity": "unknown",
|
| 136 |
+
"details": "Need at least 2 logged train loss values.",
|
| 137 |
+
"recommended_levels": [1],
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
first_loss = train_losses[0]
|
| 141 |
+
last_loss = train_losses[-1]
|
| 142 |
+
loss_change = first_loss - last_loss
|
| 143 |
+
|
| 144 |
+
# Split into halves for trend analysis
|
| 145 |
+
mid = len(train_losses) // 2
|
| 146 |
+
first_half_avg = sum(train_losses[:mid]) / mid
|
| 147 |
+
second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid)
|
| 148 |
+
|
| 149 |
+
# Recent window for spike detection
|
| 150 |
+
recent_n = min(50, len(train_losses))
|
| 151 |
+
recent = train_losses[-recent_n:]
|
| 152 |
+
recent_mean = sum(recent) / len(recent)
|
| 153 |
+
recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent)
|
| 154 |
+
recent_std = recent_var ** 0.5
|
| 155 |
+
|
| 156 |
+
# Val trend
|
| 157 |
+
val_trend = "unknown"
|
| 158 |
+
if len(val_losses) >= 2:
|
| 159 |
+
val_mid = len(val_losses) // 2
|
| 160 |
+
val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
|
| 161 |
+
val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
|
| 162 |
+
if val_second_avg < val_first_avg - 0.05:
|
| 163 |
+
val_trend = "decreasing"
|
| 164 |
+
elif val_second_avg > val_first_avg + 0.1:
|
| 165 |
+
val_trend = "increasing"
|
| 166 |
+
else:
|
| 167 |
+
val_trend = "flat"
|
| 168 |
+
|
| 169 |
+
# โโ Classify โโ
|
| 170 |
+
status = STATUS_NORMAL
|
| 171 |
+
severity = "green"
|
| 172 |
+
details = ""
|
| 173 |
+
recommended_levels: List[int] = []
|
| 174 |
+
|
| 175 |
+
# Check 1: No decrease at all
|
| 176 |
+
if loss_change < 0.1 and first_loss > expected_initial - 2.0:
|
| 177 |
+
status = STATUS_NO_DECREASE
|
| 178 |
+
severity = "red"
|
| 179 |
+
details = (
|
| 180 |
+
f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} "
|
| 181 |
+
f"(delta={loss_change:.4f}). Likely a data or implementation bug."
|
| 182 |
+
)
|
| 183 |
+
recommended_levels = [1, 2]
|
| 184 |
+
|
| 185 |
+
# Check 2: Diverging
|
| 186 |
+
elif last_loss > expected_initial + 1.0:
|
| 187 |
+
status = STATUS_DIVERGING
|
| 188 |
+
severity = "red"
|
| 189 |
+
details = (
|
| 190 |
+
f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). "
|
| 191 |
+
f"Training is diverging โ check LR, data, or numerical issues."
|
| 192 |
+
)
|
| 193 |
+
recommended_levels = [1, 2, 3]
|
| 194 |
+
|
| 195 |
+
# Check 3: Unstable (large spikes)
|
| 196 |
+
elif recent_std > 0.5 * recent_mean:
|
| 197 |
+
status = STATUS_UNSTABLE
|
| 198 |
+
severity = "yellow"
|
| 199 |
+
details = (
|
| 200 |
+
f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. "
|
| 201 |
+
f"Training is unstable โ likely LR too high or batch too small."
|
| 202 |
+
)
|
| 203 |
+
recommended_levels = [3, 2]
|
| 204 |
+
|
| 205 |
+
# Check 4: Overfitting
|
| 206 |
+
elif val_trend == "increasing" and second_half_avg < first_half_avg:
|
| 207 |
+
status = STATUS_OVERFITTING
|
| 208 |
+
severity = "yellow"
|
| 209 |
+
details = (
|
| 210 |
+
f"Train loss decreasing but val loss increasing. "
|
| 211 |
+
f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, "
|
| 212 |
+
f"Val trend: {val_trend}."
|
| 213 |
+
)
|
| 214 |
+
recommended_levels = [4]
|
| 215 |
+
|
| 216 |
+
# Check 5: Plateau
|
| 217 |
+
elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
|
| 218 |
+
status = STATUS_PLATEAU
|
| 219 |
+
severity = "yellow"
|
| 220 |
+
details = (
|
| 221 |
+
f"Loss has plateaued: first half avg={first_half_avg:.4f}, "
|
| 222 |
+
f"second half avg={second_half_avg:.4f}. "
|
| 223 |
+
f"Current loss ({last_loss:.4f}) is above expected range."
|
| 224 |
+
)
|
| 225 |
+
recommended_levels = [3, 4, 5]
|
| 226 |
+
|
| 227 |
+
# Normal
|
| 228 |
+
else:
|
| 229 |
+
status = STATUS_NORMAL
|
| 230 |
+
severity = "green"
|
| 231 |
+
details = (
|
| 232 |
+
f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} "
|
| 233 |
+
f"(delta={loss_change:.4f}). Val trend: {val_trend}."
|
| 234 |
+
)
|
| 235 |
+
recommended_levels = []
|
| 236 |
+
|
| 237 |
+
# โโ Print โโ
|
| 238 |
+
icons = {"red": "๐ด", "yellow": "๐ก", "green": "๐ข"}
|
| 239 |
+
icon = icons.get(severity, "โช")
|
| 240 |
+
print(f"\n {icon} Status: {status}")
|
| 241 |
+
print(f" {details}")
|
| 242 |
+
if recommended_levels:
|
| 243 |
+
print(f" Recommended: check Level(s) {recommended_levels}")
|
| 244 |
+
|
| 245 |
+
return {
|
| 246 |
+
"status": status,
|
| 247 |
+
"severity": severity,
|
| 248 |
+
"details": details,
|
| 249 |
+
"recommended_levels": recommended_levels,
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ๏ฟฝ๏ฟฝโโโโโโโโโโโโโ
|
| 253 |
+
# Level 1: Data / Implementation Bug Checks
|
| 254 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 255 |
+
|
| 256 |
+
@staticmethod
|
| 257 |
+
def check_data_pipeline(
|
| 258 |
+
model: nn.Module,
|
| 259 |
+
dataloader: DataLoader,
|
| 260 |
+
tokenizer: Any,
|
| 261 |
+
vocab_size: int,
|
| 262 |
+
device: torch.device,
|
| 263 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 264 |
+
) -> Dict[str, Any]:
|
| 265 |
+
"""Run 6 data/implementation checks (Level 1).
|
| 266 |
+
|
| 267 |
+
This is the most important level โ 70% of loss issues are data bugs.
|
| 268 |
+
|
| 269 |
+
Checks:
|
| 270 |
+
1. Shift relationship (targets[t] == input_ids[t+1])
|
| 271 |
+
2. Token range (0 <= ids < vocab_size)
|
| 272 |
+
3. Initial loss (โ ln(vocab_size) for random weights)
|
| 273 |
+
4. Single-batch overfit (loss โ ~0 in 200 steps)
|
| 274 |
+
5. Tokenizer roundtrip (encodeโdecode preserves text)
|
| 275 |
+
6. Data quality sampling (visual inspection)
|
| 276 |
+
"""
|
| 277 |
+
print(_header("Level 1: Data / Implementation Bug Checks"))
|
| 278 |
+
print(" (70% of loss issues come from data pipeline bugs)\n")
|
| 279 |
+
|
| 280 |
+
results: List[Dict[str, Any]] = []
|
| 281 |
+
batch = next(iter(dataloader))
|
| 282 |
+
input_ids = batch["input_ids"]
|
| 283 |
+
targets = batch["targets"]
|
| 284 |
+
|
| 285 |
+
# โโ Check 1: Shift relationship โโ
|
| 286 |
+
shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
|
| 287 |
+
passed = shift_match > 0.99
|
| 288 |
+
detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)"
|
| 289 |
+
results.append(_check_result("Shift relationship", passed, detail))
|
| 290 |
+
icon = "โ
" if passed else "โ"
|
| 291 |
+
print(f" {icon} Check 1: {detail}")
|
| 292 |
+
|
| 293 |
+
# โโ Check 2: Token range โโ
|
| 294 |
+
min_id = input_ids.min().item()
|
| 295 |
+
max_id = input_ids.max().item()
|
| 296 |
+
range_ok = min_id >= 0 and max_id < vocab_size
|
| 297 |
+
detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}"
|
| 298 |
+
results.append(_check_result("Token range", range_ok, detail))
|
| 299 |
+
icon = "โ
" if range_ok else "โ"
|
| 300 |
+
print(f" {icon} Check 2: {detail}")
|
| 301 |
+
|
| 302 |
+
# โโ Check 3: Initial loss โโ
|
| 303 |
+
expected_loss = math.log(vocab_size)
|
| 304 |
+
model_copy = copy.deepcopy(model)
|
| 305 |
+
model_copy._init_weights() # re-initialize to random
|
| 306 |
+
model_copy.to(device)
|
| 307 |
+
model_copy.eval()
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 310 |
+
_, initial_loss = model_copy(
|
| 311 |
+
input_ids.to(device),
|
| 312 |
+
targets.to(device),
|
| 313 |
+
)
|
| 314 |
+
initial_loss_val = initial_loss.item()
|
| 315 |
+
loss_diff = abs(initial_loss_val - expected_loss)
|
| 316 |
+
loss_ok = loss_diff < 1.0
|
| 317 |
+
detail = (
|
| 318 |
+
f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} "
|
| 319 |
+
f"(diff={loss_diff:.4f})"
|
| 320 |
+
)
|
| 321 |
+
results.append(_check_result("Initial loss", loss_ok, detail))
|
| 322 |
+
icon = "โ
" if loss_ok else "โ"
|
| 323 |
+
print(f" {icon} Check 3: {detail}")
|
| 324 |
+
if initial_loss_val > expected_loss + 1.0:
|
| 325 |
+
print(f" Hint: loss >> ln(V) suggests label mismatch or loss function bug")
|
| 326 |
+
elif initial_loss_val < expected_loss - 2.0:
|
| 327 |
+
print(f" Hint: loss << ln(V) suggests data leakage")
|
| 328 |
+
del model_copy
|
| 329 |
+
if torch.cuda.is_available():
|
| 330 |
+
torch.cuda.empty_cache()
|
| 331 |
+
|
| 332 |
+
# โโ Check 4: Single-batch overfit test โโ
|
| 333 |
+
print(f"\n โณ Check 4: Single-batch overfit test (200 steps)...")
|
| 334 |
+
overfit_model = copy.deepcopy(model)
|
| 335 |
+
overfit_model.to(device)
|
| 336 |
+
overfit_model.train()
|
| 337 |
+
overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=1e-3)
|
| 338 |
+
single_input = input_ids[:1].to(device) # single sample
|
| 339 |
+
single_target = targets[:1].to(device)
|
| 340 |
+
|
| 341 |
+
overfit_losses = []
|
| 342 |
+
for step in range(200):
|
| 343 |
+
overfit_optimizer.zero_grad()
|
| 344 |
+
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 345 |
+
_, loss = overfit_model(single_input, single_target)
|
| 346 |
+
loss.backward()
|
| 347 |
+
overfit_optimizer.step()
|
| 348 |
+
overfit_losses.append(loss.item())
|
| 349 |
+
if (step + 1) % 50 == 0:
|
| 350 |
+
print(f" Step {step + 1}: Loss = {loss.item():.4f}")
|
| 351 |
+
|
| 352 |
+
final_overfit_loss = overfit_losses[-1]
|
| 353 |
+
overfit_ok = final_overfit_loss < 0.1
|
| 354 |
+
detail = (
|
| 355 |
+
f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
|
| 356 |
+
f"(target < 0.1)"
|
| 357 |
+
)
|
| 358 |
+
results.append(_check_result("Single-batch overfit", overfit_ok, detail))
|
| 359 |
+
icon = "โ
" if overfit_ok else "โ"
|
| 360 |
+
print(f" {icon} Check 4: {detail}")
|
| 361 |
+
if not overfit_ok:
|
| 362 |
+
print(f" CRITICAL: Model cannot memorize a single batch!")
|
| 363 |
+
print(f" This means the model or loss function has a bug.")
|
| 364 |
+
del overfit_model, overfit_optimizer
|
| 365 |
+
if torch.cuda.is_available():
|
| 366 |
+
torch.cuda.empty_cache()
|
| 367 |
+
|
| 368 |
+
# โโ Check 5: Tokenizer roundtrip โโ
|
| 369 |
+
test_text = "The quick brown fox jumps over the lazy dog."
|
| 370 |
+
encoded = tokenizer.encode(test_text)
|
| 371 |
+
decoded = tokenizer.decode(encoded)
|
| 372 |
+
roundtrip_ok = test_text.strip() in decoded.strip()
|
| 373 |
+
detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'"
|
| 374 |
+
results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail))
|
| 375 |
+
icon = "โ
" if roundtrip_ok else "โ"
|
| 376 |
+
print(f" {icon} Check 5: {detail}")
|
| 377 |
+
|
| 378 |
+
# โโ Check 6: Data quality sampling โโ
|
| 379 |
+
print(f"\n ๐ Check 6: Data quality sampling (visual inspection)")
|
| 380 |
+
for i in range(min(3, input_ids.shape[0])):
|
| 381 |
+
sample_tokens = input_ids[i][:100].tolist()
|
| 382 |
+
decoded_text = tokenizer.decode(sample_tokens)
|
| 383 |
+
preview = decoded_text[:200].replace("\n", "\\n")
|
| 384 |
+
print(f" Sample {i}: {preview}...")
|
| 385 |
+
|
| 386 |
+
passed_count = sum(1 for r in results if r["passed"])
|
| 387 |
+
total_count = len(results)
|
| 388 |
+
print(f"\n Result: {passed_count}/{total_count} checks passed")
|
| 389 |
+
|
| 390 |
+
return {
|
| 391 |
+
"level": 1,
|
| 392 |
+
"checks": results,
|
| 393 |
+
"passed": [r for r in results if r["passed"]],
|
| 394 |
+
"failed": [r for r in results if not r["passed"]],
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 398 |
+
# Level 2: Numerical Stability
|
| 399 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 400 |
+
|
| 401 |
+
@staticmethod
|
| 402 |
+
def check_numerical_stability(
|
| 403 |
+
model: nn.Module,
|
| 404 |
+
dataloader: DataLoader,
|
| 405 |
+
device: torch.device,
|
| 406 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 407 |
+
) -> Dict[str, Any]:
|
| 408 |
+
"""Check for NaN/Inf in gradients, activations, and logits (Level 2).
|
| 409 |
+
|
| 410 |
+
Checks:
|
| 411 |
+
        - Mixed precision config (RMSNorm fp32 upcast, loss dtype)
        - NaN/Inf gradients → softmax overflow, bad data
        - Inf gradients → log(0) in loss, missing ignore_index
        - Large activations growing per layer → initialization or norm bug
        - Logit scale → should be < 1000
        """
        print(_header("Level 2: Numerical Stability Checks"))

        batch = next(iter(dataloader))
        input_ids = batch["input_ids"].to(device)
        targets = batch["targets"].to(device)

        results: List[Dict[str, Any]] = []
        activation_stats: List[Dict[str, Any]] = []

        # ── Mixed Precision Configuration Check ──
        print("\n  Mixed Precision Config:")
        print(f"    Training dtype: {dtype}")

        # Check RMSNorm fp32 upcast by inspecting the forward source for .float()
        import inspect

        norm_fp32_ok = True
        for name, module in model.named_modules():
            cls_name = module.__class__.__name__
            if "Norm" in cls_name:
                try:
                    src = inspect.getsource(module.forward)
                    has_upcast = ".float()" in src or "float32" in src
                except (TypeError, OSError):
                    has_upcast = True  # assume ok if the source is unavailable
                if not has_upcast:
                    norm_fp32_ok = False
                    print(f"    🔴 {name} ({cls_name}): no fp32 upcast detected!")
                    break  # one failing norm layer is enough to flag
        if norm_fp32_ok:
            print("    ✅ Norm layers use fp32 upcast (safe)")

        results.append(_check_result(
            "Norm fp32 upcast", norm_fp32_ok,
            "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype",
        ))

        # Loss computation dtype: recommend fp32 when training in half precision
        if dtype in (torch.bfloat16, torch.float16):
            print(f"    ℹ️ Best practice: compute loss in fp32 when using {dtype}")
            print("       logits_fp32 = logits.float()")
            print("       loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))")

        # Common numerical issues reference
        print("\n  Common Numerical Issues Reference:")
        print("  ┌──────────────────────┬──────────────────────────┬─────────────────────────┐")
        print("  │ Symptom              │ Likely Cause             │ Solution                │")
        print("  ├──────────────────────┼──────────────────────────┼─────────────────────────┤")
        print("  │ Loss → NaN           │ Large logits → softmax   │ Check init, logit scale │")
        print("  │ Loss → Inf           │ log(0) in CE loss        │ Add eps, ignore_index   │")
        print("  │ Loss oscillation     │ fp16 gradient underflow  │ Switch to bf16 / scaler │")
        print("  │ Late-training NaN    │ Activation growth        │ Check RMSNorm, wd       │")
        print("  └──────────────────────┴──────────────────────────┴─────────────────────────┘")

        # ── Activation monitoring via hooks ──
        hooks = []

        def make_hook(name: str):
            def hook_fn(module, input, output):
                if isinstance(output, torch.Tensor):
                    out_f = output.float()
                    stats = {
                        "name": name,
                        "mean": out_f.mean().item(),
                        "std": out_f.std().item(),
                        "max": out_f.abs().max().item(),
                        "has_nan": bool(torch.isnan(output).any()),
                        "has_inf": bool(torch.isinf(output).any()),
                    }
                    activation_stats.append(stats)
            return hook_fn

        # Register hooks on transformer layers
        for i, layer in enumerate(model.layers):
            h = layer.register_forward_hook(make_hook(f"layer_{i}"))
            hooks.append(h)

        # ── Forward + Backward ──
        model.train()
        model.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
            logits, loss = model(input_ids, targets)

        loss_val = loss.item()
        loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val))
        results.append(_check_result(
            "Loss value",
            loss_ok,
            f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)",
        ))

        loss.backward()

        # Remove hooks
        for h in hooks:
            h.remove()

        # ── Gradient checks ──
        print("\n  Gradient Health:")
        grad_issues = []
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad
            if torch.isnan(grad).any():
                grad_issues.append(f"🔴 NaN gradient: {name}")
            if torch.isinf(grad).any():
                grad_issues.append(f"🔴 Inf gradient: {name}")
            if grad.abs().max().item() > 100:
                grad_issues.append(
                    f"🟡 Large gradient: {name} max={grad.abs().max().item():.1f}"
                )

        grad_ok = len(grad_issues) == 0
        if grad_ok:
            print("    ✅ All gradients are healthy (no NaN/Inf/large values)")
        else:
            for issue in grad_issues[:10]:  # limit output
                print(f"    {issue}")
            if len(grad_issues) > 10:
                print(f"    ... and {len(grad_issues) - 10} more issues")

        results.append(_check_result(
            "Gradient health",
            grad_ok,
            "All healthy" if grad_ok else f"{len(grad_issues)} issues found",
        ))

        # ── Activation checks ──
        print("\n  Activation Stats (per transformer layer):")
        act_nan_count = 0
        for stats in activation_stats:
            icon = "🔴" if stats["has_nan"] or stats["has_inf"] else "  "
            if stats["has_nan"] or stats["has_inf"]:
                act_nan_count += 1
            print(
                f"    {icon} {stats['name']}: "
                f"mean={stats['mean']:.4f}, "
                f"std={stats['std']:.4f}, "
                f"max={stats['max']:.4f}"
                + (" [NaN!]" if stats["has_nan"] else "")
                + (" [Inf!]" if stats["has_inf"] else "")
            )

        act_ok = act_nan_count == 0
        results.append(_check_result(
            "Activation health",
            act_ok,
            "All layers healthy" if act_ok else f"{act_nan_count} layers with NaN/Inf",
        ))

        # ── Logit scale check ──
        logit_max = logits.float().abs().max().item()
        logit_ok = logit_max < 1000
        detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)"
        results.append(_check_result("Logit scale", logit_ok, detail))
        icon = "✅" if logit_ok else "🔴"
        print(f"\n  {icon} Logit scale: {detail}")

        model.zero_grad(set_to_none=True)

        passed_count = sum(1 for r in results if r["passed"])
        print(f"\n  Result: {passed_count}/{len(results)} checks passed")

        return {
            "level": 2,
            "checks": results,
            "activation_stats": activation_stats,
            "grad_issues": grad_issues,
        }
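The issues table above attributes loss-goes-NaN to large logits overflowing softmax. A framework-free sketch of the failure and the standard max-subtraction fix (function names here are illustrative, not from the debugger itself):

```python
import math

def naive_softmax(logits):
    """Overflows: math.exp(x) raises OverflowError for x > ~709."""
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(logits):
    """Subtracting the max keeps every exponent <= 0, so exp() never overflows."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Logits of magnitude ~1000 (what the "Logit scale" check guards against)
# break the naive version but are fine after max subtraction:
probs = stable_softmax([1000.0, 1001.0, 1002.0])
```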
    # ───────────────────────────────────────────────────────────────
    # Level 3: Hyperparameter Diagnosis
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def diagnose_hyperparameters(
        metrics_history: Dict[str, list],
        config: TrainConfig,
    ) -> Dict[str, Any]:
        """Analyze hyperparameter health from training metrics (Level 3).

        Checks:
        - LR: too high (grad norm hitting the clip limit) or too low (grad norm tiny)
        - Batch size: high loss variance indicates the batch may be too small
        - Warmup: loss spikes in early steps indicate warmup may be too short
        """
        print(_header("Level 3: Hyperparameter Diagnosis"))

        findings: List[Dict[str, str]] = []
        grad_norms = metrics_history.get("grad_norm", [])
        train_losses = metrics_history.get("train_loss", [])

        # ── LR diagnosis ──
        print("\n  Learning Rate Analysis:")
        print(f"    Peak LR: {config.learning_rate:.2e}")
        print(f"    Min LR:  {config.min_learning_rate:.2e}")

        if grad_norms:
            avg_grad = sum(grad_norms) / len(grad_norms)
            clip_count = sum(1 for g in grad_norms if g >= config.grad_clip * 0.95)
            clip_rate = clip_count / len(grad_norms)
            tiny_count = sum(1 for g in grad_norms if g < 0.01)
            tiny_rate = tiny_count / len(grad_norms)

            print(f"    Avg grad norm:  {avg_grad:.4f}")
            print(f"    Clip rate:      {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})")
            print(f"    Tiny grad rate: {tiny_rate * 100:.1f}% (< 0.01)")

            if clip_rate > 0.3:
                findings.append({
                    "issue": "LR may be too high",
                    "evidence": f"Grad norm hits the clip limit {clip_rate * 100:.0f}% of the time",
                    "action": f"Try LR = {config.learning_rate / 2:.2e} (÷2)",
                })
                print(f"    💡 Grad clipping is frequent ({clip_rate * 100:.0f}%) → LR may be too high")
            elif tiny_rate > 0.5:
                findings.append({
                    "issue": "LR may be too low",
                    "evidence": f"Grad norm < 0.01 in {tiny_rate * 100:.0f}% of steps",
                    "action": f"Try LR = {config.learning_rate * 2:.2e} (×2)",
                })
                print(f"    💡 Grad norm is tiny ({tiny_rate * 100:.0f}% < 0.01) → LR may be too low")
            else:
                print("    ✅ LR looks appropriate")

        # ── Batch size diagnosis ──
        print("\n  Batch Size Analysis:")
        print(f"    Effective batch: {config.effective_batch_size}")

        if len(train_losses) >= 20:
            recent_losses = train_losses[-20:]
            loss_mean = sum(recent_losses) / len(recent_losses)
            loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
            loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)

            print(f"    Recent loss CV: {loss_cv:.4f} (coefficient of variation)")

            if loss_cv > 0.1:
                findings.append({
                    "issue": "Batch size may be too small",
                    "evidence": f"Loss CV = {loss_cv:.4f} (high variance)",
                    "action": "Increase gradient_accumulation_steps",
                })
                print("    💡 High loss variance → batch may be too small")
            else:
                print("    ✅ Loss variance is acceptable")

        # ── β₂ diagnosis ──
        print("\n  β₂ (Adam second moment) Analysis:")
        print(f"    Current β₂: {config.beta2}")
        if config.beta2 >= _DEFAULT_PYTORCH_BETA2:
            findings.append({
                "issue": "β₂ may be too high for LLM training",
                "evidence": (
                    f"β₂={config.beta2} (PyTorch default). "
                    f"LLM standard is {_RECOMMENDED_BETA2}"
                ),
                "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)",
            })
            print(f"    💡 β₂={config.beta2} is the PyTorch default → "
                  f"LLM training standard is {_RECOMMENDED_BETA2}")
            print("       Why: β₂=0.999 averages ~1000 steps of gradient stats;")
            print("       β₂=0.95 averages ~20 steps → faster adaptation to changing data")
            print("       (Cattaneo & Shigida 2025, Compagnoni et al. 2025)")
        else:
            print(f"    ✅ β₂={config.beta2} is within the LLM standard range")

        # ── Weight decay diagnosis ──
        print("\n  Weight Decay Analysis:")
        print(f"    Current weight_decay: {config.weight_decay}")
        if config.weight_decay == 0:
            findings.append({
                "issue": "Weight decay is disabled",
                "evidence": "weight_decay=0 increases overfitting risk",
                "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)",
            })
            print("    💡 weight_decay=0 → overfitting risk. Standard is 0.1")
        elif config.weight_decay > 0.3:
            findings.append({
                "issue": "Weight decay may be too high",
                "evidence": f"weight_decay={config.weight_decay} (unusually high)",
                "action": "Try weight_decay=0.1 (standard value)",
            })
            print(f"    💡 weight_decay={config.weight_decay} is unusually high (standard: 0.1)")
        else:
            print(f"    ✅ weight_decay={config.weight_decay} is within the normal range")

        # ── Model-size LR reference ──
        print("\n  GPT-3 LR Reference (Brown et al. 2020):")
        print("  ┌──────────┬───────────┬──────────────┐")
        print("  │ Model    │ Peak LR   │ Batch Tokens │")
        print("  ├──────────┼───────────┼──────────────┤")
        for params, lr, batch_tok in _GPT3_LR_REFERENCE:
            label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M"
            marker = " ←" if abs(params - 1.1e9) < 0.5e9 else ""
            print(f"  │ {label:<8} │ {lr:.1e}   │ {batch_tok:<12} │{marker}")
        print("  └──────────┴───────────┴──────────────┘")
        print("  ↑ Larger models need lower LR and larger batch")

        # ── Batch-LR scaling guidance ──
        print("\n  Batch-LR Scaling Rules:")
        print("    • Batch ×2 → LR ×√2 (square-root scaling, safer)")
        print("    • Batch ×2 → LR ×2  (linear scaling, used by GPT-3)")
        print("    • 1B model: effective batch 64~512 is a typical range")

        # ── Warmup diagnosis ──
        print("\n  Warmup Analysis:")
        print(f"    Warmup steps: {config.warmup_steps} "
              f"({config.warmup_steps / config.total_steps * 100:.1f}% of total)")

        if len(train_losses) >= 10:
            early_losses = train_losses[:min(50, len(train_losses))]
            # Detect spikes in early training
            spike_count = 0
            for i in range(1, len(early_losses)):
                if early_losses[i] > early_losses[i - 1] * 1.5:
                    spike_count += 1

            if spike_count > 3:
                findings.append({
                    "issue": "Warmup may be too short",
                    "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps",
                    "action": f"Try warmup_steps = {config.warmup_steps * 2}",
                })
                print(f"    💡 {spike_count} spikes in early training → warmup may be too short")
            else:
                print("    ✅ Early training is stable")

        # ── Summary ──
        if not findings:
            print("\n  ✅ No hyperparameter issues detected")
        else:
            print(f"\n  Found {len(findings)} potential issue(s):")
            for f in findings:
                print(f"    • {f['issue']}: {f['action']}")

        # ── Warmup reference from real projects ──
        print("\n  Warmup Reference (real projects):")
        print("    • TinyLlama 1.1B (3T tokens): 2,000 steps ≈ 0.1% of total")
        print("    • GPT-3 175B: 375 steps ≈ 0.2% of total")
        print("    • General range: 0.1% ~ 5% of total steps")
        print("    • Smaller experiments: 5~10% is also reasonable")

        print("\n  Tuning priority (high → low):")
        print("    1. Learning Rate → tune first (10x impact)")
        print("    2. Batch Size    → adjust together with LR")
        print("    3. Warmup Steps  → early stability")
        print("    4. Weight Decay  → if overfitting (typically 0.1)")
        print("    5. β₁, β₂ (Adam) → see the β₂ analysis above")
        print("    6. Gradient Clip → usually keep at 1.0")

        return {
            "level": 3,
            "findings": findings,
            "config_summary": {
                "learning_rate": config.learning_rate,
                "effective_batch": config.effective_batch_size,
                "warmup_steps": config.warmup_steps,
                "total_steps": config.total_steps,
                "grad_clip": config.grad_clip,
            },
        }
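The batch-LR scaling rules printed by the method above can be expressed as a small stand-alone helper. This is an illustrative sketch (the function name and base values are not part of the debugger):

```python
import math

def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "sqrt") -> float:
    """Scale a learning rate when the effective batch size changes.

    rule="sqrt":   LR multiplied by sqrt(batch ratio) (safer)
    rule="linear": LR multiplied by the batch ratio (GPT-3 style)
    """
    ratio = new_batch / base_batch
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    if rule == "linear":
        return base_lr * ratio
    raise ValueError(f"unknown rule: {rule}")

# Doubling the batch from 128 to 256 starting at LR 3e-4:
#   sqrt rule:   3e-4 * sqrt(2)
#   linear rule: 3e-4 * 2
lr_sqrt = scale_lr(3e-4, 128, 256, "sqrt")
lr_linear = scale_lr(3e-4, 128, 256, "linear")
```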
    @staticmethod
    def lr_range_test(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        lr_start: float = 1e-7,
        lr_end: float = 1e-1,
        steps: int = 300,
    ) -> Dict[str, Any]:
        """Run an LR range test to find a good learning rate (Level 3 bonus).

        Sweeps the LR from lr_start to lr_end exponentially, recording the loss.
        A good peak LR is near where the loss decreases fastest (steepest slope),
        divided by 3~10 for stability.

        NOTE: The test runs on a deep copy of the model; the original is untouched.
        """
        print(_header("Level 3 Bonus: LR Range Test"))
        print(f"  Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n")

        test_model = copy.deepcopy(model)
        test_model.to(device)
        test_model.train()
        optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start)

        lr_mult = (lr_end / lr_start) ** (1 / steps)
        lr = lr_start

        lrs: List[float] = []
        losses: List[float] = []
        data_iter = iter(dataloader)

        for step in range(steps):
            for pg in optimizer.param_groups:
                pg["lr"] = lr

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)

            input_ids = batch["input_ids"].to(device)
            targets_t = batch["targets"].to(device)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
                _, loss = test_model(input_ids, targets_t)
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            lrs.append(lr)
            losses.append(loss_val)

            if (step + 1) % 50 == 0:
                print(f"  Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}")

            # Stop if the loss explodes
            if len(losses) > 1 and loss_val > losses[0] * 4:
                print(f"  Loss exploded at LR = {lr:.2e}, stopping.")
                break

            lr *= lr_mult

        del test_model, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Find the point of steepest descent
        best_lr = lr_start
        if len(losses) > 10:
            # Smooth the losses with a moving average, then find the steepest negative slope
            window = 5
            smoothed = []
            for i in range(len(losses) - window):
                smoothed.append(sum(losses[i:i + window]) / window)

            min_slope = 0.0
            min_idx = 0
            for i in range(1, len(smoothed)):
                slope = smoothed[i] - smoothed[i - 1]
                if slope < min_slope:
                    min_slope = slope
                    min_idx = i

            best_lr = lrs[min_idx]
            suggested_lr = best_lr / 3  # conservative choice

            print(f"\n  Steepest descent at LR = {best_lr:.2e}")
            print(f"  Suggested peak LR: {suggested_lr:.2e} (÷3 for stability)")
            print(f"  Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]")
        else:
            suggested_lr = 3e-4
            print(f"\n  Not enough data points. Using default LR = {suggested_lr:.2e}")

        return {
            "lrs": lrs,
            "losses": losses,
            "best_lr": best_lr,
            "suggested_lr": suggested_lr,
        }
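The slope-picking logic at the end of the range test can be exercised on its own with a synthetic loss curve, no training required. A pure-Python sketch (the `pick_lr` helper mirrors the method's smoothing-and-steepest-slope step but is not part of the debugger):

```python
def pick_lr(lrs, losses, window=5):
    """Return the LR at the steepest smoothed descent, as in lr_range_test."""
    smoothed = [sum(losses[i:i + window]) / window
                for i in range(len(losses) - window)]
    min_slope, min_idx = 0.0, 0
    for i in range(1, len(smoothed)):
        slope = smoothed[i] - smoothed[i - 1]
        if slope < min_slope:
            min_slope, min_idx = slope, i
    return lrs[min_idx]

# Synthetic sweep: flat near the initial loss, a sharp drop, then divergence.
lrs = [1e-7 * (10 ** (i / 4)) for i in range(24)]
losses = ([10.4] * 10
          + [10.0, 9.0, 7.0, 5.5, 4.8, 4.5, 4.4, 4.5]
          + [5.0, 7.0, 12.0, 30.0, 80.0, 200.0])
best = pick_lr(lrs, losses)
suggested = best / 3  # conservative peak LR, as in the method above
```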
    # ───────────────────────────────────────────────────────────────
    # Level 4: Overfitting vs Underfitting Diagnosis
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def diagnose_fitting(
        metrics_history: Dict[str, list],
        model_params: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Diagnose overfitting vs underfitting from metrics (Level 4).

        Cases:
        1. Both high, decreasing    → Normal (still training)
        2. Both high, plateau       → Underfitting
        3. Train↓ but Val→ or Val↑  → Overfitting
        4. Both low, plateau        → Converged (or at the data/model limit)
        """
        print(_header("Level 4: Overfitting vs Underfitting Diagnosis"))

        train_losses = metrics_history.get("train_loss", [])
        val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]

        if len(train_losses) < 10 or len(val_losses) < 2:
            print("  [!] Not enough data. Need more training steps with eval.")
            return {"level": 4, "case": "insufficient_data", "recommendations": []}

        # Recent train trend: compare the first and second halves of a recent window
        recent_n = min(50, len(train_losses))
        train_recent = train_losses[-recent_n:]
        train_mid = len(train_recent) // 2
        train_first = sum(train_recent[:train_mid]) / max(train_mid, 1)
        train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1)
        train_decreasing = train_second < train_first - 0.02

        # Val trend
        val_mid = len(val_losses) // 2
        val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
        val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
        val_decreasing = val_second < val_first - 0.02
        val_increasing = val_second > val_first + 0.05

        # Train-val gap (val is typically slightly higher than train)
        last_train = train_losses[-1]
        last_val = val_losses[-1]
        gap = last_train - last_val

        print(f"  Train loss (recent): {train_first:.4f} → {train_second:.4f} "
              f"({'↓' if train_decreasing else '→'})")
        print(f"  Val loss:            {val_first:.4f} → {val_second:.4f} "
              f"({'↓' if val_decreasing else '↑' if val_increasing else '→'})")
        print(f"  Train-Val gap:       {abs(gap):.4f}")

        # ── Classify ──
        case = ""
        recommendations: List[str] = []

        if train_decreasing and val_decreasing:
            case = "Case 1: Normal - both decreasing"
            recommendations.append("Training is progressing normally. Continue.")
            if model_params and total_tokens:
                ratio = total_tokens / model_params
                chinchilla = 20  # Chinchilla-optimal: ~20 tokens per parameter
                if ratio < chinchilla:
                    recommendations.append(
                        f"Token/param ratio = {ratio:.1f}x "
                        f"(Chinchilla optimal ≈ {chinchilla}x). "
                        f"Model may benefit from more data."
                    )
            print(f"\n  🟢 {case}")

        elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]:
            case = "Case 2: Underfitting - both plateaued at high loss"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Training insufficient? → check if the loss curve still has a downward slope",
                "   - Chinchilla: a 1B model needs ~20B tokens minimum",
                "   - TinyLlama trains 1.1B on 3T tokens (inference-optimal)",
                "2) LR too low? → try LR ×2, see if the loss drops faster",
                "3) Model capacity too small? → train a 2x larger model on the same data",
                "   - If the larger model reaches a lower loss → capacity was the limit",
                "4) Data quality? → sample and read the training data manually",
                "   - Noisy/low-quality data raises the achievable loss floor",
            ]
            if model_params and total_tokens:
                ratio = total_tokens / model_params
                if ratio < 10:
                    recommendations.insert(0,
                        f"⚠ Token/param ratio = {ratio:.1f}x → "
                        f"very likely undertrained. Chinchilla recommends ≥20x."
                    )
                elif ratio < 20:
                    recommendations.insert(0,
                        f"ℹ Token/param ratio = {ratio:.1f}x → "
                        f"below Chinchilla optimal (20x). More tokens may help."
                    )
            print(f"\n  🟡 {case}")

        elif train_decreasing and (val_increasing or not val_decreasing):
            case = "Case 3: Overfitting - train↓ but val↑/→"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Data repetition? (most common cause in pretraining)",
                "   - Check: total tokens vs unique tokens",
                "   - More than 1 epoch dramatically increases overfitting risk",
                "   - Solution: add more data, stay within 1 epoch",
                "2) Weight decay too low?",
                "   - Check: weight_decay value (standard: 0.1)",
                "   - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1",
                "   - Experiment: 0.01 / 0.05 / 0.1 / 0.3",
                "3) Data diversity?",
                "   - Single-domain data overfits faster",
                "   - Mix: web, books, code, wiki, etc.",
                "",
                "Note on dropout in LLM pretraining:",
                "  - Modern LLMs do NOT use dropout in pretraining",
                "    (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)",
                "  - Sufficient data is the best regularization",
                "  - Dropout is useful for fine-tuning on small datasets",
            ]
            print(f"\n  🟡 {case}")

        else:
            case = "Case 4: Converged - loss is low and stable"
            recommendations = [
                "Training has converged (or reached the data/model limit).",
                "To push further: add more data or increase the model size.",
            ]
            print(f"\n  🟢 {case}")

        for rec in recommendations:
            print(f"    {rec}")

        return {
            "level": 4,
            "case": case,
            "train_trend": "decreasing" if train_decreasing else "flat",
            "val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"),
            "gap": abs(gap),
            "recommendations": recommendations,
        }
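The Chinchilla token-per-parameter thresholds used by the method above (<10x very likely undertrained, <20x below optimal) factor into a tiny stand-alone helper. An illustrative sketch (the helper is not part of the debugger):

```python
def token_param_verdict(total_tokens: int, model_params: int) -> str:
    """Classify the token/parameter ratio against the Chinchilla heuristic."""
    ratio = total_tokens / model_params
    if ratio < 10:
        return f"{ratio:.1f}x: very likely undertrained (Chinchilla recommends >=20x)"
    if ratio < 20:
        return f"{ratio:.1f}x: below Chinchilla optimal (20x); more tokens may help"
    return f"{ratio:.1f}x: at or beyond Chinchilla optimal"

# A 1B-parameter model on 5B tokens is far short of the ~20B-token optimum:
verdict = token_param_verdict(5_000_000_000, 1_000_000_000)
```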
    # ───────────────────────────────────────────────────────────────
    # Level 5: Architecture Checks
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def check_architecture(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
    ) -> Dict[str, Any]:
        """Check weight initialization and per-layer activation health (Level 5).

        Healthy initialization:
        - All layers: std ≈ 1.0, mean ≈ 0.0
        Problems:
        - std increasing per layer   → activation explosion (init scale too large)
        - std decreasing per layer   → activation vanishing (init scale too small)
        - Sudden change at one layer → implementation bug in that layer
        """
        print(_header("Level 5: Architecture / Initialization Check"))

        batch = next(iter(dataloader))
        sample_input = batch["input_ids"][:1].to(device)

        model.eval()
        layer_stats: List[Dict[str, Any]] = []

        with torch.no_grad():
            h = model.token_embedding(sample_input)
            emb_std = h.float().std().item()
            print(f"\n  Embedding: std={emb_std:.4f}")

            for i, layer in enumerate(model.layers):
                h = layer(h, mask=None, position_offset=0)
                h_f = h.float()
                stats = {
                    "layer": i,
                    "mean": h_f.mean().item(),
                    "std": h_f.std().item(),
                    "max": h_f.abs().max().item(),
                }
                layer_stats.append(stats)

        # Print per-layer stats
        print("\n  Layer-by-layer activation statistics:")
        print(f"  {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}")
        print(f"  {'-' * 38}")
        for s in layer_stats:
            print(f"  {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}")

        # ── Weight initialization distribution check ──
        print("\n  Weight Initialization Distribution:")
        print(f"  {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}")
        print(f"  {'-' * 75}")
        weight_issues = []
        for name, param in model.named_parameters():
            if param.ndim < 2:
                continue  # skip biases and norm weights
            p_f = param.float()
            p_mean = p_f.mean().item()
            p_std = p_f.std().item()
            shape_str = str(list(param.shape))
            # Expected: std ≈ 0.02 for most layers; residual projections
            # (o_proj, down_proj) are intentionally scaled down, so use a
            # looser lower bound for them.
            is_residual = "o_proj" in name or "down_proj" in name
            expected_std = 0.02  # GPT-2 style
            tiny_floor = expected_std * 0.01 if is_residual else expected_std * 0.1
            if p_std > expected_std * 5:
                weight_issues.append(f"Large std: {name} (std={p_std:.4f})")
                print(f"  💡 {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
            elif p_std < tiny_floor:
                weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})")
                print(f"  💡 {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}")
            else:
                print(f"     {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")

        if weight_issues:
            print(f"\n  ⚠ {len(weight_issues)} weight distribution issue(s) found")
            for issue in weight_issues[:5]:
                print(f"    • {issue}")
        else:
            print("\n  ✅ All weight distributions look normal (std ≈ 0.02)")

        print("\n  Expected init pattern:")
        print("    • General Linear: N(0, 0.02)")
        print("    • Residual proj (o_proj, down_proj): N(0, 0.02/√(2×layers))")
        print("    • Embedding: N(0, 0.02)")

        # ── Ablation study guidance ──
        print("\n  Component Ablation Reference:")
        print("  ┌──────────────────────┬─────────────────────────────────────┐")
        print("  │ Experiment           │ Expected Outcome                    │")
        print("  ├──────────────────────┼─────────────────────────────────────┤")
        print("  │ RMSNorm → LayerNorm  │ Minimal loss diff → OK              │")
        print("  │ RoPE → Absolute PE   │ Similar on short seq (<512)         │")
        print("  │ SwiGLU → ReLU FFN    │ Loss +0.05~0.15 → SwiGLU working    │")
        print("  │ GQA → MHA            │ Same loss, less memory → OK         │")
        print("  └──────────────────────┴─────────────────────────────────────┘")
        print("  If any replacement shows unexpected results, check that component.")

        # ── Analyze activation trends ──
        stds = [s["std"] for s in layer_stats]
        diagnosis = "healthy"
        detail = ""

        if len(stds) >= 3:
            # Check for monotonic increase/decrease between the first and last third
            first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3)
            last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3)
            ratio = last_third / max(first_third, 1e-8)

            if ratio > 5:
                diagnosis = "exploding"
                detail = (
                    f"Activation std grows {ratio:.1f}x from early to late layers. "
                    f"Init scale may be too large."
                )
            elif ratio < 0.2:
                diagnosis = "vanishing"
                detail = (
                    f"Activation std shrinks to {ratio:.1f}x from early to late layers. "
                    f"Init scale may be too small."
                )
            else:
                detail = f"Std ratio (last/first third) = {ratio:.2f} → within normal range."

            # Check for sudden jumps between adjacent layers
            for i in range(1, len(stds)):
                jump = stds[i] / max(stds[i - 1], 1e-8)
                if jump > 10 or jump < 0.1:
                    diagnosis = "anomaly"
                    detail = (
                        f"Sudden activation change at layer {i}: "
                        f"std {stds[i - 1]:.4f} → {stds[i]:.4f}. "
                        f"Possible implementation bug in that layer."
                    )
                    break

        icon = {"healthy": "✅", "exploding": "🔴", "vanishing": "🟡", "anomaly": "🔴"}
        print(f"\n  {icon.get(diagnosis, '⚪')} Diagnosis: {diagnosis}")
        print(f"  {detail}")

        return {
            "level": 5,
            "diagnosis": diagnosis,
            "detail": detail,
            "layer_stats": layer_stats,
            "weight_issues": weight_issues,
        }
|
| 1178 |
+
|
| 1179 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 1180 |
+
# Scenario Auto-Detection
|
| 1181 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 1182 |
+
|
| 1183 |
+
@staticmethod
|
| 1184 |
+
def detect_scenario(
|
| 1185 |
+
metrics_history: Dict[str, list],
|
| 1186 |
+
vocab_size: int = 32000,
|
| 1187 |
+
) -> Dict[str, Any]:
|
| 1188 |
+
"""Auto-detect which debugging scenario applies.
|
| 1189 |
+
|
| 1190 |
+
Scenarios (from the guide):
|
| 1191 |
+
A: Loss stuck at ~10.37 (doesn't decrease at all)
|
| 1192 |
+
B: Loss was decreasing then suddenly NaN
|
| 1193 |
+
C: Loss decreased to X then started increasing
|
| 1194 |
+
D: Loss stuck at high value (e.g. 4.0 for 1B model)
|
| 1195 |
+
"""
|
| 1196 |
+
print(_header("Scenario Auto-Detection"))
|
| 1197 |
+
|
| 1198 |
+
train_losses = metrics_history.get("train_loss", [])
|
| 1199 |
+
val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
|
| 1200 |
+
expected_initial = math.log(vocab_size)
|
| 1201 |
+
|
| 1202 |
+
if len(train_losses) < 5:
|
| 1203 |
+
print(" [!] Not enough data to detect scenario.")
|
| 1204 |
+
return {"scenario": "unknown", "steps": []}
|
| 1205 |
+
|
| 1206 |
+
first_loss = train_losses[0]
|
| 1207 |
+
last_loss = train_losses[-1]
|
| 1208 |
+
has_nan = any(math.isnan(l) for l in train_losses)
|
| 1209 |
+
min_loss = min(l for l in train_losses if not math.isnan(l))
|
| 1210 |
+
min_loss_idx = next(i for i, l in enumerate(train_losses) if l == min_loss)
|
| 1211 |
+
loss_recovered = last_loss > min_loss + 0.3 and min_loss_idx < len(train_losses) * 0.8
|
| 1212 |
+
|
| 1213 |
+
scenario = "unknown"
|
| 1214 |
+
steps: List[str] = []
|
| 1215 |
+
|
| 1216 |
+
# Scenario A: Loss stuck near initial value
|
| 1217 |
+
if abs(last_loss - expected_initial) < 1.5 and abs(first_loss - last_loss) < 0.5:
|
| 1218 |
+
scenario = "A"
|
| 1219 |
+
steps = [
|
| 1220 |
+
"1. Run single-batch overfit test โ if it fails, model/loss has a bug",
|
| 1221 |
+
"2. Check if gradients are zero โ optimizer.step() may be missing",
|
| 1222 |
+
"3. Verify input_ids/targets shift โ data pipeline bug",
|
| 1223 |
+
"4. Check LR โ is it set to 0?",
|
| 1224 |
+
"5. Check model.train() โ eval mode changes norm/dropout behavior",
|
| 1225 |
+
]
|
| 1226 |
+
|
| 1227 |
+
# Scenario B: NaN appeared
|
| 1228 |
+
elif has_nan:
|
| 1229 |
+
nan_idx = next(i for i, l in enumerate(train_losses) if math.isnan(l))
|
| 1230 |
+
scenario = "B"
|
| 1231 |
+
steps = [
|
| 1232 |
+
f"1. NaN appeared at step ~{nan_idx}. Check that batch's data for bad tokens",
|
| 1233 |
+
"2. Check gradient norm just before NaN โ was there a spike?",
|
| 1234 |
+
"3. Check LR schedule โ does NaN coincide with warmup end?",
|
| 1235 |
+
"4. Check specific layer weights for Inf values",
|
| 1236 |
+
"5. Try switching to fp32 to see if it's a mixed precision issue",
|
| 1237 |
+
" (Pythia-1B had irrecoverable fp16 loss spikes โ switched to bf16,",
|
| 1238 |
+
" Biderman et al. 2023)",
|
| 1239 |
+
]
|
| 1240 |
+
|
| 1241 |
+
# Scenario C: Loss decreased then increased
|
| 1242 |
+
elif loss_recovered:
|
| 1243 |
+
scenario = "C"
|
| 1244 |
+
steps = [
|
| 1245 |
+
"1. Check Train and Val loss simultaneously:",
|
| 1246 |
+
" - Both increasing โ LR too high (check cosine decay)",
|
| 1247 |
+
" - Only train increasing โ data quality changed (streaming order)",
|
| 1248 |
+
" - Only val increasing โ overfitting started",
|
| 1249 |
+
"2. Verify LR schedule is decaying as intended",
|
| 1250 |
+
"3. Check data shuffling โ same data repeating?",
|
| 1251 |
+
]
|
| 1252 |
+
|
| 1253 |
+
# Scenario D: Loss stuck at high value
|
| 1254 |
+
elif last_loss > _EXPECTED_TRAIN_LOSS[1] and abs(last_loss - min_loss) < 0.3:
|
| 1255 |
+
scenario = "D"
|
| 1256 |
+
total_tokens = len(train_losses) * 262144 # approximate
|
| 1257 |
+
steps = [
|
| 1258 |
+
f"1. Check total tokens trained: ~{total_tokens / 1e9:.1f}B "
|
| 1259 |
+
f"(need 5-10B for 1B model)",
|
| 1260 |
+
"2. Compare with smaller model (100M) at same step โ "
|
| 1261 |
+
"if 100M is lower, 1B may have a bug",
|
| 1262 |
+
"3. Run LR range test โ current LR may not be optimal",
|
| 1263 |
+
"4. Sample training data โ check for noise, duplicates, low quality",
|
| 1264 |
+
"5. Try different effective batch size (64 vs 128 vs 256)",
|
| 1265 |
+
]
|
| 1266 |
+
|
| 1267 |
+
else:
|
| 1268 |
+
scenario = "none"
|
| 1269 |
+
steps = ["Training appears normal. No specific scenario detected."]
|
| 1270 |
+
|
| 1271 |
+
label = {
|
| 1272 |
+
"A": "Loss stuck at initial value (~10.37)",
|
| 1273 |
+
"B": "Loss was decreasing, then NaN",
|
| 1274 |
+
"C": "Loss decreased then started increasing",
|
| 1275 |
+
"D": f"Loss stuck at high value (>{_EXPECTED_TRAIN_LOSS[1]})",
|
| 1276 |
+
"none": "No problematic scenario detected",
|
| 1277 |
+
"unknown": "Cannot determine",
|
| 1278 |
+
}
|
| 1279 |
+
|
| 1280 |
+
print(f"\n Detected: Scenario {scenario} โ {label.get(scenario, 'Unknown')}")
|
| 1281 |
+
print(f"\n Recommended debugging steps:")
|
| 1282 |
+
for step in steps:
|
| 1283 |
+
print(f" {step}")
|
| 1284 |
+
|
| 1285 |
+
return {"scenario": scenario, "label": label.get(scenario, ""), "steps": steps}
|
| 1286 |
+
|
| 1287 |
+
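The branch order above (A before B before C) can be exercised without a model or dataloader; a minimal standalone sketch, where `classify` is an illustrative helper (not part of `LossDebugger`) that reimplements only the A/B/C branch conditions on synthetic loss curves:

```python
import math

def classify(train_losses, vocab_size=32000):
    """Toy re-implementation of detect_scenario's branch order: A, then B, then C."""
    expected_initial = math.log(vocab_size)
    if len(train_losses) < 5:
        return "unknown"
    first, last = train_losses[0], train_losses[-1]
    finite = [l for l in train_losses if not math.isnan(l)]
    min_loss = min(finite)
    min_idx = train_losses.index(min_loss)
    if abs(last - expected_initial) < 1.5 and abs(first - last) < 0.5:
        return "A"   # stuck near ln(vocab_size)
    if any(math.isnan(l) for l in train_losses):
        return "B"   # NaN appeared mid-training
    if last > min_loss + 0.3 and min_idx < len(train_losses) * 0.8:
        return "C"   # bounced well above an earlier minimum
    return "none"

stuck = [10.37 - 0.001 * i for i in range(50)]                  # barely moves
nan_curve = [10.37 - 0.05 * i for i in range(40)] + [float("nan")] * 10
bounce = [10.0 - 0.1 * i for i in range(40)] + [6.2 + 0.1 * i for i in range(20)]

print(classify(stuck), classify(nan_curve), classify(bounce))  # A B C
```

Note the NaN check must come after the "stuck" check only because a NaN `last` makes the comparison in branch A false; with finite losses the branch order matters.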
    # ───────────────────────────────────────────────────────────────
    # Main Entry Point
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def run_diagnostics(
        model: nn.Module,
        dataloader: DataLoader,
        tokenizer: Any,
        train_config: TrainConfig,
        metrics_history: Dict[str, list],
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        vocab_size: int = 32000,
        levels: Optional[List[int]] = None,
    ) -> Dict[str, Any]:
        """Run the full 5-level debugging framework.

        Args:
            model: the LLM model
            dataloader: training dataloader
            tokenizer: tokenizer with encode/decode methods
            train_config: TrainConfig instance
            metrics_history: dict from MetricsTracker.history
            device: torch device
            dtype: mixed precision dtype
            vocab_size: model vocabulary size
            levels: which levels to run (default: all [0, 1, 2, 3, 4, 5])

        Returns:
            Full diagnostic report dict.
        """
        if levels is None:
            levels = [0, 1, 2, 3, 4, 5]

        print("\n" + "━" * 60)
        print("  LLM Loss Debugging Framework")
        print("  Levels to run: " + ", ".join(str(l) for l in levels))
        print("━" * 60)

        report: Dict[str, Any] = {}

        if 0 in levels:
            report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history)
            # If status is normal and only level 0 was explicitly requested, skip the rest
            if (
                report["level_0"]["status"] == STATUS_NORMAL
                and levels == [0]
            ):
                print("\n  Training is healthy — no further debugging needed.")
                return report

        if 1 in levels:
            report["level_1"] = LossDebugger.check_data_pipeline(
                model, dataloader, tokenizer, vocab_size, device, dtype,
            )

        if 2 in levels:
            report["level_2"] = LossDebugger.check_numerical_stability(
                model, dataloader, device, dtype,
            )

        if 3 in levels:
            report["level_3"] = LossDebugger.diagnose_hyperparameters(
                metrics_history, train_config,
            )

        if 4 in levels:
            model_params = sum(p.numel() for p in model.parameters())
            total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step
            report["level_4"] = LossDebugger.diagnose_fitting(
                metrics_history, model_params, total_tokens,
            )

        if 5 in levels:
            report["level_5"] = LossDebugger.check_architecture(
                model, dataloader, device,
            )

        # Auto-detect scenario
        report["scenario"] = LossDebugger.detect_scenario(metrics_history, vocab_size)

        # Final summary
        print("\n" + "━" * 60)
        print("  Diagnostics Complete")
        print("━" * 60)

        return report

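The expected-initial-loss checks used throughout the framework follow from cross-entropy against a uniform predictive distribution: an untrained LM assigns each of the V tokens probability roughly 1/V, so the per-token loss starts near ln(V). A pure-Python check (no torch; `uniform_cross_entropy` is an illustrative helper, not part of the module):

```python
import math

def uniform_cross_entropy(vocab_size: int) -> float:
    """Expected initial loss of an untrained LM: each token gets probability
    ~1/V, so per-token cross-entropy is -log(1/V) = log(V)."""
    return -math.log(1.0 / vocab_size)

for v in (32000, 50257, 128256):
    print(f"vocab={v}: expected initial loss = {uniform_cross_entropy(v):.2f}")
# vocab=32000 → 10.37, vocab=50257 → 10.82, vocab=128256 → 11.76
```

A measured initial loss far above this suggests bad logit scaling or initialization; far below suggests target leakage (e.g. inputs and targets not shifted).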
    # ───────────────────────────────────────────────────────────────
    # Study Roadmap
    # ───────────────────────────────────────────────────────────────

    @staticmethod
    def print_study_roadmap() -> None:
        """Print the recommended study roadmap for LLM training optimization."""
        print(_header("Study Roadmap — LLM Training Optimization"))

        print("""
  ⭐⭐⭐ Top Priority: Optimization Fundamentals
  ─────────────────────────────────────────────
  1. SGD → Momentum → Adam → AdamW progression
     - Why Adam > SGD? Why decouple weight decay in AdamW?
     - β₁, β₂ intuition (1st / 2nd moment)
     - Ref: Loshchilov & Hutter 2019 (AdamW)
     - Ref: Karpathy "A Recipe for Training Neural Networks"

  2. Loss Landscape
     - Why a large LR diverges and a small LR stalls
     - Batch size effect on landscape exploration
     - Ref: Li et al. 2018 "Visualizing the Loss Landscape"
     - Ref: McCandlish et al. 2018 "Large-Batch Training"

  3. Chinchilla Scaling Law
     - Loss = f(N, D) relationship
     - Compute-optimal model size vs data allocation
     - Ref: Hoffmann et al. 2022 (original)
     - Ref: Kaplan et al. 2020 (predecessor)
     - Ref: Besiroglu et al. 2024 (replication/verification)

  ⭐⭐ Important: Training Stability
  ──────────────────────────────────
  4. Gradient Flow: vanishing/exploding, residual as gradient highway
  5. Weight Init: Xavier / Kaiming / GPT-2 style
  6. Normalization: BatchNorm → LayerNorm → RMSNorm
  7. Weight Decay: L2 vs decoupled, why exclude embed/norm

  ⭐ Advanced: Optimization Techniques
  ─────────────────────────────────────
  8. LR Schedules: cosine vs linear vs step, warmup/cooldown
  9. Gradient Accumulation & Large Batch Training
  10. μP (Maximal Update Parameterization): transfer HPs across scales

  Recommended Experiments (in order):
  ───────────────────────────────────
  1. Single-batch overfit         (30 min)   → basic sanity
  2. LR Range Test                (1 hour)   → optimal LR range
  3. 10M model quick train        (2-3 hrs)  → pipeline validation
  4. Ablation (remove components) (1 day)    → component contribution
  5. 100M model + 5B tokens       (1-2 days) → mid-scale dynamics
  6. 1B model full training       (2-3 days) → scaling law verification
  7. LR / batch size comparison   (1 day)    → HP sensitivity

  Key References:
  ───────────────
  ⭐⭐⭐ Karpathy "Recipe for Training NNs"  → debugging mindset
  ⭐⭐⭐ Hoffmann et al. 2022 (Chinchilla)   → scaling law
  ⭐⭐  Touvron et al. 2023 (LLaMA)         → 1B+ training details
  ⭐⭐  Biderman et al. 2023 (Pythia)       → open training logs
  ⭐⭐  Zhang et al. 2024 (TinyLlama)       → 1.1B on 3T tokens
  ⭐⭐  Groeneveld et al. 2024 (OLMo)       → fully open LLM
  ⭐⭐  Li et al. 2018 (Loss Landscape)     → loss terrain intuition
  ⭐⭐  Loshchilov & Hutter 2019 (AdamW)    → optimizer basics
  ⭐   Yang et al. 2022 (μP)               → HP transfer
  ⭐   McCandlish et al. 2018 (Batch size) → critical batch size
""")
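The β₁/β₂ intuition in roadmap item 1 can be made concrete: an exponential moving average with decay β has an effective averaging window of roughly 1/(1−β) steps, which is where the "~1000 steps vs ~20 steps" contrast between β₂=0.999 (the PyTorch AdamW default) and β₂=0.95 (the common LLM setting) comes from. A small pure-Python sketch (`ema_window` is an illustrative helper, not part of the module):

```python
def ema_window(beta: float) -> float:
    """Approximate averaging window of an EMA with decay `beta`: 1 / (1 - beta) steps."""
    return 1.0 / (1.0 - beta)

# PyTorch AdamW default vs the common LLM setting
for beta2 in (0.999, 0.95):
    print(f"beta2={beta2}: averages gradients over ~{ema_window(beta2):.0f} steps")
```

A shorter window tracks the shifting token distribution faster, at the cost of noisier second-moment estimates.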
@@ -0,0 +1,384 @@
{
 "nbformat": 4,
 "nbformat_minor": 4,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 05. Loss Debugging (5-Level Diagnostic Framework)\n",
    "\n",
    "A framework for systematically diagnosing the cause when the loss does not decrease as expected during training.\n",
    "\n",
    "**Always check the lower levels first** — catch data bugs before tuning hyperparameters.\n",
    "\n",
    "```\n",
    "Level 0: Status Diagnosis      → classify current training health (6 categories)\n",
    "Level 1: Data / Implementation → the most common cause (70%)\n",
    "Level 2: Numerical Stability   → NaN/Inf, activation explosion\n",
    "Level 3: Hyperparameters       → LR, batch size, warmup\n",
    "Level 4: Fitting Diagnosis     → overfitting vs underfitting\n",
    "Level 5: Architecture          → initialization, per-layer activations\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No additional packages required:\n",
    "# LossDebugger uses only torch and the built-in llm_lab modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '..')\n",
    "\n",
    "import math\n",
    "import torch\n",
    "\n",
    "from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n",
    "from llm_lab.model import LLMModel\n",
    "from llm_lab.data import setup_data_pipeline\n",
    "from llm_lab.training import LossDebugger\n",
    "from llm_lab.utils import auto_configure, get_device"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Setup\n",
    "\n",
    "Use the `debug_10m` preset so everything runs quickly even on CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Config ---\n",
    "model_config = ModelConfig.debug_10m()\n",
    "data_config = DataConfig(\n",
    "    max_seq_len=model_config.max_seq_len,\n",
    "    batch_size=4,\n",
    ")\n",
    "train_config = TrainConfig.debug_10m()\n",
    "\n",
    "# --- Device / dtype ---\n",
    "train_config = auto_configure(train_config)\n",
    "device = get_device()\n",
    "dtype = train_config.torch_dtype\n",
    "\n",
    "vocab_size = data_config.vocab_size\n",
    "print(f\"Device: {device}, dtype: {dtype}\")\n",
    "print(f\"Vocab size: {vocab_size:,}\")\n",
    "print(f\"Expected initial loss: ln({vocab_size}) = {math.log(vocab_size):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Model ---\n",
    "model = LLMModel(model_config).to(device)\n",
    "print(f\"Model parameters: {model.count_parameters():,}\")\n",
    "\n",
    "# --- Data pipeline ---\n",
    "tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
    "    tokenizer_mode=\"pretrained\",\n",
    "    config=data_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0.1 Mock Metrics History\n",
    "\n",
    "To exercise Levels 0 / 3 / 4 and scenario detection without a real training run,\n",
    "we provide mock `metrics_history` dicts.\n",
    "\n",
    "After a real training run, use `trainer.metrics.history` instead:\n",
    "```python\n",
    "# metrics_history = trainer.metrics.history\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import random\nrandom.seed(42)\n\nexpected_initial = math.log(vocab_size)  # ~10.37\n\n# --- Scenario A: loss stuck near ln(vocab_size) ---\n# Loss barely decreases (suspect a data/implementation bug)\nn_steps_a = 200\nmock_history_a = {\n    \"step\": list(range(1, n_steps_a + 1)),\n    \"train_loss\": [expected_initial - 0.001 * i + random.uniform(-0.05, 0.05)\n                   for i in range(n_steps_a)],\n    \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_a)],\n    \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_a)],\n    \"tokens_per_sec\": [random.uniform(8000, 12000) for _ in range(n_steps_a)],\n    \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_a)],\n    \"val_loss\": [expected_initial + random.uniform(-0.1, 0.1)\n                 for _ in range(0, n_steps_a, 50)],\n    \"val_ppl\": [math.exp(expected_initial) + random.uniform(-50, 50)\n                for _ in range(0, n_steps_a, 50)],\n}\nprint(f\"Mock A — train_loss: {mock_history_a['train_loss'][0]:.2f} -> {mock_history_a['train_loss'][-1]:.2f}\")\nprint(\"  Expected: NO_DECREASE (loss barely changes)\")\n\n# --- Scenario B: loss decreasing, then NaN ---\nn_steps_b = 200\nmock_history_b = {\n    \"step\": list(range(1, n_steps_b + 1)),\n    \"train_loss\": [expected_initial - 0.03 * i + random.uniform(-0.05, 0.05)\n                   for i in range(150)]\n                  + [float('nan')] * 50,\n    \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_b)],\n    \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(145)]\n                 + [random.uniform(5.0, 50.0) for _ in range(5)]\n                 + [float('nan')] * 50,\n    \"tokens_per_sec\": [random.uniform(8000, 12000) for _ in range(n_steps_b)],\n    \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_b)],\n    \"val_loss\": [expected_initial - 0.03 * i + random.uniform(-0.1, 0.1)\n                 for i in range(0, n_steps_b, 50)],\n    \"val_ppl\": [math.exp(expected_initial - 0.03 * i)\n                for i in range(0, n_steps_b, 50)],\n}\nprint(\"\\nMock B — train_loss starts normal, then NaN at step ~150\")\nprint(\"  Expected: Scenario B (NaN detected)\")\n\n# --- Scenario C: loss decreased, then increased ---\nn_steps_c = 200\nmock_history_c = {\n    \"step\": list(range(1, n_steps_c + 1)),\n    \"train_loss\": [expected_initial - 0.04 * i + random.uniform(-0.05, 0.05)\n                   for i in range(120)]\n                  + [expected_initial - 0.04 * 120 + 0.02 * (i - 120) + random.uniform(-0.05, 0.05)\n                     for i in range(120, n_steps_c)],\n    \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_c)],\n    \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_c)],\n    \"tokens_per_sec\": [random.uniform(8000, 12000) for _ in range(n_steps_c)],\n    \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_c)],\n    \"val_loss\": [expected_initial - 0.02 * i + random.uniform(-0.1, 0.1)\n                 for i in range(0, 120, 50)]\n                + [expected_initial - 0.02 * 120 + 0.03 * (i - 120) + random.uniform(-0.1, 0.1)\n                   for i in range(120, n_steps_c, 50)],\n    \"val_ppl\": [math.exp(expected_initial - 0.02 * i)\n                for i in range(0, n_steps_c, 50)],\n}\nprint(\"\\nMock C — train_loss: decrease → increase (bounce)\")\nprint(\"  Expected: Scenario C (loss recovery)\")\n\n# --- Scenario D: loss stuck at a high value (4.0) ---\nn_steps_d = 200\nmock_history_d = {\n    \"step\": list(range(1, n_steps_d + 1)),\n    \"train_loss\": [4.0 + random.uniform(-0.1, 0.1) for _ in range(n_steps_d)],\n    \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_d)],\n    \"grad_norm\": [random.uniform(0.1, 0.3) for _ in range(n_steps_d)],\n    \"tokens_per_sec\": [random.uniform(8000, 12000) for _ in range(n_steps_d)],\n    \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_d)],\n    \"val_loss\": [4.2 + random.uniform(-0.1, 0.1)\n                 for _ in range(0, n_steps_d, 50)],\n    \"val_ppl\": [math.exp(4.2) + random.uniform(-5, 5)\n                for _ in range(0, n_steps_d, 50)],\n}\nprint(\"\\nMock D — train_loss stuck at ~4.0\")\nprint(\"  Expected: Scenario D (plateau at high value)\")\n\n# --- Normal: loss decreasing steadily ---\nn_steps_n = 200\nmock_history_normal = {\n    \"step\": list(range(1, n_steps_n + 1)),\n    \"train_loss\": [expected_initial - 0.03 * i + random.uniform(-0.05, 0.05)\n                   for i in range(n_steps_n)],\n    \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_n)],\n    \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_n)],\n    \"tokens_per_sec\": [random.uniform(8000, 12000) for _ in range(n_steps_n)],\n    \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_n)],\n    \"val_loss\": [expected_initial - 0.03 * i + random.uniform(-0.1, 0.1)\n                 for i in range(0, n_steps_n, 50)],\n    \"val_ppl\": [math.exp(expected_initial - 0.03 * i)\n                for i in range(0, n_steps_n, 50)],\n}\nprint(f\"\\nMock Normal — train_loss: {mock_history_normal['train_loss'][0]:.2f} -> {mock_history_normal['train_loss'][-1]:.2f}\")\nprint(\"  Expected: NORMAL (loss decreasing steadily)\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Level 0 — Status Diagnosis\n",
    "\n",
    "Classifies the current training state into one of six categories using `metrics_history` alone:\n",
    "\n",
    "| Status | Meaning | Severity |\n",
    "|--------|---------|----------|\n",
    "| `NORMAL` | training is healthy | green |\n",
    "| `NO_DECREASE` | loss is not decreasing | red |\n",
    "| `DIVERGING` | loss is diverging | red |\n",
    "| `PLATEAU` | loss is stuck at a high value | yellow |\n",
    "| `OVERFITTING` | train↓ val↑ | yellow |\n",
    "| `UNSTABLE` | large loss fluctuations | yellow |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scenario A (problem case)\n",
    "status_a = LossDebugger.diagnose_status(vocab_size, mock_history_a)\n",
    "print(f\"\\n>>> Result: {status_a['status']}\")\n",
    "\n",
    "print(\"\\n\" + \"-\" * 40)\n",
    "\n",
    "# Normal (healthy case)\n",
    "status_n = LossDebugger.diagnose_status(vocab_size, mock_history_normal)\n",
    "print(f\"\\n>>> Result: {status_n['status']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Level 1 — Data / Implementation Checks\n",
    "\n",
    "About **70% of loss problems are data pipeline bugs**. Six checks:\n",
    "\n",
    "1. **Shift relation** → verify `targets[t] == input_ids[t+1]`\n",
    "2. **Token range** → verify `0 <= ids < vocab_size`\n",
    "3. **Initial loss** → with random weights, verify loss ≈ `ln(vocab_size)`\n",
    "4. **Single-batch overfit** → repeat one batch for 200 steps → loss should approach 0\n",
    "5. **Tokenizer round-trip** → verify encode → decode preserves the text\n",
    "6. **Data quality** → sample the text and look for junk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "level1 = LossDebugger.check_data_pipeline(\n",
    "    model=model,\n",
    "    dataloader=train_dl,\n",
    "    tokenizer=tokenizer,\n",
    "    vocab_size=vocab_size,\n",
    "    device=device,\n",
    "    dtype=dtype,\n",
    ")\n",
    "\n",
    "print(f\"\\nPassed: {len(level1['passed'])}, Failed: {len(level1['failed'])}\")\n",
    "for f in level1['failed']:\n",
    "    print(f\"  FAILED: {f['name']} — {f['detail']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 3. Level 2 — Numerical Stability\n\nRuns one forward + backward pass and checks:\n\n- **Mixed precision config** → RMSNorm computed in fp32, loss dtype\n- **Gradients** → NaN/Inf/large values\n- **Activations** → mean/std/max of each transformer layer\n- **Logit scale** → output logits in a reasonable range (< 1000)\n\n### Mixed Precision Best Practices\n\n```python\n# ❌ Danger: in bf16, adding a small value to a large one loses the small value\n# ✅ Fix: compute RMSNorm internals in float32\nx_float = x.float()  # bf16 → fp32\nrms = torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + eps)\nreturn (x_float * rms).to(x.dtype) * self.weight\n\n# ✅ Compute the loss in float32 as well:\nlogits_fp32 = logits.float()\nloss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))\n```\n\n### Common Numerical Issues\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| Loss → NaN | very large inputs to softmax | check logit scale, review initialization |\n| Loss → Inf | log of zero | add eps, set ignore_index |\n| Loss oscillates wildly | gradient underflow in fp16 | switch to bf16 or use GradScaler |\n| NaN late in training | activations grow gradually | check RMSNorm placement and that weight decay is applied |"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "level2 = LossDebugger.check_numerical_stability(\n",
    "    model=model,\n",
    "    dataloader=train_dl,\n",
    "    device=device,\n",
    "    dtype=dtype,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 4. Level 3 — Hyperparameter Diagnosis\n\nAnalyzes `metrics_history` and `TrainConfig` to check:\n\n- **LR**: grad norm frequently hitting the clip limit → LR too high\n- **Batch size**: large loss fluctuations → batch too small\n- **Warmup**: many early loss spikes → warmup too short\n- **β₂**: using the LLM standard (0.95) instead of the PyTorch default (0.999)\n- **Weight decay**: 0 risks overfitting; 0.1 is the standard\n\n### Tuning Priority (by impact)\n\n| Rank | Parameter | Impact | Notes |\n|------|-----------|--------|-------|\n| 1 | **Learning Rate** | 10x | always tune first |\n| 2 | **Batch Size** | high | adjust together with LR |\n| 3 | **Warmup Steps** | medium | early-phase stability |\n| 4 | **Weight Decay** | medium | tune when overfitting (usually fixed at 0.1) |\n| 5 | **β₁, β₂** | low | β₂=0.95 recommended (LLaMA, TinyLlama, OLMo) |\n| 6 | **Gradient Clip** | low | usually fixed at 1.0 |\n\n### Choosing β₂\n\n- **PyTorch default** β₂=0.999 → gradient statistics over roughly the last ~1000 steps\n- **LLM standard** β₂=0.95 → gradient statistics over roughly the last ~20 steps\n- The data distribution keeps shifting during LLM training, so faster adaptation (lower β₂) helps\n- A lower β₂ also reduces the risk of loss spikes (Cattaneo & Shigida 2025)\n\n### Batch-LR Scaling\n\n- Batch ×2 → LR ×√2 (square-root scaling, more stable)\n- Batch ×2 → LR ×2 (linear scaling, used by GPT-3 and others)\n- For a 1B model, an effective batch of 64–512 is typical"
  },
|
| 233 |
+
{
|
| 234 |
+
"cell_type": "code",
|
| 235 |
+
"execution_count": null,
|
| 236 |
+
"metadata": {},
|
| 237 |
+
"outputs": [],
|
| 238 |
+
"source": [
|
| 239 |
+
"level3 = LossDebugger.diagnose_hyperparameters(\n",
|
| 240 |
+
" metrics_history=mock_history_a,\n",
|
| 241 |
+
" config=train_config,\n",
|
| 242 |
+
")"
|
| 243 |
+
]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"cell_type": "markdown",
|
| 247 |
+
"metadata": {},
|
| 248 |
+
"source": [
|
| 249 |
+
"### 4.1 LR Range Test\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"LR์ `1e-7` โ `1e-1`๋ก ์ง์์ ์ผ๋ก ์ฆ๊ฐ์ํค๋ฉฐ loss๋ฅผ ๊ธฐ๋กํฉ๋๋ค.\n",
|
| 252 |
+
"loss๊ฐ ๊ฐ์ฅ ๋น ๋ฅด๊ฒ ์ค์ด๋๋ ์ง์ ์ LR รท 3์ด ๊ถ์ฅ peak LR์
๋๋ค.\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"> ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆด ์ ์์ต๋๋ค (debug_10m ๊ธฐ์ค ~1๋ถ).\n",
|
| 255 |
+
"> ์ค์ ํ์ต ์ ํ ๋ฒ๋ง ์คํํ๋ฉด ๋ฉ๋๋ค."
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "code",
|
| 260 |
+
"execution_count": null,
|
| 261 |
+
"metadata": {},
|
| 262 |
+
"outputs": [],
|
| 263 |
+
"source": [
|
| 264 |
+
"lr_result = LossDebugger.lr_range_test(\n",
|
| 265 |
+
" model=model,\n",
|
| 266 |
+
" dataloader=train_dl,\n",
|
| 267 |
+
" device=device,\n",
|
| 268 |
+
" dtype=dtype,\n",
|
| 269 |
+
" steps=100, # debug_10m: 100 steps for speed\n",
|
| 270 |
+
")\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"print(f\"\\nSuggested peak LR: {lr_result['suggested_lr']:.2e}\")"
|
| 273 |
+
]
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
"cell_type": "markdown",
|
| 277 |
+
"metadata": {},
|
| 278 |
+
"source": "## 5. Level 4 — Fitting Diagnosis\n\nCompare the train/val loss trends to tell four cases apart:\n\n| Case | Train | Val | Diagnosis |\n|------|-------|-----|------|\n| 1 | ↓ | ↓ | Normal (training in progress) |\n| 2 | → | → | Underfitting (model/data insufficient) |\n| 3 | ↓ | ↑ | Overfitting (suspect data repetition) |\n| 4 | → | → (low) | Converged |\n\n### Underfitting: order of cause analysis\n\n1. **Not enough model capacity?** → train a 2x larger model on the same data → if the loss drops, it is a capacity problem\n2. **Not enough training?** → check whether the loss curve is still trending down (Chinchilla: 1B → ~20B tokens)\n3. **LR too small, so convergence is slow?** → try LR ×2\n4. **Data quality problem?** → sample the data and read it yourself\n\n### Overfitting: order of cause analysis\n\n1. **Too little data / repetition?** → with epochs > 1 the overfitting risk rises sharply\n2. **Not enough weight decay?** → 0.1 is the standard (LLaMA, TinyLlama, GPT-3, OLMo)\n3. **Not enough data diversity?** → mix in more domains\n\n### Practical facts about overfitting in LLM pretraining\n\n- Overfitting within **a single epoch** is very rare\n- When it does appear, the cause is usually **data repetition** (epochs > 1)\n- **Dropout** is **almost never used** in modern LLM pretraining\n  - Pythia, TinyLlama, OLMo, and LLaMA all use dropout=0\n  - With enough data, the data itself is the best regularizer\n  - Dropout is useful when fine-tuning on small datasets"
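The Chinchilla rule of thumb (~20 tokens per parameter) from the underfitting checklist can be verified in a few lines. A sketch; `diagnose_fitting` receives these numbers, but the `chinchilla_check` helper below is hypothetical:

```python
def chinchilla_check(model_params: int, total_tokens: int, target_ratio: float = 20.0):
    """Compare trained tokens against the ~20 tokens/param Chinchilla heuristic."""
    ratio = total_tokens / model_params
    verdict = "under-trained" if ratio < target_ratio else "compute-optimal or beyond"
    return ratio, verdict

# A 1B model trained on 5B tokens is far short of the ~20B the heuristic suggests
ratio, verdict = chinchilla_check(model_params=1_000_000_000, total_tokens=5_000_000_000)
print(f"{ratio:.1f} tokens/param -> {verdict}")  # 5.0 tokens/param -> under-trained
```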
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_params = sum(p.numel() for p in model.parameters())\n",
"total_tokens = len(mock_history_normal[\"train_loss\"]) * train_config.tokens_per_step\n",
"\n",
"level4 = LossDebugger.diagnose_fitting(\n",
"    metrics_history=mock_history_normal,\n",
"    model_params=model_params,\n",
"    total_tokens=total_tokens,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## 6. Level 5 — Architecture / Initialization Check\n\nRun one batch through the model and collect **per-layer activation statistics** and the **weight distributions**.\n\n### Activation diagnosis\n\n- **healthy**: std stays stable across layers\n- **exploding**: std grows sharply toward the later layers → initialization scale is too large\n- **vanishing**: std shrinks sharply toward the later layers → initialization scale is too small\n- **anomaly**: an abrupt change at one specific layer → implementation bug in that layer\n\n### Weight initialization diagnosis\n\nGPT-2 style initialization:\n- **Regular Linear**: `N(0, 0.02)`\n- **Residual projection** (o_proj, down_proj): `N(0, 0.02/√(2×layers))`\n  → shrinks each residual contribution in deep models for stability\n\n### Ablation study (per-component impact)\n\n| Experiment | Expected result | If not |\n|------|----------|---------|\n| RMSNorm → LayerNorm | negligible loss difference | normalization implementation bug |\n| RoPE → absolute position embeddings | little difference on short sequences | check the RoPE implementation |\n| SwiGLU → ReLU FFN | loss +0.05~0.15 | check the SwiGLU implementation |\n| GQA → MHA | nearly identical loss (only memory differs) | KV repeat bug |"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"level5 = LossDebugger.check_architecture(\n",
"    model=model,\n",
"    dataloader=train_dl,\n",
"    device=device,\n",
")\n",
"\n",
"print(f\"\\n>>> Diagnosis: {level5['diagnosis']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## 7. Scenario Auto-Detection\n\nAnalyzes `metrics_history` and automatically decides which of the following four scenarios applies:\n\n| Scenario | Symptom | Main cause | Key diagnostic |\n|----------|------|----------|----------|\n| **A** | Loss barely moves from ~10.37 | data/implementation bug | single-batch overfit test |\n| **B** | Loss improves, then suddenly NaN | numerical instability | grad norm right before the NaN |\n| **C** | Loss decreases, then rises again | LR/data problem | check train and val together |\n| **D** | Loss plateaus at a high value | insufficient training / LR problem | check token count, LR range test |\n\n### Diagnostic checklist per scenario\n\n**Scenario A**: \"The loss does not move from 10.37 at all\"\n1. Single-batch overfit test → if it fails, model/loss bug\n2. Check whether the gradients are zero → missing `optimizer.step()`?\n3. Check the `input_ids/targets` shift → data pipeline bug\n4. Check the LR → is it accidentally set to 0?\n5. Check that `model.train()` was called\n\n**Scenario B**: \"The loss goes down, then suddenly NaN\"\n1. Inspect the batch data at that step\n2. Check for a gradient norm spike right before the NaN\n3. Compare the NaN timing against the LR schedule\n4. Mixed precision issue → retry in fp32 to reproduce\n   - (Pythia-1B switched fp16 → bf16, Biderman et al. 2023)\n\n**Scenario C**: \"The loss fell to 3.5, then climbed back up\"\n1. Check train and val together:\n   - both rising → LR too large\n   - only train rising → data quality changed (streaming order)\n   - only val rising → suspect overfitting\n2. Check the LR schedule and the data shuffling\n\n**Scenario D**: \"The loss will not drop below 4.0\"\n1. Check the number of training tokens (below 5B means under-trained)\n2. Compare against a 100M model\n3. Run the LR range test\n4. Sample the data for quality"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# --- Scenario A ---\nprint(\"=\" * 50)\nprint(\"Testing Scenario A (loss stuck at initial value)\")\nprint(\"=\" * 50)\nscenario_a = LossDebugger.detect_scenario(\n metrics_history=mock_history_a,\n vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_a['scenario']}\")\n\n# --- Scenario B ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario B (NaN appeared)\")\nprint(\"=\" * 50)\nscenario_b = LossDebugger.detect_scenario(\n metrics_history=mock_history_b,\n vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_b['scenario']}\")\n\n# --- Scenario C ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario C (loss bounce)\")\nprint(\"=\" * 50)\nscenario_c = LossDebugger.detect_scenario(\n metrics_history=mock_history_c,\n vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_c['scenario']}\")\n\n# --- Scenario D ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario D (loss plateau)\")\nprint(\"=\" * 50)\nscenario_d = LossDebugger.detect_scenario(\n metrics_history=mock_history_d,\n vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_d['scenario']}\")\n\n# --- Normal ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Normal scenario\")\nprint(\"=\" * 50)\nscenario_n = LossDebugger.detect_scenario(\n metrics_history=mock_history_normal,\n vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_n['scenario']}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Full Diagnostics (run_diagnostics)\n",
"\n",
"Runs all the levels above in a single call. Use the `levels` parameter to select which levels to run.\n",
"\n",
"```python\n",
"# Example usage after a real training run:\n",
"# report = LossDebugger.run_diagnostics(\n",
"#     model=model, dataloader=train_dl, tokenizer=tokenizer,\n",
"#     train_config=train_config,\n",
"#     metrics_history=trainer.metrics.history,\n",
"#     device=device, dtype=dtype,\n",
"# )\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"report = LossDebugger.run_diagnostics(\n",
"    model=model,\n",
"    dataloader=train_dl,\n",
"    tokenizer=tokenizer,\n",
"    train_config=train_config,\n",
"    metrics_history=mock_history_a,\n",
"    device=device,\n",
"    dtype=dtype,\n",
"    vocab_size=vocab_size,\n",
"    levels=[0, 1, 2, 3, 4, 5],\n",
")"
]
},
{
"cell_type": "markdown",
"source": "## 9. Study Roadmap\n\nA systematic study path for LLM training optimization.",
"metadata": {}
},
{
"cell_type": "code",
"source": "LossDebugger.print_study_roadmap()",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## 10. Debugging Tips\n\n**Recommended order:**\n1. Level 0 for a health check → tells you which levels to inspect\n2. Level 1 (data) → check this first! ~70% of problems are found here\n3. Level 2 (numerical stability) → resolve NaN/Inf issues\n4. Level 3 (hyperparameters) → tune the LR first (10x the impact)\n5. Level 4 (fitting diagnosis) → check after training long enough\n6. Level 5 (architecture) → when the earlier levels find no cause\n\n**Scenario Detection** → identifies the scenario automatically from the symptoms\n\n### Key references\n\n| Reference | Content | Priority |\n|------|------|----------|\n| Karpathy \"Recipe for Training NNs\" | practical debugging mindset | ⭐⭐⭐ |\n| Hoffmann et al. 2022 (Chinchilla) | core of the scaling laws | ⭐⭐⭐ |\n| Kaplan et al. 2020 (Scaling Laws) | loss prediction formula | ⭐⭐⭐ |\n| Touvron et al. 2023 (LLaMA) | training details for 1B-class models | ⭐⭐ |\n| Biderman et al. 2023 (Pythia) | public training logs, reproducibility | ⭐⭐ |\n| Zhang et al. 2024 (TinyLlama) | 1.1B model trained on 3T tokens | ⭐⭐ |\n| Groeneveld et al. 2024 (OLMo) | fully open LLM framework | ⭐⭐ |\n| Li et al. 2018 (Loss Landscape) | intuition for the loss surface | ⭐⭐ |\n| Loshchilov & Hutter 2019 (AdamW) | optimizer fundamentals | ⭐⭐ |\n| Yang et al. 2022 (μP) | hyperparameter transfer | ⭐ |\n\n---\n**Previous step:** train the model in `03_training.ipynb`. \n**Next step:** evaluate the trained model in `04_evaluation.ipynb`."
}
]
}