samrat-rm committed on
Commit
a310ad6
·
1 Parent(s): 7cc0ee9

feat: add [END] log after all tasks complete and error handling for WebSocket disconnects

Browse files
Files changed (1) hide show
  1. inference.py +13 -2
inference.py CHANGED
@@ -16,6 +16,7 @@ STDOUT FORMAT
16
  [STEP] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
17
  [RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
18
  [SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
 
19
  """
20
 
21
  import asyncio
@@ -24,6 +25,8 @@ import os
24
  import textwrap
25
  from typing import List
26
 
 
 
27
  from dotenv import load_dotenv
28
  load_dotenv()
29
 
@@ -119,9 +122,12 @@ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str])
119
  ],
120
  temperature=TEMPERATURE,
121
  max_tokens=MAX_TOKENS,
 
122
  )
123
  text = (completion.choices[0].message.content or "").strip()
124
- return WhyDidItFailAction(**json.loads(text))
 
 
125
  except Exception as exc:
126
  print(f" [DEBUG] parse error: {exc}", flush=True)
127
  if step <= 2:
@@ -144,7 +150,11 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
144
  break
145
 
146
  action = _get_action(client, step, _summarize(obs), history)
147
- result = await env.step(action)
 
 
 
 
148
  obs = result.observation
149
  reward = result.reward or 0.0
150
  done = result.done
@@ -219,6 +229,7 @@ async def main() -> None:
219
  await run_task("easy", EASY_SCENARIOS, env, client)
220
  await run_task("medium", MEDIUM_SCENARIOS, env, client)
221
  await run_task("hard", HARD_SCENARIOS, env, client)
 
222
  finally:
223
  try:
224
  await env.close()
 
16
  [STEP] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
17
  [RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
18
  [SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
19
+ [END] all tasks complete
20
  """
21
 
22
  import asyncio
 
25
  import textwrap
26
  from typing import List
27
 
28
+ from websockets.exceptions import ConnectionClosedError
29
+
30
  from dotenv import load_dotenv
31
  load_dotenv()
32
 
 
122
  ],
123
  temperature=TEMPERATURE,
124
  max_tokens=MAX_TOKENS,
125
+ response_format={"type": "json_object"},
126
  )
127
  text = (completion.choices[0].message.content or "").strip()
128
+ data = json.loads(text)
129
+ filtered = {k: v for k, v in data.items() if k in WhyDidItFailAction.model_fields}
130
+ return WhyDidItFailAction(**filtered)
131
  except Exception as exc:
132
  print(f" [DEBUG] parse error: {exc}", flush=True)
133
  if step <= 2:
 
150
  break
151
 
152
  action = _get_action(client, step, _summarize(obs), history)
153
+ try:
154
+ result = await env.step(action)
155
+ except ConnectionClosedError as e:
156
+ print(f" [WARN] scenario={scenario_key} step={step} WebSocket dropped: {e}", flush=True)
157
+ break
158
  obs = result.observation
159
  reward = result.reward or 0.0
160
  done = result.done
 
229
  await run_task("easy", EASY_SCENARIOS, env, client)
230
  await run_task("medium", MEDIUM_SCENARIOS, env, client)
231
  await run_task("hard", HARD_SCENARIOS, env, client)
232
+ print("[END] all tasks complete", flush=True)
233
  finally:
234
  try:
235
  await env.close()