Switch Space to Docker + FastAPI
- .dockerignore +8 -0
- Dockerfile +16 -0
- README.md +1 -3
- app.py +22 -42
- hf_space_repo/README.md +428 -0
- hf_space_repo/__init__.py +30 -0
- hf_space_repo/chess/__init__.py +0 -0
- hf_space_repo/chess/boards_dataset.py +465 -0
- hf_space_repo/chess/chess_logic.py +63 -0
- hf_space_repo/chess/policy_player.py +98 -0
- hf_space_repo/chess/rewards.py +108 -0
- hf_space_repo/chess/searcher.py +90 -0
- hf_space_repo/chess/stockfish.py +288 -0
- hf_space_repo/configs/__init__.py +43 -0
- hf_space_repo/configs/config_loader.py +290 -0
- hf_space_repo/configs/default.yaml +123 -0
- hf_space_repo/configs/pretrain.yaml +49 -0
- hf_space_repo/constants.py +15 -0
- hf_space_repo/eval_utils.py +211 -0
- hf_space_repo/evaluator.py +118 -0
- hf_space_repo/grpo_logic/__init__.py +0 -0
- hf_space_repo/grpo_logic/loss.py +235 -0
- hf_space_repo/grpo_logic/model.py +782 -0
- hf_space_repo/grpo_logic/sampling.py +243 -0
- hf_space_repo/logging_utils.py +32 -0
- hf_space_repo/models.py +234 -0
- hf_space_repo/pretrain/README.md +153 -0
- hf_space_repo/pretrain/__init__.py +15 -0
- hf_space_repo/pretrain/pretrain.py +579 -0
- hf_space_repo/pretrain/pretrain_dataset.py +328 -0
- hf_space_repo/pretrain/pretrain_load_config.py +21 -0
- hf_space_repo/searchless_chess_imports.py +3 -0
- hf_space_repo/searchless_chess_model/.gitattributes +35 -0
- hf_space_repo/searchless_chess_model/README.md +177 -0
- hf_space_repo/searchless_chess_model/config.json +10 -0
- hf_space_repo/searchless_chess_model/model_info.json +13 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/__init__.py +1 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/config.py +90 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/constants.py +119 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/hf_model.py +329 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/tokenizer.py +116 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/transformer.py +284 -0
- hf_space_repo/searchless_chess_model/searchless_chess_code/utils.py +162 -0
- hf_space_repo/train_self_play.py +72 -0
- hf_space_repo/trainer.py +74 -0
- requirements.txt +1 -5
.dockerignore
ADDED
@@ -0,0 +1,8 @@
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.pytest_cache/
+.git/
+.DS_Store
+node_modules/
Dockerfile
ADDED
@@ -0,0 +1,16 @@
+FROM python:3.11-slim
+
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV PORT=7860
+
+WORKDIR /app
+
+COPY requirements.txt /app/requirements.txt
+RUN pip install --no-cache-dir -r /app/requirements.txt
+
+COPY . /app
+
+EXPOSE 7860
+
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md
CHANGED
@@ -15,9 +15,7 @@ title: grpo-chess-api
 emoji: ♟️
 colorFrom: amber
 colorTo: red
-sdk: gradio
-sdk_version: 4.44.1
-app_file: app.py
+sdk: docker
 pinned: false
 ---
 
app.py
CHANGED
@@ -2,8 +2,9 @@ import os
 from pathlib import Path
 
 import chess
-import gradio as gr
 import torch
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from huggingface_hub import hf_hub_download
 from pydantic import BaseModel
 from safetensors.torch import load_file
@@ -69,51 +70,30 @@ def choose_move(model, board: chess.Board, temperature: float, greedy: bool) ->
     return move
 
 
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=False,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+
+@app.post("/move", response_model=MoveResponse)
 def move(req: MoveRequest):
-    board = chess.Board(req.fen)
+    try:
+        board = chess.Board(req.fen)
+    except Exception as exc:
+        raise HTTPException(status_code=400, detail=f"Invalid FEN: {exc}")
     model = load_model()
     move = choose_move(model, board, req.temperature, req.greedy)
     san = board.san(move)
    board.push(move)
     return MoveResponse(uci=move.uci(), san=san, fen=board.fen())
 
-
-def gradio_move(fen: str, temperature: float, greedy: bool):
-    req = MoveRequest(fen=fen, temperature=temperature, greedy=greedy)
-    res = move(req)
-    return res.uci, res.san, res.fen
-
-
-with gr.Blocks(title="GRPO Chess API") as demo:
-    gr.Markdown(
-        "## GRPO Chess Model API\n"
-        "Use this panel to test the model. The website calls the Gradio API at "
-        "`/run/move`."
-    )
-    fen = gr.Textbox(
-        label="FEN",
-        value="rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
-    )
-    temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
-    greedy = gr.Checkbox(label="Greedy", value=False)
-    btn = gr.Button("Get Move")
-    uci = gr.Textbox(label="UCI Move")
-    san = gr.Textbox(label="SAN Move")
-    fen_out = gr.Textbox(label="Next FEN")
-    btn.click(
-        gradio_move,
-        inputs=[fen, temperature, greedy],
-        outputs=[uci, san, fen_out],
-        api_name="move",
-    )
-
-
-app = demo
-
-
-if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=int(os.environ.get("PORT", 7860)),
-        show_error=True,
-    )
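For reference, the two endpoints added above can be exercised from any HTTP client. Below is a minimal client sketch using the `requests` library; the base URL is a placeholder for wherever the Space is served, while the field names come from the `MoveRequest`/`MoveResponse` models in app.py.

```python
# Minimal client sketch for the endpoints above (BASE_URL is a placeholder).
import requests

BASE_URL = "http://localhost:7860"

# GET /health should return {"status": "ok"}.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# POST /move takes the MoveRequest fields (fen, temperature, greedy) and
# returns the MoveResponse fields (uci, san, and the FEN after the move).
payload = {
    "fen": "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    "temperature": 1.0,
    "greedy": False,
}
res = requests.post(f"{BASE_URL}/move", json=payload, timeout=30)
res.raise_for_status()
print(res.json())
```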
hf_space_repo/README.md
ADDED
@@ -0,0 +1,428 @@
+# GRPO Self-Play Chess Module
+
+An experimental, research-grade implementation of **Group Relative Policy Optimization (GRPO)** for training transformer-based chess policies through self-play. This module implements a full reinforcement learning pipeline for chess, but training stability and final strength are still under active investigation.
+
+## Overview
+
+This module trains neural network chess policies using GRPO, a variant of Proximal Policy Optimization (PPO) that uses group-based advantage estimation. The system learns to play chess by:
+
+1. **Self-Play**: Sampling multiple trajectory groups from diverse starting positions
+2. **Reward Computation**: Using Stockfish evaluations to compute dense rewards
+3. **Policy Optimization**: Applying GRPO with PPO clipping and KL divergence penalties
+4. **Evaluation**: Comprehensive benchmarking against Stockfish at multiple skill levels
+
+## Key Features
+
+### 🎯 Core Capabilities
+
+- **Transformer-Based Policy Network**: Deep neural network architecture that processes FEN-encoded board states
+- **GRPO Training Algorithm**: Group-relative advantage estimation with PPO-style clipping
+- **Self-Play Training Loop**: Infinite dataset of diverse chess positions for robust learning
+- **Stockfish Integration**: Professional-grade evaluation and reward computation
+- **Comprehensive Evaluation**: Multi-level skill ladder evaluation against Stockfish
+- **PyTorch Lightning Integration**: Scalable training with automatic mixed precision, gradient clipping, and checkpointing
+- **Weights & Biases Logging**: Full experiment tracking and visualization
+
+### 🏗️ Architecture Highlights
+
+- **Modular Design**: Clean separation between model, training logic, chess rules, and evaluation
+- **Efficient Batching**: Parallel trajectory sampling across multiple board positions
+- **Legal Move Masking**: Proper handling of chess rules with action space masking
+- **Trajectory Search**: Optional trajectory search wrapper for improved play strength
+- **Resource Management**: Efficient Stockfish engine pooling and caching
+
+## Installation
+
+```bash
+# Install dependencies
+pip install torch pytorch-lightning wandb chess python-chess
+
+# Ensure Stockfish is available
+# On Ubuntu/Debian: sudo apt-get install stockfish
+# On macOS: brew install stockfish
+# Or download from: https://stockfishchess.org/download/
+```
+
+## Quick Start
+
+### Basic Training
+
+The easiest way to start training is using the YAML-based configuration system:
+
+```python
+from src.grpo_self_play.train_self_play import train
+
+# Use default configuration (loads from configs/default.yaml)
+train()
+
+# Use a custom config file
+train(config_path="my_experiment.yaml")
+
+# Override specific hyperparameters programmatically
+train(
+    config_path="default.yaml",
+    overrides={
+        "grpo": {"lr": 1e-4, "num_trajectories": 8},
+        "training": {"num_epochs": 100},
+    }
+)
+```
+
+All hyperparameters (learning rate, model architecture, training settings, etc.) are defined in YAML configuration files. See the [Configuration](#configuration) section below for details.
+
+### Running Training in Google Colab
+
+**Note for AI agents and contributors**: The primary way this code is run is through the `chess_model_run_git.ipynb` notebook in Google Colab. This notebook is the actual workflow used for training and evaluation.
+
+The `chess_model_run_git.ipynb` notebook provides:
+
+- **Automated Setup**: Clones the repository, installs dependencies, and downloads the searchless chess model
+- **Complete Configuration**: Pre-configured settings for GRPO training, dataset generation, and evaluation
+- **Phase-Aware Dataset**: Example configuration using `ChessDatasetConfig` with `phase_distribution` for balanced training across opening, middlegame, and endgame positions
+- **Evaluation Pipeline**: Integrated evaluation against Stockfish at multiple skill levels
+
+The notebook handles all setup steps including:
+1. Repository cloning and branch checkout
+2. Dependency installation (PyTorch Lightning, WandB, python-chess, etc.)
+3. Downloading the searchless chess model from HuggingFace
+4. Stockfish installation
+5. Training configuration with phase-distributed dataset sampling
+6. Model training and periodic evaluation
+
+### Evaluation
+
+```python
+from src.grpo_self_play import Evaluator, EvalConfig
+from src.grpo_self_play.chess.stockfish import StockfishConfig
+
+# Create evaluator
+evaluator = Evaluator(
+    eval_cfg=EvalConfig(games=50),
+    stockfish_cfg=StockfishConfig(skill_level=10, movetime_ms=100)
+)
+
+# Single evaluation
+results, policy = evaluator.single_evaluation(model)
+print(f"Win rate: {results['score']:.2%}")
+print(f"Approx Elo diff: {results['elo_diff_vs_stockfish_approx']:.0f}")
+
+# Skill ladder evaluation
+skill_results = evaluator.eval_ladder(model)
+for skill, score in skill_results.items():
+    print(f"Skill {skill}: {score:.2%} win rate")
+```
+
+## Architecture
+
+### Model Architecture
+
+The `ChessTransformer` processes chess positions using:
+
+- **Input Encoding**: FEN strings tokenized using DeepMind's chess tokenizer
+- **Transformer Encoder**: Multi-head self-attention with learnable positional encodings
+- **Policy Head**: Dense layers outputting logits over 1968 possible moves
+- **Legal Move Masking**: Automatic filtering of illegal moves during inference
+
+### GRPO Algorithm
+
+Group Relative Policy Optimization extends PPO by:
+
+1. **Group-Based Sampling**: Sample G trajectories per starting position
+2. **Group Rewards**: Compute final reward for each trajectory group
+3. **Relative Advantages**: Normalize advantages within each batch using group statistics
+4. **PPO Clipping**: Prevent large policy updates with clipped importance ratios
+5. **KL Penalty**: Regularize policy updates to prevent divergence
+
+The loss function combines:
+- **PPO Surrogate Loss**: `L_clip = E[min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)]`
+- **KL Divergence Penalty**: `β * KL(π_old || π_new)`
+
+### Training Pipeline
+
+```
+1. Sample random starting positions (FEN strings)
+2. For each position:
+   - Sample G trajectory groups using old policy
+   - Compute group rewards using Stockfish evaluation
+3. Compute advantages via group normalization
+4. Update policy using GRPO loss
+5. Sync old policy every epoch
+6. Periodic evaluation against Stockfish
+```
+
+## Module Structure
+
+```
+grpo_self_play/
+├── models.py              # ChessTransformer architecture
+├── trainer.py             # PyTorch Lightning trainer setup
+├── train_self_play.py     # Main training script
+├── evaluator.py           # Evaluation framework
+├── eval_utils.py          # Evaluation utilities
+├── constants.py           # Configuration constants
+├── grpo_logic/
+│   ├── model.py           # GRPOChessTransformer (Lightning module)
+│   ├── loss.py            # GRPO loss computation
+│   └── sampling.py        # Trajectory sampling logic
+└── chess/
+    ├── chess_logic.py     # Board encoding, legal moves
+    ├── policy_player.py   # Policy-based player
+    ├── searcher.py        # Trajectory search wrapper
+    ├── rewards.py         # Stockfish reward computation
+    └── stockfish.py       # Stockfish engine integration
+```
+
+## Key Design Decisions
+
+### 1. Group-Based Advantage Estimation
+
+Instead of using value functions or Monte Carlo returns, GRPO computes advantages by normalizing rewards within trajectory groups. This approach:
+- Eliminates the need for value function approximation
+- Provides stable learning signals through relative comparisons
+- Reduces variance in advantage estimates
+
+### 2. Stockfish-Based Rewards
+
+Using Stockfish for reward computation provides:
+- **Dense Rewards**: Evaluation at every position, not just terminal states
+- **High-Quality Signals**: Professional-grade position evaluation
+- **Caching**: LRU cache for efficient reward computation during training
+
+### 3. Legal Move Masking
+
+The action space (1968 moves) is larger than legal moves in any position. The system:
+- Masks illegal moves with `-inf` in logits
+- Ensures policy only samples legal moves
+- Handles edge cases (no legal moves, promotion moves)
+
+### 4. Trajectory Padding and Masking
+
+Trajectories have variable lengths due to game terminations. The implementation:
+- Pads trajectories to fixed length for batching
+- Uses attention masks to ignore padding
+- Only considers moves from the starting player's perspective
+
+## Configuration
+
+This module uses a **YAML-based configuration system** to manage all hyperparameters and experiment settings. All training hyperparameters, model architecture settings, and evaluation configurations are centralized in YAML files located in `configs/`.
+
+### Configuration Files
+
+The default configuration file is `configs/default.yaml`, which contains all hyperparameters organized into sections:
+
+- **`training`**: Training loop settings (epochs, batch size, steps per epoch)
+- **`grpo`**: GRPO algorithm hyperparameters (learning rate, trajectories, clipping, KL penalty, entropy regularization, adaptive KL control)
+- **`transformer`**: Model architecture (embedding dimension, layers, attention heads, vocabulary size, action space)
+- **`eval`**: Evaluation settings (number of games, max plies, opening randomization)
+- **`stockfish`**: Stockfish engine configuration (path, skill level, time limits, resource usage)
+- **`policy`**: Policy player settings (temperature, greedy mode, branching factor, search depth)
+- **`searcher`**: Optional trajectory search configuration
+- **`dataset`**: Dataset generation settings (position phases, quality filters, evaluation bounds)
+
+### Using Configurations
+
+#### Loading Configurations
+
+```python
+from src.grpo_self_play.configs.config_loader import load_experiment_config
+
+# Load default config
+config = load_experiment_config("default.yaml")
+
+# Load with overrides
+config = load_experiment_config("default.yaml", overrides={
+    "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
+    "training": {"num_epochs": 100},
+})
+
+# Access config values
+print(config.grpo.lr)
+print(config.training.batch_size)
+print(config.transformer.embed_dim)
+```
+
+#### Training with Configurations
+
+```python
+from src.grpo_self_play.train_self_play import train
+
+# Use default config
+train()
+
+# Use custom config file
+train(config_path="my_experiment.yaml")
+
+# Override specific values
+train(
+    config_path="default.yaml",
+    overrides={
+        "grpo": {"lr": 1e-4},
+        "training": {"num_epochs": 50},
+    },
+    dataloader_kwargs={"num_workers": 4}  # Override DataLoader args
+)
+```
+
+### Creating Custom Configurations
+
+1. Copy the default config:
+```bash
+cp configs/default.yaml configs/my_experiment.yaml
+```
+
+2. Edit `my_experiment.yaml` to modify hyperparameters
+
+3. Use your custom config:
+```python
+train(config_path="my_experiment.yaml")
+```
+
+### Configuration Dataclasses
+
+The configuration system converts YAML files into typed dataclasses:
+
+- **`TrainingConfig`**: Training loop settings
+- **`GRPOConfig`**: GRPO algorithm hyperparameters
+- **`ChessTransformerConfig`**: Model architecture
+- **`EvalConfig`**: Evaluation settings
+- **`StockfishConfig`**: Stockfish engine settings
+- **`PolicyConfig`**: Policy player settings
+- **`SearchConfig`**: Trajectory search settings (optional)
+- **`ChessDatasetConfig`**: Dataset generation settings
+
+All configs are combined into an `ExperimentConfig` object that provides type-safe access to all settings.
+
+### Key Hyperparameters
+
+All hyperparameters are defined in YAML files. Key settings include:
+
+**GRPO Algorithm:**
+- `grpo.lr`: Learning rate for policy optimization
+- `grpo.num_trajectories`: Number of trajectory groups per starting position
+- `grpo.trajectory_depth`: Maximum moves per trajectory
+- `grpo.clip_ratio`: PPO clipping epsilon (prevents large policy updates)
+- `grpo.kl_coef`: KL divergence penalty coefficient
+- `grpo.entropy_coef`: Entropy regularization coefficient
+- `grpo.adaptive_kl`: Enable adaptive KL coefficient adjustment
+- `grpo.use_entropy_floor`: Monitor and respond to entropy collapse
+
+**Model Architecture:**
+- `transformer.embed_dim`: Transformer embedding dimension
+- `transformer.num_layers`: Number of transformer layers
+- `transformer.num_heads`: Number of attention heads
+- `transformer.vocab_size`: Token vocabulary size
+- `transformer.action_dim`: Action space size (1968 for chess)
+
+**Training:**
+- `training.num_epochs`: Total number of training epochs
+- `training.batch_size`: Batch size for training
+- `training.steps_per_epoch`: Number of training steps per epoch
+
+See `configs/default.yaml` for the complete list of all hyperparameters and their default values.
+
+## Advanced Usage
+
+### Custom Reward Function
+
+```python
+from src.grpo_self_play.chess.rewards import reward_board
+
+def custom_reward(board, start_board):
+    # Your custom reward logic
+    return reward_board(board, start_board, depth=8, movetime_ms=50)
+```
+
+### Trajectory Search
+
+```python
+from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
+from src.grpo_self_play.chess.policy_player import PolicyPlayer

+policy = PolicyPlayer(model)
+searcher = TrajectorySearcher(
+    policy,
+    cfg=SearchConfig(n_trajectories=10, trajectory_depth=3)
+)
+```
+
+### Custom Training Loop
+
+```python
+import pytorch_lightning as pl
+from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer
+
+model = GRPOChessTransformer(transformer_config, grpo_config)
+trainer = pl.Trainer(
+    max_epochs=1000,
+    gradient_clip_val=1.0,
+    accelerator="gpu",
+    devices=1
+)
+trainer.fit(model, dataloader)
+```
+
+## Performance Considerations
+
+- **Batch Size**: Larger batches improve advantage normalization quality
+- **Trajectory Depth**: Deeper trajectories provide more learning signal but increase compute
+- **Stockfish Depth**: Higher depth = better rewards but slower training
+- **Caching**: Reward caching significantly speeds up training
+- **Gradient Clipping**: Prevents exploding gradients in transformer training
+
+## Monitoring and Logging
+
+The module logs comprehensive metrics to Weights & Biases:
+
+- **Training Metrics**: Loss, KL divergence, policy ratios, reward statistics
+- **Evaluation Metrics**: Win rate, Elo difference, game outcomes
+- **System Metrics**: Trajectory lengths, padding fractions, gradient norms
+
+## Research Background
+
+GRPO (Group Relative Policy Optimization) is inspired by:
+- **PPO (Proximal Policy Optimization)**: Clipped surrogate objective
+- **REINFORCE**: Policy gradient methods
+- **Self-Play**: Learning through playing against oneself
+- **AlphaZero**: Combining deep learning with game tree search
+
+This implementation adapts these ideas specifically for chess, using Stockfish for reward signals and evaluation.
+
+## Technical Highlights
+
+- ✅ **Practical Infrastructure**: Error handling, resource management, logging
+- ✅ **Scalable Design**: Efficient batching, parallel trajectory sampling
+- ✅ **Extensible**: Modular design allows easy customization
+- ✅ **Documented**: Type hints, docstrings, clear structure
+- ⚠️ **Status**: This is a research system, not a production-ready chess engine
+
+## Future Enhancements
+
+Potential improvements:
+- Value function approximation for better advantage estimates
+- More robust entropy and KL control for GRPO
+- Multi-GPU training support
+- Distributed self-play
+- Opening book integration
+- Endgame tablebase integration
+
+## License
+
+[Specify your license here]
+
+## Citation
+
+If you use this code in your research, please cite:
+
+```bibtex
+@software{grpo_chess,
+  title = {GRPO Self-Play Chess Module},
+  author = {Your Name},
+  year = {2024},
+  url = {https://github.com/yourusername/grpo_chess}
+}
+```
+
+## Contact
+
+For questions or contributions, please open an issue or contact [your email].
+
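As a companion to the loss description in the README above, here is a minimal, self-contained sketch of a GRPO-style objective: group-normalized advantages combined with the clipped PPO surrogate and a KL penalty. Function and argument names here are illustrative only, not the module's actual `grpo_logic.loss` API.

```python
# Sketch of a GRPO-style loss (illustrative names, not the repo's API).
import torch

def grpo_clipped_loss_sketch(
    logp_new: torch.Tensor,   # [G, T] log-probs of taken actions, current policy
    logp_old: torch.Tensor,   # [G, T] log-probs under the sampling (old) policy
    rewards: torch.Tensor,    # [G] one scalar reward per trajectory in the group
    mask: torch.Tensor,       # [G, T] 1.0 for real moves, 0.0 for padding
    clip_eps: float = 0.2,
    kl_coef: float = 0.01,
) -> torch.Tensor:
    # Group-relative advantage: normalize rewards within the group,
    # which is what removes the need for a learned value baseline.
    adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)      # [G]
    adv = adv.unsqueeze(1).expand_as(logp_new)                     # broadcast over time

    # Clipped PPO surrogate on the importance ratio r = pi_new / pi_old.
    ratio = torch.exp(logp_new - logp_old)
    surr = torch.min(ratio * adv,
                     torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)

    # Simple per-action KL estimate between old and new policies.
    kl = logp_old - logp_new

    # Masked mean over real steps; negated because we maximize the surrogate.
    denom = mask.sum().clamp(min=1)
    return -((surr - kl_coef * kl) * mask).sum() / denom
```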
hf_space_repo/__init__.py
ADDED
@@ -0,0 +1,30 @@
+"""GRPO Self-Play Module for Chess.
+
+This module implements Group Relative Policy Optimization (GRPO) for training
+chess policies through self-play. It includes:
+- Transformer-based chess policy models
+- GRPO training logic with PPO clipping
+- Trajectory sampling and reward computation
+- Evaluation against Stockfish
+"""
+
+__version__ = "0.1.0"
+
+# Main exports
+from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
+from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer, GRPOConfig
+from src.grpo_self_play.grpo_logic.loss import grpo_ppo_loss, GRPOLossInfo
+from src.grpo_self_play.evaluator import Evaluator
+from src.grpo_self_play.eval_utils import EvalConfig
+
+__all__ = [
+    "ChessTransformer",
+    "ChessTransformerConfig",
+    "GRPOChessTransformer",
+    "GRPOConfig",
+    "grpo_ppo_loss",
+    "GRPOLossInfo",
+    "Evaluator",
+    "EvalConfig",
+]
+
hf_space_repo/chess/__init__.py
ADDED
(empty file)
hf_space_repo/chess/boards_dataset.py
ADDED
@@ -0,0 +1,465 @@
+"""Dataset of random chess boards."""
+
+import chess
+import random
+import torch
+from collections import deque
+
+from typing import Any, Optional, Dict
+from dataclasses import dataclass
+from torch.utils.data import IterableDataset
+from src.grpo_self_play.chess.rewards import evaluate_fen
+
+
+def generate_random_board(step_num=30):
+    """Generate a random board position by making random moves from starting position.
+
+    Args:
+        step_num: Maximum number of random moves to make
+
+    Returns:
+        Chess board after random moves
+    """
+    board = chess.Board()
+    random_steps = random.randint(0, step_num)
+    for _ in range(random_steps):
+        if board.is_game_over(): break
+        board.push(random.choice(list(board.legal_moves)))
+    return board
+
+
+def get_game_phase(board: chess.Board) -> str:
+    """Determine the game phase (opening, middlegame, or endgame).
+
+    Args:
+        board: Chess board position
+
+    Returns:
+        "opening", "middlegame", or "endgame"
+    """
+    move_count = board.fullmove_number * 2 - (1 if board.turn == chess.BLACK else 0)
+
+    # Count material (excluding kings)
+    material_count = sum(
+        len(board.pieces(pt, color))
+        for pt in [chess.PAWN, chess.ROOK, chess.KNIGHT, chess.BISHOP, chess.QUEEN]
+        for color in [chess.WHITE, chess.BLACK]
+    )
+
+    # Endgame: few pieces remaining (typically < 12-14 pieces)
+    if material_count <= 12:
+        return "endgame"
+    # Opening: early moves (typically first 15 moves)
+    elif move_count <= 15:
+        return "opening"
+    # Middlegame: everything else
+    else:
+        return "middlegame"
+
+def evaluate_position_quality(board: chess.Board, depth: int = 2) -> Optional[float]:
+    """Quick Stockfish evaluation to check position quality.
+
+    Args:
+        board: Chess board position
+        depth: Stockfish search depth (shallow for speed)
+
+    Returns:
+        Centipawn evaluation from White's perspective, or None if evaluation fails
+    """
+    try:
+        fen = board.fen()
+        pov_is_white = board.turn == chess.WHITE
+        eval_cp = evaluate_fen(fen, pov_is_white, movetime_ms=0, depth=depth)
+        return eval_cp
+    except Exception:
+        return None
+
+def generate_opening_position(max_moves: int = 15) -> chess.Board:
+    """Generate a realistic opening position using common opening moves.
+
+    Args:
+        max_moves: Maximum number of opening moves to make
+
+    Returns:
+        Chess board in opening phase
+    """
+    board = chess.Board()
+
+    # Common first moves for White
+    first_moves = [
+        chess.Move.from_uci("e2e4"),  # King's pawn
+        chess.Move.from_uci("d2d4"),  # Queen's pawn
+        chess.Move.from_uci("g1f3"),  # King's knight
+        chess.Move.from_uci("c2c4"),  # English opening
+    ]
+
+    # Make first move
+    if first_moves:
+        first_move = random.choice(first_moves)
+        if first_move in board.legal_moves:
+            board.push(first_move)
+
+    # Continue with semi-random play (preferring development moves)
+    moves_made = 1
+    while moves_made < max_moves and not board.is_game_over():
+        legal_moves = list(board.legal_moves)
+        if not legal_moves:
+            break
+
+        # Prefer piece development over pawn moves in opening
+        piece_moves = [m for m in legal_moves if board.piece_at(m.from_square) and
+                       board.piece_at(m.from_square).piece_type != chess.PAWN]
+
+        if piece_moves and random.random() < 0.6:  # 60% chance to prefer piece moves
+            move = random.choice(piece_moves)
+        else:
+            move = random.choice(legal_moves)
+
+        board.push(move)
+        moves_made += 1
+
+    return board
+
+
+def generate_middlegame_position(min_moves: int = 15, max_moves: int = 40) -> chess.Board:
+    """Generate a middlegame position from a reasonable starting point.
+
+    Args:
+        min_moves: Minimum moves to reach middlegame
+        max_moves: Maximum moves for middlegame
+
+    Returns:
+        Chess board in middlegame phase
+    """
+    # Start from an opening position
+    board = generate_opening_position(max_moves=min_moves)
+
+    # Continue with random play to reach middlegame
+    target_moves = random.randint(min_moves, max_moves)
+    moves_made = len(board.move_stack)
+
+    while moves_made < target_moves and not board.is_game_over():
+        legal_moves = list[Any](board.legal_moves)
+        if not legal_moves:
+            break
+        board.push(random.choice(legal_moves))
+        moves_made += 1
+
+    return board
+
+
+
+def generate_endgame_position() -> chess.Board:  # TODO: This is not working as expected, it should be a function that generates a random endgame position.
+    """Generate an endgame position by removing pieces from a middlegame position.
+
+    Returns:
+        Chess board in endgame phase
+    """
+    # Start with a middlegame position
+    board = generate_middlegame_position(min_moves=20, max_moves=35)
+
+    # Remove pieces to create endgame (keep kings, remove other pieces randomly)
+    pieces_to_remove = []
+    for square in chess.SQUARES:
+        piece = board.piece_at(square)
+        if piece and piece.piece_type != chess.KING:
+            pieces_to_remove.append(square)
+
+    # Remove random pieces until we have endgame material (<= 12 pieces total)
+    target_pieces = random.randint(6, 12)  # Endgame typically has 6-12 pieces
+    current_pieces = len([p for p in pieces_to_remove if board.piece_at(p)])
+
+    # We need to remove pieces, but we can't directly remove them from python-chess Board
+    # Instead, we'll generate a new position by making moves that trade pieces
+    # For simplicity, we'll just continue playing until we naturally reach endgame material
+
+    # Count material
+    def count_material(b: chess.Board) -> int:
+        return sum(
+            len(b.pieces(pt, color))
+            for pt in [chess.PAWN, chess.ROOK, chess.KNIGHT, chess.BISHOP, chess.QUEEN]
+            for color in [chess.WHITE, chess.BLACK]
+        )
+
+    # Play random moves until we reach endgame material
+    max_attempts = 100
+    attempts = 0
+    while count_material(board) > 12 and attempts < max_attempts and not board.is_game_over():
+        legal_moves = list(board.legal_moves)
+        if not legal_moves:
+            break
+
+        # Prefer captures to reduce material
+        captures = [m for m in legal_moves if board.is_capture(m)]
+        if captures:
+            move = random.choice(captures)
+        else:
+            move = random.choice(legal_moves)
+
+        board.push(move)
+        attempts += 1
+
+    return board
+
+
+
+def generate_position_by_phase(phase: str) -> chess.Board:
+    """Generate a position for a specific game phase.
+
+    Args:
+        phase: "opening", "middlegame", or "endgame"
+
+    Returns:
+        Chess board in the specified phase
+    """
+    if phase == "opening":
+        return generate_opening_position()
+    elif phase == "middlegame":
+        return generate_middlegame_position()
+    elif phase == "endgame":
+        return generate_endgame_position()
+    else:
+        raise ValueError(f"Unknown phase: {phase}. Must be 'opening', 'middlegame', or 'endgame'")
+
+
+def generate_quality_filtered_board(
+    step_num: int = 30,
+    min_eval_cp: int = -200,
+    max_eval_cp: int = 200,
+    filter_depth: int = 2,
+    max_attempts: int = 50,
+    phase: Optional[str] = None
+) -> Optional[chess.Board]:
+    """Generate a random board position filtered by Stockfish evaluation quality.
+
+    Args:
+        step_num: Maximum number of random moves (if phase is None)
+        min_eval_cp: Minimum centipawn evaluation to accept
+        max_eval_cp: Maximum centipawn evaluation to accept
+        filter_depth: Stockfish depth for filtering (shallow for speed)
+        max_attempts: Maximum attempts to generate a valid position
+        phase: Optional phase to generate ("opening", "middlegame", "endgame")
+
+    Returns:
+        Chess board within evaluation range, or None if no valid position found
+    """
+    for attempt in range(max_attempts):
+        # Generate position
+        if phase:
+            board = generate_position_by_phase(phase)
+        else:
+            board = generate_random_board(step_num)
+
+        # Skip if game over or no legal moves
+        if board.is_game_over() or not list(board.legal_moves):
+            continue
+
+        # Evaluate position quality
+        eval_cp = evaluate_position_quality(board, depth=filter_depth)
+        if eval_cp is None:
+            continue
+
+        # Check if evaluation is within acceptable range
+        if min_eval_cp <= eval_cp <= max_eval_cp:
+            return board
+
+    # If we couldn't find a good position, return a random one anyway
+    if phase:
+        return generate_position_by_phase(phase)
+    else:
+        return generate_random_board(step_num)
+
+
+@dataclass
+class ChessDatasetConfig:
+    """Configuration for the Chess Start States Dataset.
+
+    Attributes:
+        max_steps: Maximum number of positions to generate per epoch
+        random_walk_gen_steps: Maximum random moves (legacy, used if phase_distribution is None)
+        phase_distribution: Dict mapping phase names to weights, e.g. {"opening": 0.3, "middlegame": 0.5, "endgame": 0.2}
+        min_eval_cp: Minimum centipawn evaluation to accept (-200)
+        max_eval_cp: Maximum centipawn evaluation to accept (+200)
+        use_opening_book: Whether to use opening book moves for opening positions
+        stockfish_filter_depth: Stockfish depth for quality filtering (2-4 for speed)
+        cache_positions: Whether to cache and reuse high-quality positions
+        cache_size: Maximum number of positions to cache
+        quality_filter: Whether to filter positions by Stockfish evaluation
+    """
+    max_steps: int = 10000
+    random_walk_gen_steps: int = 30
+    phase_distribution: Optional[Dict[str, float]] = None
+    min_eval_cp: int = -200
+    max_eval_cp: int = 200
+    use_opening_book: bool = True
+    stockfish_filter_depth: int = 2
+    cache_positions: bool = False
+    cache_size: int = 1000
+    quality_filter: bool = True
+
+
+class ChessStartStatesDataset(IterableDataset):
+    """
+    Infinite dataset that yields high-quality FEN strings from diverse game phases.
+
+    Supports quality filtering, phase-aware generation, and position caching.
+    """
+    def __init__(
+        self,
+        config: ChessDatasetConfig = ChessDatasetConfig()
+    ):
+        """
+        Initialize dataset with quality filtering and phase diversity options.
+
+        Args:
+            config: ChessDatasetConfig object with all configuration parameters.
+                Defaults to ChessDatasetConfig() if no config is provided.
+                Parameters in the config are:
+                    max_steps: Maximum number of positions to generate per epoch
+                    random_walk_gen_steps: Maximum random moves (legacy, used if phase_distribution is None)
+                    phase_distribution: Dict mapping phase names to weights, e.g. {"opening": 0.3, "middlegame": 0.5, "endgame": 0.2}
+                    min_eval_cp: Minimum centipawn evaluation to accept (-200)
+                    max_eval_cp: Maximum centipawn evaluation to accept (+200)
+                    use_opening_book: Whether to use opening book moves for opening positions
+                    stockfish_filter_depth: Stockfish depth for quality filtering (2-4 for speed)
+                    cache_positions: Whether to cache and reuse high-quality positions
+                    cache_size: Maximum number of positions to cache
+                    quality_filter: Whether to filter positions by Stockfish evaluation
+        """
+        # Use config if provided, otherwise use individual parameters or defaults
+
+        self.max_steps = config.max_steps
+        self.random_walk_gen_steps = config.random_walk_gen_steps
+        self.phase_distribution = config.phase_distribution
+        self.min_eval_cp = config.min_eval_cp
+        self.max_eval_cp = config.max_eval_cp
+        self.use_opening_book = config.use_opening_book
+        self.stockfish_filter_depth = config.stockfish_filter_depth
+        self.cache_positions = config.cache_positions
+        self.cache_size = config.cache_size
+        self.quality_filter = config.quality_filter
+
+        # Normalize phase distribution (only if not None)
+        if self.phase_distribution is not None:
+            total_weight = sum(self.phase_distribution.values())
+            if total_weight > 0:
+                self.phase_distribution = {k: v / total_weight for k, v in self.phase_distribution.items()}
+
+        # Position cache
+        self._position_cache: deque = deque[Any](maxlen=self.cache_size)
+        self._cache_stats = {"hits": 0, "misses": 0, "generated": 0}
+
+        # Statistics tracking
+        self._stats = {
+            "opening": 0,
+            "middlegame": 0,
+            "endgame": 0,
+            "filtered_out": 0,
+            "total_generated": 0,
+        }
+
+    def _sample_phase(self) -> str:
+        """Sample a game phase according to phase_distribution weights.
+
+        Returns:
+            Phase name: "opening", "middlegame", or "endgame"
+        """
+        rand = random.random()
+        cumulative = 0.0
+        for phase, weight in self.phase_distribution.items():
+            cumulative += weight
+            if rand <= cumulative:
+                return phase
+        # Fallback to middlegame
+        return "middlegame"
+
+    def _generate_position(self) -> Optional[chess.Board]:
+        """Generate a single position according to configuration.
+
+        Returns:
+            Chess board or None if generation fails
+        """
+        # Check cache first
+        if self.cache_positions and self._position_cache:
+            if random.random() < 0.3:  # 30% chance to use cached position
+                cached_pos = random.choice(self._position_cache)
+                self._cache_stats["hits"] += 1
+                return chess.Board(cached_pos)
+            self._cache_stats["misses"] += 1
+
+        # Determine phase
+        if self.phase_distribution:
+            phase = self._sample_phase()
+        else:
+            phase = None
+
+        # Generate position
+        if self.quality_filter:
+            board = generate_quality_filtered_board(
+                step_num=self.random_walk_gen_steps,
+                min_eval_cp=self.min_eval_cp,
+                max_eval_cp=self.max_eval_cp,
+                filter_depth=self.stockfish_filter_depth,
+                phase=phase
+            )
+        else:
+            if phase:
+                board = generate_position_by_phase(phase)
+            else:
+                board = generate_random_board(self.random_walk_gen_steps)
+
+        if board is None:
+            return None
+
+        # Update statistics
+        if not board.is_game_over():
+            actual_phase = get_game_phase(board)
+            self._stats[actual_phase] = self._stats.get(actual_phase, 0) + 1
+            self._stats["total_generated"] += 1
+
+        # Cache position if enabled
+        if self.cache_positions:
+            self._position_cache.append(board.fen())
+            self._cache_stats["generated"] += 1
+
+        return board
+
+    def get_stats(self) -> Dict:
+        """Get statistics about generated positions.
+
+        Returns:
+            Dictionary with statistics
+        """
+        stats = self._stats.copy()
+        if self.cache_positions:
+            stats["cache"] = self._cache_stats.copy()
+            stats["cache"]["size"] = len(self._position_cache)
+        return stats
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+
+        # Determine how many steps this worker should generate
+        if worker_info is not None:
+            # Split work among workers
+            num_workers = worker_info.num_workers
+            worker_id = worker_info.id
+            per_worker = self.max_steps // num_workers
+            # Give remainder to the last worker
+            if worker_id == num_workers - 1:
+                per_worker += self.max_steps % num_workers
+
+            # Set deterministic seed per worker for reproducibility and isolation
+            worker_seed = 42 + worker_id * 1000
+            random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            steps_to_generate = per_worker
+        else:
+            # Single process mode
+            steps_to_generate = self.max_steps
+
+        # Generate positions for this worker's share
+        for step in range(steps_to_generate):
+            board = self._generate_position()
+            if board is not None and not board.is_game_over():
+                yield board.fen()
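A usage sketch for the dataset above, assuming the repo's import layout as shown in this diff: a `phase_distribution` mixes opening, middlegame, and endgame starts, and the iterable dataset feeds FEN strings straight into a PyTorch `DataLoader`.

```python
# Sketch: wiring ChessStartStatesDataset into a DataLoader
# (import path follows the repo layout shown in this commit).
from torch.utils.data import DataLoader
from src.grpo_self_play.chess.boards_dataset import (
    ChessDatasetConfig,
    ChessStartStatesDataset,
)

cfg = ChessDatasetConfig(
    max_steps=1000,
    phase_distribution={"opening": 0.3, "middlegame": 0.5, "endgame": 0.2},
    quality_filter=True,  # keep only roughly balanced positions (+/-200 cp)
)
dataset = ChessStartStatesDataset(cfg)

# Default collation of strings yields a list of FENs per batch,
# ready for trajectory sampling.
loader = DataLoader(dataset, batch_size=8)
fens = next(iter(loader))
print(fens[0])
print(dataset.get_stats())
```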
hf_space_repo/chess/chess_logic.py
ADDED
@@ -0,0 +1,63 @@
+import chess
+import torch
+
+from typing import Optional
+from src.grpo_self_play.searchless_chess_imports import (MOVE_TO_ACTION,
+                                                         ACTION_TO_MOVE,
+                                                         tokenize as deepmind_tokenize)
+
+MAX_ACTION = max(ACTION_TO_MOVE.keys())
+
+
+def board_to_tensor(board, device: str | torch.device = 'cpu') -> torch.Tensor:
+    fen = board.fen()
+    token_ids = list[int](deepmind_tokenize(fen))  # Returns list of ints
+    input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
+    return input_tensor
+
+
+def get_legal_moves_indices(board):
+    legal_moves = list(board.legal_moves)
+    legal_indices = []
+    for move in legal_moves:
+        # move.uci() returns "e2e4" or "a7a8q" which matches your dict keys
+        uci_str = move.uci()
+        if uci_str in MOVE_TO_ACTION:
+            legal_indices.append(MOVE_TO_ACTION[uci_str])
+        else:
+            # Fallback: unlikely if MOVE_TO_ACTION is complete
+            raise ValueError(f"Invalid move: {uci_str}")
+    return legal_indices
+
+
+def get_legal_moves_mask(board, device: str | torch.device = 'cpu') -> torch.Tensor:
+    legal_moves = list(board.legal_moves)
+    mask = torch.zeros(MAX_ACTION + 1, dtype=torch.bool)
+    for move in legal_moves:
+        uci_str = move.uci()
+        action_idx = MOVE_TO_ACTION.get(uci_str)
+        if action_idx is not None:
+            mask[action_idx] = True
+    return mask.to(device)
+
+
+def action_to_move(board: chess.Board, action_idx: int):
+    uci = ACTION_TO_MOVE.get(action_idx)
+    if uci is None:
+        return None
+    try:
+        mv = chess.Move.from_uci(uci)
+    except ValueError:
+        return None
+    return mv if mv in board.legal_moves else None
+
+
+class ChessPlayer:
+    """
+    An abstract chess player interface.
+    """
+    def act(self, board: chess.Board) -> Optional[chess.Move]:
+        """
+        Given a chess.Board, return a chess.Move or None to resign.
+        """
+        raise NotImplementedError()
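A quick round-trip sketch of these helpers (the random logits stand in for a real policy output): encode a board, mask the action space down to legal moves, and decode an action index back into a `chess.Move`.

```python
# Sketch: exercising the chess_logic.py helpers with stand-in logits.
import chess
import torch
from src.grpo_self_play.chess.chess_logic import (
    board_to_tensor,
    get_legal_moves_mask,
    action_to_move,
)

board = chess.Board()               # starting position
tokens = board_to_tensor(board)     # [1, seq_len] token ids for the model
mask = get_legal_moves_mask(board)  # bool vector over the full action space

logits = torch.randn(mask.shape[0])  # stand-in for model(tokens)
logits[~mask] = -float("inf")        # illegal moves can never be chosen
move = action_to_move(board, int(torch.argmax(logits)))
print(move.uci() if move is not None else "mapping failed")
```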
hf_space_repo/chess/policy_player.py
ADDED
@@ -0,0 +1,98 @@
import random
import torch
import torch.nn.functional as F
from src.grpo_self_play.chess.chess_logic import (board_to_tensor,
                                                  get_legal_moves_indices,
                                                  action_to_move,
                                                  ChessPlayer)

from dataclasses import dataclass

@dataclass
class PolicyConfig:
    temperature: float = 1.0
    greedy: bool = False       # if True, pick argmax among legal moves
    branching_factor: int = 4  # for search; 0 = no limit
    search_depth: int = 2      # for search; 0 = no search


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([PolicyConfig])


class PolicyPlayer(ChessPlayer):
    def __init__(self, model, device=None, cfg=PolicyConfig()):
        self.model = model.eval()
        self.device = device or next(model.parameters()).device
        self.cfg = cfg
        self.stats = {"no_legal_idxs": 0, "mapping_failed": 0, "random_fallback": 0}

    @torch.no_grad()
    def act(self, board):
        legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            self.stats["no_legal_idxs"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        return self.sample_move(board, legal_moves_indices)

    @torch.no_grad()
    def sample_move(self, board, legal_moves_indices=None):
        if legal_moves_indices is None:
            legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            self.stats["no_legal_idxs"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        board_tensor = board_to_tensor(board, self.device)
        logits = self.model(board_tensor)  # [1, A]

        A = logits.size(-1)
        masked = torch.full(
            (A,),
            -float("inf"),
            device=self.device,
            dtype=logits.dtype,
        )
        li = torch.tensor(legal_moves_indices, device=self.device, dtype=torch.long)
        masked[li] = logits[0, li]

        if self.cfg.greedy:
            action_idx = int(torch.argmax(masked).item())
        else:
            temp = max(1e-6, self.cfg.temperature)
            probs = F.softmax(masked / temp, dim=-1)
            action_idx = int(torch.multinomial(probs, 1).item())
        move = action_to_move(board, action_idx)
        if move is None:
            self.stats["mapping_failed"] += 1
            self.stats["random_fallback"] += 1
            return random.choice(list(board.legal_moves))
        return move

    @torch.no_grad()
    def eval_board(self, board, root_color):
        board_tensor = board_to_tensor(board, self.device)
        legal_moves_indices = get_legal_moves_indices(board)
        if not legal_moves_indices:
            # no moves -> treat via game result if available
            outcome = board.outcome()
            if outcome is not None:
                if outcome.winner is None:
                    return 0.0
                return 1.0 if outcome.winner == root_color else -1.0

        logits = self.model(board_tensor)  # [1, A]
        A = logits.size(-1)
        masked = torch.full(
            (A,),
            -float("inf"),
            device=self.device,
            dtype=logits.dtype,
        )
        li = torch.tensor(legal_moves_indices, device=self.device, dtype=torch.long)
        masked[li] = logits[-1, li]
        best_logit = float(torch.max(torch.tanh(masked)).item())
        return best_logit if board.turn == root_color else -best_logit
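A minimal sketch of how this player is driven (illustrative, not part of the diff; `model` is assumed here to be a network mapping the [1, seq_len] token tensor to [1, action_dim] logits, e.g. the ChessTransformer configured elsewhere in this repo):

# Illustrative sketch (not part of the diff): greedy move selection.
import chess
from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig

player = PolicyPlayer(model, cfg=PolicyConfig(greedy=True))  # `model` is assumed defined
board = chess.Board()
move = player.act(board)  # always a legal move (random fallback on mapping failure)
board.push(move)
print(player.stats)       # e.g. {'no_legal_idxs': 0, 'mapping_failed': 0, 'random_fallback': 0}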
hf_space_repo/chess/rewards.py
ADDED
@@ -0,0 +1,108 @@
import os
import chess
import chess.engine

from functools import lru_cache
from src.grpo_self_play.chess.stockfish import stockfish_analyse, DEFAULT_STOCKFISH_TIMEOUT

# Engine name for reward evaluation
REWARD_ENGINE_NAME = f"reward_engine_{os.getpid()}"


def _get_reward_engine_name() -> str:
    """Get a process-specific engine name for reward evaluation."""
    return f"reward_engine_{os.getpid()}"


def _raw_white_reward(fen: str, movetime_ms: int, depth: int, timeout: float = DEFAULT_STOCKFISH_TIMEOUT) -> float:
    """Get the raw centipawn evaluation from White's perspective using the centralized wrapper."""
    if depth and depth > 0:
        limit = chess.engine.Limit(depth=depth)
    else:
        limit = chess.engine.Limit(time=movetime_ms / 1000.0)

    info = stockfish_analyse(_get_reward_engine_name(), chess.Board(fen), limit, timeout=timeout)

    if info is None:
        return 0.0  # Fallback on engine failure

    score = info["score"].pov(chess.WHITE)
    if score.is_mate():
        return 10000.0 if score.mate() > 0 else -10000.0
    return float(score.score())


@lru_cache(maxsize=50_000)
def cached_raw_reward_white(fen: str, depth: int) -> float:
    """
    Cached Stockfish raw eval for a given FEN from White's POV.
    Returns a centipawn score (positive = White is better).
    Caches by depth only, not movetime, since movetime is not deterministic.
    """
    return _raw_white_reward(fen, movetime_ms=10, depth=depth)


def normalize_cp(raw_cp: float) -> float:
    """Normalize a raw centipawn score to [-2, 2] using linear clipping."""
    return float(max(-2.0, min(2.0, raw_cp / 1000.0)))


def evaluate_fen(fen: str, pov_is_white: bool, movetime_ms: int, depth: int, normalize: bool = True):
    """
    Cached Stockfish eval for a given FEN and settings.
    Returns a normalized reward in [-2, 2], or raw centipawns if normalize=False.
    """
    if depth and depth > 0:
        raw_score = cached_raw_reward_white(fen, depth)
    else:
        raw_score = _raw_white_reward(fen, movetime_ms, depth)

    if not pov_is_white:  # Flip sign for Black's POV
        raw_score = -raw_score
    # Normalize the raw score using linear clipping instead of tanh.
    # Linear clipping preserves gradient signal regardless of position evaluation;
    # tanh was compressing differentials at higher absolute values.
    if normalize:
        return normalize_cp(raw_score)
    else:
        return raw_score


def evaluate_board(board: chess.Board, pov_is_white: bool, depth: int = 16, normalize: bool = True) -> float:
    """
    Evaluate a board position from a given POV.
    Returns a normalized reward in [-2, 2], or raw centipawns if normalize=False.
    """
    if board.is_game_over(claim_draw=True):
        if board.is_checkmate():
            pov_loses = (board.turn == (chess.WHITE if pov_is_white else chess.BLACK))
            raw = -10000.0 if pov_loses else 10000.0
        else:
            raw = 0.0  # Draw
        return normalize_cp(raw) if normalize else raw
    else:
        return evaluate_fen(board.fen(), pov_is_white, movetime_ms=0, depth=depth, normalize=normalize)


def reward_board(env: chess.Board, board_start: chess.Board, movetime_ms: int = 0, depth: int = 16) -> float:
    """
    Stockfish-based reward from the perspective of board_start.turn.

    env: current board (python-chess Board)
    board_start: board at trajectory start (used for POV)
    """
    pov_is_white = (board_start.turn == chess.WHITE)
    if env.is_game_over(claim_draw=True):  # Terminal state
        if env.is_checkmate():
            pov_loses = (env.turn == (chess.WHITE if pov_is_white else chess.BLACK))
            r_t = -1.0 if pov_loses else 1.0
        else:
            r_t = 0.0  # Draw
    else:
        fen_t = env.fen()
        r_t = evaluate_fen(fen_t, pov_is_white, movetime_ms, depth)

    fen_0 = board_start.fen()
    r_0 = evaluate_fen(fen_0, pov_is_white, movetime_ms, depth)
    return r_t - r_0  # Reward is the change in eval
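Worked numbers for the reward scaling above, derived directly from normalize_cp and reward_board (illustrative, not part of the diff):

# Worked example of the reward scaling (not part of the diff).
from src.grpo_self_play.chess.rewards import normalize_cp

normalize_cp(350.0)    # +3.5 pawns for White -> 0.35
normalize_cp(-2500.0)  # clearly lost          -> -2.0 (clipped)
normalize_cp(10000.0)  # mate score            -> 2.0 (clipped)
# reward_board returns r_t - r_0, so a trajectory that improves the starting
# player's normalized eval from 0.10 to 0.35 yields a reward of 0.25.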
hf_space_repo/chess/searcher.py
ADDED
@@ -0,0 +1,90 @@
'''
Implements a search method to choose moves based on a policy network.
'''

import chess
import torch

from typing import Optional
from dataclasses import dataclass
from src.grpo_self_play.chess.chess_logic import ChessPlayer
from src.grpo_self_play.chess.policy_player import PolicyPlayer

@dataclass
class SearchConfig:
    n_trajectories: int = 1    # G: number of sampled trajectories
    trajectory_depth: int = 1  # T: max plies per trajectory


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([SearchConfig])


class TrajectorySearcher(ChessPlayer):
    """
    Searcher that uses a PolicyPlayer to:
      - sample trajectories using the policy
      - evaluate their final states using the policy
    and picks the first move of the best-scoring trajectory.
    """

    def __init__(self, policy: PolicyPlayer, cfg: SearchConfig = SearchConfig()):
        self.policy = policy
        self.cfg = cfg

    @torch.no_grad()
    def act(self, board: chess.Board) -> Optional[chess.Move]:
        """
        If n_trajectories or trajectory_depth <= 1:
            just use the policy's one-step act() (no search).

        Otherwise:
            sample G trajectories, score each by its final state,
            and pick the first move of the best trajectory.
        """
        if self.cfg.n_trajectories <= 1 or self.cfg.trajectory_depth <= 1:
            return self.policy.act(board)

        root_color = board.turn
        best_score = -float("inf")
        best_first_move = None

        for g in range(self.cfg.n_trajectories):
            rollout_board = board.copy()

            first_move = None
            for step in range(self.cfg.trajectory_depth):
                if rollout_board.is_game_over():
                    break

                mv = self.policy.sample_move(rollout_board)
                if mv is None:
                    # no move available -> end trajectory
                    break

                if first_move is None:
                    first_move = mv

                rollout_board.push(mv)

            if first_move is None:
                # This trajectory failed to produce any move (should be rare)
                continue

            score = self.policy.eval_board(rollout_board, root_color)

            if score > best_score:
                best_score = score
                best_first_move = first_move

        if best_first_move is None:
            # Fallback to the simple one-step policy
            return self.policy.act(board)

        return best_first_move

    @property
    def stats(self) -> dict:
        return self.policy.stats
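A minimal sketch of wrapping the policy in the searcher (illustrative, not part of the diff; `policy` is a PolicyPlayer as above and `board` an arbitrary chess.Board):

# Illustrative sketch (not part of the diff): best-of-G rollout search.
from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig

searcher = TrajectorySearcher(policy, SearchConfig(n_trajectories=4, trajectory_depth=8))
move = searcher.act(board)  # samples 4 rollouts of up to 8 plies, keeps the best first move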
hf_space_repo/chess/stockfish.py
ADDED
@@ -0,0 +1,288 @@
import os
import threading
import chess
import chess.engine
import torch
from typing import Optional
from dataclasses import dataclass
from concurrent.futures import TimeoutError as FuturesTimeoutError
from src.grpo_self_play.chess.chess_logic import ChessPlayer
from src.grpo_self_play.logging_utils import get_logger

logger = get_logger("grpo_chess.stockfish")

DEFAULT_STOCKFISH_PATH = "/usr/games/stockfish"


@dataclass(frozen=True)
class StockfishConfig:
    path: str = DEFAULT_STOCKFISH_PATH
    skill_level: int = 20
    use_elo_limit: bool = False
    elo: int = 2500
    movetime_ms: int = 50
    threads: int = 1
    hash_mb: int = 128


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([StockfishConfig])


class StockfishManager:
    '''
    Manages Stockfish engine instances by name for player, eval and reward engines.
    For example, we may use several engines at different strength levels for evaluation,
    or limit the reward engine by time.
    '''
    _pid: int = os.getpid()
    _engines: dict[str, chess.engine.SimpleEngine] = {}
    _cfgs: dict[str, StockfishConfig] = {}
    _locks: dict[str, threading.Lock] = {}  # Per-engine locks for thread safety
    _manager_lock: threading.Lock = threading.Lock()  # Lock for managing _engines/_locks dicts

    @classmethod
    def ensure_pid(cls) -> None:
        pid = os.getpid()
        if pid != cls._pid:
            # We are in a forked/spawned child; discard inherited engine handles.
            # This is a workaround to avoid issues with multiprocessing.
            cls._pid = pid
            cls._engines = {}
            cls._cfgs = {}
            cls._locks = {}
            cls._manager_lock = threading.Lock()

    @classmethod
    def _configure_engine(cls, engine: chess.engine.SimpleEngine, cfg: StockfishConfig) -> None:
        try:
            engine.configure({"Threads": cfg.threads})
        except Exception:
            logger.warning("Failed to set Stockfish threads")

        try:
            engine.configure({"Hash": cfg.hash_mb})
        except Exception:
            logger.warning("Failed to set Stockfish hash size")

        try:
            engine.configure({"Skill Level": cfg.skill_level})
        except Exception:
            logger.warning("Failed to set Stockfish skill level")

        if cfg.use_elo_limit:
            try:
                engine.configure({
                    "UCI_LimitStrength": True,
                    "UCI_Elo": cfg.elo,
                })
            except Exception:
                logger.warning("Failed to set Stockfish ELO limit")

    @classmethod
    def is_name_registered(cls, name: str) -> bool:
        return name in cls._engines

    @classmethod
    def get_lock(cls, name: str) -> threading.Lock:
        """Get the lock for a named engine (creates it if needed)."""
        with cls._manager_lock:
            if name not in cls._locks:
                cls._locks[name] = threading.Lock()
            return cls._locks[name]

    @classmethod
    def get_engine(cls, name: str, cfg: StockfishConfig | None = None) -> chess.engine.SimpleEngine:
        """
        Get (or create) a named engine instance.
        - name: e.g. "reward", "player"
        - cfg: config to use when creating it (ignored on later calls).
        """
        cls.ensure_pid()  # If we are in a forked/spawned child, discard inherited engine handles.
        with cls._manager_lock:
            if not cls.is_name_registered(name):
                if cfg is None:
                    cfg = StockfishConfig()
                engine = chess.engine.SimpleEngine.popen_uci(cfg.path)
                cls._configure_engine(engine, cfg)
                cls._engines[name] = engine
                cls._cfgs[name] = cfg
                cls._locks[name] = threading.Lock()
            return cls._engines[name]

    @classmethod
    def close(cls, name: str) -> None:
        with cls._manager_lock:
            engine = cls._engines.get(name)
            if engine is not None:
                try:
                    engine.quit()
                except Exception:
                    logger.warning(f"Failed to close Stockfish engine '{name}'")
                finally:
                    cls._engines.pop(name, None)
                    cls._cfgs.pop(name, None)
                    cls._locks.pop(name, None)

    @classmethod
    def close_all(cls) -> None:
        for name in list(cls._engines.keys()):
            cls.close(name)


# Default timeout for Stockfish operations (seconds)
DEFAULT_STOCKFISH_TIMEOUT = 10.0


def run_with_timeout(func, timeout: float, *args, **kwargs):
    """Run a function with a timeout.

    Uses a single threading.Thread + join(timeout) instead of ThreadPoolExecutor
    so that this works correctly in forked child processes (ProcessPoolExecutor
    with fork). ThreadPoolExecutor can deadlock in forked workers due to
    inherited lock state.

    Args:
        func: Function to call
        timeout: Maximum time to wait (seconds)
        *args, **kwargs: Arguments to pass to func

    Returns:
        Result of func

    Raises:
        FuturesTimeoutError: If the function doesn't complete within timeout
    """
    result_holder: list = []
    exc_holder: list = []

    def target() -> None:
        try:
            out = func(*args, **kwargs)
            result_holder.append(out)
        except BaseException as e:
            exc_holder.append(e)

    t = threading.Thread(target=target, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise FuturesTimeoutError()
    if exc_holder:
        raise exc_holder[0]
    return result_holder[0]


def stockfish_analyse(
    engine_name: str,
    board: chess.Board,
    limit: chess.engine.Limit,
    timeout: float = DEFAULT_STOCKFISH_TIMEOUT,
    cfg: StockfishConfig | None = None,
    attempts_n: int = 2
) -> Optional[chess.engine.InfoDict]:
    """Analyse a position with Stockfish, with timeout and crash recovery.

    Args:
        engine_name: Name of the engine instance to use
        board: Chess board position to analyse
        limit: Search limit (depth, time, etc.)
        timeout: Maximum time to wait for a response (seconds)
        cfg: Optional config for engine creation
        attempts_n: How many attempts to try

    Returns:
        Analysis info dict, or None if analysis failed
    """
    for attempt in range(attempts_n):
        try:
            engine = StockfishManager.get_engine(engine_name, cfg)
            lock = StockfishManager.get_lock(engine_name)
            with lock:
                return run_with_timeout(engine.analyse, timeout, board, limit)
        except chess.engine.EngineTerminatedError:
            logger.error(f"Stockfish engine '{engine_name}' terminated unexpectedly, recreating...")
            StockfishManager.close(engine_name)
            if attempt == attempts_n - 1:
                return None
        except FuturesTimeoutError:
            logger.warning(f"Stockfish analyse timed out after {timeout}s for engine '{engine_name}'")
            return None
        except Exception as e:
            logger.error(f"Stockfish analyse error: {e}")
            return None
    return None


def stockfish_play(
    engine_name: str,
    board: chess.Board,
    limit: chess.engine.Limit,
    timeout: float = DEFAULT_STOCKFISH_TIMEOUT,
    cfg: StockfishConfig | None = None,
) -> Optional[chess.Move]:
    """Get the best move from Stockfish, with timeout and crash recovery.

    Args:
        engine_name: Name of the engine instance to use
        board: Chess board position
        limit: Search limit (depth, time, etc.)
        timeout: Maximum time to wait for a response (seconds)
        cfg: Optional config for engine creation

    Returns:
        Best move, or None if the engine failed
    """
    if board.is_game_over():
        return None

    for attempt in range(2):
        try:
            engine = StockfishManager.get_engine(engine_name, cfg)
            lock = StockfishManager.get_lock(engine_name)
            with lock:
                result = run_with_timeout(engine.play, timeout, board, limit)
            return result.move
        except chess.engine.EngineTerminatedError:
            logger.error(f"Stockfish engine '{engine_name}' terminated unexpectedly, recreating...")
            StockfishManager.close(engine_name)
            if attempt == 1:
                return None
        except FuturesTimeoutError:
            logger.warning(f"Stockfish play timed out after {timeout}s for engine '{engine_name}'")
            return None
        except Exception as e:
            logger.error(f"Stockfish play error: {e}")
            return None
    return None


class StockfishPlayer(ChessPlayer):
    '''
    A chess player that uses the Stockfish engine to select moves.
    '''

    DEFUALT_PLAYER_ENGINE_NAME = "player_engine"

    def __init__(self, cfg: StockfishConfig, engine_name: Optional[str] = None):
        if engine_name is None:
            engine_name = self.DEFUALT_PLAYER_ENGINE_NAME
        self.engine_name = engine_name
        self.cfg = cfg
        self.engine = StockfishManager.get_engine(self.engine_name, cfg)

    def close(self):
        try:
            StockfishManager.close(self.engine_name)
        except Exception:
            logger.warning("Failed to close Stockfish engine in StockfishPlayer")

    def act(self, board: chess.Board) -> chess.Move | None:
        limit = chess.engine.Limit(time=self.cfg.movetime_ms / 1000.0)
        return stockfish_play(self.engine_name, board, limit, cfg=self.cfg)
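A minimal sketch of a one-off analysis through the centralized manager (illustrative, not part of the diff; the engine name "scratch_engine" is made up here, and a Stockfish binary is assumed at the default path):

# Illustrative sketch (not part of the diff): depth-limited analysis.
import chess
import chess.engine
from src.grpo_self_play.chess.stockfish import stockfish_analyse, StockfishManager

info = stockfish_analyse("scratch_engine", chess.Board(), chess.engine.Limit(depth=8))
if info is not None:
    print(info["score"].pov(chess.WHITE))  # e.g. a small plus for the start position
StockfishManager.close_all()  # engines persist per name until closed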
hf_space_repo/configs/__init__.py
ADDED
@@ -0,0 +1,43 @@
"""
Config module for GRPO Chess experiments.

Provides YAML-based configuration loading with override support.

Usage:
    from src.grpo_self_play.configs import load_experiment_config

    # Load default config
    config = load_experiment_config()

    # Load with overrides
    config = load_experiment_config("default.yaml", overrides={
        "grpo": {"lr": 1e-4},
        "training": {"num_epochs": 100},
    })
"""

from src.grpo_self_play.configs.config_loader import (
    ExperimentConfig,
    TrainingConfig,
    load_experiment_config,
    load_grpo_config,
    load_transformer_config,
    load_eval_config,
    load_stockfish_config,
    load_dataset_config,
    list_available_configs,
    print_config_summary,
)

__all__ = [
    "ExperimentConfig",
    "TrainingConfig",
    "load_experiment_config",
    "load_grpo_config",
    "load_transformer_config",
    "load_eval_config",
    "load_stockfish_config",
    "load_dataset_config",
    "list_available_configs",
    "print_config_summary",
]
hf_space_repo/configs/config_loader.py
ADDED
@@ -0,0 +1,290 @@
"""
Config loader for GRPO Chess experiments.

This module provides utilities to load experiment configurations from YAML files
and convert them to the appropriate dataclass objects.

Usage:
    from src.grpo_self_play.configs.config_loader import load_experiment_config

    # Load a complete experiment config
    config = load_experiment_config("default.yaml")

    # Load with overrides
    config = load_experiment_config("default.yaml", overrides={
        "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
        "training": {"num_epochs": 100},
    })

    # Access configs
    grpo_config = config.grpo
    transformer_config = config.transformer
"""

from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Optional, TypeVar, Type
import yaml

# Import all config dataclasses
from src.grpo_self_play.grpo_logic.model import GRPOConfig
from src.grpo_self_play.models import ChessTransformerConfig
from src.grpo_self_play.eval_utils import EvalConfig
from src.grpo_self_play.chess.stockfish import StockfishConfig
from src.grpo_self_play.chess.policy_player import PolicyConfig
from src.grpo_self_play.chess.searcher import SearchConfig
from src.grpo_self_play.chess.boards_dataset import ChessDatasetConfig
from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig


# Directory containing config YAML files
CONFIGS_DIR = Path(__file__).parent


@dataclass
class TrainingConfig:
    """Training loop configuration."""
    num_epochs: int = 400
    batch_size: int = 32
    steps_per_epoch: int = 512
    checkpoint_every_n_epochs: int = 5
    keep_n_checkpoints: int = 3


@dataclass
class ExperimentConfig:
    """Complete experiment configuration containing all sub-configs."""
    training: TrainingConfig
    grpo: GRPOConfig
    transformer: ChessTransformerConfig
    eval: EvalConfig
    stockfish: StockfishConfig
    policy: PolicyConfig
    searcher: Optional[SearchConfig]
    dataset: ChessDatasetConfig
    pretrain: PretrainLoadConfig


T = TypeVar('T')


def _deep_merge(base: dict, overrides: dict) -> dict:
    """Deep-merge two dictionaries, with overrides taking precedence.

    Args:
        base: Base dictionary
        overrides: Dictionary with values to override

    Returns:
        Merged dictionary
    """
    result = base.copy()
    for key, value in overrides.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = _deep_merge(result[key], value)
        else:
            result[key] = value
    return result


def dict_to_dataclass(cls: Type[T], data: dict[str, Any]) -> T:
    """Convert a dictionary to a dataclass, ignoring extra keys.

    Args:
        cls: The dataclass type to instantiate
        data: Dictionary with field values

    Returns:
        Instance of the dataclass with values from data
    """
    if data is None:
        return None

    # Get valid field names for this dataclass
    valid_fields = {f.name for f in fields(cls)}

    # Filter to only include valid fields
    filtered_data = {k: v for k, v in data.items() if k in valid_fields}

    return cls(**filtered_data)


def load_yaml_file(path: str | Path) -> dict[str, Any]:
    """Load a YAML config file.

    Args:
        path: Path to the YAML file (absolute or relative to the configs dir)

    Returns:
        Dictionary containing the parsed YAML
    """
    path = Path(path)

    # If not absolute, look in the configs directory
    if not path.is_absolute():
        path = CONFIGS_DIR / path

    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    with open(path, 'r') as f:
        return yaml.safe_load(f)


def load_experiment_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, dict[str, Any]] | None = None
) -> ExperimentConfig:
    """Load a complete experiment configuration from a YAML file.

    Args:
        path: Path to the YAML file (absolute or relative to the configs dir)
        overrides: Optional dict of overrides per section. Example:
            {
                "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
                "training": {"num_epochs": 100},
                "stockfish": {"skill_level": 5},
            }

    Returns:
        ExperimentConfig containing all sub-configs
    """
    data = load_yaml_file(path)

    # Apply overrides if provided
    if overrides:
        data = _deep_merge(data, overrides)

    # Convert each section to its dataclass
    training = dict_to_dataclass(TrainingConfig, data.get('training', {}))
    grpo = dict_to_dataclass(GRPOConfig, data.get('grpo', {}))
    transformer = dict_to_dataclass(ChessTransformerConfig, data.get('transformer', {}))
    eval_cfg = dict_to_dataclass(EvalConfig, data.get('eval', {}))
    stockfish = dict_to_dataclass(StockfishConfig, data.get('stockfish', {}))
    policy = dict_to_dataclass(PolicyConfig, data.get('policy', {}))
    dataset = dict_to_dataclass(ChessDatasetConfig, data.get('dataset', {}))
    pretrain = dict_to_dataclass(PretrainLoadConfig, data.get('pretrain', {}))

    # Searcher is optional (can be null)
    searcher_data = data.get('searcher')
    searcher = dict_to_dataclass(SearchConfig, searcher_data) if searcher_data else None

    return ExperimentConfig(
        training=training,
        grpo=grpo,
        transformer=transformer,
        eval=eval_cfg,
        stockfish=stockfish,
        policy=policy,
        searcher=searcher,
        dataset=dataset,
        pretrain=pretrain,
    )


def load_grpo_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, Any] | None = None
) -> GRPOConfig:
    """Load just the GRPO config from a YAML file.

    Args:
        path: Path to the YAML file
        overrides: Optional dict of field overrides. Example: {"lr": 1e-4}
    """
    data = load_yaml_file(path)
    grpo_data = data.get('grpo', {})
    if overrides:
        grpo_data = _deep_merge(grpo_data, overrides)
    return dict_to_dataclass(GRPOConfig, grpo_data)


def load_transformer_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, Any] | None = None
) -> ChessTransformerConfig:
    """Load just the transformer config from a YAML file."""
    data = load_yaml_file(path)
    cfg_data = data.get('transformer', {})
    if overrides:
        cfg_data = _deep_merge(cfg_data, overrides)
    return dict_to_dataclass(ChessTransformerConfig, cfg_data)


def load_eval_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, Any] | None = None
) -> EvalConfig:
    """Load just the eval config from a YAML file."""
    data = load_yaml_file(path)
    cfg_data = data.get('eval', {})
    if overrides:
        cfg_data = _deep_merge(cfg_data, overrides)
    return dict_to_dataclass(EvalConfig, cfg_data)


def load_stockfish_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, Any] | None = None
) -> StockfishConfig:
    """Load just the Stockfish config from a YAML file."""
    data = load_yaml_file(path)
    cfg_data = data.get('stockfish', {})
    if overrides:
        cfg_data = _deep_merge(cfg_data, overrides)
    return dict_to_dataclass(StockfishConfig, cfg_data)


def load_dataset_config(
    path: str | Path = "default.yaml",
    overrides: dict[str, Any] | None = None
) -> ChessDatasetConfig:
    """Load just the dataset config from a YAML file."""
    data = load_yaml_file(path)
    cfg_data = data.get('dataset', {})
    if overrides:
        cfg_data = _deep_merge(cfg_data, overrides)
    return dict_to_dataclass(ChessDatasetConfig, cfg_data)


def list_available_configs() -> list[str]:
    """List all available YAML config files in the configs directory."""
    return [f.name for f in CONFIGS_DIR.glob("*.yaml")]


def print_config_summary(config: ExperimentConfig) -> None:
    """Print a summary of the experiment configuration."""
    print("=" * 60)
    print("EXPERIMENT CONFIGURATION")
    print("=" * 60)

    print("\n[Training]")
    print(f"  epochs: {config.training.num_epochs}")
    print(f"  batch_size: {config.training.batch_size}")
    print(f"  steps_per_epoch: {config.training.steps_per_epoch}")

    print("\n[GRPO]")
    print(f"  lr: {config.grpo.lr}")
    print(f"  num_trajectories: {config.grpo.num_trajectories}")
    print(f"  trajectory_depth: {config.grpo.trajectory_depth}")
    print(f"  entropy_coef: {config.grpo.entropy_coef}")
    print(f"  rollout_temperature: {config.grpo.rollout_temperature}")
    print(f"  adaptive_kl: {config.grpo.adaptive_kl}")
    print(f"  use_entropy_floor: {config.grpo.use_entropy_floor}")

    print("\n[Transformer]")
    print(f"  embed_dim: {config.transformer.embed_dim}")
    print(f"  num_layers: {config.transformer.num_layers}")
    print(f"  num_heads: {config.transformer.num_heads}")

    print("\n[Eval]")
    print(f"  games: {config.eval.games}")
    print(f"  max_plies: {config.eval.max_plies}")

    print("\n[Stockfish]")
    print(f"  skill_level: {config.stockfish.skill_level}")

    print("\n[Searcher]")
    print(f"  enabled: {config.searcher is not None}")

    print("=" * 60)
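A minimal sketch of loading and inspecting a run config with this module (illustrative, not part of the diff):

# Illustrative sketch (not part of the diff): load and inspect a config.
from src.grpo_self_play.configs import load_experiment_config, print_config_summary

config = load_experiment_config("default.yaml", overrides={
    "grpo": {"lr": 1e-5},
    "training": {"num_epochs": 50},
})
print_config_summary(config)   # prints the [Training]/[GRPO]/... sections shown above
assert config.grpo.lr == 1e-5  # overrides are deep-merged over the YAML values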
hf_space_repo/configs/default.yaml
ADDED
@@ -0,0 +1,123 @@
# Default experiment configuration
# This file contains all hyperparameters for a training run.
# Copy this file and modify for new experiments.

# =============================================================================
# Training Loop Settings
# =============================================================================
training:
  num_epochs: 400
  batch_size: 32
  steps_per_epoch: 512
  checkpoint_every_n_epochs: 5  # Save a periodic checkpoint every N epochs for crash recovery
  keep_n_checkpoints: 3         # Keep the last N periodic checkpoints per run

# =============================================================================
# GRPO (Group Relative Policy Optimization) Config
# Clean run config (see research_docs/2026-02-06_loss-budget-and-monitor-analysis.md)
# =============================================================================
grpo:
  lr: 0.000001  # 1e-6: reduced because the PPO signal now dominates the gradient
  num_trajectories: 16
  trajectory_depth: 16
  clip_ratio: 0.20
  kl_coef: 0.001     # reduced from 0.01 (was being overridden to 0.1 by adaptive KL)
  entropy_coef: 0.0  # removed: not part of the original GRPO loss, was 95% of the gradient
  eval_every_n_epochs: 10
  ppo_steps: 1
  rollout_temperature: 1.3

  # Entropy floor monitoring — disabled (never triggered, see research doc)
  use_entropy_floor: false
  entropy_floor: 1.5
  entropy_floor_steps: 150
  entropy_floor_action: "boost"
  entropy_boost_factor: 1.5

  # Adaptive KL controller — disabled (saturated at max instantly, see research doc)
  adaptive_kl: false
  target_kl: 0.012
  kl_adapt_rate: 1.2
  kl_coef_min: 0.001
  kl_coef_max: 0.1

  # Safety checks
  enable_safety_checks: false
  safety_patience_steps: 1000
  max_clip_fraction: 0.95
  min_entropy: 0.5
  max_kl_divergence: 0.08

  # Teacher forcing: use Stockfish for rival moves during trajectory sampling
  teacher_forcing_prob: 0.1  # 10% of rival moves will come from Stockfish
  teacher_forcing_depth: 4

# =============================================================================
# Transformer Model Config
# =============================================================================
transformer:
  vocab_size: 300
  embed_dim: 256
  num_layers: 4
  num_heads: 8
  action_dim: 1968

# =============================================================================
# Evaluation Config (vs Stockfish)
# =============================================================================
eval:
  games: 64
  seed: 0
  max_plies: 400
  randomize_opening: true
  opening_plies: 6

# =============================================================================
# Stockfish Config
# =============================================================================
stockfish:
  path: "/usr/games/stockfish"  # Override in Colab/local as needed
  skill_level: 2
  use_elo_limit: false
  elo: 2500
  movetime_ms: 50
  threads: 1
  hash_mb: 128

# =============================================================================
# Policy Player Config (for evaluation)
# =============================================================================
policy:
  temperature: 0.8
  greedy: true
  branching_factor: 4
  search_depth: 2

# =============================================================================
# Searcher Config (optional - set to null to disable)
# =============================================================================
searcher: null
# searcher:
#   n_trajectories: 4
#   trajectory_depth: 8

# =============================================================================
# Pretraining (optional - load pretrained weights before GRPO)
# =============================================================================
pretrain:
  checkpoint_path: null  # Path to a pretrained checkpoint (e.g., "checkpoints/pretrain/pretrain_final.pt")
  freeze_layers: 2       # Freeze the first 2 transformer layers to preserve learned representations

# =============================================================================
# Dataset Config (Chess Start States)
# =============================================================================
dataset:
  max_steps: 512  # Should match steps_per_epoch
  phase_distribution:
    opening: 0.33
    middlegame: 0.34
    endgame: 0.33
  min_eval_cp: -200
  max_eval_cp: 200
  quality_filter: true
  stockfish_filter_depth: 4
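Each top-level section of this file maps to one dataclass; a single section can also be loaded on its own (illustrative, not part of the diff):

# Illustrative sketch (not part of the diff): load one section of this file.
from src.grpo_self_play.configs import load_stockfish_config

sf_cfg = load_stockfish_config("default.yaml", overrides={"skill_level": 5})
# -> StockfishConfig(path="/usr/games/stockfish", skill_level=5, movetime_ms=50, ...)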
hf_space_repo/configs/pretrain.yaml
ADDED
@@ -0,0 +1,49 @@
# Pretraining configuration for the chess model
# This file contains hyperparameters for supervised pretraining on Lichess games.
#
# Usage:
#   python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml

# =============================================================================
# Pretraining Settings
# =============================================================================
pretrain:
  lr: 0.0001               # Learning rate (higher than GRPO fine-tuning)
  batch_size: 4096         # Batch size for pretraining
  num_epochs: 22           # Number of passes through the dataset
  warmup_steps: 1000       # Linear warmup steps
  weight_decay: 0.01       # AdamW weight decay
  max_grad_norm: 1.0       # Gradient clipping
  checkpoint_dir: "checkpoints/pretrain"
  resume_from: null        # Path to resume from (optional)
  use_wandb: true
  wandb_project: "chess-grpo-pretrain"
  label_smoothing: 0.1     # Prevents overconfidence
  num_workers: 4           # DataLoader workers
  val_check_interval: 0.1  # Validate every 10% of an epoch

# =============================================================================
# Dataset Settings (Lichess games from HuggingFace)
# =============================================================================
dataset:
  min_elo: 1800                 # Minimum player rating to include
  max_samples: 5000000          # Max samples per epoch (null = unlimited)
  skip_first_n_moves: 5         # Skip opening moves (book territory)
  skip_last_n_moves: 5          # Skip endgame/resignation moves
  sample_positions_per_game: 3  # Positions to sample from each game
  buffer_size: 10000            # Shuffle buffer size for streaming
  filter_abandoned: true        # Skip abandoned games
  dataset_name: "Lichess/standard-chess-games"
  split: "train"                # Dataset split to use
  is_eval: false                # False for training, True for evaluation
  eval_fraction: 0.05           # 5% of games held out for evaluation

# =============================================================================
# Transformer Model Config (should match GRPO training)
# =============================================================================
transformer:
  vocab_size: 300
  embed_dim: 256
  num_layers: 4
  num_heads: 8
  action_dim: 1968
hf_space_repo/constants.py
ADDED
@@ -0,0 +1,15 @@
"""Constants used across the GRPO self-play module."""

# Sequence length for tokenized FEN strings
SEQUENCE_LENGTH = 77

# Default training hyperparameters
DEFAULT_LEARNING_RATE = 1e-4
DEFAULT_NUM_TRAJECTORIES = 4
DEFAULT_TRAJECTORY_DEPTH = 5
DEFAULT_CLIP_RATIO = 0.2
DEFAULT_KL_COEF = 0.01

# Default evaluation settings
DEFAULT_EVAL_GAMES = 50
DEFAULT_EVAL_MAX_PLIES = 400
hf_space_repo/eval_utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities for evaluating chess policies against Stockfish."""
|
| 2 |
+
import io
|
| 3 |
+
import math
|
| 4 |
+
import chess
|
| 5 |
+
import chess.pgn
|
| 6 |
+
import chess.engine
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Dict, List, Tuple
|
| 13 |
+
|
| 14 |
+
from src.grpo_self_play.chess.chess_logic import MOVE_TO_ACTION
|
| 15 |
+
from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig
|
| 16 |
+
from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
|
| 17 |
+
from src.grpo_self_play.chess.stockfish import StockfishPlayer, StockfishConfig, DEFAULT_STOCKFISH_PATH as STOCKFISH_PATH
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class EvalConfig:
|
| 22 |
+
games: int = 50
|
| 23 |
+
    seed: int = 0
    max_plies: int = 400  # safety to avoid extremely long games
    randomize_opening: bool = False
    opening_plies: int = 6  # random legal moves to diversify early positions


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([EvalConfig])


def debug_legal_coverage(board: chess.Board) -> tuple[int, int, list[str]]:
    """Debug function to check coverage of legal moves in the action space.

    Args:
        board: Chess board position

    Returns:
        Tuple of (covered_count, total_legal_moves, list_of_missing_moves)
    """
    legals = list(board.legal_moves)
    covered = 0
    missing = []
    for mv in legals:
        u = mv.uci()
        if u in MOVE_TO_ACTION:
            covered += 1
        else:
            missing.append(u)
    return covered, len(legals), missing[:10]


def play_one_game(
    policy: PolicyPlayer | TrajectorySearcher,
    stockfish: StockfishPlayer,
    policy_is_white: bool,
    cfg: EvalConfig,
    game_number: int = 0,
) -> Tuple[str, str, str]:
    """Play a single game between the policy and Stockfish.

    Args:
        policy: Policy player to evaluate
        stockfish: Stockfish player
        policy_is_white: Whether the policy plays as white
        cfg: Evaluation configuration
        game_number: Game number for PGN metadata

    Returns:
        Tuple of (result_str, termination_reason, pgn_str);
        result_str is one of {"1-0", "0-1", "1/2-1/2"}
    """
    board = chess.Board()
    game = chess.pgn.Game()
    game.headers["Event"] = "Policy vs Stockfish Evaluation"
    game.headers["White"] = "Policy" if policy_is_white else "Stockfish"
    game.headers["Black"] = "Stockfish" if policy_is_white else "Policy"
    game.headers["Round"] = str(game_number + 1)
    node = game

    # Optional random opening to reduce overfitting to a single line
    if cfg.randomize_opening and cfg.opening_plies > 0:
        for _ in range(cfg.opening_plies):
            if board.is_game_over():
                break
            move = random.choice(list(board.legal_moves))
            board.push(move)
            node = node.add_variation(move)

    for ply in range(cfg.max_plies):
        if board.is_game_over(claim_draw=True):
            break

        is_white_to_move = board.turn
        policy_turn = (is_white_to_move and policy_is_white) or ((not is_white_to_move) and (not policy_is_white))

        if policy_turn:
            move = policy.act(board)
        else:
            move = stockfish.act(board)
        if move is None:
            break  # no legal moves

        board.push(move)
        node = node.add_variation(move)

    # Determine the result
    if board.is_game_over(claim_draw=True):
        res = board.result(claim_draw=True)
        reason = "game_over"
    else:
        # Reached max plies: treat as a draw
        res = "1/2-1/2"
        reason = "max_plies"

    game.headers["Result"] = res

    # Generate the PGN string
    pgn_output = io.StringIO()
    exporter = chess.pgn.FileExporter(pgn_output)
    game.accept(exporter)
    pgn_str = pgn_output.getvalue()

    return res, reason, pgn_str


def estimate_elo_diff(score: float) -> float:
    """Estimate the Elo difference from a match score.

    Uses the logistic model S = 1/(1 + 10^(-d/400)), so d = -400 * log10(1/S - 1).
    The score is clamped for numeric stability.

    Args:
        score: Win-rate score in [0, 1]

    Returns:
        Estimated Elo difference
    """
    eps = 1e-6
    s = min(max(score, eps), 1 - eps)
    return -400.0 * math.log10(1.0 / s - 1.0)


def evaluate_policy_vs_stockfish(
    policy: PolicyPlayer | TrajectorySearcher,
    sf: StockfishPlayer,
    eval_cfg: EvalConfig,
) -> Tuple[Dict, PolicyPlayer | TrajectorySearcher, List[str]]:
    """Evaluate a policy by playing multiple games against Stockfish.

    Args:
        policy: Policy player to evaluate
        sf: Stockfish player
        eval_cfg: Evaluation configuration

    Returns:
        Tuple of (results_dict, policy_player, pgns);
        results_dict contains games, wins, draws, losses, score, elo_diff, etc.,
        and pgns is a list of PGN strings for all games played.
    """
    random.seed(eval_cfg.seed)
    torch.manual_seed(eval_cfg.seed)

    wins = draws = losses = 0
    term_reasons = {}
    pgns: List[str] = []

    try:
        for g in range(eval_cfg.games):
            policy_is_white = (g % 2 == 0)
            res, reason, pgn = play_one_game(policy, sf, policy_is_white, eval_cfg, game_number=g)
            term_reasons[reason] = term_reasons.get(reason, 0) + 1
            pgns.append(pgn)

            # From the policy's perspective
            if res == "1-0":
                if policy_is_white:
                    wins += 1
                else:
                    losses += 1
            elif res == "0-1":
                if policy_is_white:
                    losses += 1
                else:
                    wins += 1
            else:
                draws += 1

    finally:
        sf.close()

    total = wins + draws + losses
    score = (wins + 0.5 * draws) / total if total else 0.0
    elo_diff = estimate_elo_diff(score) if total else 0.0

    return {
        "games": total,
        "wins": wins,
        "draws": draws,
        "losses": losses,
        "score": score,
        "elo_diff_vs_stockfish_approx": elo_diff,
        "termination_reasons": term_reasons,
        "eval_cfg": eval_cfg,
    }, policy, pgns
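As a quick sanity check of the logistic model in estimate_elo_diff, the formula can be inverted by hand. A minimal sketch, with made-up scores (not results from this repo):

import math

# Same formula as estimate_elo_diff: d = -400 * log10(1/S - 1)
for s in (0.25, 0.50, 0.75):
    d = -400.0 * math.log10(1.0 / s - 1.0)
    print(f"score={s:.2f} -> elo_diff={d:+.0f}")
# score=0.25 -> elo_diff=-191
# score=0.50 -> elo_diff=+0
# score=0.75 -> elo_diff=+191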
hf_space_repo/evaluator.py
ADDED
@@ -0,0 +1,118 @@
from typing import Dict, List, Optional, Tuple
from chess import engine
import torch.nn as nn

from src.grpo_self_play.chess.policy_player import PolicyPlayer, PolicyConfig
from src.grpo_self_play.chess.searcher import TrajectorySearcher, SearchConfig
from src.grpo_self_play.chess.stockfish import StockfishPlayer, StockfishConfig, StockfishManager
from src.grpo_self_play.eval_utils import EvalConfig, evaluate_policy_vs_stockfish


class Evaluator:
    """Evaluate a chess model by playing against Stockfish.

    Handles evaluation of chess policies against Stockfish at various skill levels.
    Supports both single evaluations and skill-ladder evaluations.
    """

    def __init__(self,
                 eval_cfg: EvalConfig = EvalConfig(),
                 policy_cfg: PolicyConfig = PolicyConfig(),
                 searcher_cfg: Optional[SearchConfig] = None,
                 stockfish_cfg: StockfishConfig = StockfishConfig()):
        """
        Initialize the evaluator.

        Args:
            eval_cfg: Evaluation configuration (number of games, etc.)
            policy_cfg: Policy player configuration
            searcher_cfg: Optional search configuration for tree search
            stockfish_cfg: Stockfish engine configuration
        """
        self.eval_cfg = eval_cfg
        self.policy_cfg = policy_cfg
        self.searcher_cfg = searcher_cfg
        self.default_stockfish_cfg = stockfish_cfg

    def _make_policy(self, model: nn.Module) -> PolicyPlayer | TrajectorySearcher:
        """Create a policy player (optionally wrapped with search).

        Args:
            model: Neural network model

        Returns:
            Policy player, optionally wrapped with trajectory search
        """
        policy = PolicyPlayer(model, cfg=self.policy_cfg)
        if self.searcher_cfg is not None:
            policy = TrajectorySearcher(policy, cfg=self.searcher_cfg)
        return policy

    def _make_stockfish(self) -> StockfishPlayer:
        """Create a Stockfish player with the default configuration.

        Returns:
            Stockfish player instance
        """
        return StockfishPlayer(self.default_stockfish_cfg)

    def single_evaluation(self, model: nn.Module) -> Tuple[Dict, PolicyPlayer | TrajectorySearcher, List[str]]:
        """Evaluate the model by playing games against Stockfish.

        Args:
            model: Neural network model to evaluate

        Returns:
            Tuple of (results_dict, policy_or_searcher, pgns);
            pgns is a list of PGN strings for all games played
        """
        stockfish_player = self._make_stockfish()
        policy = self._make_policy(model)
        results, policy_or_searcher, pgns = evaluate_policy_vs_stockfish(
            policy,
            stockfish_player,
            self.eval_cfg,
        )
        return results, policy_or_searcher, pgns

    def eval_ladder(self, model: nn.Module) -> Dict[int, float]:
        """Evaluate the model against Stockfish at multiple skill levels.

        Args:
            model: Neural network model to evaluate

        Returns:
            Dictionary mapping skill level to win-rate score
        """
        policy = self._make_policy(model)
        results = {}
        skill_levels = [1, 3, 5, 8, 10]
        for skill in skill_levels:
            stockfish_cfg = StockfishConfig(
                path=self.default_stockfish_cfg.path,
                skill_level=skill,
                movetime_ms=self.default_stockfish_cfg.movetime_ms,
            )
            engine_name = f"stockfish_skill_{skill}"
            stockfish_player = StockfishPlayer(stockfish_cfg, engine_name=engine_name)

            try:
                r, policy_wrapper, _ = evaluate_policy_vs_stockfish(
                    policy,
                    stockfish_player,
                    self.eval_cfg,
                )
                results[skill] = r["score"]
                print(f"Skill {skill}: {r}")
                if hasattr(policy_wrapper, 'stats'):
                    print(f'Policy stats: {policy_wrapper.stats}')
            except Exception as e:
                print(f"Error evaluating at skill {skill}: {e}")
                results[skill] = 0.0
            finally:
                StockfishManager.close(engine_name)  # Close engine to free resources
        return results
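The intended call pattern for the Evaluator class above is roughly the following sketch. my_model is a hypothetical nn.Module; games and path are the EvalConfig/StockfishConfig fields the code above already reads:

evaluator = Evaluator(
    eval_cfg=EvalConfig(games=10),
    stockfish_cfg=StockfishConfig(path="/usr/bin/stockfish"),  # hypothetical engine path
)
ladder_scores = evaluator.eval_ladder(my_model)  # e.g. {1: 0.80, 3: 0.55, ...}
for skill, score in sorted(ladder_scores.items()):
    print(f"skill {skill}: score {score:.2f}")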
hf_space_repo/grpo_logic/__init__.py
ADDED
File without changes
hf_space_repo/grpo_logic/loss.py
ADDED
@@ -0,0 +1,235 @@
import torch
from typing import Tuple
from dataclasses import dataclass


@dataclass
class GRPOLossInfo:
    """Information about GRPO loss components for logging and debugging."""
    kl_div: torch.Tensor
    mean_ratio: torch.Tensor
    mean_clip_fraction: torch.Tensor
    ppo_loss: torch.Tensor
    entropy: torch.Tensor
    loss_without_entropy: torch.Tensor


def grpo_chess_loss(
        logprobs_new: torch.Tensor,  # [G, T] log πθ(a_{g,k,t} | s_{g,k,t})
        logprobs_old: torch.Tensor,  # [G, T] log πold(a_{g,k,t} | s_{g,k,t})
        advantages: torch.Tensor,    # [G, T]
        clip_eps: float = 0.2,       # ε in the formula
        beta_kl: float = 0.0,        # β in the formula (0 = no explicit KL penalty)
        eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the GRPO chess loss (legacy function; consider using grpo_ppo_loss instead).

    Args:
        logprobs_new: New policy log probabilities [G, T]
        logprobs_old: Old policy log probabilities [G, T]
        advantages: Advantage values [G, T]
        clip_eps: PPO clipping epsilon
        beta_kl: KL penalty coefficient
        eps: Numerical stability epsilon

    Returns:
        Tuple of (loss, approximate_kl_divergence)
    """
    # ------------------------------------------------------------
    # Probability ratio r_{g,k,t}(θ)
    #
    # r_{g,k,t}(θ) = πθ(a_{g,k,t}|s_{g,k,t}) / πold(a_{g,k,t}|s_{g,k,t})
    #             = exp( logπθ - logπold )
    # ------------------------------------------------------------
    ratio = (logprobs_new - logprobs_old).exp()  # [G, T]
    pg_unclipped = -advantages * ratio  # [G, T]
    pg_clipped = -advantages * ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps)  # [G, T]

    # Surrogate policy gradient loss (PPO-clip part).
    # This corresponds to the -E[min(...)] in the formula.
    policy_loss = torch.max(pg_unclipped, pg_clipped).mean()
    approx_kl = (logprobs_old - logprobs_new).mean()

    # KL penalty: β * E[ KL(...) ]
    kl_loss = beta_kl * approx_kl
    loss = policy_loss + kl_loss

    return loss, approx_kl


# Utility functions for GRPO
def group_advantage(group_rewards: torch.Tensor) -> torch.Tensor:
    """
    Compute normalized advantages from group rewards using standardization.

    Args:
        group_rewards: Group rewards tensor [B, G] or [G]

    Returns:
        Normalized advantages with the same shape as the input
    """
    mean_reward = group_rewards.mean(dim=-1, keepdim=True)
    std_reward = group_rewards.std(dim=-1, unbiased=False, keepdim=True) + 1e-8
    advantages = (group_rewards - mean_reward) / std_reward
    return advantages


def step_group_advantage(step_rewards: torch.Tensor, pad_mask: torch.Tensor | None = None) -> torch.Tensor:
    """
    Compute per-step normalized advantages from step rewards.
    For each timestep t, normalizes across the G dimension (trajectories).

    NOTE: No std normalization is applied here, following the Dr. GRPO paper.

    Args:
        step_rewards: Per-step rewards tensor [B, G, T]
        pad_mask: Optional mask for valid steps [B, G, T], True=valid

    Returns:
        Normalized advantages [B, G, T] where each timestep is normalized across G
    """
    # Normalize across the G dimension for each (batch, timestep)
    # step_rewards: [B, G, T]
    mean_t = step_rewards.mean(dim=1, keepdim=True)  # [B, 1, T]
    advantages = step_rewards - mean_t  # [B, G, T]

    if pad_mask is not None:
        advantages = advantages * pad_mask.float()

    return advantages


def ppo_chess_loss(
        logprobs_new: torch.Tensor,  # [G, T] log πθ(a_{g,k,t} | s_{g,k,t})
        logprobs_old: torch.Tensor,  # [G, T] log πold(a_{g,k,t} | s_{g,k,t})
        advantages: torch.Tensor,    # [G, T]
        clip_eps: float = 0.2,       # ε in the formula
        pad_mask: torch.Tensor | None = None,  # [G, T], True = real, False = pad
        return_info: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the PPO-clip loss for chess policy optimization.

    Args:
        logprobs_new: New policy log probabilities [B, G, T] or [G, T]
        logprobs_old: Old policy log probabilities [B, G, T] or [G, T]
        advantages: Advantage values [B, G, T] or [G, T]
        clip_eps: PPO clipping epsilon (default: 0.2)
        pad_mask: Mask indicating valid steps, True=valid, False=padding
        return_info: If True, return additional statistics

    Returns:
        If return_info=False: policy loss tensor [B, G, T] or [G, T]
        If return_info=True: tuple of (policy_loss, mean_ratio, mean_clip_fraction)
    """
    if pad_mask is None:
        pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)
    ratio = (logprobs_new - logprobs_old).exp()  # [G, T]
    pg_unclipped = -advantages * ratio  # [G, T]
    pg_clipped = -advantages * ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps)  # [G, T]
    # Surrogate policy gradient loss (PPO-clip part).
    # This corresponds to the -E[min(...)] in the formula.
    policy_loss = torch.max(pg_unclipped, pg_clipped) * pad_mask.float()
    if return_info:
        valid_steps = pad_mask.sum().clamp_min(1.0)
        mean_padded_ratio = (ratio * pad_mask.float()).sum() / valid_steps
        clip_fraction_mask = (ratio > (1.0 + clip_eps)) | (ratio < (1.0 - clip_eps))
        mean_clip_fraction = (clip_fraction_mask.float() * pad_mask.float()).sum() / valid_steps
        return policy_loss, mean_padded_ratio, mean_clip_fraction  # [G, T], scalar, scalar
    return policy_loss  # [G, T]


def kl_penalty(logprobs_new: torch.Tensor,
               logprobs_old: torch.Tensor,
               pad_mask: torch.Tensor | None = None) -> torch.Tensor:
    """
    Compute the KL divergence penalty between the old and new policies.

    Args:
        logprobs_new: New policy log probabilities
        logprobs_old: Old policy log probabilities
        pad_mask: Optional mask for valid steps

    Returns:
        Mean KL divergence over valid steps
    """
    if pad_mask is None:
        pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)
    return (logprobs_old - logprobs_new)[pad_mask].mean()


def grpo_ppo_loss(
        logprobs_new: torch.Tensor,   # [B, G, T] or [G, T]
        logprobs_old: torch.Tensor,   # [B, G, T] or [G, T]
        step_rewards: torch.Tensor,   # [B, G, T] or [G, T] - per-step rewards
        pad_mask: torch.Tensor | None = None,  # [B, G, T] or [G, T]
        clip_ratio: float = 0.2,      # PPO clipping ratio (epsilon in the paper)
        kl_coef: float = 0.01,        # KL penalty coefficient (beta in the paper)
        entropy_coef: float = 0.1,    # Entropy bonus coefficient (prevents policy collapse)
        return_info: bool = False,    # Return extra info for logging
) -> torch.Tensor | Tuple[torch.Tensor, GRPOLossInfo]:
    """
    Compute the GRPO (Group Relative Policy Optimization) loss with PPO clipping.

    This combines the PPO-clip loss with a KL divergence penalty and an optional
    entropy bonus. Advantages are computed per step by normalizing step rewards
    across trajectories (the G dimension) for each timestep.

    Args:
        logprobs_new: New policy log probabilities [B, G, T] or [G, T]
        logprobs_old: Old policy log probabilities [B, G, T] or [G, T]
        step_rewards: Per-step rewards [B, G, T] or [G, T]
        pad_mask: Mask indicating valid steps, True=valid, False=padding
        clip_ratio: PPO clipping ratio (default: 0.2)
        kl_coef: KL divergence penalty coefficient (default: 0.01)
        entropy_coef: Entropy bonus coefficient (default: 0.1; set >0 to encourage exploration)
        return_info: If True, return GRPOLossInfo for logging

    Returns:
        If return_info=False: scalar loss tensor
        If return_info=True: tuple of (loss, GRPOLossInfo)
    """
    # Handle 2-D input (no batch dimension) by adding a batch dimension
    if logprobs_new.ndim == 2:
        logprobs_new = logprobs_new.unsqueeze(0)
        logprobs_old = logprobs_old.unsqueeze(0)
        step_rewards = step_rewards.unsqueeze(0)
        if pad_mask is not None:
            pad_mask = pad_mask.unsqueeze(0)

    if pad_mask is None:
        pad_mask = torch.ones_like(logprobs_new, dtype=torch.bool)

    # Compute per-step advantages (normalized across G for each timestep)
    advantages = step_group_advantage(step_rewards, pad_mask).detach()  # [B, G, T]

    ppo_loss, mean_ratio, mean_clip_fraction = ppo_chess_loss(logprobs_new,
                                                              logprobs_old,
                                                              advantages,
                                                              clip_eps=clip_ratio,
                                                              pad_mask=pad_mask,
                                                              return_info=True)
    valid_steps = pad_mask.sum().clamp_min(1)
    ppo_loss = ppo_loss.sum() / valid_steps
    kl_div = kl_penalty(logprobs_new, logprobs_old, pad_mask)

    # Entropy bonus: H(π) ≈ -E[log π(a|s)] encourages exploration.
    # We use the negative log-probs of the selected actions as an estimate.
    entropy = -logprobs_new[pad_mask].mean()

    # Loss components:
    # - loss_without_entropy = PPO loss + KL penalty
    # - total loss = loss_without_entropy - entropy bonus
    loss_without_entropy = ppo_loss + kl_coef * kl_div
    loss = loss_without_entropy - entropy_coef * entropy

    if return_info:
        return loss, GRPOLossInfo(
            kl_div=kl_div.detach(),
            mean_ratio=mean_ratio.detach(),
            mean_clip_fraction=mean_clip_fraction.detach(),
            ppo_loss=ppo_loss.detach(),
            entropy=entropy.detach(),
            loss_without_entropy=loss_without_entropy.detach(),
        )
    return loss
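A minimal smoke test of grpo_ppo_loss above on toy tensors; the shapes follow the documented [G, T] convention and the numbers are random, purely illustrative:

import torch

torch.manual_seed(0)
G, T = 4, 5
logp_old = -torch.rand(G, T)                    # fake log-probs from the frozen policy
logp_new = logp_old + 0.01 * torch.randn(G, T)  # slightly perturbed current policy
rewards = torch.randn(G, T)                     # fake per-step rewards
mask = torch.ones(G, T, dtype=torch.bool)

loss, info = grpo_ppo_loss(logp_new, logp_old, rewards, mask, return_info=True)
print(loss.item(), info.kl_div.item(), info.mean_clip_fraction.item())

# step_group_advantage centers rewards across G at every timestep:
adv = step_group_advantage(rewards.unsqueeze(0))  # [1, G, T]
print(adv.mean(dim=1))                            # ~0 for every t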
hf_space_repo/grpo_logic/model.py
ADDED
@@ -0,0 +1,782 @@
from typing import Optional
import torch
import pytorch_lightning as pl
import chess

from dataclasses import dataclass

from src.grpo_self_play.evaluator import Evaluator
from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
from src.grpo_self_play.grpo_logic.loss import grpo_ppo_loss
from src.grpo_self_play.grpo_logic.sampling import sample_trajectories_batched
from src.grpo_self_play.eval_utils import EvalConfig
from src.grpo_self_play.chess.policy_player import PolicyConfig
from src.grpo_self_play.chess.searcher import SearchConfig
from src.grpo_self_play.chess.stockfish import StockfishConfig
from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig


class EntropyFloorMonitor:
    """Monitors entropy and takes action when it falls below a floor (Recommendation 1).

    Tracks consecutive steps where entropy is below a threshold and triggers
    configurable actions (warn, stop, or boost entropy_coef) when the threshold
    is breached for too long.
    """

    def __init__(self, floor: float, steps_threshold: int, action: str, boost_factor: float):
        """
        Args:
            floor: Minimum entropy threshold
            steps_threshold: Consecutive steps below the floor before action
            action: Action to take ("warn", "stop", "boost")
            boost_factor: Factor to multiply entropy_coef by when boosting
        """
        self.floor = floor
        self.steps_threshold = steps_threshold
        self.action = action
        self.boost_factor = boost_factor
        self.consecutive_low_steps = 0
        self.triggered = False

    def check(self, entropy: float, current_entropy_coef: float) -> tuple[float, dict]:
        """Check entropy and return the updated entropy_coef and metrics.

        Args:
            entropy: Current entropy value
            current_entropy_coef: Current entropy coefficient

        Returns:
            Tuple of (new_entropy_coef, metrics_dict)
        """
        metrics = {}
        new_entropy_coef = current_entropy_coef

        if entropy < self.floor:
            self.consecutive_low_steps += 1

            if self.consecutive_low_steps >= self.steps_threshold and not self.triggered:
                self.triggered = True
                if self.action == "warn":
                    print(f"WARNING: Entropy collapse detected! Entropy={entropy:.4f} < floor={self.floor} "
                          f"for {self.consecutive_low_steps} consecutive steps.")
                elif self.action == "stop":
                    raise RuntimeError(
                        f"STOPPING: Entropy collapse detected! Entropy={entropy:.4f} < floor={self.floor} "
                        f"for {self.consecutive_low_steps} consecutive steps.")
                elif self.action == "boost":
                    new_entropy_coef = current_entropy_coef * self.boost_factor
                    print(f"BOOSTING entropy_coef: {current_entropy_coef:.4f} -> {new_entropy_coef:.4f} "
                          f"(entropy={entropy:.4f} < floor={self.floor})")
                    self.consecutive_low_steps = 0
                    self.triggered = False
        else:
            self.consecutive_low_steps = 0
            self.triggered = False

        metrics["entropy_floor/consecutive_low_steps"] = self.consecutive_low_steps
        metrics["entropy_floor/below_floor"] = float(entropy < self.floor)
        metrics["entropy_floor/current_entropy_coef"] = new_entropy_coef

        return new_entropy_coef, metrics

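# Illustrative usage sketch (commentary, not part of this diff): how the
# monitor is meant to be driven from a training loop, with hypothetical values:
#
#     monitor = EntropyFloorMonitor(floor=1.5, steps_threshold=200,
#                                   action="boost", boost_factor=2.0)
#     entropy_coef = 0.01
#     for step_entropy in entropies_from_training:   # hypothetical stream of readings
#         entropy_coef, metrics = monitor.check(step_entropy, entropy_coef)
#     # After steps_threshold consecutive sub-floor readings, action="boost"
#     # multiplies entropy_coef by boost_factor.
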
def compute_group_collapse_metrics(
    actions: torch.Tensor,
    group_rewards: torch.Tensor,
    step_rewards: torch.Tensor,
    pad_mask: torch.Tensor,
) -> dict:
    """Compute within-board group collapse metrics (Recommendation 4).

    These metrics directly measure whether all G trajectories from the same board
    are converging to the same moves, which is the key failure mode in entropy collapse.

    Args:
        actions: Action indices [B, G, T]
        group_rewards: Final rewards for each trajectory [B, G]
        step_rewards: Per-step rewards [B, G, T]
        pad_mask: Mask indicating valid steps [B, G, T], True=valid

    Returns:
        Dictionary of metrics for logging
    """
    B, _, T = actions.shape
    metrics = {}

    # 1. Action agreement: for each (b, t), what fraction of trajectories chose the most common action?
    # agreement[b,t] = max_count(actions[b,:,t]) / G
    action_agreement = torch.zeros(B, T, device=actions.device)
    for b in range(B):
        for t in range(T):
            if pad_mask[b, :, t].any():  # At least one valid trajectory at this timestep
                valid_actions = actions[b, pad_mask[b, :, t], t]
                if len(valid_actions) > 0:
                    # Count occurrences of each action
                    _, counts = valid_actions.unique(return_counts=True)
                    max_count = counts.max().item()
                    num_valid = pad_mask[b, :, t].sum().item()
                    action_agreement[b, t] = max_count / num_valid

    # Mask to only consider valid (b, t) pairs
    valid_bt_mask = pad_mask.any(dim=1)  # [B, T] - True if any trajectory is valid at (b, t)
    valid_agreements = action_agreement[valid_bt_mask]

    if len(valid_agreements) > 0:
        metrics["group_collapse/action_agreement_mean"] = valid_agreements.mean().item()
        metrics["group_collapse/action_agreement_p90"] = valid_agreements.quantile(0.9).item()
        metrics["group_collapse/action_agreement_max"] = valid_agreements.max().item()
    else:
        metrics["group_collapse/action_agreement_mean"] = 0.0
        metrics["group_collapse/action_agreement_p90"] = 0.0
        metrics["group_collapse/action_agreement_max"] = 0.0

    # 2. Within-board reward diversity: std(group_rewards[b,:]) for each board b.
    # This measures whether trajectories from the same starting position get similar rewards.
    reward_std_within = group_rewards.std(dim=1)  # [B]
    metrics["group_collapse/reward_std_within_mean"] = reward_std_within.mean().item()
    metrics["group_collapse/reward_std_within_min"] = reward_std_within.min().item()

    # 3. Within-board step reward diversity: std(step_rewards[b,:,t]) for each (b, t).
    # Only computed for valid (b, t) pairs.
    step_reward_std_within = torch.zeros(B, T, device=step_rewards.device)
    for b in range(B):
        for t in range(T):
            valid_mask_bt = pad_mask[b, :, t]
            if valid_mask_bt.sum() > 1:  # Need at least 2 valid trajectories for std
                step_reward_std_within[b, t] = step_rewards[b, valid_mask_bt, t].std().item()

    valid_step_stds = step_reward_std_within[valid_bt_mask]
    if len(valid_step_stds) > 0:
        metrics["group_collapse/step_reward_std_within_mean"] = valid_step_stds.mean().item()
        metrics["group_collapse/step_reward_std_within_min"] = valid_step_stds.min().item()
    else:
        metrics["group_collapse/step_reward_std_within_mean"] = 0.0
        metrics["group_collapse/step_reward_std_within_min"] = 0.0

    return metrics

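# Vectorized alternative sketch (commentary, not part of this diff): when no
# padding is present, the action-agreement loop above reduces to tensor ops:
#
#     mode_vals = actions.mode(dim=1, keepdim=True).values    # [B, 1, T] modal action per (b, t)
#     agreement = (actions == mode_vals).float().mean(dim=1)  # [B, T] fraction choosing the mode
#     agreement.mean()                                        # ~ action_agreement_mean
#
# With padding, the masked loop above is the safe (if slower) reference.
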
class AdaptiveKLController:
    """Adapts the KL coefficient to maintain a target KL divergence (Recommendation 2).

    Implements a simple multiplicative controller that increases kl_coef when
    the KL divergence exceeds the target and decreases it when below the target.
    """

    def __init__(self, initial_kl_coef: float, target_kl: float, adapt_rate: float,
                 kl_coef_min: float, kl_coef_max: float):
        """
        Args:
            initial_kl_coef: Starting KL coefficient
            target_kl: Target KL divergence value
            adapt_rate: Multiplicative factor for adjustment
            kl_coef_min: Minimum allowed kl_coef
            kl_coef_max: Maximum allowed kl_coef
        """
        self.current_kl_coef = initial_kl_coef
        self.target_kl = target_kl
        self.adapt_rate = adapt_rate
        self.kl_coef_min = kl_coef_min
        self.kl_coef_max = kl_coef_max

    def update(self, kl_div: float) -> dict:
        """Update the KL coefficient based on the current KL divergence.

        Args:
            kl_div: Current KL divergence value

        Returns:
            Metrics dict for logging
        """
        if kl_div > self.target_kl:
            self.current_kl_coef = min(self.current_kl_coef * self.adapt_rate, self.kl_coef_max)
        else:
            self.current_kl_coef = max(self.current_kl_coef / self.adapt_rate, self.kl_coef_min)

        return {
            "adaptive_kl/current_kl_coef": self.current_kl_coef,
            "adaptive_kl/target_kl": self.target_kl,
            "adaptive_kl/kl_ratio": kl_div / self.target_kl if self.target_kl > 0 else 0.0,
        }

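# Illustrative controller dynamics (commentary, not part of this diff),
# with hypothetical KL readings:
#
#     ctl = AdaptiveKLController(initial_kl_coef=0.01, target_kl=0.015,
#                                adapt_rate=1.2, kl_coef_min=0.003, kl_coef_max=0.05)
#     for kl in (0.03, 0.03, 0.005):     # two high readings, then one low
#         ctl.update(kl)
#     ctl.current_kl_coef                # 0.01 * 1.2 * 1.2 / 1.2 = 0.012
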
@dataclass
class GRPOConfig:
    """Configuration for GRPO (Group Relative Policy Optimization) training.

    Attributes:
        lr: Learning rate for the optimizer
        num_trajectories: Number of trajectory groups to sample per batch
        trajectory_depth: Maximum depth of each trajectory
        clip_ratio: PPO clipping ratio (epsilon)
        kl_coef: KL divergence penalty coefficient (beta)
        entropy_coef: Entropy bonus coefficient (encourages exploration, prevents policy collapse)
        eval_every_n_epochs: Frequency of evaluation runs (not used in the model, but useful for the trainer)

        # Entropy floor monitoring (Recommendation 1)
        use_entropy_floor: Whether to enable entropy floor monitoring
        entropy_floor: Minimum entropy threshold for collapse detection
        entropy_floor_steps: Number of consecutive steps below the floor before alert/action
        entropy_floor_action: Action to take when the entropy floor is breached ("warn", "stop", "boost")
        entropy_boost_factor: Factor to multiply entropy_coef by when boosting (if action="boost")

        # Adaptive KL controller (Recommendation 2)
        adaptive_kl: Whether to use an adaptive KL coefficient
        target_kl: Target KL divergence value
        kl_adapt_rate: Rate at which to adjust kl_coef (higher = faster adaptation)
        kl_coef_min: Minimum allowed kl_coef
        kl_coef_max: Maximum allowed kl_coef

        # PPO-style multiple updates
        ppo_steps: Number of optimization steps per sampled trajectory batch (reuses samples)

        # Rollout temperature for exploration
        rollout_temperature: Temperature for action sampling during rollouts (>1 increases exploration)

        # Safety checks on training dynamics
        enable_safety_checks: Whether to abort training when known-bad patterns persist
        safety_patience_steps: Number of training steps to tolerate violations before aborting
        max_clip_fraction: If mean_clip_fraction > this for too long -> abort
        min_entropy: If entropy < this for too long -> abort
        max_kl_divergence: If KL >> target_kl for too long -> abort
    """
    # Clean run defaults (see research_docs/2026-02-06_loss-budget-and-monitor-analysis.md)
    lr: float = 1e-6  # Reduced: PPO signal now dominates the gradient
    num_trajectories: int = 4
    trajectory_depth: int = 5
    clip_ratio: float = 0.2
    kl_coef: float = 0.001  # Reduced from 0.01 (was overridden to 0.1 by adaptive KL)
    entropy_coef: float = 0.0  # Removed: not in the original GRPO loss, was 95% of the gradient
    eval_every_n_epochs: int = 10

    # Entropy floor monitoring — disabled by default (never triggered in practice)
    use_entropy_floor: bool = False
    entropy_floor: float = 1.5
    entropy_floor_steps: int = 200
    entropy_floor_action: str = "boost"
    entropy_boost_factor: float = 2.0

    # Adaptive KL controller — disabled by default (saturated at max instantly)
    adaptive_kl: bool = False
    target_kl: float = 0.015
    kl_adapt_rate: float = 1.2
    kl_coef_min: float = 0.003
    kl_coef_max: float = 0.05

    # PPO-style multiple updates per sample
    ppo_steps: int = 1

    # Rollout temperature for exploration (>1 flattens the distribution, increases entropy)
    rollout_temperature: float = 1.0

    # Safety checks on training dynamics
    enable_safety_checks: bool = False
    safety_patience_steps: int = 1000  # Number of training steps to tolerate violations
    # Thresholds derived from prior research docs
    max_clip_fraction: float = 0.95  # If mean_clip_fraction > this for too long -> abort
    min_entropy: float = 0.5  # If entropy < this for too long -> abort
    max_kl_divergence: float = 0.08  # If KL >> target_kl for too long -> abort

    # Teacher forcing: use Stockfish for rival moves during trajectory sampling
    teacher_forcing_prob: float = 0.0  # Probability of using Stockfish for rival (opponent) moves
    teacher_forcing_depth: int = 4  # Stockfish search depth for teacher-forcing moves


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([GRPOConfig])

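# Commentary (not part of this diff): the defaults above encode the "clean run";
# a hypothetical experiment overrides fields per instance:
#
#     cfg = GRPOConfig(lr=5e-6, num_trajectories=8, rollout_temperature=1.2)
#     cfg.kl_coef, cfg.entropy_coef      # (0.001, 0.0) -- untouched defaults
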
class GRPOChessTransformer(pl.LightningModule):
    """PyTorch Lightning module for training a chess policy with GRPO.

    This module implements Group Relative Policy Optimization (GRPO) for training
    a chess transformer policy. It maintains both a current policy and an old policy
    for computing importance-sampling ratios in the PPO loss.

    Attributes:
        policy_model: Current policy model being trained
        old_policy_model: Frozen copy of the policy for importance sampling
        evaluator: Evaluator for running games against Stockfish
        eval_every_n_epochs: Frequency of evaluation runs
        entropy_monitor: Optional entropy floor monitor (Recommendation 1)
        kl_controller: Optional adaptive KL controller (Recommendation 2)
        current_entropy_coef: Current entropy coefficient (mutable for entropy boosting)
        automatic_optimization: Set to False for manual PPO steps
    """
    automatic_optimization = False  # Manual optimization for ppo_steps

    def __init__(self,
                 transformer_config: ChessTransformerConfig,
                 grpo_config: GRPOConfig,
                 eval_cfg: EvalConfig | None = None,
                 stockfish_cfg: StockfishConfig | None = None,
                 policy_cfg: PolicyConfig | None = None,
                 searcher_cfg: SearchConfig | None = None,
                 pretrain_cfg: PretrainLoadConfig | None = None):
        """
        Initialize the GRPO Chess Transformer.

        Args:
            transformer_config: Configuration for the chess transformer model
            grpo_config: GRPO training configuration
            eval_cfg: Optional evaluation configuration
            stockfish_cfg: Optional Stockfish configuration for evaluation
            policy_cfg: Optional policy player configuration
            searcher_cfg: Optional search configuration
            pretrain_cfg: Optional pretrain config for loading pretrained weights
        """
        super().__init__()
        self.save_hyperparameters()
        self.policy_model = ChessTransformer(transformer_config)
        self.old_policy_model = ChessTransformer(transformer_config)

        # Load pretrained weights if specified
        if pretrain_cfg and pretrain_cfg.checkpoint_path:
            self._load_pretrained_weights(pretrain_cfg)

        self._sync_old_policy()

        # Evaluation config
        self.eval_every_n_epochs = grpo_config.eval_every_n_epochs
        self.evaluator = Evaluator(eval_cfg=eval_cfg or EvalConfig(),
                                   policy_cfg=policy_cfg or PolicyConfig(),
                                   stockfish_cfg=stockfish_cfg or StockfishConfig(),
                                   searcher_cfg=searcher_cfg)

        # Entropy floor monitor (Recommendation 1) - optional
        self.entropy_monitor: EntropyFloorMonitor | None = None
        if grpo_config.use_entropy_floor:
            self.entropy_monitor = EntropyFloorMonitor(
                floor=grpo_config.entropy_floor,
                steps_threshold=grpo_config.entropy_floor_steps,
                action=grpo_config.entropy_floor_action,
                boost_factor=grpo_config.entropy_boost_factor,
            )
        self.current_entropy_coef = grpo_config.entropy_coef

        # Adaptive KL controller (Recommendation 2) - optional
        self.kl_controller: AdaptiveKLController | None = None
        if grpo_config.adaptive_kl:
            self.kl_controller = AdaptiveKLController(
                initial_kl_coef=grpo_config.kl_coef,
                target_kl=grpo_config.target_kl,
                adapt_rate=grpo_config.kl_adapt_rate,
                kl_coef_min=grpo_config.kl_coef_min,
                kl_coef_max=grpo_config.kl_coef_max,
            )

        # Safety-check state (for tracking persistent violations)
        self._safety_step_idx: int = 0
        self._high_clip_steps: int = 0
        self._low_entropy_steps: int = 0
        self._high_kl_steps: int = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the current policy model.

        Args:
            x: Input tensor [batch, seq_len]

        Returns:
            Policy logits [batch, action_dim]
        """
        return self.policy_model(x)

    def _old_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the old (frozen) policy model.

        Args:
            x: Input tensor [batch, seq_len]

        Returns:
            Policy logits [batch, action_dim]
        """
        return self.old_policy_model(x)

    def _sync_old_policy(self) -> None:
        """Synchronize the old policy model with the current policy and freeze it."""
        self.old_policy_model.load_state_dict(self.policy_model.state_dict())
        # Freeze old policy parameters
        for param in self.old_policy_model.parameters():
            param.requires_grad = False

    def _load_pretrained_weights(self, pretrain_cfg: PretrainLoadConfig) -> None:
        """Load pretrained weights from a checkpoint.

        Args:
            pretrain_cfg: Pretrain configuration with checkpoint path and freeze settings
        """
        checkpoint_path = pretrain_cfg.checkpoint_path
        print(f"Loading pretrained weights from: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            # Lightning checkpoint format - extract policy_model weights
            state_dict = {}
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('model.'):
                    # From PretrainChessTransformer
                    state_dict[k[6:]] = v  # Remove 'model.' prefix
                elif k.startswith('policy_model.'):
                    # From GRPOChessTransformer
                    state_dict[k[13:]] = v  # Remove 'policy_model.' prefix
        else:
            # Assume it's a raw state dict
            state_dict = checkpoint

        # Load into the policy model
        missing, unexpected = self.policy_model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys in pretrained checkpoint: {missing}")
        if unexpected:
            print(f"Warning: Unexpected keys in pretrained checkpoint: {unexpected}")

        print("Successfully loaded pretrained weights")

        # Optionally freeze transformer layers
        if pretrain_cfg.freeze_layers > 0:
            self._freeze_transformer_layers(pretrain_cfg.freeze_layers)

    def _freeze_transformer_layers(self, num_layers: int) -> None:
        """Freeze the first N transformer encoder layers.

        Args:
            num_layers: Number of layers to freeze (from the bottom)
        """
        # Freeze embedding and positional encoding
        for param in self.policy_model.embedding.parameters():
            param.requires_grad = False
        self.policy_model.pos_encoding.requires_grad = False

        # Freeze the specified number of transformer layers
        for i, layer in enumerate(self.policy_model.transformer.layers):
            if i < num_layers:
                for param in layer.parameters():
                    param.requires_grad = False
                print(f"Froze transformer layer {i}")

        # Count trainable parameters
        trainable = sum(p.numel() for p in self.policy_model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.policy_model.parameters())
        print(f"Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")

    def _log_rewards_metrics(self, batch_group_rewards: torch.Tensor, prefix: str = "train/") -> None:
        """Log reward statistics for monitoring training progress.

        Args:
            batch_group_rewards: Group rewards tensor [B, G]
            prefix: Prefix for log keys (default: "train/")
        """
        mean_r = batch_group_rewards.mean()
        best = batch_group_rewards.max()
        gap = best - mean_r

        self.log(prefix + "avg_reward", mean_r, prog_bar=True)
        self.log(prefix + "reward_std", batch_group_rewards.std())
        self.log(prefix + "reward_p50", batch_group_rewards.median())
        self.log(prefix + "reward_p90", batch_group_rewards.quantile(0.9))
        self.log(prefix + "reward_best", best)
        self.log(prefix + "reward_gap_best_minus_mean", gap)

    def on_train_epoch_start(self) -> None:
        """Called at the start of each training epoch. Syncs the old policy."""
        self._sync_old_policy()

    def _ppo_step(
        self,
        trajectories_states: torch.Tensor,
        trajectories_actions: torch.Tensor,
        trajectories_old_log_probs: torch.Tensor,
        trajectories_legal_masks: torch.Tensor | None,
        step_rewards: torch.Tensor,
        effective_pad_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, object]:
        """Perform a single PPO optimization step.

        Args:
            trajectories_states: State tensors [B, G, T, SEQ]
            trajectories_actions: Action indices [B, G, T]
            trajectories_old_log_probs: Log probs from the old policy [B, G, T]
            trajectories_legal_masks: Legal move masks [B, G, T, A] or None
            step_rewards: Per-step rewards [B, G, T]
            effective_pad_mask: Mask for valid steps [B, G, T]

        Returns:
            Tuple of (loss, loss_info)
        """
        # Compute new log probs with the current policy
        new_log_probs = self.policy_model.get_group_log_probs(
            trajectories_states, trajectories_actions, trajectories_legal_masks
        )

        # Use current (possibly adapted) coefficients
        kl_coef = self.kl_controller.current_kl_coef if self.kl_controller else self.hparams.grpo_config.kl_coef

        loss, loss_info = grpo_ppo_loss(
            new_log_probs,
            trajectories_old_log_probs,
            step_rewards,
            effective_pad_mask,
            clip_ratio=self.hparams.grpo_config.clip_ratio,
            kl_coef=kl_coef,
            entropy_coef=self.current_entropy_coef,
            return_info=True,
        )

        if not torch.isfinite(loss):
            raise ValueError(f"Non-finite loss encountered: {loss.item()}")

        return loss, loss_info

    def _run_safety_checks(self, loss_info) -> None:
        """Run safety checks on training dynamics and abort if they persistently fail."""
        cfg = self.hparams.grpo_config
        if not cfg.enable_safety_checks:
            return

        self._safety_step_idx += 1

        # 1) PPO clipping saturation
        if loss_info.mean_clip_fraction.item() > cfg.max_clip_fraction:
            self._high_clip_steps += 1
        else:
            self._high_clip_steps = 0

        # 2) Entropy collapse
        if loss_info.entropy.item() < cfg.min_entropy:
            self._low_entropy_steps += 1
        else:
            self._low_entropy_steps = 0

        # 3) Excessive KL divergence
        if loss_info.kl_div.item() > cfg.max_kl_divergence:
            self._high_kl_steps += 1
        else:
            self._high_kl_steps = 0

        # Log safety counters for debugging
        self.log("safety/high_clip_steps", float(self._high_clip_steps))
        self.log("safety/low_entropy_steps", float(self._low_entropy_steps))
        self.log("safety/high_kl_steps", float(self._high_kl_steps))

        if (
            self._high_clip_steps >= cfg.safety_patience_steps
            or self._low_entropy_steps >= cfg.safety_patience_steps
            or self._high_kl_steps >= cfg.safety_patience_steps
        ):
            raise RuntimeError(
                "Safety checks triggered: training aborted due to persistent "
                f"bad dynamics (clip={loss_info.mean_clip_fraction.item():.3f}, "
                f"entropy={loss_info.entropy.item():.3f}, "
                f"kl={loss_info.kl_div.item():.4f}). "
                "Adjust GRPOConfig or investigate recent research docs."
            )

    def training_step(self, batch_fens: list[str], batch_idx: int) -> None:
        """Perform a training step with multiple PPO optimization iterations.

        Samples trajectories once, then performs ppo_steps optimization iterations
        on the same sampled data to improve compute efficiency.

        Args:
            batch_fens: List of FEN strings representing starting positions
            batch_idx: Batch index (unused)
        """
        opt = self.optimizers()

        boards = [chess.Board(start_fen) for start_fen in batch_fens]
        boards = [board for board in boards if not board.is_game_over()]
        if not boards:
            return  # Skip if all games are over

        trajectories_sample = sample_trajectories_batched(
            self.old_policy_model,
            boards,
            self.hparams.grpo_config.num_trajectories,
            self.hparams.grpo_config.trajectory_depth,
            temperature=self.hparams.grpo_config.rollout_temperature,
            teacher_forcing_prob=self.hparams.grpo_config.teacher_forcing_prob,
            teacher_forcing_depth=self.hparams.grpo_config.teacher_forcing_depth,
        )
        if trajectories_sample is None:
            return  # Skip if no moves

        # Extract trajectory components (sampled once, reused for ppo_steps)
        trajectories_old_log_probs = trajectories_sample.trajectories_log_probs  # [B, G, T]
        trajectories_actions = trajectories_sample.trajectories_actions  # [B, G, T]
        trajectories_states = trajectories_sample.trajectories_states  # [B, G, T, SEQ]
        batch_group_rewards = trajectories_sample.group_rewards  # [B, G] (for logging)
        step_rewards = trajectories_sample.step_rewards  # [B, G, T]
        pad_mask = trajectories_sample.pad_mask  # [B, G, T]
        trajectories_legal_masks = trajectories_sample.trajectories_legal_masks  # [B, G, T, A] or None

        # Add a starting-player mask (only consider moves from the starting player's perspective)
        _, _, T = pad_mask.shape
        t = torch.arange(T, device=pad_mask.device)
        start_player_mask = (t % 2 == 0)[None, None, :]  # [1, 1, T]
        effective_pad_mask = pad_mask & start_player_mask  # [B, G, T]

        ppo_steps = self.hparams.grpo_config.ppo_steps

        # Perform multiple PPO optimization steps on the same sampled trajectories
        for ppo_step_idx in range(ppo_steps):
            loss, loss_info = self._ppo_step(
                trajectories_states,
                trajectories_actions,
                trajectories_old_log_probs,
                trajectories_legal_masks,
                step_rewards,
                effective_pad_mask,
            )

            # Manual optimization step
            opt.zero_grad()
            self.manual_backward(loss)
            self.clip_gradients(opt, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
            opt.step()

            # Entropy floor monitoring (Recommendation 1) - only on the last ppo_step
            if ppo_step_idx == ppo_steps - 1 and self.entropy_monitor is not None:
                self.current_entropy_coef, entropy_metrics = self.entropy_monitor.check(
                    loss_info.entropy.item(), self.current_entropy_coef
                )
                for key, value in entropy_metrics.items():
                    self.log(key, value)

            # Adaptive KL controller (Recommendation 2) - only on the last ppo_step
            if ppo_step_idx == ppo_steps - 1 and self.kl_controller is not None:
                kl_metrics = self.kl_controller.update(loss_info.kl_div.item())
                for key, value in kl_metrics.items():
                    self.log(key, value)

        # Within-board group collapse metrics (Recommendation 4) - log once per training_step
        collapse_metrics = compute_group_collapse_metrics(
            trajectories_actions, batch_group_rewards, step_rewards, pad_mask
        )
        for key, value in collapse_metrics.items():
            self.log(key, value)

        # Standard logging (log final ppo_step metrics)
        valid_mask = pad_mask.float()  # [B, G, T]; 1 = real step

        self.log("train_total_loss", loss, prog_bar=True)
        self.log("pad_fraction", 1.0 - valid_mask.mean())
        self.log("avg_trajectory_length", pad_mask.float().sum(dim=-1).mean())

        self.log("mean_kl_divergence", loss_info.kl_div)
        self.log("mean_ratio", loss_info.mean_ratio)
        self.log("mean_clip_fraction", loss_info.mean_clip_fraction)
        self.log("ppo_loss", loss_info.ppo_loss)
        self.log("entropy", loss_info.entropy)
        # Loss without the entropy bonus term (PPO + KL only)
        self.log("train/loss_without_entropy", loss_info.loss_without_entropy)
        self.log("ppo_steps", float(ppo_steps))
        self._log_rewards_metrics(batch_group_rewards, prefix="train/")

        # Log step-reward statistics (only for valid steps)
        valid_step_rewards = step_rewards[pad_mask]
        self.log("train/step_reward_mean", valid_step_rewards.mean())
        self.log("train/step_reward_std", valid_step_rewards.std())

        # Log raw centipawn step rewards (before normalization) for debugging
        raw_step_cp = trajectories_sample.raw_step_cp
        valid_raw_step_cp = raw_step_cp[pad_mask]
        self.log("train/raw_step_cp_mean", valid_raw_step_cp.mean())
        self.log("train/raw_step_cp_std", valid_raw_step_cp.std())
        self.log("train/raw_step_cp_abs_mean", valid_raw_step_cp.abs().mean())

        # Run safety checks on the final loss statistics
        self._run_safety_checks(loss_info)

    def configure_optimizers(self) -> torch.optim.Adam:
        """Configure the optimizer for training.

        Returns:
            Adam optimizer with the learning rate from the GRPO config
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.grpo_config.lr)

    def _evaluate_against_stockfish(self) -> Optional[tuple[dict, list[str]]]:
        """Run a single-game evaluation against Stockfish with the current policy model.

        Returns:
            Tuple of (results_dict, pgns) or None if the evaluation failed;
            pgns is a list of PGN strings for all games played
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                results, _, pgns = self.evaluator.single_evaluation(self.policy_model)
                return results, pgns
        except Exception as e:
            self.logger.warning(f"Evaluation against Stockfish failed: {e}") if hasattr(self, 'logger') else print(f"Evaluation against Stockfish failed: {e}")
|
| 719 |
+
return None
|
| 720 |
+
finally:
|
| 721 |
+
if was_training:
|
| 722 |
+
self.train()
|
| 723 |
+
|
| 724 |
+
def _log_stockfish_eval(self, results: dict) -> None:
|
| 725 |
+
"""Log scalar evaluation metrics from the Stockfish evaluation.
|
| 726 |
+
|
| 727 |
+
Args:
|
| 728 |
+
results: Dictionary containing evaluation results with keys:
|
| 729 |
+
- games: Total number of games played
|
| 730 |
+
- wins: Number of wins
|
| 731 |
+
- draws: Number of draws
|
| 732 |
+
- losses: Number of losses
|
| 733 |
+
- score: Win rate (0-1)
|
| 734 |
+
- elo_diff_vs_stockfish_approx: Approximate Elo difference
|
| 735 |
+
- termination_reasons: Dict mapping termination reasons to counts
|
| 736 |
+
"""
|
| 737 |
+
# Scalar stats
|
| 738 |
+
self.log("eval_stockfish/games", results["games"])
|
| 739 |
+
self.log("eval_stockfish/wins", results["wins"])
|
| 740 |
+
self.log("eval_stockfish/draws", results["draws"])
|
| 741 |
+
self.log("eval_stockfish/losses", results["losses"])
|
| 742 |
+
self.log("eval_stockfish/score", results["score"], prog_bar=True)
|
| 743 |
+
self.log("eval_stockfish/elo_diff", results["elo_diff_vs_stockfish_approx"], prog_bar=True)
|
| 744 |
+
|
| 745 |
+
# Termination reasons as fractions
|
| 746 |
+
games = results["games"] or 1
|
| 747 |
+
for reason, cnt in results["termination_reasons"].items():
|
| 748 |
+
frac = cnt / games
|
| 749 |
+
self.log(f"eval_stockfish/term_{reason}", frac)
|
| 750 |
+
|
| 751 |
+
def _log_pgns(self, pgns: list[str]) -> None:
|
| 752 |
+
"""Log PGNs to WandB as a text artifact.
|
| 753 |
+
|
| 754 |
+
Args:
|
| 755 |
+
pgns: List of PGN strings for all games played
|
| 756 |
+
"""
|
| 757 |
+
if not pgns:
|
| 758 |
+
return
|
| 759 |
+
|
| 760 |
+
# Combine all PGNs into a single string
|
| 761 |
+
combined_pgn = "\n\n".join(pgns)
|
| 762 |
+
|
| 763 |
+
# Log to WandB if available
|
| 764 |
+
if self.logger and hasattr(self.logger, 'experiment'):
|
| 765 |
+
try:
|
| 766 |
+
import wandb
|
| 767 |
+
# Log as a text artifact
|
| 768 |
+
self.logger.experiment.log({
|
| 769 |
+
"eval_stockfish/pgns": wandb.Html(f"<pre>{combined_pgn}</pre>"),
|
| 770 |
+
"eval_stockfish/pgn_text": combined_pgn,
|
| 771 |
+
})
|
| 772 |
+
except Exception as e:
|
| 773 |
+
print(f"Failed to log PGNs to WandB: {e}")
|
| 774 |
+
|
| 775 |
+
def on_train_epoch_end(self) -> None:
|
| 776 |
+
"""Called at the end of each training epoch. Runs evaluation if scheduled."""
|
| 777 |
+
if (self.current_epoch + 1) % self.eval_every_n_epochs == 0:
|
| 778 |
+
eval_result = self._evaluate_against_stockfish()
|
| 779 |
+
if eval_result is not None:
|
| 780 |
+
results, pgns = eval_result
|
| 781 |
+
self._log_stockfish_eval(results)
|
| 782 |
+
self._log_pgns(pgns)
|
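A quick sanity check of the starting-player mask used in `training_step` above — a minimal, self-contained sketch needing only `torch`. Even timesteps are the starting player's own moves and stay in the loss; odd timesteps, the rival's replies, are masked out:

```python
import torch

T = 6
pad_mask = torch.ones(1, 1, T, dtype=torch.bool)  # one trajectory, all steps valid
t = torch.arange(T)
start_player_mask = (t % 2 == 0)[None, None, :]   # [1, 1, T]: True at t = 0, 2, 4
effective = pad_mask & start_player_mask
print(effective.squeeze().tolist())               # [True, False, True, False, True, False]
```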
hf_space_repo/grpo_logic/sampling.py
ADDED
@@ -0,0 +1,243 @@
import os
import random
from typing import List, Optional
import chess
import chess.engine
import torch
import torch.nn.functional as F
from dataclasses import dataclass

from src.grpo_self_play.chess.rewards import reward_board, evaluate_board, normalize_cp
from src.grpo_self_play.models import ChessTransformer
from src.grpo_self_play.searchless_chess_imports import ACTION_TO_MOVE, SEQUENCE_LENGTH, MOVE_TO_ACTION
from src.grpo_self_play.chess.chess_logic import board_to_tensor, get_legal_moves_mask
from src.grpo_self_play.chess.stockfish import stockfish_play, DEFAULT_STOCKFISH_TIMEOUT


def _get_teacher_engine_name() -> str:
    """Get process-specific engine name for teacher forcing."""
    return f"teacher_forcing_{os.getpid()}"


def get_stockfish_move(board: chess.Board, depth: int = 4, timeout: float = DEFAULT_STOCKFISH_TIMEOUT) -> Optional[chess.Move]:
    """Get the best move from Stockfish for a given board position.

    Args:
        board: Chess board position
        depth: Stockfish search depth
        timeout: Maximum time to wait for a response (seconds)

    Returns:
        Best move from Stockfish, or None if no move is available or on error
    """
    limit = chess.engine.Limit(depth=depth)
    return stockfish_play(_get_teacher_engine_name(), board, limit, timeout=timeout)


# Trajectories sampling logic
@dataclass
class TrajectoriesSample:
    """Container for batched trajectory samples.

    Attributes:
        trajectories_log_probs: Log probabilities of sampled actions [B, G, T]
        trajectories_actions: Action indices [B, G, T]
        trajectories_states: State tensors [B, G, T, SEQ]
        group_rewards: Final rewards for each trajectory group [B, G] (for logging)
        step_rewards: Per-step rewards [B, G, T] where step_rewards[b, g, t] = eval(s_{t+1}) - eval(s_t)
        pad_mask: Mask indicating valid steps, True=valid, False=padding [B, G, T]
        trajectories_legal_masks: Legal moves masks [B, G, T, A]
        raw_step_cp: Raw centipawn step rewards [B, G, T] (for logging, not normalized)
    """
    trajectories_log_probs: torch.Tensor  # [B, G, T]
    trajectories_actions: torch.Tensor  # [B, G, T]
    trajectories_states: torch.Tensor  # [B, G, T, SEQ]
    group_rewards: torch.Tensor  # [B, G]
    step_rewards: torch.Tensor  # [B, G, T]
    pad_mask: torch.Tensor  # [B, G, T]
    trajectories_legal_masks: torch.Tensor  # [B, G, T, A]
    raw_step_cp: torch.Tensor  # [B, G, T] - raw centipawn differences


def batched_policy_step(model: ChessTransformer, boards: List[chess.Board], temperature: float = 1.0) -> Optional[tuple]:
    """Sample actions from the policy for a batch of boards.

    Args:
        model: Chess transformer model
        boards: List of chess board positions
        temperature: Temperature for sampling

    Returns:
        Tuple of (action_indices, log_probs, moves, states_tensor, legal_mask) or None if empty
    """
    N = len(boards)
    if N == 0:
        return None
    device = next(model.parameters()).device
    states_list = []
    legal_masks = []
    for board in boards:
        state = board_to_tensor(board, device=device)
        states_list.append(state)
        mask = get_legal_moves_mask(board, device=device)
        if mask.ndim == 2:
            mask = mask.squeeze(0)
        assert mask.ndim == 1, f"legal_moves_mask must be 1D [A], got {mask.shape}"
        legal_masks.append(mask)

    states_tensor = torch.cat(states_list, dim=0)  # [N, SEQ]
    legal_mask = torch.stack(legal_masks, dim=0)  # [N, A] bool
    assert legal_mask.dtype == torch.bool, "legal_mask must be bool dtype"
    assert legal_mask.shape[0] == N, f"legal_mask batch size mismatch {legal_mask.shape[0]} vs {N}"
    assert legal_mask.shape[1] == model.action_size, f"legal_mask action size mismatch {legal_mask.shape[1]} vs {model.action_size}"
    if not legal_mask.any(dim=1).all():
        bad = (~legal_mask.any(dim=1)).nonzero(as_tuple=False).flatten().tolist()
        raise ValueError(f"Empty legal mask for boards: {bad}")
    probs = model.get_legal_moves_probs(states_tensor, legal_mask, temperature)  # [N, O]

    action_idx = torch.multinomial(probs, 1).squeeze(1)  # [N,]
    chosen_probs = probs.gather(1, action_idx.unsqueeze(1)).squeeze(1)  # [N,]
    chosen_log_probs = torch.log(chosen_probs + 1e-12)  # [N,], avoid log(0)

    # Convert action indices to moves, ensure legality
    moves = []
    for i, idx in enumerate(action_idx.tolist()):
        uci = ACTION_TO_MOVE[idx]
        move = chess.Move.from_uci(uci)
        if move not in boards[i].legal_moves:
            raise ValueError(f"Sampled illegal move {uci} for board:\n{boards[i]}")
        moves.append(move)
    return action_idx, chosen_log_probs, moves, states_tensor, legal_mask


def sample_trajectories_batched(model: ChessTransformer,
                                boards: List[chess.Board],
                                num_trajectories: int,
                                trajectory_depth: int,
                                reward_depth: int = 4,
                                temperature: float = 1.0,
                                teacher_forcing_prob: float = 0.0,
                                teacher_forcing_depth: int = 4) -> Optional[TrajectoriesSample]:
    """Sample multiple trajectories from each board position using the policy model.

    Args:
        model: Chess transformer model for action selection
        boards: List of starting board positions [B]
        num_trajectories: Number of trajectory groups per board (G)
        trajectory_depth: Maximum depth of each trajectory (T)
        reward_depth: Stockfish depth for reward computation (default: 4)
        temperature: Temperature for action sampling (default: 1.0, >1 increases exploration)
        teacher_forcing_prob: Probability of using Stockfish for rival moves (default: 0.0)
        teacher_forcing_depth: Stockfish depth for teacher forcing moves (default: 4)

    Returns:
        TrajectoriesSample containing batched trajectory data, or None if no boards
    """
    device = next(model.parameters()).device
    B, G, T = len(boards), num_trajectories, trajectory_depth
    if B == 0:
        return None

    # Create B*G copies of boards for parallel trajectory sampling
    envs = [boards[b].copy() for b in range(B) for _ in range(G)]  # Length B*G
    # Per (b, g) storage as nested lists
    traj_log_probs = [[[] for _ in range(G)] for _ in range(B)]
    traj_actions = [[[] for _ in range(G)] for _ in range(B)]
    traj_states = [[[] for _ in range(G)] for _ in range(B)]
    traj_legal_masks = [[[] for _ in range(G)] for _ in range(B)]
    traj_step_rewards = [[[] for _ in range(G)] for _ in range(B)]
    traj_raw_step_cp = [[[] for _ in range(G)] for _ in range(B)]  # Raw centipawn differences for logging

    # Track POV and previous raw eval for each trajectory (we normalize step rewards later)
    pov_is_white = [(boards[b].turn == chess.WHITE) for b in range(B) for _ in range(G)]
    prev_evals_raw = [evaluate_board(boards[b], pov_is_white[b * G], depth=reward_depth, normalize=False)
                      for b in range(B) for _ in range(G)]

    # Rollout: sample trajectories in batches
    for t in range(T):
        active_env_idx = [i for i, e in enumerate(envs) if not e.is_game_over()]
        if not active_env_idx:
            break

        # Determine if this is the rival's turn (odd timesteps)
        is_rival_turn = (t % 2 == 1)
        use_teacher_forcing = is_rival_turn and teacher_forcing_prob > 0 and random.random() < teacher_forcing_prob

        active_boards = [envs[i] for i in active_env_idx]
        roll_out_step = batched_policy_step(model, active_boards, temperature=temperature)
        if roll_out_step is None:
            break

        action_indices, log_probs, moves, states_batch, legal_mask = roll_out_step
        if action_indices is None:
            break

        for j, env_idx_j in enumerate(active_env_idx):
            move_j = moves[j]
            if move_j is None:
                continue  # End of game for this env
            b_idx = env_idx_j // G
            g_idx = env_idx_j % G
            state_j = states_batch[j]

            # Teacher forcing: override rival's move with Stockfish
            if use_teacher_forcing:
                sf_move = get_stockfish_move(envs[env_idx_j], depth=teacher_forcing_depth)
                if sf_move is not None and sf_move in envs[env_idx_j].legal_moves:
                    move_j = sf_move
                    # Update action index to match the Stockfish move
                    action_indices[j] = MOVE_TO_ACTION[move_j.uci()]

            traj_log_probs[b_idx][g_idx].append(log_probs[j])
            traj_actions[b_idx][g_idx].append(int(action_indices[j].item()))
            traj_states[b_idx][g_idx].append(state_j)
            traj_legal_masks[b_idx][g_idx].append(legal_mask[j])
            envs[env_idx_j].push(move_j)

            # Compute step reward: eval(new_state) - eval(prev_state)
            # Get raw centipawn value, then normalize for step_rewards
            new_eval_raw = evaluate_board(envs[env_idx_j], pov_is_white[env_idx_j], depth=reward_depth, normalize=False)
            raw_step_cp = new_eval_raw - prev_evals_raw[env_idx_j]
            step_reward = normalize_cp(new_eval_raw) - normalize_cp(prev_evals_raw[env_idx_j])
            traj_step_rewards[b_idx][g_idx].append(step_reward)
            traj_raw_step_cp[b_idx][g_idx].append(raw_step_cp)
            prev_evals_raw[env_idx_j] = new_eval_raw

    # Compute group_rewards for logging (sum of step rewards = final - initial)
    group_rewards = torch.zeros(B, G, dtype=torch.float32, device=device)
    for env_idx, env in enumerate(envs):
        b_idx = env_idx // G
        g_idx = env_idx % G
        group_rewards[b_idx, g_idx] = reward_board(env, boards[b_idx], depth=reward_depth, movetime_ms=0)

    # Allocate padded tensors
    trajectories_log_probs = torch.zeros(B, G, T, dtype=torch.float32, device=device)
    trajectories_actions = torch.zeros(B, G, T, dtype=torch.long, device=device)
    trajectories_states = torch.zeros(B, G, T, SEQUENCE_LENGTH, dtype=torch.long, device=device)
    trajectories_legal_masks = torch.zeros(B, G, T, model.action_size, dtype=torch.bool, device=device)
    trajectories_legal_masks[..., 0] = True  # Ensure at least one legal move (to avoid empty legal masks -> NaNs in log_softmax)
    step_rewards = torch.zeros(B, G, T, dtype=torch.float32, device=device)
    raw_step_cp = torch.zeros(B, G, T, dtype=torch.float32, device=device)
    pad_mask = torch.zeros(B, G, T, dtype=torch.bool, device=device)
    for b in range(B):
        for g in range(G):
            L = len(traj_log_probs[b][g])
            assert L <= T, f"Trajectory length {L} exceeds pad_length {T}"
            pad_mask[b, g, :L] = True
            if L > 0:  # torch.stack raises on an empty list, so guard all fills
                trajectories_log_probs[b, g, :L] = torch.stack(traj_log_probs[b][g], dim=0)
                trajectories_actions[b, g, :L] = torch.tensor(traj_actions[b][g], dtype=torch.long, device=device)
                trajectories_states[b, g, :L] = torch.stack(traj_states[b][g], dim=0)
                trajectories_legal_masks[b, g, :L] = torch.stack(traj_legal_masks[b][g], dim=0)
                step_rewards[b, g, :L] = torch.tensor(traj_step_rewards[b][g], dtype=torch.float32, device=device)
                raw_step_cp[b, g, :L] = torch.tensor(traj_raw_step_cp[b][g], dtype=torch.float32, device=device)

    return TrajectoriesSample(trajectories_log_probs,
                              trajectories_actions,
                              trajectories_states,
                              group_rewards,
                              step_rewards,
                              pad_mask,
                              trajectories_legal_masks,
                              raw_step_cp)
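A quick check of the flat-index bookkeeping used in `sample_trajectories_batched` above — a minimal pure-Python sketch with no project imports. Each board `b` owns `G` consecutive entries in the flattened `envs` list, and `env_idx // G`, `env_idx % G` recover the `(board, group)` pair:

```python
B, G = 3, 4                                          # boards, trajectory groups per board
envs = [(b, g) for b in range(B) for g in range(G)]  # flattened copies, length B*G

for env_idx, (b, g) in enumerate(envs):
    b_idx, g_idx = env_idx // G, env_idx % G         # recover (board, group) from the flat index
    assert (b_idx, g_idx) == (b, g)
print("flat index <-> (b, g) mapping is consistent")
```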
hf_space_repo/logging_utils.py
ADDED
@@ -0,0 +1,32 @@
"""Logging utilities for GRPO training.

Uses Python's standard logging module which WandB captures automatically
in the Logs tab of a run.
"""

import logging

_initialized_loggers = set()


def get_logger(name: str = "grpo_chess") -> logging.Logger:
    """Get a logger that appears in WandB Logs tab.

    Args:
        name: Logger name (default: "grpo_chess")

    Returns:
        Configured logger instance
    """
    logger = logging.getLogger(name)

    if name not in _initialized_loggers:
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        ))
        logger.addHandler(handler)
        _initialized_loggers.add(name)

    return logger
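A minimal usage sketch (assuming the `src.grpo_self_play` package layout that the other modules in this diff import from):

```python
from src.grpo_self_play.logging_utils import get_logger

logger = get_logger()
logger.info("Starting GRPO training run")  # appears on stderr and in WandB's Logs tab
```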
hf_space_repo/models.py
ADDED
@@ -0,0 +1,234 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
from typing import Optional

from dataclasses import dataclass
from src.grpo_self_play.searchless_chess_imports import ACTION_TO_MOVE
from src.grpo_self_play.chess.chess_logic import board_to_tensor, get_legal_moves_indices


@dataclass
class ChessTransformerConfig:
    """Configuration for the Chess Transformer model.

    Attributes:
        vocab_size: Size of the vocabulary (token dictionary)
        embed_dim: Embedding dimension for the transformer
        num_layers: Number of transformer encoder layers
        num_heads: Number of attention heads
        action_dim: Dimension of the action space (number of possible moves)
    """
    vocab_size: int = 300
    embed_dim: int = 256
    num_layers: int = 4
    num_heads: int = 8
    action_dim: int = 1968


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([ChessTransformerConfig])


class ChessTransformer(nn.Module):
    """Transformer-based chess policy network.

    Takes FEN-encoded board states as input and outputs action logits.
    Uses a transformer encoder with learnable positional encodings.
    """

    def __init__(self, transformer_config: ChessTransformerConfig):
        """Initialize the Chess Transformer.

        Args:
            transformer_config: Configuration for the transformer model
        """
        super().__init__()
        vocab_size = transformer_config.vocab_size
        embed_dim = transformer_config.embed_dim
        num_layers = transformer_config.num_layers
        num_heads = transformer_config.num_heads
        action_dim = transformer_config.action_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # DeepMind uses absolute or relative positional encoding.
        # For simplicity, we use a learnable absolute encoding for FEN length (~80 chars)
        self.pos_encoding = nn.Parameter(torch.randn(1, 128, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Head outputs 1968 logits (one for each possible unique move type)
        self.policy_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, action_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the transformer.

        Args:
            x: Input tensor of token IDs [batch, seq_len]

        Returns:
            Action logits [batch, action_dim]
        """
        batch, seq = x.shape

        # Create padding mask: True indicates a masked position (padding token 0)
        src_key_padding_mask = (x == 0)
        x = self.embedding(x) + self.pos_encoding[:, :seq, :]

        # Pass the padding mask to the transformer
        out = self.transformer(x, src_key_padding_mask=src_key_padding_mask)

        # Pool: mean of the non-masked tokens
        mask = ~src_key_padding_mask
        mask_expanded = mask.unsqueeze(-1).float()  # [B, SEQ, 1]
        pooled = (out * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1).clamp_min(1)

        return self.policy_head(pooled)

    @property
    def device(self) -> torch.device:
        """Get the device of the model parameters."""
        return next(self.parameters()).device

    @property
    def action_size(self) -> int:
        """Get the size of the action space."""
        return self.policy_head[-1].out_features

    def get_legal_moves_logits(self, tensor_state: torch.Tensor,
                               legal_moves_mask: torch.Tensor,
                               temperature: float = 1.0) -> torch.Tensor:
        """Get logits for legal moves only, masking illegal moves.

        Args:
            tensor_state: Board state tensor [B, SEQ]
            legal_moves_mask: Boolean mask for legal moves [B, A]
            temperature: Temperature for scaling logits

        Returns:
            Masked logits [B, A] with illegal moves set to -inf
        """
        assert legal_moves_mask is not None, "legal_moves_mask cannot be None"
        logits = self(tensor_state) / temperature
        return logits.masked_fill(~legal_moves_mask, -float('inf'))

    def get_legal_moves_probs(self, tensor_state: torch.Tensor,
                              legal_moves_mask: torch.Tensor,
                              temperature: float = 1.0) -> torch.Tensor:
        """Get a probability distribution over legal moves.

        Args:
            tensor_state: Board state tensor [B, SEQ]
            legal_moves_mask: Boolean mask for legal moves [B, A]
            temperature: Temperature for scaling logits

        Returns:
            Probability distribution [B, A] over legal moves
        """
        mask_logits = self.get_legal_moves_logits(tensor_state, legal_moves_mask, temperature)
        return F.softmax(mask_logits, dim=-1)

    def get_group_log_probs(self,
                            trajectories_states: torch.Tensor,
                            action_idx: torch.Tensor,
                            legal_moves_mask: torch.Tensor,
                            temperature: float = 1.0) -> torch.Tensor:
        """Get log probabilities for actions in batched trajectories.

        Args:
            trajectories_states: State tensors [B, G, T, SEQ]
            action_idx: Action indices [B, G, T]
            legal_moves_mask: Legal moves mask [B, G, T, A]
            temperature: Temperature for scaling logits

        Returns:
            Log probabilities [B, G, T] for the selected actions
        """
        assert legal_moves_mask is not None, "legal_moves_mask cannot be None"
        assert legal_moves_mask.dtype == torch.bool, "legal_moves_mask must be bool dtype"
        x = trajectories_states  # [B, G, T, SEQ]
        B, G, T, L = x.shape
        x_flat = x.view(B * G * T, L)  # [B*G*T, SEQ]
        legal_moves_mask = legal_moves_mask.view(B * G * T, -1)  # [B*G*T, O]
        masked_logits = self.get_legal_moves_logits(x_flat, legal_moves_mask, temperature)  # [B*G*T, O]
        log_probs_all = F.log_softmax(masked_logits, dim=-1)  # [B*G*T, O]

        action_idx_flat = action_idx.view(B * G * T, 1)  # [B*G*T, 1]
        log_probs_flat = log_probs_all.gather(1, action_idx_flat).squeeze(-1)  # [B*G*T]
        log_probs = log_probs_flat.view(B, G, T)  # [B, G, T]
        return log_probs

    def _get_action_logits(self, board: chess.Board, temperature: float = 1.0) -> Optional[torch.Tensor]:
        """Get action logits for a single board position.

        Args:
            board: Chess board position
            temperature: Temperature for scaling logits

        Returns:
            Logits tensor [1, action_dim] or None if there are no legal moves
        """
        legal_moves = list(board.legal_moves)
        legal_indices = get_legal_moves_indices(board)

        if not legal_moves:
            return None

        # Run model
        state = board_to_tensor(board, device=self.device)
        logits = self(state)  # [1, O]

        output = torch.full_like(logits, -float('inf'))
        output[0, legal_indices] = logits[0, legal_indices] / temperature
        return output

    def select_action(self, board: chess.Board, temperature: float = 1.0) -> tuple[Optional[chess.Move], Optional[torch.Tensor], Optional[int]]:
        """Sample an action from the policy for a given board position.

        Args:
            board: Chess board position
            temperature: Temperature for sampling (higher = more random)

        Returns:
            Tuple of (move, log_prob, action_idx) or (None, None, None) if no legal moves
        """
        logits = self._get_action_logits(board, temperature)
        if logits is None:
            return None, None, None
        logits = logits.squeeze(0)  # Remove batch dimension
        probs = F.softmax(logits, dim=0)

        # Sample
        action_idx = int(torch.multinomial(probs, 1).item())
        chosen_move = ACTION_TO_MOVE[action_idx]
        log_prob = torch.log(probs[action_idx] + 1e-12)  # Avoid log(0)

        return chess.Move.from_uci(chosen_move), log_prob, action_idx


def select_action_greedy(model: ChessTransformer, board: chess.Board, temperature: float = 1.0) -> Optional[chess.Move]:
    """Select the best action greedily (no sampling).

    Args:
        model: Chess transformer model
        board: Chess board position
        temperature: Temperature for scaling logits (it does not affect the argmax)

    Returns:
        Best move or None if there are no legal moves
    """
    logits = model._get_action_logits(board, temperature)
    if logits is None:
        return None
    logits = logits.squeeze(0)  # Remove batch dimension
    probs = F.softmax(logits, dim=0)
    action_idx = int(torch.argmax(probs).item())
    chosen_move = ACTION_TO_MOVE[action_idx]
    return chess.Move.from_uci(chosen_move)
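A minimal usage sketch of the policy API above (assuming `python-chess` is installed and the project imports resolve as in this diff; the model here is randomly initialized, so the moves are legal but not strong):

```python
import chess

from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig, select_action_greedy

model = ChessTransformer(ChessTransformerConfig())  # untrained policy
board = chess.Board()

move, log_prob, action_idx = model.select_action(board, temperature=1.0)  # stochastic
greedy_move = select_action_greedy(model, board)                          # deterministic
print(move, action_idx, greedy_move)
```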
hf_space_repo/pretrain/README.md
ADDED
@@ -0,0 +1,153 @@
# Chess Model Pretraining

This module provides supervised pretraining on expert chess moves from Lichess games before GRPO reinforcement learning fine-tuning.

## Overview

The pretraining pipeline:
1. Streams chess games from HuggingFace (`Lichess/standard-chess-games`)
2. Filters by player ELO rating
3. Extracts positions and moves from games
4. Trains the ChessTransformer with cross-entropy loss on expert moves
5. Saves checkpoints compatible with GRPO training

## Quick Start

```bash
# Run pretraining with default config
python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml

# With custom parameters
python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml \
    --lr 1e-4 --batch_size 512 --min_elo 1800

# Disable wandb logging
python -m src.grpo_self_play.pretrain.pretrain --no_wandb
```

## Configuration

Configuration is in `src/grpo_self_play/configs/pretrain.yaml`:

```yaml
pretrain:
  lr: 0.0001              # Learning rate
  batch_size: 256         # Batch size
  num_epochs: 1           # Number of epochs
  warmup_steps: 1000      # Linear warmup steps
  weight_decay: 0.01      # AdamW weight decay
  max_grad_norm: 1.0      # Gradient clipping
  label_smoothing: 0.1    # Prevents overconfidence
  val_check_interval: 0.1 # Validate every 10% of epoch

dataset:
  min_elo: 1800                 # Minimum player rating
  skip_first_n_moves: 5         # Skip opening moves
  skip_last_n_moves: 5          # Skip endgame moves
  sample_positions_per_game: 3  # Positions per game
  eval_fraction: 0.05           # 5% held out for evaluation

transformer:
  embed_dim: 256
  num_layers: 4
  num_heads: 8
```

## Train/Eval Split

The dataset uses a **hash-based deterministic split** to ensure:
- No data leakage between training and evaluation
- Consistent splits across runs
- Process-safe multi-worker data loading

Games are assigned to train or eval based on:
```python
is_eval = hash(game_site_url) % 10000 < (eval_fraction * 10000)
```

This means the same game always goes to the same split, regardless of worker or epoch.
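As a sketch of the property being relied on (not the dataset's actual implementation — this uses `hashlib` for a digest that is stable across processes, whereas Python's built-in `hash()` on strings is salted per process unless `PYTHONHASHSEED` is fixed):

```python
import hashlib

def is_eval_game(game_site_url: str, eval_fraction: float = 0.05) -> bool:
    # Deterministic bucket in [0, 10000) derived from the game URL
    bucket = int(hashlib.md5(game_site_url.encode()).hexdigest(), 16) % 10000
    return bucket < int(eval_fraction * 10000)

# The same URL lands in the same split in every worker, process, and epoch
print(is_eval_game("https://lichess.org/abcd1234"))
```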
## Using Pretrained Weights in GRPO

After pretraining, use the checkpoint for GRPO fine-tuning by updating `default.yaml`:

```yaml
pretrain:
  checkpoint_path: "checkpoints/pretrain/pretrain_final.pt"
  freeze_layers: 0  # Optional: freeze first N transformer layers
```

Or pass the path when running training:
```bash
python -m src.grpo_self_play.train_self_play --config default.yaml
```

## Module Structure

```
pretrain/
├── __init__.py              # Package exports
├── pretrain.py              # PyTorch Lightning training module
├── pretrain_dataset.py      # Streaming dataset from HuggingFace
├── pretrain_load_config.py  # Config for loading pretrained weights
└── README.md                # This file
```

## Key Classes

### PretrainChessTransformer

PyTorch Lightning module that wraps the ChessTransformer for supervised learning.

```python
from src.grpo_self_play.pretrain.pretrain import PretrainChessTransformer, PretrainConfig
from src.grpo_self_play.models import ChessTransformerConfig

model = PretrainChessTransformer(
    transformer_config=ChessTransformerConfig(embed_dim=256, num_layers=4, num_heads=8),
    pretrain_config=PretrainConfig(lr=1e-4, batch_size=256),
)
```

### ChessPretrainDataset

Streaming dataset that yields (board_tokens, action, legal_mask) tuples.

```python
from src.grpo_self_play.pretrain import ChessPretrainDataset, PretrainDatasetConfig

dataset = ChessPretrainDataset(PretrainDatasetConfig(
    min_elo=1800,
    is_eval=False,  # True for evaluation set
))
```

## Metrics

The following metrics are logged during training:

| Metric | Description |
|--------|-------------|
| `train/loss` | Cross-entropy loss with label smoothing |
| `train/accuracy` | Top-1 move prediction accuracy |
| `train/top5_accuracy` | Top-5 move prediction accuracy |
| `train/entropy` | Policy entropy (confidence measure) |
| `train/perplexity` | Exponential of loss |

## Tests

Run the test suite:
```bash
pytest tests/test_pretrain_pipeline.py -v
```

Tests cover:
- Configuration dataclasses
- PGN move parsing
- Position extraction from games
- UCI to action conversion
- Collate function
- Model creation and forward pass
- Training and validation steps
- Hash-based train/eval splitting
- Integration with PyTorch Lightning
hf_space_repo/pretrain/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""Pretraining module for chess model."""

from src.grpo_self_play.pretrain.pretrain_load_config import PretrainLoadConfig
from src.grpo_self_play.pretrain.pretrain_dataset import (
    ChessPretrainDataset,
    PretrainDatasetConfig,
    collate_pretrain_batch,
)

__all__ = [
    "PretrainLoadConfig",
    "ChessPretrainDataset",
    "PretrainDatasetConfig",
    "collate_pretrain_batch",
]
hf_space_repo/pretrain/pretrain.py
ADDED
@@ -0,0 +1,579 @@
"""Pretraining script for chess model on Lichess games using PyTorch Lightning.

This script trains the ChessTransformer model using supervised learning
on expert moves from Lichess games before GRPO reinforcement learning.

Usage:
    python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml

    # Or with overrides:
    python -m src.grpo_self_play.pretrain.pretrain --config pretrain.yaml \
        --lr 1e-4 --batch_size 512 --min_elo 1800
"""

import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from torch.utils.data import DataLoader

from src.grpo_self_play.models import ChessTransformer, ChessTransformerConfig
from src.grpo_self_play.pretrain.pretrain_dataset import (
    ChessPretrainDataset,
    PretrainDatasetConfig,
    collate_pretrain_batch,
)
from src.grpo_self_play.configs.config_loader import (
    load_yaml_file,
    dict_to_dataclass,
)


@dataclass
class PretrainConfig:
    """Configuration for pretraining.

    Attributes:
        lr: Learning rate
        batch_size: Batch size for training
        num_epochs: Number of epochs to train
        warmup_steps: Number of warmup steps for learning rate
        weight_decay: Weight decay for AdamW
        max_grad_norm: Maximum gradient norm for clipping
        checkpoint_dir: Directory to save checkpoints
        resume_from: Path to checkpoint to resume from
        use_wandb: Whether to use Weights & Biases logging
        wandb_project: WandB project name
        label_smoothing: Label smoothing factor for cross-entropy
        num_workers: Number of DataLoader workers
        val_check_interval: Validation check interval (fraction of epoch or int steps)
    """
    lr: float = 1e-4
    batch_size: int = 256
    num_epochs: int = 1
    warmup_steps: int = 1000
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    checkpoint_dir: str = "checkpoints/pretrain"
    resume_from: Optional[str] = None
    use_wandb: bool = True
    wandb_project: str = "chess-grpo-pretrain"
    label_smoothing: float = 0.1
    num_workers: int = 4
    val_check_interval: float = 0.1


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([PretrainConfig])


class PretrainChessTransformer(pl.LightningModule):
    """PyTorch Lightning module for pretraining chess policy with supervised learning.

    This module implements supervised learning on expert chess moves from Lichess games.
    The pretrained model can then be fine-tuned with GRPO reinforcement learning.

    Attributes:
        model: The ChessTransformer policy model
        pretrain_config: Pretraining configuration
        transformer_config: Model architecture configuration
    """

    def __init__(
        self,
        transformer_config: ChessTransformerConfig,
        pretrain_config: PretrainConfig,
    ):
        """Initialize the pretraining module.

        Args:
            transformer_config: Configuration for the chess transformer model
            pretrain_config: Pretraining configuration
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = ChessTransformer(transformer_config)
        self.pretrain_config = pretrain_config
        self.transformer_config = transformer_config

        # For warmup scheduler
        self._num_training_steps = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the model.

        Args:
            x: Input tensor [batch, seq_len]

        Returns:
            Policy logits [batch, action_dim]
        """
        return self.model(x)

    def _compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        legal_masks: torch.Tensor,
    ) -> tuple[torch.Tensor, dict]:
        """Compute cross-entropy loss with legal move masking.

        Args:
            logits: Model output logits [B, num_actions]
            targets: Target action indices [B]
            legal_masks: Legal moves mask [B, num_actions]

        Returns:
            Tuple of (loss, metrics_dict)
        """
        # Validate shapes match
        B, action_dim = logits.shape
        if legal_masks.shape != (B, action_dim):
            raise ValueError(
                f"Shape mismatch: logits {logits.shape} vs legal_masks {legal_masks.shape}. "
                f"Expected legal_masks to be [{B}, {action_dim}]"
            )
        if targets.shape != (B,):
            raise ValueError(
                f"Shape mismatch: targets {targets.shape} vs expected [{B}]"
            )

        # Validate target actions are within bounds
        max_target = targets.max().item()
        min_target = targets.min().item()
        if max_target >= action_dim or min_target < 0:
            raise ValueError(
                f"Target action indices out of bounds: min={min_target}, max={max_target}, "
                f"action_dim={action_dim}. This suggests a mismatch between dataset action "
                f"space and model action_dim."
            )

        # Validate target actions are legal (should always be true, but check defensively)
        target_legal = legal_masks.gather(1, targets.unsqueeze(1)).squeeze(1)
        if not target_legal.all():
            illegal_count = (~target_legal).sum().item()
            illegal_indices = (~target_legal).nonzero(as_tuple=False).flatten().tolist()
            raise ValueError(
                f"Found {illegal_count} illegal target actions in batch (out of {B}). "
                f"First few batch indices: {illegal_indices[:10]}. "
                f"This should not happen - the dataset should filter these out."
            )

        # Check for NaN or Inf in raw logits (before masking)
        if not torch.isfinite(logits).all():
            nan_count = (~torch.isfinite(logits)).sum().item()
            raise ValueError(
                f"Found {nan_count} non-finite values in raw logits before masking. "
                f"This suggests the model is outputting NaN/Inf."
            )

        # Mask illegal moves to -inf
        masked_logits = logits.masked_fill(~legal_masks, float('-inf'))

        # Check that each sample has at least one legal move (before checking masked logits)
        legal_per_sample = legal_masks.sum(dim=1)
        if (legal_per_sample == 0).any():
            empty_samples = (legal_per_sample == 0).nonzero(as_tuple=False).flatten().tolist()
            raise ValueError(
                f"Found {len(empty_samples)} samples with no legal moves. "
                f"Batch indices: {empty_samples[:10]}. This should not happen."
            )

        # Check masked logits: each sample must have at least one finite logit (legal move)
        finite_per_sample = torch.isfinite(masked_logits).sum(dim=1)
        if (finite_per_sample == 0).any():
            bad_samples = (finite_per_sample == 0).nonzero(as_tuple=False).flatten().tolist()
            raise ValueError(
                f"Found {len(bad_samples)} samples with all -inf logits after masking. "
                f"Batch indices: {bad_samples[:10]}. This means no legal moves have finite logits."
            )

        # Ensure target actions are not masked (defensive check)
        target_logits = masked_logits.gather(1, targets.unsqueeze(1)).squeeze(1)
        if not torch.isfinite(target_logits).all():
            inf_count = (~torch.isfinite(target_logits)).sum().item()
            raise ValueError(
                f"Found {inf_count} target actions with -inf logits after masking. "
                f"This means target actions are being masked as illegal, which should not happen."
            )

        # Compute NLL loss (works correctly with -inf masked logits)
        nll_loss = F.cross_entropy(masked_logits, targets, reduction='mean')

        # Apply label smoothing only over legal moves to avoid inf from -inf logits.
        # Standard F.cross_entropy with label_smoothing averages log_softmax over ALL
        # actions, but -inf logits would make the smoothing term +inf.
        eps = self.pretrain_config.label_smoothing
        if eps > 0:
            # Compute log_softmax (illegal moves will be -inf)
            log_probs = F.log_softmax(masked_logits, dim=-1)
            # Zero out illegal moves so they don't contribute to the smoothing term
            log_probs_legal = log_probs.masked_fill(~legal_masks, 0.0)
            # Average only over legal moves
            num_legal = legal_masks.sum(dim=-1).float()  # [B]
            smooth_loss = -log_probs_legal.sum(dim=-1) / num_legal  # [B]
            loss = (1 - eps) * nll_loss + eps * smooth_loss.mean()
        else:
            loss = nll_loss

        # Check if loss is infinite or NaN
        if not torch.isfinite(loss):
            # Additional debugging info
            target_logits_debug = masked_logits.gather(1, targets.unsqueeze(1)).squeeze(1)
            print(f"DEBUG: Loss is {loss.item()}")
            print(f"DEBUG: NLL loss: {nll_loss.item()}")
            if eps > 0:
                print(f"DEBUG: Smooth loss mean: {smooth_loss.mean().item()}")
            print(f"DEBUG: Logits shape: {logits.shape}")
            print(f"DEBUG: Legal masks shape: {legal_masks.shape}")
            print(f"DEBUG: Targets range: [{targets.min().item()}, {targets.max().item()}]")
            print(f"DEBUG: Target logits range: [{target_logits_debug.min().item():.2f}, {target_logits_debug.max().item():.2f}]")
            print(f"DEBUG: Legal moves per sample: min={legal_per_sample.min().item()}, max={legal_per_sample.max().item()}")
            raise ValueError(
                f"Loss is {loss.item()}. This can happen if:\n"
                f"1. Target actions are out of bounds\n"
                f"2. Target actions are masked as illegal\n"
                f"3. Model outputs contain NaN/Inf\n"
                f"4. All logits are -inf (no legal moves)"
            )

        # Compute metrics
        with torch.no_grad():
            # Top-1 accuracy
            predictions = masked_logits.argmax(dim=-1)
            accuracy = (predictions == targets).float().mean()

            # Top-5 accuracy
            _, top5_preds = masked_logits.topk(5, dim=-1)
            top5_correct = (top5_preds == targets.unsqueeze(-1)).any(dim=-1)
            top5_accuracy = top5_correct.float().mean()

            # Entropy of the distribution (measure of confidence)
            probs = F.softmax(masked_logits, dim=-1)
            log_probs = F.log_softmax(masked_logits, dim=-1)
            # Handle -inf * 0 = nan by replacing with 0
            entropy_terms = probs * log_probs
            entropy_terms = torch.where(
                torch.isfinite(entropy_terms),
                entropy_terms,
                torch.zeros_like(entropy_terms)
            )
            entropy = -entropy_terms.sum(dim=-1).mean()

            # Perplexity - clamp to avoid inf
            perplexity = torch.exp(loss.clamp(max=50))

            metrics = {
                'accuracy': accuracy,
                'top5_accuracy': top5_accuracy,
                'entropy': entropy,
                'perplexity': perplexity,
            }

        return loss, metrics

    def training_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Perform a training step.

        Args:
            batch: Tuple of (boards, actions, legal_masks)
            batch_idx: Batch index

        Returns:
            Loss value
        """
        boards, actions, legal_masks = batch

        # Forward pass
        logits = self(boards)

        # Compute loss and metrics
        loss, metrics = self._compute_loss(logits, actions, legal_masks)

        # Log metrics
        self.log('train/loss', loss, prog_bar=True)
        self.log('train/accuracy', metrics['accuracy'], prog_bar=True)
        self.log('train/top5_accuracy', metrics['top5_accuracy'])
        self.log('train/entropy', metrics['entropy'])
        self.log('train/perplexity', metrics['perplexity'])

        return loss

    def validation_step(self, batch: tuple, batch_idx: int) -> torch.Tensor:
        """Perform a validation step.

        Args:
            batch: Tuple of (boards, actions, legal_masks)
            batch_idx: Batch index

        Returns:
            Loss value
        """
        boards, actions, legal_masks = batch

        # Forward pass
        logits = self(boards)

        # Compute loss and metrics
        loss, metrics = self._compute_loss(logits, actions, legal_masks)

        # Log metrics
        self.log('val/loss', loss, prog_bar=True, sync_dist=True)
        self.log('val/accuracy', metrics['accuracy'], prog_bar=True, sync_dist=True)
        self.log('val/top5_accuracy', metrics['top5_accuracy'], sync_dist=True)
        self.log('val/entropy', metrics['entropy'], sync_dist=True)
        self.log('val/perplexity', metrics['perplexity'], sync_dist=True)

        return loss

    def configure_optimizers(self):
        """Configure optimizer and learning rate scheduler.

        Returns:
            Dictionary with optimizer and lr_scheduler configuration
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.pretrain_config.lr,
            weight_decay=self.pretrain_config.weight_decay,
        )

        # Linear warmup scheduler (constant LR after warmup; cosine decay could be added)
        def lr_lambda(current_step: int) -> float:
            warmup_steps = self.pretrain_config.warmup_steps
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            return 1.0  # After warmup, use constant LR (or add cosine decay)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
|
| 361 |
+
'interval': 'step',
|
| 362 |
+
'frequency': 1,
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def get_pretrain_trainer(
|
| 368 |
+
pretrain_config: PretrainConfig,
|
| 369 |
+
run_name: str,
|
| 370 |
+
) -> pl.Trainer:
|
| 371 |
+
"""Create a PyTorch Lightning trainer for pretraining.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
pretrain_config: Pretraining configuration
|
| 375 |
+
run_name: Name for this training run
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
Configured PyTorch Lightning trainer
|
| 379 |
+
"""
|
| 380 |
+
# Create checkpoint directory
|
| 381 |
+
checkpoint_dir = Path(pretrain_config.checkpoint_dir)
|
| 382 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 383 |
+
|
| 384 |
+
callbacks = [
|
| 385 |
+
ModelCheckpoint(
|
| 386 |
+
dirpath=str(checkpoint_dir),
|
| 387 |
+
filename=run_name + "-{epoch:02d}-{train/loss:.4f}",
|
| 388 |
+
save_top_k=3,
|
| 389 |
+
monitor="train/loss",
|
| 390 |
+
mode="min",
|
| 391 |
+
save_last=True,
|
| 392 |
+
),
|
| 393 |
+
LearningRateMonitor(logging_interval='step'),
|
| 394 |
+
]
|
| 395 |
+
|
| 396 |
+
logger = None
|
| 397 |
+
if pretrain_config.use_wandb:
|
| 398 |
+
logger = WandbLogger(
|
| 399 |
+
project=pretrain_config.wandb_project,
|
| 400 |
+
name=run_name,
|
| 401 |
+
log_model=True,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
trainer = pl.Trainer(
|
| 405 |
+
max_epochs=pretrain_config.num_epochs,
|
| 406 |
+
accelerator="auto",
|
| 407 |
+
devices=1,
|
| 408 |
+
logger=logger,
|
| 409 |
+
callbacks=callbacks,
|
| 410 |
+
gradient_clip_val=pretrain_config.max_grad_norm,
|
| 411 |
+
log_every_n_steps=50,
|
| 412 |
+
val_check_interval=pretrain_config.val_check_interval,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
return trainer
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def load_pretrain_config(
|
| 419 |
+
path: str = "pretrain.yaml",
|
| 420 |
+
overrides: dict = None,
|
| 421 |
+
) -> tuple[PretrainConfig, PretrainDatasetConfig, ChessTransformerConfig]:
|
| 422 |
+
"""Load pretraining configuration from YAML file.
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
path: Path to config file (relative to configs dir or absolute)
|
| 426 |
+
overrides: Optional dict of overrides
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
Tuple of (PretrainConfig, PretrainDatasetConfig, ChessTransformerConfig)
|
| 430 |
+
"""
|
| 431 |
+
data = load_yaml_file(path)
|
| 432 |
+
|
| 433 |
+
if overrides:
|
| 434 |
+
for section, section_overrides in overrides.items():
|
| 435 |
+
if section in data:
|
| 436 |
+
data[section].update(section_overrides)
|
| 437 |
+
else:
|
| 438 |
+
data[section] = section_overrides
|
| 439 |
+
|
| 440 |
+
pretrain = dict_to_dataclass(PretrainConfig, data.get('pretrain', {}))
|
| 441 |
+
dataset = dict_to_dataclass(PretrainDatasetConfig, data.get('dataset', {}))
|
| 442 |
+
transformer = dict_to_dataclass(ChessTransformerConfig, data.get('transformer', {}))
|
| 443 |
+
|
| 444 |
+
return pretrain, dataset, transformer
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def train(
|
| 448 |
+
pretrain_config: PretrainConfig,
|
| 449 |
+
dataset_config: PretrainDatasetConfig,
|
| 450 |
+
transformer_config: ChessTransformerConfig,
|
| 451 |
+
) -> str:
|
| 452 |
+
"""Main pretraining function.
|
| 453 |
+
|
| 454 |
+
Args:
|
| 455 |
+
pretrain_config: Pretraining configuration
|
| 456 |
+
dataset_config: Dataset configuration
|
| 457 |
+
transformer_config: Model configuration
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
Path to final checkpoint
|
| 461 |
+
"""
|
| 462 |
+
import time
|
| 463 |
+
import random
|
| 464 |
+
import string
|
| 465 |
+
|
| 466 |
+
# Generate run name
|
| 467 |
+
timestamp = time.strftime("%Y%m%d-%H%M")
|
| 468 |
+
random_suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
|
| 469 |
+
run_name = f"pretrain-{timestamp}-{random_suffix}"
|
| 470 |
+
print(f"Run name: {run_name}")
|
| 471 |
+
|
| 472 |
+
# Create model
|
| 473 |
+
model = PretrainChessTransformer(transformer_config, pretrain_config)
|
| 474 |
+
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 475 |
+
|
| 476 |
+
# Create datasets
|
| 477 |
+
train_dataset = ChessPretrainDataset(dataset_config)
|
| 478 |
+
|
| 479 |
+
# Create validation dataset using hash-based split
|
| 480 |
+
val_dataset_config = PretrainDatasetConfig(
|
| 481 |
+
min_elo=dataset_config.min_elo,
|
| 482 |
+
max_samples=10000, # Smaller validation set
|
| 483 |
+
skip_first_n_moves=dataset_config.skip_first_n_moves,
|
| 484 |
+
skip_last_n_moves=dataset_config.skip_last_n_moves,
|
| 485 |
+
sample_positions_per_game=1, # Less samples per game for validation
|
| 486 |
+
is_eval=True, # Use eval portion of hash-based split
|
| 487 |
+
eval_fraction=dataset_config.eval_fraction,
|
| 488 |
+
cache_path=dataset_config.cache_path,
|
| 489 |
+
)
|
| 490 |
+
val_dataset = ChessPretrainDataset(val_dataset_config)
|
| 491 |
+
print(f"Train: {len(train_dataset):,} samples, Eval: {len(val_dataset):,} samples")
|
| 492 |
+
|
| 493 |
+
# Create dataloaders
|
| 494 |
+
train_dataloader = DataLoader(
|
| 495 |
+
train_dataset,
|
| 496 |
+
batch_size=pretrain_config.batch_size,
|
| 497 |
+
shuffle=True, # Shuffle for training
|
| 498 |
+
num_workers=pretrain_config.num_workers,
|
| 499 |
+
collate_fn=collate_pretrain_batch,
|
| 500 |
+
pin_memory=True,
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
val_dataloader = DataLoader(
|
| 504 |
+
val_dataset,
|
| 505 |
+
batch_size=pretrain_config.batch_size,
|
| 506 |
+
shuffle=False,
|
| 507 |
+
num_workers=max(1, pretrain_config.num_workers // 2),
|
| 508 |
+
collate_fn=collate_pretrain_batch,
|
| 509 |
+
pin_memory=True,
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# Create trainer
|
| 513 |
+
trainer = get_pretrain_trainer(pretrain_config, run_name)
|
| 514 |
+
|
| 515 |
+
# Resume from checkpoint if specified
|
| 516 |
+
ckpt_path = pretrain_config.resume_from
|
| 517 |
+
|
| 518 |
+
# Train
|
| 519 |
+
trainer.fit(model, train_dataloader, val_dataloader, ckpt_path=ckpt_path)
|
| 520 |
+
|
| 521 |
+
# Save final checkpoint in a standard location
|
| 522 |
+
final_path = Path(pretrain_config.checkpoint_dir) / "pretrain_final.pt"
|
| 523 |
+
torch.save({
|
| 524 |
+
'model_state_dict': model.model.state_dict(),
|
| 525 |
+
'transformer_config': transformer_config,
|
| 526 |
+
'pretrain_config': pretrain_config,
|
| 527 |
+
}, final_path)
|
| 528 |
+
|
| 529 |
+
print(f"\nPretraining complete! Final checkpoint saved to {final_path}")
|
| 530 |
+
return str(final_path)
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def main():
|
| 534 |
+
"""Main entry point for pretraining script."""
|
| 535 |
+
parser = argparse.ArgumentParser(description="Pretrain chess model on Lichess games")
|
| 536 |
+
parser.add_argument("--config", type=str, default="pretrain.yaml",
|
| 537 |
+
help="Path to config file")
|
| 538 |
+
|
| 539 |
+
# Allow command-line overrides for common parameters
|
| 540 |
+
parser.add_argument("--lr", type=float, help="Learning rate")
|
| 541 |
+
parser.add_argument("--batch_size", type=int, help="Batch size")
|
| 542 |
+
parser.add_argument("--num_epochs", type=int, help="Number of epochs")
|
| 543 |
+
parser.add_argument("--min_elo", type=int, help="Minimum player ELO")
|
| 544 |
+
parser.add_argument("--max_samples", type=int, help="Max samples per epoch")
|
| 545 |
+
parser.add_argument("--resume_from", type=str, help="Resume from checkpoint")
|
| 546 |
+
parser.add_argument("--no_wandb", action="store_true", help="Disable wandb logging")
|
| 547 |
+
|
| 548 |
+
args = parser.parse_args()
|
| 549 |
+
|
| 550 |
+
# Build overrides from command-line arguments
|
| 551 |
+
overrides = {'pretrain': {}, 'dataset': {}}
|
| 552 |
+
|
| 553 |
+
if args.lr:
|
| 554 |
+
overrides['pretrain']['lr'] = args.lr
|
| 555 |
+
if args.batch_size:
|
| 556 |
+
overrides['pretrain']['batch_size'] = args.batch_size
|
| 557 |
+
if args.num_epochs:
|
| 558 |
+
overrides['pretrain']['num_epochs'] = args.num_epochs
|
| 559 |
+
if args.resume_from:
|
| 560 |
+
overrides['pretrain']['resume_from'] = args.resume_from
|
| 561 |
+
if args.no_wandb:
|
| 562 |
+
overrides['pretrain']['use_wandb'] = False
|
| 563 |
+
if args.min_elo:
|
| 564 |
+
overrides['dataset']['min_elo'] = args.min_elo
|
| 565 |
+
if args.max_samples:
|
| 566 |
+
overrides['dataset']['max_samples'] = args.max_samples
|
| 567 |
+
|
| 568 |
+
# Load config
|
| 569 |
+
pretrain_config, dataset_config, transformer_config = load_pretrain_config(
|
| 570 |
+
args.config,
|
| 571 |
+
overrides=overrides if any(v for v in overrides.values()) else None
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
# Run training
|
| 575 |
+
train(pretrain_config, dataset_config, transformer_config)
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
if __name__ == "__main__":
|
| 579 |
+
main()
|
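The `main()` entry point above is argparse-driven, but the same flow can be invoked programmatically. A minimal sketch, assuming `pretrain.py` is importable from the package root (the import path below is a guess, not part of this diff):

```python
# Hypothetical import path; adjust to wherever this module lives in the package.
from src.grpo_self_play.pretrain.pretrain import load_pretrain_config, train

pretrain_cfg, dataset_cfg, transformer_cfg = load_pretrain_config(
    "pretrain.yaml",
    overrides={
        "pretrain": {"lr": 3e-4, "use_wandb": False},
        "dataset": {"max_samples": 50_000},
    },
)
final_ckpt = train(pretrain_cfg, dataset_cfg, transformer_cfg)
print(f"Final checkpoint: {final_ckpt}")
```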
hf_space_repo/pretrain/pretrain_dataset.py
ADDED
@@ -0,0 +1,328 @@
"""Dataset for pretraining on chess games from HuggingFace.

Uses angeluriot/chess_games: 14M high-ELO games (7.3GB download).
Mean ELO ~2355, moves already in UCI format - no parsing needed.
"""

import hashlib
import os
import chess
import torch
import random
from typing import Optional
from dataclasses import dataclass
from multiprocessing import cpu_count
from torch.utils.data import Dataset
from datasets import load_dataset
from tqdm import tqdm

from src.grpo_self_play.searchless_chess_imports import MOVE_TO_ACTION, tokenize

# Global constant
_ACTION_SPACE_SIZE = max(MOVE_TO_ACTION.values()) + 1


@dataclass
class PretrainDatasetConfig:
    """Configuration for the pretraining dataset.

    Uses angeluriot/chess_games: 14M high-ELO games (7.3GB download).
    Mean ELO ~2355, moves already in UCI format.

    Attributes:
        min_elo: Minimum player ELO to include games
        max_samples: Maximum number of samples per epoch (None for unlimited)
        skip_first_n_moves: Skip the first N moves (avoid memorizing openings)
        skip_last_n_moves: Skip the last N moves (avoid noisy endgame positions)
        sample_positions_per_game: Number of positions to sample from each game
        is_eval: If True, use the eval portion of the hash-based split.
        eval_fraction: Fraction of data to use for evaluation (default 0.05 = 5%)
        cache_path: Path to save/load the filtered dataset (e.g., Google Drive,
            studio storage). If set and the cache exists, loads from cache;
            otherwise downloads, filters, and saves.
    """
    min_elo: int = 2000
    max_samples: Optional[int] = None
    skip_first_n_moves: int = 5
    skip_last_n_moves: int = 5
    sample_positions_per_game: int = 3
    is_eval: bool = False
    eval_fraction: float = 0.05
    cache_path: Optional[str] = None


def uci_to_action(uci_move: str) -> Optional[int]:
    """Convert UCI move string to action index."""
    return MOVE_TO_ACTION.get(uci_move)


def get_positions_from_game(
    moves: list[str],
    skip_first_n: int = 5,
    skip_last_n: int = 5,
    sample_n: int = 3,
) -> list[tuple[str, str, int]]:
    """Extract (FEN, move_played, move_number) tuples from a game.

    Args:
        moves: List of UCI moves
        skip_first_n: Skip first N moves (opening book territory)
        skip_last_n: Skip last N moves (endgame/resignation noise)
        sample_n: Number of positions to randomly sample

    Returns:
        List of (fen, uci_move, move_number) tuples
    """
    if len(moves) <= skip_first_n + skip_last_n:
        return []

    board = chess.Board()
    positions = []

    for i, uci_move in enumerate(moves):
        if i < skip_first_n:
            try:
                board.push_uci(uci_move)
            except (ValueError, chess.InvalidMoveError):
                return positions
            continue

        if i >= len(moves) - skip_last_n:
            break

        fen = board.fen()
        positions.append((fen, uci_move, i))

        try:
            board.push_uci(uci_move)
        except (ValueError, chess.InvalidMoveError):
            break

    if len(positions) > sample_n:
        positions = random.sample(positions, sample_n)

    return positions


class ChessPretrainDataset(Dataset):
    """Dataset for chess pretraining from angeluriot/chess_games.

    Downloads the full dataset (7.3GB) and processes games into
    (board_tensor, target_action, legal_moves_mask) tuples.

    Example:
        >>> config = PretrainDatasetConfig(min_elo=2000)
        >>> dataset = ChessPretrainDataset(config)
        >>> dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    """

    def __init__(self, config: PretrainDatasetConfig = PretrainDatasetConfig()):
        """Initialize the dataset - downloads and processes all games."""
        self.config = config
        self._action_space_size = max(MOVE_TO_ACTION.values()) + 1
        self._samples: list[tuple[torch.Tensor, int, torch.Tensor]] = []

        self._load_and_process()

    def _load_and_process(self):
        """Download dataset and process all games into samples."""
        # Try loading processed samples from cache
        if self.config.cache_path:
            cache_file = self._get_cache_filename()
            if os.path.exists(cache_file):
                print(f"Loading processed samples from {cache_file}...")
                self._samples = torch.load(cache_file)
                print(f"Loaded {len(self._samples):,} samples from cache")
                return

        # Download, filter, and process
        dataset = self._load_filtered_dataset()

        # Limit dataset size if max_samples is set
        if self.config.max_samples:
            max_games = self.config.max_samples // self.config.sample_positions_per_game + 1000
            if len(dataset) > max_games:
                dataset = dataset.select(range(max_games))
                print(f"Limited to {len(dataset):,} games")

        # Process games using HuggingFace's optimized map
        num_workers = min(8, cpu_count() or 4)
        print(f"Processing games into samples with {num_workers} workers...")

        skip_first = self.config.skip_first_n_moves
        skip_last = self.config.skip_last_n_moves
        sample_n = self.config.sample_positions_per_game

        def process_batch(batch):
            """Process a batch of games - returns lists for the HF dataset."""
            all_boards, all_actions, all_masks = [], [], []

            for i in range(len(batch['moves_uci'])):
                moves = batch['moves_uci'][i]
                if not moves:
                    continue

                positions = get_positions_from_game(moves, skip_first, skip_last, sample_n)

                for fen, uci_move, _ in positions:
                    action_idx = MOVE_TO_ACTION.get(uci_move)
                    if action_idx is None:
                        continue
                    try:
                        token_ids = list(tokenize(fen))
                        board = chess.Board(fen)
                        legal_mask = [False] * _ACTION_SPACE_SIZE
                        for move in board.legal_moves:
                            move_idx = MOVE_TO_ACTION.get(move.uci())
                            if move_idx is not None:
                                legal_mask[move_idx] = True
                        if not legal_mask[action_idx]:
                            continue
                        all_boards.append(token_ids)
                        all_actions.append(action_idx)
                        all_masks.append(legal_mask)
                    except Exception:
                        continue

            return {'boards': all_boards, 'actions': all_actions, 'masks': all_masks}

        processed = dataset.map(
            process_batch,
            batched=True,
            batch_size=1000,
            num_proc=num_workers,
            remove_columns=dataset.column_names,
            desc="Processing"
        )

        # Convert to tensors (HF map flattens the lists)
        print("Converting to tensors...")
        for i in tqdm(range(len(processed)), desc="Tensorizing"):
            board_tensor = torch.tensor(processed[i]['boards'], dtype=torch.long)
            legal_mask = torch.tensor(processed[i]['masks'], dtype=torch.bool)
            self._samples.append((board_tensor, processed[i]['actions'], legal_mask))
            if self.config.max_samples and len(self._samples) >= self.config.max_samples:
                break

        print(f"Done: {len(self._samples):,} samples")

        # Save processed samples to cache
        if self.config.cache_path:
            cache_file = self._get_cache_filename()
            print(f"Saving processed samples to {cache_file}...")
            os.makedirs(self.config.cache_path, exist_ok=True)
            torch.save(self._samples, cache_file)
            print("Saved to cache")

    def _get_cache_filename(self) -> str:
        """Generate cache filename based on config."""
        split = 'eval' if self.config.is_eval else 'train'
        max_samples = self.config.max_samples or 'all'
        return f"{self.config.cache_path}/processed_elo{self.config.min_elo}_{split}_{max_samples}.pt"

    def _load_filtered_dataset(self):
        """Download and filter dataset."""
        # Download (uses cache_path for HuggingFace cache)
        print("Downloading angeluriot/chess_games (7.3GB)...")
        cache_dir = self.config.cache_path if self.config.cache_path else None
        dataset = load_dataset("angeluriot/chess_games", split="train", cache_dir=cache_dir)
        print(f"Loaded {len(dataset):,} games")

        # Fast batched filtering
        print(f"Filtering games (min_elo={self.config.min_elo})...")
        min_elo = self.config.min_elo
        eval_frac = self.config.eval_fraction
        is_eval = self.config.is_eval

        def batch_filter(batch):
            """Filter a batch of games - much faster than per-example."""
            keep = []
            for i in range(len(batch['white_elo'])):
                white_elo = batch['white_elo'][i]
                black_elo = batch['black_elo'][i]

                # Skip if ELO is missing
                if white_elo is None or black_elo is None:
                    keep.append(False)
                    continue
                # ELO filter
                if white_elo < min_elo or black_elo < min_elo:
                    keep.append(False)
                    continue
                # Moves filter
                if len(batch['moves_uci'][i]) < 10:
                    keep.append(False)
                    continue
                # Hash-based train/eval split. Use a stable hash: Python's
                # built-in hash() is salted per process, which would assign
                # games to different splits across runs and map workers.
                game_id = f"{batch['date'][i]}-{white_elo}-{black_elo}"
                hash_val = int(hashlib.md5(game_id.encode()).hexdigest(), 16) % 10000
                is_eval_game = hash_val < (eval_frac * 10000)
                if is_eval_game != is_eval:
                    keep.append(False)
                    continue
                keep.append(True)
            return keep

        dataset = dataset.filter(batch_filter, batched=True, batch_size=10000, desc="Filtering")
        print(f"After filtering: {len(dataset):,} games")

        return dataset

    def _process_game(self, game: dict):
        """Process a single game and yield training samples."""
        moves = game.get('moves_uci', [])

        positions = get_positions_from_game(
            moves,
            skip_first_n=self.config.skip_first_n_moves,
            skip_last_n=self.config.skip_last_n_moves,
            sample_n=self.config.sample_positions_per_game,
        )

        for fen, uci_move, _ in positions:
            action_idx = uci_to_action(uci_move)
            if action_idx is None:
                continue

            try:
                token_ids = list(tokenize(fen))
                board_tensor = torch.tensor(token_ids, dtype=torch.long)
            except Exception:
                continue

            try:
                board = chess.Board(fen)
                legal_mask = torch.zeros(self._action_space_size, dtype=torch.bool)
                for move in board.legal_moves:
                    move_idx = MOVE_TO_ACTION.get(move.uci())
                    if move_idx is not None:
                        legal_mask[move_idx] = True
            except Exception:
                continue

            if not legal_mask[action_idx]:
                continue

            yield board_tensor, action_idx, legal_mask

    def __len__(self) -> int:
        return len(self._samples)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, torch.Tensor]:
        return self._samples[idx]


def collate_pretrain_batch(
    batch: list[tuple[torch.Tensor, int, torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function for DataLoader.

    Returns:
        Tuple of (boards [B, 77], actions [B], legal_masks [B, num_actions])
    """
    boards, actions, masks = zip(*batch)

    boards = torch.stack(boards)
    actions = torch.tensor(actions, dtype=torch.long)
    masks = torch.stack(masks)

    return boards, actions, masks
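A minimal usage sketch for the dataset above (all values are illustrative; `cache_path` is a hypothetical local directory):

```python
from torch.utils.data import DataLoader

config = PretrainDatasetConfig(min_elo=2200, max_samples=100_000, cache_path="./data_cache")
dataset = ChessPretrainDataset(config)
loader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate_pretrain_batch)

# One batch: boards [256, 77] token ids, actions [256], legal_masks [256, num_actions],
# assuming the 77-token FEN encoding used by the bundled tokenizer.
boards, actions, legal_masks = next(iter(loader))
```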
hf_space_repo/pretrain/pretrain_load_config.py
ADDED
@@ -0,0 +1,21 @@
"""Pretrain load configuration - separated to avoid circular imports."""

import torch
from dataclasses import dataclass
from typing import Optional


@dataclass
class PretrainLoadConfig:
    """Configuration for loading pretrained weights.

    Attributes:
        checkpoint_path: Path to pretrained checkpoint file
        freeze_layers: Number of transformer layers to freeze (0 = train all)
    """
    checkpoint_path: Optional[str] = None
    freeze_layers: int = 0


# Register as safe for torch.load with weights_only=True (PyTorch 2.6+ compatibility)
torch.serialization.add_safe_globals([PretrainLoadConfig])
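Why the registration matters, as a sketch: PyTorch 2.6+ defaults `torch.load` to `weights_only=True` and refuses to unpickle arbitrary classes, so a checkpoint that embeds `PretrainLoadConfig` loads only because of the call above (the path and key below are hypothetical; other custom dataclasses stored in checkpoints need the same treatment):

```python
import torch

ckpt = torch.load("checkpoints/some_run.pt", weights_only=True)  # hypothetical path
load_cfg = ckpt["load_config"]  # would fail without add_safe_globals(...)
```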
hf_space_repo/searchless_chess_imports.py
ADDED
@@ -0,0 +1,3 @@
from src.searchless_chess_model.searchless_chess_code.utils import ACTION_TO_MOVE, MOVE_TO_ACTION
from src.searchless_chess_model.searchless_chess_code.tokenizer import tokenize, SEQUENCE_LENGTH
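A quick sketch of what this shim re-exports (assuming the searchless_chess action table covers standard UCI moves and `tokenize` returns a fixed-length encoding):

```python
from src.searchless_chess_model.searchless_chess_code.utils import ACTION_TO_MOVE, MOVE_TO_ACTION
from src.searchless_chess_model.searchless_chess_code.tokenizer import tokenize, SEQUENCE_LENGTH

action = MOVE_TO_ACTION["e2e4"]          # UCI move -> integer action id
assert ACTION_TO_MOVE[action] == "e2e4"  # inverse mapping

tokens = tokenize("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
assert len(tokens) == SEQUENCE_LENGTH    # fixed-length board encoding
```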
hf_space_repo/searchless_chess_model/.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
hf_space_repo/searchless_chess_model/README.md
ADDED
@@ -0,0 +1,177 @@
---
license: apache-2.0
tags:
- chess
- reinforcement-learning
- jax
- transformer
language:
- en
library_name: jax
---

# Searchless Chess 9M Self-Play

A 9-million-parameter transformer-based chess engine trained via self-play with Stockfish evaluation. This model learns to play chess without explicit search during inference, relying purely on learned pattern recognition.

## Model Description

- **Model Size**: 9M parameters (8 layers, 256 embedding dim, 8 attention heads)
- **Architecture**: Decoder-only Transformer with learned positional encodings
- **Training Method**: Self-play with Stockfish rewards
- **Framework**: JAX + Haiku
- **Q-Value Distribution**: 128 return buckets for action-value prediction

This model predicts action-values (Q-values) for chess positions without performing tree search, making it extremely fast at inference while maintaining strong play.

## Installation

### CPU Installation

Install the required dependencies for CPU inference:

```bash
pip install jax jaxlib dm-haiku orbax-checkpoint numpy chess huggingface-hub jaxtyping apache-beam grain
```

### GPU Installation (Recommended)

For GPU acceleration with CUDA 12:

```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install dm-haiku orbax-checkpoint numpy chess huggingface-hub jaxtyping apache-beam grain
```

For other CUDA versions, see the [JAX installation guide](https://github.com/google/jax#installation).

**Note**: This model includes all necessary code and can be used **without cloning the original repository**.

## Quick Start

```python
import sys
from huggingface_hub import snapshot_download

# Download model from HuggingFace Hub
model_path = snapshot_download(
    repo_id="dbest-isi/searchless-chess-9M-selfplay",
    local_dir="./searchless_chess_model"
)

# Add bundled code to Python path
sys.path.insert(0, f"{model_path}/searchless_chess_code")

# Import model wrapper
import hf_model

# Load the model
model = hf_model.SearchlessChessModel.from_pretrained(model_path)

# Make a prediction
fen = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"
result = model.predict(fen, temperature=1.0)

print(f"Best move: {result['best_move']}")
print(f"Q-value: {result['q_value']:.4f}")
print(f"Action probabilities shape: {result['action_probs'].shape}")
```

## Example Output

```
Best move: e7e5
Q-value: 0.0119
Action probabilities shape: (1968,)
```

## Full Example with Multiple Positions

```python
import sys
from huggingface_hub import snapshot_download

# Download and setup
model_path = snapshot_download(
    repo_id="dbest-isi/searchless-chess-9M-selfplay",
    local_dir="./searchless_chess_model"
)
sys.path.insert(0, f"{model_path}/searchless_chess_code")

import hf_model

# Load model
print("Loading model...")
model = hf_model.SearchlessChessModel.from_pretrained(model_path)
print("Model loaded!")

# Test on multiple positions
positions = [
    ("Starting position", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"),
    ("After 1.e4", "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"),
    ("Scandinavian Defense", "rnbqkbnr/ppp1pppp/8/3p4/4P3/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 2"),
]

for name, fen in positions:
    result = model.predict(fen)
    print(f"\n{name}")
    print(f"  FEN: {fen}")
    print(f"  Best move: {result['best_move']}")
    print(f"  Q-value: {result['q_value']:.4f}")
```

## Model Architecture

```python
TransformerConfig(
    vocab_size=1968,
    output_size=128,
    embedding_dim=256,
    num_layers=8,
    num_heads=8,
    max_sequence_length=79,
    num_return_buckets=128,
    pos_encodings="LEARNED",
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)
```

## Training Details

- **Base Model**: Initialized from pretrained 9M checkpoint
- **Training Method**: Self-play reinforcement learning
- **Reward Signal**: Stockfish evaluation at depth 20
- **Iteration**: 22 (EMA parameters)
- **Action Space**: 1968 possible moves (all legal chess moves)
- **Value Representation**: Discretized into 128 buckets

## Use Cases

- Fast chess move prediction without search
- Chess position evaluation
- Research on learned planning in board games
- Integration into chess applications requiring low-latency move suggestions

## Limitations

- Does not perform explicit search (unlike traditional chess engines)
- May make suboptimal moves in complex tactical positions
- Performance depends on training data distribution
- Best suited for fast move suggestions rather than deep analysis

## Background

This model is based on the architecture from DeepMind's [Searchless Chess](https://github.com/google-deepmind/searchless_chess) work. The **self-play training implementation and this trained model** are original work by Darrell Best.

For the full self-play training implementation and codebase, visit:
- Repository: https://github.com/DarrellBest/searchless_chess

## License

Apache 2.0

## Model Card Contact

For questions or issues, please open an issue on the [GitHub repository](https://github.com/DarrellBest/searchless_chess).
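One thing the README examples above do not show is restricting the returned distribution to legal moves before sampling. A sketch (not part of the bundled API), assuming `model` is loaded as in Quick Start, `utils` is the bundled module, and `action_probs` spans the full 1968-move action space as stated above:

```python
import chess
import numpy as np
import utils  # bundled searchless_chess_code module

fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
result = model.predict(fen, temperature=0.5)

board = chess.Board(fen)
legal = [m.uci() for m in board.legal_moves]
probs = np.array([result["action_probs"][utils.MOVE_TO_ACTION[m]] for m in legal])
probs = probs / probs.sum()  # renormalize over legal moves only
print(np.random.choice(legal, p=probs))
```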
hf_space_repo/searchless_chess_model/config.json
ADDED
@@ -0,0 +1,10 @@
{
  "vocab_size": 1968,
  "output_size": 128,
  "embedding_dim": 256,
  "num_layers": 8,
  "num_heads": 8,
  "max_sequence_length": 79,
  "num_return_buckets": 128,
  "model_name": "9M"
}
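The `num_return_buckets` entry drives the value head: the model predicts a distribution over 128 return buckets, which collapses to a scalar Q-value by expectation. A sketch of that collapse, assuming uniform buckets over win probability in [0, 1] (bucket placement is an assumption; the bundled `utils.get_uniform_buckets_edges_values` defines the real values):

```python
import numpy as np

num_buckets = 128
bucket_values = (np.arange(num_buckets) + 0.5) / num_buckets  # assumed bucket midpoints

# Stand-in for one action's predicted log-probabilities over the buckets.
bucket_log_probs = np.log(np.full(num_buckets, 1.0 / num_buckets))

q_value = float(np.exp(bucket_log_probs) @ bucket_values)  # expected win prob (0.5 here)
```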
hf_space_repo/searchless_chess_model/model_info.json
ADDED
@@ -0,0 +1,13 @@
{
  "model_type": "searchless_chess",
  "framework": "jax",
  "library": "dm-haiku",
  "includes_source": true,
  "source_modules": [
    "tokenizer.py",
    "transformer.py",
    "constants.py",
    "utils.py",
    "config.py"
  ]
}
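A small sanity-check sketch for this manifest, assuming the snapshot was downloaded to `./searchless_chess_model` as in the README:

```python
import json
import os

with open("./searchless_chess_model/model_info.json") as f:
    info = json.load(f)

for module in info["source_modules"]:
    path = f"./searchless_chess_model/searchless_chess_code/{module}"
    assert os.path.exists(path), f"missing bundled module: {module}"
```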
hf_space_repo/searchless_chess_model/searchless_chess_code/__init__.py
ADDED
@@ -0,0 +1 @@
# Searchless Chess code bundle
hf_space_repo/searchless_chess_model/searchless_chess_code/config.py
ADDED
@@ -0,0 +1,90 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Defines the configuration dataclasses."""

import dataclasses
from typing import Literal


PolicyType = Literal['action_value', 'state_value', 'behavioral_cloning']
POLICY_TYPES = ['action_value', 'state_value', 'behavioral_cloning']


@dataclasses.dataclass(kw_only=True)
class DataConfig:
  """Config for the data generation."""

  # The batch size for the sequences.
  batch_size: int
  # Whether to shuffle the dataset (shuffling is applied per epoch).
  shuffle: bool = False
  # The seed used for shuffling and transformations of the data.
  seed: int | None = 0
  # Whether to drop partial batches.
  drop_remainder: bool = False
  # The number of child processes launched to parallelize the transformations.
  worker_count: int | None = 0
  # The number of return buckets.
  num_return_buckets: int
  # The dataset split.
  split: Literal['train', 'test']
  # The policy used to create the dataset.
  policy: PolicyType
  # The number of records to read from the dataset (can be useful when, e.g.,
  # the dataset does not fit into memory).
  num_records: int | None = None


@dataclasses.dataclass(kw_only=True)
class TrainConfig:
  """Config for the training function."""

  # The data configuration for training.
  data: DataConfig
  # The learning rate for Adam.
  learning_rate: float
  # The gradient clipping value.
  max_grad_norm: float = 1.0
  # The number of gradient steps.
  num_steps: int
  # The frequency (in gradient steps) at which checkpoints should be saved
  # (`None` means there is no checkpointing).
  ckpt_frequency: int | None = None
  # If provided, the maximum number of checkpoints to keep.
  ckpt_max_to_keep: int | None = 1
  # The frequency (in gradient steps) at which checkpoints should be saved
  # permanently (`None` means all checkpoints are temporary).
  save_frequency: int | None = None
  # The frequency of logging in gradient steps (`None` means no logging).
  log_frequency: int | None = None


@dataclasses.dataclass(kw_only=True)
class EvalConfig:
  """Config for the evaluator."""

  # The data configuration for evaluation.
  data: DataConfig
  # How many data points to consider for evaluation.
  num_eval_data: int | None = None
  # Enables use of ema-ed params in eval.
  use_ema_params: bool = False
  # The policy used to play moves with the model.
  policy: PolicyType
  # The number of return buckets.
  num_return_buckets: int
  # The batch size for evaluation.
  batch_size: int | None = None
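A sketch of wiring these dataclasses together (field values are illustrative; the flat `config` module name matches the bundle's layout):

```python
import config as config_lib

data_cfg = config_lib.DataConfig(
    batch_size=256,
    shuffle=True,
    num_return_buckets=128,
    split='train',
    policy='behavioral_cloning',
)
train_cfg = config_lib.TrainConfig(
    data=data_cfg,
    learning_rate=1e-4,
    num_steps=10_000,
    ckpt_frequency=1_000,
    log_frequency=100,
)
```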
hf_space_repo/searchless_chess_model/searchless_chess_code/constants.py
ADDED
@@ -0,0 +1,119 @@
# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Constants, interfaces, and types."""

import abc
from collections.abc import Callable, Mapping
import dataclasses
from typing import Any, NamedTuple, Protocol

from apache_beam import coders
from grain import python as pygrain
import haiku as hk
import jaxtyping as jtp

import config as config_lib


# Integer sequences of token ids.
Sequences = jtp.UInt32[jtp.Array, 'B T']

# The predictions are log-probabilities (natural logarithm) for the passed
# sequences. It can either be marginal log-probabilities (i.e. log P(s) for all
# sequences s in the batch), or full conditionals (i.e. log P(token | s_<t) for
# all sequence s, time t and token in the alphabet).
Marginals = jtp.Float32[jtp.Array, '*B']
Conditionals = jtp.Float32[jtp.Array, '*B T F']
Predictions = Marginals | Conditionals

# True means the loss will be masked there, i.e. we ignore it.
LossMask = jtp.Bool[jtp.Array, 'B T']


@dataclasses.dataclass
class Predictor:
  """Defines the predictor interface."""

  initial_params: Callable[..., hk.MutableParams]
  predict: Callable[..., Predictions]


class DataLoaderBuilder(Protocol):

  def __call__(self, config: config_lib.DataConfig) -> pygrain.DataLoader:
    """Returns a PyGrain data loader from the `config`."""


class Evaluator(abc.ABC):
  """Defines the interface of the evaluator that evaluates a predictor."""

  @abc.abstractmethod
  def step(self, params: hk.Params, step: int) -> Mapping[str, Any]:
    """Returns the results of evaluating the predictor with `params`."""


class EvaluatorBuilder(Protocol):

  def __call__(
      self,
      predictor: Predictor,
      config: config_lib.EvalConfig,
  ) -> Evaluator:
    """Returns an evaluator for the `predictor` and `config`.

    Args:
      predictor: The predictor to be evaluated. The training loop continuously
        saves the predictor's parameters, which are then loaded in the
        evaluation loop and passed to the evaluator's step method.
      config: The configuration of the evaluator.
    """


CODERS = {
    'fen': coders.StrUtf8Coder(),
    'move': coders.StrUtf8Coder(),
    'count': coders.BigIntegerCoder(),
    'win_prob': coders.FloatCoder(),
}
CODERS['state_value'] = coders.TupleCoder((
    CODERS['fen'],
    CODERS['win_prob'],
))
CODERS['action_value'] = coders.TupleCoder((
    CODERS['fen'],
    CODERS['move'],
    CODERS['win_prob'],
))
CODERS['behavioral_cloning'] = coders.TupleCoder((
    CODERS['fen'],
    CODERS['move'],
))


class BehavioralCloningData(NamedTuple):
  fen: str
  move: str


class StateValueData(NamedTuple):
  fen: str
  win_prob: float


class ActionValueData(NamedTuple):
  fen: str
  move: str
  win_prob: float
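A round-trip sketch for the coders above (the record values are made up; exact float round-tripping depends on the coder's byte width):

```python
record = ("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "e2e4", 0.53)
blob = CODERS['action_value'].encode(record)           # serialized bytes
fen, move, win_prob = CODERS['action_value'].decode(blob)
assert (fen, move) == record[:2]
```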
hf_space_repo/searchless_chess_model/searchless_chess_code/hf_model.py
ADDED
@@ -0,0 +1,329 @@
"""HuggingFace model wrapper for searchless chess."""

import json
import os
from typing import Dict, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

import tokenizer
import transformer
import utils


class SearchlessChessConfig:
    """Configuration for SearchlessChess model."""

    def __init__(
        self,
        vocab_size: int = 1968,
        output_size: int = 128,
        embedding_dim: int = 256,
        num_layers: int = 8,
        num_heads: int = 8,
        max_sequence_length: int = 79,
        num_return_buckets: int = 128,
        model_name: str = "9M",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.output_size = output_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_sequence_length = max_sequence_length
        self.num_return_buckets = num_return_buckets
        self.model_name = model_name

        # Store any extra kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self) -> Dict:
        """Convert config to dictionary."""
        return {
            "vocab_size": self.vocab_size,
            "output_size": self.output_size,
            "embedding_dim": self.embedding_dim,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "max_sequence_length": self.max_sequence_length,
            "num_return_buckets": self.num_return_buckets,
            "model_name": self.model_name,
        }

    @classmethod
    def from_dict(cls, config_dict: Dict) -> "SearchlessChessConfig":
        """Load config from dictionary."""
        return cls(**config_dict)

    def save_pretrained(self, save_directory: str):
        """Save config to directory."""
        os.makedirs(save_directory, exist_ok=True)
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, model_path: str) -> "SearchlessChessConfig":
        """Load config from directory."""
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)


class SearchlessChessModel:
    """HuggingFace-compatible wrapper for SearchlessChess JAX/Haiku model."""

    def __init__(self, config: SearchlessChessConfig):
        self.config = config

        # Build transformer config
        self.transformer_config = transformer.TransformerConfig(
            vocab_size=config.vocab_size,
            output_size=config.output_size,
            pos_encodings=transformer.PositionalEncodings.LEARNED,
            max_sequence_length=config.max_sequence_length,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            embedding_dim=config.embedding_dim,
            apply_post_ln=True,
            apply_qk_layernorm=False,
            use_causal_mask=False,
        )

        # Build predictor
        self.predictor = transformer.build_transformer_predictor(self.transformer_config)

        # Initialize params
        self.params = None
        self.return_buckets_values = None

        # Get return bucket values
        _, self.return_buckets_values = utils.get_uniform_buckets_edges_values(
            config.num_return_buckets
        )

    def load_params(self, params_path: str):
        """Load parameters from Orbax checkpoint."""
        # Convert to absolute path (Orbax requires absolute paths)
        params_path = os.path.abspath(params_path)

        # Create dummy params for structure
        dummy_params = self.predictor.initial_params(
            rng=jax.random.PRNGKey(0),
            targets=np.ones((1, 1), dtype=np.uint32),
        )

        # Load checkpoint
        restore_args = ocp.checkpoint_utils.construct_restore_args(dummy_params)
        checkpointer = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
        self.params = checkpointer.restore(params_path, restore_args=restore_args)

    def predict(self, fen: str, temperature: float = 1.0) -> Dict:
        """Predict move from FEN position.

        Args:
            fen: Chess position in FEN notation
            temperature: Temperature for sampling (1.0 = no modification)

        Returns:
            Dictionary with:
                - q_value: Q-value of the position
                - action_probs: Action probabilities
                - best_action: Best action index
                - best_move: Best move in UCI notation
        """
        if self.params is None:
            raise ValueError("Model parameters not loaded. Call load_params() first.")

        # Tokenize input
        tokens = tokenizer.tokenize(fen)
        tokens = tokens[None, :]  # Add batch dimension

        # Get predictions
        bucket_log_probs = self.predictor.predict(
            params=self.params,
            targets=tokens,
            rng=None,
        )

        # Extract action Q-values (second to last position)
        action_bucket_log_probs = bucket_log_probs[0, -2]  # [num_return_buckets]
        action_bucket_probs = jnp.exp(action_bucket_log_probs)

        # Compute Q-value for each action bucket
        q_value = float(jnp.dot(action_bucket_probs, self.return_buckets_values))

        # Get action probabilities from Q-values
        # Use softmax over return bucket expectations
        action_values = jnp.dot(
            jnp.exp(bucket_log_probs[0, -2:]),
            self.return_buckets_values,
        )

        # Apply temperature and softmax
        action_logits = action_values / temperature
        action_probs = jax.nn.softmax(action_logits)

        # Get best action
        best_action = int(jnp.argmax(action_probs))

        # Convert action to move
        best_move = utils.ACTION_TO_MOVE.get(best_action, "unknown")

        return {
            "q_value": q_value,
            "action_probs": np.array(action_probs),
            "best_action": best_action,
            "best_move": best_move,
        }

    def save_pretrained(self, save_directory: str):
        """Save model to directory in HuggingFace format."""
        os.makedirs(save_directory, exist_ok=True)

        # Save config
        self.config.save_pretrained(save_directory)

        # Save parameters as numpy arrays
        if self.params is not None:
            params_cpu = jax.device_get(self.params)
            params_flat, tree_def = jax.tree.flatten(params_cpu)

            # Save flattened params
            params_path = os.path.join(save_directory, "params.npz")
            np.savez(params_path, *params_flat)

            # Save tree structure
            import pickle
            tree_path = os.path.join(save_directory, "tree_structure.pkl")
            with open(tree_path, "wb") as f:
                pickle.dump(tree_def, f)

        # Copy necessary source files for standalone usage
        import shutil
+
src_dir = os.path.dirname(__file__)
|
| 212 |
+
code_dir = os.path.join(save_directory, "searchless_chess_code")
|
| 213 |
+
os.makedirs(code_dir, exist_ok=True)
|
| 214 |
+
|
| 215 |
+
# Copy core modules and fix imports for standalone usage
|
| 216 |
+
def fix_imports(content):
|
| 217 |
+
"""Replace absolute imports with relative imports."""
|
| 218 |
+
content = content.replace("import tokenizer", "import tokenizer")
|
| 219 |
+
content = content.replace("import transformer", "import transformer")
|
| 220 |
+
content = content.replace("import utils", "import utils")
|
| 221 |
+
content = content.replace("import constants", "import constants")
|
| 222 |
+
content = content.replace("import config as config_lib", "import config as config_lib")
|
| 223 |
+
content = content.replace("import config", "import config")
|
| 224 |
+
return content
|
| 225 |
+
|
| 226 |
+
for module in ["tokenizer.py", "transformer.py", "constants.py", "utils.py", "config.py"]:
|
| 227 |
+
src_file = os.path.join(src_dir, module)
|
| 228 |
+
dst_file = os.path.join(code_dir, module)
|
| 229 |
+
if os.path.exists(src_file):
|
| 230 |
+
with open(src_file, 'r') as f:
|
| 231 |
+
content = fix_imports(f.read())
|
| 232 |
+
with open(dst_file, 'w') as f:
|
| 233 |
+
f.write(content)
|
| 234 |
+
|
| 235 |
+
# Create standalone hf_model.py
|
| 236 |
+
standalone_hf_model = os.path.join(code_dir, "hf_model.py")
|
| 237 |
+
with open(__file__, 'r') as source:
|
| 238 |
+
content = fix_imports(source.read())
|
| 239 |
+
with open(standalone_hf_model, 'w') as dest:
|
| 240 |
+
dest.write(content)
|
| 241 |
+
|
| 242 |
+
# Create __init__.py
|
| 243 |
+
with open(os.path.join(code_dir, "__init__.py"), "w") as f:
|
| 244 |
+
f.write("# Searchless Chess code bundle\n")
|
| 245 |
+
|
| 246 |
+
# Save model info
|
| 247 |
+
model_info = {
|
| 248 |
+
"model_type": "searchless_chess",
|
| 249 |
+
"framework": "jax",
|
| 250 |
+
"library": "dm-haiku",
|
| 251 |
+
"includes_source": True,
|
| 252 |
+
"source_modules": ["tokenizer.py", "transformer.py", "constants.py", "utils.py", "config.py"],
|
| 253 |
+
}
|
| 254 |
+
with open(os.path.join(save_directory, "model_info.json"), "w") as f:
|
| 255 |
+
json.dump(model_info, f, indent=2)
|
| 256 |
+
|
| 257 |
+
@classmethod
|
| 258 |
+
def from_pretrained(cls, model_path: str) -> "SearchlessChessModel":
|
| 259 |
+
"""Load model from directory."""
|
| 260 |
+
# Load config
|
| 261 |
+
config = SearchlessChessConfig.from_pretrained(model_path)
|
| 262 |
+
|
| 263 |
+
# Create model
|
| 264 |
+
model = cls(config)
|
| 265 |
+
|
| 266 |
+
# Load parameters
|
| 267 |
+
params_path = os.path.join(model_path, "params.npz")
|
| 268 |
+
tree_path = os.path.join(model_path, "tree_structure.pkl")
|
| 269 |
+
|
| 270 |
+
if os.path.exists(params_path) and os.path.exists(tree_path):
|
| 271 |
+
# Load tree structure
|
| 272 |
+
import pickle
|
| 273 |
+
with open(tree_path, "rb") as f:
|
| 274 |
+
tree_def = pickle.load(f)
|
| 275 |
+
|
| 276 |
+
# Load params
|
| 277 |
+
params_data = np.load(params_path)
|
| 278 |
+
params_flat = [params_data[f"arr_{i}"] for i in range(len(params_data.files))]
|
| 279 |
+
|
| 280 |
+
# Reconstruct pytree
|
| 281 |
+
model.params = jax.tree.unflatten(tree_def, params_flat)
|
| 282 |
+
|
| 283 |
+
return model
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def create_model_from_checkpoint(
|
| 287 |
+
checkpoint_path: str,
|
| 288 |
+
model_name: str = "9M",
|
| 289 |
+
use_ema: bool = True,
|
| 290 |
+
) -> SearchlessChessModel:
|
| 291 |
+
"""Create HuggingFace model from existing checkpoint.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
checkpoint_path: Path to checkpoint directory (e.g., checkpoints/9M_selfplay/4)
|
| 295 |
+
model_name: Model size (9M, 136M, 270M)
|
| 296 |
+
use_ema: Whether to load EMA parameters
|
| 297 |
+
|
| 298 |
+
Returns:
|
| 299 |
+
SearchlessChessModel ready to save or use
|
| 300 |
+
"""
|
| 301 |
+
# Determine architecture from model name
|
| 302 |
+
if model_name == "9M":
|
| 303 |
+
num_layers, embedding_dim, num_heads = 8, 256, 8
|
| 304 |
+
elif model_name == "136M":
|
| 305 |
+
num_layers, embedding_dim, num_heads = 8, 1024, 8
|
| 306 |
+
else: # 270M
|
| 307 |
+
num_layers, embedding_dim, num_heads = 16, 1024, 8
|
| 308 |
+
|
| 309 |
+
# Create config
|
| 310 |
+
config = SearchlessChessConfig(
|
| 311 |
+
vocab_size=1968,
|
| 312 |
+
output_size=128,
|
| 313 |
+
embedding_dim=embedding_dim,
|
| 314 |
+
num_layers=num_layers,
|
| 315 |
+
num_heads=num_heads,
|
| 316 |
+
max_sequence_length=79,
|
| 317 |
+
num_return_buckets=128,
|
| 318 |
+
model_name=model_name,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Create model
|
| 322 |
+
model = SearchlessChessModel(config)
|
| 323 |
+
|
| 324 |
+
# Load parameters from Orbax checkpoint
|
| 325 |
+
params_dir = "params_ema" if use_ema else "params"
|
| 326 |
+
params_path = os.path.join(checkpoint_path, params_dir)
|
| 327 |
+
model.load_params(params_path)
|
| 328 |
+
|
| 329 |
+
return model
|
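For orientation, a minimal usage sketch of this wrapper; the checkpoint path and the `hf_model` import name are illustrative, not fixed by the repo:

# Sketch only: assumes hf_model.py is importable and an Orbax checkpoint
# exists at the given (illustrative) path.
from hf_model import SearchlessChessModel, create_model_from_checkpoint

model = create_model_from_checkpoint("checkpoints/9M_selfplay/4", model_name="9M")
model.save_pretrained("searchless_chess_model")  # writes config.json, params.npz, code bundle

reloaded = SearchlessChessModel.from_pretrained("searchless_chess_model")
result = reloaded.predict("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
print(result["best_move"], result["q_value"])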
hf_space_repo/searchless_chess_model/searchless_chess_code/tokenizer.py
ADDED
@@ -0,0 +1,116 @@

# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements tokenization of FEN strings."""

import jaxtyping as jtp
import numpy as np


# pyfmt: disable
_CHARACTERS = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
    'p', 'n', 'r', 'k', 'q',
    'P', 'B', 'N', 'R', 'Q', 'K',
    'w', '.',
]
# pyfmt: enable
_CHARACTERS_INDEX = {letter: index for index, letter in enumerate(_CHARACTERS)}
_SPACES_CHARACTERS = frozenset({'1', '2', '3', '4', '5', '6', '7', '8'})
SEQUENCE_LENGTH = 77


def tokenize(fen: str) -> jtp.Int32[jtp.Array, 'T']:
  """Returns an array of tokens from a fen string.

  We compute a tokenized representation of the board, from the FEN string.
  The final array of tokens is a mapping from this string to numbers, which
  are defined in the dictionary `_CHARACTERS_INDEX`.
  For the 'en passant' information, we convert the '-' (which means there is
  no en passant relevant square) to '..', to always have two characters, and
  a fixed length output.

  Args:
    fen: The board position in Forsyth-Edwards Notation.
  """
  # Extracting the relevant information from the FEN.
  board, side, castling, en_passant, halfmoves_last, fullmoves = fen.split(' ')
  board = board.replace('/', '')
  board = side + board

  indices = list()

  for char in board:
    if char in _SPACES_CHARACTERS:
      indices.extend(int(char) * [_CHARACTERS_INDEX['.']])
    else:
      indices.append(_CHARACTERS_INDEX[char])

  if castling == '-':
    indices.extend(4 * [_CHARACTERS_INDEX['.']])
  else:
    for char in castling:
      indices.append(_CHARACTERS_INDEX[char])
    # Padding castling to have exactly 4 characters.
    if len(castling) < 4:
      indices.extend((4 - len(castling)) * [_CHARACTERS_INDEX['.']])

  if en_passant == '-':
    indices.extend(2 * [_CHARACTERS_INDEX['.']])
  else:
    # En passant is a square like 'e3'.
    for char in en_passant:
      indices.append(_CHARACTERS_INDEX[char])

  # Three digits for halfmoves (since last capture) is enough since the game
  # ends at 50.
  halfmoves_last += '.' * (3 - len(halfmoves_last))
  indices.extend([_CHARACTERS_INDEX[x] for x in halfmoves_last])

  # Three digits for full moves is enough (no game lasts longer than 999
  # moves).
  fullmoves += '.' * (3 - len(fullmoves))
  indices.extend([_CHARACTERS_INDEX[x] for x in fullmoves])

  assert len(indices) == SEQUENCE_LENGTH

  return np.asarray(indices, dtype=np.uint8)
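The fixed length 77 decomposes as 1 side-to-move character + 64 board squares + 4 castling characters + 2 en passant characters + 3 halfmove digits + 3 fullmove digits. A quick sketch (assuming the module is importable as `tokenizer`):

import tokenizer

start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
tokens = tokenizer.tokenize(start_fen)
assert tokens.shape == (77,)  # 1 + 64 + 4 + 2 + 3 + 3
assert tokens.dtype == "uint8"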
hf_space_repo/searchless_chess_model/searchless_chess_code/transformer.py
ADDED
@@ -0,0 +1,284 @@

# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Transformer model."""

import dataclasses
import enum
import functools

import haiku as hk
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np

import constants


class PositionalEncodings(enum.Enum):
  SINUSOID = enum.auto()
  LEARNED = enum.auto()


@dataclasses.dataclass(kw_only=True)
class TransformerConfig:
  """Hyperparameters used in the Transformer architectures."""

  # The random seed for parameter initialization.
  seed: int = 1
  # The input vocabulary size.
  vocab_size: int
  # The output size (by default equal to the vocabulary size).
  output_size: int | None = None
  # The dimension of the first embedding.
  embedding_dim: int = 64
  # The number of multi-head attention layers.
  num_layers: int = 4
  # The number of heads per layer.
  num_heads: int = 8
  # Whether to use a causal mask or not.
  use_causal_mask: bool = True
  # The parameter initialization scale for the embeddings.
  emb_init_scale: float = 0.02
  # Positional encodings to use.
  pos_encodings: PositionalEncodings = PositionalEncodings.SINUSOID
  # Maximum sequence length, useful for the LEARNED positional encodings.
  max_sequence_length: int | None = None
  # How much larger the hidden layer of the feedforward network should be
  # compared to the `embedding_dim`.
  widening_factor: int = 4
  # Whether to apply the QK normalization trick in the attention layer.
  apply_qk_layernorm: bool = False
  # Whether to apply post LN after the attention + MLP blocks.
  apply_post_ln: bool = True

  def __post_init__(self):
    if self.output_size is None:
      self.output_size = self.vocab_size


class MultiHeadDotProductAttention(hk.Module):
  """Multi-head dot-product attention (Vaswani et al., 2017)."""

  def __init__(
      self,
      num_heads: int,
      num_hiddens_per_head: int,
      name: str | None = None,
      apply_qk_layernorm: bool = False,
  ) -> None:
    """Initializes the attention module.

    Args:
      num_heads: Number of heads to use.
      num_hiddens_per_head: Number of hidden neurons per head.
      name: Name of the module.
      apply_qk_layernorm: Applies layernorm to query and key matrices, which
        helps training stability.
    """
    super().__init__(name=name)
    self._num_heads = num_heads
    self._num_hiddens_per_head = num_hiddens_per_head
    self._apply_qk_layernorm = apply_qk_layernorm

  def __call__(
      self,
      inputs_q: jax.Array,
      inputs_kv: jax.Array,
      mask: jax.Array | None = None,
  ) -> jax.Array:
    """Returns the output of the multi-head attention."""
    batch_size, sequence_length, embedding_size = inputs_q.shape

    num_hiddens = self._num_hiddens_per_head * self._num_heads
    q = hk.Linear(num_hiddens, with_bias=False)(inputs_q)
    k = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)

    if self._apply_qk_layernorm:
      q = layer_norm(q)
      k = layer_norm(k)

    v = hk.Linear(num_hiddens, with_bias=False)(inputs_kv)
    # The second (sequence) dimension is undefined since it can differ between
    # queries and keys/values when decoding. Also, checking that the inputs have
    # the same batch size as the reshape below does not guarantee a failure if
    # they are different.
    new_shape = (batch_size, -1, self._num_heads, self._num_hiddens_per_head)
    q = jnp.reshape(q, new_shape)
    k = jnp.reshape(k, new_shape)
    v = jnp.reshape(v, new_shape)

    # Let b=batch_size, t=seq_len, h=num_heads, and d=num_hiddens_per_head.
    attention = jnp.einsum('bthd,bThd->bhtT', q, k)
    attention *= 1.0 / jnp.sqrt(self._num_hiddens_per_head)

    if mask is not None:
      attention = jnp.where(mask, attention, jnp.finfo(jnp.float32).min)

    normalized_attention = jnn.softmax(attention)

    output = jnp.einsum('bhtT,bThd->bthd', normalized_attention, v)
    output = jnp.reshape(output, (batch_size, sequence_length, num_hiddens))
    return hk.Linear(embedding_size, with_bias=False)(output)


def sinusoid_position_encoding(
    sequence_length: int,
    hidden_size: int,
    max_timescale: float = 1e4,
) -> np.ndarray:
  """Creates sinusoidal encodings from the original transformer paper.

  The returned values are, for all i < D/2:
    array[pos, i] = sin(pos / (max_timescale^(2*i / D)))
    array[pos, D/2 + i] = cos(pos / (max_timescale^(2*i / D)))

  Args:
    sequence_length: Sequence length.
    hidden_size: Dimension of the positional encoding vectors, D. Should be
      even.
    max_timescale: Maximum timescale for the frequency.

  Returns:
    An array of shape [L, D].
  """
  freqs = np.arange(0, hidden_size + 1, 2)
  inv_freq = max_timescale ** (-freqs / hidden_size)

  pos_seq = np.arange(start=0, stop=sequence_length)

  sinusoid_inp = np.einsum('i,j->ij', pos_seq, inv_freq)
  embeddings = np.concatenate(
      [np.sin(sinusoid_inp), np.cos(sinusoid_inp)], axis=-1
  )
  return embeddings[:, :hidden_size]


def embed_sequences(
    sequences: jax.Array,
    config: TransformerConfig,
) -> jax.Array:
  """Returns embeddings for sequences of tokens."""
  embs_init = hk.initializers.TruncatedNormal(stddev=config.emb_init_scale)
  embeddings_layer = hk.Embed(
      vocab_size=config.vocab_size,
      embed_dim=config.embedding_dim,
      lookup_style=hk.EmbedLookupStyle.ARRAY_INDEX,
      w_init=embs_init,
  )
  embeddings = embeddings_layer(sequences)
  embeddings *= jnp.sqrt(config.embedding_dim)

  _, sequence_length, embedding_size = embeddings.shape
  match config.pos_encodings:
    case PositionalEncodings.SINUSOID:
      pos_encodings = sinusoid_position_encoding(
          sequence_length=sequence_length,
          hidden_size=embedding_size,
      )
    case PositionalEncodings.LEARNED:
      assert sequence_length <= config.max_sequence_length
      positions = jnp.arange(sequence_length)
      pos_encodings = hk.Embed(
          vocab_size=config.max_sequence_length,
          embed_dim=embedding_size,
      )(positions)
  return embeddings + pos_encodings


def layer_norm(x: jax.Array) -> jax.Array:
  """Helper function for layer norm."""
  return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)


def shift_right(sequences: jax.Array) -> jax.Array:
  """Right-shift the one-hot encoded input by padding on the temporal axis."""
  bos_array = jnp.zeros((sequences.shape[0], 1), dtype=jnp.uint8)
  padded_sequences = jnp.concatenate([bos_array, sequences], axis=1)
  return padded_sequences[:, :-1]


def _mlp_block(inputs: jax.Array, config: TransformerConfig) -> jax.Array:
  """Gated MLP block for the Transformer."""
  ffn_dim = config.embedding_dim * config.widening_factor
  split_1 = hk.Linear(ffn_dim, with_bias=False)(inputs)
  split_2 = hk.Linear(ffn_dim, with_bias=False)(inputs)
  gate_output = jnn.silu(split_1) * split_2
  return hk.Linear(config.embedding_dim, with_bias=False)(gate_output)


def _attention_block(inputs: jax.Array, config: TransformerConfig) -> jax.Array:
  """Attention block for the Transformer."""
  batch_size, sequence_length = inputs.shape[:2]
  if config.use_causal_mask:
    causal_mask = np.tril(
        np.ones((batch_size, 1, sequence_length, sequence_length))
    )
  else:
    causal_mask = None
  block = MultiHeadDotProductAttention(
      num_heads=config.num_heads,
      num_hiddens_per_head=config.embedding_dim // config.num_heads,
      apply_qk_layernorm=config.apply_qk_layernorm,
  )
  return block(inputs_q=inputs, inputs_kv=inputs, mask=causal_mask)


def transformer_decoder(
    targets: jax.Array,
    config: TransformerConfig,
) -> jax.Array:
  """Returns the transformer decoder output, shape [B, T, V].

  Follows the LLaMa architecture:
  https://github.com/facebookresearch/llama/blob/main/llama/model.py
  Main changes to the original Transformer decoder:
  - Using gating in the MLP block, with SwiGLU activation function.
  - Using normalization before the attention and MLP blocks.

  Args:
    targets: The integer target values, shape [B, T].
    config: The config to use for the transformer.
  """
  # Right shift the targets to get the inputs (the first token is now a 0).
  inputs = shift_right(targets)

  # Embeds the inputs and adds positional encodings.
  embeddings = embed_sequences(inputs, config)

  h = embeddings
  for _ in range(config.num_layers):
    attention_input = layer_norm(h)
    attention = _attention_block(attention_input, config)
    h += attention

    mlp_input = layer_norm(h)
    mlp_output = _mlp_block(mlp_input, config)
    h += mlp_output

  if config.apply_post_ln:
    h = layer_norm(h)
  logits = hk.Linear(config.output_size)(h)
  return jnn.log_softmax(logits, axis=-1)


def build_transformer_predictor(
    config: TransformerConfig,
) -> constants.Predictor:
  """Returns a transformer predictor."""
  model = hk.transform(functools.partial(transformer_decoder, config=config))
  return constants.Predictor(initial_params=model.init, predict=model.apply)
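A quick shape check of the predictor interface, assuming the module's names are in scope; all dimensions here are arbitrary toy values:

# Toy smoke test; dimensions are arbitrary.
import jax
import numpy as np

toy_config = TransformerConfig(
    vocab_size=32,
    output_size=16,
    embedding_dim=64,
    num_layers=2,
    num_heads=4,
)
toy_predictor = build_transformer_predictor(toy_config)

targets = np.zeros((2, 10), dtype=np.uint32)  # [batch, sequence_length]
params = toy_predictor.initial_params(rng=jax.random.PRNGKey(0), targets=targets)
log_probs = toy_predictor.predict(params=params, targets=targets, rng=None)
assert log_probs.shape == (2, 10, 16)  # [B, T, output_size], log-probabilities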
hf_space_repo/searchless_chess_model/searchless_chess_code/utils.py
ADDED
@@ -0,0 +1,162 @@

# Copyright 2025 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements some utility functions."""

import math

import chess
import numpy as np


# The lists of the strings of the rows and columns of a chess board,
# traditionally named rank and file.
_CHESS_FILE = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']


def _compute_all_possible_actions() -> tuple[dict[str, int], dict[int, str]]:
  """Returns two dicts converting moves to actions and actions to moves.

  These dicts contain all possible chess moves.
  """
  all_moves = []

  # First, deal with the normal moves.
  # Note that this includes castling, as it is just a rook or king move from
  # one square to another.
  board = chess.BaseBoard.empty()
  for square in range(64):
    next_squares = []

    # Place the queen and see where it attacks (we don't need to cover the
    # case for a bishop, rook, or pawn because the queen's moves include all
    # their squares).
    board.set_piece_at(square, chess.Piece.from_symbol('Q'))
    next_squares += board.attacks(square)

    # Place knight and see where it attacks
    board.set_piece_at(square, chess.Piece.from_symbol('N'))
    next_squares += board.attacks(square)
    board.remove_piece_at(square)

    for next_square in next_squares:
      all_moves.append(
          chess.square_name(square) + chess.square_name(next_square)
      )

  # Then deal with promotions.
  # Only look at the last ranks.
  promotion_moves = []
  for rank, next_rank in [('2', '1'), ('7', '8')]:
    for index_file, file in enumerate(_CHESS_FILE):
      # Normal promotions.
      move = f'{file}{rank}{file}{next_rank}'
      promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]

      # Capture promotions.
      # Left side.
      if file > 'a':
        next_file = _CHESS_FILE[index_file - 1]
        move = f'{file}{rank}{next_file}{next_rank}'
        promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]
      # Right side.
      if file < 'h':
        next_file = _CHESS_FILE[index_file + 1]
        move = f'{file}{rank}{next_file}{next_rank}'
        promotion_moves += [(move + piece) for piece in ['q', 'r', 'b', 'n']]
  all_moves += promotion_moves

  move_to_action, action_to_move = {}, {}
  for action, move in enumerate(all_moves):
    assert move not in move_to_action
    move_to_action[move] = action
    action_to_move[action] = move

  return move_to_action, action_to_move


MOVE_TO_ACTION, ACTION_TO_MOVE = _compute_all_possible_actions()
NUM_ACTIONS = len(MOVE_TO_ACTION)


def centipawns_to_win_probability(centipawns: int) -> float:
  """Returns the win probability (in [0, 1]) converted from the centipawn score.

  Reference: https://lichess.org/page/accuracy
  Well-known transformation, backed by real-world data.

  Args:
    centipawns: The chess score in centipawns.
  """
  return 0.5 + 0.5 * (2 / (1 + math.exp(-0.00368208 * centipawns)) - 1)


def get_uniform_buckets_edges_values(
    num_buckets: int,
) -> tuple[np.ndarray, np.ndarray]:
  """Returns edges and values of uniformly sampled buckets in [0, 1].

  Example: for num_buckets=4, it returns:
    edges=[0.25, 0.50, 0.75]
    values=[0.125, 0.375, 0.625, 0.875]

  Args:
    num_buckets: Number of buckets to create.
  """
  full_linspace = np.linspace(0.0, 1.0, num_buckets + 1)
  edges = full_linspace[1:-1]
  values = (full_linspace[:-1] + full_linspace[1:]) / 2
  return edges, values


def compute_return_buckets_from_returns(
    returns: np.ndarray,
    bins_edges: np.ndarray,
) -> np.ndarray:
  """Arranges the discounted returns into bins.

  The returns are put into the bins specified by `bins_edges`. The length of
  `bins_edges` is equal to the number of buckets minus 1. In case of a tie (if
  the return is exactly equal to an edge), we take the bucket right before the
  edge. See the examples below.
  This function purely uses np.searchsorted, so it's a good reference to
  look at.

  Examples:
  * bins_edges=[0.5] and returns=[0., 1.] gives the buckets [0, 1].
  * bins_edges=[-30., 30.] and returns=[-200., -30., 0., 1.] gives the buckets
    [0, 0, 1, 1].

  Args:
    returns: An array of discounted returns, rank 1.
    bins_edges: The boundary values of the return buckets, rank 1.

  Returns:
    An array of buckets, described as integers, rank 1.

  Raises:
    ValueError if `returns` or `bins_edges` are not of rank 1.
  """
  if len(returns.shape) != 1:
    raise ValueError(
        'The passed returns should be of rank 1. Got'
        f' rank={len(returns.shape)}.'
    )
  if len(bins_edges.shape) != 1:
    raise ValueError(
        'The passed bins_edges should be of rank 1. Got'
        f' rank={len(bins_edges.shape)}.'
    )
  return np.searchsorted(bins_edges, returns, side='left')
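Concretely, for num_buckets=4 the two bucket helpers compose like this (a sketch; the values follow from the docstrings above):

import numpy as np

edges, values = get_uniform_buckets_edges_values(4)
# edges  == [0.25, 0.5, 0.75]; values == [0.125, 0.375, 0.625, 0.875]

returns = np.array([0.0, 0.3, 0.5, 1.0])
buckets = compute_return_buckets_from_returns(returns, edges)
# side='left' puts the exact-edge return 0.5 into the bucket before the edge:
# buckets == [0, 1, 1, 3]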
hf_space_repo/train_self_play.py
ADDED
@@ -0,0 +1,72 @@

"""Training script for GRPO chess self-play."""
import argparse
import warnings
from typing import Any
from torch.utils.data import DataLoader

from src.grpo_self_play.trainer import get_trainer
from src.grpo_self_play.chess.boards_dataset import ChessStartStatesDataset
from src.grpo_self_play.grpo_logic.model import GRPOChessTransformer
from src.grpo_self_play.configs.config_loader import load_experiment_config


def train(
    config_path: str = "default.yaml",
    overrides: dict[str, dict[str, Any]] | None = None,
    dataloader_kwargs: dict[str, Any] | None = None,
) -> None:
    """Main training function for GRPO chess self-play.

    Args:
        config_path: Path to the YAML config file (relative to configs directory)
        overrides: Optional dict of overrides per section. Example:
            {
                "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
                "training": {"num_epochs": 100},
                "stockfish": {"skill_level": 5},
            }
        dataloader_kwargs: Optional dict of arguments to pass to the DataLoader
            constructor. These override config values.
            Example: {"batch_size": 64, "num_workers": 4}
    """
    config = load_experiment_config(config_path, overrides=overrides)

    # Build dataloader kwargs from config, with defaults
    dataloader_config = {
        "batch_size": config.training.batch_size,
        "num_workers": 2,
    }

    # Apply dataloader_kwargs overrides and warn when overriding config values
    if dataloader_kwargs:
        for key, value in dataloader_kwargs.items():
            if key in dataloader_config:
                warnings.warn(
                    f"Overriding DataLoader '{key}' from config ({dataloader_config[key]}) "
                    f"with provided value ({value})",
                    UserWarning,
                    stacklevel=2,
                )
            dataloader_config[key] = value

    trainer = get_trainer(num_epochs=config.training.num_epochs)
    dataset = ChessStartStatesDataset(config.dataset)
    dataloader = DataLoader(dataset, **dataloader_config)
    model = GRPOChessTransformer(
        transformer_config=config.transformer,
        grpo_config=config.grpo,
        eval_cfg=config.eval,
        stockfish_cfg=config.stockfish,
        policy_cfg=config.policy,
        searcher_cfg=config.searcher,
        pretrain_cfg=config.pretrain,
    )

    print("Starting Training with WandB Tracking...")
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="default.yaml")
    args = parser.parse_args()
    train(config_path=args.config)
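As a usage sketch, the overrides and DataLoader hooks from the docstring look like this in practice (all values here are illustrative):

# Illustrative values only.
train(
    config_path="default.yaml",
    overrides={
        "grpo": {"lr": 1e-4, "entropy_coef": 0.2},
        "training": {"num_epochs": 100},
    },
    dataloader_kwargs={"batch_size": 64, "num_workers": 4},  # warns when overriding config keys
)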
hf_space_repo/trainer.py
ADDED
@@ -0,0 +1,74 @@

import time
import random
import string
import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint


def generate_run_name(project: str = "chess-grpo") -> str:
    """Generate a unique run name with timestamp and random suffix.

    Args:
        project: Project name prefix

    Returns:
        Unique run name string
    """
    timestamp = time.strftime("%Y%m%d-%H%M")
    random_suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=4))
    return f"{project}-{timestamp}-{random_suffix}"


def get_trainer(num_epochs: int = 5000,
                checkpoint_dir: str = "/content/drive/MyDrive/data/grpo-chess/checkpoints/",
                checkpoint_every_n_epochs: int = 5,
                keep_n_checkpoints: int = 3) -> pl.Trainer:
    """Create a PyTorch Lightning trainer with WandB logging and checkpointing.

    Args:
        num_epochs: Maximum number of training epochs
        checkpoint_dir: Directory to save model checkpoints
        checkpoint_every_n_epochs: Save a periodic checkpoint every N epochs
        keep_n_checkpoints: Keep the best N periodic checkpoints per run

    Returns:
        Configured PyTorch Lightning trainer
    """
    run_name = generate_run_name()
    print(f"Generated run name: {run_name}")

    wandb_logger = WandbLogger(project="Chess-GRPO-Bot", log_model=True, name=run_name)

    # Best checkpoint - saves the top 2 based on loss
    best_checkpoint_cb = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=run_name + "-best-{epoch:02d}-{train_total_loss:.4f}",
        save_top_k=2,
        monitor="train_total_loss",
        mode="min"
    )

    # Periodic checkpoint for crash recovery. The filename is fixed per run;
    # Lightning versions duplicates and keeps the `keep_n_checkpoints` best
    # by train_total_loss, plus the very last checkpoint.
    periodic_checkpoint_cb = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=run_name + "-periodic",
        save_top_k=keep_n_checkpoints,
        monitor="train_total_loss",
        mode="min",
        every_n_epochs=checkpoint_every_n_epochs,
        save_last=True,  # Always keep the very last checkpoint
    )

    return pl.Trainer(
        max_epochs=num_epochs,
        # Gradient clipping handled manually in GRPOChessTransformer.training_step
        accelerator="auto",
        devices=1,
        logger=wandb_logger,
        callbacks=[best_checkpoint_cb, periodic_checkpoint_cb],
        log_every_n_steps=1  # Log every step for GRPO debugging
    )
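Typical call, as a sketch; the checkpoint path is illustrative and the logger assumes W&B credentials are configured in the environment:

# Sketch: assumes W&B credentials are available.
trainer = get_trainer(
    num_epochs=100,
    checkpoint_dir="./checkpoints/",  # illustrative path
    checkpoint_every_n_epochs=10,
    keep_n_checkpoints=3,
)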
requirements.txt
CHANGED
@@ -4,9 +4,5 @@ torch
 safetensors
 python-chess
 huggingface_hub
-
-mcp>=0.9.0
-wandb>=0.16.0
+numpy
 jaxtyping
-datasets
-gradio>=4.44.1