Spaces:
Sleeping
Sleeping
| """ | |
Python client for the Data Cleaning RL Environment.

Provides a lightweight async wrapper for local testing and integration
with RL training frameworks.

Usage (async):

    import asyncio

    from data_cleaning_env.client import DataCleaningEnvClient
    from data_cleaning_env.models import CleaningAction, ActionType, FillStrategy

    async def main():
        client = DataCleaningEnvClient(base_url="http://localhost:8000")
        result = await client.reset(task="easy")
        episode_id = result["state"]["episode_id"]

        action = CleaningAction(
            action_type=ActionType.fill_missing,
            column="sepallength",
            strategy=FillStrategy.median,
        )
        result = await client.step(episode_id, action)
        print(result)

    asyncio.run(main())
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| try: | |
| import httpx | |
| _HAS_HTTPX = True | |
| except ImportError: | |
| _HAS_HTTPX = False | |
| from data_cleaning_env.models import CleaningAction | |
class DataCleaningEnvClient:
    """Async HTTP client for the Data Cleaning OpenEnv server.

    Every public method is a coroutine that issues one HTTP request against
    ``base_url`` and returns the decoded JSON response body. A fresh
    ``httpx.AsyncClient`` is opened and closed for each request, so an
    instance holds no connection state and needs no explicit cleanup.

    All request methods raise:
        ImportError: if ``httpx`` is not installed.
        httpx.HTTPStatusError: if the server replies with a 4xx/5xx status.
    """

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        """Remember the server root URL, stripped of trailing slashes."""
        self.base_url = base_url.rstrip("/")

    async def reset(self, task: str = "easy") -> dict[str, Any]:
        """Start a new episode. Returns {observation, state}."""
        return await self._post("/reset", {"task": task})

    async def step(self, episode_id: str, action: CleaningAction) -> dict[str, Any]:
        """Apply a cleaning action. Returns {observation, reward, done, info}."""
        payload = {
            "episode_id": episode_id,
            # mode="json" converts enums and other rich field types into
            # JSON-encodable primitives before the payload reaches httpx.
            # NOTE(review): assumes CleaningAction is a Pydantic v2 model.
            "action": action.model_dump(mode="json"),
        }
        return await self._post("/step", payload)

    async def get_state(self, episode_id: str) -> dict[str, Any]:
        """Get episode metadata."""
        # Pass the id as a query parameter so httpx URL-encodes it; building
        # the query string by hand breaks on ids containing '&', '#', ' ', etc.
        return await self._get("/state", params={"episode_id": episode_id})

    async def grade(self, episode_id: str) -> dict[str, Any]:
        """Grade the current episode. Returns {episode_id, task, score}."""
        return await self._post("/grader", {"episode_id": episode_id})

    async def get_tasks(self) -> dict[str, Any]:
        """Get available tasks and action schema."""
        return await self._get("/tasks")

    async def baseline(self) -> dict[str, Any]:
        """Trigger the baseline agent and return scores."""
        return await self._post("/baseline", {})

    async def health(self) -> dict[str, Any]:
        """Liveness check."""
        return await self._get("/health")

    @staticmethod
    def _require_httpx() -> None:
        """Raise ImportError when the optional ``httpx`` dependency is missing."""
        if not _HAS_HTTPX:
            raise ImportError(
                "httpx is required for async HTTP. Install it: pip install httpx"
            )

    async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` as a JSON body to ``path`` and return the decoded JSON reply."""
        self._require_httpx()
        async with httpx.AsyncClient(base_url=self.base_url, timeout=60) as client:
            resp = await client.post(path, json=payload)
            resp.raise_for_status()
            return resp.json()

    async def _get(
        self, path: str, params: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """GET ``path`` and return the decoded JSON reply.

        ``params``, when given, is sent as the URL query string; httpx takes
        care of the encoding.
        """
        self._require_httpx()
        async with httpx.AsyncClient(base_url=self.base_url, timeout=60) as client:
            resp = await client.get(path, params=params)
            resp.raise_for_status()
            return resp.json()