Vjeong Claude Opus 4.6 committed on
Commit
2fb0306
·
1 Parent(s): 6b7ca0e

Fix check_numerical_stability accuracy and completeness

Browse files

- Norm fp32 check: inspect all unique norm classes instead of breaking
after first; use type(module).forward to resolve inherited methods
- Remove dead loss_fp32_note variable (condition always False)
- Add GradScaler for fp16 backward to prevent gradient underflow
from producing false positives in gradient health checks
- Add activation growth trend detection (std ratio across layers)
to fulfill docstring promise of catching initialization/norm bugs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. llm_lab/training/debugger.py +30 -10
llm_lab/training/debugger.py CHANGED
@@ -510,20 +510,20 @@ class LossDebugger:
510
 
511
  # Check RMSNorm fp32 upcast
512
  norm_fp32_ok = True
 
513
  for name, module in model.named_modules():
514
  cls_name = module.__class__.__name__
515
- if "Norm" in cls_name:
516
- # Inspect forward source for .float() call
517
  import inspect
518
  try:
519
- src = inspect.getsource(module.forward)
520
  has_upcast = ".float()" in src or "float32" in src
521
  except (TypeError, OSError):
522
  has_upcast = True # assume ok if can't inspect
523
  if not has_upcast:
524
  norm_fp32_ok = False
525
- print(f" πŸ”΄ {name} ({cls_name}): no fp32 upcast detected!")
526
- break # check one norm layer is enough
527
  if norm_fp32_ok:
528
  print(f" βœ… Norm layers use fp32 upcast (safe)")
529
 
@@ -533,10 +533,6 @@ class LossDebugger:
533
  ))
534
 
535
  # Check loss computation dtype
536
- loss_fp32_note = (
537
- dtype in (torch.bfloat16, torch.float16)
538
- and "cross_entropy" in str(type(model))
539
- )
540
  if dtype in (torch.bfloat16, torch.float16):
541
  print(f" ℹ️ Best practice: compute loss in fp32 when using {dtype}")
542
  print(f" logits_fp32 = logits.float()")
@@ -579,6 +575,9 @@ class LossDebugger:
579
  # ── Forward + Backward ──
580
  model.train()
581
  model.zero_grad(set_to_none=True)
 
 
 
582
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
583
  logits, loss = model(input_ids, targets)
584
 
@@ -590,7 +589,12 @@ class LossDebugger:
590
  f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
591
  ))
592
 
593
- loss.backward()
 
 
 
 
 
594
 
595
  # Remove hooks
596
  for h in hooks:
@@ -650,6 +654,22 @@ class LossDebugger:
650
  f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
651
  ))
652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  # ── Logit scale check ──
654
  logit_max = logits.float().abs().max().item()
655
  logit_ok = logit_max < 1000
 
510
 
511
  # Check RMSNorm fp32 upcast
512
  norm_fp32_ok = True
513
+ checked_norm_classes: set = set()
514
  for name, module in model.named_modules():
515
  cls_name = module.__class__.__name__
516
+ if "Norm" in cls_name and cls_name not in checked_norm_classes:
517
+ checked_norm_classes.add(cls_name)
518
  import inspect
519
  try:
520
+ src = inspect.getsource(type(module).forward)
521
  has_upcast = ".float()" in src or "float32" in src
522
  except (TypeError, OSError):
523
  has_upcast = True # assume ok if can't inspect
524
  if not has_upcast:
525
  norm_fp32_ok = False
526
+ print(f" πŸ”΄ {cls_name}: no fp32 upcast detected!")
 
527
  if norm_fp32_ok:
528
  print(f" βœ… Norm layers use fp32 upcast (safe)")
529
 
 
533
  ))
534
 
535
  # Check loss computation dtype
 
 
 
 
536
  if dtype in (torch.bfloat16, torch.float16):
537
  print(f" ℹ️ Best practice: compute loss in fp32 when using {dtype}")
538
  print(f" logits_fp32 = logits.float()")
 
575
  # ── Forward + Backward ──
576
  model.train()
577
  model.zero_grad(set_to_none=True)
578
+ use_scaler = dtype == torch.float16 and torch.cuda.is_available()
579
+ scaler = torch.amp.GradScaler() if use_scaler else None
580
+
581
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
582
  logits, loss = model(input_ids, targets)
583
 
 
589
  f"Loss = {loss_val:.4f}" if loss_ok else f"Loss = {loss_val} (NaN/Inf!)"
590
  ))
591
 
592
+ if scaler is not None:
593
+ scaler.scale(loss).backward()
594
+ _temp_opt = torch.optim.SGD(model.parameters(), lr=0)
595
+ scaler.unscale_(_temp_opt)
596
+ else:
597
+ loss.backward()
598
 
599
  # Remove hooks
600
  for h in hooks:
 
654
  f"{act_nan_count} layers with NaN/Inf" if not act_ok else "All layers healthy",
655
  ))
656
 
657
+ # ── Activation growth trend ──
658
+ if len(activation_stats) >= 2:
659
+ stds = [s["std"] for s in activation_stats]
660
+ if stds[0] > 1e-8:
661
+ growth_ratio = stds[-1] / stds[0]
662
+ growth_ok = growth_ratio < 10
663
+ detail = (
664
+ f"Activation std ratio (last/first): {growth_ratio:.1f}x "
665
+ f"(layer_0={stds[0]:.4f}, last={stds[-1]:.4f})"
666
+ )
667
+ results.append(_check_result("Activation growth", growth_ok, detail))
668
+ icon = "βœ…" if growth_ok else "🟑"
669
+ print(f" {icon} {detail}")
670
+ if not growth_ok:
671
+ print(f" Possible initialization or normalization issue")
672
+
673
  # ── Logit scale check ──
674
  logit_max = logits.float().abs().max().item()
675
  logit_ok = logit_max < 1000