| """LLM Loss Debugging & Optimization Framework. |
| |
| A systematic 5-level debugging framework for diagnosing training issues. |
| Always start from Level 1 β fixing lower-level bugs before tuning |
| hyperparameters saves time. |
| |
| Levels: |
| 0. Status Diagnosis β classify current training health |
| 1. Data/Implementation β most common cause (70% of issues) |
| 2. Numerical Stability β dtype, normalization, gradient health |
| 3. Hyperparameters β LR, batch size, warmup |
| 4. Fitting Diagnosis β overfitting vs underfitting |
| 5. Architecture β initialization, component checks |
| """ |
|
|
| import copy |
| import math |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from llm_lab.config import TrainConfig |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| _EXPECTED_TRAIN_LOSS = (2.5, 3.3) |
| _EXPECTED_VAL_LOSS = (2.7, 3.6) |
| _EXPECTED_VAL_PPL = (15, 37) |
|
|
| |
| STATUS_NORMAL = "NORMAL" |
| STATUS_NO_DECREASE = "NO_DECREASE" |
| STATUS_DIVERGING = "DIVERGING" |
| STATUS_PLATEAU = "PLATEAU" |
| STATUS_OVERFITTING = "OVERFITTING" |
| STATUS_UNSTABLE = "UNSTABLE" |
| STATUS_NAN_DETECTED = "NAN_DETECTED" |
| STATUS_LOSS_BOUNCE = "LOSS_BOUNCE" |
|
|
| |
| |
| _GPT3_LR_REFERENCE = [ |
| (125e6, 6e-4, "0.5M"), |
| (350e6, 3e-4, "0.5M"), |
| (760e6, 2.5e-4, "0.5M"), |
| (1.3e9, 2e-4, "1M"), |
| (2.7e9, 1.6e-4, "1M"), |
| (6.7e9, 1.2e-4, "2M"), |
| (13e9, 1e-4, "2M"), |
| (175e9, 6e-5, "3.2M"), |
| ] |
|
|
| |
| _LLM_TRAINING_REFS = { |
| "TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000}, |
| "LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000}, |
| "Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.01}, |
| "OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1}, |
| } |
|
|
| |
| _RECOMMENDED_BETA2 = 0.95 |
| _DEFAULT_PYTORCH_BETA2 = 0.999 |
|
|
|
|
| def _header(title: str) -> str: |
| return f"\n{'=' * 60}\n{title}\n{'=' * 60}" |
|
|
|
|
| def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]: |
| return {"name": name, "passed": passed, "detail": detail} |
|
|
|
|
| |
| |
| |
|
|
|
|
| class LossDebugger: |
| """5-level loss debugging framework for LLM training. |
| |
| Usage:: |
| |
| from llm_lab.training.debugger import LossDebugger |
| |
| # Quick status check |
| status = LossDebugger.diagnose_status(vocab_size=32000, |
| metrics_history=trainer.metrics.history) |
| |
| # Full diagnostics |
| report = LossDebugger.run_diagnostics( |
| model=model, dataloader=train_dl, tokenizer=tok, |
| train_config=train_cfg, metrics_history=trainer.metrics.history, |
| device=device, dtype=torch.bfloat16, |
| ) |
| """ |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def diagnose_status( |
| vocab_size: int, |
| metrics_history: Dict[str, list], |
| ) -> Dict[str, Any]: |
| """Classify current training health from metrics history. |
| |
| Args: |
| vocab_size: model vocabulary size (e.g. 32000) |
| metrics_history: dict with keys 'train_loss', 'val_loss', etc. |
| |
| Returns: |
| dict with 'status', 'severity', 'details', 'recommended_levels' |
| """ |
| print(_header("Level 0: Training Status Diagnosis")) |
|
|
| expected_initial = math.log(vocab_size) |
| print(f" Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}") |
| print(f" Normal convergence range (1B, 10B tokens):") |
| print(f" Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}") |
| print(f" Val Loss: {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}") |
| print(f" Val PPL: {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}") |
|
|
| raw_train_losses = metrics_history.get("train_loss", []) |
| train_losses = [l for l in raw_train_losses if not math.isnan(l)] |
| val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None] |
|
|
| if len(train_losses) < 2: |
| print("\n [!] Not enough training data to diagnose. Run more steps first.") |
| return { |
| "status": "INSUFFICIENT_DATA", |
| "severity": "unknown", |
| "details": "Need at least 2 logged train loss values.", |
| "recommended_levels": [1], |
| } |
|
|
| |
| has_nan = len(train_losses) < len(raw_train_losses) |
| if has_nan: |
| nan_count = len(raw_train_losses) - len(train_losses) |
| print(f"\n β {nan_count} NaN values detected in train_loss β filtered for analysis") |
|
|
| first_loss = train_losses[0] |
| last_loss = train_losses[-1] |
| loss_change = first_loss - last_loss |
|
|
| |
| mid = len(train_losses) // 2 |
| first_half_avg = sum(train_losses[:mid]) / mid |
| second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid) |
|
|
| |
| recent_n = min(50, len(train_losses)) |
| recent = train_losses[-recent_n:] |
| recent_mean = sum(recent) / len(recent) |
| recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent) |
| recent_std = recent_var ** 0.5 |
|
|
| |
| val_trend = "unknown" |
| if len(val_losses) >= 2: |
| val_mid = len(val_losses) // 2 |
| val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1) |
| val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1) |
| if val_second_avg < val_first_avg - 0.05: |
| val_trend = "decreasing" |
| elif val_second_avg > val_first_avg + 0.1: |
| val_trend = "increasing" |
| else: |
| val_trend = "flat" |
|
|
| |
| |
| _ma_window = max(1, len(train_losses) // 20) |
| _ma_losses = [ |
| sum(train_losses[max(0, i - _ma_window + 1):i + 1]) |
| / (i - max(0, i - _ma_window + 1) + 1) |
| for i in range(len(train_losses)) |
| ] |
| _min_ma_loss = min(_ma_losses) |
| _min_ma_idx = _ma_losses.index(_min_ma_loss) |
| _last_ma_loss = _ma_losses[-1] |
| _bounce_amount = _last_ma_loss - _min_ma_loss |
| _has_bounce = ( |
| loss_change > 0.1 |
| and _min_ma_idx < len(train_losses) * 0.85 |
| and _bounce_amount > _min_ma_loss * 0.05 |
| ) |
| |
| _val_improving = ( |
| val_trend == "decreasing" |
| or (len(val_losses) >= 4 |
| and val_losses[-1] <= min(val_losses[:len(val_losses) // 2])) |
| ) |
|
|
| |
| status = STATUS_NORMAL |
| severity = "green" |
| details = "" |
| recommended_levels: List[int] = [] |
|
|
| |
| if loss_change < 0.1 and first_loss > expected_initial - 2.0: |
| status = STATUS_NO_DECREASE |
| severity = "red" |
| details = ( |
| f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} " |
| f"(delta={loss_change:.4f}). Likely a data or implementation bug." |
| ) |
| recommended_levels = [1, 2] |
|
|
| |
| elif last_loss > expected_initial + 1.0: |
| status = STATUS_DIVERGING |
| severity = "red" |
| details = ( |
| f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). " |
| f"Training is diverging β check LR, data, or numerical issues." |
| ) |
| recommended_levels = [1, 2, 3] |
|
|
| |
| elif has_nan: |
| nan_count = len(raw_train_losses) - len(train_losses) |
| nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l)) |
| status = STATUS_NAN_DETECTED |
| severity = "red" |
| details = ( |
| f"NaN detected in train_loss: {nan_count} NaN values " |
| f"(first at step ~{nan_idx}). " |
| f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. " |
| f"Check gradient norms, LR schedule, and numerical precision." |
| ) |
| recommended_levels = [2, 3] |
|
|
| |
| elif recent_std > 0.5 * recent_mean: |
| status = STATUS_UNSTABLE |
| severity = "yellow" |
| details = ( |
| f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. " |
| f"Training is unstable β likely LR too high or batch too small." |
| ) |
| recommended_levels = [3, 2] |
|
|
| |
| elif _has_bounce: |
| status = STATUS_LOSS_BOUNCE |
| if _val_improving: |
| severity = "green" |
| details = ( |
| f"Train loss bounced (moving-avg): " |
| f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} " |
| f"(bounce={_bounce_amount:.4f}), but val loss is still improving " |
| f"({val_losses[0]:.4f} -> {val_losses[-1]:.4f}). " |
| f"Likely data distribution variation, not a real issue." |
| ) |
| recommended_levels = [] |
| else: |
| severity = "yellow" |
| details = ( |
| f"Train loss bounced (moving-avg): " |
| f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} " |
| f"(bounce={_bounce_amount:.4f}). " |
| f"Possible LR too high, data issue, or overfitting." |
| ) |
| recommended_levels = [3, 4] |
|
|
| |
| elif val_trend == "increasing" and second_half_avg < first_half_avg: |
| status = STATUS_OVERFITTING |
| severity = "yellow" |
| details = ( |
| f"Train loss decreasing but val loss increasing. " |
| f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, " |
| f"Val trend: {val_trend}." |
| ) |
| recommended_levels = [4] |
|
|
| |
| elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]: |
| status = STATUS_PLATEAU |
| severity = "yellow" |
| details = ( |
| f"Loss has plateaued: first half avg={first_half_avg:.4f}, " |
| f"second half avg={second_half_avg:.4f}. " |
| f"Current loss ({last_loss:.4f}) is above expected range." |
| ) |
| recommended_levels = [3, 4, 5] |
|
|
| |
| else: |
| status = STATUS_NORMAL |
| severity = "green" |
| details = ( |
| f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} " |
| f"(delta={loss_change:.4f}). Val trend: {val_trend}." |
| ) |
| recommended_levels = [] |
|
|
| |
| icons = {"red": "π΄", "yellow": "π‘", "green": "π’"} |
| icon = icons.get(severity, "βͺ") |
| print(f"\n {icon} Status: {status}") |
| print(f" {details}") |
| if recommended_levels: |
| print(f" Recommended: check Level(s) {recommended_levels}") |
|
|
| return { |
| "status": status, |
| "severity": severity, |
| "details": details, |
| "recommended_levels": recommended_levels, |
| } |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def check_data_pipeline( |
| model: nn.Module, |
| dataloader: DataLoader, |
| tokenizer: Any, |
| vocab_size: int, |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| ) -> Dict[str, Any]: |
| """Run 6 data/implementation checks (Level 1). |
| |
| This is the most important level β 70% of loss issues are data bugs. |
| |
| Checks: |
| 1. Shift relationship (targets[t] == input_ids[t+1]) |
| 2. Token range (0 <= ids < vocab_size) |
| 3. Initial loss (β ln(vocab_size) for random weights) |
| 4. Single-batch overfit (loss β ~0 in 200 steps) |
| 5. Tokenizer roundtrip (encodeβdecode preserves text) |
| 6. Data quality sampling (visual inspection) |
| """ |
| print(_header("Level 1: Data / Implementation Bug Checks")) |
| print(" (70% of loss issues come from data pipeline bugs)\n") |
|
|
| results: List[Dict[str, Any]] = [] |
| batch = next(iter(dataloader)) |
| input_ids = batch["input_ids"] |
| targets = batch["targets"] |
|
|
| |
| shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item() |
| passed = shift_match > 0.99 |
| detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)" |
| results.append(_check_result("Shift relationship", passed, detail)) |
| icon = "β
" if passed else "β" |
| print(f" {icon} Check 1: {detail}") |
|
|
| |
| min_id = input_ids.min().item() |
| max_id = input_ids.max().item() |
| range_ok = min_id >= 0 and max_id < vocab_size |
| detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}" |
| results.append(_check_result("Token range", range_ok, detail)) |
| icon = "β
" if range_ok else "β" |
| print(f" {icon} Check 2: {detail}") |
|
|
| |
| expected_loss = math.log(vocab_size) |
| model_copy = copy.deepcopy(model) |
| model_copy._init_weights() |
| model_copy.to(device) |
| model_copy.eval() |
| with torch.no_grad(): |
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| _, initial_loss = model_copy( |
| input_ids.to(device), |
| targets.to(device), |
| ) |
| initial_loss_val = initial_loss.item() |
| loss_diff = abs(initial_loss_val - expected_loss) |
| loss_ok = loss_diff < 1.0 |
| detail = ( |
| f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} " |
| f"(diff={loss_diff:.4f})" |
| ) |
| results.append(_check_result("Initial loss", loss_ok, detail)) |
| icon = "β
" if loss_ok else "β" |
| print(f" {icon} Check 3: {detail}") |
| if initial_loss_val > expected_loss + 1.0: |
| print(f" Hint: loss >> ln(V) suggests label mismatch or loss function bug") |
| elif initial_loss_val < expected_loss - 2.0: |
| print(f" Hint: loss << ln(V) suggests data leakage") |
| del model_copy |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| |
| |
| num_params = sum(p.numel() for p in model.parameters()) |
| if num_params > 500e6: |
| overfit_lr, overfit_steps = 1e-4, 400 |
| elif num_params > 50e6: |
| overfit_lr, overfit_steps = 3e-4, 300 |
| else: |
| overfit_lr, overfit_steps = 1e-3, 200 |
| print(f"\n β³ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...") |
| overfit_model = copy.deepcopy(model) |
| overfit_model.to(device) |
| overfit_model.train() |
| overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr) |
| single_input = input_ids[:1].to(device) |
| single_target = targets[:1].to(device) |
| log_interval = max(overfit_steps // 4, 1) |
|
|
| overfit_losses = [] |
| for step in range(overfit_steps): |
| overfit_optimizer.zero_grad() |
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| _, loss = overfit_model(single_input, single_target) |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0) |
| overfit_optimizer.step() |
| overfit_losses.append(loss.item()) |
| if (step + 1) % log_interval == 0: |
| print(f" Step {step + 1}: Loss = {loss.item():.4f}") |
|
|
| final_overfit_loss = overfit_losses[-1] |
| min_overfit_loss = min(overfit_losses) |
| overfit_ok = min_overfit_loss < 0.5 |
| detail = ( |
| f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} " |
| f"(min={min_overfit_loss:.4f}, target < 0.5)" |
| ) |
| results.append(_check_result("Single-batch overfit", overfit_ok, detail)) |
| icon = "β
" if overfit_ok else "β" |
| print(f" {icon} Check 4: {detail}") |
| if not overfit_ok: |
| print(f" CRITICAL: Model cannot memorize a single batch!") |
| print(f" This means the model or loss function has a bug.") |
| del overfit_model, overfit_optimizer |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| |
| test_text = "The quick brown fox jumps over the lazy dog." |
| encoded = tokenizer.encode(test_text) |
| decoded = tokenizer.decode(encoded) |
| roundtrip_ok = test_text.strip() in decoded.strip() |
| detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'" |
| results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail)) |
| icon = "β
" if roundtrip_ok else "β" |
| print(f" {icon} Check 5: {detail}") |
|
|
| |
| print(f"\n π Check 6: Data quality sampling (visual inspection)") |
| for i in range(min(3, input_ids.shape[0])): |
| sample_tokens = input_ids[i][:100].tolist() |
| decoded_text = tokenizer.decode(sample_tokens) |
| preview = decoded_text[:200].replace("\n", "\\n") |
| print(f" Sample {i}: {preview}...") |
|
|
| passed_count = sum(1 for r in results if r["passed"]) |
| total_count = len(results) |
| print(f"\n Result: {passed_count}/{total_count} checks passed") |
|
|
| return { |
| "level": 1, |
| "checks": results, |
| "passed": [r for r in results if r["passed"]], |
| "failed": [r for r in results if not r["passed"]], |
| } |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def check_numerical_stability( |
| model: nn.Module, |
| dataloader: DataLoader, |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| ) -> Dict[str, Any]: |
| """Check for NaN/Inf in gradients, activations, and logits (Level 2). |
| |
| Checks: |
| - Mixed precision config (RMSNorm fp32 upcast, loss dtype) |
| - NaN/Inf gradients β softmax overflow, bad data |
| - Inf gradients β log(0) in loss, missing ignore_index |
| - Large activations growing per layer β initialization or norm bug |
| - Logit scale β should be < 1000 |
| """ |
| print(_header("Level 2: Numerical Stability Checks")) |
|
|
| batch = next(iter(dataloader)) |
| input_ids = batch["input_ids"].to(device) |
| targets = batch["targets"].to(device) |
|
|
| results: List[Dict[str, Any]] = [] |
| activation_stats: List[Dict[str, Any]] = [] |
|
|
| |
| print("\n Mixed Precision Config:") |
| print(f" Training dtype: {dtype}") |
|
|
| |
| norm_fp32_ok = True |
| checked_norm_classes: set = set() |
| for name, module in model.named_modules(): |
| cls_name = module.__class__.__name__ |
| if "Norm" in cls_name and cls_name not in checked_norm_classes: |
| checked_norm_classes.add(cls_name) |
| import inspect |
| try: |
| src = inspect.getsource(type(module).forward) |
| has_upcast = ".float()" in src or "float32" in src |
| except (TypeError, OSError): |
| has_upcast = True |
| if not has_upcast: |
| norm_fp32_ok = False |
| print(f" π΄ {cls_name}: no fp32 upcast detected!") |
| if norm_fp32_ok: |
| print(f" β
Norm layers use fp32 upcast (safe)") |
|
|
| results.append(_check_result( |
| "Norm fp32 upcast", norm_fp32_ok, |
| "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype", |
| )) |
|
|
| |
| if dtype in (torch.bfloat16, torch.float16): |
| print(f" βΉοΈ Best practice: compute loss in fp32 when using {dtype}") |
| print(f" logits_fp32 = logits.float()") |
| print(f" loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))") |
|
|
| |
| print("\n Common Numerical Issues Reference:") |
| print(" ββββββββββββββββββββββββ¬βββββββββββββββββββββββββββ¬ββββββββββββββββββββββββββ") |
| print(" β Symptom β Likely Cause β Solution β") |
| print(" ββββββββββββββββββββββββΌβββββββββββββββββββββββββββΌββββββββββββββββββββββββββ€") |
| print(" β Loss β NaN β Large logits β softmax β Check init, logit scale β") |
| print(" β Loss β Inf β log(0) in CE loss β Add eps, ignore_index β") |
| print(" β Loss oscillation β fp16 gradient underflow β Switch to bf16 / scaler β") |
| print(" β Late-training NaN β Activation growth β Check RMSNorm, wd β") |
| print(" ββββββββββββββββββββββββ΄βββββββββββββββββββββββββββ΄ββββββββββββββββββββββββββ") |
|
|
| |
| hooks = [] |
|
|
| def make_hook(name: str): |
| def hook_fn(module, input, output): |
| if isinstance(output, torch.Tensor): |
| out_f = output.float() |
| stats = { |
| "name": name, |
| "mean": out_f.mean().item(), |
| "std": out_f.std().item(), |
| "max": out_f.abs().max().item(), |
| "has_nan": bool(torch.isnan(output).any()), |
| "has_inf": bool(torch.isinf(output).any()), |
| } |
| activation_stats.append(stats) |
| return hook_fn |
|
|
| |
| for i, layer in enumerate(model.layers): |
| h = layer.register_forward_hook(make_hook(f"layer_{i}")) |
| hooks.append(h) |
|
|
| |
| model.train() |
| model.zero_grad(set_to_none=True) |
| use_scaler = dtype == torch.float16 and torch.cuda.is_available() |
| scaler = torch.amp.GradScaler() if use_scaler else None |
|
|
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| logits, loss = model(input_ids, targets) |
|
|
| loss_val = loss.item() |
| loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val)) |
| results.append(_check_result( |
| "Loss value", |
| loss_ok, |
| f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)" |
| )) |
|
|
| if scaler is not None: |
| scaler.scale(loss).backward() |
| _temp_opt = torch.optim.SGD(model.parameters(), lr=0) |
| scaler.unscale_(_temp_opt) |
| else: |
| loss.backward() |
|
|
| |
| for h in hooks: |
| h.remove() |
|
|
| |
| print("\n Gradient Health:") |
| grad_issues = [] |
| for name, param in model.named_parameters(): |
| if param.grad is None: |
| continue |
| grad = param.grad |
| if torch.isnan(grad).any(): |
| grad_issues.append(f"π΄ NaN gradient: {name}") |
| if torch.isinf(grad).any(): |
| grad_issues.append(f"π΄ Inf gradient: {name}") |
| if grad.abs().max().item() > 100: |
| grad_issues.append( |
| f"π‘ Large gradient: {name} max={grad.abs().max().item():.1f}" |
| ) |
|
|
| grad_ok = len(grad_issues) == 0 |
| if grad_ok: |
| print(" β
All gradients are healthy (no NaN/Inf/large values)") |
| else: |
| for issue in grad_issues[:10]: |
| print(f" {issue}") |
| if len(grad_issues) > 10: |
| print(f" ... and {len(grad_issues) - 10} more issues") |
|
|
| results.append(_check_result( |
| "Gradient health", |
| grad_ok, |
| f"{len(grad_issues)} issues found" if not grad_ok else "All healthy", |
| )) |
|
|
| |
| print("\n Activation Stats (per transformer layer):") |
| act_nan_count = 0 |
| for stats in activation_stats: |
| icon = "π΄" if stats["has_nan"] or stats["has_inf"] else " " |
| if stats["has_nan"] or stats["has_inf"]: |
| act_nan_count += 1 |
| print( |
| f" {icon} {stats['name']}: " |
| f"mean={stats['mean']:.4f}, " |
| f"std={stats['std']:.4f}, " |
| f"max={stats['max']:.4f}" |
| + (" [NaN!]" if stats["has_nan"] else "") |
| + (" [Inf!]" if stats["has_inf"] else "") |
| ) |
|
|
| act_ok = act_nan_count == 0 |
| results.append(_check_result( |
| "Activation health", |
| act_ok, |
| f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy", |
| )) |
|
|
| |
| if len(activation_stats) >= 2: |
| stds = [s["std"] for s in activation_stats] |
| if stds[0] > 1e-8: |
| growth_ratio = stds[-1] / stds[0] |
| growth_ok = growth_ratio < 10 |
| detail = ( |
| f"Activation std ratio (last/first): {growth_ratio:.1f}x " |
| f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})" |
| ) |
| results.append(_check_result("Activation growth", growth_ok, detail)) |
| icon = "β
" if growth_ok else "π‘" |
| print(f" {icon} {detail}") |
| if not growth_ok: |
| print(f" Possible initialization or normalization issue") |
|
|
| |
| logit_max = logits.float().abs().max().item() |
| logit_ok = logit_max < 1000 |
| detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)" |
| results.append(_check_result("Logit scale", logit_ok, detail)) |
| icon = "β
" if logit_ok else "π΄" |
| print(f"\n {icon} Logit scale: {detail}") |
|
|
| model.zero_grad(set_to_none=True) |
|
|
| passed_count = sum(1 for r in results if r["passed"]) |
| print(f"\n Result: {passed_count}/{len(results)} checks passed") |
|
|
| return { |
| "level": 2, |
| "checks": results, |
| "activation_stats": activation_stats, |
| "grad_issues": grad_issues, |
| } |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def diagnose_hyperparameters( |
| metrics_history: Dict[str, list], |
| config: TrainConfig, |
| ) -> Dict[str, Any]: |
| """Analyze hyperparameter health from training metrics (Level 3). |
| |
| Checks: |
| - LR: too high (grad_norm hitting clip limit) or too low (grad_norm tiny) |
| - Batch size: loss variance indicates batch too small |
| - Warmup: spikes in early steps indicate warmup too short |
| """ |
| print(_header("Level 3: Hyperparameter Diagnosis")) |
|
|
| findings: List[Dict[str, str]] = [] |
| grad_norms = metrics_history.get("grad_norm", []) |
| train_losses = metrics_history.get("train_loss", []) |
|
|
| |
| print("\n Learning Rate Analysis:") |
| print(f" Peak LR: {config.learning_rate:.2e}") |
| print(f" Min LR: {config.min_learning_rate:.2e}") |
|
|
| if grad_norms: |
| avg_grad = sum(grad_norms) / len(grad_norms) |
| |
| clip_count = sum(1 for g in grad_norms if g >= config.grad_clip) |
| clip_rate = clip_count / len(grad_norms) |
| |
| tiny_threshold = config.grad_clip * 0.01 |
| tiny_count = sum(1 for g in grad_norms if g < tiny_threshold) |
| tiny_rate = tiny_count / len(grad_norms) |
|
|
| print(f" Avg grad norm: {avg_grad:.4f}") |
| print(f" Clip rate: {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})") |
| print(f" Tiny grad rate: {tiny_rate * 100:.1f}% (< {tiny_threshold:.4f})") |
|
|
| |
| |
| |
| if clip_rate > 0.5: |
| findings.append({ |
| "issue": "LR may be too high", |
| "evidence": f"Grad norm hits clip limit {clip_rate * 100:.0f}% of the time", |
| "action": f"Try LR = {config.learning_rate / 2:.2e} (Γ·2)", |
| }) |
| print(f" π‘ Grad clipping frequent ({clip_rate * 100:.0f}%) β LR may be too high") |
| elif tiny_rate > 0.5: |
| findings.append({ |
| "issue": "Possible vanishing gradients", |
| "evidence": f"Grad norm < {tiny_threshold:.4f} in {tiny_rate * 100:.0f}% of steps", |
| "action": "Check weight initialization, layer norms, and model depth", |
| }) |
| print(f" π‘ Grad norm too small ({tiny_rate * 100:.0f}% < {tiny_threshold:.4f}) β possible vanishing gradients") |
| else: |
| print(f" β
LR looks appropriate") |
|
|
| |
| print("\n Batch Size Analysis:") |
| print(f" Effective batch: {config.effective_batch_size}") |
|
|
| if len(train_losses) >= 50: |
| recent_losses = train_losses[-50:] |
| loss_mean = sum(recent_losses) / len(recent_losses) |
| loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses) |
| loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8) |
|
|
| print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation, last 50 steps)") |
|
|
| if loss_cv > 0.1: |
| findings.append({ |
| "issue": "Training loss has high variance", |
| "evidence": f"Loss CV = {loss_cv:.4f} over last 50 steps", |
| "action": "Check: (1) LR may be too high, (2) increase gradient_accumulation_steps, (3) inspect data quality", |
| }) |
| print(f" π‘ High loss variance β check LR, batch size, or data quality") |
| else: |
| print(f" β
Loss variance is acceptable") |
|
|
| |
| print("\n Ξ²β (Adam second momentum) Analysis:") |
| print(f" Current Ξ²β: {config.beta2}") |
| if config.beta2 >= _DEFAULT_PYTORCH_BETA2: |
| findings.append({ |
| "issue": "Ξ²β may be too high for LLM training", |
| "evidence": ( |
| f"Ξ²β={config.beta2} (PyTorch default). " |
| f"LLM standard is {_RECOMMENDED_BETA2}" |
| ), |
| "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)", |
| }) |
| print(f" π‘ Ξ²β={config.beta2} is PyTorch default β " |
| f"LLM training standard is {_RECOMMENDED_BETA2}") |
| print(f" Why: Ξ²β=0.999 averages ~1000 steps of gradient stats,") |
| print(f" Ξ²β=0.95 averages ~20 steps β faster adaptation to changing data") |
| print(f" (Cattaneo & Shigida 2025, 'Tuning Adam(W)')") |
| else: |
| print(f" β
Ξ²β={config.beta2} is within LLM standard range") |
|
|
| |
| print("\n Weight Decay Analysis:") |
| print(f" Current weight_decay: {config.weight_decay}") |
| if config.weight_decay == 0: |
| findings.append({ |
| "issue": "Weight decay is disabled", |
| "evidence": "weight_decay=0 increases overfitting risk", |
| "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)", |
| }) |
| print(f" π‘ weight_decay=0 β overfitting risk. Standard is 0.1") |
| elif config.weight_decay > 0.3: |
| findings.append({ |
| "issue": "Weight decay may be too high", |
| "evidence": f"weight_decay={config.weight_decay} (unusually high)", |
| "action": "Try weight_decay=0.1 (standard value)", |
| }) |
| print(f" π‘ weight_decay={config.weight_decay} is unusually high (standard: 0.1)") |
| else: |
| print(f" β
weight_decay={config.weight_decay} is within normal range") |
|
|
| |
| print("\n GPT-3 LR Reference (Brown et al. 2020):") |
| print(" ββββββββββββ¬ββββββββββββ¬βββββββββββββββ") |
| print(" β Model β Peak LR β Batch Tokens β") |
| print(" ββββββββββββΌββββββββββββΌβββββββββββββββ€") |
| for params, lr, batch_tok in _GPT3_LR_REFERENCE: |
| label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M" |
| marker = " β" if abs(params - 1.1e9) < 0.5e9 else "" |
| print(f" β {label:<8} β {lr:.1e} β {batch_tok:<12} β{marker}") |
| print(" ββββββββββββ΄ββββββββββββ΄βββββββββββββββ") |
| print(" β Larger models need lower LR and larger batch") |
|
|
| |
| print("\n Batch-LR Scaling Rules:") |
| print(" β’ Batch Γ2 β LR Γβ2 (square root scaling, recommended for Adam)") |
| print(" (Malladi et al. NeurIPS 2022, 'On the SDEs and Scaling Rules for Adaptive Gradient Algorithms')") |
| print(" β’ Batch Γ2 β LR Γ2 (linear scaling, Goyal et al. 2017, mainly SGD)") |
| print(" β’ 1B model: ~1K-2K sequences (~2-4M tokens) is typical") |
| print(" (Pythia-1B: ~2M tokens, TinyLlama: ~2M, OLMo-1B: ~4M)") |
|
|
| |
| print("\n Warmup Analysis:") |
| print(f" Warmup steps: {config.warmup_steps} " |
| f"({config.warmup_steps / config.total_steps * 100:.1f}% of total)") |
|
|
| if len(train_losses) >= 10: |
| early_losses = train_losses[:min(50, len(train_losses))] |
| |
| spike_count = 0 |
| for i in range(1, len(early_losses)): |
| if early_losses[i] > early_losses[i - 1] * 1.5: |
| spike_count += 1 |
|
|
| if spike_count > 3: |
| findings.append({ |
| "issue": "Warmup may be too short", |
| "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps", |
| "action": f"Try warmup_steps = {config.warmup_steps * 2}", |
| }) |
| print(f" π‘ {spike_count} spikes in early training β warmup may be too short") |
| else: |
| print(f" β
Early training is stable") |
|
|
| |
| if not findings: |
| print("\n β
No hyperparameter issues detected") |
| else: |
| print(f"\n Found {len(findings)} potential issue(s):") |
| for f in findings: |
| print(f" β’ {f['issue']}: {f['action']}") |
|
|
| |
| print("\n Warmup Reference (real projects):") |
| print(" β’ TinyLlama 1.1B (3T tokens): 2,000 steps β 0.1% of total") |
| print(" β’ GPT-3 175B: 375M warmup tokens β 117 steps") |
| print(" β’ General range: 0.1% ~ 5% of total steps") |
| print(" β’ Smaller experiments: 5~10% is also reasonable") |
|
|
| print("\n Tuning priority (high β low):") |
| print(" 1. Learning Rate β tune first (10x impact)") |
| print(" 2. Batch Size β adjust with LR") |
| print(" 3. Warmup Steps β early stability") |
| print(" 4. Weight Decay β if overfitting (typically 0.1)") |
| print(" 5. Ξ²β, Ξ²β (Adam) β see Ξ²β analysis above") |
| print(" 6. Gradient Clip β usually keep at 1.0") |
|
|
| return { |
| "level": 3, |
| "findings": findings, |
| "config_summary": { |
| "learning_rate": config.learning_rate, |
| "effective_batch": config.effective_batch_size, |
| "warmup_steps": config.warmup_steps, |
| "total_steps": config.total_steps, |
| "grad_clip": config.grad_clip, |
| }, |
| } |
|
|
| @staticmethod |
| def lr_range_test( |
| model: nn.Module, |
| dataloader: DataLoader, |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| lr_start: float = 1e-7, |
| lr_end: float = 1e-1, |
| steps: int = 300, |
| ) -> Dict[str, Any]: |
| """Run an LR range test to find the optimal learning rate (Level 3 bonus). |
| |
| Sweeps LR from lr_start to lr_end exponentially, recording loss. |
| The optimal LR is where loss decreases fastest (steepest slope), |
| divided by 3~10 for stability. |
| |
| WARNING: This modifies a copy of the model. The original is untouched. |
| """ |
| print(_header("Level 3 Bonus: LR Range Test")) |
| print(f" Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n") |
|
|
| test_model = copy.deepcopy(model) |
| test_model.to(device) |
| test_model.train() |
| optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start) |
|
|
| lr_mult = (lr_end / lr_start) ** (1 / steps) |
| lr = lr_start |
|
|
| lrs: List[float] = [] |
| losses: List[float] = [] |
| data_iter = iter(dataloader) |
|
|
| for step in range(steps): |
| for pg in optimizer.param_groups: |
| pg["lr"] = lr |
|
|
| try: |
| batch = next(data_iter) |
| except StopIteration: |
| data_iter = iter(dataloader) |
| batch = next(data_iter) |
|
|
| input_ids = batch["input_ids"].to(device) |
| targets_t = batch["targets"].to(device) |
|
|
| optimizer.zero_grad() |
| with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): |
| _, loss = test_model(input_ids, targets_t) |
| loss.backward() |
| optimizer.step() |
|
|
| loss_val = loss.item() |
| lrs.append(lr) |
| losses.append(loss_val) |
|
|
| if (step + 1) % 50 == 0: |
| print(f" Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}") |
|
|
| |
| if len(losses) > 1 and loss_val > losses[0] * 4: |
| print(f" Loss exploded at LR = {lr:.2e}, stopping.") |
| break |
|
|
| lr *= lr_mult |
|
|
| del test_model, optimizer |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| |
| best_lr = lr_start |
| if len(losses) > 10: |
| |
| window = 5 |
| smoothed = [] |
| for i in range(len(losses) - window): |
| smoothed.append(sum(losses[i:i + window]) / window) |
|
|
| min_slope = 0 |
| min_idx = 0 |
| for i in range(1, len(smoothed)): |
| slope = smoothed[i] - smoothed[i - 1] |
| if slope < min_slope: |
| min_slope = slope |
| min_idx = i |
|
|
| best_lr = lrs[min_idx] |
| suggested_lr = best_lr / 3 |
|
|
| print(f"\n Steepest descent at LR = {best_lr:.2e}") |
| print(f" Suggested peak LR: {suggested_lr:.2e} (Γ·3 for stability)") |
| print(f" Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]") |
| else: |
| suggested_lr = 3e-4 |
| print(f"\n Not enough data points. Using default LR = {suggested_lr:.2e}") |
|
|
| return { |
| "lrs": lrs, |
| "losses": losses, |
| "best_lr": best_lr, |
| "suggested_lr": suggested_lr, |
| } |
|
|
| |
| |
| |
|
|
    @staticmethod
    def diagnose_fitting(
        metrics_history: Dict[str, list],
        model_params: Optional[int] = None,
        total_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Diagnose overfitting vs underfitting from metrics (Level 4).

        Cases:
        1. Both high, decreasing β Normal (still training)
        2. Both high, plateau β Underfitting
        3. Trainβ Valβ or Valβ β Overfitting
        4. Both low, plateau β Converged (or at limit)

        Args:
            metrics_history: metric lists; reads "train_loss" and "val_loss"
                (val entries may contain None placeholders between evals).
            model_params: total parameter count; enables Chinchilla
                token/param ratio hints when given together with total_tokens.
            total_tokens: total training tokens consumed so far.

        Returns:
            Dict with "level", "case", "train_trend", "val_trend", "gap",
            and "recommendations" (case "insufficient_data" on too few points).
        """
        print(_header("Level 4: Overfitting vs Underfitting Diagnosis"))

        train_losses = metrics_history.get("train_loss", [])
        # Eval is sparse: drop the None placeholders logged on non-eval steps.
        val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None]

        if len(train_losses) < 10 or len(val_losses) < 2:
            print(" [!] Not enough data. Need more training steps with eval.")
            return {"level": 4, "case": "insufficient_data", "recommendations": []}

        # Train trend: split the most recent window (up to 50 points) in half
        # and compare the half-means. "Decreasing" requires a 0.02 margin so
        # plain noise is not mistaken for progress.
        recent_n = min(50, len(train_losses))
        train_recent = train_losses[-recent_n:]
        train_mid = len(train_recent) // 2
        train_first = sum(train_recent[:train_mid]) / max(train_mid, 1)
        train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1)
        train_decreasing = train_second < train_first - 0.02

        # Val trend: same half-vs-half comparison over the full val history;
        # "increasing" uses a looser 0.05 margin than "decreasing" (0.02).
        val_mid = len(val_losses) // 2
        val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1)
        val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1)
        val_decreasing = val_second < val_first - 0.02
        val_increasing = val_second > val_first + 0.05

        # Train-val gap from the latest readings; the sign is discarded when
        # printed and returned (only the magnitude is reported).
        last_train = train_losses[-1]
        last_val = val_losses[-1]
        gap = last_train - last_val

        print(f" Train loss (recent): {train_first:.4f} β {train_second:.4f} "
              f"({'β' if train_decreasing else 'β'})")
        print(f" Val loss: {val_first:.4f} β {val_second:.4f} "
              f"({'β' if val_decreasing else 'β' if val_increasing else 'β'})")
        print(f" Train-Val gap: {abs(gap):.4f}")

        # Classify into one of four cases; order matters (case 1 wins when
        # both trends are decreasing, before the overfitting check runs).
        case = ""
        recommendations: List[str] = []

        if train_decreasing and val_decreasing:
            case = "Case 1: Normal β both decreasing"
            recommendations.append("Training is progressing normally. Continue.")
            if model_params and total_tokens:
                # Chinchilla heuristic: ~20 training tokens per parameter is
                # compute-optimal; below that, more data is the cheap win.
                ratio = total_tokens / model_params
                chinchilla = 20
                if ratio < chinchilla:
                    recommendations.append(
                        f"Token/param ratio = {ratio:.1f}x "
                        f"(Chinchilla optimal β {chinchilla}x). "
                        f"Model may benefit from more data."
                    )
            print(f"\n π’ {case}")

        elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]:
            # Both plateaued AND the train loss sits above the expected band
            # for this setup β underfitting rather than convergence.
            case = "Case 2: Underfitting β both plateaued at high loss"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Training insufficient? β check if loss curve still has downward slope",
                " - Chinchilla: 1B model needs ~20B tokens minimum",
                " - TinyLlama trains 1.1B on 3T tokens (inference-optimal)",
                "2) LR too low? β try LR Γ2, see if loss drops faster",
                "3) Model capacity too small? β train 2x larger model on same data",
                " - If larger model gets lower loss β capacity was the limit",
                "4) Data quality? β sample and read training data manually",
                " - Noisy/low-quality data raises the achievable loss floor",
            ]
            if model_params and total_tokens:
                # Prepend the token-budget verdict so it is read first.
                ratio = total_tokens / model_params
                if ratio < 10:
                    recommendations.insert(0,
                        f"β Token/param ratio = {ratio:.1f}x β "
                        f"very likely undertrained. Chinchilla recommends β₯20x."
                    )
                elif ratio < 20:
                    recommendations.insert(0,
                        f"βΉ Token/param ratio = {ratio:.1f}x β "
                        f"below Chinchilla optimal (20x). More tokens may help."
                    )
            print(f"\n π‘ {case}")

        elif train_decreasing and (val_increasing or not val_decreasing):
            # Train keeps improving while val stalls or worsens β overfitting.
            case = "Case 3: Overfitting β trainβ but valβ/β"
            recommendations = [
                "Diagnosis priority (check in order):",
                "1) Data repetition? (most common cause in pretraining)",
                " - Check: total tokens vs unique tokens",
                " - Epoch > 1 dramatically increases overfitting risk",
                " - Solution: add more data, stay within 1 epoch",
                "2) Weight decay too low?",
                " - Check: weight_decay value (standard: 0.1)",
                " - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1",
                " - Experiment: 0.01 / 0.05 / 0.1 / 0.3",
                "3) Data diversity?",
                " - Single-domain data overfits faster",
                " - Mix: web, books, code, wiki, etc.",
                "",
                "Note on Dropout in LLM pretraining:",
                " - Modern LLMs do NOT use dropout in pretraining",
                " (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)",
                " - Sufficient data is the best regularization",
                " - Dropout is useful for fine-tuning on small datasets",
            ]
            print(f"\n π‘ {case}")

        else:
            # Remaining combinations are treated as converged / at the limit.
            case = "Case 4: Converged β loss is low and stable"
            recommendations = [
                "Training has converged (or reached the data/model limit).",
                "To push further: add more data or increase model size.",
            ]
            print(f"\n π’ {case}")

        for rec in recommendations:
            print(f" {rec}")

        return {
            "level": 4,
            "case": case,
            "train_trend": "decreasing" if train_decreasing else "flat",
            "val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"),
            "gap": abs(gap),
            "recommendations": recommendations,
        }
|
|
| |
| |
| |
|
|
    @staticmethod
    def check_architecture(
        model: nn.Module,
        dataloader: DataLoader,
        device: torch.device,
    ) -> Dict[str, Any]:
        """Check weight initialization and per-layer activation health (Level 5).

        Healthy initialization:
        - All layers: std β 1.0, mean β 0.0
        Problems:
        - std increasing per layer β activation explosion (init scale too large)
        - std decreasing per layer β activation vanishing (init scale too small)
        - Sudden change at specific layer β implementation bug in that layer

        NOTE(review): assumes the model exposes a `token_embedding` module and
        an iterable `layers` whose blocks accept (h, mask, position_offset) β
        confirm against the model definition.

        Returns:
            Dict with "level", "diagnosis" (healthy/exploding/vanishing/anomaly),
            "detail", per-layer "layer_stats", and "weight_issues".
        """
        print(_header("Level 5: Architecture / Initialization Check"))

        # A single example is enough for activation statistics.
        batch = next(iter(dataloader))
        sample_input = batch["input_ids"][:1].to(device)

        model.eval()
        layer_stats: List[Dict[str, Any]] = []

        with torch.no_grad():
            h = model.token_embedding(sample_input)
            emb_std = h.float().std().item()
            print(f"\n Embedding: std={emb_std:.4f}")

            # Thread the activation through each block manually so we can
            # snapshot mean/std/max after every layer.
            for i, layer in enumerate(model.layers):
                h = layer(h, mask=None, position_offset=0)
                h_f = h.float()  # compute stats in fp32 regardless of model dtype
                stats = {
                    "layer": i,
                    "mean": h_f.mean().item(),
                    "std": h_f.std().item(),
                    "max": h_f.abs().max().item(),
                }
                layer_stats.append(stats)

        # Per-layer activation table.
        print(f"\n Layer-by-layer activation statistics:")
        print(f" {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}")
        print(f" {'-' * 38}")
        for s in layer_stats:
            print(f" {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}")

        # Weight distribution scan: flag any matrix-shaped parameter whose std
        # is far from the GPT-2-style 0.02 target (>5x or <0.1x).
        print(f"\n Weight Initialization Distribution:")
        print(f" {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}")
        print(f" {'-' * 75}")
        weight_issues = []
        for name, param in model.named_parameters():
            # Skip biases / norm gains: only 2D+ weights follow the 0.02 rule.
            if param.ndim < 2:
                continue
            p_f = param.float()
            p_mean = p_f.mean().item()
            p_std = p_f.std().item()

            shape_str = str(list(param.shape))
            # NOTE(review): is_residual is computed but never used β residual
            # projections are held to the same 0.02 threshold as other weights,
            # even though the docstring below says they should be scaled down.
            is_residual = "o_proj" in name or "down_proj" in name
            expected_std = 0.02
            if p_std > expected_std * 5:
                weight_issues.append(f"Large std: {name} (std={p_std:.4f})")
                print(f" π‘ {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")
            elif p_std < expected_std * 0.1:
                weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})")
                print(f" π‘ {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}")
            else:
                print(f" {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}")

        if weight_issues:
            print(f"\n β {len(weight_issues)} weight distribution issue(s) found")
            for issue in weight_issues[:5]:
                print(f" β’ {issue}")
        else:
            print(f"\n β All weight distributions look normal (std β 0.02)")

        print(f"\n Expected init pattern:")
        print(f" β’ General Linear: N(0, 0.02)")
        print(f" β’ Residual proj (o_proj, down_proj): N(0, 0.02/β(2Γlayers))")
        print(f" β’ Embedding: N(0, 0.02)")

        # Static reference table: expected outcomes of component ablations.
        print(f"\n Component Ablation Reference:")
        print(" ββββββββββββββββββββββββ¬βββββββββββββββββββββββββββββββββββββ")
        print(" β Experiment β Expected Outcome β")
        print(" ββββββββββββββββββββββββΌβββββββββββββββββββββββββββββββββββββ€")
        print(" β RMSNorm β LayerNorm β Minimal loss diff β OK β")
        print(" β RoPE β Absolute PE β Similar on short seq (<512) β")
        print(" β SwiGLU β ReLU FFN β Loss +0.05~0.15 β SwiGLU working β")
        print(" β GQA β MHA β Same loss, less memory β OK β")
        print(" ββββββββββββββββββββββββ΄βββββββββββββββββββββββββββββββββββββ")
        print(" If any replacement shows unexpected results, check that component.")

        # Trend diagnosis: compare mean std of the first vs last third of
        # layers. >5x growth β exploding; <0.2x β vanishing.
        stds = [s["std"] for s in layer_stats]
        diagnosis = "healthy"
        detail = ""

        if len(stds) >= 3:
            first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3)
            last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3)
            ratio = last_third / max(first_third, 1e-8)

            if ratio > 5:
                diagnosis = "exploding"
                detail = (
                    f"Activation std grows {ratio:.1f}x from early to late layers. "
                    f"Init scale may be too large."
                )
            elif ratio < 0.2:
                diagnosis = "vanishing"
                detail = (
                    f"Activation std shrinks to {ratio:.1f}x from early to late layers. "
                    f"Init scale may be too small."
                )
            else:
                detail = f"Std ratio (last/first third) = {ratio:.2f} β within normal range."

        # A single-layer jump (>10x or <0.1x vs the previous layer) overrides
        # the trend verdict: it points at a bug in that specific layer.
        for i in range(1, len(stds)):
            jump = stds[i] / max(stds[i - 1], 1e-8)
            if jump > 10 or jump < 0.1:
                diagnosis = "anomaly"
                detail = (
                    f"Sudden activation change at layer {i}: "
                    f"std {stds[i - 1]:.4f} β {stds[i]:.4f}. "
                    f"Possible implementation bug in that layer."
                )
                break

        icon = {"healthy": "β", "exploding": "π΄", "vanishing": "π‘", "anomaly": "π΄"}
        print(f"\n {icon.get(diagnosis, 'βͺ')} Diagnosis: {diagnosis}")
        print(f" {detail}")

        return {
            "level": 5,
            "diagnosis": diagnosis,
            "detail": detail,
            "layer_stats": layer_stats,
            "weight_issues": weight_issues,
        }
|
|
| |
| |
| |
|
|
| @staticmethod |
| def run_diagnostics( |
| model: nn.Module, |
| dataloader: DataLoader, |
| tokenizer: Any, |
| train_config: TrainConfig, |
| metrics_history: Dict[str, list], |
| device: torch.device, |
| dtype: torch.dtype = torch.bfloat16, |
| vocab_size: int = 32000, |
| levels: Optional[List[int]] = None, |
| ) -> Dict[str, Any]: |
| """Run the full 5-level debugging framework. |
| |
| Args: |
| model: the LLM model |
| dataloader: training dataloader |
| tokenizer: tokenizer with encode/decode methods |
| train_config: TrainConfig instance |
| metrics_history: dict from MetricsTracker.history |
| device: torch device |
| dtype: mixed precision dtype |
| vocab_size: model vocabulary size |
| levels: which levels to run (default: all [0,1,2,3,4,5]) |
| |
| Returns: |
| Full diagnostic report dict. |
| """ |
| if levels is None: |
| levels = [0, 1, 2, 3, 4, 5] |
|
|
| print("\n" + "β" * 60) |
| print(" LLM Loss Debugging Framework") |
| print(" Levels to run: " + ", ".join(str(l) for l in levels)) |
| print("β" * 60) |
|
|
| report: Dict[str, Any] = {} |
|
|
| if 0 in levels: |
| report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history) |
| |
| if ( |
| report["level_0"]["status"] == STATUS_NORMAL |
| and levels == [0] |
| ): |
| print("\n Training is healthy β no further debugging needed.") |
| return report |
|
|
| if 1 in levels: |
| report["level_1"] = LossDebugger.check_data_pipeline( |
| model, dataloader, tokenizer, vocab_size, device, dtype, |
| ) |
|
|
| if 2 in levels: |
| report["level_2"] = LossDebugger.check_numerical_stability( |
| model, dataloader, device, dtype, |
| ) |
|
|
| if 3 in levels: |
| report["level_3"] = LossDebugger.diagnose_hyperparameters( |
| metrics_history, train_config, |
| ) |
|
|
| if 4 in levels: |
| model_params = sum(p.numel() for p in model.parameters()) |
| total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step |
| report["level_4"] = LossDebugger.diagnose_fitting( |
| metrics_history, model_params, total_tokens, |
| ) |
|
|
| if 5 in levels: |
| report["level_5"] = LossDebugger.check_architecture( |
| model, dataloader, device, |
| ) |
|
|
| |
| print("\n" + "β" * 60) |
| print(" Diagnostics Complete") |
| print("β" * 60) |
|
|
| return report |
|
|
| |
| |
| |
|
|
    @staticmethod
    def print_study_roadmap() -> None:
        """Print the recommended study roadmap for LLM training optimization.

        Purely informational: writes a static curriculum (topics in priority
        order, suggested experiments, and key references) to stdout. Reads and
        modifies no state.
        """
        print(_header("Study Roadmap β LLM Training Optimization"))

        # The roadmap is a single static literal; keep it verbatim so the
        # printed layout stays aligned.
        print("""
 βββ Top Priority: Optimization Fundamentals
 βββββββββββββββββββββββββββββββββββββββββββββ
 1. SGD β Momentum β Adam β AdamW progression
 - Why Adam > SGD? Why decouple weight decay in AdamW?
 - Ξ²β, Ξ²β intuition (1st / 2nd momentum)
 - Ref: Loshchilov & Hutter 2019 (AdamW)
 - Ref: Karpathy "A Recipe for Training Neural Networks"

 2. Loss Landscape
 - Why large LR diverges, small LR stalls
 - Batch size effect on landscape exploration
 - Ref: Li et al. 2018 "Visualizing the Loss Landscape"
 - Ref: McCandlish et al. 2018 "Large-Batch Training"

 3. Chinchilla Scaling Law
 - Loss = f(N, D) relationship
 - Compute-optimal model size vs data allocation
 - Ref: Hoffmann et al. 2022 (original)
 - Ref: Kaplan et al. 2020 (predecessor)
 - Ref: Besiroglu et al. 2024 (replication/verification)

 ββ Important: Training Stability
 ββββββββββββββββββββββββββββββββββ
 4. Gradient Flow: vanishing/exploding, residual as gradient highway
 5. Weight Init: Xavier / Kaiming / GPT-2 style
 6. Normalization: BatchNorm β LayerNorm β RMSNorm
 7. Weight Decay: L2 vs decoupled, why exclude embed/norm

 β Advanced: Optimization Techniques
 βββββββββββββββββββββββββββββββββββββ
 8. LR Schedules: cosine vs linear vs step, warmup/cooldown
 9. Gradient Accumulation & Large Batch Training
 10. ΞΌP (Maximal Update Parameterization): transfer HP across scales

 Recommended Experiments (in order):
 βββββββββββββββββββββββββββββββββββ
 1. Single-batch overfit (30 min) β basic sanity
 2. LR Range Test (1 hour) β optimal LR range
 3. 10M model quick train (2-3 hrs) β pipeline validation
 4. Ablation (remove components) (1 day) β component contribution
 5. 100M model + 5B tokens (1-2 days)β mid-scale dynamics
 6. 1B model full training (2-3 days)β scaling law verification
 7. LR / batch size comparison (1 day) β HP sensitivity

 Key References:
 βββββββββββββββ
 βββ Karpathy "Recipe for Training NNs" β debugging mindset
 βββ Hoffmann et al. 2022 (Chinchilla) β scaling law
 ββ Touvron et al. 2023 (LLaMA) β 1B+ training details
 ββ Biderman et al. 2023 (Pythia) β open training logs
 ββ Zhang et al. 2024 (TinyLlama) β 1.1B on 3T tokens
 ββ Groeneveld et al. 2024 (OLMo) β fully open LLM
 ββ Li et al. 2018 (Loss Landscape) β loss terrain intuition
 ββ Loshchilov & Hutter 2019 (AdamW) β optimizer basics
 β Yang et al. 2022 (ΞΌP) β HP transfer
 β McCandlish et al. 2018 (Batch size) β critical batch size
 """)
|
|