Amogh1221 commited on
Commit
da04874
Β·
verified Β·
1 Parent(s): 578e922

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +161 -46
main.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional, List, Dict
 
5
  import os
6
  import math
7
  import chess
@@ -9,8 +10,6 @@ import chess.engine
9
  import asyncio
10
  import json
11
 
12
- app = FastAPI(title="Deepcastle Engine API")
13
-
14
  # ─── Multiplaying / Challenge Manager ──────────────────────────────────────────
15
  class ConnectionManager:
16
  def __init__(self):
@@ -41,34 +40,6 @@ class ConnectionManager:
41
 
42
  manager = ConnectionManager()
43
 
44
- @app.websocket("/ws/{match_id}")
45
- async def websocket_endpoint(websocket: WebSocket, match_id: str):
46
- await manager.connect(websocket, match_id)
47
- room = manager.active_connections.get(match_id, [])
48
- # Notify others that someone joined
49
- await manager.broadcast(json.dumps({"type": "join"}), match_id, exclude=websocket)
50
- try:
51
- while True:
52
- data = await websocket.receive_text()
53
- # Relay the message (move, etc.) to others in the same room
54
- await manager.broadcast(data, match_id, exclude=websocket)
55
- except WebSocketDisconnect:
56
- manager.disconnect(websocket, match_id)
57
- # Notify remaining players that opponent disconnected β†’ they win
58
- await manager.broadcast(json.dumps({"type": "opponent_disconnected"}), match_id)
59
- except Exception:
60
- manager.disconnect(websocket, match_id)
61
- await manager.broadcast(json.dumps({"type": "opponent_disconnected"}), match_id)
62
-
63
-
64
- # Allow ALL for easy testing (we can restrict this later if needed)
65
- app.add_middleware(
66
- CORSMiddleware,
67
- allow_origins=["*"],
68
- allow_methods=["*"],
69
- allow_headers=["*"],
70
- )
71
-
72
  # Paths relative to the Docker container
73
  DEEPCASTLE_ENGINE_PATH = os.environ.get(
74
  "DEEPCASTLE_ENGINE_PATH",
@@ -114,16 +85,6 @@ class AnalyzeResponse(BaseModel):
114
  moves: List[MoveAnalysis]
115
  counts: Dict[str, int]
116
 
117
- @app.get("/")
118
- def home():
119
- return {"status": "online", "engine": "Deepcastle Hybrid Neural", "platform": "Hugging Face Spaces"}
120
-
121
- @app.get("/health")
122
- def health():
123
- if not os.path.exists(DEEPCASTLE_ENGINE_PATH):
124
- return {"status": "error", "message": "Missing engine binary: deepcastle"}
125
- return {"status": "ok", "engine": "deepcastle"}
126
-
127
  # Global engine instances to save memory and improve performance
128
  _GLOBAL_DEEPCASTLE_ENGINE = None
129
  _ENGINE_LOCK = asyncio.Lock()
@@ -209,6 +170,127 @@ async def get_stockfish_engine():
209
  # Compatibility alias: analysis now also uses DeepCastle.
210
  return await get_deepcastle_engine()
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def get_normalized_score(info) -> tuple[float, Optional[int]]:
213
  """Returns the score from White's perspective in centipawns."""
214
  if "score" not in info:
@@ -246,11 +328,20 @@ async def get_move(request: MoveRequest):
246
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
247
 
248
  # One search: stats must come from this run (not a separate short analyse).
 
249
  async with _ENGINE_IO_LOCK:
250
- result = await engine.play(board, limit, info=chess.engine.INFO_ALL)
 
 
 
 
251
  info = dict(result.info)
252
  if not info:
253
- info = await engine.analyse(board, limit, info=chess.engine.INFO_ALL)
 
 
 
 
254
 
255
  # From White's perspective in CP -> converted to Pawns for UI
256
  score_cp, mate_in = get_normalized_score(info)
@@ -287,6 +378,8 @@ async def get_move(request: MoveRequest):
287
  mate_in=mate_in,
288
  opening=opening_name
289
  )
 
 
290
  except Exception as e:
291
  print(f"Error: {e}")
292
  raise HTTPException(status_code=500, detail=str(e))
@@ -298,11 +391,20 @@ async def get_analysis_move(request: MoveRequest):
298
  board = chess.Board(request.fen)
299
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
300
 
 
301
  async with _ENGINE_IO_LOCK:
302
- result = await engine.play(board, limit, info=chess.engine.INFO_ALL)
 
 
 
 
303
  info = dict(result.info)
304
  if not info:
305
- info = await engine.analyse(board, limit, info=chess.engine.INFO_ALL)
 
 
 
 
306
 
307
  score_cp, mate_in = get_normalized_score(info)
308
 
@@ -336,6 +438,8 @@ async def get_analysis_move(request: MoveRequest):
336
  mate_in=mate_in,
337
  opening=opening_name
338
  )
 
 
339
  except Exception as e:
340
  print(f"Analysis move error: {e}")
341
  raise HTTPException(status_code=500, detail=str(e))
@@ -492,8 +596,13 @@ async def analyze_game(request: AnalyzeRequest):
492
 
493
  analysis_results = []
494
 
 
495
  async with _ENGINE_IO_LOCK:
496
- infos_before = await engine.analyse(board, limit, multipv=2)
 
 
 
 
497
  infos_before = infos_before if isinstance(infos_before, list) else [infos_before]
498
 
499
  counts = {
@@ -542,7 +651,11 @@ async def analyze_game(request: AnalyzeRequest):
542
  fen_history.append(board.fen())
543
 
544
  async with _ENGINE_IO_LOCK:
545
- infos_after_raw = await engine.analyse(board, limit, multipv=2)
 
 
 
 
546
  infos_after: List[dict] = infos_after_raw if isinstance(infos_after_raw, list) else [infos_after_raw]
547
 
548
  info_after_dict: dict = infos_after[0]
@@ -620,6 +733,8 @@ async def analyze_game(request: AnalyzeRequest):
620
  counts=counts
621
  )
622
 
 
 
623
  except Exception as e:
624
  print(f"Analysis Error: {e}")
625
  raise HTTPException(status_code=500, detail=str(e))
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional, List, Dict
5
+ from contextlib import asynccontextmanager
6
  import os
7
  import math
8
  import chess
 
10
  import asyncio
11
  import json
12
 
 
 
13
  # ─── Multiplaying / Challenge Manager ──────────────────────────────────────────
14
  class ConnectionManager:
15
  def __init__(self):
 
40
 
41
  manager = ConnectionManager()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Paths relative to the Docker container
44
  DEEPCASTLE_ENGINE_PATH = os.environ.get(
45
  "DEEPCASTLE_ENGINE_PATH",
 
85
  moves: List[MoveAnalysis]
86
  counts: Dict[str, int]
87
 
 
 
 
 
 
 
 
 
 
 
88
  # Global engine instances to save memory and improve performance
89
  _GLOBAL_DEEPCASTLE_ENGINE = None
90
  _ENGINE_LOCK = asyncio.Lock()
 
170
  # Compatibility alias: analysis now also uses DeepCastle.
171
  return await get_deepcastle_engine()
172
 
173
+
174
+ async def shutdown_engine_async() -> None:
175
+ """Release UCI subprocess on process exit (deploy / SIGTERM)."""
176
+ global _GLOBAL_DEEPCASTLE_ENGINE
177
+ async with _ENGINE_IO_LOCK:
178
+ async with _ENGINE_LOCK:
179
+ eng = _GLOBAL_DEEPCASTLE_ENGINE
180
+ _GLOBAL_DEEPCASTLE_ENGINE = None
181
+ if eng:
182
+ try:
183
+ await asyncio.wait_for(eng.quit(), timeout=5.0)
184
+ except Exception:
185
+ pass
186
+
187
+
188
+ async def _detach_and_quit_engine(engine) -> None:
189
+ """After a hung search, drop the singleton and try to terminate the process."""
190
+ global _GLOBAL_DEEPCASTLE_ENGINE
191
+ async with _ENGINE_LOCK:
192
+ if _GLOBAL_DEEPCASTLE_ENGINE is engine:
193
+ _GLOBAL_DEEPCASTLE_ENGINE = None
194
+ try:
195
+ await asyncio.wait_for(engine.quit(), timeout=5.0)
196
+ except Exception:
197
+ pass
198
+
199
+
200
+ def _search_timeout_sec(request_time: float, depth: Optional[int] = None) -> float:
201
+ """Wall-clock cap for a single play/analyse (env ENGINE_SEARCH_TIMEOUT_SEC, default 120)."""
202
+ try:
203
+ cap = float(os.environ.get("ENGINE_SEARCH_TIMEOUT_SEC", "120"))
204
+ except ValueError:
205
+ cap = 120.0
206
+ cap = max(15.0, min(600.0, cap))
207
+ if request_time and request_time > 0:
208
+ return min(cap, max(request_time * 3.0 + 10.0, 30.0))
209
+ return cap
210
+
211
+
212
+ def _analyze_ply_timeout(time_per_move: float) -> float:
213
+ """Wall-clock cap per analyse() in /analyze-game (multipv=2 needs headroom)."""
214
+ try:
215
+ cap = float(os.environ.get("ENGINE_SEARCH_TIMEOUT_SEC", "120"))
216
+ except ValueError:
217
+ cap = 120.0
218
+ cap = max(15.0, min(600.0, cap))
219
+ if time_per_move and time_per_move > 0:
220
+ return min(cap, max(time_per_move * 80.0 + 15.0, 30.0))
221
+ return cap
222
+
223
+
224
+ async def _engine_call(engine, coro, timeout_sec: float):
225
+ try:
226
+ return await asyncio.wait_for(coro, timeout=timeout_sec)
227
+ except asyncio.TimeoutError:
228
+ await _detach_and_quit_engine(engine)
229
+ raise HTTPException(status_code=504, detail="Engine search timed out")
230
+
231
+
232
+ @asynccontextmanager
233
+ async def lifespan(app: FastAPI):
234
+ yield
235
+ await shutdown_engine_async()
236
+
237
+
238
+ app = FastAPI(title="Deepcastle Engine API", lifespan=lifespan)
239
+
240
+ # Allow ALL for easy testing (we can restrict this later if needed)
241
+ app.add_middleware(
242
+ CORSMiddleware,
243
+ allow_origins=["*"],
244
+ allow_methods=["*"],
245
+ allow_headers=["*"],
246
+ )
247
+
248
+
249
+ @app.websocket("/ws/{match_id}")
250
+ async def websocket_endpoint(websocket: WebSocket, match_id: str):
251
+ await manager.connect(websocket, match_id)
252
+ room = manager.active_connections.get(match_id, [])
253
+ await manager.broadcast(json.dumps({"type": "join"}), match_id, exclude=websocket)
254
+ try:
255
+ while True:
256
+ data = await websocket.receive_text()
257
+ await manager.broadcast(data, match_id, exclude=websocket)
258
+ except WebSocketDisconnect:
259
+ manager.disconnect(websocket, match_id)
260
+ await manager.broadcast(json.dumps({"type": "opponent_disconnected"}), match_id)
261
+ except Exception:
262
+ manager.disconnect(websocket, match_id)
263
+ await manager.broadcast(json.dumps({"type": "opponent_disconnected"}), match_id)
264
+
265
+
266
+ @app.get("/")
267
+ def home():
268
+ return {"status": "online", "engine": "Deepcastle Hybrid Neural", "platform": "Hugging Face Spaces"}
269
+
270
+
271
+ @app.get("/health")
272
+ def health():
273
+ if not os.path.exists(DEEPCASTLE_ENGINE_PATH):
274
+ return {"status": "error", "message": "Missing engine binary: deepcastle"}
275
+ return {"status": "ok", "engine": "deepcastle"}
276
+
277
+
278
+ @app.get("/health/ready")
279
+ async def health_ready():
280
+ """Optional deep check: binary exists and engine answers UCI (for orchestrators)."""
281
+ if not os.path.exists(DEEPCASTLE_ENGINE_PATH):
282
+ raise HTTPException(status_code=503, detail="Missing engine binary")
283
+ try:
284
+ engine = await get_deepcastle_engine()
285
+ async with _ENGINE_IO_LOCK:
286
+ await asyncio.wait_for(engine.ping(), timeout=5.0)
287
+ return {"status": "ok", "engine": "responsive"}
288
+ except HTTPException:
289
+ raise
290
+ except Exception as e:
291
+ raise HTTPException(status_code=503, detail=str(e))
292
+
293
+
294
  def get_normalized_score(info) -> tuple[float, Optional[int]]:
295
  """Returns the score from White's perspective in centipawns."""
296
  if "score" not in info:
 
328
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
329
 
330
  # One search: stats must come from this run (not a separate short analyse).
331
+ tsec = _search_timeout_sec(request.time, request.depth)
332
  async with _ENGINE_IO_LOCK:
333
+ result = await _engine_call(
334
+ engine,
335
+ engine.play(board, limit, info=chess.engine.INFO_ALL),
336
+ tsec,
337
+ )
338
  info = dict(result.info)
339
  if not info:
340
+ info = await _engine_call(
341
+ engine,
342
+ engine.analyse(board, limit, info=chess.engine.INFO_ALL),
343
+ tsec,
344
+ )
345
 
346
  # From White's perspective in CP -> converted to Pawns for UI
347
  score_cp, mate_in = get_normalized_score(info)
 
378
  mate_in=mate_in,
379
  opening=opening_name
380
  )
381
+ except HTTPException:
382
+ raise
383
  except Exception as e:
384
  print(f"Error: {e}")
385
  raise HTTPException(status_code=500, detail=str(e))
 
391
  board = chess.Board(request.fen)
392
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
393
 
394
+ tsec = _search_timeout_sec(request.time, request.depth)
395
  async with _ENGINE_IO_LOCK:
396
+ result = await _engine_call(
397
+ engine,
398
+ engine.play(board, limit, info=chess.engine.INFO_ALL),
399
+ tsec,
400
+ )
401
  info = dict(result.info)
402
  if not info:
403
+ info = await _engine_call(
404
+ engine,
405
+ engine.analyse(board, limit, info=chess.engine.INFO_ALL),
406
+ tsec,
407
+ )
408
 
409
  score_cp, mate_in = get_normalized_score(info)
410
 
 
438
  mate_in=mate_in,
439
  opening=opening_name
440
  )
441
+ except HTTPException:
442
+ raise
443
  except Exception as e:
444
  print(f"Analysis move error: {e}")
445
  raise HTTPException(status_code=500, detail=str(e))
 
596
 
597
  analysis_results = []
598
 
599
+ ply_timeout = _analyze_ply_timeout(request.time_per_move)
600
  async with _ENGINE_IO_LOCK:
601
+ infos_before = await _engine_call(
602
+ engine,
603
+ engine.analyse(board, limit, multipv=2),
604
+ ply_timeout,
605
+ )
606
  infos_before = infos_before if isinstance(infos_before, list) else [infos_before]
607
 
608
  counts = {
 
651
  fen_history.append(board.fen())
652
 
653
  async with _ENGINE_IO_LOCK:
654
+ infos_after_raw = await _engine_call(
655
+ engine,
656
+ engine.analyse(board, limit, multipv=2),
657
+ ply_timeout,
658
+ )
659
  infos_after: List[dict] = infos_after_raw if isinstance(infos_after_raw, list) else [infos_after_raw]
660
 
661
  info_after_dict: dict = infos_after[0]
 
733
  counts=counts
734
  )
735
 
736
+ except HTTPException:
737
+ raise
738
  except Exception as e:
739
  print(f"Analysis Error: {e}")
740
  raise HTTPException(status_code=500, detail=str(e))