File size: 6,156 Bytes
10fd3a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Coordinate chess tokenizer with a fixed 72-token vocabulary
(4 special tokens, 4 promotion pieces, 64 board squares).
Compatible with Hugging Face AutoTokenizer and the existing evaluation scripts.
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional, Tuple, Union
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """Coordinate-level chess tokenizer with a fixed 72-token vocabulary.

    Each move is emitted as its origin square, its destination square and,
    when present, a lowercase promotion piece.  The built-in vocabulary is:

    * ids 0-3:  special tokens ``[PAD]``, ``[BOS]``, ``[EOS]``, ``[UNK]``
    * ids 4-7:  promotion pieces ``q``, ``r``, ``b``, ``n``
    * ids 8-71: squares in rank-major order (``a1``, ``b1``, ..., ``h8``)
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Built once and shared by _tokenize and convert_tokens_to_string.
    _SPECIAL_TOKEN_SET = frozenset((PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN))

    # Captures "e2", "e4" and an optional promotion letter from strings such
    # as "WPe2e4" or "e2e4q".  The match is case-sensitive, so uppercase piece
    # prefixes are skipped.  The (?![1-8]) lookahead stops the promotion group
    # from swallowing the file letter of a following square: in "e2e4b8c6"
    # the "b" belongs to "b8", it is not a bishop promotion.
    MOVE_REGEX = re.compile(r"([a-h][1-8])([a-h][1-8])([qrbn](?![1-8]))?")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        **kwargs,
    ):
        """Build the tokenizer from a saved vocabulary or the fixed default.

        Args:
            vocab_file: Path to a JSON file mapping token -> id.  When it is
                ``None`` or the path does not exist, the fixed 72-token
                vocabulary is used instead.
            **kwargs: Forwarded to the base class constructor.  Any
                ``pad_token`` / ``bos_token`` / ``eos_token`` / ``unk_token``
                entries are discarded; the class constants are always used.
        """
        # Initialize special tokens
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Clean kwargs to avoid duplication errors during loading: the same
        # keywords are passed explicitly to the base constructor below.
        for name in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(name, None)

        # Load the vocabulary if a file is provided (loading from HF),
        # otherwise create the fixed 72-token vocabulary.
        if vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_fixed_vocab()
        self._ids_to_tokens = {idx: token for token, idx in self._vocab.items()}

        # NOTE(review): the vocabulary is populated before the base
        # constructor runs, exactly as in the original; the base class
        # presumably consults it during initialisation - confirm before
        # reordering.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _create_fixed_vocab(self) -> Dict[str, int]:
        """Create the deterministic 72-token vocabulary.

        Returns:
            Mapping token -> id: special tokens (0-3), promotion pieces
            (4-7), then the 64 squares rank by rank (8-71).
        """
        ordered_tokens = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
            "q",
            "r",
            "b",
            "n",
        ]
        # Rank-major order: a1, b1, ..., h1, a2, ..., h8.
        ordered_tokens.extend(
            file_ + rank for rank in "12345678" for file_ in "abcdefgh"
        )
        return {token: idx for idx, token in enumerate(ordered_tokens)}

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into square / promotion / special tokens.

        The text is split on whitespace and each chunk is handled as follows:

        * a chunk equal to a special token is kept verbatim;
        * otherwise every move found in the chunk contributes its origin
          square, destination square and optional promotion letter, so
          "WPe2e4" and "e2e4" both yield ``["e2", "e4"]``;
        * a chunk with no move that is itself a vocabulary entry (for
          example an isolated "a1" or "q") is kept;
        * anything else becomes the unknown token.

        Args:
            text: Whitespace-separated moves in any coordinate-based format.

        Returns:
            The list of token strings.
        """
        tokens: List[str] = []
        for chunk in text.split():
            if chunk in self._SPECIAL_TOKEN_SET:
                tokens.append(chunk)
                continue

            # finditer keeps every move of a chunk; a single search() would
            # silently drop all moves after the first one.
            found_move = False
            for move in self.MOVE_REGEX.finditer(chunk):
                found_move = True
                # The promotion group is None when absent; skip it.
                tokens.extend(part for part in move.groups() if part)
            if found_move:
                continue

            tokens.append(chunk if chunk in self._vocab else self.UNK_TOKEN)
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the id of the unknown token."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or the unknown token string."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping all special tokens.

        Coordinates are deliberately space-separated, e.g.
        ``["e2", "e4", "q"]`` becomes ``"e2 e4 q"``.
        """
        return " ".join(
            token for token in tokens if token not in self._SPECIAL_TOKEN_SET
        )

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the vocabulary as JSON into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix; when given the file is named
                ``<prefix>-vocab.json`` instead of ``vocab.json``.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        # exist_ok makes a separate isdir() check unnecessary.
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str = "dlouapre/lichess_2025-01_1M",
        split: str = "train",
        column: str = "text",
        min_frequency: int = 500,  # Ignored
        max_samples: Optional[int] = 100000,  # Ignored
    ) -> "ChessTokenizer":
        """Return a tokenizer with the fixed vocabulary.

        Mock implementation kept so callers that build a vocabulary from a
        dataset keep working.  No dataset is read and every argument is
        ignored, because the vocabulary is fixed.
        """
        print("Coordinate Tokenizer: Using fixed vocabulary (size 72). Ignoring dataset scan.")
        return cls()