File size: 2,259 Bytes
2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 56ddfd4 2cc94a4 40e4201 56ddfd4 40e4201 56ddfd4 2cc94a4 56ddfd4 2cc94a4 40e4201 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 | """OpenEnv Data Cleaning Environment Client."""
from typing import Dict, Any, Optional
from openenv.core import EnvClient, SyncEnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import DataCleaningAction, DataCleaningObservation, DataCleaningState
class DataCleaningClient(EnvClient[DataCleaningAction, DataCleaningObservation, DataCleaningState]):
    """
    Client for the Data Cleaning Environment.

    Translates ``DataCleaningAction`` objects into request payloads and turns
    the server's response dictionaries back into ``DataCleaningObservation``
    and ``DataCleaningState`` instances.

    Example:
        >>> with DataCleaningClient(base_url="http://localhost:7860") as client:
        ...     result = client.reset(task_id="easy_001")
        ...     print(result.observation.metadata.get("message"))
    """

    def _step_payload(self, action: DataCleaningAction) -> Dict[str, Any]:
        """Serialize ``action`` into the JSON-ready body of a step request.

        Only ``action_type`` and ``params`` are sent; both values are passed
        through unchanged.
        """
        return dict(action_type=action.action_type, params=action.params)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[DataCleaningObservation]:
        """Turn a server response dictionary into a ``StepResult``.

        Observation fields are read from ``payload["observation"]``. A missing
        ``observation`` key is treated as an empty dict, and every missing
        field falls back to a default. ``done`` (default ``False``) and
        ``reward`` (default ``None``) are read from the top level of
        ``payload``. They are attached both to the observation and to the
        returned ``StepResult``.
        """
        raw_obs = payload.get("observation", {})
        # Read the episode-level signals once; the same values are attached
        # to the observation and to the wrapping StepResult.
        is_done = payload.get("done", False)
        step_reward = payload.get("reward")

        observation = DataCleaningObservation(
            dataset_info=raw_obs.get("dataset_info", {}),
            available_actions=raw_obs.get("available_actions", []),
            step_count=raw_obs.get("step_count", 0),
            task_id=raw_obs.get("task_id"),
            message=raw_obs.get("message", ""),
            done=is_done,
            reward=step_reward,
            metadata=raw_obs.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=step_reward, done=is_done)

    def _parse_state(self, payload: Dict[str, Any]) -> DataCleaningState:
        """Build a ``DataCleaningState`` from a server response dictionary.

        Missing keys fall back to defaults:
        - ``step_count`` becomes ``0``.
        - ``session_id`` becomes ``""``.
        - ``action_history`` becomes ``[]``.
        - ``episode_id``, ``task_id`` and ``grade`` become ``None``.
        """
        read = payload.get
        return DataCleaningState(
            episode_id=read("episode_id"),
            step_count=read("step_count", 0),
            session_id=read("session_id", ""),
            task_id=read("task_id"),
            action_history=read("action_history", []),
            grade=read("grade"),
        )
|