gbl1357 committed on
Commit
cf87091
·
verified ·
1 Parent(s): 4397aa6

Update tokenizer.py

Browse files
Files changed (1) hide show
  1. tokenizer.py +215 -0
tokenizer.py CHANGED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/tokenizer.py
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ from transformers import PreTrainedTokenizer
9
+
10
# --- Fixed vocab pieces ---
# All 64 board squares in rank-major order: a1..h1, a2..h2, ..., a8..h8.
_SQUARES = [f"{col}{row}" for row in "12345678" for col in "abcdefgh"]
# The four promotion suffixes in UCI-like "=piece" form.
_PROMOS = ["=" + piece for piece in "QRBN"]
13
+
14
+
15
class SquaresOnlyChessTokenizer(PreTrainedTokenizer):
    """Move-level chess tokenizer that minimizes illegal-move formatting issues
    under the provided evaluate.py, WITHOUT modifying evaluate.py.

    Key constraints imposed by the evaluator:
      - it extracts a UCI move as ``move_token[2:4] + move_token[4:6]``, so every
        decoded move must look like ``"W" + <any char> + from_sq + to_sq [+ "=Q/R/B/N"]``,
        e.g. ``WPe2e4``, ``WNg8f6``, ``WPe7e8=Q``;
      - it stops generation on whitespace, so a literal SPACE token is used as
        the move separator.

    Encoding (per move):  from_sq, to_sq, promo?, " "   (space is a separator token)
    Decoding (per move):  "WP" + from_sq + to_sq + promo?   (constant prefix)

    Parenthesized annotation suffixes such as ``(x)``, ``(+)``, ``(+*)``,
    ``(o)``/``(O)`` are stripped because the evaluator does not use them.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    MOVE_SEP = " "  # IMPORTANT: whitespace => evaluator stops on separator

    def __init__(
        self,
        vocab: Optional[Dict[str, int]] = None,
        vocab_file: Optional[str] = None,
        **kwargs,
    ):
        """Build the tokenizer from an explicit vocab, a vocab.json file, or the
        deterministic built-in vocabulary (in that order of precedence).

        Args:
            vocab: Pre-built token -> id mapping.
            vocab_file: Path to a ``vocab.json`` written by :meth:`save_vocabulary`.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any caller-supplied
                special-token kwargs are dropped to avoid duplicates when
                loading/saving through the transformers machinery.

        Raises:
            ValueError: if a supplied vocabulary lacks the UNK token or the move
                separator (encoding cannot function without either).
        """
        # Avoid duplicate special-token kwargs on the save/load round-trip.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # NOTE(review): these private attributes are assigned BEFORE
        # super().__init__() because the base constructor may call
        # get_vocab()/_tokenize(); whether the base class still reads these
        # privates is transformers-version dependent — confirm against the
        # pinned version.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        if vocab is not None:
            self._vocab = dict(vocab)  # defensive copy; caller may mutate theirs
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._build_fixed_vocab()

        # Fail fast on malformed vocabularies. PAD/BOS/EOS can be auto-added by
        # the base class as added tokens, but without [UNK] every fallback lookup
        # raises an opaque KeyError, and without the separator token every
        # MOVE_SEP silently maps to UNK.
        missing = [t for t in (self.UNK_TOKEN, self.MOVE_SEP) if t not in self._vocab]
        if missing:
            raise ValueError(f"vocabulary is missing required tokens: {missing!r}")

        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}
        self._unk_id = self._vocab[self.UNK_TOKEN]  # cached: hot path in encoding

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    # -------------------------
    # Vocab
    # -------------------------
    @classmethod
    def _build_fixed_vocab(cls) -> Dict[str, int]:
        """Return the deterministic default vocab: specials, separator, squares, promos."""
        toks = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        toks += [cls.MOVE_SEP]
        toks += _SQUARES
        toks += _PROMOS
        return {t: i for i, t in enumerate(toks)}

    @property
    def vocab_size(self) -> int:
        """Size of the base vocabulary (excludes tokens added after init)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    # -------------------------
    # Helpers: parse / normalize
    # -------------------------
    @staticmethod
    def _strip_suffixes(token: str) -> str:
        """Drop any parenthesized annotation suffix: "(x)", "(+)", "(+*)", "(o)", "(O)"."""
        return token.split("(", 1)[0]

    @staticmethod
    def _extract_squares_and_promo(base: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Parse a normalized move string.

        ``base`` is expected to look like ``WPe2e4``, ``BNg8f6`` or ``WPe7e8=Q``.

        Returns:
            ``(from_sq, to_sq, promo)`` where ``promo`` is e.g. ``"=Q"`` or None;
            ``(None, None, None)`` when the squares cannot be recovered.
        """
        if len(base) < 6:
            return None, None, None
        from_sq = base[2:4].lower()
        to_sq = base[4:6].lower()
        if from_sq not in _SQUARES or to_sq not in _SQUARES:
            return None, None, None

        promo = None
        eq = base.find("=")
        if eq != -1:
            # BUGFIX: take exactly two characters ("=X"). The previous code used
            # the whole tail after "=", so any trailing character after the promo
            # piece (e.g. "=Q!") made the lookup fail and silently discarded the
            # promotion, yielding an illegal non-promotion move.
            candidate = base[eq:eq + 2].upper()
            if candidate in _PROMOS:
                promo = candidate
        return from_sq, to_sq, promo

    # -------------------------
    # Tokenization API
    # -------------------------
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a whitespace-separated move string.

        Special tokens pass through unchanged. Each move becomes
        ``from_sq, to_sq, promo?, MOVE_SEP``; an unparseable move becomes
        ``UNK, MOVE_SEP`` so move boundaries stay aligned.
        """
        raw = text.strip().split()
        out: List[str] = []

        for tok in raw:
            if tok in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
                out.append(tok)
                continue

            base = self._strip_suffixes(tok)
            from_sq, to_sq, promo = self._extract_squares_and_promo(base)

            if from_sq is None or to_sq is None:
                out.append(self.UNK_TOKEN)
                out.append(self.MOVE_SEP)
                continue

            out.append(from_sq)
            out.append(to_sq)
            if promo is not None:
                out.append(promo)
            out.append(self.MOVE_SEP)

        return out

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the UNK id."""
        return self._vocab.get(token, self._unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to the UNK token."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Render tokens back into text compatible with evaluate.py.

        Each move is emitted as ``"WP" + from + to + promo?`` — the constant "WP"
        prefix satisfies the evaluator's fixed ``token[2:4]``/``token[4:6]``
        slicing. Moves are separated by real spaces (the MOVE_SEP token).
        """
        s: List[str] = []
        at_move_start = True

        for tok in tokens:
            if tok in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
                continue

            if tok == self.MOVE_SEP:
                s.append(" ")
                at_move_start = True
                continue

            if tok in _PROMOS:
                s.append(tok)
                continue

            if tok in _SQUARES:
                if at_move_start:
                    s.append("WP")  # constant prefix, starts with 'W'
                    at_move_start = False
                s.append(tok)
                continue

            # Fallback for unexpected tokens (should be rare): still emit the
            # prefix so the evaluator's positional slicing stays aligned.
            if at_move_start:
                s.append("WP")
                at_move_start = False
            s.append(tok)

        return "".join(s)

    # -------------------------
    # Saving / loading
    # -------------------------
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocab to ``<prefix->vocab.json`` in *save_directory*.

        Returns:
            A 1-tuple with the written file's path (transformers convention).
        """
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (path,)
214
+
215
# Public alias: lets callers refer to the tokenizer by the shorter generic name.
ChessTokenizer = SquaresOnlyChessTokenizer