File size: 4,796 Bytes
61f7235 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
"""
Synapse-Base Inference API
FastAPI server for chess move prediction
Optimized for HF Spaces CPU environment
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import time
import logging
from typing import Optional
from engine import SynapseEngine
# Configure logging: timestamped, module-qualified lines to stdout/stderr
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app (title/description/version surface in /docs)
app = FastAPI(
    title="Synapse-Base Inference API",
    description="High-performance chess engine powered by 38M parameter neural network",
    version="3.0.0"
)

# CORS middleware (allow your frontend domain)
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — lock allow_origins down
# to the real frontend domain in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance; None until startup_event() loads the model.
engine = None
# Request/Response models
class MoveRequest(BaseModel):
    """Request body for POST /get-move."""
    # Board position in FEN notation; checked via engine.validate_fen()
    fen: str = Field(..., description="Board position in FEN notation")
    # Search depth in plies, constrained to 1-5 (defaults to 3)
    depth: Optional[int] = Field(3, ge=1, le=5, description="Search depth (1-5)")
    # Per-move time budget in milliseconds, 1s-30s (defaults to 5s)
    time_limit: Optional[int] = Field(5000, ge=1000, le=30000, description="Time limit in ms")
class MoveResponse(BaseModel):
    """Response body for POST /get-move."""
    best_move: str           # move string as produced by the engine — presumably UCI; confirm against SynapseEngine
    evaluation: float        # engine's position score
    depth_searched: int      # depth actually reached by the search
    nodes_evaluated: int     # nodes visited during the search
    time_taken: int          # wall time spent serving the request, in ms
    pv: Optional[list] = None  # Principal variation (absent if engine omits it)
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str        # "healthy" once the model is loaded, else "unhealthy"
    model_loaded: bool  # True when the global engine instance exists
    version: str       # API version string
# Startup event
@app.on_event("startup")
async def startup_event():
    """Construct the global SynapseEngine once, before any request is served.

    A load failure is logged and re-raised so the process dies loudly rather
    than serving requests without a model.
    """
    global engine
    logger.info("🚀 Starting Synapse-Base Inference API...")
    try:
        # Two inference threads to match the HF Spaces 2-vCPU allotment.
        engine = SynapseEngine(model_path="/app/models/synapse_base.onnx", num_threads=2)
        logger.info("✅ Model loaded successfully")
        logger.info(f"📊 Model size: {engine.get_model_size():.2f} MB")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report liveness and whether the model is resident in memory."""
    loaded = engine is not None
    return {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "version": "3.0.0",
    }
# Main inference endpoint
@app.post("/get-move", response_model=MoveResponse)
async def get_move(request: MoveRequest):
    """
    Get best move for given position.

    Args:
        request: MoveRequest with FEN, depth, and time_limit

    Returns:
        MoveResponse with best_move, evaluation, and search statistics

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an invalid
            FEN string, 500 for any engine failure during search.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate FEN before spending any search time on it
    if not engine.validate_fen(request.fen):
        raise HTTPException(status_code=400, detail="Invalid FEN string")

    # perf_counter() is monotonic, so the reported duration cannot go
    # negative or jump when the wall clock is NTP-adjusted (time.time() can).
    start_time = time.perf_counter()
    try:
        # Run the search
        result = engine.get_best_move(
            fen=request.fen,
            depth=request.depth,
            time_limit=request.time_limit
        )

        time_taken = int((time.perf_counter() - start_time) * 1000)

        # One structured log line per served request
        logger.info(
            f"Move: {result['best_move']} | "
            f"Eval: {result['evaluation']:.3f} | "
            f"Depth: {result['depth_searched']} | "
            f"Nodes: {result['nodes_evaluated']} | "
            f"Time: {time_taken}ms"
        )

        return MoveResponse(
            best_move=result['best_move'],
            evaluation=result['evaluation'],
            depth_searched=result['depth_searched'],
            nodes_evaluated=result['nodes_evaluated'],
            time_taken=time_taken,
            pv=result.get('pv', None)
        )
    except Exception as e:
        # logger.exception records the traceback, which plain error() dropped
        logger.exception("Error processing move")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root endpoint
@app.get("/")
async def root():
    """Describe the service and enumerate its available endpoints."""
    endpoints = {
        "POST /get-move": "Get best move for position",
        "GET /health": "Health check",
        "GET /docs": "API documentation",
    }
    return {
        "name": "Synapse-Base Inference API",
        "version": "3.0.0",
        "model": "38.1M parameters",
        "architecture": "CNN-Transformer Hybrid",
        "endpoints": endpoints,
    }
# Run server (direct invocation only; HF Spaces may launch via its own runner)
if __name__ == "__main__":
    import uvicorn

    serve_options = {
        "host": "0.0.0.0",    # bind all interfaces inside the container
        "port": 7860,         # port HF Spaces routes traffic to
        "log_level": "info",
        "access_log": True,
    }
    uvicorn.run(app, **serve_options)