Rafs-an09002 committed on
Commit
61f7235
·
verified ·
1 Parent(s): 9304068

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synapse-Base Inference API
3
+ FastAPI server for chess move prediction
4
+ Optimized for HF Spaces CPU environment
5
+ """
6
+
7
import logging
import os
import time
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from engine import SynapseEngine
15
+
16
# Configure logging once for the whole process; handlers/format apply to
# every module that uses logging.getLogger(__name__).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app (title/description/version surface in /docs)
app = FastAPI(
    title="Synapse-Base Inference API",
    description="High-performance chess engine powered by 38M parameter neural network",
    version="3.0.0"
)

# CORS middleware (allow your frontend domain)
# NOTE(review): wildcard origins together with allow_credentials=True is a
# known CORS footgun — browsers reject `Access-Control-Allow-Origin: *` for
# credentialed requests, so Starlette echoes the caller's Origin instead,
# effectively trusting every site. Pin allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to your domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global engine instance. Populated by the startup hook below; endpoints
# treat None as "model not loaded" and answer 503.
engine = None
41
+
42
+
43
# Request/Response models
class MoveRequest(BaseModel):
    """Payload for POST /get-move."""

    # Board position in FEN notation; syntactic validity is checked
    # server-side via engine.validate_fen() before searching.
    fen: str = Field(..., description="Board position in FEN notation")
    # Search depth; pydantic enforces the 1-5 range (ge/le), default 3.
    depth: Optional[int] = Field(3, ge=1, le=5, description="Search depth (1-5)")
    # Search time budget in milliseconds, constrained to 1s-30s, default 5s.
    time_limit: Optional[int] = Field(5000, ge=1000, le=30000, description="Time limit in ms")
48
+
49
+
50
class MoveResponse(BaseModel):
    """Response body for POST /get-move, mirroring the engine result dict."""

    # Best move found by the search (format determined by the engine —
    # presumably UCI like "e2e4"; TODO confirm against SynapseEngine).
    best_move: str
    # Position evaluation as reported by the engine.
    evaluation: float
    # Depth the search actually reached.
    depth_searched: int
    # Number of nodes evaluated during the search.
    nodes_evaluated: int
    # Wall time spent serving the request, in milliseconds.
    time_taken: int
    # Principal variation, if the engine provided one.
    pv: Optional[list] = None  # Principal variation
57
+
58
+
59
class HealthResponse(BaseModel):
    """Response body for GET /health."""

    # "healthy" when the model is loaded, otherwise "unhealthy".
    status: str
    # True once the startup hook has constructed the global engine.
    model_loaded: bool
    # API version string (kept in sync with the FastAPI app version).
    version: str
63
+
64
+
65
# Startup event
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers; kept for compatibility with the existing app object.
@app.on_event("startup")
async def startup_event():
    """Load the ONNX model once at process startup.

    The model path defaults to the baked-in HF Spaces location but can be
    overridden with the MODEL_PATH environment variable (backward
    compatible: nothing changes when the variable is unset).

    Raises:
        Exception: re-raised after logging so the process fails fast
            instead of serving permanent 503s with no model.
    """
    global engine

    logger.info("🚀 Starting Synapse-Base Inference API...")

    # Allow deployments to relocate the model without editing code.
    model_path = os.environ.get("MODEL_PATH", "/app/models/synapse_base.onnx")

    try:
        engine = SynapseEngine(
            model_path=model_path,
            num_threads=2  # Match HF Spaces 2 vCPU
        )
        logger.info("✅ Model loaded successfully")
        # Lazy %-args avoid formatting cost when INFO is disabled.
        logger.info("📊 Model size: %.2f MB", engine.get_model_size())

    except Exception:
        # logger.exception records the full traceback, not just the message.
        logger.exception("❌ Failed to load model from %s", model_path)
        raise
84
+
85
+
86
# Health check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service liveness and whether the model has been loaded."""
    loaded = engine is not None
    status = "healthy" if loaded else "unhealthy"
    return {
        "status": status,
        "model_loaded": loaded,
        "version": "3.0.0",
    }
95
+
96
+
97
# Main inference endpoint
@app.post("/get-move", response_model=MoveResponse)
async def get_move(request: MoveRequest):
    """
    Get best move for given position.

    Args:
        request: MoveRequest with FEN, depth, and time_limit

    Returns:
        MoveResponse with best_move and evaluation

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an invalid
            FEN string, 500 for any failure inside the engine search.
    """

    if engine is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate FEN before spending any search time on it
    if not engine.validate_fen(request.fen):
        raise HTTPException(status_code=400, detail="Invalid FEN string")

    # perf_counter() is monotonic; time.time() can jump (NTP/DST) and
    # produce negative or wildly wrong elapsed times.
    start_time = time.perf_counter()

    # Keep the try body minimal: only the engine call can legitimately
    # fail here, so response-construction bugs are no longer masked as 500s.
    try:
        result = engine.get_best_move(
            fen=request.fen,
            depth=request.depth,
            time_limit=request.time_limit
        )
    except Exception as e:
        # logger.exception captures the traceback; chain the cause for
        # correct __cause__ on the HTTPException.
        logger.exception("Error processing move for FEN %r", request.fen)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Calculate time taken (milliseconds)
    time_taken = int((time.perf_counter() - start_time) * 1000)

    # Lazy %-style args: no string formatting cost when INFO is disabled
    logger.info(
        "Move: %s | Eval: %.3f | Depth: %s | Nodes: %s | Time: %sms",
        result['best_move'],
        result['evaluation'],
        result['depth_searched'],
        result['nodes_evaluated'],
        time_taken,
    )

    return MoveResponse(
        best_move=result['best_move'],
        evaluation=result['evaluation'],
        depth_searched=result['depth_searched'],
        nodes_evaluated=result['nodes_evaluated'],
        time_taken=time_taken,
        pv=result.get('pv', None)
    )
152
+
153
+
154
# Root endpoint
@app.get("/")
async def root():
    """Describe the API: name, version, model, and available endpoints."""
    endpoints = {
        "POST /get-move": "Get best move for position",
        "GET /health": "Health check",
        "GET /docs": "API documentation",
    }
    return {
        "name": "Synapse-Base Inference API",
        "version": "3.0.0",
        "model": "38.1M parameters",
        "architecture": "CNN-Transformer Hybrid",
        "endpoints": endpoints,
    }
169
+
170
+
171
# Run server (only when executed directly; a WSGI/ASGI runner imports `app`)
if __name__ == "__main__":
    import uvicorn

    # Port 7860 is the default port exposed by HF Spaces.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "access_log": True,
    }
    uvicorn.run(app, **server_options)