File size: 5,052 Bytes
3e1f9da
 
 
e5572a6
3e1f9da
 
 
 
 
 
 
 
 
b5e858e
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5572a6
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""Client for the Chess OpenEnv environment."""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import httpx

from .models import ChessAction, ChessObservation, ChessState


@dataclass()
class StepResult:
    """Outcome of a single ``ChessEnvClient.step`` call.

    Bundles the observation parsed from the server's ``/step`` response with
    the reward and termination flag reported alongside it.
    """

    # Observation parsed from the "observation" field of the /step response.
    observation: ChessObservation
    # Reward value taken verbatim from the /step response.
    reward: float
    # Episode-finished flag taken verbatim from the /step response.
    done: bool


class ChessEnvClient:
    """
    HTTP client for the Chess OpenEnv environment.

    Provides a simple interface to interact with a remote chess environment
    server for reinforcement learning. All requests share one underlying
    ``httpx.Client``; call ``close()`` (or use the client as a context
    manager) when finished so its connections are released.

    Example usage:
        client = ChessEnvClient("http://localhost:8000")
        obs = client.reset()
        print(f"Legal moves: {obs.legal_moves}")

        result = client.step(ChessAction(move="e2e4"))
        print(f"Reward: {result.reward}, Done: {result.done}")

        state = client.state()
        print(f"Move count: {state.step_count}")

        client.close()

    Equivalent usage with automatic cleanup:
        with ChessEnvClient("http://localhost:8000") as client:
            obs = client.reset()
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0):
        """
        Initialize the chess environment client.

        Args:
            base_url: URL of the chess environment server. A trailing slash
                is stripped so endpoint paths can be appended directly.
            timeout: Request timeout in seconds
        """
        self.base_url = base_url.rstrip("/")
        self._client = httpx.Client(timeout=timeout)

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        fen: Optional[str] = None,
    ) -> ChessObservation:
        """
        Reset the environment and start a new episode.

        Args:
            seed: Random seed (optional)
            episode_id: Unique episode identifier (optional)
            fen: Starting position in FEN notation (optional)

        Returns:
            Initial observation of the board state

        Raises:
            httpx.HTTPStatusError: if the server answers with a non-2xx status.
            KeyError: if the response lacks a required observation field.
        """
        # Only send options the caller actually set; unset ones are omitted
        # entirely (presumably so the server applies its own defaults).
        options = {"seed": seed, "episode_id": episode_id, "fen": fen}
        payload: Dict[str, Any] = {
            key: value for key, value in options.items() if value is not None
        }

        data = self._post_json("/reset", payload)
        return self._parse_observation(data)

    def step(self, action: ChessAction) -> StepResult:
        """
        Execute a move in the environment.

        Args:
            action: The chess action (move in UCI format)

        Returns:
            StepResult with observation, reward, and done flag

        Raises:
            httpx.HTTPStatusError: if the server answers with a non-2xx status
                (e.g. if it rejects the move — server-dependent).
            KeyError: if the response lacks "observation", "reward" or "done".
        """
        data = self._post_json("/step", {"move": action.move})

        return StepResult(
            observation=self._parse_observation(data["observation"]),
            reward=data["reward"],
            done=data["done"],
        )

    def state(self) -> ChessState:
        """
        Get the current episode state.

        Returns:
            Current episode state with metadata

        Raises:
            httpx.HTTPStatusError: if the server answers with a non-2xx status.
            KeyError: if the response lacks a required state field.
        """
        data = self._get_json("/state")

        return ChessState(
            episode_id=data["episode_id"],
            step_count=data["step_count"],
            current_player=data["current_player"],
            fen=data["fen"],
            # A missing key *or* an explicit JSON null both mean "no moves yet".
            move_history=data.get("move_history") or [],
        )

    def metadata(self) -> Dict[str, Any]:
        """
        Get environment metadata.

        Returns:
            Dictionary with environment configuration, as decoded from the
            server's JSON response.

        Raises:
            httpx.HTTPStatusError: if the server answers with a non-2xx status.
        """
        return self._get_json("/metadata")

    def health(self) -> bool:
        """
        Check if the server is healthy.

        This is deliberately best-effort: any failure (connection refused,
        timeout, closed client, ...) is reported as ``False`` rather than
        raised.

        Returns:
            True if the server answered the health endpoint with HTTP 200
        """
        try:
            response = self._client.get(f"{self.base_url}/health")
            return response.status_code == 200
        except Exception:
            return False

    def close(self) -> None:
        """Close the underlying HTTP client; the instance is unusable afterwards."""
        self._client.close()

    def __enter__(self) -> "ChessEnvClient":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returning None means exceptions raised inside the `with` propagate.
        self.close()

    def _get_json(self, path: str) -> Any:
        """GET ``path`` (relative to ``base_url``), fail on non-2xx, return decoded JSON."""
        response = self._client.get(f"{self.base_url}{path}")
        response.raise_for_status()
        return response.json()

    def _post_json(self, path: str, payload: Dict[str, Any]) -> Any:
        """POST ``payload`` as JSON to ``path``, fail on non-2xx, return decoded JSON."""
        response = self._client.post(f"{self.base_url}{path}", json=payload)
        response.raise_for_status()
        return response.json()

    def _parse_observation(self, data: Dict[str, Any]) -> ChessObservation:
        """
        Parse an observation from a decoded JSON response.

        "fen" and "legal_moves" are required (KeyError if absent); the other
        fields fall back to defaults when missing.
        """
        return ChessObservation(
            fen=data["fen"],
            legal_moves=data["legal_moves"],
            is_check=data.get("is_check", False),
            done=data.get("done", False),
            reward=data.get("reward"),
            result=data.get("result"),
            # Treat an explicit JSON null the same as a missing key so callers
            # always receive a dict.
            metadata=data.get("metadata") or {},
        )


# Convenience function for quick usage
def make_env(
    base_url: str = "http://localhost:8000", timeout: float = 30.0
) -> ChessEnvClient:
    """
    Create a chess environment client.

    Args:
        base_url: URL of the chess environment server
        timeout: Request timeout in seconds, forwarded to ``ChessEnvClient``.
            The default matches the client's own default of 30 seconds.

    Returns:
        ChessEnvClient instance; call ``close()`` on it (or use it as a
        context manager) when finished.
    """
    return ChessEnvClient(base_url, timeout=timeout)