File size: 3,984 Bytes
c65b2a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
OpenSpielEnv Client.

This module provides the client for connecting to an OpenSpiel Environment server
via WebSocket for persistent sessions.
"""

from __future__ import annotations

from typing import Any, Dict, Optional, TYPE_CHECKING

from openenv.core.client_types import StepResult

from openenv.core.env_client import EnvClient

from .models import OpenSpielAction, OpenSpielObservation, OpenSpielState

if TYPE_CHECKING:
    from openenv.core.containers.runtime import ContainerProvider


class OpenSpielEnv(EnvClient[OpenSpielAction, OpenSpielObservation, OpenSpielState]):
    """
    Client for OpenSpiel Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.

    The hooks defined here only translate between the typed OpenSpiel models
    and the plain JSON dictionaries exchanged with the server:

    * ``_step_payload`` serializes an ``OpenSpielAction``.
    * ``_parse_result`` builds a ``StepResult[OpenSpielObservation]``.
    * ``_parse_state`` builds an ``OpenSpielState``.

    NOTE(review): these are presumably invoked by ``EnvClient`` from its
    ``step``/``reset``/``state`` methods; confirm against the base class.

    Example:
        >>> # Connect to a running server
        >>> with OpenSpielEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.info_state)
        ...
        ...     result = client.step(OpenSpielAction(action_id=1, game_name="catch"))
        ...     print(result.observation.reward)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = OpenSpielEnv.from_docker_image("openspiel-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(OpenSpielAction(action_id=0))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: OpenSpielAction) -> Dict[str, Any]:
        """
        Convert an OpenSpielAction to the JSON payload for a step request.

        Args:
            action: The action to send to the server.

        Returns:
            A dict with the keys ``action_id``, ``game_name`` and
            ``game_params``, copied verbatim from ``action`` and suitable for
            JSON encoding.
        """
        return {
            "action_id": action.action_id,
            "game_name": action.game_name,
            "game_params": action.game_params,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[OpenSpielObservation]:
        """
        Parse a server response into a StepResult[OpenSpielObservation].

        Observation fields are read from ``payload["observation"]``. ``reward``
        and ``done`` are read from the top level of ``payload``.

        Missing fields fall back to these defaults:

        * empty ``info_state`` and ``legal_actions``
        * ``game_phase="playing"``
        * ``current_player_id=0``
        * ``opponent_last_action=None``
        * ``done=False``
        * ``reward=None``
        * empty ``metadata``

        A missing *or null* ``observation`` is treated as an empty dict.

        Args:
            payload: JSON response from the server.

        Returns:
            StepResult wrapping the parsed OpenSpielObservation. Its ``reward``
            and ``done`` are identical to the observation's.
        """
        # ``dict.get(key, default)`` only applies the default when the key is
        # absent; an explicit ``"observation": null`` would otherwise yield
        # ``None`` and crash on ``None.get``. Treat both cases the same.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = OpenSpielObservation(
            info_state=obs_data.get("info_state", []),
            legal_actions=obs_data.get("legal_actions", []),
            game_phase=obs_data.get("game_phase", "playing"),
            current_player_id=obs_data.get("current_player_id", 0),
            opponent_last_action=obs_data.get("opponent_last_action"),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> OpenSpielState:
        """
        Parse a server response into an OpenSpielState object.

        Missing fields fall back to these defaults:

        * ``episode_id=None``
        * ``step_count=0``
        * ``game_name="unknown"``
        * ``agent_player=0``
        * ``opponent_policy="random"``
        * ``game_params={}``
        * ``num_players=1``

        Args:
            payload: JSON response from the /state endpoint.

        Returns:
            OpenSpielState populated from ``payload``.
        """
        return OpenSpielState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            game_name=payload.get("game_name", "unknown"),
            agent_player=payload.get("agent_player", 0),
            opponent_policy=payload.get("opponent_policy", "random"),
            game_params=payload.get("game_params", {}),
            num_players=payload.get("num_players", 1),
        )