Vjeong Claude Opus 4.6 committed on
Commit
8313ca8
·
1 Parent(s): d789de8

Add NaN detection to diagnose_status classification chain

Browse files

diagnose_status filtered NaN values but never classified them as
problematic, causing NaN-containing histories (e.g. mock_history_b)
to be reported as NORMAL. Add STATUS_NAN_DETECTED check before
unstable/overfitting/plateau checks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. llm_lab/training/debugger.py +18 -3
llm_lab/training/debugger.py CHANGED
@@ -41,6 +41,7 @@ STATUS_DIVERGING = "DIVERGING"
41
  STATUS_PLATEAU = "PLATEAU"
42
  STATUS_OVERFITTING = "OVERFITTING"
43
  STATUS_UNSTABLE = "UNSTABLE"
 
44
 
45
  # GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
46
  # (param_count, recommended_lr, batch_tokens_str)
@@ -199,7 +200,21 @@ class LossDebugger:
199
  )
200
  recommended_levels = [1, 2, 3]
201
 
202
- # Check 3: Unstable (large spikes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  elif recent_std > 0.5 * recent_mean:
204
  status = STATUS_UNSTABLE
205
  severity = "yellow"
@@ -209,7 +224,7 @@ class LossDebugger:
209
  )
210
  recommended_levels = [3, 2]
211
 
212
- # Check 4: Overfitting
213
  elif val_trend == "increasing" and second_half_avg < first_half_avg:
214
  status = STATUS_OVERFITTING
215
  severity = "yellow"
@@ -220,7 +235,7 @@ class LossDebugger:
220
  )
221
  recommended_levels = [4]
222
 
223
- # Check 5: Plateau
224
  elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
225
  status = STATUS_PLATEAU
226
  severity = "yellow"
 
41
  STATUS_PLATEAU = "PLATEAU"
42
  STATUS_OVERFITTING = "OVERFITTING"
43
  STATUS_UNSTABLE = "UNSTABLE"
44
+ STATUS_NAN_DETECTED = "NAN_DETECTED"
45
 
46
  # GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
47
  # (param_count, recommended_lr, batch_tokens_str)
 
200
  )
201
  recommended_levels = [1, 2, 3]
202
 
203
+ # Check 3: NaN detected in training loss
204
+ elif has_nan:
205
+ nan_count = len(raw_train_losses) - len(train_losses)
206
+ nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l))
207
+ status = STATUS_NAN_DETECTED
208
+ severity = "red"
209
+ details = (
210
+ f"NaN detected in train_loss: {nan_count} NaN values "
211
+ f"(first at step ~{nan_idx}). "
212
+ f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. "
213
+ f"Check gradient norms, LR schedule, and numerical precision."
214
+ )
215
+ recommended_levels = [2, 3]
216
+
217
+ # Check 4: Unstable (large spikes)
218
  elif recent_std > 0.5 * recent_mean:
219
  status = STATUS_UNSTABLE
220
  severity = "yellow"
 
224
  )
225
  recommended_levels = [3, 2]
226
 
227
+ # Check 5: Overfitting
228
  elif val_trend == "increasing" and second_half_avg < first_half_avg:
229
  status = STATUS_OVERFITTING
230
  severity = "yellow"
 
235
  )
236
  recommended_levels = [4]
237
 
238
+ # Check 6: Plateau
239
  elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
240
  status = STATUS_PLATEAU
241
  severity = "yellow"