Spaces:
Sleeping
Sleeping
Upload 10 files
Browse files- Dockerfile +25 -0
- README.md +0 -14
- app.py +771 -0
- model.py +365 -0
- requirements.txt +10 -0
- utils/__init__.py +17 -0
- utils/buffer.py +274 -0
- utils/chess_env.py +151 -0
- utils/engine.py +759 -0
- utils/mapping.py +141 -0
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12
|
| 2 |
+
|
| 3 |
+
# Install system dependencies including Stockfish
|
| 4 |
+
RUN apt-get update && apt-get install -y \
|
| 5 |
+
stockfish \
|
| 6 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
# Set working directory
|
| 9 |
+
WORKDIR /code
|
| 10 |
+
|
| 11 |
+
# Copy requirements and install Python dependencies
|
| 12 |
+
COPY ./requirements.txt /code/requirements.txt
|
| 13 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
| 14 |
+
|
| 15 |
+
# Copy application code
|
| 16 |
+
COPY . /code
|
| 17 |
+
|
| 18 |
+
# Make sure Stockfish is executable
|
| 19 |
+
RUN chmod +x /usr/bin/stockfish
|
| 20 |
+
|
| 21 |
+
# Expose port
|
| 22 |
+
EXPOSE 7860
|
| 23 |
+
|
| 24 |
+
# Run the application
|
| 25 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Chessformer Demo
|
| 3 |
-
emoji: 🌍
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.32.1
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
short_description: Play chess with Chessformer or Stockfish!
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,771 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import chess
|
| 7 |
+
import chess.svg
|
| 8 |
+
import chess.pgn
|
| 9 |
+
import re
|
| 10 |
+
import torch
|
| 11 |
+
import os
|
| 12 |
+
import io
|
| 13 |
+
import math
|
| 14 |
+
from typing import Optional, Tuple, List
|
| 15 |
+
import traceback
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
|
| 18 |
+
from utils import Engine, ChessformerConfig, StockfishConfig, UCI_MOVE_TO_IDX
|
| 19 |
+
from model import ChessFormerModel
|
| 20 |
+
|
| 21 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 22 |
+
|
| 23 |
+
import spaces
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
+
|
| 26 |
+
# Add to ChessApp.__init__
|
| 27 |
+
def __init__(self):
|
| 28 |
+
# ... existing init code ...
|
| 29 |
+
self.analysis_executor = ThreadPoolExecutor(max_workers=2)
|
| 30 |
+
|
| 31 |
+
def update_evaluations_async(self):
|
| 32 |
+
"""Update evaluations asynchronously"""
|
| 33 |
+
def update_current_engine():
|
| 34 |
+
if self.current_engine:
|
| 35 |
+
try:
|
| 36 |
+
self.current_engine_eval = self.current_engine.analyze_position(self.board.copy())
|
| 37 |
+
if self.current_engine_eval is None:
|
| 38 |
+
self.current_engine_eval = 0.0
|
| 39 |
+
except:
|
| 40 |
+
self.current_engine_eval = 0.0
|
| 41 |
+
|
| 42 |
+
def update_stockfish():
|
| 43 |
+
try:
|
| 44 |
+
self.stockfish_eval = self.fast_stockfish_analysis(self.board.copy())
|
| 45 |
+
if self.stockfish_eval is None:
|
| 46 |
+
self.stockfish_eval = 0.0
|
| 47 |
+
except:
|
| 48 |
+
self.stockfish_eval = 0.0
|
| 49 |
+
|
| 50 |
+
# Run both analyses in parallel
|
| 51 |
+
future1 = self.analysis_executor.submit(update_current_engine)
|
| 52 |
+
future2 = self.analysis_executor.submit(update_stockfish)
|
| 53 |
+
|
| 54 |
+
# Wait for both to complete
|
| 55 |
+
future1.result()
|
| 56 |
+
future2.result()
|
| 57 |
+
|
| 58 |
+
class ChessApp:
|
| 59 |
+
def __init__(self, device):
|
| 60 |
+
self.board = chess.Board()
|
| 61 |
+
self.move_history = []
|
| 62 |
+
self.current_engine = None
|
| 63 |
+
self.analysis_engine = None
|
| 64 |
+
self.game_over = False
|
| 65 |
+
self.user_color = chess.WHITE
|
| 66 |
+
self.models = {}
|
| 67 |
+
self.device = device
|
| 68 |
+
|
| 69 |
+
self.current_engine_eval = 0.0
|
| 70 |
+
self.stockfish_eval = 0.0
|
| 71 |
+
|
| 72 |
+
self.load_models()
|
| 73 |
+
self.create_analysis_engine()
|
| 74 |
+
|
| 75 |
+
def load_models(self):
|
| 76 |
+
model_paths = {
|
| 77 |
+
"ChessFormer-SL": "./ckpts/chessformer-sl_01.pth",
|
| 78 |
+
"ChessFormer-RL": "./ckpts/chessformer-rl_final.pth"
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
for name, path in model_paths.items():
|
| 82 |
+
if os.path.exists(path):
|
| 83 |
+
print(f"Loading {name} from {path}...")
|
| 84 |
+
checkpoint = torch.load(path,map_location=self.device)
|
| 85 |
+
config = checkpoint["config"]
|
| 86 |
+
model = ChessFormerModel(**config)
|
| 87 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 88 |
+
model.to(self.device)
|
| 89 |
+
model.eval()
|
| 90 |
+
|
| 91 |
+
self.models[name] = model
|
| 92 |
+
print(f"Successfully loaded {name}.")
|
| 93 |
+
else:
|
| 94 |
+
print(f"Model file not found: {path}")
|
| 95 |
+
|
| 96 |
+
def get_depth_limits(self, engine_type: str) -> Tuple[int,int]:
|
| 97 |
+
if engine_type == "Stockfish":
|
| 98 |
+
return 1,24,6
|
| 99 |
+
else:
|
| 100 |
+
return 0,6,0
|
| 101 |
+
|
| 102 |
+
def create_evaluation_bar(self, eval_score: float, title: str) -> str:
|
| 103 |
+
"""Create HTML evaluation bar from user's perspective with page-matching colors"""
|
| 104 |
+
# Convert eval_score from white's perspective to user's perspective
|
| 105 |
+
user_eval = eval_score if self.user_color == chess.WHITE else -eval_score
|
| 106 |
+
|
| 107 |
+
# Clamp evaluation between -1 and 1 for display
|
| 108 |
+
clamped_eval = max(-1.0, min(1.0, user_eval))
|
| 109 |
+
|
| 110 |
+
# Convert to percentage (0 = user losing, 100 = user winning)
|
| 111 |
+
percentage = (clamped_eval + 1.0) / 2.0 * 100
|
| 112 |
+
|
| 113 |
+
# Format evaluation text from user's perspective
|
| 114 |
+
eval_text = f"{user_eval:+.2f}"
|
| 115 |
+
if abs(user_eval) > 5:
|
| 116 |
+
eval_text = "±∞" if user_eval > 0 else "∓∞"
|
| 117 |
+
|
| 118 |
+
# Determine advantage text and colors (matching page theme)
|
| 119 |
+
if user_eval > 0.5:
|
| 120 |
+
advantage_text = "WINNING"
|
| 121 |
+
text_color = "#1e40af" # Blue-700
|
| 122 |
+
indicator_color = "#3b82f6" # Blue-500
|
| 123 |
+
elif user_eval > 0.1:
|
| 124 |
+
advantage_text = "SLIGHT ADVANTAGE"
|
| 125 |
+
text_color = "#1e40af"
|
| 126 |
+
indicator_color = "#60a5fa" # Blue-400
|
| 127 |
+
elif user_eval < -0.5:
|
| 128 |
+
advantage_text = "LOSING"
|
| 129 |
+
text_color = "#7c2d12" # Orange-800 (more muted than red)
|
| 130 |
+
indicator_color = "#ea580c" # Orange-600
|
| 131 |
+
elif user_eval < -0.1:
|
| 132 |
+
advantage_text = "SLIGHT DISADVANTAGE"
|
| 133 |
+
text_color = "#9a3412" # Orange-700
|
| 134 |
+
indicator_color = "#f97316" # Orange-500
|
| 135 |
+
else:
|
| 136 |
+
advantage_text = "EQUAL POSITION"
|
| 137 |
+
text_color = "#4b5563" # Gray-600
|
| 138 |
+
indicator_color = "#6b7280" # Gray-500
|
| 139 |
+
|
| 140 |
+
return f"""
|
| 141 |
+
<div style="margin: 10px 0; font-family: 'Segoe UI', Arial, sans-serif;">
|
| 142 |
+
<h4 style="margin: 5px 0 10px 0; color: #374151; font-size: 14px; font-weight: 600;">{title}</h4>
|
| 143 |
+
|
| 144 |
+
<!-- Evaluation bar with page-matching gradient -->
|
| 145 |
+
<div style="width: 100%; height: 40px; border: 2px solid #d1d5db; border-radius: 8px; position: relative;
|
| 146 |
+
background: linear-gradient(to right,
|
| 147 |
+
#fed7aa 0%, /* Orange-200 - losing */
|
| 148 |
+
#fde68a 20%, /* Yellow-200 */
|
| 149 |
+
#e5e7eb 50%, /* Gray-200 - equal */
|
| 150 |
+
#bfdbfe 80%, /* Blue-200 */
|
| 151 |
+
#93c5fd 100% /* Blue-300 - winning */
|
| 152 |
+
);
|
| 153 |
+
box-shadow: inset 0 1px 3px rgba(0,0,0,0.05);">
|
| 154 |
+
|
| 155 |
+
<!-- Evaluation indicator -->
|
| 156 |
+
<div style="position: absolute; left: {percentage}%; top: 50%; transform: translateX(-50%) translateY(-50%);
|
| 157 |
+
background: {indicator_color}; border: 3px solid white; border-radius: 50%; width: 18px; height: 18px;
|
| 158 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.15), 0 0 0 1px #d1d5db; z-index: 10;
|
| 159 |
+
transition: all 0.3s ease;"></div>
|
| 160 |
+
</div>
|
| 161 |
+
|
| 162 |
+
<!-- Evaluation text -->
|
| 163 |
+
<div style="text-align: center; margin-top: 8px; padding: 8px; background: #f9fafb;
|
| 164 |
+
border-radius: 6px; border: 1px solid #e5e7eb;">
|
| 165 |
+
<div style="font-weight: 600; color: {text_color}; font-size: 16px; margin-bottom: 2px;">
|
| 166 |
+
{eval_text}
|
| 167 |
+
</div>
|
| 168 |
+
<div style="font-size: 10px; color: {text_color}; text-transform: uppercase; letter-spacing: 0.8px; font-weight: 500; opacity: 0.8;">
|
| 169 |
+
{advantage_text}
|
| 170 |
+
</div>
|
| 171 |
+
</div>
|
| 172 |
+
</div>
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def create_analysis_engine(self):
|
| 176 |
+
"""Create optimized Stockfish depth 27 engine for analysis"""
|
| 177 |
+
try:
|
| 178 |
+
config = StockfishConfig(
|
| 179 |
+
engine_path="/usr/games/stockfish",
|
| 180 |
+
depth=27
|
| 181 |
+
)
|
| 182 |
+
self.analysis_engine = Engine(type="stockfish", stockfish_config=config)
|
| 183 |
+
|
| 184 |
+
# Configure Stockfish for faster analysis
|
| 185 |
+
if self.analysis_engine and hasattr(self.analysis_engine, 'engine_path'):
|
| 186 |
+
# We'll patch the engine creation to use optimized settings
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
print("Analysis engine (Stockfish depth 27) created successfully")
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print(f"Failed to create analysis engine: {e}")
|
| 192 |
+
self.analysis_engine = None
|
| 193 |
+
|
| 194 |
+
def update_evaluations(self):
|
| 195 |
+
"""Update evaluations from both engines with optimized Stockfish analysis"""
|
| 196 |
+
# Get current engine evaluation
|
| 197 |
+
if self.current_engine:
|
| 198 |
+
try:
|
| 199 |
+
self.current_engine_eval = self.current_engine.analyze_position(self.board.copy())
|
| 200 |
+
if self.current_engine_eval is None:
|
| 201 |
+
self.current_engine_eval = 0.0
|
| 202 |
+
except:
|
| 203 |
+
self.current_engine_eval = 0.0
|
| 204 |
+
|
| 205 |
+
# Get optimized Stockfish analysis
|
| 206 |
+
if self.analysis_engine:
|
| 207 |
+
try:
|
| 208 |
+
self.stockfish_eval = self.fast_stockfish_analysis(self.board.copy())
|
| 209 |
+
if self.stockfish_eval is None:
|
| 210 |
+
self.stockfish_eval = 0.0
|
| 211 |
+
except:
|
| 212 |
+
self.stockfish_eval = 0.0
|
| 213 |
+
|
| 214 |
+
def fast_stockfish_analysis(self, board: chess.Board) -> Optional[float]:
|
| 215 |
+
"""Fast Stockfish analysis with optimized settings"""
|
| 216 |
+
try:
|
| 217 |
+
import chess.engine
|
| 218 |
+
|
| 219 |
+
# Create engine with optimized settings
|
| 220 |
+
with chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish") as engine:
|
| 221 |
+
# Configure for speed
|
| 222 |
+
engine.configure({
|
| 223 |
+
"Threads": min(8, os.cpu_count() or 4), # Use multiple threads
|
| 224 |
+
"Hash": 256, # 256MB hash table
|
| 225 |
+
"UCI_AnalyseMode": True
|
| 226 |
+
})
|
| 227 |
+
|
| 228 |
+
# Use time limit instead of depth for faster analysis
|
| 229 |
+
info = engine.analyse(
|
| 230 |
+
board,
|
| 231 |
+
chess.engine.Limit(time=1.0), # 1 second analysis
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
score_obj = info.get("score")
|
| 235 |
+
if score_obj is None:
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
pov_score = score_obj.pov(chess.WHITE)
|
| 239 |
+
|
| 240 |
+
if pov_score.is_mate():
|
| 241 |
+
mate_score = pov_score.mate()
|
| 242 |
+
cp = 10000.0 if mate_score > 0 else -10000.0
|
| 243 |
+
elif pov_score.cp is not None:
|
| 244 |
+
cp = float(pov_score.cp)
|
| 245 |
+
else:
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
# Normalize score
|
| 249 |
+
normalized_score = 2 / (1 + math.exp(-0.004 * cp)) - 1
|
| 250 |
+
return normalized_score
|
| 251 |
+
|
| 252 |
+
except Exception as e:
|
| 253 |
+
print(f"Fast Stockfish analysis error: {e}")
|
| 254 |
+
return None
|
| 255 |
+
|
| 256 |
+
def create_engine(self, engine_type: str, depth: int, temperature: float=0.5) -> Optional[Engine]:
|
| 257 |
+
if engine_type == "Stockfish":
|
| 258 |
+
config = StockfishConfig(
|
| 259 |
+
engine_path="/usr/games/stockfish",
|
| 260 |
+
depth=depth
|
| 261 |
+
)
|
| 262 |
+
return Engine(type="stockfish",stockfish_config=config)
|
| 263 |
+
elif engine_type in self.models:
|
| 264 |
+
config = ChessformerConfig(
|
| 265 |
+
chessformer=self.models[engine_type],
|
| 266 |
+
device=self.device,
|
| 267 |
+
temperature=temperature,
|
| 268 |
+
depth=depth if depth > 0 else 0,
|
| 269 |
+
top_k=8,
|
| 270 |
+
decay_rate=0.6,
|
| 271 |
+
max_batch_size=800
|
| 272 |
+
)
|
| 273 |
+
return Engine(type="chessformer",chessformer_config=config)
|
| 274 |
+
|
| 275 |
+
return None
|
| 276 |
+
|
| 277 |
+
def parse_move(self, move_str: str) -> Optional[chess.Move]:
|
| 278 |
+
"""Parse move input in either UCI format ("e2e4") or algebraic notation ("Ne5")"""
|
| 279 |
+
if not move_str:
|
| 280 |
+
return None
|
| 281 |
+
|
| 282 |
+
move_str = move_str.strip()
|
| 283 |
+
|
| 284 |
+
# Try UCI format first
|
| 285 |
+
uci_pattern = r'^[a-h][1-8][a-h][1-8][qrbn]?$'
|
| 286 |
+
if re.match(uci_pattern,move_str.lower()):
|
| 287 |
+
try:
|
| 288 |
+
return chess.Move.from_uci(move_str.lower())
|
| 289 |
+
except ValueError:
|
| 290 |
+
pass
|
| 291 |
+
|
| 292 |
+
# Try algrebraic notation
|
| 293 |
+
try:
|
| 294 |
+
return self.board.parse_san(move_str)
|
| 295 |
+
except ValueError:
|
| 296 |
+
pass
|
| 297 |
+
|
| 298 |
+
return None
|
| 299 |
+
|
| 300 |
+
def get_board_svg(self) -> str:
|
| 301 |
+
"""Generate SVG representation of the chess board"""
|
| 302 |
+
flip = (self.user_color == chess.BLACK)
|
| 303 |
+
|
| 304 |
+
lastmove = None
|
| 305 |
+
if self.move_history:
|
| 306 |
+
lastmove = self.move_history[-1]
|
| 307 |
+
|
| 308 |
+
svg = chess.svg.board(
|
| 309 |
+
board=self.board,
|
| 310 |
+
flipped=flip,
|
| 311 |
+
lastmove=lastmove,
|
| 312 |
+
size=600
|
| 313 |
+
)
|
| 314 |
+
return svg
|
| 315 |
+
|
| 316 |
+
def get_move_history_text(self) -> str:
|
| 317 |
+
"""Generate move history in PGN format"""
|
| 318 |
+
try:
|
| 319 |
+
game = chess.pgn.Game()
|
| 320 |
+
game.headers["Event"] = "ChessFormer Demo"
|
| 321 |
+
game.headers["Date"] = datetime.now().strftime("%Y.%m.%d")
|
| 322 |
+
game.headers["White"] = "You" if self.user_color == chess.WHITE else "Engine"
|
| 323 |
+
game.headers["Black"] = "Engine" if self.user_color == chess.WHITE else "You"
|
| 324 |
+
|
| 325 |
+
node = game
|
| 326 |
+
temp_board = chess.Board()
|
| 327 |
+
|
| 328 |
+
for move in self.move_history:
|
| 329 |
+
node = node.add_variation(move)
|
| 330 |
+
temp_board.push(move)
|
| 331 |
+
|
| 332 |
+
if self.game_over:
|
| 333 |
+
outcome = self.board.outcome()
|
| 334 |
+
if outcome:
|
| 335 |
+
if outcome.winner == chess.WHITE:
|
| 336 |
+
game.headers["Result"] = "1-0"
|
| 337 |
+
elif outcome.winner == chess.BLACK:
|
| 338 |
+
game.headers["Result"] = "0-1"
|
| 339 |
+
else:
|
| 340 |
+
game.headers["Result"] = "1/2-1/2"
|
| 341 |
+
else:
|
| 342 |
+
game.headers["Result"] = "*"
|
| 343 |
+
else:
|
| 344 |
+
game.headers["Result"] = "*"
|
| 345 |
+
|
| 346 |
+
return str(game)
|
| 347 |
+
except Exception as e:
|
| 348 |
+
print(f"Error generating move history: {e}")
|
| 349 |
+
return "Move history unavailable"
|
| 350 |
+
|
| 351 |
+
def export_pgn(self) -> str:
|
| 352 |
+
return self.get_move_history_text()
|
| 353 |
+
|
| 354 |
+
def import_fen(self, fen: str) -> Tuple[str,str,str,str,str]:
|
| 355 |
+
try:
|
| 356 |
+
test_board = chess.Board(fen.strip())
|
| 357 |
+
self.board = test_board
|
| 358 |
+
self.move_history = []
|
| 359 |
+
self.game_over = False
|
| 360 |
+
self.update_evaluations()
|
| 361 |
+
|
| 362 |
+
return (
|
| 363 |
+
self.get_board_svg(),
|
| 364 |
+
self.get_move_history_text(),
|
| 365 |
+
f"Position loaded from FEN: {fen}",
|
| 366 |
+
"",
|
| 367 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 368 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 369 |
+
)
|
| 370 |
+
except Exception as e:
|
| 371 |
+
return (
|
| 372 |
+
self.get_board_svg(),
|
| 373 |
+
self.get_move_history_text(),
|
| 374 |
+
f"Invalid FEN: {str(e)}",
|
| 375 |
+
"",
|
| 376 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 377 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
def import_pgn(self, pgn_text: str) -> Tuple[str,str,str,str,str]:
|
| 381 |
+
try:
|
| 382 |
+
pgn_io = io.StringIO(pgn_text.strip())
|
| 383 |
+
game = chess.pgn.read_game(pgn_io)
|
| 384 |
+
|
| 385 |
+
if game is None:
|
| 386 |
+
raise ValueError("Could not parse PGN")
|
| 387 |
+
|
| 388 |
+
self.board = game.board()
|
| 389 |
+
self.move_history = []
|
| 390 |
+
|
| 391 |
+
for move in game.mainline_moves():
|
| 392 |
+
self.board.push(move)
|
| 393 |
+
self.move_history.append(move)
|
| 394 |
+
|
| 395 |
+
self.game_over = self.board.is_game_over()
|
| 396 |
+
self.update_evaluations()
|
| 397 |
+
|
| 398 |
+
return (
|
| 399 |
+
self.get_board_svg(),
|
| 400 |
+
self.get_move_history_text(),
|
| 401 |
+
f"Game loaded from PGN ({len(self.move_history)} moves)",
|
| 402 |
+
"",
|
| 403 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 404 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 405 |
+
)
|
| 406 |
+
except Exception as e:
|
| 407 |
+
return (
|
| 408 |
+
self.get_board_svg(),
|
| 409 |
+
self.get_move_history_text(),
|
| 410 |
+
f"Invalid PGN: {str(e)}",
|
| 411 |
+
"",
|
| 412 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 413 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
def make_user_move(self, move_str: str) -> Tuple[str,str,str,str,str,str]:
|
| 417 |
+
if self.game_over:
|
| 418 |
+
return (
|
| 419 |
+
self.get_board_svg(),
|
| 420 |
+
self.get_move_history_text(),
|
| 421 |
+
"Game is over. Click 'New Game' to start a new game.",
|
| 422 |
+
"",
|
| 423 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 424 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if self.board.turn != self.user_color:
|
| 428 |
+
return (
|
| 429 |
+
self.get_board_svg(),
|
| 430 |
+
self.get_move_history_text(),
|
| 431 |
+
"It's not your turn now!",
|
| 432 |
+
"",
|
| 433 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 434 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
move = self.parse_move(move_str)
|
| 438 |
+
if move is None:
|
| 439 |
+
return (
|
| 440 |
+
self.get_board_svg(),
|
| 441 |
+
self.get_move_history_text(),
|
| 442 |
+
f"Invalid move: '{move_str}'. Try formats like 'e2e4' or 'Ne5'",
|
| 443 |
+
"",
|
| 444 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 445 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
if move not in self.board.legal_moves:
|
| 449 |
+
return (
|
| 450 |
+
self.get_board_svg(),
|
| 451 |
+
self.get_move_history_text(),
|
| 452 |
+
f"Illegal move: '{move_str}'",
|
| 453 |
+
"",
|
| 454 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 455 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
self.board.push(move)
|
| 459 |
+
self.move_history.append(move)
|
| 460 |
+
|
| 461 |
+
self.update_evaluations()
|
| 462 |
+
|
| 463 |
+
if self.board.is_game_over():
|
| 464 |
+
self.game_over = True
|
| 465 |
+
outcome = self.board.outcome()
|
| 466 |
+
if outcome:
|
| 467 |
+
if outcome.winner == self.user_color:
|
| 468 |
+
status = "Congratulations! You won!"
|
| 469 |
+
elif outcome.winner is None:
|
| 470 |
+
status = "Game drawn."
|
| 471 |
+
else:
|
| 472 |
+
status = "You lost."
|
| 473 |
+
status += f" ({outcome.termination.name})"
|
| 474 |
+
else:
|
| 475 |
+
status = "Game over."
|
| 476 |
+
|
| 477 |
+
return (
|
| 478 |
+
self.get_board_svg(),
|
| 479 |
+
self.get_move_history_text(),
|
| 480 |
+
status,
|
| 481 |
+
"",
|
| 482 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 483 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
# Get engine move
|
| 487 |
+
try:
|
| 488 |
+
engine_move_uci, engine_value = self.current_engine.move(self.board)
|
| 489 |
+
|
| 490 |
+
if engine_move_uci == "<claim_draw>":
|
| 491 |
+
self.game_over = True
|
| 492 |
+
status = "Engine claimed a draw."
|
| 493 |
+
else:
|
| 494 |
+
engine_move = chess.Move.from_uci(engine_move_uci)
|
| 495 |
+
self.board.push(engine_move)
|
| 496 |
+
self.move_history.append(engine_move)
|
| 497 |
+
|
| 498 |
+
if self.board.is_game_over():
|
| 499 |
+
self.game_over = True
|
| 500 |
+
outcome = self.board.outcome()
|
| 501 |
+
if outcome:
|
| 502 |
+
if outcome.winner == self.user_color:
|
| 503 |
+
status = "🎉🏆 CONGRATULATIONS! YOU WON! 🏆🎉"
|
| 504 |
+
status += f"\n🎯 Victory by {outcome.termination.name}! 🎯"
|
| 505 |
+
elif outcome.winner is None:
|
| 506 |
+
status = "🤝 GAME DRAWN 🤝"
|
| 507 |
+
status += f"\n⚖️ Draw by {outcome.termination.name} ⚖️"
|
| 508 |
+
else:
|
| 509 |
+
status = "😔 YOU LOST 😔"
|
| 510 |
+
status += f"\n💔 Defeated by {outcome.termination.name} 💔"
|
| 511 |
+
else:
|
| 512 |
+
status = "🏁 GAME OVER 🏁"
|
| 513 |
+
else:
|
| 514 |
+
status = f"Engine played: {engine_move.uci()}."
|
| 515 |
+
|
| 516 |
+
except Exception as e:
|
| 517 |
+
status = f"Engine error: {str(e)}"
|
| 518 |
+
print(f"Engine error: {e}")
|
| 519 |
+
traceback.print_exc()
|
| 520 |
+
|
| 521 |
+
return (
|
| 522 |
+
self.get_board_svg(),
|
| 523 |
+
self.get_move_history_text(),
|
| 524 |
+
status,
|
| 525 |
+
"", # clear input
|
| 526 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 527 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
def new_game(self, engine_type: str, depth: int, color: str, temperature: float) -> Tuple[str,str,str,str,str,str]:
|
| 531 |
+
"Start a new game"
|
| 532 |
+
self.board = chess.Board()
|
| 533 |
+
self.move_history = []
|
| 534 |
+
self.game_over = False
|
| 535 |
+
self.user_color = chess.WHITE if color == "White" else chess.BLACK
|
| 536 |
+
|
| 537 |
+
# Create new engine
|
| 538 |
+
self.current_engine = self.create_engine(engine_type, depth, temperature)
|
| 539 |
+
|
| 540 |
+
self.update_evaluations()
|
| 541 |
+
|
| 542 |
+
if self.current_engine is None:
|
| 543 |
+
status = f"Failed to create {engine_type} engine."
|
| 544 |
+
else:
|
| 545 |
+
status = f"New game started! You are playing {color} against {engine_type} (depth {depth})."
|
| 546 |
+
|
| 547 |
+
# If user is black, make engine move first
|
| 548 |
+
if self.user_color == chess.BLACK:
|
| 549 |
+
try:
|
| 550 |
+
engine_move_uci, engine_value = self.current_engine.move(self.board)
|
| 551 |
+
if engine_move_uci != "<claim_draw>":
|
| 552 |
+
engine_move = chess.Move.from_uci(engine_move_uci)
|
| 553 |
+
self.board.push(engine_move)
|
| 554 |
+
self.move_history.append(engine_move)
|
| 555 |
+
status += f" Engine opened with: {engine_move.uci()}"
|
| 556 |
+
except Exception as e:
|
| 557 |
+
status += f" Engine error on first move: {str(e)}"
|
| 558 |
+
|
| 559 |
+
return (
|
| 560 |
+
self.get_board_svg(),
|
| 561 |
+
self.get_move_history_text(),
|
| 562 |
+
status,
|
| 563 |
+
"",
|
| 564 |
+
self.create_evaluation_bar(self.stockfish_eval, "Stockfish Analysis (from your perspective)"),
|
| 565 |
+
self.create_evaluation_bar(self.current_engine_eval, "Engine Analysis (from your perspective)")
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
app = ChessApp(torch.device("cpu"))
|
| 570 |
+
|
| 571 |
+
def create_interface():
|
| 572 |
+
"""Create the Gradio interface with improved layout"""
|
| 573 |
+
|
| 574 |
+
with gr.Blocks(title="ChessFormer Demo", theme=gr.themes.Soft()) as interface:
|
| 575 |
+
gr.Markdown("# 🏆 ChessFormer Demo")
|
| 576 |
+
gr.Markdown("Play chess against ChessFormer models or Stockfish!")
|
| 577 |
+
|
| 578 |
+
with gr.Row():
|
| 579 |
+
# Left column - Analysis + History
|
| 580 |
+
with gr.Column(scale=1):
|
| 581 |
+
gr.Markdown("### 📊 Position Analysis")
|
| 582 |
+
|
| 583 |
+
# Stockfish Analysis
|
| 584 |
+
stockfish_eval_display = gr.HTML(
|
| 585 |
+
value=app.create_evaluation_bar(0.0, "Stockfish Analysis"),
|
| 586 |
+
label="Stockfish"
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# Current Engine Analysis
|
| 590 |
+
current_engine_eval_display = gr.HTML(
|
| 591 |
+
value=app.create_evaluation_bar(0.0, "Engine Analysis"),
|
| 592 |
+
label="Engine"
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# Move history
|
| 596 |
+
gr.Markdown("### 📝 Game History")
|
| 597 |
+
history_display = gr.Textbox(
|
| 598 |
+
value=app.get_move_history_text(),
|
| 599 |
+
label="PGN",
|
| 600 |
+
lines=12,
|
| 601 |
+
max_lines=15,
|
| 602 |
+
interactive=False
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Middle column - Game Board + Controls
|
| 606 |
+
with gr.Column(scale=4):
|
| 607 |
+
# Chess board display
|
| 608 |
+
board_display = gr.HTML(
|
| 609 |
+
value=app.get_board_svg(),
|
| 610 |
+
label="Chess Board"
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# Move input
|
| 614 |
+
with gr.Row():
|
| 615 |
+
move_input = gr.Textbox(
|
| 616 |
+
placeholder="Enter move (e.g., 'e2e4' or 'Ne5')",
|
| 617 |
+
label="Your Move",
|
| 618 |
+
scale=4
|
| 619 |
+
)
|
| 620 |
+
move_button = gr.Button("Make Move", variant="primary", scale=1)
|
| 621 |
+
|
| 622 |
+
# Game status
|
| 623 |
+
status_display = gr.Textbox(
|
| 624 |
+
value="Click 'New Game' to start playing!",
|
| 625 |
+
label="Game Status",
|
| 626 |
+
interactive=False,
|
| 627 |
+
lines=2
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
# Right column - Settings + Import/Export
|
| 631 |
+
with gr.Column(scale=2):
|
| 632 |
+
# Engine settings
|
| 633 |
+
gr.Markdown("### ⚙️ Game Settings")
|
| 634 |
+
|
| 635 |
+
engine_choices = ["Stockfish"] + list(app.models.keys())
|
| 636 |
+
engine_select = gr.Dropdown(
|
| 637 |
+
choices=engine_choices,
|
| 638 |
+
value="ChessFormer-SL" if engine_choices else None,
|
| 639 |
+
label="Opponent Engine"
|
| 640 |
+
)
|
| 641 |
+
|
| 642 |
+
depth_slider = gr.Slider(
|
| 643 |
+
minimum=0,
|
| 644 |
+
maximum=6,
|
| 645 |
+
value=0,
|
| 646 |
+
step=1,
|
| 647 |
+
label="Engine Depth"
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
color_select = gr.Radio(
|
| 651 |
+
choices=["White", "Black"],
|
| 652 |
+
value="White",
|
| 653 |
+
label="Your Color"
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
temperature_slider = gr.Slider(
|
| 657 |
+
minimum=0.1,
|
| 658 |
+
maximum=2.0,
|
| 659 |
+
value=0.5,
|
| 660 |
+
step=0.1,
|
| 661 |
+
label="Temperature (ChessFormer only)"
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
new_game_button = gr.Button("🔄 New Game", variant="secondary", size="lg")
|
| 665 |
+
|
| 666 |
+
# Import/Export section
|
| 667 |
+
gr.Markdown("### 📁 Import/Export")
|
| 668 |
+
|
| 669 |
+
with gr.Tabs():
|
| 670 |
+
with gr.Tab("Import FEN"):
|
| 671 |
+
fen_input = gr.Textbox(
|
| 672 |
+
placeholder="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
|
| 673 |
+
label="FEN String",
|
| 674 |
+
lines=2
|
| 675 |
+
)
|
| 676 |
+
import_fen_button = gr.Button("Import FEN")
|
| 677 |
+
|
| 678 |
+
with gr.Tab("Import PGN"):
|
| 679 |
+
pgn_input = gr.Textbox(
|
| 680 |
+
placeholder="1. e4 e5 2. Nf3 Nc6...",
|
| 681 |
+
label="PGN Text",
|
| 682 |
+
lines=3
|
| 683 |
+
)
|
| 684 |
+
import_pgn_button = gr.Button("Import PGN")
|
| 685 |
+
|
| 686 |
+
with gr.Tab("Export"):
|
| 687 |
+
export_button = gr.Button("📁 Download PGN")
|
| 688 |
+
export_output = gr.File(label="Download")
|
| 689 |
+
|
| 690 |
+
# Available models info
|
| 691 |
+
gr.Markdown("### 🤖 Available Models")
|
| 692 |
+
if app.models:
|
| 693 |
+
model_info = "**Loaded ChessFormer models:**\n" + "\n".join([f"• {name}" for name in app.models.keys()])
|
| 694 |
+
else:
|
| 695 |
+
model_info = "⚠️ No ChessFormer models found. Make sure model checkpoints are in the ./ckpts/ directory."
|
| 696 |
+
gr.Markdown(model_info)
|
| 697 |
+
|
| 698 |
+
# Function to update depth limits based on engine selection
|
| 699 |
+
def update_depth_limits(engine_type):
|
| 700 |
+
min_depth, max_depth, value = app.get_depth_limits(engine_type)
|
| 701 |
+
return gr.Slider(minimum=min_depth, maximum=max_depth, value=value, step=1)
|
| 702 |
+
|
| 703 |
+
# Function to export PGN
|
| 704 |
+
def export_pgn_file():
|
| 705 |
+
pgn_content = app.export_pgn()
|
| 706 |
+
filename = f"chess_game_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pgn"
|
| 707 |
+
with open(filename, 'w') as f:
|
| 708 |
+
f.write(pgn_content)
|
| 709 |
+
return filename
|
| 710 |
+
|
| 711 |
+
# Event handlers (same as before...)
|
| 712 |
+
engine_select.change(
|
| 713 |
+
fn=update_depth_limits,
|
| 714 |
+
inputs=[engine_select],
|
| 715 |
+
outputs=[depth_slider]
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
move_button.click(
|
| 719 |
+
fn=app.make_user_move,
|
| 720 |
+
inputs=[move_input],
|
| 721 |
+
outputs=[board_display, history_display, status_display, move_input,
|
| 722 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
move_input.submit(
|
| 726 |
+
fn=app.make_user_move,
|
| 727 |
+
inputs=[move_input],
|
| 728 |
+
outputs=[board_display, history_display, status_display, move_input,
|
| 729 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
new_game_button.click(
|
| 733 |
+
fn=app.new_game,
|
| 734 |
+
inputs=[engine_select, depth_slider, color_select, temperature_slider],
|
| 735 |
+
outputs=[board_display, history_display, status_display, move_input,
|
| 736 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
import_fen_button.click(
|
| 740 |
+
fn=app.import_fen,
|
| 741 |
+
inputs=[fen_input],
|
| 742 |
+
outputs=[board_display, history_display, status_display, fen_input,
|
| 743 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
import_pgn_button.click(
|
| 747 |
+
fn=app.import_pgn,
|
| 748 |
+
inputs=[pgn_input],
|
| 749 |
+
outputs=[board_display, history_display, status_display, pgn_input,
|
| 750 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
export_button.click(
|
| 754 |
+
fn=export_pgn_file,
|
| 755 |
+
outputs=[export_output]
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# Auto-start a new game when interface loads
|
| 759 |
+
interface.load(
|
| 760 |
+
fn=app.new_game,
|
| 761 |
+
inputs=[gr.State("Stockfish"), gr.State(6), gr.State("White"), gr.State(0.5)],
|
| 762 |
+
outputs=[board_display, history_display, status_display, move_input,
|
| 763 |
+
stockfish_eval_display, current_engine_eval_display]
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
return interface
|
| 767 |
+
|
| 768 |
+
if __name__ == "__main__":
|
| 769 |
+
# Create and launch interface
|
| 770 |
+
interface = create_interface()
|
| 771 |
+
interface.launch()
|
model.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 6 |
+
|
| 7 |
+
from utils import MAX_HALFMOVES, MAX_FULLMOVES, EMPTY_SQ_IDX, PIECE_TO_IDX, SQUARE_TO_IDX, IDX_TO_UCI_MOVE
|
| 8 |
+
|
| 9 |
+
# --- Tokenizer --- #
|
| 10 |
+
class FENTokenizer(nn.Module):
|
| 11 |
+
"""Convert FEN (and repetitions) to a sequence of tokens"""
|
| 12 |
+
def __init__(self, hidden_size,dtype):
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
self.side_embed = nn.Embedding(2,hidden_size,dtype=dtype) # black/white embedding
|
| 16 |
+
|
| 17 |
+
self.castling_embed_k = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
|
| 18 |
+
self.castling_embed_q = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
|
| 19 |
+
self.castling_embed_K = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
|
| 20 |
+
self.castling_embed_Q = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
|
| 21 |
+
self.no_castling_embed = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype))
|
| 22 |
+
|
| 23 |
+
self.piece_embed = nn.Embedding(13,hidden_size,dtype=dtype) # 6 for white pieces, 6 for black pieces, 1 for empty
|
| 24 |
+
|
| 25 |
+
self.no_en_passant_embed = nn.Parameter(torch.randn(1,1,hidden_size,dtype=dtype)) # use positional embed for the target square, or a special one for '-'
|
| 26 |
+
|
| 27 |
+
self.half_move_embed = nn.Embedding(MAX_HALFMOVES,hidden_size,dtype=dtype)
|
| 28 |
+
|
| 29 |
+
self.full_move_embed = nn.Embedding(MAX_FULLMOVES,hidden_size,dtype=dtype)
|
| 30 |
+
|
| 31 |
+
self.repetition_embed = nn.Embedding(3,hidden_size,dtype=dtype)
|
| 32 |
+
|
| 33 |
+
self.pos_embed = nn.Embedding(64,hidden_size,dtype=dtype) # positional embedding
|
| 34 |
+
|
| 35 |
+
def _parse_fen_string(self, fen_str: str) -> Dict:
|
| 36 |
+
parts = fen_str.split()
|
| 37 |
+
if len(parts) != 6:
|
| 38 |
+
raise ValueError(f"Invalid FEN string: {fen_str}. Expected 6 fields")
|
| 39 |
+
return {
|
| 40 |
+
"piece_placement": parts[0],
|
| 41 |
+
"side_to_move": parts[1],
|
| 42 |
+
"castling": parts[2],
|
| 43 |
+
"en_passant": parts[3],
|
| 44 |
+
"halfmove_clock": parts[4],
|
| 45 |
+
"fullmove_number": parts[5],
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
def forward(self, fen_list: List[str], repetitions: torch.Tensor) -> torch.Tensor:
|
| 49 |
+
"""
|
| 50 |
+
Args:
|
| 51 |
+
fen: List of fen strings
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
torch tensor of shape (n_fen,73,hidden_size) where 73 tokens consists of:
|
| 55 |
+
64 piece tokens (fen's first field) +
|
| 56 |
+
1 which-side-to-move token (fen's second field) +
|
| 57 |
+
4 casting rights tokens (fen's third field) +
|
| 58 |
+
1 en-passant target token (fen's fourth field) +
|
| 59 |
+
1 half move clock token (fen's fifth field) +
|
| 60 |
+
1 full move number token (fen's fifth field) +
|
| 61 |
+
1 repetition count token (repetitions input)
|
| 62 |
+
"""
|
| 63 |
+
batch_size = len(fen_list)
|
| 64 |
+
assert batch_size == repetitions.shape[0]
|
| 65 |
+
assert len(repetitions.size()) == 1
|
| 66 |
+
batch_tokens = []
|
| 67 |
+
device = self.side_embed.weight.device
|
| 68 |
+
|
| 69 |
+
# Precompute all square indices
|
| 70 |
+
square_indices = torch.arange(64, device=device)
|
| 71 |
+
all_pos_embeds = self.pos_embed(square_indices) # (64,D)
|
| 72 |
+
|
| 73 |
+
for fen_str in fen_list:
|
| 74 |
+
parsed_fen = self._parse_fen_string(fen_str)
|
| 75 |
+
tokens = []
|
| 76 |
+
|
| 77 |
+
# --- 1. Piece Placement (64 tokens) ---
|
| 78 |
+
piece_indices = torch.full((64,), EMPTY_SQ_IDX, dtype=torch.long, device=device)
|
| 79 |
+
current_rank = 7 # Start from rank 8
|
| 80 |
+
current_file = 0 # Start from file 'a'
|
| 81 |
+
for char in parsed_fen["piece_placement"]:
|
| 82 |
+
if char == '/':
|
| 83 |
+
current_rank -= 1
|
| 84 |
+
current_file = 0
|
| 85 |
+
elif char.isdigit():
|
| 86 |
+
current_file += int(char)
|
| 87 |
+
elif char in PIECE_TO_IDX:
|
| 88 |
+
sq_idx = current_rank * 8 + current_file
|
| 89 |
+
if 0 <= sq_idx < 64:
|
| 90 |
+
piece_indices[sq_idx] = PIECE_TO_IDX[char]
|
| 91 |
+
else:
|
| 92 |
+
raise ValueError(f"Invalid FEN piece placement: {parsed_fen['piece_placement']}")
|
| 93 |
+
current_file += 1
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError(f"Invalid character in FEN piece placement: {char}")
|
| 96 |
+
|
| 97 |
+
piece_embeds = self.piece_embed(piece_indices) # (64, D)
|
| 98 |
+
# Add positional embeddings
|
| 99 |
+
board_tokens = piece_embeds + all_pos_embeds # (64, D)
|
| 100 |
+
tokens.append(board_tokens)
|
| 101 |
+
|
| 102 |
+
# --- 2. Side to Move (1 token) ---
|
| 103 |
+
side_idx = 0 if parsed_fen["side_to_move"] == 'w' else 1
|
| 104 |
+
side_token = self.side_embed(torch.tensor(side_idx, device=device)).unsqueeze(0) # (1, D)
|
| 105 |
+
tokens.append(side_token)
|
| 106 |
+
|
| 107 |
+
# --- 3. Castling Rights (4 tokens) ---
|
| 108 |
+
castling_str = parsed_fen["castling"]
|
| 109 |
+
castling_tokens = torch.cat([
|
| 110 |
+
self.castling_embed_K if 'K' in castling_str else self.no_castling_embed.expand(1, 1, -1),
|
| 111 |
+
self.castling_embed_Q if 'Q' in castling_str else self.no_castling_embed.expand(1, 1, -1),
|
| 112 |
+
self.castling_embed_k if 'k' in castling_str else self.no_castling_embed.expand(1, 1, -1),
|
| 113 |
+
self.castling_embed_q if 'q' in castling_str else self.no_castling_embed.expand(1, 1, -1)
|
| 114 |
+
], dim=1).squeeze(0) # (4, D)
|
| 115 |
+
tokens.append(castling_tokens)
|
| 116 |
+
|
| 117 |
+
# --- 4. En Passant Target (1 token) ---
|
| 118 |
+
en_passant_str = parsed_fen["en_passant"]
|
| 119 |
+
if en_passant_str == '-':
|
| 120 |
+
en_passant_token = self.no_en_passant_embed.squeeze(0) # (1, D)
|
| 121 |
+
else:
|
| 122 |
+
if en_passant_str in SQUARE_TO_IDX:
|
| 123 |
+
sq_idx = SQUARE_TO_IDX[en_passant_str]
|
| 124 |
+
en_passant_token = self.pos_embed(torch.tensor(sq_idx, device=device)).unsqueeze(0) # (1, D)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f"Invalid en passant square: {en_passant_str}")
|
| 127 |
+
tokens.append(en_passant_token)
|
| 128 |
+
|
| 129 |
+
# --- 5. Half Move Clock (1 token) ---
|
| 130 |
+
try:
|
| 131 |
+
half_move_int = int(parsed_fen["halfmove_clock"])
|
| 132 |
+
except ValueError:
|
| 133 |
+
raise ValueError(f"Invalid halfmove clock value: {parsed_fen['halfmove_clock']}")
|
| 134 |
+
# Clamp value before embedding lookup
|
| 135 |
+
half_move_clamped = torch.clamp(torch.tensor(half_move_int, device=device), 0, MAX_HALFMOVES - 1)
|
| 136 |
+
half_move_token = self.half_move_embed(half_move_clamped).unsqueeze(0) # (1, D)
|
| 137 |
+
tokens.append(half_move_token)
|
| 138 |
+
|
| 139 |
+
# --- 6. Full Move Number (1 token) ---
|
| 140 |
+
try:
|
| 141 |
+
full_move_int = int(parsed_fen["fullmove_number"])
|
| 142 |
+
except ValueError:
|
| 143 |
+
raise ValueError(f"Invalid fullmove number value: {parsed_fen['fullmove_number']}")
|
| 144 |
+
# Clamp value (min 1 for full moves) before embedding lookup (adjusting for 0-based index)
|
| 145 |
+
full_move_clamped = torch.clamp(torch.tensor(full_move_int, device=device), 1, MAX_FULLMOVES) - 1
|
| 146 |
+
full_move_token = self.full_move_embed(full_move_clamped).unsqueeze(0) # (1, D)
|
| 147 |
+
tokens.append(full_move_token)
|
| 148 |
+
|
| 149 |
+
# Concatenate all tokens for this FEN string
|
| 150 |
+
# Shapes: (64, D), (1, D), (4, D), (1, D), (1, D), (1, D) -> Total 72 tokens
|
| 151 |
+
fen_embedding = torch.cat(tokens, dim=0) # (72, D)
|
| 152 |
+
batch_tokens.append(fen_embedding)
|
| 153 |
+
|
| 154 |
+
# Stack into a batch
|
| 155 |
+
batch_tokens = torch.stack(batch_tokens, dim=0) # (B,72,D)
|
| 156 |
+
|
| 157 |
+
# ---7. Repetition Count (1 token) ---
|
| 158 |
+
repetitions = repetitions - 1 # from 1~3 to 0~2
|
| 159 |
+
repetitions = torch.clamp(repetitions,0,2) # if repetition count >3 but no player claimed a draw, it will be treated as 3 repetitions
|
| 160 |
+
repetition_tokens = self.repetition_embed(repetitions) # (B,D)
|
| 161 |
+
repetition_tokens = repetition_tokens.unsqueeze(1) # (B,1,D)
|
| 162 |
+
|
| 163 |
+
return torch.cat([batch_tokens,repetition_tokens], dim=1) # (B, 73, D)
|
| 164 |
+
|
| 165 |
+
# --- Helper Modules --- #
|
| 166 |
+
class SwiGLUFFN(nn.Module):
|
| 167 |
+
def __init__(self,
|
| 168 |
+
d_model,
|
| 169 |
+
dim_feedforward,
|
| 170 |
+
dropout: float,
|
| 171 |
+
bias_up: bool=False,
|
| 172 |
+
bias_gate: bool=False,
|
| 173 |
+
bias_down: bool=True,
|
| 174 |
+
dtype=None):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.up_proj = nn.Linear(d_model,dim_feedforward,bias=bias_up,dtype=dtype)
|
| 177 |
+
self.gate_proj = nn.Linear(d_model,dim_feedforward,bias=bias_gate,dtype=dtype)
|
| 178 |
+
self.down_proj = nn.Linear(dim_feedforward,d_model,bias=bias_down,dtype=dtype)
|
| 179 |
+
|
| 180 |
+
self.dropout = nn.Dropout(dropout)
|
| 181 |
+
|
| 182 |
+
def forward(self, x):
|
| 183 |
+
x = self.up_proj(x) * self.dropout(nn.functional.silu(self.gate_proj(x)))
|
| 184 |
+
return self.down_proj(x)
|
| 185 |
+
|
| 186 |
+
class TransformerEncoderLayer(nn.Module):
|
| 187 |
+
"""Custom transformer encoder layer with RMSNorm and SwiGLUFFN"""
|
| 188 |
+
def __init__(self,
|
| 189 |
+
d_model: int,
|
| 190 |
+
nhead: int,
|
| 191 |
+
dim_feedforward: int,
|
| 192 |
+
dropout: float,
|
| 193 |
+
batch_first: bool=True,
|
| 194 |
+
norm_first: bool=False,
|
| 195 |
+
dtype=None):
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.norm_first = norm_first
|
| 198 |
+
|
| 199 |
+
self.norm1 = nn.RMSNorm(d_model,dtype=dtype)
|
| 200 |
+
self.dropout_sa = nn.Dropout(dropout)
|
| 201 |
+
self.self_attn = nn.MultiheadAttention(
|
| 202 |
+
d_model,
|
| 203 |
+
nhead,
|
| 204 |
+
dropout=dropout,
|
| 205 |
+
bias=False,
|
| 206 |
+
batch_first=batch_first,
|
| 207 |
+
dtype=dtype
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
self.norm2 = nn.RMSNorm(d_model,dtype=dtype)
|
| 211 |
+
self.dropout_ff = nn.Dropout(dropout)
|
| 212 |
+
self.mlp = SwiGLUFFN(
|
| 213 |
+
d_model,
|
| 214 |
+
dim_feedforward,
|
| 215 |
+
dropout=dropout,
|
| 216 |
+
bias_up=False,
|
| 217 |
+
bias_gate=False,
|
| 218 |
+
bias_down=True,
|
| 219 |
+
dtype=dtype
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def forward(self, x, return_attention=False):
|
| 223 |
+
if self.norm_first:
|
| 224 |
+
if return_attention:
|
| 225 |
+
x_norm = self.norm1(x)
|
| 226 |
+
attn_output, attn_weights = self._sa_block(x_norm,return_attention=True)
|
| 227 |
+
x = x + attn_output
|
| 228 |
+
x = x + self._ff_block(self.norm2(x))
|
| 229 |
+
return x, attn_weights
|
| 230 |
+
else:
|
| 231 |
+
x = x + self._sa_block(self.norm1(x))
|
| 232 |
+
x = x + self._ff_block(self.norm2(x))
|
| 233 |
+
return x
|
| 234 |
+
else:
|
| 235 |
+
if return_attention:
|
| 236 |
+
attn_output, attn_weights = self._sa_block(x, return_attention=True)
|
| 237 |
+
x = self.norm1(x + attn_output)
|
| 238 |
+
x = self.norm2(x + self._ff_block(x))
|
| 239 |
+
return x, attn_weights
|
| 240 |
+
else:
|
| 241 |
+
x = self.norm1(x + self._sa_block(x))
|
| 242 |
+
x = self.norm2(x + self._ff_block(x))
|
| 243 |
+
return x
|
| 244 |
+
|
| 245 |
+
def _sa_block(self, x, return_attention=False):
|
| 246 |
+
if return_attention:
|
| 247 |
+
attn_output, attn_weights = self.self_attn(x,x,x,need_weights=True,average_attn_weights=False)
|
| 248 |
+
return self.dropout_sa(attn_output), attn_weights
|
| 249 |
+
else:
|
| 250 |
+
x = self.self_attn(x,x,x)[0]
|
| 251 |
+
return self.dropout_sa(x)
|
| 252 |
+
|
| 253 |
+
def _ff_block(self,x):
|
| 254 |
+
x = self.mlp(x)
|
| 255 |
+
return self.dropout_ff(x)
|
| 256 |
+
nn.TransformerEncoderLayer
|
| 257 |
+
|
| 258 |
+
# --- Model Arch --- #
|
| 259 |
+
class ChessFormerModel(nn.Module, PyTorchModelHubMixin):
|
| 260 |
+
def __init__(self,
|
| 261 |
+
num_blocks,
|
| 262 |
+
hidden_size,
|
| 263 |
+
intermediate_size,
|
| 264 |
+
num_heads,
|
| 265 |
+
dropout: float=0.00,
|
| 266 |
+
possible_moves: int=len(IDX_TO_UCI_MOVE), # 1969 structurally valid moves
|
| 267 |
+
dtype=None):
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.fen_tokenizer = FENTokenizer(hidden_size,dtype=dtype)
|
| 270 |
+
|
| 271 |
+
self.act_token = nn.Parameter(torch.randn((1,1,hidden_size),dtype=dtype) * 0.02)
|
| 272 |
+
self.val_token = nn.Parameter(torch.randn((1,1,hidden_size),dtype=dtype) * 0.02)
|
| 273 |
+
|
| 274 |
+
self.act_proj = nn.Linear(hidden_size,possible_moves,dtype=dtype)
|
| 275 |
+
self.val_proj = nn.Linear(hidden_size,1,dtype=dtype)
|
| 276 |
+
|
| 277 |
+
self.blocks = nn.ModuleList(
|
| 278 |
+
TransformerEncoderLayer(
|
| 279 |
+
d_model=hidden_size,
|
| 280 |
+
nhead=num_heads,
|
| 281 |
+
dim_feedforward=intermediate_size,
|
| 282 |
+
dropout=dropout,
|
| 283 |
+
batch_first=True,
|
| 284 |
+
norm_first=True,
|
| 285 |
+
dtype=dtype
|
| 286 |
+
) for _ in range(num_blocks)
|
| 287 |
+
)
|
| 288 |
+
self.dtype=dtype
|
| 289 |
+
self.possible_moves = possible_moves
|
| 290 |
+
|
| 291 |
+
self.final_norm = nn.RMSNorm(hidden_size)
|
| 292 |
+
|
| 293 |
+
self._initialize_weights()
|
| 294 |
+
|
| 295 |
+
def _initialize_weights(self):
|
| 296 |
+
"""Initialize weights"""
|
| 297 |
+
for m in self.modules():
|
| 298 |
+
if isinstance(m,nn.Linear):
|
| 299 |
+
nn.init.kaiming_normal_(m.weight,mode='fan_in',nonlinearity='relu')
|
| 300 |
+
if m.bias is not None:
|
| 301 |
+
nn.init.constant_(m.bias, 0)
|
| 302 |
+
elif isinstance(m, nn.Embedding):
|
| 303 |
+
nn.init.normal_(m.weight, std=0.02)
|
| 304 |
+
elif isinstance(m, nn.LayerNorm):
|
| 305 |
+
if hasattr(m, 'weight'):
|
| 306 |
+
nn.init.constant_(m.weight, 1.0)
|
| 307 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
| 308 |
+
nn.init.constant_(m.weight, 0.0)
|
| 309 |
+
elif isinstance(m, nn.RMSNorm):
|
| 310 |
+
if hasattr(m, 'weight'):
|
| 311 |
+
nn.init.constant_(m.weight, 1.0)
|
| 312 |
+
|
| 313 |
+
tokenizer_params = dict(self.fen_tokenizer.named_parameters())
|
| 314 |
+
|
| 315 |
+
params_to_init = [
|
| 316 |
+
self.act_token, self.val_token,
|
| 317 |
+
tokenizer_params.get('castling_embed_k'), tokenizer_params.get('castling_embed_q'),
|
| 318 |
+
tokenizer_params.get('castling_embed_K'), tokenizer_params.get('castling_embed_Q'),
|
| 319 |
+
tokenizer_params.get('no_castling_embed'), tokenizer_params.get('no_en_passant_embed')
|
| 320 |
+
]
|
| 321 |
+
|
| 322 |
+
for param in params_to_init:
|
| 323 |
+
if param is not None and param.requires_grad:
|
| 324 |
+
nn.init.normal_(param, std=0.02)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def forward(self, fen: List[str], repetitions: torch.Tensor, return_attention: bool=False) -> torch.Tensor:
|
| 328 |
+
x = self.fen_tokenizer(fen,repetitions) # (B,73,D), pos embed are added here
|
| 329 |
+
bs = x.shape[0]
|
| 330 |
+
x = torch.cat([x,self.act_token.expand(bs,-1,-1),self.val_token.expand(bs,-1,-1)],dim=1) # (B,75,D)
|
| 331 |
+
|
| 332 |
+
attention_maps = [] if return_attention else None
|
| 333 |
+
|
| 334 |
+
for block in self.blocks:
|
| 335 |
+
if return_attention:
|
| 336 |
+
x, attn = block(x, return_attention=True)
|
| 337 |
+
attention_maps.append(attn)
|
| 338 |
+
else:
|
| 339 |
+
x = block(x)
|
| 340 |
+
|
| 341 |
+
x = self.final_norm(x)
|
| 342 |
+
|
| 343 |
+
act = x[:,-2,:]
|
| 344 |
+
val = x[:,-1,:]
|
| 345 |
+
act_logits = self.act_proj(act) # (B,1969)
|
| 346 |
+
val = self.val_proj(val) # (B,1)
|
| 347 |
+
|
| 348 |
+
if return_attention:
|
| 349 |
+
return act_logits, val.squeeze(1), attention_maps
|
| 350 |
+
else:
|
| 351 |
+
return act_logits, val.squeeze(1)
|
| 352 |
+
|
| 353 |
+
def load_model(ckpt_path):
|
| 354 |
+
checkpoint = torch.load(ckpt_path)
|
| 355 |
+
model_config = checkpoint["model_config"]
|
| 356 |
+
model = ChessFormerModel(**model_config)
|
| 357 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 358 |
+
return model
|
| 359 |
+
|
| 360 |
+
if __name__ == "__main__":
|
| 361 |
+
checkpoint = torch.load("./ckpts/chessformer-sl_01.pth",map_location=torch.device("cpu"))
|
| 362 |
+
model = ChessFormerModel(**checkpoint["config"])
|
| 363 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 364 |
+
|
| 365 |
+
model.push_to_hub("kaupane/ChessFormer-SL")
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
python-chess
|
| 4 |
+
chess
|
| 5 |
+
huggingface-hub
|
| 6 |
+
transformers
|
| 7 |
+
numpy
|
| 8 |
+
Pillow
|
| 9 |
+
datasets
|
| 10 |
+
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .buffer import ReplayBuffer, Game
|
| 2 |
+
from .chess_env import BatchChessEnv
|
| 3 |
+
from .engine import Engine, ChessformerConfig, StockfishConfig
|
| 4 |
+
from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE, MAX_HALFMOVES, MAX_FULLMOVES, EMPTY_SQ_IDX, PIECE_TO_IDX, SQUARE_TO_IDX
|
| 5 |
+
|
| 6 |
+
__all__ = ['ReplayBuffer',
|
| 7 |
+
'BatchChessEnv',
|
| 8 |
+
'Engine',
|
| 9 |
+
'Game',
|
| 10 |
+
'UCI_MOVE_TO_IDX',
|
| 11 |
+
'IDX_TO_UCI_MOVE',
|
| 12 |
+
'MAX_HALFMOVES',
|
| 13 |
+
'MAX_FULLMOVES',
|
| 14 |
+
'EMPTY_SQ_IDX',
|
| 15 |
+
'PIECE_TO_IDX',
|
| 16 |
+
'SQUARE_TO_IDX'
|
| 17 |
+
]
|
utils/buffer.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from collections import deque
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import List, Iterator, Tuple, Optional
|
| 6 |
+
import chess
|
| 7 |
+
|
| 8 |
+
class Game:
|
| 9 |
+
"""
|
| 10 |
+
Represents a single chess game trajectory with all relevant data for RL training.
|
| 11 |
+
Acts as a *temporary* buffer inside loop
|
| 12 |
+
Handles:
|
| 13 |
+
- Storing trajectory data (fens, reps, actions, log_probs, values, invalid_masks)
|
| 14 |
+
- Tracking game status (active/complete)
|
| 15 |
+
"""
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.active = True
|
| 18 |
+
self.valid = True
|
| 19 |
+
self.completion_reason = None
|
| 20 |
+
self.game_result = None
|
| 21 |
+
|
| 22 |
+
self.fens = []
|
| 23 |
+
self.repetition_counts = []
|
| 24 |
+
self.actions = []
|
| 25 |
+
self.values = []
|
| 26 |
+
self.log_probs = []
|
| 27 |
+
self.invalid_masks = []
|
| 28 |
+
|
| 29 |
+
def update_trajectory(self, fen, rep, act, val, logp, inv_m):
|
| 30 |
+
self.fens.append(fen)
|
| 31 |
+
self.repetition_counts.append(rep)
|
| 32 |
+
self.actions.append(act)
|
| 33 |
+
self.values.append(val)
|
| 34 |
+
self.log_probs.append(logp)
|
| 35 |
+
self.invalid_masks.append(inv_m)
|
| 36 |
+
|
| 37 |
+
def update_game_status(self, done, reason):
|
| 38 |
+
if done == True:
|
| 39 |
+
self.active = False
|
| 40 |
+
if reason in ["1-0","0-1","1/2-1/2"]:
|
| 41 |
+
self.completion_reason = reason
|
| 42 |
+
self.game_result = reason
|
| 43 |
+
else:
|
| 44 |
+
self.completion_reason = reason
|
| 45 |
+
self.game_result = None
|
| 46 |
+
self.valid = False
|
| 47 |
+
|
| 48 |
+
def get_white_trajectory(self):
|
| 49 |
+
"""Extract the trajectory for white"""
|
| 50 |
+
indices = []
|
| 51 |
+
for i in range(len(self.fens) - 1):
|
| 52 |
+
board = chess.Board(self.fens[i])
|
| 53 |
+
if board.turn: # True if white to move
|
| 54 |
+
indices.append(i)
|
| 55 |
+
|
| 56 |
+
return {
|
| 57 |
+
'fens': [self.fens[i] for i in indices],
|
| 58 |
+
'repetition_counts': [self.repetition_counts[i] for i in indices],
|
| 59 |
+
'actions': [self.actions[i] for i in indices],
|
| 60 |
+
'values': [self.values[i] for i in indices],
|
| 61 |
+
'log_probs': [self.log_probs[i] for i in indices],
|
| 62 |
+
'invalid_masks': [self.invalid_masks[i] for i in indices]
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
def get_black_trajectory(self):
|
| 66 |
+
"""Extract the trajectory for black pieces."""
|
| 67 |
+
indices = []
|
| 68 |
+
for i in range(len(self.fens) - 1):
|
| 69 |
+
board = chess.Board(self.fens[i])
|
| 70 |
+
if not board.turn: # False if black to move
|
| 71 |
+
indices.append(i)
|
| 72 |
+
|
| 73 |
+
return {
|
| 74 |
+
'fens': [self.fens[i] for i in indices],
|
| 75 |
+
'repetition_counts': [self.repetition_counts[i] for i in indices],
|
| 76 |
+
'actions': [self.actions[i] for i in indices],
|
| 77 |
+
'values': [self.values[i] for i in indices],
|
| 78 |
+
'log_probs': [self.log_probs[i] for i in indices],
|
| 79 |
+
'invalid_masks': [self.invalid_masks[i] for i in indices]
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ReplayBuffer:
|
| 87 |
+
"""
|
| 88 |
+
The buffer class for PPO reinforcement learning.
|
| 89 |
+
Handles:
|
| 90 |
+
- store samples including:
|
| 91 |
+
1. fens
|
| 92 |
+
2. reps
|
| 93 |
+
3. actions
|
| 94 |
+
4. log_probs
|
| 95 |
+
5. values
|
| 96 |
+
6. invalid_masks
|
| 97 |
+
- calculate advantage based on reward and value (7. advantage)
|
| 98 |
+
- output samples in batches
|
| 99 |
+
Since PPO is largely on-policy, so the replay buffer will not be so large that deque is not appropriate
|
| 100 |
+
"""
|
| 101 |
+
def __init__(self,
|
| 102 |
+
capacity: int,
|
| 103 |
+
batch_size: int,
|
| 104 |
+
gamma: float,
|
| 105 |
+
gae_lambda: float,
|
| 106 |
+
shuffle: bool=True
|
| 107 |
+
):
|
| 108 |
+
self.gamma = gamma
|
| 109 |
+
self.gae_lambda = gae_lambda
|
| 110 |
+
|
| 111 |
+
self.fens = deque(maxlen=capacity)
|
| 112 |
+
self.repetition_counts = deque(maxlen=capacity)
|
| 113 |
+
self.actions = deque(maxlen=capacity)
|
| 114 |
+
self.log_probs = deque(maxlen=capacity)
|
| 115 |
+
self.values = deque(maxlen=capacity)
|
| 116 |
+
self.invalid_masks = deque(maxlen=capacity)
|
| 117 |
+
self.advantages = deque(maxlen=capacity)
|
| 118 |
+
|
| 119 |
+
self.batch_size = batch_size
|
| 120 |
+
self.shuffle = shuffle
|
| 121 |
+
|
| 122 |
+
def push_game(self, game: Game):
|
| 123 |
+
"""
|
| 124 |
+
Process a completed game and add its trajectories to the buffer.
|
| 125 |
+
Handles reward computation for both white and black players.
|
| 126 |
+
"""
|
| 127 |
+
if not game.valid:
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
result = game.game_result
|
| 131 |
+
if result not in ["1-0","0-1","1/2-1/2"]:
|
| 132 |
+
raise ValueError(f"Result not recognized: {result}. Either an incompleted game was passed in, or game.update_game_status() method is wrong.")
|
| 133 |
+
|
| 134 |
+
if result == "1-0": result = 1
|
| 135 |
+
elif result == "0-1": result = -1
|
| 136 |
+
elif result == "1/2-1/2": result = 0
|
| 137 |
+
|
| 138 |
+
white_traj = game.get_white_trajectory()
|
| 139 |
+
if white_traj["fens"]:
|
| 140 |
+
self._process_trajectory(
|
| 141 |
+
white_traj["fens"],
|
| 142 |
+
white_traj["repetition_counts"],
|
| 143 |
+
white_traj["actions"],
|
| 144 |
+
white_traj["log_probs"],
|
| 145 |
+
white_traj["values"],
|
| 146 |
+
white_traj["invalid_masks"],
|
| 147 |
+
result
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
black_traj = game.get_black_trajectory()
|
| 151 |
+
if black_traj["fens"]:
|
| 152 |
+
self._process_trajectory(
|
| 153 |
+
black_traj["fens"],
|
| 154 |
+
black_traj["repetition_counts"],
|
| 155 |
+
black_traj["actions"],
|
| 156 |
+
black_traj["log_probs"],
|
| 157 |
+
black_traj["values"],
|
| 158 |
+
black_traj["invalid_masks"],
|
| 159 |
+
-result # flip reward for black's perspective
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
def _process_trajectory(self, fens, reps, actions, log_probs, values, invalid_masks, final_reward):
|
| 163 |
+
"""Process a trajectory for one player, compute advantages and add to buffer"""
|
| 164 |
+
values_tensor = torch.tensor(values) if not torch.is_tensor(values) else values
|
| 165 |
+
|
| 166 |
+
advantages = self._compute_advantage(values_tensor, final_reward)
|
| 167 |
+
|
| 168 |
+
for i in range(len(fens)):
|
| 169 |
+
self.fens.append(fens[i])
|
| 170 |
+
self.repetition_counts.append(reps[i])
|
| 171 |
+
self.actions.append(actions[i])
|
| 172 |
+
self.log_probs.append(log_probs[i])
|
| 173 |
+
self.values.append(values[i])
|
| 174 |
+
self.invalid_masks.append(invalid_masks[i])
|
| 175 |
+
self.advantages.append(advantages[i])
|
| 176 |
+
|
| 177 |
+
def _compute_advantage(self, value_traj: torch.Tensor, final_reward: float) -> torch.Tensor:
|
| 178 |
+
"""
|
| 179 |
+
Calculate GAE with only a terminal reward: r_t = 0 for t < T-1 and r_{T-1} = final_reward
|
| 180 |
+
Args:
|
| 181 |
+
value_traj: value prediction of the model
|
| 182 |
+
final_reward: game result
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
advantage, torch.Tensor of shape same with value_traj
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
vals = value_traj.detach().cpu().float()
|
| 189 |
+
T = vals.shape[0] if vals.dim() > 0 else 1
|
| 190 |
+
|
| 191 |
+
adv = torch.zeros(T)
|
| 192 |
+
next_value = 0.0
|
| 193 |
+
gae = 0.0
|
| 194 |
+
|
| 195 |
+
for t in reversed(range(T)):
|
| 196 |
+
reward = final_reward if t == T-1 else 0.0
|
| 197 |
+
delta = reward + self.gamma * next_value - vals[t]
|
| 198 |
+
gae = delta + self.gamma * self.gae_lambda * gae
|
| 199 |
+
adv[t] = gae
|
| 200 |
+
next_value = vals[t]
|
| 201 |
+
|
| 202 |
+
return adv
|
| 203 |
+
|
| 204 |
+
def sample(self) -> Iterator[Tuple[List[str], # fen
|
| 205 |
+
torch.Tensor,# rep
|
| 206 |
+
torch.Tensor,# act
|
| 207 |
+
torch.Tensor,# logp
|
| 208 |
+
torch.Tensor,# val
|
| 209 |
+
torch.Tensor,# inv_m
|
| 210 |
+
torch.Tensor]]: # adv
|
| 211 |
+
"""Yield minibatches of size batch_size for training"""
|
| 212 |
+
n = len(self.fens)
|
| 213 |
+
if n < self.batch_size:
|
| 214 |
+
return
|
| 215 |
+
|
| 216 |
+
idxs = np.arange(n)
|
| 217 |
+
if self.shuffle:
|
| 218 |
+
np.random.shuffle(idxs)
|
| 219 |
+
|
| 220 |
+
for start in range(0, n, self.batch_size):
|
| 221 |
+
batch_idx = idxs[start:start+self.batch_size]
|
| 222 |
+
if len(batch_idx) < self.batch_size:
|
| 223 |
+
break
|
| 224 |
+
|
| 225 |
+
fens_b = [self.fens[i] for i in batch_idx]
|
| 226 |
+
|
| 227 |
+
reps_b = torch.stack([
|
| 228 |
+
self.repetition_counts[i].detach().clone() if torch.is_tensor(self.repetition_counts[i])
|
| 229 |
+
else torch.tensor(self.repetition_counts[i])
|
| 230 |
+
for i in batch_idx
|
| 231 |
+
])
|
| 232 |
+
|
| 233 |
+
acts_b = torch.stack([
|
| 234 |
+
self.actions[i].detach().clone() if torch.is_tensor(self.actions[i])
|
| 235 |
+
else torch.tensor(self.actions[i])
|
| 236 |
+
for i in batch_idx
|
| 237 |
+
])
|
| 238 |
+
logps_b = torch.stack([
|
| 239 |
+
self.log_probs[i].detach().clone() if torch.is_tensor(self.log_probs[i])
|
| 240 |
+
else torch.tensor(self.log_probs[i])
|
| 241 |
+
for i in batch_idx
|
| 242 |
+
])
|
| 243 |
+
|
| 244 |
+
vals_b = torch.stack([
|
| 245 |
+
self.values[i].detach().clone() if torch.is_tensor(self.values[i])
|
| 246 |
+
else torch.tensor(self.values[i])
|
| 247 |
+
for i in batch_idx
|
| 248 |
+
])
|
| 249 |
+
|
| 250 |
+
advs_b = torch.stack([
|
| 251 |
+
self.advantages[i].detach().clone() if torch.is_tensor(self.advantages[i])
|
| 252 |
+
else torch.tensor(self.advantages[i])
|
| 253 |
+
for i in batch_idx
|
| 254 |
+
])
|
| 255 |
+
|
| 256 |
+
invs_b = torch.stack([
|
| 257 |
+
self.invalid_masks[i] if torch.is_tensor(self.invalid_masks[i])
|
| 258 |
+
else torch.tensor(self.invalid_masks[i])
|
| 259 |
+
for i in batch_idx
|
| 260 |
+
])
|
| 261 |
+
|
| 262 |
+
yield fens_b, reps_b, acts_b, logps_b, vals_b, invs_b, advs_b
|
| 263 |
+
|
| 264 |
+
def __len__(self) -> int:
|
| 265 |
+
return len(self.fens)
|
| 266 |
+
|
| 267 |
+
def clear(self) -> None:
|
| 268 |
+
self.fens.clear()
|
| 269 |
+
self.repetition_counts.clear()
|
| 270 |
+
self.actions.clear()
|
| 271 |
+
self.log_probs.clear()
|
| 272 |
+
self.values.clear()
|
| 273 |
+
self.invalid_masks.clear()
|
| 274 |
+
self.advantages.clear()
|
utils/chess_env.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Provide a gym-like environment for clarity"""
|
| 2 |
+
|
| 3 |
+
import chess
|
| 4 |
+
import torch
|
| 5 |
+
import time
|
| 6 |
+
from typing import List, Tuple, Dict
|
| 7 |
+
try:
|
| 8 |
+
from .mapping import IDX_TO_UCI_MOVE, UCI_MOVE_TO_IDX
|
| 9 |
+
except:
|
| 10 |
+
from mapping import IDX_TO_UCI_MOVE, UCI_MOVE_TO_IDX
|
| 11 |
+
|
| 12 |
+
class BatchChessEnv:
|
| 13 |
+
"""A single chess environment with sparse terminal reward"""
|
| 14 |
+
def __init__(self, batch_size: int, max_moves: int=200):
|
| 15 |
+
self.batch_size = batch_size
|
| 16 |
+
self.max_moves = max_moves
|
| 17 |
+
self.reset()
|
| 18 |
+
|
| 19 |
+
def reset(self) -> Tuple[List[str], torch.Tensor]:
|
| 20 |
+
"""
|
| 21 |
+
Starts all games from the initial position
|
| 22 |
+
Returns:
|
| 23 |
+
fens (List[str]), repetition_counts (torch.Tensor of shape [batch_size,])
|
| 24 |
+
"""
|
| 25 |
+
self.boards = [chess.Board() for _ in range(self.batch_size)]
|
| 26 |
+
self.move_counts = [0] * self.batch_size
|
| 27 |
+
self.done_flags = [False] * self.batch_size
|
| 28 |
+
|
| 29 |
+
fens = [self.boards[0].fen()] * self.batch_size
|
| 30 |
+
reps = torch.ones(self.batch_size,dtype=torch.long)
|
| 31 |
+
return fens, reps # (bs,)
|
| 32 |
+
|
| 33 |
+
def _compute_rep(self, board: chess.Board) -> int:
|
| 34 |
+
board_copy = board.copy()
|
| 35 |
+
trasposition_key = board_copy._transposition_key()
|
| 36 |
+
count = 0
|
| 37 |
+
while board_copy.move_stack:
|
| 38 |
+
board_copy.pop()
|
| 39 |
+
if board_copy._transposition_key() == trasposition_key:
|
| 40 |
+
count += 1
|
| 41 |
+
return count + 1 # 1 for fresh position
|
| 42 |
+
|
| 43 |
+
def step(self, uci_moves: List[str]) -> Tuple[List[str], # next fens (next state)
|
| 44 |
+
torch.Tensor, # next reps (next state)
|
| 45 |
+
List[bool], # dones
|
| 46 |
+
List[Dict]]: # infos
|
| 47 |
+
"""
|
| 48 |
+
Apply one move per game in the batch.
|
| 49 |
+
Args:
|
| 50 |
+
uci_moves: list of UCI strings (plus "<claim_draw>")
|
| 51 |
+
Returns:
|
| 52 |
+
next_fens: new FENs for each game, List[str]
|
| 53 |
+
reps: repetition counts, Tensor[batch_size]
|
| 54 |
+
dones: whether this game is now terminated, List[bool]
|
| 55 |
+
infos: info dict with 'result' key for completed games List[dict]
|
| 56 |
+
"""
|
| 57 |
+
next_fens, reps, dones, infos = [], [], [], []
|
| 58 |
+
|
| 59 |
+
for i, move in enumerate(uci_moves):
|
| 60 |
+
board = self.boards[i]
|
| 61 |
+
info = {
|
| 62 |
+
"max_steps_exceeded": False,
|
| 63 |
+
"truncation_due_to_error": False,
|
| 64 |
+
"result": None
|
| 65 |
+
}
|
| 66 |
+
done = self.done_flags[i]
|
| 67 |
+
|
| 68 |
+
if done:
|
| 69 |
+
# Game already done, pass through the existing state
|
| 70 |
+
next_fens.append(board.fen())
|
| 71 |
+
reps.append(1)
|
| 72 |
+
dones.append(True)
|
| 73 |
+
infos.append(info)
|
| 74 |
+
continue
|
| 75 |
+
|
| 76 |
+
if move == "0000":
|
| 77 |
+
# Skip through dummy moves
|
| 78 |
+
next_fens.append(board.fen())
|
| 79 |
+
reps.append(1)
|
| 80 |
+
dones.append(True)
|
| 81 |
+
infos.append(info)
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
if board.is_game_over():
|
| 85 |
+
# Game already over
|
| 86 |
+
done = True
|
| 87 |
+
info["result"] = board.result()
|
| 88 |
+
next_fens.append(board.fen())
|
| 89 |
+
reps.append(self._compute_rep(board))
|
| 90 |
+
dones.append(done)
|
| 91 |
+
infos.append(info)
|
| 92 |
+
continue
|
| 93 |
+
|
| 94 |
+
try:
|
| 95 |
+
if move == "<claim_draw>":
|
| 96 |
+
if board.can_claim_draw():
|
| 97 |
+
done = True
|
| 98 |
+
info['result'] = "1/2-1/2"
|
| 99 |
+
else:
|
| 100 |
+
raise ValueError(f"Invalid move ('<claim_draw>') passed in.")
|
| 101 |
+
else:
|
| 102 |
+
try:
|
| 103 |
+
m = chess.Move.from_uci(move)
|
| 104 |
+
if m in board.legal_moves:
|
| 105 |
+
board.push(m)
|
| 106 |
+
self.move_counts[i] += 1
|
| 107 |
+
|
| 108 |
+
if board.is_game_over():
|
| 109 |
+
done = True
|
| 110 |
+
info['result'] = board.result()
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Invalid move ('{m.uci()}') passed in.")
|
| 113 |
+
except Exception as e:
|
| 114 |
+
done = True
|
| 115 |
+
info['truncation_due_to_error'] = True
|
| 116 |
+
print(f"Unexpected error: {e}")
|
| 117 |
+
|
| 118 |
+
if self.move_counts[i] >= self.max_moves:
|
| 119 |
+
done = True
|
| 120 |
+
info['max_steps_exceeded'] = True
|
| 121 |
+
info['result'] = "1/2-1/2"
|
| 122 |
+
|
| 123 |
+
next_fens.append(board.fen())
|
| 124 |
+
reps.append(self._compute_rep(board))
|
| 125 |
+
dones.append(done)
|
| 126 |
+
infos.append(info)
|
| 127 |
+
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"Error processing move {move} for board {i}: {e}")
|
| 130 |
+
done = True
|
| 131 |
+
info["truncation_due_to_error"] = True
|
| 132 |
+
next_fens.append(board.fen())
|
| 133 |
+
reps.append(self._compute_rep(board))
|
| 134 |
+
dones.append(done)
|
| 135 |
+
infos.append(info)
|
| 136 |
+
|
| 137 |
+
self.done_flags[i] = done
|
| 138 |
+
|
| 139 |
+
reps = torch.tensor(reps,dtype=torch.long) # [bs,]
|
| 140 |
+
return next_fens, reps, dones, infos
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
env = BatchChessEnv(1)
|
| 144 |
+
env.reset()
|
| 145 |
+
board = env.boards[0]
|
| 146 |
+
board.push(chess.Move.from_uci("e2e4"))
|
| 147 |
+
new_board = board.copy()
|
| 148 |
+
rep = env._compute_rep(new_board)
|
| 149 |
+
print(rep)
|
| 150 |
+
|
| 151 |
+
|
utils/engine.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""An engine class to provide a universal way to interact with both chessformer and stockfish"""
|
| 2 |
+
import torch
|
| 3 |
+
import chess
|
| 4 |
+
import math
|
| 5 |
+
import chess.engine
|
| 6 |
+
import multiprocessing
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from functools import partial
|
| 9 |
+
import time
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
|
| 14 |
+
except ImportError:
|
| 15 |
+
from mapping import UCI_MOVE_TO_IDX, IDX_TO_UCI_MOVE
|
| 16 |
+
from torch.distributions import Categorical
|
| 17 |
+
from typing import Optional, Tuple, List, Union
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class ChessformerConfig:
|
| 21 |
+
chessformer: torch.nn.Module=None
|
| 22 |
+
device: Optional[torch.device]=None
|
| 23 |
+
temperature: float=0.5
|
| 24 |
+
depth: int=2
|
| 25 |
+
top_k: int=8
|
| 26 |
+
decay_rate: float=0.6
|
| 27 |
+
max_batch_size: int=896
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class StockfishConfig:
|
| 31 |
+
engine_path: str="/usr/games/stockfish"
|
| 32 |
+
depth: int=16
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _stockfish_worker(board_fen: str, engine_path: str, depth: int) -> Optional[Tuple[str, float]]:
|
| 36 |
+
"""
|
| 37 |
+
Analyzes a single board FEN using a temporary Stockfish engine instance.
|
| 38 |
+
Designed for use with multiprocessing.
|
| 39 |
+
Returns the best move UCI and the normalized score [-1,1].
|
| 40 |
+
Does not handle draw claims explicitly as FEN lacks history.
|
| 41 |
+
Caller should check board.is_game_over() on the main board object.
|
| 42 |
+
"""
|
| 43 |
+
engine = None
|
| 44 |
+
try:
|
| 45 |
+
engine = chess.engine.SimpleEngine.popen_uci(engine_path)
|
| 46 |
+
# initialize board from FEN - history is lost here
|
| 47 |
+
board = chess.Board(board_fen)
|
| 48 |
+
|
| 49 |
+
info = engine.analyse(board, chess.engine.Limit(depth=depth))
|
| 50 |
+
|
| 51 |
+
score_obj = info.get("score")
|
| 52 |
+
pv = info.get("pv")
|
| 53 |
+
|
| 54 |
+
if score_obj is None or pv is None or not pv:
|
| 55 |
+
# Analysis failed
|
| 56 |
+
print(f"Warning: Stockfish analysis failed for FEN: {board_fen}")
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
best_move_uci = pv[0].uci()
|
| 60 |
+
pov_score = score_obj.pov(board.turn)
|
| 61 |
+
cp = None
|
| 62 |
+
|
| 63 |
+
if pov_score.is_mate():
|
| 64 |
+
mate_score = pov_score.mate()
|
| 65 |
+
cp = 10000.0 if mate_score > 0 else -10000.0
|
| 66 |
+
elif pov_score.cp is not None:
|
| 67 |
+
cp = float(pov_score.cp)
|
| 68 |
+
else:
|
| 69 |
+
print(f"Warning: Stockfish score object lacks cp/mate for FEN: {board_fen}")
|
| 70 |
+
return None # analysis is unclear
|
| 71 |
+
|
| 72 |
+
normalized_cp = 2 / (1 + math.exp(-0.004*cp)) - 1
|
| 73 |
+
|
| 74 |
+
return best_move_uci, normalized_cp
|
| 75 |
+
|
| 76 |
+
except (chess.engine.EngineError, chess.engine.EngineTerminatedError, FileNotFoundError, ValueError) as e:
|
| 77 |
+
print(f"Stockfish worker error for FEN {board_fen}: {e}")
|
| 78 |
+
return None
|
| 79 |
+
finally:
|
| 80 |
+
if engine:
|
| 81 |
+
engine.quit()
|
| 82 |
+
|
| 83 |
+
def _compute_repetition_single(board: chess.Board) -> int:
|
| 84 |
+
"""Compute repetition count. Used in _chessformer_move and _batch_chessformer_move"""
|
| 85 |
+
|
| 86 |
+
transposition_key = board._transposition_key()
|
| 87 |
+
count = 0
|
| 88 |
+
if board.move_stack:
|
| 89 |
+
if board._transposition_key() == transposition_key:
|
| 90 |
+
count = 1
|
| 91 |
+
else:
|
| 92 |
+
count = 1
|
| 93 |
+
try:
|
| 94 |
+
# Iterate back through history
|
| 95 |
+
while board.move_stack:
|
| 96 |
+
move = board.pop() # note that history is lost here
|
| 97 |
+
if board.is_irreversible(move):
|
| 98 |
+
break
|
| 99 |
+
if board._transposition_key() == transposition_key:
|
| 100 |
+
count += 1
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Error occurred during repetition count for board {board.fen()}: {e}")
|
| 103 |
+
return 1 # fallback to 1
|
| 104 |
+
return max(1, count)
|
| 105 |
+
|
| 106 |
+
# Engine class, designed to be used in the Evaluator class and app.py
|
| 107 |
+
class Engine:
|
| 108 |
+
def __init__(self,
|
| 109 |
+
type: str,
|
| 110 |
+
chessformer_config: Optional[ChessformerConfig]=None,
|
| 111 |
+
stockfish_config: Optional[StockfishConfig]=None):
|
| 112 |
+
self.type = type
|
| 113 |
+
if type == "chessformer":
|
| 114 |
+
if chessformer_config is None:
|
| 115 |
+
raise ValueError("ChessformerConfig must be provided for chessformer engine.")
|
| 116 |
+
|
| 117 |
+
self.config = chessformer_config
|
| 118 |
+
if self.config.chessformer is None:
|
| 119 |
+
raise ValueError("ChessFormer model must be provided in config.")
|
| 120 |
+
|
| 121 |
+
if self.config.device is None:
|
| 122 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 123 |
+
elif isinstance(self.config.device, str):
|
| 124 |
+
self.device = torch.device(self.config.device)
|
| 125 |
+
else:
|
| 126 |
+
self.device = self.config.device
|
| 127 |
+
|
| 128 |
+
self.model = self.config.chessformer
|
| 129 |
+
self.model.to(self.device)
|
| 130 |
+
self.model.eval()
|
| 131 |
+
|
| 132 |
+
if not (self.config.temperature > 0):
|
| 133 |
+
raise ValueError("Temperature must be greater than 0.")
|
| 134 |
+
if not (self.config.top_k > 0):
|
| 135 |
+
raise ValueError("Top-k must be greater than 0.")
|
| 136 |
+
if not (self.config.depth >= 0):
|
| 137 |
+
raise ValueError("Depth must be greater than or equal to 0.")
|
| 138 |
+
if not (0.0 < self.config.decay_rate <= 1.0):
|
| 139 |
+
raise ValueError("Decay rate must be in range (0.0,1.0].")
|
| 140 |
+
if not (self.config.max_batch_size > 0):
|
| 141 |
+
raise ValueError("Max batch size must be an integer greater than 0.")
|
| 142 |
+
|
| 143 |
+
self.temperature = self.config.temperature
|
| 144 |
+
self.top_k = self.config.top_k
|
| 145 |
+
self.initial_k = self.top_k
|
| 146 |
+
self.depth = self.config.depth
|
| 147 |
+
self.decay_rate = self.config.decay_rate
|
| 148 |
+
self.max_batch_size = self.config.max_batch_size
|
| 149 |
+
elif type == "stockfish":
|
| 150 |
+
if stockfish_config is None:
|
| 151 |
+
raise ValueError("StockfishConfig must be provided for stockfish engine.")
|
| 152 |
+
|
| 153 |
+
self.config = stockfish_config
|
| 154 |
+
self.engine_path = self.config.engine_path
|
| 155 |
+
self.depth = self.config.depth
|
| 156 |
+
if self.config.engine_path is None:
|
| 157 |
+
raise ValueError("Engine path must be provided in config.")
|
| 158 |
+
try:
|
| 159 |
+
with chess.engine.SimpleEngine.popen_uci(self.config.engine_path) as test:
|
| 160 |
+
pass
|
| 161 |
+
except (FileNotFoundError, chess.engine.EngineError) as e:
|
| 162 |
+
raise ValueError(f"Invalid engine path or engine not found: {e}")
|
| 163 |
+
else:
|
| 164 |
+
raise ValueError("Invalid engine type. Choose 'chessformer' or 'stockfish'.")
|
| 165 |
+
|
| 166 |
+
def get_invalid_mask(self, boards: List[chess.Board]) -> torch.Tensor:
|
| 167 |
+
bs = len(boards)
|
| 168 |
+
possible_moves = len(UCI_MOVE_TO_IDX)
|
| 169 |
+
invalid_mask = torch.full((bs,possible_moves), -torch.inf, dtype=torch.float32, device=self.device)
|
| 170 |
+
for idx,board in enumerate(boards):
|
| 171 |
+
if board.is_game_over(claim_draw=True):
|
| 172 |
+
continue # leave all as -inf
|
| 173 |
+
legal_moves = list(board.legal_moves)
|
| 174 |
+
legal_move_ids = [UCI_MOVE_TO_IDX[move.uci()] for move in legal_moves]
|
| 175 |
+
if legal_move_ids:
|
| 176 |
+
invalid_mask[idx,legal_move_ids] = 0
|
| 177 |
+
if board.can_claim_draw():
|
| 178 |
+
invalid_mask[idx,0] = 0
|
| 179 |
+
|
| 180 |
+
return invalid_mask
|
| 181 |
+
|
| 182 |
+
def compute_repetition(self, boards: List[chess.Board]) -> torch.Tensor:
|
| 183 |
+
"""Use multiprocessing to compute repetition count for a batch of boards."""
|
| 184 |
+
bs = len(boards)
|
| 185 |
+
num_workers = min(bs, max(1, os.cpu_count()//2 if os.cpu_count else 1))
|
| 186 |
+
if bs < num_workers * 2: # avoid overhead for very small batches per worker
|
| 187 |
+
num_workers = max(1, bs//2)
|
| 188 |
+
|
| 189 |
+
try:
|
| 190 |
+
if num_workers > 1 and bs > 1:
|
| 191 |
+
board_copies = [board.copy(stack=True) for board in boards]
|
| 192 |
+
with multiprocessing.Pool(processes=num_workers) as pool:
|
| 193 |
+
counts = pool.map(_compute_repetition_single, board_copies)
|
| 194 |
+
else:
|
| 195 |
+
# Run sequentially if only one worker needed or very small batch
|
| 196 |
+
counts = [_compute_repetition_single(b.copy(stack=True)) for b in boards]
|
| 197 |
+
|
| 198 |
+
counts_tensor = torch.tensor(counts, dtype=torch.long, device=self.device)
|
| 199 |
+
return counts_tensor # (B,)
|
| 200 |
+
except Exception as e:
|
| 201 |
+
print(f"Error during batch repetition computation: {e}")
|
| 202 |
+
# Fall back to single board computation if multiprocessing fails
|
| 203 |
+
return torch.ones((bs,),dtype=torch.long, device=self.device)
|
| 204 |
+
|
| 205 |
+
def _raw_chessformer_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
|
| 206 |
+
"""Get the next move from ChessFormer model with optional tactical verification."""
|
| 207 |
+
# Get FEN
|
| 208 |
+
fen = board.fen()
|
| 209 |
+
|
| 210 |
+
# Compute repetition
|
| 211 |
+
count_tensor = self.compute_repetition([board])
|
| 212 |
+
|
| 213 |
+
move_logits, value = self.model([fen],count_tensor)
|
| 214 |
+
move_logits = move_logits.squeeze(0) # remove batch dimension since it will always be 1
|
| 215 |
+
value = value.squeeze(0).item()
|
| 216 |
+
|
| 217 |
+
# Calculate invalid mask
|
| 218 |
+
legal_moves = list(board.legal_moves)
|
| 219 |
+
if not legal_moves and not board.can_claim_draw():
|
| 220 |
+
# No legal moves. Should not happen if this function is called correctly, but it wouldn't hurt to add a check
|
| 221 |
+
return None
|
| 222 |
+
legal_move_ids = [UCI_MOVE_TO_IDX[move.uci()] for move in legal_moves]
|
| 223 |
+
invalid_mask = torch.ones_like(move_logits) * (-torch.inf)
|
| 224 |
+
invalid_mask[legal_move_ids] = 0
|
| 225 |
+
if board.can_claim_draw():
|
| 226 |
+
invalid_mask[0] = 0
|
| 227 |
+
move_logits = move_logits + invalid_mask
|
| 228 |
+
|
| 229 |
+
if return_perplexity:
|
| 230 |
+
probs = torch.softmax(move_logits, dim=-1)
|
| 231 |
+
perplexity = torch.exp(-torch.sum(probs*torch.log(probs+1e-8))).item()
|
| 232 |
+
|
| 233 |
+
top_k_ids = torch.topk(move_logits, self.top_k, dim=-1).indices
|
| 234 |
+
top_k_mask = torch.ones_like(move_logits) * (-torch.inf)
|
| 235 |
+
top_k_mask[top_k_ids] = 0
|
| 236 |
+
move_logits = move_logits + top_k_mask
|
| 237 |
+
move_logits = move_logits / self.temperature
|
| 238 |
+
|
| 239 |
+
# Sample
|
| 240 |
+
dist = Categorical(logits=move_logits)
|
| 241 |
+
move_id = dist.sample().item()
|
| 242 |
+
move = IDX_TO_UCI_MOVE[move_id]
|
| 243 |
+
if return_perplexity:
|
| 244 |
+
return move, value, perplexity
|
| 245 |
+
else:
|
| 246 |
+
return move, value
|
| 247 |
+
|
| 248 |
+
def _search_enhanced_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
|
| 249 |
+
"""Get move from chessformer using tactical search"""
|
| 250 |
+
# Step 1: Build search tree level by level
|
| 251 |
+
current_boards = [board] # aggregate board to a list for batch inference
|
| 252 |
+
board_probs = [1] # the probabilities of getting to this position (estimated)
|
| 253 |
+
|
| 254 |
+
terminal_leaves = [] # (root_move, prob, game_result_value) ^from white's perspective
|
| 255 |
+
search_leaves = [] # (root_move, prob, board) - leaves not terminal but reached max depth therefore needs evaluation from model
|
| 256 |
+
|
| 257 |
+
# Track which root_move each board came from
|
| 258 |
+
board_to_root_move = [None] # root board has no parent move
|
| 259 |
+
|
| 260 |
+
for depth in range(self.depth+1):
|
| 261 |
+
if not current_boards:
|
| 262 |
+
break
|
| 263 |
+
k = max(1,int(self.initial_k*(self.decay_rate**depth)))
|
| 264 |
+
|
| 265 |
+
fens = [b.fen() for b in current_boards]
|
| 266 |
+
reps = self.compute_repetition(current_boards)
|
| 267 |
+
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
logits, values = self.model(fens,reps)
|
| 270 |
+
|
| 271 |
+
next_boards = []
|
| 272 |
+
next_board_probs = []
|
| 273 |
+
next_board_to_root_move = []
|
| 274 |
+
|
| 275 |
+
# Process each board at current depth
|
| 276 |
+
for board_idx, current_board in enumerate(current_boards):
|
| 277 |
+
board_logits = logits[board_idx]
|
| 278 |
+
board_prob = board_probs[board_idx]
|
| 279 |
+
parent_root_move = board_to_root_move[board_idx]
|
| 280 |
+
|
| 281 |
+
# Check if game is over
|
| 282 |
+
if current_board.is_game_over(claim_draw=True):
|
| 283 |
+
outcome = current_board.outcome(claim_draw=True)
|
| 284 |
+
if outcome.winner == chess.WHITE:
|
| 285 |
+
game_value = 1.0
|
| 286 |
+
elif outcome.winner == chess.BLACK:
|
| 287 |
+
game_value = -1.0
|
| 288 |
+
else:
|
| 289 |
+
game_value = 0.0
|
| 290 |
+
terminal_leaves.append((parent_root_move, board_prob, game_value))
|
| 291 |
+
continue
|
| 292 |
+
|
| 293 |
+
# If we've reached max depth, add to search leaves
|
| 294 |
+
if depth == self.depth:
|
| 295 |
+
search_leaves.append((parent_root_move, board_prob, current_board))
|
| 296 |
+
continue
|
| 297 |
+
|
| 298 |
+
# Otherwise, recursively search deeper
|
| 299 |
+
invalid_mask = self.get_invalid_mask([current_board])[0]
|
| 300 |
+
masked_logits = board_logits + invalid_mask
|
| 301 |
+
|
| 302 |
+
top_k_values, top_k_indices = torch.topk(masked_logits,k=min(k,torch.sum(invalid_mask==0).item()))
|
| 303 |
+
top_k_probs = torch.softmax(top_k_values,dim=0)
|
| 304 |
+
if depth==0:
|
| 305 |
+
initial_masked_logits = masked_logits.squeeze(0)
|
| 306 |
+
initial_invalid_mask = invalid_mask.squeeze(0)
|
| 307 |
+
initial_top_k_indices = top_k_indices
|
| 308 |
+
|
| 309 |
+
# Expand each top k move
|
| 310 |
+
for i,move_idx in enumerate(top_k_indices):
|
| 311 |
+
move_prob = top_k_probs[i].item()
|
| 312 |
+
move_uci = IDX_TO_UCI_MOVE[move_idx.item()]
|
| 313 |
+
|
| 314 |
+
root_move = parent_root_move if parent_root_move is not None else move_uci
|
| 315 |
+
|
| 316 |
+
new_board = current_board.copy()
|
| 317 |
+
|
| 318 |
+
if move_uci == "<claim_draw>":
|
| 319 |
+
if new_board.can_claim_draw():
|
| 320 |
+
terminal_leaves.append((root_move,board_prob*move_prob,0.0))
|
| 321 |
+
continue
|
| 322 |
+
else:
|
| 323 |
+
continue # should not happen, invalid draw claim
|
| 324 |
+
else:
|
| 325 |
+
move = chess.Move.from_uci(move_uci)
|
| 326 |
+
new_board.push(move)
|
| 327 |
+
|
| 328 |
+
next_boards.append(new_board)
|
| 329 |
+
next_board_probs.append(board_prob*move_prob)
|
| 330 |
+
next_board_to_root_move.append(root_move)
|
| 331 |
+
|
| 332 |
+
current_boards = next_boards
|
| 333 |
+
board_probs = next_board_probs
|
| 334 |
+
board_to_root_move = next_board_to_root_move
|
| 335 |
+
|
| 336 |
+
# Step 2: Evaluate all search leaves
|
| 337 |
+
if search_leaves:
|
| 338 |
+
search_boards = [leaf[2] for leaf in search_leaves]
|
| 339 |
+
search_fens = [b.fen() for b in search_boards]
|
| 340 |
+
search_reps = self.compute_repetition(search_boards)
|
| 341 |
+
|
| 342 |
+
with torch.no_grad():
|
| 343 |
+
_, search_values = self.model(search_fens, search_reps)
|
| 344 |
+
|
| 345 |
+
for i, (root_move, prob, leaf_board) in enumerate(search_leaves):
|
| 346 |
+
value = search_values[i].item()
|
| 347 |
+
white_perspective_value = value if leaf_board.turn == chess.WHITE else -value
|
| 348 |
+
terminal_leaves.append((root_move,prob,white_perspective_value))
|
| 349 |
+
|
| 350 |
+
# Step 3: Aggregate all leaves using probability weights
|
| 351 |
+
root_move_weighted_values = {}
|
| 352 |
+
root_move_total_probs = {}
|
| 353 |
+
for root_move, prob, value in terminal_leaves:
|
| 354 |
+
if root_move not in root_move_weighted_values:
|
| 355 |
+
root_move_weighted_values[root_move] = 0.0
|
| 356 |
+
root_move_total_probs[root_move] = 0.0
|
| 357 |
+
root_move_weighted_values[root_move] += prob * value
|
| 358 |
+
root_move_total_probs[root_move] += prob
|
| 359 |
+
|
| 360 |
+
final_value = sum(root_move_weighted_values.values())
|
| 361 |
+
final_value = final_value if board.turn == chess.WHITE else -final_value
|
| 362 |
+
|
| 363 |
+
root_move_values = {}
|
| 364 |
+
for root_move in root_move_total_probs:
|
| 365 |
+
if root_move_total_probs[root_move] > 0:
|
| 366 |
+
root_move_values[root_move] = root_move_weighted_values[root_move] / root_move_total_probs[root_move]
|
| 367 |
+
else:
|
| 368 |
+
root_move_values[root_move] = 0
|
| 369 |
+
|
| 370 |
+
# Step 4: Confidence-based weighting with search results
|
| 371 |
+
initial_probs = torch.softmax(initial_masked_logits,dim=0)
|
| 372 |
+
entropy = -torch.sum(initial_probs*torch.log(initial_probs+1e-8))
|
| 373 |
+
max_entropy = math.log(torch.sum(initial_invalid_mask==0).item())
|
| 374 |
+
confidence = 1.0 - (entropy/max_entropy) if max_entropy > 0 else 1.0
|
| 375 |
+
|
| 376 |
+
if root_move_values:
|
| 377 |
+
search_adjustment_logits = torch.zeros_like(initial_masked_logits)
|
| 378 |
+
for move_uci,search_value in root_move_values.items():
|
| 379 |
+
move_idx = UCI_MOVE_TO_IDX[move_uci]
|
| 380 |
+
search_adjustment_logits[move_idx] += search_value
|
| 381 |
+
# flip value according to perpective
|
| 382 |
+
search_adjustment_logits = search_adjustment_logits if board.turn==chess.WHITE else -search_adjustment_logits
|
| 383 |
+
search_adjustment_logits = search_adjustment_logits - search_adjustment_logits.mean()
|
| 384 |
+
|
| 385 |
+
# Normalize search logits to be in the same norm as the initial logits
|
| 386 |
+
|
| 387 |
+
initial_valid_norm = torch.norm(initial_masked_logits[initial_top_k_indices]) + 1e-8
|
| 388 |
+
search_valid_norm = torch.norm(search_adjustment_logits[initial_top_k_indices]) + 1e-8
|
| 389 |
+
|
| 390 |
+
normalized_search = search_adjustment_logits * initial_valid_norm / search_valid_norm
|
| 391 |
+
normalized_initial = initial_masked_logits
|
| 392 |
+
|
| 393 |
+
adjusted_logits = confidence * normalized_initial + (1 - confidence) * normalized_search
|
| 394 |
+
else:
|
| 395 |
+
adjusted_logits = initial_masked_logits
|
| 396 |
+
|
| 397 |
+
# Apply temperature and top-k filtering
|
| 398 |
+
top_k_mask = torch.full_like(adjusted_logits, -torch.inf)
|
| 399 |
+
top_k_mask[initial_top_k_indices] = 0
|
| 400 |
+
adjusted_logits = adjusted_logits + top_k_mask
|
| 401 |
+
adjusted_logits = adjusted_logits / self.temperature
|
| 402 |
+
|
| 403 |
+
dist = Categorical(logits=adjusted_logits)
|
| 404 |
+
move_idx = dist.sample().item()
|
| 405 |
+
move_uci = IDX_TO_UCI_MOVE[move_idx]
|
| 406 |
+
|
| 407 |
+
if return_perplexity:
|
| 408 |
+
final_probs = torch.softmax(adjusted_logits,dim=0)
|
| 409 |
+
perplexity = torch.exp(-torch.sum(final_probs * torch.log(final_probs + 1e-8))).item()
|
| 410 |
+
|
| 411 |
+
if verbose and self.depth > 0:
|
| 412 |
+
print(f"\n--- Search Enhanced Move Debug Info ({board.fen()}) ---")
|
| 413 |
+
print(f"Confidence: {confidence:.4f}")
|
| 414 |
+
|
| 415 |
+
print("\nMove Analysis (Initial Top-K Candidates):")
|
| 416 |
+
print(f"{'Move':<8} {'Initial Logit':<15} {'Search Adj. Logit':<19} {'Final Adj. Logit':<18} {'Final Prob':<12}")
|
| 417 |
+
print(f"{'-'*8:<8} {'-'*15:<15} {'-'*19:<19} {'-'*18:<18} {'-'*12:<12}")
|
| 418 |
+
|
| 419 |
+
for i, idx in enumerate(initial_top_k_indices):
|
| 420 |
+
move_uci_v = IDX_TO_UCI_MOVE[idx.item()]
|
| 421 |
+
initial_logit = normalized_initial[idx].item()
|
| 422 |
+
|
| 423 |
+
search_adj_logit_val = normalized_search[idx].item() if root_move_values else 0.0
|
| 424 |
+
|
| 425 |
+
final_adj_logit = adjusted_logits[idx].item()
|
| 426 |
+
final_prob_val = final_probs[idx].item()
|
| 427 |
+
|
| 428 |
+
print(f"{move_uci_v:<8} {initial_logit:<15.4f} {search_adj_logit_val:<19.4f} {final_adj_logit:<18.4f} {final_prob_val:<12.4f}")
|
| 429 |
+
|
| 430 |
+
print(f"\nPerplexity: {perplexity:.4f}")
|
| 431 |
+
print(f"Predicted Value (White's POV): {final_value:.4f}")
|
| 432 |
+
|
| 433 |
+
print("\nLeaf Node Values (Root Move, Probability, Value from White's POV):")
|
| 434 |
+
for rm, prob, val in terminal_leaves:
|
| 435 |
+
print(f" Root Move: {rm:<8}, Prob: {prob:<.4f}, Value: {val:<.4f}")
|
| 436 |
+
print("--------------------------------------------------")
|
| 437 |
+
|
| 438 |
+
return move_uci, final_value, perplexity
|
| 439 |
+
else:
|
| 440 |
+
return move_uci, final_value
|
| 441 |
+
|
| 442 |
+
def _chessformer_move(self, board: chess.Board, return_perplexity: bool=False, verbose: bool=False) -> Tuple[str,float]:
|
| 443 |
+
"""Get move from chessformer with optional search enhance"""
|
| 444 |
+
if self.depth == 0:
|
| 445 |
+
return self._raw_chessformer_move(board,return_perplexity)
|
| 446 |
+
else:
|
| 447 |
+
return self._search_enhanced_move(board,return_perplexity,verbose)
|
| 448 |
+
|
| 449 |
+
def _stockfish_move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str,float]:
|
| 450 |
+
"""Get best move from stockfish"""
|
| 451 |
+
try:
|
| 452 |
+
engine = chess.engine.SimpleEngine.popen_uci(self.engine_path)
|
| 453 |
+
info = engine.analyse(board, chess.engine.Limit(depth=self.depth))
|
| 454 |
+
except (chess.engine.EngineError, chess.engine.EngineTerminatedError) as e:
|
| 455 |
+
print(f"Stockfish error: {e}")
|
| 456 |
+
return None
|
| 457 |
+
|
| 458 |
+
loss_threshold = -0.4
|
| 459 |
+
|
| 460 |
+
score_obj = info.get("score")
|
| 461 |
+
can_claim_draw = board.can_claim_draw()
|
| 462 |
+
if score_obj is None or info.get("pv") is None or not info.get("pv"):
|
| 463 |
+
# Invalid analysis result
|
| 464 |
+
return None
|
| 465 |
+
|
| 466 |
+
pv = info["pv"]
|
| 467 |
+
pov_score = score_obj.pov(chess.WHITE)
|
| 468 |
+
cp = None
|
| 469 |
+
if pov_score.is_mate():
|
| 470 |
+
mate_score = pov_score.mate()
|
| 471 |
+
cp = 10000.0 if mate_score > 0 else -10000.0
|
| 472 |
+
relative_score = score_obj.relative
|
| 473 |
+
if relative_score.is_mate():
|
| 474 |
+
cp = 10000.0 if relative_score.mate() > 0 else -10000.0
|
| 475 |
+
else:
|
| 476 |
+
if relative_score.cp is not None:
|
| 477 |
+
cp = float(relative_score.cp)
|
| 478 |
+
else:
|
| 479 |
+
return None
|
| 480 |
+
|
| 481 |
+
elif pov_score.cp is not None:
|
| 482 |
+
relative_score = score_obj.relative
|
| 483 |
+
if relative_score.cp is not None:
|
| 484 |
+
cp = float(relative_score.cp)
|
| 485 |
+
else:
|
| 486 |
+
return None
|
| 487 |
+
|
| 488 |
+
else:
|
| 489 |
+
return None
|
| 490 |
+
|
| 491 |
+
if cp is not None:
|
| 492 |
+
normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1
|
| 493 |
+
else:
|
| 494 |
+
return None
|
| 495 |
+
|
| 496 |
+
if can_claim_draw and normalized_score < loss_threshold:
|
| 497 |
+
best_move_uci = "<claim_draw>"
|
| 498 |
+
else:
|
| 499 |
+
best_move_uci = pv[0].uci()
|
| 500 |
+
|
| 501 |
+
if engine:
|
| 502 |
+
engine.quit()
|
| 503 |
+
|
| 504 |
+
if return_perplexity:
|
| 505 |
+
return best_move_uci, normalized_score, None
|
| 506 |
+
else:
|
| 507 |
+
return best_move_uci, normalized_score
|
| 508 |
+
|
| 509 |
+
def _batch_chessformer_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 510 |
+
"""Get the next moves from Chessformer model using batch inference."""
|
| 511 |
+
bs = len(boards)
|
| 512 |
+
if bs > self.max_batch_size:
|
| 513 |
+
raise ValueError(f"num boards ({bs}) exceeded max batch size ({self.max_batch_size}).")
|
| 514 |
+
fens = [board.fen() for board in boards]
|
| 515 |
+
|
| 516 |
+
count_tensor = self.compute_repetition(boards) # shape (bs, 1)
|
| 517 |
+
count_tensor = count_tensor.to(self.device)
|
| 518 |
+
|
| 519 |
+
with torch.no_grad():
|
| 520 |
+
move_logits, values = self.model(fens, count_tensor)
|
| 521 |
+
|
| 522 |
+
invalid_mask = self.get_invalid_mask(boards)
|
| 523 |
+
|
| 524 |
+
# Apply mask
|
| 525 |
+
move_logits = move_logits + invalid_mask
|
| 526 |
+
|
| 527 |
+
all_masked = torch.all(torch.isinf(move_logits), dim=-1)
|
| 528 |
+
|
| 529 |
+
# Apply top-p filtering
|
| 530 |
+
if 0.0 < self.top_p < 1.0: # Apply only if top_p is strictly between 0 and 1
|
| 531 |
+
sorted_logits, sorted_indices = torch.sort(move_logits, descending=True, dim=-1)
|
| 532 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 533 |
+
sorted_indices_to_remove = cumulative_probs > self.top_p
|
| 534 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 535 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 536 |
+
indices_to_remove = torch.zeros_like(move_logits, dtype=torch.bool).scatter_(
|
| 537 |
+
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
| 538 |
+
)
|
| 539 |
+
move_logits[indices_to_remove] = -torch.inf
|
| 540 |
+
|
| 541 |
+
# Apply temperature
|
| 542 |
+
temp = self.temperature if self.temperature > 0 else 1.0
|
| 543 |
+
move_logits = move_logits / temp
|
| 544 |
+
|
| 545 |
+
# Sample moves
|
| 546 |
+
dist = Categorical(logits=move_logits)
|
| 547 |
+
try:
|
| 548 |
+
sampled_indices = dist.sample()
|
| 549 |
+
except RuntimeError as e:
|
| 550 |
+
print(f"Error sampling moves: {e}. Checking logit values...")
|
| 551 |
+
results = []
|
| 552 |
+
for i in range(bs):
|
| 553 |
+
print(f"Board {i} logits sum: {torch.logsumexp(move_logits[i], dim=-1)}")
|
| 554 |
+
results.append(None) # indicate failure
|
| 555 |
+
return results
|
| 556 |
+
|
| 557 |
+
results = []
|
| 558 |
+
for i in range(bs):
|
| 559 |
+
if all_masked[i]:
|
| 560 |
+
results.append(None) # Game already over
|
| 561 |
+
continue
|
| 562 |
+
|
| 563 |
+
move_id = sampled_indices[i].item()
|
| 564 |
+
move_uci = IDX_TO_UCI_MOVE.get(move_id)
|
| 565 |
+
value = values[i].item()
|
| 566 |
+
|
| 567 |
+
if move_uci is None:
|
| 568 |
+
print(f"Warning: Sampled move ID {move_id} not in IDX_TO_UCI_MOVE map")
|
| 569 |
+
results.append(None)
|
| 570 |
+
continue
|
| 571 |
+
|
| 572 |
+
results.append((move_uci, value))
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
return results
|
| 576 |
+
|
| 577 |
+
def _batch_stockfish_move(self, boards: List[chess.Board], allow_claim_draw: bool=False) -> List[Tuple[str, float]]:
|
| 578 |
+
"""Get the next moves from Stockfish engine using multiprocessing"""
|
| 579 |
+
if allow_claim_draw:
|
| 580 |
+
"""Use sequential processing to maintain board history"""
|
| 581 |
+
return [self._stockfish_move(board) for board in boards]
|
| 582 |
+
else:
|
| 583 |
+
"""Use multiprocessing to speed up if no need to include claim draw logic"""
|
| 584 |
+
bs = len(boards)
|
| 585 |
+
num_workers = min(bs, max(1, os.cpu_count()//2 if os.cpu_count() else 1))
|
| 586 |
+
if bs < num_workers * 2:
|
| 587 |
+
num_workers = max(1, bs//2)
|
| 588 |
+
if bs == 1: num_workers = 1
|
| 589 |
+
|
| 590 |
+
board_fens = [board.fen() for board in boards]
|
| 591 |
+
|
| 592 |
+
worker_func = partial(_stockfish_worker,
|
| 593 |
+
engine_path=self.engine_path,
|
| 594 |
+
depth=self.depth)
|
| 595 |
+
results: List[Optional[Tuple[str,float]]] = [None] * bs
|
| 596 |
+
|
| 597 |
+
active_indices = [i for i,b in enumerate(boards) if not b.is_game_over(claim_draw=True)]
|
| 598 |
+
active_fens = [board_fens[i] for i in active_indices]
|
| 599 |
+
|
| 600 |
+
if not active_fens:
|
| 601 |
+
# All games are over
|
| 602 |
+
return results # list of None
|
| 603 |
+
|
| 604 |
+
try:
|
| 605 |
+
if num_workers > 1 and len(active_fens) > 1:
|
| 606 |
+
with multiprocessing.Pool(processes=num_workers) as pool:
|
| 607 |
+
worker_results = pool.map(worker_func, active_fens)
|
| 608 |
+
else:
|
| 609 |
+
worker_results = [worker_func(fen) for fen in active_fens]
|
| 610 |
+
|
| 611 |
+
for i, res in enumerate(worker_results):
|
| 612 |
+
original_index = active_indices[i]
|
| 613 |
+
results[original_index] = res
|
| 614 |
+
|
| 615 |
+
except Exception as e:
|
| 616 |
+
print(f"Error during batch Stockfish move: {e}")
|
| 617 |
+
|
| 618 |
+
return results
|
| 619 |
+
|
| 620 |
+
def move(self, board: chess.Board, return_perplexity: bool=False) -> Tuple[str, float]:
|
| 621 |
+
if self.type == "stockfish":
|
| 622 |
+
return self._stockfish_move(board, return_perplexity)
|
| 623 |
+
elif self.type == "chessformer":
|
| 624 |
+
return self._chessformer_move(board, return_perplexity)
|
| 625 |
+
else:
|
| 626 |
+
raise ValueError(f"Invalid engine type: {self.type}")
|
| 627 |
+
|
| 628 |
+
def batch_move(self, boards: List[chess.Board]) -> List[Tuple[str, float]]:
|
| 629 |
+
if self.type == "stockfish":
|
| 630 |
+
return self._batch_stockfish_move(boards)
|
| 631 |
+
elif self.type == "chessformer":
|
| 632 |
+
return self._batch_chessformer_move(boards)
|
| 633 |
+
else:
|
| 634 |
+
raise ValueError(f"Invalid engine type: {self.type}")
|
| 635 |
+
|
| 636 |
+
def analyze_position(self, board: chess.Board) -> Optional[float]:
|
| 637 |
+
"""
|
| 638 |
+
Analyzes the given **single board** position using the engine.
|
| 639 |
+
For Stockfish, returns list of centipawn scores from white's perspective;
|
| 640 |
+
For ChessFormer, returns list of models's value estimates
|
| 641 |
+
Returns None if analysis failed.
|
| 642 |
+
"""
|
| 643 |
+
if self.type == "stockfish":
|
| 644 |
+
try:
|
| 645 |
+
engine = chess.engine.SimpleEngine.popen_uci(self.engine_path)
|
| 646 |
+
info = engine.analyse(board,chess.engine.Limit(depth=self.depth))
|
| 647 |
+
engine.quit()
|
| 648 |
+
except Exception as e:
|
| 649 |
+
print(f"Stockfish error: {e}")
|
| 650 |
+
return None
|
| 651 |
+
|
| 652 |
+
score_obj = info.get("score")
|
| 653 |
+
pov_score = score_obj.pov(chess.WHITE)
|
| 654 |
+
cp = None
|
| 655 |
+
if pov_score.is_mate():
|
| 656 |
+
mate_score = pov_score.mate()
|
| 657 |
+
cp = 10000.0 if mate_score > 0 else -10000.0
|
| 658 |
+
relative_score = score_obj.relative
|
| 659 |
+
if relative_score.is_mate():
|
| 660 |
+
cp = 10000.0 if relative_score.mate() > 0 else -10000.0
|
| 661 |
+
else:
|
| 662 |
+
if relative_score.cp is not None:
|
| 663 |
+
cp = float(relative_score.cp)
|
| 664 |
+
else:
|
| 665 |
+
return None
|
| 666 |
+
elif pov_score.cp is not None:
|
| 667 |
+
relative_score = score_obj.relative
|
| 668 |
+
if relative_score.cp is not None:
|
| 669 |
+
cp = float(relative_score.cp)
|
| 670 |
+
else:
|
| 671 |
+
return None
|
| 672 |
+
else:
|
| 673 |
+
return None
|
| 674 |
+
|
| 675 |
+
if cp is not None:
|
| 676 |
+
normalized_score = 2 / (1+math.exp(-0.004*cp)) - 1
|
| 677 |
+
return normalized_score if board.turn == chess.WHITE else -normalized_score
|
| 678 |
+
else:
|
| 679 |
+
return None
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
elif self.type == "chessformer":
|
| 683 |
+
fen = board.fen()
|
| 684 |
+
count_tensor = self.compute_repetition([board.copy(stack=True)])
|
| 685 |
+
|
| 686 |
+
with torch.no_grad():
|
| 687 |
+
_, value = self.model([fen],count_tensor)
|
| 688 |
+
|
| 689 |
+
value = value.item()
|
| 690 |
+
return value if board.turn == chess.WHITE else -value
|
| 691 |
+
|
| 692 |
+
else:
|
| 693 |
+
raise ValueError(f"Invalid engine type.")
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def test_search_enhanced_move(model_path,device):
|
| 697 |
+
"""Test the search-enhanced move functionality"""
|
| 698 |
+
print("\n--- Testing Search-Enhanced ChessFormer ---")
|
| 699 |
+
|
| 700 |
+
import sys
|
| 701 |
+
sys.path.append("./")
|
| 702 |
+
try:
|
| 703 |
+
from model import ChessFormerModel
|
| 704 |
+
except ImportError:
|
| 705 |
+
from model import ChessFormerModel
|
| 706 |
+
|
| 707 |
+
# Load the trained model
|
| 708 |
+
checkpoint = torch.load(model_path,map_location=device)
|
| 709 |
+
config = checkpoint["config"]
|
| 710 |
+
model = ChessFormerModel(**config)
|
| 711 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 712 |
+
|
| 713 |
+
model.to(device)
|
| 714 |
+
|
| 715 |
+
# Test different search configurations
|
| 716 |
+
test_configs = [
|
| 717 |
+
#{"depth": 0, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2}, # No search (baseline)
|
| 718 |
+
#{"depth": 1, "top_k": 8, "decay_rate": 0.6, "temperature": 0.2}, # Shallow search
|
| 719 |
+
{"depth": 8, "top_k": 8, "decay_rate": 0.5, "temperature": 0.5}, # Medium search
|
| 720 |
+
]
|
| 721 |
+
|
| 722 |
+
# Test positions
|
| 723 |
+
test_positions = [
|
| 724 |
+
#chess.Board(), # Starting position
|
| 725 |
+
#chess.Board("r1bqkb1r/pppp1ppp/2n2n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4"), # Italian Game
|
| 726 |
+
#chess.Board("rnbqkbnr/pp1ppppp/8/2p5/4P3/8/PPPP1PPP/RNBQKBNR w KQkq c6 0 2"), # Sicilian Defense
|
| 727 |
+
#chess.Board("r1bq1rk1/ppp2ppp/2n2n2/2bpp3/2B1P3/3P1N2/PPP2PPP/RNBQ1RK1 w - - 0 6"), # Complex middlegame
|
| 728 |
+
chess.Board("r1b1k2r/1p2bpp1/2p1p1np/2N1P3/1q1P4/5N2/B1Q2PPP/R3R1K1 w kq - 0 19"), # blunder: c2e4
|
| 729 |
+
chess.Board("rn1qk2r/1b2bpp1/1pp1pn1p/p7/3P4/2PB1N2/PP1NQPPP/R1B1R1K1 w kq - 2 12"), # blunder: e2e6
|
| 730 |
+
]
|
| 731 |
+
|
| 732 |
+
for i, cfg in enumerate(test_configs):
|
| 733 |
+
print(f"\n--- Test Configuration {i+1}: Depth={cfg['depth']}, Top-K={cfg['top_k']}, Decay={cfg['decay_rate']}, Temp={cfg['temperature']} ---")
|
| 734 |
+
chessformer_config = ChessformerConfig(
|
| 735 |
+
chessformer=model,
|
| 736 |
+
device=device,
|
| 737 |
+
temperature=cfg['temperature'],
|
| 738 |
+
depth=cfg['depth'],
|
| 739 |
+
top_k=cfg['top_k'],
|
| 740 |
+
decay_rate=cfg['decay_rate']
|
| 741 |
+
)
|
| 742 |
+
engine = Engine(type="chessformer", chessformer_config=chessformer_config)
|
| 743 |
+
|
| 744 |
+
for j, board in enumerate(test_positions):
|
| 745 |
+
print(f"\n--- Analyzing Position {j+1}: {board.fen()} ---")
|
| 746 |
+
try:
|
| 747 |
+
move, value, perplexity = engine._chessformer_move(board, return_perplexity=True, verbose=True)
|
| 748 |
+
print(f"Selected Move: {move}, Predicted Value (White's POV): {value:.4f}, Perplexity: {perplexity:.4f}")
|
| 749 |
+
except Exception as e:
|
| 750 |
+
print(f"Error analyzing position {board.fen()}: {e}")
|
| 751 |
+
import traceback
|
| 752 |
+
traceback.print_exc()
|
| 753 |
+
|
| 754 |
+
if __name__ == "__main__":
|
| 755 |
+
model_path = "./ckpts/chessformer-sl_01.pth"
|
| 756 |
+
device = torch.device("cpu")
|
| 757 |
+
test_search_enhanced_move(model_path,device)
|
| 758 |
+
|
| 759 |
+
|
utils/mapping.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Tuple, Set
|
| 2 |
+
|
| 3 |
+
# --- Constants --- #
|
| 4 |
+
MAX_HALFMOVES = 128 # cap for embedding table size
|
| 5 |
+
MAX_FULLMOVES = 256 # cap for embedding table size
|
| 6 |
+
|
| 7 |
+
# --- Helper Mappings --- #
|
| 8 |
+
PIECE_TO_IDX: Dict[str, int] = {
|
| 9 |
+
'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
|
| 10 |
+
'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
|
| 11 |
+
'.': 12
|
| 12 |
+
}
|
| 13 |
+
IDX_TO_PIECE: Dict[int, str] = {v: k for k, v in PIECE_TO_IDX.items()}
|
| 14 |
+
EMPTY_SQ_IDX = PIECE_TO_IDX['.']
|
| 15 |
+
# Map algebraic square notation (e.g., 'a1', 'h8') to 0-63 index
|
| 16 |
+
# a1=0, b1=1, ..., h1=7, a2=8, ..., h8=63
|
| 17 |
+
SQUARE_TO_IDX: Dict[str, int] = {
|
| 18 |
+
f"{file}{rank}": (rank - 1) * 8 + (ord(file) - ord('a'))
|
| 19 |
+
for rank in range(1, 9)
|
| 20 |
+
for file in 'abcdefgh'
|
| 21 |
+
}
|
| 22 |
+
IDX_TO_SQUARE: Dict[int, str] = {v: k for k, v in SQUARE_TO_IDX.items()}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# --- Coordinate and Notation Helpers ---
|
| 27 |
+
|
| 28 |
+
# Precompute maps for efficiency
|
| 29 |
+
_IDX_TO_COORDS: Dict[int, Tuple[int, int]] = {i: (i // 8, i % 8) for i in range(64)} # (rank, file) 0-7
|
| 30 |
+
_COORDS_TO_IDX: Dict[Tuple[int, int], int] = {v: k for k, v in _IDX_TO_COORDS.items()}
|
| 31 |
+
_IDX_TO_ALG: Dict[int, str] = {
|
| 32 |
+
i: f"{chr(ord('a') + file)}{rank + 1}"
|
| 33 |
+
for i, (rank, file) in _IDX_TO_COORDS.items()
|
| 34 |
+
}
|
| 35 |
+
_ALG_TO_IDX: Dict[str, int] = {v: k for k, v in _IDX_TO_ALG.items()}
|
| 36 |
+
|
| 37 |
+
def _coords_to_alg(r: int, f: int) -> str:
|
| 38 |
+
"""Converts 0-indexed (rank, file) to algebraic notation."""
|
| 39 |
+
if 0 <= r < 8 and 0 <= f < 8:
|
| 40 |
+
return f"{chr(ord('a') + f)}{r + 1}"
|
| 41 |
+
# This should not happen with valid indices, but good for safety
|
| 42 |
+
raise ValueError(f"Invalid coordinates: ({r}, {f})")
|
| 43 |
+
|
| 44 |
+
def generate_structurally_valid_move_map() -> Dict[str, int]:
|
| 45 |
+
"""
|
| 46 |
+
Generates a dictionary mapping chess moves that are geometrically possible
|
| 47 |
+
by *some* standard piece (K, Q, R, B, N, or P) to unique integer indices.
|
| 48 |
+
It excludes moves that are structurally impossible for any piece to make
|
| 49 |
+
in one turn (e.g., a1->h5 for non-knight).
|
| 50 |
+
|
| 51 |
+
Includes standard UCI promotions (e.g., "e7e8q"), replacing the
|
| 52 |
+
corresponding simple pawn move to the final rank (e.g., "e7e8").
|
| 53 |
+
This is based purely on piece movement geometry, not the current board state.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Dict[str, int]: A map from the valid UCI move string to a unique
|
| 57 |
+
integer index (0 to N-1). The size N is expected
|
| 58 |
+
to be around 1800-1900.
|
| 59 |
+
"""
|
| 60 |
+
valid_moves: Set[str] = set()
|
| 61 |
+
# Keep track of base moves (like 'e7e8') that are replaced by promotions
|
| 62 |
+
# according to UCI standard.
|
| 63 |
+
promo_base_moves_to_exclude: Set[str] = set()
|
| 64 |
+
|
| 65 |
+
# 1. Generate all geometrically possible non-promotion moves
|
| 66 |
+
for from_idx in range(64):
|
| 67 |
+
from_r, from_f = _IDX_TO_COORDS[from_idx]
|
| 68 |
+
from_alg = _IDX_TO_ALG[from_idx]
|
| 69 |
+
|
| 70 |
+
for to_idx in range(64):
|
| 71 |
+
if from_idx == to_idx:
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
to_r, to_f = _IDX_TO_COORDS[to_idx]
|
| 75 |
+
to_alg = _IDX_TO_ALG[to_idx]
|
| 76 |
+
dr, df = to_r - from_r, to_f - from_f
|
| 77 |
+
abs_dr, abs_df = abs(dr), abs(df)
|
| 78 |
+
|
| 79 |
+
# Check if the geometry matches any standard piece movement
|
| 80 |
+
# Note: Queen moves are covered by Rook + Bishop checks.
|
| 81 |
+
# Note: Pawn single pushes/captures are covered by King/Rook/Bishop geometry.
|
| 82 |
+
# Note: Pawn double pushes are covered by Rook geometry.
|
| 83 |
+
is_king_move = max(abs_dr, abs_df) == 1
|
| 84 |
+
is_knight_move = (abs_dr == 2 and abs_df == 1) or (abs_dr == 1 and abs_df == 2)
|
| 85 |
+
is_rook_move = dr == 0 or df == 0 # Includes King horiz/vert & pawn double push
|
| 86 |
+
is_bishop_move = abs_dr == abs_df # Includes King diagonal & pawn capture/push
|
| 87 |
+
|
| 88 |
+
if is_king_move or is_knight_move or is_rook_move or is_bishop_move:
|
| 89 |
+
uci_move = f"{from_alg}{to_alg}"
|
| 90 |
+
valid_moves.add(uci_move)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# 2. Generate promotion moves explicitly and mark base moves for exclusion
|
| 94 |
+
promo_pieces = ['q', 'r', 'b', 'n']
|
| 95 |
+
for from_f in range(8):
|
| 96 |
+
# White promotions (from rank 7 (idx 6) to rank 8 (idx 7))
|
| 97 |
+
from_r_w, to_r_w = 6, 7
|
| 98 |
+
if from_r_w != 7: # Ensure we are on the correct rank before promotion
|
| 99 |
+
from_alg_w = _coords_to_alg(from_r_w, from_f)
|
| 100 |
+
# Possible destinations: push (df=0), capture left (df=-1), capture right (df=1)
|
| 101 |
+
for df in [-1, 0, 1]:
|
| 102 |
+
to_f_w = from_f + df
|
| 103 |
+
if 0 <= to_f_w < 8:
|
| 104 |
+
to_alg_w = _coords_to_alg(to_r_w, to_f_w)
|
| 105 |
+
base_move = f"{from_alg_w}{to_alg_w}"
|
| 106 |
+
#promo_base_moves_to_exclude.add(base_move) # Mark e.g. "e7e8" for exclusion
|
| 107 |
+
for p in promo_pieces:
|
| 108 |
+
valid_moves.add(f"{base_move}{p}") # Add e.g. "e7e8q"
|
| 109 |
+
|
| 110 |
+
# Black promotions (from rank 2 (idx 1) to rank 1 (idx 0))
|
| 111 |
+
from_r_b, to_r_b = 1, 0
|
| 112 |
+
if from_r_b != 0: # Ensure we are on the correct rank before promotion
|
| 113 |
+
from_alg_b = _coords_to_alg(from_r_b, from_f)
|
| 114 |
+
# Possible destinations: push (df=0), capture left (df=-1), capture right (df=1)
|
| 115 |
+
for df in [-1, 0, 1]:
|
| 116 |
+
to_f_b = from_f + df
|
| 117 |
+
if 0 <= to_f_b < 8:
|
| 118 |
+
to_alg_b = _coords_to_alg(to_r_b, to_f_b)
|
| 119 |
+
base_move = f"{from_alg_b}{to_alg_b}"
|
| 120 |
+
#promo_base_moves_to_exclude.add(base_move) # Mark e.g. "e2e1" for exclusion
|
| 121 |
+
for p in promo_pieces:
|
| 122 |
+
valid_moves.add(f"{base_move}{p}") # Add e.g. "e2e1q"
|
| 123 |
+
|
| 124 |
+
# 3. Remove the base moves that were replaced by promotions
|
| 125 |
+
final_valid_moves = valid_moves - promo_base_moves_to_exclude
|
| 126 |
+
|
| 127 |
+
# 4. Add draw claim
|
| 128 |
+
final_valid_moves.add("<claim_draw>")
|
| 129 |
+
|
| 130 |
+
# 5. Create the final map with sorted keys for deterministic indices
|
| 131 |
+
sorted_moves = sorted(list(final_valid_moves))
|
| 132 |
+
move_map = {move: i for i, move in enumerate(sorted_moves)}
|
| 133 |
+
|
| 134 |
+
# Optional: Print the number of moves found for verification
|
| 135 |
+
# print(f"Generated {len(move_map)} structurally valid unique UCI moves.")
|
| 136 |
+
|
| 137 |
+
return move_map
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
UCI_MOVE_TO_IDX = generate_structurally_valid_move_map()
|
| 141 |
+
IDX_TO_UCI_MOVE = {v:k for k,v in UCI_MOVE_TO_IDX.items()}
|