Commit 8a7719b · Parent: d7e086f

Integrated the template, made evaluation more robust to different tokenizers
Files changed:

- .gitignore (+56, -0)
- TEMPLATE_README.md (+152, -0)
- app.py (+7, -1)
- pyproject.toml (+59, -0)
- src/__init__.py (+10, -1)
- src/data.py (+253, -0)
- src/evaluate.py (+270, -132)
- src/train.py (+250, -0)
- src/utils.py (+305, -0)
- submit.py (+144, -0)
.gitignore
ADDED

```
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/
*.egg

# Virtual environments
.venv/
venv/
env/
ENV/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Model outputs and training artifacts
my_model/
checkpoints/
runs/
wandb/
*.pth
*.pt
*.safetensors
*.bin

# Dataset caches
.cache/
*.arrow
*.parquet

# Jupyter
.ipynb_checkpoints/
*.ipynb

# Testing
.pytest_cache/
.coverage
htmlcov/

# Logs
*.log
logs/

# Environment variables
.env
.env.local
```
TEMPLATE_README.md
ADDED

# Chess Challenge

Train a 1M parameter LLM to play chess!

## Objective

Design and train a transformer-based language model to predict chess moves. Your model must:

1. **Stay under 1M parameters** - This is the hard constraint!
2. **Use a custom tokenizer** - Design an efficient move-level tokenizer
3. **Play legal chess** - The model should learn to generate valid moves
4. **Beat Stockfish** - Your ELO will be measured against Stockfish Level 1

## Dataset

We use the Lichess dataset: [`dlouapre/lichess_2025-01_1M`](https://huggingface.co/datasets/dlouapre/lichess_2025-01_1M)

The dataset uses an extended UCI notation:

- `W`/`B` prefix for White/Black
- Piece letter: `P`=Pawn, `N`=Knight, `B`=Bishop, `R`=Rook, `Q`=Queen, `K`=King
- Source and destination squares (e.g., `e2e4`)
- Special suffixes: `(x)`=capture, `(+)`=check, `(+*)`=checkmate, `(o)`/`(O)`=castling

Example game:

```
WPe2e4 BPe7e5 WNg1f3 BNb8c6 WBf1b5 BPa7a6 WBb5c6(x) BPd7c6(x) ...
```
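The notation is regular enough that a single regular expression can split a token into its fields. A minimal sketch (the pattern and group names are illustrative, not part of the template, and promotion markers are omitted for brevity):

```python
import re

# One token of the extended UCI notation: color, piece, source square,
# destination square, plus optional annotations such as (x), (+), (+*), (o).
MOVE_RE = re.compile(
    r"^(?P<color>[WB])(?P<piece>[PNBRQK])"
    r"(?P<src>[a-h][1-8])(?P<dst>[a-h][1-8])"
    r"(?P<suffix>(\([^)]*\))*)$"
)

def parse_move(token: str) -> dict:
    """Split one move token into its named components."""
    m = MOVE_RE.match(token)
    if m is None:
        raise ValueError(f"not a valid move token: {token!r}")
    return m.groupdict()

print(parse_move("WBb5c6(x)"))
# {'color': 'W', 'piece': 'B', 'src': 'b5', 'dst': 'c6', 'suffix': '(x)'}
```

A parser like this is one way to sanity-check a custom tokenizer's vocabulary against the raw dataset.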
## Quick Start

### Train a Model

```bash
# Basic training
python -m src.train \
    --output_dir ./my_model \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32
```

### Evaluate Your Model

Evaluation happens in two phases:

```bash
# Phase 1: Legal Move Evaluation (quick sanity check)
python -m src.evaluate \
    --model_path ./my_model \
    --mode legal \
    --n_positions 500

# Phase 2: Win Rate Evaluation (full games against Stockfish)
python -m src.evaluate \
    --model_path ./my_model \
    --mode winrate \
    --n_games 100 \
    --stockfish_level 1

# Or run both phases:
python -m src.evaluate \
    --model_path ./my_model \
    --mode both
```
## Parameter Budget

Use the utility function to check your budget:

```python
from src import ChessConfig, print_parameter_budget

config = ChessConfig(
    vocab_size=1200,
    n_embd=128,
    n_layer=4,
    n_head=4,
)
print_parameter_budget(config)
```
### Pro Tips

1. **Weight Tying**: The default config ties the embedding and output layer weights, saving ~154k parameters
2. **Vocabulary Size**: Keep it small! ~1200 tokens covers all moves
3. **Depth vs Width**: With limited parameters, experiment with shallow-but-wide vs deep-but-narrow
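The ~154k figure for weight tying is simply `vocab_size × n_embd` = 1200 × 128 = 153,600 shared parameters. A rough, self-contained budget check (a sketch of a GPT-style layer accounting; the template's `print_parameter_budget` may count slightly differently):

```python
def rough_param_count(vocab_size, n_embd, n_layer, n_head,
                      n_inner=None, n_positions=256, tie_weights=True):
    """Approximate parameter count for a small GPT-style decoder.

    n_head does not affect the count: q/k/v have the same total size
    regardless of how they are split across heads.
    """
    n_inner = n_inner or 3 * n_embd
    embed = vocab_size * n_embd + n_positions * n_embd  # token + position tables
    attn = 4 * n_embd * n_embd + 4 * n_embd             # q/k/v/out weights + biases
    mlp = 2 * n_embd * n_inner + n_inner + n_embd       # up/down projections + biases
    norms = 2 * (2 * n_embd)                            # two LayerNorms per block
    lm_head = 0 if tie_weights else vocab_size * n_embd
    return embed + n_layer * (attn + mlp + norms) + 2 * n_embd + lm_head

tied = rough_param_count(1200, 128, 4, 4)
untied = rough_param_count(1200, 128, 4, 4, tie_weights=False)
print(f"{tied:,} params tied; tying saves {untied - tied:,}")
```

With the default shapes this lands comfortably under the 1M budget, and the tied/untied difference reproduces the ~154k saving.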
## Customization

### Custom Tokenizer

The template provides a move-level tokenizer that builds its vocabulary from the actual dataset.
Feel free to try different approaches!

### Custom Architecture

Modify the model in `src/model.py`:

```python
from src import ChessConfig, ChessForCausalLM

# Customize configuration
config = ChessConfig(
    vocab_size=1200,
    n_embd=128,      # Try 96, 128, or 192
    n_layer=4,       # Try 3, 4, or 6
    n_head=4,        # Try 4 or 8
    n_inner=384,     # Feed-forward dimension (default: 3*n_embd)
    dropout=0.1,
    tie_weights=True,
)

model = ChessForCausalLM(config)
```
## Evaluation Metrics

### Phase 1: Legal Move Evaluation

Tests if your model generates valid chess moves:

| Metric | Description |
|--------|-------------|
| **Legal Rate (1st try)** | % of legal moves on first attempt |
| **Legal Rate (with retry)** | % of legal moves within 3 attempts |

> **Target**: >90% legal rate before proceeding to Phase 2

### Phase 2: Win Rate Evaluation

Full games against Stockfish to measure playing strength:

| Metric | Description |
|--------|-------------|
| **Win Rate** | % of games won against Stockfish |
| **ELO Rating** | Estimated rating based on game results |
| **Avg Game Length** | Average number of moves per game |
| **Illegal Move Rate** | % of illegal moves during games |
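An ELO estimate of this kind is typically obtained by inverting the expected-score formula against an opponent of known strength. A sketch (the 800 reference rating for Stockfish Level 1 is an assumption for illustration, not a value from the template):

```python
import math

def elo_estimate(score: float, opponent_elo: float = 800.0) -> float:
    """Invert the ELO expected-score model E = 1 / (1 + 10**((Ro - Rp) / 400)).

    `score` is (wins + 0.5 * draws) / games; clamped to avoid infinities
    at 0% or 100%.
    """
    score = min(max(score, 0.01), 0.99)
    return opponent_elo + 400.0 * math.log10(score / (1.0 - score))

print(round(elo_estimate(0.5)))   # an even score implies the opponent's rating: 800
```

A 50% score maps exactly to the opponent's rating, and each additional 400 points of estimated difference corresponds to 10:1 odds.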
## Submission

1. Train your model
2. Log in to Hugging Face: `hf auth login`
3. Submit your model using the submission script:

```bash
python submit.py --model_path ./my_model/final_model --model_name your-model-name
```

The script will:

- Upload your model to the LLM-course organization
- Include your HF username in the model card for tracking
app.py
CHANGED

````diff
@@ -591,7 +591,13 @@ with gr.Blocks(
 The goal is to create a chess-playing language model with **under 1 million parameters**, which is roughly the number of neurons in a honey bee's brain.
 At this scale, efficiency and clever architecture choices are key! We are not targetting superhuman performance, but rather exploring how well small models can learn the rules of chess, the goal being (only) to play **legal moves**.
 
-
+0. **Clone this repository**:
+```bash
+git clone https://huggingface.co/spaces/LLM-course/Chess1MChallenge
+```
+and check the `TEMPLATE_README.md` for detailed instructions.
+
+1. **Train your model**
 
 2. **Push to Hugging Face Hub** using the `submit.py` script provided in the template to make sure that your model is registered correctly.
 
````
pyproject.toml
ADDED

```toml
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "chess-challenge"
version = "0.1.0"
description = "LLM Chess Challenge - Train a 1M parameter model to play chess"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.10"
authors = [
    {name = "Nathanaël Fijalkow", email = "nathanael.fijalkow@gmail.com"}
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Education",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]
dependencies = [
    "torch>=2.0.0",
    "transformers>=4.40.0",
    "accelerate>=0.26.0",
    "datasets>=2.14.0",
    "python-chess>=1.999",
    "huggingface-hub>=0.20.0",
    "tqdm>=4.65.0",
    "numpy>=1.24.0",
    "wandb>=0.15.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "black>=23.0.0",
    "ruff>=0.1.0",
]
eval = [
    "stockfish>=3.28.0",
]

[project.scripts]
chess-train = "src.train:main"
chess-eval = "src.evaluate:main"

[tool.setuptools.packages.find]
where = ["src"]

[tool.black]
line-length = 100
target-version = ["py310"]

[tool.ruff]
line-length = 100
select = ["E", "F", "I", "W"]
```
src/__init__.py
CHANGED

```diff
@@ -2,7 +2,16 @@
 
 from .model import ChessConfig, ChessForCausalLM
 from .tokenizer import ChessTokenizer
-
+
+# Lazy import for evaluate to avoid RuntimeWarning when running as module
+def __getattr__(name):
+    if name == "ChessEvaluator":
+        from .evaluate import ChessEvaluator
+        return ChessEvaluator
+    if name == "load_model_from_hub":
+        from .evaluate import load_model_from_hub
+        return load_model_from_hub
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 __all__ = [
     "ChessConfig",
```
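The `__getattr__` hook added here is module-level attribute lookup (PEP 562): `from src import ChessEvaluator` only imports `src.evaluate` at access time, so `python -m src.evaluate` does not eagerly re-import the module it is running. A standalone illustration of the mechanism with toy names (not the template's own code):

```python
# A module that defers a "heavy" import until first attribute access (PEP 562).
import sys
import types

mod = types.ModuleType("lazymod")

def _getattr(name):
    if name == "sqrt":            # pretend `math` were expensive to import
        import math
        return math.sqrt
    raise AttributeError(f"module 'lazymod' has no attribute {name!r}")

# Placing __getattr__ in the module's namespace activates the PEP 562 hook.
mod.__getattr__ = _getattr
sys.modules["lazymod"] = mod

import lazymod
print(lazymod.sqrt(9.0))   # 3.0 - `math` was imported only at this access
```

In a real package the same effect comes from defining `__getattr__` at the top level of `__init__.py`, exactly as the diff above does.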
src/data.py
ADDED

```python
"""
Data loading utilities for the Chess Challenge.

This module provides functions to load and process chess game data
from the Lichess dataset on Hugging Face.
"""

from __future__ import annotations

from typing import Dict, Iterator, List, Optional

import torch
from torch.utils.data import Dataset


class ChessDataset(Dataset):
    """
    PyTorch Dataset for chess games.

    This dataset loads games from a Hugging Face dataset and prepares
    them for language modeling training.

    Each game is tokenized and truncated/padded to max_length.
    The labels are shifted by one position for next-token prediction.

    Example:
        >>> from src.tokenizer import ChessTokenizer
        >>> tokenizer = ChessTokenizer.build_vocab_from_dataset()
        >>> dataset = ChessDataset(tokenizer, max_length=256)
        >>> sample = dataset[0]
        >>> print(sample["input_ids"].shape)  # (256,)
    """

    def __init__(
        self,
        tokenizer,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        max_length: int = 256,
        max_samples: Optional[int] = None,
    ):
        """
        Initialize the chess dataset.

        Args:
            tokenizer: The chess tokenizer to use.
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            max_length: Maximum sequence length.
            max_samples: Maximum number of samples to load.
        """
        from datasets import load_dataset

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.column = column

        # Load dataset
        dataset = load_dataset(dataset_name, split=split)

        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        self.data = dataset

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        game = self.data[idx][self.column]

        # Prepend BOS token for proper language modeling
        game_with_bos = self.tokenizer.bos_token + " " + game

        # Tokenize
        encoding = self.tokenizer(
            game_with_bos,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        # Squeeze batch dimension
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Labels are the same as input_ids (model will shift internally)
        labels = input_ids.clone()

        # Set padding tokens to -100 to ignore in loss
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


class ChessDataCollator:
    """
    Data collator for chess games.

    This collator pads sequences to the same length within a batch
    and creates the appropriate attention masks.
    """

    def __init__(self, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        # Stack tensors
        input_ids = torch.stack([f["input_ids"] for f in features])
        attention_mask = torch.stack([f["attention_mask"] for f in features])
        labels = torch.stack([f["labels"] for f in features])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


def create_train_val_datasets(
    tokenizer,
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_length: int = 256,
    train_samples: Optional[int] = None,
    val_samples: int = 5000,
    val_ratio: float = 0.05,
):
    """
    Create training and validation datasets.

    Args:
        tokenizer: The chess tokenizer.
        dataset_name: Name of the dataset.
        max_length: Maximum sequence length.
        train_samples: Maximum training samples (None for all).
        val_samples: Number of validation samples.
        val_ratio: Ratio of validation samples (used if train_samples is None).

    Returns:
        Tuple of (train_dataset, val_dataset).
    """
    from datasets import load_dataset

    # Load full dataset
    full_dataset = load_dataset(dataset_name, split="train")

    # Determine split sizes
    total = len(full_dataset)

    if train_samples is not None:
        n_train = min(train_samples, total - val_samples)
    else:
        n_train = int(total * (1 - val_ratio))

    n_val = min(val_samples, total - n_train)

    # Split dataset
    train_data = full_dataset.select(range(n_train))
    val_data = full_dataset.select(range(n_train, n_train + n_val))

    # Create dataset objects
    train_dataset = ChessDataset(
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        max_length=max_length,
    )
    train_dataset.data = train_data

    val_dataset = ChessDataset(
        tokenizer=tokenizer,
        dataset_name=dataset_name,
        max_length=max_length,
    )
    val_dataset.data = val_data

    return train_dataset, val_dataset


def stream_games(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
) -> Iterator[str]:
    """
    Stream games from the dataset for memory-efficient processing.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.

    Yields:
        Game strings one at a time.
    """
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split, streaming=True)

    for example in dataset:
        yield example[column]


def analyze_dataset_statistics(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    max_samples: int = 10000,
) -> Dict:
    """
    Analyze statistics of the chess dataset.

    Args:
        dataset_name: Name of the dataset.
        max_samples: Maximum number of samples to analyze.

    Returns:
        Dictionary containing dataset statistics.
    """
    from collections import Counter
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split="train")
    dataset = dataset.select(range(min(max_samples, len(dataset))))

    game_lengths = []
    move_counts = Counter()
    opening_moves = Counter()

    for example in dataset:
        moves = example["text"].strip().split()
        game_lengths.append(len(moves))
        move_counts.update(moves)

        # Track common openings (first 4 moves)
        if len(moves) >= 4:
            opening = " ".join(moves[:4])
            opening_moves[opening] += 1

    return {
        "total_games": len(dataset),
        "avg_game_length": sum(game_lengths) / len(game_lengths),
        "min_game_length": min(game_lengths),
        "max_game_length": max(game_lengths),
        "unique_moves": len(move_counts),
        "most_common_moves": move_counts.most_common(20),
        "most_common_openings": opening_moves.most_common(10),
    }
```
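The `-100` sentinel written into `labels` in `__getitem__` is the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions contribute nothing to the training loss. The masking rule in plain Python (names and shapes are illustrative, mirroring the dataset code without torch):

```python
# Mirror of the ChessDataset labeling rule: labels copy input_ids, but
# positions masked out by attention_mask become -100 so a loss with
# ignore_index=-100 skips them.
def make_labels(input_ids, attention_mask, ignore_index=-100):
    return [
        tok if mask == 1 else ignore_index
        for tok, mask in zip(input_ids, attention_mask)
    ]

input_ids      = [1, 57, 312, 88, 0, 0]   # last two positions are [PAD]
attention_mask = [1,  1,   1,  1, 0, 0]
print(make_labels(input_ids, attention_mask))
# [1, 57, 312, 88, -100, -100]
```

The tensor version in `ChessDataset` does exactly this with `labels[attention_mask == 0] = -100`.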
src/evaluate.py
CHANGED

```diff
@@ -9,6 +9,7 @@ from __future__ import annotations
 
 import argparse
 import random
+
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
@@ -23,16 +24,23 @@ class GameResult:
     model_color: str  # "white" or "black"
     termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
     illegal_move_count: int
-
-
 class ChessEvaluator:
     """
     Evaluator for chess models.
 
     This class handles playing games between a trained model and Stockfish,
     tracking results, and computing ELO ratings.
     """
 
     def __init__(
         self,
         model,
@@ -88,10 +96,100 @@ class ChessEvaluator:
         if hasattr(self, 'engine') and self.engine:
             self.engine.quit()
 
     def _convert_board_to_moves(self, board) -> str:
-        """
         moves = []
         temp_board = self.chess.Board()
 
         for move in board.move_stack:
             # Get piece and color
@@ -103,29 +201,44 @@ class ChessEvaluator:
             from_sq = self.chess.square_name(move.from_square)
             to_sq = self.chess.square_name(move.to_square)
 
-
-
-            # Add promotion
             if move.promotion:
-
-
-            # Add capture suffix
-            if temp_board.is_capture(move):
-                move_str += "(x)"
 
-            #
-
-            if temp_board.is_checkmate():
-                move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
-            elif temp_board.is_check():
-                move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"
 
-            #
-            if
-
 
             moves.append(move_str)
 
@@ -160,6 +273,65 @@ class ChessEvaluator:
 
         return False
 
     def _generate_move_tokens(
         self,
         input_ids: torch.Tensor,
@@ -168,11 +340,12 @@ class ChessEvaluator:
         max_tokens: int = 20,
     ) -> str:
         """
-        Generate tokens until a
 
-        This method
-
 
         Args:
             input_ids: The input token IDs.
@@ -181,10 +354,11 @@ class ChessEvaluator:
            max_tokens: Maximum tokens to generate for a single move.
 
         Returns:
-            The generated move string
         """
         generated_tokens = []
         current_ids = input_ids.clone()
 
         for _ in range(max_tokens):
             with torch.no_grad():
```
```diff
@@ -193,31 +367,47 @@ class ChessEvaluator:
 
                 # Apply top-k filtering
                 if top_k > 0:
-
-                    indices_to_remove = logits <
                     logits[indices_to_remove] = float("-inf")
 
                 # Sample
                 probs = torch.softmax(logits, dim=-1)
-                next_token = torch.multinomial(probs, num_samples=1)
 
                 # Decode the token
                 token_str = self.tokenizer.decode(next_token[0])
 
                 # Check if this is a separator token
                 if self._is_separator_token(token_str):
 
-
                 current_ids = torch.cat([current_ids, next_token], dim=-1)
 
-                #
 
             # Decode all generated tokens together
             if generated_tokens:
```
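The top-k step reworked in this hunk keeps only the k largest logits, sets the rest to negative infinity, and samples from the temperature-scaled softmax over what remains. The same logic in plain Python (a toy sketch with list logits; the template's version operates on torch tensors and its exact cutoff expression is not recoverable here):

```python
import math
import random

def sample_top_k(logits, k=3, temperature=0.7, rng=random.Random(0)):
    """Temperature-scaled softmax sampling restricted to the top-k logits."""
    cutoff = sorted(logits, reverse=True)[k - 1]           # k-th largest logit
    filtered = [x if x >= cutoff else float("-inf") for x in logits]
    scaled = [x / temperature for x in filtered]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]               # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

idx = sample_top_k([2.0, 0.1, -1.0, 1.5, 0.3], k=2)
print(idx)   # only indices 0 or 3 can ever be drawn
```

Masked positions get probability exactly zero because `exp(-inf)` is 0.0, which is the same effect the tensor code achieves with `float("-inf")` before `torch.softmax`.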
```diff
@@ -236,11 +426,15 @@ class ChessEvaluator:
         """
         Get the model's next move prediction.
 
-        This method
-
 
         Returns:
             Tuple of (UCI move string, number of retries used).
@@ -257,32 +451,26 @@ class ChessEvaluator:
         input_text = self.tokenizer.bos_token + " " + moves_str
 
         # Tokenize
-        max_len = getattr(self.model.config, 'n_ctx', None) or getattr(self.model.config, 'max_position_embeddings', 256)
         inputs = self.tokenizer(
             input_text,
             return_tensors="pt",
             truncation=True,
-            max_length=
         ).to(self.device)
 
         # Try to generate a legal move
         for retry in range(self.max_retries):
-            # Generate tokens until
                 inputs["input_ids"],
                 temperature=temperature,
                 top_k=top_k,
             )
 
-            # Handle promotion
-            if "=" in move_token:
-                promo_idx = move_token.index("=")
-                uci_move += move_token[promo_idx + 1].lower()
-
             try:
                 move = self.chess.Move.from_uci(uci_move)
                 if move in board.legal_moves:
```
```diff
@@ -390,7 +578,6 @@ class ChessEvaluator:
         n_positions: int = 1000,
         temperature: float = 0.7,
         verbose: bool = True,
-        seed: int = 42,
     ) -> dict:
         """
         Evaluate the model's ability to generate legal moves.
@@ -402,14 +589,10 @@ class ChessEvaluator:
             n_positions: Number of positions to test.
             temperature: Sampling temperature.
             verbose: Whether to print progress.
-            seed: Random seed for reproducibility.
 
         Returns:
             Dictionary with legal move statistics.
         """
-        # Set seed for deterministic evaluation
-        random.seed(seed)
-
         results = {
             "total_positions": 0,
             "legal_first_try": 0,
```
```diff
@@ -572,73 +755,24 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
     Returns:
         Tuple of (model, tokenizer).
     """
-    import
-    from huggingface_hub import hf_hub_download
-    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
-    # Import custom classes
-
-
-        from src.tokenizer import ChessTokenizer
-    except ImportError:
-        from .model import ChessConfig, ChessForCausalLM
-        from .tokenizer import ChessTokenizer
 
-    #
     try:
-
-    except
-
-        try:
-            AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
-        except ValueError:
-            pass
-
-    print(f"Loading model {model_id}...")
-
-    # Download and load config manually to avoid transformers auto-detection issues
-    config_path = hf_hub_download(repo_id=model_id, filename="config.json")
-    with open(config_path, "r") as f:
-        config_dict = json.load(f)
-
-    # Remove fields that are not in ChessConfig to avoid unexpected kwargs
-    config_dict.pop("model_type", None)
-    config_dict.pop("architectures", None)
-    config_dict.pop("transformers_version", None)
-    config_dict.pop("dtype", None)
-    config_dict.pop("torch_dtype", None)
-
-    config = ChessConfig(**config_dict)
 
-    model = ChessForCausalLM.from_pretrained(
         model_id,
-
         device_map=device,
     )
 
-    # Load tokenizer - try to find vocab.json, else build default
-    try:
-        tokenizer = ChessTokenizer.from_pretrained(model_id)
-    except Exception as e:
-        print(f"ChessTokenizer.from_pretrained failed: {e}")
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-        except Exception as e2:
-            print(f"AutoTokenizer also failed: {e2}")
-            print("Creating default tokenizer with vocab_size from config...")
-            # Create a minimal tokenizer with just the vocab size
-            tokenizer = ChessTokenizer()
-            # Ensure vocab size matches model
-            if hasattr(config, 'vocab_size'):
-                # Build a placeholder vocab of the right size
-                tokenizer._vocab = {f"[MOVE_{i}]": i for i in range(config.vocab_size)}
-                tokenizer._vocab["[PAD]"] = 0
-                tokenizer._vocab["[BOS]"] = 1
-                tokenizer._vocab["[EOS]"] = 2
-                tokenizer._vocab["[UNK]"] = 3
-                tokenizer._ids_to_tokens = {v: k for k, v in tokenizer._vocab.items()}
-
     return model, tokenizer
```
@@ -684,21 +818,25 @@ def main():
     # Load model
     print(f"\nLoading model from: {args.model_path}")

     # Local path
     from transformers import AutoModelForCausalLM
-    try:
-        from src.tokenizer import ChessTokenizer
-        from src.model import ChessConfig, ChessForCausalLM
-    except ImportError:
-        from .tokenizer import ChessTokenizer
-        from .model import ChessConfig, ChessForCausalLM

     tokenizer = ChessTokenizer.from_pretrained(args.model_path)
     model = AutoModelForCausalLM.from_pretrained(args.model_path)

     # Create evaluator
     print(f"\nSetting up evaluator...")
@@ +9 @@

 import argparse
 import random
+import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

@@ +24 @@
     model_color: str  # "white" or "black"
     termination: str  # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
     illegal_move_count: int
+
+
 class ChessEvaluator:
     """
     Evaluator for chess models.

     This class handles playing games between a trained model and Stockfish,
     tracking results, and computing ELO ratings.
+
+    Supports any tokenization format as long as the model generates valid
+    chess squares (e.g., e2, e4). The evaluator extracts UCI moves by finding
+    square patterns in the generated output.
     """

+    # Regex pattern to match chess squares
+    SQUARE_PATTERN = r'[a-h][1-8]'
+
     def __init__(
         self,
         model,
|
| 97 |
self.engine.quit()
|
| 98 |
|
| 99 |
+
def _detect_tokenizer_format(self) -> str:
|
| 100 |
+
"""
|
| 101 |
+
Detect the tokenizer's expected move format by testing tokenization.
|
| 102 |
+
|
| 103 |
+
Tests various formats with a sample move and picks the one that
|
| 104 |
+
produces the fewest unknown tokens. This makes evaluation work
|
| 105 |
+
with any tokenizer format.
|
| 106 |
+
|
| 107 |
+
Supported formats:
|
| 108 |
+
- 'decomposed': "WP e2_f e4_t" (piece, from_suffix, to_suffix)
|
| 109 |
+
- 'standard': "WPe2e4" (combined with optional annotations)
|
| 110 |
+
- 'uci': "e2e4" (pure UCI notation)
|
| 111 |
+
- 'uci_spaced': "e2 e4" (UCI with space separator)
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
The format string that best matches the tokenizer's vocabulary.
|
| 115 |
+
"""
|
| 116 |
+
if hasattr(self, '_cached_format'):
|
| 117 |
+
return self._cached_format
|
| 118 |
+
|
| 119 |
+
# Sample move representations to test
|
| 120 |
+
test_formats = {
|
| 121 |
+
'decomposed': "WP e2_f e4_t",
|
| 122 |
+
'standard': "WPe2e4",
|
| 123 |
+
'uci': "e2e4",
|
| 124 |
+
'uci_spaced': "e2 e4",
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
unk_token_id = getattr(self.tokenizer, 'unk_token_id', None)
|
| 128 |
+
best_format = 'standard'
|
| 129 |
+
min_unk_count = float('inf')
|
| 130 |
+
|
| 131 |
+
for fmt, sample in test_formats.items():
|
| 132 |
+
try:
|
| 133 |
+
tokens = self.tokenizer.encode(sample, add_special_tokens=False)
|
| 134 |
+
# Count unknown tokens
|
| 135 |
+
unk_count = tokens.count(unk_token_id) if unk_token_id is not None else 0
|
| 136 |
+
# Also penalize if the entire thing became one UNK
|
| 137 |
+
if len(tokens) == 1 and unk_count == 1:
|
| 138 |
+
unk_count = 100 # Heavy penalty
|
| 139 |
+
|
| 140 |
+
if unk_count < min_unk_count:
|
| 141 |
+
min_unk_count = unk_count
|
| 142 |
+
best_format = fmt
|
| 143 |
+
except Exception:
|
| 144 |
+
continue
|
| 145 |
+
|
| 146 |
+
self._cached_format = best_format
|
| 147 |
+
return best_format
|
| 148 |
+
|
| 149 |
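The UNK-counting heuristic above can be exercised standalone. The sketch below mirrors the selection logic with a hypothetical toy `encode` function (not the repo's tokenizer): encode each candidate sample and keep the format that produces the fewest unknown-token ids.

```python
def detect_format(encode, unk_id):
    """Pick the move format whose sample encodes with the fewest UNK tokens."""
    samples = {
        "decomposed": "WP e2_f e4_t",
        "standard": "WPe2e4",
        "uci": "e2e4",
        "uci_spaced": "e2 e4",
    }
    best_fmt, min_unk = "standard", float("inf")
    for fmt, text in samples.items():
        ids = encode(text)
        unk = ids.count(unk_id)
        if len(ids) == 1 and unk == 1:
            unk = 100  # whole sample collapsed into a single UNK: heavy penalty
        if unk < min_unk:
            min_unk, best_fmt = unk, fmt
    return best_fmt

# Toy whitespace tokenizer that only knows the token "e2e4" (id 5); 0 is UNK.
vocab = {"e2e4": 5}
encode = lambda text: [vocab.get(tok, 0) for tok in text.split()]

print(detect_format(encode, 0))  # uci
```

With this vocabulary, only the pure-UCI sample encodes without unknowns, so `'uci'` wins; a tokenizer trained on `"WPe2e4"`-style tokens would instead select `'standard'`.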
+
def _format_move(self, color: str, piece: str, from_sq: str, to_sq: str,
|
| 150 |
+
promotion: str = None) -> str:
|
| 151 |
+
"""
|
| 152 |
+
Format a single move according to the detected tokenizer format.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
color: 'W' or 'B'
|
| 156 |
+
piece: Piece letter (P, N, B, R, Q, K)
|
| 157 |
+
from_sq: Source square (e.g., 'e2')
|
| 158 |
+
to_sq: Destination square (e.g., 'e4')
|
| 159 |
+
promotion: Promotion piece letter or None
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Formatted move string.
|
| 163 |
+
"""
|
| 164 |
+
fmt = self._detect_tokenizer_format()
|
| 165 |
+
|
| 166 |
+
if fmt == 'decomposed':
|
| 167 |
+
move_str = f"{color}{piece} {from_sq}_f {to_sq}_t"
|
| 168 |
+
elif fmt == 'uci':
|
| 169 |
+
move_str = f"{from_sq}{to_sq}"
|
| 170 |
+
if promotion:
|
| 171 |
+
move_str += promotion.lower()
|
| 172 |
+
elif fmt == 'uci_spaced':
|
| 173 |
+
move_str = f"{from_sq} {to_sq}"
|
| 174 |
+
if promotion:
|
| 175 |
+
move_str += f" {promotion.lower()}"
|
| 176 |
+
else: # standard
|
| 177 |
+
move_str = f"{color}{piece}{from_sq}{to_sq}"
|
| 178 |
+
if promotion:
|
| 179 |
+
move_str += f"={promotion}"
|
| 180 |
+
|
| 181 |
+
return move_str
|
| 182 |
+
|
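As a quick check of the four output shapes, here is a dependency-free version of the same dispatch (`format_move` is a standalone stand-in for the method above):

```python
def format_move(fmt, color, piece, from_sq, to_sq, promotion=None):
    """Render one move in the given format (stand-in for _format_move)."""
    if fmt == "decomposed":
        return f"{color}{piece} {from_sq}_f {to_sq}_t"
    if fmt == "uci":
        return f"{from_sq}{to_sq}" + (promotion.lower() if promotion else "")
    if fmt == "uci_spaced":
        return f"{from_sq} {to_sq}" + (f" {promotion.lower()}" if promotion else "")
    # standard: combined token, promotion spelled with '='
    return f"{color}{piece}{from_sq}{to_sq}" + (f"={promotion}" if promotion else "")

print(format_move("standard", "W", "P", "e7", "e8", "Q"))  # WPe7e8=Q
print(format_move("decomposed", "W", "P", "e2", "e4"))     # WP e2_f e4_t
print(format_move("uci", "W", "P", "a7", "a8", "Q"))       # a7a8q
```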
     def _convert_board_to_moves(self, board) -> str:
+        """
+        Convert board move history to model input format.
+
+        Automatically detects the tokenizer's expected format and outputs
+        moves accordingly. Supports any tokenization strategy.
+        """
         moves = []
         temp_board = self.chess.Board()
+        fmt = self._detect_tokenizer_format()

         for move in board.move_stack:
             # Get piece and color
@@ +201 @@
             from_sq = self.chess.square_name(move.from_square)
             to_sq = self.chess.square_name(move.to_square)

+            # Get promotion piece if any
+            promo = None
             if move.promotion:
+                promo = self.chess.piece_symbol(move.promotion).upper()

+            # Format based on detected tokenizer format
+            move_str = self._format_move(color, piece_letter, from_sq, to_sq, promo)

+            # For standard format, add annotations (capture, check, castling)
+            if fmt == 'standard':
+                # Add capture suffix
+                if temp_board.is_capture(move):
+                    move_str += "(x)"
+
+                # Push move to check for check/checkmate
+                temp_board.push(move)
+
+                if temp_board.is_checkmate():
+                    if "(x)" in move_str:
+                        move_str = move_str.replace("(x)", "(x+*)")
+                    else:
+                        move_str += "(+*)"
+                elif temp_board.is_check():
+                    if "(x)" in move_str:
+                        move_str = move_str.replace("(x)", "(x+)")
+                    else:
+                        move_str += "(+)"
+
+                # Handle castling notation
+                if piece_letter == "K":
+                    if abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
+                        if to_sq[0] == 'g':  # Kingside
+                            move_str = move_str.split("(")[0] + "(o)"
+                        else:  # Queenside
+                            move_str = move_str.split("(")[0] + "(O)"
+            else:
+                # For non-standard formats, just push the move
+                temp_board.push(move)

             moves.append(move_str)

@@ +273 @@

         return False
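The standard-format suffix rules above — `(x)` for capture, `(+)` for check, `(+*)` for mate, merged into `(x+)`/`(x+*)` when a capture also gives check or mate — reduce to a small string transform. A standalone sketch, independent of python-chess (the boolean flags stand in for `temp_board.is_capture`/`is_check`/`is_checkmate`):

```python
def annotate(move_str, is_capture, is_check, is_mate):
    """Apply standard-format suffixes: (x) capture, (+) check, (+*) mate."""
    if is_capture:
        move_str += "(x)"
    if is_mate:
        move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
    elif is_check:
        move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"
    return move_str

print(annotate("WQd1h5", is_capture=True, is_check=True, is_mate=False))  # WQd1h5(x+)
print(annotate("WQh5f7", is_capture=True, is_check=False, is_mate=True))  # WQh5f7(x+*)
```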
+    def _extract_uci_move(self, text: str) -> Optional[str]:
+        """
+        Extract a UCI move from generated text using pattern matching.
+
+        This generic method works with any tokenization format by finding
+        chess square patterns ([a-h][1-8]) in the output.
+
+        Supported formats include:
+        - Standard: "WPe2e4" -> "e2e4"
+        - Decomposed: "WP e2_f e4_t" -> "e2e4"
+        - Pure UCI: "e2e4" -> "e2e4"
+        - With separators: "e2-e4", "e2 e4" -> "e2e4"
+        - With promotion: "e7e8=Q", "e7e8q" -> "e7e8q"
+
+        Args:
+            text: The generated text containing a move.
+
+        Returns:
+            UCI move string (e.g., "e2e4", "e7e8q") or None if not found.
+        """
+        if not text:
+            return None
+
+        # Find all squares in the text
+        squares = re.findall(self.SQUARE_PATTERN, text)
+
+        if len(squares) < 2:
+            return None
+
+        # Take the first two squares as from and to
+        from_sq, to_sq = squares[0], squares[1]
+        uci_move = from_sq + to_sq
+
+        # Check for promotion (letter after to_square)
+        # Look for patterns like "=Q", "=q", or just "q" after the to_square
+        to_sq_idx = text.find(to_sq)
+        if to_sq_idx != -1:
+            remaining = text[to_sq_idx + 2:to_sq_idx + 5]  # Check next few chars
+            promo_match = re.search(r'[=]?([qrbnQRBN])', remaining)
+            if promo_match:
+                uci_move += promo_match.group(1).lower()
+
+        return uci_move
+
+    def _has_complete_move(self, text: str) -> bool:
+        """
+        Check if the generated text contains a complete move.
+
+        A complete move has at least two valid chess squares.
+
+        Args:
+            text: The generated text so far.
+
+        Returns:
+            True if text contains at least two squares.
+        """
+        squares = re.findall(self.SQUARE_PATTERN, text)
+        return len(squares) >= 2
+
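The extraction logic is plain `re`; a standalone version of the same steps (module-level `SQUARE` mirrors `SQUARE_PATTERN`, and the promotion lookup copies the method's "scan a few characters past the to-square" approach):

```python
import re

SQUARE = r"[a-h][1-8]"  # mirrors ChessEvaluator.SQUARE_PATTERN

def extract_uci(text):
    """First two squares give from/to; a q/r/b/n right after the to-square is a promotion."""
    squares = re.findall(SQUARE, text or "")
    if len(squares) < 2:
        return None
    uci = squares[0] + squares[1]
    idx = text.find(squares[1])
    promo = re.search(r"=?([qrbnQRBN])", text[idx + 2 : idx + 5])
    if promo:
        uci += promo.group(1).lower()
    return uci

print(extract_uci("WP e2_f e4_t"))  # e2e4
print(extract_uci("e7e8=Q"))        # e7e8q
print(extract_uci("e2"))            # None
```

Note that, like the method above, this takes the *first* occurrence of the to-square string when hunting for a promotion letter, which is a heuristic rather than a full parser.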
     def _generate_move_tokens(
         self,
         input_ids: torch.Tensor,
@@ +340 @@
         max_tokens: int = 20,
     ) -> str:
         """
+        Generate tokens until a complete move is detected or separator is hit.

+        This method is tokenizer-agnostic and stops when:
+        - A separator token (whitespace/EOS) is encountered
+        - Two chess squares have been generated (complete move)
+        - max_tokens limit is reached

         Args:
             input_ids: The input token IDs.
@@ +354 @@
             max_tokens: Maximum tokens to generate for a single move.

         Returns:
+            The generated move string.
         """
         generated_tokens = []
         current_ids = input_ids.clone()
+        accumulated_text = ""

         for _ in range(max_tokens):
             with torch.no_grad():
@@ +367 @@

                 # Apply top-k filtering
                 if top_k > 0:
+                    top_k_vals = torch.topk(logits, min(top_k, logits.size(-1)))
+                    indices_to_remove = logits < top_k_vals[0][..., -1, None]
                     logits[indices_to_remove] = float("-inf")

                 # Sample
                 probs = torch.softmax(logits, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1)

                 # Decode the token
                 token_str = self.tokenizer.decode(next_token[0])

                 # Check if this is a separator token
                 if self._is_separator_token(token_str):
+                    # If we already have a complete move, stop
+                    if self._has_complete_move(accumulated_text):
+                        break
+                    # Otherwise, if it's EOS, we should also stop
+                    if hasattr(self.tokenizer, 'eos_token'):
+                        if token_str == self.tokenizer.eos_token:
+                            break
+                    # For whitespace separators, only stop if we have content
+                    if accumulated_text:
+                        break

+                generated_tokens.append(next_token[0])
                 current_ids = torch.cat([current_ids, next_token], dim=-1)
+                accumulated_text += token_str

+                # Stop if we have a complete move (two squares found)
+                if self._has_complete_move(accumulated_text):
+                    # Check if this might be a promotion - peek for one more token
+                    # if the move is to rank 1 or 8
+                    squares = re.findall(self.SQUARE_PATTERN, accumulated_text)
+                    if len(squares) >= 2:
+                        to_sq = squares[1]
+                        if to_sq[1] in '18':  # Potential promotion
+                            # Allow one more iteration to capture promotion piece
+                            if len(generated_tokens) > 3:  # Already have enough
+                                break
+                        else:
+                            break

         # Decode all generated tokens together
         if generated_tokens:
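The top-k masking step above keeps the k largest logits and sets the rest to `-inf` before the softmax, so only the top-k tokens receive probability mass. It can be sanity-checked without torch; a pure-Python sketch (ties at the threshold survive, matching the `<` comparison in the torch version):

```python
def top_k_filter(logits, k):
    """Keep the k largest logits; set the rest to -inf (pre-softmax mask)."""
    if k <= 0 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]

print(top_k_filter([2.0, 0.5, 1.0, -1.0], k=2))  # [2.0, -inf, 1.0, -inf]
```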
@@ +426 @@
         """
         Get the model's next move prediction.

+        This method is tokenizer-agnostic. It generates tokens and extracts
+        UCI moves using pattern matching on chess squares.
+
+        Works with any tokenization format:
+        - Move-level: "WPe2e4" -> e2e4
+        - Decomposed: "WP e2_f e4_t" -> e2e4
+        - Pure UCI: "e2e4" -> e2e4
+        - Character-level: "e" "2" "e" "4" -> e2e4
+        - BPE/subword: "e2" "e4" -> e2e4

         Returns:
             Tuple of (UCI move string, number of retries used).
@@ +451 @@
         input_text = self.tokenizer.bos_token + " " + moves_str

         # Tokenize
         inputs = self.tokenizer(
             input_text,
             return_tensors="pt",
             truncation=True,
+            max_length=self.model.config.n_ctx - 10,
         ).to(self.device)

         # Try to generate a legal move
         for retry in range(self.max_retries):
+            # Generate tokens until we have a move
+            move_text = self._generate_move_tokens(
                 inputs["input_ids"],
                 temperature=temperature,
                 top_k=top_k,
             )

+            # Extract UCI move using generic pattern matching
+            uci_move = self._extract_uci_move(move_text)
+
+            if uci_move:
                 try:
                     move = self.chess.Move.from_uci(uci_move)
                     if move in board.legal_moves:
@@ +578 @@
         n_positions: int = 1000,
         temperature: float = 0.7,
         verbose: bool = True,
     ) -> dict:
         """
         Evaluate the model's ability to generate legal moves.
@@ +589 @@
             n_positions: Number of positions to test.
             temperature: Sampling temperature.
             verbose: Whether to print progress.

         Returns:
             Dictionary with legal move statistics.
         """
         results = {
             "total_positions": 0,
             "legal_first_try": 0,
@@ +755 @@
     Returns:
         Tuple of (model, tokenizer).
     """
+    from transformers import AutoModelForCausalLM, AutoTokenizer

+    # Import to register custom classes
+    from src.model import ChessConfig, ChessForCausalLM
+    from src.tokenizer import ChessTokenizer

+    # Try loading with custom tokenizer first, fall back to AutoTokenizer
     try:
+        tokenizer = ChessTokenizer.from_pretrained(model_id)
+    except Exception:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

+    model = AutoModelForCausalLM.from_pretrained(
         model_id,
+        trust_remote_code=True,
         device_map=device,
     )

     return model, tokenizer
@@ +818 @@
     # Load model
     print(f"\nLoading model from: {args.model_path}")

+    import os
+    is_local_path = os.path.exists(args.model_path)
+
+    if is_local_path:
         # Local path
         from transformers import AutoModelForCausalLM
+        from src.tokenizer import ChessTokenizer
+        from src.model import ChessConfig, ChessForCausalLM

         tokenizer = ChessTokenizer.from_pretrained(args.model_path)
         model = AutoModelForCausalLM.from_pretrained(args.model_path)
+    else:
+        # Assume Hugging Face model ID (or invalid path)
+        if args.model_path.startswith(".") or args.model_path.startswith("/"):
+            raise FileNotFoundError(
+                f"Local model path not found: {args.model_path}\n"
+                f"Please check that the path exists and contains model files."
+            )
+        model, tokenizer = load_model_from_hub(args.model_path)

     # Create evaluator
     print(f"\nSetting up evaluator...")
src/train.py ADDED
@@ -0,0 +1,250 @@
+"""
+Training script for the Chess Challenge.
+
+This script provides a complete training pipeline using the Hugging Face Trainer.
+Students can modify this script to experiment with different training strategies.
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import warnings
+from pathlib import Path
+
+# Suppress warnings from third-party libraries (multiprocess has Python 3.14 compat issues)
+warnings.filterwarnings("ignore", message="'return' in a 'finally' block")
+
+import torch
+from transformers import (
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+
+from src.data import ChessDataCollator, create_train_val_datasets
+from src.model import ChessConfig, ChessForCausalLM
+from src.tokenizer import ChessTokenizer
+from src.utils import count_parameters, print_parameter_budget
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Train a chess-playing language model"
+    )
+
+    # Model arguments
+    parser.add_argument(
+        "--vocab_size", type=int, default=1200,
+        help="Vocabulary size"
+    )
+    parser.add_argument(
+        "--n_embd", type=int, default=128,
+        help="Embedding dimension"
+    )
+    parser.add_argument(
+        "--n_layer", type=int, default=4,
+        help="Number of transformer layers"
+    )
+    parser.add_argument(
+        "--n_head", type=int, default=4,
+        help="Number of attention heads"
+    )
+    parser.add_argument(
+        "--n_ctx", type=int, default=256,
+        help="Maximum context length"
+    )
+    parser.add_argument(
+        "--n_inner", type=int, default=None,
+        help="Feed-forward inner dimension (default: 4 * n_embd)"
+    )
+    parser.add_argument(
+        "--dropout", type=float, default=0.1,
+        help="Dropout probability"
+    )
+    parser.add_argument(
+        "--no_tie_weights", action="store_true",
+        help="Disable weight tying between embedding and output layers"
+    )
+
+    # Data arguments
+    parser.add_argument(
+        "--dataset_name", type=str, default="dlouapre/lichess_2025-01_1M",
+        help="Name of the dataset on Hugging Face Hub"
+    )
+    parser.add_argument(
+        "--max_train_samples", type=int, default=None,
+        help="Maximum number of training samples"
+    )
+    parser.add_argument(
+        "--val_samples", type=int, default=5000,
+        help="Number of validation samples"
+    )
+
+    # Training arguments
+    parser.add_argument(
+        "--output_dir", type=str, default="./output",
+        help="Output directory for model and logs"
+    )
+    parser.add_argument(
+        "--num_train_epochs", type=int, default=3,
+        help="Number of training epochs"
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size", type=int, default=32,
+        help="Training batch size per device"
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size", type=int, default=64,
+        help="Evaluation batch size per device"
+    )
+    parser.add_argument(
+        "--learning_rate", type=float, default=5e-4,
+        help="Learning rate"
+    )
+    parser.add_argument(
+        "--weight_decay", type=float, default=0.01,
+        help="Weight decay"
+    )
+    parser.add_argument(
+        "--warmup_ratio", type=float, default=0.1,
+        help="Warmup ratio"
+    )
+    parser.add_argument(
+        "--seed", type=int, default=42,
+        help="Random seed"
+    )
+
+    # Logging arguments
+    parser.add_argument(
+        "--logging_steps", type=int, default=100,
+        help="Logging frequency"
+    )
+    parser.add_argument(
+        "--eval_steps", type=int, default=500,
+        help="Evaluation frequency"
+    )
+    parser.add_argument(
+        "--save_steps", type=int, default=1000,
+        help="Checkpoint saving frequency"
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main training function."""
+    args = parse_args()
+
+    # Set seed for reproducibility
+    set_seed(args.seed)
+
+    print("=" * 60)
+    print("CHESS CHALLENGE - TRAINING")
+    print("=" * 60)
+
+    # Build tokenizer from dataset
+    print("\nBuilding tokenizer from dataset...")
+    tokenizer = ChessTokenizer.build_vocab_from_dataset(
+        dataset_name=args.dataset_name,
+        min_frequency=500,  # Only keep moves that appear at least 500 times
+        max_samples=100000,  # Use 100k games to build vocabulary
+    )
+    print(f"  Vocabulary size: {tokenizer.vocab_size}")
+
+    # Use the vocab size from tokenizer (override args if provided)
+    actual_vocab_size = tokenizer.vocab_size
+
+    # Create model configuration
+    print("\nCreating model configuration...")
+    config = ChessConfig(
+        vocab_size=actual_vocab_size,
+        n_embd=args.n_embd,
+        n_layer=args.n_layer,
+        n_head=args.n_head,
+        n_ctx=args.n_ctx,
+        n_inner=args.n_inner,
+        dropout=args.dropout,
+        tie_weights=not args.no_tie_weights,
+        pad_token_id=tokenizer.pad_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    # Print parameter budget
+    print_parameter_budget(config)
+
+    # Create model
+    print("\nCreating model...")
+    model = ChessForCausalLM(config)
+    n_params = count_parameters(model)
+    print(f"  Total parameters: {n_params:,}")
+
+    if n_params > 1_000_000:
+        print("WARNING: Model exceeds 1M parameter limit!")
+    else:
+        print("✓ Model is within 1M parameter limit")
+
+    # Load datasets
+    print("\nLoading datasets...")
+    train_dataset, val_dataset = create_train_val_datasets(
+        tokenizer=tokenizer,
+        dataset_name=args.dataset_name,
+        max_length=args.n_ctx,
+        train_samples=args.max_train_samples,
+        val_samples=args.val_samples,
+    )
+    print(f"  Training samples: {len(train_dataset):,}")
+    print(f"  Validation samples: {len(val_dataset):,}")
+
+    # Create data collator
+    data_collator = ChessDataCollator(tokenizer, max_length=args.n_ctx)
+
+    # Training arguments
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        num_train_epochs=args.num_train_epochs,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+        warmup_ratio=args.warmup_ratio,
+        logging_dir=os.path.join(args.output_dir, "logs"),
+        logging_steps=args.logging_steps,
+        eval_strategy="epoch",
+        save_strategy="epoch",
+        save_total_limit=3,
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+        seed=args.seed,
+        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
+        report_to=["none"],
+    )
+
+    # Create trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        data_collator=data_collator,
+        tokenizer=tokenizer,
+    )
+
+    # Train
+    print("\nStarting training...")
+    trainer.train()
+
+    # Save final model
+    print("\nSaving final model...")
+    trainer.save_model(os.path.join(args.output_dir, "final_model"))
+    tokenizer.save_pretrained(os.path.join(args.output_dir, "final_model"))
+
+    print("\nTraining complete!")
+    print(f"  Model saved to: {args.output_dir}/final_model")
+
+
+if __name__ == "__main__":
+    main()
src/utils.py ADDED
@@ -0,0 +1,305 @@
"""
Utility functions for the Chess Challenge.

This module provides helper functions for:
- Parameter counting and budget analysis
- Model registration with Hugging Face
- Move validation with python-chess
"""

from __future__ import annotations

from typing import Dict, Optional, TYPE_CHECKING

import torch.nn as nn

if TYPE_CHECKING:
    from src.model import ChessConfig


def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """
    Count the number of parameters in a model.

    Args:
        model: The PyTorch model.
        trainable_only: If True, only count trainable parameters.

    Returns:
        Total number of parameters.
    """
    if trainable_only:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())


def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
    """
    Count parameters broken down by model component.

    Args:
        model: The PyTorch model.

    Returns:
        Dictionary mapping component names to parameter counts.
    """
    counts = {}
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # Leaf module
            param_count = sum(p.numel() for p in module.parameters(recurse=False))
            if param_count > 0:
                counts[name] = param_count
    return counts


def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
    """
    Estimate the parameter count for a given configuration.

    This is useful for planning your architecture before building the model.
    Bias terms in the attention/FFN projections are not included, so this
    slightly undercounts architectures that use them.

    Args:
        config: Model configuration.

    Returns:
        Dictionary with estimated parameter counts by component.
    """
    V = config.vocab_size
    d = config.n_embd
    L = config.n_layer
    n_ctx = config.n_ctx
    n_inner = config.n_inner

    estimates = {
        "token_embeddings": V * d,
        "position_embeddings": n_ctx * d,
        "attention_qkv_per_layer": 3 * d * d,
        "attention_proj_per_layer": d * d,
        "ffn_per_layer": 2 * d * n_inner,
        "layernorm_per_layer": 4 * d,  # 2 LayerNorms, each with weight and bias
        "final_layernorm": 2 * d,
    }

    # Calculate totals
    per_layer = (
        estimates["attention_qkv_per_layer"]
        + estimates["attention_proj_per_layer"]
        + estimates["ffn_per_layer"]
        + estimates["layernorm_per_layer"]
    )

    estimates["total_transformer_layers"] = L * per_layer

    # LM head (tied with embeddings by default)
    if config.tie_weights:
        estimates["lm_head"] = 0  # tied with token embeddings
    else:
        estimates["lm_head"] = V * d

    # Grand total
    estimates["total"] = (
        estimates["token_embeddings"]
        + estimates["position_embeddings"]
        + estimates["total_transformer_layers"]
        + estimates["final_layernorm"]
        + estimates["lm_head"]
    )

    return estimates
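The estimate can be sanity-checked by hand. A minimal sketch using the same terms as `estimate_parameters`; all configuration values below are hypothetical examples, not defaults from this repo:

```python
# Back-of-envelope version of the estimate above (hypothetical values).
V, d, L, n_ctx, n_inner = 64, 128, 4, 512, 512

# qkv + output projection + FFN + two LayerNorms, per transformer layer
per_layer = 3 * d * d + d * d + 2 * d * n_inner + 4 * d

# embeddings + layers + final LayerNorm; a tied LM head adds nothing
total = V * d + n_ctx * d + L * per_layer + 2 * d

print(per_layer)  # 197120
print(total)      # 862464 -> comfortably under a 1M-parameter budget
```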
def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
    """
    Print a formatted parameter budget analysis.

    Args:
        config: Model configuration.
        limit: Parameter limit to compare against.
    """
    estimates = estimate_parameters(config)

    print("=" * 60)
    print("PARAMETER BUDGET ANALYSIS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f"  vocab_size (V) = {config.vocab_size}")
    print(f"  n_embd     (d) = {config.n_embd}")
    print(f"  n_layer    (L) = {config.n_layer}")
    print(f"  n_head         = {config.n_head}")
    print(f"  n_ctx          = {config.n_ctx}")
    print(f"  n_inner        = {config.n_inner}")
    print(f"  tie_weights    = {config.tie_weights}")

    print("\nParameter Breakdown:")
    print(f"  Token Embeddings:    {estimates['token_embeddings']:>10,}")
    print(f"  Position Embeddings: {estimates['position_embeddings']:>10,}")
    print(f"  Transformer Layers:  {estimates['total_transformer_layers']:>10,}")
    print(f"  Final LayerNorm:     {estimates['final_layernorm']:>10,}")

    if config.tie_weights:
        print(f"  LM Head:             {'(tied)':>10}")
    else:
        print(f"  LM Head:             {estimates['lm_head']:>10,}")

    print("  " + "-" * 30)
    print(f"  TOTAL:               {estimates['total']:>10,}")

    print("\nBudget Status:")
    print(f"  Limit:     {limit:>10,}")
    print(f"  Used:      {estimates['total']:>10,}")
    print(f"  Remaining: {limit - estimates['total']:>10,}")

    if estimates["total"] <= limit:
        print(f"\n  Within budget! ({estimates['total'] / limit * 100:.1f}% used)")
    else:
        print(f"\n  OVER BUDGET by {estimates['total'] - limit:,} parameters!")

    print("=" * 60)


def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
    """
    Validate a move using python-chess.

    This function converts the dataset's extended UCI format to standard UCI
    and validates it against the current board state.

    Args:
        move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)").
        board_fen: FEN string of the current board state (optional).

    Returns:
        True if the move is legal, False otherwise.
    """
    try:
        import chess
    except ImportError:
        raise ImportError(
            "python-chess is required for move validation. "
            "Install it with: pip install python-chess"
        )

    # Parse the extended UCI format
    # Format: [W|B][Piece][from_sq][to_sq][suffix]
    # Example: WPe2e4, BNg8f6(x), WKe1g1(o)

    if len(move) < 6:
        return False

    # Extract components
    color = move[0]      # W or B
    piece = move[1]      # P, N, B, R, Q, K
    from_sq = move[2:4]  # e.g., "e2"
    to_sq = move[4:6]    # e.g., "e4"

    # Check for promotion
    promotion = None
    if "=" in move:
        promo_idx = move.index("=")
        promotion = move[promo_idx + 1].lower()

    # Create board
    board = chess.Board(board_fen) if board_fen else chess.Board()

    # Build UCI move string
    uci_move = from_sq + to_sq
    if promotion:
        uci_move += promotion

    try:
        move_obj = chess.Move.from_uci(uci_move)
        return move_obj in board.legal_moves
    except (ValueError, chess.InvalidMoveError):
        return False


def convert_extended_uci_to_uci(move: str) -> str:
    """
    Convert extended UCI format to standard UCI format.

    Args:
        move: Move in extended UCI format (e.g., "WPe2e4").

    Returns:
        Move in standard UCI format (e.g., "e2e4").
    """
    if len(move) < 6:
        return move

    # Extract squares
    from_sq = move[2:4]
    to_sq = move[4:6]

    # Check for promotion
    promotion = ""
    if "=" in move:
        promo_idx = move.index("=")
        promotion = move[promo_idx + 1].lower()

    return from_sq + to_sq + promotion


def convert_uci_to_extended(
    uci_move: str,
    board_fen: str,
) -> str:
    """
    Convert standard UCI format to extended UCI format.

    Args:
        uci_move: Move in standard UCI format (e.g., "e2e4").
        board_fen: FEN string of the current board state.

    Returns:
        Move in extended UCI format (e.g., "WPe2e4").
    """
    try:
        import chess
    except ImportError:
        raise ImportError("python-chess is required for move conversion.")

    board = chess.Board(board_fen)
    move = chess.Move.from_uci(uci_move)

    # Get color
    color = "W" if board.turn == chess.WHITE else "B"

    # Get piece
    piece = board.piece_at(move.from_square)
    piece_letter = piece.symbol().upper() if piece else "P"

    # Build extended UCI
    from_sq = chess.square_name(move.from_square)
    to_sq = chess.square_name(move.to_square)

    result = f"{color}{piece_letter}{from_sq}{to_sq}"

    # Add promotion
    if move.promotion:
        result += f"={chess.piece_symbol(move.promotion).upper()}"

    # Add suffix for captures
    if board.is_capture(move):
        result += "(x)"

    # Add suffix for check/checkmate
    board.push(move)
    if board.is_checkmate():
        if "(x)" in result:
            result = result.replace("(x)", "(x+*)")
        else:
            result += "(+*)"
    elif board.is_check():
        if "(x)" in result:
            result = result.replace("(x)", "(x+)")
        else:
            result += "(+)"
    board.pop()

    # Handle castling notation
    if board.is_castling(move):
        if move.to_square in [chess.G1, chess.G8]:  # Kingside
            result = result.replace("(x)", "").replace("(+)", "") + "(o)"
        else:  # Queenside
            result = result.replace("(x)", "").replace("(+)", "") + "(O)"

    return result
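The slicing performed by `convert_extended_uci_to_uci` can be illustrated without a board. A minimal standalone sketch mirroring the function above (the helper name `to_uci` is hypothetical):

```python
def to_uci(move: str) -> str:
    # Squares live at fixed offsets: [W|B][Piece][from][to][suffix]
    promotion = ""
    if "=" in move:
        promotion = move[move.index("=") + 1].lower()
    return move[2:4] + move[4:6] + promotion

print(to_uci("WPe2e4"))       # e2e4
print(to_uci("BNg8f6(x)"))    # g8f6  (capture suffix dropped)
print(to_uci("WPe7e8=Q(+)"))  # e7e8q (promotion piece lowercased)
```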
submit.py ADDED
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""
Submission script for the Chess Challenge.

This script pushes your trained model to the Hugging Face Hub under the
LLM-course organization, with metadata tracking who submitted it.

Usage:
    python submit.py --model_path ./my_model/final_model --model_name my-chess-model
"""

import argparse
import os
import tempfile
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(description="Submit your chess model to Hugging Face Hub")
    parser.add_argument(
        "--model_path", type=str, default="./my_model/final_model",
        help="Path to your trained model directory"
    )
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Name for your model on the Hub (e.g., 'my-chess-model')"
    )
    args = parser.parse_args()

    # Fixed organization
    organization = "LLM-course"

    # Check model path exists
    if not os.path.exists(args.model_path):
        print(f"Error: Model path '{args.model_path}' does not exist.")
        print("Train a model first with: python -m src.train --output_dir ./my_model")
        return 1

    # Import here to avoid slow startup
    from huggingface_hub import HfApi, whoami
    from transformers import AutoModelForCausalLM

    # Ensure user is logged in and get their info
    print("=" * 60)
    print("CHESS CHALLENGE - MODEL SUBMISSION")
    print("=" * 60)

    try:
        user_info = whoami()
        username = user_info["name"]
        print(f"\nLogged in as: {username}")
    except Exception:
        print("\nYou need to log in to Hugging Face first.")
        print("Run: huggingface-cli login")
        return 1

    # Import custom classes to register them
    from src.model import ChessConfig, ChessForCausalLM
    from src.tokenizer import ChessTokenizer

    # Load model and tokenizer
    print(f"\nLoading model from: {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(args.model_path)
    tokenizer = ChessTokenizer.from_pretrained(args.model_path)

    # Count parameters
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    if n_params > 1_000_000:
        print(f"WARNING: Model exceeds 1M parameter limit ({n_params:,} params)")

    # Prepare repo name
    repo_id = f"{organization}/{args.model_name}"
    print(f"\nSubmitting to: {repo_id}")

    # Create a temporary directory to prepare the submission
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = Path(tmp_dir)

        # Save model and tokenizer
        model.save_pretrained(tmp_path)
        tokenizer.save_pretrained(tmp_path)

        # Create model card with submitter info
        model_card = f"""---
library_name: transformers
tags:
- chess
- llm-course
- chess-challenge
license: mit
---

# {args.model_name}

Chess model submitted to the LLM Course Chess Challenge.

## Submission Info

- **Submitted by**: [{username}](https://huggingface.co/{username})
- **Parameters**: {n_params:,}
- **Organization**: {organization}

## Model Details

- **Architecture**: Chess Transformer (GPT-style)
- **Vocab size**: {tokenizer.vocab_size}
- **Embedding dim**: {model.config.n_embd}
- **Layers**: {model.config.n_layer}
- **Heads**: {model.config.n_head}
"""
        (tmp_path / "README.md").write_text(model_card)

        # Push to Hub
        print("\nUploading to Hugging Face Hub...")
        api = HfApi()

        # Create repo if it doesn't exist
        api.create_repo(
            repo_id=repo_id,
            exist_ok=True,
        )

        # Upload all files
        api.upload_folder(
            folder_path=tmp_path,
            repo_id=repo_id,
            commit_message=f"Chess Challenge submission by {username}",
        )

    print("\n" + "=" * 60)
    print("SUBMISSION COMPLETE!")
    print("=" * 60)
    print("\nYour model is now available at:")
    print(f"  https://huggingface.co/{repo_id}")
    print(f"\nSubmitted by: {username}")
    print(f"Parameters: {n_params:,}")

    return 0


if __name__ == "__main__":
    exit(main())