samrat-rm commited on
Commit
61e83f1
·
1 Parent(s): 89b370c

feat: rewards upgrade

Browse files
Files changed (1) hide show
  1. inference.py +8 -2
inference.py CHANGED
@@ -111,6 +111,8 @@ SYSTEM_PROMPT = textwrap.dedent("""
111
  - Never inspect the same source twice.
112
  """).strip()
113
 
 
 
114
 
115
  def _user_prompt(step: int, obs_summary: str, history: List[str]) -> str:
116
  history_block = "\n".join(history[-4:]) if history else "None"
@@ -212,7 +214,7 @@ async def run_episode(
212
  print(f"[STEP] step={step} action={action.action_type} reward=0.10 done=true error={e}", flush=True)
213
  break
214
  obs = result.observation
215
- reward = round(max(0.10, min(0.90, obs.reward)), 2)
216
  done = result.done
217
  if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
218
  source = action.action_type.replace("inspect_", "")
@@ -225,6 +227,7 @@ async def run_episode(
225
  rewards.append(reward)
226
  data_seen = json.dumps(obs.visible_data) if obs.visible_data else "{}"
227
  history.append(f"Step {step}: {action.action_type} → reward={reward:.2f} | {obs.feedback}\n Data: {data_seen}")
 
228
  print(f"[STEP] step={step} action={action.action_type} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
229
 
230
  if done:
@@ -255,6 +258,7 @@ async def run_episode(
255
  finally:
256
  steps_taken = len(rewards)
257
  rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.10"
 
258
  print(f"[END] success={str(success).lower()} steps={steps_taken} rewards={rewards_str}", flush=True)
259
 
260
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
@@ -281,7 +285,8 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
281
 
282
  scores = [r["score"] for r in results]
283
  task_score = round(max(0.10, min(0.90, sum(scores) / len(scores))), 2) if scores else 0.10
284
- print(f"[END] score={task_score}", flush=True)
 
285
  return scores
286
 
287
 
@@ -296,6 +301,7 @@ async def main() -> None:
296
  scores += await run_task("task_hard", HARD_SCENARIOS, env, client)
297
  pass # scoring is handled by the yaml grader, not stdout
298
  finally:
 
299
  try:
300
  await env.close()
301
  except Exception as e:
 
111
  - Never inspect the same source twice.
112
  """).strip()
113
 
114
+ numbers = []
115
+
116
 
117
  def _user_prompt(step: int, obs_summary: str, history: List[str]) -> str:
118
  history_block = "\n".join(history[-4:]) if history else "None"
 
214
  print(f"[STEP] step={step} action={action.action_type} reward=0.10 done=true error={e}", flush=True)
215
  break
216
  obs = result.observation
217
+ reward = round(max(0.10, min(0.90, result.reward or 0.10)), 2)
218
  done = result.done
219
  if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
220
  source = action.action_type.replace("inspect_", "")
 
227
  rewards.append(reward)
228
  data_seen = json.dumps(obs.visible_data) if obs.visible_data else "{}"
229
  history.append(f"Step {step}: {action.action_type} → reward={reward:.2f} | {obs.feedback}\n Data: {data_seen}")
230
+ numbers.append(f"{reward:.2f}")
231
  print(f"[STEP] step={step} action={action.action_type} reward={reward:.2f} done={str(done).lower()} error=null", flush=True)
232
 
233
  if done:
 
258
  finally:
259
  steps_taken = len(rewards)
260
  rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.10"
261
+ numbers.append(f"{rewards_str}")
262
  print(f"[END] success={str(success).lower()} steps={steps_taken} rewards={rewards_str}", flush=True)
263
 
264
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
 
285
 
286
  scores = [r["score"] for r in results]
287
  task_score = round(max(0.10, min(0.90, sum(scores) / len(scores))), 2) if scores else 0.10
288
+ numbers.append(f"{task_score:.2f}")
289
+ print(f"[END] score={task_score:.2f}", flush=True)
290
  return scores
291
 
292
 
 
301
  scores += await run_task("task_hard", HARD_SCENARIOS, env, client)
302
  pass # scoring is handled by the yaml grader, not stdout
303
  finally:
304
+ raise Exception(numbers)
305
  try:
306
  await env.close()
307
  except Exception as e: