Kevin Hamon committed on
Commit
876e9df
·
1 Parent(s): cd81776

fix tokenizer

Browse files
Files changed (1) hide show
  1. tokenizer.py +218 -69
tokenizer.py CHANGED
@@ -1,14 +1,13 @@
1
  """
2
  Custom Chess Tokenizer for the Chess Challenge.
3
 
4
- This tokenizer treats each move as a single token using the extended UCI notation
5
- from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
-
7
- The dataset format uses:
8
  - W/B prefix for White/Black
9
  - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
- - Source and destination squares (e.g., e2e4)
 
11
  - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
 
12
  """
13
 
14
  from __future__ import annotations
@@ -16,6 +15,8 @@ from __future__ import annotations
16
  import json
17
  import os
18
  from pathlib import Path
 
 
19
  from typing import Dict, List, Optional
20
 
21
  from transformers import PreTrainedTokenizer
@@ -23,16 +24,12 @@ from transformers import PreTrainedTokenizer
23
 
24
  class ChessTokenizer(PreTrainedTokenizer):
25
  """
26
- A custom tokenizer for chess moves using extended UCI notation.
27
-
28
- This tokenizer maps each possible chess move to a unique token ID.
29
- The vocabulary is built from the training dataset to ensure all moves
30
- encountered during training have a corresponding token.
31
 
32
  Example:
33
  >>> tokenizer = ChessTokenizer()
34
  >>> tokenizer.encode("WPe2e4 BPe7e5")
35
- [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
36
  """
37
 
38
  model_input_names = ["input_ids", "attention_mask"]
@@ -43,6 +40,7 @@ class ChessTokenizer(PreTrainedTokenizer):
43
  BOS_TOKEN = "[BOS]"
44
  EOS_TOKEN = "[EOS]"
45
  UNK_TOKEN = "[UNK]"
 
46
 
47
  def __init__(
48
  self,
@@ -63,6 +61,7 @@ class ChessTokenizer(PreTrainedTokenizer):
63
  self._bos_token = self.BOS_TOKEN
64
  self._eos_token = self.EOS_TOKEN
65
  self._unk_token = self.UNK_TOKEN
 
66
 
67
  # Remove any duplicate special-token entries passed through kwargs
68
  # to avoid "multiple values for keyword" errors when loading from disk.
@@ -70,6 +69,7 @@ class ChessTokenizer(PreTrainedTokenizer):
70
  kwargs.pop("bos_token", None)
71
  kwargs.pop("eos_token", None)
72
  kwargs.pop("unk_token", None)
 
73
 
74
  # Load or create vocabulary
75
  if vocab is not None:
@@ -91,6 +91,7 @@ class ChessTokenizer(PreTrainedTokenizer):
91
  bos_token=self._bos_token,
92
  eos_token=self._eos_token,
93
  unk_token=self._unk_token,
 
94
  **kwargs,
95
  )
96
 
@@ -101,48 +102,10 @@ class ChessTokenizer(PreTrainedTokenizer):
101
  For the full vocabulary, use `build_vocab_from_dataset()`.
102
  This minimal vocab is just a placeholder - you should build from data.
103
  """
104
- special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
105
  vocab = {token: idx for idx, token in enumerate(special_tokens)}
106
  return vocab
107
 
108
- @classmethod
109
- def build_vocab_from_iterator(
110
- cls,
111
- iterator,
112
- min_frequency: int = 1,
113
- ) -> "ChessTokenizer":
114
- """
115
- Build a tokenizer vocabulary from an iterator of game strings.
116
-
117
- Args:
118
- iterator: An iterator yielding game strings (space-separated moves).
119
- min_frequency: Minimum frequency for a token to be included.
120
-
121
- Returns:
122
- A ChessTokenizer with the built vocabulary.
123
- """
124
- from collections import Counter
125
-
126
- token_counts = Counter()
127
-
128
- for game in iterator:
129
- moves = game.strip().split()
130
- token_counts.update(moves)
131
-
132
- # Filter by frequency
133
- tokens = [
134
- token for token, count in token_counts.items()
135
- if count >= min_frequency
136
- ]
137
-
138
- # Sort for reproducibility
139
- tokens = sorted(tokens)
140
-
141
- # Build vocabulary
142
- special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
143
- vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
144
-
145
- return cls(vocab=vocab)
146
 
147
  @classmethod
148
  def build_vocab_from_dataset(
@@ -150,8 +113,7 @@ class ChessTokenizer(PreTrainedTokenizer):
150
  dataset_name: str = "dlouapre/lichess_2025-01_1M",
151
  split: str = "train",
152
  column: str = "text",
153
- min_frequency: int = 500,
154
- max_samples: Optional[int] = 100000,
155
  ) -> "ChessTokenizer":
156
  """
157
  Build a tokenizer vocabulary from a Hugging Face dataset.
@@ -160,24 +122,101 @@ class ChessTokenizer(PreTrainedTokenizer):
160
  dataset_name: Name of the dataset on Hugging Face Hub.
161
  split: Dataset split to use.
162
  column: Column containing the game strings.
163
- min_frequency: Minimum frequency for a token to be included (default: 500).
164
- max_samples: Maximum number of samples to process (default: 100k).
165
 
166
  Returns:
167
  A ChessTokenizer with the built vocabulary.
 
 
 
 
168
  """
169
  from datasets import load_dataset
170
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  dataset = load_dataset(dataset_name, split=split)
172
-
173
- if max_samples is not None:
174
- dataset = dataset.select(range(min(max_samples, len(dataset))))
175
-
176
- def game_iterator():
177
- for example in dataset:
178
- yield example[column]
179
-
180
- return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  @property
183
  def vocab_size(self) -> int:
@@ -198,7 +237,34 @@ class ChessTokenizer(PreTrainedTokenizer):
198
  Returns:
199
  List of move tokens.
200
  """
201
- return text.strip().split()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  def _convert_token_to_id(self, token: str) -> int:
204
  """Convert a token to its ID."""
@@ -213,7 +279,16 @@ class ChessTokenizer(PreTrainedTokenizer):
213
  # Filter out special tokens for cleaner output
214
  special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
215
  return " ".join(t for t in tokens if t not in special)
216
-
 
 
 
 
 
 
 
 
 
217
  def save_vocabulary(
218
  self,
219
  save_directory: str,
@@ -242,6 +317,80 @@ class ChessTokenizer(PreTrainedTokenizer):
242
 
243
  return (vocab_file,)
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  def count_vocab_from_dataset(
247
  dataset_name: str = "dlouapre/lichess_2025-01_1M",
@@ -269,10 +418,10 @@ def count_vocab_from_dataset(
269
  if max_samples is not None:
270
  dataset = dataset.select(range(min(max_samples, len(dataset))))
271
 
 
272
  token_counts = Counter()
273
 
274
  for example in dataset:
275
- moves = example[column].strip().split()
276
- token_counts.update(moves)
277
 
278
- return dict(token_counts)
 
1
  """
2
  Custom Chess Tokenizer for the Chess Challenge.
3
 
4
+ We build a vocabulary with:
 
 
 
5
  - W/B prefix for White/Black
6
  - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
7
+ - Source square file and rank: e.g. e 2
8
+ - Destination square file and rank: e.g. e 4
9
  - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
10
+
11
  """
12
 
13
  from __future__ import annotations
 
15
  import json
16
  import os
17
  from pathlib import Path
18
+ import shutil
19
+ import inspect
20
  from typing import Dict, List, Optional
21
 
22
  from transformers import PreTrainedTokenizer
 
24
 
25
  class ChessTokenizer(PreTrainedTokenizer):
26
  """
27
+ A custom tokenizer for chess moves.
 
 
 
 
28
 
29
  Example:
30
  >>> tokenizer = ChessTokenizer()
31
  >>> tokenizer.encode("WPe2e4 BPe7e5")
32
+ # [BOS, W, P, e, 2, e, 4, B, P, e, 7, e, 5, EOS]
33
  """
34
 
35
  model_input_names = ["input_ids", "attention_mask"]
 
40
  BOS_TOKEN = "[BOS]"
41
  EOS_TOKEN = "[EOS]"
42
  UNK_TOKEN = "[UNK]"
43
+ SEP_TOKEN = "[SEP]"
44
 
45
  def __init__(
46
  self,
 
61
  self._bos_token = self.BOS_TOKEN
62
  self._eos_token = self.EOS_TOKEN
63
  self._unk_token = self.UNK_TOKEN
64
+ self._sep_token = self.SEP_TOKEN
65
 
66
  # Remove any duplicate special-token entries passed through kwargs
67
  # to avoid "multiple values for keyword" errors when loading from disk.
 
69
  kwargs.pop("bos_token", None)
70
  kwargs.pop("eos_token", None)
71
  kwargs.pop("unk_token", None)
72
+ kwargs.pop("sep_token", None)
73
 
74
  # Load or create vocabulary
75
  if vocab is not None:
 
91
  bos_token=self._bos_token,
92
  eos_token=self._eos_token,
93
  unk_token=self._unk_token,
94
+ sep_token=self._sep_token,
95
  **kwargs,
96
  )
97
 
 
102
  For the full vocabulary, use `build_vocab_from_dataset()`.
103
  This minimal vocab is just a placeholder - you should build from data.
104
  """
105
+ special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN, self.SEP_TOKEN]
106
  vocab = {token: idx for idx, token in enumerate(special_tokens)}
107
  return vocab
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  @classmethod
111
  def build_vocab_from_dataset(
 
113
  dataset_name: str = "dlouapre/lichess_2025-01_1M",
114
  split: str = "train",
115
  column: str = "text",
116
+ save_path: Optional[str] = None,
 
117
  ) -> "ChessTokenizer":
118
  """
119
  Build a tokenizer vocabulary from a Hugging Face dataset.
 
122
  dataset_name: Name of the dataset on Hugging Face Hub.
123
  split: Dataset split to use.
124
  column: Column containing the game strings.
 
 
125
 
126
  Returns:
127
  A ChessTokenizer with the built vocabulary.
128
+
129
+ Args:
130
+ save_path: Optional path to write the generated vocab JSON. If not
131
+ provided, the vocab will be saved to ``./chess_tokenizer_vocab.json``.
132
  """
133
  from datasets import load_dataset
134
+
135
+ # If a saved vocab exists at `save_path`, load it and return a tokenizer
136
+ if save_path is None:
137
+ cwd = os.getcwd()
138
+ save_path = os.path.join(cwd, "chess_tokenizer_vocab.json")
139
+
140
+ if os.path.exists(save_path):
141
+ try:
142
+ with open(save_path, "r", encoding="utf-8") as f:
143
+ print("Loading existing tokenizer vocab from", save_path)
144
+ vocab = json.load(f)
145
+ return cls(vocab=vocab)
146
+ except Exception:
147
+ # If loading fails, fall through to rebuild the vocab.
148
+ pass
149
+
150
  dataset = load_dataset(dataset_name, split=split)
151
+
152
+ # Iterator over games (respect max_samples if provided)
153
+ samples = dataset[column]
154
+
155
+ tokens = set()
156
+
157
+ for game in samples:
158
+ if not isinstance(game, str):
159
+ continue
160
+ moves = game.strip().split()
161
+ for move in moves:
162
+ # Basic parsing of move token components
163
+ if len(move) < 2:
164
+ continue
165
+ color = move[0]
166
+ piece = move[1]
167
+ from_square = move[2:4] if len(move) >= 4 else ''
168
+ to_square = move[4:6] if len(move) >= 6 else ''
169
+ suffix = move[6:] if len(move) > 6 else ''
170
+
171
+ tokens.add(color)
172
+ tokens.add(piece)
173
+ tokens.add(from_square)
174
+ tokens.add(to_square)
175
+ if suffix:
176
+ tokens.add(suffix)
177
+
178
+ # Sort tokens
179
+ tokens = sorted(tokens)
180
+
181
+ # Ensure special tokens are present at fixed ids
182
+ special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN, cls.SEP_TOKEN]
183
+
184
+ # Build vocab mapping: special tokens first, then tokens
185
+ vocab: Dict[str, int] = {}
186
+ idx = 0
187
+ for st in special_tokens:
188
+ vocab[st] = idx
189
+ idx += 1
190
+
191
+ for t in tokens:
192
+ if t in vocab:
193
+ continue
194
+ vocab[t] = idx
195
+ idx += 1
196
+
197
+ # Create tokenizer instance with this vocab
198
+ tokenizer = cls(vocab=vocab)
199
+
200
+ # Save vocab to disk. Use provided `save_path` or default file name.
201
+ try:
202
+ if save_path is None:
203
+ cwd = os.getcwd()
204
+ save_path = os.path.join(cwd, "chess_tokenizer_vocab.json")
205
+
206
+ # Write to a temporary file first and atomically replace final file.
207
+ tmp_path = save_path + ".tmp"
208
+ with open(tmp_path, "w", encoding="utf-8") as f:
209
+ json.dump(vocab, f, ensure_ascii=False, indent=2)
210
+ os.replace(tmp_path, save_path)
211
+ except Exception:
212
+ # Non-fatal: ignore save errors but don't leave temp files behind.
213
+ try:
214
+ if 'tmp_path' in locals() and os.path.exists(tmp_path):
215
+ os.remove(tmp_path)
216
+ except Exception:
217
+ pass
218
+
219
+ return tokenizer
220
 
221
  @property
222
  def vocab_size(self) -> int:
 
237
  Returns:
238
  List of move tokens.
239
  """
240
+ tokens: List[str] = []
241
+ for move in text.strip().split():
242
+ if len(move) < 2:
243
+ continue
244
+ color, piece, from_square, to_square, suffix = self._decompose_move(move)
245
+ tokens.append(color)
246
+ tokens.append(piece)
247
+ tokens.append(from_square)
248
+ tokens.append(to_square)
249
+ if suffix:
250
+ tokens.append(suffix)
251
+
252
+ tokens.append(self._sep_token)
253
+
254
+ return tokens[:-1] # Remove last SEP token
255
+
256
+ @staticmethod
257
+ def _decompose_move(move: str):
258
+ """Decompose a move string into components: color, piece, from_square, to_square, suffix.
259
+
260
+ Returns a 5-tuple of strings (empty strings for missing parts).
261
+ """
262
+ color = move[0]
263
+ piece = move[1] if len(move) >= 2 else ''
264
+ from_square = move[2:4] if len(move) >= 4 else ''
265
+ to_square = move[4:6] if len(move) >= 6 else ''
266
+ suffix = move[6:] if len(move) > 6 else ''
267
+ return color, piece, from_square, to_square, suffix
268
 
269
  def _convert_token_to_id(self, token: str) -> int:
270
  """Convert a token to its ID."""
 
279
  # Filter out special tokens for cleaner output
280
  special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
281
  return " ".join(t for t in tokens if t not in special)
282
+
283
+ def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
284
+ """Decode a list of token IDs back to a string."""
285
+ tokens = [self._convert_id_to_token(int(tid)) for tid in token_ids]
286
+ if skip_special_tokens:
287
+ special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
288
+ # SEP token should be replace by space
289
+ tokens = [t if t != self.SEP_TOKEN else " " for t in tokens if t not in special]
290
+ return "".join(tokens)
291
+
292
  def save_vocabulary(
293
  self,
294
  save_directory: str,
 
317
 
318
  return (vocab_file,)
319
 
320
+ def save_pretrained(
321
+ self,
322
+ save_directory: str,
323
+ filename_prefix: Optional[str] = None,
324
+ save_tokenizer_code: bool = True,
325
+ ) -> None:
326
+ """Save tokenizer files to a directory in a HF-compatible layout.
327
+
328
+ This writes the vocab JSON (via `save_vocabulary`), a small
329
+ `tokenizer_config.json` describing special tokens and the vocab
330
+ filename, and optionally copies the tokenizer module source file
331
+ into the directory so others can import the implementation.
332
+ """
333
+ if not os.path.isdir(save_directory):
334
+ os.makedirs(save_directory, exist_ok=True)
335
+
336
+ # Save the vocabulary file
337
+ vocab_file_tuple = self.save_vocabulary(save_directory, filename_prefix)
338
+ vocab_file = vocab_file_tuple[0]
339
+
340
+ # Write a minimal tokenizer config
341
+ config = {
342
+ "tokenizer_class": self.__class__.__name__,
343
+ "vocab_file": os.path.basename(vocab_file),
344
+ "pad_token": self.PAD_TOKEN,
345
+ "bos_token": self.BOS_TOKEN,
346
+ "eos_token": self.EOS_TOKEN,
347
+ "unk_token": self.UNK_TOKEN,
348
+ }
349
+ config_path = os.path.join(save_directory, "tokenizer_config.json")
350
+ with open(config_path, "w", encoding="utf-8") as f:
351
+ json.dump(config, f, ensure_ascii=False, indent=2)
352
+
353
+ # Optionally copy this module file so the tokenizer class implementation
354
+ # is available alongside the saved vocab/config. This helps when
355
+ # transferring the saved tokenizer to another environment.
356
+ if save_tokenizer_code:
357
+ try:
358
+ src_file = Path(inspect.getsourcefile(self.__class__))
359
+ dst_file = Path(save_directory) / src_file.name
360
+ shutil.copy2(src_file, dst_file)
361
+ except Exception:
362
+ # Non-fatal; we still saved vocab and config
363
+ pass
364
+
365
+ @classmethod
366
+ def from_pretrained(cls, load_directory: str) -> "ChessTokenizer":
367
+ """Load tokenizer from a directory previously written with `save_pretrained`.
368
+
369
+ This primarily reads the vocab file and constructs the tokenizer.
370
+ If a `tokenizer_config.json` exists it will be consulted for the
371
+ vocab filename and special tokens (but we still instantiate using
372
+ the provided class).
373
+ """
374
+ config_path = os.path.join(load_directory, "tokenizer_config.json")
375
+ vocab_file = None
376
+ if os.path.exists(config_path):
377
+ try:
378
+ with open(config_path, "r", encoding="utf-8") as f:
379
+ cfg = json.load(f)
380
+ vocab_file = os.path.join(load_directory, cfg.get("vocab_file", "vocab.json"))
381
+ except Exception:
382
+ pass
383
+
384
+ if vocab_file is None:
385
+ # Fallback: look for a vocab file in the directory
386
+ candidates = [p for p in os.listdir(load_directory) if p.endswith("vocab.json")]
387
+ if candidates:
388
+ vocab_file = os.path.join(load_directory, candidates[0])
389
+
390
+ if vocab_file is None or not os.path.exists(vocab_file):
391
+ raise FileNotFoundError(f"No vocab file found in {load_directory}")
392
+
393
+ return cls(vocab_file=vocab_file)
394
 
395
  def count_vocab_from_dataset(
396
  dataset_name: str = "dlouapre/lichess_2025-01_1M",
 
418
  if max_samples is not None:
419
  dataset = dataset.select(range(min(max_samples, len(dataset))))
420
 
421
+ tokenizer = ChessTokenizer()
422
  token_counts = Counter()
423
 
424
  for example in dataset:
425
+ token_counts.update(tokenizer._tokenize(example[column]))
 
426
 
427
+ return dict(token_counts)