Giu0804 committed on
Commit
10fd3a5
·
verified ·
1 Parent(s): 28e2d18

Chess Challenge submission by Giu0804

Browse files
Files changed (1) hide show
  1. tokenizer_v2.py +178 -0
tokenizer_v2.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Coordinate Chess Tokenizer (Vocab Size = 72).
3
+ Compatible with Hugging Face AutoTokenizer and existing Evaluation scripts.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import os
10
+ import re
11
+ from typing import Dict, List, Optional, Tuple, Union
12
+
13
+ from transformers import PreTrainedTokenizer
14
+
15
class ChessTokenizer(PreTrainedTokenizer):
    """Coordinate chess tokenizer with a fixed 72-token vocabulary.

    Token-id layout:
        0-3   special tokens ([PAD], [BOS], [EOS], [UNK])
        4-7   promotion pieces (q, r, b, n)
        8-71  board squares (a1, b1, ..., h8 — files vary fastest)

    Compatible with Hugging Face ``AutoTokenizer`` loading/saving and the
    project's existing evaluation scripts.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Captures a from-square, a to-square and an optional promotion piece
    # from any move encoding: plain UCI ("e2e4"), extended ("WPe2e4"),
    # or with promotion ("e7e8q").
    MOVE_REGEX = re.compile(r"([a-h][1-8])([a-h][1-8])([qrbn])?")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        **kwargs,
    ):
        """Create the tokenizer.

        Args:
            vocab_file: Optional path to a ``vocab.json`` written by
                :meth:`save_vocabulary`. When absent or missing on disk, the
                fixed 72-token vocabulary is built deterministically.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Caller-supplied
                special-token overrides are discarded because this tokenizer
                fixes its special tokens.
        """
        # Initialize special tokens before super().__init__ needs them.
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop any caller-supplied special tokens to avoid duplicate-keyword
        # errors when reloading via AutoTokenizer.
        for key in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(key, None)

        # Load the persisted vocabulary if one is available (loading from a
        # HF checkpoint); otherwise build the fixed 72-token vocabulary.
        if vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_fixed_vocab()

        # Reverse mapping for id -> token decoding.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_fixed_vocab(self) -> Dict[str, int]:
        """Create the deterministic 72-token vocabulary (see class docstring)."""
        vocab: Dict[str, int] = {}

        # 0-3: special tokens.
        for token in (self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN):
            vocab[token] = len(vocab)

        # 4-7: promotion pieces.
        for token in ("q", "r", "b", "n"):
            vocab[token] = len(vocab)

        # 8-71: squares a1 ... h8 (files vary fastest within each rank).
        for rank in "12345678":
            for file in "abcdefgh":
                vocab[file + rank] = len(vocab)

        return vocab

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary (72 for the fixed vocab)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split *text* into square/promotion tokens.

        Handles raw UCI coordinates ("e2e4"), 'dirty' extended strings
        ("WPe2e4"), promotions ("e7e8q"), isolated vocab entries (a lone
        "a1") and explicit special tokens. Any chunk that matches none of
        these becomes [UNK].
        """
        special_set = {self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, self.UNK_TOKEN}
        tokens: List[str] = []

        # Split by whitespace first; each chunk is resolved independently.
        for chunk in text.strip().split():
            # If it's explicitly a special token, keep it verbatim.
            if chunk in special_set:
                tokens.append(chunk)
                continue

            # Use the regex to extract the first move embedded in the chunk:
            # "WPe2e4" -> ["e2", "e4"], "e2e4q" -> ["e2", "e4", "q"].
            match = self.MOVE_REGEX.search(chunk)
            if match:
                start_sq, end_sq, promotion = match.groups()
                tokens.append(start_sq)
                tokens.append(end_sq)
                if promotion:
                    tokens.append(promotion)
            elif chunk in self._vocab:
                # Regex failed but the chunk itself is a vocab entry.
                tokens.append(chunk)
            else:
                tokens.append(self.UNK_TOKEN)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the [UNK] id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to [UNK]."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping special tokens.

        The space-separated form is what the downstream evaluation regex
        expects when re-parsing coordinates.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write ``vocab.json`` into *save_directory*.

        Required by Hugging Face's ``save_pretrained`` machinery.

        Args:
            save_directory: Target directory (created if missing).
            filename_prefix: Optional prefix prepended as ``"<prefix>-"``.

        Returns:
            A 1-tuple containing the path of the written vocabulary file.
        """
        # exist_ok=True makes the separate isdir check unnecessary.
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,  # Ignored: vocabulary is fixed.
        max_samples: Optional[int] = 100000,  # Ignored: vocabulary is fixed.
    ) -> "ChessTokenizer":
        """Mock implementation to satisfy the train.py API.

        The vocabulary is fixed, so no dataset scan is performed; all
        parameters are accepted only for signature compatibility.
        """
        # Plain string: the original used an f-string with no placeholders.
        print("Coordinate Tokenizer: Using fixed vocabulary (size 72). Ignoring dataset scan.")
        return cls()