File size: 4,940 Bytes
3e1f9da
 
23fab64
3e1f9da
 
23fab64
3e1f9da
23fab64
 
3e1f9da
 
23fab64
e5572a6
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
 
23fab64
 
 
 
 
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23fab64
 
 
 
 
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
23fab64
 
 
 
 
 
 
 
 
a270816
 
 
 
 
 
 
 
 
3e1f9da
 
 
 
 
 
23fab64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e1f9da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5e858e
3e1f9da
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""FastAPI server for the Chess OpenEnv environment."""

from pathlib import Path
from typing import Any, Dict, Optional

import chess
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel

from moonfish.lib import search_move
from ..models import ChessAction
from .chess_environment import ChessEnvironment


# Pydantic models for API requests/responses
class ResetRequest(BaseModel):
    # Request body for POST /reset. Every field is optional (defaults to None)
    # and is passed unchanged to the environment's reset() by the `reset` handler.
    seed: Optional[int] = None  # forwarded as `seed`; presumably an RNG seed — confirm in ChessEnvironment
    episode_id: Optional[str] = None  # forwarded as `episode_id`
    fen: Optional[str] = None  # forwarded as `fen`; presumably a custom start position — confirm


class StepRequest(BaseModel):
    # Request body for POST /step.
    # The `step` handler wraps this string in ChessAction(move=...); the expected
    # notation is not visible in this file (presumably UCI — confirm).
    move: str


class EngineMoveRequest(BaseModel):
    # Request body for POST /engine-move.
    fen: str  # position to search; the handler parses it with chess.Board(fen)
    depth: int = 2  # requested search depth; the handler clamps it to the range 1-4


class ObservationResponse(BaseModel):
    # Wire format for a single environment observation. The /reset and /step
    # handlers fill each field from the same-named attribute of the
    # observation object returned by the environment.
    fen: str
    legal_moves: list[str]
    is_check: bool = False
    done: bool = False
    reward: Optional[float] = None
    result: Optional[str] = None
    # NOTE(review): mutable `{}` default — pydantic is expected to copy field
    # defaults per instance rather than share them; confirm for the pinned version.
    metadata: Dict[str, Any] = {}


class StepResponse(BaseModel):
    # Response body for POST /step: the new observation plus the reward and
    # done flag returned alongside it by the environment's step().
    observation: ObservationResponse
    reward: float
    done: bool


class StateResponse(BaseModel):
    # Response body for GET /state. Each field is copied from the same-named
    # attribute of the environment's `state` object by the `state` handler.
    episode_id: str
    step_count: int
    current_player: str
    fen: str
    move_history: list[str]


# Application object for the chess environment API.
app = FastAPI(
    version="1.0.0",
    title="Chess OpenEnv",
    description="Chess environment for reinforcement learning using moonfish",
)

# Web UI assets live in a "static" folder next to this module; they are
# exposed under /static only when that folder is actually present.
STATIC_DIR = Path(__file__).parent.joinpath("static")
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Single environment shared by every request (single-player mode); it is
# created on first use by get_env(). Serving several players at once would
# need a session manager instead.
_env: Optional[ChessEnvironment] = None


def get_env() -> ChessEnvironment:
    """Return the shared environment, building it the first time it is requested."""
    global _env
    if _env is not None:
        return _env
    # First call: create the instance and remember it for later requests.
    _env = ChessEnvironment()
    return _env


def _index_or_api_info():
    """Return the bundled UI page if it exists, otherwise a small JSON pointer.

    Shared by the "/" and "/web" routes so both always answer identically.

    Returns:
        A FileResponse for ``STATIC_DIR / "index.html"`` when that file exists,
        else a dict naming the API and the location of the interactive docs.
    """
    index_path = STATIC_DIR / "index.html"
    if index_path.exists():
        return FileResponse(str(index_path))
    return {"message": "Moonfish Chess API", "docs": "/docs"}


@app.get("/")
def root():
    """Serve the chess UI."""
    return _index_or_api_info()


@app.get("/web")
def web():
    """Serve the chess UI (for HF Spaces base_path)."""
    return _index_or_api_info()


@app.get("/health")
def health():
    """Liveness probe: always answers with a fixed OK payload."""
    payload = {"status": "ok"}
    return payload


@app.post("/engine-move")
def engine_move(request: EngineMoveRequest):
    """Ask the moonfish engine for a move in the position given by ``request.fen``.

    Returns:
        A dict with the engine's move (``move.uci()``) under "move" and the
        FEN that was submitted, unchanged, under "fen".

    Raises:
        HTTPException: status 400 when the FEN cannot be parsed or the
            position is already game-over.
    """
    fen = request.fen
    try:
        position = chess.Board(fen)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {e}")

    if position.is_game_over():
        raise HTTPException(status_code=400, detail="Game is already over")

    # Keep the search depth within 1..4 whatever the client asked for.
    clamped_depth = min(request.depth, 4)
    if clamped_depth < 1:
        clamped_depth = 1
    best = search_move(position, depth=clamped_depth)

    return {"move": best.uci(), "fen": fen}


@app.get("/metadata")
def metadata():
    """Return whatever the shared environment reports from get_metadata()."""
    env = get_env()
    return env.get_metadata()


@app.post("/reset", response_model=ObservationResponse)
def reset(request: ResetRequest):
    """Begin a new episode on the shared environment and return its first observation.

    The request's ``seed``, ``episode_id`` and ``fen`` are handed to the
    environment's reset() as-is.
    """
    first_obs = get_env().reset(
        seed=request.seed,
        episode_id=request.episode_id,
        fen=request.fen,
    )
    # Copy the observation attribute by attribute into the response schema.
    payload = {
        "fen": first_obs.fen,
        "legal_moves": first_obs.legal_moves,
        "is_check": first_obs.is_check,
        "done": first_obs.done,
        "reward": first_obs.reward,
        "result": first_obs.result,
        "metadata": first_obs.metadata,
    }
    return ObservationResponse(**payload)


@app.post("/step", response_model=StepResponse)
def step(request: StepRequest):
    """Apply the requested move to the shared environment and report the outcome.

    Raises:
        HTTPException: status 400 when building the action or stepping the
            environment raises RuntimeError; the error text becomes the detail.
    """
    environment = get_env()

    try:
        chosen = ChessAction(move=request.move)
        next_obs, step_reward, finished = environment.step(chosen)
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))

    # Translate the environment's observation into the response schema.
    observation = ObservationResponse(
        fen=next_obs.fen,
        legal_moves=next_obs.legal_moves,
        is_check=next_obs.is_check,
        done=next_obs.done,
        reward=next_obs.reward,
        result=next_obs.result,
        metadata=next_obs.metadata,
    )
    return StepResponse(observation=observation, reward=step_reward, done=finished)


@app.get("/state", response_model=StateResponse)
def state():
    """Return a snapshot of the current episode from the shared environment.

    Raises:
        HTTPException: status 400 when reading the environment's ``state``
            raises RuntimeError; the error text becomes the detail.
    """
    environment = get_env()
    try:
        snapshot = environment.state
    except RuntimeError as err:
        raise HTTPException(status_code=400, detail=str(err))

    fields = {
        "episode_id": snapshot.episode_id,
        "step_count": snapshot.step_count,
        "current_player": snapshot.current_player,
        "fen": snapshot.fen,
        "move_history": snapshot.move_history,
    }
    return StateResponse(**fields)


def main():
    """Launch the API under uvicorn, bound to 0.0.0.0 on port 8000."""
    # Imported inside the function, so merely importing this module does not load uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)


# Start the server only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()