| """ |
| Debug utilities for systematic NaN/Inf detection and root-cause analysis. |
| |
| Usage in train.py: |
| from src.utils.debug_utils import TrainingDebugger |
| |
| debugger = TrainingDebugger( |
| model, ancestor_table, optimizer, |
| config={"debug_mode": True, "raise_on_nan": True, "use_hooks": True} |
| ) |
| with debugger: |
| for step in range(...): |
| debugger.check_batch(batch, step) |
| loss, metrics, output = forward_step(...) |
| debugger.check_forward_output(output, step) |
| debugger.check_loss(loss, metrics, step) |
| loss.backward() |
| gnorm = debugger.check_gradients(step) |
| debugger.clip_grads(step) |
| optimizer.step() |
| debugger.check_params_after_step(step) |
| |
| Design principles: |
| - Fail-fast: detect the FIRST place where NaN/Inf appears. |
| - Informative: print tensor name, shape, dtype, stats, module name, param name. |
| - Minimally intrusive: wrap existing logic, don't replace it. |
| - Configurable: debug_mode=False disables almost all overhead. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| import math |
| from pathlib import Path |
| from datetime import datetime |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| |
| |
| |
|
|
def check_tensor_stats(
    name: str,
    tensor: Optional[torch.Tensor],
    step: Optional[int] = None,
    stats: bool = True,
    raise_on_nan: bool = False,
    max_elements_for_stats: int = 50_000_000,
) -> Dict[str, Any]:
    """
    Check a single tensor for NaN/Inf and optionally summarise its values.

    The function never builds a filtered copy of the tensor via boolean
    indexing (``tensor[tensor.isfinite()]``); only reductions are used.
    Detailed stats are skipped for tensors with more than
    ``max_elements_for_stats`` elements.

    Args:
        name: Label used in the report message.
        tensor: Tensor to inspect. ``None``, non-tensor objects and empty
            tensors are reported as finite without further work.
        step: Optional training step, prepended to every message.
        stats: When True and the tensor is finite, append min/max/mean/std
            to the message (min/max are omitted for complex dtypes, which
            have no ordering).
        raise_on_nan: Raise ``RuntimeError`` when a NaN or Inf is found.
        max_elements_for_stats: Size limit above which stats are skipped.

    Returns:
        dict with keys: has_nan, has_inf, is_finite, msg

    Raises:
        RuntimeError: if the tensor is non-finite and ``raise_on_nan`` is set.

    Note:
        The message is printed only for non-finite tensors; for finite
        tensors it is just returned.
    """
    prefix = f"[step {step}] " if step is not None else ""

    def _trivially_finite(note: str) -> Dict[str, Any]:
        # Shared result for inputs that cannot contain NaN/Inf.
        return {
            "has_nan": False,
            "has_inf": False,
            "is_finite": True,
            "msg": f"{prefix}[{name}] {note}",
        }

    if tensor is None:
        return _trivially_finite("None")
    if not isinstance(tensor, torch.Tensor):
        return _trivially_finite(f"non-tensor {type(tensor)}")
    if tensor.numel() == 0:
        return _trivially_finite("empty tensor")

    header = (
        f"{prefix}[{name}] shape={list(tensor.shape)} dtype={tensor.dtype} "
        f"device={tensor.device}"
    )

    if not tensor.isfinite().all().item():
        # Derive the flags from the counts so the tensor is scanned once per
        # kind instead of twice.
        nan_count = tensor.isnan().sum().item()
        inf_count = tensor.isinf().sum().item()
        has_nan = nan_count > 0
        has_inf = inf_count > 0
        status = "NAN" if has_nan else "INF"
        msg = f"{header} status={status} nan_count={nan_count} inf_count={inf_count}"
        print(msg, flush=True)
        if raise_on_nan:
            raise RuntimeError(f"{prefix}NaN/Inf detected in {name}")
        return {"has_nan": has_nan, "has_inf": has_inf, "is_finite": False, "msg": msg}

    msg = f"{header} status=OK"

    if stats:
        if tensor.numel() > max_elements_for_stats:
            msg += " | stats_skipped_large_tensor"
        else:
            parts: List[str] = []
            if not tensor.dtype.is_complex:
                # min()/max() are not defined for complex dtypes and would raise.
                parts.append(f"min={tensor.min().item()}")
                parts.append(f"max={tensor.max().item()}")
            if tensor.dtype.is_floating_point or tensor.dtype.is_complex:
                values = tensor
            else:
                # Integer/bool tensors need a float view for mean/std.
                values = tensor.float()
            parts.append(f"mean={values.mean().item():.4e}")
            if tensor.numel() > 1:
                parts.append(f"std={values.std().item():.4e}")
            msg += " | " + " ".join(parts)

    return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": msg}
|
|
|
|
def check_nested_tensors(
    name: str,
    obj: Any,
    step: Optional[int] = None,
    raise_on_nan: bool = False,
    stats: bool = True,
    max_elements_for_stats: int = 50_000_000,
) -> List[Dict[str, Any]]:
    """
    Walk a nested list/tuple/dict structure and check every tensor in it.

    Tensors are passed to ``check_tensor_stats``; sequence items are named
    ``name[i]`` and dict values ``name.key``. Leaves of any other type are
    ignored.

    Returns:
        One ``check_tensor_stats`` result per tensor, in traversal order.
    """
    # Leaf tensor: a single report.
    if isinstance(obj, torch.Tensor):
        return [
            check_tensor_stats(
                name, obj, step,
                stats=stats,
                raise_on_nan=raise_on_nan,
                max_elements_for_stats=max_elements_for_stats,
            )
        ]

    # Containers: enumerate (child_name, child) pairs lazily.
    if isinstance(obj, (list, tuple)):
        children = ((f"{name}[{idx}]", child) for idx, child in enumerate(obj))
    elif isinstance(obj, dict):
        children = ((f"{name}.{key}", child) for key, child in obj.items())
    else:
        return []

    collected: List[Dict[str, Any]] = []
    for child_name, child in children:
        collected += check_nested_tensors(
            child_name, child, step, raise_on_nan, stats, max_elements_for_stats
        )
    return collected
|
|
|
|
| |
| |
| |
|
|
def compute_grad_norm(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
) -> Tuple[float, Optional[Tuple[str, nn.Parameter]], float]:
    """
    Return the global L2 grad norm and the parameter with the largest grad norm.

    Only parameters whose ``.grad`` is set are considered. Parameters of
    ``ancestor_table`` are reported with an ``ancestor_table.`` name prefix.
    Per-parameter norms come from one batched ``torch._foreach_norm`` call.

    Returns:
        (total_norm, (name, param) or None, max_single_norm); all zeros/None
        when no parameter has a gradient.
    """
    named = [(n, p) for n, p in model.named_parameters() if p.grad is not None]
    if ancestor_table is not None:
        named.extend(
            (f"ancestor_table.{n}", p)
            for n, p in ancestor_table.named_parameters()
            if p.grad is not None
        )

    if not named:
        return 0.0, None, 0.0

    # One 0-d norm per parameter, stacked into a vector.
    per_param = torch.stack(torch._foreach_norm([p.grad for _, p in named], 2.0))
    overall = torch.norm(per_param, 2.0).item()

    largest = int(per_param.argmax().item())
    return overall, named[largest], per_param[largest].item()
|
|
|
|
def check_gradients(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
    print_all: bool = False,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Report parameters whose gradient norm is NaN or Inf.

    A single batched ``torch._foreach_norm`` call yields one norm per
    parameter; parameters are inspected individually only when the combined
    norm is non-finite.

    Args:
        model: Module whose parameters are checked.
        ancestor_table: Optional second module; its parameters are reported
            with an ``ancestor_table.`` prefix.
        step: Optional step number prepended to printed messages.
        raise_on_nan: Raise ``RuntimeError`` when any offender is found.
        print_all: When everything is finite, print each parameter's norm.

    Returns:
        List of ``(name, param, "NAN" | "INF")`` for offending parameters.
    """
    tag = "" if step is None else f"[step {step}] "
    offenders: List[Tuple[str, nn.Parameter, str]] = []

    named = [(n, p) for n, p in model.named_parameters() if p.grad is not None]
    if ancestor_table is not None:
        named.extend(
            (f"ancestor_table.{n}", p)
            for n, p in ancestor_table.named_parameters()
            if p.grad is not None
        )
    if not named:
        return offenders

    norms = torch._foreach_norm([p.grad for _, p in named], 2.0)
    combined = torch.norm(torch.stack(norms), 2.0)

    if torch.isfinite(combined):
        # Healthy step: optionally dump every per-parameter norm.
        if print_all:
            for (pname, _), gnorm in zip(named, norms):
                print(
                    f"{tag}[GRAD OK] {pname} grad_norm={gnorm.item():.4e}",
                    flush=True,
                )
        return offenders

    # Something is non-finite: name each offending parameter.
    for (pname, param), gnorm in zip(named, norms):
        if torch.isfinite(gnorm):
            continue
        kind = "NAN" if param.grad.isnan().any().item() else "INF"
        print(
            f"{tag}[GRAD {kind}] {pname} shape={list(param.grad.shape)}",
            flush=True,
        )
        offenders.append((pname, param, kind))

    if offenders and raise_on_nan:
        offender_names = [entry[0] for entry in offenders]
        raise RuntimeError(
            f"{tag}Non-finite grad in {len(offenders)} params: {offender_names[:10]}..."
        )
    return offenders
|
|
|
|
def check_model_params(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Report parameters whose values contain NaN or Inf.

    Intended to run after ``optimizer.step()``. Per-parameter norms are
    computed with one ``torch._foreach_norm`` call; individual parameters are
    examined only when the combined norm is non-finite.

    Returns:
        List of ``(name, param, "NAN" | "INF")`` for offending parameters.

    Raises:
        RuntimeError: if offenders exist and ``raise_on_nan`` is True.
    """
    tag = "" if step is None else f"[step {step}] "
    offenders: List[Tuple[str, nn.Parameter, str]] = []

    named = list(model.named_parameters())
    if ancestor_table is not None:
        named.extend(
            (f"ancestor_table.{n}", p) for n, p in ancestor_table.named_parameters()
        )
    if not named:
        return offenders

    norms = torch._foreach_norm([p.data for _, p in named], 2.0)
    combined = torch.norm(torch.stack(norms), 2.0)
    if torch.isfinite(combined):
        return offenders

    for (pname, param), pnorm in zip(named, norms):
        if torch.isfinite(pnorm):
            continue
        kind = "NAN" if param.data.isnan().any().item() else "INF"
        print(
            f"{tag}[PARAM {kind}] {pname} shape={list(param.shape)}",
            flush=True,
        )
        offenders.append((pname, param, kind))

    if offenders and raise_on_nan:
        offender_names = [entry[0] for entry in offenders]
        raise RuntimeError(
            f"{tag}Non-finite param in {len(offenders)} params: {offender_names[:10]}..."
        )
    return offenders
|
|
|
|
| |
| |
| |
|
|
def register_nan_hooks(
    model: nn.Module,
    step: Optional[int] = None,
    verbose: bool = False,
    module_filter: Optional[str] = "blocks",
) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, Any]]:
    """
    Register forward hooks to catch the FIRST layer that produces non-finite
    output (or receives non-finite input).

    Args:
        model: Module whose submodules are hooked.
        step: Optional step number baked into the printed messages at
            registration time.
        verbose: Also print a line for every healthy hooked module.
        module_filter: Controls which modules get hooks (keywords are
            case-insensitive).
            - "blocks" (default): hooks on each entry of ``model.blocks`` when
              that attribute is an ``nn.ModuleList``; otherwise falls back to
              the direct children of ``model``.
            - "children": hooks on all direct children of ``model``.
            - "all": hooks on *every* named submodule.
            - None / "": no hooks.
            - anything else: a comma-separated list of module names as
              produced by ``model.named_modules()``.

    Returns:
        (list_of_handles, first_bad_dict). ``first_bad_dict`` is shared with
        the hooks: it is filled in by the first hook that sees a non-finite
        tensor, and every hook is a no-op while it is non-empty.
    """
    handles: List[torch.utils.hooks.RemovableHandle] = []
    first_bad: Dict[str, Any] = {}

    def _targets():
        """Yield (name, module) pairs to hook."""
        if not module_filter:
            return
        mode = str(module_filter).lower()
        if mode == "all":
            for name, module in model.named_modules():
                if name:  # skip the root module (empty name)
                    yield name, module
        elif mode == "children":
            yield from model.named_children()
        elif mode == "blocks":
            if hasattr(model, "blocks") and isinstance(model.blocks, nn.ModuleList):
                for i, block in enumerate(model.blocks):
                    yield f"blocks.{i}", block
            else:
                yield from model.named_children()
        else:
            # Explicit list of module names. Names are matched with their
            # original casing; the lower-cased spellings are kept as well so
            # filters that relied on the earlier lower-casing still match.
            requested = {s.strip() for s in str(module_filter).split(",")}
            requested |= {s.lower() for s in requested}
            for name, module in model.named_modules():
                if name in requested:
                    yield name, module

    def make_hook(module_name: str, module_type: str):
        def hook(module, inputs, output):
            # Only the first offender is recorded; later hooks stay silent
            # until the caller clears the shared dict.
            if first_bad:
                return
            prefix = f"[step {step}] " if step is not None else ""

            # Inputs first: a bad input means the fault is upstream.
            for i, inp in enumerate(inputs):
                if isinstance(inp, torch.Tensor) and inp.numel() > 0 and not inp.isfinite().all():
                    first_bad.update({
                        "stage": "input",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_idx": i,
                        "shape": list(inp.shape),
                        "has_nan": inp.isnan().any().item(),
                        "has_inf": inp.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK INPUT] {module_name} ({module_type}) "
                        f"input[{i}] shape={list(inp.shape)} "
                        f"nan={inp.isnan().sum().item()} inf={inp.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            # Collect output tensors (one container level deep only).
            tensors_to_check = []
            if isinstance(output, torch.Tensor):
                tensors_to_check = [("output", output)]
            elif isinstance(output, (list, tuple)):
                for i, o in enumerate(output):
                    if isinstance(o, torch.Tensor):
                        tensors_to_check.append((f"output[{i}]", o))
            elif isinstance(output, dict):
                for k, v in output.items():
                    if isinstance(v, torch.Tensor):
                        tensors_to_check.append((f"output['{k}']", v))

            for out_name, out_tensor in tensors_to_check:
                if out_tensor.numel() > 0 and not out_tensor.isfinite().all():
                    first_bad.update({
                        "stage": "output",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_name": out_name,
                        "shape": list(out_tensor.shape),
                        "has_nan": out_tensor.isnan().any().item(),
                        "has_inf": out_tensor.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK OUTPUT] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)} "
                        f"nan={out_tensor.isnan().sum().item()} inf={out_tensor.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            if verbose:
                for out_name, out_tensor in tensors_to_check:
                    print(
                        f"{prefix}[HOOK OK] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)}",
                        flush=True,
                    )

        return hook

    for name, module in _targets():
        handles.append(
            module.register_forward_hook(make_hook(name, module.__class__.__name__))
        )

    return handles, first_bad
|
|
|
|
def remove_nan_hooks(handles: List[torch.utils.hooks.RemovableHandle]) -> None:
    """Detach every hook in ``handles`` by calling its ``remove()`` method."""
    for handle in handles:
        handle.remove()
|
|
|
|
| |
| |
| |
|
|
def save_debug_state(
    batch: Optional[Dict[str, Any]],
    model: nn.Module,
    ancestor_table: Optional[nn.Module],
    optimizer: torch.optim.Optimizer,
    loss: Optional[torch.Tensor],
    metrics: Optional[Dict[str, Any]],
    step: int,
    save_dir: Union[str, Path],
    extra: Optional[Dict[str, Any]] = None,
) -> Path:
    """
    Save current batch, model, optimizer, and metrics for post-mortem analysis.

    Files written into ``save_dir`` (created if missing), all sharing the
    prefix ``debug_step{step:06d}_{timestamp}``:
        - ``*_state.pt``: step, batch (top-level tensors moved to CPU), loss,
          metrics and ``extra``.
        - ``*_model.pt``: ``state_dict`` of the unwrapped model.
        - ``*_ancestor.pt``: ``state_dict`` of ``ancestor_table`` (if given).
        - ``*_optimizer.pt``: ``state_dict`` of the optimizer.

    Single-element tensors (loss, metrics) are stored as Python numbers with
    their actual value, so NaN and +/-Inf remain distinguishable in the
    snapshot; multi-element tensors are stored as detached CPU tensors.

    Returns:
        Path of the ``*_state.pt`` file.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = f"debug_step{step:06d}_{ts}"

    def _scalar_or_cpu(value: Any) -> Any:
        # .item() is only defined for one-element tensors; anything larger is
        # kept as a tensor so the snapshot itself cannot fail on it.
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else value.detach().cpu()
        return value

    state: Dict[str, Any] = {"step": step}
    if batch is not None:
        state["batch"] = {
            k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
    if loss is not None:
        state["loss"] = _scalar_or_cpu(loss)
    if metrics is not None:
        state["metrics"] = {k: _scalar_or_cpu(v) for k, v in metrics.items()}
    if extra is not None:
        state["extra"] = extra

    state_path = out_dir / f"{prefix}_state.pt"
    torch.save(state, state_path)

    # Peel off wrapper layers in whatever order they were applied:
    # ``_orig_mod`` (as set by torch.compile) and ``module`` (as set by
    # DataParallel / DistributedDataParallel).
    m = model
    while True:
        if hasattr(m, "_orig_mod"):
            m = m._orig_mod
        elif isinstance(getattr(m, "module", None), nn.Module):
            m = m.module
        else:
            break

    torch.save(m.state_dict(), out_dir / f"{prefix}_model.pt")
    if ancestor_table is not None:
        torch.save(ancestor_table.state_dict(), out_dir / f"{prefix}_ancestor.pt")
    torch.save(optimizer.state_dict(), out_dir / f"{prefix}_optimizer.pt")

    print(f"[DEBUG] Saved debug snapshot to {state_path}", flush=True)
    return state_path
|
|
|
|
| |
| |
| |
|
|
class TrainingDebugger:
    """
    Systematic NaN/Inf debugger for training loops.

    Typical usage:
        debugger = TrainingDebugger(model, ancestor_table, optimizer, config)
        with debugger:  # registers hooks on enter, removes on exit
            for step in range(num_steps):
                ...

    Recognised ``config`` keys (defaults in parentheses):
        debug_mode (True), raise_on_nan (False),
        check_inputs (False), check_outputs (False), check_loss (False),
        check_grads (True), check_params (False),
        use_hooks (False), hook_modules ("blocks"),
        save_on_nan (True), save_dir ("outputs/debug"),
        grad_clip (1.0; values <= 0 disable clipping),
        log_interval (100),
        print_stats_every (100; values <= 0 disable periodic stats),
        grad_norm_warn_threshold (1e5),
        check_params_every (100), check_outputs_every (1), check_batch_every (1),
        scaler (None).
    """

    def __init__(
        self,
        model: nn.Module,
        ancestor_table: Optional[nn.Module],
        optimizer: torch.optim.Optimizer,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.model = model
        self.ancestor_table = ancestor_table
        self.optimizer = optimizer
        self.cfg = config or {}

        # Master switches.
        self.debug_mode = self.cfg.get("debug_mode", True)
        self.raise_on_nan = self.cfg.get("raise_on_nan", False)

        # Per-stage switches; only the gradient check is on by default.
        self._check_inputs = self.cfg.get("check_inputs", False)
        self._check_outputs = self.cfg.get("check_outputs", False)
        self._check_loss = self.cfg.get("check_loss", False)
        self._check_grads = self.cfg.get("check_grads", True)
        self._check_params = self.cfg.get("check_params", False)

        # Forward hooks are opt-in.
        self._use_hooks = self.cfg.get("use_hooks", False)

        # Snapshot, clipping and logging cadence settings.
        self.save_on_nan = self.cfg.get("save_on_nan", True)
        self.save_dir = Path(self.cfg.get("save_dir", "outputs/debug"))
        self.grad_clip = float(self.cfg.get("grad_clip", 1.0))
        self.log_interval = int(self.cfg.get("log_interval", 100))
        self.print_stats_every = int(self.cfg.get("print_stats_every", 100))
        self.grad_norm_warn_threshold = float(self.cfg.get("grad_norm_warn_threshold", 1e5))
        self.check_params_every = int(self.cfg.get("check_params_every", 100))
        self.check_outputs_every = int(self.cfg.get("check_outputs_every", 1))
        self.check_batch_every = int(self.cfg.get("check_batch_every", 1))

        # Optional AMP grad scaler; only read by log_step().
        self.scaler: Optional[torch.cuda.amp.GradScaler] = self.cfg.get("scaler", None)

        # Hook bookkeeping.
        self._hook_handles: List[torch.utils.hooks.RemovableHandle] = []
        self._hook_first_bad: Dict[str, Any] = {}
        self._last_clean_step = -1

    # ------------------------------------------------------------------ helpers

    def _unwrap_model(self) -> nn.Module:
        """Return the innermost module, peeling wrappers in any nesting order.

        ``_orig_mod`` (as set by torch.compile) and ``module`` (as set by
        DataParallel / DistributedDataParallel) are followed until neither is
        present.
        """
        m = self.model
        while True:
            if hasattr(m, "_orig_mod"):
                m = m._orig_mod
            elif isinstance(getattr(m, "module", None), nn.Module):
                m = m.module
            else:
                return m

    @staticmethod
    def _model_device(model: nn.Module) -> torch.device:
        """Device of the first parameter, or CPU for a parameter-less model."""
        first = next(model.parameters(), None)
        return first.device if first is not None else torch.device("cpu")

    def _named_grad_params(self, model: nn.Module) -> List[Tuple[str, nn.Parameter]]:
        """(name, param) for every param of model/ancestor_table that has a grad."""
        named = [(n, p) for n, p in model.named_parameters() if p.grad is not None]
        if self.ancestor_table is not None:
            named += [
                (f"ancestor_table.{n}", p)
                for n, p in self.ancestor_table.named_parameters()
                if p.grad is not None
            ]
        return named

    def _is_stats_step(self, step: int) -> bool:
        """True on steps where periodic grad stats are due (never if interval <= 0)."""
        return self.print_stats_every > 0 and step % self.print_stats_every == 0

    def _report_bad_grads(
        self,
        named: List[Tuple[str, nn.Parameter]],
        norms: Any,
        total_norm: float,
        step: int,
    ) -> None:
        """Print, snapshot and optionally raise for params with a non-finite grad norm."""
        prefix = self._prefix(step)
        bad = []
        for (name, p), norm in zip(named, norms):
            if not torch.isfinite(norm):
                status = "NAN" if p.grad.isnan().any().item() else "INF"
                print(
                    f"{prefix}[GRAD {status}] {name} shape={list(p.grad.shape)}",
                    flush=True,
                )
                bad.append((name, p, status))
        if not bad:
            return
        print(
            f"{prefix}[GRAD ERROR] {len(bad)} params with non-finite grad. "
            f"Total norm={total_norm:.4e}",
            flush=True,
        )
        for name, _, status in bad[:5]:
            print(f"  - {name}: {status}", flush=True)
        self._maybe_save(
            None, None, None, step,
            extra={"bad_grads": [n for n, _, _ in bad]},
        )
        if self.raise_on_nan:
            raise RuntimeError(f"{prefix}Non-finite gradients")

    def _log_grad_stats(
        self,
        named: List[Tuple[str, nn.Parameter]],
        norms_stacked: torch.Tensor,
        total_norm: float,
        step: int,
    ) -> None:
        """Print the large-gradient warning and/or the periodic stats line."""
        too_large = total_norm > self.grad_norm_warn_threshold
        periodic = self._is_stats_step(step)
        if not (too_large or periodic):
            return
        prefix = self._prefix(step)
        max_idx = int(norms_stacked.argmax().item())
        max_val = norms_stacked[max_idx].item()
        max_name = named[max_idx][0]
        if too_large:
            print(
                f"{prefix}[GRAD WARN] grad_norm={total_norm:.4e} (very large). "
                f"Largest single-param grad={max_val:.4e} in {max_name}",
                flush=True,
            )
        if periodic:
            print(
                f"{prefix}[GRAD STATS] total_norm={total_norm:.4e} "
                f"max_single={max_val:.4e} ({max_name})",
                flush=True,
            )

    # ---------------------------------------------------------- context manager

    def __enter__(self) -> "TrainingDebugger":
        if self.debug_mode and self._use_hooks:
            hook_filter = self.cfg.get("hook_modules", "blocks")
            # Re-entering must not stack a second set of hooks on the model.
            if self._hook_handles:
                remove_nan_hooks(self._hook_handles)
            self._hook_handles, self._hook_first_bad = register_nan_hooks(
                self._unwrap_model(), step=None, verbose=False, module_filter=hook_filter
            )
            # Report the number of hooks actually registered rather than
            # guessing from ``model.blocks``.
            print(
                f"[DEBUG] Forward hooks ENABLED ({hook_filter}, {len(self._hook_handles)} hooks). "
                f"This will slow down forward pass significantly. "
                f"Set use_hooks=false in debug config to disable.",
                flush=True,
            )
        return self

    def __exit__(self, *args: Any) -> None:
        remove_nan_hooks(self._hook_handles)
        self._hook_handles = []

    def _prefix(self, step: Optional[int] = None) -> str:
        return f"[step {step}] " if step is not None else ""

    def _maybe_save(
        self,
        batch: Optional[Dict[str, Any]],
        loss: Optional[torch.Tensor],
        metrics: Optional[Dict[str, Any]],
        step: int,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Write a debug snapshot unless ``save_on_nan`` is disabled."""
        if not self.save_on_nan:
            return
        save_debug_state(
            batch, self.model, self.ancestor_table, self.optimizer,
            loss, metrics, step, self.save_dir, extra,
        )

    # ------------------------------------------------------------------- checks

    def check_batch(self, batch: Dict[str, Any], step: int) -> None:
        """Check top-level tensors of ``batch`` for NaN/Inf (needs check_inputs)."""
        if not self.debug_mode or not self._check_inputs:
            return
        if self.check_batch_every > 1 and step % self.check_batch_every != 0:
            return
        prefix = self._prefix(step)
        for key, val in batch.items():
            if isinstance(val, torch.Tensor):
                # Finite check only (stats=False); raising is handled below so
                # the snapshot can be written first.
                r = check_tensor_stats(
                    f"batch.{key}", val, step, stats=False, raise_on_nan=False
                )
                if not r["is_finite"]:
                    print(
                        f"{prefix}[INPUT ERROR] Non-finite tensor in batch['{key}']\n  {r['msg']}",
                        flush=True,
                    )
                    self._maybe_save(batch, None, None, step, extra={"bad_key": key})
                    if self.raise_on_nan:
                        raise RuntimeError(f"{prefix}Non-finite input in batch['{key}']")

    def check_forward_output(self, output: Any, step: int) -> None:
        """Check every tensor nested in the model output (needs check_outputs)."""
        if not self.debug_mode or not self._check_outputs:
            return
        if self.check_outputs_every > 1 and step % self.check_outputs_every != 0:
            return
        results = check_nested_tensors(
            "model.output", output, step,
            raise_on_nan=False, stats=False, max_elements_for_stats=50_000_000,
        )
        bad = [r for r in results if not r["is_finite"]]
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[FORWARD ERROR] Non-finite model output:", flush=True)
            for r in bad:
                print(f"  {r['msg']}", flush=True)
            if self._hook_first_bad:
                print(f"  First bad layer from hooks: {self._hook_first_bad}", flush=True)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite model output")

    def check_loss(
        self,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        step: int,
    ) -> None:
        """Check loss sub-items in ``metrics`` and then the total loss (needs check_loss)."""
        if not self.debug_mode or not self._check_loss:
            return
        prefix = self._prefix(step)

        # Sub-items first, so the offending loss term is named before the total.
        for key, val in (metrics or {}).items():
            if isinstance(val, torch.Tensor) and not val.isfinite().all():
                has_nan = val.isnan().any().item()
                status = "NAN" if has_nan else "INF"
                print(
                    f"{prefix}[LOSS SUBITEM {status}] {key} = "
                    f"{val.item() if val.numel() == 1 else val.detach()}",
                    flush=True,
                )
                self._maybe_save(None, loss, metrics, step, extra={"bad_metric": key})
                if self.raise_on_nan:
                    raise RuntimeError(f"{prefix}Non-finite loss subitem: {key}")

        if not loss.isfinite().all():
            nan_count = loss.isnan().sum().item()
            inf_count = loss.isinf().sum().item()
            print(
                f"{prefix}[TOTAL LOSS ERROR] loss is non-finite "
                f"nan_count={nan_count} inf_count={inf_count}",
                flush=True,
            )
            self._maybe_save(None, loss, metrics, step)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite total loss")
        else:
            self._last_clean_step = step

    def check_gradients(self, step: int) -> torch.Tensor:
        """
        Check every param's grad and return total norm (before clipping).

        A single ``torch._foreach_norm`` call serves both NaN detection and
        norm statistics. Non-finite reporting requires ``debug_mode`` and
        ``check_grads``.

        Returns:
            0-d tensor with the global L2 grad norm (0.0 if no param has a grad).
        """
        model = self._unwrap_model()
        device = self._model_device(model)

        named = self._named_grad_params(model)
        if not named:
            return torch.tensor(0.0, device=device)

        norms_stacked = torch.stack(
            torch._foreach_norm([p.grad for _, p in named], 2.0)
        )
        total_norm = torch.norm(norms_stacked, 2.0).item()

        if self.debug_mode and self._check_grads and not math.isfinite(total_norm):
            self._report_bad_grads(named, norms_stacked, total_norm, step)

        # NOTE(review): the warning/stats lines are printed here even when
        # debug_mode is off, unlike clip_grads(); kept as-is because callers
        # may rely on these log lines.
        self._log_grad_stats(named, norms_stacked, total_norm, step)

        return torch.tensor(total_norm, device=device)

    def clip_grads(self, step: int) -> torch.Tensor:
        """
        Clip gradients and return the total (pre-clip) grad norm.

        When ``debug_mode`` and ``check_grads`` are enabled this also performs
        NaN/Inf detection, large-gradient warnings and periodic stats logging.
        Per-parameter norms are measured BEFORE clipping, because
        ``clip_grad_norm_`` multiplies every gradient by a coefficient derived
        from the total norm: once that norm is non-finite, a scan made after
        clipping can no longer tell which parameter was the source.

        With ``grad_clip <= 0`` no clipping is applied, but the norm is still
        computed so the return value and the non-finite check stay meaningful.

        Returns:
            0-d tensor with the global L2 grad norm before clipping.

        Raises:
            RuntimeError: if the norm is non-finite and ``raise_on_nan`` is set.
        """
        model = self._unwrap_model()
        prefix = self._prefix(step)
        checks_on = bool(self.debug_mode and self._check_grads)
        clipping = self.grad_clip > 0

        named = self._named_grad_params(model)
        norms_stacked: Optional[torch.Tensor] = None
        pre_norm: Optional[torch.Tensor] = None
        if named and (checks_on or not clipping):
            norms_stacked = torch.stack(
                torch._foreach_norm([p.grad for _, p in named], 2.0)
            )
            pre_norm = torch.norm(norms_stacked, 2.0)

        # Attribute non-finite grads while they are still untouched.
        if checks_on and pre_norm is not None and not torch.isfinite(pre_norm).item():
            self._report_bad_grads(named, norms_stacked, pre_norm.item(), step)

        if clipping:
            params = list(model.parameters())
            if self.ancestor_table is not None:
                params += list(self.ancestor_table.parameters())
            total_norm = torch.nn.utils.clip_grad_norm_(params, self.grad_clip)
        elif pre_norm is not None:
            total_norm = pre_norm
        else:
            total_norm = torch.tensor(0.0, device=self._model_device(model))

        total_norm_finite = torch.isfinite(total_norm).item()

        # Stats are skipped when no parameter has a gradient yet (e.g. a call
        # before the first backward), since there is no "largest" param then.
        if checks_on and total_norm_finite and norms_stacked is not None:
            self._log_grad_stats(named, norms_stacked, total_norm.item(), step)

        # Reported regardless of debug_mode so a non-finite norm never passes
        # silently.
        if not total_norm_finite:
            print(
                f"{prefix}[CLIP ERROR] total_norm non-finite after clip_grad_norm_: {total_norm}",
                flush=True,
            )
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite grad norm after clipping")
        return total_norm

    def check_params_after_step(self, step: int) -> None:
        """Check parameters for NaN/Inf after ``optimizer.step()`` (needs check_params)."""
        if not self.debug_mode or not self._check_params:
            return
        if self.check_params_every > 1 and step % self.check_params_every != 0:
            return
        model = self._unwrap_model()
        bad = check_model_params(model, self.ancestor_table, step, raise_on_nan=False)
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[PARAM ERROR] {len(bad)} params non-finite after optimizer.step():", flush=True)
            for name, _, status in bad[:5]:
                print(f"  - {name}: {status}", flush=True)
            self._maybe_save(None, None, None, step, extra={"bad_params": [n for n, _, _ in bad]})
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite params after step")

    # ------------------------------------------------------------------ logging

    def log_step(
        self,
        step: int,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        lr: float,
        grad_norm: float,
        elapsed: float,
    ) -> None:
        """
        Print one structured training log line.

        The line contains the step, total/leaf/ancestor losses (read from
        ``metrics`` keys ``loss_total``/``loss_leaf``/``loss_ancestor``, with
        ``loss`` and 0.0 as fallbacks), learning rate, grad norm, the AMP
        scale when a scaler was configured, the first bad module recorded by
        the hooks (if any), and the elapsed time.
        """
        prefix = self._prefix(step)
        parts = [f"step={step:6d}"]

        l_total = metrics.get("loss_total", loss)
        l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
        l_anc = metrics.get("loss_ancestor", torch.tensor(0.0))
        for label, t in [("total", l_total), ("leaf", l_leaf), ("ancestor", l_anc)]:
            # Accept both tensors and plain numbers.
            v = t.item() if hasattr(t, "item") else float(t)
            parts.append(f"{label}={v:.4f}")

        parts.append(f"lr={lr:.2e}")
        parts.append(f"gnorm={grad_norm:.4e}")

        if self.scaler is not None:
            parts.append(f"scale={self.scaler.get_scale():.1f}")

        if self._hook_first_bad:
            fb = self._hook_first_bad
            parts.append(f"FIRST_BAD={fb.get('module_name','?')}({fb.get('stage','?')})")

        parts.append(f"t={elapsed:.1f}s")
        print(f"{prefix}{' | '.join(parts)}", flush=True)

    def reset_hooks_step(self, step: int) -> None:
        """
        Reset hook first-bad dict for the new step. Call at start of each step.

        The dict is cleared in place because the registered hooks hold a
        reference to the same object. ``step`` is currently unused.
        """
        self._hook_first_bad.clear()
|
|