Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| import chess | |
| from contextlib import asynccontextmanager | |
| import logging | |
| import sys | |
| import time | |
| import torch | |
| from pydantic import BaseModel | |
| # Pickle stored the model/tokenizer classes under their original (top-level) | |
| # module paths. Alias src.* as top-level names so torch.load can resolve them. | |
| from src import tokenizer as _tokenizer_module | |
| sys.modules["tokenizer"] = _tokenizer_module | |
| from src.model import ChessPolicyModel, PolicyModelInference | |
# Configure the root logger once at import time: INFO and above, with a
# timestamp, level, and logger name on every record.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s - %(message)s",
    level=logging.INFO,
)

# Application logger used by the startup hook and the request handlers.
log = logging.getLogger("transformer4chess")

# Module-level registry: the lifespan hook stores the loaded inference
# wrapper under the "inference" key and empties the dict on shutdown.
ml = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and policy model at startup; release them at shutdown.

    Intended to be passed as ``FastAPI(lifespan=...)``. The code before
    ``yield`` runs once before the application starts serving; the code in
    the ``finally`` clause runs once when it stops.

    Args:
        app: The FastAPI application being started (not used in the body).

    Yields:
        None. The ready-to-use inference wrapper is published through the
        module-level ``ml`` dict under the ``"inference"`` key.
    """
    log.info("loading tokenizer from ./model/tokenizer.pt")
    t0 = time.perf_counter()
    # NOTE(security): weights_only=False performs a full unpickle, which can
    # execute arbitrary code. It is needed here because the file holds a
    # pickled tokenizer object rather than plain tensors; only ever load
    # artifacts that come from a trusted source.
    tokenizer = torch.load(
        "./model/tokenizer.pt", weights_only=False, map_location="cpu"
    )
    log.info(
        "tokenizer loaded (vocab=%d) in %.2fs",
        tokenizer.language_size,
        time.perf_counter() - t0,
    )

    log.info("loading policy model from ./model/policy_model.pt")
    t0 = time.perf_counter()
    model = ChessPolicyModel(vocab_size=tokenizer.language_size)
    # NOTE(review): this file is fed to load_state_dict, so it presumably
    # holds only tensors; if so, weights_only=True would be the safer
    # setting. Confirm against the checkpoint before switching.
    state_dict = torch.load(
        "./model/policy_model.pt", weights_only=False, map_location="cpu"
    )
    model.load_state_dict(state_dict)
    ml["inference"] = PolicyModelInference(model, tokenizer, device="cpu")
    log.info("policy model loaded in %.2fs", time.perf_counter() - t0)

    try:
        yield
    finally:
        # Runs on normal shutdown and also when an exception is thrown into
        # the generator, so the model references are always dropped.
        log.info("shutting down — clearing model cache")
        ml.clear()
# ASGI application instance; `lifespan` is handed to FastAPI so the model is
# loaded before serving starts and released when it stops.
app = FastAPI(lifespan=lifespan)
class InferenceRequest(BaseModel):
    # Request body for the inference handler.
    # `moves` is the game so far as UCI strings (e.g. "e2e4"); the handler
    # replays them in order on a fresh chess.Board().
    moves: list[str]
@app.get("/")
def root():
    """Return a static status payload listing the service's endpoints."""
    return {"status": "ok", "endpoints": ["/inference", "/docs"]}
@app.post("/inference")
def model_inference(req: InferenceRequest):
    """Replay the submitted moves and return the model's predicted next move.

    Args:
        req: Request body; ``req.moves`` are UCI strings applied in order to
            a fresh board.

    Returns:
        ``{"move": prediction}``, where ``prediction`` is whatever the loaded
        inference wrapper returns for the resulting position.

    Raises:
        HTTPException: 400 if a move is rejected by ``board.push_uci`` for
            the position reached so far; 500 if the model call raises.
    """
    log.info("inference request: %d moves", len(req.moves))

    board = chess.Board()
    for move in req.moves:
        try:
            board.push_uci(move)
        except ValueError as e:
            log.warning("rejected illegal move %r: %s", move, e)
            raise HTTPException(
                status_code=400, detail=f"Incorrect move {move}: {e}"
            ) from e

    t0 = time.perf_counter()
    # Keep the try body to the model call alone so that only inference
    # failures (including a missing "inference" entry) map to the 500.
    try:
        prediction = ml["inference"](board)
    except Exception as e:
        log.exception("model inference failed")
        raise HTTPException(
            status_code=500, detail="Model failed to evaluate"
        ) from e

    log.info("predicted %s in %.3fs", prediction, time.perf_counter() - t0)
    return {"move": prediction}