Spaces:
Running
Running
File size: 4,474 Bytes
ff9fcbd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | """
HTTP client for the Code Review Environment.

Usage:
    from client import CodeReviewEnv, ReviewAction

    with CodeReviewEnv(base_url="http://localhost:7860").sync() as env:
        result = env.reset(task_id="bug-detection")
        obs = result.observation
        print(obs.task_description)
        print(obs.code_files)

        # Flag an issue
        result = env.step(ReviewAction(
            action_type="flag_issue",
            line_number=6,
            filename="utils.py",
            issue_type="bug",
            severity="high",
            description="Off-by-one in range()"
        ))
        print(result.observation.feedback)

        # Submit
        result = env.step(ReviewAction(action_type="submit_review"))
        print(f"Score: {result.reward:.3f}")
"""
from __future__ import annotations
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from typing import Optional, Generic, TypeVar
from models import ReviewAction, ReviewObservation, ReviewState, Issue
ObsT = TypeVar("ObsT")


class StepResult(Generic[ObsT]):
    """Container for what a ``reset``/``step`` call hands back.

    Attributes:
        observation: The observation payload (generic over ``ObsT``).
        reward: Reward value, or ``None`` when none was supplied.
        done: Episode-finished flag; ``False`` unless given.
    """

    def __init__(
        self,
        observation: ObsT,
        reward: Optional[float] = None,
        done: bool = False,
    ):
        self.observation, self.reward, self.done = observation, reward, done

    def __repr__(self) -> str:
        # The observation may not expose ``current_score``; show None instead
        # of raising AttributeError.
        score = getattr(self.observation, "current_score", None)
        return (
            f"StepResult(done={self.done}, reward={self.reward}, "
            f"score={score})"
        )
try:
import httpx
_HAS_HTTPX = True
except ImportError:
_HAS_HTTPX = False
try:
from openenv.core.http_env_client import HTTPEnvClient as _OfficialClient
_HAS_OPENENV_CLIENT = True
except ImportError:
_HAS_OPENENV_CLIENT = False
class SyncCodeReviewEnv:
    """Blocking HTTP client for the Code Review Environment server.

    Wraps an ``httpx.Client`` and exposes the server's ``/reset``, ``/step``,
    ``/state``, ``/health`` and ``/tasks`` endpoints. Usable as a context
    manager; the underlying ``httpx.Client`` is closed on exit.

    Args:
        base_url: Root URL of the server. Trailing ``/`` characters are
            stripped so endpoint paths can be appended directly.
        timeout: Request timeout in seconds handed to ``httpx.Client``.
            Defaults to 30.0, the value that used to be hard-coded.

    Raises:
        ImportError: If ``httpx`` is not installed.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:7860",
        timeout: float = 30.0,
    ):
        self.base_url = base_url.rstrip("/")
        if not _HAS_HTTPX:
            raise ImportError("httpx is required: pip install httpx")
        # The module-level ``import httpx`` succeeded whenever _HAS_HTTPX is
        # True, so the global name is bound here; no re-import needed.
        self._client = httpx.Client(timeout=timeout)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    # -- private helpers --------------------------------------------------

    def _get(self, path: str):
        """GET ``base_url + path``; raise on HTTP error status, return JSON."""
        resp = self._client.get(f"{self.base_url}{path}")
        resp.raise_for_status()
        return resp.json()

    def _post(self, path: str, body: dict):
        """POST ``body`` as JSON to ``base_url + path``; return decoded JSON."""
        resp = self._client.post(f"{self.base_url}{path}", json=body)
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _to_step_result(data: dict) -> StepResult[ReviewObservation]:
        """Parse an observation payload and copy its reward/done onto a StepResult."""
        obs = ReviewObservation.from_dict(data)
        return StepResult(observation=obs, reward=obs.reward, done=obs.done)

    # -- public API ---------------------------------------------------------

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
    ) -> StepResult[ReviewObservation]:
        """Start a new episode via ``POST /reset``.

        Args:
            task_id: Task to load; left out of the request body when falsy.
            seed: Seed value; sent whenever it is not ``None`` (so 0 is sent).
            episode_id: Episode identifier; left out of the body when falsy.

        Returns:
            A ``StepResult`` wrapping the observation returned by the server.

        Raises:
            httpx.HTTPStatusError: From ``raise_for_status()`` on an error
                status code.
        """
        body = {}
        # task_id / episode_id use truthiness (empty strings are dropped),
        # whereas seed uses an explicit None check so seed=0 is still sent.
        if task_id:
            body["task_id"] = task_id
        if seed is not None:
            body["seed"] = seed
        if episode_id:
            body["episode_id"] = episode_id
        return self._to_step_result(self._post("/reset", body))

    def step(self, action: ReviewAction) -> StepResult[ReviewObservation]:
        """Send ``action`` (serialized via ``to_dict()``) to ``POST /step``.

        Raises:
            httpx.HTTPStatusError: From ``raise_for_status()`` on an error
                status code.
        """
        return self._to_step_result(self._post("/step", action.to_dict()))

    def state(self) -> ReviewState:
        """Fetch the current episode state via ``GET /state``.

        Fields missing from the response fall back to empty/zero/False
        defaults; ``episode_id`` falls back to ``None``.
        """
        data = self._get("/state")
        return ReviewState(
            task_id=data.get("task_id", ""),
            difficulty=data.get("difficulty", ""),
            episode_id=data.get("episode_id"),
            step_count=data.get("step_count", 0),
            flagged_issues=[Issue.from_dict(i) for i in data.get("flagged_issues", [])],
            current_score=data.get("current_score", 0.0),
            submitted=data.get("submitted", False),
        )

    def health(self) -> dict:
        """Return the decoded JSON body of ``GET /health``."""
        return self._get("/health")

    def list_tasks(self) -> dict:
        """Return the decoded JSON body of ``GET /tasks``."""
        return self._get("/tasks")
class CodeReviewEnv:
    """Lightweight factory for Code Review Environment clients.

    Call :meth:`sync` to get a :class:`SyncCodeReviewEnv`, or use this object
    in a ``with`` statement. Entering builds a fresh sync client and returns
    it; exiting closes that client.
    """

    def __init__(self, base_url: str = "http://localhost:7860"):
        # Stored exactly as given; no normalisation happens at this level.
        self.base_url = base_url

    def sync(self) -> SyncCodeReviewEnv:
        """Build and return a new ``SyncCodeReviewEnv`` for ``self.base_url``."""
        return SyncCodeReviewEnv(self.base_url)

    def __enter__(self):
        client = self.sync()
        self._sync = client
        return client

    def __exit__(self, *args):
        try:
            client = self._sync
        except AttributeError:
            # __enter__ never stored a client, so there is nothing to close.
            return
        client.close()
|