Spaces:
Sleeping
Sleeping
File size: 5,731 Bytes
9b47159 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 | """
HTTP clients for the Bug Triage OpenEnv environment.
Provides synchronous (BugTriageEnvClient) and asynchronous
(AsyncBugTriageEnvClient) clients for interacting with the server.
Sync usage:

    from bug_triage_env.client import BugTriageEnvClient
    from bug_triage_env.models import BugTriageAction

    with BugTriageEnvClient("http://localhost:8000") as env:
        obs = env.reset(task_id="task_1")
        action = BugTriageAction(task_id="task_1", bug_type="crash")
        result = env.step(obs["episode_id"], action)
        print(result["grader_score"])

Async usage:

    from bug_triage_env.client import AsyncBugTriageEnvClient

    async with AsyncBugTriageEnvClient("http://localhost:8000") as env:
        obs = await env.reset("task_3")
"""
from __future__ import annotations
import os
from typing import Any, Dict, Optional
import requests
from .models import BugTriageAction
# Server root used when a client is constructed without an explicit URL.
# Overridable through the BUG_TRIAGE_ENV_URL environment variable.
DEFAULT_URL = os.environ.get("BUG_TRIAGE_ENV_URL", "http://localhost:8000")
class BugTriageEnvClient:
    """
    Synchronous HTTP client for the Bug Triage environment.

    Follows the OpenEnv HTTPEnvClient interface: reset(), step(), state().
    Use it as a context manager, or call close() explicitly, so the
    underlying ``requests.Session`` is released.
    """

    def __init__(self, base_url: str = DEFAULT_URL, timeout: int = 30) -> None:
        """
        Args:
            base_url: Server root URL; a trailing slash is stripped.
            timeout: Per-request timeout in seconds, passed to ``requests``.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()
        # Episode id captured by the most recent reset(); step() falls back to it.
        self._episode_id: Optional[str] = None

    def _get(self, path: str) -> Dict[str, Any]:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        resp = self._session.get(f"{self.base_url}{path}", timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path`` and return the JSON body.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        resp = self._session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()

    def reset(self, task_id: str = "task_1") -> Dict[str, Any]:
        """Start a new episode. Returns observation dict with episode_id.

        The returned ``episode_id`` is remembered so later step() calls may
        omit it.
        """
        data = self._post("/reset", {"task_id": task_id})
        self._episode_id = data.get("episode_id")
        return data

    def step(
        self,
        episode_id: Optional[str] = None,
        action: Optional[BugTriageAction] = None,
    ) -> Dict[str, Any]:
        """Execute one action. Returns observation with grader_score.

        Args:
            episode_id: Episode to act in; defaults to the id from the last reset().
            action: The triage action, serialized with ``model_dump()``.

        Raises:
            ValueError: if no episode id is available or ``action`` is missing.
            requests.HTTPError: if the server rejects the request.
        """
        ep_id = episode_id or self._episode_id
        if ep_id is None:
            raise ValueError("episode_id required. Call reset() first.")
        if action is None:
            raise ValueError("action required.")
        return self._post(
            "/step", {"episode_id": ep_id, "action": action.model_dump()},
        )

    def state(self) -> Dict[str, Any]:
        """Return current episode state."""
        return self._get("/state")

    def health(self) -> bool:
        """Check if the server is healthy.

        Returns False (rather than raising) when the server is unreachable,
        times out, or returns a body that is not JSON of the expected shape.
        """
        try:
            resp = self._session.get(
                f"{self.base_url}/health", timeout=self.timeout,
            )
            data = resp.json()
        except (requests.RequestException, ValueError):
            # Connection/timeout errors and undecodable JSON mean "not healthy";
            # anything else is a genuine bug and is allowed to propagate.
            return False
        return isinstance(data, dict) and data.get("status") == "healthy"

    def tasks(self) -> Dict[str, Any]:
        """Get task list and action schemas."""
        return self._get("/tasks")

    def grade(self, episode_id: str, task_id: str) -> Dict[str, Any]:
        """Get grader score for a completed episode."""
        return self._post("/grader", {"episode_id": episode_id, "task_id": task_id})

    def close(self) -> None:
        """Release the underlying HTTP session (idempotent in ``requests``)."""
        self._session.close()

    def __enter__(self) -> BugTriageEnvClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()
class AsyncBugTriageEnvClient:
    """
    Async HTTP client using aiohttp.

    Designed for use with GRPO training and vLLM colocate mode.
    Must be entered with ``async with``: the aiohttp session is created in
    __aenter__ and closed in __aexit__ (or by an explicit ``await close()``).
    Non-2xx responses raise, matching the synchronous client.
    """

    def __init__(self, base_url: str = DEFAULT_URL, timeout: int = 30) -> None:
        """
        Args:
            base_url: Server root URL; a trailing slash is stripped.
            timeout: Total per-request timeout in seconds (aiohttp ClientTimeout).
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Episode id captured by the most recent reset(); step() falls back to it.
        self._episode_id: Optional[str] = None
        # aiohttp.ClientSession while entered; typed Any because aiohttp is
        # imported lazily so the sync client works without it installed.
        self._session: Any = None

    async def __aenter__(self) -> AsyncBugTriageEnvClient:
        import aiohttp
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self._session = aiohttp.ClientSession(timeout=timeout)
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the aiohttp session if one is open; safe to call repeatedly."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    def _require_session(self) -> Any:
        """Return the open session or fail with an actionable error."""
        if self._session is None:
            raise RuntimeError(
                "Session is not open. Use "
                "'async with AsyncBugTriageEnvClient(...) as env:'."
            )
        return self._session

    async def _get(self, path: str) -> Dict[str, Any]:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            aiohttp.ClientResponseError: on a 4xx/5xx response.
        """
        session = self._require_session()
        async with session.get(f"{self.base_url}{path}") as resp:
            resp.raise_for_status()
            return await resp.json()

    async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path`` and return the JSON body.

        Raises:
            aiohttp.ClientResponseError: on a 4xx/5xx response.
        """
        session = self._require_session()
        async with session.post(f"{self.base_url}{path}", json=payload) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def reset(self, task_id: str = "task_1") -> Dict[str, Any]:
        """Start a new episode and remember its episode_id for step()."""
        data = await self._post("/reset", {"task_id": task_id})
        self._episode_id = data.get("episode_id")
        return data

    async def step(
        self,
        episode_id: Optional[str] = None,
        action: Optional[BugTriageAction] = None,
    ) -> Dict[str, Any]:
        """Execute one triage action.

        Args:
            episode_id: Episode to act in; defaults to the id from the last reset().
            action: The triage action, serialized with ``model_dump()``.

        Raises:
            ValueError: if no episode id is available or ``action`` is missing.
        """
        ep_id = episode_id or self._episode_id
        if ep_id is None:
            raise ValueError("Call reset() first.")
        if action is None:
            raise ValueError("action required.")
        return await self._post(
            "/step", {"episode_id": ep_id, "action": action.model_dump()},
        )

    async def state(self) -> Dict[str, Any]:
        """Return current episode state."""
        return await self._get("/state")

    async def tasks(self) -> Dict[str, Any]:
        """Get task list and action schemas."""
        return await self._get("/tasks")

    async def grade(self, episode_id: str, task_id: str) -> Dict[str, Any]:
        """Get grader score for a completed episode."""
        return await self._post(
            "/grader", {"episode_id": episode_id, "task_id": task_id},
        )

    async def health(self) -> bool:
        """Check if the server is healthy; False on transport or JSON errors."""
        import asyncio
        import aiohttp
        session = self._require_session()
        try:
            async with session.get(f"{self.base_url}/health") as resp:
                data = await resp.json()
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError):
            return False
        return isinstance(data, dict) and data.get("status") == "healthy"
|