""" CodeSensei — OpenEnv-Compatible Client. Provides both sync and async clients that connect to the CodeDebug environment server. Supports the standard OpenEnv patterns: - CodeSenseiEnv.from_docker_image(image_name) - await env.reset() -> StepResult - await env.step(action) -> StepResult - await env.close() """ from __future__ import annotations import json import asyncio import subprocess import time import uuid from dataclasses import dataclass from typing import Optional, List try: import websockets import websockets.sync.client as ws_sync except ImportError: websockets = None # type: ignore ws_sync = None # type: ignore try: import aiohttp except ImportError: aiohttp = None # type: ignore from env.models import CodeDebugAction, CodeDebugObservation, TestResult # --------------------------------------------------------------------------- # Result wrapper — matches OpenEnv pattern: result.observation, result.reward, result.done # --------------------------------------------------------------------------- @dataclass class StepResult: """Wrapper returned by reset() and step() to match the OpenEnv pattern.""" observation: CodeDebugObservation reward: float = 0.0 done: bool = False # --------------------------------------------------------------------------- # Async client — matches the sample inference script pattern # --------------------------------------------------------------------------- class CodeSenseiEnv: """Async OpenEnv-compatible client for the CodeDebug environment. Usage (matches sample inference script): env = await CodeSenseiEnv.from_docker_image(IMAGE_NAME) result = await env.reset() result = await env.step(CodeSenseiAction(proposed_fix="...")) await env.close() """ def __init__(self, base_url: str): self.base_url = base_url.rstrip("/") self._session_id: str = "" self._container_id: Optional[str] = None @classmethod async def from_docker_image(cls, image_name: Optional[str] = None) -> "CodeSenseiEnv": """Start the environment from a Docker image (OpenEnv standard). If image_name is provided, starts a Docker container and connects. If image_name is None, connects to localhost:7860 (for local dev). """ if image_name: # Start Docker container port = 7860 container_id = subprocess.check_output( [ "docker", "run", "-d", "--rm", "-p", f"{port}:{port}", image_name, ], text=True, ).strip() # Wait for container to be ready base_url = f"http://localhost:{port}" env = cls(base_url) env._container_id = container_id # Poll health endpoint for _ in range(60): try: if aiohttp: async with aiohttp.ClientSession() as session: async with session.get(f"{base_url}/health", timeout=aiohttp.ClientTimeout(total=2)) as resp: if resp.status == 200: return env else: import urllib.request urllib.request.urlopen(f"{base_url}/health", timeout=2) return env except Exception: await asyncio.sleep(1) raise RuntimeError(f"Docker container {image_name} did not become healthy in 60s") else: # No image — connect to local server return cls("http://localhost:7860") async def reset(self) -> StepResult: """Start a new debugging episode.""" self._session_id = str(uuid.uuid4()) data = await self._post("/reset", {"session_id": self._session_id}) self._session_id = data.get("session_id", self._session_id) obs = self._parse_observation(data) return StepResult(observation=obs, reward=0.0, done=False) async def step(self, action: CodeDebugAction) -> StepResult: """Submit a proposed code fix.""" data = await self._post("/step", { "proposed_fix": action.proposed_fix, "session_id": self._session_id, }) obs = self._parse_observation(data) return StepResult(observation=obs, reward=obs.reward, done=obs.done) async def state(self) -> dict: """Get current episode state.""" return await self._get(f"/state?session_id={self._session_id}") async def close(self) -> None: """Clean up: stop Docker container if we started one.""" if self._container_id: try: subprocess.run( ["docker", "stop", self._container_id], capture_output=True, timeout=30, ) except Exception: pass self._container_id = None # --- HTTP helpers --- async def _post(self, path: str, payload: dict) -> dict: url = f"{self.base_url}{path}" if aiohttp: async with aiohttp.ClientSession() as session: async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=30)) as resp: return await resp.json() else: import urllib.request req = urllib.request.Request( url, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}, ) with urllib.request.urlopen(req, timeout=30) as resp: return json.loads(resp.read().decode()) async def _get(self, path: str) -> dict: url = f"{self.base_url}{path}" if aiohttp: async with aiohttp.ClientSession() as session: async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: return await resp.json() else: import urllib.request with urllib.request.urlopen(url, timeout=30) as resp: return json.loads(resp.read().decode()) # --- Parse --- @staticmethod def _parse_observation(data: dict) -> CodeDebugObservation: test_results = [ TestResult( test_name=tr.get("test_name", ""), passed=tr.get("passed", False), error_message=tr.get("error_message", ""), ) for tr in data.get("test_results", []) ] return CodeDebugObservation( buggy_code=data.get("buggy_code", ""), current_code=data.get("current_code", ""), error_output=data.get("error_output", ""), test_results=test_results, tests_passed=data.get("tests_passed", 0), tests_total=data.get("tests_total", 0), reward=data.get("reward", 0.0), done=data.get("done", False), attempt=data.get("attempt", 0), max_attempts=data.get("max_attempts", 6), feedback=data.get("feedback", ""), ) # --------------------------------------------------------------------------- # Sync client (kept for backward compat / training) # --------------------------------------------------------------------------- class CodeDebugEnv: """Synchronous WebSocket client for the CodeDebug OpenEnv. Usage: env = CodeDebugEnv(base_url="wss://your-space.hf.space") obs = env.reset() obs = env.step("def add(a, b):\\n return a + b") """ def __init__(self, base_url: str): self.base_url = self._to_ws_url(base_url) self._ws = None self._session_id: str = "" def connect(self): if ws_sync is None: raise ImportError("websockets is required. pip install websockets") ws_url = f"{self.base_url}/ws" self._ws = ws_sync.connect(ws_url) return self def close(self): if self._ws: self._ws.close() self._ws = None def reset(self, session_id: str = "") -> CodeDebugObservation: if self._ws is None: self.connect() msg = {"type": "reset"} if session_id: msg["session_id"] = session_id self._ws.send(json.dumps(msg)) raw = self._ws.recv() data = json.loads(raw) self._session_id = data.get("session_id", session_id) return CodeSenseiEnv._parse_observation(data) def step(self, proposed_fix: str) -> CodeDebugObservation: if self._ws is None: raise RuntimeError("Not connected. Call connect() or reset() first.") msg = {"type": "step", "proposed_fix": proposed_fix} self._ws.send(json.dumps(msg)) raw = self._ws.recv() data = json.loads(raw) return CodeSenseiEnv._parse_observation(data) def state(self) -> dict: if self._ws is None: raise RuntimeError("Not connected.") self._ws.send(json.dumps({"type": "state"})) raw = self._ws.recv() return json.loads(raw) @property def session_id(self) -> str: return self._session_id def __enter__(self): self.connect() return self def __exit__(self, *args): self.close() @staticmethod def _to_ws_url(url: str) -> str: url = url.rstrip("/") if url.startswith("https://"): return "wss://" + url[8:] elif url.startswith("http://"): return "ws://" + url[7:] elif url.startswith("wss://") or url.startswith("ws://"): return url else: return "ws://" + url