Willy Vo committed on
Commit
d23894d
·
1 Parent(s): a21ba01

Add Tokenizer

Browse files
Files changed (1) hide show
  1. tokenizer.py +340 -0
tokenizer.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer treats each move as a single token using the extended UCI notation
5
+ from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
+
7
+ The dataset format uses:
8
+ - W/B prefix for White/Black
9
+ - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
+ - Source and destination squares (e.g., e2e4)
11
+ - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+
23
+
24
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves using extended UCI notation.

    Each move string from the Lichess-style dataset (e.g. ``WPe2e4``,
    ``WBb5c6(x+)``) is decomposed into a short sequence of structured tokens
    drawn from a fixed vocabulary: color, piece, source square, destination
    square, optional flags (capture / check / mate / promotion / castling)
    and a ``[MOVE_END]`` separator.  Because the vocabulary is fixed and
    structural, it does not depend on the training data.

    Example:
        >>> tokenizer = ChessTokenizer()
        >>> tokenizer.encode("WPe2e4 BPe7e5")  # doctest: +SKIP
        [1, 4, 6, 24, 40, 103, 2]  # ids depend on the vocabulary layout
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
            vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Special-token attributes must exist before super().__init__ runs,
        # because the base class reads them during construction.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        # Load or create the vocabulary: explicit dict wins, then a vocab
        # file on disk, otherwise the fixed structured vocabulary.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        # Reverse mapping for id -> token lookups.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up the vocab, since the base class
        # may tokenize the special tokens immediately.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create a fixed structured vocabulary (no dataset-dependent move tokens).

        Tokens:
            - Special: [PAD], [BOS], [EOS], [UNK]
            - Color: [W], [B]
            - Pieces: [P], [N], [BISHOP], [R], [Q], [K]
            - Squares: [a1]..[h8]
            - Suffixes: [x], [+], [#]
            - Castling: [O-O], [O-O-O]
            - Promotions: [prom_Q], [prom_R], [prom_B], [prom_N]
            - Move separator: [MOVE_END]
        """
        special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        colors = ["[W]", "[B]"]
        # "[BISHOP]" avoids a clash with the Black color token "[B]".
        pieces = ["[P]", "[N]", "[BISHOP]", "[R]", "[Q]", "[K]"]

        files = "abcdefgh"
        ranks = "12345678"
        squares = [f"[{f}{r}]" for r in ranks for f in files]  # a1..h8

        suffixes = ["[x]", "[+]", "[#]"]
        castling = ["[O-O]", "[O-O-O]"]
        promotions = ["[prom_Q]", "[prom_R]", "[prom_B]", "[prom_N]"]
        move_end = ["[MOVE_END]"]

        tokens = (
            special + colors + pieces + squares
            + suffixes + castling + promotions + move_end
        )
        return {tok: i for i, tok in enumerate(tokens)}

    @classmethod
    def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
        """Return a tokenizer with the fixed structured vocabulary.

        The structured tokenizer does not learn its vocabulary, so both
        *iterator* and *min_frequency* are accepted for API compatibility
        but ignored.
        """
        return cls(vocab=cls().get_vocab())

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """Return a tokenizer with the fixed structured vocabulary.

        Kept for API compatibility with a dataset-driven vocabulary builder;
        the structured tokenizer uses a fixed vocab, so all dataset
        parameters are ignored.
        """
        return cls(vocab=cls().get_vocab())

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the vocabulary as a dictionary."""
        return dict(self._vocab)

    def _move_to_tokens(self, move: str) -> List[str]:
        """
        Convert one extended-UCI move string to structured tokens.

        Examples:
            "WPe2e4"      -> ["[W]","[P]","[e2]","[e4]"]
            "WBb5c6(x+)"  -> ["[W]","[BISHOP]","[b5]","[c6]","[x]","[+]"]
            "BKe8g8(o)"   -> ["[B]","[O-O]"]
            "WPa7a8(Q)"   -> ["[W]","[P]","[a7]","[a8]","[prom_Q]"]
        """
        toks: List[str] = []

        if not move:
            return [self.UNK_TOKEN]

        # Color: dataset uses a leading W/B.
        color = move[0]
        toks.append("[W]" if color == "W" else "[B]")

        # Piece letter (P,N,B,R,Q,K) at position 1.
        piece_char = move[1] if len(move) > 1 else ""
        piece_map = {"P": "[P]", "N": "[N]", "B": "[BISHOP]", "R": "[R]", "Q": "[Q]", "K": "[K]"}
        toks.append(piece_map.get(piece_char, self.UNK_TOKEN))

        # Source and destination squares assumed at positions 2:4 and 4:6,
        # e.g. WPe2e4 -> from=e2 to=e4.
        if len(move) >= 6:
            from_sq = move[2:4]
            to_sq = move[4:6]
            toks.append(f"[{from_sq}]")
            toks.append(f"[{to_sq}]")
        else:
            # Malformed move string: pad with unknowns.
            toks.append(self.UNK_TOKEN)
            toks.append(self.UNK_TOKEN)

        # --- Castling ---
        # Dataset marks castling with (o)/(O), attached to king moves.
        # Map based on the king's destination file:
        # g-file => O-O (kingside), c-file => O-O-O (queenside).
        if "(o)" in move or "(O)" in move:
            if len(move) >= 6:
                to_sq = move[4:6]
                if to_sq[0] == "g":
                    return [toks[0], "[O-O]"]
                if to_sq[0] == "c":
                    return [toks[0], "[O-O-O]"]

        # --- Promotion ---
        # NOTE(review): assumes the promotion letter appears in its own
        # parenthesised group, e.g. "(Q)" — combined suffix forms would
        # need extra handling; confirm against the dataset.
        if "(Q)" in move:
            toks.append("[prom_Q]")
        elif "(R)" in move:
            toks.append("[prom_R]")
        elif "(B)" in move:
            toks.append("[prom_B]")
        elif "(N)" in move:
            toks.append("[prom_N]")

        # --- Capture / check / mate ---
        # Capture patterns: "(x)", "(x+)", "(x+*)", ...
        if "(x" in move:
            toks.append("[x]")

        # Checkmate is marked with "+*" (dataset: "(+*)"), possibly combined
        # with a capture as "(x+*)".  Test for the mate marker first so a
        # capture-checkmate is not downgraded to a bare check — the previous
        # substring checks ("(+*)", "(+)", "(x+)") missed "(x+*)" entirely.
        if "+*" in move:
            toks.append("[#]")
        elif "+" in move:
            toks.append("[+]")

        return toks

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a game string into structured tokens.

        Each whitespace-separated move becomes:
        [W]/[B], [PIECE], [from], [to], optional flags, then [MOVE_END].
        """
        moves = text.strip().split()
        out: List[str] = []
        for mv in moves:
            out.extend(self._move_to_tokens(mv))
            out.append("[MOVE_END]")
        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID (falling back to the UNK id, then 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token (UNK for unknown ids)."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens into a string, dropping special and separator tokens."""
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if (t not in special and t != "[MOVE_END]"))

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary.
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        # exist_ok makes the isdir pre-check unnecessary.
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
+
307
+
308
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count move-token frequencies in a dataset (useful for vocabulary analysis).

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Maximum number of samples to process.

    Returns:
        Dictionary mapping tokens to their frequencies.
    """
    # Local imports keep the heavy `datasets` dependency optional.
    from collections import Counter
    from datasets import load_dataset

    ds = load_dataset(dataset_name, split=split)

    # Cap the number of rows scanned, guarding against short datasets.
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))

    freq: Counter = Counter()
    for row in ds:
        freq.update(row[column].strip().split())

    return dict(freq)