Vjeong Claude Opus 4.6 committed on
Commit
5b7ea5e
·
1 Parent(s): c1a8df8

feat(training): add LossDebugger 5-level diagnostic framework


Systematic debugging tool for diagnosing LLM training loss issues:
- Level 0: auto-classify training health status (6 categories)
- Level 1: data/implementation bug checks (shift, token range, overfit test)
- Level 2: numerical stability (mixed precision config, gradient/activation
NaN/Inf detection, common issues reference table)
- Level 3: hyperparameter analysis (LR, β₂, weight decay, batch-LR scaling,
GPT-3 LR reference table, warmup reference, LR range test)
- Level 4: overfitting vs underfitting diagnosis (detailed sub-cause analysis,
Chinchilla token ratio, dropout note for pretraining)
- Level 5: architecture checks (per-layer activation stats, weight init
distribution analysis, ablation study reference)
- Scenario auto-detection (A/B/C/D) with step-by-step recommendations
- Study roadmap with recommended experiments and key references

Also adds 05_debugging.ipynb with mock scenarios A/B/C/D for all
diagnostic levels and educational content from the optimization guide.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
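
The Level 1 "initial loss" check described above compares a freshly initialized model's loss against ln(vocab_size): a random-weight LM predicts a near-uniform distribution over the vocabulary, so its cross-entropy starts near that value. A minimal standalone sketch of that sanity check (plain Python, independent of the `llm_lab` code in this commit; the tolerance of 1.0 nat mirrors the threshold used in the diff):

```python
import math

def expected_initial_loss(vocab_size: int) -> float:
    """Cross-entropy of a uniform prediction over vocab_size classes."""
    return math.log(vocab_size)

def initial_loss_ok(measured: float, vocab_size: int, tol: float = 1.0) -> bool:
    # A randomly initialized LM should start within ~tol nats of ln(V);
    # much higher hints at a label/shift bug, much lower at data leakage.
    return abs(measured - expected_initial_loss(vocab_size)) < tol

print(f"ln(32000) = {expected_initial_loss(32000):.2f}")  # ≈ 10.37
```
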

CLAUDE.md CHANGED
@@ -35,7 +35,8 @@ LLM_Foundation_Model/
 │ │ ├── metrics.py # MetricsTracker (wandb integration)
 │ │ ├── optimizer.py # create_optimizer (weight decay separation)
 │ │ ├── trainer.py # Trainer (gradient accumulation, mixed precision)
-│ │ └── runner.py # start_training (one-line helper)
+│ │ ├── runner.py # start_training (one-line helper)
+│ │ └── debugger.py # LossDebugger (5-level diagnostic framework)
 │ ├── evaluation/ # Evaluation & analysis
 │ │ ├── perplexity.py # PerplexityEvaluator (including per-position loss)
 │ │ ├── generation.py # GenerationEvaluator (various prompts)
@@ -52,7 +53,8 @@ LLM_Foundation_Model/
 │ ├── 01_data_pipeline.ipynb
 │ ├── 02_model.ipynb
 │ ├── 03_training.ipynb
-│ └── 04_evaluation.ipynb
+│ ├── 04_evaluation.ipynb
+│ └── 05_debugging.ipynb
 └── _archive/ # Original single-file backups
 ├── llm-1b-model.py
 ├── llm-1b-data-pipeline.py
llm_lab/training/__init__.py CHANGED
@@ -5,8 +5,9 @@ from .metrics import MetricsTracker
 from .optimizer import create_optimizer
 from .trainer import Trainer
 from .runner import start_training
+from .debugger import LossDebugger

 __all__ = [
     "CosineWarmupScheduler", "CheckpointManager", "MetricsTracker",
-    "create_optimizer", "Trainer", "start_training",
+    "create_optimizer", "Trainer", "start_training", "LossDebugger",
 ]
llm_lab/training/debugger.py ADDED
@@ -0,0 +1,1442 @@
+"""LLM Loss Debugging & Optimization Framework.
+
+A systematic 5-level debugging framework for diagnosing training issues.
+Always start from Level 1 — fixing lower-level bugs before tuning
+hyperparameters saves time.
+
+Levels:
+    0. Status Diagnosis — classify current training health
+    1. Data/Implementation — most common cause (70% of issues)
+    2. Numerical Stability — dtype, normalization, gradient health
+    3. Hyperparameters — LR, batch size, warmup
+    4. Fitting Diagnosis — overfitting vs underfitting
+    5. Architecture — initialization, component checks
+"""
+
+import copy
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from llm_lab.config import TrainConfig
+
+
+# ══════════════════════════════════════════════════════════════════
+# Constants
+# ══════════════════════════════════════════════════════════════════
+
+# Normal convergence ranges for a 1B model trained on ~10B tokens
+_EXPECTED_TRAIN_LOSS = (2.8, 3.5)
+_EXPECTED_VAL_LOSS = (3.0, 3.8)
+_EXPECTED_VAL_PPL = (20, 45)
+
+# Status labels
+STATUS_NORMAL = "NORMAL"
+STATUS_NO_DECREASE = "NO_DECREASE"
+STATUS_DIVERGING = "DIVERGING"
+STATUS_PLATEAU = "PLATEAU"
+STATUS_OVERFITTING = "OVERFITTING"
+STATUS_UNSTABLE = "UNSTABLE"
+
+# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
+# (param_count, recommended_lr, batch_tokens_str)
+_GPT3_LR_REFERENCE = [
+    (125e6, 6e-4, "0.5M"),
+    (350e6, 3e-4, "0.5M"),
+    (1.3e9, 2e-4, "1M"),
+    (2.7e9, 1.6e-4, "1M"),
+    (6.7e9, 1.2e-4, "2M"),
+    (175e9, 6e-5, "3.2M"),
+]
+
+# Known LLM training references
+_LLM_TRAINING_REFS = {
+    "TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
+    "LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
+    "Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1},
+    "OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1},
+}
+
+# Recommended β₂ for LLM training
+_RECOMMENDED_BETA2 = 0.95
+_DEFAULT_PYTORCH_BETA2 = 0.999
+
+
+def _header(title: str) -> str:
+    return f"\n{'=' * 60}\n{title}\n{'=' * 60}"
+
+
+def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]:
+    return {"name": name, "passed": passed, "detail": detail}
+
+
+# ══════════════════════════════════════════════════════════════════
+# LossDebugger
+# ══════════════════════════════════════════════════════════════════
+
+
+class LossDebugger:
+    """5-level loss debugging framework for LLM training.
+
+    Usage::
+
+        from llm_lab.training.debugger import LossDebugger
+
+        # Quick status check
+        status = LossDebugger.diagnose_status(vocab_size=32000,
+            metrics_history=trainer.metrics.history)
+
+        # Full diagnostics
+        report = LossDebugger.run_diagnostics(
+            model=model, dataloader=train_dl, tokenizer=tok,
+            train_config=train_cfg, metrics_history=trainer.metrics.history,
+            device=device, dtype=torch.bfloat16,
+        )
+    """
+
+    # ──────────────────────────────────────────────────────────────
+    # Level 0: Status Diagnosis
+    # ──────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def diagnose_status(
+        vocab_size: int,
+        metrics_history: Dict[str, list],
+    ) -> Dict[str, Any]:
+        """Classify current training health from metrics history.
+
+        Args:
+            vocab_size: model vocabulary size (e.g. 32000)
+            metrics_history: dict with keys 'train_loss', 'val_loss', etc.
+
+        Returns:
+            dict with 'status', 'severity', 'details', 'recommended_levels'
+        """
+        print(_header("Level 0: Training Status Diagnosis"))
+
+        expected_initial = math.log(vocab_size)
+        print(f" Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}")
+        print(" Normal convergence range (1B, 10B tokens):")
+        print(f" Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}")
+        print(f" Val Loss: {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}")
+        print(f" Val PPL: {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}")
+
+        train_losses = metrics_history.get("train_loss", [])
+        val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
+
+        if len(train_losses) < 2:
+            print("\n [!] Not enough training data to diagnose. Run more steps first.")
+            return {
+                "status": "INSUFFICIENT_DATA",
+                "severity": "unknown",
+                "details": "Need at least 2 logged train loss values.",
+                "recommended_levels": [1],
+            }
+
+        first_loss = train_losses[0]
+        last_loss = train_losses[-1]
+        loss_change = first_loss - last_loss
+
+        # Split into halves for trend analysis
+        mid = len(train_losses) // 2
+        first_half_avg = sum(train_losses[:mid]) / mid
+        second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid)
+
+        # Recent window for spike detection
+        recent_n = min(50, len(train_losses))
+        recent = train_losses[-recent_n:]
+        recent_mean = sum(recent) / len(recent)
+        recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent)
+        recent_std = recent_var ** 0.5
+
+        # Val trend
+        val_trend = "unknown"
+        if len(val_losses) >= 2:
+            val_mid = len(val_losses) // 2
+            val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
+            val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
+            if val_second_avg < val_first_avg - 0.05:
+                val_trend = "decreasing"
+            elif val_second_avg > val_first_avg + 0.1:
+                val_trend = "increasing"
+            else:
+                val_trend = "flat"
+
+        # ── Classify ──
+        status = STATUS_NORMAL
+        severity = "green"
+        details = ""
+        recommended_levels: List[int] = []
+
+        # Check 1: No decrease at all
+        if loss_change < 0.1 and first_loss > expected_initial - 2.0:
+            status = STATUS_NO_DECREASE
+            severity = "red"
+            details = (
+                f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} "
+                f"(delta={loss_change:.4f}). Likely a data or implementation bug."
+            )
+            recommended_levels = [1, 2]
+
+        # Check 2: Diverging
+        elif last_loss > expected_initial + 1.0:
+            status = STATUS_DIVERGING
+            severity = "red"
+            details = (
+                f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). "
+                f"Training is diverging — check LR, data, or numerical issues."
+            )
+            recommended_levels = [1, 2, 3]
+
+        # Check 3: Unstable (large spikes)
+        elif recent_std > 0.5 * recent_mean:
+            status = STATUS_UNSTABLE
+            severity = "yellow"
+            details = (
+                f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. "
+                f"Training is unstable — likely LR too high or batch too small."
+            )
+            recommended_levels = [3, 2]
+
+        # Check 4: Overfitting
+        elif val_trend == "increasing" and second_half_avg < first_half_avg:
+            status = STATUS_OVERFITTING
+            severity = "yellow"
+            details = (
+                f"Train loss decreasing but val loss increasing. "
+                f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, "
+                f"Val trend: {val_trend}."
+            )
+            recommended_levels = [4]
+
+        # Check 5: Plateau
+        elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
+            status = STATUS_PLATEAU
+            severity = "yellow"
+            details = (
+                f"Loss has plateaued: first half avg={first_half_avg:.4f}, "
+                f"second half avg={second_half_avg:.4f}. "
+                f"Current loss ({last_loss:.4f}) is above expected range."
+            )
+            recommended_levels = [3, 4, 5]
+
+        # Normal
+        else:
+            status = STATUS_NORMAL
+            severity = "green"
+            details = (
+                f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} "
+                f"(delta={loss_change:.4f}). Val trend: {val_trend}."
+            )
+            recommended_levels = []
+
+        # ── Print ──
+        icons = {"red": "🔴", "yellow": "🟡", "green": "🟢"}
+        icon = icons.get(severity, "⚪")
+        print(f"\n {icon} Status: {status}")
+        print(f" {details}")
+        if recommended_levels:
+            print(f" Recommended: check Level(s) {recommended_levels}")
+
+        return {
+            "status": status,
+            "severity": severity,
+            "details": details,
+            "recommended_levels": recommended_levels,
+        }
+
+    # ──────────────────────────────────────────────────────────────
+    # Level 1: Data / Implementation Bug Checks
+    # ──────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def check_data_pipeline(
+        model: nn.Module,
+        dataloader: DataLoader,
+        tokenizer: Any,
+        vocab_size: int,
+        device: torch.device,
+        dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        """Run 6 data/implementation checks (Level 1).
+
+        This is the most important level — 70% of loss issues are data bugs.
+
+        Checks:
+            1. Shift relationship (targets[t] == input_ids[t+1])
+            2. Token range (0 <= ids < vocab_size)
+            3. Initial loss (≈ ln(vocab_size) for random weights)
+            4. Single-batch overfit (loss → ~0 in 200 steps)
+            5. Tokenizer roundtrip (encode→decode preserves text)
+            6. Data quality sampling (visual inspection)
+        """
+        print(_header("Level 1: Data / Implementation Bug Checks"))
+        print(" (70% of loss issues come from data pipeline bugs)\n")
+
+        results: List[Dict[str, Any]] = []
+        batch = next(iter(dataloader))
+        input_ids = batch["input_ids"]
+        targets = batch["targets"]
+
+        # ── Check 1: Shift relationship ──
+        shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
+        passed = shift_match > 0.99
+        detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)"
+        results.append(_check_result("Shift relationship", passed, detail))
+        icon = "✅" if passed else "❌"
+        print(f" {icon} Check 1: {detail}")
+
+        # ── Check 2: Token range ──
+        min_id = input_ids.min().item()
+        max_id = input_ids.max().item()
+        range_ok = min_id >= 0 and max_id < vocab_size
+        detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}"
+        results.append(_check_result("Token range", range_ok, detail))
+        icon = "✅" if range_ok else "❌"
+        print(f" {icon} Check 2: {detail}")
+
+        # ── Check 3: Initial loss ──
+        expected_loss = math.log(vocab_size)
+        model_copy = copy.deepcopy(model)
+        model_copy._init_weights()  # re-initialize to random
+        model_copy.to(device)
+        model_copy.eval()
+        with torch.no_grad():
+            with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
+                _, initial_loss = model_copy(
+                    input_ids.to(device),
+                    targets.to(device),
+                )
+        initial_loss_val = initial_loss.item()
+        loss_diff = abs(initial_loss_val - expected_loss)
+        loss_ok = loss_diff < 1.0
+        detail = (
+            f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} "
+            f"(diff={loss_diff:.4f})"
+        )
+        results.append(_check_result("Initial loss", loss_ok, detail))
+        icon = "✅" if loss_ok else "❌"
+        print(f" {icon} Check 3: {detail}")
+        if initial_loss_val > expected_loss + 1.0:
+            print(" Hint: loss >> ln(V) suggests label mismatch or loss function bug")
+        elif initial_loss_val < expected_loss - 2.0:
+            print(" Hint: loss << ln(V) suggests data leakage")
+        del model_copy
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # ── Check 4: Single-batch overfit test ──
+        print("\n ⏳ Check 4: Single-batch overfit test (200 steps)...")
+        overfit_model = copy.deepcopy(model)
+        overfit_model.to(device)
+        overfit_model.train()
+        overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=1e-3)
+        single_input = input_ids[:1].to(device)  # single sample
+        single_target = targets[:1].to(device)
+
+        overfit_losses = []
+        for step in range(200):
+            overfit_optimizer.zero_grad()
+            with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
+                _, loss = overfit_model(single_input, single_target)
+            loss.backward()
+            overfit_optimizer.step()
+            overfit_losses.append(loss.item())
+            if (step + 1) % 50 == 0:
+                print(f" Step {step + 1}: Loss = {loss.item():.4f}")
+
+        final_overfit_loss = overfit_losses[-1]
+        overfit_ok = final_overfit_loss < 0.1
+        detail = (
+            f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
+            f"(target < 0.1)"
+        )
+        results.append(_check_result("Single-batch overfit", overfit_ok, detail))
+        icon = "✅" if overfit_ok else "❌"
+        print(f" {icon} Check 4: {detail}")
+        if not overfit_ok:
+            print(" CRITICAL: Model cannot memorize a single batch!")
+            print(" This means the model or loss function has a bug.")
+        del overfit_model, overfit_optimizer
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # ── Check 5: Tokenizer roundtrip ──
+        test_text = "The quick brown fox jumps over the lazy dog."
+        encoded = tokenizer.encode(test_text)
+        decoded = tokenizer.decode(encoded)
+        roundtrip_ok = test_text.strip() in decoded.strip()
+        detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'"
+        results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail))
+        icon = "✅" if roundtrip_ok else "❌"
+        print(f" {icon} Check 5: {detail}")
+
+        # ── Check 6: Data quality sampling ──
+        print("\n 📋 Check 6: Data quality sampling (visual inspection)")
+        for i in range(min(3, input_ids.shape[0])):
+            sample_tokens = input_ids[i][:100].tolist()
+            decoded_text = tokenizer.decode(sample_tokens)
+            preview = decoded_text[:200].replace("\n", "\\n")
+            print(f" Sample {i}: {preview}...")
+
+        passed_count = sum(1 for r in results if r["passed"])
+        total_count = len(results)
+        print(f"\n Result: {passed_count}/{total_count} checks passed")
+
+        return {
+            "level": 1,
+            "checks": results,
+            "passed": [r for r in results if r["passed"]],
+            "failed": [r for r in results if not r["passed"]],
+        }
+
+    # ──────────────────────────────────────────────────────────────
+    # Level 2: Numerical Stability
+    # ──────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def check_numerical_stability(
+        model: nn.Module,
+        dataloader: DataLoader,
+        device: torch.device,
+        dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        """Check for NaN/Inf in gradients, activations, and logits (Level 2).
+
+        Checks:
+            - Mixed precision config (RMSNorm fp32 upcast, loss dtype)
+            - NaN/Inf gradients → softmax overflow, bad data
+            - Inf gradients → log(0) in loss, missing ignore_index
+            - Large activations growing per layer → initialization or norm bug
+            - Logit scale → should be < 1000
+        """
+        print(_header("Level 2: Numerical Stability Checks"))
+
+        batch = next(iter(dataloader))
+        input_ids = batch["input_ids"].to(device)
+        targets = batch["targets"].to(device)
+
+        results: List[Dict[str, Any]] = []
+        activation_stats: List[Dict[str, Any]] = []
+
+        # ── Mixed Precision Configuration Check ──
+        print("\n Mixed Precision Config:")
+        print(f" Training dtype: {dtype}")
+
+        # Check RMSNorm fp32 upcast
+        norm_fp32_ok = True
+        for name, module in model.named_modules():
+            cls_name = module.__class__.__name__
+            if "Norm" in cls_name:
+                # Inspect forward source for .float() call
+                import inspect
+                try:
+                    src = inspect.getsource(module.forward)
+                    has_upcast = ".float()" in src or "float32" in src
+                except (TypeError, OSError):
+                    has_upcast = True  # assume ok if can't inspect
+                if not has_upcast:
+                    norm_fp32_ok = False
+                    print(f" 🔴 {name} ({cls_name}): no fp32 upcast detected!")
+                break  # checking one norm layer is enough
+        if norm_fp32_ok:
+            print(" ✅ Norm layers use fp32 upcast (safe)")
+
+        results.append(_check_result(
+            "Norm fp32 upcast", norm_fp32_ok,
+            "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype",
+        ))
+
+        # Check loss computation dtype
+        if dtype in (torch.bfloat16, torch.float16):
+            print(f" ℹ️ Best practice: compute loss in fp32 when using {dtype}")
+            print(" logits_fp32 = logits.float()")
+            print(" loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))")
+
+        # Common numerical issues reference
+        print("\n Common Numerical Issues Reference:")
+        print(" ┌────────────────────┬──────────────────────────┬─────────────────────────┐")
+        print(" │ Symptom            │ Likely Cause             │ Solution                │")
+        print(" ├────────────────────┼──────────────────────────┼─────────────────────────┤")
+        print(" │ Loss → NaN         │ Large logits → softmax   │ Check init, logit scale │")
+        print(" │ Loss → Inf         │ log(0) in CE loss        │ Add eps, ignore_index   │")
+        print(" │ Loss oscillation   │ fp16 gradient underflow  │ Switch to bf16 / scaler │")
+        print(" │ Late-training NaN  │ Activation growth        │ Check RMSNorm, wd       │")
+        print(" └────────────────────┴──────────────────────────┴─────────────────────────┘")
+
+        # ── Activation monitoring via hooks ──
+        hooks = []
+
+        def make_hook(name: str):
+            def hook_fn(module, input, output):
+                if isinstance(output, torch.Tensor):
+                    out_f = output.float()
+                    stats = {
+                        "name": name,
+                        "mean": out_f.mean().item(),
+                        "std": out_f.std().item(),
+                        "max": out_f.abs().max().item(),
+                        "has_nan": bool(torch.isnan(output).any()),
+                        "has_inf": bool(torch.isinf(output).any()),
+                    }
+                    activation_stats.append(stats)
+            return hook_fn
+
+        # Register hooks on transformer layers
+        for i, layer in enumerate(model.layers):
+            h = layer.register_forward_hook(make_hook(f"layer_{i}"))
+            hooks.append(h)
+
+        # ── Forward + Backward ──
+        model.train()
+        model.zero_grad(set_to_none=True)
+        with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
+            logits, loss = model(input_ids, targets)
+
+        loss_val = loss.item()
+        loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val))
+        results.append(_check_result(
+            "Loss value",
+            loss_ok,
+            f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)",
+        ))
+
+        loss.backward()
+
+        # Remove hooks
+        for h in hooks:
+            h.remove()
+
+        # ── Gradient checks ──
+        print("\n Gradient Health:")
+        grad_issues = []
+        for name, param in model.named_parameters():
+            if param.grad is None:
+                continue
+            grad = param.grad
+            if torch.isnan(grad).any():
+                grad_issues.append(f"🔴 NaN gradient: {name}")
+            if torch.isinf(grad).any():
+                grad_issues.append(f"🔴 Inf gradient: {name}")
+            if grad.abs().max().item() > 100:
+                grad_issues.append(
+                    f"🟡 Large gradient: {name} max={grad.abs().max().item():.1f}"
+                )
+
+        grad_ok = len(grad_issues) == 0
+        if grad_ok:
+            print(" ✅ All gradients are healthy (no NaN/Inf/large values)")
+        else:
+            for issue in grad_issues[:10]:  # limit output
+                print(f" {issue}")
+            if len(grad_issues) > 10:
+                print(f" ... and {len(grad_issues) - 10} more issues")
+
+        results.append(_check_result(
+            "Gradient health",
+            grad_ok,
+            f"{len(grad_issues)} issues found" if not grad_ok else "All healthy",
+        ))
+
+        # ── Activation checks ──
+        print("\n Activation Stats (per transformer layer):")
+        act_nan_count = 0
+        for stats in activation_stats:
+            icon = "🔴" if stats["has_nan"] or stats["has_inf"] else "  "
+            if stats["has_nan"] or stats["has_inf"]:
+                act_nan_count += 1
+            print(
+                f" {icon} {stats['name']}: "
+                f"mean={stats['mean']:.4f}, "
+                f"std={stats['std']:.4f}, "
+                f"max={stats['max']:.4f}"
+                + (" [NaN!]" if stats["has_nan"] else "")
+                + (" [Inf!]" if stats["has_inf"] else "")
+            )
+
+        act_ok = act_nan_count == 0
+        results.append(_check_result(
+            "Activation health",
+            act_ok,
+            f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
+        ))
+
+        # ── Logit scale check ──
+        logit_max = logits.float().abs().max().item()
+        logit_ok = logit_max < 1000
+        detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)"
+        results.append(_check_result("Logit scale", logit_ok, detail))
+        icon = "✅" if logit_ok else "🔴"
+        print(f"\n {icon} Logit scale: {detail}")
+
+        model.zero_grad(set_to_none=True)
+
+        passed_count = sum(1 for r in results if r["passed"])
+        print(f"\n Result: {passed_count}/{len(results)} checks passed")
+
+        return {
+            "level": 2,
+            "checks": results,
+            "activation_stats": activation_stats,
+            "grad_issues": grad_issues,
+        }
+
+    # ──────────────────────────────────────────────────────────────
+    # Level 3: Hyperparameter Diagnosis
+    # ──────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def diagnose_hyperparameters(
+        metrics_history: Dict[str, list],
+        config: TrainConfig,
+    ) -> Dict[str, Any]:
+        """Analyze hyperparameter health from training metrics (Level 3).
+
+        Checks:
+            - LR: too high (grad_norm hitting clip limit) or too low (grad_norm tiny)
+            - Batch size: loss variance indicates batch too small
+            - Warmup: spikes in early steps indicate warmup too short
+        """
+        print(_header("Level 3: Hyperparameter Diagnosis"))
+
+        findings: List[Dict[str, str]] = []
+        grad_norms = metrics_history.get("grad_norm", [])
+        train_losses = metrics_history.get("train_loss", [])
+
+        # ── LR diagnosis ──
+        print("\n Learning Rate Analysis:")
+        print(f" Peak LR: {config.learning_rate:.2e}")
+        print(f" Min LR: {config.min_learning_rate:.2e}")
+
+        if grad_norms:
+            avg_grad = sum(grad_norms) / len(grad_norms)
+            clip_count = sum(1 for g in grad_norms if g >= config.grad_clip * 0.95)
+            clip_rate = clip_count / len(grad_norms)
+            tiny_count = sum(1 for g in grad_norms if g < 0.01)
+            tiny_rate = tiny_count / len(grad_norms)
+
+            print(f" Avg grad norm: {avg_grad:.4f}")
+            print(f" Clip rate: {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})")
+            print(f" Tiny grad rate: {tiny_rate * 100:.1f}% (< 0.01)")
+
+            if clip_rate > 0.3:
+                findings.append({
+                    "issue": "LR may be too high",
+                    "evidence": f"Grad norm hits clip limit {clip_rate * 100:.0f}% of the time",
+                    "action": f"Try LR = {config.learning_rate / 2:.2e} (÷2)",
+                })
+                print(f" 🟡 Grad clipping frequent ({clip_rate * 100:.0f}%) → LR may be too high")
+            elif tiny_rate > 0.5:
+                findings.append({
+                    "issue": "LR may be too low",
+                    "evidence": f"Grad norm < 0.01 in {tiny_rate * 100:.0f}% of steps",
+                    "action": f"Try LR = {config.learning_rate * 2:.2e} (×2)",
+                })
+                print(f" 🟡 Grad norm too small ({tiny_rate * 100:.0f}% < 0.01) → LR may be too low")
+            else:
+                print(" ✅ LR looks appropriate")
+
+        # ── Batch size diagnosis ──
+        print("\n Batch Size Analysis:")
+        print(f" Effective batch: {config.effective_batch_size}")
+
+        if len(train_losses) >= 20:
+            recent_losses = train_losses[-20:]
+            loss_mean = sum(recent_losses) / len(recent_losses)
+            loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
+            loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)
+
+            print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation)")
+
+            if loss_cv > 0.1:
+                findings.append({
+                    "issue": "Batch size may be too small",
+                    "evidence": f"Loss CV = {loss_cv:.4f} (high variance)",
+                    "action": "Increase gradient_accumulation_steps",
+                })
+                print(" 🟡 High loss variance → batch may be too small")
+            else:
+                print(" ✅ Loss variance is acceptable")
+
+        # ── β₂ diagnosis ──
+        print("\n β₂ (Adam second momentum) Analysis:")
+        print(f" Current β₂: {config.beta2}")
+        if config.beta2 >= _DEFAULT_PYTORCH_BETA2:
673
+ findings.append({
674
+ "issue": "ฮฒโ‚‚ may be too high for LLM training",
675
+ "evidence": (
676
+ f"ฮฒโ‚‚={config.beta2} (PyTorch default). "
677
+ f"LLM standard is {_RECOMMENDED_BETA2}"
678
+ ),
679
+ "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)",
680
+ })
681
+ print(f" ๐ŸŸก ฮฒโ‚‚={config.beta2} is PyTorch default โ†’ "
682
+ f"LLM training standard is {_RECOMMENDED_BETA2}")
683
+ print(f" Why: ฮฒโ‚‚=0.999 averages ~1000 steps of gradient stats,")
684
+ print(f" ฮฒโ‚‚=0.95 averages ~20 steps โ†’ faster adaptation to changing data")
685
+ print(f" (Cattaneo & Shigida 2025, Compagnoni et al. 2025)")
686
+ else:
687
+ print(f" โœ… ฮฒโ‚‚={config.beta2} is within LLM standard range")
688
+
689
+ # โ”€โ”€ Weight Decay diagnosis โ”€โ”€
690
+ print("\n Weight Decay Analysis:")
691
+ print(f" Current weight_decay: {config.weight_decay}")
692
+ if config.weight_decay == 0:
693
+ findings.append({
694
+ "issue": "Weight decay is disabled",
695
+ "evidence": "weight_decay=0 increases overfitting risk",
696
+ "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)",
697
+ })
698
+ print(f" ๐ŸŸก weight_decay=0 โ†’ overfitting risk. Standard is 0.1")
699
+ elif config.weight_decay > 0.3:
700
+ findings.append({
701
+ "issue": "Weight decay may be too high",
702
+ "evidence": f"weight_decay={config.weight_decay} (unusually high)",
703
+ "action": "Try weight_decay=0.1 (standard value)",
704
+ })
705
+ print(f" ๐ŸŸก weight_decay={config.weight_decay} is unusually high (standard: 0.1)")
706
+ else:
707
+ print(f" โœ… weight_decay={config.weight_decay} is within normal range")
708
+
709
+ # โ”€โ”€ Model-size LR reference โ”€โ”€
710
+ print("\n GPT-3 LR Reference (Brown et al. 2020):")
711
+ print(" โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”")
712
+ print(" โ”‚ Model โ”‚ Peak LR โ”‚ Batch Tokens โ”‚")
713
+ print(" โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค")
714
+ for params, lr, batch_tok in _GPT3_LR_REFERENCE:
715
+ label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M"
716
+ marker = " โ†" if abs(params - 1.1e9) < 0.5e9 else ""
717
+ print(f" โ”‚ {label:<8} โ”‚ {lr:.1e} โ”‚ {batch_tok:<12} โ”‚{marker}")
718
+ print(" โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜")
719
+ print(" โ†’ Larger models need lower LR and larger batch")
720
+
721
+ # โ”€โ”€ Batch-LR scaling guidance โ”€โ”€
722
+ print("\n Batch-LR Scaling Rules:")
723
+ print(" โ€ข Batch ร—2 โ†’ LR ร—โˆš2 (square root scaling, safer)")
724
+ print(" โ€ข Batch ร—2 โ†’ LR ร—2 (linear scaling, used by GPT-3)")
725
+ print(" โ€ข 1B model: effective batch 64~512 is typical range")
726
+
727
+ # โ”€โ”€ Warmup diagnosis โ”€โ”€
728
+ print("\n Warmup Analysis:")
729
+ print(f" Warmup steps: {config.warmup_steps} "
730
+ f"({config.warmup_steps / config.total_steps * 100:.1f}% of total)")
731
+
732
+ if len(train_losses) >= 10:
733
+ early_losses = train_losses[:min(50, len(train_losses))]
734
+ # Detect spikes in early training
735
+ spike_count = 0
736
+ for i in range(1, len(early_losses)):
737
+ if early_losses[i] > early_losses[i - 1] * 1.5:
738
+ spike_count += 1
739
+
740
+ if spike_count > 3:
741
+ findings.append({
742
+ "issue": "Warmup may be too short",
743
+ "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps",
744
+ "action": f"Try warmup_steps = {config.warmup_steps * 2}",
745
+ })
746
+ print(f" ๐ŸŸก {spike_count} spikes in early training โ†’ warmup may be too short")
747
+ else:
748
+ print(f" โœ… Early training is stable")
749
+
750
+ # โ”€โ”€ Summary โ”€โ”€
751
+ if not findings:
752
+ print("\n โœ… No hyperparameter issues detected")
753
+ else:
754
+ print(f"\n Found {len(findings)} potential issue(s):")
755
+ for f in findings:
756
+ print(f" โ€ข {f['issue']}: {f['action']}")
757
+
758
+ # โ”€โ”€ Warmup reference from real projects โ”€โ”€
759
+ print("\n Warmup Reference (real projects):")
760
+ print(" โ€ข TinyLlama 1.1B (3T tokens): 2,000 steps โ‰ˆ 0.1% of total")
761
+ print(" โ€ข GPT-3 175B: 375 steps โ‰ˆ 0.2% of total")
762
+ print(" โ€ข General range: 0.1% ~ 5% of total steps")
763
+ print(" โ€ข Smaller experiments: 5~10% is also reasonable")
764
+
765
+ print("\n Tuning priority (high โ†’ low):")
766
+ print(" 1. Learning Rate โ† tune first (10x impact)")
767
+ print(" 2. Batch Size โ† adjust with LR")
768
+ print(" 3. Warmup Steps โ† early stability")
769
+ print(" 4. Weight Decay โ† if overfitting (typically 0.1)")
770
+ print(" 5. ฮฒโ‚, ฮฒโ‚‚ (Adam) โ† see ฮฒโ‚‚ analysis above")
771
+ print(" 6. Gradient Clip โ† usually keep at 1.0")
772
+
773
+ return {
774
+ "level": 3,
775
+ "findings": findings,
776
+ "config_summary": {
777
+ "learning_rate": config.learning_rate,
778
+ "effective_batch": config.effective_batch_size,
779
+ "warmup_steps": config.warmup_steps,
780
+ "total_steps": config.total_steps,
781
+ "grad_clip": config.grad_clip,
782
+ },
783
+ }
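The batch-LR scaling rules printed above can be sketched as a small standalone helper; `scale_lr` and the base values below are hypothetical, purely to illustrate square-root vs linear scaling, and are not part of `LossDebugger`.

```python
import math

def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "sqrt") -> float:
    """Scale a learning rate when the effective batch size changes.

    rule="sqrt":   LR grows with the square root of the batch ratio (safer).
    rule="linear": LR grows linearly with the batch ratio (GPT-3 style).
    """
    ratio = new_batch / base_batch
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    if rule == "linear":
        return base_lr * ratio
    raise ValueError(f"unknown rule: {rule}")

# Doubling the batch from 128 to 256:
print(scale_lr(3e-4, 128, 256, "sqrt"))    # ~4.24e-4
print(scale_lr(3e-4, 128, 256, "linear"))  # 6e-4
```

Square-root scaling is the conservative default here since linear scaling can destabilize smaller runs.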
784
+
+ @staticmethod
+ def lr_range_test(
+ model: nn.Module,
+ dataloader: DataLoader,
+ device: torch.device,
+ dtype: torch.dtype = torch.bfloat16,
+ lr_start: float = 1e-7,
+ lr_end: float = 1e-1,
+ steps: int = 300,
+ ) -> Dict[str, Any]:
+ """Run an LR range test to find the optimal learning rate (Level 3 bonus).
+
+ Sweeps LR from lr_start to lr_end exponentially, recording loss.
+ The optimal LR is where loss decreases fastest (steepest slope),
+ divided by 3~10 for stability.
+
+ WARNING: This modifies a copy of the model. The original is untouched.
+ """
+ print(_header("Level 3 Bonus: LR Range Test"))
+ print(f" Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n")
+
+ test_model = copy.deepcopy(model)
+ test_model.to(device)
+ test_model.train()
+ optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start)
+
+ lr_mult = (lr_end / lr_start) ** (1 / steps)
+ lr = lr_start
+
+ lrs: List[float] = []
+ losses: List[float] = []
+ data_iter = iter(dataloader)
+
+ for step in range(steps):
+ for pg in optimizer.param_groups:
+ pg["lr"] = lr
+
+ try:
+ batch = next(data_iter)
+ except StopIteration:
+ data_iter = iter(dataloader)
+ batch = next(data_iter)
+
+ input_ids = batch["input_ids"].to(device)
+ targets_t = batch["targets"].to(device)
+
+ optimizer.zero_grad()
+ with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(dtype != torch.float32)):
+ _, loss = test_model(input_ids, targets_t)
+ loss.backward()
+ optimizer.step()
+
+ loss_val = loss.item()
+ lrs.append(lr)
+ losses.append(loss_val)
+
+ if (step + 1) % 50 == 0:
+ print(f" Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}")
+
+ # Stop if loss explodes
+ if len(losses) > 1 and loss_val > losses[0] * 4:
+ print(f" Loss exploded at LR = {lr:.2e}, stopping.")
+ break
+
+ lr *= lr_mult
+
+ del test_model, optimizer
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Find steepest descent
+ best_lr = lr_start
+ if len(losses) > 10:
+ # Smooth losses and find steepest negative slope
+ window = 5
+ smoothed = []
+ for i in range(len(losses) - window):
+ smoothed.append(sum(losses[i:i + window]) / window)
+
+ min_slope = 0
+ min_idx = 0
+ for i in range(1, len(smoothed)):
+ slope = smoothed[i] - smoothed[i - 1]
+ if slope < min_slope:
+ min_slope = slope
+ min_idx = i
+
+ best_lr = lrs[min_idx]
+ suggested_lr = best_lr / 3 # conservative choice
+
+ print(f"\n Steepest descent at LR = {best_lr:.2e}")
+ print(f" Suggested peak LR: {suggested_lr:.2e} (÷3 for stability)")
+ print(f" Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]")
+ else:
+ suggested_lr = 3e-4
+ print(f"\n Not enough data points. Using default LR = {suggested_lr:.2e}")
+
+ return {
+ "lrs": lrs,
+ "losses": losses,
+ "best_lr": best_lr,
+ "suggested_lr": suggested_lr,
+ }
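The post-processing step of the range test (moving-average smoothing, then picking the steepest negative slope) can be exercised in isolation on synthetic data. `pick_lr` below is an illustrative extraction of that logic, not a method of the class:

```python
def pick_lr(lrs, losses, window=5):
    """Return the LR at the steepest decrease of the smoothed loss curve."""
    # Moving-average smoothing, same as in lr_range_test
    smoothed = [sum(losses[i:i + window]) / window
                for i in range(len(losses) - window)]
    best_i, best_slope = 0, 0.0
    for i in range(1, len(smoothed)):
        slope = smoothed[i] - smoothed[i - 1]
        if slope < best_slope:
            best_slope, best_i = slope, i
    return lrs[best_i]

# Synthetic sweep: loss flat at first, drops fastest in the middle, then explodes.
lrs = [1e-5 * (10 ** (i / 10)) for i in range(40)]
losses = ([10.0] * 15
          + [10.0 - 0.5 * i for i in range(1, 11)]
          + [5.0 + 2.0 * i for i in range(15)])
print(pick_lr(lrs, losses))  # ~3.2e-4, where the smoothed loss starts falling fastest
```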
888
+
+ # ───────────────────────────────────────────────────────────────
+ # Level 4: Overfitting vs Underfitting Diagnosis
+ # ───────────────────────────────────────────────────────────────
+
+ @staticmethod
+ def diagnose_fitting(
+ metrics_history: Dict[str, list],
+ model_params: Optional[int] = None,
+ total_tokens: Optional[int] = None,
+ ) -> Dict[str, Any]:
+ """Diagnose overfitting vs underfitting from metrics (Level 4).
+
+ Cases:
+ 1. Both high, decreasing → Normal (still training)
+ 2. Both high, plateau → Underfitting
+ 3. Train↓ Val→ or Val↑ → Overfitting
+ 4. Both low, plateau → Converged (or at limit)
+ """
+ print(_header("Level 4: Overfitting vs Underfitting Diagnosis"))
+
+ train_losses = metrics_history.get("train_loss", [])
+ val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
+
+ if len(train_losses) < 10 or len(val_losses) < 2:
+ print(" [!] Not enough data. Need more training steps with eval.")
+ return {"level": 4, "case": "insufficient_data", "recommendations": []}
+
+ # Recent train trend
+ recent_n = min(50, len(train_losses))
+ train_recent = train_losses[-recent_n:]
+ train_mid = len(train_recent) // 2
+ train_first = sum(train_recent[:train_mid]) / max(train_mid, 1)
+ train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1)
+ train_decreasing = train_second < train_first - 0.02
+
+ # Val trend
+ val_mid = len(val_losses) // 2
+ val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
+ val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
+ val_decreasing = val_second < val_first - 0.02
+ val_increasing = val_second > val_first + 0.05
+
+ # Train-Val gap
+ last_train = train_losses[-1]
+ last_val = val_losses[-1]
+ gap = last_train - last_val # negative means val > train (typical)
+
+ print(f" Train loss (recent): {train_first:.4f} → {train_second:.4f} "
+ f"({'↓' if train_decreasing else '→'})")
+ print(f" Val loss: {val_first:.4f} → {val_second:.4f} "
+ f"({'↓' if val_decreasing else '↑' if val_increasing else '→'})")
+ print(f" Train-Val gap: {abs(gap):.4f}")
+
+ # ── Classify ──
+ case = ""
+ recommendations: List[str] = []
+
+ if train_decreasing and val_decreasing:
+ case = "Case 1: Normal — both decreasing"
+ recommendations.append("Training is progressing normally. Continue.")
+ if model_params and total_tokens:
+ ratio = total_tokens / model_params
+ chinchilla = 20 # Chinchilla optimal: 20 tokens per param
+ if ratio < chinchilla:
+ recommendations.append(
+ f"Token/param ratio = {ratio:.1f}x "
+ f"(Chinchilla optimal ≈ {chinchilla}x). "
+ f"Model may benefit from more data."
+ )
+ print(f"\n 🟢 {case}")
+
+ elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]:
+ case = "Case 2: Underfitting — both plateaued at high loss"
+ recommendations = [
+ "Diagnosis priority (check in order):",
+ "1) Training insufficient? → check if loss curve still has downward slope",
+ " - Chinchilla: 1B model needs ~20B tokens minimum",
+ " - TinyLlama trains 1.1B on 3T tokens (inference-optimal)",
+ "2) LR too low? → try LR ×2, see if loss drops faster",
+ "3) Model capacity too small? → train 2x larger model on same data",
+ " - If larger model gets lower loss → capacity was the limit",
+ "4) Data quality? → sample and read training data manually",
+ " - Noisy/low-quality data raises the achievable loss floor",
+ ]
+ if model_params and total_tokens:
+ ratio = total_tokens / model_params
+ if ratio < 10:
+ recommendations.insert(0,
+ f"⚠ Token/param ratio = {ratio:.1f}x — "
+ f"very likely undertrained. Chinchilla recommends ≥20x."
+ )
+ elif ratio < 20:
+ recommendations.insert(0,
+ f"ℹ Token/param ratio = {ratio:.1f}x — "
+ f"below Chinchilla optimal (20x). More tokens may help."
+ )
+ print(f"\n 🟡 {case}")
+
+ elif train_decreasing and (val_increasing or not val_decreasing):
+ case = "Case 3: Overfitting — train↓ but val→/↑"
+ recommendations = [
+ "Diagnosis priority (check in order):",
+ "1) Data repetition? (most common cause in pretraining)",
+ " - Check: total tokens vs unique tokens",
+ " - Epoch > 1 dramatically increases overfitting risk",
+ " - Solution: add more data, stay within 1 epoch",
+ "2) Weight decay too low?",
+ " - Check: weight_decay value (standard: 0.1)",
+ " - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1",
+ " - Experiment: 0.01 / 0.05 / 0.1 / 0.3",
+ "3) Data diversity?",
+ " - Single-domain data overfits faster",
+ " - Mix: web, books, code, wiki, etc.",
+ "",
+ "Note on Dropout in LLM pretraining:",
+ " - Modern LLMs do NOT use dropout in pretraining",
+ " (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)",
+ " - Sufficient data is the best regularization",
+ " - Dropout is useful for fine-tuning on small datasets",
+ ]
+ print(f"\n 🟡 {case}")
+
+ else:
+ case = "Case 4: Converged — loss is low and stable"
+ recommendations = [
+ "Training has converged (or reached the data/model limit).",
+ "To push further: add more data or increase model size.",
+ ]
+ print(f"\n 🟢 {case}")
+
+ for rec in recommendations:
+ print(f" {rec}")
+
+ return {
+ "level": 4,
+ "case": case,
+ "train_trend": "decreasing" if train_decreasing else "flat",
+ "val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"),
+ "gap": abs(gap),
+ "recommendations": recommendations,
+ }
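The token/param ratio used in Cases 1 and 2 reduces to a one-liner. `token_param_ratio` below is a hypothetical helper mirroring the 10x/20x thresholds applied above (the "20 tokens per parameter" Chinchilla rule of thumb):

```python
def token_param_ratio(total_tokens: int, model_params: int) -> tuple:
    """Classify training-data sufficiency against the Chinchilla 20x rule."""
    ratio = total_tokens / model_params
    if ratio < 10:
        verdict = "very likely undertrained"
    elif ratio < 20:
        verdict = "below Chinchilla optimal"
    else:
        verdict = "at or beyond Chinchilla optimal"
    return ratio, verdict

# A 1.1B-parameter model trained on 3T tokens (TinyLlama's setup):
print(token_param_ratio(3_000_000_000_000, 1_100_000_000))
# A 1B model on only 5B tokens:
print(token_param_ratio(5_000_000_000, 1_000_000_000))
```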
1030
+
+ # ───────────────────────────────────────────────────────────────
+ # Level 5: Architecture Checks
+ # ───────────────────────────────────────────────────────────────
+
+ @staticmethod
+ def check_architecture(
+ model: nn.Module,
+ dataloader: DataLoader,
+ device: torch.device,
+ ) -> Dict[str, Any]:
+ """Check weight initialization and per-layer activation health (Level 5).
+
+ Healthy initialization:
+ - All layers: std ≈ 1.0, mean ≈ 0.0
+ Problems:
+ - std increasing per layer → activation explosion (init scale too large)
+ - std decreasing per layer → activation vanishing (init scale too small)
+ - Sudden change at specific layer → implementation bug in that layer
+ """
+ print(_header("Level 5: Architecture / Initialization Check"))
+
+ batch = next(iter(dataloader))
+ sample_input = batch["input_ids"][:1].to(device)
+
+ model.eval()
+ layer_stats: List[Dict[str, Any]] = []
+
+ with torch.no_grad():
+ h = model.token_embedding(sample_input)
+ emb_std = h.float().std().item()
+ print(f"\n Embedding: std={emb_std:.4f}")
+
+ for i, layer in enumerate(model.layers):
+ h = layer(h, mask=None, position_offset=0)
+ h_f = h.float()
+ stats = {
+ "layer": i,
+ "mean": h_f.mean().item(),
+ "std": h_f.std().item(),
+ "max": h_f.abs().max().item(),
+ }
+ layer_stats.append(stats)
+
+ # Print stats
+ print(f"\n Layer-by-layer activation statistics:")
+ print(f" {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}")
+ print(f" {'-' * 38}")
+ for s in layer_stats:
+ print(f" {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}")
+
+ # ── Weight initialization distribution check ──
+ print(f"\n Weight Initialization Distribution:")
+ print(f" {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}")
+ print(f" {'-' * 75}")
+ weight_issues = []
+ for name, param in model.named_parameters():
+ if param.ndim < 2:
+ continue # skip biases, norm weights
+ p_f = param.float()
+ p_mean = p_f.mean().item()
+ p_std = p_f.std().item()
+ # Expected: std ≈ 0.02 for most layers, smaller for residual projections
+ shape_str = str(list(param.shape))
+ is_residual = "o_proj" in name or "down_proj" in name
+ expected_std = 0.02 # GPT-2 style
+ if p_std > expected_std * 5:
+ weight_issues.append(f"Large std: {name} (std={p_std:.4f})")
+ print(f" 🟡 {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
+ elif p_std < expected_std * 0.1:
+ weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})")
+ print(f" 🟡 {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}")
+ else:
+ print(f" {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
+
+ if weight_issues:
+ print(f"\n ⚠ {len(weight_issues)} weight distribution issue(s) found")
+ for issue in weight_issues[:5]:
+ print(f" • {issue}")
+ else:
+ print(f"\n ✅ All weight distributions look normal (std ≈ 0.02)")
+
+ print(f"\n Expected init pattern:")
+ print(f" • General Linear: N(0, 0.02)")
+ print(f" • Residual proj (o_proj, down_proj): N(0, 0.02/√(2×layers))")
+ print(f" • Embedding: N(0, 0.02)")
+
+ # ── Ablation study guidance ──
+ print(f"\n Component Ablation Reference:")
+ print(" ┌──────────────────────┬──────────────────────────────────┐")
+ print(" │ Experiment           │ Expected Outcome                 │")
+ print(" ├──────────────────────┼──────────────────────────────────┤")
+ print(" │ RMSNorm → LayerNorm  │ Minimal loss diff → OK           │")
+ print(" │ RoPE → Absolute PE   │ Similar on short seq (<512)      │")
+ print(" │ SwiGLU → ReLU FFN    │ Loss +0.05~0.15 → SwiGLU working │")
+ print(" │ GQA → MHA            │ Same loss, less memory → OK      │")
+ print(" └──────────────────────┴──────────────────────────────────┘")
+ print(" If any replacement shows unexpected results, check that component.")
+
+ # Analyze trends
+ stds = [s["std"] for s in layer_stats]
+ diagnosis = "healthy"
+ detail = ""
+
+ if len(stds) >= 3:
+ # Check for monotonic increase/decrease
+ first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3)
+ last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3)
+ ratio = last_third / max(first_third, 1e-8)
+
+ if ratio > 5:
+ diagnosis = "exploding"
+ detail = (
+ f"Activation std grows {ratio:.1f}x from early to late layers. "
+ f"Init scale may be too large."
+ )
+ elif ratio < 0.2:
+ diagnosis = "vanishing"
+ detail = (
+ f"Activation std shrinks to {ratio:.2f}x from early to late layers. "
+ f"Init scale may be too small."
+ )
+ else:
+ detail = f"Std ratio (last/first third) = {ratio:.2f} — within normal range."
+
+ # Check for sudden jumps
+ for i in range(1, len(stds)):
+ jump = stds[i] / max(stds[i - 1], 1e-8)
+ if jump > 10 or jump < 0.1:
+ diagnosis = "anomaly"
+ detail = (
+ f"Sudden activation change at layer {i}: "
+ f"std {stds[i - 1]:.4f} → {stds[i]:.4f}. "
+ f"Possible implementation bug in that layer."
+ )
+ break
+
+ icon = {"healthy": "✅", "exploding": "🔴", "vanishing": "🟡", "anomaly": "🔴"}
+ print(f"\n {icon.get(diagnosis, '⚪')} Diagnosis: {diagnosis}")
+ print(f" {detail}")
+
+ return {
+ "level": 5,
+ "diagnosis": diagnosis,
+ "detail": detail,
+ "layer_stats": layer_stats,
+ "weight_issues": weight_issues,
+ }
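The expected init pattern printed above, under the GPT-2-style assumption stated in the code (base std 0.02, residual projections scaled down by 1/√(2·layers)), works out to a simple formula. `expected_init_std` is an illustrative helper, not part of the class:

```python
import math

def expected_init_std(n_layers: int, residual_proj: bool = False) -> float:
    """GPT-2-style init: N(0, 0.02); residual projections scaled by 1/sqrt(2*L)."""
    base = 0.02
    return base / math.sqrt(2 * n_layers) if residual_proj else base

# For a 24-layer model:
print(expected_init_std(24))                  # 0.02 for a general Linear
print(round(expected_init_std(24, True), 5))  # 0.00289 for o_proj / down_proj
```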
1178
+
+ # ───────────────────────────────────────────────────────────────
+ # Scenario Auto-Detection
+ # ───────────────────────────────────────────────────────────────
+
+ @staticmethod
+ def detect_scenario(
+ metrics_history: Dict[str, list],
+ vocab_size: int = 32000,
+ ) -> Dict[str, Any]:
+ """Auto-detect which debugging scenario applies.
+
+ Scenarios (from the guide):
+ A: Loss stuck at ~10.37 (doesn't decrease at all)
+ B: Loss was decreasing then suddenly NaN
+ C: Loss decreased to X then started increasing
+ D: Loss stuck at high value (e.g. 4.0 for 1B model)
+ """
+ print(_header("Scenario Auto-Detection"))
+
+ train_losses = metrics_history.get("train_loss", [])
+ val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
+ expected_initial = math.log(vocab_size)
+
+ if len(train_losses) < 5:
+ print(" [!] Not enough data to detect scenario.")
+ return {"scenario": "unknown", "steps": []}
+
+ first_loss = train_losses[0]
+ last_loss = train_losses[-1]
+ has_nan = any(math.isnan(l) for l in train_losses)
+ # Guard against an all-NaN history (a bare min() would raise ValueError)
+ min_loss = min((l for l in train_losses if not math.isnan(l)), default=float("inf"))
+ min_loss_idx = next((i for i, l in enumerate(train_losses) if l == min_loss), 0)
+ loss_recovered = last_loss > min_loss + 0.3 and min_loss_idx < len(train_losses) * 0.8
+
+ scenario = "unknown"
+ steps: List[str] = []
+
+ # Scenario A: Loss stuck near initial value
+ if abs(last_loss - expected_initial) < 1.5 and abs(first_loss - last_loss) < 0.5:
+ scenario = "A"
+ steps = [
+ "1. Run single-batch overfit test → if it fails, model/loss has a bug",
+ "2. Check if gradients are zero → optimizer.step() may be missing",
+ "3. Verify input_ids/targets shift → data pipeline bug",
+ "4. Check LR → is it set to 0?",
+ "5. Check model.train() → eval mode changes norm/dropout behavior",
+ ]
+
+ # Scenario B: NaN appeared
+ elif has_nan:
+ nan_idx = next(i for i, l in enumerate(train_losses) if math.isnan(l))
+ scenario = "B"
+ steps = [
+ f"1. NaN appeared at step ~{nan_idx}. Check that batch's data for bad tokens",
+ "2. Check gradient norm just before NaN → was there a spike?",
+ "3. Check LR schedule → does NaN coincide with warmup end?",
+ "4. Check specific layer weights for Inf values",
+ "5. Try switching to fp32 to see if it's a mixed precision issue",
+ " (Pythia-1B had irrecoverable fp16 loss spikes → switched to bf16,",
+ " Biderman et al. 2023)",
+ ]
+
+ # Scenario C: Loss decreased then increased
+ elif loss_recovered:
+ scenario = "C"
+ steps = [
+ "1. Check Train and Val loss simultaneously:",
+ " - Both increasing → LR too high (check cosine decay)",
+ " - Only train increasing → data quality changed (streaming order)",
+ " - Only val increasing → overfitting started",
+ "2. Verify LR schedule is decaying as intended",
+ "3. Check data shuffling → same data repeating?",
+ ]
+
+ # Scenario D: Loss stuck at high value
+ elif last_loss > _EXPECTED_TRAIN_LOSS[1] and abs(last_loss - min_loss) < 0.3:
+ scenario = "D"
+ total_tokens = len(train_losses) * 262144 # approximate
+ steps = [
+ f"1. Check total tokens trained: ~{total_tokens / 1e9:.1f}B "
+ f"(need 5-10B for 1B model)",
+ "2. Compare with smaller model (100M) at same step → "
+ "if 100M is lower, 1B may have a bug",
+ "3. Run LR range test → current LR may not be optimal",
+ "4. Sample training data → check for noise, duplicates, low quality",
+ "5. Try different effective batch size (64 vs 128 vs 256)",
+ ]
+
+ else:
+ scenario = "none"
+ steps = ["Training appears normal. No specific scenario detected."]
+
+ label = {
+ "A": "Loss stuck at initial value (~10.37)",
+ "B": "Loss was decreasing, then NaN",
+ "C": "Loss decreased then started increasing",
+ "D": f"Loss stuck at high value (>{_EXPECTED_TRAIN_LOSS[1]})",
+ "none": "No problematic scenario detected",
+ "unknown": "Cannot determine",
+ }
+
+ print(f"\n Detected: Scenario {scenario} — {label.get(scenario, 'Unknown')}")
+ print(f"\n Recommended debugging steps:")
+ for step in steps:
+ print(f" {step}")
+
+ return {"scenario": scenario, "label": label.get(scenario, ""), "steps": steps}
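Scenario A hinges on the loss sitting at the uniform-distribution baseline ln(vocab_size), which is where the ~10.37 figure comes from for a 32k vocabulary. A minimal standalone check (the helper name is illustrative):

```python
import math

def is_stuck_at_init(last_loss: float, vocab_size: int, tol: float = 1.5) -> bool:
    """True if the loss is still near the random-guessing baseline ln(V)."""
    return abs(last_loss - math.log(vocab_size)) < tol

print(round(math.log(32000), 2))      # 10.37, the value quoted in Scenario A
print(is_stuck_at_init(10.4, 32000))  # True
print(is_stuck_at_init(4.0, 32000))   # False
```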
1286
+
+ # ───────────────────────────────────────────────────────────────
+ # Main Entry Point
+ # ───────────────────────────────────────────────────────────────
+
+ @staticmethod
+ def run_diagnostics(
+ model: nn.Module,
+ dataloader: DataLoader,
+ tokenizer: Any,
+ train_config: TrainConfig,
+ metrics_history: Dict[str, list],
+ device: torch.device,
+ dtype: torch.dtype = torch.bfloat16,
+ vocab_size: int = 32000,
+ levels: Optional[List[int]] = None,
+ ) -> Dict[str, Any]:
+ """Run the full 5-level debugging framework.
+
+ Args:
+ model: the LLM model
+ dataloader: training dataloader
+ tokenizer: tokenizer with encode/decode methods
+ train_config: TrainConfig instance
+ metrics_history: dict from MetricsTracker.history
+ device: torch device
+ dtype: mixed precision dtype
+ vocab_size: model vocabulary size
+ levels: which levels to run (default: all [0,1,2,3,4,5])
+
+ Returns:
+ Full diagnostic report dict.
+ """
+ if levels is None:
+ levels = [0, 1, 2, 3, 4, 5]
+
+ print("\n" + "═" * 60)
+ print(" LLM Loss Debugging Framework")
+ print(" Levels to run: " + ", ".join(str(l) for l in levels))
+ print("═" * 60)
+
+ report: Dict[str, Any] = {}
+
+ if 0 in levels:
+ report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history)
+ # If status is normal and only level 0 was explicitly requested, skip rest
+ if (
+ report["level_0"]["status"] == STATUS_NORMAL
+ and levels == [0]
+ ):
+ print("\n Training is healthy — no further debugging needed.")
+ return report
+
+ if 1 in levels:
+ report["level_1"] = LossDebugger.check_data_pipeline(
+ model, dataloader, tokenizer, vocab_size, device, dtype,
+ )
+
+ if 2 in levels:
+ report["level_2"] = LossDebugger.check_numerical_stability(
+ model, dataloader, device, dtype,
+ )
+
+ if 3 in levels:
+ report["level_3"] = LossDebugger.diagnose_hyperparameters(
+ metrics_history, train_config,
+ )
+
+ if 4 in levels:
+ model_params = sum(p.numel() for p in model.parameters())
+ total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step
+ report["level_4"] = LossDebugger.diagnose_fitting(
+ metrics_history, model_params, total_tokens,
+ )
+
+ if 5 in levels:
+ report["level_5"] = LossDebugger.check_architecture(
+ model, dataloader, device,
+ )
+
+ # Auto-detect scenario
+ report["scenario"] = LossDebugger.detect_scenario(metrics_history, vocab_size)
+
+ # Final summary
+ print("\n" + "═" * 60)
+ print(" Diagnostics Complete")
+ print("═" * 60)
+
+ return report
1375
+
+    # ─────────────────────────────────────────────────────────────────
+    # Study Roadmap
+    # ─────────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def print_study_roadmap() -> None:
+        """Print the recommended study roadmap for LLM training optimization."""
+        print(_header("Study Roadmap — LLM Training Optimization"))
+
+        print("""
+    ⭐⭐⭐ Top Priority: Optimization Fundamentals
+    ─────────────────────────────────────────────
+    1. SGD → Momentum → Adam → AdamW progression
+       - Why Adam > SGD? Why decouple weight decay in AdamW?
+       - β₁, β₂ intuition (1st / 2nd moment estimates)
+       - Ref: Loshchilov & Hutter 2019 (AdamW)
+       - Ref: Karpathy "A Recipe for Training Neural Networks"
+
+    2. Loss Landscape
+       - Why large LR diverges, small LR stalls
+       - Batch size effect on landscape exploration
+       - Ref: Li et al. 2018 "Visualizing the Loss Landscape"
+       - Ref: McCandlish et al. 2018 "Large-Batch Training"
+
+    3. Chinchilla Scaling Law
+       - Loss = f(N, D) relationship
+       - Compute-optimal model size vs data allocation
+       - Ref: Hoffmann et al. 2022 (original)
+       - Ref: Kaplan et al. 2020 (predecessor)
+       - Ref: Besiroglu et al. 2024 (replication/verification)
+
+    ⭐⭐ Important: Training Stability
+    ──────────────────────────────────
+    4. Gradient Flow: vanishing/exploding, residual as gradient highway
+    5. Weight Init: Xavier / Kaiming / GPT-2 style
+    6. Normalization: BatchNorm → LayerNorm → RMSNorm
+    7. Weight Decay: L2 vs decoupled, why exclude embed/norm
+
+    ⭐ Advanced: Optimization Techniques
+    ─────────────────────────────────────
+    8. LR Schedules: cosine vs linear vs step, warmup/cooldown
+    9. Gradient Accumulation & Large Batch Training
+    10. μP (Maximal Update Parameterization): transfer HP across scales
+
+    Recommended Experiments (in order):
+    ───────────────────────────────────
+    1. Single-batch overfit (30 min)        → basic sanity
+    2. LR Range Test (1 hour)               → optimal LR range
+    3. 10M model quick train (2-3 hrs)      → pipeline validation
+    4. Ablation (remove components) (1 day) → component contribution
+    5. 100M model + 5B tokens (1-2 days)    → mid-scale dynamics
+    6. 1B model full training (2-3 days)    → scaling law verification
+    7. LR / batch size comparison (1 day)   → HP sensitivity
+
+    Key References:
+    ───────────────
+    ⭐⭐⭐ Karpathy "Recipe for Training NNs"  — debugging mindset
+    ⭐⭐⭐ Hoffmann et al. 2022 (Chinchilla)   — scaling law
+    ⭐⭐  Touvron et al. 2023 (LLaMA)          — 1B+ training details
+    ⭐⭐  Biderman et al. 2023 (Pythia)        — open training logs
+    ⭐⭐  Zhang et al. 2024 (TinyLlama)        — 1.1B on 3T tokens
+    ⭐⭐  Groeneveld et al. 2024 (OLMo)        — fully open LLM
+    ⭐⭐  Li et al. 2018 (Loss Landscape)      — loss terrain intuition
+    ⭐⭐  Loshchilov & Hutter 2019 (AdamW)     — optimizer basics
+    ⭐   Yang et al. 2022 (μP)                — HP transfer
+    ⭐   McCandlish et al. 2018 (Batch size)  — critical batch size
+    """)
notebooks/05_debugging.ipynb ADDED
@@ -0,0 +1,384 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 4,
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 05. Loss Debugging (5-Level Diagnostic Framework)\n",
+    "\n",
+    "A framework for systematically diagnosing the cause when training loss does not decrease as expected.\n",
+    "\n",
+    "**Always check from the lowest level up** — catch data bugs before tuning hyperparameters.\n",
+    "\n",
+    "```\n",
+    "Level 0: Status Diagnosis       ← classify current training health (6 categories)\n",
+    "Level 1: Data / Implementation  ← the most common cause (70%)\n",
+    "Level 2: Numerical Stability    ← NaN/Inf, activation explosion\n",
+    "Level 3: Hyperparameters        ← LR, batch size, warmup\n",
+    "Level 4: Fitting Diagnosis      ← overfitting vs underfitting\n",
+    "Level 5: Architecture           ← initialization, per-layer activations\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# No extra packages to install\n",
+    "# LossDebugger uses only torch and the built-in llm_lab modules"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.insert(0, '..')\n",
+    "\n",
+    "import math\n",
+    "import torch\n",
+    "\n",
+    "from llm_lab.config import ModelConfig, DataConfig, TrainConfig\n",
+    "from llm_lab.model import LLMModel\n",
+    "from llm_lab.data import setup_data_pipeline\n",
+    "from llm_lab.training import LossDebugger\n",
+    "from llm_lab.utils import auto_configure, get_device"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 0. Setup\n",
+    "\n",
+    "Uses the `debug_10m` preset so everything runs quickly even on CPU."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# --- Config ---\n",
+    "model_config = ModelConfig.debug_10m()\n",
+    "data_config = DataConfig(\n",
+    "    max_seq_len=model_config.max_seq_len,\n",
+    "    batch_size=4,\n",
+    ")\n",
+    "train_config = TrainConfig.debug_10m()\n",
+    "\n",
+    "# --- Device / dtype ---\n",
+    "train_config = auto_configure(train_config)\n",
+    "device = get_device()\n",
+    "dtype = train_config.torch_dtype\n",
+    "\n",
+    "vocab_size = data_config.vocab_size\n",
+    "print(f\"Device: {device}, dtype: {dtype}\")\n",
+    "print(f\"Vocab size: {vocab_size:,}\")\n",
+    "print(f\"Expected initial loss: ln({vocab_size}) = {math.log(vocab_size):.2f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# --- Model ---\n",
+    "model = LLMModel(model_config).to(device)\n",
+    "print(f\"Model parameters: {model.count_parameters():,}\")\n",
+    "\n",
+    "# --- Data pipeline ---\n",
+    "tokenizer, train_dl, val_dl = setup_data_pipeline(\n",
+    "    tokenizer_mode=\"pretrained\",\n",
+    "    config=data_config,\n",
+    ")"
+   ]
+  },
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "### 0.1 ํ•™์Šต ์ด๋ ฅ (Mock Metrics History)\n",
121
+ "\n",
122
+ "์‹ค์ œ ํ•™์Šต ์—†์ด๋„ Level 0 / 3 / 4 / ์‹œ๋‚˜๋ฆฌ์˜ค ๊ฐ์ง€๋ฅผ ํ…Œ์ŠคํŠธํ•  ์ˆ˜ ์žˆ๋„๋ก\n",
123
+ "mock `metrics_history`๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.\n",
124
+ "\n",
125
+ "์‹ค์ œ ํ•™์Šต ํ›„์—๋Š” `trainer.metrics.history`๋ฅผ ๋Œ€์‹  ์‚ฌ์šฉํ•˜์„ธ์š”:\n",
126
+ "```python\n",
127
+ "# metrics_history = trainer.metrics.history\n",
128
+ "```"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": "import random\nrandom.seed(42)\n\nexpected_initial = math.log(vocab_size) # ~10.37\n\n# --- Scenario A: loss stuck near ln(vocab_size) ---\n# loss๊ฐ€ ๊ฑฐ์˜ ์ค„์–ด๋“ค์ง€ ์•Š๋Š” ์ƒํ™ฉ (๋ฐ์ดํ„ฐ/๊ตฌํ˜„ ๋ฒ„๊ทธ ์˜์‹ฌ)\nn_steps_a = 200\nmock_history_a = {\n \"step\": list(range(1, n_steps_a + 1)),\n \"train_loss\": [expected_initial - 0.01 * i + random.uniform(-0.05, 0.05)\n for i in range(n_steps_a)],\n \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_a)],\n \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_a)],\n \"tokens_per_sec\":[random.uniform(8000, 12000) for _ in range(n_steps_a)],\n \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_a)],\n \"val_loss\": [expected_initial + random.uniform(-0.1, 0.1)\n for _ in range(0, n_steps_a, 50)],\n \"val_ppl\": [math.exp(expected_initial) + random.uniform(-50, 50)\n for _ in range(0, n_steps_a, 50)],\n}\nprint(f\"Mock A โ€” train_loss: {mock_history_a['train_loss'][0]:.2f} -> {mock_history_a['train_loss'][-1]:.2f}\")\nprint(f\" Expected: NO_DECREASE (loss barely changes)\")\n\n# --- Scenario B: loss decreasing then NaN ---\nn_steps_b = 200\nmock_history_b = {\n \"step\": list(range(1, n_steps_b + 1)),\n \"train_loss\": [expected_initial - 0.03 * i + random.uniform(-0.05, 0.05)\n for i in range(150)]\n + [float('nan')] * 50,\n \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_b)],\n \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(145)]\n + [random.uniform(5.0, 50.0) for _ in range(5)]\n + [float('nan')] * 50,\n \"tokens_per_sec\":[random.uniform(8000, 12000) for _ in range(n_steps_b)],\n \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_b)],\n \"val_loss\": [expected_initial - 0.03 * i + random.uniform(-0.1, 0.1)\n for i in range(0, n_steps_b, 50)],\n \"val_ppl\": [math.exp(expected_initial - 0.03 * i)\n for i in range(0, n_steps_b, 50)],\n}\nprint(f\"\\nMock B โ€” train_loss starts normal, then NaN at step 
~150\")\nprint(f\" Expected: Scenario B (NaN detected)\")\n\n# --- Scenario C: loss decreased then increased ---\nn_steps_c = 200\nmock_history_c = {\n \"step\": list(range(1, n_steps_c + 1)),\n \"train_loss\": [expected_initial - 0.04 * i + random.uniform(-0.05, 0.05)\n for i in range(120)]\n + [expected_initial - 0.04 * 120 + 0.02 * (i - 120) + random.uniform(-0.05, 0.05)\n for i in range(120, n_steps_c)],\n \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_c)],\n \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_c)],\n \"tokens_per_sec\":[random.uniform(8000, 12000) for _ in range(n_steps_c)],\n \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_c)],\n \"val_loss\": [expected_initial - 0.02 * i + random.uniform(-0.1, 0.1)\n for i in range(0, 120, 50)]\n + [expected_initial - 0.02 * 120 + 0.03 * (i - 120) + random.uniform(-0.1, 0.1)\n for i in range(120, n_steps_c, 50)],\n \"val_ppl\": [math.exp(expected_initial - 0.02 * i)\n for i in range(0, n_steps_c, 50)],\n}\nprint(f\"\\nMock C โ€” train_loss: decrease โ†’ increase (bounce)\")\nprint(f\" Expected: Scenario C (loss recovery)\")\n\n# --- Scenario D: loss stuck at high value (4.0) ---\nn_steps_d = 200\nmock_history_d = {\n \"step\": list(range(1, n_steps_d + 1)),\n \"train_loss\": [4.0 + random.uniform(-0.1, 0.1) for _ in range(n_steps_d)],\n \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_d)],\n \"grad_norm\": [random.uniform(0.1, 0.3) for _ in range(n_steps_d)],\n \"tokens_per_sec\":[random.uniform(8000, 12000) for _ in range(n_steps_d)],\n \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_d)],\n \"val_loss\": [4.2 + random.uniform(-0.1, 0.1)\n for _ in range(0, n_steps_d, 50)],\n \"val_ppl\": [math.exp(4.2) + random.uniform(-5, 5)\n for _ in range(0, n_steps_d, 50)],\n}\nprint(f\"\\nMock D โ€” train_loss stuck at ~4.0\")\nprint(f\" Expected: Scenario D (plateau at high value)\")\n\n# --- Normal: loss decreasing normally 
---\nn_steps_n = 200\nmock_history_normal = {\n \"step\": list(range(1, n_steps_n + 1)),\n \"train_loss\": [expected_initial - 0.03 * i + random.uniform(-0.05, 0.05)\n for i in range(n_steps_n)],\n \"learning_rate\": [1e-3 * min(i / 200, 1.0) for i in range(n_steps_n)],\n \"grad_norm\": [random.uniform(0.3, 0.8) for _ in range(n_steps_n)],\n \"tokens_per_sec\":[random.uniform(8000, 12000) for _ in range(n_steps_n)],\n \"gpu_mem_gb\": [random.uniform(1.0, 2.0) for _ in range(n_steps_n)],\n \"val_loss\": [expected_initial - 0.03 * i + random.uniform(-0.1, 0.1)\n for i in range(0, n_steps_n, 50)],\n \"val_ppl\": [math.exp(expected_initial - 0.03 * i)\n for i in range(0, n_steps_n, 50)],\n}\nprint(f\"\\nMock Normal โ€” train_loss: {mock_history_normal['train_loss'][0]:.2f} -> {mock_history_normal['train_loss'][-1]:.2f}\")\nprint(f\" Expected: NORMAL (loss decreasing steadily)\")"
137
+ },
138
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Level 0 — Status Diagnosis\n",
+    "\n",
+    "Classifies the current training state into one of six categories using only `metrics_history`:\n",
+    "\n",
+    "| Status | Meaning | Severity |\n",
+    "|--------|---------|----------|\n",
+    "| `NORMAL` | training normally | green |\n",
+    "| `NO_DECREASE` | loss is not decreasing | red |\n",
+    "| `DIVERGING` | loss is diverging | red |\n",
+    "| `PLATEAU` | loss stuck at a high value | yellow |\n",
+    "| `OVERFITTING` | train↓ val↑ | yellow |\n",
+    "| `UNSTABLE` | loss fluctuates heavily | yellow |"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Scenario A (problem case)\n",
+    "status_a = LossDebugger.diagnose_status(vocab_size, mock_history_a)\n",
+    "print(f\"\\n>>> Result: {status_a['status']}\")\n",
+    "\n",
+    "print(\"\\n\" + \"-\" * 40)\n",
+    "\n",
+    "# Normal (healthy case)\n",
+    "status_n = LossDebugger.diagnose_status(vocab_size, mock_history_normal)\n",
+    "print(f\"\\n>>> Result: {status_n['status']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Level 1 — Data / Implementation Checks\n",
+    "\n",
+    "**~70% of loss problems are data pipeline bugs.** Six checks:\n",
+    "\n",
+    "1. **Shift relation** — verify `targets[t] == input_ids[t+1]`\n",
+    "2. **Token range** — verify `0 <= ids < vocab_size`\n",
+    "3. **Initial loss** — verify `≈ ln(vocab_size)` with random weights\n",
+    "4. **Single-batch overfit** — repeat one batch for 200 steps → loss should reach ≈ 0\n",
+    "5. **Tokenizer round-trip** — verify encode → decode preserves the text\n",
+    "6. **Data quality** — eyeball sample text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "level1 = LossDebugger.check_data_pipeline(\n",
+    "    model=model,\n",
+    "    dataloader=train_dl,\n",
+    "    tokenizer=tokenizer,\n",
+    "    vocab_size=vocab_size,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ")\n",
+    "\n",
+    "print(f\"\\nPassed: {len(level1['passed'])}, Failed: {len(level1['failed'])}\")\n",
+    "for f in level1['failed']:\n",
+    "    print(f\"  FAILED: {f['name']} — {f['detail']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 3. Level 2 — Numerical Stability\n\nRuns one forward + backward pass and checks:\n\n- **Mixed precision config** — whether RMSNorm computes in fp32; the loss dtype\n- **Gradients** — NaN/Inf/unusually large gradients\n- **Activations** — mean/std/max of each transformer layer\n- **Logit scale** — output logits in a reasonable range (< 1000)\n\n### Mixed Precision Best Practices\n\n```python\n# ❌ Risky: adding a large and a small value in bf16 loses the small one\n# ✅ Fix: compute the RMSNorm internals in float32\nx_float = x.float()  # bf16 → fp32\nrms = torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + eps)\nreturn (x_float * rms).to(x.dtype) * self.weight\n\n# ✅ Compute the loss in float32 as well:\nlogits_fp32 = logits.float()\nloss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))\n```\n\n### Common Numerical Issues\n\n| Symptom | Cause | Fix |\n|---------|-------|-----|\n| Loss → NaN | very large values fed into softmax | check logit scale, review initialization |\n| Loss → Inf | log of zero | add eps, set ignore_index |\n| Loss oscillates heavily | gradient underflow in fp16 | switch to bf16 or use GradScaler |\n| NaN late in training | activations grow gradually | verify RMSNorm works, verify weight decay is applied |"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "level2 = LossDebugger.check_numerical_stability(\n",
+    "    model=model,\n",
+    "    dataloader=train_dl,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 4. Level 3 — Hyperparameter Diagnosis\n\nAnalyzes `metrics_history` and the `TrainConfig` to check:\n\n- **LR**: grad norm frequently hitting the clip limit → LR too high\n- **Batch size**: large loss variance → batch too small\n- **Warmup**: many early loss spikes → warmup too short\n- **β₂**: whether the LLM standard (0.95) is used instead of the PyTorch default (0.999)\n- **Weight decay**: 0 risks overfitting; the standard is 0.1\n\n### Tuning Priority (by impact)\n\n| Rank | Parameter | Impact | Notes |\n|------|-----------|--------|-------|\n| 1 | **Learning Rate** | 10x | always tune first |\n| 2 | **Batch Size** | high | adjust together with LR |\n| 3 | **Warmup Steps** | medium | early-phase stability |\n| 4 | **Weight Decay** | medium | adjust when overfitting (usually fixed at 0.1) |\n| 5 | **β₁, β₂** | low | β₂=0.95 recommended (LLaMA, TinyLlama, OLMo) |\n| 6 | **Gradient Clip** | low | usually fixed at 1.0 |\n\n### Choosing β₂\n\n- **PyTorch default** β₂=0.999 → averages gradient statistics over the last ~1000 steps\n- **LLM standard** β₂=0.95 → averages over the last ~20 steps\n- LLM training sees a constantly shifting data distribution, so faster adaptation (a lower β₂) is advantageous\n- A lower β₂ also helps dampen loss spikes (Cattaneo & Shigida 2025)\n\n### Batch-LR Scaling\n\n- Batch ×2 → LR ×√2 (square-root scaling, safer)\n- Batch ×2 → LR ×2 (linear scaling, used by GPT-3 and others)\n- For a 1B model, an effective batch of 64-512 is typical"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "level3 = LossDebugger.diagnose_hyperparameters(\n",
+    "    metrics_history=mock_history_a,\n",
+    "    config=train_config,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 4.1 LR Range Test\n",
+    "\n",
+    "Increases the LR exponentially from `1e-7` to `1e-1` while recording the loss.\n",
+    "The recommended peak LR is the LR at the steepest loss decrease, divided by 3.\n",
+    "\n",
+    "> This can take a while (~1 minute with debug_10m).\n",
+    "> Run it once before real training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lr_result = LossDebugger.lr_range_test(\n",
+    "    model=model,\n",
+    "    dataloader=train_dl,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    "    steps=100,  # debug_10m: 100 steps for speed\n",
+    ")\n",
+    "\n",
+    "print(f\"\\nSuggested peak LR: {lr_result['suggested_lr']:.2e}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 5. Level 4 — Fitting Diagnosis\n\nCompares the train/val loss trends to distinguish four cases:\n\n| Case | Train | Val | Diagnosis |\n|------|-------|-----|-----------|\n| 1 | ↓ | ↓ | Normal (still learning) |\n| 2 | → | → | Underfitting (insufficient model/data) |\n| 3 | ↓ | ↑ | Overfitting (suspect data repetition) |\n| 4 | → | → (low) | Converged |\n\n### Analyzing underfitting, in order\n\n1. **Insufficient model capacity?** → train a 2x larger model on the same data → if loss drops, capacity was the problem\n2. **Not trained enough?** → check whether the loss curve is still trending down (Chinchilla: 1B → ~20B tokens)\n3. **LR too small, convergence slow?** → experiment with LR ×2\n4. **Data quality problem?** → sample the data and read it yourself\n\n### Analyzing overfitting, in order\n\n1. **Too little / repeated data?** → epochs > 1 sharply raises the overfitting risk\n2. **Too little weight decay?** → 0.1 is the standard (LLaMA, TinyLlama, GPT-3, OLMo)\n3. **Too little data diversity?** → mix in diverse domains\n\n### Key facts about overfitting in LLM pretraining\n\n- Within **a single epoch**, overfitting is very rare\n- When overfitting appears, the cause is almost always **data repetition** (epochs > 1)\n- **Dropout** is **almost never used** in modern LLM pretraining\n  - Pythia, TinyLlama, OLMo, and LLaMA all use dropout=0\n  - With enough data, the data itself is the best regularizer\n  - Dropout is useful when fine-tuning on small datasets"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_params = sum(p.numel() for p in model.parameters())\n",
+    "total_tokens = len(mock_history_normal[\"train_loss\"]) * train_config.tokens_per_step\n",
+    "\n",
+    "level4 = LossDebugger.diagnose_fitting(\n",
+    "    metrics_history=mock_history_normal,\n",
+    "    model_params=model_params,\n",
+    "    total_tokens=total_tokens,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 6. Level 5 — Architecture / Initialization Check\n\nPasses one input through the model and collects **per-layer activation statistics** and the **weight distribution**.\n\n### Activation diagnosis\n\n- **healthy**: std stays stable across layers\n- **exploding**: std grows sharply toward later layers → init scale too large\n- **vanishing**: std shrinks sharply toward later layers → init scale too small\n- **anomaly**: abrupt change at one specific layer → implementation bug in that layer\n\n### Weight initialization diagnosis\n\nGPT-2 style initialization:\n- **Regular Linear layers**: `N(0, 0.02)`\n- **Residual projections** (o_proj, down_proj): `N(0, 0.02/√(2×layers))`\n  → shrinks each residual contribution to keep deep models stable\n\n### Ablation Study (per-component impact)\n\n| Experiment | Expected result | If abnormal |\n|------------|-----------------|-------------|\n| RMSNorm → LayerNorm | negligible loss difference | normalization bug |\n| RoPE → absolute position embeddings | small difference on short sequences | check RoPE implementation |\n| SwiGLU → ReLU FFN | loss +0.05-0.15 | check SwiGLU implementation |\n| GQA → MHA | nearly identical loss (only memory differs) | KV repeat bug |"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "level5 = LossDebugger.check_architecture(\n",
+    "    model=model,\n",
+    "    dataloader=train_dl,\n",
+    "    device=device,\n",
+    ")\n",
+    "\n",
+    "print(f\"\\n>>> Diagnosis: {level5['diagnosis']}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 7. Scenario Auto-Detection\n\nAnalyzes `metrics_history` to decide which of four scenarios applies:\n\n| Scenario | Symptom | Likely cause | Key diagnostic |\n|----------|---------|--------------|----------------|\n| **A** | loss stuck near ~10.37 | data/implementation bug | single-batch overfit test |\n| **B** | loss decreasing, then sudden NaN | numerical instability | grad norm right before the NaN |\n| **C** | loss decreases, then rises again | LR/data problem | check train and val together |\n| **D** | loss plateaus at a high value | undertrained / LR problem | check token count, run LR Range Test |\n\n### Per-scenario diagnostic checklist\n\n**Scenario A**: \"Loss sits at 10.37 and never drops\"\n1. Single-batch overfit → if it fails, there is a model/loss bug\n2. Are the gradients zero? → missing `optimizer.step()`?\n3. Check the `input_ids/targets` shift → data pipeline bug\n4. Check the LR → make sure it is not set to 0\n5. Verify `model.train()` is called\n\n**Scenario B**: \"Loss decreases, then suddenly NaN\"\n1. Inspect the batch at that step\n2. Look for a gradient-norm spike right before the NaN\n3. Compare the LR schedule against the NaN step\n4. Mixed precision problem → try to reproduce in fp32\n   - (Pythia-1B switched fp16 → bf16; Biderman et al. 2023)\n\n**Scenario C**: \"Loss fell to 3.5, then climbed back up\"\n1. Check train and val together:\n   - both rising → LR too high\n   - only train rising → data quality shift (streaming order)\n   - only val rising → overfitting is starting\n2. Check the LR schedule; check data shuffling\n\n**Scenario D**: \"Loss is stuck at 4.0\"\n1. Check tokens trained (under 5B suggests undertraining)\n2. Compare against a 100M model\n3. Run the LR Range Test\n4. Sample the data quality"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "# --- Scenario A ---\nprint(\"=\" * 50)\nprint(\"Testing Scenario A (loss stuck at initial value)\")\nprint(\"=\" * 50)\nscenario_a = LossDebugger.detect_scenario(\n    metrics_history=mock_history_a,\n    vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_a['scenario']}\")\n\n# --- Scenario B ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario B (NaN appeared)\")\nprint(\"=\" * 50)\nscenario_b = LossDebugger.detect_scenario(\n    metrics_history=mock_history_b,\n    vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_b['scenario']}\")\n\n# --- Scenario C ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario C (loss bounce)\")\nprint(\"=\" * 50)\nscenario_c = LossDebugger.detect_scenario(\n    metrics_history=mock_history_c,\n    vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_c['scenario']}\")\n\n# --- Scenario D ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Scenario D (loss plateau)\")\nprint(\"=\" * 50)\nscenario_d = LossDebugger.detect_scenario(\n    metrics_history=mock_history_d,\n    vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_d['scenario']}\")\n\n# --- Normal ---\nprint(\"\\n\" + \"=\" * 50)\nprint(\"Testing Normal scenario\")\nprint(\"=\" * 50)\nscenario_n = LossDebugger.detect_scenario(\n    metrics_history=mock_history_normal,\n    vocab_size=vocab_size,\n)\nprint(f\"\\n>>> Detected: Scenario {scenario_n['scenario']}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 8. Full Diagnostics (run_diagnostics)\n",
+    "\n",
+    "Runs all of the levels above in one call. The `levels` parameter selects which levels to run.\n",
+    "\n",
+    "```python\n",
+    "# Example usage after real training:\n",
+    "# report = LossDebugger.run_diagnostics(\n",
+    "#     model=model, dataloader=train_dl, tokenizer=tokenizer,\n",
+    "#     train_config=train_config,\n",
+    "#     metrics_history=trainer.metrics.history,\n",
+    "#     device=device, dtype=dtype,\n",
+    "# )\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "report = LossDebugger.run_diagnostics(\n",
+    "    model=model,\n",
+    "    dataloader=train_dl,\n",
+    "    tokenizer=tokenizer,\n",
+    "    train_config=train_config,\n",
+    "    metrics_history=mock_history_a,\n",
+    "    device=device,\n",
+    "    dtype=dtype,\n",
+    "    vocab_size=vocab_size,\n",
+    "    levels=[0, 1, 2, 3, 4, 5],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": "## 9. Study Roadmap\n\nA structured study path for LLM training optimization.",
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "source": "LossDebugger.print_study_roadmap()",
+   "metadata": {},
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "## 10. Debugging Tips\n\n**Recommended order:**\n1. Level 0 for status → tells you which level to inspect\n2. Level 1 (data) → first! 70% of problems are found here\n3. Level 2 (numerical stability) → resolve NaN/Inf\n4. Level 3 (hyperparameters) → tune the LR first (10x impact)\n5. Level 4 (fitting diagnosis) → check only after training long enough\n6. Level 5 (architecture) → when the levels above find nothing\n\n**Scenario Detection** → classifies the scenario automatically from the symptoms\n\n### Key References\n\n| Reference | Topic | Priority |\n|-----------|-------|----------|\n| Karpathy \"Recipe for Training NNs\" | hands-on debugging mindset | ⭐⭐⭐ |\n| Hoffmann et al. 2022 (Chinchilla) | core scaling law | ⭐⭐⭐ |\n| Kaplan et al. 2020 (Scaling Laws) | loss prediction formula | ⭐⭐⭐ |\n| Touvron et al. 2023 (LLaMA) | 1B-scale training details | ⭐⭐ |\n| Biderman et al. 2023 (Pythia) | open training logs, reproducibility | ⭐⭐ |\n| Zhang et al. 2024 (TinyLlama) | 1.1B model on 3T tokens | ⭐⭐ |\n| Groeneveld et al. 2024 (OLMo) | fully open LLM framework | ⭐⭐ |\n| Li et al. 2018 (Loss Landscape) | loss-terrain intuition | ⭐⭐ |\n| Loshchilov & Hutter 2019 (AdamW) | optimizer fundamentals | ⭐⭐ |\n| Yang et al. 2022 (μP) | hyperparameter transfer | ⭐ |\n\n---\n**Previous step:** train the model in `03_training.ipynb`.  \n**Next step:** evaluate the trained model in `04_evaluation.ipynb`."
+  }
+ ]
+}
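The LR Range Test heuristic the notebook describes (exponential LR sweep, then peak LR = LR at the steepest loss drop ÷ 3) can be sketched independently of `llm_lab`. The helper names below are illustrative, not `LossDebugger`'s actual API:

```python
import math

def lr_at(step: int, steps: int, lr_min: float = 1e-7, lr_max: float = 1e-1) -> float:
    """Exponential sweep: lr_min at step 0, lr_max at the final step."""
    t = step / max(steps - 1, 1)
    return lr_min * (lr_max / lr_min) ** t

def suggest_peak_lr(lrs: list, losses: list) -> float:
    """Peak-LR heuristic: LR at the steepest single-step loss drop, divided by 3."""
    drops = [losses[i - 1] - losses[i] for i in range(1, len(losses))]
    steepest = max(range(len(drops)), key=drops.__getitem__)
    return lrs[steepest + 1] / 3

# Synthetic sanity check: loss falls fastest mid-sweep, then diverges at high LR
steps = 8
lrs = [lr_at(i, steps) for i in range(steps)]
losses = [10.4, 10.3, 10.0, 9.0, 7.0, 6.5, 8.0, 12.0]  # steepest drop into index 4
print(f"suggested peak LR: {suggest_peak_lr(lrs, losses):.2e}")
```

In a real sweep the `losses` list would come from stepping the model once per LR value; the divide-by-3 margin keeps the chosen peak LR safely below the edge of divergence.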