File size: 1,960 Bytes
8ab6a5f
b3fce19
 
8ab6a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from models import DataAction, DataObservation, DataState
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient


class DataAnalysisClient(EnvClient[DataAction, DataObservation, DataState]):
    """Typed client for the Data Analysis environment server.

    Implements the three serialisation hooks of ``EnvClient``: turning a
    ``DataAction`` into a request payload, and turning server responses
    into ``StepResult`` / ``DataState`` objects.

    Usage patterns:
        - Async: ``async with DataAnalysisClient(base_url=...) as client:``
        - Sync: ``with DataAnalysisClient(base_url=...).sync() as client:``
    """

    def _step_payload(self, action: DataAction) -> dict:
        """Build the request body for a step from ``action``.

        ``action_type`` is always included. ``code`` and ``answer`` are
        only included when they are not ``None``.

        Args:
            action: The action to serialise.

        Returns:
            A dict holding ``action_type`` plus any non-``None`` optional
            fields, in the order ``code`` then ``answer``.
        """
        body = {"action_type": action.action_type}
        # Optional fields are omitted entirely rather than sent as null.
        for field_name in ("code", "answer"):
            value = getattr(action, field_name)
            if value is not None:
                body[field_name] = value
        return body

    def _parse_result(self, payload: dict) -> StepResult[DataObservation]:
        """Turn a step/reset response into a ``StepResult``.

        The observation fields are read from ``payload["observation"]``
        when that key exists; otherwise the whole payload is treated as
        the observation. Top-level ``reward`` and ``done`` take precedence
        over the values carried by the observation itself.

        Args:
            payload: The decoded response body from the server.

        Returns:
            A ``StepResult`` wrapping the observation, reward and done flag.
        """
        if "observation" in payload:
            fields = payload["observation"]
        else:
            # No nested observation: the payload itself is the observation.
            fields = payload
        observation = DataObservation(**fields)

        # Fall back to the observation's own values when the envelope
        # does not carry them at the top level.
        reward = payload.get("reward", observation.reward)
        done = payload.get("done", observation.done)

        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict) -> DataState:
        """Turn a state response into a ``DataState``.

        Args:
            payload: The decoded state response body; its keys are passed
                to ``DataState`` as keyword arguments.

        Returns:
            The constructed ``DataState``.
        """
        state = DataState(**payload)
        return state