"""Per-task grader functions that score a finished episode.

Each grader returns a normalized score in [0.0, 1.0] computed from the
EpisodeState at episode end. Graders are kept separate from reward_engine.py:
they evaluate the episode holistically and are NOT a sum of step rewards.
"""
|
|
| from __future__ import annotations |
|
|
| import torch |
|
|
| from ml_training_debugger.models import EpisodeState |
| from ml_training_debugger.scenarios import ScenarioParams |
|
|
# Action-type names treated as fix attempts (judging by the constant's name).
# None of the graders in this module reference it directly; presumably it is
# consumed elsewhere — confirm before changing its contents.
FIX_ACTIONS = frozenset(
    [
        "add_callback",
        "fix_code",
        "fix_model_mode",
        "modify_config",
        "patch_data_loader",
        "replace_optimizer",
    ]
)
|
|
|
|
| def _has_action(state: EpisodeState, action_type: str) -> bool: |
| return action_type in state.actions_taken |
|
|
|
|
def _correct_diagnosis(state: EpisodeState, scenario: ScenarioParams) -> bool:
    """Return True if the episode's latest diagnosis matches the root cause.

    Requires ``state.diagnosis_submitted`` to be truthy. The submitted label
    is taken from the most recent ``mark_diagnosed:<label>`` action (parsed
    by ``_submitted_diagnosis``) and compared for exact equality with
    ``scenario.root_cause.value``.

    Returns:
        False when no diagnosis was submitted or no ``mark_diagnosed:``
        action is recorded; otherwise whether the labels are equal.
    """
    if not state.diagnosis_submitted:
        return False
    # Reuse the shared parser so both helpers agree on the action format.
    submitted = _submitted_diagnosis(state)
    return submitted is not None and submitted == scenario.root_cause.value
|
|
|
|
| def _submitted_diagnosis(state: EpisodeState) -> str | None: |
| for action_str in reversed(state.actions_taken): |
| if action_str.startswith("mark_diagnosed:"): |
| return action_str.split(":", 1)[1] |
| return None |
|
|
|
|
def grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 1 — Exploding Gradients (easy).

    Weighted checklist (weights sum to 1.0), clamped to [0.0, 1.0]:
      * 0.05 — gradients inspected
      * 0.20 — ``modify_config`` action taken
      * 0.35 — ``restart_after_fix`` set
      * 0.40 — correct diagnosis submitted
    """
    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.gradients_inspected, 0.05),
        (_has_action(state, "modify_config"), 0.20),
        (state.restart_after_fix, 0.35),
        (_correct_diagnosis(state, scenario), 0.40),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_002(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 2 — Vanishing Gradients (easy).

    Weighted checklist (weights sum to 1.0), clamped to [0.0, 1.0]:
      * 0.05 — gradients inspected
      * 0.20 — ``modify_config`` action taken
      * 0.35 — ``restart_after_fix`` set
      * 0.40 — correct diagnosis submitted
    """
    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.gradients_inspected, 0.05),
        (_has_action(state, "modify_config"), 0.20),
        (state.restart_after_fix, 0.35),
        (_correct_diagnosis(state, scenario), 0.40),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_003(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 3 — Silent Data Leakage (medium).

    Weighted checklist (weights sum to 1.0), clamped to [0.0, 1.0]:
      * 0.05 — data inspected
      * 0.30 — ``patch_data_loader`` action taken
      * 0.30 — ``restart_after_fix`` set
      * 0.35 — correct diagnosis submitted
    """
    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.data_inspected, 0.05),
        (_has_action(state, "patch_data_loader"), 0.30),
        (state.restart_after_fix, 0.30),
        (_correct_diagnosis(state, scenario), 0.35),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_004(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 4 — Overfitting (medium).

    Weighted checklist (weights sum to 1.0), clamped to [0.0, 1.0]:
      * 0.05 — data inspected
      * 0.25 — ``modify_config`` or ``add_callback`` action taken
      * 0.30 — ``restart_after_fix`` set
      * 0.40 — correct diagnosis submitted
    """
    # Either fix is accepted (e.g. stronger regularization or early stopping).
    fixed = _has_action(state, "modify_config") or _has_action(state, "add_callback")

    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.data_inspected, 0.05),
        (fixed, 0.25),
        (state.restart_after_fix, 0.30),
        (_correct_diagnosis(state, scenario), 0.40),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_005(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 5 — BatchNorm Eval Mode (hard).

    Hard task requires thorough investigation for full credit.
    Full credit requires inspecting gradients, weights, data, AND modes.

    Checklist, accumulated in this order and clamped to [0.0, 1.0]:
      * +0.05 each for inspecting gradients, model modes and data.
      * -0.20 context-gated penalty: ``add_callback`` was used while the
        gradients were inspected and found normal (only the final flags are
        checked, not the order of the actions).
      * -0.10 wrong-fix penalty: ``modify_config`` used (LR isn't the problem).
      * ``fix_model_mode``: +0.25 if both weights and data were inspected,
        +0.15 if only one of them was, +0.08 otherwise.
      * ``restart_after_fix``: +0.20 if weights were inspected, else +0.10.
      * +0.40 for the correct diagnosis.
    """
    weights_seen = state.model_weights_inspected
    data_seen = state.data_inspected

    # Credit for the right fix scales with how much was ruled out first.
    if weights_seen and data_seen:
        mode_fix_credit = 0.25
    elif weights_seen or data_seen:
        mode_fix_credit = 0.15
    else:
        mode_fix_credit = 0.08

    callback_despite_normal_grads = (
        _has_action(state, "add_callback")
        and state.gradients_inspected
        and state.gradients_were_normal
    )

    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.gradients_inspected, 0.05),
        (state.model_modes_inspected, 0.05),
        (data_seen, 0.05),
        (callback_despite_normal_grads, -0.20),
        (_has_action(state, "modify_config"), -0.10),
        (_has_action(state, "fix_model_mode"), mode_fix_credit),
        (state.restart_after_fix, 0.20 if weights_seen else 0.10),
        (_correct_diagnosis(state, scenario), 0.40),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 6 — PyTorch Code Bug (hard).

    Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
    Hard task rewards thorough investigation before fixing.
    Full credit requires ruling out non-code causes via weight inspection.

    Checklist, accumulated in this order and clamped to [0.0, 1.0]:
      * +0.05 each for inspecting code, gradients, model modes, weights, data.
      * ``fix_code`` taken and ``fix_action_taken`` set: +0.15 if weights
        were inspected, else +0.08.
      * ``restart_after_fix``: +0.15 if weights were inspected, else +0.08.
      * +0.45 for the correct diagnosis.
    """
    weights_seen = state.model_weights_inspected
    # Fix and restart are both worth more once weights have been ruled out.
    tiered_credit = 0.15 if weights_seen else 0.08

    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.code_inspected, 0.05),
        (state.gradients_inspected, 0.05),
        (state.model_modes_inspected, 0.05),
        (weights_seen, 0.05),
        (state.data_inspected, 0.05),
        (_has_action(state, "fix_code") and state.fix_action_taken, tiered_credit),
        (state.restart_after_fix, tiered_credit),
        (_correct_diagnosis(state, scenario), 0.45),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
def grade_task_007(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade Task 7 — LR Scheduler Misconfigured (hard).

    Requires thorough investigation: agents must inspect weights to rule out
    weight-related issues before concluding scheduler is the root cause.
    Penalizes wrong fixes (e.g. patch_data_loader when data is fine).

    Checklist, accumulated in this order and clamped to [0.0, 1.0]:
      * +0.05 each for inspecting gradients, data, weights and model modes.
      * ``modify_config``: +0.20 if weights were inspected, else +0.12.
      * ``restart_after_fix``: +0.20 if weights were inspected, else +0.12.
      * +0.40 for the correct diagnosis.
      * -0.10 wrong-fix penalty if ``patch_data_loader`` was used.
    """
    weights_seen = state.model_weights_inspected
    # Fix and restart are both worth more once weights have been ruled out.
    tiered_credit = 0.20 if weights_seen else 0.12

    # Entries are accumulated in this order so float results are bit-stable.
    checklist = (
        (state.gradients_inspected, 0.05),
        (state.data_inspected, 0.05),
        (weights_seen, 0.05),
        (state.model_modes_inspected, 0.05),
        (_has_action(state, "modify_config"), tiered_credit),
        (state.restart_after_fix, tiered_credit),
        (_correct_diagnosis(state, scenario), 0.40),
        (_has_action(state, "patch_data_loader"), -0.10),
    )
    score = 0.0
    for achieved, weight in checklist:
        if achieved:
            score += weight
    return min(1.0, max(0.0, score))
|
|
|
|
| |
# Registry mapping each task id to its grader function; looked up by
# grade_episode.
GRADERS = dict(
    task_001=grade_task_001,
    task_002=grade_task_002,
    task_003=grade_task_003,
    task_004=grade_task_004,
    task_005=grade_task_005,
    task_006=grade_task_006,
    task_007=grade_task_007,
)
|
|
|
|
def grade_episode(task_id: str, state: EpisodeState, scenario: ScenarioParams) -> float:
    """Grade a completed episode with the grader registered for *task_id*.

    The grader's score is clamped to [0.01, 0.99], so the result always lies
    strictly inside (0.0, 1.0). An unknown *task_id* yields the floor, 0.01.
    """
    grader = GRADERS.get(task_id)
    raw_score = 0.01 if grader is None else grader(state, scenario)
    # Clamp order (min first, then max) kept identical to the original.
    return max(0.01, min(0.99, raw_score))
|
|