Vjeong Claude Opus 4.6 commited on
Commit
fb048e4
·
1 Parent(s): a671953

Fix batch size diagnostic: widen window and list multiple causes

Browse files

- Increase loss CV window from 20 to 50 steps for statistical stability
- Replace single-cause diagnosis ("batch too small") with multi-cause
guidance (LR, batch size, data quality)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. llm_lab/training/debugger.py +7 -7
llm_lab/training/debugger.py CHANGED
@@ -754,21 +754,21 @@ class LossDebugger:
754
  print("\n Batch Size Analysis:")
755
  print(f" Effective batch: {config.effective_batch_size}")
756
 
757
- if len(train_losses) >= 20:
758
- recent_losses = train_losses[-20:]
759
  loss_mean = sum(recent_losses) / len(recent_losses)
760
  loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
761
  loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)
762
 
763
- print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation)")
764
 
765
  if loss_cv > 0.1:
766
  findings.append({
767
- "issue": "Batch size may be too small",
768
- "evidence": f"Loss CV = {loss_cv:.4f} (high variance)",
769
- "action": "Increase gradient_accumulation_steps",
770
  })
771
- print(f" 🟡 High loss variance → batch may be too small")
772
  else:
773
  print(f" ✅ Loss variance is acceptable")
774
 
 
754
  print("\n Batch Size Analysis:")
755
  print(f" Effective batch: {config.effective_batch_size}")
756
 
757
+ if len(train_losses) >= 50:
758
+ recent_losses = train_losses[-50:]
759
  loss_mean = sum(recent_losses) / len(recent_losses)
760
  loss_var = sum((x - loss_mean) ** 2 for x in recent_losses) / len(recent_losses)
761
  loss_cv = (loss_var ** 0.5) / max(loss_mean, 1e-8)
762
 
763
+ print(f" Recent loss CV: {loss_cv:.4f} (coefficient of variation, last 50 steps)")
764
 
765
  if loss_cv > 0.1:
766
  findings.append({
767
+ "issue": "Training loss has high variance",
768
+ "evidence": f"Loss CV = {loss_cv:.4f} over last 50 steps",
769
+ "action": "Check: (1) LR may be too high, (2) increase gradient_accumulation_steps, (3) inspect data quality",
770
  })
771
+ print(f" 🟡 High loss variance → check LR, batch size, or data quality")
772
  else:
773
  print(f" ✅ Loss variance is acceptable")
774