Add NaN detection to diagnose_status classification chain
diagnose_status filtered NaN values but never classified them as
problematic, causing NaN-containing histories (e.g. mock_history_b)
to be reported as NORMAL. Add STATUS_NAN_DETECTED check before
unstable/overfitting/plateau checks.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- llm_lab/training/debugger.py +18 -3
llm_lab/training/debugger.py
CHANGED
|
@@ -41,6 +41,7 @@ STATUS_DIVERGING = "DIVERGING"
|
|
| 41 |
STATUS_PLATEAU = "PLATEAU"
|
| 42 |
STATUS_OVERFITTING = "OVERFITTING"
|
| 43 |
STATUS_UNSTABLE = "UNSTABLE"
|
|
|
|
| 44 |
|
| 45 |
# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
|
| 46 |
# (param_count, recommended_lr, batch_tokens_str)
|
|
@@ -199,7 +200,21 @@ class LossDebugger:
|
|
| 199 |
)
|
| 200 |
recommended_levels = [1, 2, 3]
|
| 201 |
|
| 202 |
-
# Check 3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
elif recent_std > 0.5 * recent_mean:
|
| 204 |
status = STATUS_UNSTABLE
|
| 205 |
severity = "yellow"
|
|
@@ -209,7 +224,7 @@ class LossDebugger:
|
|
| 209 |
)
|
| 210 |
recommended_levels = [3, 2]
|
| 211 |
|
| 212 |
-
# Check
|
| 213 |
elif val_trend == "increasing" and second_half_avg < first_half_avg:
|
| 214 |
status = STATUS_OVERFITTING
|
| 215 |
severity = "yellow"
|
|
@@ -220,7 +235,7 @@ class LossDebugger:
|
|
| 220 |
)
|
| 221 |
recommended_levels = [4]
|
| 222 |
|
| 223 |
-
# Check
|
| 224 |
elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
|
| 225 |
status = STATUS_PLATEAU
|
| 226 |
severity = "yellow"
|
|
|
|
| 41 |
STATUS_PLATEAU = "PLATEAU"
|
| 42 |
STATUS_OVERFITTING = "OVERFITTING"
|
| 43 |
STATUS_UNSTABLE = "UNSTABLE"
|
| 44 |
+
STATUS_NAN_DETECTED = "NAN_DETECTED"
|
| 45 |
|
| 46 |
# GPT-3 LR reference by model size (Brown et al. 2020, Table 2.1)
|
| 47 |
# (param_count, recommended_lr, batch_tokens_str)
|
|
|
|
| 200 |
)
|
| 201 |
recommended_levels = [1, 2, 3]
|
| 202 |
|
| 203 |
+
# Check 3: NaN detected in training loss
|
| 204 |
+
elif has_nan:
|
| 205 |
+
nan_count = len(raw_train_losses) - len(train_losses)
|
| 206 |
+
nan_idx = next(i for i, l in enumerate(raw_train_losses) if math.isnan(l))
|
| 207 |
+
status = STATUS_NAN_DETECTED
|
| 208 |
+
severity = "red"
|
| 209 |
+
details = (
|
| 210 |
+
f"NaN detected in train_loss: {nan_count} NaN values "
|
| 211 |
+
f"(first at step ~{nan_idx}). "
|
| 212 |
+
f"Before NaN: {first_loss:.4f} -> {last_loss:.4f}. "
|
| 213 |
+
f"Check gradient norms, LR schedule, and numerical precision."
|
| 214 |
+
)
|
| 215 |
+
recommended_levels = [2, 3]
|
| 216 |
+
|
| 217 |
+
# Check 4: Unstable (large spikes)
|
| 218 |
elif recent_std > 0.5 * recent_mean:
|
| 219 |
status = STATUS_UNSTABLE
|
| 220 |
severity = "yellow"
|
|
|
|
| 224 |
)
|
| 225 |
recommended_levels = [3, 2]
|
| 226 |
|
| 227 |
+
# Check 5: Overfitting
|
| 228 |
elif val_trend == "increasing" and second_half_avg < first_half_avg:
|
| 229 |
status = STATUS_OVERFITTING
|
| 230 |
severity = "yellow"
|
|
|
|
| 235 |
)
|
| 236 |
recommended_levels = [4]
|
| 237 |
|
| 238 |
+
# Check 6: Plateau
|
| 239 |
elif abs(second_half_avg - first_half_avg) < 0.05 and last_loss > _EXPECTED_TRAIN_LOSS[1]:
|
| 240 |
status = STATUS_PLATEAU
|
| 241 |
severity = "yellow"
|