khanghoang0902 committed
Commit 1302f66 · verified · 1 Parent(s): af201b1

final-model
Files changed (4)
  1. __init__.py +0 -0
  2. tokenizer.py +321 -0
  3. tokenizer_config.json +3 -6
  4. training_args.bin +3 -0
__init__.py ADDED
File without changes
tokenizer.py ADDED
@@ -0,0 +1,321 @@
+ """
+ Factorized UCI / verbose-UCI tokenizer for the Chess1MChallenge.
+
+ Input move examples:
+ - WPe2e4
+ - BNg8f6
+ - WQd1h5(x)+
+ - WKe1g1(o)   # castling is indicated by a suffix in this dataset convention
+ - WPe7e8=Q    # promotion styles vary; we support =Q or a trailing q/r/b/n
+
+ We tokenize each move into a small sequence of tokens:
+ [SIDE] [PIECE] [FROM_SQ] [TO_SQ] (optional: [PROMO_*]) (optional: [CAPTURE]) (optional: [CHECK/MATE]) (optional: [CASTLE])
+
+ This keeps the vocabulary small and compositional.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+ import torch
+ from transformers import PreTrainedTokenizer
+
+
+ class ChessTokenizer(PreTrainedTokenizer):
+     model_input_names = ["input_ids", "attention_mask"]
+     vocab_files_names = {"vocab_file": "vocab.json"}
+
+     PAD_TOKEN = "[PAD]"
+     BOS_TOKEN = "[BOS]"
+     EOS_TOKEN = "[EOS]"
+     UNK_TOKEN = "[UNK]"
+
+     # Token prefixes (purely cosmetic; each is still a single token in the vocab)
+     SIDE_W = "SIDE_W"
+     SIDE_B = "SIDE_B"
+
+     PIECES = ["P", "N", "B", "R", "Q", "K"]
+
+     # Suffix/flag tokens
+     CAPTURE = "CAPTURE"  # (x)
+     CHECK = "CHECK"      # +
+     MATE = "MATE"        # ++ or (+*)
+     CASTLE = "CASTLE"    # (o) or (O)
+
+     PROMO_PREFIX = "PROMO_"  # PROMO_Q, PROMO_R, PROMO_B, PROMO_N
+
+     # Regex for the verbose convention:
+     #   <Side><Piece><from><to><optional_promo><optional_suffixes>
+     #   side:    W/B
+     #   piece:   P N B R Q K
+     #   from/to: [a-h][1-8]
+     #   promo:   =Q or =q or a trailing Q/q etc. (we accept several)
+     MOVE_RE = re.compile(
+         r"^(?P<side>[WB])"
+         r"(?P<piece>[PNBRQK])"
+         r"(?P<from>[a-h][1-8])"
+         r"(?P<to>[a-h][1-8])"
+         r"(?P<rest>.*)$"
+     )
+
+     def __init__(
+         self,
+         vocab_file: Optional[str] = None,
+         vocab: Optional[Dict[str, int]] = None,
+         **kwargs,
+     ):
+         # Avoid duplicate special tokens passed by HF loaders
+         kwargs.pop("pad_token", None)
+         kwargs.pop("bos_token", None)
+         kwargs.pop("eos_token", None)
+         kwargs.pop("unk_token", None)
+
+         self._pad_token = self.PAD_TOKEN
+         self._bos_token = self.BOS_TOKEN
+         self._eos_token = self.EOS_TOKEN
+         self._unk_token = self.UNK_TOKEN
+
+         if vocab is not None:
+             self._vocab = vocab
+         elif vocab_file is not None and os.path.exists(vocab_file):
+             with open(vocab_file, "r", encoding="utf-8") as f:
+                 self._vocab = json.load(f)
+         else:
+             self._vocab = self._build_fixed_vocab()
+
+         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
+
+         super().__init__(
+             pad_token=self._pad_token,
+             bos_token=self._bos_token,
+             eos_token=self._eos_token,
+             unk_token=self._unk_token,
+             **kwargs,
+         )
+
+     def _build_fixed_vocab(self) -> Dict[str, int]:
+         special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
+
+         sides = [self.SIDE_W, self.SIDE_B]
+         pieces = [f"PIECE_{p}" for p in self.PIECES]
+
+         squares = [f"SQ_{file}{rank}" for file in "abcdefgh" for rank in "12345678"]
+
+         promos = [f"{self.PROMO_PREFIX}{p}" for p in ["Q", "R", "B", "N"]]
+
+         flags = [self.CAPTURE, self.CHECK, self.MATE, self.CASTLE]
+
+         tokens = special + sides + pieces + squares + promos + flags
+         return {tok: i for i, tok in enumerate(tokens)}
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self._vocab)
+
+     def get_vocab(self) -> Dict[str, int]:
+         return dict(self._vocab)
+
+     # -------------------------
+     # Core tokenization methods
+     # -------------------------
+
+     def _tokenize(self, text: str) -> List[str]:
+         """
+         Tokenize a game string into factorized tokens.
+
+         Input is a space-separated move sequence.
+         Each move becomes multiple tokens.
+         """
+         out: List[str] = []
+         for move in text.strip().split():
+             out.extend(self._tokenize_move(move))
+         return out
+
+     def _tokenize_move(self, move: str) -> List[str]:
+         """
+         Convert one verbose-UCI move into tokens.
+         """
+         m = self.MOVE_RE.match(move)
+         if not m:
+             return [self.UNK_TOKEN]
+
+         side = m.group("side")
+         piece = m.group("piece")
+         frm = m.group("from")
+         to = m.group("to")
+         rest = m.group("rest") or ""
+
+         tokens: List[str] = []
+         tokens.append(self.SIDE_W if side == "W" else self.SIDE_B)
+         tokens.append(f"PIECE_{piece}")
+         tokens.append(f"SQ_{frm}")
+         tokens.append(f"SQ_{to}")
+
+         promo = self._parse_promotion(rest)
+         if promo is not None:
+             tokens.append(f"{self.PROMO_PREFIX}{promo}")
+
+         # Flags (order is fixed for determinism)
+         if "(x)" in rest:
+             tokens.append(self.CAPTURE)
+
+         # Checkmate conventions vary; support "++" or "(+*)" (and also "#")
+         if "++" in rest or "(+*)" in rest or "#" in rest:
+             tokens.append(self.MATE)
+         elif "+" in rest:
+             tokens.append(self.CHECK)
+
+         # The castling flag appears in this dataset as (o) or (O)
+         if "(o)" in rest or "(O)" in rest:
+             tokens.append(self.CASTLE)
+
+         return tokens
+
+     def _parse_promotion(self, rest: str) -> Optional[str]:
+         """
+         Detect the promotion piece if present.
+         Accepts patterns like:
+           =Q, =q, e7e8Q, e7e8=q, etc.
+         Returns one of Q/R/B/N or None.
+         """
+         # Look for =<piece>
+         m = re.search(r"=([QRBNqrbn])", rest)
+         if m:
+             return m.group(1).upper()
+
+         # Or a bare trailing promo letter (rare, but some formats do it).
+         # Heuristic: only accept it when the whole suffix is exactly one promo
+         # letter and contains no parenthesized flags, so we never grab a random
+         # character out of "(x)", "(o)", etc.
+         if "(" not in rest and rest.strip() in ["Q", "R", "B", "N", "q", "r", "b", "n"]:
+             return rest.strip().upper()
+
+         return None
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._ids_to_tokens.get(index, self.UNK_TOKEN)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         """
+         Convert factorized tokens back into a move string sequence.
+         This expects tokens to be aligned in move-chunks, so it's mostly for debugging.
+         """
+         special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+         toks = [t for t in tokens if t not in special]
+
+         # Re-chunk by detecting SIDE tokens as move boundaries.
+         moves: List[str] = []
+         i = 0
+         while i < len(toks):
+             if toks[i] not in (self.SIDE_W, self.SIDE_B):
+                 # If misaligned, skip until the next SIDE token
+                 i += 1
+                 continue
+             move_str, next_i = self._decode_one_move(toks, i)
+             moves.append(move_str)
+             i = next_i
+
+         return " ".join(moves)
+
+     def _decode_one_move(self, toks: List[str], i: int) -> Tuple[str, int]:
+         """
+         Decode a single move starting at index i (which should be SIDE_*).
+         Returns (move_string, next_index).
+         """
+         side_tok = toks[i]
+         side = "W" if side_tok == self.SIDE_W else "B"
+
+         # Need at least PIECE + FROM + TO after the SIDE token
+         if i + 3 >= len(toks):
+             return "", i + 1
+
+         piece_tok = toks[i + 1]
+         from_tok = toks[i + 2]
+         to_tok = toks[i + 3]
+
+         if not piece_tok.startswith("PIECE_") or not from_tok.startswith("SQ_") or not to_tok.startswith("SQ_"):
+             return "", i + 1
+
+         piece = piece_tok.replace("PIECE_", "")
+         frm = from_tok.replace("SQ_", "")
+         to = to_tok.replace("SQ_", "")
+
+         j = i + 4
+         promo = None
+         flags: List[str] = []
+
+         # Read promo/flag tokens until the next SIDE token or the end
+         while j < len(toks) and toks[j] not in (self.SIDE_W, self.SIDE_B):
+             t = toks[j]
+             if t.startswith(self.PROMO_PREFIX):
+                 promo = t.replace(self.PROMO_PREFIX, "")
+             elif t in (self.CAPTURE, self.CHECK, self.MATE, self.CASTLE):
+                 flags.append(t)
+             j += 1
+
+         rest = ""
+         if promo is not None:
+             rest += f"={promo}"
+
+         # Match the dataset-style suffixes
+         if self.CAPTURE in flags:
+             rest += "(x)"
+         if self.MATE in flags:
+             rest += "++"  # or "(+*)" if your dataset prefers that
+         elif self.CHECK in flags:
+             rest += "+"
+         if self.CASTLE in flags:
+             rest += "(o)"
+
+         return f"{side}{piece}{frm}{to}{rest}", j
+
+     # ---------------
+     # Saving/loading
+     # ---------------
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         if not os.path.isdir(save_directory):
+             os.makedirs(save_directory, exist_ok=True)
+
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
+         )
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             json.dump(self._vocab, f, ensure_ascii=False, indent=2)
+
+         return (vocab_file,)
+
+     def decode(
+         self,
+         token_ids: Union[int, Sequence[int], torch.Tensor],
+         skip_special_tokens: bool = False,
+         clean_up_tokenization_spaces: bool = False,
+         **kwargs: Any,
+     ) -> str:
+         # Normalize the input type
+         if isinstance(token_ids, int):
+             ids = [token_ids]
+         elif isinstance(token_ids, torch.Tensor):
+             ids = token_ids.detach().cpu().flatten().tolist()
+         else:
+             ids = list(token_ids)
+
+         # Convert ids -> token strings
+         toks = [self._convert_id_to_token(i) for i in ids]
+
+         if skip_special_tokens:
+             special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
+             toks = [t for t in toks if t not in special]
+
+         # IMPORTANT: reconstruct chess moves instead of joining token names
+         return self.convert_tokens_to_string(toks)
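
A quick round-trip sketch of the class above (assumes transformers and torch are installed and this runs next to tokenizer.py; the fixed vocab is 4 special + 2 sides + 6 pieces + 64 squares + 4 promotions + 4 flags = 84 tokens):

    from tokenizer import ChessTokenizer

    tok = ChessTokenizer()  # no vocab file given -> falls back to the fixed built-in vocab
    print(len(tok.get_vocab()))  # 84

    ids = tok.encode("WPe2e4 BNg8f6 WQd1h5(x)+", add_special_tokens=False)
    # Each move expands to SIDE / PIECE / SQ_from / SQ_to plus optional flag tokens,
    # so the three moves above become 4 + 4 + 6 = 14 ids.
    print(tok.decode(ids))  # "WPe2e4 BNg8f6 WQd1h5(x)+"

The factorized scheme is what keeps this workable: one token per whole move would need a vocabulary in the thousands, while side + piece + squares + flags compose any move from 84 symbols.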
tokenizer_config.json CHANGED
@@ -33,12 +33,6 @@
       "special": true
     }
   },
-  "auto_map": {
-    "AutoTokenizer": [
-      "tokenizer.py",
-      "ChessTokenizer"
-    ]
-  },
   "bos_token": "[BOS]",
   "clean_up_tokenization_spaces": false,
   "eos_token": "[EOS]",
@@ -46,5 +40,8 @@
   "model_max_length": 1000000000000000019884624838656,
   "pad_token": "[PAD]",
   "tokenizer_class": "ChessTokenizer",
+  "auto_map": {
+    "AutoTokenizer": ["tokenizer.py", "ChessTokenizer"]
+  },
   "unk_token": "[UNK]"
 }
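
The net effect of this change is a move, not a removal: "auto_map" is re-added (compacted) after "tokenizer_class", so AutoTokenizer can still resolve the custom ChessTokenizer class from tokenizer.py. Loading it therefore requires opting in to remote code; the repo id below is a placeholder, not taken from this commit:

    from transformers import AutoTokenizer

    # "<user>/<repo>" is a placeholder for this model's Hub path.
    # trust_remote_code=True is needed because "auto_map" points at
    # custom code (tokenizer.py) shipped inside the repo.
    tok = AutoTokenizer.from_pretrained("<user>/<repo>", trust_remote_code=True)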
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae24236124c061824973c9e4f4340e7186cd3564179ff2db75e0a72e99907bbe
+ size 5777
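
The three lines above are a Git LFS pointer, not the binary itself; training_args.bin is the pickled TrainingArguments object that transformers' Trainer saves alongside checkpoints. A minimal inspection sketch, assuming the file has been fetched with git lfs pull:

    import torch

    # weights_only=False because this is a pickled TrainingArguments object,
    # not a plain tensor state dict (only load files you trust).
    args = torch.load("training_args.bin", weights_only=False)
    print(type(args).__name__, args.learning_rate)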