sad / src /utils /debug_utils.py
haochengsama's picture
Add files using upload-large-folder tool
922bb4b verified
Raw
History Blame Contribute Delete
37.2 kB
"""
Debug utilities for systematic NaN/Inf detection and root-cause analysis.

Usage in train.py:

    from src.utils.debug_utils import TrainingDebugger

    debugger = TrainingDebugger(
        model, ancestor_table, optimizer,
        config={"debug_mode": True, "raise_on_nan": True, "use_hooks": True}
    )
    with debugger:
        for step in range(...):
            debugger.check_batch(batch, step)
            loss, metrics, output = forward_step(...)
            debugger.check_forward_output(output, step)
            debugger.check_loss(loss, metrics, step)
            loss.backward()
            gnorm = debugger.check_gradients(step)
            debugger.clip_grads(step)
            optimizer.step()
            debugger.check_params_after_step(step)

Design principles:
    - Fail-fast: detect the FIRST place where NaN/Inf appears.
    - Informative: print tensor name, shape, dtype, stats, module name, param name.
    - Minimally intrusive: wrap existing logic, don't replace it.
    - Configurable: debug_mode=False disables almost all overhead.
"""
from __future__ import annotations
import sys
import math
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
# --------------------------------------------------------------------------- #
# 1. Low-level tensor checks
# --------------------------------------------------------------------------- #
def check_tensor_stats(
    name: str,
    tensor: Optional[torch.Tensor],
    step: Optional[int] = None,
    stats: bool = True,
    raise_on_nan: bool = False,
    max_elements_for_stats: int = 50_000_000,
) -> Dict[str, Any]:
    """
    Check a single tensor for NaN/Inf and optionally append summary stats.

    The check never builds a filtered copy of the tensor via boolean indexing
    (``tensor[tensor.isfinite()]``), which caused OOM on large logits.
    Summary stats are only computed for all-finite tensors, using plain
    reduce ops, and are skipped entirely for tensors with more than
    ``max_elements_for_stats`` elements.

    A non-finite tensor is reported with ``print`` immediately; the message
    for a finite tensor is only returned, not printed.

    Args:
        name: Label used in the message.
        tensor: Tensor to inspect. ``None``, non-tensor objects and empty
            tensors are reported as finite without further work.
        step: Optional training step, prepended to the message as
            ``[step N] ``.
        stats: When True, append min/max/mean/std to the OK message.
        raise_on_nan: When True, raise instead of returning on NaN/Inf.
        max_elements_for_stats: Element-count limit above which stats are
            skipped.

    Returns:
        dict with keys: has_nan, has_inf, is_finite, msg

    Raises:
        RuntimeError: if the tensor contains NaN/Inf and ``raise_on_nan`` is True.
    """
    prefix = f"[step {step}] " if step is not None else ""
    if tensor is None:
        return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": f"{prefix}[{name}] None"}
    if not isinstance(tensor, torch.Tensor):
        return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": f"{prefix}[{name}] non-tensor {type(tensor)}"}
    if tensor.numel() == 0:
        return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": f"{prefix}[{name}] empty tensor"}

    # ---- Fast path: a single reduction + host sync settles the common case ----
    is_finite = tensor.isfinite().all().item()
    if not is_finite:
        # The counts already tell us whether each kind is present, so the
        # flags are derived from them instead of running two more reductions.
        nan_count = tensor.isnan().sum().item()
        inf_count = tensor.isinf().sum().item()
        has_nan = nan_count > 0
        has_inf = inf_count > 0
        status = "NAN" if has_nan else "INF"
        msg = (
            f"{prefix}[{name}] shape={list(tensor.shape)} dtype={tensor.dtype} "
            f"device={tensor.device} status={status} "
            f"nan_count={nan_count} inf_count={inf_count}"
        )
        print(msg, flush=True)
        if raise_on_nan:
            raise RuntimeError(f"{prefix}NaN/Inf detected in {name}")
        return {"has_nan": has_nan, "has_inf": has_inf, "is_finite": False, "msg": msg}

    # ---- OK path ----
    msg = (
        f"{prefix}[{name}] shape={list(tensor.shape)} dtype={tensor.dtype} "
        f"device={tensor.device} status=OK"
    )
    if stats:
        if tensor.numel() > max_elements_for_stats:
            msg += " | stats_skipped_large_tensor"
        else:
            # The tensor is known to be all-finite here, so plain reductions
            # are safe and no NaN-aware variant is needed.
            tmin = tensor.min().item()
            tmax = tensor.max().item()
            msg += f" | min={tmin} max={tmax}"
            if tensor.dtype.is_floating_point or tensor.dtype.is_complex:
                msg += f" mean={tensor.mean().item():.4e}"
                if tensor.numel() > 1:
                    msg += f" std={tensor.std().item():.4e}"
            else:
                # Non-float dtypes are converted to float for mean/std; the
                # size of that temporary is bounded by max_elements_for_stats.
                tf = tensor.float()
                msg += f" mean={tf.mean().item():.4e}"
                if tensor.numel() > 1:
                    msg += f" std={tf.std().item():.4e}"
    return {"has_nan": False, "has_inf": False, "is_finite": True, "msg": msg}
def check_nested_tensors(
    name: str,
    obj: Any,
    step: Optional[int] = None,
    raise_on_nan: bool = False,
    stats: bool = True,
    max_elements_for_stats: int = 50_000_000,
) -> List[Dict[str, Any]]:
    """
    Walk a nested structure of dicts, lists and tuples and check every tensor.

    Tensors are visited depth-first in container order. Each one is labelled
    with its path (``name[i]`` for sequence items, ``name.key`` for dict
    entries) and passed to ``check_tensor_stats``. Anything that is not a
    tensor, list, tuple or dict is ignored.

    Returns:
        One ``check_tensor_stats`` result dict per tensor found, in visit order.
    """

    def _leaves(label: str, node: Any):
        # Lazily yield (path, tensor) pairs so checking stays interleaved with
        # traversal, exactly as in a plain recursive walk.
        if isinstance(node, torch.Tensor):
            yield label, node
        elif isinstance(node, (list, tuple)):
            for idx, child in enumerate(node):
                yield from _leaves(f"{label}[{idx}]", child)
        elif isinstance(node, dict):
            for key, child in node.items():
                yield from _leaves(f"{label}.{key}", child)

    return [
        check_tensor_stats(
            label,
            leaf,
            step,
            stats=stats,
            raise_on_nan=raise_on_nan,
            max_elements_for_stats=max_elements_for_stats,
        )
        for label, leaf in _leaves(name, obj)
    ]
# --------------------------------------------------------------------------- #
# 2. Gradient & parameter checks
# --------------------------------------------------------------------------- #
def compute_grad_norm(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
) -> Tuple[float, Optional[Tuple[str, nn.Parameter]], float]:
    """
    Return the global L2 grad norm and the parameter with the largest grad norm.

    Only parameters whose ``.grad`` is set are considered. Parameters of
    ``ancestor_table`` (when given) are included under the name prefix
    ``ancestor_table.``. Per-parameter norms come from one batched
    ``torch._foreach_norm`` call rather than a Python loop.

    Returns:
        (total_norm, (name, param) or None, max_single_norm); the result is
        ``(0.0, None, 0.0)`` when no parameter has a gradient.
    """
    named_grads: List[Tuple[str, nn.Parameter]] = [
        (pname, param)
        for pname, param in model.named_parameters()
        if param.grad is not None
    ]
    if ancestor_table is not None:
        for pname, param in ancestor_table.named_parameters():
            if param.grad is not None:
                named_grads.append((f"ancestor_table.{pname}", param))

    if len(named_grads) == 0:
        return 0.0, None, 0.0

    # One batched call yields every per-parameter L2 norm.
    per_param = torch.stack(
        torch._foreach_norm([param.grad for _, param in named_grads], 2.0)
    )
    # The norm of the per-parameter norms equals the global L2 norm.
    overall = torch.norm(per_param, 2.0).item()
    largest = int(per_param.argmax().item())
    return overall, named_grads[largest], per_param[largest].item()
def check_gradients(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
    print_all: bool = False,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Look for NaN/Inf in parameter gradients.

    Per-parameter L2 norms are computed with one ``torch._foreach_norm`` call
    and combined into a total norm. Only when that total is non-finite are the
    individual norms inspected, so that offending parameters can be named.

    Args:
        model: Module whose parameters are inspected.
        ancestor_table: Optional second module; its parameters are reported
            with the name prefix ``ancestor_table.``.
        step: Optional step number, prepended to printed messages.
        raise_on_nan: Raise when at least one offending parameter is found.
        print_all: When the total norm is finite, print every parameter's
            grad norm.

    Returns:
        List of ``(name, param, status)`` with status ``"NAN"`` or ``"INF"``.

    Raises:
        RuntimeError: if offenders were found and ``raise_on_nan`` is True.
    """
    tag = "" if step is None else f"[step {step}] "
    offenders: List[Tuple[str, nn.Parameter, str]] = []

    named: List[Tuple[str, nn.Parameter]] = [
        (pname, param)
        for pname, param in model.named_parameters()
        if param.grad is not None
    ]
    if ancestor_table is not None:
        named.extend(
            (f"ancestor_table.{pname}", param)
            for pname, param in ancestor_table.named_parameters()
            if param.grad is not None
        )
    if not named:
        return offenders

    per_param = torch._foreach_norm([param.grad for _, param in named], 2.0)
    overall = torch.norm(torch.stack(per_param), 2.0)

    if torch.isfinite(overall):
        # Healthy gradients: nothing to report unless a full dump was requested.
        if print_all:
            for (pname, _), pnorm in zip(named, per_param):
                print(
                    f"{tag}[GRAD OK] {pname} grad_norm={pnorm.item():.4e}",
                    flush=True,
                )
        return offenders

    # Slow path: name every parameter whose own norm is non-finite.
    for (pname, param), pnorm in zip(named, per_param):
        if torch.isfinite(pnorm):
            continue
        kind = "NAN" if param.grad.isnan().any().item() else "INF"
        print(
            f"{tag}[GRAD {kind}] {pname} shape={list(param.grad.shape)}",
            flush=True,
        )
        offenders.append((pname, param, kind))

    if raise_on_nan and offenders:
        names = [entry[0] for entry in offenders]
        raise RuntimeError(
            f"{tag}Non-finite grad in {len(offenders)} params: {names[:10]}..."
        )
    return offenders
def check_model_params(
    model: nn.Module,
    ancestor_table: Optional[nn.Module] = None,
    step: Optional[int] = None,
    raise_on_nan: bool = True,
) -> List[Tuple[str, nn.Parameter, str]]:
    """
    Look for NaN/Inf in parameter values (typically after optimizer.step()).

    Works like the gradient check: one ``torch._foreach_norm`` call over all
    parameter tensors, followed by a per-parameter scan only when the combined
    norm is non-finite.

    Returns:
        List of ``(name, param, status)`` with status ``"NAN"`` or ``"INF"``.

    Raises:
        RuntimeError: if offenders were found and ``raise_on_nan`` is True.
    """
    tag = "" if step is None else f"[step {step}] "
    offenders: List[Tuple[str, nn.Parameter, str]] = []

    named: List[Tuple[str, nn.Parameter]] = list(model.named_parameters())
    if ancestor_table is not None:
        named.extend(
            (f"ancestor_table.{pname}", param)
            for pname, param in ancestor_table.named_parameters()
        )
    if not named:
        return offenders

    per_param = torch._foreach_norm([param.data for _, param in named], 2.0)
    overall = torch.norm(torch.stack(per_param), 2.0)
    if torch.isfinite(overall):
        return offenders

    # Slow path: name every parameter whose own norm is non-finite.
    for (pname, param), pnorm in zip(named, per_param):
        if torch.isfinite(pnorm):
            continue
        kind = "NAN" if param.data.isnan().any().item() else "INF"
        print(
            f"{tag}[PARAM {kind}] {pname} shape={list(param.shape)}",
            flush=True,
        )
        offenders.append((pname, param, kind))

    if raise_on_nan and offenders:
        names = [entry[0] for entry in offenders]
        raise RuntimeError(
            f"{tag}Non-finite param in {len(offenders)} params: {names[:10]}..."
        )
    return offenders
# --------------------------------------------------------------------------- #
# 3. Forward hooks for intermediate-layer NaN detection
# --------------------------------------------------------------------------- #
def register_nan_hooks(
    model: nn.Module,
    step: Optional[int] = None,
    verbose: bool = False,
    module_filter: Optional[str] = "blocks",
) -> Tuple[List[torch.utils.hooks.RemovableHandle], Dict[str, Any]]:
    """
    Register forward hooks to catch the FIRST layer that produces non-finite
    output (or receives non-finite input).

    Once a hook has recorded a finding, every hook returns immediately until
    the returned ``first_bad`` dict is emptied again by the caller.

    Args:
        model: Module whose submodules get hooked.
        step: Optional step number used as a message prefix. It is captured
            at registration time and does not change afterwards.
        verbose: Also print a ``[HOOK OK]`` line for every clean output tensor.
        module_filter: Controls which modules get hooks (keywords are
            case-insensitive).
            - "blocks" (default): hooks on each entry of ``model.blocks`` when
              that attribute is an ``nn.ModuleList``; otherwise falls back to
              the direct children of ``model``.
            - "children": hooks on all direct children of ``model``.
            - "all": hooks on *every* named submodule (old behaviour).
              Very slow (~10x forward slowdown) and increases memory.
            - None / "": no hooks.
            - anything else: a comma-separated list of module names as
              reported by ``model.named_modules()``.

    Returns:
        (list_of_handles, first_bad_dict)
    """
    handles: List[torch.utils.hooks.RemovableHandle] = []
    first_bad: Dict[str, Any] = {}

    def _targets():
        """Yield (name, module) pairs to hook."""
        if not module_filter:
            return
        raw = str(module_filter)
        mf = raw.lower()
        if mf == "all":
            for name, module in model.named_modules():
                if name:  # skip the root module, whose name is ""
                    yield name, module
        elif mf == "children":
            for name, module in model.named_children():
                yield name, module
        elif mf == "blocks":
            # Default: hook each block in model.blocks (e.g. transformer blocks)
            if hasattr(model, "blocks") and isinstance(model.blocks, nn.ModuleList):
                for i, block in enumerate(model.blocks):
                    yield f"blocks.{i}", block
            else:
                # Fallback to direct children if no .blocks attr
                for name, module in model.named_children():
                    yield name, module
        else:
            # Treat as a comma-separated list of module names. Module names
            # are case-sensitive, so match the spelling the caller supplied.
            # The lower-cased spelling is kept as well so that filters which
            # matched before this fix still match.
            allowed = {s.strip() for s in raw.split(",")}
            allowed |= {s.strip() for s in mf.split(",")}
            for name, module in model.named_modules():
                if name in allowed:
                    yield name, module

    def make_hook(module_name: str, module_type: str):
        def hook(module, inputs, output):
            # Only the first finding is of interest; skip all work after it.
            if first_bad:
                return
            prefix = f"[step {step}] " if step is not None else ""

            # Check inputs (top-level positional tensors only)
            for i, inp in enumerate(inputs):
                if isinstance(inp, torch.Tensor) and inp.numel() > 0 and not inp.isfinite().all():
                    first_bad.update({
                        "stage": "input",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_idx": i,
                        "shape": list(inp.shape),
                        "has_nan": inp.isnan().any().item(),
                        "has_inf": inp.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK INPUT] {module_name} ({module_type}) "
                        f"input[{i}] shape={list(inp.shape)} "
                        f"nan={inp.isnan().sum().item()} inf={inp.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            # Check outputs: a bare tensor, or the tensors directly inside a
            # list/tuple/dict. Deeper nesting is not descended into.
            tensors_to_check = []
            if isinstance(output, torch.Tensor):
                tensors_to_check = [("output", output)]
            elif isinstance(output, (list, tuple)):
                for i, o in enumerate(output):
                    if isinstance(o, torch.Tensor):
                        tensors_to_check.append((f"output[{i}]", o))
            elif isinstance(output, dict):
                for k, v in output.items():
                    if isinstance(v, torch.Tensor):
                        tensors_to_check.append((f"output['{k}']", v))

            for out_name, out_tensor in tensors_to_check:
                if out_tensor.numel() > 0 and not out_tensor.isfinite().all():
                    first_bad.update({
                        "stage": "output",
                        "module_name": module_name,
                        "module_type": module_type,
                        "tensor_name": out_name,
                        "shape": list(out_tensor.shape),
                        "has_nan": out_tensor.isnan().any().item(),
                        "has_inf": out_tensor.isinf().any().item(),
                    })
                    print(
                        f"{prefix}[HOOK OUTPUT] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)} "
                        f"nan={out_tensor.isnan().sum().item()} inf={out_tensor.isinf().sum().item()}",
                        flush=True,
                    )
                    return

            if verbose:
                for out_name, out_tensor in tensors_to_check:
                    print(
                        f"{prefix}[HOOK OK] {module_name} ({module_type}) "
                        f"{out_name} shape={list(out_tensor.shape)}",
                        flush=True,
                    )

        return hook

    for name, module in _targets():
        h = module.register_forward_hook(make_hook(name, module.__class__.__name__))
        handles.append(h)
    return handles, first_bad
def remove_nan_hooks(handles: List[torch.utils.hooks.RemovableHandle]) -> None:
    """Detach every hook in ``handles`` by calling its ``remove()`` method.

    The list itself is left untouched; callers that keep it around should
    clear or replace it themselves.
    """
    for handle in handles:
        handle.remove()
# --------------------------------------------------------------------------- #
# 4. Snapshot saving on anomaly
# --------------------------------------------------------------------------- #
def save_debug_state(
    batch: Optional[Dict[str, Any]],
    model: nn.Module,
    ancestor_table: Optional[nn.Module],
    optimizer: torch.optim.Optimizer,
    loss: Optional[torch.Tensor],
    metrics: Optional[Dict[str, Any]],
    step: int,
    save_dir: Union[str, Path],
    extra: Optional[Dict[str, Any]] = None,
) -> Path:
    """
    Save current batch, model, optimizer, and metrics for post-mortem analysis.

    Files written into ``save_dir`` (created if missing), all sharing the
    prefix ``debug_step<step>_<timestamp>``:
        - ``*_state.pt``: dict with ``step`` and, when given, ``batch``,
          ``loss``, ``metrics`` and ``extra``.
        - ``*_model.pt``: ``state_dict`` of the unwrapped model.
        - ``*_ancestor.pt``: ``state_dict`` of ``ancestor_table`` (if not None).
        - ``*_optimizer.pt``: ``state_dict`` of the optimizer.

    The loss and tensor-valued metrics are stored as they are: single-element
    tensors become Python numbers (NaN and Inf are preserved as such), and
    larger tensors are stored as detached CPU tensors.

    Returns:
        Path of the ``*_state.pt`` file.
    """

    def _scalar_or_cpu(value: Any) -> Any:
        # Snapshots must not fail on multi-element tensors, and must keep the
        # difference between NaN and Inf, so no value is rewritten here.
        if not isinstance(value, torch.Tensor):
            return value
        if value.numel() == 1:
            return value.item()
        return value.detach().cpu()

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = f"debug_step{step:06d}_{ts}"

    state: Dict[str, Any] = {"step": step}
    if batch is not None:
        state["batch"] = {
            k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }
    if loss is not None:
        state["loss"] = _scalar_or_cpu(loss)
    if metrics is not None:
        state["metrics"] = {k: _scalar_or_cpu(v) for k, v in metrics.items()}
    if extra is not None:
        state["extra"] = extra
    path = save_dir / f"{prefix}_state.pt"
    torch.save(state, path)

    # Unwrap to get the raw state_dict. Wrappers exposing ``_orig_mod`` or
    # ``module`` may be stacked in either order, so peel them one at a time
    # until neither attribute is present.
    m = model
    while True:
        if hasattr(m, "_orig_mod"):
            m = m._orig_mod
        elif hasattr(m, "module"):
            m = m.module
        else:
            break
    torch.save(m.state_dict(), save_dir / f"{prefix}_model.pt")
    if ancestor_table is not None:
        torch.save(ancestor_table.state_dict(), save_dir / f"{prefix}_ancestor.pt")
    torch.save(optimizer.state_dict(), save_dir / f"{prefix}_optimizer.pt")

    print(f"[DEBUG] Saved debug snapshot to {path}", flush=True)
    return path
# --------------------------------------------------------------------------- #
# 5. TrainingDebugger — high-level orchestrator
# --------------------------------------------------------------------------- #
class TrainingDebugger:
    """
    Systematic NaN/Inf debugger for training loops.

    Typical usage:
        debugger = TrainingDebugger(model, ancestor_table, optimizer, config)
        with debugger:  # registers hooks on enter, removes on exit
            for step in range(num_steps):
                ...

    All behaviour is driven by the optional ``config`` dict; see ``__init__``
    for the keys that are read and their defaults.
    """

    def __init__(
        self,
        model: nn.Module,
        ancestor_table: Optional[nn.Module],
        optimizer: torch.optim.Optimizer,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.model = model
        self.ancestor_table = ancestor_table
        self.optimizer = optimizer
        self.cfg = config or {}

        # Top-level switches
        self.debug_mode = self.cfg.get("debug_mode", True)
        self.raise_on_nan = self.cfg.get("raise_on_nan", False)

        # Stage toggles (prefixed with _ to avoid shadowing methods)
        # Defaults are conservative: only the gradient guard is on by default;
        # input/output/loss/param checks default OFF to avoid GPU sync overhead.
        self._check_inputs = self.cfg.get("check_inputs", False)
        self._check_outputs = self.cfg.get("check_outputs", False)
        self._check_loss = self.cfg.get("check_loss", False)
        self._check_grads = self.cfg.get("check_grads", True)
        self._check_params = self.cfg.get("check_params", False)

        # Hooks default to OFF because they add significant forward overhead
        # (5-10x slower) and increase memory. Only enable when you need to
        # pinpoint which *intermediate* layer first produces NaN.
        self._use_hooks = self.cfg.get("use_hooks", False)

        # Behaviour
        self.save_on_nan = self.cfg.get("save_on_nan", True)
        self.save_dir = Path(self.cfg.get("save_dir", "outputs/debug"))
        self.grad_clip = float(self.cfg.get("grad_clip", 1.0))
        self.log_interval = int(self.cfg.get("log_interval", 100))
        # A value <= 0 disables the periodic [GRAD STATS] line.
        self.print_stats_every = int(self.cfg.get("print_stats_every", 100))
        self.grad_norm_warn_threshold = float(self.cfg.get("grad_norm_warn_threshold", 1e5))
        self.check_params_every = int(self.cfg.get("check_params_every", 100))
        self.check_outputs_every = int(self.cfg.get("check_outputs_every", 1))
        self.check_batch_every = int(self.cfg.get("check_batch_every", 1))

        # AMP / scaler tracking
        self.scaler: Optional[torch.cuda.amp.GradScaler] = self.cfg.get("scaler", None)

        # Internal state
        self._hook_handles: List[torch.utils.hooks.RemovableHandle] = []
        self._hook_first_bad: Dict[str, Any] = {}
        self._last_clean_step = -1

    def _unwrap_model(self) -> nn.Module:
        """Strip ``_orig_mod`` / ``module`` wrappers, in whatever order they are stacked."""
        m = self.model
        while True:
            if hasattr(m, "_orig_mod"):
                m = m._orig_mod
            elif hasattr(m, "module"):
                m = m.module
            else:
                return m

    # -- context manager --
    def __enter__(self) -> "TrainingDebugger":
        if self.debug_mode and self._use_hooks:
            hook_filter = self.cfg.get("hook_modules", "blocks")
            # Drop hooks left over from an earlier __enter__ so they cannot
            # pile up on the model.
            remove_nan_hooks(self._hook_handles)
            self._hook_handles, self._hook_first_bad = register_nan_hooks(
                self._unwrap_model(), step=None, verbose=False, module_filter=hook_filter
            )
            # Report the number of hooks actually registered; it depends on
            # the filter, not just on the presence of ``model.blocks``.
            n_hooks = len(self._hook_handles)
            print(
                f"[DEBUG] Forward hooks ENABLED ({hook_filter}, ~{n_hooks} hooks). "
                f"This will slow down forward pass significantly. "
                f"Set use_hooks=false in debug config to disable.",
                flush=True,
            )
        return self

    def __exit__(self, *args: Any) -> None:
        remove_nan_hooks(self._hook_handles)
        self._hook_handles = []

    # -- helpers --
    def _prefix(self, step: Optional[int] = None) -> str:
        return f"[step {step}] " if step is not None else ""

    def _maybe_save(
        self,
        batch: Optional[Dict[str, Any]],
        loss: Optional[torch.Tensor],
        metrics: Optional[Dict[str, Any]],
        step: int,
        extra: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Write a debug snapshot unless ``save_on_nan`` is disabled."""
        if not self.save_on_nan:
            return
        save_debug_state(
            batch, self.model, self.ancestor_table, self.optimizer,
            loss, metrics, step, self.save_dir, extra,
        )

    @staticmethod
    def _model_device(model: nn.Module) -> torch.device:
        """Device of the first parameter, or CPU for a parameter-less model."""
        first = next(model.parameters(), None)
        return first.device if first is not None else torch.device("cpu")

    def _named_grad_params(self, model: nn.Module) -> List[Tuple[str, nn.Parameter]]:
        """Named parameters of ``model`` and the ancestor table that have a gradient."""
        named = [(n, p) for n, p in model.named_parameters() if p.grad is not None]
        if self.ancestor_table is not None:
            named += [
                (f"ancestor_table.{n}", p)
                for n, p in self.ancestor_table.named_parameters()
                if p.grad is not None
            ]
        return named

    def _stats_due(self, step: int) -> bool:
        """True when the periodic stats line should be printed at ``step``."""
        return self.print_stats_every > 0 and step % self.print_stats_every == 0

    def _handle_bad_grads(
        self,
        bad: List[Tuple[str, nn.Parameter, str]],
        total_norm: float,
        step: int,
    ) -> None:
        """Print the non-finite-gradient summary, snapshot, and raise if configured."""
        prefix = self._prefix(step)
        print(
            f"{prefix}[GRAD ERROR] {len(bad)} params with non-finite grad. "
            f"Total norm={total_norm:.4e}",
            flush=True,
        )
        for name, _, status in bad[:5]:
            print(f" - {name}: {status}", flush=True)
        self._maybe_save(
            None, None, None, step,
            extra={"bad_grads": [n for n, _, _ in bad]},
        )
        if self.raise_on_nan:
            raise RuntimeError(f"{prefix}Non-finite gradients")

    def _report_grad_stats(
        self,
        step: int,
        total_norm: float,
        named: Optional[List[Tuple[str, nn.Parameter]]] = None,
        norms_stacked: Optional[torch.Tensor] = None,
    ) -> None:
        """Print the large-gradient warning and/or the periodic stats line.

        ``named`` and ``norms_stacked`` are computed here only when they are
        needed and were not supplied by the caller.
        """
        too_large = total_norm > self.grad_norm_warn_threshold
        periodic = self._stats_due(step)
        if not (too_large or periodic):
            return
        if named is None:
            named = self._named_grad_params(self._unwrap_model())
        if not named:
            return
        if norms_stacked is None:
            norms_stacked = torch.stack(
                torch._foreach_norm([p.grad for _, p in named], 2.0)
            )
        prefix = self._prefix(step)
        max_idx = int(norms_stacked.argmax().item())
        max_val = norms_stacked[max_idx].item()
        max_name = named[max_idx][0]
        if too_large:
            print(
                f"{prefix}[GRAD WARN] grad_norm={total_norm:.4e} (very large). "
                f"Largest single-param grad={max_val:.4e} in {max_name}",
                flush=True,
            )
        if periodic:
            print(
                f"{prefix}[GRAD STATS] total_norm={total_norm:.4e} "
                f"max_single={max_val:.4e} ({max_name})",
                flush=True,
            )

    # -- stage A: input data --
    def check_batch(self, batch: Dict[str, Any], step: int) -> None:
        """Check every top-level tensor in ``batch`` for NaN/Inf."""
        if not self.debug_mode or not self._check_inputs:
            return
        if self.check_batch_every > 1 and step % self.check_batch_every != 0:
            return
        prefix = self._prefix(step)
        for key, val in batch.items():
            if isinstance(val, torch.Tensor):
                # stats=False: batch tensor stats (min/max/mean) are rarely useful
                # and the extra .item() syncs add ~3-4 ms per tensor.
                r = check_tensor_stats(
                    f"batch.{key}", val, step, stats=False, raise_on_nan=False
                )
                if not r["is_finite"]:
                    print(
                        f"{prefix}[INPUT ERROR] Non-finite tensor in batch['{key}']\n {r['msg']}",
                        flush=True,
                    )
                    self._maybe_save(batch, None, None, step, extra={"bad_key": key})
                    if self.raise_on_nan:
                        raise RuntimeError(f"{prefix}Non-finite input in batch['{key}']")

    # -- stage B: model output --
    def check_forward_output(self, output: Any, step: int) -> None:
        """Check all tensors in the (possibly nested) model output for NaN/Inf."""
        if not self.debug_mode or not self._check_outputs:
            return
        if self.check_outputs_every > 1 and step % self.check_outputs_every != 0:
            return
        # For forward output we only check NaN/Inf, NOT full stats.
        # Model logits [B, S, V] can be >1B elements; computing min/max/mean/std
        # would temporarily allocate several GB and cause OOM.
        results = check_nested_tensors(
            "model.output", output, step,
            raise_on_nan=False, stats=False, max_elements_for_stats=50_000_000,
        )
        bad = [r for r in results if not r["is_finite"]]
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[FORWARD ERROR] Non-finite model output:", flush=True)
            for r in bad:
                print(f" {r['msg']}", flush=True)
            if self._hook_first_bad:
                print(f" First bad layer from hooks: {self._hook_first_bad}", flush=True)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite model output")

    # -- stage C: loss --
    def check_loss(
        self,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        step: int,
    ) -> None:
        """Check the loss sub-items in ``metrics`` and then the total loss."""
        if not self.debug_mode or not self._check_loss:
            return
        prefix = self._prefix(step)
        # Check sub-items
        for key, val in metrics.items():
            if isinstance(val, torch.Tensor) and not val.isfinite().all():
                has_nan = val.isnan().any().item()
                status = "NAN" if has_nan else "INF"
                print(
                    f"{prefix}[LOSS SUBITEM {status}] {key} = "
                    f"{val.item() if val.numel() == 1 else val.detach()}",
                    flush=True,
                )
                self._maybe_save(None, loss, metrics, step, extra={"bad_metric": key})
                if self.raise_on_nan:
                    raise RuntimeError(f"{prefix}Non-finite loss subitem: {key}")
        # Check total loss
        if not loss.isfinite().all():
            nan_count = loss.isnan().sum().item()
            inf_count = loss.isinf().sum().item()
            print(
                f"{prefix}[TOTAL LOSS ERROR] loss is non-finite "
                f"nan_count={nan_count} inf_count={inf_count}",
                flush=True,
            )
            self._maybe_save(None, loss, metrics, step)
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite total loss")
        else:
            self._last_clean_step = step

    # -- stage D: gradients --
    def check_gradients(self, step: int) -> torch.Tensor:
        """
        Check every param's grad and return total norm (before clipping).

        Optimised: single ``torch._foreach_norm`` call serves both NaN detection
        and norm statistics (avoids duplicate GPU work). The norm is computed
        and returned even when ``debug_mode`` or ``check_grads`` is off; only
        the diagnostics are skipped then.

        Raises:
            RuntimeError: on non-finite gradients when ``raise_on_nan`` is set.
        """
        model = self._unwrap_model()
        device = self._model_device(model)
        params = self._named_grad_params(model)
        if not params:
            return torch.tensor(0.0, device=device)

        norms = torch._foreach_norm([p.grad for _, p in params], 2.0)
        norms_stacked = torch.stack(norms)
        total_norm_t = torch.norm(norms_stacked, 2.0)
        prefix = self._prefix(step)
        total_norm = total_norm_t.item()

        if self.debug_mode and self._check_grads:
            if not total_norm_t.isfinite():
                # Slow path: name each parameter whose own norm is non-finite.
                bad = []
                for (name, p), norm in zip(params, norms):
                    if not norm.isfinite():
                        has_nan = p.grad.isnan().any().item()
                        status = "NAN" if has_nan else "INF"
                        print(
                            f"{prefix}[GRAD {status}] {name} shape={list(p.grad.shape)}",
                            flush=True,
                        )
                        bad.append((name, p, status))
                if bad:
                    self._handle_bad_grads(bad, total_norm, step)
            # The per-param maximum is only computed when something is printed.
            self._report_grad_stats(step, total_norm, params, norms_stacked)
        return torch.tensor(total_norm, device=device)

    def clip_grads(self, step: int) -> torch.Tensor:
        """
        Clip gradients and return total norm (PyTorch returns pre-clip norm).

        When ``debug_mode`` and ``check_grads`` are enabled, this method also
        performs NaN/Inf detection, large-gradient warnings, and periodic stats
        logging. This merges the work of ``check_gradients()`` so that callers
        don't need to pay for a duplicate ``_foreach_norm`` pass.

        When ``grad_clip <= 0`` nothing is clipped, the returned norm is a
        constant 0.0 and, as a consequence, no NaN/Inf detection happens here.

        Raises:
            RuntimeError: on a non-finite total norm when ``raise_on_nan`` is set.
        """
        model = self._unwrap_model()
        device = self._model_device(model)
        params = list(model.parameters())
        if self.ancestor_table is not None:
            params += list(self.ancestor_table.parameters())
        if self.grad_clip > 0:
            total_norm = torch.nn.utils.clip_grad_norm_(params, self.grad_clip)
        else:
            total_norm = torch.tensor(0.0, device=device)
        prefix = self._prefix(step)

        # Single isfinite evaluation shared by debug and clip-error paths.
        # Calling it once avoids duplicate GPU sync.
        total_norm_finite = torch.isfinite(total_norm).item()
        checking = self.debug_mode and self._check_grads

        # ---- NaN / Inf detection (merged from check_gradients) ----
        if checking and not total_norm_finite:
            # Slow path: identify the culprit(s). The ancestor table is part
            # of the clipped set and of the norm, so it is scanned as well.
            # NOTE(review): this scan runs after clip_grad_norm_, which may
            # already have scaled the gradients in place by a non-finite
            # factor; parameters other than the original culprit may
            # therefore show up here too. Confirm against the installed
            # PyTorch version.
            bad = []
            for name, p in self._named_grad_params(model):
                if not torch.isfinite(p.grad).all():
                    has_nan = p.grad.isnan().any().item()
                    status = "NAN" if has_nan else "INF"
                    print(
                        f"{prefix}[GRAD {status}] {name} shape={list(p.grad.shape)}",
                        flush=True,
                    )
                    bad.append((name, p, status))
            if bad:
                self._handle_bad_grads(bad, total_norm.item(), step)

        # Stats / warnings (only when total norm is finite)
        if checking and total_norm_finite:
            self._report_grad_stats(step, total_norm.item())

        # Unconditional guard: the norm reported by clip_grad_norm_ is
        # non-finite (clip_grad_norm_ with inf norm can produce nan via inf*0).
        if not total_norm_finite:
            print(
                f"{prefix}[CLIP ERROR] total_norm non-finite after clip_grad_norm_: {total_norm}",
                flush=True,
            )
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite grad norm after clipping")
        return total_norm

    # -- stage E: parameters after step --
    def check_params_after_step(self, step: int) -> None:
        """Check parameter values for NaN/Inf every ``check_params_every`` steps."""
        if not self.debug_mode or not self._check_params:
            return
        if self.check_params_every > 1 and step % self.check_params_every != 0:
            return
        model = self._unwrap_model()
        bad = check_model_params(model, self.ancestor_table, step, raise_on_nan=False)
        if bad:
            prefix = self._prefix(step)
            print(f"{prefix}[PARAM ERROR] {len(bad)} params non-finite after optimizer.step():", flush=True)
            for name, _, status in bad[:5]:
                print(f" - {name}: {status}", flush=True)
            self._maybe_save(None, None, None, step, extra={"bad_params": [n for n, _, _ in bad]})
            if self.raise_on_nan:
                raise RuntimeError(f"{prefix}Non-finite params after step")

    # -- periodic summary log --
    def log_step(
        self,
        step: int,
        loss: torch.Tensor,
        metrics: Dict[str, Any],
        lr: float,
        grad_norm: float,
        elapsed: float,
    ) -> None:
        """
        Structured training log line. Only called on main rank.
        """
        prefix = self._prefix(step)
        parts = [f"step={step:6d}"]
        # Loss items; missing sub-losses are logged as 0.
        l_total = metrics.get("loss_total", loss)
        l_leaf = metrics.get("loss_leaf", torch.tensor(0.0))
        l_anc = metrics.get("loss_ancestor", torch.tensor(0.0))
        for label, t in [("total", l_total), ("leaf", l_leaf), ("ancestor", l_anc)]:
            v = t.item() if hasattr(t, "item") else float(t)
            parts.append(f"{label}={v:.4f}")
        # LR & grad norm
        parts.append(f"lr={lr:.2e}")
        parts.append(f"gnorm={grad_norm:.4e}")
        # AMP scaler
        if self.scaler is not None:
            parts.append(f"scale={self.scaler.get_scale():.1f}")
        # Hooks first-bad (if any)
        if self._hook_first_bad:
            fb = self._hook_first_bad
            parts.append(f"FIRST_BAD={fb.get('module_name','?')}({fb.get('stage','?')})")
        parts.append(f"t={elapsed:.1f}s")
        print(f"{prefix}{' | '.join(parts)}", flush=True)

    def reset_hooks_step(self, step: int) -> None:
        """
        Reset hook first-bad dict for the new step. Call at start of each step.
        """
        # Cleared in place: the registered hooks hold a reference to this
        # same dict, so rebinding the attribute would disconnect them.
        self._hook_first_bad.clear()