File size: 2,259 Bytes
2cc94a4
56ddfd4
2cc94a4
56ddfd4
2cc94a4
56ddfd4
 
 
2cc94a4
56ddfd4
 
2cc94a4
56ddfd4
2cc94a4
 
56ddfd4
2cc94a4
 
 
56ddfd4
 
2cc94a4
 
56ddfd4
2cc94a4
 
56ddfd4
 
2cc94a4
 
56ddfd4
2cc94a4
40e4201
 
 
 
 
56ddfd4
 
40e4201
56ddfd4
 
 
 
 
 
 
2cc94a4
 
 
56ddfd4
 
2cc94a4
 
 
 
40e4201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""OpenEnv Data Cleaning Environment Client."""

from typing import Dict, Any, Optional

from openenv.core import EnvClient, SyncEnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import DataCleaningAction, DataCleaningObservation, DataCleaningState


class DataCleaningClient(EnvClient[DataCleaningAction, DataCleaningObservation, DataCleaningState]):
    """
    Client for the Data Cleaning Environment.

    Converts ``DataCleaningAction`` objects into step payloads and parses the
    server's JSON responses into ``DataCleaningObservation`` /
    ``DataCleaningState`` objects.

    Example:
        >>> with DataCleaningClient(base_url="http://localhost:7860") as client:
        ...     result = client.reset(task_id="easy_001")
        ...     print(result.observation.message)
    """

    def _step_payload(self, action: DataCleaningAction) -> Dict[str, Any]:
        """Convert a DataCleaningAction into the JSON payload for a step request.

        Args:
            action: The action to send. Its ``action_type`` and ``params`` are
                forwarded unchanged.

        Returns:
            A dict with exactly the keys ``"action_type"`` and ``"params"``.
        """
        return {
            "action_type": action.action_type,
            "params": action.params,
        }

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[DataCleaningObservation]:
        """Parse a server response into a StepResult.

        Args:
            payload: Decoded JSON response. Observation fields are read from
                ``payload["observation"]``; ``done`` and ``reward`` are read
                from the top level of ``payload``.

        Returns:
            A StepResult whose observation carries the same ``done`` and
            ``reward`` values as the StepResult itself.
        """
        # ``or {}`` rather than a ``.get`` default: an explicit JSON null would
        # otherwise yield None and the ``obs_data.get`` calls below would raise
        # AttributeError.
        obs_data = payload.get("observation") or {}
        # Read once so the observation and the StepResult cannot disagree.
        done = payload.get("done") or False
        reward = payload.get("reward")
        observation = DataCleaningObservation(
            # A null value is treated the same as a missing key for every field
            # that has a non-None default.
            dataset_info=obs_data.get("dataset_info") or {},
            available_actions=obs_data.get("available_actions") or [],
            step_count=obs_data.get("step_count") or 0,
            task_id=obs_data.get("task_id"),
            message=obs_data.get("message") or "",
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata") or {},
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> DataCleaningState:
        """Parse a server state response into a DataCleaningState.

        Args:
            payload: Decoded JSON state response (flat dict, no nesting).

        Returns:
            A DataCleaningState. ``episode_id``, ``task_id`` and ``grade`` stay
            None when absent; the other fields fall back to empty/zero defaults
            when missing or null.
        """
        return DataCleaningState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
            session_id=payload.get("session_id") or "",
            task_id=payload.get("task_id"),
            action_history=payload.get("action_history") or [],
            grade=payload.get("grade"),
        )