Spaces:
Sleeping
Sleeping
File size: 4,940 Bytes
3e1f9da 23fab64 3e1f9da 23fab64 3e1f9da 23fab64 3e1f9da 23fab64 e5572a6 3e1f9da 23fab64 3e1f9da 23fab64 3e1f9da 23fab64 a270816 3e1f9da 23fab64 3e1f9da b5e858e 3e1f9da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
"""FastAPI server for the Chess OpenEnv environment."""
from pathlib import Path
from typing import Any, Dict, Optional
import chess
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from moonfish.lib import search_move
from ..models import ChessAction
from .chess_environment import ChessEnvironment
# Pydantic models for API requests/responses
class ResetRequest(BaseModel):
    """Request body for ``POST /reset``; every field is optional."""

    # Passed through unchanged as ChessEnvironment.reset(seed=...).
    seed: Optional[int] = None
    # Passed through unchanged as ChessEnvironment.reset(episode_id=...).
    episode_id: Optional[str] = None
    # Passed through unchanged as ChessEnvironment.reset(fen=...); how None is
    # interpreted (e.g. the standard start position) is up to the environment.
    fen: Optional[str] = None
class StepRequest(BaseModel):
    """Request body for ``POST /step``."""

    # Move string; the /step handler wraps it as ChessAction(move=...).
    # NOTE(review): presumably UCI notation — confirm against ChessEnvironment.
    move: str
class EngineMoveRequest(BaseModel):
    """Request body for ``POST /engine-move``."""

    # Position to analyse; the handler parses it with chess.Board(fen).
    fen: str
    # Requested search depth; the handler clamps it into the range 1-4.
    depth: int = 2
class ObservationResponse(BaseModel):
    """Observation payload returned by ``/reset`` and nested in ``/step``.

    Every field is copied verbatim from the environment's observation object.
    """

    # Board position as a FEN string.
    fen: str
    # Legal moves as strings (format decided by the environment — likely UCI).
    legal_moves: list[str]
    # Check flag as reported by the environment.
    is_check: bool = False
    # Episode-finished flag as reported by the environment.
    done: bool = False
    # Reward attached to this observation, when the environment supplies one.
    reward: Optional[float] = None
    # Game result string, when the environment supplies one.
    result: Optional[str] = None
    # Free-form extra information. Pydantic copies mutable defaults for each
    # instance, so this shared ``{}`` literal is never mutated across responses.
    metadata: Dict[str, Any] = {}
class StepResponse(BaseModel):
    """Response body for ``POST /step``."""

    # Observation of the position after the move was applied.
    observation: ObservationResponse
    # Reward value returned by ChessEnvironment.step.
    reward: float
    # Episode-finished flag returned by ChessEnvironment.step.
    done: bool
class StateResponse(BaseModel):
    """Response body for ``GET /state``, mirroring ``ChessEnvironment.state``."""

    # Identifier of the running episode.
    episode_id: str
    # Number of steps taken so far, as counted by the environment.
    step_count: int
    # Side to move, as a string chosen by the environment.
    current_player: str
    # Current board position as a FEN string.
    fen: str
    # Moves played so far, as strings supplied by the environment.
    move_history: list[str]
# Create the FastAPI application that every route below registers on.
app = FastAPI(
    title="Chess OpenEnv",
    description="Chess environment for reinforcement learning using moonfish",
    version="1.0.0",
)

# Serve static files for the web UI. The mount is optional: without a
# ``static`` directory next to this file the JSON API still works and the
# UI routes fall back to a JSON message.
STATIC_DIR = Path(__file__).parent / "static"
if STATIC_DIR.exists():
    app.mount(
        "/static",
        StaticFiles(directory=str(STATIC_DIR)),
        name="static",
    )
# Single process-wide environment instance (single-player mode).
# Multi-player support would need a per-session manager instead.
_env: Optional[ChessEnvironment] = None


def get_env() -> ChessEnvironment:
    """Return the shared environment, constructing it on first use.

    NOTE(review): FastAPI runs plain ``def`` endpoints in a thread pool, so
    two simultaneous first requests could each build an environment; only
    the last assignment is kept.
    """
    global _env
    if _env is not None:
        return _env
    _env = ChessEnvironment()
    return _env
@app.get("/")
def root():
    """Serve the chess UI page, or a JSON pointer to the docs if it is absent."""
    index_file = STATIC_DIR / "index.html"
    if not index_file.exists():
        return {"message": "Moonfish Chess API", "docs": "/docs"}
    return FileResponse(str(index_file))
@app.get("/web")
def web():
    """Serve the chess UI under ``/web`` (the HF Spaces base_path).

    Behaves exactly like ``/``, so it delegates to ``root`` instead of
    duplicating the index-file lookup and JSON fallback.
    """
    return root()
@app.get("/health")
def health():
    """Liveness probe: always reports ``{"status": "ok"}``."""
    return dict(status="ok")
@app.post("/engine-move")
def engine_move(request: EngineMoveRequest):
    """Return moonfish's chosen move for the position in ``request.fen``.

    The requested depth is clamped into the range 1-4 before searching.

    Returns:
        ``{"move": <move as UCI string>, "fen": <the FEN that was searched>}``.

    Raises:
        HTTPException: 400 if the FEN cannot be parsed, describes an illegal
            position, or the game in that position is already over.
    """
    try:
        board = chess.Board(request.fen)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {e}") from e
    # chess.Board() accepts syntactically valid but illegal positions (e.g. a
    # missing king, or the side not to move being in check). Reject them here
    # instead of handing such a board to the engine.
    if not board.is_valid():
        raise HTTPException(status_code=400, detail="Invalid FEN: illegal position")
    if board.is_game_over():
        raise HTTPException(status_code=400, detail="Game is already over")
    depth = max(1, min(request.depth, 4))  # Clamp depth 1-4
    move = search_move(board, depth=depth)
    return {"move": move.uci(), "fen": request.fen}
@app.get("/metadata")
def metadata():
    """Return whatever the shared environment's ``get_metadata()`` reports."""
    env = get_env()
    return env.get_metadata()
@app.post("/reset", response_model=ObservationResponse)
def reset(request: ResetRequest):
    """Begin a new episode and return its first observation.

    ``seed``, ``episode_id`` and ``fen`` are forwarded unchanged to
    ``ChessEnvironment.reset``.

    NOTE(review): exceptions raised by ``env.reset`` (for example on a
    malformed ``fen``) are not converted into HTTP 4xx errors here.
    """
    observation = get_env().reset(
        seed=request.seed,
        episode_id=request.episode_id,
        fen=request.fen,
    )
    return ObservationResponse(
        fen=observation.fen,
        legal_moves=observation.legal_moves,
        is_check=observation.is_check,
        done=observation.done,
        reward=observation.reward,
        result=observation.result,
        metadata=observation.metadata,
    )
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
    """Play ``request.move`` in the current episode and report the outcome.

    Raises:
        HTTPException: 400 carrying the message of any ``RuntimeError``
            raised while building the action or stepping the environment.
    """
    env = get_env()
    try:
        obs, reward, done = env.step(ChessAction(move=request.move))
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))

    next_observation = ObservationResponse(
        fen=obs.fen,
        legal_moves=obs.legal_moves,
        is_check=obs.is_check,
        done=obs.done,
        reward=obs.reward,
        result=obs.result,
        metadata=obs.metadata,
    )
    return StepResponse(observation=next_observation, reward=reward, done=done)
@app.get("/state", response_model=StateResponse)
def state():
    """Report a snapshot of the current episode.

    Raises:
        HTTPException: 400 carrying the message of any ``RuntimeError`` the
            environment raises when its ``state`` is read.
    """
    env = get_env()
    try:
        snapshot = env.state
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return StateResponse(
        episode_id=snapshot.episode_id,
        step_count=snapshot.step_count,
        current_player=snapshot.current_player,
        fen=snapshot.fen,
        move_history=snapshot.move_history,
    )
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Entry point: serve ``app`` with uvicorn.

    Args:
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to listen on.
    """
    # Imported here so that merely importing this module does not need uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
|