Scale overfit test LR and steps by model size in LossDebugger
Adjust Check 4 (single-batch overfit) parameters based on model size:
- Large (>500M): lr=1e-4, 400 steps; medium (>50M): lr=3e-4, 300 steps; small: lr=1e-3, 200 steps
- Add gradient clipping (norm=1.0) to prevent instability in large models
- Relax pass threshold from loss < 0.1 to min_loss < 0.5 to reduce false negatives
- Track min loss across all steps for a more robust pass/fail signal
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- llm_lab/training/debugger.py +17 -6
llm_lab/training/debugger.py
CHANGED
|
@@ -400,30 +400,41 @@ class LossDebugger:
|
|
| 400 |
torch.cuda.empty_cache()
|
| 401 |
|
| 402 |
# ββ Check 4: Single-batch overfit test ββ
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
overfit_model = copy.deepcopy(model)
|
| 405 |
overfit_model.to(device)
|
| 406 |
overfit_model.train()
|
| 407 |
-
overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=
|
| 408 |
single_input = input_ids[:1].to(device) # single sample
|
| 409 |
single_target = targets[:1].to(device)
|
|
|
|
| 410 |
|
| 411 |
overfit_losses = []
|
| 412 |
-
for step in range(
|
| 413 |
overfit_optimizer.zero_grad()
|
| 414 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 415 |
_, loss = overfit_model(single_input, single_target)
|
| 416 |
loss.backward()
|
|
|
|
| 417 |
overfit_optimizer.step()
|
| 418 |
overfit_losses.append(loss.item())
|
| 419 |
-
if (step + 1) %
|
| 420 |
print(f" Step {step + 1}: Loss = {loss.item():.4f}")
|
| 421 |
|
| 422 |
final_overfit_loss = overfit_losses[-1]
|
| 423 |
-
|
|
|
|
| 424 |
detail = (
|
| 425 |
f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
|
| 426 |
-
f"(target < 0.
|
| 427 |
)
|
| 428 |
results.append(_check_result("Single-batch overfit", overfit_ok, detail))
|
| 429 |
icon = "β
" if overfit_ok else "β"
|
|
|
|
| 400 |
torch.cuda.empty_cache()
|
| 401 |
|
| 402 |
# ββ Check 4: Single-batch overfit test ββ
|
| 403 |
+
# Scale LR and steps based on model size to avoid instability
|
| 404 |
+
num_params = sum(p.numel() for p in model.parameters())
|
| 405 |
+
if num_params > 500e6:
|
| 406 |
+
overfit_lr, overfit_steps = 1e-4, 400
|
| 407 |
+
elif num_params > 50e6:
|
| 408 |
+
overfit_lr, overfit_steps = 3e-4, 300
|
| 409 |
+
else:
|
| 410 |
+
overfit_lr, overfit_steps = 1e-3, 200
|
| 411 |
+
print(f"\n β³ Check 4: Single-batch overfit test ({overfit_steps} steps, lr={overfit_lr:.0e})...")
|
| 412 |
overfit_model = copy.deepcopy(model)
|
| 413 |
overfit_model.to(device)
|
| 414 |
overfit_model.train()
|
| 415 |
+
overfit_optimizer = torch.optim.AdamW(overfit_model.parameters(), lr=overfit_lr)
|
| 416 |
single_input = input_ids[:1].to(device) # single sample
|
| 417 |
single_target = targets[:1].to(device)
|
| 418 |
+
log_interval = max(overfit_steps // 4, 1)
|
| 419 |
|
| 420 |
overfit_losses = []
|
| 421 |
+
for step in range(overfit_steps):
|
| 422 |
overfit_optimizer.zero_grad()
|
| 423 |
with torch.amp.autocast(device_type="cuda", dtype=dtype, enabled=(dtype != torch.float32)):
|
| 424 |
_, loss = overfit_model(single_input, single_target)
|
| 425 |
loss.backward()
|
| 426 |
+
torch.nn.utils.clip_grad_norm_(overfit_model.parameters(), 1.0)
|
| 427 |
overfit_optimizer.step()
|
| 428 |
overfit_losses.append(loss.item())
|
| 429 |
+
if (step + 1) % log_interval == 0:
|
| 430 |
print(f" Step {step + 1}: Loss = {loss.item():.4f}")
|
| 431 |
|
| 432 |
final_overfit_loss = overfit_losses[-1]
|
| 433 |
+
min_overfit_loss = min(overfit_losses)
|
| 434 |
+
overfit_ok = min_overfit_loss < 0.5
|
| 435 |
detail = (
|
| 436 |
f"Single-batch overfit: {overfit_losses[0]:.4f} -> {final_overfit_loss:.4f} "
|
| 437 |
+
f"(min={min_overfit_loss:.4f}, target < 0.5)"
|
| 438 |
)
|
| 439 |
results.append(_check_result("Single-batch overfit", overfit_ok, detail))
|
| 440 |
icon = "β
" if overfit_ok else "β"
|