dataops-env / client.py
visheshrathi's picture
Upload folder using huggingface_hub
f89b1ac verified
"""Typed clients for the DataOpsEnv environment."""
from typing import Optional
from urllib.parse import quote

import requests
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from models import DataOpsAction, DataOpsObservation, DataOpsState
class DataOpsEnv(EnvClient[DataOpsAction, DataOpsObservation, DataOpsState]):
    """Native OpenEnv WebSocket client for persistent sessions.

    Implements the ``EnvClient`` serialization hooks, translating between the
    typed DataOps models and the JSON payloads exchanged with the server.
    """

    def _step_payload(self, action: DataOpsAction) -> dict:
        """Serialize *action* into the dict sent to the server on ``step``."""
        return action.model_dump()

    def _parse_result(self, payload: dict) -> StepResult[DataOpsObservation]:
        """Build a ``StepResult`` from a server result payload.

        ``reward`` and ``done`` are read from the top level of *payload*
        (``done`` defaults to ``False``) and copied onto both the observation
        and the returned ``StepResult``.
        """
        reward = payload.get("reward")
        done = payload.get("done", False)
        # Merge into a single dict rather than passing reward/done as extra
        # keywords next to ``**observation``: if the server also echoes those
        # fields inside "observation", the keyword form raises
        # "got multiple values for keyword argument". Top-level values win.
        # ``or {}`` also tolerates an explicit ``"observation": null``.
        observation_fields = dict(payload.get("observation") or {})
        observation_fields["reward"] = reward
        observation_fields["done"] = done
        observation = DataOpsObservation(**observation_fields)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict) -> DataOpsState:
        """Build a ``DataOpsState`` from the server's state payload."""
        return DataOpsState(**payload)
class DataOpsEnvClient:
    """Compatibility HTTP client for the validator-facing REST API.

    All calls go through one ``requests.Session`` owned by the client. The
    client can be used as a context manager; the session is closed on exit.
    """

    def __init__(
        self, base_url: str = "http://127.0.0.1:7860", timeout: float = 30.0
    ) -> None:
        """Create a client.

        Args:
            base_url: Server root URL. A trailing ``/`` is stripped so endpoint
                paths can be appended directly.
            timeout: Timeout in seconds passed to every ``requests`` call.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    @staticmethod
    def _parse_observation(payload: dict) -> DataOpsObservation:
        """Build a ``DataOpsObservation`` from a ``/reset`` or ``/step`` response.

        Fields come from ``payload["observation"]``. Top-level ``reward`` and
        ``done`` override any copies inside it, but only when the server
        actually sent them.
        """
        # ``or {}`` handles both a missing key and an explicit null.
        observation_payload = dict(payload.get("observation") or {})
        for key in ("reward", "done"):
            if key in payload:
                observation_payload[key] = payload[key]
        return DataOpsObservation(**observation_payload)

    def reset(
        self, task_id: str = "task_1_easy_anomaly", seed: Optional[int] = None,
    ) -> DataOpsObservation:
        """POST ``/reset`` and return the parsed observation.

        Args:
            task_id: Sent as the ``task_id`` query parameter.
            seed: Sent in the JSON body as ``{"seed": seed}`` (``null`` when
                omitted).

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        resp = self._session.post(
            f"{self.base_url}/reset",
            params={"task_id": task_id},
            json={"seed": seed},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def step(self, action: DataOpsAction) -> DataOpsObservation:
        """POST *action* to ``/step`` and return the parsed observation.

        The action is sent as ``{"action": action.model_dump()}``.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        resp = self._session.post(
            f"{self.base_url}/step",
            json={"action": action.model_dump()},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._parse_observation(resp.json())

    def state(self) -> DataOpsState:
        """GET ``/state`` and parse the response as a ``DataOpsState``.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        resp = self._session.get(f"{self.base_url}/state", timeout=self.timeout)
        resp.raise_for_status()
        return DataOpsState(**resp.json())

    def grade(self, task_id: Optional[str] = None) -> dict:
        """GET the grader endpoint and return its decoded JSON body.

        Args:
            task_id: If truthy, requests ``/grader/<task_id>``; otherwise
                ``/grader``.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        url = f"{self.base_url}/grader"
        if task_id:
            # Percent-encode the id as a single path segment so characters
            # such as "/", "?", "#" or spaces cannot retarget the request or
            # corrupt the URL. Plain ids (letters, digits, "_", "-", ".")
            # pass through unchanged.
            url = f"{url}/{quote(task_id, safe='')}"
        resp = self._session.get(url, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> "DataOpsEnvClient":
        return self

    def __exit__(self, *args: object) -> None:
        # Close the session regardless of how the ``with`` block exited;
        # returning None means exceptions are not suppressed.
        self.close()