nathanael-fijalkow committed · Commit 8a7719b · 1 Parent(s): d7e086f

Integrated the template, made evaluation more robust to different tokenizers

Files changed (10)
  1. .gitignore +56 -0
  2. TEMPLATE_README.md +152 -0
  3. app.py +7 -1
  4. pyproject.toml +59 -0
  5. src/__init__.py +10 -1
  6. src/data.py +253 -0
  7. src/evaluate.py +270 -132
  8. src/train.py +250 -0
  9. src/utils.py +305 -0
  10. submit.py +144 -0
.gitignore ADDED
@@ -0,0 +1,56 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+*.egg-info/
+dist/
+build/
+*.egg
+
+# Virtual environments
+.venv/
+venv/
+env/
+ENV/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# Model outputs and training artifacts
+my_model/
+checkpoints/
+runs/
+wandb/
+*.pth
+*.pt
+*.safetensors
+*.bin
+
+# Dataset caches
+.cache/
+*.arrow
+*.parquet
+
+# Jupyter
+.ipynb_checkpoints/
+*.ipynb
+
+# Testing
+.pytest_cache/
+.coverage
+htmlcov/
+
+# Logs
+*.log
+logs/
+
+# Environment variables
+.env
+.env.local
TEMPLATE_README.md ADDED
@@ -0,0 +1,152 @@
+# Chess Challenge
+
+Train a 1M parameter LLM to play chess!
+
+## Objective
+
+Design and train a transformer-based language model to predict chess moves. Your model must:
+
+1. **Stay under 1M parameters** - This is the hard constraint!
+2. **Use a custom tokenizer** - Design an efficient move-level tokenizer
+3. **Play legal chess** - The model should learn to generate valid moves
+4. **Beat Stockfish** - Your ELO will be measured against Stockfish Level 1
+
+## Dataset
+
+We use the Lichess dataset: [`dlouapre/lichess_2025-01_1M`](https://huggingface.co/datasets/dlouapre/lichess_2025-01_1M)
+
+The dataset uses an extended UCI notation:
+- `W`/`B` prefix for White/Black
+- Piece letter: `P`=Pawn, `N`=Knight, `B`=Bishop, `R`=Rook, `Q`=Queen, `K`=King
+- Source and destination squares (e.g., `e2e4`)
+- Special suffixes: `(x)`=capture, `(+)`=check, `(+*)`=checkmate, `(o)`/`(O)`=castling
+
+Example game:
+```
+WPe2e4 BPe7e5 WNg1f3 BNb8c6 WBf1b5 BPa7a6 WBb5c6(x) BPd7c6(x) ...
+```
+
+## Quick Start
+
+### Train a Model
+
+```bash
+# Basic training
+python -m src.train \
+    --output_dir ./my_model \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 32
+```
+
+### Evaluate Your Model
+
+Evaluation happens in two phases:
+
+```bash
+# Phase 1: Legal Move Evaluation (quick sanity check)
+python -m src.evaluate \
+    --model_path ./my_model \
+    --mode legal \
+    --n_positions 500
+
+# Phase 2: Win Rate Evaluation (full games against Stockfish)
+python -m src.evaluate \
+    --model_path ./my_model \
+    --mode winrate \
+    --n_games 100 \
+    --stockfish_level 1
+
+# Or run both phases:
+python -m src.evaluate \
+    --model_path ./my_model \
+    --mode both
+```
+
+## Parameter Budget
+
+Use the utility function to check your budget:
+
+```python
+from src import ChessConfig, print_parameter_budget
+
+config = ChessConfig(
+    vocab_size=1200,
+    n_embd=128,
+    n_layer=4,
+    n_head=4,
+)
+print_parameter_budget(config)
+```
+
+### Pro Tips
+
+1. **Weight Tying**: The default config ties the embedding and output layer weights, saving ~154k parameters
+2. **Vocabulary Size**: Keep it small! ~1200 tokens covers all moves
+3. **Depth vs Width**: With limited parameters, experiment with shallow-but-wide vs deep-but-narrow
+
+## Customization
+
+### Custom Tokenizer
+
+The template provides a move-level tokenizer that builds its vocabulary from the actual dataset.
+Feel free to try different approaches!
+
+### Custom Architecture
+
+Modify the model in `src/model.py`:
+
+```python
+from src import ChessConfig, ChessForCausalLM
+
+# Customize configuration
+config = ChessConfig(
+    vocab_size=1200,
+    n_embd=128,      # Try 96, 128, or 192
+    n_layer=4,       # Try 3, 4, or 6
+    n_head=4,        # Try 4 or 8
+    n_inner=384,     # Feed-forward dimension (default: 3*n_embd)
+    dropout=0.1,
+    tie_weights=True,
+)
+
+model = ChessForCausalLM(config)
+```
+
+## Evaluation Metrics
+
+### Phase 1: Legal Move Evaluation
+
+Tests if your model generates valid chess moves:
+
+| Metric | Description |
+|--------|-------------|
+| **Legal Rate (1st try)** | % of legal moves on first attempt |
+| **Legal Rate (with retry)** | % of legal moves within 3 attempts |
+
+> **Target**: >90% legal rate before proceeding to Phase 2
+
+### Phase 2: Win Rate Evaluation
+
+Full games against Stockfish to measure playing strength:
+
+| Metric | Description |
+|--------|-------------|
+| **Win Rate** | % of games won against Stockfish |
+| **ELO Rating** | Estimated rating based on game results |
+| **Avg Game Length** | Average number of moves per game |
+| **Illegal Move Rate** | % of illegal moves during games |
+
+## Submission
+
+1. Train your model
+2. Log in to Hugging Face: `hf auth login`
+3. Submit your model using the submission script:
+
+```bash
+python submit.py --model_path ./my_model/final_model --model_name your-model-name
+```
+
+The script will:
+- Upload your model to the LLM-course organization
+- Include your HF username in the model card for tracking
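The "~154k parameters" figure in the README's weight-tying tip follows directly from the default shapes: an untied LM head is a `vocab_size × n_embd` matrix that tying eliminates. A quick back-of-the-envelope check, independent of the template's own `print_parameter_budget` (the helper name below is illustrative, not part of the template):

```python
def tied_head_savings(vocab_size: int, n_embd: int) -> int:
    """Parameters saved by tying the output projection to the input embedding.

    An untied LM head is a (vocab_size x n_embd) weight matrix; weight tying
    reuses the embedding matrix instead, so those parameters disappear.
    """
    return vocab_size * n_embd


# Default README config: vocab_size=1200, n_embd=128
print(tied_head_savings(1200, 128))  # 153600, i.e. ~154k parameters
```

With a 1M budget, that single choice frees roughly 15% of the total allowance.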
app.py CHANGED
@@ -591,7 +591,13 @@ with gr.Blocks(
     The goal is to create a chess-playing language model with **under 1 million parameters**, which is roughly the number of neurons in a honey bee's brain.
     At this scale, efficiency and clever architecture choices are key! We are not targeting superhuman performance, but rather exploring how well small models can learn the rules of chess, the goal being (only) to play **legal moves**.
 
-    1. **Train your model** using the [Chess Challenge Template](https://github.com/nathanael-fijalkow/ChessChallengeTemplate)
+    0. **Clone this repository**:
+    ```bash
+    git clone https://huggingface.co/spaces/LLM-course/Chess1MChallenge
+    ```
+    and check the `TEMPLATE_README.md` for detailed instructions.
+
+    1. **Train your model**
 
     2. **Push to Hugging Face Hub** using the `submit.py` script provided in the template to make sure that your model is registered correctly.
pyproject.toml ADDED
@@ -0,0 +1,59 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "chess-challenge"
+version = "0.1.0"
+description = "LLM Chess Challenge - Train a 1M parameter model to play chess"
+readme = "README.md"
+license = {text = "MIT"}
+requires-python = ">=3.10"
+authors = [
+    {name = "Nathanaël Fijalkow", email = "nathanael.fijalkow@gmail.com"}
+]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Education",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+dependencies = [
+    "torch>=2.0.0",
+    "transformers>=4.40.0",
+    "accelerate>=0.26.0",
+    "datasets>=2.14.0",
+    "python-chess>=1.999",
+    "huggingface-hub>=0.20.0",
+    "tqdm>=4.65.0",
+    "numpy>=1.24.0",
+    "wandb>=0.15.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0.0",
+    "black>=23.0.0",
+    "ruff>=0.1.0",
+]
+eval = [
+    "stockfish>=3.28.0",
+]
+
+[project.scripts]
+chess-train = "src.train:main"
+chess-eval = "src.evaluate:main"
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.black]
+line-length = 100
+target-version = ["py310"]
+
+[tool.ruff]
+line-length = 100
+select = ["E", "F", "I", "W"]
src/__init__.py CHANGED
@@ -2,7 +2,16 @@
 
 from .model import ChessConfig, ChessForCausalLM
 from .tokenizer import ChessTokenizer
-from .evaluate import ChessEvaluator, load_model_from_hub
+
+# Lazy import for evaluate to avoid RuntimeWarning when running as module
+def __getattr__(name):
+    if name == "ChessEvaluator":
+        from .evaluate import ChessEvaluator
+        return ChessEvaluator
+    if name == "load_model_from_hub":
+        from .evaluate import load_model_from_hub
+        return load_model_from_hub
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
 
 __all__ = [
     "ChessConfig",
src/data.py ADDED
@@ -0,0 +1,253 @@
+"""
+Data loading utilities for the Chess Challenge.
+
+This module provides functions to load and process chess game data
+from the Lichess dataset on Hugging Face.
+"""
+
+from __future__ import annotations
+
+from typing import Dict, Iterator, List, Optional
+
+import torch
+from torch.utils.data import Dataset
+
+
+class ChessDataset(Dataset):
+    """
+    PyTorch Dataset for chess games.
+
+    This dataset loads games from a Hugging Face dataset and prepares
+    them for language modeling training.
+
+    Each game is tokenized and truncated/padded to max_length.
+    The labels are shifted by one position for next-token prediction.
+
+    Example:
+        >>> from src.tokenizer import ChessTokenizer
+        >>> tokenizer = ChessTokenizer.build_vocab_from_dataset()
+        >>> dataset = ChessDataset(tokenizer, max_length=256)
+        >>> sample = dataset[0]
+        >>> print(sample["input_ids"].shape)  # (256,)
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        dataset_name: str = "dlouapre/lichess_2025-01_1M",
+        split: str = "train",
+        column: str = "text",
+        max_length: int = 256,
+        max_samples: Optional[int] = None,
+    ):
+        """
+        Initialize the chess dataset.
+
+        Args:
+            tokenizer: The chess tokenizer to use.
+            dataset_name: Name of the dataset on Hugging Face Hub.
+            split: Dataset split to use.
+            column: Column containing the game strings.
+            max_length: Maximum sequence length.
+            max_samples: Maximum number of samples to load.
+        """
+        from datasets import load_dataset
+
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.column = column
+
+        # Load dataset
+        dataset = load_dataset(dataset_name, split=split)
+
+        if max_samples is not None:
+            dataset = dataset.select(range(min(max_samples, len(dataset))))
+
+        self.data = dataset
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        game = self.data[idx][self.column]
+
+        # Prepend BOS token for proper language modeling
+        game_with_bos = self.tokenizer.bos_token + " " + game
+
+        # Tokenize
+        encoding = self.tokenizer(
+            game_with_bos,
+            truncation=True,
+            max_length=self.max_length,
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+        # Squeeze batch dimension
+        input_ids = encoding["input_ids"].squeeze(0)
+        attention_mask = encoding["attention_mask"].squeeze(0)
+
+        # Labels are the same as input_ids (model will shift internally)
+        labels = input_ids.clone()
+
+        # Set padding tokens to -100 to ignore in loss
+        labels[attention_mask == 0] = -100
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+
+
+class ChessDataCollator:
+    """
+    Data collator for chess games.
+
+    This collator pads sequences to the same length within a batch
+    and creates the appropriate attention masks.
+    """
+
+    def __init__(self, tokenizer, max_length: int = 256):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
+        # Stack tensors
+        input_ids = torch.stack([f["input_ids"] for f in features])
+        attention_mask = torch.stack([f["attention_mask"] for f in features])
+        labels = torch.stack([f["labels"] for f in features])
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+
+
+def create_train_val_datasets(
+    tokenizer,
+    dataset_name: str = "dlouapre/lichess_2025-01_1M",
+    max_length: int = 256,
+    train_samples: Optional[int] = None,
+    val_samples: int = 5000,
+    val_ratio: float = 0.05,
+):
+    """
+    Create training and validation datasets.
+
+    Args:
+        tokenizer: The chess tokenizer.
+        dataset_name: Name of the dataset.
+        max_length: Maximum sequence length.
+        train_samples: Maximum training samples (None for all).
+        val_samples: Number of validation samples.
+        val_ratio: Ratio of validation samples (used if train_samples is None).
+
+    Returns:
+        Tuple of (train_dataset, val_dataset).
+    """
+    from datasets import load_dataset
+
+    # Load full dataset
+    full_dataset = load_dataset(dataset_name, split="train")
+
+    # Determine split sizes
+    total = len(full_dataset)
+
+    if train_samples is not None:
+        n_train = min(train_samples, total - val_samples)
+    else:
+        n_train = int(total * (1 - val_ratio))
+
+    n_val = min(val_samples, total - n_train)
+
+    # Split dataset
+    train_data = full_dataset.select(range(n_train))
+    val_data = full_dataset.select(range(n_train, n_train + n_val))
+
+    # Create dataset objects
+    train_dataset = ChessDataset(
+        tokenizer=tokenizer,
+        dataset_name=dataset_name,
+        max_length=max_length,
+    )
+    train_dataset.data = train_data
+
+    val_dataset = ChessDataset(
+        tokenizer=tokenizer,
+        dataset_name=dataset_name,
+        max_length=max_length,
+    )
+    val_dataset.data = val_data
+
+    return train_dataset, val_dataset
+
+
+def stream_games(
+    dataset_name: str = "dlouapre/lichess_2025-01_1M",
+    split: str = "train",
+    column: str = "text",
+) -> Iterator[str]:
+    """
+    Stream games from the dataset for memory-efficient processing.
+
+    Args:
+        dataset_name: Name of the dataset on Hugging Face Hub.
+        split: Dataset split to use.
+        column: Column containing the game strings.
+
+    Yields:
+        Game strings one at a time.
+    """
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset_name, split=split, streaming=True)
+
+    for example in dataset:
+        yield example[column]
+
+
+def analyze_dataset_statistics(
+    dataset_name: str = "dlouapre/lichess_2025-01_1M",
+    max_samples: int = 10000,
+) -> Dict:
+    """
+    Analyze statistics of the chess dataset.

+    Args:
+        dataset_name: Name of the dataset.
+        max_samples: Maximum number of samples to analyze.
+
+    Returns:
+        Dictionary containing dataset statistics.
+    """
+    from collections import Counter
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset_name, split="train")
+    dataset = dataset.select(range(min(max_samples, len(dataset))))
+
+    game_lengths = []
+    move_counts = Counter()
+    opening_moves = Counter()
+
+    for example in dataset:
+        moves = example["text"].strip().split()
+        game_lengths.append(len(moves))
+        move_counts.update(moves)
+
+        # Track common openings (first 4 moves)
+        if len(moves) >= 4:
+            opening = " ".join(moves[:4])
+            opening_moves[opening] += 1
+
+    return {
+        "total_games": len(dataset),
+        "avg_game_length": sum(game_lengths) / len(game_lengths),
+        "min_game_length": min(game_lengths),
+        "max_game_length": max(game_lengths),
+        "unique_moves": len(move_counts),
+        "most_common_moves": move_counts.most_common(20),
+        "most_common_openings": opening_moves.most_common(10),
+    }
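`ChessDataset.__getitem__` masks padded positions with `-100`, the ignore index that `torch.nn.CrossEntropyLoss` (and hence the HF `Trainer`) skips when averaging the loss. The same rule, stripped down to plain Python lists to show exactly what `labels[attention_mask == 0] = -100` does:

```python
IGNORE_INDEX = -100  # positions with this label contribute nothing to the loss


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids into labels, replacing padded positions with -100.

    Mirrors `labels[attention_mask == 0] = -100` in ChessDataset.__getitem__.
    """
    return [tok if mask == 1 else IGNORE_INDEX
            for tok, mask in zip(input_ids, attention_mask)]


# A short game right-padded with token id 0 to length 8:
labels = mask_padding_labels([1, 57, 198, 42, 0, 0, 0, 0],
                             [1, 1, 1, 1, 0, 0, 0, 0])
print(labels)  # [1, 57, 198, 42, -100, -100, -100, -100]
```

Without this masking, the model would be rewarded for predicting pad tokens, which inflates the loss signal on short games.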
src/evaluate.py CHANGED
@@ -9,6 +9,7 @@ from __future__ import annotations
9
 
10
  import argparse
11
  import random
 
12
  from dataclasses import dataclass
13
  from typing import List, Optional, Tuple
14
 
@@ -23,16 +24,23 @@ class GameResult:
23
  model_color: str # "white" or "black"
24
  termination: str # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
25
  illegal_move_count: int
26
-
27
-
28
  class ChessEvaluator:
29
  """
30
  Evaluator for chess models.
31
 
32
  This class handles playing games between a trained model and Stockfish,
33
  tracking results, and computing ELO ratings.
 
 
 
 
34
  """
35
 
 
 
 
36
  def __init__(
37
  self,
38
  model,
@@ -88,10 +96,100 @@ class ChessEvaluator:
88
  if hasattr(self, 'engine') and self.engine:
89
  self.engine.quit()
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def _convert_board_to_moves(self, board) -> str:
92
- """Convert board move history to model input format."""
 
 
 
 
 
93
  moves = []
94
  temp_board = self.chess.Board()
 
95
 
96
  for move in board.move_stack:
97
  # Get piece and color
@@ -103,29 +201,44 @@ class ChessEvaluator:
103
  from_sq = self.chess.square_name(move.from_square)
104
  to_sq = self.chess.square_name(move.to_square)
105
 
106
- move_str = f"{color}{piece_letter}{from_sq}{to_sq}"
107
-
108
- # Add promotion
109
  if move.promotion:
110
- move_str += f"={self.chess.piece_symbol(move.promotion).upper()}"
111
-
112
- # Add capture suffix
113
- if temp_board.is_capture(move):
114
- move_str += "(x)"
115
 
116
- # Add check/checkmate suffix
117
- temp_board.push(move)
118
- if temp_board.is_checkmate():
119
- move_str = move_str.replace("(x)", "(x+*)") if "(x)" in move_str else move_str + "(+*)"
120
- elif temp_board.is_check():
121
- move_str = move_str.replace("(x)", "(x+)") if "(x)" in move_str else move_str + "(+)"
122
 
123
- # Handle castling
124
- if piece_letter == "K" and abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
125
- if to_sq[0] == 'g': # Kingside
126
- move_str = move_str.split("(")[0] + "(o)"
127
- else: # Queenside
128
- move_str = move_str.split("(")[0] + "(O)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  moves.append(move_str)
131
 
@@ -160,6 +273,65 @@ class ChessEvaluator:
160
 
161
  return False
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def _generate_move_tokens(
164
  self,
165
  input_ids: torch.Tensor,
@@ -168,11 +340,12 @@ class ChessEvaluator:
168
  max_tokens: int = 20,
169
  ) -> str:
170
  """
171
- Generate tokens until a separator (whitespace/EOS) is encountered.
172
 
173
- This method supports different tokenization strategies:
174
- - For move-level tokenizers: generates one token (the full move)
175
- - For character/subword tokenizers: generates until whitespace
 
176
 
177
  Args:
178
  input_ids: The input token IDs.
@@ -181,10 +354,11 @@ class ChessEvaluator:
181
  max_tokens: Maximum tokens to generate for a single move.
182
 
183
  Returns:
184
- The generated move string (without trailing separator).
185
  """
186
  generated_tokens = []
187
  current_ids = input_ids.clone()
 
188
 
189
  for _ in range(max_tokens):
190
  with torch.no_grad():
@@ -193,31 +367,47 @@ class ChessEvaluator:
193
 
194
  # Apply top-k filtering
195
  if top_k > 0:
196
- top_k_values = torch.topk(logits, min(top_k, logits.size(-1)))[0]
197
- indices_to_remove = logits < top_k_values[..., -1, None]
198
  logits[indices_to_remove] = float("-inf")
199
 
200
  # Sample
201
  probs = torch.softmax(logits, dim=-1)
202
- next_token = torch.multinomial(probs, num_samples=1) # Shape: [1, 1]
203
 
204
  # Decode the token
205
  token_str = self.tokenizer.decode(next_token[0])
206
 
207
  # Check if this is a separator token
208
  if self._is_separator_token(token_str):
209
- break
210
-
211
- generated_tokens.append(next_token[0]) # Store [1] tensor
 
 
 
 
 
 
 
212
 
213
- # Append to input for next iteration (next_token is already [1, 1])
214
  current_ids = torch.cat([current_ids, next_token], dim=-1)
 
215
 
216
- # For move-level tokenizers, a single non-separator token is the full move
217
- # We can detect this by checking if the token looks like a complete move
218
- # (starts with W or B, has enough characters for a move)
219
- if len(token_str) >= 6 and token_str[0] in "WB":
220
- break
 
 
 
 
 
 
 
 
221
 
222
  # Decode all generated tokens together
223
  if generated_tokens:
@@ -236,11 +426,15 @@ class ChessEvaluator:
236
  """
237
  Get the model's next move prediction.
238
 
239
- This method generates tokens until a separator (whitespace/EOS) is produced,
240
- allowing it to work with different tokenization strategies:
241
- - Move-level tokenizers: each move is a single token
242
- - Character-level tokenizers: moves are generated character by character
243
- - BPE/subword tokenizers: moves may be split into subwords
 
 
 
 
244
 
245
  Returns:
246
  Tuple of (UCI move string, number of retries used).
@@ -257,32 +451,26 @@ class ChessEvaluator:
257
  input_text = self.tokenizer.bos_token + " " + moves_str
258
 
259
  # Tokenize
260
- max_len = getattr(self.model.config, 'n_ctx', None) or getattr(self.model.config, 'max_position_embeddings', 256)
261
  inputs = self.tokenizer(
262
  input_text,
263
  return_tensors="pt",
264
  truncation=True,
265
- max_length=max_len - 10, # Leave room for generated tokens
266
  ).to(self.device)
267
 
268
  # Try to generate a legal move
269
  for retry in range(self.max_retries):
270
- # Generate tokens until separator
271
- move_token = self._generate_move_tokens(
272
  inputs["input_ids"],
273
  temperature=temperature,
274
  top_k=top_k,
275
  )
276
 
277
- # Convert to UCI
278
- if len(move_token) >= 6:
279
- uci_move = move_token[2:4] + move_token[4:6]
280
-
281
- # Handle promotion
282
- if "=" in move_token:
283
- promo_idx = move_token.index("=")
284
- uci_move += move_token[promo_idx + 1].lower()
285
-
286
  try:
287
  move = self.chess.Move.from_uci(uci_move)
288
  if move in board.legal_moves:
@@ -390,7 +578,6 @@ class ChessEvaluator:
390
  n_positions: int = 1000,
391
  temperature: float = 0.7,
392
  verbose: bool = True,
393
- seed: int = 42,
394
  ) -> dict:
395
  """
396
  Evaluate the model's ability to generate legal moves.
@@ -402,14 +589,10 @@ class ChessEvaluator:
402
  n_positions: Number of positions to test.
403
  temperature: Sampling temperature.
404
  verbose: Whether to print progress.
405
- seed: Random seed for reproducibility.
406
 
407
  Returns:
408
  Dictionary with legal move statistics.
409
  """
410
- # Set seed for deterministic evaluation
411
- random.seed(seed)
412
-
413
  results = {
414
  "total_positions": 0,
415
  "legal_first_try": 0,
@@ -572,73 +755,24 @@ def load_model_from_hub(model_id: str, device: str = "auto"):
572
  Returns:
573
  Tuple of (model, tokenizer).
574
  """
575
- import json
576
- from huggingface_hub import hf_hub_download
577
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
578
 
579
- # Import custom classes
580
- try:
581
- from src.model import ChessConfig, ChessForCausalLM
582
- from src.tokenizer import ChessTokenizer
583
- except ImportError:
584
- from .model import ChessConfig, ChessForCausalLM
585
- from .tokenizer import ChessTokenizer
586
 
587
- # Register BEFORE any from_pretrained calls
588
  try:
589
- AutoConfig.register("chess_transformer", ChessConfig)
590
- except ValueError:
591
- pass
592
- try:
593
- AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)
594
- except ValueError:
595
- pass
596
-
597
- print(f"Loading model {model_id}...")
598
-
599
- # Download and load config manually to avoid transformers auto-detection issues
600
- config_path = hf_hub_download(repo_id=model_id, filename="config.json")
601
- with open(config_path, "r") as f:
602
- config_dict = json.load(f)
603
-
604
- # Remove fields that are not in ChessConfig to avoid unexpected kwargs
605
- config_dict.pop("model_type", None)
606
- config_dict.pop("architectures", None)
607
- config_dict.pop("transformers_version", None)
608
- config_dict.pop("dtype", None)
609
- config_dict.pop("torch_dtype", None)
610
-
611
- config = ChessConfig(**config_dict)
612
 
613
- # Load model weights with our config
614
- model = ChessForCausalLM.from_pretrained(
615
  model_id,
616
- config=config,
617
  device_map=device,
618
  )
619
 
620
- # Load tokenizer - try to find vocab.json, else build default
621
- try:
622
- tokenizer = ChessTokenizer.from_pretrained(model_id)
623
- except Exception as e:
624
- print(f"ChessTokenizer.from_pretrained failed: {e}")
625
- try:
626
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
627
- except Exception as e2:
628
- print(f"AutoTokenizer also failed: {e2}")
629
- print("Creating default tokenizer with vocab_size from config...")
630
- # Create a minimal tokenizer with just the vocab size
631
- tokenizer = ChessTokenizer()
632
- # Ensure vocab size matches model
633
- if hasattr(config, 'vocab_size'):
634
- # Build a placeholder vocab of the right size
635
- tokenizer._vocab = {f"[MOVE_{i}]": i for i in range(config.vocab_size)}
636
- tokenizer._vocab["[PAD]"] = 0
637
- tokenizer._vocab["[BOS]"] = 1
638
- tokenizer._vocab["[EOS]"] = 2
639
- tokenizer._vocab["[UNK]"] = 3
640
- tokenizer._ids_to_tokens = {v: k for k, v in tokenizer._vocab.items()}
641
-
642
  return model, tokenizer
643
 
644
 
@@ -684,21 +818,25 @@ def main():
684
  # Load model
685
  print(f"\nLoading model from: {args.model_path}")
686
 
687
- if "/" in args.model_path and not args.model_path.startswith("."):
688
- # Assume Hugging Face model ID
689
- model, tokenizer = load_model_from_hub(args.model_path)
690
- else:
691
  # Local path
692
  from transformers import AutoModelForCausalLM
693
- try:
694
- from src.tokenizer import ChessTokenizer
695
- from src.model import ChessConfig, ChessForCausalLM
696
- except ImportError:
697
- from .tokenizer import ChessTokenizer
698
- from .model import ChessConfig, ChessForCausalLM
699
 
700
  tokenizer = ChessTokenizer.from_pretrained(args.model_path)
701
  model = AutoModelForCausalLM.from_pretrained(args.model_path)
 
 
 
 
 
 
 
 
702
 
703
  # Create evaluator
704
  print(f"\nSetting up evaluator...")
 
9
 
10
  import argparse
11
  import random
12
+ import re
13
  from dataclasses import dataclass
14
  from typing import List, Optional, Tuple
15
 
 
24
  model_color: str # "white" or "black"
25
  termination: str # "checkmate", "stalemate", "illegal_move", "max_moves", etc.
26
  illegal_move_count: int
27
+
28
+
29
  class ChessEvaluator:
30
  """
31
  Evaluator for chess models.
32
 
33
  This class handles playing games between a trained model and Stockfish,
34
  tracking results, and computing ELO ratings.
35
+
36
+ Supports any tokenization format as long as the model generates valid
37
+ chess squares (e.g., e2, e4). The evaluator extracts UCI moves by finding
38
+ square patterns in the generated output.
39
  """
40
 
41
+ # Regex pattern to match chess squares
42
+ SQUARE_PATTERN = r'[a-h][1-8]'
43
+
44
  def __init__(
45
  self,
46
  model,
 
96
  if hasattr(self, 'engine') and self.engine:
97
  self.engine.quit()
98
 
99
+ def _detect_tokenizer_format(self) -> str:
100
+ """
101
+ Detect the tokenizer's expected move format by testing tokenization.
102
+
103
+ Tests various formats with a sample move and picks the one that
104
+ produces the fewest unknown tokens. This makes evaluation work
105
+ with any tokenizer format.
106
+
107
+ Supported formats:
108
+ - 'decomposed': "WP e2_f e4_t" (piece, from_suffix, to_suffix)
109
+ - 'standard': "WPe2e4" (combined with optional annotations)
110
+ - 'uci': "e2e4" (pure UCI notation)
111
+ - 'uci_spaced': "e2 e4" (UCI with space separator)
112
+
113
+ Returns:
114
+ The format string that best matches the tokenizer's vocabulary.
115
+ """
116
+ if hasattr(self, '_cached_format'):
117
+ return self._cached_format
118
+
119
+ # Sample move representations to test
120
+ test_formats = {
121
+ 'decomposed': "WP e2_f e4_t",
122
+ 'standard': "WPe2e4",
123
+ 'uci': "e2e4",
124
+ 'uci_spaced': "e2 e4",
125
+ }
126
+
127
+ unk_token_id = getattr(self.tokenizer, 'unk_token_id', None)
128
+ best_format = 'standard'
129
+ min_unk_count = float('inf')
130
+
131
+ for fmt, sample in test_formats.items():
132
+ try:
133
+ tokens = self.tokenizer.encode(sample, add_special_tokens=False)
134
+ # Count unknown tokens
135
+ unk_count = tokens.count(unk_token_id) if unk_token_id is not None else 0
136
+ # Also penalize if the entire thing became one UNK
137
+ if len(tokens) == 1 and unk_count == 1:
138
+ unk_count = 100 # Heavy penalty
139
+
140
+ if unk_count < min_unk_count:
141
+ min_unk_count = unk_count
142
+ best_format = fmt
143
+ except Exception:
144
+ continue
145
+
146
+ self._cached_format = best_format
147
+ return best_format
148
+
+    def _format_move(self, color: str, piece: str, from_sq: str, to_sq: str,
+                     promotion: Optional[str] = None) -> str:
+        """
+        Format a single move according to the detected tokenizer format.
+
+        Args:
+            color: 'W' or 'B'
+            piece: Piece letter (P, N, B, R, Q, K)
+            from_sq: Source square (e.g., 'e2')
+            to_sq: Destination square (e.g., 'e4')
+            promotion: Promotion piece letter or None
+
+        Returns:
+            Formatted move string.
+        """
+        fmt = self._detect_tokenizer_format()
+
+        if fmt == 'decomposed':
+            move_str = f"{color}{piece} {from_sq}_f {to_sq}_t"
+        elif fmt == 'uci':
+            move_str = f"{from_sq}{to_sq}"
+            if promotion:
+                move_str += promotion.lower()
+        elif fmt == 'uci_spaced':
+            move_str = f"{from_sq} {to_sq}"
+            if promotion:
+                move_str += f" {promotion.lower()}"
+        else:  # standard
+            move_str = f"{color}{piece}{from_sq}{to_sq}"
+            if promotion:
+                move_str += f"={promotion}"
+
+        return move_str
+
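Stripped of `self`, the dispatch above reduces to a pure function; a sketch (the standalone `format_move` is hypothetical, mirroring the branches in `_format_move`):

```python
from typing import Optional

def format_move(fmt: str, color: str, piece: str, from_sq: str, to_sq: str,
                promotion: Optional[str] = None) -> str:
    """Render one move in the detected tokenizer format."""
    if fmt == 'decomposed':
        return f"{color}{piece} {from_sq}_f {to_sq}_t"
    if fmt == 'uci':
        return from_sq + to_sq + (promotion.lower() if promotion else "")
    if fmt == 'uci_spaced':
        return f"{from_sq} {to_sq}" + (f" {promotion.lower()}" if promotion else "")
    # 'standard' combined format
    return f"{color}{piece}{from_sq}{to_sq}" + (f"={promotion}" if promotion else "")

print(format_move('standard', 'W', 'P', 'e7', 'e8', 'Q'))  # WPe7e8=Q
print(format_move('uci', 'W', 'P', 'e7', 'e8', 'Q'))       # e7e8q
print(format_move('decomposed', 'B', 'N', 'g8', 'f6'))     # BN g8_f f6_t
```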
     def _convert_board_to_moves(self, board) -> str:
+        """
+        Convert board move history to model input format.
+
+        Automatically detects the tokenizer's expected format and outputs
+        moves accordingly. Supports any tokenization strategy.
+        """
         moves = []
         temp_board = self.chess.Board()
+        fmt = self._detect_tokenizer_format()
 
         for move in board.move_stack:
             # Get piece and color

             from_sq = self.chess.square_name(move.from_square)
             to_sq = self.chess.square_name(move.to_square)
 
+            # Get promotion piece if any
+            promo = None
             if move.promotion:
+                promo = self.chess.piece_symbol(move.promotion).upper()
 
+            # Format based on detected tokenizer format
+            move_str = self._format_move(color, piece_letter, from_sq, to_sq, promo)
 
+            # For standard format, add annotations (capture, check, castling)
+            if fmt == 'standard':
+                # Add capture suffix
+                if temp_board.is_capture(move):
+                    move_str += "(x)"
+
+                # Push move to check for check/checkmate
+                temp_board.push(move)
+
+                if temp_board.is_checkmate():
+                    if "(x)" in move_str:
+                        move_str = move_str.replace("(x)", "(x+*)")
+                    else:
+                        move_str += "(+*)"
+                elif temp_board.is_check():
+                    if "(x)" in move_str:
+                        move_str = move_str.replace("(x)", "(x+)")
+                    else:
+                        move_str += "(+)"
+
+                # Handle castling notation
+                if piece_letter == "K":
+                    if abs(ord(from_sq[0]) - ord(to_sq[0])) > 1:
+                        if to_sq[0] == 'g':  # Kingside
+                            move_str = move_str.split("(")[0] + "(o)"
+                        else:  # Queenside
+                            move_str = move_str.split("(")[0] + "(O)"
+            else:
+                # For non-standard formats, just push the move
+                temp_board.push(move)
 
             moves.append(move_str)
 
         return False
 
+    def _extract_uci_move(self, text: str) -> Optional[str]:
+        """
+        Extract a UCI move from generated text using pattern matching.
+
+        This generic method works with any tokenization format by finding
+        chess square patterns ([a-h][1-8]) in the output.
+
+        Supported formats include:
+        - Standard: "WPe2e4" -> "e2e4"
+        - Decomposed: "WP e2_f e4_t" -> "e2e4"
+        - Pure UCI: "e2e4" -> "e2e4"
+        - With separators: "e2-e4", "e2 e4" -> "e2e4"
+        - With promotion: "e7e8=Q", "e7e8q" -> "e7e8q"
+
+        Args:
+            text: The generated text containing a move.
+
+        Returns:
+            UCI move string (e.g., "e2e4", "e7e8q") or None if not found.
+        """
+        if not text:
+            return None
+
+        # Find all squares in the text
+        squares = re.findall(self.SQUARE_PATTERN, text)
+
+        if len(squares) < 2:
+            return None
+
+        # Take the first two squares as from and to
+        from_sq, to_sq = squares[0], squares[1]
+        uci_move = from_sq + to_sq
+
+        # Check for promotion (letter after to_square)
+        # Look for patterns like "=Q", "=q", or just "q" after the to_square
+        to_sq_idx = text.find(to_sq)
+        if to_sq_idx != -1:
+            remaining = text[to_sq_idx + 2:to_sq_idx + 5]  # Check next few chars
+            promo_match = re.search(r'[=]?([qrbnQRBN])', remaining)
+            if promo_match:
+                uci_move += promo_match.group(1).lower()
+
+        return uci_move
+
+    def _has_complete_move(self, text: str) -> bool:
+        """
+        Check if the generated text contains a complete move.
+
+        A complete move has at least two valid chess squares.
+
+        Args:
+            text: The generated text so far.
+
+        Returns:
+            True if text contains at least two squares.
+        """
+        squares = re.findall(self.SQUARE_PATTERN, text)
+        return len(squares) >= 2
+
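The square-based extraction is easy to sanity-check in isolation; a sketch assuming `SQUARE_PATTERN` is `[a-h][1-8]` (its exact value is defined elsewhere in the class):

```python
import re

SQUARE_PATTERN = r"[a-h][1-8]"  # assumed value of self.SQUARE_PATTERN

def extract_uci_move(text: str):
    """Recover 'from'+'to' squares (plus optional promotion) from any format."""
    squares = re.findall(SQUARE_PATTERN, text)
    if len(squares) < 2:
        return None
    from_sq, to_sq = squares[0], squares[1]
    uci_move = from_sq + to_sq
    idx = text.find(to_sq)
    if idx != -1:
        promo = re.search(r"=?([qrbnQRBN])", text[idx + 2:idx + 5])
        if promo:
            uci_move += promo.group(1).lower()
    return uci_move

print(extract_uci_move("WP e2_f e4_t"))  # e2e4
print(extract_uci_move("WPe7e8=Q(+)"))   # e7e8q
print(extract_uci_move("e4"))            # None
```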
     def _generate_move_tokens(
         self,
         input_ids: torch.Tensor,

         max_tokens: int = 20,
     ) -> str:
         """
+        Generate tokens until a complete move is detected or separator is hit.
 
+        This method is tokenizer-agnostic and stops when:
+        - A separator token (whitespace/EOS) is encountered
+        - Two chess squares have been generated (complete move)
+        - max_tokens limit is reached
 
         Args:
             input_ids: The input token IDs.

             max_tokens: Maximum tokens to generate for a single move.
 
         Returns:
+            The generated move string.
         """
         generated_tokens = []
         current_ids = input_ids.clone()
+        accumulated_text = ""
 
         for _ in range(max_tokens):
             with torch.no_grad():

             # Apply top-k filtering
             if top_k > 0:
+                top_k_vals = torch.topk(logits, min(top_k, logits.size(-1)))
+                indices_to_remove = logits < top_k_vals[0][..., -1, None]
                 logits[indices_to_remove] = float("-inf")
 
             # Sample
             probs = torch.softmax(logits, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
 
             # Decode the token
             token_str = self.tokenizer.decode(next_token[0])
 
             # Check if this is a separator token
             if self._is_separator_token(token_str):
+                # If we already have a complete move, stop
+                if self._has_complete_move(accumulated_text):
+                    break
+                # Otherwise, if it's EOS, we should also stop
+                if hasattr(self.tokenizer, 'eos_token'):
+                    if token_str == self.tokenizer.eos_token:
+                        break
+                # For whitespace separators, only stop if we have content
+                if accumulated_text:
+                    break
 
+            generated_tokens.append(next_token[0])
             current_ids = torch.cat([current_ids, next_token], dim=-1)
+            accumulated_text += token_str
 
+            # Stop if we have a complete move (two squares found)
+            if self._has_complete_move(accumulated_text):
+                # Check if this might be a promotion - peek for one more token
+                # if the move is to rank 1 or 8
+                squares = re.findall(self.SQUARE_PATTERN, accumulated_text)
+                if len(squares) >= 2:
+                    to_sq = squares[1]
+                    if to_sq[1] in '18':  # Potential promotion
+                        # Allow one more iteration to capture promotion piece
+                        if len(generated_tokens) > 3:  # Already have enough
+                            break
+                    else:
+                        break
 
         # Decode all generated tokens together
         if generated_tokens:
 
         """
         Get the model's next move prediction.
 
+        This method is tokenizer-agnostic. It generates tokens and extracts
+        UCI moves using pattern matching on chess squares.
+
+        Works with any tokenization format:
+        - Move-level: "WPe2e4" -> e2e4
+        - Decomposed: "WP e2_f e4_t" -> e2e4
+        - Pure UCI: "e2e4" -> e2e4
+        - Character-level: "e" "2" "e" "4" -> e2e4
+        - BPE/subword: "e2" "e4" -> e2e4
 
         Returns:
             Tuple of (UCI move string, number of retries used).

         input_text = self.tokenizer.bos_token + " " + moves_str
 
         # Tokenize
         inputs = self.tokenizer(
             input_text,
             return_tensors="pt",
             truncation=True,
+            max_length=self.model.config.n_ctx - 10,
         ).to(self.device)
 
         # Try to generate a legal move
         for retry in range(self.max_retries):
+            # Generate tokens until we have a move
+            move_text = self._generate_move_tokens(
                 inputs["input_ids"],
                 temperature=temperature,
                 top_k=top_k,
             )
 
+            # Extract UCI move using generic pattern matching
+            uci_move = self._extract_uci_move(move_text)
+
+            if uci_move:

                 try:
                     move = self.chess.Move.from_uci(uci_move)
                     if move in board.legal_moves:
 
         n_positions: int = 1000,
         temperature: float = 0.7,
         verbose: bool = True,
     ) -> dict:
         """
         Evaluate the model's ability to generate legal moves.

             n_positions: Number of positions to test.
             temperature: Sampling temperature.
             verbose: Whether to print progress.
 
         Returns:
             Dictionary with legal move statistics.
         """
         results = {
             "total_positions": 0,
             "legal_first_try": 0,
 
     Returns:
         Tuple of (model, tokenizer).
     """
+    from transformers import AutoModelForCausalLM, AutoTokenizer
 
+    # Import to register custom classes
+    from src.model import ChessConfig, ChessForCausalLM
+    from src.tokenizer import ChessTokenizer
 
+    # Try loading with custom tokenizer first, fall back to AutoTokenizer
     try:
+        tokenizer = ChessTokenizer.from_pretrained(model_id)
+    except Exception:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
+    model = AutoModelForCausalLM.from_pretrained(
         model_id,
+        trust_remote_code=True,
         device_map=device,
     )
 
     return model, tokenizer
 
 
     # Load model
     print(f"\nLoading model from: {args.model_path}")
 
+    import os
+    is_local_path = os.path.exists(args.model_path)
+
+    if is_local_path:
         # Local path
         from transformers import AutoModelForCausalLM
+        from src.tokenizer import ChessTokenizer
+        from src.model import ChessConfig, ChessForCausalLM
 
         tokenizer = ChessTokenizer.from_pretrained(args.model_path)
         model = AutoModelForCausalLM.from_pretrained(args.model_path)
+    else:
+        # Assume Hugging Face model ID (or invalid path)
+        if args.model_path.startswith(".") or args.model_path.startswith("/"):
+            raise FileNotFoundError(
+                f"Local model path not found: {args.model_path}\n"
+                f"Please check that the path exists and contains model files."
+            )
+        model, tokenizer = load_model_from_hub(args.model_path)
 
     # Create evaluator
     print(f"\nSetting up evaluator...")
src/train.py ADDED
@@ -0,0 +1,250 @@
+"""
+Training script for the Chess Challenge.
+
+This script provides a complete training pipeline using the Hugging Face Trainer.
+Students can modify this script to experiment with different training strategies.
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import warnings
+from pathlib import Path
+
+# Suppress warnings from third-party libraries (multiprocess has Python 3.14 compat issues)
+warnings.filterwarnings("ignore", message="'return' in a 'finally' block")
+
+import torch
+from transformers import (
+    Trainer,
+    TrainingArguments,
+    set_seed,
+)
+
+from src.data import ChessDataCollator, create_train_val_datasets
+from src.model import ChessConfig, ChessForCausalLM
+from src.tokenizer import ChessTokenizer
+from src.utils import count_parameters, print_parameter_budget
+
+
+def parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Train a chess-playing language model"
+    )
+
+    # Model arguments
+    parser.add_argument(
+        "--vocab_size", type=int, default=1200,
+        help="Vocabulary size"
+    )
+    parser.add_argument(
+        "--n_embd", type=int, default=128,
+        help="Embedding dimension"
+    )
+    parser.add_argument(
+        "--n_layer", type=int, default=4,
+        help="Number of transformer layers"
+    )
+    parser.add_argument(
+        "--n_head", type=int, default=4,
+        help="Number of attention heads"
+    )
+    parser.add_argument(
+        "--n_ctx", type=int, default=256,
+        help="Maximum context length"
+    )
+    parser.add_argument(
+        "--n_inner", type=int, default=None,
+        help="Feed-forward inner dimension (default: 4 * n_embd)"
+    )
+    parser.add_argument(
+        "--dropout", type=float, default=0.1,
+        help="Dropout probability"
+    )
+    parser.add_argument(
+        "--no_tie_weights", action="store_true",
+        help="Disable weight tying between embedding and output layers"
+    )
+
+    # Data arguments
+    parser.add_argument(
+        "--dataset_name", type=str, default="dlouapre/lichess_2025-01_1M",
+        help="Name of the dataset on Hugging Face Hub"
+    )
+    parser.add_argument(
+        "--max_train_samples", type=int, default=None,
+        help="Maximum number of training samples"
+    )
+    parser.add_argument(
+        "--val_samples", type=int, default=5000,
+        help="Number of validation samples"
+    )
+
+    # Training arguments
+    parser.add_argument(
+        "--output_dir", type=str, default="./output",
+        help="Output directory for model and logs"
+    )
+    parser.add_argument(
+        "--num_train_epochs", type=int, default=3,
+        help="Number of training epochs"
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size", type=int, default=32,
+        help="Training batch size per device"
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size", type=int, default=64,
+        help="Evaluation batch size per device"
+    )
+    parser.add_argument(
+        "--learning_rate", type=float, default=5e-4,
+        help="Learning rate"
+    )
+    parser.add_argument(
+        "--weight_decay", type=float, default=0.01,
+        help="Weight decay"
+    )
+    parser.add_argument(
+        "--warmup_ratio", type=float, default=0.1,
+        help="Warmup ratio"
+    )
+    parser.add_argument(
+        "--seed", type=int, default=42,
+        help="Random seed"
+    )
+
+    # Logging arguments
+    parser.add_argument(
+        "--logging_steps", type=int, default=100,
+        help="Logging frequency"
+    )
+    parser.add_argument(
+        "--eval_steps", type=int, default=500,
+        help="Evaluation frequency"
+    )
+    parser.add_argument(
+        "--save_steps", type=int, default=1000,
+        help="Checkpoint saving frequency"
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main training function."""
+    args = parse_args()
+
+    # Set seed for reproducibility
+    set_seed(args.seed)
+
+    print("=" * 60)
+    print("CHESS CHALLENGE - TRAINING")
+    print("=" * 60)
+
+    # Build tokenizer from dataset
+    print("\nBuilding tokenizer from dataset...")
+    tokenizer = ChessTokenizer.build_vocab_from_dataset(
+        dataset_name=args.dataset_name,
+        min_frequency=500,  # Only keep moves that appear at least 500 times
+        max_samples=100000,  # Use 100k games to build vocabulary
+    )
+    print(f"  Vocabulary size: {tokenizer.vocab_size}")
+
+    # Use the vocab size from the tokenizer (overrides --vocab_size)
+    actual_vocab_size = tokenizer.vocab_size
+
+    # Create model configuration
+    print("\nCreating model configuration...")
+    config = ChessConfig(
+        vocab_size=actual_vocab_size,
+        n_embd=args.n_embd,
+        n_layer=args.n_layer,
+        n_head=args.n_head,
+        n_ctx=args.n_ctx,
+        n_inner=args.n_inner,
+        dropout=args.dropout,
+        tie_weights=not args.no_tie_weights,
+        pad_token_id=tokenizer.pad_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    # Print parameter budget
+    print_parameter_budget(config)
+
+    # Create model
+    print("\nCreating model...")
+    model = ChessForCausalLM(config)
+    n_params = count_parameters(model)
+    print(f"  Total parameters: {n_params:,}")
+
+    if n_params > 1_000_000:
+        print("WARNING: Model exceeds 1M parameter limit!")
+    else:
+        print("✓ Model is within 1M parameter limit")
+
+    # Load datasets
+    print("\nLoading datasets...")
+    train_dataset, val_dataset = create_train_val_datasets(
+        tokenizer=tokenizer,
+        dataset_name=args.dataset_name,
+        max_length=args.n_ctx,
+        train_samples=args.max_train_samples,
+        val_samples=args.val_samples,
+    )
+    print(f"  Training samples: {len(train_dataset):,}")
+    print(f"  Validation samples: {len(val_dataset):,}")
+
+    # Create data collator
+    data_collator = ChessDataCollator(tokenizer, max_length=args.n_ctx)
+
+    # Training arguments
+    training_args = TrainingArguments(
+        output_dir=args.output_dir,
+        num_train_epochs=args.num_train_epochs,
+        per_device_train_batch_size=args.per_device_train_batch_size,
+        per_device_eval_batch_size=args.per_device_eval_batch_size,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+        warmup_ratio=args.warmup_ratio,
+        logging_dir=os.path.join(args.output_dir, "logs"),
+        logging_steps=args.logging_steps,
+        eval_strategy="epoch",
+        save_strategy="epoch",
+        save_total_limit=3,
+        load_best_model_at_end=True,
+        metric_for_best_model="eval_loss",
+        greater_is_better=False,
+        seed=args.seed,
+        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
+        report_to=["none"],
+    )
+
+    # Create trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        data_collator=data_collator,
+        tokenizer=tokenizer,
+    )
+
+    # Train
+    print("\nStarting training...")
+    trainer.train()
+
+    # Save final model
+    print("\nSaving final model...")
+    trainer.save_model(os.path.join(args.output_dir, "final_model"))
+    tokenizer.save_pretrained(os.path.join(args.output_dir, "final_model"))
+
+    print("\nTraining complete!")
+    print(f"  Model saved to: {args.output_dir}/final_model")
+
+
+if __name__ == "__main__":
+    main()
src/utils.py ADDED
@@ -0,0 +1,305 @@
+"""
+Utility functions for the Chess Challenge.
+
+This module provides helper functions for:
+- Parameter counting and budget analysis
+- Model registration with Hugging Face
+- Move validation with python-chess
+"""
+
+from __future__ import annotations
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+import torch.nn as nn
+
+if TYPE_CHECKING:
+    from src.model import ChessConfig
+
+
+def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
+    """
+    Count the number of parameters in a model.
+
+    Args:
+        model: The PyTorch model.
+        trainable_only: If True, only count trainable parameters.
+
+    Returns:
+        Total number of parameters.
+    """
+    if trainable_only:
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return sum(p.numel() for p in model.parameters())
+
+
+def count_parameters_by_component(model: nn.Module) -> Dict[str, int]:
+    """
+    Count parameters broken down by model component.
+
+    Args:
+        model: The PyTorch model.
+
+    Returns:
+        Dictionary mapping component names to parameter counts.
+    """
+    counts = {}
+    for name, module in model.named_modules():
+        if len(list(module.children())) == 0:  # Leaf module
+            param_count = sum(p.numel() for p in module.parameters(recurse=False))
+            if param_count > 0:
+                counts[name] = param_count
+    return counts
+
+
+def estimate_parameters(config: "ChessConfig") -> Dict[str, int]:
+    """
+    Estimate the parameter count for a given configuration.
+
+    This is useful for planning your architecture before building the model.
+
+    Args:
+        config: Model configuration.
+
+    Returns:
+        Dictionary with estimated parameter counts by component.
+    """
+    V = config.vocab_size
+    d = config.n_embd
+    L = config.n_layer
+    n_ctx = config.n_ctx
+    n_inner = config.n_inner
+
+    estimates = {
+        "token_embeddings": V * d,
+        "position_embeddings": n_ctx * d,
+        "attention_qkv_per_layer": 3 * d * d,
+        "attention_proj_per_layer": d * d,
+        "ffn_per_layer": 2 * d * n_inner,
+        "layernorm_per_layer": 4 * d,  # 2 LayerNorms, each with weight and bias
+        "final_layernorm": 2 * d,
+    }
+
+    # Calculate totals
+    per_layer = (
+        estimates["attention_qkv_per_layer"] +
+        estimates["attention_proj_per_layer"] +
+        estimates["ffn_per_layer"] +
+        estimates["layernorm_per_layer"]
+    )
+
+    estimates["total_transformer_layers"] = L * per_layer
+
+    # LM head (tied with embeddings by default)
+    if config.tie_weights:
+        estimates["lm_head"] = 0
+        estimates["lm_head_note"] = "Tied with token embeddings"
+    else:
+        estimates["lm_head"] = V * d
+
+    # Grand total
+    estimates["total"] = (
+        estimates["token_embeddings"] +
+        estimates["position_embeddings"] +
+        estimates["total_transformer_layers"] +
+        estimates["final_layernorm"] +
+        estimates["lm_head"]
+    )
+
+    return estimates
+
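Plugging in the defaults from `train.py` (`n_embd=128`, `n_layer=4`, `n_ctx=256`; `vocab_size=1200` is only the CLI default, since the real value comes from the tokenizer) gives a concrete budget. A worked sketch of the same arithmetic, under those assumed values:

```python
# Assumed configuration: train.py defaults, weight tying enabled
V, d, L, n_ctx = 1200, 128, 4, 256
n_inner = 4 * d  # default feed-forward width

token_emb = V * d        # 153,600
pos_emb = n_ctx * d      # 32,768
per_layer = 3 * d * d + d * d + 2 * d * n_inner + 4 * d  # qkv + proj + ffn + 2 LayerNorms
final_ln = 2 * d
lm_head = 0              # tied with the token embeddings

total = token_emb + pos_emb + L * per_layer + final_ln + lm_head
print(total)  # 975104 -- under the 1M limit
```

Untying the weights would add another `V * d = 153,600` parameters and blow the budget, which is why `tie_weights` defaults to on.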
+def print_parameter_budget(config: "ChessConfig", limit: int = 1_000_000) -> None:
+    """
+    Print a formatted parameter budget analysis.
+
+    Args:
+        config: Model configuration.
+        limit: Parameter limit to compare against.
+    """
+    estimates = estimate_parameters(config)
+
+    print("=" * 60)
+    print("PARAMETER BUDGET ANALYSIS")
+    print("=" * 60)
+    print(f"\nConfiguration:")
+    print(f"  vocab_size (V) = {config.vocab_size}")
+    print(f"  n_embd (d) = {config.n_embd}")
+    print(f"  n_layer (L) = {config.n_layer}")
+    print(f"  n_head = {config.n_head}")
+    print(f"  n_ctx = {config.n_ctx}")
+    print(f"  n_inner = {config.n_inner}")
+    print(f"  tie_weights = {config.tie_weights}")
+
+    print(f"\nParameter Breakdown:")
+    print(f"  Token Embeddings:    {estimates['token_embeddings']:>10,}")
+    print(f"  Position Embeddings: {estimates['position_embeddings']:>10,}")
+    print(f"  Transformer Layers:  {estimates['total_transformer_layers']:>10,}")
+    print(f"  Final LayerNorm:     {estimates['final_layernorm']:>10,}")
+
+    if config.tie_weights:
+        print(f"  LM Head:             {'(tied)':>10}")
+    else:
+        print(f"  LM Head:             {estimates['lm_head']:>10,}")
+
+    print(f"  " + "-" * 30)
+    print(f"  TOTAL:               {estimates['total']:>10,}")
+
+    print(f"\nBudget Status:")
+    print(f"  Limit:     {limit:>10,}")
+    print(f"  Used:      {estimates['total']:>10,}")
+    print(f"  Remaining: {limit - estimates['total']:>10,}")
+
+    if estimates['total'] <= limit:
+        print(f"\n  Within budget! ({estimates['total'] / limit * 100:.1f}% used)")
+    else:
+        print(f"\n  OVER BUDGET by {estimates['total'] - limit:,} parameters!")
+
+    print("=" * 60)
+
+
+def validate_move_with_chess(move: str, board_fen: Optional[str] = None) -> bool:
+    """
+    Validate a move using python-chess.
+
+    This function converts the dataset's extended UCI format to standard UCI
+    and validates it against the current board state.
+
+    Args:
+        move: Move in extended UCI format (e.g., "WPe2e4", "BNg8f6(x)").
+        board_fen: FEN string of the current board state (optional).
+
+    Returns:
+        True if the move is legal, False otherwise.
+    """
+    try:
+        import chess
+    except ImportError:
+        raise ImportError("python-chess is required for move validation. "
+                          "Install it with: pip install python-chess")
+
+    # Parse the extended UCI format
+    # Format: [W|B][Piece][from_sq][to_sq][suffix]
+    # Example: WPe2e4, BNg8f6(x), WKe1g1(o)
+
+    if len(move) < 6:
+        return False
+
+    # Extract components
+    color = move[0]      # W or B
+    piece = move[1]      # P, N, B, R, Q, K
+    from_sq = move[2:4]  # e.g., "e2"
+    to_sq = move[4:6]    # e.g., "e4"
+
+    # Check for promotion
+    promotion = None
+    if "=" in move:
+        promo_idx = move.index("=")
+        promotion = move[promo_idx + 1].lower()
+
+    # Create board
+    board = chess.Board(board_fen) if board_fen else chess.Board()
+
+    # Build UCI move string
+    uci_move = from_sq + to_sq
+    if promotion:
+        uci_move += promotion
+
+    try:
+        move_obj = chess.Move.from_uci(uci_move)
+        return move_obj in board.legal_moves
+    except (ValueError, chess.InvalidMoveError):
+        return False
+
+
+def convert_extended_uci_to_uci(move: str) -> str:
+    """
+    Convert extended UCI format to standard UCI format.
+
+    Args:
+        move: Move in extended UCI format (e.g., "WPe2e4").
+
+    Returns:
+        Move in standard UCI format (e.g., "e2e4").
+    """
+    if len(move) < 6:
+        return move
+
+    # Extract squares
+    from_sq = move[2:4]
+    to_sq = move[4:6]
+
+    # Check for promotion
+    promotion = ""
+    if "=" in move:
+        promo_idx = move.index("=")
+        promotion = move[promo_idx + 1].lower()
+
+    return from_sq + to_sq + promotion
+
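Since the extended-to-UCI direction is pure string slicing, it can be verified without python-chess; the function is restated inline here so the snippet is self-contained:

```python
def convert_extended_uci_to_uci(move: str) -> str:
    # Same slicing as above: [color][piece][from][to] plus optional "=<promo>" and suffix
    if len(move) < 6:
        return move
    promotion = ""
    if "=" in move:
        promotion = move[move.index("=") + 1].lower()
    return move[2:4] + move[4:6] + promotion

print(convert_extended_uci_to_uci("WPe2e4"))        # e2e4
print(convert_extended_uci_to_uci("BNg8f6(x)"))     # g8f6
print(convert_extended_uci_to_uci("WPe7e8=Q(+*)"))  # e7e8q
```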
+def convert_uci_to_extended(
+    uci_move: str,
+    board_fen: str,
+) -> str:
+    """
+    Convert standard UCI format to extended UCI format.
+
+    Args:
+        uci_move: Move in standard UCI format (e.g., "e2e4").
+        board_fen: FEN string of the current board state.
+
+    Returns:
+        Move in extended UCI format (e.g., "WPe2e4").
+    """
+    try:
+        import chess
+    except ImportError:
+        raise ImportError("python-chess is required for move conversion.")
+
+    board = chess.Board(board_fen)
+    move = chess.Move.from_uci(uci_move)
+
+    # Get color
+    color = "W" if board.turn == chess.WHITE else "B"
+
+    # Get piece
+    piece = board.piece_at(move.from_square)
+    piece_letter = piece.symbol().upper() if piece else "P"
+
+    # Build extended UCI
+    from_sq = chess.square_name(move.from_square)
+    to_sq = chess.square_name(move.to_square)
+
+    result = f"{color}{piece_letter}{from_sq}{to_sq}"
+
+    # Add promotion
+    if move.promotion:
+        result += f"={chess.piece_symbol(move.promotion).upper()}"
+
+    # Add suffix for captures
+    if board.is_capture(move):
+        result += "(x)"
+
+    # Add suffix for check/checkmate
+    board.push(move)
+    if board.is_checkmate():
+        if "(x)" in result:
+            result = result.replace("(x)", "(x+*)")
+        else:
+            result += "(+*)"
+    elif board.is_check():
+        if "(x)" in result:
+            result = result.replace("(x)", "(x+)")
+        else:
+            result += "(+)"
+    board.pop()
+
+    # Handle castling notation
+    if board.is_castling(move):
+        if move.to_square in [chess.G1, chess.G8]:  # Kingside
+            result = result.replace("(x)", "").replace("(+)", "") + "(o)"
+        else:  # Queenside
+            result = result.replace("(x)", "").replace("(+)", "") + "(O)"
+
+    return result
submit.py ADDED
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+"""
+Submission script for the Chess Challenge.
+
+This script pushes your trained model to the Hugging Face Hub under the
+LLM-course organization, with metadata tracking who submitted it.
+
+Usage:
+    python submit.py --model_path ./my_model/final_model --model_name my-chess-model
+"""
+
+import argparse
+import os
+import tempfile
+from pathlib import Path
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Submit your chess model to Hugging Face Hub")
+    parser.add_argument(
+        "--model_path", type=str, default="./my_model/final_model",
+        help="Path to your trained model directory"
+    )
+    parser.add_argument(
+        "--model_name", type=str, required=True,
+        help="Name for your model on the Hub (e.g., 'my-chess-model')"
+    )
+    args = parser.parse_args()
+
+    # Fixed organization
+    organization = "LLM-course"
+
+    # Check model path exists
+    if not os.path.exists(args.model_path):
+        print(f"Error: Model path '{args.model_path}' does not exist.")
+        print("Train a model first with: python -m src.train --output_dir ./my_model")
+        return 1
+
+    # Import here to avoid slow startup
+    from huggingface_hub import HfApi, HfFolder, whoami
+    from transformers import AutoModelForCausalLM
+
+    # Ensure user is logged in and get their info
+    print("=" * 60)
+    print("CHESS CHALLENGE - MODEL SUBMISSION")
+    print("=" * 60)
+
+    try:
+        user_info = whoami()
+        username = user_info["name"]
+        print(f"\nLogged in as: {username}")
+    except Exception:
+        print("\nYou need to log in to Hugging Face first.")
+        print("Run: huggingface-cli login")
+        return 1
+
+    # Import custom classes to register them
+    from src.model import ChessConfig, ChessForCausalLM
+    from src.tokenizer import ChessTokenizer
+
+    # Load model and tokenizer
+    print(f"\nLoading model from: {args.model_path}")
+    model = AutoModelForCausalLM.from_pretrained(args.model_path)
+    tokenizer = ChessTokenizer.from_pretrained(args.model_path)
+
+    # Count parameters
+    n_params = sum(p.numel() for p in model.parameters())
+    print(f"Model parameters: {n_params:,}")
+
+    if n_params > 1_000_000:
+        print(f"WARNING: Model exceeds 1M parameter limit ({n_params:,} params)")
+
+    # Prepare repo name
+    repo_id = f"{organization}/{args.model_name}"
+    print(f"\nSubmitting to: {repo_id}")
+
+    # Create a temporary directory to prepare submission
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_path = Path(tmp_dir)
+
+        # Save model and tokenizer
+        model.save_pretrained(tmp_path)
+        tokenizer.save_pretrained(tmp_path)
+
+        # Create model card with submitter info
+        model_card = f"""---
+library_name: transformers
+tags:
+- chess
+- llm-course
+- chess-challenge
+license: mit
+---
+
+# {args.model_name}
+
+Chess model submitted to the LLM Course Chess Challenge.
+
+## Submission Info
+
+- **Submitted by**: [{username}](https://huggingface.co/{username})
+- **Parameters**: {n_params:,}
+- **Organization**: {organization}
+
+## Model Details
+
+- **Architecture**: Chess Transformer (GPT-style)
+- **Vocab size**: {tokenizer.vocab_size}
+- **Embedding dim**: {model.config.n_embd}
+- **Layers**: {model.config.n_layer}
+- **Heads**: {model.config.n_head}
+"""
+        (tmp_path / "README.md").write_text(model_card)
+
+        # Push to Hub
+        print("\nUploading to Hugging Face Hub...")
+        api = HfApi()
+
+        # Create repo if it doesn't exist
+        api.create_repo(
+            repo_id=repo_id,
+            exist_ok=True,
+        )
+
+        # Upload all files
+        api.upload_folder(
+            folder_path=tmp_path,
+            repo_id=repo_id,
+            commit_message=f"Chess Challenge submission by {username}",
+        )
+
+    print("\n" + "=" * 60)
+    print("SUBMISSION COMPLETE!")
+    print("=" * 60)
+    print(f"\nYour model is now available at:")
+    print(f"  https://huggingface.co/{repo_id}")
+    print(f"\nSubmitted by: {username}")
+    print(f"Parameters: {n_params:,}")
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())