Sunxt25 commited on
Commit
2ab9c78
·
verified ·
1 Parent(s): 09b738c

Upload tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +84 -237
tokenizer.py CHANGED
@@ -1,278 +1,125 @@
1
- """
2
- Custom Chess Tokenizer for the Chess Challenge.
3
-
4
- This tokenizer treats each move as a single token using the extended UCI notation
5
- from the Lichess dataset (e.g., WPe2e4, BNg8f6).
6
-
7
- The dataset format uses:
8
- - W/B prefix for White/Black
9
- - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
- - Source and destination squares (e.g., e2e4)
11
- - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
- """
13
-
14
  from __future__ import annotations
15
-
16
  import json
17
  import os
18
- from pathlib import Path
19
  from typing import Dict, List, Optional
20
-
21
  from transformers import PreTrainedTokenizer
22
-
23
 
24
  class ChessTokenizer(PreTrainedTokenizer):
25
  """
26
- A custom tokenizer for chess moves using extended UCI notation.
27
-
28
- This tokenizer maps each possible chess move to a unique token ID.
29
- The vocabulary is built from the training dataset to ensure all moves
30
- encountered during training have a corresponding token.
31
-
32
- Example:
33
- >>> tokenizer = ChessTokenizer()
34
- >>> tokenizer.encode("WPe2e4 BPe7e5")
35
- [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
36
  """
37
-
38
  model_input_names = ["input_ids", "attention_mask"]
39
  vocab_files_names = {"vocab_file": "vocab.json"}
40
-
41
- # Special tokens
42
  PAD_TOKEN = "[PAD]"
43
  BOS_TOKEN = "[BOS]"
44
  EOS_TOKEN = "[EOS]"
45
  UNK_TOKEN = "[UNK]"
46
-
47
- def __init__(
48
- self,
49
- vocab_file: Optional[str] = None,
50
- vocab: Optional[Dict[str, int]] = None,
51
- **kwargs,
52
- ):
53
- """
54
- Initialize the chess tokenizer.
55
-
56
- Args:
57
- vocab_file: Path to a JSON file containing the vocabulary mapping.
58
- vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
59
- **kwargs: Additional arguments passed to PreTrainedTokenizer.
60
- """
61
- # Initialize special tokens
62
- self._pad_token = self.PAD_TOKEN
63
- self._bos_token = self.BOS_TOKEN
64
- self._eos_token = self.EOS_TOKEN
65
- self._unk_token = self.UNK_TOKEN
66
 
67
- # Remove any duplicate special-token entries passed through kwargs
68
- # to avoid "multiple values for keyword" errors when loading from disk.
69
- kwargs.pop("pad_token", None)
70
- kwargs.pop("bos_token", None)
71
- kwargs.pop("eos_token", None)
72
- kwargs.pop("unk_token", None)
73
 
74
- # Load or create vocabulary
 
 
 
75
  if vocab is not None:
76
  self._vocab = vocab
77
  elif vocab_file is not None and os.path.exists(vocab_file):
78
  with open(vocab_file, "r", encoding="utf-8") as f:
79
  self._vocab = json.load(f)
80
  else:
81
- # Create a minimal vocabulary with just special tokens
82
- # The full vocabulary should be built from the dataset
83
- self._vocab = self._create_default_vocab()
84
-
85
- # Create reverse mapping
 
 
 
 
 
86
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
87
-
88
- # Call parent init AFTER setting up vocab
89
  super().__init__(
90
- pad_token=self._pad_token,
91
- bos_token=self._bos_token,
92
- eos_token=self._eos_token,
93
- unk_token=self._unk_token,
94
  **kwargs,
95
  )
96
-
97
- def _create_default_vocab(self) -> Dict[str, int]:
98
- """
99
- Create a minimal default vocabulary with just special tokens.
100
-
101
- For the full vocabulary, use `build_vocab_from_dataset()`.
102
- This minimal vocab is just a placeholder - you should build from data.
103
- """
104
- special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
105
- vocab = {token: idx for idx, token in enumerate(special_tokens)}
106
- return vocab
107
-
108
- @classmethod
109
- def build_vocab_from_iterator(
110
- cls,
111
- iterator,
112
- min_frequency: int = 1,
113
- ) -> "ChessTokenizer":
114
- """
115
- Build a tokenizer vocabulary from an iterator of game strings.
116
-
117
- Args:
118
- iterator: An iterator yielding game strings (space-separated moves).
119
- min_frequency: Minimum frequency for a token to be included.
120
-
121
- Returns:
122
- A ChessTokenizer with the built vocabulary.
123
- """
124
- from collections import Counter
125
-
126
- token_counts = Counter()
127
-
128
- for game in iterator:
129
- moves = game.strip().split()
130
- token_counts.update(moves)
131
-
132
- # Filter by frequency
133
- tokens = [
134
- token for token, count in token_counts.items()
135
- if count >= min_frequency
136
- ]
137
-
138
- # Sort for reproducibility
139
- tokens = sorted(tokens)
140
-
141
- # Build vocabulary
142
- special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
143
- vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
144
-
145
- return cls(vocab=vocab)
146
-
147
- @classmethod
148
- def build_vocab_from_dataset(
149
- cls,
150
- dataset_name: str = "dlouapre/lichess_2025-01_1M",
151
- split: str = "train",
152
- column: str = "text",
153
- min_frequency: int = 500,
154
- max_samples: Optional[int] = 100000,
155
- ) -> "ChessTokenizer":
156
- """
157
- Build a tokenizer vocabulary from a Hugging Face dataset.
158
-
159
- Args:
160
- dataset_name: Name of the dataset on Hugging Face Hub.
161
- split: Dataset split to use.
162
- column: Column containing the game strings.
163
- min_frequency: Minimum frequency for a token to be included (default: 500).
164
- max_samples: Maximum number of samples to process (default: 100k).
165
-
166
- Returns:
167
- A ChessTokenizer with the built vocabulary.
168
- """
169
- from datasets import load_dataset
170
-
171
- dataset = load_dataset(dataset_name, split=split)
172
-
173
- if max_samples is not None:
174
- dataset = dataset.select(range(min(max_samples, len(dataset))))
175
-
176
- def game_iterator():
177
- for example in dataset:
178
- yield example[column]
179
-
180
- return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
181
-
182
  @property
183
  def vocab_size(self) -> int:
184
- """Return the size of the vocabulary."""
185
  return len(self._vocab)
186
-
187
  def get_vocab(self) -> Dict[str, int]:
188
- """Return the vocabulary as a dictionary."""
189
  return dict(self._vocab)
190
-
191
  def _tokenize(self, text: str) -> List[str]:
192
- """
193
- Tokenize a string of moves into a list of tokens.
194
-
195
- Args:
196
- text: A string of space-separated moves.
197
-
198
- Returns:
199
- List of move tokens.
200
- """
201
- return text.strip().split()
202
-
203
- def _convert_token_to_id(self, token: str) -> int:
204
- """Convert a token to its ID."""
205
- return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
206
-
207
  def _convert_id_to_token(self, index: int) -> str:
208
- """Convert an ID to its token."""
209
- return self._ids_to_tokens.get(index, self.UNK_TOKEN)
210
-
 
 
 
211
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
212
- """Convert a list of tokens back to a string."""
213
- # Filter out special tokens for cleaner output
214
- special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
215
- return " ".join(t for t in tokens if t not in special)
216
-
217
- def save_vocabulary(
218
- self,
219
- save_directory: str,
220
- filename_prefix: Optional[str] = None,
221
- ) -> tuple:
222
- """
223
- Save the vocabulary to a JSON file.
224
-
225
- Args:
226
- save_directory: Directory to save the vocabulary.
227
- filename_prefix: Optional prefix for the filename.
228
-
229
- Returns:
230
- Tuple containing the path to the saved vocabulary file.
231
- """
 
 
 
 
 
232
  if not os.path.isdir(save_directory):
233
  os.makedirs(save_directory, exist_ok=True)
234
-
235
- vocab_file = os.path.join(
236
- save_directory,
237
- (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
238
- )
239
-
240
  with open(vocab_file, "w", encoding="utf-8") as f:
241
  json.dump(self._vocab, f, ensure_ascii=False, indent=2)
242
-
243
  return (vocab_file,)
244
 
245
-
246
- def count_vocab_from_dataset(
247
- dataset_name: str = "dlouapre/lichess_2025-01_1M",
248
- split: str = "train",
249
- column: str = "text",
250
- max_samples: Optional[int] = 10000,
251
- ) -> Dict[str, int]:
252
- """
253
- Count token frequencies in a dataset (useful for vocabulary analysis).
254
-
255
- Args:
256
- dataset_name: Name of the dataset on Hugging Face Hub.
257
- split: Dataset split to use.
258
- column: Column containing the game strings.
259
- max_samples: Maximum number of samples to process.
260
-
261
- Returns:
262
- Dictionary mapping tokens to their frequencies.
263
- """
264
- from collections import Counter
265
- from datasets import load_dataset
266
-
267
- dataset = load_dataset(dataset_name, split=split)
268
-
269
- if max_samples is not None:
270
- dataset = dataset.select(range(min(max_samples, len(dataset))))
271
-
272
- token_counts = Counter()
273
-
274
- for example in dataset:
275
- moves = example[column].strip().split()
276
- token_counts.update(moves)
277
-
278
- return dict(token_counts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
 
2
  import json
3
  import os
 
4
  from typing import Dict, List, Optional
 
5
  from transformers import PreTrainedTokenizer
6
+ import torch
7
 
8
  class ChessTokenizer(PreTrainedTokenizer):
9
  """
10
+ 符合评估脚本要求的 Chess Tokenizer。
11
+ 词表大小: 149 (4 special + 12 pieces + 64 from_sq + 64 to_sq + 5 suffix)
 
 
 
 
 
 
 
 
12
  """
13
+
14
  model_input_names = ["input_ids", "attention_mask"]
15
  vocab_files_names = {"vocab_file": "vocab.json"}
16
+
 
17
  PAD_TOKEN = "[PAD]"
18
  BOS_TOKEN = "[BOS]"
19
  EOS_TOKEN = "[EOS]"
20
  UNK_TOKEN = "[UNK]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
23
+ special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
 
 
 
 
24
 
25
+ self.colors_pieces = [f'{c}{p}' for c in ['W','B'] for p in ['P','N','B','R','Q','K']]
26
+ self.squares = [f'{f}{r}' for r in '12345678' for f in 'abcdefgh']
27
+ self.suffixes = ["(x)", "(+)", "(+*)", "(o)", "(O)"]
28
+
29
  if vocab is not None:
30
  self._vocab = vocab
31
  elif vocab_file is not None and os.path.exists(vocab_file):
32
  with open(vocab_file, "r", encoding="utf-8") as f:
33
  self._vocab = json.load(f)
34
  else:
35
+ self._vocab = {t: i for i, t in enumerate(special_tokens)}
36
+ for cp in self.colors_pieces:
37
+ self._vocab[cp] = len(self._vocab)
38
+ for sq in self.squares:
39
+ self._vocab[f"{sq}_f"] = len(self._vocab)
40
+ for sq in self.squares:
41
+ self._vocab[f"{sq}_t"] = len(self._vocab)
42
+ for suf in self.suffixes:
43
+ self._vocab[suf] = len(self._vocab)
44
+
45
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
46
+
 
47
  super().__init__(
48
+ pad_token=self.PAD_TOKEN,
49
+ bos_token=self.BOS_TOKEN,
50
+ eos_token=self.EOS_TOKEN,
51
+ unk_token=self.UNK_TOKEN,
52
  **kwargs,
53
  )
54
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @property
56
  def vocab_size(self) -> int:
 
57
  return len(self._vocab)
58
+
59
  def get_vocab(self) -> Dict[str, int]:
 
60
  return dict(self._vocab)
61
+
62
  def _tokenize(self, text: str) -> List[str]:
63
+ """关键:支持识别带后缀的 token,让 eval 识别为 decomposed 模式"""
64
+ tokens = []
65
+ parts = text.strip().split()
66
+ for part in parts:
67
+ if part in self._vocab:
68
+ tokens.append(part)
69
+ elif len(part) >= 6: # 处理 WPe2e4 紧凑格式
70
+ piece, f_sq, t_sq = part[:2], part[2:4] + "_f", part[4:6] + "_t"
71
+ if piece in self._vocab: tokens.append(piece)
72
+ if f_sq in self._vocab: tokens.append(f_sq)
73
+ if t_sq in self._vocab: tokens.append(t_sq)
74
+ if len(part) > 6 and part[6:] in self.suffixes:
75
+ tokens.append(part[6:])
76
+ return tokens
77
+
78
  def _convert_id_to_token(self, index: int) -> str:
79
+ """关键:去掉后缀,让 eval 的正则 [a-h][1-8] 能抓到坐标"""
80
+ token = self._ids_to_tokens.get(index, self.UNK_TOKEN)
81
+ if token in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]:
82
+ return ""
83
+ return token.replace("_f", "").replace("_t", "")
84
+
85
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
86
+ """关键:在 Piece 前加空格,确保历史棋局格式正确"""
87
+ res = []
88
+ for t in tokens:
89
+ if not t: continue
90
+ # 如果是棋子 token,说明是新 move,加空格
91
+ if len(t) == 2 and (t.startswith('W') or t.startswith('B')):
92
+ res.append(" " + t)
93
+ else:
94
+ res.append(t)
95
+ return "".join(res).strip()
96
+ def _convert_token_to_id(self, token: str) -> int:
97
+ return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))
98
+
99
+ def decode(self, token_ids, skip_special_tokens=True, **kwargs) -> str:
100
+ if hasattr(token_ids, "tolist"):
101
+ ids = token_ids.tolist()
102
+ elif isinstance(token_ids, (int, torch.LongTensor, torch.IntTensor)):
103
+ ids = [int(token_ids)] if isinstance(token_ids, int) else token_ids.tolist()
104
+ else:
105
+ ids = token_ids
106
+
107
+ tokens = [self._convert_id_to_token(i) for i in ids]
108
+ return self.convert_tokens_to_string(tokens)
109
+
110
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
111
  if not os.path.isdir(save_directory):
112
  os.makedirs(save_directory, exist_ok=True)
113
+ vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
 
 
 
 
 
114
  with open(vocab_file, "w", encoding="utf-8") as f:
115
  json.dump(self._vocab, f, ensure_ascii=False, indent=2)
 
116
  return (vocab_file,)
117
 
118
+ @classmethod
119
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "ChessTokenizer":
120
+ vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
121
+ if not os.path.exists(vocab_file):
122
+ return cls()
123
+ with open(vocab_file, "r", encoding="utf-8") as f:
124
+ vocab = json.load(f)
125
+ return cls(vocab=vocab, **kwargs)