Spaces:
Sleeping
Sleeping
File size: 5,052 Bytes
3e1f9da e5572a6 3e1f9da b5e858e 3e1f9da e5572a6 3e1f9da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
"""Client for the Chess OpenEnv environment."""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import httpx
from .models import ChessAction, ChessObservation, ChessState
@dataclass
class StepResult:
    """Bundle of values returned by a single ``step()`` call."""

    # Board observation parsed from the "observation" entry of the step response.
    observation: ChessObservation
    # Value of the "reward" entry of the step response.
    reward: float
    # Value of the "done" entry of the step response.
    done: bool
class ChessEnvClient:
    """
    HTTP client for the Chess OpenEnv environment.

    Provides a simple interface to interact with a remote chess environment
    server for reinforcement learning.

    Example usage:
        client = ChessEnvClient("http://localhost:8000")
        obs = client.reset()
        print(f"Legal moves: {obs.legal_moves}")
        result = client.step(ChessAction(move="e2e4"))
        print(f"Reward: {result.reward}, Done: {result.done}")
        state = client.state()
        print(f"Move count: {state.step_count}")
        client.close()

    The client also works as a context manager; leaving the ``with`` block
    calls ``close()``.
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 30.0) -> None:
        """
        Initialize the chess environment client.

        Args:
            base_url: URL of the chess environment server
            timeout: Request timeout in seconds
        """
        # Drop trailing slashes so endpoint paths can be appended as "/name".
        self.base_url = base_url.rstrip("/")
        self._client = httpx.Client(timeout=timeout)

    def _request_json(self, method: str, path: str, **kwargs: Any) -> Any:
        """
        Send a request to ``base_url + path`` and return the decoded JSON body.

        Shared by every endpoint wrapper that needs the response payload.

        Args:
            method: HTTP method name, e.g. "GET" or "POST"
            path: Endpoint path starting with "/"
            **kwargs: Extra keyword arguments forwarded to the HTTP client
                (for example ``json=payload``)

        Returns:
            The value produced by the response's ``json()`` method

        Raises:
            httpx.HTTPStatusError: If ``raise_for_status()`` rejects the response
        """
        response = self._client.request(method, f"{self.base_url}{path}", **kwargs)
        response.raise_for_status()
        return response.json()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        fen: Optional[str] = None,
    ) -> "ChessObservation":
        """
        Reset the environment and start a new episode.

        Args:
            seed: Random seed (optional)
            episode_id: Unique episode identifier (optional)
            fen: Starting position in FEN notation (optional)

        Returns:
            Initial observation of the board state
        """
        candidates: Dict[str, Any] = {"seed": seed, "episode_id": episode_id, "fen": fen}
        # Only explicitly supplied options are sent; an empty JSON object is
        # still posted when none are given.
        payload: Dict[str, Any] = {
            key: value for key, value in candidates.items() if value is not None
        }
        return self._parse_observation(self._request_json("POST", "/reset", json=payload))

    def step(self, action: "ChessAction") -> "StepResult":
        """
        Execute a move in the environment.

        Args:
            action: The chess action (move in UCI format)

        Returns:
            StepResult with observation, reward, and done flag

        Raises:
            KeyError: If the response lacks "observation", "reward" or "done"
        """
        data = self._request_json("POST", "/step", json={"move": action.move})
        return StepResult(
            observation=self._parse_observation(data["observation"]),
            reward=data["reward"],
            done=data["done"],
        )

    def state(self) -> "ChessState":
        """
        Get the current episode state.

        Returns:
            Current episode state with metadata

        Raises:
            KeyError: If the response lacks "episode_id", "step_count",
                "current_player" or "fen"
        """
        data = self._request_json("GET", "/state")
        return ChessState(
            episode_id=data["episode_id"],
            step_count=data["step_count"],
            current_player=data["current_player"],
            fen=data["fen"],
            # The move history is optional in the response.
            move_history=data.get("move_history", []),
        )

    def metadata(self) -> Dict[str, Any]:
        """
        Get environment metadata.

        Returns:
            Dictionary with environment configuration
        """
        return self._request_json("GET", "/metadata")

    def health(self) -> bool:
        """
        Check if the server is healthy.

        Returns:
            True if the /health endpoint answers with status 200, False for
            any other status or if the request raises
        """
        try:
            response = self._client.get(f"{self.base_url}/health")
        except Exception:
            # Deliberately best-effort: any failure to reach the server is
            # reported as "not healthy" instead of propagating.
            return False
        return response.status_code == 200

    def close(self) -> None:
        """Close the HTTP client."""
        self._client.close()

    def __enter__(self) -> "ChessEnvClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP client on leaving the ``with`` block; exceptions are not suppressed."""
        self.close()

    def _parse_observation(self, data: Dict[str, Any]) -> "ChessObservation":
        """
        Parse observation from JSON response.

        "fen" and "legal_moves" are required (a missing key raises KeyError);
        every other field falls back to a default when absent.
        """
        return ChessObservation(
            fen=data["fen"],
            legal_moves=data["legal_moves"],
            is_check=data.get("is_check", False),
            done=data.get("done", False),
            reward=data.get("reward"),
            result=data.get("result"),
            metadata=data.get("metadata", {}),
        )
# Convenience function for quick usage
def make_env(
    base_url: str = "http://localhost:8000",
    timeout: float = 30.0,
) -> ChessEnvClient:
    """
    Create a chess environment client.

    Args:
        base_url: URL of the chess environment server
        timeout: Request timeout in seconds, forwarded to ChessEnvClient

    Returns:
        ChessEnvClient instance
    """
    return ChessEnvClient(base_url, timeout=timeout)
|