# data_analysis_env/client.py
# Author: HimanshuSardana2 — "Upload folder using huggingface_hub" (commit abb357f, verified)
from typing import Optional
import httpx
from openenv.core.env_client import EnvClient, StepResult
from models import DataAnalysisAction, DataAnalysisObservation, DataAnalysisState
class DataAnalysisEnv(EnvClient):
    """HTTP client for the data-analysis environment server.

    Talks to the server's ``/reset``, ``/step`` and ``/state`` endpoints with
    JSON over HTTP, using an ``httpx.AsyncClient`` that is created lazily on
    first use and released by :meth:`close`.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0):
        """Store the normalized server URL; the HTTP client is created lazily.

        Args:
            base_url: Server address. ``ws://`` / ``wss://`` schemes are mapped
                to ``http://`` / ``https://``, ``http://`` / ``https://`` are
                kept, and an address without a scheme gets ``http://``
                prepended. Trailing slashes are removed.
            timeout: Timeout in seconds passed to the underlying
                ``httpx.AsyncClient``.
        """
        # NOTE(review): EnvClient.__init__ is not called (same as before);
        # confirm the base class needs no initialisation for HTTP-only use.
        self._base_url = self._normalize_base_url(base_url)
        self._timeout = timeout
        self._client: Optional[httpx.AsyncClient] = None

    @staticmethod
    def _normalize_base_url(base_url: str) -> str:
        """Return *base_url* with an HTTP(S) scheme and no trailing slash."""
        url = base_url.rstrip("/")
        # Map websocket schemes onto their HTTP equivalents; only the leading
        # scheme is rewritten, never occurrences later in the URL.
        scheme_map = {"ws": "http", "wss": "https", "http": "http", "https": "https"}
        scheme, sep, rest = url.partition("://")
        if sep and scheme.lower() in scheme_map:
            return f"{scheme_map[scheme.lower()]}://{rest}"
        # No (recognised) scheme: treat the whole string as host[:port][/path].
        return "http://" + url

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared ``httpx.AsyncClient``, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(base_url=self._base_url, timeout=self._timeout)
        return self._client

    async def _post_json(self, path: str, payload: dict) -> dict:
        """POST *payload* as JSON to *path* and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        response = await self._get_client().post(path, json=payload)
        response.raise_for_status()
        return response.json()

    async def _get_json(self, path: str) -> dict:
        """GET *path* and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        response = await self._get_client().get(path)
        response.raise_for_status()
        return response.json()

    async def reset(self, task: str = "task_1", **kwargs) -> StepResult:
        """Start a new episode for *task* and return the initial step result.

        Extra keyword arguments are accepted for interface compatibility but
        are not sent to the server.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        data = await self._post_json("/reset", {"task": task})
        return self._parse_result(data)

    async def step(self, action: DataAnalysisAction) -> StepResult:
        """Send *action* (tool name plus parameters) to ``/step``.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        payload = {
            "action": {
                "tool": action.tool,
                "parameters": action.parameters,
            }
        }
        data = await self._post_json("/step", payload)
        return self._parse_result(data)

    async def state(self) -> DataAnalysisState:
        """Fetch the server's current state from ``/state``.

        Raises:
            httpx.HTTPStatusError: if the server responds with an error status.
        """
        data = await self._get_json("/state")
        return self._parse_state(data)

    async def close(self):
        """Close the underlying HTTP client, if one was created.

        Safe to call more than once; a later request creates a fresh client.
        """
        if self._client is not None:
            # Drop the reference before awaiting so no caller can pick up a
            # client that is in the middle of closing.
            client, self._client = self._client, None
            await client.aclose()

    @staticmethod
    def _parse_result(payload: dict) -> StepResult:
        """Build a ``StepResult`` from a ``/reset`` or ``/step`` JSON response.

        A missing or null ``observation`` is treated as an empty one; a
        missing ``reward`` defaults to 0.0 and a missing ``done`` to False.
        """
        obs = DataAnalysisObservation(**(payload.get("observation") or {}))
        return StepResult(
            observation=obs,
            reward=payload.get("reward", 0.0),
            done=payload.get("done", False),
        )

    @staticmethod
    def _parse_state(payload: dict) -> DataAnalysisState:
        """Build a ``DataAnalysisState`` from a ``/state`` JSON response."""
        return DataAnalysisState(**payload)