Sunxt25 commited on
Commit
21ed1fa
·
verified ·
1 Parent(s): b7dc9f4

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +154 -25
tokenizer.py CHANGED
@@ -1,55 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
 
2
  import json
3
  import os
4
- import re
5
  from typing import Dict, List, Optional
 
6
  from transformers import PreTrainedTokenizer
7
- import torch
8
 
9
  class ChessTokenizer(PreTrainedTokenizer):
10
  """
11
- vocab size: 4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix + 4 promotions (qrbn)
 
 
 
 
 
 
 
 
 
12
  """
13
-
14
  model_input_names = ["input_ids", "attention_mask"]
15
  vocab_files_names = {"vocab_file": "vocab.json"}
16
-
 
17
  PAD_TOKEN = "[PAD]"
18
  BOS_TOKEN = "[BOS]"
19
  EOS_TOKEN = "[EOS]"
20
  UNK_TOKEN = "[UNK]"
21
-
22
- def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
23
- special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
 
 
 
 
 
 
24
 
25
- self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
26
- self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
27
- self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
28
- self.promotions = ["q", "r", "b", "n"] # promotion Token
 
 
 
 
 
 
29
 
 
 
 
 
 
 
 
 
30
  if vocab is not None:
31
  self._vocab = vocab
32
  elif vocab_file is not None and os.path.exists(vocab_file):
33
  with open(vocab_file, "r", encoding="utf-8") as f:
34
  self._vocab = json.load(f)
35
  else:
36
- self._vocab = {t: i for i, t in enumerate(special_tokens)}
37
- for cp in self.colors_pieces: self._vocab[cp] = len(self._vocab)
38
- for sq in self.squares: self._vocab[f"{sq}_f"] = len(self._vocab)
39
- for sq in self.squares: self._vocab[f"{sq}_t"] = len(self._vocab)
40
- for suf in self.suffixes: self._vocab[suf] = len(self._vocab)
41
- for promo in self.promotions: self._vocab[promo] = len(self._vocab)
42
-
43
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
44
-
 
45
  super().__init__(
46
- pad_token=self.PAD_TOKEN,
47
- bos_token=self.BOS_TOKEN,
48
- eos_token=self.EOS_TOKEN,
49
- unk_token=self.UNK_TOKEN,
50
  **kwargs,
51
  )
52
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @property
54
  def vocab_size(self) -> int:
55
  """Return the size of the vocabulary."""
 
1
+ """
2
+ Custom Chess Tokenizer for the Chess Challenge.
3
+
4
+ This tokenizer treats each move as a single token using the extended UCI notation
5
+ from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
+
7
+ The dataset format uses:
8
+ - W/B prefix for White/Black
9
+ - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
+ - Source and destination squares (e.g., e2e4)
11
+ - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
+ """
13
+
14
  from __future__ import annotations
15
+
16
  import json
17
  import os
18
+ from pathlib import Path
19
  from typing import Dict, List, Optional
20
+
21
  from transformers import PreTrainedTokenizer
22
+
23
 
24
  class ChessTokenizer(PreTrainedTokenizer):
25
  """
26
+ A custom tokenizer for chess moves using extended UCI notation.
27
+
28
+ This tokenizer maps each possible chess move to a unique token ID.
29
+ The vocabulary is built from the training dataset to ensure all moves
30
+ encountered during training have a corresponding token.
31
+
32
+ Example:
33
+ >>> tokenizer = ChessTokenizer()
34
+ >>> tokenizer.encode("WPe2e4 BPe7e5")
35
+ [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
36
  """
37
+
38
  model_input_names = ["input_ids", "attention_mask"]
39
  vocab_files_names = {"vocab_file": "vocab.json"}
40
+
41
+ # Special tokens
42
  PAD_TOKEN = "[PAD]"
43
  BOS_TOKEN = "[BOS]"
44
  EOS_TOKEN = "[EOS]"
45
  UNK_TOKEN = "[UNK]"
46
+
47
+ def __init__(
48
+ self,
49
+ vocab_file: Optional[str] = None,
50
+ vocab: Optional[Dict[str, int]] = None,
51
+ **kwargs,
52
+ ):
53
+ """
54
+ Initialize the chess tokenizer.
55
 
56
+ Args:
57
+ vocab_file: Path to a JSON file containing the vocabulary mapping.
58
+ vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
59
+ **kwargs: Additional arguments passed to PreTrainedTokenizer.
60
+ """
61
+ # Initialize special tokens
62
+ self._pad_token = self.PAD_TOKEN
63
+ self._bos_token = self.BOS_TOKEN
64
+ self._eos_token = self.EOS_TOKEN
65
+ self._unk_token = self.UNK_TOKEN
66
 
67
+ # Remove any duplicate special-token entries passed through kwargs
68
+ # to avoid "multiple values for keyword" errors when loading from disk.
69
+ kwargs.pop("pad_token", None)
70
+ kwargs.pop("bos_token", None)
71
+ kwargs.pop("eos_token", None)
72
+ kwargs.pop("unk_token", None)
73
+
74
+ # Load or create vocabulary
75
  if vocab is not None:
76
  self._vocab = vocab
77
  elif vocab_file is not None and os.path.exists(vocab_file):
78
  with open(vocab_file, "r", encoding="utf-8") as f:
79
  self._vocab = json.load(f)
80
  else:
81
+ # Create a minimal vocabulary with just special tokens
82
+ # The full vocabulary should be built from the dataset
83
+ self._vocab = self._create_default_vocab()
84
+
85
+ # Create reverse mapping
 
 
86
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
87
+
88
+ # Call parent init AFTER setting up vocab
89
  super().__init__(
90
+ pad_token=self._pad_token,
91
+ bos_token=self._bos_token,
92
+ eos_token=self._eos_token,
93
+ unk_token=self._unk_token,
94
  **kwargs,
95
  )
96
+
97
+ def _create_default_vocab(self) -> Dict[str, int]:
98
+ """
99
+ Create a minimal default vocabulary with just special tokens.
100
+
101
+ For the full vocabulary, use `build_vocab_from_dataset()`.
102
+ This minimal vocab is just a placeholder - you should build from data.
103
+ """
104
+ special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
105
+ vocab = {token: idx for idx, token in enumerate(special_tokens)}
106
+ return vocab
107
+
108
+ @classmethod
109
+ def build_vocab_from_iterator(
110
+ cls,
111
+ iterator,
112
+ min_frequency: int = 1,
113
+ ) -> "ChessTokenizer":
114
+ """
115
+ Build a tokenizer vocabulary from an iterator of game strings.
116
+
117
+ Args:
118
+ iterator: An iterator yielding game strings (space-separated moves).
119
+ min_frequency: Minimum frequency for a token to be included.
120
+
121
+ Returns:
122
+ A ChessTokenizer with the built vocabulary.
123
+ """
124
+ from collections import Counter
125
+
126
+ token_counts = Counter()
127
+
128
+ for game in iterator:
129
+ moves = game.strip().split()
130
+ token_counts.update(moves)
131
+
132
+ # Filter by frequency
133
+ tokens = [
134
+ token for token, count in token_counts.items()
135
+ if count >= min_frequency
136
+ ]
137
+
138
+ # Sort for reproducibility
139
+ tokens = sorted(tokens)
140
+
141
+ # Build vocabulary
142
+ special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
143
+ vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
144
+
145
+ return cls(vocab=vocab)
146
+
147
+ @classmethod
148
+ def build_vocab_from_dataset(
149
+ cls,
150
+ dataset_name: str = "dlouapre/lichess_2025-01_1M",
151
+ split: str = "train",
152
+ column: str = "text",
153
+ min_frequency: int = 500,
154
+ max_samples: Optional[int] = 100000,
155
+ ) -> "ChessTokenizer":
156
+ """
157
+ Build a tokenizer vocabulary from a Hugging Face dataset.
158
+
159
+ Args:
160
+ dataset_name: Name of the dataset on Hugging Face Hub.
161
+ split: Dataset split to use.
162
+ column: Column containing the game strings.
163
+ min_frequency: Minimum frequency for a token to be included (default: 500).
164
+ max_samples: Maximum number of samples to process (default: 100k).
165
+
166
+ Returns:
167
+ A ChessTokenizer with the built vocabulary.
168
+ """
169
+ from datasets import load_dataset
170
+
171
+ dataset = load_dataset(dataset_name, split=split)
172
+
173
+ if max_samples is not None:
174
+ dataset = dataset.select(range(min(max_samples, len(dataset))))
175
+
176
+ def game_iterator():
177
+ for example in dataset:
178
+ yield example[column]
179
+
180
+ return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
181
+
182
  @property
183
  def vocab_size(self) -> int:
184
  """Return the size of the vocabulary."""