Skydono committed on
Commit
7d16fe4
·
verified ·
1 Parent(s): def9824

Upload tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +318 -0
tokenizer.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer treats each move as a single token using the extended UCI notation
5
+ from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
+
7
+ The dataset format uses:
8
+ - W/B prefix for White/Black
9
+ - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
+ - Source and destination squares (e.g., e2e4)
11
+ - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional
20
+
21
+ import re
22
+ from transformers import PreTrainedTokenizer
23
+
24
+
25
# Matches one extended-UCI move from the Lichess dataset, e.g. "WPe2e4"
# or "BNg8f6(x)": side letter, piece letter, source square, destination
# square, then any trailing suffix (captures, checks, promotions, ...).
MOVE_RE = re.compile(
    r"^(?P<side>[WB])(?P<piece>[PNBRQK])"
    r"(?P<src>[a-h][1-8])(?P<dst>[a-h][1-8])(?P<suffix>.*)$"
)
32
+
33
+
34
class ChessTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer for chess moves using extended UCI notation.

    Each move string from the dataset (e.g. ``WPe2e4``, ``BNg8f6(x)``) is
    decomposed into component tokens: a side token ([W]/[B]), a piece token
    ([P]/[N]/[BISHOP]/[R]/[Q]/[K]), source and destination square tokens
    ([a1]..[h8]), plus optional suffix tokens for captures ([x]), checks
    ([+]), mates ([#]), castling ([O-O]/[O-O-O]) and promotions
    ([prom_Q]/[prom_R]/[prom_B]/[prom_N]).

    Example:
        >>> tokenizer = ChessTokenizer()
        >>> tokenizer.encode("WPe2e4 BPe7e5")
        [1, 42, 87, 2]  # [BOS, e2e4, e7e5, EOS]
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Initialize the chess tokenizer.

        Args:
            vocab_file: Path to a JSON file containing the vocabulary mapping.
            vocab: Dictionary mapping tokens to IDs (takes precedence over
                ``vocab_file``).
            **kwargs: Additional arguments passed to PreTrainedTokenizer.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Remove any duplicate special-token entries passed through kwargs
        # to avoid "multiple values for keyword" errors when loading from disk.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Load or create vocabulary: explicit dict > file on disk > default.
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # Fall back to the complete component vocabulary.
            self._vocab = self._create_default_vocab()

        # Create reverse mapping for id -> token lookups.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # Call parent init AFTER setting up vocab: PreTrainedTokenizer
        # needs get_vocab()/vocab_size to work during its own __init__.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_default_vocab(self) -> Dict[str, int]:
        """
        Create the default component vocabulary.

        Because moves are split into side/piece/square/suffix components,
        this fixed vocabulary already covers every move the rule-based
        `_tokenize` can produce; `build_vocab_from_dataset()` merely
        filters it by observed frequency.
        """
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        side_tokens = ["[W]", "[B]"]

        # "[BISHOP]" avoids a clash with the "[B]" side token.
        piece_tokens = ["[P]", "[N]", "[BISHOP]", "[R]", "[Q]", "[K]"]

        square_tokens = [f"[{file}{rank}]" for rank in "12345678" for file in "abcdefgh"]

        suffix_tokens = ["[x]", "[+]", "[#]", "[O-O]", "[O-O-O]", "[prom_Q]", "[prom_R]", "[prom_B]", "[prom_N]"]

        vocab_list = special_tokens + side_tokens + piece_tokens + square_tokens + suffix_tokens
        vocab = {token: idx for idx, token in enumerate(vocab_list)}
        return vocab

    @classmethod
    def build_vocab_from_iterator(
        cls,
        iterator,
        min_frequency: int = 1,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from an iterator of game strings.

        Token frequencies are counted using the component tokenization of
        `_tokenize`; tokens seen fewer than ``min_frequency`` times are
        dropped. The four special tokens are always kept and pinned to
        IDs 0-3 (PAD, BOS, EOS, UNK).

        Args:
            iterator: An iterator yielding game strings (space-separated moves).
            min_frequency: Minimum frequency for a token to be included.

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        from collections import Counter

        # A default-vocab instance is enough to run the rule-based _tokenize.
        scratch = cls()
        counts = Counter()
        for game in iterator:
            counts.update(scratch._tokenize(game))

        # Special tokens first so PAD/BOS/EOS/UNK keep stable low IDs.
        vocab: Dict[str, int] = {}
        for token in (cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN):
            vocab[token] = len(vocab)

        # Sort for a deterministic token -> id assignment across runs.
        for token, freq in sorted(counts.items()):
            if freq >= min_frequency and token not in vocab:
                vocab[token] = len(vocab)

        return cls(vocab=vocab)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,
        max_samples: Optional[int] = 100000,
    ) -> "ChessTokenizer":
        """
        Build a tokenizer vocabulary from a Hugging Face dataset.

        Args:
            dataset_name: Name of the dataset on Hugging Face Hub.
            split: Dataset split to use.
            column: Column containing the game strings.
            min_frequency: Minimum frequency for a token to be included (default: 500).
            max_samples: Maximum number of samples to process (default: 100k).

        Returns:
            A ChessTokenizer with the built vocabulary.
        """
        # Local import keeps `datasets` optional, matching the module's
        # existing convention for dataset access.
        from datasets import load_dataset

        dataset = load_dataset(dataset_name, split=split)

        if max_samples is not None:
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        # Delegate the counting/filtering to build_vocab_from_iterator so
        # both entry points share one vocabulary-construction path.
        return cls.build_vocab_from_iterator(
            (example[column] for example in dataset),
            min_frequency=min_frequency,
        )

    @property
    def vocab_size(self) -> int:
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a dictionary (defensive copy)."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string of moves into a list of component tokens.

        Args:
            text: A string of space-separated moves in extended UCI notation.

        Returns:
            List of component tokens; unparseable moves become [UNK].
        """
        tokens: List[str] = []
        moves = text.strip().split()

        for move in moves:
            # Check long castling first: "O-O" is a substring of "O-O-O".
            if "O-O-O" in move:
                side = "[W]" if move.startswith("W") else "[B]"
                tokens.append(side)
                tokens.append("[O-O-O]")
                continue

            if "O-O" in move:
                side = "[W]" if move.startswith("W") else "[B]"
                tokens.append(side)
                tokens.append("[O-O]")
                continue

            m = MOVE_RE.match(move)
            if not m:
                tokens.append(self.UNK_TOKEN)
                continue

            side = "[W]" if m.group("side") == "W" else "[B]"
            piece = m.group("piece")
            src = m.group("src")
            dst = m.group("dst")
            suffix = m.group("suffix") or ""

            tokens.append(side)

            # Bishop gets a distinct token so it can't collide with side [B].
            if piece == "B":
                tokens.append("[BISHOP]")
            else:
                tokens.append(f"[{piece}]")

            tokens.append(f"[{src}]")
            tokens.append(f"[{dst}]")

            if "x" in suffix:
                tokens.append("[x]")

            # "*" marks checkmate in the dataset's (+*) suffix; plain "+"
            # is only a check, so mate takes precedence.
            if "*" in suffix:
                tokens.append("[#]")
            elif "+" in suffix:
                tokens.append("[+]")

            # Promotions are written "=Q", "=R", "=B", "=N" in the suffix.
            if "=" in suffix:
                i = suffix.find("=")
                if i != -1 and i + 1 < len(suffix):
                    promo = suffix[i + 1].upper()
                    if promo in ("Q", "R", "B", "N"):
                        tokens.append(f"[prom_{promo}]")

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID, falling back to the UNK id (or 0)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token, falling back to UNK."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens back to a space-separated string."""
        # Filter out special tokens for cleaner output
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        """
        Save the vocabulary to a JSON file.

        Args:
            save_directory: Directory to save the vocabulary (created if missing).
            filename_prefix: Optional prefix for the filename.

        Returns:
            Tuple containing the path to the saved vocabulary file.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
284
+
285
+
286
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count token frequencies in a dataset (useful for vocabulary analysis).

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Maximum number of samples to process.

    Returns:
        Dictionary mapping tokens to their frequencies.
    """
    from collections import Counter
    from datasets import load_dataset

    ds = load_dataset(dataset_name, split=split)

    # Optionally restrict to a prefix of the split to bound runtime.
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))

    tokenizer = ChessTokenizer()

    # Flatten every game into a single token stream and count in one pass.
    frequencies = Counter(
        token
        for example in ds
        for token in tokenizer._tokenize(example[column])
    )
    return dict(frequencies)