"""
Custom Chess Tokenizer for the Chess1MChallenge.
Goal: maximize legal-move rate in the evaluation.
Key idea:
- The evaluator only needs to recover the UCI move (e.g. "e2e4") from the model output.
It extracts squares like [a-h][1-8] and builds a move from the first 2 squares.
- So we normalize dataset tokens like "WPe2e4(x+)" to plain UCI "e2e4" (plus promotion suffix "q/r/b/n").
- We use a FIXED UCI vocabulary so there is (almost) no OOV -> far fewer [UNK] -> higher legal-move rate.
Vocabulary:
- All from-to square pairs: "a1a2", ..., excluding from==to.
- All promotion moves: e7e8[qrbn], a2a1[qrbn], including capture-promotions (still covered by from-to).
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
# Matches a single board square: a file letter a-h followed by a rank digit 1-8.
# Used with findall() to pull the from/to squares out of a move token.
_SQUARE_RE = re.compile(r"[a-h][1-8]")
# Captures the promotion piece letter (Q, R, B or N) that follows an "=" sign.
_PROMO_RE = re.compile(r"=([QRBN])") # dataset often uses "=Q"
class ChessTokenizer(PreTrainedTokenizer):
    """
    Tokenizer that maps each chess move to a single token.

    It is compatible with Hugging Face `AutoTokenizer(..., trust_remote_code=True)`.

    Notes:
    - Input text may contain "extended UCI" tokens from the Lichess dataset
      (e.g. "WPe2e4", "BKe8g8(O)", "WPe7e8=Q(+)" ...).
    - We normalize those tokens to plain UCI: "e2e4", "e8g8", "e7e8q", ...
    - Text that is already plain UCI (e.g. "e7e8q") normalizes to itself, so
      tokenizing previously normalized text keeps promotion pieces.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        """
        Create the tokenizer from an explicit vocab, a vocab file, or a default.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. Used only when
                `vocab` is None and the file exists.
            vocab: Token -> id mapping. Takes precedence over `vocab_file`.
            **kwargs: Forwarded to `PreTrainedTokenizer.__init__`. Any
                pad/bos/eos/unk token entries are dropped in favour of the
                class-level constants.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Avoid duplicate special-token args when loading from disk
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        # Load or create vocabulary
        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            # No usable vocabulary given: fall back to special tokens only.
            self._vocab = self._create_default_vocab()

        # Reverse mapping used when decoding ids back to tokens.
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @classmethod
    def _reserved_tokens(cls) -> List[str]:
        """Return the special tokens in the order of their ids (0..3)."""
        return [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]

    @classmethod
    def _vocab_from_moves(cls, moves) -> Dict[str, int]:
        """
        Build a token -> id mapping with contiguous ids.

        Special tokens come first, followed by the sorted, de-duplicated move
        tokens. Special tokens found in `moves` are dropped so that no token
        appears twice; a duplicate would leave a hole in the id range and push
        the largest id up to `vocab_size`.
        """
        reserved = cls._reserved_tokens()
        unique_moves = sorted(set(moves) - set(reserved))
        return {tok: i for i, tok in enumerate(reserved + unique_moves)}

    def _create_default_vocab(self) -> Dict[str, int]:
        """Return a minimal vocabulary containing only the special tokens."""
        return {tok: i for i, tok in enumerate(self._reserved_tokens())}

    @staticmethod
    def _normalize_one_token(tok: str) -> str:
        """
        Convert an extended token to plain UCI.

        The first two squares found in `tok` form the move. A promotion piece
        is appended when the token carries an "=Q"-style marker, or when it is
        already plain UCI with a lowercase q/r/b/n directly after a
        destination square on rank 1 or 8.

        Returns the UNK token when fewer than two squares are present.

        Examples:
            "WPe2e4"      -> "e2e4"
            "BKe8g8(O)"   -> "e8g8"
            "WPe7e8=Q(+)" -> "e7e8q"
            "WPe5d6(x)"   -> "e5d6"
            "e7e8q"       -> "e7e8q"
        """
        found = list(_SQUARE_RE.finditer(tok))
        if len(found) < 2:
            return ChessTokenizer.UNK_TOKEN

        origin, target = found[0], found[1]
        uci = origin.group() + target.group()

        marker = _PROMO_RE.search(tok)
        if marker:
            return uci + marker.group(1).lower()  # Q->q etc.

        # Plain UCI input: keep a trailing promotion letter instead of dropping it.
        # The emptiness check matters because "" in "qrbn" is True.
        suffix = tok[target.end():target.end() + 1]
        if suffix and suffix in "qrbn" and uci[-1] in "18":
            uci += suffix
        return uci

    @classmethod
    def build_fixed_uci_vocab(cls) -> "ChessTokenizer":
        """
        Build a FIXED vocabulary of (almost) all possible UCI moves.

        The vocabulary holds every from-to square pair with from != to, plus
        pawn promotions (rank 7->8 and rank 2->1, straight or one file
        sideways) with each of the suffixes q/r/b/n.

        This dramatically reduces OOV compared to building vocab from the dataset
        with a high min_frequency.
        """
        files = "abcdefgh"
        ranks = "12345678"
        promos = "qrbn"

        # All from-to square pairs (excluding from==to)
        squares = [f + r for f in files for r in ranks]
        moves = {src + dst for src in squares for dst in squares if src != dst}

        # Promotions: white (7->8) and black (2->1), with q/r/b/n
        for col, src_file in enumerate(files):
            for shift in (-1, 0, 1):
                dst_col = col + shift
                if not 0 <= dst_col < len(files):
                    continue  # capture would leave the board
                dst_file = files[dst_col]
                for src_rank, dst_rank in (("7", "8"), ("2", "1")):
                    base = f"{src_file}{src_rank}{dst_file}{dst_rank}"
                    moves.update(base + piece for piece in promos)

        return cls(vocab=cls._vocab_from_moves(moves))

    @classmethod
    def build_vocab_from_iterator(cls, iterator, min_frequency: int = 1) -> "ChessTokenizer":
        """
        Optional: build vocabulary from an iterator of strings.

        Each string is split on whitespace and every piece is normalized to
        UCI before counting. Tokens seen fewer than `min_frequency` times are
        left out. Pieces that normalize to a special token (e.g. non-move
        words, which become the UNK token) are never added as move tokens, so
        the resulting ids stay contiguous.
        """
        from collections import Counter

        counts = Counter()
        for game in iterator:
            counts.update(cls._normalize_one_token(t) for t in game.strip().split())

        frequent = [t for t, c in counts.items() if c >= min_frequency]
        return cls(vocab=cls._vocab_from_moves(frequent))

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary, special tokens included."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split `text` on whitespace and normalize each piece to plain UCI."""
        return [self._normalize_one_token(t) for t in text.strip().split()]

    def _convert_token_to_id(self, token: str) -> int:
        """Look up a token id, falling back to the UNK id (or 0 if UNK is absent)."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        """Look up the token for an id, falling back to the UNK token."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, leaving out all special tokens."""
        special = set(self._reserved_tokens())
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """
        Write the vocabulary as JSON into `save_directory`.

        The directory is created if needed. The file is named "vocab.json",
        preceded by "<filename_prefix>-" when a prefix is given.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)
|