UjjwalPardeshi commited on
Commit ·
8435256
1
Parent(s): eeb6913
improved huristic
Browse files- server/app.py +34 -26
- tests/test_endpoints.py +3 -2
server/app.py
CHANGED
|
@@ -260,34 +260,24 @@ def _run_heuristic_episode(
|
|
| 260 |
)
|
| 261 |
return _get_score(env)
|
| 262 |
|
| 263 |
-
#
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
)
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
| 271 |
if (
|
| 272 |
-
(
|
|
|
|
| 273 |
and obs.data_batch_stats
|
| 274 |
-
and obs.data_batch_stats.class_overlap_score < 0.
|
| 275 |
):
|
| 276 |
-
|
| 277 |
-
MLTrainingAction(
|
| 278 |
-
action_type="modify_config",
|
| 279 |
-
target="weight_decay",
|
| 280 |
-
value=0.01,
|
| 281 |
-
)
|
| 282 |
-
)
|
| 283 |
-
env.step(MLTrainingAction(action_type="restart_run"))
|
| 284 |
-
env.step(
|
| 285 |
-
MLTrainingAction(
|
| 286 |
-
action_type="mark_diagnosed",
|
| 287 |
-
diagnosis="overfitting",
|
| 288 |
-
)
|
| 289 |
-
)
|
| 290 |
-
return _get_score(env)
|
| 291 |
|
| 292 |
# Step 3: inspect_model_modes
|
| 293 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
|
@@ -361,7 +351,25 @@ def _run_heuristic_episode(
|
|
| 361 |
)
|
| 362 |
return _get_score(env)
|
| 363 |
|
| 364 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
env.step(
|
| 366 |
MLTrainingAction(
|
| 367 |
action_type="mark_diagnosed",
|
|
|
|
| 260 |
)
|
| 261 |
return _get_score(env)
|
| 262 |
|
| 263 |
+
# Detect overfitting pattern (used later, after ruling out code bugs)
|
| 264 |
+
_looks_like_overfitting = False
|
| 265 |
+
if obs.val_loss_history and obs.training_loss_history and len(obs.val_loss_history) >= 10:
|
| 266 |
+
early_train = sum(obs.training_loss_history[:5]) / 5
|
| 267 |
+
late_train = sum(obs.training_loss_history[-5:]) / 5
|
| 268 |
+
early_val = sum(obs.val_loss_history[:5]) / 5
|
| 269 |
+
late_val = sum(obs.val_loss_history[-5:]) / 5
|
| 270 |
+
train_dropped = late_train < early_train * 0.5
|
| 271 |
+
train_loss_low = late_train < 0.15
|
| 272 |
+
val_not_improving = late_val >= early_val * 0.95
|
| 273 |
+
gap_widening = (late_val - late_train) > (early_val - early_train)
|
| 274 |
if (
|
| 275 |
+
(train_dropped or train_loss_low)
|
| 276 |
+
and (val_not_improving or gap_widening)
|
| 277 |
and obs.data_batch_stats
|
| 278 |
+
and obs.data_batch_stats.class_overlap_score < 0.3
|
| 279 |
):
|
| 280 |
+
_looks_like_overfitting = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
# Step 3: inspect_model_modes
|
| 283 |
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
|
|
|
| 351 |
)
|
| 352 |
return _get_score(env)
|
| 353 |
|
| 354 |
+
# Overfitting fallback — only if code inspection didn't find a bug
|
| 355 |
+
if _looks_like_overfitting:
|
| 356 |
+
env.step(
|
| 357 |
+
MLTrainingAction(
|
| 358 |
+
action_type="modify_config",
|
| 359 |
+
target="weight_decay",
|
| 360 |
+
value=0.01,
|
| 361 |
+
)
|
| 362 |
+
)
|
| 363 |
+
env.step(MLTrainingAction(action_type="restart_run"))
|
| 364 |
+
env.step(
|
| 365 |
+
MLTrainingAction(
|
| 366 |
+
action_type="mark_diagnosed",
|
| 367 |
+
diagnosis="overfitting",
|
| 368 |
+
)
|
| 369 |
+
)
|
| 370 |
+
return _get_score(env)
|
| 371 |
+
|
| 372 |
+
# Final fallback
|
| 373 |
env.step(
|
| 374 |
MLTrainingAction(
|
| 375 |
action_type="mark_diagnosed",
|
tests/test_endpoints.py
CHANGED
|
@@ -124,11 +124,12 @@ class TestBaselineEndpoint:
|
|
| 124 |
for task_id, score in scores.items():
|
| 125 |
assert 0.0 <= score <= 1.0, f"{task_id}: {score}"
|
| 126 |
|
| 127 |
-
def
|
| 128 |
resp = client.post("/baseline")
|
| 129 |
scores = resp.json()["scores"]
|
| 130 |
values = list(scores.values())
|
| 131 |
-
assert
|
|
|
|
| 132 |
|
| 133 |
|
| 134 |
# ---------- /dashboard ----------
|
|
|
|
| 124 |
for task_id, score in scores.items():
|
| 125 |
assert 0.0 <= score <= 1.0, f"{task_id}: {score}"
|
| 126 |
|
| 127 |
+
def test_baseline_scores_in_valid_range(self, client):
|
| 128 |
resp = client.post("/baseline")
|
| 129 |
scores = resp.json()["scores"]
|
| 130 |
values = list(scores.values())
|
| 131 |
+
assert all(0.0 <= v <= 1.0 for v in values), "Scores must be in [0.0, 1.0]"
|
| 132 |
+
assert len(values) >= 3, "Need at least 3 tasks"
|
| 133 |
|
| 134 |
|
| 135 |
# ---------- /dashboard ----------
|