# dataclean-env / client.py
# Anuj424614's picture
# fix: client.py docstring uses correct done-tracking pattern
# ba9bdc7 verified
"""Typed WebSocket client for the DataClean-Env OpenEnv environment.

Connects to a running DataClean environment server via WebSocket and
provides a type-safe interface for stepping through data-cleaning episodes.

Usage::

    from dataclean_env.client import DataCleanEnv
    from dataclean_env.models import DataCleanAction

    with DataCleanEnv(base_url="http://localhost:8000").sync() as env:
        result = env.reset()
        obs = result.observation
        done = False
        while not done:
            action = DataCleanAction(
                action_type="fix_value",
                params={"row_id": 0, "column": "name", "new_value": "Alice"},
            )
            result = env.step(action)
            obs = result.observation
            done = result.done
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult
from dataclean_env.models import (
ActionResult,
DataCleanAction,
DataCleanObservation,
DataCleanState,
DataSummary,
IssueGroup,
QualityIssue,
)
class DataCleanEnv(EnvClient[DataCleanAction, DataCleanObservation, DataCleanState]):
    """WebSocket client for the DataClean environment.

    Subclasses :class:`EnvClient` and implements the three required
    serialisation hooks so the generic base class can handle the
    WebSocket transport transparently.

    Parameters
    ----------
    base_url:
        Root URL of the DataClean environment server
        (e.g. ``"http://localhost:8000"``).
    """

    # ------------------------------------------------------------------
    # Abstract-method implementations
    # ------------------------------------------------------------------

    def _step_payload(self, action: DataCleanAction) -> Dict[str, Any]:
        """Serialise a :class:`DataCleanAction` to a JSON-safe dict.

        The dict is sent over the WebSocket to the server's ``step``
        endpoint and carries exactly two keys: ``action_type`` and
        ``params``, copied from the action unchanged.
        """
        return {
            "action_type": action.action_type,
            "params": action.params,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[DataCleanObservation]:
        """Parse the server's step response into a typed :class:`StepResult`.

        The *payload* dict is the JSON body returned by the server after
        a ``step`` or ``reset`` call.  ``reward`` and ``done`` are read from
        the top level of the payload; the nested ``observation`` dict is
        handed to ``_build_observation``.

        When the nested observation does not carry its own ``reward`` /
        ``done`` values, the top-level ones are mirrored into it so that
        ``result.observation.done`` and ``result.done`` cannot disagree.
        """
        reward = payload.get("reward")
        done = payload.get("done", False)

        # ``or {}`` covers both a missing key and an explicit
        # ``"observation": null``; ``dict(...)`` copies so the caller's
        # payload is never mutated by the mirroring below.
        obs_data: Dict[str, Any] = dict(payload.get("observation") or {})

        # NOTE(review): OpenEnv servers appear to report reward/done on the
        # envelope rather than inside the observation dict -- confirm
        # against the server version in use.  Values already present in the
        # observation are left untouched.
        if obs_data.get("reward") is None:
            obs_data["reward"] = reward
        if obs_data.get("done") is None:
            obs_data["done"] = bool(done)

        return StepResult(
            observation=_build_observation(obs_data),
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> DataCleanState:
        """Parse the server's state response into :class:`DataCleanState`.

        Every key of *payload* is forwarded as a keyword argument to the
        :class:`DataCleanState` constructor.
        """
        return DataCleanState(**payload)
# ------------------------------------------------------------------
# Private helpers — kept outside the class to stay under 50 lines/fn
# ------------------------------------------------------------------
def _parse_quality_issues(
raw_issues: List[Dict[str, Any]],
) -> List[QualityIssue]:
"""Convert a list of raw dicts into :class:`QualityIssue` models."""
return [QualityIssue(**item) for item in raw_issues]
def _parse_issue_groups(
raw_groups: List[Dict[str, Any]],
) -> List[IssueGroup]:
"""Convert a list of raw dicts into :class:`IssueGroup` models."""
return [IssueGroup(**item) for item in raw_groups]
def _parse_action_result(
raw: Optional[Dict[str, Any]],
) -> Optional[ActionResult]:
"""Convert a raw dict into an :class:`ActionResult`, or ``None``."""
if raw is None:
return None
return ActionResult(**raw)
def _parse_recent_actions(
raw_actions: List[Dict[str, Any]],
) -> List[ActionResult]:
"""Convert a list of raw dicts into :class:`ActionResult` models."""
return [ActionResult(**item) for item in raw_actions]
def _build_observation(obs_data: Dict[str, Any]) -> DataCleanObservation:
    """Assemble a typed :class:`DataCleanObservation` from raw JSON data.

    Nested structures (the data summary, quality issues, issue groups and
    action results) are converted to their model classes first; every
    remaining field is read straight from *obs_data*.  A key that is
    absent from *obs_data* falls back to the default written next to it
    below.
    """
    field = obs_data.get

    # An absent or empty summary yields a default-constructed DataSummary.
    summary_payload: Dict[str, Any] = field("data_summary", {})
    if summary_payload:
        summary = DataSummary(**summary_payload)
    else:
        summary = DataSummary()

    # Nested models, parsed in the same order the constructor lists them.
    issues = _parse_quality_issues(field("quality_issues", []))
    groups = _parse_issue_groups(field("issue_groups", []))
    previous_result = _parse_action_result(field("last_action_result"))
    history = _parse_recent_actions(field("recent_actions", []))

    return DataCleanObservation(
        # Issue-first fields
        data_summary=summary,
        quality_issues=issues,
        issue_groups=groups,
        issues_remaining=field("issues_remaining", 0),
        # Data fields
        columns=field("columns", []),
        rows=field("rows", []),
        row_count=field("row_count", 0),
        # Schema info
        schema_info=field("schema_info", {}),
        # Step context
        step_number=field("step_number", 0),
        max_steps=field("max_steps", 30),
        steps_remaining=field("steps_remaining", 30),
        # History
        last_action_result=previous_result,
        recent_actions=history,
        # Task info
        task_id=field("task_id", ""),
        task_name=field("task_name", ""),
        difficulty=field("difficulty", ""),
        # Budget fields
        budget_spent=field("budget_spent", 0.0),
        budget_remaining=field("budget_remaining", 100.0),
        action_costs=field("action_costs", {}),
        # Inherited Observation fields
        done=field("done", False),
        reward=field("reward"),
        metadata=field("metadata", {}),
    )