Spaces:
Sleeping
Sleeping
"""FastAPI server for the Chess OpenEnv environment.

Exposes one shared ``ChessEnvironment`` (reset / step / state / metadata
handlers), a moonfish engine-move helper, a health check, and the static web
UI when a ``static`` directory exists next to this module.

NOTE(review): in this copy the handler functions carry no visible
``@app.get`` / ``@app.post`` route decorators — confirm how they are
registered with ``app``.
"""
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| import chess | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from moonfish.lib import search_move | |
| from ..models import ChessAction | |
| from .chess_environment import ChessEnvironment | |
# Pydantic models for API requests/responses


# Request body for `reset`; every field is optional and is forwarded unchanged
# to ChessEnvironment.reset().
class ResetRequest(BaseModel):
    seed: Optional[int] = None  # forwarded as env.reset(seed=...)
    episode_id: Optional[str] = None  # forwarded as env.reset(episode_id=...)
    fen: Optional[str] = None  # forwarded as env.reset(fen=...); presumably the starting position
# Request body for `step`: the move to play.
class StepRequest(BaseModel):
    # Wrapped as ChessAction(move=...) by `step`; the expected notation is not
    # checked here — presumably UCI, confirm against ChessAction.
    move: str
# Request body for `engine_move`: the position to search and how deep.
class EngineMoveRequest(BaseModel):
    fen: str  # parsed with chess.Board(fen); a parse failure becomes HTTP 400
    depth: int = 2  # search depth; engine_move clamps it to the range 1-4
# Serialized observation returned by `reset` and nested inside StepResponse.
# Fields are copied one-to-one from the environment's observation object.
class ObservationResponse(BaseModel):
    fen: str  # current position
    legal_moves: list[str]  # moves available in the current position
    is_check: bool = False
    done: bool = False  # whether the episode has ended
    reward: Optional[float] = None
    result: Optional[str] = None  # presumably the game result once done — confirm in ChessEnvironment
    # pydantic copies field defaults per instance, so this `{}` is not shared.
    metadata: Dict[str, Any] = {}
# Response for `step`: the new observation plus the (reward, done) values
# returned by ChessEnvironment.step().
class StepResponse(BaseModel):
    observation: ObservationResponse
    reward: float
    done: bool
# Response for `state`: a snapshot of the current episode, copied from the
# environment's `state` attribute.
class StateResponse(BaseModel):
    episode_id: str
    step_count: int
    current_player: str  # side to move; exact string values not visible here — confirm in ChessEnvironment
    fen: str
    move_history: list[str]  # moves played so far in this episode
# Create FastAPI app. title/description/version populate the generated
# OpenAPI schema and the interactive /docs page.
app = FastAPI(
    title="Chess OpenEnv",
    description="Chess environment for reinforcement learning using moonfish",
    version="1.0.0",
)
# Serve static files for the web UI from the `static` directory next to this
# module. The mount is registered only when that directory exists, so a
# missing directory is skipped rather than handed to StaticFiles.
STATIC_DIR = Path(__file__).parent / "static"
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# Global environment instance (for single-player mode): one game shared by
# every request. Created lazily by get_env().
# For multi-player, you'd want a session manager.
_env: Optional[ChessEnvironment] = None
def get_env() -> ChessEnvironment:
    """Return the process-wide ChessEnvironment, building it on first use.

    Every handler shares this single instance, which is stored in the
    module-level ``_env`` global.
    """
    global _env
    if _env is not None:
        return _env
    _env = ChessEnvironment()
    return _env
def root():
    """Serve the chess UI page.

    Returns ``static/index.html`` as a file response when it exists;
    otherwise returns a small JSON payload pointing at the API docs.
    """
    ui_page = STATIC_DIR / "index.html"
    if not ui_page.exists():
        return {"message": "Moonfish Chess API", "docs": "/docs"}
    return FileResponse(str(ui_page))
def web():
    """Serve the chess UI page (alternate entry for the HF Spaces base_path).

    Behaves like ``root``: returns ``static/index.html`` when present, else a
    JSON payload pointing at the API docs.
    """
    ui_page = STATIC_DIR / "index.html"
    if not ui_page.exists():
        return {"message": "Moonfish Chess API", "docs": "/docs"}
    return FileResponse(str(ui_page))
def health():
    """Health check: unconditionally reports the service as up."""
    payload = {"status": "ok"}
    return payload
def engine_move(request: EngineMoveRequest):
    """Return moonfish's best move for the position in ``request.fen``.

    The requested depth is clamped to the range 1-4 before searching.

    Args:
        request: The FEN to analyse and the desired search depth.

    Returns:
        A dict with the chosen move in UCI notation under ``"move"`` and the
        FEN it was computed for under ``"fen"``.

    Raises:
        HTTPException: 400 if the FEN cannot be parsed, describes an illegal
            position, or the game is already over.
    """
    try:
        board = chess.Board(request.fen)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {e}") from e

    # A FEN can parse cleanly yet describe an impossible position (e.g. a
    # missing king). Move generation and search are not meant for such
    # boards, so reject them here instead of failing inside the engine.
    if not board.is_valid():
        raise HTTPException(status_code=400, detail="Invalid FEN: illegal position")

    if board.is_game_over():
        raise HTTPException(status_code=400, detail="Game is already over")

    # Clamp depth to 1-4; presumably this bounds the search cost per request.
    depth = max(1, min(request.depth, 4))
    move = search_move(board, depth=depth)
    return {"move": move.uci(), "fen": request.fen}
def metadata():
    """Return whatever the shared environment reports from ``get_metadata()``."""
    environment = get_env()
    return environment.get_metadata()
def reset(request: ResetRequest):
    """Start a new episode and return its first observation.

    ``seed``, ``episode_id`` and ``fen`` from the request are passed through
    unchanged to ``ChessEnvironment.reset``.
    """
    environment = get_env()
    reset_kwargs = {
        "seed": request.seed,
        "episode_id": request.episode_id,
        "fen": request.fen,
    }
    initial = environment.reset(**reset_kwargs)

    # Copy the observation's attributes one-to-one into the response model.
    response_fields = (
        "fen",
        "legal_moves",
        "is_check",
        "done",
        "reward",
        "result",
        "metadata",
    )
    return ObservationResponse(
        **{name: getattr(initial, name) for name in response_fields}
    )
def step(request: StepRequest):
    """Play ``request.move`` in the current game and report the outcome.

    Raises:
        HTTPException: 400 when building the action or stepping the
            environment raises ``RuntimeError``.
    """
    env = get_env()
    try:
        observation, reward, done = env.step(ChessAction(move=request.move))
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))

    observation_model = ObservationResponse(
        fen=observation.fen,
        legal_moves=observation.legal_moves,
        is_check=observation.is_check,
        done=observation.done,
        reward=observation.reward,
        result=observation.result,
        metadata=observation.metadata,
    )
    return StepResponse(observation=observation_model, reward=reward, done=done)
def state():
    """Return a snapshot of the current episode.

    Raises:
        HTTPException: 400 when reading ``env.state`` raises ``RuntimeError``.
    """
    environment = get_env()
    try:
        snapshot = environment.state
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))

    # Copy the snapshot's attributes one-to-one into the response model.
    field_names = ("episode_id", "step_count", "current_player", "fen", "move_history")
    return StateResponse(**{name: getattr(snapshot, name) for name in field_names})
def main(*, host: str = "0.0.0.0", port: int = 8000) -> None:
    """Entry point for running the server with uvicorn.

    Calling ``main()`` with no arguments behaves exactly as before.

    Args:
        host: Interface to bind. Defaults to ``"0.0.0.0"`` (all interfaces).
        port: TCP port to listen on. Defaults to ``8000``.

    Blocks until the server shuts down.
    """
    # Local import: importing this module (e.g. to serve `app` some other way)
    # does not require uvicorn; it is only needed when launching from here.
    import uvicorn

    uvicorn.run(app, host=host, port=port)
# Start the server when this module is executed directly.
# NOTE(review): the module uses relative imports (`from ..models ...`), so it
# must be launched as a module (`python -m <package>...`), not as a plain
# script path.
if __name__ == "__main__":
    main()