Spaces:
Sleeping
Sleeping
feat: add [END] log for each episode and error handling for websocket
Browse files- inference.py +13 -2
inference.py
CHANGED
|
@@ -16,6 +16,7 @@ STDOUT FORMAT
|
|
| 16 |
[STEP] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
|
| 17 |
[RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
|
| 18 |
[SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
import asyncio
|
|
@@ -24,6 +25,8 @@ import os
|
|
| 24 |
import textwrap
|
| 25 |
from typing import List
|
| 26 |
|
|
|
|
|
|
|
| 27 |
from dotenv import load_dotenv
|
| 28 |
load_dotenv()
|
| 29 |
|
|
@@ -119,9 +122,12 @@ def _get_action(client: OpenAI, step: int, obs_summary: str, history: List[str])
|
|
| 119 |
],
|
| 120 |
temperature=TEMPERATURE,
|
| 121 |
max_tokens=MAX_TOKENS,
|
|
|
|
| 122 |
)
|
| 123 |
text = (completion.choices[0].message.content or "").strip()
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
except Exception as exc:
|
| 126 |
print(f" [DEBUG] parse error: {exc}", flush=True)
|
| 127 |
if step <= 2:
|
|
@@ -144,7 +150,11 @@ async def run_episode(env: WhyDidItFailEnv, client: OpenAI, scenario_key: str) -
|
|
| 144 |
break
|
| 145 |
|
| 146 |
action = _get_action(client, step, _summarize(obs), history)
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
obs = result.observation
|
| 149 |
reward = result.reward or 0.0
|
| 150 |
done = result.done
|
|
@@ -219,6 +229,7 @@ async def main() -> None:
|
|
| 219 |
await run_task("easy", EASY_SCENARIOS, env, client)
|
| 220 |
await run_task("medium", MEDIUM_SCENARIOS, env, client)
|
| 221 |
await run_task("hard", HARD_SCENARIOS, env, client)
|
|
|
|
| 222 |
finally:
|
| 223 |
try:
|
| 224 |
await env.close()
|
|
|
|
| 16 |
[STEP] scenario=<key> step=<n> action=<json> reward=<0.00> done=<bool>
|
| 17 |
[RESULT] scenario=<key> score=<0.000> steps=<n> success=<bool>
|
| 18 |
[SUMMARY] task=<task_name> avg_score=<0.000> pass_rate=<0.00>
|
| 19 |
+
[END] all tasks complete
|
| 20 |
"""
|
| 21 |
|
| 22 |
import asyncio
|
|
|
|
| 25 |
import textwrap
|
| 26 |
from typing import List
|
| 27 |
|
| 28 |
+
from websockets.exceptions import ConnectionClosedError
|
| 29 |
+
|
| 30 |
from dotenv import load_dotenv
|
| 31 |
load_dotenv()
|
| 32 |
|
|
|
|
| 122 |
],
|
| 123 |
temperature=TEMPERATURE,
|
| 124 |
max_tokens=MAX_TOKENS,
|
| 125 |
+
response_format={"type": "json_object"},
|
| 126 |
)
|
| 127 |
text = (completion.choices[0].message.content or "").strip()
|
| 128 |
+
data = json.loads(text)
|
| 129 |
+
filtered = {k: v for k, v in data.items() if k in WhyDidItFailAction.model_fields}
|
| 130 |
+
return WhyDidItFailAction(**filtered)
|
| 131 |
except Exception as exc:
|
| 132 |
print(f" [DEBUG] parse error: {exc}", flush=True)
|
| 133 |
if step <= 2:
|
|
|
|
| 150 |
break
|
| 151 |
|
| 152 |
action = _get_action(client, step, _summarize(obs), history)
|
| 153 |
+
try:
|
| 154 |
+
result = await env.step(action)
|
| 155 |
+
except ConnectionClosedError as e:
|
| 156 |
+
print(f" [WARN] scenario={scenario_key} step={step} WebSocket dropped: {e}", flush=True)
|
| 157 |
+
break
|
| 158 |
obs = result.observation
|
| 159 |
reward = result.reward or 0.0
|
| 160 |
done = result.done
|
|
|
|
| 229 |
await run_task("easy", EASY_SCENARIOS, env, client)
|
| 230 |
await run_task("medium", MEDIUM_SCENARIOS, env, client)
|
| 231 |
await run_task("hard", HARD_SCENARIOS, env, client)
|
| 232 |
+
print("[END] all tasks complete", flush=True)
|
| 233 |
finally:
|
| 234 |
try:
|
| 235 |
await env.close()
|