kaupane commited on
Commit
e8a0fd8
·
verified ·
1 Parent(s): e72f592

Upload 10 files

Browse files
Files changed (10) hide show
  1. Dockerfile +25 -0
  2. README.md +0 -14
  3. app.py +771 -0
  4. model.py +365 -0
  5. requirements.txt +10 -0
  6. utils/__init__.py +17 -0
  7. utils/buffer.py +274 -0
  8. utils/chess_env.py +151 -0
  9. utils/engine.py +759 -0
  10. utils/mapping.py +141 -0
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12
2
+
3
+ # Install system dependencies including Stockfish
4
+ RUN apt-get update && apt-get install -y \
5
+ stockfish \
6
+ && rm -rf /var/lib/apt/lists/*
7
+
8
+ # Set working directory
9
+ WORKDIR /code
10
+
11
+ # Copy requirements and install Python dependencies
12
+ COPY ./requirements.txt /code/requirements.txt
13
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
14
+
15
+ # Copy application code
16
+ COPY . /code
17
+
18
+ # Make sure Stockfish is executable
19
+ RUN chmod +x /usr/bin/stockfish
20
+
21
+ # Expose port
22
+ EXPOSE 7860
23
+
24
+ # Run the application
25
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,14 +0,0 @@
1
- ---
2
- title: Chessformer Demo
3
- emoji: 🌍
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.32.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Play chess with Chessformer or Stockfish!
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,771 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ """
4
+
5
+ import gradio as gr
6
+ import chess
7
+ import chess.svg
8
+ import chess.pgn
9
+ import re
10
+ import torch
11
+ import os
12
+ import io
13
+ import math
14
+ from typing import Optional, Tuple, List
15
+ import traceback
16
+ from datetime import datetime
17
+
18
+ from utils import Engine, ChessformerConfig, StockfishConfig, UCI_MOVE_TO_IDX
19
+ from model import ChessFormerModel
20
+
21
+ from concurrent.futures import ThreadPoolExecutor
22
+
23
+ import spaces
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ # Add to ChessApp.__init__
27
+ def __init__(self):
28
+ # ... existing init code ...
29
+ self.analysis_executor = ThreadPoolExecutor(max_workers=2)
30
+
31
+ def update_evaluations_async(self):
32
+ """Update evaluations asynchronously"""
33
+ def update_current_engine():
34
+ if self.current_engine:
35
+ try:
36
+ self.current_engine_eval = self.current_engine.analyze_position(self.board.copy())
37
+ if self.current_engine_eval is None:
38
+ self.current_engine_eval = 0.0
39
+ except:
40
+ self.current_engine_eval = 0.0
41
+
42
+ def update_stockfish():
43
+ try:
44
+ self.stockfish_eval = self.fast_stockfish_analysis(self.board.copy())
45
+ if self.stockfish_eval is None:
46
+ self.stockfish_eval = 0.0
47
+ except:
48
+ self.stockfish_eval = 0.0
49
+
50
+ # Run both analyses in parallel
51
+ future1 = self.analysis_executor.submit(update_current_engine)
52
+ future2 = self.analysis_executor.submit(update_stockfish)
53
+
54
+ # Wait for both to complete
55
+ future1.result()
56
+ future2.result()
57
+
58
+ class ChessApp:
59
+ def __init__(self, device):
60
+ self.board = chess.Board()
61
+ self.move_history = []
62
+ self.current_engine = None
63
+ self.analysis_engine = None
64
+ self.game_over = False
65
+ self.user_color = chess.WHITE
66
+ self.models = {}
67
+ self.device = device
68
+
69
+ self.current_engine_eval = 0.0
70
+ self.stockfish_eval = 0.0
71
+
72
+ self.load_models()
73
+ self.create_analysis_engine()
74
+
75
+ def load_models(self):
76
+ model_paths = {
77
+ "ChessFormer-SL": "./ckpts/chessformer-sl_01.pth",
78
+ "ChessFormer-RL": "./ckpts/chessformer-rl_final.pth"
79
+ }
80
+
81
+ for name, path in model_paths.items():
82
+ if os.path.exists(path):
83
+ print(f"Loading {name} from {path}...")
84
+ checkpoint = torch.load(path,map_location=self.device)
85
+ config = checkpoint["config"]
86
+ model = ChessFormerModel(**config)
87
+ model.load_state_dict(checkpoint["model_state_dict"])
88
+ model.to(self.device)
89
+ model.eval()
90
+
91
+ self.models[name] = model
92
+ print(f"Successfully loaded {name}.")
93
+ else:
94
+ print(f"Model file not found: {path}")
95
+
96
+ def get_depth_limits(self, engine_type: str) -> Tuple[int,int]:
97
+ if engine_type == "Stockfish":
98
+ return 1,24,6
99
+ else:
100
+ return 0,6,0
101
+
102
+ def create_evaluation_bar(self, eval_score: float, title: str) -> str:
103
+ """Create HTML evaluation bar from user's perspective with page-matching colors"""
104
+ # Convert eval_score from white's perspective to user's perspective
105
+ user_eval = eval_score if self.user_color == chess.WHITE else -eval_score
106
+
107
+ # Clamp evaluation between -1 and 1 for display
108
+ clamped_eval = max(-1.0, min(1.0, user_eval))
109
+
110
+ # Convert to percentage (0 = user losing, 100 = user winning)
111
+ percentage = (clamped_eval + 1.0) / 2.0 * 100
112
+
113
+ # Format evaluation text from user's perspective
114
+ eval_text = f"{user_eval:+.2f}"
115
+ if abs(user_eval) > 5:
116
+ eval_text = "±∞" if user_eval > 0 else "∓∞"
117
+
118
+ # Determine advantage text and colors (matching page theme)
119
+ if user_eval > 0.5:
120
+ advantage_text = "WINNING"
121
+ text_color = "#1e40af" # Blue-700
122
+ indicator_color = "#3b82f6" # Blue-500
123
+ elif user_eval > 0.1:
124
+ advantage_text = "SLIGHT ADVANTAGE"
125
+ text_color = "#1e40af"
126
+ indicator_color = "#60a5fa" # Blue-400
127
+ elif user_eval < -0.5:
128
+ advantage_text = "LOSING"
129
+ text_color = "#7c2d12" # Orange-800 (more muted than red)
130
+ indicator_color = "#ea580c" # Orange-600
131
+ elif user_eval < -0.1:
132
+ advantage_text = "SLIGHT DISADVANTAGE"
133
+ text_color = "#9a3412" # Orange-700
134
+ indicator_color = "#f97316" # Orange-500
135
+ else:
136
+ advantage_text = "EQUAL POSITION"
137
+ text_color = "#4b5563" # Gray-600
138
+ indicator_color = "#6b7280" # Gray-500
139
+
140
+ return f"""
141
+ <div style="margin: 10px 0; font-family: 'Segoe UI', Arial, sans-serif;">
142
+ <h4 style="margin: 5px 0 10px 0; color: #374151; font-size: 14px; font-weight: 600;">{title}</h4>
143
+
144
+ <!-- Evaluation bar with page-matching gradient -->
145
+ <div style="width: 100%; height: 40px; border: 2px solid #d1d5db; border-radius: 8px; position: relative;
146
+ background: linear-gradient(to right,
147
+ #fed7aa 0%, /* Orange-200 - losing */
148
+ #fde68a 20%, /* Yellow-200 */
149
+ #e5e7eb 50%, /* Gray-200 - equal */
150
+ #bfdbfe 80%, /* Blue-200 */
151
+ #93c5fd 100% /* Blue-300 - winning */
152
+ );
153
+ box-shadow: inset 0 1px 3px rgba(0,0,0,0.05);">
154
+
155
+ <!-- Evaluation indicator -->
156
+ <div style="position: absolute; left: {percentage}%; top: 50%; transform: translateX(-50%) translateY(-50%);
157
+ background: {indicator_color}; border: 3px solid white; border-radius: 50%; width: 18px; height: 18px;
158
+ box-shadow: 0 2px 4px rgba(0,0,0,0.15), 0 0 0 1px #d1d5db; z-index: 10;
159
+ transition: all 0.3s ease;"></div>
160
+ </div>
161
+
162
+ <!-- Evaluation text -->
163
+ <div style="text-align: center; margin-top: 8px; padding: 8px; background: #f9fafb;
164
+ border-radius: 6px; border: 1px solid #e5e7eb;">
165
+ <div style="font-weight: 600; color: {text_color}; font-size: 16px; margin-bottom: 2px;">
166
+ {eval_text}
167
+ </div>
168
+ <div style="font-size: 10px; color: {text_color}; text-transform: uppercase; letter-spacing: 0.8px; font-weight: 500; opacity: 0.8;">
169
+ {advantage_text}
170
+ </div>
171
+ </div>
172
+ </div>
173
+ """
174
+
175
+ def create_analysis_engine(self):
176
+ """Create optimized Stockfish depth 27 engine for analysis"""
177
+ try:
178
+ config = StockfishConfig(
179
+ engine_path="/usr/games/stockfish",
180
+ depth=27
181
+ )
182
+ self.analysis_engine = Engine(type="stockfish", stockfish_config=config)
183
+
184
+ # Configure Stockfish for faster analysis
185
+ if self.analysis_engine and hasattr(self.analysis_engine, 'engine_path'):
186
+ # We'll patch the engine creation to use optimized settings
187
+ pass
188
+
189
+ print("Analysis engine (Stockfish depth 27) created successfully")
190
+ except Exception as e:
191
+ print(f"Failed to create analysis engine: {e}")
192
+ self.analysis_engine = None
193
+
194
+ def update_evaluations(self):
195
+ """Update evaluations from both engines with optimized Stockfish analysis"""
196
+ # Get current engine evaluation
197
+ if self.current_engine:
198
+ try:
199
+ self.current_engine_eval = self.current_engine.analyze_position(self.board.copy())
200
+ if self.current_engine_eval is None:
201
+ self.current_engine_eval = 0.0
202
+ except:
203
+ self.current_engine_eval = 0.0
204
+
205
+ # Get optimized Stockfish analysis
206
+ if self.analysis_engine:
207
+ try:
208
+ self.stockfish_eval = self.fast_stockfish_analysis(self.board.copy())
209
+ if self.stockfish_eval is None:
210
+ self.stockfish_eval = 0.0
211
+ except:
212
+ self.stockfish_eval = 0.0
213
+
214
+ def fast_stockfish_analysis(self, board: chess.Board) -> Optional[float]:
215
+ """Fast Stockfish analysis with optimized settings"""
216
+ try:
217
+ import chess.engine
218
+
219
+ # Create engine with optimized settings
220
+ with chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish") as engine:
221
+ # Configure for speed
222
+ engine.configure({
223
+ "Threads": min(8, os.cpu_count() or 4), # Use multiple threads
224
+ "Hash": 256, # 256MB hash table
225
+ "UCI_AnalyseMode": True
226
+ })
227
+
228
+ # Use time limit instead of depth for faster analysis
229
+ info = engine.analyse(
230
+ board,
231
+ chess.engine.Limit(time=1.0), # 1 second analysis
232
+ )
233
+
234
+ score_obj = info.get("score")
235
+ if score_obj is None:
236
+ return None
237
+
238
+ pov_score = score_obj.pov(chess.WHITE)
239
+
240
+ if pov_score.is_mate():
241
+ mate_score = pov_score.mate()
242
+ cp = 10000.0 if mate_score > 0 else -10000.0
243
+ elif pov_score.cp is not None:
244
+ cp = float(pov_score.cp)
245
+ else:
246
+ return None
247
+
248
+ # Normalize score
249
+ normalized_score = 2 / (1 + math.exp(-0.004 * cp)) - 1
250
+ return normalized_score
251
+
252
+ except Exception as e:
253
+ print(f"Fast Stockfish analysis error: {e}")
254
+ return None
255
+
256
+ def create_engine(self, engine_type: str, depth: int, temperature: float=0.5) -> Optional[Engine]:
257
+ if engine_type == "Stockfish":
258
+ config = StockfishConfig(
259
+ engine_path="/usr/games/stockfish",
260
+ depth=depth
261
+ )
262
+ return Engine(type="stockfish",stockfish_config=config)
263
+ elif engine_type in self.models:
264
+ config = ChessformerConfig(
265
+ chessformer=self.models[engine_type],
266
+ device=self.device,
267
+ temperature=temperature,
268
+ depth=depth if depth > 0 else 0,
269
+ top_k=8,
270
+ decay_rate=0.6,
271
+ max_batch_size=800
272
+ )
273
+ return Engine(type="chessformer",chessformer_config=config)
274
+
275
+ return None
276
+
277
+ def parse_move(self, move_str: str) -> Optional[chess.Move]:
278
+ """Parse move input in either UCI format ("e2e4") or algebraic notation ("Ne5")"""
279
+ if not move_str:
280
+ return None
281
+
282
+ move_str = move_str.strip()
283
+
284
+ # Try UCI format first
285
+ uci_pattern = r'^[a-h][1-8][a-h][1-8][qrbn]?$'
286
+ if re.match(uci_pattern,move_str.lower()):
287
+ try:
288
+ return chess.Move.from_uci(move_str.lower())
289
+ except ValueError:
290
+ pass
291
+
292
+ # Try algrebraic notation
293
+ try:
294
+ return self.board.parse_san(move_str)
295
+ except ValueError:
296
+ pass
297
+
298
+ return None
299
+
300
+ def get_board_svg(self) -> str:
301
+ """Generate SVG representation of the chess board"""
302
+ flip = (self.user_color == chess.BLACK)
303
+
304
+ lastmove = None
305
+ if self.move_history:
306
+ lastmove = self.move_history[-1]
307
+
308
+ svg = chess.svg.board(
309
+ board=self.board,
310
+ flipped=flip,
311
+ lastmove=lastmove,
312
+ size=600
313
+ )
314
+ return svg
315
+
316
+ def get_move_history_text(self) -> str:
317
+ """Generate move history in PGN format"""
318
+ try:
319
+ game = chess.pgn.Game()
320
+ game.headers["Event"] = "ChessFormer Demo"
321
+ game.headers["Date"] = datetime.now().strftime("%Y.%m.%d")
322
+ game.headers["White"] = "You" if self.user_color == chess.WHITE else "Engine"
323
+ game.headers["Black"] = "Engine" if self.user_color == chess.WHITE else "You"
324
+
325
+ node = game
326
+ temp_board = chess.Board()
327
+
328
+ for move in self.move_history:
329
+ node = node.add_variation(move)
330
+ temp_board.push(move)
331
+
332
+ if self.game_over:
333
+ outcome = self.board.outcome()
334
+ if outcome:
335
+ if outcome.winner == chess.WHITE:
336
+ game.headers["Result"] = "1-0"
337
+ elif outcome.winner == chess.BLACK:
338
+ game.headers["Result"] = "0-1"
339
+ else:
340
+ game.headers["Result"] = "1/2-1/2"
341
+ else:
342
+ game.headers["Result"] = "*"
343
+ else:
344
+ game.headers["Result"] = "*"
345
+
346
+ return str(game)
347
+ except Exception as e:
348
+ print(f"Error generating move history: {e}")
349
+ return "Move history unavailable"
350
+
351
+ def export_pgn(self) -> str:
352
+ return self.get_move_history_text()
353
+
354
+ def import_fen(self, fen: str) -> Tuple[str,str,str,str,str]:
355
+ try:
356
+ test_board = chess.Board(fen.strip())
357
+ self.board = test_board
358
+ self.move_history = []
359
+ self.game_over = False
360
+ self.update_evaluations()
361
+
362
+ return (
363
+ self.get_board_svg(),
364
+ self.get_move_history_text(),
365
+ f"Position loaded from FEN: {fen}",
366
+ "",
367
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
368
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
369
+ )
370
+ except Exception as e:
371
+ return (
372
+ self.get_board_svg(),
373
+ self.get_move_history_text(),
374
+ f"Invalid FEN: {str(e)}",
375
+ "",
376
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
377
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
378
+ )
379
+
380
+ def import_pgn(self, pgn_text: str) -> Tuple[str,str,str,str,str]:
381
+ try:
382
+ pgn_io = io.StringIO(pgn_text.strip())
383
+ game = chess.pgn.read_game(pgn_io)
384
+
385
+ if game is None:
386
+ raise ValueError("Could not parse PGN")
387
+
388
+ self.board = game.board()
389
+ self.move_history = []
390
+
391
+ for move in game.mainline_moves():
392
+ self.board.push(move)
393
+ self.move_history.append(move)
394
+
395
+ self.game_over = self.board.is_game_over()
396
+ self.update_evaluations()
397
+
398
+ return (
399
+ self.get_board_svg(),
400
+ self.get_move_history_text(),
401
+ f"Game loaded from PGN ({len(self.move_history)} moves)",
402
+ "",
403
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
404
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
405
+ )
406
+ except Exception as e:
407
+ return (
408
+ self.get_board_svg(),
409
+ self.get_move_history_text(),
410
+ f"Invalid PGN: {str(e)}",
411
+ "",
412
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
413
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
414
+ )
415
+
416
+ def make_user_move(self, move_str: str) -> Tuple[str,str,str,str,str,str]:
417
+ if self.game_over:
418
+ return (
419
+ self.get_board_svg(),
420
+ self.get_move_history_text(),
421
+ "Game is over. Click 'New Game' to start a new game.",
422
+ "",
423
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
424
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
425
+ )
426
+
427
+ if self.board.turn != self.user_color:
428
+ return (
429
+ self.get_board_svg(),
430
+ self.get_move_history_text(),
431
+ "It's not your turn now!",
432
+ "",
433
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
434
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
435
+ )
436
+
437
+ move = self.parse_move(move_str)
438
+ if move is None:
439
+ return (
440
+ self.get_board_svg(),
441
+ self.get_move_history_text(),
442
+ f"Invalid move: '{move_str}'. Try formats like 'e2e4' or 'Ne5'",
443
+ "",
444
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
445
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
446
+ )
447
+
448
+ if move not in self.board.legal_moves:
449
+ return (
450
+ self.get_board_svg(),
451
+ self.get_move_history_text(),
452
+ f"Illegal move: '{move_str}'",
453
+ "",
454
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
455
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
456
+ )
457
+
458
+ self.board.push(move)
459
+ self.move_history.append(move)
460
+
461
+ self.update_evaluations()
462
+
463
+ if self.board.is_game_over():
464
+ self.game_over = True
465
+ outcome = self.board.outcome()
466
+ if outcome:
467
+ if outcome.winner == self.user_color:
468
+ status = "Congratulations! You won!"
469
+ elif outcome.winner is None:
470
+ status = "Game drawn."
471
+ else:
472
+ status = "You lost."
473
+ status += f" ({outcome.termination.name})"
474
+ else:
475
+ status = "Game over."
476
+
477
+ return (
478
+ self.get_board_svg(),
479
+ self.get_move_history_text(),
480
+ status,
481
+ "",
482
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
483
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
484
+ )
485
+
486
+ # Get engine move
487
+ try:
488
+ engine_move_uci, engine_value = self.current_engine.move(self.board)
489
+
490
+ if engine_move_uci == "<claim_draw>":
491
+ self.game_over = True
492
+ status = "Engine claimed a draw."
493
+ else:
494
+ engine_move = chess.Move.from_uci(engine_move_uci)
495
+ self.board.push(engine_move)
496
+ self.move_history.append(engine_move)
497
+
498
+ if self.board.is_game_over():
499
+ self.game_over = True
500
+ outcome = self.board.outcome()
501
+ if outcome:
502
+ if outcome.winner == self.user_color:
503
+ status = "🎉🏆 CONGRATULATIONS! YOU WON! 🏆🎉"
504
+ status += f"\n🎯 Victory by {outcome.termination.name}! 🎯"
505
+ elif outcome.winner is None:
506
+ status = "🤝 GAME DRAWN 🤝"
507
+ status += f"\n⚖️ Draw by {outcome.termination.name} ⚖️"
508
+ else:
509
+ status = "😔 YOU LOST 😔"
510
+ status += f"\n💔 Defeated by {outcome.termination.name} 💔"
511
+ else:
512
+ status = "🏁 GAME OVER 🏁"
513
+ else:
514
+ status = f"Engine played: {engine_move.uci()}."
515
+
516
+ except Exception as e:
517
+ status = f"Engine error: {str(e)}"
518
+ print(f"Engine error: {e}")
519
+ traceback.print_exc()
520
+
521
+ return (
522
+ self.get_board_svg(),
523
+ self.get_move_history_text(),
524
+ status,
525
+ "", # clear input
526
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
527
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
528
+ )
529
+
530
+ def new_game(self, engine_type: str, depth: int, color: str, temperature: float) -> Tuple[str,str,str,str,str,str]:
531
+ "Start a new game"
532
+ self.board = chess.Board()
533
+ self.move_history = []
534
+ self.game_over = False
535
+ self.user_color = chess.WHITE if color == "White" else chess.BLACK
536
+
537
+ # Create new engine
538
+ self.current_engine = self.create_engine(engine_type, depth, temperature)
539
+
540
+ self.update_evaluations()
541
+
542
+ if self.current_engine is None:
543
+ status = f"Failed to create {engine_type} engine."
544
+ else:
545
+ status = f"New game started! You are playing {color} against {engine_type} (depth {depth})."
546
+
547
+ # If user is black, make engine move first
548
+ if self.user_color == chess.BLACK:
549
+ try:
550
+ engine_move_uci, engine_value = self.current_engine.move(self.board)
551
+ if engine_move_uci != "<claim_draw>":
552
+ engine_move = chess.Move.from_uci(engine_move_uci)
553
+ self.board.push(engine_move)
554
+ self.move_history.append(engine_move)
555
+ status += f" Engine opened with: {engine_move.uci()}"
556
+ except Exception as e:
557
+ status += f" Engine error on first move: {str(e)}"
558
+
559
+ return (
560
+ self.get_board_svg(),
561
+ self.get_move_history_text(),
562
+ status,
563
+ "",
564
+ self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
565
+ self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
566
+ )
567
+
568
+
569
+ app = ChessApp(torch.device("cpu"))
570
+
571
+ def create_interface():
572
+ """Create the Gradio interface with improved layout"""
573
+
574
+ with gr.Blocks(title="ChessFormer Demo", theme=gr.themes.Soft()) as interface:
575
+ gr.Markdown("# 🏆 ChessFormer Demo")
576
+ gr.Markdown("Play chess against ChessFormer models or Stockfish!")
577
+
578
+ with gr.Row():
579
+ # Left column - Analysis + History
580
+ with gr.Column(scale=1):
581
+ gr.Markdown("### 📊 Position Analysis")
582
+
583
+ # Stockfish Analysis
584
+ stockfish_eval_display = gr.HTML(
585
+ value=app.create_evaluation_bar(0.0, "Stockfish Analysis"),
586
+ label="Stockfish"
587
+ )
588
+
589
+ # Current Engine Analysis
590
+ current_engine_eval_display = gr.HTML(
591
+ value=app.create_evaluation_bar(0.0, "Engine Analysis"),
592
+ label="Engine"
593
+ )
594
+
595
+ # Move history
596
+ gr.Markdown("### 📝 Game History")
597
+ history_display = gr.Textbox(
598
+ value=app.get_move_history_text(),
599
+ label="PGN",
600
+ lines=12,
601
+ max_lines=15,
602
+ interactive=False
603
+ )
604
+
605
+ # Middle column - Game Board + Controls
606
+ with gr.Column(scale=4):
607
+ # Chess board display
608
+ board_display = gr.HTML(
609
+ value=app.get_board_svg(),
610
+ label="Chess Board"
611
+ )
612
+
613
+ # Move input
614
+ with gr.Row():
615
+ move_input = gr.Textbox(
616
+ placeholder="Enter move (e.g., 'e2e4' or 'Ne5')",
617
+ label="Your Move",
618
+ scale=4
619
+ )
620
+ move_button = gr.Button("Make Move", variant="primary", scale=1)
621
+
622
+ # Game status
623
+ status_display = gr.Textbox(
624
+ value="Click 'New Game' to start playing!",
625
+ label="Game Status",
626
+ interactive=False,
627
+ lines=2
628
+ )
629
+
630
+ # Right column - Settings + Import/Export
631
+ with gr.Column(scale=2):
632
+ # Engine settings
633
+ gr.Markdown("### ⚙️ Game Settings")
634
+
635
+ engine_choices = ["Stockfish"] + list(app.models.keys())
636
+ engine_select = gr.Dropdown(
637
+ choices=engine_choices,
638
+ value="ChessFormer-SL" if engine_choices else None,
639
+ label="Opponent Engine"
640
+ )
641
+
642
+ depth_slider = gr.Slider(
643
+ minimum=0,
644
+ maximum=6,
645
+ value=0,
646
+ step=1,
647
+ label="Engine Depth"
648
+ )
649
+
650
+ color_select = gr.Radio(
651
+ choices=["White", "Black"],
652
+ value="White",
653
+ label="Your Color"
654
+ )
655
+
656
+ temperature_slider = gr.Slider(
657
+ minimum=0.1,
658
+ maximum=2.0,
659
+ value=0.5,
660
+ step=0.1,
661
+ label="Temperature (ChessFormer only)"
662
+ )
663
+
664
+ new_game_button = gr.Button("🔄 New Game", variant="secondary", size="lg")
665
+
666
+ # Import/Export section
667
+ gr.Markdown("### 📁 Import/Export")
668
+
669
+ with gr.Tabs():
670
+ with gr.Tab("Import FEN"):
671
+ fen_input = gr.Textbox(
672
+ placeholder="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
673
+ label="FEN String",
674
+ lines=2
675
+ )
676
+ import_fen_button = gr.Button("Import FEN")
677
+
678
+ with gr.Tab("Import PGN"):
679
+ pgn_input = gr.Textbox(
680
+ placeholder="1. e4 e5 2. Nf3 Nc6...",
681
+ label="PGN Text",
682
+ lines=3
683
+ )
684
+ import_pgn_button = gr.Button("Import PGN")
685
+
686
+ with gr.Tab("Export"):
687
+ export_button = gr.Button("📁 Download PGN")
688
+ export_output = gr.File(label="Download")
689
+
690
+ # Available models info
691
+ gr.Markdown("### 🤖 Available Models")
692
+ if app.models:
693
+ model_info = "**Loaded ChessFormer models:**\n" + "\n".join([f"• {name}" for name in app.models.keys()])
694
+ else:
695
+ model_info = "⚠️ No ChessFormer models found. Make sure model checkpoints are in the ./ckpts/ directory."
696
+ gr.Markdown(model_info)
697
+
698
+ # Function to update depth limits based on engine selection
699
+ def update_depth_limits(engine_type):
700
+ min_depth, max_depth, value = app.get_depth_limits(engine_type)
701
+ return gr.Slider(minimum=min_depth, maximum=max_depth, value=value, step=1)
702
+
703
+ # Function to export PGN
704
+ def export_pgn_file():
705
+ pgn_content = app.export_pgn()
706
+ filename = f"chess_game_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pgn"
707
+ with open(filename, 'w') as f:
708
+ f.write(pgn_content)
709
+ return filename
710
+
711
+ # Event handlers (same as before...)
712
+ engine_select.change(
713
+ fn=update_depth_limits,
714
+ inputs=[engine_select],
715
+ outputs=[depth_slider]
716
+ )
717
+
718
+ move_button.click(
719
+ fn=app.make_user_move,
720
+ inputs=[move_input],
721
+ outputs=[board_display, history_display, status_display, move_input,
722
+ stockfish_eval_display, current_engine_eval_display]
723
+ )
724
+
725
+ move_input.submit(
726
+ fn=app.make_user_move,
727
+ inputs=[move_input],
728
+ outputs=[board_display, history_display, status_display, move_input,
729
+ stockfish_eval_display, current_engine_eval_display]
730
+ )
731
+
732
+ new_game_button.click(
733
+ fn=app.new_game,
734
+ inputs=[engine_select, depth_slider, color_select, temperature_slider],
735
+ outputs=[board_display, history_display, status_display, move_input,
736
+ stockfish_eval_display, current_engine_eval_display]
737
+ )
738
+
739
+ import_fen_button.click(
740
+ fn=app.import_fen,
741
+ inputs=[fen_input],
742
+ outputs=[board_display, history_display, status_display, fen_input,
743
+ stockfish_eval_display, current_engine_eval_display]
744
+ )
745
+
746
+ import_pgn_button.click(
747
+ fn=app.import_pgn,
748
+ inputs=[pgn_input],
749
+ outputs=[board_display, history_display, status_display, pgn_input,
750
+ stockfish_eval_display, current_engine_eval_display]
751
+ )
752
+
753
+ export_button.click(
754
+ fn=export_pgn_file,
755
+ outputs=[export_output]
756
+ )
757
+
758
+ # Auto-start a new game when interface loads
759
+ interface.load(
760
+ fn=app.new_game,
761
+ inputs=[gr.State("Stockfish"), gr.State(6), gr.State("White"), gr.State(0.5)],
762
+ outputs=[board_display, history_display, status_display, move_input,
763
+ stockfish_eval_display, current_engine_eval_display]
764
+ )
765
+
766
+ return interface
767
+
768
+ if __name__ == "__main__":
769
+ # Create and launch interface
770
+ interface = create_interface()
771
+ interface.launch()
model.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Dict, Tuple
4
+
5
+ from huggingface_hub import PyTorchModelHubMixin
6
+
7
+ from utils import MAX_HALFMOVES, MAX_FULLMOVES, EMPTY_SQ_IDX, PIECE_TO_IDX, SQUARE_TO_IDX, IDX_TO_UCI_MOVE
8
+
9
+ # --- Tokenizer --- #
10
+ class FENTokenizer(nn.Module):
11
+ """Convert FEN (and repetitions) to a sequence of tokens"""
12
+ def __init__(self, hidden_size,dtype):
13
+ super().__init__()
14
+
15
+ self.side_embed = nn.Embedding(2,hidden_size,dtype=dtype) # black/white embedding
16
+
17
+ self.castling_embed_k = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
18
+ self.castling_embed_q = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
19
+ self.castling_embed_K = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
20
+ self.castling_embed_Q = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
21
+ self.no_castling_embed = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
22
+
23
+ self.piece_embed = nn.Embedding(13,hidden_size,dtype=dtype) # 6 for white pieces, 6 for black pieces, 1 for empty
24
+
25
+ self.no_en_passant_embed = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype)) # use positional embed for the target square, or a special one for '-'
26
+
27
+ self.half_move_embed = nn.Embedding(MAX_HALFMOVES,hidden_size,dtype=dtype)
28
+
29
+ self.full_move_embed = nn.Embedding(MAX_FULLMOVES,hidden_size,dtype=dtype)
30
+
31
+ self.repetition_embed = nn.Embedding(3,hidden_size,dtype=dtype)
32
+
33
+ self.pos_embed = nn.Embedding(64,hidden_size,dtype=dtype) # positional embedding
34
+
35
+ def _parse_fen_string(self, fen_str: str) -> Dict:
36
+ parts = fen_str.split()
37
+ if len(parts) != 6:
38
+ raise ValueError(f"Invalid FEN string: {fen_str}. Expected 6 fields")
39
+ return {
40
+ "piece_placement": parts[0],
41
+ "side_to_move": parts[1],
42
+ "castling": parts[2],
43
+ "en_passant": parts[3],
44
+ "halfmove_clock": parts[4],
45
+ "fullmove_number": parts[5],
46
+ }
47
+
48
+ def forward(self, fen_list: List[str], repetitions: torch.Tensor) -> torch.Tensor:
49
+ """
50
+ Args:
51
+ fen: List of fen strings
52
+
53
+ Returns:
54
+ torch tensor of shape (n_fen,73,hidden_size) where 73 tokens consists of:
55
+ 64 piece tokens (fen's first field) +
56
+ 1 which-side-to-move token (fen's second field) +
57
+ 4 casting rights tokens (fen's third field) +
58
+ 1 en-passant target token (fen's fourth field) +
59
+ 1 half move clock token (fen's fifth field) +
60
+ 1 full move number token (fen's fifth field) +
61
+ 1 repetition count token (repetitions input)
62
+ """
63
+ batch_size = len(fen_list)
64
+ assert batch_size == repetitions.shape[0]
65
+ assert len(repetitions.size()) == 1
66
+ batch_tokens = []
67
+ device = self.side_embed.weight.device
68
+
69
+ # Precompute all square indices
70
+ square_indices = torch.arange(64, device=device)
71
+ all_pos_embeds = self.pos_embed(square_indices) # (64,D)
72
+
73
+ for fen_str in fen_list:
74
+ parsed_fen = self._parse_fen_string(fen_str)
75
+ tokens = []
76
+
77
+ # --- 1. Piece Placement (64 tokens) ---
78
+ piece_indices = torch.full((64,), EMPTY_SQ_IDX, dtype=torch.long, device=device)
79
+ current_rank = 7 # Start from rank 8
80
+ current_file = 0 # Start from file 'a'
81
+ for char in parsed_fen["piece_placement"]:
82
+ if char == '/':
83
+ current_rank -= 1
84
+ current_file = 0
85
+ elif char.isdigit():
86
+ current_file += int(char)
87
+ elif char in PIECE_TO_IDX:
88
+ sq_idx = current_rank * 8 + current_file
89
+ if 0 <= sq_idx < 64:
90
+ piece_indices[sq_idx] = PIECE_TO_IDX[char]
91
+ else:
92
+ raise ValueError(f"Invalid FEN piece placement: {parsed_fen['piece_placement']}")
93
+ current_file += 1
94
+ else:
95
+ raise ValueError(f"Invalid character in FEN piece placement: {char}")
96
+
97
+ piece_embeds = self.piece_embed(piece_indices) # (64, D)
98
+ # Add positional embeddings
99
+ board_tokens = piece_embeds + all_pos_embeds # (64, D)
100
+ tokens.append(board_tokens)
101
+
102
+ # --- 2. Side to Move (1 token) ---
103
+ side_idx = 0 if parsed_fen["side_to_move"] == 'w' else 1
104
+ side_token = self.side_embed(torch.tensor(side_idx, device=device)).unsqueeze(0) # (1, D)
105
+ tokens.append(side_token)
106
+
107
+ # --- 3. Castling Rights (4 tokens) ---
108
+ castling_str = parsed_fen["castling"]
109
+ castling_tokens = torch.cat([
110
+ self.castling_embed_K if 'K' in castling_str else self.no_castling_embed.expand(1, 1, -1),
111
+ self.castling_embed_Q if 'Q' in castling_str else self.no_castling_embed.expand(1, 1, -1),
112
+ self.castling_embed_k if 'k' in castling_str else self.no_castling_embed.expand(1, 1, -1),
113
+ self.castling_embed_q if 'q' in castling_str else self.no_castling_embed.expand(1, 1, -1)
114
+ ], dim=1).squeeze(0) # (4, D)
115
+ tokens.append(castling_tokens)
116
+
117
+ # --- 4. En Passant Target (1 token) ---
118
+ en_passant_str = parsed_fen["en_passant"]
119
+ if en_passant_str == '-':
120
+ en_passant_token = self.no_en_passant_embed.squeeze(0) # (1, D)
121
+ else:
122
+ if en_passant_str in SQUARE_TO_IDX:
123
+ sq_idx = SQUARE_TO_IDX[en_passant_str]
124
+ en_passant_token = self.pos_embed(torch.tensor(sq_idx, device=device)).unsqueeze(0) # (1, D)
125
+ else:
126
+ raise ValueError(f"Invalid en passant square: {en_passant_str}")
127
+ tokens.append(en_passant_token)
128
+
129
+ # --- 5. Half Move Clock (1 token) ---
130
+ try:
131
+ half_move_int = int(parsed_fen["halfmove_clock"])
132
+ except ValueError:
133
+ raise ValueError(f"Invalid halfmove clock value: {parsed_fen['halfmove_clock']}")
134
+ # Clamp value before embedding lookup
135
+ half_move_clamped = torch.clamp(torch.tensor(half_move_int, device=device), 0, MAX_HALFMOVES - 1)
136
+ half_move_token = self.half_move_embed(half_move_clamped).unsqueeze(0) # (1, D)
137
+ tokens.append(half_move_token)
138
+
139
+ # --- 6. Full Move Number (1 token) ---
140
+ try:
141
+ full_move_int = int(parsed_fen["fullmove_number"])
142
+ except ValueError:
143
+ raise ValueError(f"Invalid fullmove number value: {parsed_fen['fullmove_number']}")
144
+ # Clamp value (min 1 for full moves) before embedding lookup (adjusting for 0-based index)
145
+ full_move_clamped = torch.clamp(torch.tensor(full_move_int, device=device), 1, MAX_FULLMOVES) - 1
146
+ full_move_token = self.full_move_embed(full_move_clamped).unsqueeze(0) # (1, D)
147
+ tokens.append(full_move_token)
148
+
149
+ # Concatenate all tokens for this FEN string
150
+ # Shapes: (64, D), (1, D), (4, D), (1, D), (1, D), (1, D) -> Total 72 tokens
151
+ fen_embedding = torch.cat(tokens, dim=0) # (72, D)
152
+ batch_tokens.append(fen_embedding)
153
+
154
+ # Stack into a batch
155
+ batch_tokens = torch.stack(batch_tokens, dim=0) # (B,72,D)
156
+
157
+ # ---7. Repetition Count (1 token) ---
158
+ repetitions = repetitions - 1 # from 1~3 to 0~2
159
+ repetitions = torch.clamp(repetitions,0,2) # if repetition count >3 but no player claimed a draw, it will be treated as 3 repetitions
160
+ repetition_tokens = self.repetition_embed(repetitions) # (B,D)
161
+ repetition_tokens = repetition_tokens.unsqueeze(1) # (B,1,D)
162
+
163
+ return torch.cat([batch_tokens,repetition_tokens], dim=1) # (B, 73, D)
164
+
165
+ # --- Helper Modules --- #
166
+ class SwiGLUFFN(nn.Module):
167
+ def __init__(self,
168
+ d_model,
169
+ dim_feedforward,
170
+ dropout: float,
171
+ bias_up: bool=False,
172
+ bias_gate: bool=False,
173
+ bias_down: bool=True,
174
+ dtype=None):
175
+ super().__init__()
176
+ self.up_proj = nn.Linear(d_model,dim_feedforward,bias=bias_up,dtype=dtype)
177
+ self.gate_proj = nn.Linear(d_model,dim_feedforward,bias=bias_gate,dtype=dtype)
178
+ self.down_proj = nn.Linear(dim_feedforward,d_model,bias=bias_down,dtype=dtype)
179
+
180
+ self.dropout = nn.Dropout(dropout)
181
+
182
+ def forward(self, x):
183
+ x = self.up_proj(x) * self.dropout(nn.functional.silu(self.gate_proj(x)))
184
+ return self.down_proj(x)
185
+
186
+ class TransformerEncoderLayer(nn.Module):
187
+ """Custom transformer encoder layer with RMSNorm and SwiGLUFFN"""
188
+ def __init__(self,
189
+ d_model: int,
190
+ nhead: int,
191
+ dim_feedforward: int,
192
+ dropout: float,
193
+ batch_first: bool=True,
194
+ norm_first: bool=False,
195
+ dtype=None):
196
+ super().__init__()
197
+ self.norm_first = norm_first
198
+
199
+ self.norm1 = nn.RMSNorm(d_model,dtype=dtype)
200
+ self.dropout_sa = nn.Dropout(dropout)
201
+ self.self_attn = nn.MultiheadAttention(
202
+ d_model,
203
+ nhead,
204
+ dropout=dropout,
205
+ bias=False,
206
+ batch_first=batch_first,
207
+ dtype=dtype
208
+ )
209
+
210
+ self.norm2 = nn.RMSNorm(d_model,dtype=dtype)
211
+ self.dropout_ff = nn.Dropout(dropout)
212
+ self.mlp = SwiGLUFFN(
213
+ d_model,
214
+ dim_feedforward,
215
+ dropout=dropout,
216
+ bias_up=False,
217
+ bias_gate=False,
218
+ bias_down=True,
219
+ dtype=dtype
220
+ )
221
+
222
+ def forward(self, x, return_attention=False):
223
+ if self.norm_first:
224
+ if return_attention:
225
+ x_norm = self.norm1(x)
226
+ attn_output, attn_weights = self._sa_block(x_norm,return_attention=True)
227
+ x = x + attn_output
228
+ x = x + self._ff_block(self.norm2(x))
229
+ return x, attn_weights
230
+ else:
231
+ x = x + self._sa_block(self.norm1(x))
232
+ x = x + self._ff_block(self.norm2(x))
233
+ return x
234
+ else:
235
+ if return_attention:
236
+ attn_output, attn_weights = self._sa_block(x, return_attention=True)
237
+ x = self.norm1(x + attn_output)
238
+ x = self.norm2(x + self._ff_block(x))
239
+ return x, attn_weights
240
+ else:
241
+ x = self.norm1(x + self._sa_block(x))
242
+ x = self.norm2(x + self._ff_block(x))
243
+ return x
244
+
245
+ def _sa_block(self, x, return_attention=False):
246
+ if return_attention:
247
+ attn_output, attn_weights = self.self_attn(x,x,x,need_weights=True,average_attn_weights=False)
248
+ return self.dropout_sa(attn_output), attn_weights
249
+ else:
250
+ x = self.self_attn(x,x,x)[0]
251
+ return self.dropout_sa(x)
252
+
253
+ def _ff_block(self,x):
254
+ x = self.mlp(x)
255
+ return self.dropout_ff(x)
256
+ nn.TransformerEncoderLayer
257
+
258
+ # --- Model Arch --- #
259
+ class ChessFormerModel(nn.Module, PyTorchModelHubMixin):
260
+ def __init__(self,
261
+ num_blocks,
262
+ hidden_size,
263
+ intermediate_size,
264
+ num_heads,
265
+ dropout: float=0.00,
266
+ possible_moves: int=len(IDX_TO_UCI_MOVE), # 1969 structurally valid moves
267
+ dtype=None):
268
+ super().__init__()
269
+ self.fen_tokenizer = FENTokenizer(hidden_size,dtype=dtype)
270
+
271
+ self.act_token = nn.Parameter(torch.randn((1,1,hidden_size),dtype=dtype) * 0.02)
272
+ self.val_token = nn.Parameter(torch.randn((1,1,hidden_size),dtype=dtype) * 0.02)
273
+
274
+ self.act_proj = nn.Linear(hidden_size,possible_moves,dtype=dtype)
275
+ self.val_proj = nn.Linear(hidden_size,1,dtype=dtype)
276
+
277
+ self.blocks = nn.ModuleList(
278
+ TransformerEncoderLayer(
279
+ d_model=hidden_size,
280
+ nhead=num_heads,
281
+ dim_feedforward=intermediate_size,
282
+ dropout=dropout,
283
+ batch_first=True,
284
+ norm_first=True,
285
+ dtype=dtype
286
+ ) for _ in range(num_blocks)
287
+ )
288
+ self.dtype=dtype
289
+ self.possible_moves = possible_moves
290
+
291
+ self.final_norm = nn.RMSNorm(hidden_size)
292
+
293
+ self._initialize_weights()
294
+
295
+ def _initialize_weights(self):
296
+ """Initialize weights"""
297
+ for m in self.modules():
298
+ if isinstance(m,nn.Linear):
299
+ nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='relu')
300
+ if m.bias is not None:
301
+ nn.init.constant_(m.bias, 0)
302
+ elif isinstance(m, nn.Embedding):
303
+ nn.init.normal_(m.weight, std=0.02)
304
+ elif isinstance(m, nn.LayerNorm):
305
+ if hasattr(m, 'weight'):
306
+ nn.init.constant_(m.weight, 1.0)
307
+ if hasattr(m, 'bias') and m.bias is not None:
308
+ nn.init.constant_(m.weight, 0.0)
309
+ elif isinstance(m, nn.RMSNorm):
310
+ if hasattr(m, 'weight'):
311
+ nn.init.constant_(m.weight, 1.0)
312
+
313
+ tokenizer_params = dict(self.fen_tokenizer.named_parameters())
314
+
315
+ params_to_init = [
316
+ self.act_token, self.val_token,
317
+ tokenizer_params.get('castling_embed_k'), tokenizer_params.get('castling_embed_q'),
318
+ tokenizer_params.get('castling_embed_K'), tokenizer_params.get('castling_embed_Q'),
319
+ tokenizer_params.get('no_castling_embed'), tokenizer_params.get('no_en_passant_embed')
320
+ ]
321
+
322
+ for param in params_to_init:
323
+ if param is not None and param.requires_grad:
324
+ nn.init.normal_(param, std=0.02)
325
+
326
+
327
+ def forward(self, fen: List[str], repetitions: torch.Tensor, return_attention: bool=False) -> torch.Tensor:
328
+ x = self.fen_tokenizer(fen,repetitions) # (B,73,D), pos embed are added here
329
+ bs = x.shape[0]
330
+ x = torch.cat([x,self.act_token.expand(bs,-1,-1),self.val_token.expand(bs,-1,-1)],dim=1) # (B,75,D)
331
+
332
+ attention_maps = [] if return_attention else None
333
+
334
+ for block in self.blocks:
335
+ if return_attention:
336
+ x, attn = block(x, return_attention=True)
337
+ attention_maps.append(attn)
338
+ else:
339
+ x = block(x)
340
+
341
+ x = self.final_norm(x)
342
+
343
+ act = x[:,-2,:]
344
+ val = x[:,-1,:]
345
+ act_logits = self.act_proj(act) # (B,1969)
346
+ val = self.val_proj(val) # (B,1)
347
+
348
+ if return_attention:
349
+ return act_logits, val.squeeze(1), attention_maps
350
+ else:
351
+ return act_logits, val.squeeze(1)
352
+
353
+ def load_model(ckpt_path):
354
+ checkpoint = torch.load(ckpt_path)
355
+ model_config = checkpoint["model_config"]
356
+ model = ChessFormerModel(**model_config)
357
+ model.load_state_dict(checkpoint["model_state_dict"])
358
+ return model
359
+
360
+ if __name__ == "__main__":
361
+ checkpoint = torch.load("./ckpts/chessformer-sl_01.pth",map_location=torch.device("cpu"))
362
+ model = ChessFormerModel(**checkpoint["config"])
363
+ model.load_state_dict(checkpoint["model_state_dict"])
364
+
365
+ model.push_to_hub("kaupane/ChessFormer-SL")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ python-chess
4
+ chess
5
+ huggingface-hub
6
+ transformers
7
+ numpy
8
+ Pillow
9
+ datasets
10
+
utils/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .buffer import ReplayBuffer, Game
2
+ from .chess_env import BatchChessEnv
3
+ from .engine import Engine, ChessformerConfig, StockfishConfig
4
+ from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE, MAX_HALFMOVES, MAX_FULLMOVES, EMPTY_SQ_IDX, PIECE_TO_IDX, SQUARE_TO_IDX
5
+
6
+ __all__ = ['ReplayBuffer',
7
+ 'BatchChessEnv',
8
+ 'Engine',
9
+ 'Game',
10
+ 'UCI_MOVE_TO_IDX',
11
+ 'IDX_TO_UCI_MOVE',
12
+ 'MAX_HALFMOVES',
13
+ 'MAX_FULLMOVES',
14
+ 'EMPTY_SQ_IDX',
15
+ 'PIECE_TO_IDX',
16
+ 'SQUARE_TO_IDX'
17
+ ]
utils/buffer.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from collections import deque
4
+ import numpy as np
5
+ from typing import List, Iterator, Tuple, Optional
6
+ import chess
7
+
8
+ class Game:
9
+ """
10
+ Represents a single chess game trajectory with all relevant data for RL training.
11
+ Acts as a *temporary* buffer inside loop
12
+ Handles:
13
+ - Storing trajectory data (fens, reps, actions, log_probs, values, invalid_masks)
14
+ - Tracking game status (active/complete)
15
+ """
16
+ def __init__(self):
17
+ self.active = True
18
+ self.valid = True
19
+ self.completion_reason = None
20
+ self.game_result = None
21
+
22
+ self.fens = []
23
+ self.repetition_counts = []
24
+ self.actions = []
25
+ self.values = []
26
+ self.log_probs = []
27
+ self.invalid_masks = []
28
+
29
+ def update_trajectory(self, fen, rep, act, val, logp, inv_m):
30
+ self.fens.append(fen)
31
+ self.repetition_counts.append(rep)
32
+ self.actions.append(act)
33
+ self.values.append(val)
34
+ self.log_probs.append(logp)
35
+ self.invalid_masks.append(inv_m)
36
+
37
+ def update_game_status(self, done, reason):
38
+ if done == True:
39
+ self.active = False
40
+ if reason in ["1-0","0-1","1/2-1/2"]:
41
+ self.completion_reason = reason
42
+ self.game_result = reason
43
+ else:
44
+ self.completion_reason = reason
45
+ self.game_result = None
46
+ self.valid = False
47
+
48
+ def get_white_trajectory(self):
49
+ """Extract the trajectory for white"""
50
+ indices = []
51
+ for i in range(len(self.fens) - 1):
52
+ board = chess.Board(self.fens[i])
53
+ if board.turn: # True if white to move
54
+ indices.append(i)
55
+
56
+ return {
57
+ 'fens': [self.fens[i] for i in indices],
58
+ 'repetition_counts': [self.repetition_counts[i] for i in indices],
59
+ 'actions': [self.actions[i] for i in indices],
60
+ 'values': [self.values[i] for i in indices],
61
+ 'log_probs': [self.log_probs[i] for i in indices],
62
+ 'invalid_masks': [self.invalid_masks[i] for i in indices]
63
+ }
64
+
65
+ def get_black_trajectory(self):
66
+ """Extract the trajectory for black pieces."""
67
+ indices = []
68
+ for i in range(len(self.fens) - 1):
69
+ board = chess.Board(self.fens[i])
70
+ if not board.turn: # False if black to move
71
+ indices.append(i)
72
+
73
+ return {
74
+ 'fens': [self.fens[i] for i in indices],
75
+ 'repetition_counts': [self.repetition_counts[i] for i in indices],
76
+ 'actions': [self.actions[i] for i in indices],
77
+ 'values': [self.values[i] for i in indices],
78
+ 'log_probs': [self.log_probs[i] for i in indices],
79
+ 'invalid_masks': [self.invalid_masks[i] for i in indices]
80
+ }
81
+
82
+
83
+
84
+
85
+
86
+ class ReplayBuffer:
87
+ """
88
+ The buffer class for PPO reinforcement learning.
89
+ Handles:
90
+ - store samples including:
91
+ 1. fens
92
+ 2. reps
93
+ 3. actions
94
+ 4. log_probs
95
+ 5. values
96
+ 6. invalid_masks
97
+ - calculate advantage based on reward and value (7. advantage)
98
+ - output samples in batches
99
+ Since PPO is largely on-policy, so the replay buffer will not be so large that deque is not appropriate
100
+ """
101
+ def __init__(self,
102
+ capacity: int,
103
+ batch_size: int,
104
+ gamma: float,
105
+ gae_lambda: float,
106
+ shuffle: bool=True
107
+ ):
108
+ self.gamma = gamma
109
+ self.gae_lambda = gae_lambda
110
+
111
+ self.fens = deque(maxlen=capacity)
112
+ self.repetition_counts = deque(maxlen=capacity)
113
+ self.actions = deque(maxlen=capacity)
114
+ self.log_probs = deque(maxlen=capacity)
115
+ self.values = deque(maxlen=capacity)
116
+ self.invalid_masks = deque(maxlen=capacity)
117
+ self.advantages = deque(maxlen=capacity)
118
+
119
+ self.batch_size = batch_size
120
+ self.shuffle = shuffle
121
+
122
+ def push_game(self, game: Game):
123
+ """
124
+ Process a completed game and add its trajectories to the buffer.
125
+ Handles reward computation for both white and black players.
126
+ """
127
+ if not game.valid:
128
+ return
129
+
130
+ result = game.game_result
131
+ if result not in ["1-0","0-1","1/2-1/2"]:
132
+ raise ValueError(f"Result not recognized: {result}. Either an incompleted game was passed in, or game.update_game_status() method is wrong.")
133
+
134
+ if result == "1-0": result = 1
135
+ elif result == "0-1": result = -1
136
+ elif result == "1/2-1/2": result = 0
137
+
138
+ white_traj = game.get_white_trajectory()
139
+ if white_traj["fens"]:
140
+ self._process_trajectory(
141
+ white_traj["fens"],
142
+ white_traj["repetition_counts"],
143
+ white_traj["actions"],
144
+ white_traj["log_probs"],
145
+ white_traj["values"],
146
+ white_traj["invalid_masks"],
147
+ result
148
+ )
149
+
150
+ black_traj = game.get_black_trajectory()
151
+ if black_traj["fens"]:
152
+ self._process_trajectory(
153
+ black_traj["fens"],
154
+ black_traj["repetition_counts"],
155
+ black_traj["actions"],
156
+ black_traj["log_probs"],
157
+ black_traj["values"],
158
+ black_traj["invalid_masks"],
159
+ -result # flip reward for black's perspective
160
+ )
161
+
162
+ def _process_trajectory(self, fens, reps, actions, log_probs, values, invalid_masks, final_reward):
163
+ """Process a trajectory for one player, compute advantages and add to buffer"""
164
+ values_tensor = torch.tensor(values) if not torch.is_tensor(values) else values
165
+
166
+ advantages = self._compute_advantage(values_tensor, final_reward)
167
+
168
+ for i in range(len(fens)):
169
+ self.fens.append(fens[i])
170
+ self.repetition_counts.append(reps[i])
171
+ self.actions.append(actions[i])
172
+ self.log_probs.append(log_probs[i])
173
+ self.values.append(values[i])
174
+ self.invalid_masks.append(invalid_masks[i])
175
+ self.advantages.append(advantages[i])
176
+
177
+ def _compute_advantage(self, value_traj: torch.Tensor, final_reward: float) -> torch.Tensor:
178
+ """
179
+ Calculate GAE with only a terminal reward: r_t = 0 for t < T-1 and r_{T-1} = final_reward
180
+ Args:
181
+ value_traj: value prediction of the model
182
+ final_reward: game result
183
+
184
+ Returns:
185
+ advantage, torch.Tensor of shape same with value_traj
186
+ """
187
+
188
+ vals = value_traj.detach().cpu().float()
189
+ T = vals.shape[0] if vals.dim() > 0 else 1
190
+
191
+ adv = torch.zeros(T)
192
+ next_value = 0.0
193
+ gae = 0.0
194
+
195
+ for t in reversed(range(T)):
196
+ reward = final_reward if t == T-1 else 0.0
197
+ delta = reward + self.gamma * next_value - vals[t]
198
+ gae = delta + self.gamma * self.gae_lambda * gae
199
+ adv[t] = gae
200
+ next_value = vals[t]
201
+
202
+ return adv
203
+
204
+ def sample(self) -> Iterator[Tuple[List[str], # fen
205
+ torch.Tensor,# rep
206
+ torch.Tensor,# act
207
+ torch.Tensor,# logp
208
+ torch.Tensor,# val
209
+ torch.Tensor,# inv_m
210
+ torch.Tensor]]: # adv
211
+ """Yield minibatches of size batch_size for training"""
212
+ n = len(self.fens)
213
+ if n < self.batch_size:
214
+ return
215
+
216
+ idxs = np.arange(n)
217
+ if self.shuffle:
218
+ np.random.shuffle(idxs)
219
+
220
+ for start in range(0, n, self.batch_size):
221
+ batch_idx = idxs[start:start+self.batch_size]
222
+ if len(batch_idx) < self.batch_size:
223
+ break
224
+
225
+ fens_b = [self.fens[i] for i in batch_idx]
226
+
227
+ reps_b = torch.stack([
228
+ self.repetition_counts[i].detach().clone() if torch.is_tensor(self.repetition_counts[i])
229
+ else torch.tensor(self.repetition_counts[i])
230
+ for i in batch_idx
231
+ ])
232
+
233
+ acts_b = torch.stack([
234
+ self.actions[i].detach().clone() if torch.is_tensor(self.actions[i])
235
+ else torch.tensor(self.actions[i])
236
+ for i in batch_idx
237
+ ])
238
+ logps_b = torch.stack([
239
+ self.log_probs[i].detach().clone() if torch.is_tensor(self.log_probs[i])
240
+ else torch.tensor(self.log_probs[i])
241
+ for i in batch_idx
242
+ ])
243
+
244
+ vals_b = torch.stack([
245
+ self.values[i].detach().clone() if torch.is_tensor(self.values[i])
246
+ else torch.tensor(self.values[i])
247
+ for i in batch_idx
248
+ ])
249
+
250
+ advs_b = torch.stack([
251
+ self.advantages[i].detach().clone() if torch.is_tensor(self.advantages[i])
252
+ else torch.tensor(self.advantages[i])
253
+ for i in batch_idx
254
+ ])
255
+
256
+ invs_b = torch.stack([
257
+ self.invalid_masks[i] if torch.is_tensor(self.invalid_masks[i])
258
+ else torch.tensor(self.invalid_masks[i])
259
+ for i in batch_idx
260
+ ])
261
+
262
+ yield fens_b, reps_b, acts_b, logps_b, vals_b, invs_b, advs_b
263
+
264
+ def __len__(self) -> int:
265
+ return len(self.fens)
266
+
267
+ def clear(self) -> None:
268
+ self.fens.clear()
269
+ self.repetition_counts.clear()
270
+ self.actions.clear()
271
+ self.log_probs.clear()
272
+ self.values.clear()
273
+ self.invalid_masks.clear()
274
+ self.advantages.clear()
utils/chess_env.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Provide a gym-like environment for clarity"""
2
+
3
+ import chess
4
+ import torch
5
+ import time
6
+ from typing import List, Tuple, Dict
7
+ try:
8
+ from .mapping import IDX_TO_UCI_MOVE, UCI_MOVE_TO_IDX
9
+ except:
10
+ from mapping import IDX_TO_UCI_MOVE, UCI_MOVE_TO_IDX
11
+
12
+ class BatchChessEnv:
13
+ """A single chess environment with sparse terminal reward"""
14
+ def __init__(self, batch_size: int, max_moves: int=200):
15
+ self.batch_size = batch_size
16
+ self.max_moves = max_moves
17
+ self.reset()
18
+
19
+ def reset(self) -> Tuple[List[str], torch.Tensor]:
20
+ """
21
+ Starts all games from the initial position
22
+ Returns:
23
+ fens (List[str]), repetition_counts (torch.Tensor of shape [batch_size,])
24
+ """
25
+ self.boards = [chess.Board() for _ in range(self.batch_size)]
26
+ self.move_counts = [0] * self.batch_size
27
+ self.done_flags = [False] * self.batch_size
28
+
29
+ fens = [self.boards[0].fen()] * self.batch_size
30
+ reps = torch.ones(self.batch_size,dtype=torch.long)
31
+ return fens, reps # (bs,)
32
+
33
+ def _compute_rep(self, board: chess.Board) -> int:
34
+ board_copy = board.copy()
35
+ trasposition_key = board_copy._transposition_key()
36
+ count = 0
37
+ while board_copy.move_stack:
38
+ board_copy.pop()
39
+ if board_copy._transposition_key() == trasposition_key:
40
+ count += 1
41
+ return count + 1 # 1 for fresh position
42
+
43
+ def step(self, uci_moves: List[str]) -> Tuple[List[str], # next fens (next state)
44
+ torch.Tensor, # next reps (next state)
45
+ List[bool], # dones
46
+ List[Dict]]: # infos
47
+ """
48
+ Apply one move per game in the batch.
49
+ Args:
50
+ uci_moves: list of UCI strings (plus "<claim_draw>")
51
+ Returns:
52
+ next_fens: new FENs for each game, List[str]
53
+ reps: repetition counts, Tensor[batch_size]
54
+ dones: whether this game is now terminated, List[bool]
55
+ infos: info dict with 'result' key for completed games List[dict]
56
+ """
57
+ next_fens, reps, dones, infos = [], [], [], []
58
+
59
+ for i, move in enumerate(uci_moves):
60
+ board = self.boards[i]
61
+ info = {
62
+ "max_steps_exceeded": False,
63
+ "truncation_due_to_error": False,
64
+ "result": None
65
+ }
66
+ done = self.done_flags[i]
67
+
68
+ if done:
69
+ # Game already done, pass through the existing state
70
+ next_fens.append(board.fen())
71
+ reps.append(1)
72
+ dones.append(True)
73
+ infos.append(info)
74
+ continue
75
+
76
+ if move == "0000":
77
+ # Skip through dummy moves
78
+ next_fens.append(board.fen())
79
+ reps.append(1)
80
+ dones.append(True)
81
+ infos.append(info)
82
+ continue
83
+
84
+ if board.is_game_over():
85
+ # Game already over
86
+ done = True
87
+ info["result"] = board.result()
88
+ next_fens.append(board.fen())
89
+ reps.append(self._compute_rep(board))
90
+ dones.append(done)
91
+ infos.append(info)
92
+ continue
93
+
94
+ try:
95
+ if move == "<claim_draw>":
96
+ if board.can_claim_draw():
97
+ done = True
98
+ info['result'] = "1/2-1/2"
99
+ else:
100
+ raise ValueError(f"Invalid move ('<claim_draw>') passed in.")
101
+ else:
102
+ try:
103
+ m = chess.Move.from_uci(move)
104
+ if m in board.legal_moves:
105
+ board.push(m)
106
+ self.move_counts[i] += 1
107
+
108
+ if board.is_game_over():
109
+ done = True
110
+ info['result'] = board.result()
111
+ else:
112
+ raise ValueError(f"Invalid move ('{m.uci()}') passed in.")
113
+ except Exception as e:
114
+ done = True
115
+ info['truncation_due_to_error'] = True
116
+ print(f"Unexpected error: {e}")
117
+
118
+ if self.move_counts[i] >= self.max_moves:
119
+ done = True
120
+ info['max_steps_exceeded'] = True
121
+ info['result'] = "1/2-1/2"
122
+
123
+ next_fens.append(board.fen())
124
+ reps.append(self._compute_rep(board))
125
+ dones.append(done)
126
+ infos.append(info)
127
+
128
+ except Exception as e:
129
+ print(f"Error processing move {move} for board {i}: {e}")
130
+ done = True
131
+ info["truncation_due_to_error"] = True
132
+ next_fens.append(board.fen())
133
+ reps.append(self._compute_rep(board))
134
+ dones.append(done)
135
+ infos.append(info)
136
+
137
+ self.done_flags[i] = done
138
+
139
+ reps = torch.tensor(reps,dtype=torch.long) # [bs,]
140
+ return next_fens, reps, dones, infos
141
+
142
+ if __name__ == "__main__":
143
+ env = BatchChessEnv(1)
144
+ env.reset()
145
+ board = env.boards[0]
146
+ board.push(chess.Move.from_uci("e2e4"))
147
+ new_board = board.copy()
148
+ rep = env._compute_rep(new_board)
149
+ print(rep)
150
+
151
+
utils/engine.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """An engine class to provide a universal way to interact with both chessformer and stockfish"""
2
+ import torch
3
+ import chess
4
+ import math
5
+ import chess.engine
6
+ import multiprocessing
7
+ from dataclasses import dataclass, field
8
+ from functools import partial
9
+ import time
10
+ import os
11
+
12
+ try:
13
+ from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
14
+ except ImportError:
15
+ from mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
16
+ from torch.distributions import Categorical
17
+ from typing import Optional, Tuple, List, Union
18
+
19
+ @dataclass
20
+ class ChessformerConfig:
21
+ chessformer: torch.nn.Module=None
22
+ device: Optional[torch.device]=None
23
+ temperature: float=0.5
24
+ depth: int=2
25
+ top_k: int=8
26
+ decay_rate: float=0.6
27
+ max_batch_size: int=896
28
+
29
+ @dataclass
30
+ class StockfishConfig:
31
+ engine_path: str="/usr/games/stockfish"
32
+ depth: int=16
33
+
34
+
35
+ def _stockfish_worker(board_fen: str, engine_path: str, depth: int) -> Optional[Tuple[str, float]]:
36
+ """
37
+ Analyzes a single board FEN using a temporary Stockfish engine instance.
38
+ Designed for use with multiprocessing.
39
+ Returns the best move UCI and the normalized score [-1,1].
40
+ Does not handle draw claims explicitly as FEN lacks history.
41
+ Caller should check board.is_game_over() on the main board object.
42
+ """
43
+ engine = None
44
+ try:
45
+ engine = chess.engine.SimpleEngine.popen_uci(engine_path)
46
+ # initialize board from FEN - history is lost here
47
+ board = chess.Board(board_fen)
48
+
49
+ info = engine.analyse(board, chess.engine.Limit(depth=depth))
50
+
51
+ score_obj = info.get("score")
52
+ pv = info.get("pv")
53
+
54
+ if score_obj is None or pv is None or not pv:
55
+ # Analysis failed
56
+ print(f"Warning: Stockfish analysis failed for FEN: {board_fen}")
57
+ return None
58
+
59
+ best_move_uci = pv[0].uci()
60
+ pov_score = score_obj.pov(board.turn)
61
+ cp = None
62
+
63
+ if pov_score.is_mate():
64
+ mate_score = pov_score.mate()
65
+ cp = 10000.0 if mate_score > 0 else -10000.0
66
+ elif pov_score.cp is not None:
67
+ cp = float(pov_score.cp)
68
+ else:
69
+ print(f"Warning: Stockfish score object lacks cp/mate for FEN: {board_fen}")
70
+ return None # analysis is unclear
71
+
72
+ normalized_cp = 2 / (1 + math.exp(-0.004*cp)) - 1
73
+
74
+ return best_move_uci, normalized_cp
75
+
76
+ except (chess.engine.EngineError, chess.engine.EngineTerminatedError, FileNotFoundError, ValueError) as e:
77
+ print(f"Stockfish worker error for FEN {board_fen}: {e}")
78
+ return None
79
+ finally:
80
+ if engine:
81
+ engine.quit()
82
+
83
+ def _compute_repetition_single(board: chess.Board) -> int:
84
+ """Compute repetition count. Used in _chessformer_move and _batch_chessformer_move"""
85
+
86
+ transposition_key = board._transposition_key()
87
+ count = 0
88
+ if board.move_stack:
89
+ if board._transposition_key() == transposition_key:
90
+ count = 1
91
+ else:
92
+ count = 1
93
+ try:
94
+ # Iterate back through history
95
+ while board.move_stack:
96
+ move = board.pop() # note that history is lost here
97
+ if board.is_irreversible(move):
98
+ break
99
+ if board._transposition_key() == transposition_key:
100
+ count += 1
101
+ except Exception as e:
102
+ print(f"Error occurred during repetition count for board {board.fen()}: {e}")
103
+ return 1 # fallback to 1
104
+ return max(1, count)
105
+
106
+ # Engine class, designed to be used in the Evaluator class and app.py
107
+ class Engine:
108
+ def __init__(self,
109
+ type: str,
110
+ chessformer_config: Optional[ChessformerConfig]=None,
111
+ stockfish_config: Optional[StockfishConfig]=None):
112
+ self.type = type
113
+ if type == "chessformer":
114
+ if chessformer_config is None:
115
+ raise ValueError("ChessformerConfig must be provided for chessformer engine.")
116
+
117
+ self.config = chessformer_config
118
+ if self.config.chessformer is None:
119
+ raise ValueError("ChessFormer model must be provided in config.")
120
+
121
+ if self.config.device is None:
122
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ elif isinstance(self.config.device, str):
124
+ self.device = torch.device(self.config.device)
125
+ else:
126
+ self.device = self.config.device
127
+
128
+ self.model = self.config.chessformer
129
+ self.model.to(self.device)
130
+ self.model.eval()
131
+
132
+ if not (self.config.temperature > 0):
133
+ raise ValueError("Temperature must be greater than 0.")
134
+ if not (self.config.top_k > 0):
135
+ raise ValueError("Top-k must be greater than 0.")
136
+ if not (self.config.depth >= 0):
137
+ raise ValueError("Depth must be greater than or equal to 0.")
138
+ if not (0.0 < self.config.decay_rate <= 1.0):
139
+ raise ValueError("Decay rate must be in range (0.0,1.0].")
140
+ if not (self.config.max_batch_size > 0):
141
+ raise ValueError("Max batch size must be an integer greater than 0.")
142
+
143
+ self.temperature = self.config.temperature
144
+ self.top_k = self.config.top_k
145
+ self.initial_k = self.top_k
146
+ self.depth = self.config.depth
147
+ self.decay_rate = self.config.decay_rate
148
+ self.max_batch_size = self.config.max_batch_size
149
+ elif type == "stockfish":
150
+ if stockfish_config is None:
151
+ raise ValueError("StockfishConfig must be provided for stockfish engine.")
152
+
153
+ self.config = stockfish_config
154
+ self.engine_path = self.config.engine_path
155
+ self.depth = self.config.depth
156
+ if self.config.engine_path is None:
157
+ raise ValueError("Engine path must be provided in config.")
158
+ try:
159
+ with chess.engine.SimpleEngine.popen_uci(self.config.engine_path) as test:
160
+ pass
161
+ except (FileNotFoundError, chess.engine.EngineError) as e:
162
+ raise ValueError(f"Invalid engine path or engine not found: {e}")
163
+ else:
164
+ raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
165
+
166
+ def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
167
+ bs = len(boards)
168
+ possible_moves = len(UCI_MOVE_TO_IDX)
169
+ invalid_mask = torch.full((bs,possible_moves), -torch.inf, dtype=torch.float32, device=self.device)
170
+ for idx,board in enumerate(boards):
171
+ if board.is_game_over(claim_draw=True):
172
+ continue # leave all as -inf
173
+ legal_moves = list(board.legal_moves)
174
+ legal_move_ids = [UCI_MOVE_TO_IDX[move.uci()] for move in legal_moves]
175
+ if legal_move_ids:
176
+ invalid_mask[idx,legal_move_ids] = 0
177
+ if board.can_claim_draw():
178
+ invalid_mask[idx,0] = 0
179
+
180
+ return invalid_mask
181
+
182
+ def compute_repetition(self, boards: List[chess.Board]) -> torch.Tensor:
183
+ """Use multiprocessing to compute repetition count for a batch of boards."""
184
+ bs = len(boards)
185
+ num_workers = min(bs, max(1, os.cpu_count()//2 if os.cpu_count else 1))
186
+ if bs < num_workers * 2: # avoid overhead for very small batches per worker
187
+ num_workers = max(1, bs//2)
188
+
189
+ try:
190
+ if num_workers > 1 and bs > 1:
191
+ board_copies = [board.copy(stack=True) for board in boards]
192
+ with multiprocessing.Pool(processes=num_workers) as pool:
193
+ counts = pool.map(_compute_repetition_single, board_copies)
194
+ else:
195
+ # Run sequentially if only one worker needed or very small batch
196
+ counts = [_compute_repetition_single(b.copy(stack=True)) for b in boards]
197
+
198
+ counts_tensor = torch.tensor(counts, dtype=torch.long, device=self.device)
199
+ return counts_tensor # (B,)
200
+ except Exception as e:
201
+ print(f"Error during batch repetition computation: {e}")
202
+ # Fall back to single board computation if multiprocessing fails
203
+ return torch.ones((bs,),dtype=torch.long, device=self.device)
204
+
205
+ def _raw_chessformer_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
206
+ """Get the next move from ChessFormer model with optional tactical verification."""
207
+ # Get FEN
208
+ fen = board.fen()
209
+
210
+ # Compute repetition
211
+ count_tensor = self.compute_repetition([board])
212
+
213
+ move_logits, value = self.model([fen],count_tensor)
214
+ move_logits = move_logits.squeeze(0) # remove batch dimension since it will always be 1
215
+ value = value.squeeze(0).item()
216
+
217
+ # Calculate invalid mask
218
+ legal_moves = list(board.legal_moves)
219
+ if not legal_moves and not board.can_claim_draw():
220
+ # No legal moves. Should not happen if this function is called correctly, but it wouldn't hurt to add a check
221
+ return None
222
+ legal_move_ids = [UCI_MOVE_TO_IDX[move.uci()] for move in legal_moves]
223
+ invalid_mask = torch.ones_like(move_logits) * (-torch.inf)
224
+ invalid_mask[legal_move_ids] = 0
225
+ if board.can_claim_draw():
226
+ invalid_mask[0] = 0
227
+ move_logits = move_logits + invalid_mask
228
+
229
+ if return_perplexity:
230
+ probs = torch.softmax(move_logits, dim=-1)
231
+ perplexity = torch.exp(-torch.sum(probs*torch.log(probs+1e-8))).item()
232
+
233
+ top_k_ids = torch.topk(move_logits, self.top_k, dim=-1).indices
234
+ top_k_mask = torch.ones_like(move_logits) * (-torch.inf)
235
+ top_k_mask[top_k_ids] = 0
236
+ move_logits = move_logits + top_k_mask
237
+ move_logits = move_logits / self.temperature
238
+
239
+ # Sample
240
+ dist = Categorical(logits=move_logits)
241
+ move_id = dist.sample().item()
242
+ move = IDX_TO_UCI_MOVE[move_id]
243
+ if return_perplexity:
244
+ return move, value, perplexity
245
+ else:
246
+ return move, value
247
+
248
+ def _search_enhanced_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
249
+ """Get move from chessformer using tactical search"""
250
+ # Step 1: Build search tree level by level
251
+ current_boards = [board] # aggregate board to a list for batch inference
252
+ board_probs = [1] # the probabilities of getting to this position (estimated)
253
+
254
+ terminal_leaves = [] # (root_move, prob, game_result_value) ^from white's perspective
255
+ search_leaves = [] # (root_move, prob, board) - leaves not terminal but reached max depth therefore needs evaluation from model
256
+
257
+ # Track which root_move each board came from
258
+ board_to_root_move = [None] # root board has no parent move
259
+
260
+ for depth in range(self.depth+1):
261
+ if not current_boards:
262
+ break
263
+ k = max(1,int(self.initial_k*(self.decay_rate**depth)))
264
+
265
+ fens = [b.fen() for b in current_boards]
266
+ reps = self.compute_repetition(current_boards)
267
+
268
+ with torch.no_grad():
269
+ logits, values = self.model(fens,reps)
270
+
271
+ next_boards = []
272
+ next_board_probs = []
273
+ next_board_to_root_move = []
274
+
275
+ # Process each board at current depth
276
+ for board_idx, current_board in enumerate(current_boards):
277
+ board_logits = logits[board_idx]
278
+ board_prob = board_probs[board_idx]
279
+ parent_root_move = board_to_root_move[board_idx]
280
+
281
+ # Check if game is over
282
+ if current_board.is_game_over(claim_draw=True):
283
+ outcome = current_board.outcome(claim_draw=True)
284
+ if outcome.winner == chess.WHITE:
285
+ game_value = 1.0
286
+ elif outcome.winner == chess.BLACK:
287
+ game_value = -1.0
288
+ else:
289
+ game_value = 0.0
290
+ terminal_leaves.append((parent_root_move, board_prob, game_value))
291
+ continue
292
+
293
+ # If we've reached max depth, add to search leaves
294
+ if depth == self.depth:
295
+ search_leaves.append((parent_root_move, board_prob, current_board))
296
+ continue
297
+
298
+ # Otherwise, recursively search deeper
299
+ invalid_mask = self.get_invalid_mask([current_board])[0]
300
+ masked_logits = board_logits + invalid_mask
301
+
302
+ top_k_values, top_k_indices = torch.topk(masked_logits,k=min(k,torch.sum(invalid_mask==0).item()))
303
+ top_k_probs = torch.softmax(top_k_values,dim=0)
304
+ if depth==0:
305
+ initial_masked_logits = masked_logits.squeeze(0)
306
+ initial_invalid_mask = invalid_mask.squeeze(0)
307
+ initial_top_k_indices = top_k_indices
308
+
309
+ # Expand each top k move
310
+ for i,move_idx in enumerate(top_k_indices):
311
+ move_prob = top_k_probs[i].item()
312
+ move_uci = IDX_TO_UCI_MOVE[move_idx.item()]
313
+
314
+ root_move = parent_root_move if parent_root_move is not None else move_uci
315
+
316
+ new_board = current_board.copy()
317
+
318
+ if move_uci == "<claim_draw>":
319
+ if new_board.can_claim_draw():
320
+ terminal_leaves.append((root_move,board_prob*move_prob,0.0))
321
+ continue
322
+ else:
323
+ continue # should not happen, invalid draw claim
324
+ else:
325
+ move = chess.Move.from_uci(move_uci)
326
+ new_board.push(move)
327
+
328
+ next_boards.append(new_board)
329
+ next_board_probs.append(board_prob*move_prob)
330
+ next_board_to_root_move.append(root_move)
331
+
332
+ current_boards = next_boards
333
+ board_probs = next_board_probs
334
+ board_to_root_move = next_board_to_root_move
335
+
336
+ # Step 2: Evaluate all search leaves
337
+ if search_leaves:
338
+ search_boards = [leaf[2] for leaf in search_leaves]
339
+ search_fens = [b.fen() for b in search_boards]
340
+ search_reps = self.compute_repetition(search_boards)
341
+
342
+ with torch.no_grad():
343
+ _, search_values = self.model(search_fens, search_reps)
344
+
345
+ for i, (root_move, prob, leaf_board) in enumerate(search_leaves):
346
+ value = search_values[i].item()
347
+ white_perspective_value = value if leaf_board.turn == chess.WHITE else -value
348
+ terminal_leaves.append((root_move,prob,white_perspective_value))
349
+
350
+ # Step 3: Aggregate all leaves using probability weights
351
+ root_move_weighted_values = {}
352
+ root_move_total_probs = {}
353
+ for root_move, prob, value in terminal_leaves:
354
+ if root_move not in root_move_weighted_values:
355
+ root_move_weighted_values[root_move] = 0.0
356
+ root_move_total_probs[root_move] = 0.0
357
+ root_move_weighted_values[root_move] += prob * value
358
+ root_move_total_probs[root_move] += prob
359
+
360
+ final_value = sum(root_move_weighted_values.values())
361
+ final_value = final_value if board.turn == chess.WHITE else -final_value
362
+
363
+ root_move_values = {}
364
+ for root_move in root_move_total_probs:
365
+ if root_move_total_probs[root_move] > 0:
366
+ root_move_values[root_move] = root_move_weighted_values[root_move] / root_move_total_probs[root_move]
367
+ else:
368
+ root_move_values[root_move] = 0
369
+
370
+ # Step 4: Confidence-based weighting with search results
371
+ initial_probs = torch.softmax(initial_masked_logits,dim=0)
372
+ entropy = -torch.sum(initial_probs*torch.log(initial_probs+1e-8))
373
+ max_entropy = math.log(torch.sum(initial_invalid_mask==0).item())
374
+ confidence = 1.0 - (entropy/max_entropy) if max_entropy > 0 else 1.0
375
+
376
+ if root_move_values:
377
+ search_adjustment_logits = torch.zeros_like(initial_masked_logits)
378
+ for move_uci,search_value in root_move_values.items():
379
+ move_idx = UCI_MOVE_TO_IDX[move_uci]
380
+ search_adjustment_logits[move_idx] += search_value
381
+ # flip value according to perpective
382
+ search_adjustment_logits = search_adjustment_logits if board.turn==chess.WHITE else -search_adjustment_logits
383
+ search_adjustment_logits = search_adjustment_logits - search_adjustment_logits.mean()
384
+
385
+ # Normalize search logits to be in the same norm as the initial logits
386
+
387
+ initial_valid_norm = torch.norm(initial_masked_logits[initial_top_k_indices]) + 1e-8
388
+ search_valid_norm = torch.norm(search_adjustment_logits[initial_top_k_indices]) + 1e-8
389
+
390
+ normalized_search = search_adjustment_logits * initial_valid_norm / search_valid_norm
391
+ normalized_initial = initial_masked_logits
392
+
393
+ adjusted_logits = confidence * normalized_initial + (1 - confidence) * normalized_search
394
+ else:
395
+ adjusted_logits = initial_masked_logits
396
+
397
+ # Apply temperature and top-k filtering
398
+ top_k_mask = torch.full_like(adjusted_logits, -torch.inf)
399
+ top_k_mask[initial_top_k_indices] = 0
400
+ adjusted_logits = adjusted_logits + top_k_mask
401
+ adjusted_logits = adjusted_logits / self.temperature
402
+
403
+ dist = Categorical(logits=adjusted_logits)
404
+ move_idx = dist.sample().item()
405
+ move_uci = IDX_TO_UCI_MOVE[move_idx]
406
+
407
+ if return_perplexity:
408
+ final_probs = torch.softmax(adjusted_logits,dim=0)
409
+ perplexity = torch.exp(-torch.sum(final_probs * torch.log(final_probs + 1e-8))).item()
410
+
411
+ if verbose and self.depth > 0:
412
+ print(f"\n--- Search Enhanced Move Debug Info ({board.fen()}) ---")
413
+ print(f"Confidence: {confidence:.4f}")
414
+
415
+ print("\nMove Analysis (Initial Top-K Candidates):")
416
+ print(f"{'Move':<8} {'Initial Logit':<15} {'Search Adj. Logit':<19} {'Final Adj. Logit':<18} {'Final Prob':<12}")
417
+ print(f"{'-'*8:<8} {'-'*15:<15} {'-'*19:<19} {'-'*18:<18} {'-'*12:<12}")
418
+
419
+ for i, idx in enumerate(initial_top_k_indices):
420
+ move_uci_v = IDX_TO_UCI_MOVE[idx.item()]
421
+ initial_logit = normalized_initial[idx].item()
422
+
423
+ search_adj_logit_val = normalized_search[idx].item() if root_move_values else 0.0
424
+
425
+ final_adj_logit = adjusted_logits[idx].item()
426
+ final_prob_val = final_probs[idx].item()
427
+
428
+ print(f"{move_uci_v:<8} {initial_logit:<15.4f} {search_adj_logit_val:<19.4f} {final_adj_logit:<18.4f} {final_prob_val:<12.4f}")
429
+
430
+ print(f"\nPerplexity: {perplexity:.4f}")
431
+ print(f"Predicted Value (White's POV): {final_value:.4f}")
432
+
433
+ print("\nLeaf Node Values (Root Move, Probability, Value from White's POV):")
434
+ for rm, prob, val in terminal_leaves:
435
+ print(f" Root Move: {rm:<8}, Prob: {prob:<.4f}, Value: {val:<.4f}")
436
+ print("--------------------------------------------------")
437
+
438
+ return move_uci, final_value, perplexity
439
+ else:
440
+ return move_uci, final_value
441
+
442
+ def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
443
+ """Get move from chessformer with optional search enhance"""
444
+ if self.depth == 0:
445
+ return self._raw_chessformer_move(board,return_perplexity)
446
+ else:
447
+ return self._search_enhanced_move(board,return_perplexity,verbose)
448
+
449
+ def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
450
+ """Get best move from stockfish"""
451
+ try:
452
+ engine = chess.engine.SimpleEngine.popen_uci(self.engine_path)
453
+ info = engine.analyse(board, chess.engine.Limit(depth=self.depth))
454
+ except (chess.engine.EngineError, chess.engine.EngineTerminatedError) as e:
455
+ print(f"Stockfish error: {e}")
456
+ return None
457
+
458
+ loss_threshold = -0.4
459
+
460
+ score_obj = info.get("score")
461
+ can_claim_draw = board.can_claim_draw()
462
+ if score_obj is None or info.get("pv") is None or not info.get("pv"):
463
+ # Invalid analysis result
464
+ return None
465
+
466
+ pv = info["pv"]
467
+ pov_score = score_obj.pov(chess.WHITE)
468
+ cp = None
469
+ if pov_score.is_mate():
470
+ mate_score = pov_score.mate()
471
+ cp = 10000.0 if mate_score > 0 else -10000.0
472
+ relative_score = score_obj.relative
473
+ if relative_score.is_mate():
474
+ cp = 10000.0 if relative_score.mate() > 0 else -10000.0
475
+ else:
476
+ if relative_score.cp is not None:
477
+ cp = float(relative_score.cp)
478
+ else:
479
+ return None
480
+
481
+ elif pov_score.cp is not None:
482
+ relative_score = score_obj.relative
483
+ if relative_score.cp is not None:
484
+ cp = float(relative_score.cp)
485
+ else:
486
+ return None
487
+
488
+ else:
489
+ return None
490
+
491
+ if cp is not None:
492
+ normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1
493
+ else:
494
+ return None
495
+
496
+ if can_claim_draw and normalized_score < loss_threshold:
497
+ best_move_uci = "<claim_draw>"
498
+ else:
499
+ best_move_uci = pv[0].uci()
500
+
501
+ if engine:
502
+ engine.quit()
503
+
504
+ if return_perplexity:
505
+ return best_move_uci, normalized_score, None
506
+ else:
507
+ return best_move_uci, normalized_score
508
+
509
+ def _batch_chessformer_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
510
+ """Get the next moves from Chessformer model using batch inference."""
511
+ bs = len(boards)
512
+ if bs > self.max_batch_size:
513
+ raise ValueError(f"num boards ({bs}) exceeded max batch size ({self.max_batch_size}).")
514
+ fens = [board.fen() for board in boards]
515
+
516
+ count_tensor = self.compute_repetition(boards) # shape (bs, 1)
517
+ count_tensor = count_tensor.to(self.device)
518
+
519
+ with torch.no_grad():
520
+ move_logits, values = self.model(fens, count_tensor)
521
+
522
+ invalid_mask = self.get_invalid_mask(boards)
523
+
524
+ # Apply mask
525
+ move_logits = move_logits + invalid_mask
526
+
527
+ all_masked = torch.all(torch.isinf(move_logits), dim=-1)
528
+
529
+ # Apply top-p filtering
530
+ if 0.0 < self.top_p < 1.0: # Apply only if top_p is strictly between 0 and 1
531
+ sorted_logits, sorted_indices = torch.sort(move_logits, descending=True, dim=-1)
532
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
533
+ sorted_indices_to_remove = cumulative_probs > self.top_p
534
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
535
+ sorted_indices_to_remove[..., 0] = 0
536
+ indices_to_remove = torch.zeros_like(move_logits, dtype=torch.bool).scatter_(
537
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
538
+ )
539
+ move_logits[indices_to_remove] = -torch.inf
540
+
541
+ # Apply temperature
542
+ temp = self.temperature if self.temperature > 0 else 1.0
543
+ move_logits = move_logits / temp
544
+
545
+ # Sample moves
546
+ dist = Categorical(logits=move_logits)
547
+ try:
548
+ sampled_indices = dist.sample()
549
+ except RuntimeError as e:
550
+ print(f"Error sampling moves: {e}. Checking logit values...")
551
+ results = []
552
+ for i in range(bs):
553
+ print(f"Board {i} logits sum: {torch.logsumexp(move_logits[i], dim=-1)}")
554
+ results.append(None) # indicate failure
555
+ return results
556
+
557
+ results = []
558
+ for i in range(bs):
559
+ if all_masked[i]:
560
+ results.append(None) # Game already over
561
+ continue
562
+
563
+ move_id = sampled_indices[i].item()
564
+ move_uci = IDX_TO_UCI_MOVE.get(move_id)
565
+ value = values[i].item()
566
+
567
+ if move_uci is None:
568
+ print(f"Warning: Sampled move ID {move_id} not in IDX_TO_UCI_MOVE map")
569
+ results.append(None)
570
+ continue
571
+
572
+ results.append((move_uci, value))
573
+
574
+
575
+ return results
576
+
577
+ def _batch_stockfish_move(self, boards: List[chess.Board], allow_claim_draw: bool=False) -> List[Tuple[str, float]]:
578
+ """Get the next moves from Stockfish engine using multiprocessing"""
579
+ if allow_claim_draw:
580
+ """Use sequential processing to maintain board history"""
581
+ return [self._stockfish_move(board) for board in boards]
582
+ else:
583
+ """Use multiprocessing to speed up if no need to include claim draw logic"""
584
+ bs = len(boards)
585
+ num_workers = min(bs, max(1, os.cpu_count()//2 if os.cpu_count() else 1))
586
+ if bs < num_workers * 2:
587
+ num_workers = max(1, bs//2)
588
+ if bs == 1: num_workers = 1
589
+
590
+ board_fens = [board.fen() for board in boards]
591
+
592
+ worker_func = partial(_stockfish_worker,
593
+ engine_path=self.engine_path,
594
+ depth=self.depth)
595
+ results: List[Optional[Tuple[str,float]]] = [None] * bs
596
+
597
+ active_indices = [i for i,b in enumerate(boards) if not b.is_game_over(claim_draw=True)]
598
+ active_fens = [board_fens[i] for i in active_indices]
599
+
600
+ if not active_fens:
601
+ # All games are over
602
+ return results # list of None
603
+
604
+ try:
605
+ if num_workers > 1 and len(active_fens) > 1:
606
+ with multiprocessing.Pool(processes=num_workers) as pool:
607
+ worker_results = pool.map(worker_func, active_fens)
608
+ else:
609
+ worker_results = [worker_func(fen) for fen in active_fens]
610
+
611
+ for i, res in enumerate(worker_results):
612
+ original_index = active_indices[i]
613
+ results[original_index] = res
614
+
615
+ except Exception as e:
616
+ print(f"Error during batch Stockfish move: {e}")
617
+
618
+ return results
619
+
620
+ def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
621
+ if self.type == "stockfish":
622
+ return self._stockfish_move(board, return_perplexity)
623
+ elif self.type == "chessformer":
624
+ return self._chessformer_move(board, return_perplexity)
625
+ else:
626
+ raise ValueError(f"Invalid engine type: {self.type}")
627
+
628
+ def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
629
+ if self.type == "stockfish":
630
+ return self._batch_stockfish_move(boards)
631
+ elif self.type == "chessformer":
632
+ return self._batch_chessformer_move(boards)
633
+ else:
634
+ raise ValueError(f"Invalid engine type: {self.type}")
635
+
636
+ def analyze_position(self, board: chess.Board) -> Optional[float]:
637
+ """
638
+ Analyzes the given **single board** position using the engine.
639
+ For Stockfish, returns list of centipawn scores from white's perspective;
640
+ For ChessFormer, returns list of models's value estimates
641
+ Returns None if analysis failed.
642
+ """
643
+ if self.type == "stockfish":
644
+ try:
645
+ engine = chess.engine.SimpleEngine.popen_uci(self.engine_path)
646
+ info = engine.analyse(board,chess.engine.Limit(depth=self.depth))
647
+ engine.quit()
648
+ except Exception as e:
649
+ print(f"Stockfish error: {e}")
650
+ return None
651
+
652
+ score_obj = info.get("score")
653
+ pov_score = score_obj.pov(chess.WHITE)
654
+ cp = None
655
+ if pov_score.is_mate():
656
+ mate_score = pov_score.mate()
657
+ cp = 10000.0 if mate_score > 0 else -10000.0
658
+ relative_score = score_obj.relative
659
+ if relative_score.is_mate():
660
+ cp = 10000.0 if relative_score.mate() > 0 else -10000.0
661
+ else:
662
+ if relative_score.cp is not None:
663
+ cp = float(relative_score.cp)
664
+ else:
665
+ return None
666
+ elif pov_score.cp is not None:
667
+ relative_score = score_obj.relative
668
+ if relative_score.cp is not None:
669
+ cp = float(relative_score.cp)
670
+ else:
671
+ return None
672
+ else:
673
+ return None
674
+
675
+ if cp is not None:
676
+ normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1
677
+ return normalized_score if board.turn == chess.WHITE else -normalized_score
678
+ else:
679
+ return None
680
+
681
+
682
+ elif self.type == "chessformer":
683
+ fen = board.fen()
684
+ count_tensor = self.compute_repetition([board.copy(stack=True)])
685
+
686
+ with torch.no_grad():
687
+ _, value = self.model([fen],count_tensor)
688
+
689
+ value = value.item()
690
+ return value if board.turn == chess.WHITE else -value
691
+
692
+ else:
693
+ raise ValueError(f"Invalid engine type.")
694
+
695
+
696
+ def test_search_enhanced_move(model_path,device):
697
+ """Test the search-enhanced move functionality"""
698
+ print("\n--- Testing Search-Enhanced ChessFormer ---")
699
+
700
+ import sys
701
+ sys.path.append("./")
702
+ try:
703
+ from model import ChessFormerModel
704
+ except ImportError:
705
+ from model import ChessFormerModel
706
+
707
+ # Load the trained model
708
+ checkpoint = torch.load(model_path,map_location=device)
709
+ config = checkpoint["config"]
710
+ model = ChessFormerModel(**config)
711
+ model.load_state_dict(checkpoint["model_state_dict"])
712
+
713
+ model.to(device)
714
+
715
+ # Test different search configurations
716
+ test_configs = [
717
+ #{"depth": 0, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2}, # No search (baseline)
718
+ #{"depth": 1, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2}, # Shallow search
719
+ {"depth": 8, "top_k": 8, "decay_rate": 0.5, "temperature": 0.5}, # Medium search
720
+ ]
721
+
722
+ # Test positions
723
+ test_positions = [
724
+ #chess.Board(), # Starting position
725
+ #chess.Board("r1bqkb1r/pppp1ppp/2n2n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4"), # Italian Game
726
+ #chess.Board("rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq c6 0 2"), # Sicilian Defense
727
+ #chess.Board("r1bq1rk1/ppp2ppp/2n2n2/2bpp3/2B1P3/3P1N2/PPP2PPP/RNBQ1RK1 w - - 0 6"), # Complex middlegame
728
+ chess.Board("r1b1k2r/1p2bpp1/2p1p1np/2N1P3/1q1P4/5N2/B1Q2PPP/R3R1K1 w kq - 0 19"), # blunder: c2e4
729
+ chess.Board("rn1qk2r/1b2bpp1/1pp1pn1p/p7/3P4/2PB1N2/PP1NQPPP/R1B1R1K1 w kq - 2 12"), # blunder: e2e6
730
+ ]
731
+
732
+ for i, cfg in enumerate(test_configs):
733
+ print(f"\n--- Test Configuration {i+1}: Depth={cfg['depth']}, Top-K={cfg['top_k']}, Decay={cfg['decay_rate']}, Temp={cfg['temperature']} ---")
734
+ chessformer_config = ChessformerConfig(
735
+ chessformer=model,
736
+ device=device,
737
+ temperature=cfg['temperature'],
738
+ depth=cfg['depth'],
739
+ top_k=cfg['top_k'],
740
+ decay_rate=cfg['decay_rate']
741
+ )
742
+ engine = Engine(type="chessformer", chessformer_config=chessformer_config)
743
+
744
+ for j, board in enumerate(test_positions):
745
+ print(f"\n--- Analyzing Position {j+1}: {board.fen()} ---")
746
+ try:
747
+ move, value, perplexity = engine._chessformer_move(board, return_perplexity=True, verbose=True)
748
+ print(f"Selected Move: {move}, Predicted Value (White's POV): {value:.4f}, Perplexity: {perplexity:.4f}")
749
+ except Exception as e:
750
+ print(f"Error analyzing position {board.fen()}: {e}")
751
+ import traceback
752
+ traceback.print_exc()
753
+
754
+ if __name__ == "__main__":
755
+ model_path = "./ckpts/chessformer-sl_01.pth"
756
+ device = torch.device("cpu")
757
+ test_search_enhanced_move(model_path,device)
758
+
759
+
utils/mapping.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple, Set
2
+
3
+ # --- Constants --- #
4
+ MAX_HALFMOVES = 128 # cap for embedding table size
5
+ MAX_FULLMOVES = 256 # cap for embedding table size
6
+
7
+ # --- Helper Mappings --- #
8
+ PIECE_TO_IDX: Dict[str, int] = {
9
+ 'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
10
+ 'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
11
+ '.': 12
12
+ }
13
+ IDX_TO_PIECE: Dict[int, str] = {v: k for k, v in PIECE_TO_IDX.items()}
14
+ EMPTY_SQ_IDX = PIECE_TO_IDX['.']
15
+ # Map algebraic square notation (e.g., 'a1', 'h8') to 0-63 index
16
+ # a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
17
+ SQUARE_TO_IDX: Dict[str, int] = {
18
+ f"{file}{rank}": (rank - 1) * 8 + (ord(file) - ord('a'))
19
+ for rank in range(1, 9)
20
+ for file in 'abcdefgh'
21
+ }
22
+ IDX_TO_SQUARE: Dict[int, str] = {v: k for k, v in SQUARE_TO_IDX.items()}
23
+
24
+
25
+
26
+ # --- Coordinate and Notation Helpers ---
27
+
28
+ # Precompute maps for efficiency
29
+ _IDX_TO_COORDS: Dict[int, Tuple[int, int]] = {i: (i // 8, i % 8) for i in range(64)} # (rank, file) 0-7
30
+ _COORDS_TO_IDX: Dict[Tuple[int, int], int] = {v: k for k, v in _IDX_TO_COORDS.items()}
31
+ _IDX_TO_ALG: Dict[int, str] = {
32
+ i: f"{chr(ord('a') + file)}{rank + 1}"
33
+ for i, (rank, file) in _IDX_TO_COORDS.items()
34
+ }
35
+ _ALG_TO_IDX: Dict[str, int] = {v: k for k, v in _IDX_TO_ALG.items()}
36
+
37
+ def _coords_to_alg(r: int, f: int) -> str:
38
+ """Converts 0-indexed (rank, file) to algebraic notation."""
39
+ if 0 <= r < 8 and 0 <= f < 8:
40
+ return f"{chr(ord('a') + f)}{r + 1}"
41
+ # This should not happen with valid indices, but good for safety
42
+ raise ValueError(f"Invalid coordinates: ({r}, {f})")
43
+
44
+ def generate_structurally_valid_move_map() -> Dict[str, int]:
45
+ """
46
+ Generates a dictionary mapping chess moves that are geometrically possible
47
+ by *some* standard piece (K, Q, R, B, N, or P) to unique integer indices.
48
+ It excludes moves that are structurally impossible for any piece to make
49
+ in one turn (e.g., a1->h5 for non-knight).
50
+
51
+ Includes standard UCI promotions (e.g., "e7e8q"), replacing the
52
+ corresponding simple pawn move to the final rank (e.g., "e7e8").
53
+ This is based purely on piece movement geometry, not the current board state.
54
+
55
+ Returns:
56
+ Dict[str, int]: A map from the valid UCI move string to a unique
57
+ integer index (0 to N-1). The size N is expected
58
+ to be around 1800-1900.
59
+ """
60
+ valid_moves: Set[str] = set()
61
+ # Keep track of base moves (like 'e7e8') that are replaced by promotions
62
+ # according to UCI standard.
63
+ promo_base_moves_to_exclude: Set[str] = set()
64
+
65
+ # 1. Generate all geometrically possible non-promotion moves
66
+ for from_idx in range(64):
67
+ from_r, from_f = _IDX_TO_COORDS[from_idx]
68
+ from_alg = _IDX_TO_ALG[from_idx]
69
+
70
+ for to_idx in range(64):
71
+ if from_idx == to_idx:
72
+ continue
73
+
74
+ to_r, to_f = _IDX_TO_COORDS[to_idx]
75
+ to_alg = _IDX_TO_ALG[to_idx]
76
+ dr, df = to_r - from_r, to_f - from_f
77
+ abs_dr, abs_df = abs(dr), abs(df)
78
+
79
+ # Check if the geometry matches any standard piece movement
80
+ # Note: Queen moves are covered by Rook + Bishop checks.
81
+ # Note: Pawn single pushes/captures are covered by King/Rook/Bishop geometry.
82
+ # Note: Pawn double pushes are covered by Rook geometry.
83
+ is_king_move = max(abs_dr, abs_df) == 1
84
+ is_knight_move = (abs_dr == 2 and abs_df == 1) or (abs_dr == 1 and abs_df == 2)
85
+ is_rook_move = dr == 0 or df == 0 # Includes King horiz/vert & pawn double push
86
+ is_bishop_move = abs_dr == abs_df # Includes King diagonal & pawn capture/push
87
+
88
+ if is_king_move or is_knight_move or is_rook_move or is_bishop_move:
89
+ uci_move = f"{from_alg}{to_alg}"
90
+ valid_moves.add(uci_move)
91
+
92
+
93
+ # 2. Generate promotion moves explicitly and mark base moves for exclusion
94
+ promo_pieces = ['q', 'r', 'b', 'n']
95
+ for from_f in range(8):
96
+ # White promotions (from rank 7 (idx 6) to rank 8 (idx 7))
97
+ from_r_w, to_r_w = 6, 7
98
+ if from_r_w != 7: # Ensure we are on the correct rank before promotion
99
+ from_alg_w = _coords_to_alg(from_r_w, from_f)
100
+ # Possible destinations: push (df=0), capture left (df=-1), capture right (df=1)
101
+ for df in [-1, 0, 1]:
102
+ to_f_w = from_f + df
103
+ if 0 <= to_f_w < 8:
104
+ to_alg_w = _coords_to_alg(to_r_w, to_f_w)
105
+ base_move = f"{from_alg_w}{to_alg_w}"
106
+ #promo_base_moves_to_exclude.add(base_move) # Mark e.g. "e7e8" for exclusion
107
+ for p in promo_pieces:
108
+ valid_moves.add(f"{base_move}{p}") # Add e.g. "e7e8q"
109
+
110
+ # Black promotions (from rank 2 (idx 1) to rank 1 (idx 0))
111
+ from_r_b, to_r_b = 1, 0
112
+ if from_r_b != 0: # Ensure we are on the correct rank before promotion
113
+ from_alg_b = _coords_to_alg(from_r_b, from_f)
114
+ # Possible destinations: push (df=0), capture left (df=-1), capture right (df=1)
115
+ for df in [-1, 0, 1]:
116
+ to_f_b = from_f + df
117
+ if 0 <= to_f_b < 8:
118
+ to_alg_b = _coords_to_alg(to_r_b, to_f_b)
119
+ base_move = f"{from_alg_b}{to_alg_b}"
120
+ #promo_base_moves_to_exclude.add(base_move) # Mark e.g. "e2e1" for exclusion
121
+ for p in promo_pieces:
122
+ valid_moves.add(f"{base_move}{p}") # Add e.g. "e2e1q"
123
+
124
+ # 3. Remove the base moves that were replaced by promotions
125
+ final_valid_moves = valid_moves - promo_base_moves_to_exclude
126
+
127
+ # 4. Add draw claim
128
+ final_valid_moves.add("<claim_draw>")
129
+
130
+ # 5. Create the final map with sorted keys for deterministic indices
131
+ sorted_moves = sorted(list(final_valid_moves))
132
+ move_map = {move: i for i, move in enumerate(sorted_moves)}
133
+
134
+ # Optional: Print the number of moves found for verification
135
+ # print(f"Generated {len(move_map)} structurally valid unique UCI moves.")
136
+
137
+ return move_map
138
+
139
+
140
+ UCI_MOVE_TO_IDX = generate_structurally_valid_move_map()
141
+ IDX_TO_UCI_MOVE = {v:k for k,v in UCI_MOVE_TO_IDX.items()}