Spaces:
Running
Running
| import sys | |
| import os | |
| from uuid import uuid4 | |
| from typing import Optional | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import DebugAction, DebugObservation, DebugState | |
| from bug_generator import ( | |
| get_scenario, | |
| execute_tool, | |
| BugScenario, | |
| ALL_TASKS, | |
| AVAILABLE_TOOLS, | |
| TASK_SHAPE_MISMATCH, | |
| TASK_TRAINING_COLLAPSE, | |
| TASK_DATA_LEAKAGE, | |
| TASK_WRONG_DEVICE, | |
| TASK_GRADIENT_NOT_ZEROED, | |
| TASK_MISSING_EVAL_MODE, | |
| TASK_COMPOUND_SHAPE_DEVICE, | |
| TASK_COMPOUND_LEAKAGE_EVAL, | |
| ) | |
| from grader import grade, GradeResult | |
| from adversarial_scheduler import AdversarialScheduler | |
| MAX_STEPS = 5 | |
| SUCCESS_THRESHOLD = 0.95 | |
| # Module-level session store — shared across all instances | |
| _SESSION_STORE: dict = {} | |
| def _efficiency_multiplier(steps_used: int, total_steps: int) -> float: | |
| """ | |
| Reward agents that fix bugs efficiently. | |
| steps_used = number of steps taken when fix was submitted (1-indexed). | |
| """ | |
| if steps_used <= 2: | |
| return 1.2 | |
| elif steps_used <= 3: | |
| return 1.1 | |
| else: | |
| return 1.0 | |
class MlDebugEnvEnvironment(Environment):
    """
    ML Debugging Environment — 8 tasks, easy → expert.

    Partially observable: agent sees only a minimal alert on reset().
    Must use tool calls (inspect actions) to gather information before fixing.

    Episode structure:
    - reset() → minimal alert, available tools, step budget
    - step(action_type="inspect", tool_name=X) → tool output (costs 1 step)
    - step(action_type="fix", bug_type=X, ...) → grader score (costs 1 step)
    - Max 5 steps total across all inspect + fix actions
    - The episode ends when a fix scores >= SUCCESS_THRESHOLD or the step
      budget is exhausted; further step() calls raise RuntimeError until
      reset() is called.

    Efficiency bonus (every boosted score is capped at 0.99):
    - Fix submitted within 2 total steps → score × 1.2
    - Fix submitted at step 3 → score × 1.1
    - Fix at step 4-5 → score × 1.0

    Single-bug tasks (6):
        shape_mismatch, training_collapse, wrong_device,
        gradient_not_zeroed, data_leakage, missing_eval_mode
    Compound tasks — TWO bugs per script (2):
        compound_shape_device, compound_leakage_eval
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Cap on episodes kept in the module-level session store. The oldest
    # entries are evicted first so a long-running server does not grow
    # without bound.
    _MAX_STORED_SESSIONS = 1024

    def __init__(self, task_id: Optional[str] = None):
        """
        Args:
            task_id: Task used for every episode; when None, reset() falls
                back to its own task_id argument or the scheduler's choice.
        """
        super().__init__()
        self._task_id: Optional[str] = task_id
        self._current_scenario: Optional[BugScenario] = None
        self._episode_id: Optional[str] = None
        # Placeholder state until reset() creates the episode's real state.
        self._state = DebugState(
            episode_id=None,
            step_count=0,
            task_id="",
            max_steps=MAX_STEPS,
            current_score=0.0,
            attempts=0,
            tools_used=[],
            fix_submitted=False,
        )
        self._episode_count = 0
        self._scheduler = AdversarialScheduler(ALL_TASKS)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs,
    ) -> DebugObservation:
        """
        Start a new episode and return its initial, minimal observation.

        Args:
            seed: Scenario seed; when None the scheduler's next_seed() is used.
            episode_id: Id to register the episode under; a UUID4 string is
                generated when omitted.
            task_id: Task to run; falls back to the constructor's task_id,
                then to the scheduler's next_task().
            **kwargs: Ignored.

        Returns:
            Observation with the alert, tool list and full step budget only.
        """
        active_task = task_id or self._task_id or self._scheduler.next_task()
        effective_seed = (
            seed if seed is not None else self._scheduler.next_seed(active_task)
        )
        scenario = get_scenario(active_task, seed=effective_seed)
        eid = episode_id or str(uuid4())
        state = DebugState(
            episode_id=eid,
            step_count=0,
            task_id=active_task,
            max_steps=MAX_STEPS,
            current_score=0.0,
            attempts=0,
            tools_used=[],
            fix_submitted=False,
        )
        self._current_scenario = scenario
        self._state = state
        self._episode_id = eid
        self._episode_count += 1
        self._store_session(eid, scenario, state)
        return DebugObservation(
            task_id=active_task,
            alert=scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=MAX_STEPS,
            step_number=0,
            num_bugs=scenario.num_bugs,
            action_type=None,
            tool_name=None,
            tool_result=None,
            grader_score=None,
            grader_feedback=None,
            execution_result=None,
            done=False,
            reward=None,
            efficiency_multiplier=None,
        )

    def _store_session(
        self, eid: str, scenario: BugScenario, state: DebugState
    ) -> None:
        """Register an episode in the shared store as its most recent entry."""
        # Assigning to an existing key keeps its old position in dict order,
        # so pop first: step()'s "latest session" fallback relies on order.
        _SESSION_STORE.pop(eid, None)
        _SESSION_STORE[eid] = {"scenario": scenario, "state": state}
        excess = len(_SESSION_STORE) - self._MAX_STORED_SESSIONS
        if excess > 0:
            # Dicts preserve insertion order, so the first keys are the oldest.
            for stale_id in list(_SESSION_STORE)[:excess]:
                _SESSION_STORE.pop(stale_id, None)

    def _restore_session(self, episode_id: Optional[str] = None) -> None:
        """
        Adopt a stored episode for an instance that was never reset().

        Uses ``episode_id`` when it names a stored session, otherwise the most
        recently registered one.

        Raises:
            RuntimeError: the session store is empty.
        """
        session = _SESSION_STORE.get(episode_id) if episode_id else None
        if session is None:
            # NOTE(review): with concurrent sessions the most recent episode
            # may belong to another client; pass episode_id to step().
            snapshot = list(_SESSION_STORE.items())
            if not snapshot:
                raise RuntimeError("Call reset() before step().")
            episode_id, session = snapshot[-1]
        self._current_scenario = session["scenario"]
        self._state = session["state"]
        self._episode_id = episode_id

    def _episode_over(self) -> bool:
        """True once the step budget is spent or a fix hit SUCCESS_THRESHOLD."""
        # current_score only reaches the threshold through a successful fix.
        return (
            self._state.step_count >= MAX_STEPS
            or self._state.current_score >= SUCCESS_THRESHOLD
        )

    def step(
        self,
        action: DebugAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> DebugObservation:
        """
        Apply one inspect or fix action; each accepted action costs one step.

        Args:
            action: Its action_type must be "inspect" or "fix".
            timeout_s: Accepted for interface compatibility; unused here.
            **kwargs: An optional ``episode_id`` selects which stored session
                to resume when this instance has not been reset().

        Raises:
            RuntimeError: no episode exists, or the current one already ended.
            ValueError: unknown action_type (no step is consumed).
        """
        if self._current_scenario is None:
            self._restore_session(kwargs.get("episode_id"))

        # Validate before touching step_count so a bad action costs nothing.
        if action.action_type not in ("inspect", "fix"):
            raise ValueError(f"Unknown action_type: '{action.action_type}'. Must be 'inspect' or 'fix'.")
        # Without this guard, steps past MAX_STEPS were accepted, the budget
        # went negative and extra fixes were graded and recorded again.
        if self._episode_over():
            raise RuntimeError("Episode is done. Call reset() before step().")

        self._state.step_count += 1
        steps_remaining = MAX_STEPS - self._state.step_count
        if action.action_type == "inspect":
            return self._handle_inspect(action, steps_remaining)
        return self._handle_fix(action, steps_remaining)

    def _handle_inspect(self, action: DebugAction, steps_remaining: int) -> DebugObservation:
        """Run the requested diagnostic tool; inspections always yield reward 0."""
        tool_name = action.tool_name or ""
        if tool_name not in AVAILABLE_TOOLS:
            # An unknown tool still consumes the step.
            tool_result = (
                f"Unknown tool: '{tool_name}'. "
                f"Available tools: {AVAILABLE_TOOLS}"
            )
        else:
            tool_result = execute_tool(tool_name, self._current_scenario)
            self._state.tools_used.append(tool_name)
        done = self._state.step_count >= MAX_STEPS
        if done:
            # Budget exhausted by an inspection: record the episode's best
            # score so the scheduler also learns from failed episodes.
            self._scheduler.record(self._state.task_id, self._state.current_score)
        return DebugObservation(
            task_id=self._state.task_id,
            alert=self._current_scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=steps_remaining,
            step_number=self._state.step_count,
            num_bugs=self._current_scenario.num_bugs,
            action_type="inspect",
            tool_name=tool_name,
            tool_result=tool_result,
            grader_score=None,
            grader_feedback=None,
            execution_result=None,
            done=done,
            reward=0.0,
            efficiency_multiplier=None,
        )

    def _handle_fix(self, action: DebugAction, steps_remaining: int) -> DebugObservation:
        """Grade a submitted fix and return the bonus-adjusted score as reward."""
        self._state.attempts += 1
        self._state.fix_submitted = True
        bug_type = action.bug_type or "other"
        diagnosis = action.diagnosis or ""
        fixed_code = action.fixed_code or ""
        result: GradeResult = grade(
            action_bug_type=bug_type,
            action_diagnosis=diagnosis,
            fixed_code=fixed_code,
            scenario=self._current_scenario,
        )
        multiplier = _efficiency_multiplier(self._state.step_count, MAX_STEPS)
        # The efficiency bonus can push the score above 1; cap it at 0.99.
        final_score = min(result.score * multiplier, 0.99)
        # current_score tracks the best fix of the episode.
        self._state.current_score = max(self._state.current_score, final_score)
        done = final_score >= SUCCESS_THRESHOLD or self._state.step_count >= MAX_STEPS
        if done:
            # Record the episode's best score, consistent with _handle_inspect.
            self._scheduler.record(self._state.task_id, self._state.current_score)
        return DebugObservation(
            task_id=self._state.task_id,
            alert=self._current_scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=steps_remaining,
            step_number=self._state.step_count,
            num_bugs=self._current_scenario.num_bugs,
            action_type="fix",
            tool_name=None,
            tool_result=None,
            grader_score=final_score,
            grader_feedback=result.feedback,
            execution_result=result.execution_output,
            done=done,
            reward=final_score,
            efficiency_multiplier=multiplier,
        )

    def state(self) -> DebugState:
        """Return the current episode's state object (mutated in place by step())."""
        # NOTE(review): openenv's Environment base may declare `state` as a
        # @property; as a plain method, `env.state` yields a bound method.
        # Confirm against the base class before relying on attribute access.
        return self._state

    def get_metadata(self):
        """Describe this environment for the OpenEnv server."""
        from openenv.core.env_server.types import EnvironmentMetadata
        return EnvironmentMetadata(
            name="ML Debugging Environment",
            description=(
                "Partially observable RL environment where agents debug broken PyTorch training scripts. "
                "Agent sees only a minimal failure alert on reset — no code, no traceback. "
                "Must use tool calls (run_code, get_traceback, inspect_gradients, print_shapes, view_source) "
                "to investigate before submitting a fix. "
                f"{MAX_STEPS} steps total per episode. Efficiency bonus: fix in ≤2 steps → ×1.2 reward. "
                "8 tasks: six single-bug (easy→hard), two compound double-bug tasks (expert). "
                "Execution-based grading in subprocess."
            ),
            version="4.0.0",
            author="ml-debug-env",
        )