Spaces:
Paused
Paused
| import json | |
| import os | |
| import random | |
| import sys | |
| from websocket import create_connection | |
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Where the environment server listens, and where the demo log is written.
SERVER_URL: str = "ws://localhost:8000/ws"
OUTPUT_FILE: str = "demo_http_output.txt"

# Game selection and run shape: which game to reset into, how many episodes
# to play, the difficulty sent as `level` on reset, and the maximum number of
# step actions taken per episode.
GAME_ID: str = "cicero"
NUM_EPISODES: int = 1
DIFFICULTY: int = 2
TURNS_PER_EPISODE: int = 5

# Socket timeout in seconds given to create_connection; kept generous because
# reset/step can take a while when the server is waiting on an LLM.
WS_TIMEOUT: int = 120
def send_and_receive(ws, message: dict) -> dict:
    """Send one JSON message over ``ws`` and return the decoded JSON reply.

    Args:
        ws: An open WebSocket connection exposing ``send(str)`` and ``recv()``
            (e.g. the object returned by ``websocket.create_connection``).
        message: JSON-serializable request, e.g.
            ``{"type": "reset", "data": {...}}``.

    Returns:
        The decoded reply, which is always a JSON object (``dict``).

    Raises:
        RuntimeError: if the reply is empty, is not decodable JSON, is not a
            JSON object, or has ``"type": "error"``.
    """
    ws.send(json.dumps(message))
    response = ws.recv()
    if not response:
        raise RuntimeError("Server returned empty response")
    try:
        result = json.loads(response)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError covers binary frames (bytes) that are not valid
        # UTF-8; json.loads decodes bytes before parsing.
        raise RuntimeError(
            f"Server returned invalid JSON (first 200 chars): {repr(response[:200])}"
        ) from e
    # Replies are read with .get(), so anything other than a JSON object
    # (list, string, number, null) is a protocol error rather than a reply.
    if not isinstance(result, dict):
        raise RuntimeError(
            f"Server returned unexpected JSON (expected object, got {type(result).__name__})"
        )
    if result.get("type") == "error":
        # An error reply may carry `"data": null` or a non-object payload; fall
        # back to the whole reply instead of crashing with AttributeError.
        payload = result.get("data")
        detail = payload.get("message", result) if isinstance(payload, dict) else result
        raise RuntimeError(f"Server error: {detail}")
    return result
def _unpack_result(result):
    """Split a server reply into ``(data, observation, done)``.

    Reset and step replies are read the same way: the payload lives under
    ``"data"`` and the observation under ``data["observation"]``. Missing or
    ``null`` values fall back to an empty dict / ``False`` so later ``.get``
    calls never hit ``None``.
    """
    data = result.get("data") or {}
    obs = data.get("observation") or {}
    return data, obs, bool(data.get("done", False))


def _step_reward(data, obs):
    """Return the reward for the last step: ``data["reward"]``, else ``obs["step_reward"]``."""
    reward = data.get("reward")
    # Explicit None check so a legitimate reward of 0 / 0.0 is not treated as
    # missing and replaced by the fallback.
    return reward if reward is not None else obs.get("step_reward")


# UTF-8 explicitly: observations may contain LLM-generated non-ASCII text,
# which the platform default encoding may not be able to represent.
with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
    f.write(f"Server: {SERVER_URL}\n")
    f.write(f"Game: {GAME_ID} | Episodes: {NUM_EPISODES} | Difficulty: {DIFFICULTY}\n\n")
    f.flush()
    for ep in range(NUM_EPISODES):
        try:
            ws = create_connection(SERVER_URL, timeout=WS_TIMEOUT)
        except Exception as e:
            print(f"ERROR: Cannot connect to {SERVER_URL}: {e}", file=sys.stderr)
            print("Make sure the server is running: uvicorn watchdog_env.server.app:app --port 8000 --host 0.0.0.0", file=sys.stderr)
            sys.exit(1)
        seed = ep + 42
        # Seed action sampling with the same value sent to the server's reset
        # so an episode (server state + chosen actions) can be replayed.
        rng = random.Random(seed)
        try:
            # Reset environment - data contains reset params (game_id, level, seed).
            _, obs, done = _unpack_result(send_and_receive(ws, {
                "type": "reset",
                "data": {"seed": seed, "game_id": GAME_ID, "level": DIFFICULTY},
            }))
            f.write(f"EPISODE {ep + 1}\n")
            for turn in range(TURNS_PER_EPISODE):
                if done or obs.get("phase") == "done":
                    break
                f.write(f"\n TURN {turn + 1}\n")
                # State observed *before* acting this turn.
                f.write(f" state: {obs.get('current_turn') or '(no turn yet)'}\n")
                action = rng.choice(["pass", "flag", "question"])
                data, obs, done = _unpack_result(send_and_receive(ws, {
                    "type": "step",
                    "data": {"action_type": action},
                }))
                f.write(f" action: {action}\n")
                # Reward produced by *this* action, so the last step's reward
                # is logged too.
                f.write(f" reward: {_step_reward(data, obs)}\n")
            f.write(f"\n{'='*40}\n\n")
            # Each reset/step may take up to WS_TIMEOUT seconds; flush so
            # progress is visible in the file while later episodes run.
            f.flush()
            print(f"Episode {ep + 1} done")
        finally:
            ws.close()
print(f"Saved to {OUTPUT_FILE}")