vineetshukla.work@gmail.com committed on
Commit
b64950c
·
1 Parent(s): 5fcb94c

fix: task routing by name, remove out-of-range rewards, add grader field to tasks

Browse files
Files changed (3) hide show
  1. env/server/app.py +5 -2
  2. env/server/environment.py +13 -5
  3. openenv.yaml +8 -0
env/server/app.py CHANGED
@@ -25,6 +25,7 @@ from env.models import CodeDebugAction
25
 
26
  class ResetRequest(BaseModel):
27
  session_id: str = ""
 
28
 
29
 
30
  class StepRequest(BaseModel):
@@ -75,7 +76,8 @@ app.add_middleware(
75
  async def reset(request: Optional[ResetRequest] = None):
76
  """Start a new debugging episode."""
77
  session_id = request.session_id if request else str(uuid.uuid4())
78
- obs = env.reset(session_id=session_id)
 
79
  return _obs_to_dict(obs)
80
 
81
 
@@ -142,7 +144,8 @@ async def websocket_endpoint(websocket: WebSocket):
142
 
143
  if msg_type == "reset":
144
  session_id = msg.get("session_id", str(uuid.uuid4()))
145
- obs = env.reset(session_id=session_id)
 
146
  response = _obs_to_dict(obs)
147
  response["session_id"] = session_id
148
  response["type"] = "reset_response"
 
25
 
26
  class ResetRequest(BaseModel):
27
  session_id: str = ""
28
+ task: Optional[str] = None # task name from openenv.yaml e.g. "debug-add_numbers"
29
 
30
 
31
  class StepRequest(BaseModel):
 
76
  async def reset(request: Optional[ResetRequest] = None):
77
  """Start a new debugging episode."""
78
  session_id = request.session_id if request else str(uuid.uuid4())
79
+ task = request.task if request else None
80
+ obs = env.reset(session_id=session_id, task=task)
81
  return _obs_to_dict(obs)
82
 
83
 
 
144
 
145
  if msg_type == "reset":
146
  session_id = msg.get("session_id", str(uuid.uuid4()))
147
+ task = msg.get("task", None)
148
+ obs = env.reset(session_id=session_id, task=task)
149
  response = _obs_to_dict(obs)
150
  response["session_id"] = session_id
151
  response["type"] = "reset_response"
env/server/environment.py CHANGED
@@ -172,11 +172,13 @@ class CodeDebugEnvironment:
172
  _load_dataset()
173
  self._sessions: Dict[str, CodeDebugState] = {}
174
 
175
- def reset(self, session_id: str = "") -> CodeDebugObservation:
176
  """Start a new episode: sample a buggy function.
177
 
178
  Args:
179
  session_id: WebSocket session ID. Auto-generated if empty.
 
 
180
 
181
  Returns:
182
  Initial observation with the buggy code and test info.
@@ -184,8 +186,14 @@ class CodeDebugEnvironment:
184
  if not session_id:
185
  session_id = str(uuid.uuid4())
186
 
187
- # Sample a random bug
188
- bug = random.choice(_BUG_DATASET)
 
 
 
 
 
 
189
 
190
  # Create state
191
  state = CodeDebugState(
@@ -219,7 +227,7 @@ class CodeDebugEnvironment:
219
  test_results=test_results,
220
  tests_passed=passed,
221
  tests_total=total,
222
- reward=0.0,
223
  done=False,
224
  attempt=0,
225
  max_attempts=6,
@@ -279,7 +287,7 @@ class CodeDebugEnvironment:
279
  test_results=[],
280
  tests_passed=0,
281
  tests_total=len(bug["tests"]),
282
- reward=-1.0,
283
  done=done,
284
  attempt=state.attempt,
285
  max_attempts=state.max_attempts,
 
172
  _load_dataset()
173
  self._sessions: Dict[str, CodeDebugState] = {}
174
 
175
+ def reset(self, session_id: str = "", task: Optional[str] = None) -> CodeDebugObservation:
176
  """Start a new episode: sample a buggy function.
177
 
178
  Args:
179
  session_id: WebSocket session ID. Auto-generated if empty.
180
+ task: Optional task name from openenv.yaml (e.g. "debug-add_numbers").
181
+ If provided, selects the matching bug. Otherwise picks randomly.
182
 
183
  Returns:
184
  Initial observation with the buggy code and test info.
 
186
  if not session_id:
187
  session_id = str(uuid.uuid4())
188
 
189
+ # Select bug by task name or randomly
190
+ bug = None
191
+ if task:
192
+ # Strip "debug-" prefix to get function_name (e.g. "debug-add_numbers" -> "add_numbers")
193
+ fn_name = task.replace("debug-", "", 1)
194
+ bug = next((b for b in _BUG_DATASET if b["function_name"] == fn_name), None)
195
+ if bug is None:
196
+ bug = random.choice(_BUG_DATASET)
197
 
198
  # Create state
199
  state = CodeDebugState(
 
227
  test_results=test_results,
228
  tests_passed=passed,
229
  tests_total=total,
230
+ reward=0.01, # Non-zero initial reward (0.0 is forbidden by Phase 2)
231
  done=False,
232
  attempt=0,
233
  max_attempts=6,
 
287
  test_results=[],
288
  tests_passed=0,
289
  tests_total=len(bug["tests"]),
290
+ reward=0.01, # Clamped: syntax error gives minimum reward, not -1.0
291
  done=done,
292
  attempt=state.attempt,
293
  max_attempts=state.max_attempts,
openenv.yaml CHANGED
@@ -89,31 +89,39 @@ tasks:
89
  description: "Fix subtraction → addition bug"
90
  max_steps: 6
91
  reward_range: [0.01, 0.99]
 
92
  - name: debug-find_max
93
  description: "Fix < → > comparison bug"
94
  max_steps: 6
95
  reward_range: [0.01, 0.99]
 
96
  - name: debug-reverse_string
97
  description: "Fix slice → reverse bug"
98
  max_steps: 6
99
  reward_range: [0.01, 0.99]
 
100
  - name: debug-fibonacci
101
  description: "Fix n-3 → n-2 recursion bug"
102
  max_steps: 6
103
  reward_range: [0.01, 0.99]
 
104
  - name: debug-count_vowels
105
  description: "Fix missing case-insensitive bug"
106
  max_steps: 6
107
  reward_range: [0.01, 0.99]
 
108
  - name: debug-flatten_list
109
  description: "Fix append → extend (recursive flatten) bug"
110
  max_steps: 6
111
  reward_range: [0.01, 0.99]
 
112
  - name: debug-merge_sorted
113
  description: "Fix missing remaining elements bug"
114
  max_steps: 6
115
  reward_range: [0.01, 0.99]
 
116
  - name: debug-remove_duplicates
117
  description: "Fix inverted condition bug"
118
  max_steps: 6
119
  reward_range: [0.01, 0.99]
 
 
89
  description: "Fix subtraction → addition bug"
90
  max_steps: 6
91
  reward_range: [0.01, 0.99]
92
+ grader: environment
93
  - name: debug-find_max
94
  description: "Fix < → > comparison bug"
95
  max_steps: 6
96
  reward_range: [0.01, 0.99]
97
+ grader: environment
98
  - name: debug-reverse_string
99
  description: "Fix slice → reverse bug"
100
  max_steps: 6
101
  reward_range: [0.01, 0.99]
102
+ grader: environment
103
  - name: debug-fibonacci
104
  description: "Fix n-3 → n-2 recursion bug"
105
  max_steps: 6
106
  reward_range: [0.01, 0.99]
107
+ grader: environment
108
  - name: debug-count_vowels
109
  description: "Fix missing case-insensitive bug"
110
  max_steps: 6
111
  reward_range: [0.01, 0.99]
112
+ grader: environment
113
  - name: debug-flatten_list
114
  description: "Fix append → extend (recursive flatten) bug"
115
  max_steps: 6
116
  reward_range: [0.01, 0.99]
117
+ grader: environment
118
  - name: debug-merge_sorted
119
  description: "Fix missing remaining elements bug"
120
  max_steps: 6
121
  reward_range: [0.01, 0.99]
122
+ grader: environment
123
  - name: debug-remove_duplicates
124
  description: "Fix inverted condition bug"
125
  max_steps: 6
126
  reward_range: [0.01, 0.99]
127
+ grader: environment