Fix check_numerical_stability accuracy and completeness
- Norm fp32 check: inspect all unique norm classes instead of breaking
after first; use type(module).forward to resolve inherited methods
- Remove dead loss_fp32_note variable (condition always False)
- Add GradScaler for fp16 backward to prevent gradient underflow
from producing false positives in gradient health checks
- Add activation growth trend detection (std ratio across layers)
to fulfill docstring promise of catching initialization/norm bugs
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- llm_lab/training/debugger.py +30 -10
llm_lab/training/debugger.py
CHANGED
|
@@ -510,20 +510,20 @@ class LossDebugger:
|
|
| 510 |
|
| 511 |
# Check RMSNorm fp32 upcast
|
| 512 |
norm_fp32_ok = True
|
|
|
|
| 513 |
for name, module in model.named_modules():
|
| 514 |
cls_name = module.__class__.__name__
|
| 515 |
-
if "Norm" in cls_name:
|
| 516 |
-
|
| 517 |
import inspect
|
| 518 |
try:
|
| 519 |
-
src = inspect.getsource(module.forward)
|
| 520 |
has_upcast = ".float()" in src or "float32" in src
|
| 521 |
except (TypeError, OSError):
|
| 522 |
has_upcast = True # assume ok if can't inspect
|
| 523 |
if not has_upcast:
|
| 524 |
norm_fp32_ok = False
|
| 525 |
-
print(f" π΄ {
|
| 526 |
-
break # check one norm layer is enough
|
| 527 |
if norm_fp32_ok:
|
| 528 |
print(f" β
Norm layers use fp32 upcast (safe)")
|
| 529 |
|
|
@@ -533,10 +533,6 @@ class LossDebugger:
|
|
| 533 |
))
|
| 534 |
|
| 535 |
# Check loss computation dtype
|
| 536 |
-
loss_fp32_note = (
|
| 537 |
-
dtype in (torch.bfloat16, torch.float16)
|
| 538 |
-
and "cross_entropy" in str(type(model))
|
| 539 |
-
)
|
| 540 |
if dtype in (torch.bfloat16, torch.float16):
|
| 541 |
print(f" βΉοΈ Best practice: compute loss in fp32 when using {dtype}")
|
| 542 |
print(f" logits_fp32 = logits.float()")
|
|
@@ -579,6 +575,9 @@ class LossDebugger:
|
|
| 579 |
# ββ Forward + Backward ββ
|
| 580 |
model.train()
|
| 581 |
model.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
|
|
| 582 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 583 |
logits, loss = model(input_ids, targets)
|
| 584 |
|
|
@@ -590,7 +589,12 @@ class LossDebugger:
|
|
| 590 |
f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
|
| 591 |
))
|
| 592 |
|
| 593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
|
| 595 |
# Remove hooks
|
| 596 |
for h in hooks:
|
|
@@ -650,6 +654,22 @@ class LossDebugger:
|
|
| 650 |
f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
|
| 651 |
))
|
| 652 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
# ββ Logit scale check ββ
|
| 654 |
logit_max = logits.float().abs().max().item()
|
| 655 |
logit_ok = logit_max < 1000
|
|
|
|
| 510 |
|
| 511 |
# Check RMSNorm fp32 upcast
|
| 512 |
norm_fp32_ok = True
|
| 513 |
+
checked_norm_classes: set = set()
|
| 514 |
for name, module in model.named_modules():
|
| 515 |
cls_name = module.__class__.__name__
|
| 516 |
+
if "Norm" in cls_name and cls_name not in checked_norm_classes:
|
| 517 |
+
checked_norm_classes.add(cls_name)
|
| 518 |
import inspect
|
| 519 |
try:
|
| 520 |
+
src = inspect.getsource(type(module).forward)
|
| 521 |
has_upcast = ".float()" in src or "float32" in src
|
| 522 |
except (TypeError, OSError):
|
| 523 |
has_upcast = True # assume ok if can't inspect
|
| 524 |
if not has_upcast:
|
| 525 |
norm_fp32_ok = False
|
| 526 |
+
print(f" π΄ {cls_name}: no fp32 upcast detected!")
|
|
|
|
| 527 |
if norm_fp32_ok:
|
| 528 |
print(f" β
Norm layers use fp32 upcast (safe)")
|
| 529 |
|
|
|
|
| 533 |
))
|
| 534 |
|
| 535 |
# Check loss computation dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
if dtype in (torch.bfloat16, torch.float16):
|
| 537 |
print(f" βΉοΈ Best practice: compute loss in fp32 when using {dtype}")
|
| 538 |
print(f" logits_fp32 = logits.float()")
|
|
|
|
| 575 |
# ββ Forward + Backward ββ
|
| 576 |
model.train()
|
| 577 |
model.zero_grad(set_to_none=True)
|
| 578 |
+
use_scaler = dtype == torch.float16 and torch.cuda.is_available()
|
| 579 |
+
scaler = torch.amp.GradScaler() if use_scaler else None
|
| 580 |
+
|
| 581 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 582 |
logits, loss = model(input_ids, targets)
|
| 583 |
|
|
|
|
| 589 |
f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
|
| 590 |
))
|
| 591 |
|
| 592 |
+
if scaler is not None:
|
| 593 |
+
scaler.scale(loss).backward()
|
| 594 |
+
_temp_opt = torch.optim.SGD(model.parameters(), lr=0)
|
| 595 |
+
scaler.unscale_(_temp_opt)
|
| 596 |
+
else:
|
| 597 |
+
loss.backward()
|
| 598 |
|
| 599 |
# Remove hooks
|
| 600 |
for h in hooks:
|
|
|
|
| 654 |
f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
|
| 655 |
))
|
| 656 |
|
| 657 |
+
# ββ Activation growth trend ββ
|
| 658 |
+
if len(activation_stats) >= 2:
|
| 659 |
+
stds = [s["std"] for s in activation_stats]
|
| 660 |
+
if stds[0] > 1e-8:
|
| 661 |
+
growth_ratio = stds[-1] / stds[0]
|
| 662 |
+
growth_ok = growth_ratio < 10
|
| 663 |
+
detail = (
|
| 664 |
+
f"Activation std ratio (last/first): {growth_ratio:.1f}x "
|
| 665 |
+
f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})"
|
| 666 |
+
)
|
| 667 |
+
results.append(_check_result("Activation growth", growth_ok, detail))
|
| 668 |
+
icon = "β
" if growth_ok else "π‘"
|
| 669 |
+
print(f" {icon} {detail}")
|
| 670 |
+
if not growth_ok:
|
| 671 |
+
print(f" Possible initialization or normalization issue")
|
| 672 |
+
|
| 673 |
# ββ Logit scale check ββ
|
| 674 |
logit_max = logits.float().abs().max().item()
|
| 675 |
logit_ok = logit_max < 1000
|