noamdwc committed
Commit 60fd122 · 1 Parent(s): f488ed9

Switch Space to Docker + FastAPI

Files changed (46)
  1. .dockerignore +8 -0
  2. Dockerfile +16 -0
  3. README.md +1 -3
  4. app.py +22 -42
  5. hf_space_repo/README.md +428 -0
  6. hf_space_repo/__init__.py +30 -0
  7. hf_space_repo/chess/__init__.py +0 -0
  8. hf_space_repo/chess/boards_dataset.py +465 -0
  9. hf_space_repo/chess/chess_logic.py +63 -0
  10. hf_space_repo/chess/policy_player.py +98 -0
  11. hf_space_repo/chess/rewards.py +108 -0
  12. hf_space_repo/chess/searcher.py +90 -0
  13. hf_space_repo/chess/stockfish.py +288 -0
  14. hf_space_repo/configs/__init__.py +43 -0
  15. hf_space_repo/configs/config_loader.py +290 -0
  16. hf_space_repo/configs/default.yaml +123 -0
  17. hf_space_repo/configs/pretrain.yaml +49 -0
  18. hf_space_repo/constants.py +15 -0
  19. hf_space_repo/eval_utils.py +211 -0
  20. hf_space_repo/evaluator.py +118 -0
  21. hf_space_repo/grpo_logic/__init__.py +0 -0
  22. hf_space_repo/grpo_logic/loss.py +235 -0
  23. hf_space_repo/grpo_logic/model.py +782 -0
  24. hf_space_repo/grpo_logic/sampling.py +243 -0
  25. hf_space_repo/logging_utils.py +32 -0
  26. hf_space_repo/models.py +234 -0
  27. hf_space_repo/pretrain/README.md +153 -0
  28. hf_space_repo/pretrain/__init__.py +15 -0
  29. hf_space_repo/pretrain/pretrain.py +579 -0
  30. hf_space_repo/pretrain/pretrain_dataset.py +328 -0
  31. hf_space_repo/pretrain/pretrain_load_config.py +21 -0
  32. hf_space_repo/searchless_chess_imports.py +3 -0
  33. hf_space_repo/searchless_chess_model/.gitattributes +35 -0
  34. hf_space_repo/searchless_chess_model/README.md +177 -0
  35. hf_space_repo/searchless_chess_model/config.json +10 -0
  36. hf_space_repo/searchless_chess_model/model_info.json +13 -0
  37. hf_space_repo/searchless_chess_model/searchless_chess_code/__init__.py +1 -0
  38. hf_space_repo/searchless_chess_model/searchless_chess_code/config.py +90 -0
  39. hf_space_repo/searchless_chess_model/searchless_chess_code/constants.py +119 -0
  40. hf_space_repo/searchless_chess_model/searchless_chess_code/hf_model.py +329 -0
  41. hf_space_repo/searchless_chess_model/searchless_chess_code/tokenizer.py +116 -0
  42. hf_space_repo/searchless_chess_model/searchless_chess_code/transformer.py +284 -0
  43. hf_space_repo/searchless_chess_model/searchless_chess_code/utils.py +162 -0
  44. hf_space_repo/train_self_play.py +72 -0
  45. hf_space_repo/trainer.py +74 -0
  46. requirements.txt +1 -5
.dockerignore ADDED
@@ -0,0 +1,8 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.pytest_cache/
.git/
.DS_Store
node_modules/
Dockerfile ADDED
@@ -0,0 +1,16 @@
FROM python:3.11-slim

ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV PORT=7860

WORKDIR /app

COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt

COPY . /app

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -15,9 +15,7 @@ title: grpo-chess-api
 emoji: ♟️
 colorFrom: amber
 colorTo: red
-sdk: gradio
-sdk_version: 4.44.1
-app_file: app.py
+sdk: docker
 pinned: false
 ---
app.py CHANGED
@@ -2,8 +2,9 @@ import os
 from pathlib import Path
 
 import chess
-import gradio as gr
 import torch
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from huggingface_hub import hf_hub_download
 from pydantic import BaseModel
 from safetensors.torch import load_file
@@ -69,51 +70,30 @@ def choose_move(model, board: chess.Board, temperature: float, greedy: bool) ->
     return move
 
 
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=False,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+
+@app.post("/move", response_model=MoveResponse)
 def move(req: MoveRequest):
-    board = chess.Board(req.fen)
+    try:
+        board = chess.Board(req.fen)
+    except Exception as exc:
+        raise HTTPException(status_code=400, detail=f"Invalid FEN: {exc}")
     model = load_model()
     move = choose_move(model, board, req.temperature, req.greedy)
     san = board.san(move)
    board.push(move)
     return MoveResponse(uci=move.uci(), san=san, fen=board.fen())
 
-
-def gradio_move(fen: str, temperature: float, greedy: bool):
-    req = MoveRequest(fen=fen, temperature=temperature, greedy=greedy)
-    res = move(req)
-    return res.uci, res.san, res.fen
-
-
-with gr.Blocks(title="GRPO Chess API") as demo:
-    gr.Markdown(
-        "## GRPO Chess Model API\n"
-        "Use this panel to test the model. The website calls the Gradio API at "
-        "`/run/move`."
-    )
-    fen = gr.Textbox(
-        label="FEN",
-        value="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
-    )
-    temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
-    greedy = gr.Checkbox(label="Greedy", value=False)
-    btn = gr.Button("Get Move")
-    uci = gr.Textbox(label="UCI Move")
-    san = gr.Textbox(label="SAN Move")
-    fen_out = gr.Textbox(label="Next FEN")
-    btn.click(
-        gradio_move,
-        inputs=[fen, temperature, greedy],
-        outputs=[uci, san, fen_out],
-        api_name="move",
-    )
-
-
-app = demo
-
-
-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=int(os.environ.get("PORT", 7860)),
-        show_error=True,
-    )
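With the Space now serving plain HTTP, any client can call the endpoints directly. A minimal usage sketch (the base URL is a placeholder for your Space's hostname; assumes the `requests` package is installed):

```python
import requests

BASE_URL = "https://your-space.hf.space"  # placeholder; substitute your Space URL

# Liveness check
print(requests.get(f"{BASE_URL}/health").json())  # -> {"status": "ok"}

# Request a move from the starting position
payload = {
    "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    "temperature": 1.0,
    "greedy": False,
}
resp = requests.post(f"{BASE_URL}/move", json=payload)
resp.raise_for_status()
print(resp.json())  # -> {"uci": ..., "san": ..., "fen": ...}
```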
hf_space_repo/README.md ADDED
@@ -0,0 +1,428 @@
# GRPO Self-Play Chess Module

An experimental, research-grade implementation of **Group Relative Policy Optimization (GRPO)** for training transformer-based chess policies through self-play. This module implements a full reinforcement learning pipeline for chess, but training stability and final strength are still under active investigation.

## Overview

This module trains neural network chess policies using GRPO, a variant of Proximal Policy Optimization (PPO) that uses group-based advantage estimation. The system learns to play chess by:

1. **Self-Play**: Sampling multiple trajectory groups from diverse starting positions
2. **Reward Computation**: Using Stockfish evaluations to compute dense rewards
3. **Policy Optimization**: Applying GRPO with PPO clipping and KL divergence penalties
4. **Evaluation**: Comprehensive benchmarking against Stockfish at multiple skill levels

## Key Features

### 🎯 Core Capabilities

- **Transformer-Based Policy Network**: Deep neural network architecture that processes FEN-encoded board states
- **GRPO Training Algorithm**: Group-relative advantage estimation with PPO-style clipping
- **Self-Play Training Loop**: Infinite dataset of diverse chess positions for robust learning
- **Stockfish Integration**: Professional-grade evaluation and reward computation
- **Comprehensive Evaluation**: Multi-level skill ladder evaluation against Stockfish
- **PyTorch Lightning Integration**: Scalable training with automatic mixed precision, gradient clipping, and checkpointing
- **Weights & Biases Logging**: Full experiment tracking and visualization

### 🏗️ Architecture Highlights

- **Modular Design**: Clean separation between model, training logic, chess rules, and evaluation
- **Efficient Batching**: Parallel trajectory sampling across multiple board positions
- **Legal Move Masking**: Proper handling of chess rules with action space masking
- **Trajectory Search**: Optional trajectory search wrapper for improved play strength
- **Resource Management**: Efficient Stockfish engine pooling and caching

## Installation

```bash
# Install dependencies
pip install torch pytorch-lightning wandb chess python-chess

# Ensure Stockfish is available
# On Ubuntu/Debian: sudo apt-get install stockfish
# On macOS: brew install stockfish
# Or download from: https://stockfishchess.org/download/
```

## Quick Start

### Basic Training

The easiest way to start training is using the YAML-based configuration system:

```python
from src.grpo_self_play.train_self_play import train

# Use default configuration (loads from configs/default.yaml)
train()

# Use a custom config file
train(config_path="my_experiment.yaml")

# Override specific hyperparameters programmatically
train(
    config_path="default.yaml",
    overrides={
        "grpo": {"lr": 1e-4, "num_trajectories": 8},
        "training": {"num_epochs": 100},
    }
)
```

All hyperparameters (learning rate, model architecture, training settings, etc.) are defined in YAML configuration files. See the [Configuration](#configuration) section below for details.

### Running Training in Google Colab

**Note for AI agents and contributors**: The primary way this code is run is through the `chess_model_run_git.ipynb` notebook in Google Colab. This notebook is the actual workflow used for training and evaluation.

The `chess_model_run_git.ipynb` notebook provides:

- **Automated Setup**: Clones the repository, installs dependencies, and downloads the searchless chess model
- **Complete Configuration**: Pre-configured settings for GRPO training, dataset generation, and evaluation
- **Phase-Aware Dataset**: Example configuration using `ChessDatasetConfig` with `phase_distribution` for balanced training across opening, middlegame, and endgame positions
- **Evaluation Pipeline**: Integrated evaluation against Stockfish at multiple skill levels

The notebook handles all setup steps including:
1. Repository cloning and branch checkout
2. Dependency installation (PyTorch Lightning, WandB, python-chess, etc.)
3. Downloading the searchless chess model from HuggingFace
4. Stockfish installation
5. Training configuration with phase-distributed dataset sampling
6. Model training and periodic evaluation

### Evaluation

```python
from src.grpo_self_play import Evaluator, EvalConfig
from src.grpo_self_play.chess.stockfish import StockfishConfig

# Create evaluator
evaluator = Evaluator(
    eval_cfg=EvalConfig(games=50),
    stockfish_cfg=StockfishConfig(skill_level=10, movetime_ms=100)
)

# Single evaluation
results, policy = evaluator.single_evaluation(model)
print(f"Win rate: {results['score']:.2%}")
print(f"Approx Elo diff: {results['elo_diff_vs_stockfish_approx']:.0f}")

# Skill ladder evaluation
skill_results = evaluator.eval_ladder(model)
for skill, score in skill_results.items():
    print(f"Skill {skill}: {score:.2%} win rate")
```

## Architecture

### Model Architecture

The `ChessTransformer` processes chess positions using:

- **Input Encoding**: FEN strings tokenized using DeepMind's chess tokenizer
- **Transformer Encoder**: Multi-head self-attention with learnable positional encodings
- **Policy Head**: Dense layers outputting logits over 1968 possible moves
- **Legal Move Masking**: Automatic filtering of illegal moves during inference
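The sketch below illustrates this encoder-plus-policy-head shape. It is a simplified stand-in, not the module's actual `ChessTransformer`; class name, dimensions, and sequence length are illustrative:

```python
import torch
import torch.nn as nn

class TinyChessPolicy(nn.Module):
    """Schematic only: embed tokens -> transformer encoder -> move logits."""
    def __init__(self, vocab_size=128, embed_dim=256, num_layers=4,
                 num_heads=8, action_dim=1968, max_len=80):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Parameter(torch.zeros(1, max_len, embed_dim))  # learnable positions
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.head = nn.Linear(embed_dim, action_dim)  # logits over the move action space

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: [B, L] token ids for a tokenized FEN string
        x = self.embed(tokens) + self.pos[:, : tokens.size(1)]
        x = self.encoder(x)
        return self.head(x.mean(dim=1))  # [B, action_dim]
```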
### GRPO Algorithm

Group Relative Policy Optimization extends PPO by:

1. **Group-Based Sampling**: Sample G trajectories per starting position
2. **Group Rewards**: Compute final reward for each trajectory group
3. **Relative Advantages**: Normalize advantages within each batch using group statistics
4. **PPO Clipping**: Prevent large policy updates with clipped importance ratios
5. **KL Penalty**: Regularize policy updates to prevent divergence

The loss function combines:
- **PPO Surrogate Loss**: `L_clip = E[min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)]`
- **KL Divergence Penalty**: `β * KL(π_old || π_new)`
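For concreteness, a minimal sketch of this combined objective in PyTorch (tensor names and coefficient values are illustrative, not the module's actual API):

```python
import torch

def grpo_loss_sketch(logp_new, logp_old, advantages, kl, clip_eps=0.2, beta=0.04):
    """Clipped surrogate plus KL penalty, mirroring the formulas above."""
    ratio = torch.exp(logp_new - logp_old)                                # r(θ)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    surrogate = torch.min(unclipped, clipped).mean()                      # E[min(·, ·)]
    return -surrogate + beta * kl.mean()                                  # minimized by the optimizer
```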
### Training Pipeline

```
1. Sample random starting positions (FEN strings)
2. For each position:
   - Sample G trajectory groups using old policy
   - Compute group rewards using Stockfish evaluation
3. Compute advantages via group normalization
4. Update policy using GRPO loss
5. Sync old policy every epoch
6. Periodic evaluation against Stockfish
```

## Module Structure

```
grpo_self_play/
├── models.py            # ChessTransformer architecture
├── trainer.py           # PyTorch Lightning trainer setup
├── train_self_play.py   # Main training script
├── evaluator.py         # Evaluation framework
├── eval_utils.py        # Evaluation utilities
├── constants.py         # Configuration constants
├── grpo_logic/
│   ├── model.py         # GRPOChessTransformer (Lightning module)
│   ├── loss.py          # GRPO loss computation
│   └── sampling.py      # Trajectory sampling logic
└── chess/
    ├── chess_logic.py   # Board encoding, legal moves
    ├── policy_player.py # Policy-based player
    ├── searcher.py      # Trajectory search wrapper
    ├── rewards.py       # Stockfish reward computation
    └── stockfish.py     # Stockfish engine integration
```

## Key Design Decisions

### 1. Group-Based Advantage Estimation

Instead of using value functions or Monte Carlo returns, GRPO computes advantages by normalizing rewards within trajectory groups. This approach:
- Eliminates the need for value function approximation
- Provides stable learning signals through relative comparisons
- Reduces variance in advantage estimates
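A minimal sketch of group-relative normalization, assuming a reward tensor of shape `[B, G]` with G trajectories per starting position (shapes are illustrative):

```python
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Normalize rewards within each group of G trajectories (shape [B, G])."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)  # zero-mean, unit-variance per group
```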
### 2. Stockfish-Based Rewards

Using Stockfish for reward computation provides:
- **Dense Rewards**: Evaluation at every position, not just terminal states
- **High-Quality Signals**: Professional-grade position evaluation
- **Caching**: LRU cache for efficient reward computation during training

### 3. Legal Move Masking

The action space (1968 moves) is larger than legal moves in any position. The system:
- Masks illegal moves with `-inf` in logits
- Ensures policy only samples legal moves
- Handles edge cases (no legal moves, promotion moves)

### 4. Trajectory Padding and Masking

Trajectories have variable lengths due to game terminations. The implementation:
- Pads trajectories to fixed length for batching
- Uses attention masks to ignore padding
- Only considers moves from the starting player's perspective
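A sketch of this padding scheme, assuming each trajectory is a list of per-move scalars such as log-probabilities (the helper name and input type are illustrative):

```python
import torch

def pad_trajectories(trajs: list[list[float]], pad_value: float = 0.0):
    """Pad variable-length trajectories to a fixed length and build a mask."""
    T = max(len(t) for t in trajs)
    padded = torch.full((len(trajs), T), pad_value)
    mask = torch.zeros(len(trajs), T, dtype=torch.bool)
    for i, t in enumerate(trajs):
        padded[i, : len(t)] = torch.tensor(t)
        mask[i, : len(t)] = True  # True = real move, False = padding
    return padded, mask
```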
## Configuration

This module uses a **YAML-based configuration system**: all training hyperparameters, model architecture settings, and evaluation configurations are centralized in YAML files located in `configs/`.

### Configuration Files

The default configuration file is `configs/default.yaml`, which contains all hyperparameters organized into sections:

- **`training`**: Training loop settings (epochs, batch size, steps per epoch)
- **`grpo`**: GRPO algorithm hyperparameters (learning rate, trajectories, clipping, KL penalty, entropy regularization, adaptive KL control)
- **`transformer`**: Model architecture (embedding dimension, layers, attention heads, vocabulary size, action space)
- **`eval`**: Evaluation settings (number of games, max plies, opening randomization)
- **`stockfish`**: Stockfish engine configuration (path, skill level, time limits, resource usage)
- **`policy`**: Policy player settings (temperature, greedy mode, branching factor, search depth)
- **`searcher`**: Optional trajectory search configuration
- **`dataset`**: Dataset generation settings (position phases, quality filters, evaluation bounds)

### Using Configurations

#### Loading Configurations

```python
from src.grpo_self_play.configs.config_loader import load_experiment_config

# Load default config
config = load_experiment_config("default.yaml")

# Load with overrides
config = load_experiment_config("default.yaml", overrides={
    "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
    "training": {"num_epochs": 100},
})

# Access config values
print(config.grpo.lr)
print(config.training.batch_size)
print(config.transformer.embed_dim)
```

#### Training with Configurations

```python
from src.grpo_self_play.train_self_play import train

# Use default config
train()

# Use custom config file
train(config_path="my_experiment.yaml")

# Override specific values
train(
    config_path="default.yaml",
    overrides={
        "grpo": {"lr": 1e-4},
        "training": {"num_epochs": 50},
    },
    dataloader_kwargs={"num_workers": 4}  # Override DataLoader args
)
```

### Creating Custom Configurations

1. Copy the default config:
```bash
cp configs/default.yaml configs/my_experiment.yaml
```

2. Edit `my_experiment.yaml` to modify hyperparameters

3. Use your custom config:
```python
train(config_path="my_experiment.yaml")
```

### Configuration Dataclasses

The configuration system converts YAML files into typed dataclasses:

- **`TrainingConfig`**: Training loop settings
- **`GRPOConfig`**: GRPO algorithm hyperparameters
- **`ChessTransformerConfig`**: Model architecture
- **`EvalConfig`**: Evaluation settings
- **`StockfishConfig`**: Stockfish engine settings
- **`PolicyConfig`**: Policy player settings
- **`SearchConfig`**: Trajectory search settings (optional)
- **`ChessDatasetConfig`**: Dataset generation settings

All configs are combined into an `ExperimentConfig` object that provides type-safe access to all settings.

### Key Hyperparameters

All hyperparameters are defined in YAML files. Key settings include:

**GRPO Algorithm:**
- `grpo.lr`: Learning rate for policy optimization
- `grpo.num_trajectories`: Number of trajectory groups per starting position
- `grpo.trajectory_depth`: Maximum moves per trajectory
- `grpo.clip_ratio`: PPO clipping epsilon (prevents large policy updates)
- `grpo.kl_coef`: KL divergence penalty coefficient
- `grpo.entropy_coef`: Entropy regularization coefficient
- `grpo.adaptive_kl`: Enable adaptive KL coefficient adjustment
- `grpo.use_entropy_floor`: Monitor and respond to entropy collapse

**Model Architecture:**
- `transformer.embed_dim`: Transformer embedding dimension
- `transformer.num_layers`: Number of transformer layers
- `transformer.num_heads`: Number of attention heads
- `transformer.vocab_size`: Token vocabulary size
- `transformer.action_dim`: Action space size (1968 for chess)

**Training:**
- `training.num_epochs`: Total number of training epochs
- `training.batch_size`: Batch size for training
- `training.steps_per_epoch`: Number of training steps per epoch

See `configs/default.yaml` for the complete list of all hyperparameters and their default values.

## Advanced Usage

### Custom Reward Function

```python
from src.grpo_self_play.chess.rewards import reward_board

def custom_reward(board, start_board):
    # Your custom reward logic
    return reward_board(board, start_board, depth=8, movetime_ms=50)
```

### Trajectory Search

```python
from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
from src.grpo_self_play.chess.policy_player import PolicyPlayer

policy = PolicyPlayer(model)
searcher = TrajectorySearcher(
    policy,
    cfg=SearchConfig(n_trajectories=10, trajectory_depth=3)
)
```

### Custom Training Loop

```python
import pytorch_lightning as pl
from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer

model = GRPOChessTransformer(transformer_config, grpo_config)
trainer = pl.Trainer(
    max_epochs=1000,
    gradient_clip_val=1.0,
    accelerator="gpu",
    devices=1
)
trainer.fit(model, dataloader)
```

## Performance Considerations

- **Batch Size**: Larger batches improve advantage normalization quality
- **Trajectory Depth**: Deeper trajectories provide more learning signal but increase compute
- **Stockfish Depth**: Higher depth = better rewards but slower training
- **Caching**: Reward caching significantly speeds up training
- **Gradient Clipping**: Prevents exploding gradients in transformer training

## Monitoring and Logging

The module logs comprehensive metrics to Weights & Biases:

- **Training Metrics**: Loss, KL divergence, policy ratios, reward statistics
- **Evaluation Metrics**: Win rate, Elo difference, game outcomes
- **System Metrics**: Trajectory lengths, padding fractions, gradient norms

## Research Background

GRPO (Group Relative Policy Optimization) is inspired by:
- **PPO (Proximal Policy Optimization)**: Clipped surrogate objective
- **REINFORCE**: Policy gradient methods
- **Self-Play**: Learning through playing against oneself
- **AlphaZero**: Combining deep learning with game tree search

This implementation adapts these ideas specifically for chess, using Stockfish for reward signals and evaluation.

## Technical Highlights

- ✅ **Practical Infrastructure**: Error handling, resource management, logging
- ✅ **Scalable Design**: Efficient batching, parallel trajectory sampling
- ✅ **Extensible**: Modular design allows easy customization
- ✅ **Documented**: Type hints, docstrings, clear structure
- ⚠️ **Status**: This is a research system, not a production-ready chess engine

## Future Enhancements

Potential improvements:
- Value function approximation for better advantage estimates
- More robust entropy and KL control for GRPO
- Multi-GPU training support
- Distributed self-play
- Opening book integration
- Endgame tablebase integration

## License

[Specify your license here]

## Citation

If you use this code in your research, please cite:

```bibtex
@software{grpo_chess,
  title = {GRPO Self-Play Chess Module},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/grpo_chess}
}
```

## Contact

For questions or contributions, please open an issue or contact [your email].
hf_space_repo/__init__.py ADDED
@@ -0,0 +1,30 @@
"""GRPO Self-Play Module for Chess.

This module implements Group Relative Policy Optimization (GRPO) for training
chess policies through self-play. It includes:
- Transformer-based chess policy models
- GRPO training logic with PPO clipping
- Trajectory sampling and reward computation
- Evaluation against Stockfish
"""

__version__ = "0.1.0"

# Main exports
from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer, GRPOConfig
from src.grpo_self_play.grpo_logic.loss import grpo_ppo_loss, GRPOLossInfo
from src.grpo_self_play.evaluator import Evaluator
from src.grpo_self_play.eval_utils import EvalConfig

__all__ = [
    "ChessTransformer",
    "ChessTransformerConfig",
    "GRPOChessTransformer",
    "GRPOConfig",
    "grpo_ppo_loss",
    "GRPOLossInfo",
    "Evaluator",
    "EvalConfig",
]
hf_space_repo/chess/__init__.py ADDED
File without changes
hf_space_repo/chess/boards_dataset.py ADDED
@@ -0,0 +1,465 @@
"""Dataset of random chess boards."""

import chess
import random
import torch
from collections import deque

from typing import Optional, Dict
from dataclasses import dataclass
from torch.utils.data import IterableDataset
from src.grpo_self_play.chess.rewards import evaluate_fen


def generate_random_board(step_num=30):
    """Generate a random board position by making random moves from the starting position.

    Args:
        step_num: Maximum number of random moves to make

    Returns:
        Chess board after random moves
    """
    board = chess.Board()
    random_steps = random.randint(0, step_num)
    for _ in range(random_steps):
        if board.is_game_over():
            break
        board.push(random.choice(list(board.legal_moves)))
    return board


def get_game_phase(board: chess.Board) -> str:
    """Determine the game phase (opening, middlegame, or endgame).

    Args:
        board: Chess board position

    Returns:
        "opening", "middlegame", or "endgame"
    """
    move_count = board.fullmove_number * 2 - (1 if board.turn == chess.BLACK else 0)

    # Count material (excluding kings)
    material_count = sum(
        len(board.pieces(pt, color))
        for pt in [chess.PAWN, chess.ROOK, chess.KNIGHT, chess.BISHOP, chess.QUEEN]
        for color in [chess.WHITE, chess.BLACK]
    )

    # Endgame: few pieces remaining (typically < 12-14 pieces)
    if material_count <= 12:
        return "endgame"
    # Opening: early moves (typically first 15 moves)
    elif move_count <= 15:
        return "opening"
    # Middlegame: everything else
    else:
        return "middlegame"


def evaluate_position_quality(board: chess.Board, depth: int = 2) -> Optional[float]:
    """Quick Stockfish evaluation to check position quality.

    Args:
        board: Chess board position
        depth: Stockfish search depth (shallow for speed)

    Returns:
        Centipawn evaluation from White's perspective, or None if evaluation fails
    """
    try:
        fen = board.fen()
        # Raw centipawns from White's POV (normalize=False) so the value is
        # comparable with the min_eval_cp/max_eval_cp centipawn thresholds below.
        return evaluate_fen(fen, pov_is_white=True, movetime_ms=0, depth=depth, normalize=False)
    except Exception:
        return None


def generate_opening_position(max_moves: int = 15) -> chess.Board:
    """Generate a realistic opening position using common opening moves.

    Args:
        max_moves: Maximum number of opening moves to make

    Returns:
        Chess board in opening phase
    """
    board = chess.Board()

    # Common first moves for White
    first_moves = [
        chess.Move.from_uci("e2e4"),  # King's pawn
        chess.Move.from_uci("d2d4"),  # Queen's pawn
        chess.Move.from_uci("g1f3"),  # King's knight
        chess.Move.from_uci("c2c4"),  # English opening
    ]

    # Make first move
    if first_moves:
        first_move = random.choice(first_moves)
        if first_move in board.legal_moves:
            board.push(first_move)

    # Continue with semi-random play (preferring development moves)
    moves_made = 1
    while moves_made < max_moves and not board.is_game_over():
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            break

        # Prefer piece development over pawn moves in opening
        piece_moves = [m for m in legal_moves if board.piece_at(m.from_square) and
                       board.piece_at(m.from_square).piece_type != chess.PAWN]

        if piece_moves and random.random() < 0.6:  # 60% chance to prefer piece moves
            move = random.choice(piece_moves)
        else:
            move = random.choice(legal_moves)

        board.push(move)
        moves_made += 1

    return board


def generate_middlegame_position(min_moves: int = 15, max_moves: int = 40) -> chess.Board:
    """Generate a middlegame position from a reasonable starting point.

    Args:
        min_moves: Minimum moves to reach middlegame
        max_moves: Maximum moves for middlegame

    Returns:
        Chess board in middlegame phase
    """
    # Start from an opening position
    board = generate_opening_position(max_moves=min_moves)

    # Continue with random play to reach middlegame
    target_moves = random.randint(min_moves, max_moves)
    moves_made = len(board.move_stack)

    while moves_made < target_moves and not board.is_game_over():
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            break
        board.push(random.choice(legal_moves))
        moves_made += 1

    return board


def generate_endgame_position() -> chess.Board:  # TODO: This is not working as expected; it should generate a random endgame position.
    """Generate an endgame position by playing down from a middlegame position.

    Returns:
        Chess board in endgame phase
    """
    # Start with a middlegame position
    board = generate_middlegame_position(min_moves=20, max_moves=35)

    # Pieces cannot simply be deleted from a python-chess Board without risking
    # unreachable positions, so instead keep playing (preferring captures)
    # until the material count drops to endgame levels.

    # Count material
    def count_material(b: chess.Board) -> int:
        return sum(
            len(b.pieces(pt, color))
            for pt in [chess.PAWN, chess.ROOK, chess.KNIGHT, chess.BISHOP, chess.QUEEN]
            for color in [chess.WHITE, chess.BLACK]
        )

    # Play random moves until we reach endgame material (<= 12 pieces)
    max_attempts = 100
    attempts = 0
    while count_material(board) > 12 and attempts < max_attempts and not board.is_game_over():
        legal_moves = list(board.legal_moves)
        if not legal_moves:
            break

        # Prefer captures to reduce material
        captures = [m for m in legal_moves if board.is_capture(m)]
        if captures:
            move = random.choice(captures)
        else:
            move = random.choice(legal_moves)

        board.push(move)
        attempts += 1

    return board


def generate_position_by_phase(phase: str) -> chess.Board:
    """Generate a position for a specific game phase.

    Args:
        phase: "opening", "middlegame", or "endgame"

    Returns:
        Chess board in the specified phase
    """
    if phase == "opening":
        return generate_opening_position()
    elif phase == "middlegame":
        return generate_middlegame_position()
    elif phase == "endgame":
        return generate_endgame_position()
    else:
        raise ValueError(f"Unknown phase: {phase}. Must be 'opening', 'middlegame', or 'endgame'")


def generate_quality_filtered_board(
    step_num: int = 30,
    min_eval_cp: int = -200,
    max_eval_cp: int = 200,
    filter_depth: int = 2,
    max_attempts: int = 50,
    phase: Optional[str] = None
) -> Optional[chess.Board]:
    """Generate a random board position filtered by Stockfish evaluation quality.

    Args:
        step_num: Maximum number of random moves (if phase is None)
        min_eval_cp: Minimum centipawn evaluation to accept
        max_eval_cp: Maximum centipawn evaluation to accept
        filter_depth: Stockfish depth for filtering (shallow for speed)
        max_attempts: Maximum attempts to generate a valid position
        phase: Optional phase to generate ("opening", "middlegame", "endgame")

    Returns:
        Chess board within evaluation range, or None if no valid position found
    """
    for attempt in range(max_attempts):
        # Generate position
        if phase:
            board = generate_position_by_phase(phase)
        else:
            board = generate_random_board(step_num)

        # Skip if game over or no legal moves
        if board.is_game_over() or not list(board.legal_moves):
            continue

        # Evaluate position quality
        eval_cp = evaluate_position_quality(board, depth=filter_depth)
        if eval_cp is None:
            continue

        # Check if evaluation is within acceptable range
        if min_eval_cp <= eval_cp <= max_eval_cp:
            return board

    # If we couldn't find a good position, return a random one anyway
    if phase:
        return generate_position_by_phase(phase)
    else:
        return generate_random_board(step_num)


@dataclass
class ChessDatasetConfig:
    """Configuration for the Chess Start States Dataset.

    Attributes:
        max_steps: Maximum number of positions to generate per epoch
        random_walk_gen_steps: Maximum random moves (legacy, used if phase_distribution is None)
        phase_distribution: Dict mapping phase names to weights, e.g. {"opening": 0.3, "middlegame": 0.5, "endgame": 0.2}
        min_eval_cp: Minimum centipawn evaluation to accept (-200)
        max_eval_cp: Maximum centipawn evaluation to accept (+200)
        use_opening_book: Whether to use opening book moves for opening positions
        stockfish_filter_depth: Stockfish depth for quality filtering (2-4 for speed)
        cache_positions: Whether to cache and reuse high-quality positions
        cache_size: Maximum number of positions to cache
        quality_filter: Whether to filter positions by Stockfish evaluation
    """
    max_steps: int = 10000
    random_walk_gen_steps: int = 30
    phase_distribution: Optional[Dict[str, float]] = None
    min_eval_cp: int = -200
    max_eval_cp: int = 200
    use_opening_book: bool = True
    stockfish_filter_depth: int = 2
    cache_positions: bool = False
    cache_size: int = 1000
    quality_filter: bool = True


class ChessStartStatesDataset(IterableDataset):
    """
    Infinite dataset that yields high-quality FEN strings from diverse game phases.

    Supports quality filtering, phase-aware generation, and position caching.
    """
    def __init__(
        self,
        config: ChessDatasetConfig = ChessDatasetConfig()
    ):
        """
        Initialize dataset with quality filtering and phase diversity options.

        Args:
            config: ChessDatasetConfig object with all configuration parameters.
                Defaults to ChessDatasetConfig() if no config is provided.
                See ChessDatasetConfig for the meaning of each field.
        """
        # All settings come from the config dataclass
        self.max_steps = config.max_steps
        self.random_walk_gen_steps = config.random_walk_gen_steps
        self.phase_distribution = config.phase_distribution
        self.min_eval_cp = config.min_eval_cp
        self.max_eval_cp = config.max_eval_cp
        self.use_opening_book = config.use_opening_book
        self.stockfish_filter_depth = config.stockfish_filter_depth
        self.cache_positions = config.cache_positions
        self.cache_size = config.cache_size
        self.quality_filter = config.quality_filter

        # Normalize phase distribution (only if not None)
        if self.phase_distribution is not None:
            total_weight = sum(self.phase_distribution.values())
            if total_weight > 0:
                self.phase_distribution = {k: v / total_weight for k, v in self.phase_distribution.items()}

        # Position cache
        self._position_cache: deque = deque(maxlen=self.cache_size)
        self._cache_stats = {"hits": 0, "misses": 0, "generated": 0}

        # Statistics tracking
        self._stats = {
            "opening": 0,
            "middlegame": 0,
            "endgame": 0,
            "filtered_out": 0,
            "total_generated": 0,
        }

    def _sample_phase(self) -> str:
        """Sample a game phase according to phase_distribution weights.

        Returns:
            Phase name: "opening", "middlegame", or "endgame"
        """
        rand = random.random()
        cumulative = 0.0
        for phase, weight in self.phase_distribution.items():
            cumulative += weight
            if rand <= cumulative:
                return phase
        # Fallback to middlegame
        return "middlegame"

    def _generate_position(self) -> Optional[chess.Board]:
        """Generate a single position according to configuration.

        Returns:
            Chess board or None if generation fails
        """
        # Check cache first
        if self.cache_positions and self._position_cache:
            if random.random() < 0.3:  # 30% chance to use cached position
                cached_pos = random.choice(self._position_cache)
                self._cache_stats["hits"] += 1
                return chess.Board(cached_pos)
            self._cache_stats["misses"] += 1

        # Determine phase
        if self.phase_distribution:
            phase = self._sample_phase()
        else:
            phase = None

        # Generate position
        if self.quality_filter:
            board = generate_quality_filtered_board(
                step_num=self.random_walk_gen_steps,
                min_eval_cp=self.min_eval_cp,
                max_eval_cp=self.max_eval_cp,
                filter_depth=self.stockfish_filter_depth,
                phase=phase
            )
        else:
            if phase:
                board = generate_position_by_phase(phase)
            else:
                board = generate_random_board(self.random_walk_gen_steps)

        if board is None:
            return None

        # Update statistics
        if not board.is_game_over():
            actual_phase = get_game_phase(board)
            self._stats[actual_phase] = self._stats.get(actual_phase, 0) + 1
            self._stats["total_generated"] += 1

        # Cache position if enabled
        if self.cache_positions:
            self._position_cache.append(board.fen())
            self._cache_stats["generated"] += 1

        return board

    def get_stats(self) -> Dict:
        """Get statistics about generated positions.

        Returns:
            Dictionary with statistics
        """
        stats = self._stats.copy()
        if self.cache_positions:
            stats["cache"] = self._cache_stats.copy()
            stats["cache"]["size"] = len(self._position_cache)
        return stats

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        # Determine how many steps this worker should generate
        if worker_info is not None:
            # Split work among workers
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            per_worker = self.max_steps // num_workers
            # Give remainder to the last worker
            if worker_id == num_workers - 1:
                per_worker += self.max_steps % num_workers

            # Set deterministic seed per worker for reproducibility and isolation
            worker_seed = 42 + worker_id * 1000
            random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            steps_to_generate = per_worker
        else:
            # Single process mode
            steps_to_generate = self.max_steps

        # Generate positions for this worker's share
        for step in range(steps_to_generate):
            board = self._generate_position()
            if board is not None and not board.is_game_over():
                yield board.fen()
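A usage sketch for this dataset (phase weights and DataLoader settings are illustrative; quality filtering requires a working Stockfish install):

```python
from torch.utils.data import DataLoader
from src.grpo_self_play.chess.boards_dataset import (ChessDatasetConfig,
                                                     ChessStartStatesDataset)

cfg = ChessDatasetConfig(
    max_steps=1000,
    phase_distribution={"opening": 0.3, "middlegame": 0.5, "endgame": 0.2},
    quality_filter=True,
)
loader = DataLoader(ChessStartStatesDataset(cfg), batch_size=32)

for fens in loader:  # each batch is a sequence of FEN strings
    print(fens[0])
    break
```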
hf_space_repo/chess/chess_logic.py ADDED
@@ -0,0 +1,63 @@
import chess
import torch

from typing import Optional
from src.grpo_self_play.searchless_chess_imports import (MOVE_TO_ACTION,
                                                         ACTION_TO_MOVE,
                                                         tokenize as deepmind_tokenize)

MAX_ACTION = max(ACTION_TO_MOVE.keys())


def board_to_tensor(board, device: str | torch.device = 'cpu') -> torch.Tensor:
    fen = board.fen()
    token_ids = list(deepmind_tokenize(fen))  # Returns list of ints
    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    return input_tensor


def get_legal_moves_indices(board):
    legal_moves = list(board.legal_moves)
    legal_indices = []
    for move in legal_moves:
        # move.uci() returns "e2e4" or "a7a8q", which matches the MOVE_TO_ACTION keys
        uci_str = move.uci()
        if uci_str in MOVE_TO_ACTION:
            legal_indices.append(MOVE_TO_ACTION[uci_str])
        else:
            # Fallback: unlikely if MOVE_TO_ACTION is complete
            raise ValueError(f"Invalid move: {uci_str}")
    return legal_indices


def get_legal_moves_mask(board, device: str | torch.device = 'cpu') -> torch.Tensor:
    legal_moves = list(board.legal_moves)
    mask = torch.zeros(MAX_ACTION + 1, dtype=torch.bool)
    for move in legal_moves:
        uci_str = move.uci()
        action_idx = MOVE_TO_ACTION.get(uci_str)
        if action_idx is not None:
            mask[action_idx] = True
    return mask.to(device)


def action_to_move(board: chess.Board, action_idx: int):
    uci = ACTION_TO_MOVE.get(action_idx)
    if uci is None:
        return None
    try:
        mv = chess.Move.from_uci(uci)
    except ValueError:
        return None
    return mv if mv in board.legal_moves else None


class ChessPlayer:
    """
    An abstract chess player interface.
    """
    def act(self, board: chess.Board) -> Optional[chess.Move]:
        """
        Given a chess.Board, return a chess.Move or None to resign.
        """
        raise NotImplementedError()
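A quick illustration of how these helpers combine to mask a policy's output (the random logits below are a stand-in for real model output):

```python
import chess
import torch
from src.grpo_self_play.chess.chess_logic import MAX_ACTION, get_legal_moves_mask

board = chess.Board()
mask = get_legal_moves_mask(board)                 # [MAX_ACTION + 1] boolean mask
logits = torch.randn(MAX_ACTION + 1)               # stand-in for model output
masked = logits.masked_fill(~mask, float("-inf"))  # illegal moves -> -inf
probs = torch.softmax(masked, dim=-1)              # zero mass on illegal moves
print(int(probs.argmax()))                         # index of the most likely legal move
```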
hf_space_repo/chess/policy_player.py ADDED
@@ -0,0 +1,98 @@

import random
import torch
import torch.nn.functional as F
from src.grpo_self_play.chess.chess_logic import (board_to_tensor,
                                                  get_legal_moves_indices,
                                                  action_to_move,
                                                  ChessPlayer)

from dataclasses import dataclass


@dataclass
class PolicyConfig:
    temperature: float = 1.0
    greedy: bool = False       # if True, pick argmax among legal moves
    branching_factor: int = 4  # for search; 0 = no limit
    search_depth: int = 2      # for search; 0 = no search


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([PolicyConfig])


class PolicyPlayer(ChessPlayer):
    def __init__(self, model, device=None, cfg=PolicyConfig()):
        self.model = model.eval()
        self.device = device or next(model.parameters()).device
        self.cfg = cfg
        self.stats = {"no_legal_idxs": 0, "mapping_failed": 0, "random_fallback": 0}

    @torch.no_grad()
    def act(self, board):
        legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            self.stats["no_legal_idxs"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        return self.sample_move(board, legal_moves_indices)

    @torch.no_grad()
    def sample_move(self, board, legal_moves_indices=None):
        if legal_moves_indices is None:
            legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            self.stats["no_legal_idxs"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        board_tensor = board_to_tensor(board, self.device)
        logits = self.model(board_tensor)  # [1, A]

        A = logits.size(-1)
        masked = torch.full(
            (A,),
            -float("inf"),
            device=self.device,
            dtype=logits.dtype,
        )
        li = torch.tensor(legal_moves_indices, device=self.device, dtype=torch.long)
        masked[li] = logits[0, li]

        if self.cfg.greedy:
            action_idx = int(torch.argmax(masked).item())
        else:
            temp = max(1e-6, self.cfg.temperature)
            probs = F.softmax(masked / temp, dim=-1)
            action_idx = int(torch.multinomial(probs, 1).item())
        move = action_to_move(board, action_idx)
        if move is None:
            self.stats["mapping_failed"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        return move

    @torch.no_grad()
    def eval_board(self, board, root_color):
        board_tensor = board_to_tensor(board, self.device)
        legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            # no moves -> treat via game result if available
            outcome = board.outcome()
            if outcome is not None:
                if outcome.winner is None:
                    return 0.0
                return 1.0 if outcome.winner == root_color else -1.0

        logits = self.model(board_tensor)  # [1, A]
        A = logits.size(-1)
        masked = torch.full(
            (A,),
            -float("inf"),
            device=self.device,
            dtype=logits.dtype,
        )
        li = torch.tensor(legal_moves_indices, device=self.device, dtype=torch.long)
        masked[li] = logits[-1, li]
        best_logit = float(torch.max(F.tanh(masked)).item())
        return best_logit if board.turn == root_color else -best_logit
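A usage sketch (assumes a trained policy network is already bound to `model`; the temperature value is illustrative):

```python
import chess
from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig

player = PolicyPlayer(model, cfg=PolicyConfig(temperature=0.7))
board = chess.Board()
move = player.act(board)  # a legal chess.Move sampled from the masked policy
board.push(move)
print(board.fen(), player.stats)
```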
hf_space_repo/chess/rewards.py ADDED
@@ -0,0 +1,108 @@
import os
import chess
import chess.engine

from functools import lru_cache
from src.grpo_self_play.chess.stockfish import stockfish_analyse, DEFAULT_STOCKFISH_TIMEOUT

# Engine name for reward evaluation
REWARD_ENGINE_NAME = f"reward_engine_{os.getpid()}"


def _get_reward_engine_name() -> str:
    """Get process-specific engine name for reward evaluation."""
    return f"reward_engine_{os.getpid()}"


def _raw_white_reward(fen: str, movetime_ms: int, depth: int, timeout: float = DEFAULT_STOCKFISH_TIMEOUT) -> float:
    """Get raw centipawn evaluation from White's perspective using the centralized wrapper."""
    if depth and depth > 0:
        limit = chess.engine.Limit(depth=depth)
    else:
        limit = chess.engine.Limit(time=movetime_ms / 1000.0)

    info = stockfish_analyse(_get_reward_engine_name(), chess.Board(fen), limit, timeout=timeout)

    if info is None:
        return 0.0  # Fallback on engine failure

    score = info["score"].pov(chess.WHITE)
    if score.is_mate():
        return 10000.0 if score.mate() > 0 else -10000.0
    return float(score.score())


@lru_cache(maxsize=50_000)
def cached_raw_reward_white(fen: str, depth: int) -> float:
    """
    Cached Stockfish raw eval for a given FEN from White's POV.
    Returns centipawn score (positive = White is better).
    Caches by depth only, not movetime, since movetime-limited search is not deterministic.
    """
    return _raw_white_reward(fen, movetime_ms=10, depth=depth)


def normalize_cp(raw_cp: float) -> float:
    """Normalize raw centipawn score to [-2, 2] using linear clipping."""
    return float(max(-2.0, min(2.0, raw_cp / 1000.0)))


def evaluate_fen(fen: str, pov_is_white: bool, movetime_ms: int, depth: int, normalize: bool = True):
    """
    Cached Stockfish eval for a given FEN and settings.
    Returns a normalized reward in [-2, 2] (or raw centipawns if normalize=False).
    """
    if depth and depth > 0:
        raw_score = cached_raw_reward_white(fen, depth)
    else:
        raw_score = _raw_white_reward(fen, movetime_ms, depth)

    if not pov_is_white:  # Flip sign for black POV
        raw_score = -raw_score
    # Normalize raw score using linear clipping instead of tanh.
    # Linear clipping preserves gradient signal regardless of position evaluation;
    # tanh was compressing differentials at higher absolute values.
    if normalize:
        return normalize_cp(raw_score)
    else:
        return raw_score


def evaluate_board(board: chess.Board, pov_is_white: bool, depth: int = 16, normalize: bool = True) -> float:
    """
    Evaluate a board position from a given POV.
    Returns normalized reward in [-2, 2] or raw centipawns if normalize=False.
    """
    if board.is_game_over(claim_draw=True):
        if board.is_checkmate():
            pov_loses = (board.turn == (chess.WHITE if pov_is_white else chess.BLACK))
            raw = -10000.0 if pov_loses else 10000.0
        else:
            raw = 0.0  # Draw
        return normalize_cp(raw) if normalize else raw
    else:
        return evaluate_fen(board.fen(), pov_is_white, movetime_ms=0, depth=depth, normalize=normalize)


def reward_board(env: chess.Board, board_start: chess.Board, movetime_ms: int = 0, depth: int = 16) -> float:
    """
    Stockfish-based reward from the perspective of board_start.turn.

    env: current board (python-chess Board)
    board_start: board at trajectory start (used for POV)
    """
    pov_is_white = (board_start.turn == chess.WHITE)
    if env.is_game_over(claim_draw=True):  # Terminal state
        if env.is_checkmate():
            pov_loses = (env.turn == (chess.WHITE if pov_is_white else chess.BLACK))
            r_t = -1.0 if pov_loses else 1.0
        else:
            r_t = 0.0  # Draw
    else:
        fen_t = env.fen()
        r_t = evaluate_fen(fen_t, pov_is_white, movetime_ms, depth)

    fen_0 = board_start.fen()
    r_0 = evaluate_fen(fen_0, pov_is_white, movetime_ms, depth)
    return r_t - r_0  # Reward is the change in eval
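A usage sketch of this potential-difference reward (requires a local Stockfish binary; the depth is chosen for speed):

```python
import chess
from src.grpo_self_play.chess.rewards import reward_board

start = chess.Board()
env = start.copy()
env.push(chess.Move.from_uci("e2e4"))

# Reward = eval(current) - eval(start), from the starting player's POV
print(reward_board(env, start, depth=8))
```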
hf_space_repo/chess/searcher.py ADDED
@@ -0,0 +1,90 @@
'''
Implement search method to choose moves based on a policy network.
'''

import chess
import torch

from typing import Optional
from dataclasses import dataclass
from src.grpo_self_play.chess.chess_logic import ChessPlayer
from src.grpo_self_play.chess.policy_player import PolicyPlayer


@dataclass
class SearchConfig:
    n_trajectories: int = 1    # G: number of sampled trajectories
    trajectory_depth: int = 1  # T: max plies per trajectory


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([SearchConfig])


class TrajectorySearcher(ChessPlayer):
    """
    Searcher that uses a PolicyPlayer to:
    - sample trajectories using the policy
    - evaluate their final states using the policy
    and picks the first move of the best-scoring trajectory.
    """

    def __init__(self, policy: PolicyPlayer, cfg: SearchConfig = SearchConfig()):
        self.policy = policy
        self.cfg = cfg

    @torch.no_grad()
    def act(self, board: chess.Board) -> Optional[chess.Move]:
        """
        If n_trajectories or trajectory_depth is <= 1:
            Just use the policy's one-step act() (no search).

        Otherwise:
            Sample G trajectories, score each by final state,
            pick the first move of the best trajectory.
        """
        if self.cfg.n_trajectories <= 1 or self.cfg.trajectory_depth <= 1:
            return self.policy.act(board)

        root_color = board.turn
        best_score = -float("inf")
        best_first_move = None

        for g in range(self.cfg.n_trajectories):
            rollout_board = board.copy()

            first_move = None
            for step in range(self.cfg.trajectory_depth):
                if rollout_board.is_game_over():
                    break

                mv = self.policy.sample_move(rollout_board)
                if mv is None:
                    # no move available -> end trajectory
                    break

                if first_move is None:
                    first_move = mv

                rollout_board.push(mv)

            if first_move is None:
                # This trajectory failed to get any move (should be rare)
                continue

            score = self.policy.eval_board(rollout_board, root_color)

            if score > best_score:
                best_score = score
                best_first_move = first_move

        if best_first_move is None:
            # Fallback to simple 1-step policy
            return self.policy.act(board)

        return best_first_move

    @property
    def stats(self) -> dict:
        return self.policy.stats
hf_space_repo/chess/stockfish.py ADDED
@@ -0,0 +1,288 @@
1
+ import os
2
+ import threading
3
+ import chess
4
+ import chess.engine
5
+ import torch
6
+ from typing import Optional
7
+ from dataclasses import dataclass
8
+ from concurrent.futures import TimeoutError as FuturesTimeoutError
9
+ from src.grpo_self_play.chess.chess_logic import ChessPlayer
10
+ from src.grpo_self_play.logging_utils import get_logger
11
+
12
+ logger = get_logger("grpo_chess.stockfish")
13
+
14
+ DEFAULT_STOCKFISH_PATH = "/usr/games/stockfish"
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class StockfishConfig:
19
+ path: str = DEFAULT_STOCKFISH_PATH
20
+ skill_level: int = 20
21
+ use_elo_limit: bool = False
22
+ elo: int = 2500
23
+ movetime_ms: int = 50
24
+ threads: int = 1
25
+ hash_mb: int = 128
26
+
27
+
28
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
29
+ torch.serialization.add_safe_globals([StockfishConfig])
30
+
31
+
32
+ class StockfishManager:
33
+ '''
34
+ Manage stockfish engine instances by name for player, eval and reward engines.
35
+ For example, we will use several engines at different levels for evaluation,
36
+ or limit the reward engine by move time.
37
+ '''
38
+ _pid: int = os.getpid()
39
+ _engines: dict[str, chess.engine.SimpleEngine] = {}
40
+ _cfgs: dict[str, StockfishConfig] = {}
41
+ _locks: dict[str, threading.Lock] = {} # Per-engine locks for thread safety
42
+ _manager_lock: threading.Lock = threading.Lock() # Lock for managing _engines/_locks dicts
43
+
44
+
45
+ @classmethod
46
+ def ensure_pid(cls) -> None:
47
+ pid = os.getpid()
48
+ if pid != cls._pid:
49
+ # We are in a forked/spawned child; discard inherited engine handles.
50
+ # This is a workaround to avoid issues with multiprocessing.
51
+ cls._pid = pid
52
+ cls._engines = {}
53
+ cls._cfgs = {}
54
+ cls._locks = {}
55
+ cls._manager_lock = threading.Lock()
56
+
57
+ @classmethod
58
+ def _configure_engine(cls, engine: chess.engine.SimpleEngine, cfg: StockfishConfig) -> None:
59
+ try:
60
+ engine.configure({"Threads": cfg.threads})
61
+ except Exception:
62
+ logger.warning("Failed to set Stockfish threads")
63
+
64
+ try:
65
+ engine.configure({"Hash": cfg.hash_mb})
66
+ except Exception:
67
+ logger.warning("Failed to set Stockfish hash size")
68
+
69
+ try:
70
+ engine.configure({"Skill Level": cfg.skill_level})
71
+ except Exception:
72
+ logger.warning("Failed to set Stockfish skill level")
73
+
74
+ if cfg.use_elo_limit:
75
+ try:
76
+ engine.configure({
77
+ "UCI_LimitStrength": True,
78
+ "UCI_Elo": cfg.elo,
79
+ })
80
+ except Exception:
81
+ logger.warning("Failed to set Stockfish ELO limit")
82
+
83
+
84
+ @classmethod
85
+ def is_name_registered(cls, name: str) -> bool:
86
+ return name in cls._engines
87
+
88
+ @classmethod
89
+ def get_lock(cls, name: str) -> threading.Lock:
90
+ """Get the lock for a named engine (creates if needed)."""
91
+ with cls._manager_lock:
92
+ if name not in cls._locks:
93
+ cls._locks[name] = threading.Lock()
94
+ return cls._locks[name]
95
+
96
+ @classmethod
97
+ def get_engine(cls, name: str, cfg: StockfishConfig | None = None) -> chess.engine.SimpleEngine:
98
+ """
99
+ Get (or create) a named engine instance.
100
+ - name: e.g. "reward", "player"
101
+ - cfg: config to use when creating it (ignored on later calls).
102
+ """
103
+ cls.ensure_pid() # Check if we are in a forked/spawned child and discard inherited engine handles.
104
+ with cls._manager_lock:
105
+ if not cls.is_name_registered(name):
106
+ if cfg is None:
107
+ cfg = StockfishConfig()
108
+ engine = chess.engine.SimpleEngine.popen_uci(cfg.path)
109
+ cls._configure_engine(engine, cfg)
110
+ cls._engines[name] = engine
111
+ cls._cfgs[name] = cfg
112
+ cls._locks[name] = threading.Lock()
113
+ return cls._engines[name]
114
+
115
+
116
+ @classmethod
117
+ def close(cls, name: str) -> None:
118
+ with cls._manager_lock:
119
+ engine = cls._engines.get(name)
120
+ if engine is not None:
121
+ try:
122
+ engine.quit()
123
+ except Exception:
124
+ logger.warning(f"Failed to close Stockfish engine '{name}'")
125
+ finally:
126
+ cls._engines.pop(name, None)
127
+ cls._cfgs.pop(name, None)
128
+ cls._locks.pop(name, None)
129
+
130
+
131
+ @classmethod
132
+ def close_all(cls) -> None:
133
+ for name in list(cls._engines.keys()):
134
+ cls.close(name)
135
+
136
+
137
+
138
+ # Default timeout for Stockfish operations (seconds)
139
+ DEFAULT_STOCKFISH_TIMEOUT = 10.0
140
+
141
+
142
+ def run_with_timeout(func, timeout: float, *args, **kwargs):
143
+ """Run a function with a timeout.
144
+
145
+ Uses a single threading.Thread + join(timeout) instead of ThreadPoolExecutor
146
+ so that this works correctly in forked child processes (ProcessPoolExecutor
147
+ with fork). ThreadPoolExecutor can deadlock in forked workers due to
148
+ inherited lock state.
149
+
150
+ Args:
151
+ func: Function to call
152
+ timeout: Maximum time to wait (seconds)
153
+ *args, **kwargs: Arguments to pass to func
154
+
155
+ Returns:
156
+ Result of func
157
+
158
+ Raises:
159
+ FuturesTimeoutError: If the function doesn't complete within timeout
160
+ """
161
+ result_holder: list = []
162
+ exc_holder: list = []
163
+
164
+ def target() -> None:
165
+ try:
166
+ out = func(*args, **kwargs)
167
+ result_holder.append(out)
168
+ except BaseException as e:
169
+ exc_holder.append(e)
170
+
171
+ t = threading.Thread(target=target, daemon=True)
172
+ t.start()
173
+ t.join(timeout=timeout)
174
+ if t.is_alive():
175
+ raise FuturesTimeoutError()
176
+ if exc_holder:
177
+ raise exc_holder[0]
178
+ return result_holder[0]
179
+
180
+
181
+ def stockfish_analyse(
182
+ engine_name: str,
183
+ board: chess.Board,
184
+ limit: chess.engine.Limit,
185
+ timeout: float = DEFAULT_STOCKFISH_TIMEOUT,
186
+ cfg: StockfishConfig | None = None,
187
+ attempts_n: int = 2
188
+ ) -> Optional[chess.engine.InfoDict]:
189
+ """Analyse a position with Stockfish, with timeout and crash recovery.
190
+
191
+ Args:
192
+ engine_name: Name of the engine instance to use
193
+ board: Chess board position to analyse
194
+ limit: Search limit (depth, time, etc.)
195
+ timeout: Maximum time to wait for response (seconds)
196
+ cfg: Optional config for engine creation
197
+ attempts_n: number of attempts to make before giving up
198
+
199
+ Returns:
200
+ Analysis info dict, or None if analysis failed
201
+ """
202
+ for attempt in range(attempts_n):
203
+ try:
204
+ engine = StockfishManager.get_engine(engine_name, cfg)
205
+ lock = StockfishManager.get_lock(engine_name)
206
+ with lock:
207
+ return run_with_timeout(engine.analyse, timeout, board, limit)
208
+ except chess.engine.EngineTerminatedError:
209
+ logger.error(f"Stockfish engine '{engine_name}' terminated unexpectedly, recreating...")
210
+ StockfishManager.close(engine_name)
211
+ if attempt == attempts_n - 1: # give up after the final attempt
212
+ return None
213
+ except FuturesTimeoutError:
214
+ logger.warning(f"Stockfish analyse timed out after {timeout}s for engine '{engine_name}'")
215
+ return None
216
+ except Exception as e:
217
+ logger.error(f"Stockfish analyse error: {e}")
218
+ return None
219
+ return None
220
+
221
+
222
+ def stockfish_play(
223
+ engine_name: str,
224
+ board: chess.Board,
225
+ limit: chess.engine.Limit,
226
+ timeout: float = DEFAULT_STOCKFISH_TIMEOUT,
227
+ cfg: StockfishConfig | None = None,
228
+ ) -> Optional[chess.Move]:
229
+ """Get best move from Stockfish, with timeout and crash recovery.
230
+
231
+ Args:
232
+ engine_name: Name of the engine instance to use
233
+ board: Chess board position
234
+ limit: Search limit (depth, time, etc.)
235
+ timeout: Maximum time to wait for response (seconds)
236
+ cfg: Optional config for engine creation
237
+
238
+ Returns:
239
+ Best move, or None if engine failed
240
+ """
241
+ if board.is_game_over():
242
+ return None
243
+
244
+ for attempt in range(2):
245
+ try:
246
+ engine = StockfishManager.get_engine(engine_name, cfg)
247
+ lock = StockfishManager.get_lock(engine_name)
248
+ with lock:
249
+ result = run_with_timeout(engine.play, timeout, board, limit)
250
+ return result.move
251
+ except chess.engine.EngineTerminatedError:
252
+ logger.error(f"Stockfish engine '{engine_name}' terminated unexpectedly, recreating...")
253
+ StockfishManager.close(engine_name)
254
+ if attempt == 1:
255
+ return None
256
+ except FuturesTimeoutError:
257
+ logger.warning(f"Stockfish play timed out after {timeout}s for engine '{engine_name}'")
258
+ return None
259
+ except Exception as e:
260
+ logger.error(f"Stockfish play error: {e}")
261
+ return None
262
+ return None
263
+
264
+
265
+ class StockfishPlayer(ChessPlayer):
266
+ '''
267
+ A chess player that uses the Stockfish engine to select moves.
268
+ '''
269
+
270
+ DEFAULT_PLAYER_ENGINE_NAME = "player_engine"
271
+
272
+ def __init__(self, cfg: StockfishConfig, engine_name: Optional[str] = None):
273
+ if engine_name is None:
274
+ engine_name = self.DEFAULT_PLAYER_ENGINE_NAME
275
+ self.engine_name = engine_name
276
+ self.cfg = cfg
277
+ self.engine = StockfishManager.get_engine(self.engine_name, cfg)
278
+
279
+
280
+ def close(self):
281
+ try:
282
+ StockfishManager.close(self.engine_name)
283
+ except Exception:
284
+ logger.warning("Failed to close Stockfish engine in StockfishPlayer")
285
+
286
+ def act(self, board: chess.Board) -> chess.Move | None:
287
+ limit = chess.engine.Limit(time=self.cfg.movetime_ms / 1000.0)
288
+ return stockfish_play(self.engine_name, board, limit, cfg=self.cfg)
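
A hedged sketch of how the manager and the analyse helper compose; it assumes a Stockfish binary exists at the configured path:

import chess
import chess.engine
from src.grpo_self_play.chess.stockfish import (
    StockfishConfig, StockfishManager, stockfish_analyse,
)

cfg = StockfishConfig(path="/usr/games/stockfish", movetime_ms=50)
board = chess.Board()

# Named engines are created lazily on first use and reused on later calls.
info = stockfish_analyse("reward", board, chess.engine.Limit(depth=8), cfg=cfg)
if info is not None:
    print(info["score"])  # e.g. PovScore(Cp(+20), WHITE)

StockfishManager.close_all()  # shut all engines down when finished
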
hf_space_repo/configs/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """
2
+ Config module for GRPO Chess experiments.
3
+
4
+ Provides YAML-based configuration loading with override support.
5
+
6
+ Usage:
7
+ from src.grpo_self_play.configs import load_experiment_config
8
+
9
+ # Load default config
10
+ config = load_experiment_config()
11
+
12
+ # Load with overrides
13
+ config = load_experiment_config("default.yaml", overrides={
14
+ "grpo": {"lr": 1e-4},
15
+ "training": {"num_epochs": 100},
16
+ })
17
+ """
18
+
19
+ from src.grpo_self_play.configs.config_loader import (
20
+ ExperimentConfig,
21
+ TrainingConfig,
22
+ load_experiment_config,
23
+ load_grpo_config,
24
+ load_transformer_config,
25
+ load_eval_config,
26
+ load_stockfish_config,
27
+ load_dataset_config,
28
+ list_available_configs,
29
+ print_config_summary,
30
+ )
31
+
32
+ __all__ = [
33
+ "ExperimentConfig",
34
+ "TrainingConfig",
35
+ "load_experiment_config",
36
+ "load_grpo_config",
37
+ "load_transformer_config",
38
+ "load_eval_config",
39
+ "load_stockfish_config",
40
+ "load_dataset_config",
41
+ "list_available_configs",
42
+ "print_config_summary",
43
+ ]
hf_space_repo/configs/config_loader.py ADDED
@@ -0,0 +1,290 @@
1
+ """
2
+ Config loader for GRPO Chess experiments.
3
+
4
+ This module provides utilities to load experiment configurations from YAML files
5
+ and convert them to the appropriate dataclass objects.
6
+
7
+ Usage:
8
+ from src.grpo_self_play.configs.config_loader import load_experiment_config
9
+
10
+ # Load a complete experiment config
11
+ config = load_experiment_config("default.yaml")
12
+
13
+ # Load with overrides
14
+ config = load_experiment_config("default.yaml", overrides={
15
+ "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
16
+ "training": {"num_epochs": 100},
17
+ })
18
+
19
+ # Access configs
20
+ grpo_config = config.grpo
21
+ transformer_config = config.transformer
22
+ """
23
+
24
+ from dataclasses import dataclass, fields
25
+ from pathlib import Path
26
+ from typing import Any, Optional, TypeVar, Type
27
+ import yaml
28
+
29
+ # Import all config dataclasses
30
+ from src.grpo_self_play.grpo_logic.model import GRPOConfig
31
+ from src.grpo_self_play.models import ChessTransformerConfig
32
+ from src.grpo_self_play.eval_utils import EvalConfig
33
+ from src.grpo_self_play.chess.stockfish import StockfishConfig
34
+ from src.grpo_self_play.chess.policy_player import PolicyConfig
35
+ from src.grpo_self_play.chess.searcher import SearchConfig
36
+ from src.grpo_self_play.chess.boards_dataset import ChessDatasetConfig
37
+ from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig
38
+
39
+
40
+ # Directory containing config YAML files
41
+ CONFIGS_DIR = Path(__file__).parent
42
+
43
+
44
+ @dataclass
45
+ class TrainingConfig:
46
+ """Training loop configuration."""
47
+ num_epochs: int = 400
48
+ batch_size: int = 32
49
+ steps_per_epoch: int = 512
50
+ checkpoint_every_n_epochs: int = 5
51
+ keep_n_checkpoints: int = 3
52
+
53
+
54
+ @dataclass
55
+ class ExperimentConfig:
56
+ """Complete experiment configuration containing all sub-configs."""
57
+ training: TrainingConfig
58
+ grpo: GRPOConfig
59
+ transformer: ChessTransformerConfig
60
+ eval: EvalConfig
61
+ stockfish: StockfishConfig
62
+ policy: PolicyConfig
63
+ searcher: Optional[SearchConfig]
64
+ dataset: ChessDatasetConfig
65
+ pretrain: PretrainLoadConfig
66
+
67
+
68
+ T = TypeVar('T')
69
+
70
+
71
+ def _deep_merge(base: dict, overrides: dict) -> dict:
72
+ """Deep merge two dictionaries, with overrides taking precedence.
73
+
74
+ Args:
75
+ base: Base dictionary
76
+ overrides: Dictionary with values to override
77
+
78
+ Returns:
79
+ Merged dictionary
80
+ """
81
+ result = base.copy()
82
+ for key, value in overrides.items():
83
+ if key in result and isinstance(result[key], dict) and isinstance(value, dict):
84
+ result[key] = _deep_merge(result[key], value)
85
+ else:
86
+ result[key] = value
87
+ return result
88
+
89
+
90
+ def dict_to_dataclass(cls: Type[T], data: dict[str, Any]) -> T:
91
+ """Convert a dictionary to a dataclass, ignoring extra keys.
92
+
93
+ Args:
94
+ cls: The dataclass type to instantiate
95
+ data: Dictionary with field values
96
+
97
+ Returns:
98
+ Instance of the dataclass with values from data, or None if data is None
99
+ """
100
+ if data is None:
101
+ return None
102
+
103
+ # Get valid field names for this dataclass
104
+ valid_fields = {f.name for f in fields(cls)}
105
+
106
+ # Filter to only include valid fields
107
+ filtered_data = {k: v for k, v in data.items() if k in valid_fields}
108
+
109
+ return cls(**filtered_data)
110
+
111
+
112
+ def load_yaml_file(path: str | Path) -> dict[str, Any]:
113
+ """Load a YAML config file.
114
+
115
+ Args:
116
+ path: Path to the YAML file (absolute or relative to configs dir)
117
+
118
+ Returns:
119
+ Dictionary containing the parsed YAML
120
+ """
121
+ path = Path(path)
122
+
123
+ # If not absolute, look in configs directory
124
+ if not path.is_absolute():
125
+ path = CONFIGS_DIR / path
126
+
127
+ if not path.exists():
128
+ raise FileNotFoundError(f"Config file not found: {path}")
129
+
130
+ with open(path, 'r') as f:
131
+ return yaml.safe_load(f)
132
+
133
+
134
+ def load_experiment_config(
135
+ path: str | Path = "default.yaml",
136
+ overrides: dict[str, dict[str, Any]] | None = None
137
+ ) -> ExperimentConfig:
138
+ """Load a complete experiment configuration from a YAML file.
139
+
140
+ Args:
141
+ path: Path to the YAML file (absolute or relative to configs dir)
142
+ overrides: Optional dict of overrides per section. Example:
143
+ {
144
+ "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
145
+ "training": {"num_epochs": 100},
146
+ "stockfish": {"skill_level": 5},
147
+ }
148
+
149
+ Returns:
150
+ ExperimentConfig containing all sub-configs
151
+ """
152
+ data = load_yaml_file(path)
153
+
154
+ # Apply overrides if provided
155
+ if overrides:
156
+ data = _deep_merge(data, overrides)
157
+
158
+ # Convert each section to its dataclass
159
+ training = dict_to_dataclass(TrainingConfig, data.get('training', {}))
160
+ grpo = dict_to_dataclass(GRPOConfig, data.get('grpo', {}))
161
+ transformer = dict_to_dataclass(ChessTransformerConfig, data.get('transformer', {}))
162
+ eval_cfg = dict_to_dataclass(EvalConfig, data.get('eval', {}))
163
+ stockfish = dict_to_dataclass(StockfishConfig, data.get('stockfish', {}))
164
+ policy = dict_to_dataclass(PolicyConfig, data.get('policy', {}))
165
+ dataset = dict_to_dataclass(ChessDatasetConfig, data.get('dataset', {}))
166
+ pretrain = dict_to_dataclass(PretrainLoadConfig, data.get('pretrain', {}))
167
+
168
+ # Searcher is optional (can be null)
169
+ searcher_data = data.get('searcher')
170
+ searcher = dict_to_dataclass(SearchConfig, searcher_data) if searcher_data else None
171
+
172
+ return ExperimentConfig(
173
+ training=training,
174
+ grpo=grpo,
175
+ transformer=transformer,
176
+ eval=eval_cfg,
177
+ stockfish=stockfish,
178
+ policy=policy,
179
+ searcher=searcher,
180
+ dataset=dataset,
181
+ pretrain=pretrain,
182
+ )
183
+
184
+
185
+ def load_grpo_config(
186
+ path: str | Path = "default.yaml",
187
+ overrides: dict[str, Any] | None = None
188
+ ) -> GRPOConfig:
189
+ """Load just the GRPO config from a YAML file.
190
+
191
+ Args:
192
+ path: Path to the YAML file
193
+ overrides: Optional dict of field overrides. Example: {"lr": 1e-4}
194
+ """
195
+ data = load_yaml_file(path)
196
+ grpo_data = data.get('grpo', {})
197
+ if overrides:
198
+ grpo_data = _deep_merge(grpo_data, overrides)
199
+ return dict_to_dataclass(GRPOConfig, grpo_data)
200
+
201
+
202
+ def load_transformer_config(
203
+ path: str | Path = "default.yaml",
204
+ overrides: dict[str, Any] | None = None
205
+ ) -> ChessTransformerConfig:
206
+ """Load just the transformer config from a YAML file."""
207
+ data = load_yaml_file(path)
208
+ cfg_data = data.get('transformer', {})
209
+ if overrides:
210
+ cfg_data = _deep_merge(cfg_data, overrides)
211
+ return dict_to_dataclass(ChessTransformerConfig, cfg_data)
212
+
213
+
214
+ def load_eval_config(
215
+ path: str | Path = "default.yaml",
216
+ overrides: dict[str, Any] | None = None
217
+ ) -> EvalConfig:
218
+ """Load just the eval config from a YAML file."""
219
+ data = load_yaml_file(path)
220
+ cfg_data = data.get('eval', {})
221
+ if overrides:
222
+ cfg_data = _deep_merge(cfg_data, overrides)
223
+ return dict_to_dataclass(EvalConfig, cfg_data)
224
+
225
+
226
+ def load_stockfish_config(
227
+ path: str | Path = "default.yaml",
228
+ overrides: dict[str, Any] | None = None
229
+ ) -> StockfishConfig:
230
+ """Load just the stockfish config from a YAML file."""
231
+ data = load_yaml_file(path)
232
+ cfg_data = data.get('stockfish', {})
233
+ if overrides:
234
+ cfg_data = _deep_merge(cfg_data, overrides)
235
+ return dict_to_dataclass(StockfishConfig, cfg_data)
236
+
237
+
238
+ def load_dataset_config(
239
+ path: str | Path = "default.yaml",
240
+ overrides: dict[str, Any] | None = None
241
+ ) -> ChessDatasetConfig:
242
+ """Load just the dataset config from a YAML file."""
243
+ data = load_yaml_file(path)
244
+ cfg_data = data.get('dataset', {})
245
+ if overrides:
246
+ cfg_data = _deep_merge(cfg_data, overrides)
247
+ return dict_to_dataclass(ChessDatasetConfig, cfg_data)
248
+
249
+
250
+ def list_available_configs() -> list[str]:
251
+ """List all available YAML config files in the configs directory."""
252
+ return [f.name for f in CONFIGS_DIR.glob("*.yaml")]
253
+
254
+
255
+ def print_config_summary(config: ExperimentConfig) -> None:
256
+ """Print a summary of the experiment configuration."""
257
+ print("=" * 60)
258
+ print("EXPERIMENT CONFIGURATION")
259
+ print("=" * 60)
260
+
261
+ print("\n[Training]")
262
+ print(f" epochs: {config.training.num_epochs}")
263
+ print(f" batch_size: {config.training.batch_size}")
264
+ print(f" steps_per_epoch: {config.training.steps_per_epoch}")
265
+
266
+ print("\n[GRPO]")
267
+ print(f" lr: {config.grpo.lr}")
268
+ print(f" num_trajectories: {config.grpo.num_trajectories}")
269
+ print(f" trajectory_depth: {config.grpo.trajectory_depth}")
270
+ print(f" entropy_coef: {config.grpo.entropy_coef}")
271
+ print(f" rollout_temperature: {config.grpo.rollout_temperature}")
272
+ print(f" adaptive_kl: {config.grpo.adaptive_kl}")
273
+ print(f" use_entropy_floor: {config.grpo.use_entropy_floor}")
274
+
275
+ print("\n[Transformer]")
276
+ print(f" embed_dim: {config.transformer.embed_dim}")
277
+ print(f" num_layers: {config.transformer.num_layers}")
278
+ print(f" num_heads: {config.transformer.num_heads}")
279
+
280
+ print("\n[Eval]")
281
+ print(f" games: {config.eval.games}")
282
+ print(f" max_plies: {config.eval.max_plies}")
283
+
284
+ print("\n[Stockfish]")
285
+ print(f" skill_level: {config.stockfish.skill_level}")
286
+
287
+ print("\n[Searcher]")
288
+ print(f" enabled: {config.searcher is not None}")
289
+
290
+ print("=" * 60)
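
To make the override semantics concrete, here is a small illustration of _deep_merge; the values are invented for the example:

base = {"grpo": {"lr": 1e-6, "clip_ratio": 0.2}, "training": {"num_epochs": 400}}
overrides = {"grpo": {"lr": 1e-4}}

merged = _deep_merge(base, overrides)
# merged == {"grpo": {"lr": 1e-4, "clip_ratio": 0.2}, "training": {"num_epochs": 400}}
# Nested dicts are merged key by key; only the overridden leaf changes.
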
hf_space_repo/configs/default.yaml ADDED
@@ -0,0 +1,123 @@
1
+ # Default experiment configuration
2
+ # This file contains all hyperparameters for a training run.
3
+ # Copy this file and modify for new experiments.
4
+
5
+ # =============================================================================
6
+ # Training Loop Settings
7
+ # =============================================================================
8
+ training:
9
+ num_epochs: 400
10
+ batch_size: 32
11
+ steps_per_epoch: 512
12
+ checkpoint_every_n_epochs: 5 # Save periodic checkpoint every N epochs for crash recovery
13
+ keep_n_checkpoints: 3 # Keep last N periodic checkpoints per run
14
+
15
+ # =============================================================================
16
+ # GRPO (Group Relative Policy Optimization) Config
17
+ # Clean run config (see research_docs/2026-02-06_loss-budget-and-monitor-analysis.md)
18
+ # =============================================================================
19
+ grpo:
20
+ lr: 0.000001 # 1e-6: reduced because PPO signal now dominates gradient
21
+ num_trajectories: 16
22
+ trajectory_depth: 16
23
+ clip_ratio: 0.20
24
+ kl_coef: 0.001 # reduced from 0.01 (was being overridden to 0.1 by adaptive KL)
25
+ entropy_coef: 0.0 # removed: not part of original GRPO loss, was 95% of gradient
26
+ eval_every_n_epochs: 10
27
+ ppo_steps: 1
28
+ rollout_temperature: 1.3
29
+
30
+ # Entropy floor monitoring — disabled (never triggered, see research doc)
31
+ use_entropy_floor: false
32
+ entropy_floor: 1.5
33
+ entropy_floor_steps: 150
34
+ entropy_floor_action: "boost"
35
+ entropy_boost_factor: 1.5
36
+
37
+ # Adaptive KL controller — disabled (saturated at max instantly, see research doc)
38
+ adaptive_kl: false
39
+ target_kl: 0.012
40
+ kl_adapt_rate: 1.2
41
+ kl_coef_min: 0.001
42
+ kl_coef_max: 0.1
43
+
44
+ # Safety checks
45
+ enable_safety_checks: false
46
+ safety_patience_steps: 1000
47
+ max_clip_fraction: 0.95
48
+ min_entropy: 0.5
49
+ max_kl_divergence: 0.08
50
+
51
+ # Teacher forcing: use Stockfish for rival moves during trajectory sampling
52
+ teacher_forcing_prob: 0.1 # 10% of rival moves will be from Stockfish
53
+ teacher_forcing_depth: 4
54
+
55
+ # =============================================================================
56
+ # Transformer Model Config
57
+ # =============================================================================
58
+ transformer:
59
+ vocab_size: 300
60
+ embed_dim: 256
61
+ num_layers: 4
62
+ num_heads: 8
63
+ action_dim: 1968
64
+
65
+ # =============================================================================
66
+ # Evaluation Config (vs Stockfish)
67
+ # =============================================================================
68
+ eval:
69
+ games: 64
70
+ seed: 0
71
+ max_plies: 400
72
+ randomize_opening: true
73
+ opening_plies: 6
74
+
75
+ # =============================================================================
76
+ # Stockfish Config
77
+ # =============================================================================
78
+ stockfish:
79
+ path: "/usr/games/stockfish" # Override in colab/local as needed
80
+ skill_level: 2
81
+ use_elo_limit: false
82
+ elo: 2500
83
+ movetime_ms: 50
84
+ threads: 1
85
+ hash_mb: 128
86
+
87
+ # =============================================================================
88
+ # Policy Player Config (for evaluation)
89
+ # =============================================================================
90
+ policy:
91
+ temperature: 0.8
92
+ greedy: true
93
+ branching_factor: 4
94
+ search_depth: 2
95
+
96
+ # =============================================================================
97
+ # Searcher Config (optional - set to null to disable)
98
+ # =============================================================================
99
+ searcher: null
100
+ # searcher:
101
+ # n_trajectories: 4
102
+ # trajectory_depth: 8
103
+
104
+ # =============================================================================
105
+ # Pretraining (optional - load pretrained weights before GRPO)
106
+ # =============================================================================
107
+ pretrain:
108
+ checkpoint_path: null # Path to pretrained checkpoint (e.g., "checkpoints/pretrain/pretrain_final.pt")
109
+ freeze_layers: 2 # Freeze first 2 transformer layers to preserve learned representations
110
+
111
+ # =============================================================================
112
+ # Dataset Config (Chess Start States)
113
+ # =============================================================================
114
+ dataset:
115
+ max_steps: 512 # Should match steps_per_epoch
116
+ phase_distribution:
117
+ opening: 0.33
118
+ middlegame: 0.34
119
+ endgame: 0.33
120
+ min_eval_cp: -200
121
+ max_eval_cp: 200
122
+ quality_filter: true
123
+ stockfish_filter_depth: 4
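
Loading this file and overriding a couple of fields might look like the following; paths assume the package layout used above:

from src.grpo_self_play.configs import load_experiment_config, print_config_summary

config = load_experiment_config("default.yaml", overrides={
    "grpo": {"lr": 1e-5},
    "stockfish": {"skill_level": 5},
})
print_config_summary(config)
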
hf_space_repo/configs/pretrain.yaml ADDED
@@ -0,0 +1,49 @@
1
+ # Pretraining configuration for chess model
2
+ # This file contains hyperparameters for supervised pretraining on Lichess games.
3
+ #
4
+ # Usage:
5
+ # python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml
6
+
7
+ # =============================================================================
8
+ # Pretraining Settings
9
+ # =============================================================================
10
+ pretrain:
11
+ lr: 0.0001 # Learning rate (higher than GRPO fine-tuning)
12
+ batch_size: 4096 # Batch size for pretraining
13
+ num_epochs: 22 # Number of passes through the dataset
14
+ warmup_steps: 1000 # Linear warmup steps
15
+ weight_decay: 0.01 # AdamW weight decay
16
+ max_grad_norm: 1.0 # Gradient clipping
17
+ checkpoint_dir: "checkpoints/pretrain"
18
+ resume_from: null # Path to resume from (optional)
19
+ use_wandb: true
20
+ wandb_project: "chess-grpo-pretrain"
21
+ label_smoothing: 0.1 # Prevents overconfidence
22
+ num_workers: 4 # DataLoader workers
23
+ val_check_interval: 0.1 # Validate every 10% of epoch
24
+
25
+ # =============================================================================
26
+ # Dataset Settings (Lichess games from HuggingFace)
27
+ # =============================================================================
28
+ dataset:
29
+ min_elo: 1800 # Minimum player rating to include
30
+ max_samples: 5000000 # Max samples per epoch (null = unlimited)
31
+ skip_first_n_moves: 5 # Skip opening moves (book territory)
32
+ skip_last_n_moves: 5 # Skip endgame/resignation moves
33
+ sample_positions_per_game: 3 # Positions to sample from each game
34
+ buffer_size: 10000 # Shuffle buffer size for streaming
35
+ filter_abandoned: true # Skip abandoned games
36
+ dataset_name: "Lichess/standard-chess-games"
37
+ split: "train" # Dataset split to use
38
+ is_eval: false # False for training, True for evaluation
39
+ eval_fraction: 0.05 # 5% of games held out for evaluation
40
+
41
+ # =============================================================================
42
+ # Transformer Model Config (should match GRPO training)
43
+ # =============================================================================
44
+ transformer:
45
+ vocab_size: 300
46
+ embed_dim: 256
47
+ num_layers: 4
48
+ num_heads: 8
49
+ action_dim: 1968
hf_space_repo/constants.py ADDED
@@ -0,0 +1,15 @@
1
+ """Constants used across the GRPO self-play module."""
2
+
3
+ # Sequence length for tokenized FEN strings
4
+ SEQUENCE_LENGTH = 77
5
+
6
+ # Default training hyperparameters
7
+ DEFAULT_LEARNING_RATE = 1e-4
8
+ DEFAULT_NUM_TRAJECTORIES = 4
9
+ DEFAULT_TRAJECTORY_DEPTH = 5
10
+ DEFAULT_CLIP_RATIO = 0.2
11
+ DEFAULT_KL_COEF = 0.01
12
+
13
+ # Default evaluation settings
14
+ DEFAULT_EVAL_GAMES = 50
15
+ DEFAULT_EVAL_MAX_PLIES = 400
hf_space_repo/eval_utils.py ADDED
@@ -0,0 +1,211 @@
1
+ """Utilities for evaluating chess policies against Stockfish."""
2
+ import io
3
+ import math
4
+ import chess
5
+ import chess.pgn
6
+ import chess.engine
7
+ import random
8
+
9
+ import torch
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Dict, List, Tuple
13
+
14
+ from src.grpo_self_play.chess.chess_logic import MOVE_TO_ACTION
15
+ from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig
16
+ from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
17
+ from src.grpo_self_play.chess.stockfish import StockfishPlayer, StockfishConfig, DEFAULT_STOCKFISH_PATH as STOCKFISH_PATH
18
+
19
+
20
+ @dataclass
21
+ class EvalConfig:
22
+ games: int = 50
23
+ seed: int = 0
24
+ max_plies: int = 400 # safety to avoid extremely long games
25
+ randomize_opening: bool = False
26
+ opening_plies: int = 6 # random legal moves to diversify early positions
27
+
28
+
29
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
30
+ torch.serialization.add_safe_globals([EvalConfig])
31
+
32
+
33
+ def debug_legal_coverage(board: chess.Board) -> tuple[int, int, list[str]]:
34
+ """Debug function to check coverage of legal moves in action space.
35
+
36
+ Args:
37
+ board: Chess board position
38
+
39
+ Returns:
40
+ Tuple of (covered_count, total_legal_moves, list_of_missing_moves)
41
+ """
42
+ legals = list(board.legal_moves)
43
+ covered = 0
44
+ missing = []
45
+ for mv in legals:
46
+ u = mv.uci()
47
+ if u in MOVE_TO_ACTION:
48
+ covered += 1
49
+ else:
50
+ missing.append(u)
51
+ return covered, len(legals), missing[:10]
52
+
53
+
54
+
55
+
56
+
57
+ def play_one_game(
58
+ policy: PolicyPlayer | TrajectorySearcher,
59
+ stockfish: StockfishPlayer,
60
+ policy_is_white: bool,
61
+ cfg: EvalConfig,
62
+ game_number: int = 0,
63
+ ) -> Tuple[str, str, str]:
64
+ """Play a single game between policy and Stockfish.
65
+
66
+ Args:
67
+ policy: Policy player to evaluate
68
+ stockfish: Stockfish player
69
+ policy_is_white: Whether policy plays as white
70
+ cfg: Evaluation configuration
71
+ game_number: Game number for PGN metadata
72
+
73
+ Returns:
74
+ Tuple of (result_str, termination_reason, pgn_str)
75
+ result_str in {"1-0", "0-1", "1/2-1/2"}
76
+ """
77
+
78
+ board = chess.Board()
79
+ game = chess.pgn.Game()
80
+ game.headers["Event"] = "Policy vs Stockfish Evaluation"
81
+ game.headers["White"] = "Policy" if policy_is_white else "Stockfish"
82
+ game.headers["Black"] = "Stockfish" if policy_is_white else "Policy"
83
+ game.headers["Round"] = str(game_number + 1)
84
+ node = game
85
+
86
+ # Optional random opening to reduce overfitting to a single line
87
+ if cfg.randomize_opening and cfg.opening_plies > 0:
88
+ for _ in range(cfg.opening_plies):
89
+ if board.is_game_over():
90
+ break
91
+ move = random.choice(list(board.legal_moves))
92
+ board.push(move)
93
+ node = node.add_variation(move)
94
+
95
+ for ply in range(cfg.max_plies):
96
+ if board.is_game_over(claim_draw=True):
97
+ break
98
+
99
+ is_white_to_move = board.turn
100
+ policy_turn = (is_white_to_move == policy_is_white)
101
+
102
+ if policy_turn:
103
+ move = policy.act(board)
104
+ else:
105
+ move = stockfish.act(board)
106
+ if move is None:
107
+ break # no legal moves
108
+
109
+ board.push(move)
110
+ node = node.add_variation(move)
111
+
112
+ # Determine result
113
+ if board.is_game_over(claim_draw=True):
114
+ res = board.result(claim_draw=True)
115
+ reason = "game_over"
116
+ else:
117
+ # Reached max plies: treat as draw
118
+ res = "1/2-1/2"
119
+ reason = "max_plies"
120
+
121
+ game.headers["Result"] = res
122
+
123
+ # Generate PGN string
124
+ pgn_output = io.StringIO()
125
+ exporter = chess.pgn.FileExporter(pgn_output)
126
+ game.accept(exporter)
127
+ pgn_str = pgn_output.getvalue()
128
+
129
+ return res, reason, pgn_str
130
+
131
+
132
+ def estimate_elo_diff(score: float) -> float:
133
+ """Estimate Elo difference from match score.
134
+
135
+ Uses logistic model: S = 1/(1+10^(-d/400)) => d = -400*log10(1/S - 1)
136
+ Clamped for numeric stability.
137
+
138
+ Args:
139
+ score: Win rate score in [0, 1]
140
+
141
+ Returns:
142
+ Estimated Elo difference
143
+ """
144
+ eps = 1e-6
145
+ s = min(max(score, eps), 1 - eps)
146
+ return -400.0 * math.log10(1.0 / s - 1.0)
147
+
148
+
149
+ def evaluate_policy_vs_stockfish(
150
+ policy: PolicyPlayer | TrajectorySearcher,
151
+ sf: StockfishPlayer,
152
+ eval_cfg: EvalConfig,
153
+ ) -> Tuple[Dict, PolicyPlayer | TrajectorySearcher, List[str]]:
154
+ """Evaluate a policy by playing multiple games against Stockfish.
155
+
156
+ Args:
157
+ policy: Policy player to evaluate
158
+ sf: Stockfish player
159
+ eval_cfg: Evaluation configuration
160
+
161
+ Returns:
162
+ Tuple of (results_dict, policy_player, pgns)
163
+ results_dict contains: games, wins, draws, losses, score, elo_diff, etc.
164
+ pgns is a list of PGN strings for all games played
165
+ """
166
+ random.seed(eval_cfg.seed)
167
+ torch.manual_seed(eval_cfg.seed)
168
+
169
+ wins = draws = losses = 0
170
+ term_reasons = {}
171
+ pgns: List[str] = []
172
+
173
+ try:
174
+ for g in range(eval_cfg.games):
175
+ policy_is_white = (g % 2 == 0)
176
+ res, reason, pgn = play_one_game(policy, sf, policy_is_white, eval_cfg, game_number=g)
177
+ term_reasons[reason] = term_reasons.get(reason, 0) + 1
178
+ pgns.append(pgn)
179
+
180
+ # From policy perspective
181
+ if res == "1-0":
182
+ if policy_is_white:
183
+ wins += 1
184
+ else:
185
+ losses += 1
186
+ elif res == "0-1":
187
+ if policy_is_white:
188
+ losses += 1
189
+ else:
190
+ wins += 1
191
+ else:
192
+ draws += 1
193
+
194
+ finally:
195
+ sf.close()
196
+
197
+ total = wins + draws + losses
198
+ score = (wins + 0.5 * draws) / total if total else 0.0
199
+ elo_diff = estimate_elo_diff(score) if total else 0.0
200
+
201
+ return {
202
+ "games": total,
203
+ "wins": wins,
204
+ "draws": draws,
205
+ "losses": losses,
206
+ "score": score,
207
+ "elo_diff_vs_stockfish_approx": elo_diff,
208
+ "termination_reasons": term_reasons,
209
+ "eval_cfg": eval_cfg,
210
+ }, policy, pgns
211
+
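
As a sanity check on the Elo formula: a 75% score corresponds to roughly +191 Elo.

import math

score = 0.75
elo_diff = -400.0 * math.log10(1.0 / score - 1.0)
print(round(elo_diff, 1))  # 190.8 -> winning three games out of four is about +191 Elo
# score = 0.5 gives 0.0; the epsilon clamp in estimate_elo_diff keeps scores of 0.0 and 1.0 finite.
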
hf_space_repo/evaluator.py ADDED
@@ -0,0 +1,118 @@
1
+ from typing import Dict, List, Optional, Tuple
2
+ from chess import engine
3
+ import torch.nn as nn
4
+
5
+ from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig
6
+ from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
7
+ from src.grpo_self_play.chess.stockfish import StockfishPlayer, StockfishConfig, StockfishManager
8
+ from src.grpo_self_play.eval_utils import EvalConfig, evaluate_policy_vs_stockfish
9
+
10
+
11
+
12
+ class Evaluator:
13
+ """Evaluate a chess model by playing against Stockfish.
14
+
15
+ Handles evaluation of chess policies against Stockfish at various skill levels.
16
+ Supports both single evaluations and skill ladder evaluations.
17
+ """
18
+ def __init__(self,
19
+ eval_cfg: EvalConfig = EvalConfig(),
20
+ policy_cfg: PolicyConfig = PolicyConfig(),
21
+ searcher_cfg: Optional[SearchConfig] = None,
22
+ stockfish_cfg: StockfishConfig = StockfishConfig()):
23
+ """
24
+ Initialize evaluator.
25
+
26
+ Args:
27
+ eval_cfg: Evaluation configuration (number of games, etc.)
28
+ policy_cfg: Policy player configuration
29
+ searcher_cfg: Optional search configuration for tree search
30
+ stockfish_cfg: Stockfish engine configuration
31
+ """
32
+ self.eval_cfg = eval_cfg
33
+ self.policy_cfg = policy_cfg
34
+ self.searcher_cfg = searcher_cfg
35
+ self.default_stockfish_cfg = stockfish_cfg
36
+
37
+ def _make_policy(self, model: nn.Module) -> PolicyPlayer | TrajectorySearcher:
38
+ """Create a policy player (optionally wrapped with search).
39
+
40
+ Args:
41
+ model: Neural network model
42
+
43
+ Returns:
44
+ Policy player, optionally wrapped with trajectory search
45
+ """
46
+ policy = PolicyPlayer(model, cfg=self.policy_cfg)
47
+ if self.searcher_cfg is not None:
48
+ policy = TrajectorySearcher(policy, cfg=self.searcher_cfg)
49
+ return policy
50
+
51
+ def _make_stockfish(self) -> StockfishPlayer:
52
+ """Create a Stockfish player with default configuration.
53
+
54
+ Returns:
55
+ Stockfish player instance
56
+ """
57
+ return StockfishPlayer(self.default_stockfish_cfg)
58
+
59
+ def single_evaluation(self, model: nn.Module) -> Tuple[Dict, PolicyPlayer | TrajectorySearcher, List[str]]:
60
+ """Evaluate the model by playing games against Stockfish.
61
+
62
+ Args:
63
+ model: Neural network model to evaluate
64
+
65
+ Returns:
66
+ Tuple of (results_dict, policy_or_searcher, pgns)
67
+ pgns is a list of PGN strings for all games played
68
+ """
69
+ stockfish_player = self._make_stockfish()
70
+ policy = self._make_policy(model)
71
+ results, policy_or_searcher, pgns = evaluate_policy_vs_stockfish(
72
+ policy,
73
+ stockfish_player,
74
+ self.eval_cfg,
75
+ )
76
+ return results, policy_or_searcher, pgns
77
+
78
+ def eval_ladder(self, model: nn.Module) -> Dict[int, float]:
79
+ """Evaluate model against Stockfish at multiple skill levels.
80
+
81
+ Args:
82
+ model: Neural network model to evaluate
83
+
84
+ Returns:
85
+ Dictionary mapping skill level to win rate score
86
+ """
87
+ policy = self._make_policy(model)
88
+ results = {}
89
+ skill_levels = [1, 3, 5, 8, 10]
90
+ for skill in skill_levels:
91
+ stockfish_cfg = StockfishConfig(
92
+ path=self.default_stockfish_cfg.path,
93
+ skill_level=skill,
94
+ movetime_ms=self.default_stockfish_cfg.movetime_ms,
95
+ )
96
+ engine_name = f"stockfish_skill_{skill}"
97
+ stockfish_player = StockfishPlayer(stockfish_cfg, engine_name=engine_name)
98
+
99
+ try:
100
+ r, policy_wrapper, _ = evaluate_policy_vs_stockfish(
101
+ policy,
102
+ stockfish_player,
103
+ self.eval_cfg,
104
+ )
105
+ results[skill] = r["score"]
106
+ print(f"Skill {skill}: {r}")
107
+ if hasattr(policy_wrapper, 'stats'):
108
+ print(f'Policy stats: {policy_wrapper.stats}')
109
+ except Exception as e:
110
+ print(f"Error evaluating at skill {skill}: {e}")
111
+ results[skill] = 0.0
112
+ finally:
113
+ StockfishManager.close(engine_name) # Close engine to free resources
114
+ return results
115
+
116
+
117
+
118
+
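
A sketch of a single evaluation run; `model` stands in for a trained ChessTransformer and is not defined here, and a Stockfish binary is assumed to be installed:

from src.grpo_self_play.evaluator import Evaluator
from src.grpo_self_play.eval_utils import EvalConfig

evaluator = Evaluator(eval_cfg=EvalConfig(games=8, max_plies=200))
results, _, pgns = evaluator.single_evaluation(model)  # `model` assumed trained
print(results["score"], results["elo_diff_vs_stockfish_approx"])
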
hf_space_repo/grpo_logic/__init__.py ADDED
File without changes
hf_space_repo/grpo_logic/loss.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ from typing import Tuple
3
+ from dataclasses import dataclass
4
+
5
+
6
+ @dataclass
7
+ class GRPOLossInfo:
8
+ """Information about GRPO loss components for logging and debugging."""
9
+ kl_div: torch.Tensor
10
+ mean_ratio: torch.Tensor
11
+ mean_clip_fraction: torch.Tensor
12
+ ppo_loss: torch.Tensor
13
+ entropy: torch.Tensor
14
+ loss_without_entropy: torch.Tensor
15
+
16
+ def grpo_chess_loss(
17
+ logprobs_new: torch.Tensor, # [G, T] log πθ(a_{g,k,t} | s_{g,k,t})
18
+ logprobs_old: torch.Tensor, # [G, T] log πold(a_{g,k,t} | s_{g,k,t})
19
+ advantages: torch.Tensor, # [G, T]
20
+ clip_eps: float = 0.2, # ε in the formula
21
+ beta_kl: float = 0.0, # β in the formula (0 = no explicit KL penalty)
22
+ eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
23
+ """
24
+ Compute GRPO chess loss (legacy function, consider using grpo_ppo_loss instead).
25
+
26
+ Args:
27
+ logprobs_new: New policy log probabilities [G, T]
28
+ logprobs_old: Old policy log probabilities [G, T]
29
+ advantages: Advantage values [G, T]
30
+ clip_eps: PPO clipping epsilon
31
+ beta_kl: KL penalty coefficient
32
+ eps: Numerical stability epsilon
33
+
34
+ Returns:
35
+ Tuple of (loss, approximate_kl_divergence)
36
+ """
37
+
38
+ # ------------------------------------------------------------
39
+ # 3. Probability ratio r_{g,k,t}(θ)
40
+ #
41
+ # r_{g,k,t}(θ) = πθ(a_{g,k,t}|s_{g,k,t}) / πold(a_{g,k,t}|s_{g,k,t})
42
+ # = exp( logπθ - logπold )
43
+ # ------------------------------------------------------------
44
+ ratio = (logprobs_new - logprobs_old).exp() # [G, T]
45
+ pg_unclipped = -advantages * ratio # [G, T]
46
+ pg_clipped = -advantages * ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) # [G, T]
47
+
48
+ # Surrogate policy gradient loss (PPO-clip part)
49
+ # This corresponds to the -E[min(...)] in the formula.
50
+ policy_loss = torch.max(pg_unclipped, pg_clipped).mean()
51
+ approx_kl = (logprobs_old - logprobs_new).mean()
52
+
53
+ # KL penalty: β * E[ KL(...) ]
54
+ kl_loss = beta_kl * approx_kl
55
+ loss = policy_loss + kl_loss
56
+
57
+ return loss, approx_kl
58
+
59
+
60
+ # Utils functions for GRPO
61
+ def group_advantage(group_rewards: torch.Tensor) -> torch.Tensor:
62
+ """
63
+ Compute normalized advantages from group rewards using standardization.
64
+
65
+ Args:
66
+ group_rewards: Group rewards tensor [B, G] or [G]
67
+
68
+ Returns:
69
+ Normalized advantages with same shape as input
70
+ """
71
+ mean_reward = group_rewards.mean(dim=-1, keepdim=True)
72
+ std_reward = group_rewards.std(dim=-1, unbiased=False, keepdim=True) + 1e-8
73
+ advantages = (group_rewards - mean_reward) / std_reward
74
+ return advantages
75
+
76
+
77
+ def step_group_advantage(step_rewards: torch.Tensor, pad_mask: torch.Tensor | None = None) -> torch.Tensor:
78
+ """
79
+ Compute per-step normalized advantages from step rewards.
80
+ For each timestep t, normalizes across the G dimension (trajectories).
81
+
82
+ NOTE: No std normalization is applied here, following the Dr. GRPO paper.
83
+ Args:
84
+ step_rewards: Per-step rewards tensor [B, G, T]
85
+ pad_mask: Optional mask for valid steps [B, G, T], True=valid
86
+
87
+ Returns:
88
+ Normalized advantages [B, G, T] where each timestep is normalized across G
89
+ """
90
+ # Normalize across G dimension for each (batch, timestep)
91
+ # step_rewards: [B, G, T]
92
+ mean_t = step_rewards.mean(dim=1, keepdim=True) # [B, 1, T]
93
+ advantages = (step_rewards - mean_t) # [B, G, T]
94
+
95
+ if pad_mask is not None:
96
+ advantages = advantages * pad_mask.float()
97
+
98
+ return advantages
99
+
100
+
101
+ def ppo_chess_loss(
102
+ logprobs_new: torch.Tensor, # [G, T] log πθ(a_{g,k,t} | s_{g,k,t})
103
+ logprobs_old: torch.Tensor, # [G, T] log πold(a_{g,k,t} | s_{g,k,t})
104
+ advantages: torch.Tensor, # [G, T]
105
+ clip_eps: float = 0.2, # ε in the formula
106
+ pad_mask: torch.Tensor | None = None, # [G, T], True = real, False = pad
107
+ return_info: bool = False,
108
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
109
+ """
110
+ Compute PPO-clip loss for chess policy optimization.
111
+
112
+ Args:
113
+ logprobs_new: New policy log probabilities [B, G, T] or [G, T]
114
+ logprobs_old: Old policy log probabilities [B, G, T] or [G, T]
115
+ advantages: Advantage values [B, G, T] or [G, T]
116
+ clip_eps: PPO clipping epsilon (default: 0.2)
117
+ pad_mask: Mask indicating valid steps, True=valid, False=padding
118
+ return_info: If True, return additional statistics
119
+
120
+ Returns:
121
+ If return_info=False: policy loss tensor [B, G, T] or [G, T]
122
+ If return_info=True: tuple of (policy_loss, mean_ratio, mean_clip_fraction)
123
+ """
124
+ if pad_mask is None:
125
+ pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)
126
+ ratio = (logprobs_new - logprobs_old).exp() # [G, T]
127
+ pg_unclipped = -advantages * ratio # [G, T]
128
+ pg_clipped = -advantages * ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps) # [G, T]
129
+ # Surrogate policy gradient loss (PPO-clip part)
130
+ # This corresponds to the -E[min(...)] in the formula.
131
+ policy_loss = torch.max(pg_unclipped, pg_clipped) * pad_mask.float()
132
+ if return_info:
133
+ valid_steps = pad_mask.sum().clamp_min(1.0)
134
+ mean_padded_ratio = (ratio * pad_mask.float()).sum() / valid_steps
135
+ clip_fraction_mask = (ratio > (1.0 + clip_eps)) | (ratio < (1.0 - clip_eps))
136
+ mean_clip_fraction = (clip_fraction_mask.float() * pad_mask.float()).sum() / valid_steps
137
+ return policy_loss, mean_padded_ratio, mean_clip_fraction # [G, T], scalar, scalar
138
+ return policy_loss # [G, T]
139
+
140
+
141
+ def kl_penalty(logprobs_new: torch.Tensor,
142
+ logprobs_old: torch.Tensor,
143
+ pad_mask: torch.Tensor | None = None) -> torch.Tensor:
144
+ """
145
+ Compute KL divergence penalty between old and new policies.
146
+
147
+ Args:
148
+ logprobs_new: New policy log probabilities
149
+ logprobs_old: Old policy log probabilities
150
+ pad_mask: Optional mask for valid steps
151
+
152
+ Returns:
153
+ Mean KL divergence over valid steps
154
+ """
155
+ if pad_mask is None:
156
+ pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)
157
+ return (logprobs_old - logprobs_new)[pad_mask].mean()
158
+
159
+
160
+ def grpo_ppo_loss(
161
+ logprobs_new: torch.Tensor, # [B, G, T] or [G, T]
162
+ logprobs_old: torch.Tensor, # [B, G, T] or [G, T]
163
+ step_rewards: torch.Tensor, # [B, G, T] or [G, T] - per-step rewards
164
+ pad_mask: torch.Tensor | None = None, # [B, G, T] or [G, T]
165
+ clip_ratio: float = 0.2, # PPO clipping ratio (epsilon in paper)
166
+ kl_coef: float = 0.01, # KL penalty coefficient (beta in paper)
167
+ entropy_coef: float = 0.1, # Entropy bonus coefficient (prevents policy collapse)
168
+ return_info: bool = False, # Return extra info for logging
169
+ ) -> torch.Tensor | Tuple[torch.Tensor, GRPOLossInfo]:
170
+ """
171
+ Compute GRPO (Group Relative Policy Optimization) loss with PPO clipping.
172
+
173
+ This combines PPO-clip loss with KL divergence penalty and optional entropy bonus.
174
+ Advantages are computed per-step by normalizing step rewards across trajectories
175
+ (G dimension) for each timestep.
176
+
177
+ Args:
178
+ logprobs_new: New policy log probabilities [B, G, T] or [G, T]
179
+ logprobs_old: Old policy log probabilities [B, G, T] or [G, T]
180
+ step_rewards: Per-step rewards [B, G, T] or [G, T]
181
+ pad_mask: Mask indicating valid steps, True=valid, False=padding
182
+ clip_ratio: PPO clipping ratio (default: 0.2)
183
+ kl_coef: KL divergence penalty coefficient (default: 0.01)
184
+ entropy_coef: Entropy bonus coefficient (default: 0.1; set to 0 to disable)
185
+ return_info: If True, return GRPOLossInfo for logging
186
+
187
+ Returns:
188
+ If return_info=False: scalar loss tensor
189
+ If return_info=True: tuple of (loss, GRPOLossInfo)
190
+ """
191
+ # Handle 2D input (no batch dimension) by adding batch dimension
192
+ if logprobs_new.ndim == 2:
193
+ logprobs_new = logprobs_new.unsqueeze(0)
194
+ logprobs_old = logprobs_old.unsqueeze(0)
195
+ step_rewards = step_rewards.unsqueeze(0)
196
+ if pad_mask is not None:
197
+ pad_mask = pad_mask.unsqueeze(0)
198
+
199
+ if pad_mask is None:
200
+ pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)
201
+
202
+ # Compute per-step advantages (normalized across G for each timestep)
203
+ advantages = step_group_advantage(step_rewards, pad_mask).detach() # [B, G, T]
204
+
205
+ ppo_loss, mean_ratio, mean_clip_fraction = ppo_chess_loss(logprobs_new,
206
+ logprobs_old,
207
+ advantages,
208
+ clip_ratio,
209
+ pad_mask,
210
+ return_info=True)
211
+ valid_steps = pad_mask.sum().clamp_min(1)
212
+ ppo_loss = ppo_loss.sum() / valid_steps
213
+ kl_div = kl_penalty(logprobs_new, logprobs_old, pad_mask)
214
+
215
+ # Entropy bonus: H(π) ≈ -E[log π(a|s)] encourages exploration
216
+ # We use the negative log_probs of selected actions as an estimate
217
+ entropy = -logprobs_new[pad_mask].mean()
218
+
219
+ # Loss components:
220
+ # - loss_without_entropy = PPO loss + KL penalty
221
+ # - total loss = loss_without_entropy - entropy bonus
222
+ loss_without_entropy = ppo_loss + kl_coef * kl_div
223
+ loss = loss_without_entropy - entropy_coef * entropy
224
+
225
+ if return_info:
226
+ return loss, GRPOLossInfo(
227
+ kl_div=kl_div.detach(),
228
+ mean_ratio=mean_ratio.detach(),
229
+ mean_clip_fraction=mean_clip_fraction.detach(),
230
+ ppo_loss=ppo_loss.detach(),
231
+ entropy=entropy.detach(),
232
+ loss_without_entropy=loss_without_entropy.detach(),
233
+ )
234
+ return loss
235
+
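
A minimal smoke test of the loss on toy tensors; the shapes matter, the numbers are arbitrary:

import torch
from src.grpo_self_play.grpo_logic.loss import grpo_ppo_loss

G, T = 4, 3  # four trajectories, three plies each
logprobs_old = torch.randn(G, T)
logprobs_new = logprobs_old + 0.01 * torch.randn(G, T)  # a small policy update
step_rewards = torch.randn(G, T)

loss, info = grpo_ppo_loss(logprobs_new, logprobs_old, step_rewards, return_info=True)
print(loss.item(), info.kl_div.item(), info.mean_clip_fraction.item())
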
hf_space_repo/grpo_logic/model.py ADDED
@@ -0,0 +1,782 @@
1
+ from typing import Optional
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ import chess
5
+
6
+ from dataclasses import dataclass
7
+
8
+ from src.grpo_self_play.evaluator import Evaluator
9
+ from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
10
+ from src.grpo_self_play.grpo_logic.loss import grpo_ppo_loss
11
+ from src.grpo_self_play.grpo_logic.sampling import sample_trajectories_batched
12
+ from src.grpo_self_play.eval_utils import EvalConfig
13
+ from src.grpo_self_play.chess.policy_player import PolicyConfig
14
+ from src.grpo_self_play.chess.searcher import SearchConfig
15
+ from src.grpo_self_play.chess.stockfish import StockfishConfig
16
+ from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig
17
+
18
+
19
+ class EntropyFloorMonitor:
20
+ """Monitors entropy and takes action when it falls below a floor (Recommendation 1).
21
+
22
+ Tracks consecutive steps where entropy is below a threshold and triggers
23
+ configurable actions (warn, stop, or boost entropy_coef) when the threshold
24
+ is breached for too long.
25
+ """
26
+
27
+ def __init__(self, floor: float, steps_threshold: int, action: str, boost_factor: float):
28
+ """
29
+ Args:
30
+ floor: Minimum entropy threshold
31
+ steps_threshold: Consecutive steps below floor before action
32
+ action: Action to take ("warn", "stop", "boost")
33
+ boost_factor: Factor to multiply entropy_coef when boosting
34
+ """
35
+ self.floor = floor
36
+ self.steps_threshold = steps_threshold
37
+ self.action = action
38
+ self.boost_factor = boost_factor
39
+ self.consecutive_low_steps = 0
40
+ self.triggered = False
41
+
42
+ def check(self, entropy: float, current_entropy_coef: float) -> tuple[float, dict]:
43
+ """Check entropy and return updated entropy_coef and metrics.
44
+
45
+ Args:
46
+ entropy: Current entropy value
47
+ current_entropy_coef: Current entropy coefficient
48
+
49
+ Returns:
50
+ Tuple of (new_entropy_coef, metrics_dict)
51
+ """
52
+ metrics = {}
53
+ new_entropy_coef = current_entropy_coef
54
+
55
+ if entropy < self.floor:
56
+ self.consecutive_low_steps += 1
57
+
58
+ if self.consecutive_low_steps >= self.steps_threshold and not self.triggered:
59
+ self.triggered = True
60
+ if self.action == "warn":
61
+ print(f"WARNING: Entropy collapse detected! Entropy={entropy:.4f} < floor={self.floor} "
62
+ f"for {self.consecutive_low_steps} consecutive steps.")
63
+ elif self.action == "stop":
64
+ raise RuntimeError(
65
+ f"STOPPING: Entropy collapse detected! Entropy={entropy:.4f} < floor={self.floor} "
66
+ f"for {self.consecutive_low_steps} consecutive steps.")
67
+ elif self.action == "boost":
68
+ new_entropy_coef = current_entropy_coef * self.boost_factor
69
+ print(f"BOOSTING entropy_coef: {current_entropy_coef:.4f} -> {new_entropy_coef:.4f} "
70
+ f"(entropy={entropy:.4f} < floor={self.floor})")
71
+ self.consecutive_low_steps = 0
72
+ self.triggered = False
73
+ else:
74
+ self.consecutive_low_steps = 0
75
+ self.triggered = False
76
+
77
+ metrics["entropy_floor/consecutive_low_steps"] = self.consecutive_low_steps
78
+ metrics["entropy_floor/below_floor"] = float(entropy < self.floor)
79
+ metrics["entropy_floor/current_entropy_coef"] = new_entropy_coef
80
+
81
+ return new_entropy_coef, metrics
82
+
83
+
+ def compute_group_collapse_metrics(
+     actions: torch.Tensor,
+     group_rewards: torch.Tensor,
+     step_rewards: torch.Tensor,
+     pad_mask: torch.Tensor,
+ ) -> dict:
+     """Compute within-board group collapse metrics (Recommendation 4).
+
+     These metrics directly measure whether all G trajectories from the same board
+     are converging to the same moves, which is the key failure mode in entropy collapse.
+
+     Args:
+         actions: Action indices [B, G, T]
+         group_rewards: Final rewards for each trajectory [B, G]
+         step_rewards: Per-step rewards [B, G, T]
+         pad_mask: Mask indicating valid steps [B, G, T], True=valid
+
+     Returns:
+         Dictionary of metrics for logging
+     """
+     B, _, T = actions.shape
+     metrics = {}
+
+     # 1. Action agreement: for each (b, t), what fraction of trajectories chose the most common action?
+     #    agreement[b, t] = max_count(actions[b, :, t]) / num_valid
+     action_agreement = torch.zeros(B, T, device=actions.device)
+     for b in range(B):
+         for t in range(T):
+             if pad_mask[b, :, t].any():  # At least one valid trajectory at this timestep
+                 valid_actions = actions[b, pad_mask[b, :, t], t]
+                 if len(valid_actions) > 0:
+                     # Count occurrences of each action
+                     _, counts = valid_actions.unique(return_counts=True)
+                     max_count = counts.max().item()
+                     num_valid = pad_mask[b, :, t].sum().item()
+                     action_agreement[b, t] = max_count / num_valid
+
+     # Mask to only consider valid (b, t) pairs
+     valid_bt_mask = pad_mask.any(dim=1)  # [B, T] - True if any trajectory valid at (b, t)
+     valid_agreements = action_agreement[valid_bt_mask]
+
+     if len(valid_agreements) > 0:
+         metrics["group_collapse/action_agreement_mean"] = valid_agreements.mean().item()
+         metrics["group_collapse/action_agreement_p90"] = valid_agreements.quantile(0.9).item()
+         metrics["group_collapse/action_agreement_max"] = valid_agreements.max().item()
+     else:
+         metrics["group_collapse/action_agreement_mean"] = 0.0
+         metrics["group_collapse/action_agreement_p90"] = 0.0
+         metrics["group_collapse/action_agreement_max"] = 0.0
+
+     # 2. Within-board reward diversity: std(group_rewards[b, :]) for each board b.
+     #    This measures whether trajectories from the same starting position get similar rewards.
+     reward_std_within = group_rewards.std(dim=1)  # [B]
+     metrics["group_collapse/reward_std_within_mean"] = reward_std_within.mean().item()
+     metrics["group_collapse/reward_std_within_min"] = reward_std_within.min().item()
+
+     # 3. Within-board step reward diversity: std(step_rewards[b, :, t]) for each (b, t).
+     #    Only computed for valid (b, t) pairs.
+     step_reward_std_within = torch.zeros(B, T, device=step_rewards.device)
+     for b in range(B):
+         for t in range(T):
+             valid_mask_bt = pad_mask[b, :, t]
+             if valid_mask_bt.sum() > 1:  # Need at least 2 valid trajectories for std
+                 step_reward_std_within[b, t] = step_rewards[b, valid_mask_bt, t].std().item()
+
+     valid_step_stds = step_reward_std_within[valid_bt_mask]
+     if len(valid_step_stds) > 0:
+         metrics["group_collapse/step_reward_std_within_mean"] = valid_step_stds.mean().item()
+         metrics["group_collapse/step_reward_std_within_min"] = valid_step_stds.min().item()
+     else:
+         metrics["group_collapse/step_reward_std_within_mean"] = 0.0
+         metrics["group_collapse/step_reward_std_within_min"] = 0.0
+
+     return metrics
+
+
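+ # Shape sketch (illustrative): with B=2 boards, G=4 trajectories, and T=5 steps,
+ #   compute_group_collapse_metrics(actions,        # [2, 4, 5] long
+ #                                  group_rewards,  # [2, 4] float
+ #                                  step_rewards,   # [2, 4, 5] float
+ #                                  pad_mask)       # [2, 4, 5] bool
+ # returns scalar metrics; an action_agreement_mean near 1.0 means the G
+ # trajectories have collapsed onto the same moves.
+
+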
+ class AdaptiveKLController:
+     """Adapts KL coefficient to maintain target KL divergence (Recommendation 2).
+
+     Implements a simple multiplicative controller that increases kl_coef when
+     KL divergence exceeds target and decreases it when below target.
+     """
+
+     def __init__(self, initial_kl_coef: float, target_kl: float, adapt_rate: float,
+                  kl_coef_min: float, kl_coef_max: float):
+         """
+         Args:
+             initial_kl_coef: Starting KL coefficient
+             target_kl: Target KL divergence value
+             adapt_rate: Multiplicative factor for adjustment
+             kl_coef_min: Minimum allowed kl_coef
+             kl_coef_max: Maximum allowed kl_coef
+         """
+         self.current_kl_coef = initial_kl_coef
+         self.target_kl = target_kl
+         self.adapt_rate = adapt_rate
+         self.kl_coef_min = kl_coef_min
+         self.kl_coef_max = kl_coef_max
+
+     def update(self, kl_div: float) -> dict:
+         """Update KL coefficient based on current KL divergence.
+
+         Args:
+             kl_div: Current KL divergence value
+
+         Returns:
+             Metrics dict for logging
+         """
+         if kl_div > self.target_kl:
+             self.current_kl_coef = min(self.current_kl_coef * self.adapt_rate, self.kl_coef_max)
+         else:
+             self.current_kl_coef = max(self.current_kl_coef / self.adapt_rate, self.kl_coef_min)
+
+         return {
+             "adaptive_kl/current_kl_coef": self.current_kl_coef,
+             "adaptive_kl/target_kl": self.target_kl,
+             "adaptive_kl/kl_ratio": kl_div / self.target_kl if self.target_kl > 0 else 0.0,
+         }
+
+
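+ # Worked example (illustrative): with adapt_rate=1.2 and target_kl=0.015,
+ #   ctrl = AdaptiveKLController(initial_kl_coef=0.01, target_kl=0.015,
+ #                               adapt_rate=1.2, kl_coef_min=0.003, kl_coef_max=0.05)
+ #   ctrl.update(kl_div=0.03)   # above target -> kl_coef rises to 0.012
+ #   ctrl.update(kl_div=0.005)  # below target -> kl_coef falls back to 0.01
+
+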
+ @dataclass
+ class GRPOConfig:
+     """Configuration for GRPO (Group Relative Policy Optimization) training.
+
+     Attributes:
+         lr: Learning rate for optimizer
+         num_trajectories: Number of trajectory groups to sample per batch
+         trajectory_depth: Maximum depth of each trajectory
+         clip_ratio: PPO clipping ratio (epsilon)
+         kl_coef: KL divergence penalty coefficient (beta)
+         entropy_coef: Entropy bonus coefficient (encourages exploration, prevents policy collapse)
+         eval_every_n_epochs: Frequency of evaluation runs (not used in model, but useful for trainer)
+
+         # Entropy floor monitoring (Recommendation 1)
+         use_entropy_floor: Whether to enable entropy floor monitoring
+         entropy_floor: Minimum entropy threshold for collapse detection
+         entropy_floor_steps: Number of consecutive steps below floor before alert/action
+         entropy_floor_action: Action to take when entropy floor is breached ("warn", "stop", "boost")
+         entropy_boost_factor: Factor to multiply entropy_coef when boosting (if action="boost")
+
+         # Adaptive KL controller (Recommendation 2)
+         adaptive_kl: Whether to use adaptive KL coefficient
+         target_kl: Target KL divergence value
+         kl_adapt_rate: Rate at which to adjust kl_coef (higher = faster adaptation)
+         kl_coef_min: Minimum allowed kl_coef
+         kl_coef_max: Maximum allowed kl_coef
+
+         # PPO-style multiple updates
+         ppo_steps: Number of optimization steps per sampled trajectory batch (reuses samples)
+
+         # Rollout temperature for exploration
+         rollout_temperature: Temperature for action sampling during rollouts (>1 increases exploration)
+
+         # Safety checks on training dynamics
+         enable_safety_checks: Whether to abort training when known-bad patterns persist
+         safety_patience_steps: Number of training steps to tolerate violations before aborting
+         max_clip_fraction: If mean_clip_fraction > this for too long -> abort
+         min_entropy: If entropy < this for too long -> abort
+         max_kl_divergence: If KL >> target_kl for too long -> abort
+     """
+     # Clean run defaults (see research_docs/2026-02-06_loss-budget-and-monitor-analysis.md)
+     lr: float = 1e-6  # Reduced: PPO signal now dominates gradient
+     num_trajectories: int = 4
+     trajectory_depth: int = 5
+     clip_ratio: float = 0.2
+     kl_coef: float = 0.001  # Reduced from 0.01 (was overridden to 0.1 by adaptive KL)
+     entropy_coef: float = 0.0  # Removed: not in original GRPO loss, was 95% of gradient
+     eval_every_n_epochs: int = 10
+
+     # Entropy floor monitoring — disabled by default (never triggered in practice)
+     use_entropy_floor: bool = False
+     entropy_floor: float = 1.5
+     entropy_floor_steps: int = 200
+     entropy_floor_action: str = "boost"
+     entropy_boost_factor: float = 2.0
+
+     # Adaptive KL controller — disabled by default (saturated at max instantly)
+     adaptive_kl: bool = False
+     target_kl: float = 0.015
+     kl_adapt_rate: float = 1.2
+     kl_coef_min: float = 0.003
+     kl_coef_max: float = 0.05
+
+     # PPO-style multiple updates per sample
+     ppo_steps: int = 1
+
+     # Rollout temperature for exploration (>1 flattens distribution, increases entropy)
+     rollout_temperature: float = 1.0
+
+     # Safety checks on training dynamics
+     enable_safety_checks: bool = False
+     safety_patience_steps: int = 1000  # Number of training steps to tolerate violations
+     # Thresholds derived from prior research docs
+     max_clip_fraction: float = 0.95  # If mean_clip_fraction > this for too long -> abort
+     min_entropy: float = 0.5  # If entropy < this for too long -> abort
+     max_kl_divergence: float = 0.08  # If KL >> target_kl for too long -> abort
+
+     # Teacher forcing: use Stockfish for rival moves during trajectory sampling
+     teacher_forcing_prob: float = 0.0  # Probability of using Stockfish for rival (opponent) moves
+     teacher_forcing_depth: int = 4  # Stockfish search depth for teacher forcing moves
+
+
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
+ torch.serialization.add_safe_globals([GRPOConfig])
+
+
+ class GRPOChessTransformer(pl.LightningModule):
+     """PyTorch Lightning module for training chess policy with GRPO.
+
+     This module implements Group Relative Policy Optimization (GRPO) for training
+     a chess transformer policy. It maintains both a current policy and an old policy
+     for computing importance sampling ratios in the PPO loss.
+
+     Attributes:
+         policy_model: Current policy model being trained
+         old_policy_model: Frozen copy of policy for importance sampling
+         evaluator: Evaluator for running games against Stockfish
+         eval_every_n_epochs: Frequency of evaluation runs
+         entropy_monitor: Optional entropy floor monitor (Recommendation 1)
+         kl_controller: Optional adaptive KL controller (Recommendation 2)
+         current_entropy_coef: Current entropy coefficient (mutable for entropy boosting)
+         automatic_optimization: Set to False for manual PPO steps
+     """
+     automatic_optimization = False  # Manual optimization for ppo_steps
+
+     def __init__(self,
+                  transformer_config: ChessTransformerConfig,
+                  grpo_config: GRPOConfig,
+                  eval_cfg: EvalConfig | None = None,
+                  stockfish_cfg: StockfishConfig | None = None,
+                  policy_cfg: PolicyConfig | None = None,
+                  searcher_cfg: SearchConfig | None = None,
+                  pretrain_cfg: PretrainLoadConfig | None = None):
+         """
+         Initialize the GRPO Chess Transformer.
+
+         Args:
+             transformer_config: Configuration for the chess transformer model
+             grpo_config: GRPO training configuration
+             eval_cfg: Optional evaluation configuration
+             stockfish_cfg: Optional Stockfish configuration for evaluation
+             policy_cfg: Optional policy player configuration
+             searcher_cfg: Optional search configuration
+             pretrain_cfg: Optional pretrain config for loading pretrained weights
+         """
+         super().__init__()
+         self.save_hyperparameters()
+         self.policy_model = ChessTransformer(transformer_config)
+         self.old_policy_model = ChessTransformer(transformer_config)
+
+         # Load pretrained weights if specified
+         if pretrain_cfg and pretrain_cfg.checkpoint_path:
+             self._load_pretrained_weights(pretrain_cfg)
+
+         self._sync_old_policy()
+
+         # Evaluation config
+         self.eval_every_n_epochs = grpo_config.eval_every_n_epochs
+         self.evaluator = Evaluator(eval_cfg=eval_cfg or EvalConfig(),
+                                    policy_cfg=policy_cfg or PolicyConfig(),
+                                    stockfish_cfg=stockfish_cfg or StockfishConfig(),
+                                    searcher_cfg=searcher_cfg)
+
+         # Entropy floor monitor (Recommendation 1) - optional
+         self.entropy_monitor: EntropyFloorMonitor | None = None
+         if grpo_config.use_entropy_floor:
+             self.entropy_monitor = EntropyFloorMonitor(
+                 floor=grpo_config.entropy_floor,
+                 steps_threshold=grpo_config.entropy_floor_steps,
+                 action=grpo_config.entropy_floor_action,
+                 boost_factor=grpo_config.entropy_boost_factor,
+             )
+         self.current_entropy_coef = grpo_config.entropy_coef
+
+         # Adaptive KL controller (Recommendation 2) - optional
+         self.kl_controller: AdaptiveKLController | None = None
+         if grpo_config.adaptive_kl:
+             self.kl_controller = AdaptiveKLController(
+                 initial_kl_coef=grpo_config.kl_coef,
+                 target_kl=grpo_config.target_kl,
+                 adapt_rate=grpo_config.kl_adapt_rate,
+                 kl_coef_min=grpo_config.kl_coef_min,
+                 kl_coef_max=grpo_config.kl_coef_max,
+             )
+
+         # Safety-check state (for tracking persistent violations)
+         self._safety_step_idx: int = 0
+         self._high_clip_steps: int = 0
+         self._low_entropy_steps: int = 0
+         self._high_kl_steps: int = 0
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the current policy model.
+
+         Args:
+             x: Input tensor [batch, seq_len]
+
+         Returns:
+             Policy logits [batch, action_dim]
+         """
+         return self.policy_model(x)
+
+     def _old_forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the old (frozen) policy model.
+
+         Args:
+             x: Input tensor [batch, seq_len]
+
+         Returns:
+             Policy logits [batch, action_dim]
+         """
+         return self.old_policy_model(x)
+
+     def _sync_old_policy(self) -> None:
+         """Synchronize old policy model with current policy and freeze it."""
+         self.old_policy_model.load_state_dict(self.policy_model.state_dict())
+         # Freeze old policy parameters
+         for param in self.old_policy_model.parameters():
+             param.requires_grad = False
+
+     def _load_pretrained_weights(self, pretrain_cfg: PretrainLoadConfig) -> None:
+         """Load pretrained weights from a checkpoint.
+
+         Args:
+             pretrain_cfg: Pretrain configuration with checkpoint path and freeze settings
+         """
+         checkpoint_path = pretrain_cfg.checkpoint_path
+         print(f"Loading pretrained weights from: {checkpoint_path}")
+
+         checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
+
+         # Handle different checkpoint formats
+         if 'model_state_dict' in checkpoint:
+             state_dict = checkpoint['model_state_dict']
+         elif 'state_dict' in checkpoint:
+             # Lightning checkpoint format - extract policy_model weights
+             state_dict = {}
+             for k, v in checkpoint['state_dict'].items():
+                 if k.startswith('model.'):
+                     # From PretrainChessTransformer
+                     state_dict[k[6:]] = v  # Remove 'model.' prefix
+                 elif k.startswith('policy_model.'):
+                     # From GRPOChessTransformer
+                     state_dict[k[13:]] = v  # Remove 'policy_model.' prefix
+         else:
+             # Assume it's a raw state dict
+             state_dict = checkpoint
+
+         # Load into policy model
+         missing, unexpected = self.policy_model.load_state_dict(state_dict, strict=False)
+         if missing:
+             print(f"Warning: Missing keys in pretrained checkpoint: {missing}")
+         if unexpected:
+             print(f"Warning: Unexpected keys in pretrained checkpoint: {unexpected}")
+
+         print("Successfully loaded pretrained weights")
+
+         # Optionally freeze transformer layers
+         if pretrain_cfg.freeze_layers > 0:
+             self._freeze_transformer_layers(pretrain_cfg.freeze_layers)
+
+     def _freeze_transformer_layers(self, num_layers: int) -> None:
+         """Freeze the first N transformer encoder layers.
+
+         Args:
+             num_layers: Number of layers to freeze (from the bottom)
+         """
+         # Freeze embedding and positional encoding
+         for param in self.policy_model.embedding.parameters():
+             param.requires_grad = False
+         self.policy_model.pos_encoding.requires_grad = False
+
+         # Freeze the specified number of transformer layers
+         for i, layer in enumerate(self.policy_model.transformer.layers):
+             if i < num_layers:
+                 for param in layer.parameters():
+                     param.requires_grad = False
+                 print(f"Froze transformer layer {i}")
+
+         # Count trainable parameters
+         trainable = sum(p.numel() for p in self.policy_model.parameters() if p.requires_grad)
+         total = sum(p.numel() for p in self.policy_model.parameters())
+         print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)")
+
+     def _log_rewards_metrics(self, batch_group_rewards: torch.Tensor, prefix: str = "train/") -> None:
+         """Log reward statistics for monitoring training progress.
+
+         Args:
+             batch_group_rewards: Group rewards tensor [B, G]
+             prefix: Prefix for log keys (default: "train/")
+         """
+         mean_r = batch_group_rewards.mean()
+         best = batch_group_rewards.max()
+         gap = best - mean_r
+
+         self.log(prefix + "avg_reward", mean_r, prog_bar=True)
+         self.log(prefix + "reward_std", batch_group_rewards.std())
+         self.log(prefix + "reward_p50", batch_group_rewards.median())
+         self.log(prefix + "reward_p90", batch_group_rewards.quantile(0.9))
+         self.log(prefix + "reward_best", best)
+         self.log(prefix + "reward_gap_best_minus_mean", gap)
+
+     def on_train_epoch_start(self) -> None:
+         """Called at the start of each training epoch. Syncs old policy."""
+         self._sync_old_policy()
+
+     def _ppo_step(
+         self,
+         trajectories_states: torch.Tensor,
+         trajectories_actions: torch.Tensor,
+         trajectories_old_log_probs: torch.Tensor,
+         trajectories_legal_masks: torch.Tensor | None,
+         step_rewards: torch.Tensor,
+         effective_pad_mask: torch.Tensor,
+     ) -> tuple[torch.Tensor, object]:
+         """Perform a single PPO optimization step.
+
+         Args:
+             trajectories_states: State tensors [B, G, T, SEQ]
+             trajectories_actions: Action indices [B, G, T]
+             trajectories_old_log_probs: Log probs from old policy [B, G, T]
+             trajectories_legal_masks: Legal move masks [B, G, T, A] or None
+             step_rewards: Per-step rewards [B, G, T]
+             effective_pad_mask: Mask for valid steps [B, G, T]
+
+         Returns:
+             Tuple of (loss, loss_info)
+         """
+         # Compute new log probs with current policy
+         new_log_probs = self.policy_model.get_group_log_probs(
+             trajectories_states, trajectories_actions, trajectories_legal_masks
+         )
+
+         # Use current (possibly adapted) coefficients
+         kl_coef = self.kl_controller.current_kl_coef if self.kl_controller else self.hparams.grpo_config.kl_coef
+
+         loss, loss_info = grpo_ppo_loss(
+             new_log_probs,
+             trajectories_old_log_probs,
+             step_rewards,
+             effective_pad_mask,
+             clip_ratio=self.hparams.grpo_config.clip_ratio,
+             kl_coef=kl_coef,
+             entropy_coef=self.current_entropy_coef,
+             return_info=True,
+         )
+
+         if not torch.isfinite(loss):
+             raise ValueError(f"Non-finite loss encountered: {loss.item()}")
+
+         return loss, loss_info
+
+     def _run_safety_checks(self, loss_info) -> None:
+         """Run safety checks on training dynamics and abort if they persistently fail."""
+         cfg = self.hparams.grpo_config
+         if not cfg.enable_safety_checks:
+             return
+
+         self._safety_step_idx += 1
+
+         # 1) PPO clipping saturation
+         if loss_info.mean_clip_fraction.item() > cfg.max_clip_fraction:
+             self._high_clip_steps += 1
+         else:
+             self._high_clip_steps = 0
+
+         # 2) Entropy collapse
+         if loss_info.entropy.item() < cfg.min_entropy:
+             self._low_entropy_steps += 1
+         else:
+             self._low_entropy_steps = 0
+
+         # 3) Excessive KL divergence
+         if loss_info.kl_div.item() > cfg.max_kl_divergence:
+             self._high_kl_steps += 1
+         else:
+             self._high_kl_steps = 0
+
+         # Log safety counters for debugging
+         self.log("safety/high_clip_steps", float(self._high_clip_steps))
+         self.log("safety/low_entropy_steps", float(self._low_entropy_steps))
+         self.log("safety/high_kl_steps", float(self._high_kl_steps))
+
+         if (
+             self._high_clip_steps >= cfg.safety_patience_steps
+             or self._low_entropy_steps >= cfg.safety_patience_steps
+             or self._high_kl_steps >= cfg.safety_patience_steps
+         ):
+             raise RuntimeError(
+                 "Safety checks triggered: training aborted due to persistent "
+                 f"bad dynamics (clip={loss_info.mean_clip_fraction.item():.3f}, "
+                 f"entropy={loss_info.entropy.item():.3f}, "
+                 f"kl={loss_info.kl_div.item():.4f}). "
+                 "Adjust GRPOConfig or investigate recent research docs."
+             )
+
+     def training_step(self, batch_fens: list[str], batch_idx: int) -> None:
+         """Perform a training step with multiple PPO optimization iterations.
+
+         Samples trajectories once, then performs ppo_steps optimization iterations
+         on the same sampled data to improve compute efficiency.
+
+         Args:
+             batch_fens: List of FEN strings representing starting positions
+             batch_idx: Batch index (unused)
+         """
+         opt = self.optimizers()
+
+         boards = [chess.Board(start_fen) for start_fen in batch_fens]
+         boards = [board for board in boards if not board.is_game_over()]
+         if not boards:
+             return  # Skip if all games are over
+
+         trajectories_sample = sample_trajectories_batched(
+             self.old_policy_model,
+             boards,
+             self.hparams.grpo_config.num_trajectories,
+             self.hparams.grpo_config.trajectory_depth,
+             temperature=self.hparams.grpo_config.rollout_temperature,
+             teacher_forcing_prob=self.hparams.grpo_config.teacher_forcing_prob,
+             teacher_forcing_depth=self.hparams.grpo_config.teacher_forcing_depth,
+         )
+         if trajectories_sample is None:
+             return  # Skip if no moves
+
+         # Extract trajectory components (sampled once, reused for ppo_steps)
+         trajectories_old_log_probs = trajectories_sample.trajectories_log_probs  # [B, G, T]
+         trajectories_actions = trajectories_sample.trajectories_actions  # [B, G, T]
+         trajectories_states = trajectories_sample.trajectories_states  # [B, G, T, SEQ]
+         batch_group_rewards = trajectories_sample.group_rewards  # [B, G] (for logging)
+         step_rewards = trajectories_sample.step_rewards  # [B, G, T]
+         pad_mask = trajectories_sample.pad_mask  # [B, G, T]
+         trajectories_legal_masks = trajectories_sample.trajectories_legal_masks  # [B, G, T, A] or None
+
+         # Add starting-player mask (only consider moves from the starting player's perspective)
+         _, _, T = pad_mask.shape
+         t = torch.arange(T, device=pad_mask.device)
+         start_player_mask = (t % 2 == 0)[None, None, :]  # [1, 1, T]
+         effective_pad_mask = pad_mask & start_player_mask  # [B, G, T]
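+         # Illustrative: with trajectory_depth T=5 this keeps timesteps 0, 2 and 4,
+         # i.e. the starting player's moves; odd indices are the rival's replies.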
+
+         ppo_steps = self.hparams.grpo_config.ppo_steps
+
+         # Perform multiple PPO optimization steps on the same sampled trajectories
+         for ppo_step_idx in range(ppo_steps):
+             loss, loss_info = self._ppo_step(
+                 trajectories_states,
+                 trajectories_actions,
+                 trajectories_old_log_probs,
+                 trajectories_legal_masks,
+                 step_rewards,
+                 effective_pad_mask,
+             )
+
+             # Manual optimization step
+             opt.zero_grad()
+             self.manual_backward(loss)
+             self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
+             opt.step()
+
+             # Entropy floor monitoring (Recommendation 1) - only on last ppo_step
+             if ppo_step_idx == ppo_steps - 1 and self.entropy_monitor is not None:
+                 self.current_entropy_coef, entropy_metrics = self.entropy_monitor.check(
+                     loss_info.entropy.item(), self.current_entropy_coef
+                 )
+                 for key, value in entropy_metrics.items():
+                     self.log(key, value)
+
+             # Adaptive KL controller (Recommendation 2) - only on last ppo_step
+             if ppo_step_idx == ppo_steps - 1 and self.kl_controller is not None:
+                 kl_metrics = self.kl_controller.update(loss_info.kl_div.item())
+                 for key, value in kl_metrics.items():
+                     self.log(key, value)
+
+         # Within-board group collapse metrics (Recommendation 4) - log once per training_step
+         collapse_metrics = compute_group_collapse_metrics(
+             trajectories_actions, batch_group_rewards, step_rewards, pad_mask
+         )
+         for key, value in collapse_metrics.items():
+             self.log(key, value)
+
+         # Standard logging (log final ppo_step metrics)
+         valid_mask = pad_mask.float()  # [B, G, T], 1 = real step
+
+         self.log("train_total_loss", loss, prog_bar=True)
+         self.log("pad_fraction", 1.0 - valid_mask.mean())
+         self.log("avg_trajectory_length", pad_mask.float().sum(dim=-1).mean())
+
+         self.log("mean_kl_divergence", loss_info.kl_div)
+         self.log("mean_ratio", loss_info.mean_ratio)
+         self.log("mean_clip_fraction", loss_info.mean_clip_fraction)
+         self.log("ppo_loss", loss_info.ppo_loss)
+         self.log("entropy", loss_info.entropy)
+         # Loss without the entropy bonus term (PPO + KL only)
+         self.log("train/loss_without_entropy", loss_info.loss_without_entropy)
+         self.log("ppo_steps", float(ppo_steps))
+         self._log_rewards_metrics(batch_group_rewards, prefix="train/")
+
+         # Log step reward statistics (only for valid steps)
+         valid_step_rewards = step_rewards[pad_mask]
+         self.log("train/step_reward_mean", valid_step_rewards.mean())
+         self.log("train/step_reward_std", valid_step_rewards.std())
+
+         # Log raw centipawn step rewards (before normalization) for debugging
+         raw_step_cp = trajectories_sample.raw_step_cp
+         valid_raw_step_cp = raw_step_cp[pad_mask]
+         self.log("train/raw_step_cp_mean", valid_raw_step_cp.mean())
+         self.log("train/raw_step_cp_std", valid_raw_step_cp.std())
+         self.log("train/raw_step_cp_abs_mean", valid_raw_step_cp.abs().mean())
+
+         # Run safety checks on the final loss statistics
+         self._run_safety_checks(loss_info)
+
+     def configure_optimizers(self) -> torch.optim.Adam:
+         """Configure optimizer for training.
+
+         Returns:
+             Adam optimizer with learning rate from GRPO config
+         """
+         return torch.optim.Adam(self.parameters(), lr=self.hparams.grpo_config.lr)
+
+     def _evaluate_against_stockfish(self) -> Optional[tuple[dict, list[str]]]:
+         """Run a single game evaluation against Stockfish with current policy model.
+
+         Returns:
+             Tuple of (results_dict, pgns) or None if evaluation failed.
+             pgns is a list of PGN strings for all games played.
+         """
+         was_training = self.training
+         self.eval()
+         try:
+             with torch.no_grad():
+                 results, _, pgns = self.evaluator.single_evaluation(self.policy_model)
+             return results, pgns
+         except Exception as e:
+             # Prefer the attached logger when it supports warning(); fall back to stdout
+             message = f"Evaluation against Stockfish failed: {e}"
+             if getattr(self, 'logger', None) is not None and hasattr(self.logger, 'warning'):
+                 self.logger.warning(message)
+             else:
+                 print(message)
+             return None
+         finally:
+             if was_training:
+                 self.train()
+
+     def _log_stockfish_eval(self, results: dict) -> None:
+         """Log scalar evaluation metrics from the Stockfish evaluation.
+
+         Args:
+             results: Dictionary containing evaluation results with keys:
+                 - games: Total number of games played
+                 - wins: Number of wins
+                 - draws: Number of draws
+                 - losses: Number of losses
+                 - score: Win rate (0-1)
+                 - elo_diff_vs_stockfish_approx: Approximate Elo difference
+                 - termination_reasons: Dict mapping termination reasons to counts
+         """
+         # Scalar stats
+         self.log("eval_stockfish/games", results["games"])
+         self.log("eval_stockfish/wins", results["wins"])
+         self.log("eval_stockfish/draws", results["draws"])
+         self.log("eval_stockfish/losses", results["losses"])
+         self.log("eval_stockfish/score", results["score"], prog_bar=True)
+         self.log("eval_stockfish/elo_diff", results["elo_diff_vs_stockfish_approx"], prog_bar=True)
+
+         # Termination reasons as fractions
+         games = results["games"] or 1
+         for reason, cnt in results["termination_reasons"].items():
+             frac = cnt / games
+             self.log(f"eval_stockfish/term_{reason}", frac)
+
+     def _log_pgns(self, pgns: list[str]) -> None:
+         """Log PGNs to WandB as a text artifact.
+
+         Args:
+             pgns: List of PGN strings for all games played
+         """
+         if not pgns:
+             return
+
+         # Combine all PGNs into a single string
+         combined_pgn = "\n\n".join(pgns)
+
+         # Log to WandB if available
+         if self.logger and hasattr(self.logger, 'experiment'):
+             try:
+                 import wandb
+                 # Log as a text artifact
+                 self.logger.experiment.log({
+                     "eval_stockfish/pgns": wandb.Html(f"<pre>{combined_pgn}</pre>"),
+                     "eval_stockfish/pgn_text": combined_pgn,
+                 })
+             except Exception as e:
+                 print(f"Failed to log PGNs to WandB: {e}")
+
+     def on_train_epoch_end(self) -> None:
+         """Called at the end of each training epoch. Runs evaluation if scheduled."""
+         if (self.current_epoch + 1) % self.eval_every_n_epochs == 0:
+             eval_result = self._evaluate_against_stockfish()
+             if eval_result is not None:
+                 results, pgns = eval_result
+                 self._log_stockfish_eval(results)
+                 self._log_pgns(pgns)
hf_space_repo/grpo_logic/sampling.py ADDED
@@ -0,0 +1,243 @@
+ import os
+ import random
+ from typing import List, Optional
+ import chess
+ import chess.engine
+ import torch
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+
+ from src.grpo_self_play.chess.rewards import reward_board, evaluate_board, normalize_cp
+ from src.grpo_self_play.models import ChessTransformer
+ from src.grpo_self_play.searchless_chess_imports import ACTION_TO_MOVE, SEQUENCE_LENGTH, MOVE_TO_ACTION
+ from src.grpo_self_play.chess.chess_logic import board_to_tensor, get_legal_moves_mask
+ from src.grpo_self_play.chess.stockfish import stockfish_play, DEFAULT_STOCKFISH_TIMEOUT
+
+
+ def _get_teacher_engine_name() -> str:
+     """Get a process-specific engine name for teacher forcing."""
+     return f"teacher_forcing_{os.getpid()}"
+
+
+ def get_stockfish_move(board: chess.Board, depth: int = 4, timeout: float = DEFAULT_STOCKFISH_TIMEOUT) -> Optional[chess.Move]:
+     """Get the best move from Stockfish for a given board position.
+
+     Args:
+         board: Chess board position
+         depth: Stockfish search depth
+         timeout: Maximum time to wait for response (seconds)
+
+     Returns:
+         Best move from Stockfish, or None if no move available or on error
+     """
+     limit = chess.engine.Limit(depth=depth)
+     return stockfish_play(_get_teacher_engine_name(), board, limit, timeout=timeout)
+
+
+ # Trajectory sampling logic
+ @dataclass
+ class TrajectoriesSample:
+     """Container for batched trajectory samples.
+
+     Attributes:
+         trajectories_log_probs: Log probabilities of sampled actions [B, G, T]
+         trajectories_actions: Action indices [B, G, T]
+         trajectories_states: State tensors [B, G, T, SEQ]
+         group_rewards: Final rewards for each trajectory group [B, G] (for logging)
+         step_rewards: Per-step rewards [B, G, T] where step_rewards[b, g, t] = eval(s_{t+1}) - eval(s_t)
+         pad_mask: Mask indicating valid steps, True=valid, False=padding [B, G, T]
+         trajectories_legal_masks: Legal move masks [B, G, T, A]
+         raw_step_cp: Raw centipawn step rewards [B, G, T] (for logging, not normalized)
+     """
+     trajectories_log_probs: torch.Tensor  # [B, G, T]
+     trajectories_actions: torch.Tensor  # [B, G, T]
+     trajectories_states: torch.Tensor  # [B, G, T, SEQ]
+     group_rewards: torch.Tensor  # [B, G]
+     step_rewards: torch.Tensor  # [B, G, T]
+     pad_mask: torch.Tensor  # [B, G, T]
+     trajectories_legal_masks: torch.Tensor  # [B, G, T, A]
+     raw_step_cp: torch.Tensor  # [B, G, T] - raw centipawn differences
+
+
+ def batched_policy_step(model: ChessTransformer, boards: List[chess.Board], temperature: float = 1.0) -> Optional[tuple]:
+     """Sample actions from the policy for a batch of boards.
+
+     Args:
+         model: Chess transformer model
+         boards: List of chess board positions
+         temperature: Temperature for sampling
+
+     Returns:
+         Tuple of (action_indices, log_probs, moves, states_tensor, legal_mask) or None if empty
+     """
+     N = len(boards)
+     if N == 0:
+         return None
+     device = next(model.parameters()).device
+     states_list = []
+     legal_masks = []
+     for board in boards:
+         state = board_to_tensor(board, device=device)
+         states_list.append(state)
+         mask = get_legal_moves_mask(board, device=device)
+         if mask.ndim == 2:
+             mask = mask.squeeze(0)
+         assert mask.ndim == 1, f"legal_moves_mask must be 1D [A], got {mask.shape}"
+         legal_masks.append(mask)
+
+     states_tensor = torch.cat(states_list, dim=0)  # [N, SEQ]
+     legal_mask = torch.stack(legal_masks, dim=0)  # [N, A] bool
+     assert legal_mask.dtype == torch.bool, "legal_mask must be bool dtype"
+     assert legal_mask.shape[0] == N, f"legal_mask batch size mismatch {legal_mask.shape[0]} vs {N}"
+     assert legal_mask.shape[1] == model.action_size, f"legal_mask action size mismatch {legal_mask.shape[1]} vs {model.action_size}"
+     if not legal_mask.any(dim=1).all():
+         bad = (~legal_mask.any(dim=1)).nonzero(as_tuple=False).flatten().tolist()
+         raise ValueError(f"Empty legal mask for boards: {bad}")
+     probs = model.get_legal_moves_probs(states_tensor, legal_mask, temperature)  # [N, A]
+
+     action_idx = torch.multinomial(probs, 1).squeeze(1)  # [N]
+     chosen_probs = probs.gather(1, action_idx.unsqueeze(1)).squeeze(1)  # [N]
+     chosen_log_probs = torch.log(chosen_probs + 1e-12)  # [N], avoid log(0)
+
+     # Convert action indices to moves, ensure legality
+     moves = []
+     for i, idx in enumerate(action_idx.tolist()):
+         uci = ACTION_TO_MOVE[idx]
+         move = chess.Move.from_uci(uci)
+         if move not in boards[i].legal_moves:
+             raise ValueError(f"Sampled illegal move {uci} for board:\n{boards[i]}")
+         moves.append(move)
+     return action_idx, chosen_log_probs, moves, states_tensor, legal_mask
+
+
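+ # Usage sketch (illustrative; `model` is an already-constructed ChessTransformer):
+ # one policy step over two fresh boards, unpacking the five-tuple documented above.
+ #   idx, logp, moves, states, mask = batched_policy_step(model, [chess.Board(), chess.Board()])
+
+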
+ def sample_trajectories_batched(model: ChessTransformer,
+                                 boards: List[chess.Board],
+                                 num_trajectories: int,
+                                 trajectory_depth: int,
+                                 reward_depth: int = 4,
+                                 temperature: float = 1.0,
+                                 teacher_forcing_prob: float = 0.0,
+                                 teacher_forcing_depth: int = 4) -> Optional[TrajectoriesSample]:
+     """Sample multiple trajectories from each board position using the policy model.
+
+     Args:
+         model: Chess transformer model for action selection
+         boards: List of starting board positions [B]
+         num_trajectories: Number of trajectory groups per board (G)
+         trajectory_depth: Maximum depth of each trajectory (T)
+         reward_depth: Stockfish depth for reward computation (default: 4)
+         temperature: Temperature for action sampling (default: 1.0, >1 increases exploration)
+         teacher_forcing_prob: Probability of using Stockfish for rival moves (default: 0.0)
+         teacher_forcing_depth: Stockfish depth for teacher forcing moves (default: 4)
+
+     Returns:
+         TrajectoriesSample containing batched trajectory data, or None if no boards
+     """
+     device = next(model.parameters()).device
+     B, G, T = len(boards), num_trajectories, trajectory_depth
+     if B == 0:
+         return None
+
+     # Create B*G copies of the boards for parallel trajectory sampling
+     envs = [boards[b].copy() for b in range(B) for _ in range(G)]  # Length B*G
+     # Per-(b, g) storage as nested lists
+     traj_log_probs = [[[] for _ in range(G)] for _ in range(B)]
+     traj_actions = [[[] for _ in range(G)] for _ in range(B)]
+     traj_states = [[[] for _ in range(G)] for _ in range(B)]
+     traj_legal_masks = [[[] for _ in range(G)] for _ in range(B)]
+     traj_step_rewards = [[[] for _ in range(G)] for _ in range(B)]
+     traj_raw_step_cp = [[[] for _ in range(G)] for _ in range(B)]  # Raw centipawn differences for logging
+
+     # Track POV and the previous raw eval for each trajectory (step rewards are normalized later)
+     pov_is_white = [(boards[b].turn == chess.WHITE) for b in range(B) for _ in range(G)]
+     prev_evals_raw = [evaluate_board(boards[b], pov_is_white[b * G], depth=reward_depth, normalize=False)
+                       for b in range(B) for _ in range(G)]
+
+     # Rollout: sample trajectories in batches
+     for t in range(T):
+         active_env_idx = [i for i, e in enumerate(envs) if not e.is_game_over()]
+         if not active_env_idx:
+             break
+
+         # Determine whether this is the rival's turn (odd timesteps)
+         is_rival_turn = (t % 2 == 1)
+         use_teacher_forcing = is_rival_turn and teacher_forcing_prob > 0 and random.random() < teacher_forcing_prob
+
+         active_boards = [envs[i] for i in active_env_idx]
+         roll_out_step = batched_policy_step(model, active_boards, temperature=temperature)
+         if roll_out_step is None:
+             break
+
+         action_indices, log_probs, moves, states_batch, legal_mask = roll_out_step
+         if action_indices is None:
+             break
+
+         for j, env_idx_j in enumerate(active_env_idx):
+             move_j = moves[j]
+             if move_j is None:
+                 continue  # End of game for this env
+             b_idx = env_idx_j // G
+             g_idx = env_idx_j % G
+             state_j = states_batch[j]
+
+             # Teacher forcing: override the rival's move with Stockfish
+             if use_teacher_forcing:
+                 sf_move = get_stockfish_move(envs[env_idx_j], depth=teacher_forcing_depth)
+                 if sf_move is not None and sf_move in envs[env_idx_j].legal_moves:
+                     move_j = sf_move
+                     # Update the action index to match the Stockfish move
+                     action_indices[j] = MOVE_TO_ACTION[move_j.uci()]
+
+             traj_log_probs[b_idx][g_idx].append(log_probs[j])
+             traj_actions[b_idx][g_idx].append(int(action_indices[j].item()))
+             traj_states[b_idx][g_idx].append(state_j)
+             traj_legal_masks[b_idx][g_idx].append(legal_mask[j])
+             envs[env_idx_j].push(move_j)
+
+             # Compute the step reward: eval(new_state) - eval(prev_state).
+             # Get the raw centipawn value, then normalize for step_rewards.
+             new_eval_raw = evaluate_board(envs[env_idx_j], pov_is_white[env_idx_j], depth=reward_depth, normalize=False)
+             raw_step_cp = new_eval_raw - prev_evals_raw[env_idx_j]
+             step_reward = normalize_cp(new_eval_raw) - normalize_cp(prev_evals_raw[env_idx_j])
+             traj_step_rewards[b_idx][g_idx].append(step_reward)
+             traj_raw_step_cp[b_idx][g_idx].append(raw_step_cp)
+             prev_evals_raw[env_idx_j] = new_eval_raw
+
+     # Compute group_rewards for logging (sum of step rewards = final - initial)
+     group_rewards = torch.zeros(B, G, dtype=torch.float32, device=device)
+     for env_idx, env in enumerate(envs):
+         b_idx = env_idx // G
+         g_idx = env_idx % G
+         group_rewards[b_idx, g_idx] = reward_board(env, boards[b_idx], depth=reward_depth, movetime_ms=0)
+
+     # Allocate padded tensors
+     trajectories_log_probs = torch.zeros(B, G, T, dtype=torch.float32, device=device)
+     trajectories_actions = torch.zeros(B, G, T, dtype=torch.long, device=device)
+     trajectories_states = torch.zeros(B, G, T, SEQUENCE_LENGTH, dtype=torch.long, device=device)
+     trajectories_legal_masks = torch.zeros(B, G, T, model.action_size, dtype=torch.bool, device=device)
+     trajectories_legal_masks[..., 0] = True  # Ensure at least one legal move (empty legal masks -> NaNs in log_softmax)
+     step_rewards = torch.zeros(B, G, T, dtype=torch.float32, device=device)
+     raw_step_cp = torch.zeros(B, G, T, dtype=torch.float32, device=device)
+     pad_mask = torch.zeros(B, G, T, dtype=torch.bool, device=device)
+     for b in range(B):
+         for g in range(G):
+             L = len(traj_log_probs[b][g])
+             assert L <= T, f"Trajectory length {L} exceeds pad_length {T}"
+             pad_mask[b, g, :L] = True
+             if L > 0:  # Guard every stack: torch.stack raises on empty lists
+                 trajectories_log_probs[b, g, :L] = torch.stack(traj_log_probs[b][g], dim=0)
+                 trajectories_actions[b, g, :L] = torch.tensor(traj_actions[b][g], dtype=torch.long, device=device)
+                 trajectories_states[b, g, :L] = torch.stack(traj_states[b][g], dim=0)
+                 trajectories_legal_masks[b, g, :L] = torch.stack(traj_legal_masks[b][g], dim=0)
+                 step_rewards[b, g, :L] = torch.tensor(traj_step_rewards[b][g], dtype=torch.float32, device=device)
+                 raw_step_cp[b, g, :L] = torch.tensor(traj_raw_step_cp[b][g], dtype=torch.float32, device=device)
+
+     return TrajectoriesSample(trajectories_log_probs,
+                               trajectories_actions,
+                               trajectories_states,
+                               group_rewards,
+                               step_rewards,
+                               pad_mask,
+                               trajectories_legal_masks,
+                               raw_step_cp)
+
hf_space_repo/logging_utils.py ADDED
@@ -0,0 +1,32 @@
+ """Logging utilities for GRPO training.
+
+ Uses Python's standard logging module which WandB captures automatically
+ in the Logs tab of a run.
+ """
+
+ import logging
+
+ _initialized_loggers = set()
+
+
+ def get_logger(name: str = "grpo_chess") -> logging.Logger:
+     """Get a logger that appears in the WandB Logs tab.
+
+     Args:
+         name: Logger name (default: "grpo_chess")
+
+     Returns:
+         Configured logger instance
+     """
+     logger = logging.getLogger(name)
+
+     if name not in _initialized_loggers:
+         logger.setLevel(logging.INFO)
+         handler = logging.StreamHandler()
+         handler.setFormatter(logging.Formatter(
+             '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+         ))
+         logger.addHandler(handler)
+         _initialized_loggers.add(name)
+
+     return logger
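+
+
+ # Usage sketch (illustrative):
+ #   logger = get_logger(__name__)
+ #   logger.info("This message is captured in the WandB Logs tab")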
hf_space_repo/models.py ADDED
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import chess
+ from typing import Optional
+
+ from dataclasses import dataclass
+ from src.grpo_self_play.searchless_chess_imports import ACTION_TO_MOVE
+ from src.grpo_self_play.chess.chess_logic import board_to_tensor, get_legal_moves_indices
+
+
+ @dataclass
+ class ChessTransformerConfig:
+     """Configuration for the Chess Transformer model.
+
+     Attributes:
+         vocab_size: Size of the vocabulary (token dictionary)
+         embed_dim: Embedding dimension for transformer
+         num_layers: Number of transformer encoder layers
+         num_heads: Number of attention heads
+         action_dim: Dimension of action space (number of possible moves)
+     """
+     vocab_size: int = 300
+     embed_dim: int = 256
+     num_layers: int = 4
+     num_heads: int = 8
+     action_dim: int = 1968
+
+
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
+ torch.serialization.add_safe_globals([ChessTransformerConfig])
+
+
+ class ChessTransformer(nn.Module):
+     """Transformer-based chess policy network.
+
+     Takes FEN-encoded board states as input and outputs action logits.
+     Uses a transformer encoder with learnable positional encodings.
+     """
+     def __init__(self, transformer_config: ChessTransformerConfig):
+         """
+         Initialize the Chess Transformer.
+
+         Args:
+             transformer_config: Configuration for the transformer model
+         """
+         super().__init__()
+         vocab_size = transformer_config.vocab_size
+         embed_dim = transformer_config.embed_dim
+         num_layers = transformer_config.num_layers
+         num_heads = transformer_config.num_heads
+         action_dim = transformer_config.action_dim
+
+         self.embedding = nn.Embedding(vocab_size, embed_dim)
+
+         # DeepMind uses absolute or relative positional encoding.
+         # For simplicity, we use a learnable absolute encoding for the FEN length (~80 chars).
+         self.pos_encoding = nn.Parameter(torch.randn(1, 128, embed_dim))
+
+         encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
+         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+         # Head outputs 1968 logits (one for each possible unique move type)
+         self.policy_head = nn.Sequential(
+             nn.Linear(embed_dim, embed_dim),
+             nn.ReLU(),
+             nn.Linear(embed_dim, action_dim)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the transformer.
+
+         Args:
+             x: Input tensor of token IDs [batch, seq_len]
+
+         Returns:
+             Action logits [batch, action_dim]
+         """
+         batch, seq = x.shape
+
+         # Create padding mask: True indicates a masked position (padding token 0)
+         src_key_padding_mask = (x == 0)
+         x = self.embedding(x) + self.pos_encoding[:, :seq, :]
+
+         # Pass the padding mask to the transformer
+         out = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
+
+         # Pool: mean of the non-masked tokens
+         mask = ~src_key_padding_mask
+         mask_expanded = mask.unsqueeze(-1).float()  # [B, SEQ, 1]
+         pooled = (out * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp_min(1)
+
+         return self.policy_head(pooled)
+
+     @property
+     def device(self) -> torch.device:
+         """Get the device of the model parameters."""
+         return next(self.parameters()).device
+
+     @property
+     def action_size(self) -> int:
+         """Get the size of the action space."""
+         return self.policy_head[-1].out_features
+
+     def get_legal_moves_logits(self, tensor_state: torch.Tensor,
+                                legal_moves_mask: torch.Tensor,
+                                temperature: float = 1.0) -> torch.Tensor:
+         """Get logits for legal moves only, masking illegal moves.
+
+         Args:
+             tensor_state: Board state tensor [B, SEQ]
+             legal_moves_mask: Boolean mask for legal moves [B, A]
+             temperature: Temperature for scaling logits
+
+         Returns:
+             Masked logits [B, A] with illegal moves set to -inf
+         """
+         assert legal_moves_mask is not None, "legal_moves_mask cannot be None"
+         logits = self(tensor_state) / temperature
+         return logits.masked_fill(~legal_moves_mask, -float('inf'))
+
+     def get_legal_moves_probs(self, tensor_state: torch.Tensor,
+                               legal_moves_mask: torch.Tensor,
+                               temperature: float = 1.0) -> torch.Tensor:
+         """Get a probability distribution over legal moves.
+
+         Args:
+             tensor_state: Board state tensor [B, SEQ]
+             legal_moves_mask: Boolean mask for legal moves [B, A]
+             temperature: Temperature for scaling logits
+
+         Returns:
+             Probability distribution [B, A] over legal moves
+         """
+         mask_logits = self.get_legal_moves_logits(tensor_state, legal_moves_mask, temperature)
+         return F.softmax(mask_logits, dim=-1)
+
+     def get_group_log_probs(self,
+                             trajectories_states: torch.Tensor,
+                             action_idx: torch.Tensor,
+                             legal_moves_mask: torch.Tensor,
+                             temperature: float = 1.0) -> torch.Tensor:
+         """Get log probabilities for actions in batched trajectories.
+
+         Args:
+             trajectories_states: State tensors [B, G, T, SEQ]
+             action_idx: Action indices [B, G, T]
+             legal_moves_mask: Legal moves mask [B, G, T, A]
+             temperature: Temperature for scaling logits
+
+         Returns:
+             Log probabilities [B, G, T] for the selected actions
+         """
+         assert legal_moves_mask is not None, "legal_moves_mask cannot be None"
+         assert legal_moves_mask.dtype == torch.bool, "legal_moves_mask must be bool dtype"
+         x = trajectories_states  # [B, G, T, SEQ]
+         B, G, T, L = x.shape
+         x_flat = x.view(B * G * T, L)  # [B*G*T, SEQ]
+         legal_moves_mask = legal_moves_mask.view(B * G * T, -1)  # [B*G*T, A]
+         masked_logits = self.get_legal_moves_logits(x_flat, legal_moves_mask, temperature)  # [B*G*T, A]
+         log_probs_all = F.log_softmax(masked_logits, dim=-1)  # [B*G*T, A]
+
+         action_idx_flat = action_idx.view(B * G * T, 1)  # [B*G*T, 1]
+         log_probs_flat = log_probs_all.gather(1, action_idx_flat).squeeze(-1)  # [B*G*T]
+         log_probs = log_probs_flat.view(B, G, T)  # [B, G, T]
+         return log_probs
+
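+     # Shape sketch (illustrative): states [B, G, T, SEQ] and actions [B, G, T]
+     # yield log-probs [B, G, T]; internally everything is flattened to [B*G*T, ...]
+     # so a single forward pass covers every step of every trajectory.
+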
+     def _get_action_logits(self, board: chess.Board, temperature: float = 1.0) -> Optional[torch.Tensor]:
+         """Get action logits for a single board position.
+
+         Args:
+             board: Chess board position
+             temperature: Temperature for scaling logits
+
+         Returns:
+             Logits tensor [1, action_dim] or None if no legal moves
+         """
+         legal_moves = list(board.legal_moves)
+         legal_indices = get_legal_moves_indices(board)
+
+         if not legal_moves:
+             return None
+
+         # Run model
+         state = board_to_tensor(board, device=self.device)
+         logits = self(state)  # [1, A]
+
+         output = torch.full_like(logits, -float('inf'))
+         output[0, legal_indices] = logits[0, legal_indices] / temperature
+         return output
+
+     def select_action(self, board: chess.Board, temperature: float = 1.0) -> tuple[Optional[chess.Move], Optional[torch.Tensor], Optional[int]]:
+         """Sample an action from the policy for a given board position.
+
+         Args:
+             board: Chess board position
+             temperature: Temperature for sampling (higher = more random)
+
+         Returns:
+             Tuple of (move, log_prob, action_idx) or (None, None, None) if no legal moves
+         """
+         logits = self._get_action_logits(board, temperature)
+         if logits is None:
+             return None, None, None
+         logits = logits.squeeze(0)  # Remove batch dimension
+         probs = F.softmax(logits, dim=0)
+
+         # Sample
+         action_idx = int(torch.multinomial(probs, 1).item())
+         chosen_move = ACTION_TO_MOVE[action_idx]
+         log_prob = torch.log(probs[action_idx] + 1e-12)  # Avoid log(0)
+
+         return chess.Move.from_uci(chosen_move), log_prob, action_idx
+
+
+ def select_action_greedy(model: ChessTransformer, board: chess.Board, temperature: float = 1.0) -> Optional[chess.Move]:
+     """Select the best action greedily (no sampling).
+
+     Args:
+         model: Chess transformer model
+         board: Chess board position
+         temperature: Temperature for scaling logits (does not change the argmax choice)
+
+     Returns:
+         Best move or None if no legal moves
+     """
+     logits = model._get_action_logits(board, temperature)
+     if logits is None:
+         return None
+     logits = logits.squeeze(0)  # Remove batch dimension
+     probs = F.softmax(logits, dim=0)
+     action_idx = int(torch.argmax(probs).item())
+     chosen_move = ACTION_TO_MOVE[action_idx]
+     return chess.Move.from_uci(chosen_move)
hf_space_repo/pretrain/README.md ADDED
@@ -0,0 +1,153 @@
+ # Chess Model Pretraining
+
+ This module provides supervised pretraining on expert chess moves from Lichess games before GRPO reinforcement learning fine-tuning.
+
+ ## Overview
+
+ The pretraining pipeline:
+ 1. Streams chess games from HuggingFace (`Lichess/standard-chess-games`)
+ 2. Filters by player ELO rating
+ 3. Extracts positions and moves from games
+ 4. Trains the ChessTransformer with cross-entropy loss on expert moves
+ 5. Saves checkpoints compatible with GRPO training
+
+ ## Quick Start
+
+ ```bash
+ # Run pretraining with default config
+ python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml
+
+ # With custom parameters
+ python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml \
+     --lr 1e-4 --batch_size 512 --min_elo 1800
+
+ # Disable wandb logging
+ python -m src.grpo_self_play.pretrain.pretrain --no_wandb
+ ```
+
+ ## Configuration
+
+ Configuration is in `src/grpo_self_play/configs/pretrain.yaml`:
+
+ ```yaml
+ pretrain:
+   lr: 0.0001                    # Learning rate
+   batch_size: 256               # Batch size
+   num_epochs: 1                 # Number of epochs
+   warmup_steps: 1000            # Linear warmup steps
+   weight_decay: 0.01            # AdamW weight decay
+   max_grad_norm: 1.0            # Gradient clipping
+   label_smoothing: 0.1          # Prevents overconfidence
+   val_check_interval: 0.1       # Validate every 10% of epoch
+
+ dataset:
+   min_elo: 1800                 # Minimum player rating
+   skip_first_n_moves: 5         # Skip opening moves
+   skip_last_n_moves: 5          # Skip endgame moves
+   sample_positions_per_game: 3  # Positions per game
+   eval_fraction: 0.05           # 5% held out for evaluation
+
+ transformer:
+   embed_dim: 256
+   num_layers: 4
+   num_heads: 8
+ ```
+
+ ## Train/Eval Split
+
+ The dataset uses a **hash-based deterministic split** to ensure:
+ - No data leakage between training and evaluation
+ - Consistent splits across runs
+ - Process-safe multi-worker data loading
+
+ Games are assigned to train or eval based on:
+ ```python
+ is_eval = hash(game_site_url) % 10000 < (eval_fraction * 10000)
+ ```
+
+ This means the same game always goes to the same split, regardless of worker or epoch.
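+
+ A minimal sketch of the same idea using `hashlib` (illustrative only; the real logic lives in `pretrain_dataset.py`, and `is_eval_game` is a hypothetical helper name). Unlike Python's built-in `hash()`, which is salted per process unless `PYTHONHASHSEED` is pinned, an md5 digest is stable across workers and runs:
+
+ ```python
+ import hashlib
+
+ def is_eval_game(game_site_url: str, eval_fraction: float = 0.05) -> bool:
+     # md5 of the game's unique Site URL is stable across processes and epochs
+     digest = hashlib.md5(game_site_url.encode("utf-8")).hexdigest()
+     return int(digest, 16) % 10000 < int(eval_fraction * 10000)
+ ```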
+
+ ## Using Pretrained Weights in GRPO
+
+ After pretraining, use the checkpoint for GRPO fine-tuning by updating `default.yaml`:
+
+ ```yaml
+ pretrain:
+   checkpoint_path: "checkpoints/pretrain/pretrain_final.pt"
+   freeze_layers: 0  # Optional: freeze first N transformer layers
+ ```
+
+ Then run training with that config:
+ ```bash
+ python -m src.grpo_self_play.train_self_play --config default.yaml
+ ```
+
+ ## Module Structure
+
+ ```
+ pretrain/
+ ├── __init__.py              # Package exports
+ ├── pretrain.py              # PyTorch Lightning training module
+ ├── pretrain_dataset.py      # Streaming dataset from HuggingFace
+ ├── pretrain_load_config.py  # Config for loading pretrained weights
+ └── README.md                # This file
+ ```
+
+ ## Key Classes
+
+ ### PretrainChessTransformer
+
+ PyTorch Lightning module that wraps the ChessTransformer for supervised learning.
+
+ ```python
+ from src.grpo_self_play.pretrain.pretrain import PretrainChessTransformer, PretrainConfig
+ from src.grpo_self_play.models import ChessTransformerConfig
+
+ model = PretrainChessTransformer(
+     transformer_config=ChessTransformerConfig(embed_dim=256, num_layers=4, num_heads=8),
+     pretrain_config=PretrainConfig(lr=1e-4, batch_size=256),
+ )
+ ```
+
+ ### ChessPretrainDataset
+
+ Streaming dataset that yields (board_tokens, action, legal_mask) tuples.
+
+ ```python
+ from src.grpo_self_play.pretrain import ChessPretrainDataset, PretrainDatasetConfig
+
+ dataset = ChessPretrainDataset(PretrainDatasetConfig(
+     min_elo=1800,
+     is_eval=False,  # True for the evaluation set
+ ))
+ ```
+
+ ## Metrics
+
+ The following metrics are logged during training:
+
+ | Metric | Description |
+ |--------|-------------|
+ | `train/loss` | Cross-entropy loss with label smoothing |
+ | `train/accuracy` | Top-1 move prediction accuracy |
+ | `train/top5_accuracy` | Top-5 move prediction accuracy |
+ | `train/entropy` | Policy entropy (confidence measure) |
+ | `train/perplexity` | Exponential of loss |
+
+ ## Tests
+
+ Run the test suite:
+ ```bash
+ pytest tests/test_pretrain_pipeline.py -v
+ ```
+
+ Tests cover:
+ - Configuration dataclasses
+ - PGN move parsing
+ - Position extraction from games
+ - UCI to action conversion
+ - Collate function
+ - Model creation and forward pass
+ - Training and validation steps
+ - Hash-based train/eval splitting
+ - Integration with PyTorch Lightning
hf_space_repo/pretrain/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """Pretraining module for chess model."""
+
+ from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig
+ from src.grpo_self_play.pretrain.pretrain_dataset import (
+     ChessPretrainDataset,
+     PretrainDatasetConfig,
+     collate_pretrain_batch,
+ )
+
+ __all__ = [
+     "PretrainLoadConfig",
+     "ChessPretrainDataset",
+     "PretrainDatasetConfig",
+     "collate_pretrain_batch",
+ ]
hf_space_repo/pretrain/pretrain.py ADDED
@@ -0,0 +1,579 @@
1
+ """Pretraining script for chess model on Lichess games using PyTorch Lightning.
2
+
3
+ This script trains the ChessTransformer model using supervised learning
4
+ on expert moves from Lichess games before GRPO reinforcement learning.
5
+
6
+ Usage:
7
+ python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml
8
+
9
+ # Or with overrides:
10
+ python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml \
11
+ --lr 1e-4 --batch_size 512 --min_elo 1800
12
+ """
13
+
14
+ import argparse
15
+ from dataclasses import dataclass
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import pytorch_lightning as pl
22
+ from pytorch_lightning.loggers import WandbLogger
23
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
24
+ from torch.utils.data import DataLoader
25
+
26
+ from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
27
+ from src.grpo_self_play.pretrain.pretrain_dataset import (
28
+ ChessPretrainDataset,
29
+ PretrainDatasetConfig,
30
+ collate_pretrain_batch,
31
+ )
32
+ from src.grpo_self_play.configs.config_loader import (
33
+ load_yaml_file,
34
+ dict_to_dataclass,
35
+ )
36
+
37
+
38
+ @dataclass
39
+ class PretrainConfig:
40
+ """Configuration for pretraining.
41
+
42
+ Attributes:
43
+ lr: Learning rate
44
+ batch_size: Batch size for training
45
+ num_epochs: Number of epochs to train
46
+ warmup_steps: Number of warmup steps for learning rate
47
+ weight_decay: Weight decay for AdamW
48
+ max_grad_norm: Maximum gradient norm for clipping
49
+ checkpoint_dir: Directory to save checkpoints
50
+ resume_from: Path to checkpoint to resume from
51
+ use_wandb: Whether to use Weights & Biases logging
52
+ wandb_project: WandB project name
53
+ label_smoothing: Label smoothing factor for cross-entropy
54
+ num_workers: Number of DataLoader workers
55
+ val_check_interval: Validation check interval (fraction of epoch or int steps)
56
+ """
57
+ lr: float = 1e-4
58
+ batch_size: int = 256
59
+ num_epochs: int = 1
60
+ warmup_steps: int = 1000
61
+ weight_decay: float = 0.01
62
+ max_grad_norm: float = 1.0
63
+ checkpoint_dir: str = "checkpoints/pretrain"
64
+ resume_from: Optional[str] = None
65
+ use_wandb: bool = True
66
+ wandb_project: str = "chess-grpo-pretrain"
67
+ label_smoothing: float = 0.1
68
+ num_workers: int = 4
69
+ val_check_interval: float = 0.1
70
+
71
+
72
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
73
+ torch.serialization.add_safe_globals([PretrainConfig])
74
+
75
+
76
+ class PretrainChessTransformer(pl.LightningModule):
77
+ """PyTorch Lightning module for pretraining chess policy with supervised learning.
78
+
79
+ This module implements supervised learning on expert chess moves from Lichess games.
80
+ The pretrained model can then be fine-tuned with GRPO reinforcement learning.
81
+
82
+ Attributes:
83
+ model: The ChessTransformer policy model
84
+ pretrain_config: Pretraining configuration
85
+ transformer_config: Model architecture configuration
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ transformer_config: ChessTransformerConfig,
91
+ pretrain_config: PretrainConfig,
92
+ ):
93
+ """Initialize pretraining module.
94
+
95
+ Args:
96
+ transformer_config: Configuration for the chess transformer model
97
+ pretrain_config: Pretraining configuration
98
+ """
99
+ super().__init__()
100
+ self.save_hyperparameters()
101
+
102
+ self.model = ChessTransformer(transformer_config)
103
+ self.pretrain_config = pretrain_config
104
+ self.transformer_config = transformer_config
105
+
106
+ # For warmup scheduler
107
+ self._num_training_steps = None
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ """Forward pass through the model.
111
+
112
+ Args:
113
+ x: Input tensor [batch, seq_len]
114
+
115
+ Returns:
116
+ Policy logits [batch, action_dim]
117
+ """
118
+ return self.model(x)
119
+
120
+ def _compute_loss(
121
+ self,
122
+ logits: torch.Tensor,
123
+ targets: torch.Tensor,
124
+ legal_masks: torch.Tensor,
125
+ ) -> tuple[torch.Tensor, dict]:
126
+ """Compute cross-entropy loss with legal move masking.
127
+
128
+ Args:
129
+ logits: Model output logits [B, num_actions]
130
+ targets: Target action indices [B]
131
+ legal_masks: Legal moves mask [B, num_actions]
132
+
133
+ Returns:
134
+ Tuple of (loss, metrics_dict)
135
+ """
136
+ # Validate shapes match
137
+ B, action_dim = logits.shape
138
+ if legal_masks.shape != (B, action_dim):
139
+ raise ValueError(
140
+ f"Shape mismatch: logits {logits.shape} vs legal_masks {legal_masks.shape}. "
141
+ f"Expected legal_masks to be [{B}, {action_dim}]"
142
+ )
143
+ if targets.shape != (B,):
144
+ raise ValueError(
145
+ f"Shape mismatch: targets {targets.shape} vs expected [{B}]"
146
+ )
147
+
148
+ # Validate target actions are within bounds
149
+ max_target = targets.max().item()
150
+ min_target = targets.min().item()
151
+ if max_target >= action_dim or min_target < 0:
152
+ raise ValueError(
153
+ f"Target action indices out of bounds: min={min_target}, max={max_target}, "
154
+ f"action_dim={action_dim}. This suggests a mismatch between dataset action "
155
+ f"space and model action_dim."
156
+ )
157
+
158
+ # Validate target actions are legal (should always be true, but check defensively)
159
+ target_legal = legal_masks.gather(1, targets.unsqueeze(1)).squeeze(1)
160
+ if not target_legal.all():
161
+ illegal_count = (~target_legal).sum().item()
162
+ illegal_indices = (~target_legal).nonzero(as_tuple=False).flatten().tolist()
163
+ raise ValueError(
164
+ f"Found {illegal_count} illegal target actions in batch (out of {B}). "
165
+ f"First few batch indices: {illegal_indices[:10]}. "
166
+ f"This should not happen - dataset should filter these out."
167
+ )
168
+
169
+ # Check for NaN or Inf in raw logits (before masking)
170
+ if not torch.isfinite(logits).all():
171
+ nan_count = (~torch.isfinite(logits)).sum().item()
172
+ raise ValueError(
173
+ f"Found {nan_count} non-finite values in raw logits before masking. "
174
+ f"This suggests the model is outputting NaN/Inf."
175
+ )
176
+
177
+ # Mask illegal moves to -inf
178
+ masked_logits = logits.masked_fill(~legal_masks, float('-inf'))
179
+
180
+ # Check that each sample has at least one legal move (before checking masked logits)
181
+ legal_per_sample = legal_masks.sum(dim=1)
182
+ if (legal_per_sample == 0).any():
183
+ empty_samples = (legal_per_sample == 0).nonzero(as_tuple=False).flatten().tolist()
184
+ raise ValueError(
185
+ f"Found {len(empty_samples)} samples with no legal moves. "
186
+ f"Batch indices: {empty_samples[:10]}. This should not happen."
187
+ )
188
+
189
+ # Check masked logits: each sample must have at least one finite logit (legal move)
190
+ finite_per_sample = torch.isfinite(masked_logits).sum(dim=1)
191
+ if (finite_per_sample == 0).any():
192
+ bad_samples = (finite_per_sample == 0).nonzero(as_tuple=False).flatten().tolist()
193
+ raise ValueError(
194
+ f"Found {len(bad_samples)} samples with all -inf logits after masking. "
195
+ f"Batch indices: {bad_samples[:10]}. This means no legal moves have finite logits."
196
+ )
197
+
198
+ # Ensure target actions are not masked (defensive check)
199
+ target_logits = masked_logits.gather(1, targets.unsqueeze(1)).squeeze(1)
200
+ if not torch.isfinite(target_logits).all():
201
+ inf_count = (~torch.isfinite(target_logits)).sum().item()
202
+ raise ValueError(
203
+ f"Found {inf_count} target actions with -inf logits after masking. "
204
+ f"This means target actions are being masked as illegal, which should not happen."
205
+ )
206
+
207
+ # Compute NLL loss (works correctly with -inf masked logits)
208
+ nll_loss = F.cross_entropy(masked_logits, targets, reduction='mean')
209
+
210
+ # Apply label smoothing only over legal moves to avoid inf from -inf logits
211
+ # Standard F.cross_entropy with label_smoothing averages log_softmax over ALL
212
+ # actions, but -inf logits cause smooth_loss = +inf
213
+ eps = self.pretrain_config.label_smoothing
214
+ if eps > 0:
215
+ # Compute log_softmax (illegal moves will be -inf)
216
+ log_probs = F.log_softmax(masked_logits, dim=-1)
217
+ # Zero out illegal moves so they don't contribute to smoothing term
218
+ log_probs_legal = log_probs.masked_fill(~legal_masks, 0.0)
219
+ # Average only over legal moves
220
+ num_legal = legal_masks.sum(dim=-1).float() # [B]
221
+ smooth_loss = -log_probs_legal.sum(dim=-1) / num_legal # [B]
222
+ loss = (1 - eps) * nll_loss + eps * smooth_loss.mean()
223
+ else:
224
+ loss = nll_loss
225
+
226
+ # Check if loss is infinite or NaN
227
+ if not torch.isfinite(loss):
228
+ # Additional debugging info
229
+ target_logits_debug = masked_logits.gather(1, targets.unsqueeze(1)).squeeze(1)
230
+ print(f"DEBUG: Loss is {loss.item()}")
231
+ print(f"DEBUG: NLL loss: {nll_loss.item()}")
232
+ if eps > 0:
233
+ print(f"DEBUG: Smooth loss mean: {smooth_loss.mean().item()}")
234
+ print(f"DEBUG: Logits shape: {logits.shape}")
235
+ print(f"DEBUG: Legal masks shape: {legal_masks.shape}")
236
+ print(f"DEBUG: Targets range: [{targets.min().item()}, {targets.max().item()}]")
237
+ print(f"DEBUG: Target logits range: [{target_logits_debug.min().item():.2f}, {target_logits_debug.max().item():.2f}]")
238
+ print(f"DEBUG: Legal moves per sample: min={legal_per_sample.min().item()}, max={legal_per_sample.max().item()}")
239
+ raise ValueError(
240
+ f"Loss is {loss.item()}. This can happen if:\n"
241
+ f"1. Target actions are out of bounds\n"
242
+ f"2. Target actions are masked as illegal\n"
243
+ f"3. Model outputs contain NaN/Inf\n"
244
+ f"4. All logits are -inf (no legal moves)"
245
+ )
246
+
247
+ # Compute metrics
248
+ with torch.no_grad():
249
+ # Top-1 accuracy
250
+ predictions = masked_logits.argmax(dim=-1)
251
+ accuracy = (predictions == targets).float().mean()
252
+
253
+ # Top-5 accuracy
254
+ _, top5_preds = masked_logits.topk(5, dim=-1)
255
+ top5_correct = (top5_preds == targets.unsqueeze(-1)).any(dim=-1)
256
+ top5_accuracy = top5_correct.float().mean()
257
+
258
+ # Entropy of the distribution (measure of confidence)
259
+ probs = F.softmax(masked_logits, dim=-1)
260
+ log_probs = F.log_softmax(masked_logits, dim=-1)
261
+ # Handle -inf * 0 = nan by replacing with 0
262
+ entropy_terms = probs * log_probs
263
+ entropy_terms = torch.where(
264
+ torch.isfinite(entropy_terms),
265
+ entropy_terms,
266
+ torch.zeros_like(entropy_terms)
267
+ )
268
+ entropy = -entropy_terms.sum(dim=-1).mean()
269
+
270
+ # Perplexity - clamp to avoid inf
271
+ perplexity = torch.exp(loss.clamp(max=50))
272
+
273
+ metrics = {
274
+ 'accuracy': accuracy,
275
+ 'top5_accuracy': top5_accuracy,
276
+ 'entropy': entropy,
277
+ 'perplexity': perplexity,
278
+ }
279
+
280
+ return loss, metrics
281
+
282
+ def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
283
+ """Perform a training step.
284
+
285
+ Args:
286
+ batch: Tuple of (boards, actions, legal_masks)
287
+ batch_idx: Batch index
288
+
289
+ Returns:
290
+ Loss value
291
+ """
292
+ boards, actions, legal_masks = batch
293
+
294
+ # Forward pass
295
+ logits = self(boards)
296
+
297
+ # Compute loss and metrics
298
+ loss, metrics = self._compute_loss(logits, actions, legal_masks)
299
+
300
+ # Log metrics
301
+ self.log('train/loss', loss, prog_bar=True)
302
+ self.log('train/accuracy', metrics['accuracy'], prog_bar=True)
303
+ self.log('train/top5_accuracy', metrics['top5_accuracy'])
304
+ self.log('train/entropy', metrics['entropy'])
305
+ self.log('train/perplexity', metrics['perplexity'])
306
+
307
+ return loss
308
+
309
+ def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
310
+ """Perform a validation step.
311
+
312
+ Args:
313
+ batch: Tuple of (boards, actions, legal_masks)
314
+ batch_idx: Batch index
315
+
316
+ Returns:
317
+ Loss value
318
+ """
319
+ boards, actions, legal_masks = batch
320
+
321
+ # Forward pass
322
+ logits = self(boards)
323
+
324
+ # Compute loss and metrics
325
+ loss, metrics = self._compute_loss(logits, actions, legal_masks)
326
+
327
+ # Log metrics
328
+ self.log('val/loss', loss, prog_bar=True, sync_dist=True)
329
+ self.log('val/accuracy', metrics['accuracy'], prog_bar=True, sync_dist=True)
330
+ self.log('val/top5_accuracy', metrics['top5_accuracy'], sync_dist=True)
331
+ self.log('val/entropy', metrics['entropy'], sync_dist=True)
332
+ self.log('val/perplexity', metrics['perplexity'], sync_dist=True)
333
+
334
+ return loss
335
+
336
+ def configure_optimizers(self):
337
+ """Configure optimizer and learning rate scheduler.
338
+
339
+ Returns:
340
+ Dictionary with optimizer and lr_scheduler configuration
341
+ """
342
+ optimizer = torch.optim.AdamW(
343
+ self.parameters(),
344
+ lr=self.pretrain_config.lr,
345
+ weight_decay=self.pretrain_config.weight_decay,
346
+ )
347
+
348
+ # Linear warmup + cosine decay scheduler
349
+ def lr_lambda(current_step: int) -> float:
350
+ warmup_steps = self.pretrain_config.warmup_steps
351
+ if current_step < warmup_steps:
352
+ return float(current_step) / float(max(1, warmup_steps))
353
+ return 1.0 # After warmup, use constant LR (or add cosine decay)
354
+
355
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
356
+
357
+ return {
358
+ 'optimizer': optimizer,
359
+ 'lr_scheduler': {
360
+ 'scheduler': scheduler,
361
+ 'interval': 'step',
362
+ 'frequency': 1,
363
+ }
364
+ }
365
+
366
+
367
+ def get_pretrain_trainer(
368
+ pretrain_config: PretrainConfig,
369
+ run_name: str,
370
+ ) -> pl.Trainer:
371
+ """Create a PyTorch Lightning trainer for pretraining.
372
+
373
+ Args:
374
+ pretrain_config: Pretraining configuration
375
+ run_name: Name for this training run
376
+
377
+ Returns:
378
+ Configured PyTorch Lightning trainer
379
+ """
380
+ # Create checkpoint directory
381
+ checkpoint_dir = Path(pretrain_config.checkpoint_dir)
382
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
383
+
384
+ callbacks = [
385
+ ModelCheckpoint(
386
+ dirpath=str(checkpoint_dir),
387
+ filename=run_name + "-{epoch:02d}-{train/loss:.4f}",
388
+ save_top_k=3,
389
+ monitor="train/loss",
390
+ mode="min",
391
+ save_last=True,
392
+ ),
393
+ LearningRateMonitor(logging_interval='step'),
394
+ ]
395
+
396
+ logger = None
397
+ if pretrain_config.use_wandb:
398
+ logger = WandbLogger(
399
+ project=pretrain_config.wandb_project,
400
+ name=run_name,
401
+ log_model=True,
402
+ )
403
+
404
+ trainer = pl.Trainer(
405
+ max_epochs=pretrain_config.num_epochs,
406
+ accelerator="auto",
407
+ devices=1,
408
+ logger=logger,
409
+ callbacks=callbacks,
410
+ gradient_clip_val=pretrain_config.max_grad_norm,
411
+ log_every_n_steps=50,
412
+ val_check_interval=pretrain_config.val_check_interval,
413
+ )
414
+
415
+ return trainer
416
+
417
+
418
+ def load_pretrain_config(
419
+ path: str = "pretrain.yaml",
420
+ overrides: dict = None,
421
+ ) -> tuple[PretrainConfig, PretrainDatasetConfig, ChessTransformerConfig]:
422
+ """Load pretraining configuration from YAML file.
423
+
424
+ Args:
425
+ path: Path to config file (relative to configs dir or absolute)
426
+ overrides: Optional dict of overrides
427
+
428
+ Returns:
429
+ Tuple of (PretrainConfig, PretrainDatasetConfig, ChessTransformerConfig)
430
+ """
431
+ data = load_yaml_file(path)
432
+
433
+ if overrides:
434
+ for section, section_overrides in overrides.items():
435
+ if section in data:
436
+ data[section].update(section_overrides)
437
+ else:
438
+ data[section] = section_overrides
439
+
440
+ pretrain = dict_to_dataclass(PretrainConfig, data.get('pretrain', {}))
441
+ dataset = dict_to_dataclass(PretrainDatasetConfig, data.get('dataset', {}))
442
+ transformer = dict_to_dataclass(ChessTransformerConfig, data.get('transformer', {}))
443
+
444
+ return pretrain, dataset, transformer
445
+
446
+
447
+ def train(
448
+ pretrain_config: PretrainConfig,
449
+ dataset_config: PretrainDatasetConfig,
450
+ transformer_config: ChessTransformerConfig,
451
+ ) -> str:
452
+ """Main pretraining function.
453
+
454
+ Args:
455
+ pretrain_config: Pretraining configuration
456
+ dataset_config: Dataset configuration
457
+ transformer_config: Model configuration
458
+
459
+ Returns:
460
+ Path to final checkpoint
461
+ """
462
+ import time
463
+ import random
464
+ import string
465
+
466
+ # Generate run name
467
+ timestamp = time.strftime("%Y%m%d-%H%M")
468
+ random_suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
469
+ run_name = f"pretrain-{timestamp}-{random_suffix}"
470
+ print(f"Run name: {run_name}")
471
+
472
+ # Create model
473
+ model = PretrainChessTransformer(transformer_config, pretrain_config)
474
+ print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
475
+
476
+ # Create datasets
477
+ train_dataset = ChessPretrainDataset(dataset_config)
478
+
479
+ # Create validation dataset using hash-based split
480
+ val_dataset_config = PretrainDatasetConfig(
481
+ min_elo=dataset_config.min_elo,
482
+ max_samples=10000, # Smaller validation set
483
+ skip_first_n_moves=dataset_config.skip_first_n_moves,
484
+ skip_last_n_moves=dataset_config.skip_last_n_moves,
485
+ sample_positions_per_game=1, # Less samples per game for validation
486
+ is_eval=True, # Use eval portion of hash-based split
487
+ eval_fraction=dataset_config.eval_fraction,
488
+ cache_path=dataset_config.cache_path,
489
+ )
490
+ val_dataset = ChessPretrainDataset(val_dataset_config)
491
+ print(f"Train: {len(train_dataset):,} samples, Eval: {len(val_dataset):,} samples")
492
+
493
+ # Create dataloaders
494
+ train_dataloader = DataLoader(
495
+ train_dataset,
496
+ batch_size=pretrain_config.batch_size,
497
+ shuffle=True, # Shuffle for training
498
+ num_workers=pretrain_config.num_workers,
499
+ collate_fn=collate_pretrain_batch,
500
+ pin_memory=True,
501
+ )
502
+
503
+ val_dataloader = DataLoader(
504
+ val_dataset,
505
+ batch_size=pretrain_config.batch_size,
506
+ shuffle=False,
507
+ num_workers=max(1, pretrain_config.num_workers // 2),
508
+ collate_fn=collate_pretrain_batch,
509
+ pin_memory=True,
510
+ )
511
+
512
+ # Create trainer
513
+ trainer = get_pretrain_trainer(pretrain_config, run_name)
514
+
515
+ # Resume from checkpoint if specified
516
+ ckpt_path = pretrain_config.resume_from
517
+
518
+ # Train
519
+ trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
520
+
521
+ # Save final checkpoint in a standard location
522
+ final_path = Path(pretrain_config.checkpoint_dir) / "pretrain_final.pt"
523
+ torch.save({
524
+ 'model_state_dict': model.model.state_dict(),
525
+ 'transformer_config': transformer_config,
526
+ 'pretrain_config': pretrain_config,
527
+ }, final_path)
528
+
529
+ print(f"\nPretraining complete! Final checkpoint saved to {final_path}")
530
+ return str(final_path)
531
+
532
+
533
+ def main():
534
+ """Main entry point for pretraining script."""
535
+ parser = argparse.ArgumentParser(description="Pretrain chess model on Lichess games")
536
+ parser.add_argument("--config", type=str, default="pretrain.yaml",
537
+ help="Path to config file")
538
+
539
+ # Allow command-line overrides for common parameters
540
+ parser.add_argument("--lr", type=float, help="Learning rate")
541
+ parser.add_argument("--batch_size", type=int, help="Batch size")
542
+ parser.add_argument("--num_epochs", type=int, help="Number of epochs")
543
+ parser.add_argument("--min_elo", type=int, help="Minimum player ELO")
544
+ parser.add_argument("--max_samples", type=int, help="Max samples per epoch")
545
+ parser.add_argument("--resume_from", type=str, help="Resume from checkpoint")
546
+ parser.add_argument("--no_wandb", action="store_true", help="Disable wandb logging")
547
+
548
+ args = parser.parse_args()
549
+
550
+ # Build overrides from command-line arguments
551
+ overrides = {'pretrain': {}, 'dataset': {}}
552
+
553
+ if args.lr:
554
+ overrides['pretrain']['lr'] = args.lr
555
+ if args.batch_size:
556
+ overrides['pretrain']['batch_size'] = args.batch_size
557
+ if args.num_epochs:
558
+ overrides['pretrain']['num_epochs'] = args.num_epochs
559
+ if args.resume_from:
560
+ overrides['pretrain']['resume_from'] = args.resume_from
561
+ if args.no_wandb:
562
+ overrides['pretrain']['use_wandb'] = False
563
+ if args.min_elo:
564
+ overrides['dataset']['min_elo'] = args.min_elo
565
+ if args.max_samples:
566
+ overrides['dataset']['max_samples'] = args.max_samples
567
+
568
+ # Load config
569
+ pretrain_config, dataset_config, transformer_config = load_pretrain_config(
570
+ args.config,
571
+ overrides=overrides if any(v for v in overrides.values()) else None
572
+ )
573
+
574
+ # Run training
575
+ train(pretrain_config, dataset_config, transformer_config)
576
+
577
+
578
+ if __name__ == "__main__":
579
+ main()
hf_space_repo/pretrain/pretrain_dataset.py ADDED
@@ -0,0 +1,328 @@
+ """Dataset for pretraining on chess games from HuggingFace.
+
+ Uses angeluriot/chess_games: 14M high-ELO games (7.3GB download).
+ Mean ELO ~2355, moves already in UCI format - no parsing needed.
+ """
+
+ import hashlib
+ import os
+ import chess
+ import torch
+ import random
+ from typing import Optional
+ from dataclasses import dataclass
+ from multiprocessing import cpu_count
+ from torch.utils.data import Dataset
+ from datasets import load_dataset
+ from tqdm import tqdm
+
+ from src.grpo_self_play.searchless_chess_imports import MOVE_TO_ACTION, tokenize
+
+ # Global constant
+ _ACTION_SPACE_SIZE = max(MOVE_TO_ACTION.values()) + 1
+
+
+ @dataclass
+ class PretrainDatasetConfig:
+     """Configuration for the pretraining dataset.
+
+     Uses angeluriot/chess_games: 14M high-ELO games (7.3GB download).
+     Mean ELO ~2355, moves already in UCI format.
+
+     Attributes:
+         min_elo: Minimum player ELO to include games
+         max_samples: Maximum number of samples per epoch (None for unlimited)
+         skip_first_n_moves: Skip the first N moves (avoid memorizing openings)
+         skip_last_n_moves: Skip the last N moves (avoid noisy endgame positions)
+         sample_positions_per_game: Number of positions to sample from each game
+         is_eval: If True, use eval portion of hash-based split.
+         eval_fraction: Fraction of data to use for evaluation (default 0.05 = 5%)
+         cache_path: Path to save/load filtered dataset (e.g., Google Drive, studio storage).
+             If set and exists, loads from cache. Otherwise downloads, filters, and saves.
+     """
+     min_elo: int = 2000
+     max_samples: Optional[int] = None
+     skip_first_n_moves: int = 5
+     skip_last_n_moves: int = 5
+     sample_positions_per_game: int = 3
+     is_eval: bool = False
+     eval_fraction: float = 0.05
+     cache_path: Optional[str] = None
+
+
+ def uci_to_action(uci_move: str) -> Optional[int]:
+     """Convert UCI move string to action index."""
+     return MOVE_TO_ACTION.get(uci_move)
+
+
+ def get_positions_from_game(
+     moves: list[str],
+     skip_first_n: int = 5,
+     skip_last_n: int = 5,
+     sample_n: int = 3,
+ ) -> list[tuple[str, str, int]]:
+     """Extract (FEN, move_played, move_number) tuples from a game.
+
+     Args:
+         moves: List of UCI moves
+         skip_first_n: Skip first N moves (opening book territory)
+         skip_last_n: Skip last N moves (endgame/resignation noise)
+         sample_n: Number of positions to randomly sample
+
+     Returns:
+         List of (fen, uci_move, move_number) tuples
+     """
+     if len(moves) <= skip_first_n + skip_last_n:
+         return []
+
+     board = chess.Board()
+     positions = []
+
+     for i, uci_move in enumerate(moves):
+         if i < skip_first_n:
+             try:
+                 board.push_uci(uci_move)
+             except (ValueError, chess.InvalidMoveError):
+                 return positions
+             continue
+
+         if i >= len(moves) - skip_last_n:
+             break
+
+         fen = board.fen()
+         positions.append((fen, uci_move, i))
+
+         try:
+             board.push_uci(uci_move)
+         except (ValueError, chess.InvalidMoveError):
+             break
+
+     if len(positions) > sample_n:
+         positions = random.sample(positions, sample_n)
+
+     return positions
+
+
+ class ChessPretrainDataset(Dataset):
+     """Dataset for chess pretraining from angeluriot/chess_games.
+
+     Downloads the full dataset (7.3GB) and processes games into
+     (board_tensor, target_action, legal_moves_mask) tuples.
+
+     Example:
+         >>> config = PretrainDatasetConfig(min_elo=2000)
+         >>> dataset = ChessPretrainDataset(config)
+         >>> dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
+     """
+
+     def __init__(self, config: PretrainDatasetConfig = PretrainDatasetConfig()):
+         """Initialize the dataset - downloads and processes all games."""
+         self.config = config
+         self._action_space_size = max(MOVE_TO_ACTION.values()) + 1
+         self._samples: list[tuple[torch.Tensor, int, torch.Tensor]] = []
+
+         self._load_and_process()
+
+     def _load_and_process(self):
+         """Download dataset and process all games into samples."""
+         # Try loading processed samples from cache
+         if self.config.cache_path:
+             cache_file = self._get_cache_filename()
+             if os.path.exists(cache_file):
+                 print(f"Loading processed samples from {cache_file}...")
+                 self._samples = torch.load(cache_file)
+                 print(f"Loaded {len(self._samples):,} samples from cache")
+                 return
+
+         # Download, filter, and process
+         dataset = self._load_filtered_dataset()
+
+         # Limit dataset size if max_samples is set
+         if self.config.max_samples:
+             max_games = self.config.max_samples // self.config.sample_positions_per_game + 1000
+             if len(dataset) > max_games:
+                 dataset = dataset.select(range(max_games))
+                 print(f"Limited to {len(dataset):,} games")
+
+         # Process games using HuggingFace's optimized map
+         num_workers = min(8, cpu_count() or 4)
+         print(f"Processing games into samples with {num_workers} workers...")
+
+         skip_first = self.config.skip_first_n_moves
+         skip_last = self.config.skip_last_n_moves
+         sample_n = self.config.sample_positions_per_game
+
+         def process_batch(batch):
+             """Process a batch of games - returns lists for HF dataset."""
+             all_boards, all_actions, all_masks = [], [], []
+
+             for i in range(len(batch['moves_uci'])):
+                 moves = batch['moves_uci'][i]
+                 if not moves:
+                     continue
+
+                 positions = get_positions_from_game(moves, skip_first, skip_last, sample_n)
+
+                 for fen, uci_move, _ in positions:
+                     action_idx = MOVE_TO_ACTION.get(uci_move)
+                     if action_idx is None:
+                         continue
+                     try:
+                         token_ids = list(tokenize(fen))
+                         board = chess.Board(fen)
+                         legal_mask = [False] * _ACTION_SPACE_SIZE
+                         for move in board.legal_moves:
+                             move_idx = MOVE_TO_ACTION.get(move.uci())
+                             if move_idx is not None:
+                                 legal_mask[move_idx] = True
+                         if not legal_mask[action_idx]:
+                             continue
+                         all_boards.append(token_ids)
+                         all_actions.append(action_idx)
+                         all_masks.append(legal_mask)
+                     except Exception:
+                         continue
+
+             return {'boards': all_boards, 'actions': all_actions, 'masks': all_masks}
+
+         processed = dataset.map(
+             process_batch,
+             batched=True,
+             batch_size=1000,
+             num_proc=num_workers,
+             remove_columns=dataset.column_names,
+             desc="Processing"
+         )
+
+         # Convert to tensors (HF map flattens the lists)
+         print("Converting to tensors...")
+         for i in tqdm(range(len(processed)), desc="Tensorizing"):
+             board_tensor = torch.tensor(processed[i]['boards'], dtype=torch.long)
+             legal_mask = torch.tensor(processed[i]['masks'], dtype=torch.bool)
+             self._samples.append((board_tensor, processed[i]['actions'], legal_mask))
+             if self.config.max_samples and len(self._samples) >= self.config.max_samples:
+                 break
+
+         print(f"Done: {len(self._samples):,} samples")
+
+         # Save processed samples to cache
+         if self.config.cache_path:
+             cache_file = self._get_cache_filename()
+             print(f"Saving processed samples to {cache_file}...")
+             os.makedirs(self.config.cache_path, exist_ok=True)
+             torch.save(self._samples, cache_file)
+             print("Saved to cache")
+
+     def _get_cache_filename(self) -> str:
+         """Generate cache filename based on config."""
+         split = 'eval' if self.config.is_eval else 'train'
+         max_samples = self.config.max_samples or 'all'
+         return f"{self.config.cache_path}/processed_elo{self.config.min_elo}_{split}_{max_samples}.pt"
+
+     def _load_filtered_dataset(self):
+         """Download and filter dataset."""
+         # Download (uses cache_path for HuggingFace cache)
+         print("Downloading angeluriot/chess_games (7.3GB)...")
+         cache_dir = self.config.cache_path if self.config.cache_path else None
+         dataset = load_dataset("angeluriot/chess_games", split="train", cache_dir=cache_dir)
+         print(f"Loaded {len(dataset):,} games")
+
+         # Fast batched filtering
+         print(f"Filtering games (min_elo={self.config.min_elo})...")
+         min_elo = self.config.min_elo
+         eval_frac = self.config.eval_fraction
+         is_eval = self.config.is_eval
+
+         def batch_filter(batch):
+             """Filter a batch of games - much faster than per-example."""
+             keep = []
+             for i in range(len(batch['white_elo'])):
+                 white_elo = batch['white_elo'][i]
+                 black_elo = batch['black_elo'][i]
+
+                 # Skip if ELO is missing
+                 if white_elo is None or black_elo is None:
+                     keep.append(False)
+                     continue
+                 # ELO filter
+                 if white_elo < min_elo or black_elo < min_elo:
+                     keep.append(False)
+                     continue
+                 # Moves filter
+                 if len(batch['moves_uci'][i]) < 10:
+                     keep.append(False)
+                     continue
+                 # Hash-based train/eval split. Use a stable digest rather than
+                 # the built-in hash(), which is salted per process and would
+                 # change the split between runs and workers.
+                 game_id = f"{batch['date'][i]}-{white_elo}-{black_elo}"
+                 hash_val = int(hashlib.md5(game_id.encode()).hexdigest(), 16) % 10000
+                 is_eval_game = hash_val < (eval_frac * 10000)
+                 if is_eval_game != is_eval:
+                     keep.append(False)
+                     continue
+                 keep.append(True)
+             return keep
+
+         dataset = dataset.filter(batch_filter, batched=True, batch_size=10000, desc="Filtering")
+         print(f"After filtering: {len(dataset):,} games")
+
+         return dataset
+
+     def _process_game(self, game: dict):
+         """Process a single game and yield training samples."""
+         moves = game.get('moves_uci', [])
+
+         positions = get_positions_from_game(
+             moves,
+             skip_first_n=self.config.skip_first_n_moves,
+             skip_last_n=self.config.skip_last_n_moves,
+             sample_n=self.config.sample_positions_per_game,
+         )
+
+         for fen, uci_move, _ in positions:
+             action_idx = uci_to_action(uci_move)
+             if action_idx is None:
+                 continue
+
+             try:
+                 token_ids = list(tokenize(fen))
+                 board_tensor = torch.tensor(token_ids, dtype=torch.long)
+             except Exception:
+                 continue
+
+             try:
+                 board = chess.Board(fen)
+                 legal_mask = torch.zeros(self._action_space_size, dtype=torch.bool)
+                 for move in board.legal_moves:
+                     move_idx = MOVE_TO_ACTION.get(move.uci())
+                     if move_idx is not None:
+                         legal_mask[move_idx] = True
+             except Exception:
+                 continue
+
+             if not legal_mask[action_idx]:
+                 continue
+
+             yield board_tensor, action_idx, legal_mask
+
+     def __len__(self) -> int:
+         return len(self._samples)
+
+     def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, torch.Tensor]:
+         return self._samples[idx]
+
+
+ def collate_pretrain_batch(
+     batch: list[tuple[torch.Tensor, int, torch.Tensor]]
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """Collate function for DataLoader.
+
+     Returns:
+         Tuple of (boards [B, 77], actions [B], legal_masks [B, num_actions])
+     """
+     boards, actions, masks = zip(*batch)
+
+     boards = torch.stack(boards)
+     actions = torch.tensor(actions, dtype=torch.long)
+     masks = torch.stack(masks)
+
+     return boards, actions, masks
hf_space_repo/pretrain/pretrain_load_config.py ADDED
@@ -0,0 +1,21 @@
+ """Pretrain load configuration - separated to avoid circular imports."""
+
+ import torch
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class PretrainLoadConfig:
+     """Configuration for loading pretrained weights.
+
+     Attributes:
+         checkpoint_path: Path to pretrained checkpoint file
+         freeze_layers: Number of transformer layers to freeze (0 = train all)
+     """
+     checkpoint_path: Optional[str] = None
+     freeze_layers: int = 0
+
+
+ # Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
+ torch.serialization.add_safe_globals([PretrainLoadConfig])
hf_space_repo/searchless_chess_imports.py ADDED
@@ -0,0 +1,3 @@
+ from src.searchless_chess_model.searchless_chess_code.utils import ACTION_TO_MOVE, MOVE_TO_ACTION
+ from src.searchless_chess_model.searchless_chess_code.tokenizer import tokenize, SEQUENCE_LENGTH
+
hf_space_repo/searchless_chess_model/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
hf_space_repo/searchless_chess_model/README.md ADDED
@@ -0,0 +1,177 @@
+ ---
+ license: apache-2.0
+ tags:
+ - chess
+ - reinforcement-learning
+ - jax
+ - transformer
+ language:
+ - en
+ library_name: jax
+ ---
+
+ # Searchless Chess 9M Self-Play
+
+ A 9-million parameter transformer-based chess engine trained via self-play with Stockfish evaluation. This model learns to play chess without explicit search during inference, relying purely on learned pattern recognition.
+
+ ## Model Description
+
+ - **Model Size**: 9M parameters (8 layers, 256 embedding dim, 8 attention heads)
+ - **Architecture**: Decoder-only Transformer with learned positional encodings
+ - **Training Method**: Self-play with Stockfish rewards
+ - **Framework**: JAX + Haiku
+ - **Q-Value Distribution**: 128 return buckets for action-value prediction
+
+ This model predicts action-values (Q-values) for chess positions without performing tree search, making it extremely fast for inference while maintaining strong play.
+
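+ The value head emits a distribution over 128 return buckets; the scalar Q-value is its expectation. A minimal numpy sketch of that readout (uniform bucket centers over [0, 1] are an illustrative assumption):
+
+ ```python
+ import numpy as np
+
+ num_buckets = 128
+ bucket_values = (np.arange(num_buckets) + 0.5) / num_buckets  # assumed centers
+
+ # Dummy stand-in for the model's per-bucket log-probabilities.
+ bucket_log_probs = np.log(np.full(num_buckets, 1.0 / num_buckets))
+ q_value = float(np.exp(bucket_log_probs) @ bucket_values)  # expectation, ~0.5 here
+ ```
+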
+ ## Installation
+
+ ### CPU Installation
+
+ Install the required dependencies for CPU inference:
+
+ ```bash
+ pip install jax jaxlib dm-haiku orbax-checkpoint numpy chess huggingface-hub jaxtyping apache-beam grain
+ ```
+
+ ### GPU Installation (Recommended)
+
+ For GPU acceleration with CUDA 12:
+
+ ```bash
+ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ pip install dm-haiku orbax-checkpoint numpy chess huggingface-hub jaxtyping apache-beam grain
+ ```
+
+ For other CUDA versions, see the [JAX installation guide](https://github.com/google/jax#installation).
+
+ **Note**: This model includes all necessary code and can be used **without cloning the original repository**.
+
+ ## Quick Start
+
+ ```python
+ import sys
+ from huggingface_hub import snapshot_download
+
+ # Download model from HuggingFace Hub
+ model_path = snapshot_download(
+     repo_id="dbest-isi/searchless-chess-9M-selfplay",
+     local_dir="./searchless_chess_model"
+ )
+
+ # Add bundled code to Python path
+ sys.path.insert(0, f"{model_path}/searchless_chess_code")
+
+ # Import model wrapper
+ import hf_model
+
+ # Load the model
+ model = hf_model.SearchlessChessModel.from_pretrained(model_path)
+
+ # Make a prediction
+ fen = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
+ result = model.predict(fen, temperature=1.0)
+
+ print(f"Best move: {result['best_move']}")
+ print(f"Q-value: {result['q_value']:.4f}")
+ print(f"Action probabilities shape: {result['action_probs'].shape}")
+ ```
+
+ ## Example Output
+
+ ```
+ Best move: e7e5
+ Q-value: 0.0119
+ Action probabilities shape: (1968,)
+ ```
+
+ ## Full Example with Multiple Positions
+
+ ```python
+ import sys
+ from huggingface_hub import snapshot_download
+
+ # Download and setup
+ model_path = snapshot_download(
+     repo_id="dbest-isi/searchless-chess-9M-selfplay",
+     local_dir="./searchless_chess_model"
+ )
+ sys.path.insert(0, f"{model_path}/searchless_chess_code")
+
+ import hf_model
+
+ # Load model
+ print("Loading model...")
+ model = hf_model.SearchlessChessModel.from_pretrained(model_path)
+ print("Model loaded!")
+
+ # Test on multiple positions
+ positions = [
+     ("Starting position", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"),
+     ("After 1.e4", "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"),
+     ("Scandinavian Defense", "rnbqkbnr/ppp1pppp/8/3p4/4P3/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 2"),
+ ]
+
+ for name, fen in positions:
+     result = model.predict(fen)
+     print(f"\n{name}")
+     print(f"  FEN: {fen}")
+     print(f"  Best move: {result['best_move']}")
+     print(f"  Q-value: {result['q_value']:.4f}")
+ ```
+
+ ## Model Architecture
+
+ ```python
+ TransformerConfig(
+     vocab_size=1968,
+     output_size=128,
+     embedding_dim=256,
+     num_layers=8,
+     num_heads=8,
+     max_sequence_length=79,
+     num_return_buckets=128,
+     pos_encodings="LEARNED",
+     apply_post_ln=True,
+     apply_qk_layernorm=False,
+     use_causal_mask=False,
+ )
+ ```
+
+ ## Training Details
+
+ - **Base Model**: Initialized from pretrained 9M checkpoint
+ - **Training Method**: Self-play reinforcement learning
+ - **Reward Signal**: Stockfish evaluation at depth 20
+ - **Iteration**: 22 (EMA parameters)
+ - **Action Space**: 1968 possible moves (all legal chess moves)
+ - **Value Representation**: Discretized into 128 buckets (see the sketch after this list)
+
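+ Going the other way, discretizing a scalar return into one of 128 uniform buckets can be sketched as below (a simplified stand-in for `utils.get_uniform_buckets_edges_values`, assuming values in [0, 1]):
+
+ ```python
+ import numpy as np
+
+ def uniform_buckets(num_buckets: int = 128):
+     """Inner edges and centers of uniform buckets over a [0, 1] win probability."""
+     edges = np.linspace(0.0, 1.0, num_buckets + 1)
+     values = (edges[:-1] + edges[1:]) / 2  # bucket centers
+     return edges[1:-1], values             # inner edges, ready for np.digitize
+
+ edges, values = uniform_buckets()
+ bucket = int(np.digitize(0.73, edges))     # bucket index of a 0.73 win prob
+ ```
+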
+ ## Use Cases
+
+ - Fast chess move prediction without search
+ - Chess position evaluation
+ - Research on learned planning in board games
+ - Integration into chess applications requiring low-latency move suggestions
+
+ ## Limitations
+
+ - Does not perform explicit search (unlike traditional chess engines)
+ - May make suboptimal moves in complex tactical positions
+ - Performance depends on training data distribution
+ - Best suited for fast move suggestions rather than deep analysis
+
+ ## Background
+
+ This model is based on the architecture from DeepMind's [Searchless Chess](https://github.com/google-deepmind/searchless_chess) work. The **self-play training implementation and this trained model** are original work by Darrell Best.
+
+ For the full self-play training implementation and codebase, visit:
+ - Repository: https://github.com/DarrellBest/searchless_chess
+
+ ## License
+
+ Apache 2.0
+
+ ## Model Card Contact
+
+ For questions or issues, please open an issue on the [GitHub repository](https://github.com/DarrellBest/searchless_chess).
hf_space_repo/searchless_chess_model/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "vocab_size": 1968,
+   "output_size": 128,
+   "embedding_dim": 256,
+   "num_layers": 8,
+   "num_heads": 8,
+   "max_sequence_length": 79,
+   "num_return_buckets": 128,
+   "model_name": "9M"
+ }
hf_space_repo/searchless_chess_model/model_info.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "model_type": "searchless_chess",
+   "framework": "jax",
+   "library": "dm-haiku",
+   "includes_source": true,
+   "source_modules": [
+     "tokenizer.py",
+     "transformer.py",
+     "constants.py",
+     "utils.py",
+     "config.py"
+   ]
+ }
hf_space_repo/searchless_chess_model/searchless_chess_code/__init__.py ADDED
@@ -0,0 +1 @@
+ # Searchless Chess code bundle
hf_space_repo/searchless_chess_model/searchless_chess_code/config.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright 2025 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Defines the configuration dataclasses."""
+
+ import dataclasses
+ from typing import Literal
+
+
+ PolicyType = Literal['action_value', 'state_value', 'behavioral_cloning']
+ POLICY_TYPES = ['action_value', 'state_value', 'behavioral_cloning']
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class DataConfig:
+   """Config for the data generation."""
+
+   # The batch size for the sequences.
+   batch_size: int
+   # Whether to shuffle the dataset (shuffling is applied per epoch).
+   shuffle: bool = False
+   # The seed used for shuffling and transformations of the data.
+   seed: int | None = 0
+   # Whether to drop partial batches.
+   drop_remainder: bool = False
+   # The number of child processes launched to parallelize the transformations.
+   worker_count: int | None = 0
+   # The number of return buckets.
+   num_return_buckets: int
+   # The dataset split.
+   split: Literal['train', 'test']
+   # The policy used to create the dataset.
+   policy: PolicyType
+   # The number of records to read from the dataset (can be useful when, e.g.,
+   # the dataset does not fit into memory).
+   num_records: int | None = None
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class TrainConfig:
+   """Config for the training function."""
+
+   # The data configuration for training.
+   data: DataConfig
+   # The learning rate for Adam.
+   learning_rate: float
+   # The gradient clipping value.
+   max_grad_norm: float = 1.0
+   # The number of gradient steps.
+   num_steps: int
+   # The frequency (in gradient steps) at which checkpoints should be saved
+   # (`None` means there is no checkpointing).
+   ckpt_frequency: int | None = None
+   # If provided, the maximum number of checkpoints to keep.
+   ckpt_max_to_keep: int | None = 1
+   # The frequency (in gradient steps) at which checkpoints should be saved
+   # permanently (`None` means all checkpoints are temporary).
+   save_frequency: int | None = None
+   # The frequency of logging in gradient steps (`None` means no logging).
+   log_frequency: int | None = None
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class EvalConfig:
+   """Config for the evaluator."""
+
+   # The data configuration for evaluation.
+   data: DataConfig
+   # How many data points to consider for evaluation.
+   num_eval_data: int | None = None
+   # Enables use of ema-ed params in eval.
+   use_ema_params: bool = False
+   # The policy used to play moves with the model.
+   policy: PolicyType
+   # The number of return buckets.
+   num_return_buckets: int
+   # The batch size for evaluation.
+   batch_size: int | None = None
hf_space_repo/searchless_chess_model/searchless_chess_code/constants.py ADDED
@@ -0,0 +1,119 @@
+ # Copyright 2025 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Constants, interfaces, and types."""
+
+ import abc
+ from collections.abc import Callable, Mapping
+ import dataclasses
+ from typing import Any, NamedTuple, Protocol
+
+ from apache_beam import coders
+ from grain import python as pygrain
+ import haiku as hk
+ import jaxtyping as jtp
+
+ import config as config_lib
+
+
+ # Integer sequences of token ids.
+ Sequences = jtp.UInt32[jtp.Array, 'B T']
+
+ # The predictions are log-probabilities (natural logarithm) for the passed
+ # sequences. It can either be marginal log-probabilities (i.e. log P(s) for all
+ # sequences s in the batch), or full conditionals (i.e. log P(token | s_<t) for
+ # all sequence s, time t and token in the alphabet).
+ Marginals = jtp.Float32[jtp.Array, '*B']
+ Conditionals = jtp.Float32[jtp.Array, '*B T F']
+ Predictions = Marginals | Conditionals
+
+ # True means the loss will be masked there, i.e. we ignore it.
+ LossMask = jtp.Bool[jtp.Array, 'B T']
+
+
+ @dataclasses.dataclass
+ class Predictor:
+   """Defines the predictor interface."""
+
+   initial_params: Callable[..., hk.MutableParams]
+   predict: Callable[..., Predictions]
+
+
+ class DataLoaderBuilder(Protocol):
+
+   def __call__(self, config: config_lib.DataConfig) -> pygrain.DataLoader:
+     """Returns a PyGrain data loader from the `config`."""
+
+
+ class Evaluator(abc.ABC):
+   """Defines the interface of the evaluator that evaluates a predictor."""
+
+   @abc.abstractmethod
+   def step(self, params: hk.Params, step: int) -> Mapping[str, Any]:
+     """Returns the results of evaluating the predictor with `params`."""
+
+
+ class EvaluatorBuilder(Protocol):
+
+   def __call__(
+       self,
+       predictor: Predictor,
+       config: config_lib.EvalConfig,
+   ) -> Evaluator:
+     """Returns an evaluator for the `predictor` and `config`.
+
+     Args:
+       predictor: The predictor to be evaluated. The training loop continuously
+         saves the predictor's parameters, which are then loaded in the
+         evaluation loop and passed to the evaluator's step method.
+       config: The configuration of the evaluator.
+     """
+
+
+ CODERS = {
+     'fen': coders.StrUtf8Coder(),
+     'move': coders.StrUtf8Coder(),
+     'count': coders.BigIntegerCoder(),
+     'win_prob': coders.FloatCoder(),
+ }
+ CODERS['state_value'] = coders.TupleCoder((
+     CODERS['fen'],
+     CODERS['win_prob'],
+ ))
+ CODERS['action_value'] = coders.TupleCoder((
+     CODERS['fen'],
+     CODERS['move'],
+     CODERS['win_prob'],
+ ))
+ CODERS['behavioral_cloning'] = coders.TupleCoder((
+     CODERS['fen'],
+     CODERS['move'],
+ ))
+
+
+ class BehavioralCloningData(NamedTuple):
+   fen: str
+   move: str
+
+
+ class StateValueData(NamedTuple):
+   fen: str
+   win_prob: float
+
+
+ class ActionValueData(NamedTuple):
+   fen: str
+   move: str
+   win_prob: float
hf_space_repo/searchless_chess_model/searchless_chess_code/hf_model.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace model wrapper for searchless chess."""
2
+
3
+ import json
4
+ import os
5
+ from typing import Dict, Optional
6
+
7
+ import haiku as hk
8
+ import jax
9
+ import jax.numpy as jnp
10
+ import numpy as np
11
+ import orbax.checkpoint as ocp
12
+
13
+ import tokenizer
14
+ import transformer
15
+ import utils
16
+
17
+
18
+ class SearchlessChessConfig:
19
+ """Configuration for SearchlessChess model."""
20
+
21
+ def __init__(
22
+ self,
23
+ vocab_size: int = 1968,
24
+ output_size: int = 128,
25
+ embedding_dim: int = 256,
26
+ num_layers: int = 8,
27
+ num_heads: int = 8,
28
+ max_sequence_length: int = 79,
29
+ num_return_buckets: int = 128,
30
+ model_name: str = "9M",
31
+ **kwargs,
32
+ ):
33
+ self.vocab_size = vocab_size
34
+ self.output_size = output_size
35
+ self.embedding_dim = embedding_dim
36
+ self.num_layers = num_layers
37
+ self.num_heads = num_heads
38
+ self.max_sequence_length = max_sequence_length
39
+ self.num_return_buckets = num_return_buckets
40
+ self.model_name = model_name
41
+
42
+ # Store any extra kwargs
43
+ for key, value in kwargs.items():
44
+ setattr(self, key, value)
45
+
46
+ def to_dict(self) -> Dict:
47
+ """Convert config to dictionary."""
48
+ return {
49
+ "vocab_size": self.vocab_size,
50
+ "output_size": self.output_size,
51
+ "embedding_dim": self.embedding_dim,
52
+ "num_layers": self.num_layers,
53
+ "num_heads": self.num_heads,
54
+ "max_sequence_length": self.max_sequence_length,
55
+ "num_return_buckets": self.num_return_buckets,
56
+ "model_name": self.model_name,
57
+ }
58
+
59
+ @classmethod
60
+ def from_dict(cls, config_dict: Dict) -> "SearchlessChessConfig":
61
+ """Load config from dictionary."""
62
+ return cls(**config_dict)
63
+
64
+ def save_pretrained(self, save_directory: str):
65
+ """Save config to directory."""
66
+ os.makedirs(save_directory, exist_ok=True)
67
+ config_path = os.path.join(save_directory, "config.json")
68
+ with open(config_path, "w") as f:
69
+ json.dump(self.to_dict(), f, indent=2)
70
+
71
+ @classmethod
72
+ def from_pretrained(cls, model_path: str) -> "SearchlessChessConfig":
73
+ """Load config from directory."""
74
+ config_path = os.path.join(model_path, "config.json")
75
+ with open(config_path, "r") as f:
76
+ config_dict = json.load(f)
77
+ return cls.from_dict(config_dict)
78
+
79
+
80
+ class SearchlessChessModel:
81
+ """HuggingFace-compatible wrapper for SearchlessChess JAX/Haiku model."""
82
+
83
+ def __init__(self, config: SearchlessChessConfig):
84
+ self.config = config
85
+
86
+ # Build transformer config
87
+ self.transformer_config = transformer.TransformerConfig(
88
+ vocab_size=config.vocab_size,
89
+ output_size=config.output_size,
90
+ pos_encodings=transformer.PositionalEncodings.LEARNED,
91
+ max_sequence_length=config.max_sequence_length,
92
+ num_heads=config.num_heads,
93
+ num_layers=config.num_layers,
94
+ embedding_dim=config.embedding_dim,
95
+ apply_post_ln=True,
96
+ apply_qk_layernorm=False,
97
+ use_causal_mask=False,
98
+ )
99
+
100
+ # Build predictor
101
+ self.predictor = transformer.build_transformer_predictor(self.transformer_config)
102
+
103
+ # Initialize params
104
+ self.params = None
105
+ self.return_buckets_values = None
106
+
107
+ # Get return bucket values
108
+ _, self.return_buckets_values = utils.get_uniform_buckets_edges_values(
109
+ config.num_return_buckets
110
+ )
111
+
112
+ def load_params(self, params_path: str):
113
+ """Load parameters from Orbax checkpoint."""
114
+ # Convert to absolute path (Orbax requires absolute paths)
115
+ params_path = os.path.abspath(params_path)
116
+
117
+ # Create dummy params for structure
118
+ dummy_params = self.predictor.initial_params(
119
+ rng=jax.random.PRNGKey(0),
120
+ targets=np.ones((1, 1), dtype=np.uint32),
121
+ )
122
+
123
+ # Load checkpoint
124
+ restore_args = ocp.checkpoint_utils.construct_restore_args(dummy_params)
125
+ checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
126
+ self.params = checkpointer.restore(params_path, restore_args=restore_args)
127
+
128
+ def predict(self, fen: str, temperature: float = 1.0) -> Dict:
129
+ """Predict move from FEN position.
130
+
131
+ Args:
132
+ fen: Chess position in FEN notation
133
+ temperature: Temperature for sampling (1.0 = no modification)
134
+
135
+ Returns:
136
+ Dictionary with:
137
+ - q_values: Q-value distribution
138
+ - action_probs: Action probabilities
139
+ - best_action: Best action index
140
+ - best_move: Best move in UCI notation
141
+ """
142
+ if self.params is None:
143
+ raise ValueError("Model parameters not loaded. Call load_params() first.")
144
+
145
+ # Tokenize input
146
+ tokens = tokenizer.tokenize(fen)
147
+ tokens = tokens[None, :] # Add batch dimension
148
+
149
+ # Get predictions
150
+ bucket_log_probs = self.predictor.predict(
151
+ params=self.params,
152
+ targets=tokens,
153
+ rng=None,
154
+ )
155
+
156
+ # Extract action Q-values (second to last position)
157
+ action_bucket_log_probs = bucket_log_probs[0, -2] # [num_return_buckets]
158
+ action_bucket_probs = jnp.exp(action_bucket_log_probs)
159
+
160
+ # Expected return: probability-weighted average of the bucket values
161
+ q_value = float(jnp.dot(action_bucket_probs, self.return_buckets_values))
162
+
163
+ # Get action probabilities from Q-values
164
+ # Use softmax over return bucket expectations
165
+ action_values = jnp.dot(
166
+ jnp.exp(bucket_log_probs[0, -2:]),
167
+ self.return_buckets_values,
168
+ )
169
+
170
+ # Apply temperature and softmax
171
+ action_logits = action_values / temperature
172
+ action_probs = jax.nn.softmax(action_logits)
173
+
174
+ # Get best action
175
+ best_action = int(jnp.argmax(action_probs))
176
+
177
+ # Convert action to move
178
+ best_move = utils.ACTION_TO_MOVE.get(best_action, "unknown")
179
+
180
+ return {
181
+ "q_value": q_value,
182
+ "action_probs": np.array(action_probs),
183
+ "best_action": best_action,
184
+ "best_move": best_move,
185
+ }
186
+
187
+ def save_pretrained(self, save_directory: str):
188
+ """Save model to directory in HuggingFace format."""
189
+ os.makedirs(save_directory, exist_ok=True)
190
+
191
+ # Save config
192
+ self.config.save_pretrained(save_directory)
193
+
194
+ # Save parameters as numpy arrays
195
+ if self.params is not None:
196
+ params_cpu = jax.device_get(self.params)
197
+ params_flat, tree_def = jax.tree.flatten(params_cpu)
198
+
199
+ # Save flattened params
200
+ params_path = os.path.join(save_directory, "params.npz")
201
+ np.savez(params_path, *params_flat)
202
+
203
+ # Save tree structure
204
+ import pickle
205
+ tree_path = os.path.join(save_directory, "tree_structure.pkl")
206
+ with open(tree_path, "wb") as f:
207
+ pickle.dump(tree_def, f)
208
+
209
+ # Copy necessary source files for standalone usage
211
+ src_dir = os.path.dirname(__file__)
212
+ code_dir = os.path.join(save_directory, "searchless_chess_code")
213
+ os.makedirs(code_dir, exist_ok=True)
214
+
215
+ # Copy core modules and fix imports for standalone usage
216
+ def fix_imports(content):
217
+ """Replace absolute imports with relative imports."""
218
+ content = content.replace("import tokenizer", "import tokenizer")
219
+ content = content.replace("import transformer", "import transformer")
220
+ content = content.replace("import utils", "import utils")
221
+ content = content.replace("import constants", "import constants")
222
+ content = content.replace("import config as config_lib", "import config as config_lib")
223
+ content = content.replace("import config", "import config")
224
+ return content
225
+
226
+ for module in ["tokenizer.py", "transformer.py", "constants.py", "utils.py", "config.py"]:
227
+ src_file = os.path.join(src_dir, module)
228
+ dst_file = os.path.join(code_dir, module)
229
+ if os.path.exists(src_file):
230
+ with open(src_file, 'r') as f:
231
+ content = fix_imports(f.read())
232
+ with open(dst_file, 'w') as f:
233
+ f.write(content)
234
+
235
+ # Create standalone hf_model.py
236
+ standalone_hf_model = os.path.join(code_dir, "hf_model.py")
237
+ with open(__file__, 'r') as source:
238
+ content = fix_imports(source.read())
239
+ with open(standalone_hf_model, 'w') as dest:
240
+ dest.write(content)
241
+
242
+ # Create __init__.py
243
+ with open(os.path.join(code_dir, "__init__.py"), "w") as f:
244
+ f.write("# Searchless Chess code bundle\n")
245
+
246
+ # Save model info
247
+ model_info = {
248
+ "model_type": "searchless_chess",
249
+ "framework": "jax",
250
+ "library": "dm-haiku",
251
+ "includes_source": True,
252
+ "source_modules": ["tokenizer.py", "transformer.py", "constants.py", "utils.py", "config.py"],
253
+ }
254
+ with open(os.path.join(save_directory, "model_info.json"), "w") as f:
255
+ json.dump(model_info, f, indent=2)
256
+
257
+ @classmethod
258
+ def from_pretrained(cls, model_path: str) -> "SearchlessChessModel":
259
+ """Load model from directory."""
260
+ # Load config
261
+ config = SearchlessChessConfig.from_pretrained(model_path)
262
+
263
+ # Create model
264
+ model = cls(config)
265
+
266
+ # Load parameters
267
+ params_path = os.path.join(model_path, "params.npz")
268
+ tree_path = os.path.join(model_path, "tree_structure.pkl")
269
+
270
+ if os.path.exists(params_path) and os.path.exists(tree_path):
271
+ # Load tree structure
272
+ import pickle
273
+ with open(tree_path, "rb") as f:
274
+ tree_def = pickle.load(f)
275
+
276
+ # Load params
277
+ params_data = np.load(params_path)
278
+ params_flat = [params_data[f"arr_{i}"] for i in range(len(params_data.files))]
279
+
280
+ # Reconstruct pytree
281
+ model.params = jax.tree.unflatten(tree_def, params_flat)
282
+
283
+ return model
284
+
285
+
286
+ def create_model_from_checkpoint(
287
+ checkpoint_path: str,
288
+ model_name: str = "9M",
289
+ use_ema: bool = True,
290
+ ) -> SearchlessChessModel:
291
+ """Create HuggingFace model from existing checkpoint.
292
+
293
+ Args:
294
+ checkpoint_path: Path to checkpoint directory (e.g., checkpoints/9M_selfplay/4)
295
+ model_name: Model size (9M, 136M, 270M)
296
+ use_ema: Whether to load EMA parameters
297
+
298
+ Returns:
299
+ SearchlessChessModel ready to save or use
300
+ """
301
+ # Determine architecture from model name
302
+ if model_name == "9M":
303
+ num_layers, embedding_dim, num_heads = 8, 256, 8
304
+ elif model_name == "136M":
305
+ num_layers, embedding_dim, num_heads = 8, 1024, 8
306
+ else: # 270M
307
+ num_layers, embedding_dim, num_heads = 16, 1024, 8
308
+
309
+ # Create config
310
+ config = SearchlessChessConfig(
311
+ vocab_size=1968,
312
+ output_size=128,
313
+ embedding_dim=embedding_dim,
314
+ num_layers=num_layers,
315
+ num_heads=num_heads,
316
+ max_sequence_length=79,
317
+ num_return_buckets=128,
318
+ model_name=model_name,
319
+ )
320
+
321
+ # Create model
322
+ model = SearchlessChessModel(config)
323
+
324
+ # Load parameters from Orbax checkpoint
325
+ params_dir = "params_ema" if use_ema else "params"
326
+ params_path = os.path.join(checkpoint_path, params_dir)
327
+ model.load_params(params_path)
328
+
329
+ return model
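Putting the pieces together, a minimal usage sketch (paths are illustrative; this assumes the file is importable as hf_model):

# Convert an Orbax checkpoint into the self-contained HF-style layout,
# then reload it and query a position.
from hf_model import SearchlessChessModel, create_model_from_checkpoint

model = create_model_from_checkpoint("checkpoints/9M_selfplay/4", model_name="9M")
model.save_pretrained("searchless_chess_model")

model = SearchlessChessModel.from_pretrained("searchless_chess_model")
out = model.predict("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
print(out["best_move"], out["q_value"])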
hf_space_repo/searchless_chess_model/searchless_chess_code/tokenizer.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2025 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements tokenization of FEN strings."""
17
+
18
+ import jaxtyping as jtp
19
+ import numpy as np
20
+
21
+
22
+ # pyfmt: disable
23
+ _CHARACTERS = [
24
+ '0',
25
+ '1',
26
+ '2',
27
+ '3',
28
+ '4',
29
+ '5',
30
+ '6',
31
+ '7',
32
+ '8',
33
+ '9',
34
+ 'a',
35
+ 'b',
36
+ 'c',
37
+ 'd',
38
+ 'e',
39
+ 'f',
40
+ 'g',
41
+ 'h',
42
+ 'p',
43
+ 'n',
44
+ 'r',
45
+ 'k',
46
+ 'q',
47
+ 'P',
48
+ 'B',
49
+ 'N',
50
+ 'R',
51
+ 'Q',
52
+ 'K',
53
+ 'w',
54
+ '.',
55
+ ]
56
+ # pyfmt: enable
57
+ _CHARACTERS_INDEX = {letter: index for index, letter in enumerate(_CHARACTERS)}
58
+ _SPACES_CHARACTERS = frozenset({'1', '2', '3', '4', '5', '6', '7', '8'})
59
+ SEQUENCE_LENGTH = 77
60
+
61
+
62
+ def tokenize(fen: str) -> jtp.Int32[jtp.Array, 'T']:
63
+ """Returns an array of tokens from a fen string.
64
+
65
+ We compute a tokenized representation of the board, from the FEN string.
66
+ The final array of tokens is a mapping from this string to numbers, which
67
+ are defined in the dictionary `_CHARACTERS_INDEX`.
68
+ For the 'en passant' information, we convert the '-' (which means there is
69
+ no en passant relevant square) to '..', to always have two characters, and
70
+ a fixed length output.
71
+
72
+ Args:
73
+ fen: The board position in Forsyth-Edwards Notation.
74
+ """
75
+ # Extracting the relevant information from the FEN.
76
+ board, side, castling, en_passant, halfmoves_last, fullmoves = fen.split(' ')
77
+ board = board.replace('/', '')
78
+ board = side + board
79
+
80
+ indices = list()
81
+
82
+ for char in board:
83
+ if char in _SPACES_CHARACTERS:
84
+ indices.extend(int(char) * [_CHARACTERS_INDEX['.']])
85
+ else:
86
+ indices.append(_CHARACTERS_INDEX[char])
87
+
88
+ if castling == '-':
89
+ indices.extend(4 * [_CHARACTERS_INDEX['.']])
90
+ else:
91
+ for char in castling:
92
+ indices.append(_CHARACTERS_INDEX[char])
93
+ # Padding castling to have exactly 4 characters.
94
+ if len(castling) < 4:
95
+ indices.extend((4 - len(castling)) * [_CHARACTERS_INDEX['.']])
96
+
97
+ if en_passant == '-':
98
+ indices.extend(2 * [_CHARACTERS_INDEX['.']])
99
+ else:
100
+ # En passant is a square like 'e3'.
101
+ for char in en_passant:
102
+ indices.append(_CHARACTERS_INDEX[char])
103
+
104
+ # Three digits for halfmoves (since last capture) is enough since the game
105
+ # ends at 50.
106
+ halfmoves_last += '.' * (3 - len(halfmoves_last))
107
+ indices.extend([_CHARACTERS_INDEX[x] for x in halfmoves_last])
108
+
109
+ # Three digits for full moves is enough (no game lasts longer than 999
110
+ # moves).
111
+ fullmoves += '.' * (3 - len(fullmoves))
112
+ indices.extend([_CHARACTERS_INDEX[x] for x in fullmoves])
113
+
114
+ assert len(indices) == SEQUENCE_LENGTH
115
+
116
+ return np.asarray(indices, dtype=np.uint8)
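A quick illustration of the fixed-length encoding (a sketch; assumes the module is importable as tokenizer):

import tokenizer

# Starting position: 65 board+side tokens, 4 castling, 2 en passant,
# 3 halfmove digits, 3 fullmove digits = 77 tokens total.
tokens = tokenizer.tokenize("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
assert tokens.shape == (tokenizer.SEQUENCE_LENGTH,)  # (77,)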
hf_space_repo/searchless_chess_model/searchless_chess_code/transformer.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright 2025 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Transformer model."""
17
+
18
+ import dataclasses
19
+ import enum
20
+ import functools
21
+
22
+ import haiku as hk
23
+ import jax
24
+ import jax.nn as jnn
25
+ import jax.numpy as jnp
26
+ import numpy as np
27
+
28
+ import constants
29
+
30
+
31
+ class PositionalEncodings(enum.Enum):
32
+ SINUSOID = enum.auto()
33
+ LEARNED = enum.auto()
34
+
35
+
36
+ @dataclasses.dataclass(kw_only=True)
37
+ class TransformerConfig:
38
+ """Hyperparameters used in the Transformer architectures."""
39
+
40
+ # The random seed for parameter initialization.
41
+ seed: int = 1
42
+ # The input vocabulary size.
43
+ vocab_size: int
44
+ # The output size (by default equal to the vocabulary size).
45
+ output_size: int | None = None
46
+ # The dimension of the first embedding.
47
+ embedding_dim: int = 64
48
+ # The number of multi-head attention layers.
49
+ num_layers: int = 4
50
+ # The number of heads per layer.
51
+ num_heads: int = 8
52
+ # Whether to use a causal mask or not.
53
+ use_causal_mask: bool = True
54
+ # The parameter initialization scale for the embeddings.
55
+ emb_init_scale: float = 0.02
56
+ # Positional encodings to use.
57
+ pos_encodings: PositionalEncodings = PositionalEncodings.SINUSOID
58
+ # Maximum sequence length, useful for the LEARNED positional encodings.
59
+ max_sequence_length: int | None = None
60
+ # How much larger the hidden layer of the feedforward network should be
61
+ # compared to the `embedding_dim`.
62
+ widening_factor: int = 4
63
+ # Whether to apply QK normalization trick in attention layer.
64
+ apply_qk_layernorm: bool = False
65
+ # Whether to apply post LN after attention + MLP blocks
66
+ apply_post_ln: bool = True
67
+
68
+ def __post_init__(self):
69
+ if self.output_size is None:
70
+ self.output_size = self.vocab_size
71
+
72
+
73
+ class MultiHeadDotProductAttention(hk.Module):
74
+ """Multi-head dot-product attention (Vaswani et al., 2017)."""
75
+
76
+ def __init__(
77
+ self,
78
+ num_heads: int,
79
+ num_hiddens_per_head: int,
80
+ name: str | None = None,
81
+ apply_qk_layernorm: bool = False,
82
+ ) -> None:
83
+ """Initializes the attention module.
84
+
85
+ Args:
86
+ num_heads: Number of heads to use.
87
+ num_hiddens_per_head: Number of hidden neurons per head.
88
+ name: Name of the module.
89
+ apply_qk_layernorm: Applies layernorm to query and key matrices, this
90
+ helps training stability.
91
+ """
92
+ super().__init__(name=name)
93
+ self._num_heads = num_heads
94
+ self._num_hiddens_per_head = num_hiddens_per_head
95
+ self._apply_qk_layernorm = apply_qk_layernorm
96
+
97
+ def __call__(
98
+ self,
99
+ inputs_q: jax.Array,
100
+ inputs_kv: jax.Array,
101
+ mask: jax.Array | None = None,
102
+ ) -> jax.Array:
103
+ """Returns the output of the multi-head attention."""
104
+ batch_size, sequence_length, embedding_size = inputs_q.shape
105
+
106
+ num_hiddens = self._num_hiddens_per_head * self._num_heads
107
+ q = hk.Linear(num_hiddens, with_bias=False)(inputs_q)
108
+ k = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)
109
+
110
+ if self._apply_qk_layernorm:
111
+ q = layer_norm(q)
112
+ k = layer_norm(k)
113
+
114
+ v = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)
115
+ # The second (sequence) dimension is undefined since it can differ between
116
+ # queries and keys/values when decoding. Also checking that the inputs have
117
+ # the same batch size as the reshape below does not guarantee a failure if
118
+ # they are different.
119
+ new_shape = (batch_size, -1, self._num_heads, self._num_hiddens_per_head)
120
+ q = jnp.reshape(q, new_shape)
121
+ k = jnp.reshape(k, new_shape)
122
+ v = jnp.reshape(v, new_shape)
123
+
124
+ # Let b=batch_size, t=seq_len, h=num_heads, and d=num_hiddens_per_head.
125
+ attention = jnp.einsum('bthd,bThd->bhtT', q, k)
126
+ attention *= 1.0 / jnp.sqrt(self._num_hiddens_per_head)
127
+
128
+ if mask is not None:
129
+ attention = jnp.where(mask, attention, jnp.finfo(jnp.float32).min)
130
+
131
+ normalized_attention = jnn.softmax(attention)
132
+
133
+ output = jnp.einsum('bhtT,bThd->bthd', normalized_attention, v)
134
+ output = jnp.reshape(output, (batch_size, sequence_length, num_hiddens))
135
+ return hk.Linear(embedding_size, with_bias=False)(output)
136
+
137
+
138
+ def sinusoid_position_encoding(
139
+ sequence_length: int,
140
+ hidden_size: int,
141
+ max_timescale: float = 1e4,
142
+ ) -> np.ndarray:
143
+ """Creates sinusoidal encodings from the original transformer paper.
144
+
145
+ The returned values are, for all i < D/2:
146
+ array[pos, i] = sin(pos / (max_timescale^(2*i / D)))
147
+ array[pos, D/2 + i] = cos(pos / (max_timescale^(2*i / D)))
148
+
149
+ Args:
150
+ sequence_length: Sequence length.
151
+ hidden_size: Dimension of the positional encoding vectors, D. Should be
152
+ even.
153
+ max_timescale: Maximum timescale for the frequency.
154
+
155
+ Returns:
156
+ An array of shape [L, D] containing the sinusoidal encodings.
158
+ """
159
+ freqs = np.arange(0, hidden_size + 1, 2)
160
+ inv_freq = max_timescale ** (-freqs / hidden_size)
161
+
162
+ pos_seq = np.arange(start=0, stop=sequence_length)
163
+
164
+ sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq)
165
+ embeddings = np.concatenate(
166
+ [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1
167
+ )
168
+ return embeddings[:, :hidden_size]
169
+
170
+
171
+ def embed_sequences(
172
+ sequences: jax.Array,
173
+ config: TransformerConfig,
174
+ ) -> jax.Array:
175
+ """Returns embeddings for sequences of tokens."""
176
+ embs_init = hk.initializers.TruncatedNormal(stddev=config.emb_init_scale)
177
+ embeddings_layer = hk.Embed(
178
+ vocab_size=config.vocab_size,
179
+ embed_dim=config.embedding_dim,
180
+ lookup_style=hk.EmbedLookupStyle.ARRAY_INDEX,
181
+ w_init=embs_init,
182
+ )
183
+ embeddings = embeddings_layer(sequences)
184
+ embeddings *= jnp.sqrt(config.embedding_dim)
185
+
186
+ _, sequence_length, embedding_size = embeddings.shape
187
+ match config.pos_encodings:
188
+ case PositionalEncodings.SINUSOID:
189
+ pos_encodings = sinusoid_position_encoding(
190
+ sequence_length=sequence_length,
191
+ hidden_size=embedding_size,
192
+ )
193
+ case PositionalEncodings.LEARNED:
194
+ assert sequence_length <= config.max_sequence_length
195
+ positions = jnp.arange(sequence_length)
196
+ pos_encodings = hk.Embed(
197
+ vocab_size=config.max_sequence_length,
198
+ embed_dim=embedding_size,
199
+ )(positions)
200
+ return embeddings + pos_encodings
201
+
202
+
203
+ def layer_norm(x: jax.Array) -> jax.Array:
204
+ """Helper function for layer norm."""
205
+ return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
206
+
207
+
208
+ def shift_right(sequences: jax.Array) -> jax.Array:
209
+ """Right-shift the one-hot encoded input by padding on the temporal axis."""
210
+ bos_array = jnp.zeros((sequences.shape[0], 1), dtype=jnp.uint8)
211
+ padded_sequences = jnp.concatenate([bos_array, sequences], axis=1)
212
+ return padded_sequences[:, :-1]
213
+
214
+
215
+ def _mlp_block(inputs: jax.Array, config: TransformerConfig) -> jax.Array:
216
+ """Gated MLP block for the Transformer."""
217
+ ffn_dim = config.embedding_dim * config.widening_factor
218
+ split_1 = hk.Linear(ffn_dim, with_bias=False)(inputs)
219
+ split_2 = hk.Linear(ffn_dim, with_bias=False)(inputs)
220
+ gate_output = jnn.silu(split_1) * split_2
221
+ return hk.Linear(config.embedding_dim, with_bias=False)(gate_output)
222
+
223
+
224
+ def _attention_block(inputs: jax.Array, config: TransformerConfig) -> jax.Array:
225
+ """Attention block for the Transformer."""
226
+ batch_size, sequence_length = inputs.shape[:2]
227
+ if config.use_causal_mask:
228
+ causal_mask = np.tril(
229
+ np.ones((batch_size, 1, sequence_length, sequence_length))
230
+ )
231
+ else:
232
+ causal_mask = None
233
+ block = MultiHeadDotProductAttention(
234
+ num_heads=config.num_heads,
235
+ num_hiddens_per_head=config.embedding_dim // config.num_heads,
236
+ apply_qk_layernorm=config.apply_qk_layernorm,
237
+ )
238
+ return block(inputs_q=inputs, inputs_kv=inputs, mask=causal_mask)
239
+
240
+
241
+ def transformer_decoder(
242
+ targets: jax.Array,
243
+ config: TransformerConfig,
244
+ ) -> jax.Array:
245
+ """Returns the transformer decoder output, shape [B, T, V].
246
+
247
+ Follows the LLaMa architecture:
248
+ https://github.com/facebookresearch/llama/blob/main/llama/model.py
249
+ Main changes to the original Transformer decoder:
250
+ - Using gating in the MLP block, with SwiGLU activation function.
251
+ - Using normalization before the attention and MLP blocks.
252
+
253
+ Args:
254
+ targets: The integer target values, shape [B, T].
255
+ config: The config to use for the transformer.
256
+ """
257
+ # Right shift the targets to get the inputs (the first token is now a 0).
258
+ inputs = shift_right(targets)
259
+
260
+ # Embeds the inputs and adds positional encodings.
261
+ embeddings = embed_sequences(inputs, config)
262
+
263
+ h = embeddings
264
+ for _ in range(config.num_layers):
265
+ attention_input = layer_norm(h)
266
+ attention = _attention_block(attention_input, config)
267
+ h += attention
268
+
269
+ mlp_input = layer_norm(h)
270
+ mlp_output = _mlp_block(mlp_input, config)
271
+ h += mlp_output
272
+
273
+ if config.apply_post_ln:
274
+ h = layer_norm(h)
275
+ logits = hk.Linear(config.output_size)(h)
276
+ return jnn.log_softmax(logits, axis=-1)
277
+
278
+
279
+ def build_transformer_predictor(
280
+ config: TransformerConfig,
281
+ ) -> constants.Predictor:
282
+ """Returns a transformer predictor."""
283
+ model = hk.transform(functools.partial(transformer_decoder, config=config))
284
+ return constants.Predictor(initial_params=model.init, predict=model.apply)
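A smoke-test sketch mirroring how hf_model.py builds and calls the predictor (the 9M hyperparameters below are the ones used in create_model_from_checkpoint):

import jax
import numpy as np
from transformer import PositionalEncodings, TransformerConfig, build_transformer_predictor

config = TransformerConfig(
    vocab_size=1968,
    output_size=128,
    pos_encodings=PositionalEncodings.LEARNED,
    max_sequence_length=79,
    num_heads=8,
    num_layers=8,
    embedding_dim=256,
    use_causal_mask=False,
)
predictor = build_transformer_predictor(config)
params = predictor.initial_params(
    rng=jax.random.PRNGKey(0),
    targets=np.ones((1, 1), dtype=np.uint32),
)
log_probs = predictor.predict(params=params, targets=np.ones((1, 79), dtype=np.uint32), rng=None)
# log_probs has shape [1, 79, 128]: per-position log-probabilities over return buckets.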
hf_space_repo/searchless_chess_model/searchless_chess_code/utils.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright 2025 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements some utility functions."""
17
+
18
+ import math
19
+
20
+ import chess
21
+ import numpy as np
22
+
23
+
24
+ # The lists of the strings of the row and columns of a chess board,
25
+ # traditionally named rank and file.
26
+ _CHESS_FILE = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
27
+
28
+
29
+ def _compute_all_possible_actions() -> tuple[dict[str, int], dict[int, str]]:
30
+ """Returns two dicts converting moves to actions and actions to moves.
31
+
32
+ These dicts contain all possible chess moves.
33
+ """
34
+ all_moves = []
35
+
36
+ # First, deal with the normal moves.
37
+ # Note that this includes castling, as it is just a rook or king move from one
38
+ # square to another.
39
+ board = chess.BaseBoard.empty()
40
+ for square in range(64):
41
+ next_squares = []
42
+
43
+ # Place the queen and see where it attacks (we don't need to cover the case
44
+ # for a bishop, rook, or pawn because the queen's moves includes all their
45
+ # squares).
46
+ board.set_piece_at(square, chess.Piece.from_symbol('Q'))
47
+ next_squares += board.attacks(square)
48
+
49
+ # Place knight and see where it attacks
50
+ board.set_piece_at(square, chess.Piece.from_symbol('N'))
51
+ next_squares += board.attacks(square)
52
+ board.remove_piece_at(square)
53
+
54
+ for next_square in next_squares:
55
+ all_moves.append(
56
+ chess.square_name(square) + chess.square_name(next_square)
57
+ )
58
+
59
+ # Then deal with promotions.
60
+ # Only look at the last ranks.
61
+ promotion_moves = []
62
+ for rank, next_rank in [('2', '1'), ('7', '8')]:
63
+ for index_file, file in enumerate(_CHESS_FILE):
64
+ # Normal promotions.
65
+ move = f'{file}{rank}{file}{next_rank}'
66
+ promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]
67
+
68
+ # Capture promotions.
69
+ # Left side.
70
+ if file > 'a':
71
+ next_file = _CHESS_FILE[index_file - 1]
72
+ move = f'{file}{rank}{next_file}{next_rank}'
73
+ promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]
74
+ # Right side.
75
+ if file < 'h':
76
+ next_file = _CHESS_FILE[index_file + 1]
77
+ move = f'{file}{rank}{next_file}{next_rank}'
78
+ promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]
79
+ all_moves += promotion_moves
80
+
81
+ move_to_action, action_to_move = {}, {}
82
+ for action, move in enumerate(all_moves):
83
+ assert move not in move_to_action
84
+ move_to_action[move] = action
85
+ action_to_move[action] = move
86
+
87
+ return move_to_action, action_to_move
88
+
89
+
90
+ MOVE_TO_ACTION, ACTION_TO_MOVE = _compute_all_possible_actions()
91
+ NUM_ACTIONS = len(MOVE_TO_ACTION)
92
+
93
+
94
+ def centipawns_to_win_probability(centipawns: int) -> float:
95
+ """Returns the win probability (in [0, 1]) converted from the centipawn score.
96
+
97
+ Reference: https://lichess.org/page/accuracy
98
+ Well-known transformation, backed by real-world data.
99
+
100
+ Args:
101
+ centipawns: The chess score in centipawns.
102
+ """
103
+ return 0.5 + 0.5 * (2 / (1 + math.exp(-0.00368208 * centipawns)) - 1)
104
+
105
+
106
+ def get_uniform_buckets_edges_values(
107
+ num_buckets: int,
108
+ ) -> tuple[np.ndarray, np.ndarray]:
109
+ """Returns edges and values of uniformly sampled buckets in [0, 1].
110
+
111
+ Example: for num_buckets=4, it returns:
112
+ edges=[0.25, 0.50, 0.75]
113
+ values=[0.125, 0.375, 0.625, 0.875]
114
+
115
+ Args:
116
+ num_buckets: Number of buckets to create.
117
+ """
118
+ full_linspace = np.linspace(0.0, 1.0, num_buckets + 1)
119
+ edges = full_linspace[1:-1]
120
+ values = (full_linspace[:-1] + full_linspace[1:]) / 2
121
+ return edges, values
122
+
123
+
124
+ def compute_return_buckets_from_returns(
125
+ returns: np.ndarray,
126
+ bins_edges: np.ndarray,
127
+ ) -> np.ndarray:
128
+ """Arranges the discounted returns into bins.
129
+
130
+ The returns are put into the bins specified by `bin_edges`. The length of
131
+ `bin_edges` is equal to the number of buckets minus 1. In case of a tie (if
132
+ the return is exactly equal to an edge), we take the bucket right before the
133
+ edge. See example below.
134
+ This function is purely using np.searchsorted, so it's a good reference to
135
+ look at.
136
+
137
+ Examples:
138
+ * bin_edges=[0.5] and returns=[0., 1.] gives the buckets [0, 1].
139
+ * bin_edges=[-30., 30.] and returns=[-200., -30., 0., 1.] gives the buckets
140
+ [0, 0, 1, 1].
141
+
142
+ Args:
143
+ returns: An array of discounted returns, rank 1.
144
+ bins_edges: The boundary values of the return buckets, rank 1.
145
+
146
+ Returns:
147
+ An array of buckets, described as integers, rank 1.
148
+
149
+ Raises:
150
+ ValueError if `returns` or `bins_edges` are not of rank 1.
151
+ """
152
+ if len(returns.shape) != 1:
153
+ raise ValueError(
154
+ 'The passed returns should be of rank 1. Got'
155
+ f' rank={len(returns.shape)}.'
156
+ )
157
+ if len(bins_edges.shape) != 1:
158
+ raise ValueError(
159
+ 'The passed bins_edges should be of rank 1. Got'
160
+ f' rank={len(bins_edges.shape)}.'
161
+ )
162
+ return np.searchsorted(bins_edges, returns, side='left')
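Sanity checks for the bucket utilities, matching the docstring examples above (a sketch, run within this module):

import numpy as np

edges, values = get_uniform_buckets_edges_values(4)
# edges == [0.25, 0.5, 0.75]; values == [0.125, 0.375, 0.625, 0.875]

buckets = compute_return_buckets_from_returns(
    returns=np.array([0.0, 1.0]), bins_edges=np.array([0.5])
)
# buckets == [0, 1]

assert centipawns_to_win_probability(0) == 0.5  # an even position is a coin flip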
hf_space_repo/train_self_play.py ADDED
@@ -0,0 +1,72 @@
1
+ """Training script for GRPO chess self-play."""
2
+ import argparse
3
+ import warnings
4
+ from typing import Any
5
+ from torch.utils.data import DataLoader
6
+
7
+ from src.grpo_self_play.trainer import get_trainer
8
+ from src.grpo_self_play.chess.boards_dataset import ChessStartStatesDataset
9
+ from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer
10
+ from src.grpo_self_play.configs.config_loader import load_experiment_config
11
+
12
+
13
+ def train(
14
+ config_path: str = "default.yaml",
15
+ overrides: dict[str, dict[str, Any]] | None = None,
16
+ dataloader_kwargs: dict[str, Any] | None = None
17
+ ) -> None:
18
+ """Main training function for GRPO chess self-play.
19
+
20
+ Args:
21
+ config_path: Path to the YAML config file (relative to configs directory)
22
+ overrides: Optional dict of overrides per section. Example:
23
+ {
24
+ "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
25
+ "training": {"num_epochs": 100},
26
+ "stockfish": {"skill_level": 5},
27
+ }
28
+ dataloader_kwargs: Optional dict of arguments to pass to DataLoader constructor.
29
+ These override config values. Example: {"batch_size": 64, "num_workers": 4}
30
+ """
31
+ config = load_experiment_config(config_path, overrides=overrides)
32
+
33
+ # Build dataloader kwargs from config, with defaults
34
+ dataloader_config = {
35
+ "batch_size": config.training.batch_size,
36
+ "num_workers": 2,
37
+ }
38
+
39
+ # Apply dataloader_kwargs overrides and warn if overriding config values
40
+ if dataloader_kwargs:
41
+ for key, value in dataloader_kwargs.items():
42
+ if key in dataloader_config:
43
+ warnings.warn(
44
+ f"Overriding DataLoader '{key}' from config ({dataloader_config[key]}) "
45
+ f"with provided value ({value})",
46
+ UserWarning,
47
+ stacklevel=2
48
+ )
49
+ dataloader_config[key] = value
50
+
51
+ trainer = get_trainer(num_epochs=config.training.num_epochs)
52
+ dataset = ChessStartStatesDataset(config.dataset)
53
+ dataloader = DataLoader(dataset, **dataloader_config)
54
+ model = GRPOChessTransformer(
55
+ transformer_config=config.transformer,
56
+ grpo_config=config.grpo,
57
+ eval_cfg=config.eval,
58
+ stockfish_cfg=config.stockfish,
59
+ policy_cfg=config.policy,
60
+ searcher_cfg=config.searcher,
61
+ pretrain_cfg=config.pretrain,
62
+ )
63
+
64
+ print("Starting Training with WandB Tracking...")
65
+ trainer.fit(model, dataloader)
66
+
67
+
68
+ if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument("--config", type=str, default="default.yaml")
71
+ args = parser.parse_args()
72
+ train(config_path=args.config)
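For example, a short smoke run with a few overrides (a sketch; the override keys follow the config sections listed in the docstring):

train(
    config_path="default.yaml",
    overrides={
        "training": {"num_epochs": 2},
        "stockfish": {"skill_level": 1},
    },
    dataloader_kwargs={"batch_size": 8, "num_workers": 0},
)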
hf_space_repo/trainer.py ADDED
@@ -0,0 +1,74 @@
1
+ import time
2
+ import random
3
+ import string
4
+ import pytorch_lightning as pl
5
+
6
+ from pytorch_lightning.loggers import WandbLogger
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+
9
+ def generate_run_name(project: str = "chess-grpo") -> str:
10
+ """Generate a unique run name with timestamp and random suffix.
11
+
12
+ Args:
13
+ project: Project name prefix
14
+
15
+ Returns:
16
+ Unique run name string
17
+ """
18
+ timestamp = time.strftime("%Y%m%d-%H%M")
19
+ random_suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
20
+ return f"{project}-{timestamp}-{random_suffix}"
21
+
22
+
23
+ def get_trainer(num_epochs: int = 5000,
24
+ checkpoint_dir: str = "/content/drive/MyDrive/data/grpo-chess/checkpoints/",
25
+ checkpoint_every_n_epochs: int = 5,
26
+ keep_n_checkpoints: int = 3) -> pl.Trainer:
27
+ """Create a PyTorch Lightning trainer with WandB logging and checkpointing.
28
+
29
+ Args:
30
+ num_epochs: Maximum number of training epochs
31
+ checkpoint_dir: Directory to save model checkpoints
32
+ checkpoint_every_n_epochs: Save periodic checkpoint every N epochs
33
+ keep_n_checkpoints: Keep last N periodic checkpoints per run
34
+
35
+ Returns:
36
+ Configured PyTorch Lightning trainer
37
+ """
38
+ run_name = generate_run_name()
39
+ print(f"Generated run name: {run_name}")
40
+
41
+ wandb_logger = WandbLogger(project="Chess-GRPO-Bot", log_model=True, name=run_name)
42
+
43
+ # Best checkpoint - saves top 2 based on loss
44
+ best_checkpoint_cb = ModelCheckpoint(
45
+ dirpath=checkpoint_dir,
46
+ filename=run_name + "-best-{epoch:02d}-{train_total_loss:.4f}",
47
+ save_top_k=2,
48
+ monitor="train_total_loss",
49
+ mode="min"
50
+ )
51
+
52
+ # Periodic checkpoint for crash recovery
53
+ # Fixed filenames (periodic-0, periodic-1, etc.) that rotate within each run
54
+ periodic_checkpoint_cb = ModelCheckpoint(
55
+ dirpath=checkpoint_dir,
56
+ filename=run_name + "-periodic",
57
+ save_top_k=keep_n_checkpoints,
58
+ monitor="train_total_loss",
59
+ mode="min",
60
+ every_n_epochs=checkpoint_every_n_epochs,
61
+ save_last=True, # Always keep the very last checkpoint
62
+ )
63
+
64
+ return pl.Trainer(
65
+ max_epochs=num_epochs,
66
+ # Gradient clipping handled manually in GRPOChessTransformer.training_step
67
+ accelerator="auto",
68
+ devices=1,
69
+ logger=wandb_logger,
70
+ callbacks=[best_checkpoint_cb, periodic_checkpoint_cb],
71
+ log_every_n_steps=1 # Log every step for GRPO debug
72
+ )
73
+
74
+
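Usage sketch (the checkpoint directory is illustrative; see train_self_play.py for the full wiring):

trainer = get_trainer(
    num_epochs=100,
    checkpoint_dir="./checkpoints/",
    checkpoint_every_n_epochs=5,
    keep_n_checkpoints=3,
)
# trainer.fit(model, dataloader)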
requirements.txt CHANGED
@@ -4,9 +4,5 @@ torch
  safetensors
  python-chess
  huggingface_hub
- pytorch_lightning
- mcp>=0.9.0
- wandb>=0.16.0
+ numpy
  jaxtyping
- datasets
- gradio>=4.44.1