Spaces:
Running
Running
File size: 8,988 Bytes
70a9d5e 49aa3ca 70a9d5e 6d9a8b2 49aa3ca 70a9d5e 4108ae8 6d9a8b2 70a9d5e 49aa3ca 70a9d5e 49aa3ca 1099086 49aa3ca 70a9d5e 6d9a8b2 49aa3ca 6d9a8b2 49aa3ca 6d9a8b2 49aa3ca 70a9d5e 4108ae8 70a9d5e 49aa3ca 70a9d5e 4108ae8 49aa3ca 70a9d5e 49aa3ca 70a9d5e 1099086 70a9d5e 49aa3ca 70a9d5e 1099086 70a9d5e 49aa3ca 70a9d5e 6d9a8b2 49aa3ca 70a9d5e 49aa3ca 70a9d5e 1099086 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 6d9a8b2 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e 49aa3ca 70a9d5e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 | import sys
import os
from uuid import uuid4
from typing import Optional
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import DebugAction, DebugObservation, DebugState
from bug_generator import (
get_scenario,
execute_tool,
BugScenario,
ALL_TASKS,
AVAILABLE_TOOLS,
TASK_SHAPE_MISMATCH,
TASK_TRAINING_COLLAPSE,
TASK_DATA_LEAKAGE,
TASK_WRONG_DEVICE,
TASK_GRADIENT_NOT_ZEROED,
TASK_MISSING_EVAL_MODE,
TASK_COMPOUND_SHAPE_DEVICE,
TASK_COMPOUND_LEAKAGE_EVAL,
)
from grader import grade, GradeResult
from adversarial_scheduler import AdversarialScheduler
# Hard cap on the number of actions (inspect + fix combined) per episode.
MAX_STEPS = 5
# A fix whose final (efficiency-multiplied, capped) score reaches this value
# ends the episode early.
SUCCESS_THRESHOLD = 0.95
# Module-level session store — shared across all instances.
# Keyed by episode_id; each value is a per-episode record holding at least the
# scenario and the DebugState. Entries are added in reset() and nothing in
# this file removes them.
# NOTE(review): the store grows for the lifetime of the process.
_SESSION_STORE: dict = {}
def _efficiency_multiplier(steps_used: int, total_steps: int) -> float:
"""
Reward agents that fix bugs efficiently.
steps_used = number of steps taken when fix was submitted (1-indexed).
"""
if steps_used <= 2:
return 1.2
elif steps_used <= 3:
return 1.1
else:
return 1.0
class MlDebugEnvEnvironment(Environment):
    """
    ML Debugging Environment — 8 tasks, easy → expert.

    Partially observable: the agent sees only a minimal alert on reset().
    It must use tool calls (inspect actions) to gather information before fixing.

    Episode structure:
      - reset() → minimal alert, available tools, step budget
      - step(action_type="inspect", tool_name=X) → tool output (costs 1 step)
      - step(action_type="fix", bug_type=X, ...) → grader score (costs 1 step)
      - Max 5 steps total across all inspect + fix actions

    An episode ends when a fix's final score reaches SUCCESS_THRESHOLD, or when
    the step budget is exhausted (by either an inspect or a fix). Once ended,
    the episode stays done. It is reported to the adversarial scheduler exactly
    once.

    Efficiency bonus (every final score is capped at 0.99):
      - Fix submitted in ≤2 total steps → score × 1.2
      - Fix submitted in ≤3 total steps → score × 1.1
      - Fix in 4-5 steps → score × 1.0

    Single-bug tasks (6):
      shape_mismatch, training_collapse, wrong_device,
      gradient_not_zeroed, data_leakage, missing_eval_mode
    Compound tasks — TWO bugs per script (2):
      compound_shape_device, compound_leakage_eval
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, task_id: Optional[str] = None):
        """
        Args:
            task_id: if given, every reset() without an explicit task_id uses
                this task. Otherwise the adversarial scheduler picks one.
        """
        super().__init__()
        self._task_id: Optional[str] = task_id
        self._current_scenario: Optional[BugScenario] = None
        self._state = DebugState(
            episode_id=None,
            step_count=0,
            task_id="",
            max_steps=MAX_STEPS,
            current_score=0.0,
            attempts=0,
            tools_used=[],
            fix_submitted=False,
        )
        self._episode_id: Optional[str] = None
        # True once the current episode has ended and been reported to the scheduler.
        self._episode_done = False
        self._episode_count = 0
        self._scheduler = AdversarialScheduler(ALL_TASKS)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs,
    ) -> DebugObservation:
        """
        Start a new episode and return its initial (minimal) observation.

        Args:
            seed: scenario seed. When None, the scheduler supplies one for the task.
            episode_id: id for the new episode. A fresh UUID4 string when omitted.
            task_id: task to run. It overrides the constructor's task_id. When
                both are unset, the scheduler picks the next task.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            An observation carrying only the alert, tool list, step budget and
            bug count. It contains no tool output or grading.
        """
        active_task = task_id or self._task_id or self._scheduler.next_task()
        effective_seed = seed if seed is not None else self._scheduler.next_seed(active_task)
        scenario = get_scenario(active_task, seed=effective_seed)
        eid = episode_id or str(uuid4())
        state = DebugState(
            episode_id=eid,
            step_count=0,
            task_id=active_task,
            max_steps=MAX_STEPS,
            current_score=0.0,
            attempts=0,
            tools_used=[],
            fix_submitted=False,
        )
        self._current_scenario = scenario
        self._state = state
        self._episode_id = eid
        self._episode_done = False
        # Persist the session so an instance without a scenario can resume it in step().
        _SESSION_STORE[eid] = {"scenario": scenario, "state": state, "done": False}
        return DebugObservation(
            task_id=active_task,
            alert=scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=MAX_STEPS,
            step_number=0,
            num_bugs=scenario.num_bugs,
            action_type=None,
            tool_name=None,
            tool_result=None,
            grader_score=None,
            grader_feedback=None,
            execution_result=None,
            done=False,
            reward=None,
            efficiency_multiplier=None,
        )

    def step(
        self,
        action: DebugAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> DebugObservation:
        """
        Apply one inspect or fix action to the current episode.

        Args:
            action: the agent's action. ``action.action_type`` must be
                "inspect" or "fix".
            timeout_s: accepted for interface compatibility; unused here.
            **kwargs: an optional ``episode_id`` selects which stored session
                to resume when this instance has no active scenario.

        Returns:
            The observation produced by the matching handler.

        Raises:
            RuntimeError: if no episode has been started anywhere.
            ValueError: if ``action.action_type`` is not "inspect" or "fix".
                The step counter is left untouched in that case.
        """
        if self._current_scenario is None:
            self._restore_session(kwargs.get("episode_id"))

        if action.action_type not in ("inspect", "fix"):
            raise ValueError(f"Unknown action_type: '{action.action_type}'. Must be 'inspect' or 'fix'.")

        self._state.step_count += 1
        # Clamp so stepping past the budget never reports a negative budget.
        steps_remaining = max(MAX_STEPS - self._state.step_count, 0)
        if action.action_type == "inspect":
            return self._handle_inspect(action, steps_remaining)
        return self._handle_fix(action, steps_remaining)

    def _restore_session(self, episode_id: Optional[str] = None) -> None:
        """
        Rebind this instance to a session kept in the module-level store.

        Prefers the session for ``episode_id`` when it is given and known.
        Otherwise it falls back to the last entry in the store's insertion
        order, which was the only behaviour before ``episode_id`` lookup.

        Raises:
            RuntimeError: if the store is empty (reset() was never called).
        """
        if episode_id is not None and episode_id in _SESSION_STORE:
            session = _SESSION_STORE[episode_id]
        elif _SESSION_STORE:
            # NOTE(review): with concurrent sessions this fallback may attach to
            # another client's episode; callers should pass episode_id.
            session = list(_SESSION_STORE.values())[-1]
        else:
            raise RuntimeError("Call reset() before step().")
        self._current_scenario = session["scenario"]
        self._state = session["state"]
        self._episode_id = self._state.episode_id
        # Records written by older code may lack the "done" key.
        self._episode_done = session.get("done", False)

    def _finish_episode(self, score: float) -> None:
        """Mark the episode done and report ``score`` to the scheduler, only once."""
        if self._episode_done:
            return
        self._episode_done = True
        session = _SESSION_STORE.get(self._state.episode_id)
        if session is not None:
            session["done"] = True
        self._scheduler.record(self._state.task_id, score)

    def _handle_inspect(self, action: DebugAction, steps_remaining: int) -> DebugObservation:
        """
        Run the requested tool against the current scenario.

        An unknown tool name produces an explanatory message instead of raising.
        The step still counts. Inspections always yield reward 0.0.
        """
        tool_name = action.tool_name or ""
        if tool_name not in AVAILABLE_TOOLS:
            tool_result = (
                f"Unknown tool: '{tool_name}'. "
                f"Available tools: {AVAILABLE_TOOLS}"
            )
        else:
            tool_result = execute_tool(tool_name, self._current_scenario)
            self._state.tools_used.append(tool_name)
        done = self._episode_done or self._state.step_count >= MAX_STEPS
        if done:
            # Episodes that end on an inspect are reported too, with the best fix
            # score so far (0.0 if no fix was submitted). This is a no-op if the
            # episode was already reported.
            self._finish_episode(self._state.current_score)
        return DebugObservation(
            task_id=self._state.task_id,
            alert=self._current_scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=steps_remaining,
            step_number=self._state.step_count,
            num_bugs=self._current_scenario.num_bugs,
            action_type="inspect",
            tool_name=tool_name,
            tool_result=tool_result,
            grader_score=None,
            grader_feedback=None,
            execution_result=None,
            done=done,
            reward=0.0,
            efficiency_multiplier=None,
        )

    def _handle_fix(self, action: DebugAction, steps_remaining: int) -> DebugObservation:
        """
        Grade a submitted fix and return the scored observation.

        The grader's score is multiplied by the efficiency bonus for the current
        step and capped at 0.99. The best score seen this episode is kept in
        ``state.current_score``. Missing action fields default to
        bug_type="other" and empty diagnosis/code.
        """
        self._state.attempts += 1
        self._state.fix_submitted = True
        bug_type = action.bug_type or "other"
        diagnosis = action.diagnosis or ""
        fixed_code = action.fixed_code or ""
        result: GradeResult = grade(
            action_bug_type=bug_type,
            action_diagnosis=diagnosis,
            fixed_code=fixed_code,
            scenario=self._current_scenario,
        )
        multiplier = _efficiency_multiplier(self._state.step_count, MAX_STEPS)
        final_score = min(result.score * multiplier, 0.99)
        if final_score > self._state.current_score:
            self._state.current_score = final_score
        done = (
            self._episode_done
            or final_score >= SUCCESS_THRESHOLD
            or self._state.step_count >= MAX_STEPS
        )
        if done:
            self._finish_episode(final_score)
        return DebugObservation(
            task_id=self._state.task_id,
            alert=self._current_scenario.alert,
            available_tools=AVAILABLE_TOOLS,
            step_budget=steps_remaining,
            step_number=self._state.step_count,
            num_bugs=self._current_scenario.num_bugs,
            action_type="fix",
            tool_name=None,
            tool_result=None,
            grader_score=final_score,
            grader_feedback=result.feedback,
            execution_result=result.execution_output,
            done=done,
            reward=final_score,
            efficiency_multiplier=multiplier,
        )

    @property
    def state(self) -> DebugState:
        """The DebugState of the current (or most recently restored) episode."""
        return self._state

    def get_metadata(self):
        """Return static descriptive metadata for this environment."""
        from openenv.core.env_server.types import EnvironmentMetadata
        return EnvironmentMetadata(
            name="ML Debugging Environment",
            description=(
                "Partially observable RL environment where agents debug broken PyTorch training scripts. "
                "Agent sees only a minimal failure alert on reset — no code, no traceback. "
                "Must use tool calls (run_code, get_traceback, inspect_gradients, print_shapes, view_source) "
                "to investigate before submitting a fix. "
                "5 steps total per episode. Efficiency bonus: fix in ≤2 steps → ×1.2 reward. "
                "8 tasks: six single-bug (easy→hard), two compound double-bug tasks (expert). "
                "Execution-based grading in subprocess."
            ),
            version="4.0.0",
            author="ml-debug-env",
        )