Spaces:
Sleeping
Sleeping
vineetshukla.work@gmail.com
fix: rewrite inference.py & client.py to match OpenEnv sample pattern
6cbd2ef | """ | |
| CodeSensei — OpenEnv-Compatible Client. | |
| Provides both sync and async clients that connect to the CodeDebug | |
| environment server. Supports the standard OpenEnv patterns: | |
| - CodeSenseiEnv.from_docker_image(image_name) | |
| - await env.reset() -> StepResult | |
| - await env.step(action) -> StepResult | |
| - await env.close() | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import asyncio | |
| import subprocess | |
| import time | |
| import uuid | |
| from dataclasses import dataclass | |
| from typing import Optional, List | |
| try: | |
| import websockets | |
| import websockets.sync.client as ws_sync | |
| except ImportError: | |
| websockets = None # type: ignore | |
| ws_sync = None # type: ignore | |
| try: | |
| import aiohttp | |
| except ImportError: | |
| aiohttp = None # type: ignore | |
| from env.models import CodeDebugAction, CodeDebugObservation, TestResult | |
# ---------------------------------------------------------------------------
# Result wrapper — matches OpenEnv pattern: result.observation, result.reward, result.done
# ---------------------------------------------------------------------------
@dataclass
class StepResult:
    """Wrapper returned by reset() and step() to match the OpenEnv pattern.

    Declared as a dataclass so it can be built with keyword arguments, e.g.
    ``StepResult(observation=obs, reward=0.0, done=False)``; without the
    decorator the class has no ``__init__`` accepting these fields.
    """

    observation: CodeDebugObservation
    # Reward for the step that produced this result (0.0 after reset()).
    reward: float = 0.0
    # True once the episode has ended.
    done: bool = False
# ---------------------------------------------------------------------------
# Async client — matches the sample inference script pattern
# ---------------------------------------------------------------------------
class CodeSenseiEnv:
    """Async OpenEnv-compatible client for the CodeDebug environment.

    Usage (matches sample inference script):
        env = await CodeSenseiEnv.from_docker_image(IMAGE_NAME)
        result = await env.reset()
        result = await env.step(CodeDebugAction(proposed_fix="..."))
        await env.close()

    HTTP calls use ``aiohttp`` when it is installed and fall back to the
    (blocking) stdlib ``urllib`` client otherwise.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        # Session id sent with /step and /state; assigned by reset().
        self._session_id: str = ""
        # Container started by from_docker_image(), stopped by close().
        self._container_id: Optional[str] = None

    @classmethod
    async def from_docker_image(
        cls,
        image_name: Optional[str] = None,
        port: int = 7860,
    ) -> "CodeSenseiEnv":
        """Start the environment from a Docker image (OpenEnv standard).

        Args:
            image_name: Image to run. When given, a detached ``--rm``
                container is started with ``port`` published on the host,
                and this coroutine polls ``/health`` until it answers 200.
                When ``None`` (or empty), no container is started and the
                client targets ``http://localhost:{port}`` (local dev).
            port: Port used inside the container and on the host
                (default 7860).

        Returns:
            A client pointed at the running environment.

        Raises:
            subprocess.CalledProcessError: if ``docker run`` fails.
            RuntimeError: if the container never becomes healthy; the
                container is stopped before raising.
        """
        base_url = f"http://localhost:{port}"
        if not image_name:
            # No image — connect to an already-running local server.
            return cls(base_url)

        container_id = subprocess.check_output(
            ["docker", "run", "-d", "--rm", "-p", f"{port}:{port}", image_name],
            text=True,
        ).strip()
        env = cls(base_url)
        env._container_id = container_id

        if await env._wait_until_healthy(attempts=60, interval=1.0):
            return env
        # Don't leave an unusable container running after we give up on it.
        await env.close()
        raise RuntimeError(f"Docker container {image_name} did not become healthy in 60s")

    async def _wait_until_healthy(self, attempts: int = 60, interval: float = 1.0) -> bool:
        """Poll ``GET /health`` up to ``attempts`` times; return True on a 200."""
        url = f"{self.base_url}/health"
        for _ in range(attempts):
            try:
                if await self._health_ok(url):
                    return True
            except Exception:
                # Expected while the server boots (connection refused/reset,
                # timeouts); keep polling.
                pass
            # Sleep after every failed probe — not only after exceptions — so a
            # non-200 answer cannot exhaust all attempts instantly.
            await asyncio.sleep(interval)
        return False

    @staticmethod
    async def _health_ok(url: str) -> bool:
        """Return True when ``url`` answers with HTTP 200 within 2 seconds."""
        if aiohttp:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=2)) as resp:
                    return resp.status == 200
        import urllib.request
        # Context manager so the response/socket is closed after each probe.
        with urllib.request.urlopen(url, timeout=2) as resp:
            return resp.status == 200

    async def reset(self) -> StepResult:
        """Start a new debugging episode under a fresh session id.

        The server may return its own ``session_id``; if so it replaces the
        locally generated one for subsequent step()/state() calls.
        """
        self._session_id = str(uuid.uuid4())
        data = await self._post("/reset", {"session_id": self._session_id})
        self._session_id = data.get("session_id", self._session_id)
        obs = self._parse_observation(data)
        return StepResult(observation=obs, reward=0.0, done=False)

    async def step(self, action: CodeDebugAction) -> StepResult:
        """Submit ``action.proposed_fix`` and return the resulting observation.

        ``reward`` and ``done`` on the result mirror the observation's fields.
        """
        data = await self._post("/step", {
            "proposed_fix": action.proposed_fix,
            "session_id": self._session_id,
        })
        obs = self._parse_observation(data)
        return StepResult(observation=obs, reward=obs.reward, done=obs.done)

    async def state(self) -> dict:
        """Return the server's raw JSON state for the current session."""
        return await self._get(f"/state?session_id={self._session_id}")

    async def close(self) -> None:
        """Stop the Docker container started by from_docker_image(), if any.

        Best-effort and idempotent: failures of ``docker stop`` (docker not
        installed, timeout) are ignored, and repeated calls are no-ops.
        """
        container_id, self._container_id = self._container_id, None
        if not container_id:
            return
        try:
            subprocess.run(
                ["docker", "stop", container_id],
                capture_output=True, timeout=30,
            )
        except (OSError, subprocess.SubprocessError):
            # Nothing more we can do; the container was started with --rm.
            pass

    # --- HTTP helpers ---
    async def _post(self, path: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``base_url + path``; return the JSON body.

        Raises on non-2xx responses with either transport (aiohttp's
        ``ClientResponseError`` or urllib's ``HTTPError``).
        """
        url = f"{self.base_url}{path}"
        if aiohttp:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=30)) as resp:
                    # Match urllib, which raises on 4xx/5xx, instead of
                    # silently decoding an error body as an observation.
                    resp.raise_for_status()
                    return await resp.json()
        import urllib.request
        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read().decode())

    async def _get(self, path: str) -> dict:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises on non-2xx responses with either transport.
        """
        url = f"{self.base_url}{path}"
        if aiohttp:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
                    resp.raise_for_status()
                    return await resp.json()
        import urllib.request
        with urllib.request.urlopen(url, timeout=30) as resp:
            return json.loads(resp.read().decode())

    # --- Parse ---
    @staticmethod
    def _parse_observation(data: dict) -> CodeDebugObservation:
        """Build a CodeDebugObservation from a server JSON payload.

        Missing keys fall back to empty/zero defaults (``max_attempts`` to 6).
        Static so the sync CodeDebugEnv client can reuse it.
        """
        test_results = [
            TestResult(
                test_name=tr.get("test_name", ""),
                passed=tr.get("passed", False),
                error_message=tr.get("error_message", ""),
            )
            for tr in data.get("test_results", [])
        ]
        return CodeDebugObservation(
            buggy_code=data.get("buggy_code", ""),
            current_code=data.get("current_code", ""),
            error_output=data.get("error_output", ""),
            test_results=test_results,
            tests_passed=data.get("tests_passed", 0),
            tests_total=data.get("tests_total", 0),
            reward=data.get("reward", 0.0),
            done=data.get("done", False),
            attempt=data.get("attempt", 0),
            max_attempts=data.get("max_attempts", 6),
            feedback=data.get("feedback", ""),
        )
# ---------------------------------------------------------------------------
# Sync client (kept for backward compat / training)
# ---------------------------------------------------------------------------
class CodeDebugEnv:
    """Synchronous WebSocket client for the CodeDebug OpenEnv.

    Usage:
        env = CodeDebugEnv(base_url="wss://your-space.hf.space")
        obs = env.reset()
        obs = env.step("def add(a, b):\\n    return a + b")
        env.close()

    Also a context manager: ``with CodeDebugEnv(url) as env: ...`` connects
    on entry and closes on exit.
    """

    def __init__(self, base_url: str):
        # http(s):// URLs are rewritten to ws(s)://; the socket opens lazily.
        self.base_url = self._to_ws_url(base_url)
        self._ws = None
        self._session_id: str = ""

    def connect(self):
        """Open the WebSocket at ``{base_url}/ws`` and return ``self``.

        Any socket that is already open is closed first, so reconnecting
        does not leak the previous connection.

        Raises:
            ImportError: if the ``websockets`` package is not installed.
        """
        if ws_sync is None:
            raise ImportError("websockets is required. pip install websockets")
        self.close()
        self._ws = ws_sync.connect(f"{self.base_url}/ws")
        return self

    def close(self):
        """Close the socket if open; safe to call more than once."""
        # Detach first so the client is marked disconnected even if close() raises.
        ws, self._ws = self._ws, None
        if ws:
            ws.close()

    def _exchange(self, msg: dict) -> dict:
        """Send one JSON message on the open socket and decode the JSON reply."""
        self._ws.send(json.dumps(msg))
        return json.loads(self._ws.recv())

    def reset(self, session_id: str = "") -> CodeDebugObservation:
        """Start a new episode, connecting first if necessary.

        ``session_id`` is forwarded only when non-empty; the id echoed back
        by the server (or ``session_id`` if none is echoed) is remembered.
        """
        if self._ws is None:
            self.connect()
        msg = {"type": "reset"}
        if session_id:
            msg["session_id"] = session_id
        data = self._exchange(msg)
        self._session_id = data.get("session_id", session_id)
        return CodeSenseiEnv._parse_observation(data)

    def step(self, proposed_fix: str) -> CodeDebugObservation:
        """Submit ``proposed_fix`` for the current episode.

        Raises:
            RuntimeError: if not connected.
        """
        if self._ws is None:
            raise RuntimeError("Not connected. Call connect() or reset() first.")
        data = self._exchange({"type": "step", "proposed_fix": proposed_fix})
        return CodeSenseiEnv._parse_observation(data)

    def state(self) -> dict:
        """Return the server's raw JSON state reply.

        Raises:
            RuntimeError: if not connected.
        """
        if self._ws is None:
            raise RuntimeError("Not connected.")
        return self._exchange({"type": "state"})

    def session_id(self) -> str:
        """Return the session id recorded by the last reset() ("" before any)."""
        return self._session_id

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, *args):
        self.close()

    @staticmethod
    def _to_ws_url(url: str) -> str:
        """Normalise ``url`` to a ws:// or wss:// URL without a trailing slash.

        https -> wss, http -> ws, ws/wss unchanged, and a bare host[:port]
        is treated as plain ws://.
        """
        url = url.rstrip("/")
        if url.startswith(("wss://", "ws://")):
            return url
        if url.startswith("https://"):
            return "wss://" + url[len("https://"):]
        if url.startswith("http://"):
            return "ws://" + url[len("http://"):]
        return "ws://" + url