# llm_lab/training/debugger.py  (project: LLM-1B-Lab)
# Commit 2a50172 (Vjeong): Remove redundant detect_scenario from LossDebugger
"""LLM Loss Debugging & Optimization Framework.
A systematic 5-level debugging framework for diagnosing training issues.
Always start from Level 1 β€” fixing lower-level bugs before tuning
hyperparameters saves time.
Levels:
0. Status Diagnosis β€” classify current training health
1. Data/Implementation β€” most common cause (70% of issues)
2. Numerical Stability β€” dtype, normalization, gradient health
3. Hyperparameters β€” LR, batch size, warmup
4. Fitting Diagnosis β€” overfitting vs underfitting
5. Architecture β€” initialization, component checks
"""
import copy
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from llm_lab.config import TrainConfig
# ═══════════════════════════════════════════════════════════════════
# Constants
# ═══════════════════════════════════════════════════════════════════
# Approximate convergence ranges for a 1B model trained on ~10B tokens.
# Estimated from GPT-2 scaling benchmarks (Radford et al. 2019) and
# Chinchilla scaling laws (Hoffmann et al. 2022). Not dataset-specific.
# Each tuple is an inclusive (low, high) band.
_EXPECTED_TRAIN_LOSS = (2.5, 3.3)
_EXPECTED_VAL_LOSS = (2.7, 3.6)
_EXPECTED_VAL_PPL = (15, 37)
# Status labels — possible values of the 'status' key returned by
# LossDebugger.diagnose_status (Level 0 classification).
STATUS_NORMAL = "NORMAL"
STATUS_NO_DECREASE = "NO_DECREASE"
STATUS_DIVERGING = "DIVERGING"
STATUS_PLATEAU = "PLATEAU"
STATUS_OVERFITTING = "OVERFITTING"
STATUS_UNSTABLE = "UNSTABLE"
STATUS_NAN_DETECTED = "NAN_DETECTED"
STATUS_LOSS_BOUNCE = "LOSS_BOUNCE"
# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
# (param_count, recommended_lr, batch_tokens_str)
_GPT3_LR_REFERENCE = [
    (125e6, 6e-4, "0.5M"),
    (350e6, 3e-4, "0.5M"),
    (760e6, 2.5e-4, "0.5M"),
    (1.3e9, 2e-4, "1M"),
    (2.7e9, 1.6e-4, "1M"),
    (6.7e9, 1.2e-4, "2M"),
    (13e9, 1e-4, "2M"),
    (175e9, 6e-5, "3.2M"),
]
# Known LLM training references (optimizer hyperparameters as published;
# 'warmup' is in optimizer steps where the project reported one).
_LLM_TRAINING_REFS = {
    "TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
    "LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000},
    "Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.01},
    "OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1},
}
# Recommended β₂ for LLM training vs. the PyTorch AdamW default.
_RECOMMENDED_BETA2 = 0.95
_DEFAULT_PYTORCH_BETA2 = 0.999
def _header(title: str) -> str:
return f"\n{'=' * 60}\n{title}\n{'=' * 60}"
def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]:
return {"name": name, "passed": passed, "detail": detail}
# ═══════════════════════════════════════════════════════════════════
# LossDebugger
# ═══════════════════════════════════════════════════════════════════
class LossDebugger:
"""5-level loss debugging framework for LLM training.
Usage::
from llm_lab.training.debugger import LossDebugger
# Quick status check
status = LossDebugger.diagnose_status(vocab_size=32000,
metrics_history=trainer.metrics.history)
# Full diagnostics
report = LossDebugger.run_diagnostics(
model=model, dataloader=train_dl, tokenizer=tok,
train_config=train_cfg, metrics_history=trainer.metrics.history,
device=device, dtype=torch.bfloat16,
)
"""
# ───────────────────────────────────────────────────────────────
# Level 0: Status Diagnosis
# ───────────────────────────────────────────────────────────────
@staticmethod
def diagnose_status(
    vocab_size: int,
    metrics_history: Dict[str, list],
) -> Dict[str, Any]:
    """Classify current training health from metrics history (Level 0).

    Read-only pass over logged metrics: filters NaN train losses,
    computes trend statistics (half-vs-half averages, recent-window
    variance, moving-average minimum), then walks an ordered if/elif
    chain of checks from most to least severe. The first matching
    check determines the result.

    Args:
        vocab_size: model vocabulary size (e.g. 32000); used only to
            derive the expected random-init loss ln(vocab_size).
        metrics_history: dict with keys 'train_loss', 'val_loss', etc.
            'train_loss' entries are floats and may contain NaN;
            'val_loss' may contain None placeholders.
            # NOTE(review): assumes one entry per logged step — confirm
            # against the trainer's metrics logger.
    Returns:
        dict with 'status' (a STATUS_* label or 'INSUFFICIENT_DATA'),
        'severity' ('red'/'yellow'/'green'/'unknown'), 'details', and
        'recommended_levels' (framework levels to run next).
    """
    print(_header("Level 0: Training Status Diagnosis"))
    # Cross-entropy of a uniform predictor over V classes is ln(V).
    expected_initial = math.log(vocab_size)
    print(f" Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}")
    print(f" Normal convergence range (1B, 10B tokens):")
    print(f" Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}")
    print(f" Val Loss: {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}")
    print(f" Val PPL: {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}")
    raw_train_losses = metrics_history.get("train_loss", [])
    # Drop NaNs for trend math; their presence is remembered via
    # has_nan below and surfaced as STATUS_NAN_DETECTED (Check 3).
    train_losses = [l for l in raw_train_losses if not math.isnan(l)]
    # None means "not evaluated at this step". NaN val losses are NOT
    # filtered here — presumably val loss never goes NaN without the
    # train loss doing so first (TODO confirm).
    val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
    if len(train_losses) < 2:
        print("\n [!] Not enough training data to diagnose. Run more steps first.")
        return {
            "status": "INSUFFICIENT_DATA",
            "severity": "unknown",
            "details": "Need at least 2 logged train loss values.",
            "recommended_levels": [1],
        }
    # Detect NaN presence by comparing pre-/post-filter lengths.
    has_nan = len(train_losses) < len(raw_train_losses)
    if has_nan:
        nan_count = len(raw_train_losses) - len(train_losses)
        print(f"\n ⚠ {nan_count} NaN values detected in train_loss — filtered for analysis")
    first_loss = train_losses[0]
    last_loss = train_losses[-1]
    loss_change = first_loss - last_loss  # positive == loss decreased
    # Split into halves for trend analysis.
    # mid >= 1 here because len(train_losses) >= 2, so no zero division.
    mid = len(train_losses) // 2
    first_half_avg = sum(train_losses[:mid]) / mid
    second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid)
    # Recent window (last <=50 points) for spike/variance detection.
    recent_n = min(50, len(train_losses))
    recent = train_losses[-recent_n:]
    recent_mean = sum(recent) / len(recent)
    recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent)
    recent_std = recent_var ** 0.5
    # Val trend: compare first-half vs second-half averages with an
    # asymmetric dead zone (-0.05 / +0.1) to bias toward "flat".
    val_trend = "unknown"
    if len(val_losses) >= 2:
        val_mid = len(val_losses) // 2
        val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
        val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
        if val_second_avg < val_first_avg - 0.05:
            val_trend = "decreasing"
        elif val_second_avg > val_first_avg + 0.1:
            val_trend = "increasing"
        else:
            val_trend = "flat"
    # Pre-compute bounce detection using moving-average minimum
    # to avoid false positives from single noisy data points
    _ma_window = max(1, len(train_losses) // 20)  # 5% window
    _ma_losses = [
        sum(train_losses[max(0, i - _ma_window + 1):i + 1])
        / (i - max(0, i - _ma_window + 1) + 1)
        for i in range(len(train_losses))
    ]
    _min_ma_loss = min(_ma_losses)
    _min_ma_idx = _ma_losses.index(_min_ma_loss)
    _last_ma_loss = _ma_losses[-1]
    _bounce_amount = _last_ma_loss - _min_ma_loss
    # Bounce = net loss decrease overall, minimum reached before the
    # last 15% of training, and a >=5% relative rise off that minimum.
    _has_bounce = (
        loss_change > 0.1
        and _min_ma_idx < len(train_losses) * 0.85
        and _bounce_amount > _min_ma_loss * 0.05
    )
    # Downgrade bounce severity when val loss is still improving
    # (either trending down, or the latest val loss beats the entire
    # first half's minimum).
    _val_improving = (
        val_trend == "decreasing"
        or (len(val_losses) >= 4
            and val_losses[-1] <= min(val_losses[:len(val_losses) // 2]))
    )
    # ── Classify ──
    # Checks are ordered most-severe-first; the first match wins.
    status = STATUS_NORMAL
    severity = "green"
    details = ""
    recommended_levels: List[int] = []
    # Check 1: No decrease at all (and still near the random-init loss)
    if loss_change < 0.1 and first_loss > expected_initial - 2.0:
        status = STATUS_NO_DECREASE
        severity = "red"
        details = (
            f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} "
            f"(delta={loss_change:.4f}). Likely a data or implementation bug."
        )
        recommended_levels = [1, 2]
    # Check 2: Diverging — current loss worse than a random model
    elif last_loss > expected_initial + 1.0:
        status = STATUS_DIVERGING
        severity = "red"
        details = (
            f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). "
            f"Training is diverging — check LR, data, or numerical issues."
        )
        recommended_levels = [1, 2, 3]
    # Check 3: NaN detected in training loss
    elif has_nan:
        nan_count = len(raw_train_losses) - len(train_losses)
        # Safe: has_nan guarantees at least one NaN exists.
        nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l))
        status = STATUS_NAN_DETECTED
        severity = "red"
        details = (
            f"NaN detected in train_loss: {nan_count} NaN values "
            f"(first at step ~{nan_idx}). "
            f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. "
            f"Check gradient norms, LR schedule, and numerical precision."
        )
        recommended_levels = [2, 3]
    # Check 4: Unstable — recent std more than half the recent mean
    elif recent_std > 0.5 * recent_mean:
        status = STATUS_UNSTABLE
        severity = "yellow"
        details = (
            f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. "
            f"Training is unstable — likely LR too high or batch too small."
        )
        recommended_levels = [3, 2]
    # Check 5: Loss bounce (decreased then increased again)
    elif _has_bounce:
        status = STATUS_LOSS_BOUNCE
        if _val_improving:
            # Benign variant: val loss still improving, so the train
            # bounce is likely just data distribution variation.
            severity = "green"
            details = (
                f"Train loss bounced (moving-avg): "
                f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} "
                f"(bounce={_bounce_amount:.4f}), but val loss is still improving "
                f"({val_losses[0]:.4f} -> {val_losses[-1]:.4f}). "
                f"Likely data distribution variation, not a real issue."
            )
            recommended_levels = []
        else:
            severity = "yellow"
            details = (
                f"Train loss bounced (moving-avg): "
                f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} "
                f"(bounce={_bounce_amount:.4f}). "
                f"Possible LR too high, data issue, or overfitting."
            )
            recommended_levels = [3, 4]
    # Check 6: Overfitting — train improving while val worsens
    elif val_trend == "increasing" and second_half_avg < first_half_avg:
        status = STATUS_OVERFITTING
        severity = "yellow"
        details = (
            f"Train loss decreasing but val loss increasing. "
            f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, "
            f"Val trend: {val_trend}."
        )
        recommended_levels = [4]
    # Check 7: Plateau — flat halves while above the expected band
    elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
        status = STATUS_PLATEAU
        severity = "yellow"
        details = (
            f"Loss has plateaued: first half avg={first_half_avg:.4f}, "
            f"second half avg={second_half_avg:.4f}. "
            f"Current loss ({last_loss:.4f}) is above expected range."
        )
        recommended_levels = [3, 4, 5]
    # Normal — nothing above fired
    else:
        status = STATUS_NORMAL
        severity = "green"
        details = (
            f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} "
            f"(delta={loss_change:.4f}). Val trend: {val_trend}."
        )
        recommended_levels = []
    # ── Print ──
    icons = {"red": "🔴", "yellow": "🟡", "green": "🟢"}
    icon = icons.get(severity, "⚪")
    print(f"\n {icon} Status: {status}")
    print(f" {details}")
    if recommended_levels:
        print(f" Recommended: check Level(s) {recommended_levels}")
    return {
        "status": status,
        "severity": severity,
        "details": details,
        "recommended_levels": recommended_levels,
    }
# ───────────────────────────────────────────────────────────────
# Level 1: Data / Implementation Bug Checks
# ───────────────────────────────────────────────────────────────
@staticmethod
def check_data_pipeline(
    model: nn.Module,
    dataloader: DataLoader,
    tokenizer: Any,
    vocab_size: int,
    device: torch.device,
    dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
    """Run 6 data/implementation checks (Level 1).

    This is the most important level — 70% of loss issues are data bugs.

    Checks:
        1. Shift relationship (targets[t] == input_ids[t+1])
        2. Token range (0 <= ids < vocab_size)
        3. Initial loss (≈ ln(vocab_size) for random weights)
        4. Single-batch overfit (loss → ~0 in 200-400 steps)
        5. Tokenizer roundtrip (encode→decode preserves text)
        6. Data quality sampling (visual inspection)

    Args:
        model: LM whose forward is ``model(input_ids, targets) ->
            (logits, loss)``.
            # NOTE(review): Check 3 also calls ``model._init_weights()``
            # on a deep copy — confirm all model classes expose it.
        dataloader: yields dicts with 'input_ids' and 'targets' tensors.
        tokenizer: object providing ``encode(str)`` / ``decode(ids)``.
        vocab_size: vocabulary size for the range and initial-loss checks.
        device: device the forward/backward passes run on.
        dtype: autocast dtype; float32 disables autocast.

    Returns:
        dict with 'level', 'checks', plus 'passed'/'failed' partitions.
    """
    print(_header("Level 1: Data / Implementation Bug Checks"))
    print(" (70% of loss issues come from data pipeline bugs)\n")
    results: List[Dict[str, Any]] = []
    batch = next(iter(dataloader))
    input_ids = batch["input_ids"]
    targets = batch["targets"]
    # Fix: autocast previously hardcoded device_type="cuda", which
    # breaks when `device` is CPU (or MPS). Honor the caller's device.
    amp_device = device.type
    # ── Check 1: Shift relationship ──
    # For next-token prediction, targets must be input_ids shifted by 1.
    shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item()
    passed = shift_match > 0.99
    detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)"
    results.append(_check_result("Shift relationship", passed, detail))
    icon = "✅" if passed else "❌"
    print(f" {icon} Check 1: {detail}")
    # ── Check 2: Token range ──
    min_id = input_ids.min().item()
    max_id = input_ids.max().item()
    range_ok = min_id >= 0 and max_id < vocab_size
    detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}"
    results.append(_check_result("Token range", range_ok, detail))
    icon = "✅" if range_ok else "❌"
    print(f" {icon} Check 2: {detail}")
    # ── Check 3: Initial loss ──
    # A freshly initialized model should score ~ln(V); large deviations
    # point at label mismatch, loss bugs, or data leakage.
    expected_loss = math.log(vocab_size)
    model_copy = copy.deepcopy(model)
    model_copy._init_weights()  # re-initialize to random
    model_copy.to(device)
    model_copy.eval()
    with torch.no_grad():
        with torch.amp.autocast(device_type=amp_device, dtype=dtype, enabled=(dtype != torch.float32)):
            _, initial_loss = model_copy(
                input_ids.to(device),
                targets.to(device),
            )
    initial_loss_val = initial_loss.item()
    loss_diff = abs(initial_loss_val - expected_loss)
    loss_ok = loss_diff < 1.0
    detail = (
        f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} "
        f"(diff={loss_diff:.4f})"
    )
    results.append(_check_result("Initial loss", loss_ok, detail))
    icon = "✅" if loss_ok else "❌"
    print(f" {icon} Check 3: {detail}")
    if initial_loss_val > expected_loss + 1.0:
        print(f" Hint: loss >> ln(V) suggests label mismatch or loss function bug")
    elif initial_loss_val < expected_loss - 2.0:
        print(f" Hint: loss << ln(V) suggests data leakage")
    del model_copy
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # ── Check 4: Single-batch overfit test ──
    # A healthy model must be able to memorize one sample; failure here
    # means a model/loss bug, not a tuning problem.
    # Scale LR and steps based on model size to avoid instability.
    num_params = sum(p.numel() for p in model.parameters())
    if num_params > 500e6:
        overfit_lr, overfit_steps = 1e-4, 400
    elif num_params > 50e6:
        overfit_lr, overfit_steps = 3e-4, 300
    else:
        overfit_lr, overfit_steps = 1e-3, 200
    print(f"\n ⏳ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...")
    overfit_model = copy.deepcopy(model)
    overfit_model.to(device)
    overfit_model.train()
    overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr)
    single_input = input_ids[:1].to(device)  # single sample
    single_target = targets[:1].to(device)
    log_interval = max(overfit_steps // 4, 1)
    overfit_losses = []
    for step in range(overfit_steps):
        overfit_optimizer.zero_grad()
        with torch.amp.autocast(device_type=amp_device, dtype=dtype, enabled=(dtype != torch.float32)):
            _, loss = overfit_model(single_input, single_target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0)
        overfit_optimizer.step()
        overfit_losses.append(loss.item())
        if (step + 1) % log_interval == 0:
            print(f" Step {step + 1}: Loss = {loss.item():.4f}")
    final_overfit_loss = overfit_losses[-1]
    # Use the minimum, not the final value: the loss can bounce at the
    # end once the single batch is essentially memorized.
    min_overfit_loss = min(overfit_losses)
    overfit_ok = min_overfit_loss < 0.5
    detail = (
        f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
        f"(min={min_overfit_loss:.4f}, target < 0.5)"
    )
    results.append(_check_result("Single-batch overfit", overfit_ok, detail))
    icon = "✅" if overfit_ok else "❌"
    print(f" {icon} Check 4: {detail}")
    if not overfit_ok:
        print(f" CRITICAL: Model cannot memorize a single batch!")
        print(f" This means the model or loss function has a bug.")
    del overfit_model, overfit_optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # ── Check 5: Tokenizer roundtrip ──
    # Substring containment (not equality) tolerates added special tokens.
    test_text = "The quick brown fox jumps over the lazy dog."
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    roundtrip_ok = test_text.strip() in decoded.strip()
    detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'"
    results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail))
    icon = "✅" if roundtrip_ok else "❌"
    print(f" {icon} Check 5: {detail}")
    # ── Check 6: Data quality sampling ──
    # No pass/fail — prints decoded samples for human inspection.
    print(f"\n 📋 Check 6: Data quality sampling (visual inspection)")
    for i in range(min(3, input_ids.shape[0])):
        sample_tokens = input_ids[i][:100].tolist()
        decoded_text = tokenizer.decode(sample_tokens)
        preview = decoded_text[:200].replace("\n", "\\n")
        print(f" Sample {i}: {preview}...")
    passed_count = sum(1 for r in results if r["passed"])
    total_count = len(results)
    print(f"\n Result: {passed_count}/{total_count} checks passed")
    return {
        "level": 1,
        "checks": results,
        "passed": [r for r in results if r["passed"]],
        "failed": [r for r in results if not r["passed"]],
    }
# ───────────────────────────────────────────────────────────────
# Level 2: Numerical Stability
# ───────────────────────────────────────────────────────────────
@staticmethod
def check_numerical_stability(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
    """Check for NaN/Inf in gradients, activations, and logits (Level 2).

    Runs one forward+backward pass on a single batch with forward hooks
    on each transformer layer, then inspects gradients, per-layer
    activation statistics, activation growth, and logit scale.

    Checks:
        - Mixed precision config (Norm fp32 upcast, loss dtype advice)
        - NaN/Inf gradients → softmax overflow, bad data
        - Inf gradients → log(0) in loss, missing ignore_index
        - Large activations growing per layer → init or norm bug
        - Logit scale → should be < 1000

    Args:
        model: LM whose forward is ``model(input_ids, targets) ->
            (logits, loss)``.
            # NOTE(review): assumes ``model.layers`` is an iterable of
            # transformer blocks — confirm against the model class.
        dataloader: yields dicts with 'input_ids' and 'targets' tensors.
        device: device to run the pass on.
        dtype: autocast dtype; float16 additionally exercises GradScaler.

    Returns:
        dict with 'level', 'checks', 'activation_stats', 'grad_issues'.
    """
    import inspect  # hoisted here from the norm-inspection loop below
    print(_header("Level 2: Numerical Stability Checks"))
    batch = next(iter(dataloader))
    input_ids = batch["input_ids"].to(device)
    targets = batch["targets"].to(device)
    results: List[Dict[str, Any]] = []
    activation_stats: List[Dict[str, Any]] = []
    # Fix: autocast previously hardcoded device_type="cuda", breaking
    # CPU runs even though a `device` argument is taken.
    amp_device = device.type
    # ── Mixed Precision Configuration Check ──
    print("\n Mixed Precision Config:")
    print(f" Training dtype: {dtype}")
    # Check RMSNorm fp32 upcast by scanning the forward() source of each
    # distinct *Norm class for a .float()/float32 upcast.
    norm_fp32_ok = True
    checked_norm_classes: set = set()
    for name, module in model.named_modules():
        cls_name = module.__class__.__name__
        if "Norm" in cls_name and cls_name not in checked_norm_classes:
            checked_norm_classes.add(cls_name)
            try:
                src = inspect.getsource(type(module).forward)
                has_upcast = ".float()" in src or "float32" in src
            except (TypeError, OSError):
                has_upcast = True  # assume ok if can't inspect
            if not has_upcast:
                norm_fp32_ok = False
                print(f" 🔴 {cls_name}: no fp32 upcast detected!")
    if norm_fp32_ok:
        print(f" ✅ Norm layers use fp32 upcast (safe)")
    results.append(_check_result(
        "Norm fp32 upcast", norm_fp32_ok,
        "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype",
    ))
    # Check loss computation dtype (advice only, not a pass/fail check)
    if dtype in (torch.bfloat16, torch.float16):
        print(f" ℹ️ Best practice: compute loss in fp32 when using {dtype}")
        print(f" logits_fp32 = logits.float()")
        print(f" loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))")
    # Common numerical issues reference
    print("\n Common Numerical Issues Reference:")
    print(" ┌────────────────────────┬──────────────────────────┬─────────────────────────┐")
    print(" │ Symptom │ Likely Cause │ Solution │")
    print(" ├────────────────────────┼──────────────────────────┼─────────────────────────┤")
    print(" │ Loss → NaN │ Large logits → softmax │ Check init, logit scale │")
    print(" │ Loss → Inf │ log(0) in CE loss │ Add eps, ignore_index │")
    print(" │ Loss oscillation │ fp16 gradient underflow │ Switch to bf16 / scaler │")
    print(" │ Late-training NaN │ Activation growth │ Check RMSNorm, wd │")
    print(" └────────────────────────┴──────────────────────────┴─────────────────────────┘")
    # ── Activation monitoring via hooks ──
    hooks = []
    def make_hook(name: str):
        # Closure per layer: records mean/std/max and NaN/Inf flags.
        def hook_fn(module, input, output):
            if isinstance(output, torch.Tensor):
                out_f = output.float()
                stats = {
                    "name": name,
                    "mean": out_f.mean().item(),
                    "std": out_f.std().item(),
                    "max": out_f.abs().max().item(),
                    "has_nan": bool(torch.isnan(output).any()),
                    "has_inf": bool(torch.isinf(output).any()),
                }
                activation_stats.append(stats)
        return hook_fn
    # Register hooks on transformer layers
    for i, layer in enumerate(model.layers):
        h = layer.register_forward_hook(make_hook(f"layer_{i}"))
        hooks.append(h)
    # ── Forward + Backward ──
    model.train()
    model.zero_grad(set_to_none=True)
    use_scaler = dtype == torch.float16 and torch.cuda.is_available()
    scaler = torch.amp.GradScaler() if use_scaler else None
    with torch.amp.autocast(device_type=amp_device, dtype=dtype, enabled=(dtype != torch.float32)):
        logits, loss = model(input_ids, targets)
    loss_val = loss.item()
    loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val))
    results.append(_check_result(
        "Loss value",
        loss_ok,
        f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
    ))
    if scaler is not None:
        scaler.scale(loss).backward()
        # Throwaway lr=0 optimizer: unscale_ needs an optimizer handle
        # to divide the gradients in place so they can be inspected.
        _temp_opt = torch.optim.SGD(model.parameters(), lr=0)
        scaler.unscale_(_temp_opt)
    else:
        loss.backward()
    # Remove hooks so later passes are unaffected
    for h in hooks:
        h.remove()
    # ── Gradient checks ──
    print("\n Gradient Health:")
    grad_issues = []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        grad = param.grad
        if torch.isnan(grad).any():
            grad_issues.append(f"🔴 NaN gradient: {name}")
        if torch.isinf(grad).any():
            grad_issues.append(f"🔴 Inf gradient: {name}")
        if grad.abs().max().item() > 100:
            grad_issues.append(
                f"🟡 Large gradient: {name} max={grad.abs().max().item():.1f}"
            )
    grad_ok = len(grad_issues) == 0
    if grad_ok:
        print(" ✅ All gradients are healthy (no NaN/Inf/large values)")
    else:
        for issue in grad_issues[:10]:  # limit output
            print(f" {issue}")
        if len(grad_issues) > 10:
            print(f" ... and {len(grad_issues) - 10} more issues")
    results.append(_check_result(
        "Gradient health",
        grad_ok,
        f"{len(grad_issues)} issues found" if not grad_ok else "All healthy",
    ))
    # ── Activation checks ──
    print("\n Activation Stats (per transformer layer):")
    act_nan_count = 0
    for stats in activation_stats:
        icon = "🔴" if stats["has_nan"] or stats["has_inf"] else " "
        if stats["has_nan"] or stats["has_inf"]:
            act_nan_count += 1
        print(
            f" {icon} {stats['name']}: "
            f"mean={stats['mean']:.4f}, "
            f"std={stats['std']:.4f}, "
            f"max={stats['max']:.4f}"
            + (" [NaN!]" if stats["has_nan"] else "")
            + (" [Inf!]" if stats["has_inf"] else "")
        )
    act_ok = act_nan_count == 0
    results.append(_check_result(
        "Activation health",
        act_ok,
        f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
    ))
    # ── Activation growth trend ──
    # Growing std across depth hints at init/normalization problems.
    if len(activation_stats) >= 2:
        stds = [s["std"] for s in activation_stats]
        if stds[0] > 1e-8:
            growth_ratio = stds[-1] / stds[0]
            growth_ok = growth_ratio < 10
            detail = (
                f"Activation std ratio (last/first): {growth_ratio:.1f}x "
                f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})"
            )
            results.append(_check_result("Activation growth", growth_ok, detail))
            icon = "✅" if growth_ok else "🟡"
            print(f" {icon} {detail}")
            if not growth_ok:
                print(f" Possible initialization or normalization issue")
    # ── Logit scale check ──
    logit_max = logits.float().abs().max().item()
    logit_ok = logit_max < 1000
    detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)"
    results.append(_check_result("Logit scale", logit_ok, detail))
    icon = "✅" if logit_ok else "🔴"
    print(f"\n {icon} Logit scale: {detail}")
    # Leave the model with clean gradients for the real training loop.
    model.zero_grad(set_to_none=True)
    passed_count = sum(1 for r in results if r["passed"])
    print(f"\n Result: {passed_count}/{len(results)} checks passed")
    return {
        "level": 2,
        "checks": results,
        "activation_stats": activation_stats,
        "grad_issues": grad_issues,
    }
# ───────────────────────────────────────────────────────────────
# Level 3: Hyperparameter Diagnosis
# ───────────────────────────────────────────────────────────────
@staticmethod
def diagnose_hyperparameters(
    metrics_history: Dict[str, list],
    config: TrainConfig,
) -> Dict[str, Any]:
    """Analyze hyperparameter health from training metrics (Level 3).

    Purely offline: reads logged metrics and the config, prints an
    analysis, and returns structured findings. No model access needed.

    Checks:
        - LR: too high (grad_norm hitting clip limit) or too low
          (grad_norm tiny, suggesting vanishing gradients)
        - Batch size: loss variance indicates batch too small
        - β₂ and weight decay vs. published LLM-training values
        - Warmup: spikes in early steps indicate warmup too short

    Args:
        metrics_history: dict with 'grad_norm' and 'train_loss' lists
            (either may be missing or contain NaN entries).
        config: training config providing learning_rate,
            min_learning_rate, grad_clip, effective_batch_size, beta2,
            weight_decay, warmup_steps, total_steps.

    Returns:
        dict with 'level', 'findings' (issue/evidence/action records),
        and 'config_summary'.
    """
    print(_header("Level 3: Hyperparameter Diagnosis"))
    findings: List[Dict[str, str]] = []
    # Fix: filter NaN entries (consistent with diagnose_status). A
    # single NaN previously poisoned every average, making all the
    # threshold comparisons False and silently reporting "appropriate"
    # during a divergence.
    grad_norms = [g for g in metrics_history.get("grad_norm", []) if not math.isnan(g)]
    train_losses = [l for l in metrics_history.get("train_loss", []) if not math.isnan(l)]
    # ── LR diagnosis ──
    print("\n Learning Rate Analysis:")
    print(f" Peak LR: {config.learning_rate:.2e}")
    print(f" Min LR: {config.min_learning_rate:.2e}")
    if grad_norms:
        avg_grad = sum(grad_norms) / len(grad_norms)
        # Ref: PyTorch clip_grad_norm_ clips when total_norm > max_norm
        clip_count = sum(1 for g in grad_norms if g >= config.grad_clip)
        clip_rate = clip_count / len(grad_norms)
        # Relative threshold: < 1% of clip limit (model-size independent)
        tiny_threshold = config.grad_clip * 0.01
        tiny_count = sum(1 for g in grad_norms if g < tiny_threshold)
        tiny_rate = tiny_count / len(grad_norms)
        print(f" Avg grad norm: {avg_grad:.4f}")
        print(f" Clip rate: {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})")
        print(f" Tiny grad rate: {tiny_rate * 100:.1f}% (< {tiny_threshold:.4f})")
        # Heuristic: >50% clipping means most steps are capped, so the
        # effective LR is lower than configured. Practitioners generally
        # treat this as a sign that peak LR is too high.
        if clip_rate > 0.5:
            findings.append({
                "issue": "LR may be too high",
                "evidence": f"Grad norm hits clip limit {clip_rate * 100:.0f}% of the time",
                "action": f"Try LR = {config.learning_rate / 2:.2e} (÷2)",
            })
            print(f" 🟡 Grad clipping frequent ({clip_rate * 100:.0f}%) → LR may be too high")
        elif tiny_rate > 0.5:
            findings.append({
                "issue": "Possible vanishing gradients",
                "evidence": f"Grad norm < {tiny_threshold:.4f} in {tiny_rate * 100:.0f}% of steps",
                "action": "Check weight initialization, layer norms, and model depth",
            })
            print(f" 🟡 Grad norm too small ({tiny_rate * 100:.0f}% < {tiny_threshold:.4f}) → possible vanishing gradients")
        else:
            print(f" ✅ LR looks appropriate")
    # ── Batch size diagnosis ──
    print("\n Batch Size Analysis:")
    print(f" Effective batch: {config.effective_batch_size}")
    if len(train_losses) >= 50:
        recent_losses = train_losses[-50:]
        loss_mean = sum(recent_losses) / len(recent_losses)
        loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
        # Coefficient of variation: scale-free noise measure.
        loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)
        print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation, last 50 steps)")
        if loss_cv > 0.1:
            findings.append({
                "issue": "Training loss has high variance",
                "evidence": f"Loss CV = {loss_cv:.4f} over last 50 steps",
                "action": "Check: (1) LR may be too high, (2) increase gradient_accumulation_steps, (3) inspect data quality",
            })
            print(f" 🟡 High loss variance → check LR, batch size, or data quality")
        else:
            print(f" ✅ Loss variance is acceptable")
    # ── β₂ diagnosis ──
    print("\n β₂ (Adam second momentum) Analysis:")
    print(f" Current β₂: {config.beta2}")
    if config.beta2 >= _DEFAULT_PYTORCH_BETA2:
        findings.append({
            "issue": "β₂ may be too high for LLM training",
            "evidence": (
                f"β₂={config.beta2} (PyTorch default). "
                f"LLM standard is {_RECOMMENDED_BETA2}"
            ),
            "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)",
        })
        print(f" 🟡 β₂={config.beta2} is PyTorch default → "
              f"LLM training standard is {_RECOMMENDED_BETA2}")
        print(f" Why: β₂=0.999 averages ~1000 steps of gradient stats,")
        print(f" β₂=0.95 averages ~20 steps → faster adaptation to changing data")
        print(f" (Cattaneo & Shigida 2025, 'Tuning Adam(W)')")
    else:
        print(f" ✅ β₂={config.beta2} is within LLM standard range")
    # ── Weight Decay diagnosis ──
    print("\n Weight Decay Analysis:")
    print(f" Current weight_decay: {config.weight_decay}")
    if config.weight_decay == 0:
        findings.append({
            "issue": "Weight decay is disabled",
            "evidence": "weight_decay=0 increases overfitting risk",
            "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)",
        })
        print(f" 🟡 weight_decay=0 → overfitting risk. Standard is 0.1")
    elif config.weight_decay > 0.3:
        findings.append({
            "issue": "Weight decay may be too high",
            "evidence": f"weight_decay={config.weight_decay} (unusually high)",
            "action": "Try weight_decay=0.1 (standard value)",
        })
        print(f" 🟡 weight_decay={config.weight_decay} is unusually high (standard: 0.1)")
    else:
        print(f" ✅ weight_decay={config.weight_decay} is within normal range")
    # ── Model-size LR reference ──
    print("\n GPT-3 LR Reference (Brown et al. 2020):")
    print(" ┌──────────┬───────────┬──────────────┐")
    print(" │ Model │ Peak LR │ Batch Tokens │")
    print(" ├──────────┼───────────┼──────────────┤")
    for params, lr, batch_tok in _GPT3_LR_REFERENCE:
        label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M"
        # Arrow marks rows near our ~1.1B model size.
        marker = " ←" if abs(params - 1.1e9) < 0.5e9 else ""
        print(f" │ {label:<8} │ {lr:.1e} │ {batch_tok:<12} │{marker}")
    print(" └──────────┴───────────┴──────────────┘")
    print(" → Larger models need lower LR and larger batch")
    # ── Batch-LR scaling guidance ──
    print("\n Batch-LR Scaling Rules:")
    print(" • Batch ×2 → LR ×√2 (square root scaling, recommended for Adam)")
    print(" (Malladi et al. NeurIPS 2022, 'On the SDEs and Scaling Rules for Adaptive Gradient Algorithms')")
    print(" • Batch ×2 → LR ×2 (linear scaling, Goyal et al. 2017, mainly SGD)")
    print(" • 1B model: ~1K-2K sequences (~2-4M tokens) is typical")
    print(" (Pythia-1B: ~2M tokens, TinyLlama: ~2M, OLMo-1B: ~4M)")
    # ── Warmup diagnosis ──
    print("\n Warmup Analysis:")
    # Fix: guard against total_steps == 0 (ZeroDivisionError).
    print(f" Warmup steps: {config.warmup_steps} "
          f"({config.warmup_steps / max(config.total_steps, 1) * 100:.1f}% of total)")
    if len(train_losses) >= 10:
        early_losses = train_losses[:min(50, len(train_losses))]
        # Detect spikes in early training (>50% jump step-to-step)
        spike_count = 0
        for i in range(1, len(early_losses)):
            if early_losses[i] > early_losses[i - 1] * 1.5:
                spike_count += 1
        if spike_count > 3:
            findings.append({
                "issue": "Warmup may be too short",
                "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps",
                "action": f"Try warmup_steps = {config.warmup_steps * 2}",
            })
            print(f" 🟡 {spike_count} spikes in early training → warmup may be too short")
        else:
            print(f" ✅ Early training is stable")
    # ── Summary ──
    if not findings:
        print("\n ✅ No hyperparameter issues detected")
    else:
        print(f"\n Found {len(findings)} potential issue(s):")
        for f in findings:
            print(f" • {f['issue']}: {f['action']}")
    # ── Warmup reference from real projects ──
    print("\n Warmup Reference (real projects):")
    print(" • TinyLlama 1.1B (3T tokens): 2,000 steps ≈ 0.1% of total")
    print(" • GPT-3 175B: 375M warmup tokens ≈ 117 steps")
    print(" • General range: 0.1% ~ 5% of total steps")
    print(" • Smaller experiments: 5~10% is also reasonable")
    print("\n Tuning priority (high → low):")
    print(" 1. Learning Rate ← tune first (10x impact)")
    print(" 2. Batch Size ← adjust with LR")
    print(" 3. Warmup Steps ← early stability")
    print(" 4. Weight Decay ← if overfitting (typically 0.1)")
    print(" 5. β₁, β₂ (Adam) ← see β₂ analysis above")
    print(" 6. Gradient Clip ← usually keep at 1.0")
    return {
        "level": 3,
        "findings": findings,
        "config_summary": {
            "learning_rate": config.learning_rate,
            "effective_batch": config.effective_batch_size,
            "warmup_steps": config.warmup_steps,
            "total_steps": config.total_steps,
            "grad_clip": config.grad_clip,
        },
    }
@staticmethod
def lr_range_test(
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
lr_start: float = 1e-7,
lr_end: float = 1e-1,
steps: int = 300,
) -> Dict[str, Any]:
"""Run an LR range test to find the optimal learning rate (Level 3 bonus).
Sweeps LR from lr_start to lr_end exponentially, recording loss.
The optimal LR is where loss decreases fastest (steepest slope),
divided by 3~10 for stability.
WARNING: This modifies a copy of the model. The original is untouched.
"""
print(_header("Level 3 Bonus: LR Range Test"))
print(f" Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n")
test_model = copy.deepcopy(model)
test_model.to(device)
test_model.train()
optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start)
lr_mult = (lr_end / lr_start) ** (1 / steps)
lr = lr_start
lrs: List[float] = []
losses: List[float] = []
data_iter = iter(dataloader)
for step in range(steps):
for pg in optimizer.param_groups:
pg["lr"] = lr
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(dataloader)
batch = next(data_iter)
input_ids = batch["input_ids"].to(device)
targets_t = batch["targets"].to(device)
optimizer.zero_grad()
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
_, loss = test_model(input_ids, targets_t)
loss.backward()
optimizer.step()
loss_val = loss.item()
lrs.append(lr)
losses.append(loss_val)
if (step + 1) % 50 == 0:
print(f" Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}")
# Stop if loss explodes
if len(losses) > 1 and loss_val > losses[0] * 4:
print(f" Loss exploded at LR = {lr:.2e}, stopping.")
break
lr *= lr_mult
del test_model, optimizer
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Find steepest descent
best_lr = lr_start
if len(losses) > 10:
# Smooth losses and find steepest negative slope
window = 5
smoothed = []
for i in range(len(losses) - window):
smoothed.append(sum(losses[i:i + window]) / window)
min_slope = 0
min_idx = 0
for i in range(1, len(smoothed)):
slope = smoothed[i] - smoothed[i - 1]
if slope < min_slope:
min_slope = slope
min_idx = i
best_lr = lrs[min_idx]
suggested_lr = best_lr / 3 # conservative choice
print(f"\n Steepest descent at LR = {best_lr:.2e}")
print(f" Suggested peak LR: {suggested_lr:.2e} (Γ·3 for stability)")
print(f" Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]")
else:
suggested_lr = 3e-4
print(f"\n Not enough data points. Using default LR = {suggested_lr:.2e}")
return {
"lrs": lrs,
"losses": losses,
"best_lr": best_lr,
"suggested_lr": suggested_lr,
}
# ───────────────────────────────────────────────────────────────
# Level 4: Overfitting vs Underfitting Diagnosis
# ───────────────────────────────────────────────────────────────
@staticmethod
def diagnose_fitting(
metrics_history: Dict[str, list],
model_params: Optional[int] = None,
total_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""Diagnose overfitting vs underfitting from metrics (Level 4).
Cases:
1. Both high, decreasing β†’ Normal (still training)
2. Both high, plateau β†’ Underfitting
3. Train↓ Valβ†’ or Val↑ β†’ Overfitting
4. Both low, plateau β†’ Converged (or at limit)
"""
print(_header("Level 4: Overfitting vs Underfitting Diagnosis"))
train_losses = metrics_history.get("train_loss", [])
val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]
if len(train_losses) < 10 or len(val_losses) < 2:
print(" [!] Not enough data. Need more training steps with eval.")
return {"level": 4, "case": "insufficient_data", "recommendations": []}
# Recent train trend
recent_n = min(50, len(train_losses))
train_recent = train_losses[-recent_n:]
train_mid = len(train_recent) // 2
train_first = sum(train_recent[:train_mid]) / max(train_mid, 1)
train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1)
train_decreasing = train_second < train_first - 0.02
# Val trend
val_mid = len(val_losses) // 2
val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
val_decreasing = val_second < val_first - 0.02
val_increasing = val_second > val_first + 0.05
# Train-Val gap
last_train = train_losses[-1]
last_val = val_losses[-1]
gap = last_train - last_val # negative means val > train (typical)
print(f" Train loss (recent): {train_first:.4f} β†’ {train_second:.4f} "
f"({'↓' if train_decreasing else 'β†’'})")
print(f" Val loss: {val_first:.4f} β†’ {val_second:.4f} "
f"({'↓' if val_decreasing else '↑' if val_increasing else 'β†’'})")
print(f" Train-Val gap: {abs(gap):.4f}")
# ── Classify ──
case = ""
recommendations: List[str] = []
if train_decreasing and val_decreasing:
case = "Case 1: Normal β€” both decreasing"
recommendations.append("Training is progressing normally. Continue.")
if model_params and total_tokens:
ratio = total_tokens / model_params
chinchilla = 20 # Chinchilla optimal: 20 tokens per param
if ratio < chinchilla:
recommendations.append(
f"Token/param ratio = {ratio:.1f}x "
f"(Chinchilla optimal β‰ˆ {chinchilla}x). "
f"Model may benefit from more data."
)
print(f"\n 🟒 {case}")
elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]:
case = "Case 2: Underfitting β€” both plateaued at high loss"
recommendations = [
"Diagnosis priority (check in order):",
"1) Training insufficient? β†’ check if loss curve still has downward slope",
" - Chinchilla: 1B model needs ~20B tokens minimum",
" - TinyLlama trains 1.1B on 3T tokens (inference-optimal)",
"2) LR too low? β†’ try LR Γ—2, see if loss drops faster",
"3) Model capacity too small? β†’ train 2x larger model on same data",
" - If larger model gets lower loss β†’ capacity was the limit",
"4) Data quality? β†’ sample and read training data manually",
" - Noisy/low-quality data raises the achievable loss floor",
]
if model_params and total_tokens:
ratio = total_tokens / model_params
if ratio < 10:
recommendations.insert(0,
f"⚠ Token/param ratio = {ratio:.1f}x β€” "
f"very likely undertrained. Chinchilla recommends β‰₯20x."
)
elif ratio < 20:
recommendations.insert(0,
f"β„Ή Token/param ratio = {ratio:.1f}x β€” "
f"below Chinchilla optimal (20x). More tokens may help."
)
print(f"\n 🟑 {case}")
elif train_decreasing and (val_increasing or not val_decreasing):
case = "Case 3: Overfitting β€” train↓ but valβ†’/↑"
recommendations = [
"Diagnosis priority (check in order):",
"1) Data repetition? (most common cause in pretraining)",
" - Check: total tokens vs unique tokens",
" - Epoch > 1 dramatically increases overfitting risk",
" - Solution: add more data, stay within 1 epoch",
"2) Weight decay too low?",
" - Check: weight_decay value (standard: 0.1)",
" - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1",
" - Experiment: 0.01 / 0.05 / 0.1 / 0.3",
"3) Data diversity?",
" - Single-domain data overfits faster",
" - Mix: web, books, code, wiki, etc.",
"",
"Note on Dropout in LLM pretraining:",
" - Modern LLMs do NOT use dropout in pretraining",
" (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)",
" - Sufficient data is the best regularization",
" - Dropout is useful for fine-tuning on small datasets",
]
print(f"\n 🟑 {case}")
else:
case = "Case 4: Converged β€” loss is low and stable"
recommendations = [
"Training has converged (or reached the data/model limit).",
"To push further: add more data or increase model size.",
]
print(f"\n 🟒 {case}")
for rec in recommendations:
print(f" {rec}")
return {
"level": 4,
"case": case,
"train_trend": "decreasing" if train_decreasing else "flat",
"val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"),
"gap": abs(gap),
"recommendations": recommendations,
}
# ───────────────────────────────────────────────────────────────
# Level 5: Architecture Checks
# ───────────────────────────────────────────────────────────────
@staticmethod
def check_architecture(
model: nn.Module,
dataloader: DataLoader,
device: torch.device,
) -> Dict[str, Any]:
"""Check weight initialization and per-layer activation health (Level 5).
Healthy initialization:
- All layers: std β‰ˆ 1.0, mean β‰ˆ 0.0
Problems:
- std increasing per layer β†’ activation explosion (init scale too large)
- std decreasing per layer β†’ activation vanishing (init scale too small)
- Sudden change at specific layer β†’ implementation bug in that layer
"""
print(_header("Level 5: Architecture / Initialization Check"))
batch = next(iter(dataloader))
sample_input = batch["input_ids"][:1].to(device)
model.eval()
layer_stats: List[Dict[str, Any]] = []
with torch.no_grad():
h = model.token_embedding(sample_input)
emb_std = h.float().std().item()
print(f"\n Embedding: std={emb_std:.4f}")
for i, layer in enumerate(model.layers):
h = layer(h, mask=None, position_offset=0)
h_f = h.float()
stats = {
"layer": i,
"mean": h_f.mean().item(),
"std": h_f.std().item(),
"max": h_f.abs().max().item(),
}
layer_stats.append(stats)
# Print stats
print(f"\n Layer-by-layer activation statistics:")
print(f" {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}")
print(f" {'-' * 38}")
for s in layer_stats:
print(f" {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}")
# ── Weight initialization distribution check ──
print(f"\n Weight Initialization Distribution:")
print(f" {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}")
print(f" {'-' * 75}")
weight_issues = []
for name, param in model.named_parameters():
if param.ndim < 2:
continue # skip biases, norm weights
p_f = param.float()
p_mean = p_f.mean().item()
p_std = p_f.std().item()
# Expected: std β‰ˆ 0.02 for most layers, smaller for residual projections
shape_str = str(list(param.shape))
is_residual = "o_proj" in name or "down_proj" in name
expected_std = 0.02 # GPT-2 style
if p_std > expected_std * 5:
weight_issues.append(f"Large std: {name} (std={p_std:.4f})")
print(f" 🟑 {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
elif p_std < expected_std * 0.1:
weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})")
print(f" 🟑 {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}")
else:
print(f" {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
if weight_issues:
print(f"\n ⚠ {len(weight_issues)} weight distribution issue(s) found")
for issue in weight_issues[:5]:
print(f" β€’ {issue}")
else:
print(f"\n βœ… All weight distributions look normal (std β‰ˆ 0.02)")
print(f"\n Expected init pattern:")
print(f" β€’ General Linear: N(0, 0.02)")
print(f" β€’ Residual proj (o_proj, down_proj): N(0, 0.02/√(2Γ—layers))")
print(f" β€’ Embedding: N(0, 0.02)")
# ── Ablation study guidance ──
print(f"\n Component Ablation Reference:")
print(" β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”")
print(" β”‚ Experiment β”‚ Expected Outcome β”‚")
print(" β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€")
print(" β”‚ RMSNorm β†’ LayerNorm β”‚ Minimal loss diff β†’ OK β”‚")
print(" β”‚ RoPE β†’ Absolute PE β”‚ Similar on short seq (<512) β”‚")
print(" β”‚ SwiGLU β†’ ReLU FFN β”‚ Loss +0.05~0.15 β†’ SwiGLU working β”‚")
print(" β”‚ GQA β†’ MHA β”‚ Same loss, less memory β†’ OK β”‚")
print(" β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜")
print(" If any replacement shows unexpected results, check that component.")
# Analyze trends
stds = [s["std"] for s in layer_stats]
diagnosis = "healthy"
detail = ""
if len(stds) >= 3:
# Check for monotonic increase/decrease
first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3)
last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3)
ratio = last_third / max(first_third, 1e-8)
if ratio > 5:
diagnosis = "exploding"
detail = (
f"Activation std grows {ratio:.1f}x from early to late layers. "
f"Init scale may be too large."
)
elif ratio < 0.2:
diagnosis = "vanishing"
detail = (
f"Activation std shrinks to {ratio:.1f}x from early to late layers. "
f"Init scale may be too small."
)
else:
detail = f"Std ratio (last/first third) = {ratio:.2f} β€” within normal range."
# Check for sudden jumps
for i in range(1, len(stds)):
jump = stds[i] / max(stds[i - 1], 1e-8)
if jump > 10 or jump < 0.1:
diagnosis = "anomaly"
detail = (
f"Sudden activation change at layer {i}: "
f"std {stds[i - 1]:.4f} β†’ {stds[i]:.4f}. "
f"Possible implementation bug in that layer."
)
break
icon = {"healthy": "βœ…", "exploding": "πŸ”΄", "vanishing": "🟑", "anomaly": "πŸ”΄"}
print(f"\n {icon.get(diagnosis, 'βšͺ')} Diagnosis: {diagnosis}")
print(f" {detail}")
return {
"level": 5,
"diagnosis": diagnosis,
"detail": detail,
"layer_stats": layer_stats,
"weight_issues": weight_issues,
}
# ───────────────────────────────────────────────────────────────
# Main Entry Point
# ───────────────────────────────────────────────────────────────
@staticmethod
def run_diagnostics(
model: nn.Module,
dataloader: DataLoader,
tokenizer: Any,
train_config: TrainConfig,
metrics_history: Dict[str, list],
device: torch.device,
dtype: torch.dtype = torch.bfloat16,
vocab_size: int = 32000,
levels: Optional[List[int]] = None,
) -> Dict[str, Any]:
"""Run the full 5-level debugging framework.
Args:
model: the LLM model
dataloader: training dataloader
tokenizer: tokenizer with encode/decode methods
train_config: TrainConfig instance
metrics_history: dict from MetricsTracker.history
device: torch device
dtype: mixed precision dtype
vocab_size: model vocabulary size
levels: which levels to run (default: all [0,1,2,3,4,5])
Returns:
Full diagnostic report dict.
"""
if levels is None:
levels = [0, 1, 2, 3, 4, 5]
print("\n" + "═" * 60)
print(" LLM Loss Debugging Framework")
print(" Levels to run: " + ", ".join(str(l) for l in levels))
print("═" * 60)
report: Dict[str, Any] = {}
if 0 in levels:
report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history)
# If status is normal and only level 0 was explicitly requested, skip rest
if (
report["level_0"]["status"] == STATUS_NORMAL
and levels == [0]
):
print("\n Training is healthy β€” no further debugging needed.")
return report
if 1 in levels:
report["level_1"] = LossDebugger.check_data_pipeline(
model, dataloader, tokenizer, vocab_size, device, dtype,
)
if 2 in levels:
report["level_2"] = LossDebugger.check_numerical_stability(
model, dataloader, device, dtype,
)
if 3 in levels:
report["level_3"] = LossDebugger.diagnose_hyperparameters(
metrics_history, train_config,
)
if 4 in levels:
model_params = sum(p.numel() for p in model.parameters())
total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step
report["level_4"] = LossDebugger.diagnose_fitting(
metrics_history, model_params, total_tokens,
)
if 5 in levels:
report["level_5"] = LossDebugger.check_architecture(
model, dataloader, device,
)
# Final summary
print("\n" + "═" * 60)
print(" Diagnostics Complete")
print("═" * 60)
return report
# ───────────────────────────────────────────────────────────────
# Study Roadmap
# ───────────────────────────────────────────────────────────────
    @staticmethod
    def print_study_roadmap() -> None:
        """Print the recommended study roadmap for LLM training optimization.

        Pure console output: one static, curated reading/experiment list.
        Takes no arguments, returns None, and has no other side effects.
        """
        print(_header("Study Roadmap — LLM Training Optimization"))
        # The roadmap text below is emitted verbatim; edit with care.
        print("""
 ⭐⭐⭐ Top Priority: Optimization Fundamentals
 ─────────────────────────────────────────────
 1. SGD → Momentum → Adam → AdamW progression
    - Why Adam > SGD? Why decouple weight decay in AdamW?
    - β₁, β₂ intuition (1st / 2nd momentum)
    - Ref: Loshchilov & Hutter 2019 (AdamW)
    - Ref: Karpathy "A Recipe for Training Neural Networks"
 2. Loss Landscape
    - Why large LR diverges, small LR stalls
    - Batch size effect on landscape exploration
    - Ref: Li et al. 2018 "Visualizing the Loss Landscape"
    - Ref: McCandlish et al. 2018 "Large-Batch Training"
 3. Chinchilla Scaling Law
    - Loss = f(N, D) relationship
    - Compute-optimal model size vs data allocation
    - Ref: Hoffmann et al. 2022 (original)
    - Ref: Kaplan et al. 2020 (predecessor)
    - Ref: Besiroglu et al. 2024 (replication/verification)
 ⭐⭐ Important: Training Stability
 ──────────────────────────────────
 4. Gradient Flow: vanishing/exploding, residual as gradient highway
 5. Weight Init: Xavier / Kaiming / GPT-2 style
 6. Normalization: BatchNorm → LayerNorm → RMSNorm
 7. Weight Decay: L2 vs decoupled, why exclude embed/norm
 ⭐ Advanced: Optimization Techniques
 ─────────────────────────────────────
 8. LR Schedules: cosine vs linear vs step, warmup/cooldown
 9. Gradient Accumulation & Large Batch Training
 10. μP (Maximal Update Parameterization): transfer HP across scales
 Recommended Experiments (in order):
 ───────────────────────────────────
 1. Single-batch overfit (30 min) → basic sanity
 2. LR Range Test (1 hour) → optimal LR range
 3. 10M model quick train (2-3 hrs) → pipeline validation
 4. Ablation (remove components) (1 day) → component contribution
 5. 100M model + 5B tokens (1-2 days)→ mid-scale dynamics
 6. 1B model full training (2-3 days)→ scaling law verification
 7. LR / batch size comparison (1 day) → HP sensitivity
 Key References:
 ───────────────
 ⭐⭐⭐ Karpathy "Recipe for Training NNs" — debugging mindset
 ⭐⭐⭐ Hoffmann et al. 2022 (Chinchilla) — scaling law
 ⭐⭐ Touvron et al. 2023 (LLaMA) — 1B+ training details
 ⭐⭐ Biderman et al. 2023 (Pythia) — open training logs
 ⭐⭐ Zhang et al. 2024 (TinyLlama) — 1.1B on 3T tokens
 ⭐⭐ Groeneveld et al. 2024 (OLMo) — fully open LLM
 ⭐⭐ Li et al. 2018 (Loss Landscape) — loss terrain intuition
 ⭐⭐ Loshchilov & Hutter 2019 (AdamW) — optimizer basics
 ⭐ Yang et al. 2022 (μP) — HP transfer
 ⭐ McCandlish et al. 2018 (Batch size) — critical batch size
 """)