codesensei-env / env /client.py
vineetshukla.work@gmail.com
fix: rewrite inference.py & client.py to match OpenEnv sample pattern
6cbd2ef
"""
CodeSensei — OpenEnv-Compatible Client.
Provides both sync and async clients that connect to the CodeDebug
environment server. Supports the standard OpenEnv patterns:
- await CodeSenseiEnv.from_docker_image(image_name)
- await env.reset() -> StepResult
- await env.step(action) -> StepResult
- await env.close()
"""
from __future__ import annotations
import json
import asyncio
import subprocess
import time
import uuid
from dataclasses import dataclass
from typing import Optional, List
try:
import websockets
import websockets.sync.client as ws_sync
except ImportError:
websockets = None # type: ignore
ws_sync = None # type: ignore
try:
import aiohttp
except ImportError:
aiohttp = None # type: ignore
from env.models import CodeDebugAction, CodeDebugObservation, TestResult
# ---------------------------------------------------------------------------
# Result wrapper — matches OpenEnv pattern: result.observation, result.reward, result.done
# ---------------------------------------------------------------------------
@dataclass()
class StepResult:
    """Outcome of a ``reset()`` or ``step()`` call, shaped like OpenEnv's result.

    Callers read ``result.observation``, ``result.reward`` and ``result.done``.
    """

    # Observation returned by the environment for this call.
    observation: CodeDebugObservation
    # Scalar reward for this call (0.0 unless the caller supplies one).
    reward: float = 0.0
    # Episode-finished flag as reported for this call.
    done: bool = False
# ---------------------------------------------------------------------------
# Async client — matches the sample inference script pattern
# ---------------------------------------------------------------------------
class CodeSenseiEnv:
    """Async OpenEnv-compatible client for the CodeDebug environment.

    Talks to the environment server over HTTP (``/health``, ``/reset``,
    ``/step``, ``/state``). ``aiohttp`` is used when installed; otherwise the
    client falls back to ``urllib``, run in a worker thread so the event loop
    is not blocked.

    Usage (matches sample inference script):
        env = await CodeSenseiEnv.from_docker_image(IMAGE_NAME)
        result = await env.reset()
        result = await env.step(CodeDebugAction(proposed_fix="..."))
        await env.close()
    """

    # Port the server is expected to listen on (inside the container, and
    # for a locally started dev server).
    DEFAULT_PORT = 7860

    def __init__(self, base_url: str):
        self.base_url = base_url.rstrip("/")
        self._session_id: str = ""
        # Set only when this client started a container itself; close() stops it.
        self._container_id: Optional[str] = None

    @classmethod
    async def from_docker_image(
        cls,
        image_name: Optional[str] = None,
        port: int = DEFAULT_PORT,
        health_timeout: float = 60.0,
    ) -> "CodeSenseiEnv":
        """Start the environment from a Docker image (OpenEnv standard).

        Args:
            image_name: Image to run. If given, a detached ``--rm`` container
                is started with host ``port`` mapped to the server port, and
                this waits until ``GET /health`` answers 200. If ``None`` (or
                empty), no container is started and the client simply targets
                ``http://localhost:{port}`` (local dev).
            port: Host port to connect to (default 7860).
            health_timeout: Seconds to wait for the container to become healthy.

        Returns:
            A client bound to the server. If a container was started,
            ``close()`` stops it.

        Raises:
            subprocess.CalledProcessError: if ``docker run`` fails.
            RuntimeError: if the container is not healthy within
                ``health_timeout`` seconds (the container is stopped first).
        """
        base_url = f"http://localhost:{port}"
        if not image_name:
            # No image: connect to an already-running local server.
            return cls(base_url)

        loop = asyncio.get_running_loop()
        container_id = (
            await loop.run_in_executor(
                None,
                lambda: subprocess.check_output(
                    [
                        "docker", "run", "-d", "--rm",
                        "-p", f"{port}:{cls.DEFAULT_PORT}",
                        image_name,
                    ],
                    text=True,
                ),
            )
        ).strip()

        env = cls(base_url)
        env._container_id = container_id
        if await env._wait_until_healthy(health_timeout):
            return env

        # Giving up: stop the container rather than leaving it running.
        await env.close()
        raise RuntimeError(
            f"Docker container {image_name} did not become healthy in {health_timeout:g}s"
        )

    async def _wait_until_healthy(self, timeout: float) -> bool:
        """Probe ``/health`` about once per second until it succeeds.

        Returns True as soon as a probe answers 200, or False once ``timeout``
        seconds of wall-clock time have passed without success.
        """
        deadline = time.monotonic() + timeout
        while True:
            if await self._health_ok():
                return True
            if time.monotonic() >= deadline:
                return False
            # Sleep on every failed probe (including non-200 answers) so the
            # loop never spins.
            await asyncio.sleep(1)

    async def _health_ok(self) -> bool:
        """Return True if a single ``GET /health`` probe answers HTTP 200."""
        url = f"{self.base_url}/health"
        try:
            if aiohttp:
                async with aiohttp.ClientSession() as session:
                    async with session.get(
                        url, timeout=aiohttp.ClientTimeout(total=2)
                    ) as resp:
                        return resp.status == 200

            import urllib.request

            def probe() -> bool:
                # ``with`` closes the response instead of leaking it.
                with urllib.request.urlopen(url, timeout=2) as resp:
                    return resp.status == 200

            return await asyncio.get_running_loop().run_in_executor(None, probe)
        except Exception:
            # Refused/reset connections and timeouts are expected while the
            # server is still starting: report "not healthy yet".
            return False

    async def reset(self) -> StepResult:
        """Start a new debugging episode.

        A fresh UUID is proposed as the session id; if the server replies with
        its own ``session_id``, that one is adopted for later calls.
        """
        self._session_id = str(uuid.uuid4())
        data = await self._post("/reset", {"session_id": self._session_id})
        self._session_id = data.get("session_id", self._session_id)
        obs = self._parse_observation(data)
        return StepResult(observation=obs, reward=0.0, done=False)

    async def step(self, action: CodeDebugAction) -> StepResult:
        """Submit a proposed code fix and return the server's verdict."""
        data = await self._post("/step", {
            "proposed_fix": action.proposed_fix,
            "session_id": self._session_id,
        })
        obs = self._parse_observation(data)
        return StepResult(observation=obs, reward=obs.reward, done=obs.done)

    async def state(self) -> dict:
        """Get current episode state as the server's raw JSON object."""
        import urllib.parse

        # Encode the id so a server-assigned value cannot break the query string.
        query = urllib.parse.urlencode({"session_id": self._session_id})
        return await self._get(f"/state?{query}")

    async def close(self) -> None:
        """Stop the Docker container if this client started one.

        Best effort and safe to call more than once: a missing ``docker``
        binary or a timed-out ``docker stop`` is ignored.
        """
        container_id, self._container_id = self._container_id, None
        if not container_id:
            return
        try:
            await asyncio.get_running_loop().run_in_executor(
                None,
                lambda: subprocess.run(
                    ["docker", "stop", container_id],
                    capture_output=True, timeout=30,
                ),
            )
        except (OSError, subprocess.SubprocessError):
            # Cleanup must not mask the caller's own errors.
            pass

    # --- HTTP helpers ---
    async def _post(self, path: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``path`` and return the decoded reply.

        Raises:
            aiohttp.ClientResponseError / urllib.error.HTTPError: on a non-2xx
                status (both backends now behave the same way).
        """
        url = f"{self.base_url}{path}"
        if aiohttp:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url, json=payload, timeout=aiohttp.ClientTimeout(total=30)
                ) as resp:
                    resp.raise_for_status()
                    return await resp.json()

        import urllib.request

        req = urllib.request.Request(
            url,
            data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"},
        )
        return await self._urllib_json(req)

    async def _get(self, path: str) -> dict:
        """GET ``path`` and return the decoded JSON reply (raises on non-2xx)."""
        url = f"{self.base_url}{path}"
        if aiohttp:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url, timeout=aiohttp.ClientTimeout(total=30)
                ) as resp:
                    resp.raise_for_status()
                    return await resp.json()
        return await self._urllib_json(url)

    @staticmethod
    async def _urllib_json(request) -> dict:
        """Run a blocking ``urlopen`` in a worker thread and decode its JSON body.

        ``request`` is a URL string or a ``urllib.request.Request``.
        """
        import urllib.request

        def fetch() -> dict:
            with urllib.request.urlopen(request, timeout=30) as resp:
                return json.loads(resp.read().decode())

        return await asyncio.get_running_loop().run_in_executor(None, fetch)

    # --- Parse ---
    @staticmethod
    def _parse_observation(data: dict) -> CodeDebugObservation:
        """Build a CodeDebugObservation from a server JSON payload.

        Missing keys fall back to neutral defaults (empty strings, 0, False,
        ``max_attempts`` = 6). An absent or ``null`` ``test_results`` is
        treated as an empty list.
        """
        test_results = [
            TestResult(
                test_name=tr.get("test_name", ""),
                passed=tr.get("passed", False),
                error_message=tr.get("error_message", ""),
            )
            for tr in data.get("test_results") or []
        ]
        return CodeDebugObservation(
            buggy_code=data.get("buggy_code", ""),
            current_code=data.get("current_code", ""),
            error_output=data.get("error_output", ""),
            test_results=test_results,
            tests_passed=data.get("tests_passed", 0),
            tests_total=data.get("tests_total", 0),
            reward=data.get("reward", 0.0),
            done=data.get("done", False),
            attempt=data.get("attempt", 0),
            max_attempts=data.get("max_attempts", 6),
            feedback=data.get("feedback", ""),
        )
# ---------------------------------------------------------------------------
# Sync client (kept for backward compat / training)
# ---------------------------------------------------------------------------
class CodeDebugEnv:
    """Synchronous WebSocket client for the CodeDebug OpenEnv.

    Every call sends one JSON message over a single connection to
    ``{base_url}/ws`` and reads back one JSON reply. ``http(s)://`` base URLs
    are rewritten to ``ws(s)://`` automatically.

    Usage:
        env = CodeDebugEnv(base_url="wss://your-space.hf.space")
        obs = env.reset()
        obs = env.step("def add(a, b):\\n    return a + b")
        env.close()

    or as a context manager::

        with CodeDebugEnv("https://your-space.hf.space") as env:
            obs = env.reset()
    """

    # (prefix, replacement) pairs used by _to_ws_url; matched case-insensitively.
    _SCHEME_REWRITES = (
        ("https://", "wss://"),
        ("http://", "ws://"),
        ("wss://", "wss://"),
        ("ws://", "ws://"),
    )

    def __init__(self, base_url: str):
        self.base_url = self._to_ws_url(base_url)
        self._ws = None
        self._session_id: str = ""

    def connect(self) -> "CodeDebugEnv":
        """Open the WebSocket connection, replacing any existing one.

        Returns ``self`` so calls can be chained.

        Raises:
            ImportError: if the ``websockets`` package is not installed.
        """
        if ws_sync is None:
            raise ImportError("websockets is required. pip install websockets")
        # Close a previous connection so a second connect() does not leak it.
        self.close()
        self._ws = ws_sync.connect(f"{self.base_url}/ws")
        return self

    def close(self) -> None:
        """Close the connection if one is open. Safe to call repeatedly."""
        # Detach first so ``_ws`` is cleared even if ``close()`` raises.
        ws, self._ws = self._ws, None
        if ws is not None:
            ws.close()

    def reset(self, session_id: str = "") -> CodeDebugObservation:
        """Start a new episode, connecting first if necessary.

        Args:
            session_id: Optional id sent to the server; omitted when empty.
                The id in the server's reply (if any) is remembered.
        """
        if self._ws is None:
            self.connect()
        msg = {"type": "reset"}
        if session_id:
            msg["session_id"] = session_id
        data = self._exchange(msg)
        self._session_id = data.get("session_id", session_id)
        return CodeSenseiEnv._parse_observation(data)

    def step(self, proposed_fix: str) -> CodeDebugObservation:
        """Submit a proposed fix.

        Raises:
            RuntimeError: if no connection is open.
        """
        if self._ws is None:
            raise RuntimeError("Not connected. Call connect() or reset() first.")
        data = self._exchange({"type": "step", "proposed_fix": proposed_fix})
        return CodeSenseiEnv._parse_observation(data)

    def state(self) -> dict:
        """Return the server's raw episode-state JSON.

        Raises:
            RuntimeError: if no connection is open.
        """
        if self._ws is None:
            raise RuntimeError("Not connected.")
        return self._exchange({"type": "state"})

    def _exchange(self, msg: dict) -> dict:
        """Send one JSON message and return the decoded JSON reply."""
        self._ws.send(json.dumps(msg))
        return json.loads(self._ws.recv())

    @property
    def session_id(self) -> str:
        """Session id from the most recent ``reset()`` ("" before the first)."""
        return self._session_id

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, *args):
        self.close()

    @staticmethod
    def _to_ws_url(url: str) -> str:
        """Normalise ``url`` to a ``ws://``/``wss://`` URL without a trailing slash.

        ``https`` becomes ``wss``, ``http`` becomes ``ws``, and WebSocket URLs
        are kept. Scheme matching ignores case. A URL with no recognised scheme
        gets ``ws://`` prepended.
        """
        url = url.rstrip("/")
        lowered = url.lower()
        for prefix, replacement in CodeDebugEnv._SCHEME_REWRITES:
            if lowered.startswith(prefix):
                return replacement + url[len(prefix):]
        return "ws://" + url