Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import types | |
| from pathlib import Path | |
| import chess | |
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from safetensors.torch import load_file | |
# Directory holding this file; prepended to sys.path so the sibling packages
# import no matter which working directory the server was started from.
APP_ROOT = Path(__file__).resolve().parent
_APP_ROOT_STR = str(APP_ROOT)
if _APP_ROOT_STR not in sys.path:
    sys.path.insert(0, _APP_ROOT_STR)

# Compatibility shim: the training code imports its modules as ``src.*``,
# but in this Docker layout the packages sit at the top level. Unless a
# ``src`` module is already loaded, register an empty package named ``src``
# whose search path is APP_ROOT, then alias the top-level packages under it.
if "src" not in sys.modules:
    src_pkg = types.ModuleType("src")
    src_pkg.__path__ = [_APP_ROOT_STR]
    sys.modules["src"] = src_pkg

import grpo_self_play  # noqa: E402 -- must follow the sys.path setup above
import searchless_chess_model  # noqa: E402

for _alias, _module in (
    ("src.grpo_self_play", grpo_self_play),
    ("src.searchless_chess_model", searchless_chess_model),
):
    sys.modules.setdefault(_alias, _module)
del _alias, _module
class MoveRequest(BaseModel):
    """Request body for the move endpoint.

    Attributes:
        fen: Position to move from, in Forsyth-Edwards Notation; parsed with
            ``chess.Board``.
        temperature: Sampling temperature forwarded to the policy player's
            ``PolicyConfig``.
        greedy: Forwarded to ``PolicyConfig(greedy=...)``; presumably selects
            the highest-scoring move instead of sampling (verify in
            PolicyPlayer).
    """

    fen: str
    temperature: float = 1.0
    greedy: bool = False
class MoveResponse(BaseModel):
    """Response body for the move endpoint.

    Attributes:
        uci: The chosen move in UCI notation (e.g. ``e2e4``).
        san: The same move in Standard Algebraic Notation, computed on the
            position before the move is played.
        fen: The position after the move has been played.
    """

    uci: str
    san: str
    fen: str
# Model artifact locations. Both file names are resolved relative to
# LOCAL_MODEL_DIR, and all three can be overridden through the environment.
WEIGHTS_FILE = os.getenv("WEIGHTS_FILE", "model.safetensors")
CONFIG_FILE = os.getenv("CONFIG_FILE", "config.json")
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/app/model")

# Lazily populated cache; filled in by load_model() on first use.
_model = None  # the built model instance
_config = None  # parsed JSON contents of CONFIG_FILE
def load_model():
    """Build the ChessTransformer from disk on first call and cache it.

    Reads ``CONFIG_FILE`` (JSON whose keys are passed as keyword arguments to
    ``ChessTransformerConfig``) and ``WEIGHTS_FILE`` (safetensors) from
    ``LOCAL_MODEL_DIR``. The model is put in eval mode and cached in the
    module-level ``_model`` (and its config in ``_config``). Later calls return
    the cached instance.

    Returns:
        The cached ``ChessTransformer`` instance.

    Raises:
        FileNotFoundError: If the weights or config file does not exist.
        Any error from JSON parsing, model construction or weight loading
        propagates unchanged, and the cache is left empty.
    """
    global _model, _config
    if _model is not None:
        return _model

    base = Path(LOCAL_MODEL_DIR)
    weights_path = base / WEIGHTS_FILE
    config_path = base / CONFIG_FILE
    if not weights_path.exists():
        raise FileNotFoundError(f"Weights not found at {weights_path}")
    if not config_path.exists():
        raise FileNotFoundError(f"Config not found at {config_path}")

    import json
    import logging

    from grpo_self_play.models import ChessTransformer, ChessTransformerConfig

    # JSON is UTF-8 by spec; do not depend on the container's locale.
    config = json.loads(config_path.read_text(encoding="utf-8"))
    model = ChessTransformer(ChessTransformerConfig(**config))

    # strict=False tolerates checkpoint/model key mismatches, but a silent
    # mismatch would serve partially untrained weights, so surface it.
    incompatible = model.load_state_dict(load_file(str(weights_path)), strict=False)
    missing = list(getattr(incompatible, "missing_keys", ()))
    unexpected = list(getattr(incompatible, "unexpected_keys", ()))
    logger = logging.getLogger(__name__)
    if missing:
        logger.warning(
            "Checkpoint %s lacks %d model parameter(s): %s",
            weights_path, len(missing), missing,
        )
    if unexpected:
        logger.warning(
            "Checkpoint %s has %d unused key(s): %s",
            weights_path, len(unexpected), unexpected,
        )
    model.eval()

    # Publish both globals only once the model is fully built, so a failure
    # above leaves no half-initialised cache state behind.
    _config = config
    _model = model
    return _model
def choose_move(model, board: chess.Board, temperature: float, greedy: bool) -> chess.Move:
    """Ask a PolicyPlayer wrapping *model* for a move in *board*.

    A fresh ``PolicyPlayer`` is built on every call from a ``PolicyConfig``
    carrying ``temperature`` and ``greedy``, and its ``act(board)`` result is
    returned as-is.

    NOTE(review): whether ``act`` leaves *board* unchanged is not visible here.
    Callers play the returned move on the same board, so confirm it does.
    """
    from grpo_self_play.chess.policy_player import PolicyConfig, PolicyPlayer

    player = PolicyPlayer(model, cfg=PolicyConfig(temperature=temperature, greedy=greedy))
    # Inference only. model.eval() does not disable autograd, so turn off
    # gradient tracking here. This is harmless if act() already does so.
    with torch.no_grad():
        return player.act(board)
app = FastAPI()

# Public inference API: allow cross-origin calls from any origin, with any
# method and any header. Credentials (cookies, auth headers) are not allowed.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# NOTE(review): no route decorator was present, so this handler was never
# registered. The path is inferred from the handler name; confirm it matches
# what clients call.
@app.get("/health")
def health():
    """Report service status, the configured model directory and load state.

    ``model_loaded`` is False until the first move request triggers the lazy
    ``load_model()`` call.
    """
    return {
        "status": "ok",
        "model_dir": LOCAL_MODEL_DIR,
        "model_loaded": _model is not None,
    }
# NOTE(review): no route decorator was present, so this handler was never
# registered. The path is inferred from the handler name; confirm it matches
# what clients call.
@app.post("/move", response_model=MoveResponse)
def move(req: MoveRequest):
    """Pick and play the model's move for the position in ``req.fen``.

    Returns:
        MoveResponse with the chosen move in UCI and SAN, plus the FEN of the
        position after the move.

    Raises:
        HTTPException(400): If the FEN cannot be parsed, or the side to move
            has no legal moves (checkmate or stalemate).
        HTTPException(500): If the policy returns no move or an illegal one.
    """
    try:
        board = chess.Board(req.fen)
    except ValueError as exc:  # python-chess signals malformed FEN with ValueError
        raise HTTPException(status_code=400, detail=f"Invalid FEN: {exc}") from exc

    # With no legal moves the policy cannot produce anything sensible, so
    # report a client error instead of failing deep inside the model code.
    if not any(board.legal_moves):
        raise HTTPException(
            status_code=400,
            detail="No legal moves in this position (checkmate or stalemate)",
        )

    model = load_model()
    chosen = choose_move(model, board, req.temperature, req.greedy)
    # Guard san()/push() against a bad policy output; they assume legality.
    if chosen is None or not board.is_legal(chosen):
        raise HTTPException(status_code=500, detail=f"Policy returned an illegal move: {chosen}")

    san = board.san(chosen)  # SAN must be computed before the move is pushed
    board.push(chosen)
    return MoveResponse(uci=chosen.uci(), san=san, fen=board.fen())