"""LLM Loss Debugging & Optimization Framework. A systematic 5-level debugging framework for diagnosing training issues. Always start from Level 1 — fixing lower-level bugs before tuning hyperparameters saves time. Levels: 0. Status Diagnosis — classify current training health 1. Data/Implementation — most common cause (70% of issues) 2. Numerical Stability — dtype, normalization, gradient health 3. Hyperparameters — LR, batch size, warmup 4. Fitting Diagnosis — overfitting vs underfitting 5. Architecture — initialization, component checks """ import copy import math from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from llm_lab.config import TrainConfig # ═══════════════════════════════════════════════════════════════════ # Constants # ═══════════════════════════════════════════════════════════════════ # Approximate convergence ranges for a 1B model trained on ~10B tokens. # Estimated from GPT-2 scaling benchmarks (Radford et al. 2019) and # Chinchilla scaling laws (Hoffmann et al. 2022). Not dataset-specific. _EXPECTED_TRAIN_LOSS = (2.5, 3.3) _EXPECTED_VAL_LOSS = (2.7, 3.6) _EXPECTED_VAL_PPL = (15, 37) # Status labels STATUS_NORMAL = "NORMAL" STATUS_NO_DECREASE = "NO_DECREASE" STATUS_DIVERGING = "DIVERGING" STATUS_PLATEAU = "PLATEAU" STATUS_OVERFITTING = "OVERFITTING" STATUS_UNSTABLE = "UNSTABLE" STATUS_NAN_DETECTED = "NAN_DETECTED" STATUS_LOSS_BOUNCE = "LOSS_BOUNCE" # GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1) # (param_count, recommended_lr, batch_tokens_str) _GPT3_LR_REFERENCE = [ (125e6, 6e-4, "0.5M"), (350e6, 3e-4, "0.5M"), (760e6, 2.5e-4, "0.5M"), (1.3e9, 2e-4, "1M"), (2.7e9, 1.6e-4, "1M"), (6.7e9, 1.2e-4, "2M"), (13e9, 1e-4, "2M"), (175e9, 6e-5, "3.2M"), ] # Known LLM training references _LLM_TRAINING_REFS = { "TinyLlama-1.1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000}, "LLaMA-7B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.1, "warmup": 2000}, "Pythia-1B": {"lr": 3e-4, "beta2": 0.95, "wd": 0.01}, "OLMo-1B": {"lr": 4e-4, "beta2": 0.95, "wd": 0.1}, } # Recommended β₂ for LLM training _RECOMMENDED_BETA2 = 0.95 _DEFAULT_PYTORCH_BETA2 = 0.999 def _header(title: str) -> str: return f"\n{'=' * 60}\n{title}\n{'=' * 60}" def _check_result(name: str, passed: bool, detail: str = "") -> Dict[str, Any]: return {"name": name, "passed": passed, "detail": detail} # ═══════════════════════════════════════════════════════════════════ # LossDebugger # ═══════════════════════════════════════════════════════════════════ class LossDebugger: """5-level loss debugging framework for LLM training. Usage:: from llm_lab.training.debugger import LossDebugger # Quick status check status = LossDebugger.diagnose_status(vocab_size=32000, metrics_history=trainer.metrics.history) # Full diagnostics report = LossDebugger.run_diagnostics( model=model, dataloader=train_dl, tokenizer=tok, train_config=train_cfg, metrics_history=trainer.metrics.history, device=device, dtype=torch.bfloat16, ) """ # ─────────────────────────────────────────────────────────────── # Level 0: Status Diagnosis # ─────────────────────────────────────────────────────────────── @staticmethod def diagnose_status( vocab_size: int, metrics_history: Dict[str, list], ) -> Dict[str, Any]: """Classify current training health from metrics history. Args: vocab_size: model vocabulary size (e.g. 32000) metrics_history: dict with keys 'train_loss', 'val_loss', etc. Returns: dict with 'status', 'severity', 'details', 'recommended_levels' """ print(_header("Level 0: Training Status Diagnosis")) expected_initial = math.log(vocab_size) print(f" Expected initial loss (random weights): ln({vocab_size}) = {expected_initial:.2f}") print(f" Normal convergence range (1B, 10B tokens):") print(f" Train Loss: {_EXPECTED_TRAIN_LOSS[0]} ~ {_EXPECTED_TRAIN_LOSS[1]}") print(f" Val Loss: {_EXPECTED_VAL_LOSS[0]} ~ {_EXPECTED_VAL_LOSS[1]}") print(f" Val PPL: {_EXPECTED_VAL_PPL[0]} ~ {_EXPECTED_VAL_PPL[1]}") raw_train_losses = metrics_history.get("train_loss", []) train_losses = [l for l in raw_train_losses if not math.isnan(l)] val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None] if len(train_losses) < 2: print("\n [!] Not enough training data to diagnose. Run more steps first.") return { "status": "INSUFFICIENT_DATA", "severity": "unknown", "details": "Need at least 2 logged train loss values.", "recommended_levels": [1], } # Detect NaN presence before filtering has_nan = len(train_losses) < len(raw_train_losses) if has_nan: nan_count = len(raw_train_losses) - len(train_losses) print(f"\n ⚠ {nan_count} NaN values detected in train_loss — filtered for analysis") first_loss = train_losses[0] last_loss = train_losses[-1] loss_change = first_loss - last_loss # Split into halves for trend analysis mid = len(train_losses) // 2 first_half_avg = sum(train_losses[:mid]) / mid second_half_avg = sum(train_losses[mid:]) / (len(train_losses) - mid) # Recent window for spike detection recent_n = min(50, len(train_losses)) recent = train_losses[-recent_n:] recent_mean = sum(recent) / len(recent) recent_var = sum((x - recent_mean) ** 2 for x in recent) / len(recent) recent_std = recent_var ** 0.5 # Val trend val_trend = "unknown" if len(val_losses) >= 2: val_mid = len(val_losses) // 2 val_first_avg = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1) val_second_avg = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1) if val_second_avg < val_first_avg - 0.05: val_trend = "decreasing" elif val_second_avg > val_first_avg + 0.1: val_trend = "increasing" else: val_trend = "flat" # Pre-compute bounce detection using moving-average minimum # to avoid false positives from single noisy data points _ma_window = max(1, len(train_losses) // 20) # 5% window _ma_losses = [ sum(train_losses[max(0, i - _ma_window + 1):i + 1]) / (i - max(0, i - _ma_window + 1) + 1) for i in range(len(train_losses)) ] _min_ma_loss = min(_ma_losses) _min_ma_idx = _ma_losses.index(_min_ma_loss) _last_ma_loss = _ma_losses[-1] _bounce_amount = _last_ma_loss - _min_ma_loss _has_bounce = ( loss_change > 0.1 and _min_ma_idx < len(train_losses) * 0.85 and _bounce_amount > _min_ma_loss * 0.05 ) # Downgrade bounce severity when val loss is still improving _val_improving = ( val_trend == "decreasing" or (len(val_losses) >= 4 and val_losses[-1] <= min(val_losses[:len(val_losses) // 2])) ) # ── Classify ── status = STATUS_NORMAL severity = "green" details = "" recommended_levels: List[int] = [] # Check 1: No decrease at all if loss_change < 0.1 and first_loss > expected_initial - 2.0: status = STATUS_NO_DECREASE severity = "red" details = ( f"Loss barely changed: {first_loss:.4f} -> {last_loss:.4f} " f"(delta={loss_change:.4f}). Likely a data or implementation bug." ) recommended_levels = [1, 2] # Check 2: Diverging elif last_loss > expected_initial + 1.0: status = STATUS_DIVERGING severity = "red" details = ( f"Loss ({last_loss:.4f}) exceeds initial value ({expected_initial:.2f}). " f"Training is diverging — check LR, data, or numerical issues." ) recommended_levels = [1, 2, 3] # Check 3: NaN detected in training loss elif has_nan: nan_count = len(raw_train_losses) - len(train_losses) nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l)) status = STATUS_NAN_DETECTED severity = "red" details = ( f"NaN detected in train_loss: {nan_count} NaN values " f"(first at step ~{nan_idx}). " f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. " f"Check gradient norms, LR schedule, and numerical precision." ) recommended_levels = [2, 3] # Check 4: Unstable (large spikes) elif recent_std > 0.5 * recent_mean: status = STATUS_UNSTABLE severity = "yellow" details = ( f"High loss variance: std={recent_std:.4f}, mean={recent_mean:.4f}. " f"Training is unstable — likely LR too high or batch too small." ) recommended_levels = [3, 2] # Check 5: Loss bounce (decreased then increased again) elif _has_bounce: status = STATUS_LOSS_BOUNCE if _val_improving: severity = "green" details = ( f"Train loss bounced (moving-avg): " f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} " f"(bounce={_bounce_amount:.4f}), but val loss is still improving " f"({val_losses[0]:.4f} -> {val_losses[-1]:.4f}). " f"Likely data distribution variation, not a real issue." ) recommended_levels = [] else: severity = "yellow" details = ( f"Train loss bounced (moving-avg): " f"{first_loss:.4f} -> min {_min_ma_loss:.4f} -> {_last_ma_loss:.4f} " f"(bounce={_bounce_amount:.4f}). " f"Possible LR too high, data issue, or overfitting." ) recommended_levels = [3, 4] # Check 6: Overfitting elif val_trend == "increasing" and second_half_avg < first_half_avg: status = STATUS_OVERFITTING severity = "yellow" details = ( f"Train loss decreasing but val loss increasing. " f"Train trend: {first_half_avg:.4f} -> {second_half_avg:.4f}, " f"Val trend: {val_trend}." ) recommended_levels = [4] # Check 7: Plateau elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]: status = STATUS_PLATEAU severity = "yellow" details = ( f"Loss has plateaued: first half avg={first_half_avg:.4f}, " f"second half avg={second_half_avg:.4f}. " f"Current loss ({last_loss:.4f}) is above expected range." ) recommended_levels = [3, 4, 5] # Normal else: status = STATUS_NORMAL severity = "green" details = ( f"Training looks healthy: {first_loss:.4f} -> {last_loss:.4f} " f"(delta={loss_change:.4f}). Val trend: {val_trend}." ) recommended_levels = [] # ── Print ── icons = {"red": "🔴", "yellow": "🟡", "green": "🟢"} icon = icons.get(severity, "⚪") print(f"\n {icon} Status: {status}") print(f" {details}") if recommended_levels: print(f" Recommended: check Level(s) {recommended_levels}") return { "status": status, "severity": severity, "details": details, "recommended_levels": recommended_levels, } # ─────────────────────────────────────────────────────────────── # Level 1: Data / Implementation Bug Checks # ─────────────────────────────────────────────────────────────── @staticmethod def check_data_pipeline( model: nn.Module, dataloader: DataLoader, tokenizer: Any, vocab_size: int, device: torch.device, dtype: torch.dtype = torch.bfloat16, ) -> Dict[str, Any]: """Run 6 data/implementation checks (Level 1). This is the most important level — 70% of loss issues are data bugs. Checks: 1. Shift relationship (targets[t] == input_ids[t+1]) 2. Token range (0 <= ids < vocab_size) 3. Initial loss (≈ ln(vocab_size) for random weights) 4. Single-batch overfit (loss → ~0 in 200 steps) 5. Tokenizer roundtrip (encode→decode preserves text) 6. Data quality sampling (visual inspection) """ print(_header("Level 1: Data / Implementation Bug Checks")) print(" (70% of loss issues come from data pipeline bugs)\n") results: List[Dict[str, Any]] = [] batch = next(iter(dataloader)) input_ids = batch["input_ids"] targets = batch["targets"] # ── Check 1: Shift relationship ── shift_match = (input_ids[:, 1:] == targets[:, :-1]).float().mean().item() passed = shift_match > 0.99 detail = f"Shift consistency: {shift_match * 100:.1f}% (should be ~100%)" results.append(_check_result("Shift relationship", passed, detail)) icon = "✅" if passed else "❌" print(f" {icon} Check 1: {detail}") # ── Check 2: Token range ── min_id = input_ids.min().item() max_id = input_ids.max().item() range_ok = min_id >= 0 and max_id < vocab_size detail = f"Token range: [{min_id}, {max_id}], vocab_size={vocab_size}" results.append(_check_result("Token range", range_ok, detail)) icon = "✅" if range_ok else "❌" print(f" {icon} Check 2: {detail}") # ── Check 3: Initial loss ── expected_loss = math.log(vocab_size) model_copy = copy.deepcopy(model) model_copy._init_weights() # re-initialize to random model_copy.to(device) model_copy.eval() with torch.no_grad(): with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): _, initial_loss = model_copy( input_ids.to(device), targets.to(device), ) initial_loss_val = initial_loss.item() loss_diff = abs(initial_loss_val - expected_loss) loss_ok = loss_diff < 1.0 detail = ( f"Initial loss: {initial_loss_val:.4f} vs expected {expected_loss:.2f} " f"(diff={loss_diff:.4f})" ) results.append(_check_result("Initial loss", loss_ok, detail)) icon = "✅" if loss_ok else "❌" print(f" {icon} Check 3: {detail}") if initial_loss_val > expected_loss + 1.0: print(f" Hint: loss >> ln(V) suggests label mismatch or loss function bug") elif initial_loss_val < expected_loss - 2.0: print(f" Hint: loss << ln(V) suggests data leakage") del model_copy if torch.cuda.is_available(): torch.cuda.empty_cache() # ── Check 4: Single-batch overfit test ── # Scale LR and steps based on model size to avoid instability num_params = sum(p.numel() for p in model.parameters()) if num_params > 500e6: overfit_lr, overfit_steps = 1e-4, 400 elif num_params > 50e6: overfit_lr, overfit_steps = 3e-4, 300 else: overfit_lr, overfit_steps = 1e-3, 200 print(f"\n ⏳ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...") overfit_model = copy.deepcopy(model) overfit_model.to(device) overfit_model.train() overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr) single_input = input_ids[:1].to(device) # single sample single_target = targets[:1].to(device) log_interval = max(overfit_steps // 4, 1) overfit_losses = [] for step in range(overfit_steps): overfit_optimizer.zero_grad() with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): _, loss = overfit_model(single_input, single_target) loss.backward() torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0) overfit_optimizer.step() overfit_losses.append(loss.item()) if (step + 1) % log_interval == 0: print(f" Step {step + 1}: Loss = {loss.item():.4f}") final_overfit_loss = overfit_losses[-1] min_overfit_loss = min(overfit_losses) overfit_ok = min_overfit_loss < 0.5 detail = ( f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} " f"(min={min_overfit_loss:.4f}, target < 0.5)" ) results.append(_check_result("Single-batch overfit", overfit_ok, detail)) icon = "✅" if overfit_ok else "❌" print(f" {icon} Check 4: {detail}") if not overfit_ok: print(f" CRITICAL: Model cannot memorize a single batch!") print(f" This means the model or loss function has a bug.") del overfit_model, overfit_optimizer if torch.cuda.is_available(): torch.cuda.empty_cache() # ── Check 5: Tokenizer roundtrip ── test_text = "The quick brown fox jumps over the lazy dog." encoded = tokenizer.encode(test_text) decoded = tokenizer.decode(encoded) roundtrip_ok = test_text.strip() in decoded.strip() detail = f"Roundtrip: '{test_text}' -> '{decoded.strip()}'" results.append(_check_result("Tokenizer roundtrip", roundtrip_ok, detail)) icon = "✅" if roundtrip_ok else "❌" print(f" {icon} Check 5: {detail}") # ── Check 6: Data quality sampling ── print(f"\n 📋 Check 6: Data quality sampling (visual inspection)") for i in range(min(3, input_ids.shape[0])): sample_tokens = input_ids[i][:100].tolist() decoded_text = tokenizer.decode(sample_tokens) preview = decoded_text[:200].replace("\n", "\\n") print(f" Sample {i}: {preview}...") passed_count = sum(1 for r in results if r["passed"]) total_count = len(results) print(f"\n Result: {passed_count}/{total_count} checks passed") return { "level": 1, "checks": results, "passed": [r for r in results if r["passed"]], "failed": [r for r in results if not r["passed"]], } # ─────────────────────────────────────────────────────────────── # Level 2: Numerical Stability # ─────────────────────────────────────────────────────────────── @staticmethod def check_numerical_stability( model: nn.Module, dataloader: DataLoader, device: torch.device, dtype: torch.dtype = torch.bfloat16, ) -> Dict[str, Any]: """Check for NaN/Inf in gradients, activations, and logits (Level 2). Checks: - Mixed precision config (RMSNorm fp32 upcast, loss dtype) - NaN/Inf gradients → softmax overflow, bad data - Inf gradients → log(0) in loss, missing ignore_index - Large activations growing per layer → initialization or norm bug - Logit scale → should be < 1000 """ print(_header("Level 2: Numerical Stability Checks")) batch = next(iter(dataloader)) input_ids = batch["input_ids"].to(device) targets = batch["targets"].to(device) results: List[Dict[str, Any]] = [] activation_stats: List[Dict[str, Any]] = [] # ── Mixed Precision Configuration Check ── print("\n Mixed Precision Config:") print(f" Training dtype: {dtype}") # Check RMSNorm fp32 upcast norm_fp32_ok = True checked_norm_classes: set = set() for name, module in model.named_modules(): cls_name = module.__class__.__name__ if "Norm" in cls_name and cls_name not in checked_norm_classes: checked_norm_classes.add(cls_name) import inspect try: src = inspect.getsource(type(module).forward) has_upcast = ".float()" in src or "float32" in src except (TypeError, OSError): has_upcast = True # assume ok if can't inspect if not has_upcast: norm_fp32_ok = False print(f" 🔴 {cls_name}: no fp32 upcast detected!") if norm_fp32_ok: print(f" ✅ Norm layers use fp32 upcast (safe)") results.append(_check_result( "Norm fp32 upcast", norm_fp32_ok, "Norm computes in fp32" if norm_fp32_ok else "Norm may lose precision in half dtype", )) # Check loss computation dtype if dtype in (torch.bfloat16, torch.float16): print(f" ℹ️ Best practice: compute loss in fp32 when using {dtype}") print(f" logits_fp32 = logits.float()") print(f" loss = F.cross_entropy(logits_fp32.view(-1, V), targets.view(-1))") # Common numerical issues reference print("\n Common Numerical Issues Reference:") print(" ┌──────────────────────┬──────────────────────────┬─────────────────────────┐") print(" │ Symptom │ Likely Cause │ Solution │") print(" ├──────────────────────┼──────────────────────────┼─────────────────────────┤") print(" │ Loss → NaN │ Large logits → softmax │ Check init, logit scale │") print(" │ Loss → Inf │ log(0) in CE loss │ Add eps, ignore_index │") print(" │ Loss oscillation │ fp16 gradient underflow │ Switch to bf16 / scaler │") print(" │ Late-training NaN │ Activation growth │ Check RMSNorm, wd │") print(" └──────────────────────┴──────────────────────────┴─────────────────────────┘") # ── Activation monitoring via hooks ── hooks = [] def make_hook(name: str): def hook_fn(module, input, output): if isinstance(output, torch.Tensor): out_f = output.float() stats = { "name": name, "mean": out_f.mean().item(), "std": out_f.std().item(), "max": out_f.abs().max().item(), "has_nan": bool(torch.isnan(output).any()), "has_inf": bool(torch.isinf(output).any()), } activation_stats.append(stats) return hook_fn # Register hooks on transformer layers for i, layer in enumerate(model.layers): h = layer.register_forward_hook(make_hook(f"layer_{i}")) hooks.append(h) # ── Forward + Backward ── model.train() model.zero_grad(set_to_none=True) use_scaler = dtype == torch.float16 and torch.cuda.is_available() scaler = torch.amp.GradScaler() if use_scaler else None with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): logits, loss = model(input_ids, targets) loss_val = loss.item() loss_ok = not (math.isnan(loss_val) or math.isinf(loss_val)) results.append(_check_result( "Loss value", loss_ok, f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)" )) if scaler is not None: scaler.scale(loss).backward() _temp_opt = torch.optim.SGD(model.parameters(), lr=0) scaler.unscale_(_temp_opt) else: loss.backward() # Remove hooks for h in hooks: h.remove() # ── Gradient checks ── print("\n Gradient Health:") grad_issues = [] for name, param in model.named_parameters(): if param.grad is None: continue grad = param.grad if torch.isnan(grad).any(): grad_issues.append(f"🔴 NaN gradient: {name}") if torch.isinf(grad).any(): grad_issues.append(f"🔴 Inf gradient: {name}") if grad.abs().max().item() > 100: grad_issues.append( f"🟡 Large gradient: {name} max={grad.abs().max().item():.1f}" ) grad_ok = len(grad_issues) == 0 if grad_ok: print(" ✅ All gradients are healthy (no NaN/Inf/large values)") else: for issue in grad_issues[:10]: # limit output print(f" {issue}") if len(grad_issues) > 10: print(f" ... and {len(grad_issues) - 10} more issues") results.append(_check_result( "Gradient health", grad_ok, f"{len(grad_issues)} issues found" if not grad_ok else "All healthy", )) # ── Activation checks ── print("\n Activation Stats (per transformer layer):") act_nan_count = 0 for stats in activation_stats: icon = "🔴" if stats["has_nan"] or stats["has_inf"] else " " if stats["has_nan"] or stats["has_inf"]: act_nan_count += 1 print( f" {icon} {stats['name']}: " f"mean={stats['mean']:.4f}, " f"std={stats['std']:.4f}, " f"max={stats['max']:.4f}" + (" [NaN!]" if stats["has_nan"] else "") + (" [Inf!]" if stats["has_inf"] else "") ) act_ok = act_nan_count == 0 results.append(_check_result( "Activation health", act_ok, f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy", )) # ── Activation growth trend ── if len(activation_stats) >= 2: stds = [s["std"] for s in activation_stats] if stds[0] > 1e-8: growth_ratio = stds[-1] / stds[0] growth_ok = growth_ratio < 10 detail = ( f"Activation std ratio (last/first): {growth_ratio:.1f}x " f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})" ) results.append(_check_result("Activation growth", growth_ok, detail)) icon = "✅" if growth_ok else "🟡" print(f" {icon} {detail}") if not growth_ok: print(f" Possible initialization or normalization issue") # ── Logit scale check ── logit_max = logits.float().abs().max().item() logit_ok = logit_max < 1000 detail = f"Logit max abs value: {logit_max:.1f} (should be < 1000)" results.append(_check_result("Logit scale", logit_ok, detail)) icon = "✅" if logit_ok else "🔴" print(f"\n {icon} Logit scale: {detail}") model.zero_grad(set_to_none=True) passed_count = sum(1 for r in results if r["passed"]) print(f"\n Result: {passed_count}/{len(results)} checks passed") return { "level": 2, "checks": results, "activation_stats": activation_stats, "grad_issues": grad_issues, } # ─────────────────────────────────────────────────────────────── # Level 3: Hyperparameter Diagnosis # ─────────────────────────────────────────────────────────────── @staticmethod def diagnose_hyperparameters( metrics_history: Dict[str, list], config: TrainConfig, ) -> Dict[str, Any]: """Analyze hyperparameter health from training metrics (Level 3). Checks: - LR: too high (grad_norm hitting clip limit) or too low (grad_norm tiny) - Batch size: loss variance indicates batch too small - Warmup: spikes in early steps indicate warmup too short """ print(_header("Level 3: Hyperparameter Diagnosis")) findings: List[Dict[str, str]] = [] grad_norms = metrics_history.get("grad_norm", []) train_losses = metrics_history.get("train_loss", []) # ── LR diagnosis ── print("\n Learning Rate Analysis:") print(f" Peak LR: {config.learning_rate:.2e}") print(f" Min LR: {config.min_learning_rate:.2e}") if grad_norms: avg_grad = sum(grad_norms) / len(grad_norms) # Ref: PyTorch clip_grad_norm_ clips when total_norm > max_norm clip_count = sum(1 for g in grad_norms if g >= config.grad_clip) clip_rate = clip_count / len(grad_norms) # Relative threshold: < 1% of clip limit (model-size independent) tiny_threshold = config.grad_clip * 0.01 tiny_count = sum(1 for g in grad_norms if g < tiny_threshold) tiny_rate = tiny_count / len(grad_norms) print(f" Avg grad norm: {avg_grad:.4f}") print(f" Clip rate: {clip_rate * 100:.1f}% (hitting max_norm={config.grad_clip})") print(f" Tiny grad rate: {tiny_rate * 100:.1f}% (< {tiny_threshold:.4f})") # Heuristic: >50% clipping means most steps are capped, so the # effective LR is lower than configured. Practitioners generally # treat this as a sign that peak LR is too high. if clip_rate > 0.5: findings.append({ "issue": "LR may be too high", "evidence": f"Grad norm hits clip limit {clip_rate * 100:.0f}% of the time", "action": f"Try LR = {config.learning_rate / 2:.2e} (÷2)", }) print(f" 🟡 Grad clipping frequent ({clip_rate * 100:.0f}%) → LR may be too high") elif tiny_rate > 0.5: findings.append({ "issue": "Possible vanishing gradients", "evidence": f"Grad norm < {tiny_threshold:.4f} in {tiny_rate * 100:.0f}% of steps", "action": "Check weight initialization, layer norms, and model depth", }) print(f" 🟡 Grad norm too small ({tiny_rate * 100:.0f}% < {tiny_threshold:.4f}) → possible vanishing gradients") else: print(f" ✅ LR looks appropriate") # ── Batch size diagnosis ── print("\n Batch Size Analysis:") print(f" Effective batch: {config.effective_batch_size}") if len(train_losses) >= 50: recent_losses = train_losses[-50:] loss_mean = sum(recent_losses) / len(recent_losses) loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses) loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8) print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation, last 50 steps)") if loss_cv > 0.1: findings.append({ "issue": "Training loss has high variance", "evidence": f"Loss CV = {loss_cv:.4f} over last 50 steps", "action": "Check: (1) LR may be too high, (2) increase gradient_accumulation_steps, (3) inspect data quality", }) print(f" 🟡 High loss variance → check LR, batch size, or data quality") else: print(f" ✅ Loss variance is acceptable") # ── β₂ diagnosis ── print("\n β₂ (Adam second momentum) Analysis:") print(f" Current β₂: {config.beta2}") if config.beta2 >= _DEFAULT_PYTORCH_BETA2: findings.append({ "issue": "β₂ may be too high for LLM training", "evidence": ( f"β₂={config.beta2} (PyTorch default). " f"LLM standard is {_RECOMMENDED_BETA2}" ), "action": f"Set beta2={_RECOMMENDED_BETA2} (used by LLaMA, TinyLlama, OLMo)", }) print(f" 🟡 β₂={config.beta2} is PyTorch default → " f"LLM training standard is {_RECOMMENDED_BETA2}") print(f" Why: β₂=0.999 averages ~1000 steps of gradient stats,") print(f" β₂=0.95 averages ~20 steps → faster adaptation to changing data") print(f" (Cattaneo & Shigida 2025, 'Tuning Adam(W)')") else: print(f" ✅ β₂={config.beta2} is within LLM standard range") # ── Weight Decay diagnosis ── print("\n Weight Decay Analysis:") print(f" Current weight_decay: {config.weight_decay}") if config.weight_decay == 0: findings.append({ "issue": "Weight decay is disabled", "evidence": "weight_decay=0 increases overfitting risk", "action": "Set weight_decay=0.1 (standard for LLaMA, TinyLlama, GPT-3, OLMo)", }) print(f" 🟡 weight_decay=0 → overfitting risk. Standard is 0.1") elif config.weight_decay > 0.3: findings.append({ "issue": "Weight decay may be too high", "evidence": f"weight_decay={config.weight_decay} (unusually high)", "action": "Try weight_decay=0.1 (standard value)", }) print(f" 🟡 weight_decay={config.weight_decay} is unusually high (standard: 0.1)") else: print(f" ✅ weight_decay={config.weight_decay} is within normal range") # ── Model-size LR reference ── print("\n GPT-3 LR Reference (Brown et al. 2020):") print(" ┌──────────┬───────────┬──────────────┐") print(" │ Model │ Peak LR │ Batch Tokens │") print(" ├──────────┼───────────┼──────────────┤") for params, lr, batch_tok in _GPT3_LR_REFERENCE: label = f"{params / 1e9:.1f}B" if params >= 1e9 else f"{params / 1e6:.0f}M" marker = " ←" if abs(params - 1.1e9) < 0.5e9 else "" print(f" │ {label:<8} │ {lr:.1e} │ {batch_tok:<12} │{marker}") print(" └──────────┴───────────┴──────────────┘") print(" → Larger models need lower LR and larger batch") # ── Batch-LR scaling guidance ── print("\n Batch-LR Scaling Rules:") print(" • Batch ×2 → LR ×√2 (square root scaling, recommended for Adam)") print(" (Malladi et al. NeurIPS 2022, 'On the SDEs and Scaling Rules for Adaptive Gradient Algorithms')") print(" • Batch ×2 → LR ×2 (linear scaling, Goyal et al. 2017, mainly SGD)") print(" • 1B model: ~1K-2K sequences (~2-4M tokens) is typical") print(" (Pythia-1B: ~2M tokens, TinyLlama: ~2M, OLMo-1B: ~4M)") # ── Warmup diagnosis ── print("\n Warmup Analysis:") print(f" Warmup steps: {config.warmup_steps} " f"({config.warmup_steps / config.total_steps * 100:.1f}% of total)") if len(train_losses) >= 10: early_losses = train_losses[:min(50, len(train_losses))] # Detect spikes in early training spike_count = 0 for i in range(1, len(early_losses)): if early_losses[i] > early_losses[i - 1] * 1.5: spike_count += 1 if spike_count > 3: findings.append({ "issue": "Warmup may be too short", "evidence": f"{spike_count} loss spikes in first {len(early_losses)} steps", "action": f"Try warmup_steps = {config.warmup_steps * 2}", }) print(f" 🟡 {spike_count} spikes in early training → warmup may be too short") else: print(f" ✅ Early training is stable") # ── Summary ── if not findings: print("\n ✅ No hyperparameter issues detected") else: print(f"\n Found {len(findings)} potential issue(s):") for f in findings: print(f" • {f['issue']}: {f['action']}") # ── Warmup reference from real projects ── print("\n Warmup Reference (real projects):") print(" • TinyLlama 1.1B (3T tokens): 2,000 steps ≈ 0.1% of total") print(" • GPT-3 175B: 375M warmup tokens ≈ 117 steps") print(" • General range: 0.1% ~ 5% of total steps") print(" • Smaller experiments: 5~10% is also reasonable") print("\n Tuning priority (high → low):") print(" 1. Learning Rate ← tune first (10x impact)") print(" 2. Batch Size ← adjust with LR") print(" 3. Warmup Steps ← early stability") print(" 4. Weight Decay ← if overfitting (typically 0.1)") print(" 5. β₁, β₂ (Adam) ← see β₂ analysis above") print(" 6. Gradient Clip ← usually keep at 1.0") return { "level": 3, "findings": findings, "config_summary": { "learning_rate": config.learning_rate, "effective_batch": config.effective_batch_size, "warmup_steps": config.warmup_steps, "total_steps": config.total_steps, "grad_clip": config.grad_clip, }, } @staticmethod def lr_range_test( model: nn.Module, dataloader: DataLoader, device: torch.device, dtype: torch.dtype = torch.bfloat16, lr_start: float = 1e-7, lr_end: float = 1e-1, steps: int = 300, ) -> Dict[str, Any]: """Run an LR range test to find the optimal learning rate (Level 3 bonus). Sweeps LR from lr_start to lr_end exponentially, recording loss. The optimal LR is where loss decreases fastest (steepest slope), divided by 3~10 for stability. WARNING: This modifies a copy of the model. The original is untouched. """ print(_header("Level 3 Bonus: LR Range Test")) print(f" Sweeping LR from {lr_start:.1e} to {lr_end:.1e} over {steps} steps...\n") test_model = copy.deepcopy(model) test_model.to(device) test_model.train() optimizer = torch.optim.AdamW(test_model.parameters(), lr=lr_start) lr_mult = (lr_end / lr_start) ** (1 / steps) lr = lr_start lrs: List[float] = [] losses: List[float] = [] data_iter = iter(dataloader) for step in range(steps): for pg in optimizer.param_groups: pg["lr"] = lr try: batch = next(data_iter) except StopIteration: data_iter = iter(dataloader) batch = next(data_iter) input_ids = batch["input_ids"].to(device) targets_t = batch["targets"].to(device) optimizer.zero_grad() with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)): _, loss = test_model(input_ids, targets_t) loss.backward() optimizer.step() loss_val = loss.item() lrs.append(lr) losses.append(loss_val) if (step + 1) % 50 == 0: print(f" Step {step + 1}: LR = {lr:.2e}, Loss = {loss_val:.4f}") # Stop if loss explodes if len(losses) > 1 and loss_val > losses[0] * 4: print(f" Loss exploded at LR = {lr:.2e}, stopping.") break lr *= lr_mult del test_model, optimizer if torch.cuda.is_available(): torch.cuda.empty_cache() # Find steepest descent best_lr = lr_start if len(losses) > 10: # Smooth losses and find steepest negative slope window = 5 smoothed = [] for i in range(len(losses) - window): smoothed.append(sum(losses[i:i + window]) / window) min_slope = 0 min_idx = 0 for i in range(1, len(smoothed)): slope = smoothed[i] - smoothed[i - 1] if slope < min_slope: min_slope = slope min_idx = i best_lr = lrs[min_idx] suggested_lr = best_lr / 3 # conservative choice print(f"\n Steepest descent at LR = {best_lr:.2e}") print(f" Suggested peak LR: {suggested_lr:.2e} (÷3 for stability)") print(f" Conservative range: [{best_lr / 10:.2e}, {best_lr / 3:.2e}]") else: suggested_lr = 3e-4 print(f"\n Not enough data points. Using default LR = {suggested_lr:.2e}") return { "lrs": lrs, "losses": losses, "best_lr": best_lr, "suggested_lr": suggested_lr, } # ─────────────────────────────────────────────────────────────── # Level 4: Overfitting vs Underfitting Diagnosis # ─────────────────────────────────────────────────────────────── @staticmethod def diagnose_fitting( metrics_history: Dict[str, list], model_params: Optional[int] = None, total_tokens: Optional[int] = None, ) -> Dict[str, Any]: """Diagnose overfitting vs underfitting from metrics (Level 4). Cases: 1. Both high, decreasing → Normal (still training) 2. Both high, plateau → Underfitting 3. Train↓ Val→ or Val↑ → Overfitting 4. Both low, plateau → Converged (or at limit) """ print(_header("Level 4: Overfitting vs Underfitting Diagnosis")) train_losses = metrics_history.get("train_loss", []) val_losses = [v for v in metrics_history.get("val_loss", []) if v is not None] if len(train_losses) < 10 or len(val_losses) < 2: print(" [!] Not enough data. Need more training steps with eval.") return {"level": 4, "case": "insufficient_data", "recommendations": []} # Recent train trend recent_n = min(50, len(train_losses)) train_recent = train_losses[-recent_n:] train_mid = len(train_recent) // 2 train_first = sum(train_recent[:train_mid]) / max(train_mid, 1) train_second = sum(train_recent[train_mid:]) / max(len(train_recent) - train_mid, 1) train_decreasing = train_second < train_first - 0.02 # Val trend val_mid = len(val_losses) // 2 val_first = sum(val_losses[:max(val_mid, 1)]) / max(val_mid, 1) val_second = sum(val_losses[val_mid:]) / max(len(val_losses) - val_mid, 1) val_decreasing = val_second < val_first - 0.02 val_increasing = val_second > val_first + 0.05 # Train-Val gap last_train = train_losses[-1] last_val = val_losses[-1] gap = last_train - last_val # negative means val > train (typical) print(f" Train loss (recent): {train_first:.4f} → {train_second:.4f} " f"({'↓' if train_decreasing else '→'})") print(f" Val loss: {val_first:.4f} → {val_second:.4f} " f"({'↓' if val_decreasing else '↑' if val_increasing else '→'})") print(f" Train-Val gap: {abs(gap):.4f}") # ── Classify ── case = "" recommendations: List[str] = [] if train_decreasing and val_decreasing: case = "Case 1: Normal — both decreasing" recommendations.append("Training is progressing normally. Continue.") if model_params and total_tokens: ratio = total_tokens / model_params chinchilla = 20 # Chinchilla optimal: 20 tokens per param if ratio < chinchilla: recommendations.append( f"Token/param ratio = {ratio:.1f}x " f"(Chinchilla optimal ≈ {chinchilla}x). " f"Model may benefit from more data." ) print(f"\n 🟢 {case}") elif not train_decreasing and not val_decreasing and last_train > _EXPECTED_TRAIN_LOSS[1]: case = "Case 2: Underfitting — both plateaued at high loss" recommendations = [ "Diagnosis priority (check in order):", "1) Training insufficient? → check if loss curve still has downward slope", " - Chinchilla: 1B model needs ~20B tokens minimum", " - TinyLlama trains 1.1B on 3T tokens (inference-optimal)", "2) LR too low? → try LR ×2, see if loss drops faster", "3) Model capacity too small? → train 2x larger model on same data", " - If larger model gets lower loss → capacity was the limit", "4) Data quality? → sample and read training data manually", " - Noisy/low-quality data raises the achievable loss floor", ] if model_params and total_tokens: ratio = total_tokens / model_params if ratio < 10: recommendations.insert(0, f"⚠ Token/param ratio = {ratio:.1f}x — " f"very likely undertrained. Chinchilla recommends ≥20x." ) elif ratio < 20: recommendations.insert(0, f"ℹ Token/param ratio = {ratio:.1f}x — " f"below Chinchilla optimal (20x). More tokens may help." ) print(f"\n 🟡 {case}") elif train_decreasing and (val_increasing or not val_decreasing): case = "Case 3: Overfitting — train↓ but val→/↑" recommendations = [ "Diagnosis priority (check in order):", "1) Data repetition? (most common cause in pretraining)", " - Check: total tokens vs unique tokens", " - Epoch > 1 dramatically increases overfitting risk", " - Solution: add more data, stay within 1 epoch", "2) Weight decay too low?", " - Check: weight_decay value (standard: 0.1)", " - LLaMA, TinyLlama, OLMo, GPT-3 all use 0.1", " - Experiment: 0.01 / 0.05 / 0.1 / 0.3", "3) Data diversity?", " - Single-domain data overfits faster", " - Mix: web, books, code, wiki, etc.", "", "Note on Dropout in LLM pretraining:", " - Modern LLMs do NOT use dropout in pretraining", " (Pythia, TinyLlama, OLMo, LLaMA all use dropout=0)", " - Sufficient data is the best regularization", " - Dropout is useful for fine-tuning on small datasets", ] print(f"\n 🟡 {case}") else: case = "Case 4: Converged — loss is low and stable" recommendations = [ "Training has converged (or reached the data/model limit).", "To push further: add more data or increase model size.", ] print(f"\n 🟢 {case}") for rec in recommendations: print(f" {rec}") return { "level": 4, "case": case, "train_trend": "decreasing" if train_decreasing else "flat", "val_trend": "decreasing" if val_decreasing else ("increasing" if val_increasing else "flat"), "gap": abs(gap), "recommendations": recommendations, } # ─────────────────────────────────────────────────────────────── # Level 5: Architecture Checks # ─────────────────────────────────────────────────────────────── @staticmethod def check_architecture( model: nn.Module, dataloader: DataLoader, device: torch.device, ) -> Dict[str, Any]: """Check weight initialization and per-layer activation health (Level 5). Healthy initialization: - All layers: std ≈ 1.0, mean ≈ 0.0 Problems: - std increasing per layer → activation explosion (init scale too large) - std decreasing per layer → activation vanishing (init scale too small) - Sudden change at specific layer → implementation bug in that layer """ print(_header("Level 5: Architecture / Initialization Check")) batch = next(iter(dataloader)) sample_input = batch["input_ids"][:1].to(device) model.eval() layer_stats: List[Dict[str, Any]] = [] with torch.no_grad(): h = model.token_embedding(sample_input) emb_std = h.float().std().item() print(f"\n Embedding: std={emb_std:.4f}") for i, layer in enumerate(model.layers): h = layer(h, mask=None, position_offset=0) h_f = h.float() stats = { "layer": i, "mean": h_f.mean().item(), "std": h_f.std().item(), "max": h_f.abs().max().item(), } layer_stats.append(stats) # Print stats print(f"\n Layer-by-layer activation statistics:") print(f" {'Layer':<8} {'Mean':>10} {'Std':>10} {'Max':>10}") print(f" {'-' * 38}") for s in layer_stats: print(f" {s['layer']:<8} {s['mean']:>10.4f} {s['std']:>10.4f} {s['max']:>10.4f}") # ── Weight initialization distribution check ── print(f"\n Weight Initialization Distribution:") print(f" {'Parameter':<40} {'Mean':>10} {'Std':>10} {'Shape'}") print(f" {'-' * 75}") weight_issues = [] for name, param in model.named_parameters(): if param.ndim < 2: continue # skip biases, norm weights p_f = param.float() p_mean = p_f.mean().item() p_std = p_f.std().item() # Expected: std ≈ 0.02 for most layers, smaller for residual projections shape_str = str(list(param.shape)) is_residual = "o_proj" in name or "down_proj" in name expected_std = 0.02 # GPT-2 style if p_std > expected_std * 5: weight_issues.append(f"Large std: {name} (std={p_std:.4f})") print(f" 🟡 {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}") elif p_std < expected_std * 0.1: weight_issues.append(f"Tiny std: {name} (std={p_std:.6f})") print(f" 🟡 {name:<38} {p_mean:>10.4f} {p_std:>10.6f} {shape_str}") else: print(f" {name:<38} {p_mean:>10.4f} {p_std:>10.4f} {shape_str}") if weight_issues: print(f"\n ⚠ {len(weight_issues)} weight distribution issue(s) found") for issue in weight_issues[:5]: print(f" • {issue}") else: print(f"\n ✅ All weight distributions look normal (std ≈ 0.02)") print(f"\n Expected init pattern:") print(f" • General Linear: N(0, 0.02)") print(f" • Residual proj (o_proj, down_proj): N(0, 0.02/√(2×layers))") print(f" • Embedding: N(0, 0.02)") # ── Ablation study guidance ── print(f"\n Component Ablation Reference:") print(" ┌──────────────────────┬────────────────────────────────────┐") print(" │ Experiment │ Expected Outcome │") print(" ├──────────────────────┼────────────────────────────────────┤") print(" │ RMSNorm → LayerNorm │ Minimal loss diff → OK │") print(" │ RoPE → Absolute PE │ Similar on short seq (<512) │") print(" │ SwiGLU → ReLU FFN │ Loss +0.05~0.15 → SwiGLU working │") print(" │ GQA → MHA │ Same loss, less memory → OK │") print(" └──────────────────────┴────────────────────────────────────┘") print(" If any replacement shows unexpected results, check that component.") # Analyze trends stds = [s["std"] for s in layer_stats] diagnosis = "healthy" detail = "" if len(stds) >= 3: # Check for monotonic increase/decrease first_third = sum(stds[:len(stds) // 3]) / (len(stds) // 3) last_third = sum(stds[-(len(stds) // 3):]) / (len(stds) // 3) ratio = last_third / max(first_third, 1e-8) if ratio > 5: diagnosis = "exploding" detail = ( f"Activation std grows {ratio:.1f}x from early to late layers. " f"Init scale may be too large." ) elif ratio < 0.2: diagnosis = "vanishing" detail = ( f"Activation std shrinks to {ratio:.1f}x from early to late layers. " f"Init scale may be too small." ) else: detail = f"Std ratio (last/first third) = {ratio:.2f} — within normal range." # Check for sudden jumps for i in range(1, len(stds)): jump = stds[i] / max(stds[i - 1], 1e-8) if jump > 10 or jump < 0.1: diagnosis = "anomaly" detail = ( f"Sudden activation change at layer {i}: " f"std {stds[i - 1]:.4f} → {stds[i]:.4f}. " f"Possible implementation bug in that layer." ) break icon = {"healthy": "✅", "exploding": "🔴", "vanishing": "🟡", "anomaly": "🔴"} print(f"\n {icon.get(diagnosis, '⚪')} Diagnosis: {diagnosis}") print(f" {detail}") return { "level": 5, "diagnosis": diagnosis, "detail": detail, "layer_stats": layer_stats, "weight_issues": weight_issues, } # ─────────────────────────────────────────────────────────────── # Main Entry Point # ─────────────────────────────────────────────────────────────── @staticmethod def run_diagnostics( model: nn.Module, dataloader: DataLoader, tokenizer: Any, train_config: TrainConfig, metrics_history: Dict[str, list], device: torch.device, dtype: torch.dtype = torch.bfloat16, vocab_size: int = 32000, levels: Optional[List[int]] = None, ) -> Dict[str, Any]: """Run the full 5-level debugging framework. Args: model: the LLM model dataloader: training dataloader tokenizer: tokenizer with encode/decode methods train_config: TrainConfig instance metrics_history: dict from MetricsTracker.history device: torch device dtype: mixed precision dtype vocab_size: model vocabulary size levels: which levels to run (default: all [0,1,2,3,4,5]) Returns: Full diagnostic report dict. """ if levels is None: levels = [0, 1, 2, 3, 4, 5] print("\n" + "═" * 60) print(" LLM Loss Debugging Framework") print(" Levels to run: " + ", ".join(str(l) for l in levels)) print("═" * 60) report: Dict[str, Any] = {} if 0 in levels: report["level_0"] = LossDebugger.diagnose_status(vocab_size, metrics_history) # If status is normal and only level 0 was explicitly requested, skip rest if ( report["level_0"]["status"] == STATUS_NORMAL and levels == [0] ): print("\n Training is healthy — no further debugging needed.") return report if 1 in levels: report["level_1"] = LossDebugger.check_data_pipeline( model, dataloader, tokenizer, vocab_size, device, dtype, ) if 2 in levels: report["level_2"] = LossDebugger.check_numerical_stability( model, dataloader, device, dtype, ) if 3 in levels: report["level_3"] = LossDebugger.diagnose_hyperparameters( metrics_history, train_config, ) if 4 in levels: model_params = sum(p.numel() for p in model.parameters()) total_tokens = len(metrics_history.get("train_loss", [])) * train_config.tokens_per_step report["level_4"] = LossDebugger.diagnose_fitting( metrics_history, model_params, total_tokens, ) if 5 in levels: report["level_5"] = LossDebugger.check_architecture( model, dataloader, device, ) # Final summary print("\n" + "═" * 60) print(" Diagnostics Complete") print("═" * 60) return report # ─────────────────────────────────────────────────────────────── # Study Roadmap # ─────────────────────────────────────────────────────────────── @staticmethod def print_study_roadmap() -> None: """Print the recommended study roadmap for LLM training optimization.""" print(_header("Study Roadmap — LLM Training Optimization")) print(""" ⭐⭐⭐ Top Priority: Optimization Fundamentals ───────────────────────────────────────────── 1. SGD → Momentum → Adam → AdamW progression - Why Adam > SGD? Why decouple weight decay in AdamW? - β₁, β₂ intuition (1st / 2nd momentum) - Ref: Loshchilov & Hutter 2019 (AdamW) - Ref: Karpathy "A Recipe for Training Neural Networks" 2. Loss Landscape - Why large LR diverges, small LR stalls - Batch size effect on landscape exploration - Ref: Li et al. 2018 "Visualizing the Loss Landscape" - Ref: McCandlish et al. 2018 "Large-Batch Training" 3. Chinchilla Scaling Law - Loss = f(N, D) relationship - Compute-optimal model size vs data allocation - Ref: Hoffmann et al. 2022 (original) - Ref: Kaplan et al. 2020 (predecessor) - Ref: Besiroglu et al. 2024 (replication/verification) ⭐⭐ Important: Training Stability ────────────────────────────────── 4. Gradient Flow: vanishing/exploding, residual as gradient highway 5. Weight Init: Xavier / Kaiming / GPT-2 style 6. Normalization: BatchNorm → LayerNorm → RMSNorm 7. Weight Decay: L2 vs decoupled, why exclude embed/norm ⭐ Advanced: Optimization Techniques ───────────────────────────────────── 8. LR Schedules: cosine vs linear vs step, warmup/cooldown 9. Gradient Accumulation & Large Batch Training 10. μP (Maximal Update Parameterization): transfer HP across scales Recommended Experiments (in order): ─────────────────────────────────── 1. Single-batch overfit (30 min) → basic sanity 2. LR Range Test (1 hour) → optimal LR range 3. 10M model quick train (2-3 hrs) → pipeline validation 4. Ablation (remove components) (1 day) → component contribution 5. 100M model + 5B tokens (1-2 days)→ mid-scale dynamics 6. 1B model full training (2-3 days)→ scaling law verification 7. LR / batch size comparison (1 day) → HP sensitivity Key References: ─────────────── ⭐⭐⭐ Karpathy "Recipe for Training NNs" — debugging mindset ⭐⭐⭐ Hoffmann et al. 2022 (Chinchilla) — scaling law ⭐⭐ Touvron et al. 2023 (LLaMA) — 1B+ training details ⭐⭐ Biderman et al. 2023 (Pythia) — open training logs ⭐⭐ Zhang et al. 2024 (TinyLlama) — 1.1B on 3T tokens ⭐⭐ Groeneveld et al. 2024 (OLMo) — fully open LLM ⭐⭐ Li et al. 2018 (Loss Landscape) — loss terrain intuition ⭐⭐ Loshchilov & Hutter 2019 (AdamW) — optimizer basics ⭐ Yang et al. 2022 (μP) — HP transfer ⭐ McCandlish et al. 2018 (Batch size) — critical batch size """)