omkarrr88 committed on
Commit
02e58fe
·
1 Parent(s): a3e1032

task 6 and 7 made hard

README.md CHANGED
@@ -107,6 +107,7 @@ Fields like `gradient_stats`, `data_batch_stats`, `model_mode_info`, and `code_s
107
 
108
  **Terminal** — end the episode:
109
  - `restart_run` — restart training (only available after a fix)
 
110
  - `mark_diagnosed` — submit diagnosis from 7 possible root causes
111
 
112
  Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
@@ -156,13 +157,14 @@ An agent that chases the gradient spike red herring loses 0.20 points. An agent
156
  | `task_003` | Medium | **1.00** | 0.40 |
157
  | `task_004` | Medium | **1.00** | 0.60 |
158
  | `task_005` | Hard | **0.80** | 0.38-0.55 |
159
- | `task_006` | Hard | **1.00** | 0.60-1.00 |
160
- | `task_007` | Hard | **1.00** | 0.60 |
161
- | **Average** | | **0.97** | 0.52 |
162
 
163
  **What this tells you:**
164
- - **Hard tasks are genuinely hard:** Task 5 requires thorough investigation (weight AND data inspection) for full credit. The heuristic scores 0.80 because it skips weight inspection. An LLM that falls for the gradient red herring scores 0.48 or lower.
165
  - **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
 
166
  - **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
167
  - **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
168
 
@@ -247,7 +249,7 @@ pip install pytest pytest-cov pytest-asyncio httpx websockets
247
  # Start server
248
  uvicorn server.app:app --host 0.0.0.0 --port 7860
249
 
250
- # Run tests (255 tests, 97% coverage)
251
  pytest tests/ -v --cov=ml_training_debugger
252
 
253
  # Run heuristic baseline
@@ -284,7 +286,7 @@ ml_training_debugger/
284
  models.py — Pydantic data models (Action, Observation, EpisodeState)
285
  scenarios.py — Task parameter sampling (7 tasks, deterministic per seed)
286
  pytorch_engine.py — Real PyTorch models, fault injection, gradient/weight extraction
287
- simulation.py — 20-epoch real training with parametric fallback
288
  reward_engine.py — 7-component per-step reward with context gating
289
  graders.py — Per-task holistic 0.0-1.0 scoring
290
  code_templates.py — Task 6 bug variants + 4-strategy fix validation
@@ -295,7 +297,7 @@ server/
295
  app.py — FastAPI + custom endpoints
296
  dashboard.html — Live Plotly.js diagnostic dashboard
297
 
298
- tests/ — 255 tests, 97% coverage
299
  baseline_heuristic.py — Rule-based agent (deterministic, no API key)
300
  baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
301
  ```
 
107
 
108
  **Terminal** — end the episode:
109
  - `restart_run` — restart training (only available after a fix)
110
+ - `rollback_checkpoint` — rollback to pre-fix state (only available after restart)
111
  - `mark_diagnosed` — submit diagnosis from 7 possible root causes
112
 
113
  Actions are dynamically available based on episode state: `fix_code` requires prior code inspection, `restart_run` requires a fix, `mark_diagnosed` disappears after submission.
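The gating rules above can be sketched as a pure function of episode state. This is a minimal illustration, not the environment's actual implementation: `GateState`, its `restarted` flag, and `available_gated_actions` are hypothetical names, and only the four state-dependent actions named in this section are modeled.

```python
from dataclasses import dataclass


@dataclass
class GateState:
    """Hypothetical flags; the real EpisodeState may name or track these differently."""
    code_inspected: bool = False
    fix_action_taken: bool = False
    restarted: bool = False
    diagnosis_submitted: bool = False


def available_gated_actions(state: GateState) -> list[str]:
    """Return the state-dependent actions that are currently allowed."""
    actions: list[str] = []
    if state.code_inspected:           # fix_code requires prior code inspection
        actions.append("fix_code")
    if state.fix_action_taken:         # restart_run requires a fix
        actions.append("restart_run")
    if state.restarted:                # rollback_checkpoint requires a restart
        actions.append("rollback_checkpoint")
    if not state.diagnosis_submitted:  # mark_diagnosed disappears after submission
        actions.append("mark_diagnosed")
    return actions


print(available_gated_actions(GateState()))
# prints ['mark_diagnosed']
print(available_gated_actions(GateState(code_inspected=True, fix_action_taken=True)))
# prints ['fix_code', 'restart_run', 'mark_diagnosed']
```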
 
157
  | `task_003` | Medium | **1.00** | 0.40 |
158
  | `task_004` | Medium | **1.00** | 0.60 |
159
  | `task_005` | Hard | **0.80** | 0.38-0.55 |
160
+ | `task_006` | Hard | **0.81** | 0.60-1.00 |
161
+ | `task_007` | Hard | **0.79** | 0.60 |
162
+ | **Average** | | **0.91** | 0.52 |
163
 
164
  **What this tells you:**
165
+ - **Hard tasks are genuinely hard:** All three hard tasks (5, 6, 7) require thorough investigation including weight inspection for full credit. The heuristic scores 0.79-0.81 on hard tasks because it skips weight inspection. An LLM that falls for red herrings or skips investigation scores even lower.
166
  - **Red herring traps work:** Task 5 penalizes agents that call `add_callback` after seeing normal gradients (-0.20) or `modify_config` when LR isn't the issue (-0.10). LLMs routinely fall for both traps.
167
+ - **Investigation thoroughness matters:** Tasks 6 and 7 scale fix/restart credit by whether the agent inspected model weights before acting. A correct fix, restart, and diagnosis with minimal investigation scores about 0.64-0.66 instead of 1.00.
168
  - **8B struggles on multi-step tasks:** Task 2 score of 0.05 shows small models can't maintain investigation strategy across many steps.
169
  - **The heuristic baseline is strong** because it was designed with knowledge of the task structure. An agent that doesn't know the structure has to figure it out from observations alone.
170
 
 
249
  # Start server
250
  uvicorn server.app:app --host 0.0.0.0 --port 7860
251
 
252
+ # Run tests (246 tests, 96% coverage)
253
  pytest tests/ -v --cov=ml_training_debugger
254
 
255
  # Run heuristic baseline
 
286
  models.py — Pydantic data models (Action, Observation, EpisodeState)
287
  scenarios.py — Task parameter sampling (7 tasks, deterministic per seed)
288
  pytorch_engine.py — Real PyTorch models, fault injection, gradient/weight extraction
289
+ simulation.py — 20-epoch real training with fault injection
290
  reward_engine.py — 7-component per-step reward with context gating
291
  graders.py — Per-task holistic 0.0-1.0 scoring
292
  code_templates.py — Task 6 bug variants + 4-strategy fix validation
 
297
  app.py — FastAPI + custom endpoints
298
  dashboard.html — Live Plotly.js diagnostic dashboard
299
 
300
+ tests/ — 246 tests, 96% coverage
301
  baseline_heuristic.py — Rule-based agent (deterministic, no API key)
302
  baseline_inference.py — LLM agent (Groq/Cerebras/Gemini/OpenAI)
303
  ```
ml_training_debugger/graders.py CHANGED
@@ -183,26 +183,35 @@ def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
183
 
184
  Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
185
  Hard task rewards thorough investigation before fixing.
 
186
  """
187
  score = 0.0
188
 
189
- # +0.05 for inspect_code
190
  if state.code_inspected:
191
  score += 0.05
192
-
193
- # Thoroughness bonus: inspecting other systems first rules out non-code causes
194
  if state.gradients_inspected:
195
  score += 0.05
196
  if state.model_modes_inspected:
197
  score += 0.05
 
 
 
 
198
 
199
- # Code fix credit (reduced base, bonus for thorough investigation)
200
  if _has_action(state, "fix_code") and state.fix_action_taken:
201
- score += 0.20
 
 
 
202
 
203
- # Restart credit
204
  if state.restart_after_fix:
205
- score += 0.20
 
 
 
206
 
207
  # +0.45 for correct diagnosis (must be code_bug)
208
  if _correct_diagnosis(state, scenario):
@@ -212,20 +221,46 @@ def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
212
 
213
 
214
  def grade_task_007(state: EpisodeState, scenario: ScenarioParams) -> float:
215
- """Grade Task 7 — LR Scheduler Misconfigured (medium-hard). Spec extension."""
 
 
 
 
 
216
  score = 0.0
217
 
 
218
  if state.gradients_inspected:
219
  score += 0.05
220
  if state.data_inspected:
221
  score += 0.05
 
 
 
 
 
 
222
  if _has_action(state, "modify_config"):
223
- score += 0.25
 
 
 
 
 
224
  if state.restart_after_fix:
225
- score += 0.25
 
 
 
 
 
226
  if _correct_diagnosis(state, scenario):
227
  score += 0.40
228
 
 
 
 
 
229
  return min(1.0, max(0.0, score))
230
 
231
 
 
183
 
184
  Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
185
  Hard task rewards thorough investigation before fixing.
186
+ Full credit requires ruling out non-code causes via weight inspection.
187
  """
188
  score = 0.0
189
 
190
+ # Investigation credits (+0.05 each, up to +0.25 for all 5 types)
191
  if state.code_inspected:
192
  score += 0.05
 
 
193
  if state.gradients_inspected:
194
  score += 0.05
195
  if state.model_modes_inspected:
196
  score += 0.05
197
+ if state.model_weights_inspected:
198
+ score += 0.05
199
+ if state.data_inspected:
200
+ score += 0.05
201
 
202
+ # Code fix credit scaled by investigation thoroughness
203
  if _has_action(state, "fix_code") and state.fix_action_taken:
204
+ if state.model_weights_inspected:
205
+ score += 0.15 # Thorough: ruled out weight-related causes
206
+ else:
207
+ score += 0.08 # Quick fix without full investigation
208
 
209
+ # Restart credit scaled by thoroughness
210
  if state.restart_after_fix:
211
+ if state.model_weights_inspected:
212
+ score += 0.15 # Full restart credit
213
+ else:
214
+ score += 0.08 # Partial credit
215
 
216
  # +0.45 for correct diagnosis (must be code_bug)
217
  if _correct_diagnosis(state, scenario):
 
221
 
222
 
223
  def grade_task_007(state: EpisodeState, scenario: ScenarioParams) -> float:
224
+ """Grade Task 7 — LR Scheduler Misconfigured (hard). Spec extension.
225
+
226
+ Requires thorough investigation: agents must inspect weights to rule out
227
+ weight-related issues before concluding scheduler is the root cause.
228
+ Penalizes wrong fixes (e.g. patch_data_loader when data is fine).
229
+ """
230
  score = 0.0
231
 
232
+ # Investigation credits (+0.05 each, up to +0.20 for all 4 types)
233
  if state.gradients_inspected:
234
  score += 0.05
235
  if state.data_inspected:
236
  score += 0.05
237
+ if state.model_weights_inspected:
238
+ score += 0.05
239
+ if state.model_modes_inspected:
240
+ score += 0.05
241
+
242
+ # Fix credit scaled by investigation thoroughness
243
  if _has_action(state, "modify_config"):
244
+ if state.model_weights_inspected:
245
+ score += 0.20 # Thorough: ruled out weight issues
246
+ else:
247
+ score += 0.12 # Partial: didn't check weights
248
+
249
+ # Restart credit scaled by thoroughness
250
  if state.restart_after_fix:
251
+ if state.model_weights_inspected:
252
+ score += 0.20 # Full restart credit
253
+ else:
254
+ score += 0.12 # Partial credit
255
+
256
+ # Diagnosis
257
  if _correct_diagnosis(state, scenario):
258
  score += 0.40
259
 
260
+ # Wrong-fix penalty: patch_data_loader when data is clean
261
+ if _has_action(state, "patch_data_loader"):
262
+ score -= 0.10
263
+
264
  return min(1.0, max(0.0, score))
265
 
266
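The Task 7 arithmetic in this hunk can be checked with a standalone sketch. It mirrors the point values shown above but makes several assumptions: `State7` is a hypothetical stand-in for `EpisodeState`, a plain `diagnosis_correct` boolean replaces `_correct_diagnosis(state, scenario)`, and list membership replaces `_has_action`, whose real implementation is not shown in the diff.

```python
from dataclasses import dataclass, field


@dataclass
class State7:
    """Hypothetical stand-in for EpisodeState; only the fields the grader reads."""
    gradients_inspected: bool = False
    data_inspected: bool = False
    model_weights_inspected: bool = False
    model_modes_inspected: bool = False
    restart_after_fix: bool = False
    actions_taken: list[str] = field(default_factory=list)


def grade_task_007_sketch(state: State7, diagnosis_correct: bool) -> float:
    # Investigation credits: +0.05 per inspected system, up to +0.20.
    flags = (state.gradients_inspected, state.data_inspected,
             state.model_weights_inspected, state.model_modes_inspected)
    score = 0.05 * sum(flags)

    # Fix and restart credit: 0.20 each after weight inspection, otherwise 0.12.
    step_credit = 0.20 if state.model_weights_inspected else 0.12
    if "modify_config" in state.actions_taken:
        score += step_credit
    if state.restart_after_fix:
        score += step_credit

    if diagnosis_correct:
        score += 0.40
    if "patch_data_loader" in state.actions_taken:  # wrong-fix penalty
        score -= 0.10
    return min(1.0, max(0.0, score))


thorough = State7(True, True, True, True, True, ["modify_config"])
no_weights = State7(True, True, False, True, True, ["modify_config"])
wrong_fix = State7(True, True, False, False, True,
                   ["patch_data_loader", "modify_config"])

for state in (thorough, no_weights, wrong_fix):
    print(round(grade_task_007_sketch(state, diagnosis_correct=True), 2))
# prints 1.0, then 0.79, then 0.64
```

The third case shows why the penalty matters: the same run without `patch_data_loader` would score 0.74.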
 
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [project]
2
  name = "pytorch-training-debugger"
3
- version = "1.0.0"
4
  description = "OpenEnv RL environment for PyTorch training failure debugging"
5
  requires-python = ">=3.12"
6
  dependencies = [
 
1
  [project]
2
  name = "pytorch-training-debugger"
3
+ version = "1.1.0"
4
  description = "OpenEnv RL environment for PyTorch training failure debugging"
5
  requires-python = ">=3.12"
6
  dependencies = [
server/app.py CHANGED
@@ -12,7 +12,7 @@ import sys
12
  from typing import Optional
13
 
14
  from fastapi import FastAPI
15
- from fastapi.responses import HTMLResponse, JSONResponse
16
  from openenv.core.env_server.http_server import create_app
17
 
18
  from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
@@ -77,9 +77,8 @@ _baseline_lock = asyncio.Lock()
77
 
78
 
79
  @app.get("/")
80
- def root():
81
  """Redirect root to dashboard."""
82
- from fastapi.responses import RedirectResponse
83
  return RedirectResponse(url="/dashboard")
84
 
85
 
@@ -174,7 +173,7 @@ def post_grader(session_id: Optional[str] = None) -> dict:
174
 
175
 
176
  @app.post("/baseline", response_model=None)
177
- async def post_baseline():
178
  """Trigger baseline run, return scores for all tasks.
179
 
180
  Returns 409 if already running. Uses asyncio.Lock for thread safety.
 
12
  from typing import Optional
13
 
14
  from fastapi import FastAPI
15
+ from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
16
  from openenv.core.env_server.http_server import create_app
17
 
18
  from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
 
77
 
78
 
79
  @app.get("/")
80
+ def root() -> RedirectResponse:
81
  """Redirect root to dashboard."""
 
82
  return RedirectResponse(url="/dashboard")
83
 
84
 
 
173
 
174
 
175
  @app.post("/baseline", response_model=None)
176
+ async def post_baseline() -> JSONResponse | dict:
177
  """Trigger baseline run, return scores for all tasks.
178
 
179
  Returns 409 if already running. Uses asyncio.Lock for thread safety.
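The "409 if already running" behaviour described in the docstring can be sketched with the standard library alone. This is an illustration under stated assumptions, not the endpoint's actual body: `BaselineGate` and `slow_baseline` are hypothetical, and a `(status, payload)` tuple stands in for the `JSONResponse`. Note that `asyncio.Lock` coordinates coroutines on one event loop; it is not a lock across OS threads.

```python
import asyncio
from collections.abc import Awaitable, Callable


class BaselineGate:
    """Hypothetical helper: allow one baseline run at a time, reject overlaps."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()

    async def run(self, work: Callable[[], Awaitable[dict]]) -> tuple[int, dict]:
        # No await sits between the check and the acquire, so on a single
        # event loop another coroutine cannot take the lock in between.
        if self._lock.locked():
            return 409, {"error": "baseline already running"}
        async with self._lock:
            return 200, await work()


async def demo() -> list[int]:
    gate = BaselineGate()

    async def slow_baseline() -> dict:
        await asyncio.sleep(0.05)  # placeholder for the real baseline run
        return {"status": "done"}

    first = asyncio.create_task(gate.run(slow_baseline))
    await asyncio.sleep(0)  # let the first call start and take the lock
    second_status, _ = await gate.run(slow_baseline)
    first_status, _ = await first
    return [first_status, second_status]


print(asyncio.run(demo()))
# prints [200, 409]
```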
tests/test_graders.py CHANGED
@@ -10,6 +10,7 @@ from ml_training_debugger.graders import (
10
  grade_task_001,
11
  grade_task_003,
12
  grade_task_005,
 
13
  grade_task_007,
14
  )
15
  from ml_training_debugger.models import EpisodeState
@@ -241,25 +242,168 @@ class TestGradeEpisode:
241
  assert score == 0.0
242
 
243
244
  class TestGradeTask007:
245
- def test_perfect_score(self):
 
246
  scenario = sample_scenario("task_007", seed=42)
247
  state = EpisodeState(
248
  gradients_inspected=True,
249
  data_inspected=True,
 
 
250
  fix_action_taken=True,
251
  restart_after_fix=True,
252
  diagnosis_submitted=True,
253
  actions_taken=[
254
  "inspect_gradients",
255
  "inspect_data_batch",
 
 
256
  "modify_config",
257
  "restart_run",
258
  "mark_diagnosed:scheduler_misconfigured",
259
  ],
260
  )
261
  score = grade_task_007(state, scenario)
262
- assert score == 1.0
263
 
264
  def test_wrong_diagnosis(self):
265
  scenario = sample_scenario("task_007", seed=42)
 
10
  grade_task_001,
11
  grade_task_003,
12
  grade_task_005,
13
+ grade_task_006,
14
  grade_task_007,
15
  )
16
  from ml_training_debugger.models import EpisodeState
 
242
  assert score == 0.0
243
 
244
 
245
+ class TestGradeTask006:
246
+ @pytest.fixture
247
+ def scenario_006(self):
248
+ return sample_scenario("task_006", seed=42)
249
+
250
+ def test_perfect_score_thorough(self, scenario_006):
251
+ """Thorough agent inspects ALL systems including weights — gets perfect score."""
252
+ state = EpisodeState(
253
+ code_inspected=True,
254
+ gradients_inspected=True,
255
+ model_modes_inspected=True,
256
+ model_weights_inspected=True,
257
+ data_inspected=True,
258
+ fix_action_taken=True,
259
+ restart_after_fix=True,
260
+ diagnosis_submitted=True,
261
+ actions_taken=[
262
+ "inspect_gradients",
263
+ "inspect_data_batch",
264
+ "inspect_model_weights",
265
+ "inspect_model_modes",
266
+ "inspect_code",
267
+ "fix_code",
268
+ "restart_run",
269
+ "mark_diagnosed:code_bug",
270
+ ],
271
+ )
272
+ score = grade_task_006(state, scenario_006)
273
+ assert score == pytest.approx(1.0)
274
+
275
+ def test_no_weights_inspection_partial(self, scenario_006):
276
+ """Agent that skips weight inspection gets reduced fix/restart credit."""
277
+ state = EpisodeState(
278
+ code_inspected=True,
279
+ gradients_inspected=True,
280
+ model_modes_inspected=True,
281
+ data_inspected=True,
282
+ fix_action_taken=True,
283
+ restart_after_fix=True,
284
+ diagnosis_submitted=True,
285
+ actions_taken=[
286
+ "inspect_gradients",
287
+ "inspect_data_batch",
288
+ "inspect_model_modes",
289
+ "inspect_code",
290
+ "fix_code",
291
+ "restart_run",
292
+ "mark_diagnosed:code_bug",
293
+ ],
294
+ )
295
+ score = grade_task_006(state, scenario_006)
296
+ # 0.05*4 + 0.08 + 0.08 + 0.45 = 0.81
297
+ assert score == pytest.approx(0.81)
298
+ assert score < 1.0 # Must not be perfect without weights
299
+
300
+ def test_minimal_investigation(self, scenario_006):
301
+ """Agent that only inspects code, fixes, and diagnoses."""
302
+ state = EpisodeState(
303
+ code_inspected=True,
304
+ fix_action_taken=True,
305
+ restart_after_fix=True,
306
+ diagnosis_submitted=True,
307
+ actions_taken=[
308
+ "inspect_code",
309
+ "fix_code",
310
+ "restart_run",
311
+ "mark_diagnosed:code_bug",
312
+ ],
313
+ )
314
+ score = grade_task_006(state, scenario_006)
315
+ # 0.05 + 0.08 + 0.08 + 0.45 = 0.66
316
+ assert score == pytest.approx(0.66)
317
+
318
+ def test_wrong_diagnosis(self, scenario_006):
319
+ """Submitting batchnorm_eval_mode on a code_bug task fails."""
320
+ state = EpisodeState(
321
+ code_inspected=True,
322
+ diagnosis_submitted=True,
323
+ actions_taken=[
324
+ "inspect_code",
325
+ "mark_diagnosed:batchnorm_eval_mode",
326
+ ],
327
+ )
328
+ score = grade_task_006(state, scenario_006)
329
+ assert score < 0.2 # Only gets code_inspected bonus
330
+
331
+ def test_score_in_range(self, scenario_006):
332
+ state = EpisodeState()
333
+ score = grade_task_006(state, scenario_006)
334
+ assert 0.0 <= score <= 1.0
335
+
336
+
337
  class TestGradeTask007:
338
+ def test_perfect_score_thorough(self):
339
+ """Thorough agent inspects weights — gets perfect score."""
340
  scenario = sample_scenario("task_007", seed=42)
341
  state = EpisodeState(
342
  gradients_inspected=True,
343
  data_inspected=True,
344
+ model_weights_inspected=True,
345
+ model_modes_inspected=True,
346
  fix_action_taken=True,
347
  restart_after_fix=True,
348
  diagnosis_submitted=True,
349
  actions_taken=[
350
  "inspect_gradients",
351
  "inspect_data_batch",
352
+ "inspect_model_weights",
353
+ "inspect_model_modes",
354
  "modify_config",
355
  "restart_run",
356
  "mark_diagnosed:scheduler_misconfigured",
357
  ],
358
  )
359
  score = grade_task_007(state, scenario)
360
+ assert score == pytest.approx(1.0)
361
+
362
+ def test_no_weights_partial(self):
363
+ """Agent without weight inspection gets reduced fix/restart credit."""
364
+ scenario = sample_scenario("task_007", seed=42)
365
+ state = EpisodeState(
366
+ gradients_inspected=True,
367
+ data_inspected=True,
368
+ model_modes_inspected=True,
369
+ fix_action_taken=True,
370
+ restart_after_fix=True,
371
+ diagnosis_submitted=True,
372
+ actions_taken=[
373
+ "inspect_gradients",
374
+ "inspect_data_batch",
375
+ "inspect_model_modes",
376
+ "modify_config",
377
+ "restart_run",
378
+ "mark_diagnosed:scheduler_misconfigured",
379
+ ],
380
+ )
381
+ score = grade_task_007(state, scenario)
382
+ # 0.05*3 + 0.12 + 0.12 + 0.40 = 0.79
383
+ assert score == pytest.approx(0.79)
384
+ assert score < 1.0
385
+
386
+ def test_wrong_fix_penalty(self):
387
+ """Agent that patches data loader (wrong fix) gets penalized."""
388
+ scenario = sample_scenario("task_007", seed=42)
389
+ state = EpisodeState(
390
+ gradients_inspected=True,
391
+ data_inspected=True,
392
+ fix_action_taken=True,
393
+ restart_after_fix=True,
394
+ diagnosis_submitted=True,
395
+ actions_taken=[
396
+ "inspect_gradients",
397
+ "inspect_data_batch",
398
+ "patch_data_loader",
399
+ "modify_config",
400
+ "restart_run",
401
+ "mark_diagnosed:scheduler_misconfigured",
402
+ ],
403
+ )
404
+ score = grade_task_007(state, scenario)
405
  # 0.05*2 + 0.12 + 0.12 + 0.40 - 0.10 = 0.64 (0.74 without the penalty)
406
  assert score == pytest.approx(0.64)
407
 
408
  def test_wrong_diagnosis(self):
409
  scenario = sample_scenario("task_007", seed=42)