# luccabb's picture
# Upload folder using huggingface_hub
# b5e858e verified
"""FastAPI server for the Chess OpenEnv environment."""
from pathlib import Path
from typing import Any, Dict, Optional
import chess
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from moonfish.lib import search_move
from ..models import ChessAction
from .chess_environment import ChessEnvironment
# Pydantic models for API requests/responses
class ResetRequest(BaseModel):
    # Request body for POST /reset. Every field is optional and defaults to
    # None; the /reset handler forwards each one unchanged as the keyword
    # argument of the same name to the environment's reset().
    seed: Optional[int] = None
    episode_id: Optional[str] = None
    # presumably a starting position in FEN notation — confirm in ChessEnvironment.reset
    fen: Optional[str] = None
class StepRequest(BaseModel):
    # Request body for POST /step.
    # The handler wraps this string as ChessAction(move=...).
    # NOTE(review): the expected move notation (likely UCI) is defined by
    # ChessAction / ChessEnvironment, not here — confirm.
    move: str
class EngineMoveRequest(BaseModel):
    # Request body for POST /engine-move.
    # Position to search; the handler parses it with chess.Board(fen).
    fen: str
    # Requested search depth; the handler clamps it into the range 1..4.
    depth: int = 2
class ObservationResponse(BaseModel):
    # Observation payload returned by /reset and nested inside /step responses.
    # The handlers fill each field from the environment observation attribute
    # of the same name.
    fen: str
    legal_moves: list[str]
    is_check: bool = False
    done: bool = False
    reward: Optional[float] = None
    result: Optional[str] = None
    # NOTE(review): a literal {} default; pydantic is expected to copy it per
    # instance rather than share it — confirm for the pinned pydantic version.
    metadata: Dict[str, Any] = {}
class StepResponse(BaseModel):
    # Response body for POST /step: the new observation plus the reward and
    # done flag that the environment's step() returned alongside it.
    observation: ObservationResponse
    reward: float
    done: bool
class StateResponse(BaseModel):
    # Response body for GET /state. The handler copies each field from the
    # attribute of the same name on the environment's `state` object.
    episode_id: str
    step_count: int
    current_player: str
    fen: str
    move_history: list[str]
# Create FastAPI app
# Module-level application object: every route below registers on it, and
# main() hands it to uvicorn.
app = FastAPI(
    title="Chess OpenEnv",
    description="Chess environment for reinforcement learning using moonfish",
    version="1.0.0",
)
# Serve static files for the web UI
# STATIC_DIR is the "static" folder that sits next to this module.
STATIC_DIR = Path(__file__).parent / "static"
# Mount only when the path is an actual directory. exists() is also true for a
# regular file named "static", and mounting StaticFiles on a non-directory
# would fail while this module is being imported.
if STATIC_DIR.is_dir():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# Global environment instance (for single-player mode)
# For multi-player, you'd want a session manager
# Starts as None; get_env() creates the ChessEnvironment on first use, and
# every endpoint that goes through get_env() shares that one instance.
_env: Optional[ChessEnvironment] = None
def get_env() -> ChessEnvironment:
    """Return the shared environment, constructing it on the first call."""
    global _env
    # Fast path: the environment was already built by an earlier call.
    if _env is not None:
        return _env
    _env = ChessEnvironment()
    return _env
@app.get("/")
def root():
    """Serve the chess UI."""
    ui_page = STATIC_DIR / "index.html"
    # Without a UI page on disk, answer with a small JSON pointer to the docs.
    if not ui_page.exists():
        return {"message": "Moonfish Chess API", "docs": "/docs"}
    return FileResponse(str(ui_page))
@app.get("/web")
def web():
    """Serve the chess UI (for HF Spaces base_path)."""
    page = STATIC_DIR.joinpath("index.html")
    if page.exists():
        return FileResponse(str(page))
    # No UI page on disk: return the same JSON fallback as the "/" route.
    fallback = {"message": "Moonfish Chess API", "docs": "/docs"}
    return fallback
@app.get("/health")
def health():
    """Health check endpoint."""
    # Constant payload; the environment is not consulted.
    payload = {"status": "ok"}
    return payload
@app.post("/engine-move")
def engine_move(request: EngineMoveRequest):
    """Get the best move from moonfish engine for a given position."""
    # A FEN that chess.Board cannot parse is reported as a client error; the
    # original ValueError is kept as the explicit cause for tracebacks/logs.
    try:
        board = chess.Board(request.fen)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {e}") from e

    # There is nothing to search once the game has ended.
    if board.is_game_over():
        raise HTTPException(status_code=400, detail="Game is already over")

    # Clamp the requested depth into 1..4 instead of rejecting it.
    depth = max(1, min(request.depth, 4))
    move = search_move(board, depth=depth)
    # "fen" echoes the position from the request; the move is not applied here.
    return {"move": move.uci(), "fen": request.fen}
@app.get("/metadata")
def metadata():
    """Get environment metadata."""
    # Delegates to the shared environment; its return value is sent as-is.
    env = get_env()
    return env.get_metadata()
@app.post("/reset", response_model=ObservationResponse)
def reset(request: ResetRequest):
    """Reset the environment and start a new episode."""
    observation = get_env().reset(
        seed=request.seed,
        episode_id=request.episode_id,
        fen=request.fen,
    )
    # Copy the observation attribute-by-attribute into the response schema,
    # reading the attributes in the same order as the schema declares them.
    field_names = (
        "fen",
        "legal_moves",
        "is_check",
        "done",
        "reward",
        "result",
        "metadata",
    )
    payload = {name: getattr(observation, name) for name in field_names}
    return ObservationResponse(**payload)
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
    """Execute a move and return the result."""
    env = get_env()
    # A RuntimeError from building the action or stepping the environment is
    # reported as a 400; the original exception is kept as the explicit cause.
    try:
        action = ChessAction(move=request.move)
        obs, reward, done = env.step(action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Re-pack the environment's observation into the API response schema.
    return StepResponse(
        observation=ObservationResponse(
            fen=obs.fen,
            legal_moves=obs.legal_moves,
            is_check=obs.is_check,
            done=obs.done,
            reward=obs.reward,
            result=obs.result,
            metadata=obs.metadata,
        ),
        reward=reward,
        done=done,
    )
@app.get("/state", response_model=StateResponse)
def state():
    """Get current episode state."""
    env = get_env()
    # Reading `env.state` may raise RuntimeError; that is reported as a 400
    # with the original exception kept as the explicit cause.
    try:
        s = env.state
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    return StateResponse(
        episode_id=s.episode_id,
        step_count=s.step_count,
        current_player=s.current_player,
        fen=s.fen,
        move_history=s.move_history,
    )
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Entry point for running the server.

    Args:
        host: Interface address to bind. Defaults to "0.0.0.0".
        port: TCP port to listen on. Defaults to 8000.
    """
    # Imported inside the function, so merely importing this module does not
    # import uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()