Spaces:
Sleeping
Sleeping
| """Typed clients for the DataOpsEnv environment.""" | |
| from typing import Optional | |
| import requests | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| from models import DataOpsAction, DataOpsObservation, DataOpsState | |
class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
    """Native OpenEnv WebSocket client for persistent sessions."""

    def _step_payload(self, action: DataOpsAction) -> dict:
        """Return *action* serialized to a plain dict via ``model_dump()``."""
        return action.model_dump()

    def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
        """Convert a response payload into a ``StepResult``.

        Args:
            payload: Mapping that may contain ``"observation"`` (a mapping of
                observation fields), ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` whose observation carries the same ``reward``
            and ``done`` values as the result itself. A missing reward
            becomes ``None`` and a missing done flag becomes ``False``.
        """
        reward = payload.get("reward")
        done = payload.get("done", False)

        # Merge into one dict rather than mixing ``**observation`` with
        # explicit ``reward=``/``done=`` keywords: if the nested observation
        # also carries either key, the keyword form raises TypeError for a
        # duplicated argument. The top-level values always win.
        fields = dict(payload.get("observation", {}))
        fields["reward"] = reward
        fields["done"] = done

        return StepResult(
            observation=DataOpsObservation(**fields),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict) -> DataOpsState:
        """Build a ``DataOpsState`` from *payload*, using its keys as fields."""
        return DataOpsState(**payload)
class DataOpsEnvClient:
    """Compatibility HTTP client for the validator-facing REST API.

    All requests go through a single ``requests.Session``. Use the instance
    as a context manager, or call :meth:`close`, to release it.
    """

    def __init__(
        self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
    ) -> None:
        """Store connection settings and open the HTTP session.

        Args:
            base_url: Server root URL; any trailing ``/`` is stripped.
            timeout: Timeout value passed to every request.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    # The helper takes no ``self`` but is invoked as ``self._parse_observation``
    # by ``reset`` and ``step``; without ``@staticmethod`` the instance would be
    # passed as an extra positional argument and every call would raise TypeError.
    @staticmethod
    def _parse_observation(payload: dict) -> DataOpsObservation:
        """Build a ``DataOpsObservation`` from a response payload.

        The fields come from ``payload["observation"]`` (empty if absent).
        Top-level ``"reward"`` and ``"done"`` values, when present, override
        any same-named keys inside the nested observation.
        """
        observation_payload = dict(payload.get("observation", {}))
        if "reward" in payload:
            observation_payload["reward"] = payload["reward"]
        if "done" in payload:
            observation_payload["done"] = payload["done"]
        return DataOpsObservation(**observation_payload)

    def reset(
        self, task_id: str = "task_1_easy_anomaly", seed: Optional[int] = None,
    ) -> DataOpsObservation:
        """POST to ``/reset`` and return the parsed observation.

        Args:
            task_id: Sent as the ``task_id`` query parameter.
            seed: Sent in the JSON body as ``{"seed": seed}`` (``None`` is
                sent as-is).

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        resp = self._session.post(
            f"{self.base_url}/reset",
            params={"task_id": task_id},
            json={"seed": seed},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def step(self, action: DataOpsAction) -> DataOpsObservation:
        """POST *action* to ``/step`` and return the parsed observation.

        The action is serialized with ``model_dump()`` and wrapped as
        ``{"action": ...}`` in the JSON body.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        resp = self._session.post(
            f"{self.base_url}/step",
            json={"action": action.model_dump()},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def state(self) -> DataOpsState:
        """GET ``/state`` and return the response as a ``DataOpsState``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
        resp.raise_for_status()
        return DataOpsState(**resp.json())

    def grade(self, task_id: Optional[str] = None) -> dict:
        """GET the grader endpoint and return its decoded JSON body.

        Args:
            task_id: When truthy, ``/grader/{task_id}`` is requested;
                otherwise (``None`` or empty string) plain ``/grader``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        url = f"{self.base_url}/grader/{task_id}" if task_id else f"{self.base_url}/grader"
        resp = self._session.get(url, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> "DataOpsEnvClient":
        """Return the client itself for use in a ``with`` block."""
        return self

    def __exit__(self, *args: object) -> None:
        """Close the session on leaving the ``with`` block; exceptions propagate."""
        self.close()