Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files
- client.py +1 -0
- examples/openenv_training.py +14 -8
- models.py +4 -0
- server/app.py +1 -0
- server/chess_environment.py +11 -8
client.py
CHANGED
|
@@ -11,6 +11,7 @@ from .models import ChessAction, ChessObservation, ChessState
|
|
| 11 |
@dataclass
|
| 12 |
class StepResult:
|
| 13 |
"""Result from a step() call."""
|
|
|
|
| 14 |
observation: ChessObservation
|
| 15 |
reward: float
|
| 16 |
done: bool
|
|
|
|
| 11 |
@dataclass
|
| 12 |
class StepResult:
|
| 13 |
"""Result from a step() call."""
|
| 14 |
+
|
| 15 |
observation: ChessObservation
|
| 16 |
reward: float
|
| 17 |
done: bool
|
examples/openenv_training.py
CHANGED
|
@@ -75,9 +75,11 @@ def train_with_remote_env():
|
|
| 75 |
print(" (truncated at 200 moves)")
|
| 76 |
break
|
| 77 |
|
| 78 |
-
print(
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
|
| 82 |
# Cleanup
|
| 83 |
client.close()
|
|
@@ -111,10 +113,12 @@ def train_with_local_env():
|
|
| 111 |
if env.state.step_count > 200:
|
| 112 |
break
|
| 113 |
|
| 114 |
-
print(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
| 118 |
|
| 119 |
env.close()
|
| 120 |
print("\nTraining complete!")
|
|
@@ -130,5 +134,7 @@ if __name__ == "__main__":
|
|
| 130 |
print("=== Local Environment ===\n")
|
| 131 |
train_with_local_env()
|
| 132 |
print("\nTo test with HTTP client, run:")
|
| 133 |
-
print(
|
|
|
|
|
|
|
| 134 |
print(" 2. Run: python examples/openenv_training.py --remote")
|
|
|
|
| 75 |
print(" (truncated at 200 moves)")
|
| 76 |
break
|
| 77 |
|
| 78 |
+
print(
|
| 79 |
+
f" Moves: {client.state().step_count}, "
|
| 80 |
+
f"Result: {obs.result or 'ongoing'}, "
|
| 81 |
+
f"Reward: {episode_reward:.2f}"
|
| 82 |
+
)
|
| 83 |
|
| 84 |
# Cleanup
|
| 85 |
client.close()
|
|
|
|
| 113 |
if env.state.step_count > 200:
|
| 114 |
break
|
| 115 |
|
| 116 |
+
print(
|
| 117 |
+
f"Episode {episode + 1}: "
|
| 118 |
+
f"Moves={env.state.step_count}, "
|
| 119 |
+
f"Result={obs.result or 'ongoing'}, "
|
| 120 |
+
f"Reward={episode_reward:.2f}"
|
| 121 |
+
)
|
| 122 |
|
| 123 |
env.close()
|
| 124 |
print("\nTraining complete!")
|
|
|
|
| 134 |
print("=== Local Environment ===\n")
|
| 135 |
train_with_local_env()
|
| 136 |
print("\nTo test with HTTP client, run:")
|
| 137 |
+
print(
|
| 138 |
+
" 1. Start server: python -m uvicorn moonfish.rl.server.app:app --port 8000"
|
| 139 |
+
)
|
| 140 |
print(" 2. Run: python examples/openenv_training.py --remote")
|
models.py
CHANGED
|
@@ -12,6 +12,7 @@ class ChessAction:
|
|
| 12 |
Attributes:
|
| 13 |
move: UCI format move string (e.g., "e2e4", "e7e8q" for promotion)
|
| 14 |
"""
|
|
|
|
| 15 |
move: str
|
| 16 |
|
| 17 |
|
|
@@ -29,6 +30,7 @@ class ChessObservation:
|
|
| 29 |
result: Game result string if game is over (e.g., "1-0", "0-1", "1/2-1/2")
|
| 30 |
metadata: Additional information about the position
|
| 31 |
"""
|
|
|
|
| 32 |
fen: str
|
| 33 |
legal_moves: List[str]
|
| 34 |
is_check: bool = False
|
|
@@ -50,6 +52,7 @@ class ChessState:
|
|
| 50 |
fen: Current position in FEN notation
|
| 51 |
move_history: List of moves played in UCI format
|
| 52 |
"""
|
|
|
|
| 53 |
episode_id: str
|
| 54 |
step_count: int
|
| 55 |
current_player: str
|
|
@@ -70,6 +73,7 @@ class RewardConfig:
|
|
| 70 |
use_evaluation: Whether to include position evaluation in rewards
|
| 71 |
evaluation_scale: Scale factor for evaluation-based rewards
|
| 72 |
"""
|
|
|
|
| 73 |
win: float = 1.0
|
| 74 |
loss: float = -1.0
|
| 75 |
draw: float = 0.0
|
|
|
|
| 12 |
Attributes:
|
| 13 |
move: UCI format move string (e.g., "e2e4", "e7e8q" for promotion)
|
| 14 |
"""
|
| 15 |
+
|
| 16 |
move: str
|
| 17 |
|
| 18 |
|
|
|
|
| 30 |
result: Game result string if game is over (e.g., "1-0", "0-1", "1/2-1/2")
|
| 31 |
metadata: Additional information about the position
|
| 32 |
"""
|
| 33 |
+
|
| 34 |
fen: str
|
| 35 |
legal_moves: List[str]
|
| 36 |
is_check: bool = False
|
|
|
|
| 52 |
fen: Current position in FEN notation
|
| 53 |
move_history: List of moves played in UCI format
|
| 54 |
"""
|
| 55 |
+
|
| 56 |
episode_id: str
|
| 57 |
step_count: int
|
| 58 |
current_player: str
|
|
|
|
| 73 |
use_evaluation: Whether to include position evaluation in rewards
|
| 74 |
evaluation_scale: Scale factor for evaluation-based rewards
|
| 75 |
"""
|
| 76 |
+
|
| 77 |
win: float = 1.0
|
| 78 |
loss: float = -1.0
|
| 79 |
draw: float = 0.0
|
server/app.py
CHANGED
|
@@ -193,6 +193,7 @@ def state():
|
|
| 193 |
def main():
|
| 194 |
"""Entry point for running the server."""
|
| 195 |
import uvicorn
|
|
|
|
| 196 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 197 |
|
| 198 |
|
|
|
|
| 193 |
def main():
|
| 194 |
"""Entry point for running the server."""
|
| 195 |
import uvicorn
|
| 196 |
+
|
| 197 |
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 198 |
|
| 199 |
|
server/chess_environment.py
CHANGED
|
@@ -24,8 +24,12 @@ class ChessEnvironment:
|
|
| 24 |
self,
|
| 25 |
reward_config: Optional[RewardConfig] = None,
|
| 26 |
max_moves: int = 500,
|
| 27 |
-
agent_color: Optional[
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
opponent_depth: int = 2, # Search depth for moonfish opponent
|
| 30 |
):
|
| 31 |
"""
|
|
@@ -54,7 +58,7 @@ class ChessEnvironment:
|
|
| 54 |
seed: Optional[int] = None,
|
| 55 |
episode_id: Optional[str] = None,
|
| 56 |
fen: Optional[str] = None,
|
| 57 |
-
**kwargs
|
| 58 |
) -> ChessObservation:
|
| 59 |
"""
|
| 60 |
Initialize a new chess game episode.
|
|
@@ -99,10 +103,7 @@ class ChessEnvironment:
|
|
| 99 |
return self._get_observation()
|
| 100 |
|
| 101 |
def step(
|
| 102 |
-
self,
|
| 103 |
-
action: ChessAction,
|
| 104 |
-
timeout_s: Optional[float] = None,
|
| 105 |
-
**kwargs
|
| 106 |
) -> Tuple[ChessObservation, float, bool]:
|
| 107 |
"""
|
| 108 |
Execute a chess move and return the resulting state.
|
|
@@ -263,7 +264,9 @@ class ChessEnvironment:
|
|
| 263 |
|
| 264 |
return reward, False
|
| 265 |
|
| 266 |
-
def _handle_illegal_move(
|
|
|
|
|
|
|
| 267 |
"""Handle an illegal move attempt."""
|
| 268 |
observation = self._get_observation(done=False, error=error_msg)
|
| 269 |
return observation, self.reward_config.illegal_move, False
|
|
|
|
| 24 |
self,
|
| 25 |
reward_config: Optional[RewardConfig] = None,
|
| 26 |
max_moves: int = 500,
|
| 27 |
+
agent_color: Optional[
|
| 28 |
+
bool
|
| 29 |
+
] = None, # None = alternate, True = White, False = Black
|
| 30 |
+
opponent: Optional[
|
| 31 |
+
str
|
| 32 |
+
] = None, # None = self-play, "moonfish" = moonfish engine, "random" = random
|
| 33 |
opponent_depth: int = 2, # Search depth for moonfish opponent
|
| 34 |
):
|
| 35 |
"""
|
|
|
|
| 58 |
seed: Optional[int] = None,
|
| 59 |
episode_id: Optional[str] = None,
|
| 60 |
fen: Optional[str] = None,
|
| 61 |
+
**kwargs,
|
| 62 |
) -> ChessObservation:
|
| 63 |
"""
|
| 64 |
Initialize a new chess game episode.
|
|
|
|
| 103 |
return self._get_observation()
|
| 104 |
|
| 105 |
def step(
|
| 106 |
+
self, action: ChessAction, timeout_s: Optional[float] = None, **kwargs
|
|
|
|
|
|
|
|
|
|
| 107 |
) -> Tuple[ChessObservation, float, bool]:
|
| 108 |
"""
|
| 109 |
Execute a chess move and return the resulting state.
|
|
|
|
| 264 |
|
| 265 |
return reward, False
|
| 266 |
|
| 267 |
+
def _handle_illegal_move(
|
| 268 |
+
self, error_msg: str
|
| 269 |
+
) -> Tuple[ChessObservation, float, bool]:
|
| 270 |
"""Handle an illegal move attempt."""
|
| 271 |
observation = self._get_observation(done=False, error=error_msg)
|
| 272 |
return observation, self.reward_config.illegal_move, False
|