"""
Debug utilities for systematic NaN/Inf detection and root-cause analysis.

Usage in train.py:

    from src.utils.debug_utils import TrainingDebugger

    debugger = TrainingDebugger(
        model, ancestor_table, optimizer,
        config={"debug_mode": True, "raise_on_nan": True, "use_hooks": True},
    )

    with debugger:
        for step in range(...):
            debugger.reset_hooks_step(step)
            debugger.check_batch(batch, step)
            loss, metrics, output = forward_step(...)
            debugger.check_forward_output(output, step)
            debugger.check_loss(loss, metrics, step)
            loss.backward()
            gnorm = debugger.check_gradients(step)
            debugger.clip_grads(step)
            optimizer.step()
            debugger.check_params_after_step(step)

    ``clip_grads`` performs the same NaN/Inf detection and stats logging as
    ``check_gradients``; calling both runs the gradient-norm pass twice.

Design principles:
- Fail-fast: detect the FIRST place where NaN/Inf appears.
- Informative: print tensor name, shape, dtype, stats, module name, param name.
- Minimally intrusive: wrap existing logic, don't replace it.
- Configurable: debug_mode=False disables almost all overhead.
"""

from __future__ import annotations

import sys
import math
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn


# --------------------------------------------------------------------------- #
# 0. Internal helpers shared by the checks below
# --------------------------------------------------------------------------- #


def _unwrap_module(model: nn.Module) -> nn.Module:
    """Strip ``_orig_mod`` (torch.compile) and then ``module`` (DDP-style) wrappers."""
    m = model
    while hasattr(m, "_orig_mod"):
        m = m._orig_mod
    while hasattr(m, "module"):
        m = m.module
    return m


def _param_device(model: nn.Module) -> torch.device:
    """Device of the first parameter of ``model``; CPU when it has none."""
    first = next(model.parameters(), None)
    return first.device if first is not None else torch.device("cpu")


def _named_params(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    grads_only: bool = False,
) -> List[Tuple[str, nn.Parameter]]:
    """
    List ``(name, param)`` pairs of ``model`` followed by ``ancestor_table``.

    Parameters of ``ancestor_table`` are named ``ancestor_table.<name>``.
    With ``grads_only`` parameters whose ``.grad`` is ``None`` are skipped.
    """
    params = [
        (n, p) for n, p in model.named_parameters()
        if not grads_only or p.grad is not None
    ]
    if ancestor_table is not None:
        params += [
            (f"ancestor_table.{n}", p)
            for n, p in ancestor_table.named_parameters()
            if not grads_only or p.grad is not None
        ]
    return params


def _stacked_norms(tensors: List[torch.Tensor]) -> torch.Tensor:
    """
    Return the per-tensor L2 norms of ``tensors`` as one 1-D float32 tensor.

    ``torch._foreach_norm`` computes all norms in one batched call. The norms
    are moved to the first norm's device (so parameters spread over several
    devices can still be stacked) and upcast to float32 so that a total norm
    computed from them does not overflow for fp16/bf16 gradients.
    ``tensors`` must be non-empty.
    """
    norms = torch._foreach_norm(tensors, 2.0)
    device = norms[0].device
    if any(n.device != device for n in norms):
        norms = [n.to(device) for n in norms]
    return torch.stack(norms).float()


def _nonfinite_entries(norms_stacked: torch.Tensor) -> List[Tuple[int, str]]:
    """
    Return ``(index, status)`` for every non-finite entry of ``norms_stacked``.

    ``status`` is ``"NAN"`` when the norm is NaN and ``"INF"`` otherwise.
    Indices and values are fetched in bulk instead of one host sync per entry.
    """
    idx = (~torch.isfinite(norms_stacked)).nonzero().flatten().tolist()
    if not idx:
        return []
    values = norms_stacked[idx].tolist()
    return [(i, "NAN" if math.isnan(v) else "INF") for i, v in zip(idx, values)]


# --------------------------------------------------------------------------- #
# 1. Low-level tensor checks
# --------------------------------------------------------------------------- #


def check_tensor_stats(
    name: str,
    tensor: Optional[torch.Tensor],
    step: Optional[int] = None,
    stats: bool = True,
    raise_on_nan: bool = False,
    max_elements_for_stats: int = 50_000_000,
) -> Dict[str, Any]:
    """
    Check a single tensor for NaN/Inf and optionally collect stats.

    CRITICAL: this function never copies the tensor via boolean indexing
    (``tensor[tensor.isfinite()]``), which caused OOM on large logits.
    Stats are only computed for all-finite tensors, using native reduce ops
    (no copies except a float cast for integer/bool tensors). Tensors with more
    than ``max_elements_for_stats`` elements, and complex tensors (which have
    no ``min``/``max``), skip stats.

    Only the non-finite case is printed; the OK message is returned in ``msg``.

    Args:
        name: label used in messages.
        tensor: tensor to check. ``None``, non-tensors and empty tensors are
            reported as finite without inspection.
        step: optional training step, rendered as a ``[step N] `` prefix.
        stats: append min/max/mean/std to the OK message.
        raise_on_nan: raise when NaN/Inf is found.
        max_elements_for_stats: element count above which stats are skipped.

    Returns:
        dict with keys: has_nan, has_inf, is_finite, msg

    Raises:
        RuntimeError: if ``raise_on_nan`` and the tensor contains NaN/Inf.
    """
    prefix = f"[step {step}] " if step is not None else ""
    if tensor is None:
        return {"has_nan": False, "has_inf": False, "is_finite": True,
                "msg": f"{prefix}[{name}] None"}
    if not isinstance(tensor, torch.Tensor):
        return {"has_nan": False, "has_inf": False, "is_finite": True,
                "msg": f"{prefix}[{name}] non-tensor {type(tensor)}"}
    if tensor.numel() == 0:
        return {"has_nan": False, "has_inf": False, "is_finite": True,
                "msg": f"{prefix}[{name}] empty tensor"}

    # ---- Fast path: a single sync for the common all-finite case ----
    is_finite = tensor.isfinite().all().item()
    if not is_finite:
        nan_count = tensor.isnan().sum().item()
        inf_count = tensor.isinf().sum().item()
        has_nan = nan_count > 0
        has_inf = inf_count > 0
        status = "NAN" if has_nan else "INF"
        msg = (
            f"{prefix}[{name}] shape={list(tensor.shape)} dtype={tensor.dtype} "
            f"device={tensor.device} status={status} "
            f"nan_count={nan_count} inf_count={inf_count}"
        )
        print(msg, flush=True)
        if raise_on_nan:
            raise RuntimeError(f"{prefix}NaN/Inf detected in {name}")
        return {"has_nan": has_nan, "has_inf": has_inf, "is_finite": False, "msg": msg}

    # ---- OK path (tensor is known to be all-finite here) ----
    msg = (
        f"{prefix}[{name}] shape={list(tensor.shape)} dtype={tensor.dtype} "
        f"device={tensor.device} status=OK"
    )
    if stats:
        if tensor.numel() > max_elements_for_stats:
            msg += " | stats_skipped_large_tensor"
        elif tensor.is_complex():
            # min()/max() are not defined for complex tensors.
            msg += " | stats_skipped_complex"
        else:
            msg += f" | min={tensor.min().item()} max={tensor.max().item()}"
            # mean/std require a floating dtype: cast integer/bool tensors once.
            tf = tensor if tensor.dtype.is_floating_point else tensor.float()
            msg += f" mean={tf.mean().item():.4e}"
            if tf.numel() > 1:
                msg += f" std={tf.std().item():.4e}"
    return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": msg}


def check_nested_tensors(
    name: str,
    obj: Any,
    step: Optional[int] = None,
    raise_on_nan: bool = False,
    stats: bool = True,
    max_elements_for_stats: int = 50_000_000,
) -> List[Dict[str, Any]]:
    """
    Recursively check tensors inside (possibly nested) dicts/lists/tuples.

    Every tensor is passed to :func:`check_tensor_stats` under a derived name:
    ``name.key`` for dict entries and ``name[i]`` for sequence items. Values of
    any other type are ignored.

    Returns:
        One :func:`check_tensor_stats` result per tensor, in traversal order.
    """
    if isinstance(obj, torch.Tensor):
        return [check_tensor_stats(
            name, obj, step, stats=stats, raise_on_nan=raise_on_nan,
            max_elements_for_stats=max_elements_for_stats,
        )]
    if isinstance(obj, (list, tuple)):
        children = [(f"{name}[{i}]", item) for i, item in enumerate(obj)]
    elif isinstance(obj, dict):
        children = [(f"{name}.{k}", v) for k, v in obj.items()]
    else:
        return []

    results: List[Dict[str, Any]] = []
    for child_name, child in children:
        results.extend(check_nested_tensors(
            child_name, child, step, raise_on_nan=raise_on_nan, stats=stats,
            max_elements_for_stats=max_elements_for_stats,
        ))
    return results


# --------------------------------------------------------------------------- #
# 2. Gradient & parameter checks
# --------------------------------------------------------------------------- #


def compute_grad_norm(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
) -> Tuple[float, Optional[Tuple[str, nn.Parameter]], float]:
    """
    Compute the total grad norm and find the param with the largest grad norm.

    Uses ``torch._foreach_norm`` (one batched C++ call) instead of a Python
    loop over parameters. Parameters without ``.grad`` are ignored.

    Returns:
        (total_norm, (name, param) or None, max_single_norm)
    """
    params = _named_params(model, ancestor_table, grads_only=True)
    if not params:
        return 0.0, None, 0.0

    norms_stacked = _stacked_norms([p.grad for _, p in params])
    total_norm = torch.norm(norms_stacked, 2.0).item()
    max_idx = int(norms_stacked.argmax().item())
    max_val = norms_stacked[max_idx].item()
    return total_norm, params[max_idx], max_val


def check_gradients(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
    print_all: bool = False,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Check grads of ``model`` (and ``ancestor_table``) for NaN/Inf.

    Fast path: per-param norms via ``torch._foreach_norm`` and a single check
    of their total. Slow path (only when the total is non-finite): print every
    param whose grad norm is non-finite.

    Args:
        print_all: when all grads are finite, print every param's grad norm.

    Returns:
        List of ``(name, param, "NAN"|"INF")`` for the offending params.

    Raises:
        RuntimeError: if ``raise_on_nan`` and any grad is non-finite.
    """
    prefix = f"[step {step}] " if step is not None else ""
    bad: List[Tuple[str, nn.Parameter, str]] = []

    params = _named_params(model, ancestor_table, grads_only=True)
    if not params:
        return bad

    norms_stacked = _stacked_norms([p.grad for _, p in params])
    total_norm = torch.norm(norms_stacked, 2.0)

    if not torch.isfinite(total_norm):
        # Slow path: find the culprit(s)
        for idx, status in _nonfinite_entries(norms_stacked):
            name, p = params[idx]
            print(
                f"{prefix}[GRAD {status}] {name} shape={list(p.grad.shape)}",
                flush=True,
            )
            bad.append((name, p, status))
        if bad and raise_on_nan:
            names = [n for n, _, _ in bad]
            raise RuntimeError(
                f"{prefix}Non-finite grad in {len(bad)} params: {names[:10]}..."
            )
    elif print_all:
        for (name, _), norm in zip(params, norms_stacked.tolist()):
            print(
                f"{prefix}[GRAD OK] {name} grad_norm={norm:.4e}",
                flush=True,
            )
    return bad


def check_model_params(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Check parameter values for NaN/Inf (typically after ``optimizer.step()``).

    Fast path via ``torch._foreach_norm``; per-param reporting only when the
    total norm is non-finite.

    Returns:
        List of ``(name, param, "NAN"|"INF")`` for the offending params.

    Raises:
        RuntimeError: if ``raise_on_nan`` and any parameter is non-finite.
    """
    prefix = f"[step {step}] " if step is not None else ""
    bad: List[Tuple[str, nn.Parameter, str]] = []

    params = _named_params(model, ancestor_table)
    if not params:
        return bad

    norms_stacked = _stacked_norms([p.data for _, p in params])
    total_norm = torch.norm(norms_stacked, 2.0)

    if not torch.isfinite(total_norm):
        for idx, status in _nonfinite_entries(norms_stacked):
            name, p = params[idx]
            print(
                f"{prefix}[PARAM {status}] {name} shape={list(p.shape)}",
                flush=True,
            )
            bad.append((name, p, status))
        if bad and raise_on_nan:
            names = [n for n, _, _ in bad]
            raise RuntimeError(
                f"{prefix}Non-finite param in {len(bad)} params: {names[:10]}..."
            )
    return bad


# --------------------------------------------------------------------------- #
# 3. Forward hooks for intermediate-layer NaN detection
# --------------------------------------------------------------------------- #


def register_nan_hooks(
    model: nn.Module,
    step: Optional[int] = None,
    verbose: bool = False,
    module_filter: Optional[str] = "blocks",
    step_ref: Optional[Dict[str, Any]] = None,
) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, Any]]:
    """
    Register forward hooks to catch the FIRST layer that produces non-finite
    output (or receives a non-finite positional input).

    Only the first offending module is recorded in the returned dict; later
    hook calls are no-ops until the caller clears that dict.

    Args:
        step: step shown in hook messages (fixed at registration time).
        verbose: also print a line for every hooked module whose output is OK.
        module_filter: Controls which modules get hooks.
            - "blocks" (default): hooks on each block inside ``model.blocks``
              when it is an ``nn.ModuleList``; otherwise on direct children.
              Typically ~12 hooks for a transformer, negligible overhead.
            - "children": hooks on all direct children of ``model``.
            - "all": hooks on *every* named submodule (old behaviour).
              Very slow (~10x forward slowdown) and increases memory.
            - any other string: comma-separated module names as produced by
              ``model.named_modules()``, matched case-insensitively.
            - None / "": no hooks.
        step_ref: optional mutable dict; when given, its ``"step"`` entry (if
            present) overrides ``step`` in messages at hook-call time, so the
            owner can update the step without re-registering hooks.

    Returns:
        (list_of_handles, first_bad_dict)
    """
    handles: List[torch.utils.hooks.RemovableHandle] = []
    first_bad: Dict[str, Any] = {}

    def _targets():
        """Yield (name, module) pairs to hook."""
        if not module_filter:
            return
        raw = str(module_filter)
        mf = raw.strip().lower()
        if mf == "all":
            for name, module in model.named_modules():
                if name:
                    yield name, module
        elif mf == "children":
            yield from model.named_children()
        elif mf == "blocks":
            # Default: hook each block in model.blocks (e.g. transformer blocks)
            if hasattr(model, "blocks") and isinstance(model.blocks, nn.ModuleList):
                for i, block in enumerate(model.blocks):
                    yield f"blocks.{i}", block
            else:
                # Fallback to direct children if no .blocks attr
                yield from model.named_children()
        else:
            # Comma-separated module names. Matching is case-insensitive and
            # empty entries (e.g. a trailing comma) are dropped so they do not
            # match the root module, whose name is "".
            allowed = {s.strip().lower() for s in raw.split(",") if s.strip()}
            for name, module in model.named_modules():
                if name.lower() in allowed:
                    yield name, module

    def _output_tensors(output: Any) -> List[Tuple[str, torch.Tensor]]:
        """Top-level tensors of a module output (tensor, list/tuple or dict)."""
        if isinstance(output, torch.Tensor):
            return [("output", output)]
        if isinstance(output, (list, tuple)):
            return [(f"output[{i}]", o) for i, o in enumerate(output)
                    if isinstance(o, torch.Tensor)]
        if isinstance(output, dict):
            return [(f"output['{k}']", v) for k, v in output.items()
                    if isinstance(v, torch.Tensor)]
        return []

    def make_hook(module_name: str, module_type: str):
        def hook(module, inputs, output):
            if first_bad:
                return
            cur_step = step_ref.get("step", step) if step_ref is not None else step
            prefix = f"[step {cur_step}] " if cur_step is not None else ""

            # Check inputs
            for i, inp in enumerate(inputs):
                if isinstance(inp, torch.Tensor) and inp.numel() > 0 and not inp.isfinite().all():
                    first_bad.update({
                        "stage": "input",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_idx": i,
                        "shape": list(inp.shape),
                        "has_nan": inp.isnan().any().item(),
                        "has_inf": inp.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK INPUT] {module_name} ({module_type}) "
                        f"input[{i}] shape={list(inp.shape)} "
                        f"nan={inp.isnan().sum().item()} inf={inp.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            # Check outputs (one level of tuple/list/dict)
            tensors_to_check = _output_tensors(output)
            for out_name, out_tensor in tensors_to_check:
                if out_tensor.numel() > 0 and not out_tensor.isfinite().all():
                    first_bad.update({
                        "stage": "output",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_name": out_name,
                        "shape": list(out_tensor.shape),
                        "has_nan": out_tensor.isnan().any().item(),
                        "has_inf": out_tensor.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK OUTPUT] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)} "
                        f"nan={out_tensor.isnan().sum().item()} inf={out_tensor.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            if verbose:
                for out_name, out_tensor in tensors_to_check:
                    print(
                        f"{prefix}[HOOK OK] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)}",
                        flush=True,
                    )

        return hook

    for name, module in _targets():
        handles.append(
            module.register_forward_hook(make_hook(name, module.__class__.__name__))
        )

    return handles, first_bad


def remove_nan_hooks(handles: List[torch.utils.hooks.RemovableHandle]) -> None:
    """Remove every hook handle returned by :func:`register_nan_hooks`."""
    for h in handles:
        h.remove()


# --------------------------------------------------------------------------- #
# 4. Snapshot saving on anomaly
# --------------------------------------------------------------------------- #


def _to_saveable(value: Any) -> Any:
    """Single-element tensors -> Python scalar (NaN/Inf kept); other tensors ->
    detached CPU tensor; anything else unchanged."""
    if isinstance(value, torch.Tensor):
        return value.item() if value.numel() == 1 else value.detach().cpu()
    return value


def save_debug_state(
    batch: Optional[Dict[str, Any]],
    model: nn.Module,
    ancestor_table: Optional[nn.Module],
    optimizer: torch.optim.Optimizer,
    loss: Optional[torch.Tensor],
    metrics: Optional[Dict[str, Any]],
    step: int,
    save_dir: Union[str, Path],
    extra: Optional[Dict[str, Any]] = None,
) -> Path:
    """
    Save current batch, model, optimizer, and metrics for post-mortem analysis.

    Files written into ``save_dir`` (created if missing), all sharing the
    prefix ``debug_step{step:06d}_{YYYYmmdd_HHMMSS}``:

    - ``*_state.pt``: dict with ``step`` plus, when given, ``batch`` (tensors
      moved to CPU), ``loss``, ``metrics`` and ``extra``;
    - ``*_model.pt``: ``state_dict()`` of the unwrapped model;
    - ``*_ancestor.pt``: ``ancestor_table.state_dict()`` (only if provided);
    - ``*_optimizer.pt``: ``optimizer.state_dict()``.

    Single-element tensors in ``loss``/``metrics`` are stored as Python
    scalars (NaN and Inf preserved); larger tensors as detached CPU tensors.
    Two snapshots of the same step within the same second share a prefix, so
    the later one overwrites the earlier.

    Returns:
        Path of the ``*_state.pt`` file.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = f"debug_step{step:06d}_{ts}"

    state: Dict[str, Any] = {"step": step}
    if batch is not None:
        state["batch"] = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
    if loss is not None:
        state["loss"] = _to_saveable(loss)
    if metrics is not None:
        state["metrics"] = {k: _to_saveable(v) for k, v in metrics.items()}
    if extra is not None:
        state["extra"] = extra

    state_path = save_dir / f"{prefix}_state.pt"
    torch.save(state, state_path)
    torch.save(_unwrap_module(model).state_dict(), save_dir / f"{prefix}_model.pt")
    if ancestor_table is not None:
        torch.save(ancestor_table.state_dict(), save_dir / f"{prefix}_ancestor.pt")
    torch.save(optimizer.state_dict(), save_dir / f"{prefix}_optimizer.pt")

    print(f"[DEBUG] Saved debug snapshot to {state_path}", flush=True)
    return state_path


# --------------------------------------------------------------------------- #
# 5. TrainingDebugger — high-level orchestrator
# --------------------------------------------------------------------------- #


class TrainingDebugger:
    """
    Systematic NaN/Inf debugger for training loops.

    Typical usage:

        debugger = TrainingDebugger(model, ancestor_table, optimizer, config)
        with debugger:  # registers hooks on enter, removes on exit
            for step in range(num_steps):
                debugger.reset_hooks_step(step)
                ...

    Config keys (defaults in parentheses): debug_mode (True), raise_on_nan
    (False), check_inputs / check_outputs / check_loss / check_params (False),
    check_grads (True), use_hooks (False), hook_modules ("blocks"),
    save_on_nan (True), save_dir ("outputs/debug"), grad_clip (1.0; <= 0 means
    no clipping), log_interval (100), print_stats_every (100; <= 0 disables the
    periodic grad stats line), grad_norm_warn_threshold (1e5),
    check_params_every (100), check_outputs_every (1), check_batch_every (1),
    scaler (None).
    """

    def __init__(
        self,
        model: nn.Module,
        ancestor_table: Optional[nn.Module],
        optimizer: torch.optim.Optimizer,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.model = model
        self.ancestor_table = ancestor_table
        self.optimizer = optimizer
        self.cfg = config or {}

        # Top-level switches
        self.debug_mode = self.cfg.get("debug_mode", True)
        self.raise_on_nan = self.cfg.get("raise_on_nan", False)

        # Stage toggles (prefixed with _ to avoid shadowing methods).
        # Only the gradient guard is on by default; the other stages add
        # per-step GPU syncs and are opt-in.
        self._check_inputs = self.cfg.get("check_inputs", False)
        self._check_outputs = self.cfg.get("check_outputs", False)
        self._check_loss = self.cfg.get("check_loss", False)
        self._check_grads = self.cfg.get("check_grads", True)
        self._check_params = self.cfg.get("check_params", False)
        # Hooks default to OFF because they add significant forward overhead
        # (5-10x slower) and increase memory. Only enable when you need to
        # pinpoint which *intermediate* layer first produces NaN.
        self._use_hooks = self.cfg.get("use_hooks", False)

        # Behaviour
        self.save_on_nan = self.cfg.get("save_on_nan", True)
        self.save_dir = Path(self.cfg.get("save_dir", "outputs/debug"))
        self.grad_clip = float(self.cfg.get("grad_clip", 1.0))
        self.log_interval = int(self.cfg.get("log_interval", 100))
        self.print_stats_every = int(self.cfg.get("print_stats_every", 100))
        self.grad_norm_warn_threshold = float(self.cfg.get("grad_norm_warn_threshold", 1e5))
        self.check_params_every = int(self.cfg.get("check_params_every", 100))
        self.check_outputs_every = int(self.cfg.get("check_outputs_every", 1))
        self.check_batch_every = int(self.cfg.get("check_batch_every", 1))

        # AMP / scaler tracking
        self.scaler: Optional[torch.cuda.amp.GradScaler] = self.cfg.get("scaler", None)

        # Internal state
        self._hook_handles: List[torch.utils.hooks.RemovableHandle] = []
        self._hook_first_bad: Dict[str, Any] = {}
        # Shared with the hooks so reset_hooks_step() can update their step label.
        self._hook_step: Dict[str, Any] = {}
        self._last_clean_step = -1

    def _unwrap_model(self) -> nn.Module:
        """The model with torch.compile / DDP-style wrappers removed."""
        return _unwrap_module(self.model)

    # -- context manager --
    def __enter__(self) -> "TrainingDebugger":
        if self.debug_mode and self._use_hooks:
            hook_filter = self.cfg.get("hook_modules", "blocks")
            target = self._unwrap_model()
            n_blocks = len(target.blocks) if hasattr(target, "blocks") else "?"
            print(
                f"[DEBUG] Forward hooks ENABLED ({hook_filter}, ~{n_blocks} hooks). "
                f"This will slow down forward pass significantly. "
                f"Set use_hooks=false in debug config to disable.",
                flush=True,
            )
            self._hook_handles, self._hook_first_bad = register_nan_hooks(
                target, step=None, verbose=False, module_filter=hook_filter,
                step_ref=self._hook_step,
            )
        return self

    def __exit__(self, *args: Any) -> None:
        remove_nan_hooks(self._hook_handles)
        self._hook_handles = []

    # -- helpers --
    def _prefix(self, step: Optional[int] = None) -> str:
        return f"[step {step}] " if step is not None else ""

    def _stats_due(self, step: int) -> bool:
        """True when the periodic grad-stats line should be printed."""
        return self.print_stats_every > 0 and step % self.print_stats_every == 0

    def _named_grad_params(self, model: nn.Module) -> List[Tuple[str, nn.Parameter]]:
        """Grad-bearing params of ``model`` and ``self.ancestor_table``."""
        return _named_params(model, self.ancestor_table, grads_only=True)

    def _maybe_save(
        self,
        batch: Optional[Dict[str, Any]],
        loss: Optional[torch.Tensor],
        metrics: Optional[Dict[str, Any]],
        step: int,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        if not self.save_on_nan:
            return
        save_debug_state(
            batch, self.model, self.ancestor_table, self.optimizer,
            loss, metrics, step, self.save_dir, extra,
        )

    def _report_nonfinite_grads(
        self,
        params: List[Tuple[str, nn.Parameter]],
        norms_stacked: torch.Tensor,
        total_norm: float,
        step: int,
    ) -> None:
        """
        Print every param whose grad norm in ``norms_stacked`` is non-finite,
        save a snapshot, and raise if configured.

        Status comes from the norm itself (NaN norm -> "NAN"), so this stays
        correct when the grads were modified after the norms were taken.
        """
        prefix = self._prefix(step)
        bad: List[Tuple[str, str]] = []
        for idx, status in _nonfinite_entries(norms_stacked):
            name, p = params[idx]
            print(
                f"{prefix}[GRAD {status}] {name} shape={list(p.grad.shape)}",
                flush=True,
            )
            bad.append((name, status))
        if not bad:
            return
        print(
            f"{prefix}[GRAD ERROR] {len(bad)} params with non-finite grad. "
            f"Total norm={total_norm:.4e}",
            flush=True,
        )
        for name, status in bad[:5]:
            print(f" - {name}: {status}", flush=True)
        self._maybe_save(None, None, None, step, extra={"bad_grads": [n for n, _ in bad]})
        if self.raise_on_nan:
            raise RuntimeError(f"{prefix}Non-finite gradients")

    def _log_grad_stats(
        self,
        params: List[Tuple[str, nn.Parameter]],
        norms_stacked: torch.Tensor,
        total_norm: float,
        step: int,
    ) -> None:
        """Large-norm warning and periodic stats line; the per-param max is
        only looked up when one of them is actually printed."""
        warn = total_norm > self.grad_norm_warn_threshold
        stats_due = self._stats_due(step)
        if not (warn or stats_due):
            return
        prefix = self._prefix(step)
        max_idx = int(norms_stacked.argmax().item())
        max_val = norms_stacked[max_idx].item()
        max_name = params[max_idx][0]
        if warn:
            print(
                f"{prefix}[GRAD WARN] grad_norm={total_norm:.4e} (very large). "
                f"Largest single-param grad={max_val:.4e} in {max_name}",
                flush=True,
            )
        if stats_due:
            print(
                f"{prefix}[GRAD STATS] total_norm={total_norm:.4e} "
                f"max_single={max_val:.4e} ({max_name})",
                flush=True,
            )

    # -- stage A: input data --
    def check_batch(self, batch: Dict[str, Any], step: int) -> None:
        """Check every tensor in ``batch`` for NaN/Inf (when enabled).

        All offending keys are printed, but at most one snapshot is saved per
        call. Raises ``RuntimeError`` naming the first offending key when
        ``raise_on_nan`` is set.
        """
        if not self.debug_mode or not self._check_inputs:
            return
        if self.check_batch_every > 1 and step % self.check_batch_every != 0:
            return

        prefix = self._prefix(step)
        bad_keys: List[Any] = []
        for key, val in batch.items():
            if not isinstance(val, torch.Tensor):
                continue
            # stats=False: batch tensor stats (min/max/mean) are rarely useful
            # and the extra .item() syncs add ~3-4 ms per tensor.
            r = check_tensor_stats(
                f"batch.{key}", val, step, stats=False, raise_on_nan=False
            )
            if not r["is_finite"]:
                print(
                    f"{prefix}[INPUT ERROR] Non-finite tensor in batch['{key}']\n {r['msg']}",
                    flush=True,
                )
                bad_keys.append(key)

        if bad_keys:
            self._maybe_save(
                batch, None, None, step,
                extra={"bad_key": bad_keys[0], "bad_keys": bad_keys},
            )
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite input in batch['{bad_keys[0]}']")

    # -- stage B: model output --
    def check_forward_output(self, output: Any, step: int) -> None:
        """Check (possibly nested) model output for NaN/Inf (when enabled)."""
        if not self.debug_mode or not self._check_outputs:
            return
        if self.check_outputs_every > 1 and step % self.check_outputs_every != 0:
            return

        # For forward output we only check NaN/Inf, NOT full stats.
        # Model logits [B, S, V] can be >1B elements; computing min/max/mean/std
        # would temporarily allocate several GB and cause OOM.
        results = check_nested_tensors(
            "model.output", output, step,
            raise_on_nan=False, stats=False, max_elements_for_stats=50_000_000,
        )
        bad = [r for r in results if not r["is_finite"]]
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[FORWARD ERROR] Non-finite model output:", flush=True)
            for r in bad:
                print(f" {r['msg']}", flush=True)
            if self._hook_first_bad:
                print(f" First bad layer from hooks: {self._hook_first_bad}", flush=True)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite model output")

    # -- stage C: loss --
    def check_loss(
        self,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        step: int,
    ) -> None:
        """Check tensor-valued ``metrics`` and the total ``loss`` (when enabled).

        Every offending metric is printed, but at most one snapshot is saved
        per call. Records ``step`` as the last clean step when the total loss
        is finite.
        """
        if not self.debug_mode or not self._check_loss:
            return

        prefix = self._prefix(step)
        saved = False

        # Check sub-items
        bad_metrics: List[Any] = []
        for key, val in metrics.items():
            if isinstance(val, torch.Tensor) and not val.isfinite().all():
                status = "NAN" if val.isnan().any().item() else "INF"
                print(
                    f"{prefix}[LOSS SUBITEM {status}] {key} = "
                    f"{val.item() if val.numel() == 1 else val.detach()}",
                    flush=True,
                )
                bad_metrics.append(key)
        if bad_metrics:
            self._maybe_save(
                None, loss, metrics, step,
                extra={"bad_metric": bad_metrics[0], "bad_metrics": bad_metrics},
            )
            saved = True
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite loss subitem: {bad_metrics[0]}")

        # Check total loss
        if not loss.isfinite().all():
            nan_count = loss.isnan().sum().item()
            inf_count = loss.isinf().sum().item()
            print(
                f"{prefix}[TOTAL LOSS ERROR] loss is non-finite "
                f"nan_count={nan_count} inf_count={inf_count}",
                flush=True,
            )
            if not saved:
                self._maybe_save(None, loss, metrics, step)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite total loss")
        else:
            self._last_clean_step = step

    # -- stage D: gradients --
    def check_gradients(self, step: int) -> torch.Tensor:
        """
        Check every param's grad and return the total norm (before clipping).

        A single ``torch._foreach_norm`` call serves both NaN detection and norm
        statistics. The norm is always computed and returned; the checks and
        logging only run when ``debug_mode`` and ``check_grads`` are enabled.
        """
        model = self._unwrap_model()
        device = _param_device(model)

        params = self._named_grad_params(model)
        if not params:
            return torch.tensor(0.0, device=device)

        norms_stacked = _stacked_norms([p.grad for _, p in params])
        total_norm = torch.norm(norms_stacked, 2.0).item()

        if self.debug_mode and self._check_grads:
            if not math.isfinite(total_norm):
                self._report_nonfinite_grads(params, norms_stacked, total_norm, step)
            self._log_grad_stats(params, norms_stacked, total_norm, step)

        return torch.tensor(total_norm, device=device)

    def clip_grads(self, step: int) -> torch.Tensor:
        """
        Clip gradients to ``grad_clip`` and return the total pre-clip norm.

        With ``grad_clip <= 0`` no clipping happens but the norm is still
        measured and returned (and checked).

        When ``debug_mode`` and ``check_grads`` are enabled this also performs
        NaN/Inf detection, large-gradient warnings and periodic stats logging,
        so callers don't need ``check_gradients()`` as well. Per-param norms are
        taken BEFORE clipping: clipping with a NaN/Inf total norm multiplies
        every grad by NaN (or 0, and Inf * 0 = NaN), after which the culprits
        can no longer be told apart. When ``torch.nn.utils.clip_grads_with_norm_``
        exists the pre-clip norm is reused for clipping, so no extra pass is paid.
        """
        model = self._unwrap_model()
        device = _param_device(model)
        params = self._named_grad_params(model)
        grad_params = [p for _, p in params]
        debug_grads = self.debug_mode and self._check_grads

        norms_stacked: Optional[torch.Tensor] = None
        if not params:
            total_norm = torch.tensor(0.0, device=device)
        elif debug_grads or self.grad_clip <= 0:
            norms_stacked = _stacked_norms([p.grad for p in grad_params])
            total_norm = torch.norm(norms_stacked, 2.0)
            if self.grad_clip > 0:
                clip_with_norm = getattr(torch.nn.utils, "clip_grads_with_norm_", None)
                if clip_with_norm is not None:
                    clip_with_norm(grad_params, self.grad_clip, total_norm)
                else:
                    total_norm = torch.nn.utils.clip_grad_norm_(grad_params, self.grad_clip)
        else:
            total_norm = torch.nn.utils.clip_grad_norm_(grad_params, self.grad_clip)

        prefix = self._prefix(step)
        # One host sync shared by the debug checks and the clip-error guard.
        total_norm_val = total_norm.item()
        total_norm_finite = math.isfinite(total_norm_val)

        if debug_grads and norms_stacked is not None:
            if total_norm_finite:
                self._log_grad_stats(params, norms_stacked, total_norm_val, step)
            else:
                self._report_nonfinite_grads(params, norms_stacked, total_norm_val, step)

        # Guard that also fires when debug checks are off: clipping with a
        # non-finite norm corrupts every grad.
        if not total_norm_finite and self.grad_clip > 0:
            print(
                f"{prefix}[CLIP ERROR] total_norm non-finite after clip_grad_norm_: {total_norm}",
                flush=True,
            )
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite grad norm after clipping")

        return total_norm

    # -- stage E: parameters after step --
    def check_params_after_step(self, step: int) -> None:
        """Check parameter values for NaN/Inf after ``optimizer.step()`` (when enabled)."""
        if not self.debug_mode or not self._check_params:
            return
        if self.check_params_every > 1 and step % self.check_params_every != 0:
            return

        model = self._unwrap_model()
        bad = check_model_params(model, self.ancestor_table, step, raise_on_nan=False)
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[PARAM ERROR] {len(bad)} params non-finite after optimizer.step():", flush=True)
            for name, _, status in bad[:5]:
                print(f" - {name}: {status}", flush=True)
            self._maybe_save(None, None, None, step, extra={"bad_params": [n for n, _, _ in bad]})
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite params after step")

    # -- periodic summary log --
    def log_step(
        self,
        step: int,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        lr: float,
        grad_norm: float,
        elapsed: float,
    ) -> None:
        """
        Print one structured training log line. Only called on main rank.

        Loss items come from ``metrics["loss_total"]`` (falling back to
        ``loss``), ``metrics["loss_leaf"]`` and ``metrics["loss_ancestor"]``
        (falling back to 0).
        """
        prefix = self._prefix(step)
        parts = [f"step={step:6d}"]

        # Loss items
        l_total = metrics.get("loss_total", loss)
        l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
        l_anc = metrics.get("loss_ancestor", torch.tensor(0.0))
        for label, t in [("total", l_total), ("leaf", l_leaf), ("ancestor", l_anc)]:
            v = t.item() if hasattr(t, "item") else float(t)
            parts.append(f"{label}={v:.4f}")

        # LR & grad norm
        parts.append(f"lr={lr:.2e}")
        parts.append(f"gnorm={grad_norm:.4e}")

        # AMP scaler
        if self.scaler is not None:
            parts.append(f"scale={self.scaler.get_scale():.1f}")

        # Hooks first-bad (if any)
        if self._hook_first_bad:
            fb = self._hook_first_bad
            parts.append(f"FIRST_BAD={fb.get('module_name','?')}({fb.get('stage','?')})")

        parts.append(f"t={elapsed:.1f}s")
        print(f"{prefix}{' | '.join(parts)}", flush=True)

    def reset_hooks_step(self, step: int) -> None:
        """
        Re-arm the forward hooks for a new step. Call at start of each step.

        Clears the first-bad record (the hooks only record the first offender
        until it is cleared) and sets the step shown in hook messages.
        """
        self._hook_first_bad.clear()
        self._hook_step["step"] = step