Vjeong Claude Sonnet 4.6 committed on
Commit
6b7ca0e
·
1 Parent(s): 1451cc6

Scale overfit test LR and steps by model size in LossDebugger

Browse files

Adjust Check 4 (single-batch overfit) parameters based on model size:
- Large (>500M): lr=1e-4, 400 steps; medium (>50M): lr=3e-4, 300 steps; small: lr=1e-3, 200 steps
- Add gradient clipping (norm=1.0) to prevent instability in large models
- Relax pass threshold from loss < 0.1 to min_loss < 0.5 to reduce false negatives
- Track min loss across all steps for a more robust pass/fail signal

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. llm_lab/training/debugger.py +17 -6
llm_lab/training/debugger.py CHANGED
@@ -400,30 +400,41 @@ class LossDebugger:
400
  torch.cuda.empty_cache()
401
 
402
  # ── Check 4: Single-batch overfit test ──
403
- print(f"\n ⏳ Check 4: Single-batch overfit test (200 steps)...")
 
 
 
 
 
 
 
 
404
  overfit_model = copy.deepcopy(model)
405
  overfit_model.to(device)
406
  overfit_model.train()
407
- overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=1e-3)
408
  single_input = input_ids[:1].to(device) # single sample
409
  single_target = targets[:1].to(device)
 
410
 
411
  overfit_losses = []
412
- for step in range(200):
413
  overfit_optimizer.zero_grad()
414
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
415
  _, loss = overfit_model(single_input, single_target)
416
  loss.backward()
 
417
  overfit_optimizer.step()
418
  overfit_losses.append(loss.item())
419
- if (step + 1) % 50 == 0:
420
  print(f" Step {step + 1}: Loss = {loss.item():.4f}")
421
 
422
  final_overfit_loss = overfit_losses[-1]
423
- overfit_ok = final_overfit_loss < 0.1
 
424
  detail = (
425
  f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
426
- f"(target < 0.1)"
427
  )
428
  results.append(_check_result("Single-batch overfit", overfit_ok, detail))
429
 icon = "✅" if overfit_ok else "❌"
 
400
  torch.cuda.empty_cache()
401
 
402
  # ── Check 4: Single-batch overfit test ──
403
+ # Scale LR and steps based on model size to avoid instability
404
+ num_params = sum(p.numel() for p in model.parameters())
405
+ if num_params > 500e6:
406
+ overfit_lr, overfit_steps = 1e-4, 400
407
+ elif num_params > 50e6:
408
+ overfit_lr, overfit_steps = 3e-4, 300
409
+ else:
410
+ overfit_lr, overfit_steps = 1e-3, 200
411
+ print(f"\n ⏳ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...")
412
  overfit_model = copy.deepcopy(model)
413
  overfit_model.to(device)
414
  overfit_model.train()
415
+ overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr)
416
  single_input = input_ids[:1].to(device) # single sample
417
  single_target = targets[:1].to(device)
418
+ log_interval = max(overfit_steps // 4, 1)
419
 
420
  overfit_losses = []
421
+ for step in range(overfit_steps):
422
  overfit_optimizer.zero_grad()
423
  with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
424
  _, loss = overfit_model(single_input, single_target)
425
  loss.backward()
426
+ torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0)
427
  overfit_optimizer.step()
428
  overfit_losses.append(loss.item())
429
+ if (step + 1) % log_interval == 0:
430
  print(f" Step {step + 1}: Loss = {loss.item():.4f}")
431
 
432
  final_overfit_loss = overfit_losses[-1]
433
+ min_overfit_loss = min(overfit_losses)
434
+ overfit_ok = min_overfit_loss < 0.5
435
  detail = (
436
  f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
437
+ f"(min={min_overfit_loss:.4f}, target < 0.5)"
438
  )
439
  results.append(_check_result("Single-batch overfit", overfit_ok, detail))
440
 icon = "✅" if overfit_ok else "❌"