"""Typed WebSocket client for the DataClean-Env OpenEnv environment. Connects to a running DataClean environment server via WebSocket and provides a type-safe interface for stepping through data-cleaning episodes. Usage:: from dataclean_env.client import DataCleanEnv with DataCleanEnv(base_url="http://localhost:8000").sync() as env: result = env.reset() obs = result.observation done = False while not done: action = DataCleanAction( action_type="fix_value", params={"row_id": 0, "column": "name", "new_value": "Alice"}, ) result = env.step(action) obs = result.observation done = result.done """ from __future__ import annotations from typing import Any, Dict, List, Optional from openenv.core.env_client import EnvClient from openenv.core.client_types import StepResult from dataclean_env.models import ( ActionResult, DataCleanAction, DataCleanObservation, DataCleanState, DataSummary, IssueGroup, QualityIssue, ) class DataCleanEnv(EnvClient[DataCleanAction, DataCleanObservation, DataCleanState]): """WebSocket client for the DataClean environment. Subclasses :class:`EnvClient` and implements the three required serialisation hooks so the generic base class can handle the WebSocket transport transparently. Parameters ---------- base_url: Root URL of the DataClean environment server (e.g. ``"http://localhost:8000"``). """ # ------------------------------------------------------------------ # Abstract-method implementations # ------------------------------------------------------------------ def _step_payload(self, action: DataCleanAction) -> Dict[str, Any]: """Serialise a :class:`DataCleanAction` to a JSON-safe dict. The dict is sent over the WebSocket to the server's ``step`` endpoint. """ return { "action_type": action.action_type, "params": action.params, } def _parse_result( self, payload: Dict[str, Any] ) -> StepResult[DataCleanObservation]: """Parse the server's step response into a typed :class:`StepResult`. The *payload* dict is the JSON body returned by the server after a ``step`` or ``reset`` call. """ obs_data: Dict[str, Any] = payload.get("observation", {}) observation = _build_observation(obs_data) return StepResult( observation=observation, reward=payload.get("reward"), done=payload.get("done", False), ) def _parse_state(self, payload: Dict[str, Any]) -> DataCleanState: """Parse the server's state response into :class:`DataCleanState`.""" return DataCleanState(**payload) # ------------------------------------------------------------------ # Private helpers — kept outside the class to stay under 50 lines/fn # ------------------------------------------------------------------ def _parse_quality_issues( raw_issues: List[Dict[str, Any]], ) -> List[QualityIssue]: """Convert a list of raw dicts into :class:`QualityIssue` models.""" return [QualityIssue(**item) for item in raw_issues] def _parse_issue_groups( raw_groups: List[Dict[str, Any]], ) -> List[IssueGroup]: """Convert a list of raw dicts into :class:`IssueGroup` models.""" return [IssueGroup(**item) for item in raw_groups] def _parse_action_result( raw: Optional[Dict[str, Any]], ) -> Optional[ActionResult]: """Convert a raw dict into an :class:`ActionResult`, or ``None``.""" if raw is None: return None return ActionResult(**raw) def _parse_recent_actions( raw_actions: List[Dict[str, Any]], ) -> List[ActionResult]: """Convert a list of raw dicts into :class:`ActionResult` models.""" return [ActionResult(**item) for item in raw_actions] def _build_observation(obs_data: Dict[str, Any]) -> DataCleanObservation: """Construct a fully-typed :class:`DataCleanObservation` from raw JSON. Nested models (:class:`DataSummary`, :class:`QualityIssue`, etc.) are parsed explicitly so that callers always receive validated Pydantic objects rather than raw dicts. """ data_summary_raw: Dict[str, Any] = obs_data.get("data_summary", {}) data_summary = DataSummary(**data_summary_raw) if data_summary_raw else DataSummary() return DataCleanObservation( # Issue-first fields data_summary=data_summary, quality_issues=_parse_quality_issues(obs_data.get("quality_issues", [])), issue_groups=_parse_issue_groups(obs_data.get("issue_groups", [])), issues_remaining=obs_data.get("issues_remaining", 0), # Data fields columns=obs_data.get("columns", []), rows=obs_data.get("rows", []), row_count=obs_data.get("row_count", 0), # Schema info schema_info=obs_data.get("schema_info", {}), # Step context step_number=obs_data.get("step_number", 0), max_steps=obs_data.get("max_steps", 30), steps_remaining=obs_data.get("steps_remaining", 30), # History last_action_result=_parse_action_result(obs_data.get("last_action_result")), recent_actions=_parse_recent_actions(obs_data.get("recent_actions", [])), # Task info task_id=obs_data.get("task_id", ""), task_name=obs_data.get("task_name", ""), difficulty=obs_data.get("difficulty", ""), # Budget fields budget_spent=obs_data.get("budget_spent", 0.0), budget_remaining=obs_data.get("budget_remaining", 100.0), action_costs=obs_data.get("action_costs", {}), # Inherited Observation fields done=obs_data.get("done", False), reward=obs_data.get("reward"), metadata=obs_data.get("metadata", {}), )