# grpo-chess-api / app.py
# noamdwc's picture
# chore: bake model into Docker image
# 41a2a9d
import os
import sys
import types
from pathlib import Path
import chess
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from safetensors.torch import load_file
# Directory that contains this file; it is put on sys.path so the packages
# imported below resolve from here.
APP_ROOT = Path(__file__).resolve().parent
if str(APP_ROOT) not in sys.path:
    sys.path.insert(0, str(APP_ROOT))

# Compatibility shim: the training code imports its modules as ``src.*``.
# In this Docker layout the packages are top-level, so register a synthetic
# ``src`` package whose search path is APP_ROOT, then alias the real modules
# underneath it.
if "src" not in sys.modules:
    src_pkg = types.ModuleType("src")
    src_pkg.__path__ = [str(APP_ROOT)]
    sys.modules["src"] = src_pkg

# These imports must stay below the sys.path / ``src`` setup above.
import grpo_self_play
import searchless_chess_model

# Only fill in the aliases when nothing has claimed those names already.
if "src.grpo_self_play" not in sys.modules:
    sys.modules["src.grpo_self_play"] = grpo_self_play
if "src.searchless_chess_model" not in sys.modules:
    sys.modules["src.searchless_chess_model"] = searchless_chess_model
class MoveRequest(BaseModel):
    """Request body for ``POST /move``."""

    # Position to move from, as a FEN string (parsed with chess.Board).
    fen: str
    # Forwarded unchanged to PolicyConfig(temperature=...).
    temperature: float = 1.0
    # Forwarded unchanged to PolicyConfig(greedy=...).
    greedy: bool = False
class MoveResponse(BaseModel):
    """Response body for ``POST /move``."""

    # The chosen move in UCI notation.
    uci: str
    # The chosen move in SAN, computed on the position before the move.
    san: str
    # FEN of the position after the chosen move has been played.
    fen: str
# Model artefact locations; each can be overridden through the environment
# variable of the same name.
WEIGHTS_FILE = os.getenv("WEIGHTS_FILE", "model.safetensors")
CONFIG_FILE = os.getenv("CONFIG_FILE", "config.json")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/app/model")

# Lazily populated by load_model(): the loaded model and its parsed config.
_model = None
_config = None
def load_model():
    """Return the cached chess model, loading it from disk on first call.

    The weights and config are read from ``LOCAL_MODEL_DIR / WEIGHTS_FILE``
    and ``LOCAL_MODEL_DIR / CONFIG_FILE``. The config JSON is expanded into
    ``ChessTransformerConfig(**config)``, the safetensors weights are applied
    with ``strict=False``, and ``model.eval()`` is called before the model is
    cached in the module-level ``_model`` (with its config in ``_config``).

    Returns:
        The loaded ``ChessTransformer`` instance (the same object on every
        subsequent call).

    Raises:
        FileNotFoundError: If the weights file or the config file is missing.
    """
    global _model, _config
    if _model is not None:
        return _model

    base = Path(LOCAL_MODEL_DIR)
    weights_path = base / WEIGHTS_FILE
    config_path = base / CONFIG_FILE
    if not weights_path.exists():
        raise FileNotFoundError(f"Weights not found at {weights_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found at {config_path}")

    import json
    import logging

    from grpo_self_play.models import ChessTransformer, ChessTransformerConfig

    # JSON is UTF-8 by specification; do not depend on the container locale.
    config = json.loads(config_path.read_text(encoding="utf-8"))
    model = ChessTransformer(ChessTransformerConfig(**config))
    state = load_file(str(weights_path))

    # strict=False tolerates key mismatches, which means a wrong checkpoint
    # would otherwise be served with partly uninitialised weights and no
    # signal at all. Keep the lenient load but make the mismatch visible.
    # NOTE(review): assumes load_state_dict returns torch's
    # (missing_keys, unexpected_keys) result; getattr keeps this harmless
    # if ChessTransformer overrides it.
    load_result = model.load_state_dict(state, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or [])
    unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Non-strict weight load from %s: %d missing key(s) %s, "
            "%d unexpected key(s) %s",
            weights_path,
            len(missing),
            missing,
            len(unexpected),
            unexpected,
        )
    model.eval()

    # Publish the globals only after everything succeeded, so a failed load
    # never leaves a config cached without its model.
    _config = config
    _model = model
    return _model
def choose_move(model, board: chess.Board, temperature: float, greedy: bool) -> chess.Move:
    """Ask a ``PolicyPlayer`` built around ``model`` for a move on ``board``.

    Args:
        model: Model handed to ``PolicyPlayer``.
        board: Position to move from.
        temperature: Passed to ``PolicyConfig(temperature=...)``.
        greedy: Passed to ``PolicyConfig(greedy=...)``.

    Returns:
        Whatever ``PolicyPlayer.act(board)`` returns.
    """
    # Imported lazily, as in the rest of this module's model-facing code.
    from grpo_self_play.chess.policy_player import PolicyConfig, PolicyPlayer

    policy_cfg = PolicyConfig(temperature=temperature, greedy=greedy)
    return PolicyPlayer(model, cfg=policy_cfg).act(board)
# The ASGI application object that the route decorators below attach to.
app = FastAPI()

# Wildcard CORS for origins, methods and headers, with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Report a static OK status plus the configured model directory.

    This does not load or inspect the model; it only echoes configuration.
    """
    return dict(status="ok", model_dir=LOCAL_MODEL_DIR)
@app.post("/move", response_model=MoveResponse)
def move(req: MoveRequest):
    """Choose a move for the position in ``req.fen`` and return it.

    Args:
        req: Request carrying the FEN plus the ``temperature`` and ``greedy``
            settings that are forwarded to ``choose_move``.

    Returns:
        A ``MoveResponse`` with the move in UCI and SAN (SAN is rendered on
        the position before the move) and the FEN after the move is pushed.

    Raises:
        HTTPException: 400 if the FEN cannot be parsed, or if the position
            has no legal moves to choose from.
    """
    try:
        board = chess.Board(req.fen)
    except ValueError as exc:
        # python-chess reports malformed FENs as ValueError; anything else
        # is a genuine server-side fault and should not be masked as a 400.
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {exc}") from exc

    # Checkmate/stalemate positions have nothing to choose from; reject them
    # as a client error instead of letting the policy fail with a 500.
    if next(iter(board.legal_moves), None) is None:
        raise HTTPException(
            status_code=400,
            detail="No legal moves available in this position",
        )

    model = load_model()
    # Named ``chosen`` so the local does not shadow this endpoint function.
    chosen = choose_move(model, board, req.temperature, req.greedy)
    # SAN depends on the pre-move position, so compute it before pushing.
    san = board.san(chosen)
    board.push(chosen)
    return MoveResponse(uci=chosen.uci(), san=san, fen=board.fen())