# Source: Hugging Face Space file "app.py" by Rafs-an09002 (revision 61f7235, 4.8 kB)
"""
Synapse-Base Inference API
FastAPI server for chess move prediction
Optimized for HF Spaces CPU environment
"""
import logging
import os
import time
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from engine import SynapseEngine
# Configure logging: timestamped records, INFO level and above, emitted
# by every logger in the process (basicConfig installs a root handler).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
# Initialize FastAPI app (title/description/version feed the /docs UI).
app = FastAPI(
    title="Synapse-Base Inference API",
    description="High-performance chess engine powered by 38M parameter neural network",
    version="3.0.0"
)

# CORS middleware (allow your frontend domain).
# NOTE(review): wildcard allow_origins combined with allow_credentials=True
# is rejected by browsers for credentialed requests — set allow_origins to
# the concrete frontend origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance (loaded once at startup by startup_event;
# None until the model has been loaded successfully).
engine = None
# Request/Response models
class MoveRequest(BaseModel):
    """Request body for POST /get-move."""
    # Position to analyze, in FEN notation (validated by engine.validate_fen).
    fen: str = Field(..., description="Board position in FEN notation")
    # Search depth; pydantic enforces the 1-5 range, default 3.
    depth: Optional[int] = Field(3, ge=1, le=5, description="Search depth (1-5)")
    # Per-request time budget in milliseconds; enforced range 1000-30000.
    time_limit: Optional[int] = Field(5000, ge=1000, le=30000, description="Time limit in ms")
class MoveResponse(BaseModel):
    """Response body for POST /get-move."""
    best_move: str        # move chosen by the engine (format defined by SynapseEngine)
    evaluation: float     # engine's evaluation of the position
    depth_searched: int   # depth the search actually reached
    nodes_evaluated: int  # nodes visited during the search
    time_taken: int       # wall-clock request duration in ms (measured in get_move)
    pv: Optional[list] = None  # Principal variation, when the engine provides one
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str        # "healthy" or "unhealthy"
    model_loaded: bool  # True once startup_event has loaded the model
    version: str       # API version string (matches FastAPI app version)
# Startup event
@app.on_event("startup")
async def startup_event():
    """Load the ONNX model once at process startup.

    The model location can be overridden with the MODEL_PATH environment
    variable; it defaults to the path baked into the HF Spaces image, so
    existing deployments are unaffected. Re-raises on failure so the
    server fails fast instead of serving a broken API.
    """
    global engine
    logger.info("🚀 Starting Synapse-Base Inference API...")
    # Allow the deployment to relocate the model without a code change.
    model_path = os.environ.get("MODEL_PATH", "/app/models/synapse_base.onnx")
    try:
        engine = SynapseEngine(
            model_path=model_path,
            num_threads=2  # Match HF Spaces 2 vCPU
        )
        logger.info("✅ Model loaded successfully")
        logger.info(f"📊 Model size: {engine.get_model_size():.2f} MB")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {e}")
        raise
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report liveness and whether the model has finished loading."""
    model_ready = engine is not None
    return {
        "status": "healthy" if model_ready else "unhealthy",
        "model_loaded": model_ready,
        "version": "3.0.0"
    }
# Main inference endpoint
@app.post("/get-move", response_model=MoveResponse)
async def get_move(request: MoveRequest):
    """
    Get best move for given position.

    Args:
        request: MoveRequest with FEN, depth, and time_limit.

    Returns:
        MoveResponse with best_move, evaluation and search statistics.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an
            invalid FEN, 500 if the engine fails during the search.
    """
    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Reject bad positions before spending any search time on them.
    if not engine.validate_fen(request.fen):
        raise HTTPException(status_code=400, detail="Invalid FEN string")

    # perf_counter is monotonic, so the reported duration cannot be
    # skewed (or go negative) by a system-clock adjustment mid-request,
    # unlike time.time().
    start_time = time.perf_counter()
    try:
        # Get best move from engine
        result = engine.get_best_move(
            fen=request.fen,
            depth=request.depth,
            time_limit=request.time_limit
        )
        # Calculate time taken (ms)
        time_taken = int((time.perf_counter() - start_time) * 1000)

        # Lazy %-style args avoid string formatting when INFO is disabled.
        logger.info(
            "Move: %s | Eval: %.3f | Depth: %s | Nodes: %s | Time: %sms",
            result['best_move'],
            result['evaluation'],
            result['depth_searched'],
            result['nodes_evaluated'],
            time_taken,
        )

        return MoveResponse(
            best_move=result['best_move'],
            evaluation=result['evaluation'],
            depth_searched=result['depth_searched'],
            nodes_evaluated=result['nodes_evaluated'],
            time_taken=time_taken,
            pv=result.get('pv', None)
        )
    except Exception as e:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Error processing move: %s", e)
        # Chain the cause so the original error survives in logs/debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root endpoint
@app.get("/")
async def root():
    """Root endpoint with API info"""
    endpoints = {
        "POST /get-move": "Get best move for position",
        "GET /health": "Health check",
        "GET /docs": "API documentation",
    }
    return {
        "name": "Synapse-Base Inference API",
        "version": "3.0.0",
        "model": "38.1M parameters",
        "architecture": "CNN-Transformer Hybrid",
        "endpoints": endpoints,
    }
# Run server (local/dev entry point; in HF Spaces the container usually
# starts uvicorn externally, so this guard only fires when run directly).
if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the port HF Spaces exposes; bind 0.0.0.0 so the
    # container's mapped port is reachable from outside.
    uvicorn.run(
    app,
    host="0.0.0.0",
    port=7860,
    log_level="info",
    access_log=True
    )