Amogh1221 commited on
Commit
d1e63fb
Β·
verified Β·
1 Parent(s): 1c02c1d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +185 -23
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional, List, Dict
@@ -7,9 +7,55 @@ import math
7
  import chess
8
  import chess.engine
9
  import asyncio
 
10
 
11
  app = FastAPI(title="Deepcastle Engine API")
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
@@ -17,12 +63,13 @@ app.add_middleware(
17
  allow_headers=["*"],
18
  )
19
 
 
20
  ENGINE_PATH = os.environ.get("ENGINE_PATH", "/app/engine/deepcastle")
21
  NNUE_PATH = os.environ.get("NNUE_PATH", "/app/engine/output.nnue")
22
 
23
  class MoveRequest(BaseModel):
24
  fen: str
25
- time: float = 1.0
26
  depth: Optional[int] = None
27
 
28
  class MoveResponse(BaseModel):
@@ -33,16 +80,17 @@ class MoveResponse(BaseModel):
33
  nps: int
34
  pv: str
35
 
 
36
  class AnalyzeRequest(BaseModel):
37
- moves: List[str]
38
- time_per_move: float = 0.1
39
  player_color: str = "white"
40
 
41
  class MoveAnalysis(BaseModel):
42
  move_num: int
43
  san: str
44
  fen: str
45
- classification: str
46
  cpl: float
47
  score_before: float
48
  score_after: float
@@ -76,6 +124,7 @@ async def get_engine():
76
  return engine
77
 
78
  def get_normalized_score(info, turn_color=chess.WHITE):
 
79
  if "score" not in info:
80
  return 0.0
81
  raw = info["score"].white()
@@ -83,6 +132,7 @@ def get_normalized_score(info, turn_color=chess.WHITE):
83
  return 10000.0 if (raw.mate() or 0) > 0 else -10000.0
84
  return raw.score() or 0.0
85
 
 
86
  @app.post("/move", response_model=MoveResponse)
87
  async def get_move(request: MoveRequest):
88
  engine = None
@@ -90,14 +140,53 @@ async def get_move(request: MoveRequest):
90
  engine = await get_engine()
91
  board = chess.Board(request.fen)
92
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
 
93
  result = await engine.play(board, limit)
94
  info = await engine.analyse(board, limit)
 
 
95
  score_cp = get_normalized_score(info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  score_pawns = score_cp / 100.0 if abs(score_cp) < 9900 else (100.0 if score_cp > 0 else -100.0)
97
- return MoveResponse(bestmove=result.move.uci(), score=score_pawns, depth=info.get("depth", 0), nodes=info.get("nodes", 0), nps=info.get("nps", 0), pv="")
 
 
 
 
 
 
 
 
 
 
 
98
  finally:
99
- if engine: await engine.quit()
 
 
 
 
 
100
 
 
101
  @app.post("/analyze-game", response_model=AnalyzeResponse)
102
  async def analyze_game(request: AnalyzeRequest):
103
  engine = None
@@ -105,46 +194,119 @@ async def analyze_game(request: AnalyzeRequest):
105
  engine = await get_engine()
106
  board = chess.Board()
107
  limit = chess.engine.Limit(time=request.time_per_move)
108
- analysis_results, total_cpl, player_moves_count = [], 0, 0
109
- counts = {"Brilliant": 0, "Great": 0, "Best": 0, "Excellent": 0, "Good": 0, "Inaccuracy": 0, "Mistake": 0, "Blunder": 0}
110
 
 
 
 
111
  info_before = await engine.analyse(board, limit)
112
  current_score = get_normalized_score(info_before)
 
 
 
 
 
 
 
 
 
 
 
113
  player_is_white = (request.player_color.lower() == "white")
114
 
115
  for i, san_move in enumerate(request.moves):
116
- is_player_turn = board.turn == (chess.WHITE if player_is_white else chess.BLACK)
 
 
117
  score_before = current_score
 
 
118
  try:
119
- board.push_san(san_move)
120
- except Exception: break
 
 
121
 
 
122
  info_after = await engine.analyse(board, limit)
123
- current_score = get_normalized_score(info_after)
124
 
 
 
 
 
125
  if is_player_turn:
126
- cpl = max(0, score_before - current_score if player_is_white else current_score - score_before)
 
 
 
 
 
 
 
 
127
  cpl = min(cpl, 1000.0)
 
128
  total_cpl += cpl
129
  player_moves_count += 1
130
 
131
- if cpl <= 15: cls = "Best"
132
- elif cpl <= 35: cls = "Excellent"
133
- elif cpl <= 75: cls = "Good"
134
- elif cpl <= 150: cls = "Inaccuracy"
135
- elif cpl <= 300: cls = "Mistake"
136
- else: cls = "Blunder"
 
 
 
 
 
 
 
137
 
138
  counts[cls] += 1
139
- analysis_results.append(MoveAnalysis(move_num=i+1, san=san_move, fen=board.fen(), classification=cls, cpl=cpl, score_before=score_before/100, score_after=current_score/100))
 
 
 
 
 
 
 
 
 
140
 
 
 
141
  avg_cpl = total_cpl / max(1, player_moves_count)
 
 
 
 
142
  accuracy = max(10.0, min(100.0, 100.0 * math.exp(-0.005 * avg_cpl)))
 
 
 
143
  estimated_elo = int(max(400, min(3600, 3600 - (avg_cpl * 20))))
144
- return AnalyzeResponse(accuracy=round(accuracy, 1), estimated_elo=estimated_elo, moves=analysis_results, counts=counts)
 
 
 
 
 
 
 
 
 
 
145
  finally:
146
- if engine: await engine.quit()
 
 
 
 
 
147
 
148
  if __name__ == "__main__":
149
  import uvicorn
 
150
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional, List, Dict
 
7
  import chess
8
  import chess.engine
9
  import asyncio
10
+ import json
11
 
12
  app = FastAPI(title="Deepcastle Engine API")
13
 
14
+ # ─── Multiplaying / Challenge Manager ──────────────────────────────────────────
15
+ class ConnectionManager:
16
+ def __init__(self):
17
+ # match_id -> list of websockets
18
+ self.active_connections: Dict[str, List[WebSocket]] = {}
19
+
20
+ async def connect(self, websocket: WebSocket, match_id: str):
21
+ await websocket.accept()
22
+ if match_id not in self.active_connections:
23
+ self.active_connections[match_id] = []
24
+ self.active_connections[match_id].append(websocket)
25
+
26
+ def disconnect(self, websocket: WebSocket, match_id: str):
27
+ if match_id in self.active_connections:
28
+ if websocket in self.active_connections[match_id]:
29
+ self.active_connections[match_id].remove(websocket)
30
+ if not self.active_connections[match_id]:
31
+ del self.active_connections[match_id]
32
+
33
+ async def broadcast(self, message: str, match_id: str, exclude: WebSocket = None):
34
+ if match_id in self.active_connections:
35
+ for connection in self.active_connections[match_id]:
36
+ if connection != exclude:
37
+ try:
38
+ await connection.send_text(message)
39
+ except Exception:
40
+ pass
41
+
42
+ manager = ConnectionManager()
43
+
44
+ @app.websocket("/ws/{match_id}")
45
+ async def websocket_endpoint(websocket: WebSocket, match_id: str):
46
+ await manager.connect(websocket, match_id)
47
+ try:
48
+ while True:
49
+ data = await websocket.receive_text()
50
+ # Relay the message (move, chat, etc) to others in the same room
51
+ await manager.broadcast(data, match_id, exclude=websocket)
52
+ except WebSocketDisconnect:
53
+ manager.disconnect(websocket, match_id)
54
+ except Exception:
55
+ manager.disconnect(websocket, match_id)
56
+
57
+
58
+ # Allow ALL for easy testing (we can restrict this later if needed)
59
  app.add_middleware(
60
  CORSMiddleware,
61
  allow_origins=["*"],
 
63
  allow_headers=["*"],
64
  )
65
 
66
+ # Paths relative to the Docker container
67
  ENGINE_PATH = os.environ.get("ENGINE_PATH", "/app/engine/deepcastle")
68
  NNUE_PATH = os.environ.get("NNUE_PATH", "/app/engine/output.nnue")
69
 
70
  class MoveRequest(BaseModel):
71
  fen: str
72
+ time: float = 1.0 # seconds
73
  depth: Optional[int] = None
74
 
75
  class MoveResponse(BaseModel):
 
80
  nps: int
81
  pv: str
82
 
83
+ # ─── New Analysis Types ────────────────────────────────────────────────────────
84
  class AnalyzeRequest(BaseModel):
85
+ moves: List[str] # e.g., ["e4", "e5", "Nf3", "Nc6", ...]
86
+ time_per_move: float = 0.1 # quick eval per move
87
  player_color: str = "white"
88
 
89
  class MoveAnalysis(BaseModel):
90
  move_num: int
91
  san: str
92
  fen: str
93
+ classification: str # Best, Excellent, Good, Inaccuracy, Mistake, Blunder, Brilliant
94
  cpl: float
95
  score_before: float
96
  score_after: float
 
124
  return engine
125
 
126
  def get_normalized_score(info, turn_color=chess.WHITE):
127
+ """Returns the score from White's perspective in centipawns."""
128
  if "score" not in info:
129
  return 0.0
130
  raw = info["score"].white()
 
132
  return 10000.0 if (raw.mate() or 0) > 0 else -10000.0
133
  return raw.score() or 0.0
134
 
135
+ # ─── Engine Inference Route ────────────────────────────────────────────────────
136
  @app.post("/move", response_model=MoveResponse)
137
  async def get_move(request: MoveRequest):
138
  engine = None
 
140
  engine = await get_engine()
141
  board = chess.Board(request.fen)
142
  limit = chess.engine.Limit(time=request.time, depth=request.depth)
143
+
144
  result = await engine.play(board, limit)
145
  info = await engine.analyse(board, limit)
146
+
147
+ # From White's perspective in CP -> converted to Pawns for UI
148
  score_cp = get_normalized_score(info)
149
+
150
+ depth = info.get("depth", 0)
151
+ nodes = info.get("nodes", 0)
152
+ nps = info.get("nps", 0)
153
+
154
+ pv_board = board.copy()
155
+ pv_parts = []
156
+ for m in info.get("pv", [])[:5]:
157
+ if m in pv_board.legal_moves:
158
+ try:
159
+ pv_parts.append(pv_board.san(m))
160
+ pv_board.push(m)
161
+ except Exception:
162
+ break
163
+ else:
164
+ break
165
+ pv = " ".join(pv_parts)
166
+
167
+ # Map mate score to pawns representation to not break old UI
168
  score_pawns = score_cp / 100.0 if abs(score_cp) < 9900 else (100.0 if score_cp > 0 else -100.0)
169
+
170
+ return MoveResponse(
171
+ bestmove=result.move.uci(),
172
+ score=score_pawns,
173
+ depth=depth,
174
+ nodes=nodes,
175
+ nps=nps,
176
+ pv=pv
177
+ )
178
+ except Exception as e:
179
+ print(f"Error: {e}")
180
+ raise HTTPException(status_code=500, detail=str(e))
181
  finally:
182
+ if engine:
183
+ try:
184
+ await engine.quit()
185
+ except Exception:
186
+ pass
187
+
188
 
189
+ # ─── Game Review Route ─────────────────────────────────────────────────────────
190
  @app.post("/analyze-game", response_model=AnalyzeResponse)
191
  async def analyze_game(request: AnalyzeRequest):
192
  engine = None
 
194
  engine = await get_engine()
195
  board = chess.Board()
196
  limit = chess.engine.Limit(time=request.time_per_move)
 
 
197
 
198
+ analysis_results = []
199
+
200
+ # We need the pre-move evaluation of the very first position
201
  info_before = await engine.analyse(board, limit)
202
  current_score = get_normalized_score(info_before)
203
+
204
+ # To track accuracy
205
+ total_cpl = 0
206
+ player_moves_count = 0
207
+
208
+ counts = {
209
+ "Brilliant": 0, "Great": 0, "Best": 0,
210
+ "Excellent": 0, "Good": 0, "Inaccuracy": 0,
211
+ "Mistake": 0, "Blunder": 0
212
+ }
213
+
214
  player_is_white = (request.player_color.lower() == "white")
215
 
216
  for i, san_move in enumerate(request.moves):
217
+ is_player_turn = board.turn == chess.WHITE if player_is_white else board.turn == chess.BLACK
218
+
219
+ # The current_score is the score BEFORE this move
220
  score_before = current_score
221
+
222
+ # Push move
223
  try:
224
+ move = board.parse_san(san_move)
225
+ board.push(move)
226
+ except Exception:
227
+ break # Invalid move, stop analysis here
228
 
229
+ # Get eval AFTER move
230
  info_after = await engine.analyse(board, limit)
231
+ score_after = get_normalized_score(info_after)
232
 
233
+ # Update current score for next iteration
234
+ current_score = score_after
235
+
236
+ # Only analyze the player's moves
237
  if is_player_turn:
238
+ # Calculate Centipawn Loss (diff between score before and score after)
239
+ # If player is White, positive score is good. If White drops from +100 to +50 -> CPL = 50.
240
+ # If player is Black, negative score is good. If Black rises from -100 to -50 -> CPL = 50.
241
+ if player_is_white:
242
+ cpl = max(0, score_before - score_after)
243
+ else:
244
+ cpl = max(0, score_after - score_before)
245
+
246
+ # Cap CPL to 1000 so one massive blunder doesn't utterly ruin the stats
247
  cpl = min(cpl, 1000.0)
248
+
249
  total_cpl += cpl
250
  player_moves_count += 1
251
 
252
+ # Classification mapping
253
+ if cpl <= 15:
254
+ cls = "Best"
255
+ elif cpl <= 35:
256
+ cls = "Excellent"
257
+ elif cpl <= 75:
258
+ cls = "Good"
259
+ elif cpl <= 150:
260
+ cls = "Inaccuracy"
261
+ elif cpl <= 300:
262
+ cls = "Mistake"
263
+ else:
264
+ cls = "Blunder"
265
 
266
  counts[cls] += 1
267
+
268
+ analysis_results.append(MoveAnalysis(
269
+ move_num=i+1,
270
+ san=san_move,
271
+ fen=board.fen(),
272
+ classification=cls,
273
+ cpl=cpl,
274
+ score_before=score_before / 100.0,
275
+ score_after=score_after / 100.0
276
+ ))
277
 
278
+ # Win probability matching accuracy formula
279
+ # Accuracy = 100 * exp(-0.02 * avg_cpl) smoothed
280
  avg_cpl = total_cpl / max(1, player_moves_count)
281
+
282
+ # Simple heuristic mapping for Accuracy & Elo
283
+ # 0 avg loss -> 100%
284
+ # ~100 avg loss -> ~60%
285
  accuracy = max(10.0, min(100.0, 100.0 * math.exp(-0.005 * avg_cpl)))
286
+
287
+ # Estimate Elo based slightly on accuracy
288
+ # This is a fun heuristic metric
289
  estimated_elo = int(max(400, min(3600, 3600 - (avg_cpl * 20))))
290
+
291
+ return AnalyzeResponse(
292
+ accuracy=round(accuracy, 1),
293
+ estimated_elo=estimated_elo,
294
+ moves=analysis_results,
295
+ counts=counts
296
+ )
297
+
298
+ except Exception as e:
299
+ print(f"Analysis Error: {e}")
300
+ raise HTTPException(status_code=500, detail=str(e))
301
  finally:
302
+ if engine:
303
+ try:
304
+ await engine.quit()
305
+ except Exception:
306
+ pass
307
+
308
 
309
  if __name__ == "__main__":
310
  import uvicorn
311
+ # Hugging Face Spaces port is 7860
312
  uvicorn.run(app, host="0.0.0.0", port=7860)