Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import httpx | |
| from openenv.core.env_client import EnvClient, StepResult | |
| from models import DataAnalysisAction, DataAnalysisObservation, DataAnalysisState | |
class DataAnalysisEnv(EnvClient):
    """Async HTTP client for a data-analysis environment server.

    Sends requests to the server's ``/reset``, ``/step`` and ``/state``
    endpoints through a lazily created ``httpx.AsyncClient`` and converts the
    JSON responses into ``StepResult`` / ``DataAnalysisState`` objects.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Store the normalised server URL; no connection is opened yet.

        Args:
            base_url: Server address. Trailing slashes are stripped,
                ``ws://`` / ``wss://`` are mapped to ``http://`` / ``https://``,
                and ``http://`` is prepended when no scheme is present.
        """
        # NOTE(review): EnvClient.__init__ is not invoked here (same as the
        # original code) -- confirm the base class needs no initialisation.
        self._base_url = self._normalize_base_url(base_url)
        self._client: Optional[httpx.AsyncClient] = None

    @staticmethod
    def _normalize_base_url(base_url: str) -> str:
        """Return ``base_url`` as an HTTP(S) URL without trailing slashes."""
        url = base_url.rstrip("/")
        # Only the scheme prefix is rewritten; the rest of the URL is kept
        # verbatim even if it happens to contain "ws://" further along.
        for ws_scheme, http_scheme in (("wss://", "https://"), ("ws://", "http://")):
            if url.startswith(ws_scheme):
                return http_scheme + url[len(ws_scheme):]
        if url.startswith(("http://", "https://")):
            return url
        return "http://" + url

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(base_url=self._base_url, timeout=60.0)
        return self._client

    async def reset(self, task: str = "task_1", **kwargs) -> StepResult:
        """POST ``/reset`` for ``task`` and return the parsed result.

        Extra keyword arguments are accepted but not sent to the server.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        response = await self._get_client().post("/reset", json={"task": task})
        response.raise_for_status()
        return self._parse_result(response.json())

    async def step(self, action: DataAnalysisAction) -> StepResult:
        """POST ``action`` to ``/step`` and return the parsed result.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        payload = {
            "action": {
                "tool": action.tool,
                "parameters": action.parameters,
            }
        }
        response = await self._get_client().post("/step", json=payload)
        response.raise_for_status()
        return self._parse_result(response.json())

    async def state(self) -> DataAnalysisState:
        """GET ``/state`` and return it as a ``DataAnalysisState``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        response = await self._get_client().get("/state")
        response.raise_for_status()
        return self._parse_state(response.json())

    async def close(self):
        """Close the underlying HTTP client if one was created."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    @staticmethod
    def _parse_result(payload: dict) -> StepResult:
        """Build a ``StepResult`` from a ``/reset`` or ``/step`` response body.

        A missing ``reward`` defaults to ``0.0`` and a missing ``done`` to
        ``False``; a missing or null ``observation`` is treated as empty.
        """
        # Declared static so that ``self._parse_result(data)`` passes exactly
        # one argument, matching the single-parameter signature.
        observation = DataAnalysisObservation(**(payload.get("observation") or {}))
        return StepResult(
            observation=observation,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    @staticmethod
    def _parse_state(payload: dict) -> DataAnalysisState:
        """Build a ``DataAnalysisState`` from a ``/state`` response body."""
        return DataAnalysisState(**payload)