# yashmarathe's picture
# fix: packaging, grading baseline, action registry, observations, tests, deploy config
# 7a14011
"""
Python client for the Data Cleaning RL Environment.
Provides a lightweight async wrapper for local testing and integration
with RL training frameworks.

Usage (async)::

    import asyncio

    from data_cleaning_env.client import DataCleaningEnvClient
    from data_cleaning_env.models import CleaningAction, ActionType, FillStrategy

    async def main():
        client = DataCleaningEnvClient(base_url="http://localhost:8000")
        result = await client.reset(task="easy")
        episode_id = result["state"]["episode_id"]
        action = CleaningAction(
            action_type=ActionType.fill_missing,
            column="sepallength",
            strategy=FillStrategy.median,
        )
        result = await client.step(episode_id, action)
        print(result)

    asyncio.run(main())
"""
from __future__ import annotations
from typing import Any
try:
import httpx
_HAS_HTTPX = True
except ImportError:
_HAS_HTTPX = False
from data_cleaning_env.models import CleaningAction
class DataCleaningEnvClient:
    """Async HTTP client for the Data Cleaning OpenEnv server.

    Every public coroutine issues exactly one HTTP request against
    ``base_url`` and returns the decoded JSON response body. A new
    ``httpx.AsyncClient`` is opened and closed for each call, so the
    instance itself only stores the base URL.

    All request coroutines raise:
        ImportError: if ``httpx`` could not be imported.
        httpx.HTTPStatusError: if the server answers with a 4xx/5xx status
            (raised by ``Response.raise_for_status()``).
    """

    # Timeout, in seconds, handed to httpx.AsyncClient for every request.
    _TIMEOUT_SECONDS = 60

    def __init__(self, base_url: str = "http://localhost:8000") -> None:
        """Store the server base URL, dropping any trailing slashes."""
        # Endpoint paths all start with "/", so strip the trailing slash here
        # to avoid producing "//" when the two are joined.
        self.base_url = base_url.rstrip("/")

    async def reset(self, task: str = "easy") -> dict[str, Any]:
        """Start a new episode. Returns {observation, state}."""
        return await self._post("/reset", {"task": task})

    async def step(self, episode_id: str, action: CleaningAction) -> dict[str, Any]:
        """Apply a cleaning action. Returns {observation, reward, done, info}."""
        payload = {
            "episode_id": episode_id,
            # mode="json" makes the dump contain only JSON-native values
            # (e.g. enum members become their values), so the request body
            # can always be encoded.
            "action": action.model_dump(mode="json"),
        }
        return await self._post("/step", payload)

    async def get_state(self, episode_id: str) -> dict[str, Any]:
        """Get episode metadata."""
        # Send the id via ``params`` so httpx URL-encodes it; formatting it
        # into the path by hand breaks on ids containing '&', '#', spaces, etc.
        return await self._get("/state", params={"episode_id": episode_id})

    async def grade(self, episode_id: str) -> dict[str, Any]:
        """Grade the current episode. Returns {episode_id, task, score}."""
        return await self._post("/grader", {"episode_id": episode_id})

    async def get_tasks(self) -> dict[str, Any]:
        """Get available tasks and action schema."""
        return await self._get("/tasks")

    async def baseline(self) -> dict[str, Any]:
        """Trigger the baseline agent and return scores."""
        return await self._post("/baseline", {})

    async def health(self) -> dict[str, Any]:
        """Liveness check."""
        return await self._get("/health")

    async def _post(self, path: str, payload: dict) -> dict[str, Any]:
        """POST ``payload`` as a JSON body to ``path``; return the decoded JSON reply."""
        return await self._request("POST", path, payload=payload)

    async def _get(
        self, path: str, params: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """GET ``path`` with optional query ``params``; return the decoded JSON reply."""
        return await self._request("GET", path, params=params)

    async def _request(
        self,
        method: str,
        path: str,
        *,
        payload: dict | None = None,
        params: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Send one request and return its decoded JSON body.

        Args:
            method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
            path: Endpoint path, resolved against ``self.base_url``.
            payload: Optional JSON body; nothing is sent when ``None``.
            params: Optional query parameters; httpx URL-encodes them.

        Raises:
            ImportError: if ``httpx`` could not be imported.
            httpx.HTTPStatusError: on a 4xx/5xx response.
        """
        if not _HAS_HTTPX:
            raise ImportError(
                "httpx is required for async HTTP. Install it: pip install httpx"
            )
        async with httpx.AsyncClient(
            base_url=self.base_url, timeout=self._TIMEOUT_SECONDS
        ) as client:
            resp = await client.request(method, path, json=payload, params=params)
            # Surface HTTP errors as exceptions instead of returning an error body.
            resp.raise_for_status()
            return resp.json()