"""Typed clients for the DataOpsEnv environment."""

from typing import Optional
from urllib.parse import quote

import requests
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from models import DataOpsAction, DataOpsObservation, DataOpsState


class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
    """Native OpenEnv WebSocket client for persistent sessions.

    Only the (de)serialization hooks are defined here; transport and session
    handling come from ``EnvClient``. The methods below are presumably invoked
    by the base class when sending steps and receiving results/state
    (confirm against the openenv ``EnvClient`` implementation).
    """

    def _step_payload(self, action: DataOpsAction) -> dict:
        """Serialize *action* into the dict sent to the server for a step."""
        return action.model_dump()

    def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
        """Build a ``StepResult`` from a raw server response payload.

        Top-level ``reward``/``done`` keys take precedence over any copies
        nested in ``payload["observation"]``. When absent at the top level,
        the nested values are used, falling back to ``None`` / ``False``.
        """
        # Copy so the caller's payload is never mutated; also tolerate an
        # explicit ``"observation": null`` instead of crashing on ``**None``.
        observation_payload = dict(payload.get("observation") or {})
        reward = payload.get("reward", observation_payload.get("reward"))
        done = payload.get("done", observation_payload.get("done", False))
        # Merge into a single dict rather than passing reward/done as extra
        # keyword arguments. If the nested observation already contained
        # those keys, ``DataOpsObservation(**obs, reward=...)`` would raise
        # "got multiple values for keyword argument".
        observation_payload["reward"] = reward
        observation_payload["done"] = done
        observation = DataOpsObservation(**observation_payload)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict) -> DataOpsState:
        """Build a ``DataOpsState`` from a raw state payload."""
        return DataOpsState(**payload)


class DataOpsEnvClient:
    """Compatibility HTTP client for the validator-facing REST API.

    Wraps a ``requests.Session`` for connection reuse. Use it as a context
    manager (or call :meth:`close`) to release the session.

    Every request method raises ``requests.HTTPError`` on a non-2xx response
    (via ``raise_for_status``). Transport failures surface as the usual
    ``requests.RequestException`` subclasses.
    """

    def __init__(
        self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
    ) -> None:
        """Create a client.

        Args:
            base_url: Server root URL; a trailing slash is stripped.
            timeout: Per-request timeout in seconds, passed to ``requests``.
        """
        # Strip the trailing slash so f"{base_url}/reset" never yields "//reset".
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    @staticmethod
    def _parse_observation(payload: dict) -> DataOpsObservation:
        """Build an observation, copying top-level ``reward``/``done`` into it.

        Top-level values override nested ones only when they are present.
        """
        # ``or {}`` also covers an explicit ``"observation": null``.
        observation_payload = dict(payload.get("observation") or {})
        for key in ("reward", "done"):
            if key in payload:
                observation_payload[key] = payload[key]
        return DataOpsObservation(**observation_payload)

    def reset(
        self,
        task_id: str = "task_1_easy_anomaly",
        seed: Optional[int] = None,
    ) -> DataOpsObservation:
        """Start a new episode for *task_id* and return the first observation.

        ``task_id`` goes in the query string and ``seed`` in the JSON body.
        ``{"seed": null}`` is sent when no seed is given.
        """
        resp = self._session.post(
            f"{self.base_url}/reset",
            params={"task_id": task_id},
            json={"seed": seed},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def step(self, action: DataOpsAction) -> DataOpsObservation:
        """Submit *action* and return the resulting observation.

        Reward and done flags are folded into the returned observation.
        """
        resp = self._session.post(
            f"{self.base_url}/step",
            json={"action": action.model_dump()},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def state(self) -> DataOpsState:
        """Fetch the server's current environment state."""
        resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
        resp.raise_for_status()
        return DataOpsState(**resp.json())

    def grade(self, task_id: Optional[str] = None) -> dict:
        """Return the grader's JSON response.

        Queries ``/grader/<task_id>`` when *task_id* is truthy, otherwise
        ``/grader``.
        """
        url = f"{self.base_url}/grader"
        if task_id:
            # Percent-encode so IDs containing "/", "?" or "#" stay a single
            # path segment instead of addressing a different endpoint.
            url = f"{url}/{quote(task_id, safe='')}"
        resp = self._session.get(url, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> "DataOpsEnvClient":
        return self

    def __exit__(self, *args: object) -> None:
        # Always release the session; exceptions are not suppressed.
        self.close()