# Chess Challenge submission by SmileChou (commit dc52c9a, verified)
"""
Chess Tokenizer (Refactored).
Architecture:
- Splits chess moves into atomic component tokens.
- Structure: [Actor] -> [Source_Square] -> [Target_Square] -> [Promotion?]
- Output format: "WP", "e2_f", "e4_t"
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional, Any, Tuple
from transformers import PreTrainedTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """
    A tokenizer that breaks chess moves into explicit actor and coordinate tokens.

    Each move string such as ``WPe2e4`` is decomposed into three (or four)
    atomic tokens: the actor (``WP``), the source square (``e2_f``), the
    target square (``e4_t``) and, for promotions, a piece token
    (``q``/``r``/``b``/``n``).  Designed for high-precision state tracking.
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # --- Configuration Constants ---
    TOKENS_SPECIAL = ["[PAD]", "[BOS]", "[EOS]", "[UNK]"]
    CHARS_PIECE = "PNBRQK"
    CHARS_COLOR = "WB"
    CHARS_FILE = "abcdefgh"
    CHARS_RANK = "12345678"
    CHARS_PROMO = {"q", "r", "b", "n"}

    # Regex to validate and parse standard Lichess moves (e.g., WPe2e4)
    # Group 1: Color, 2: Piece, 3: Source, 4: Target, 5: Suffix
    PATTERN_MOVE = re.compile(r"^([WB])([PNBRQK])([a-h][1-8])([a-h][1-8])(.*)$")

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs: Any,
    ):
        """Create the tokenizer.

        Vocabulary resolution order: an explicit ``vocab`` dict wins, then a
        readable ``vocab_file`` (JSON), otherwise the fixed vocabulary is
        generated deterministically.

        Args:
            vocab_file: Path to a ``vocab.json`` mapping token -> id.
            vocab: Pre-built token -> id mapping (takes precedence).
            **kwargs: Forwarded to ``PreTrainedTokenizer``; any special-token
                overrides are discarded because this tokenizer fixes them.
        """
        # Special tokens are fixed; expose them before calling super().__init__
        # because the parent constructor may touch them.
        self._pad_token = self.TOKENS_SPECIAL[0]
        self._bos_token = self.TOKENS_SPECIAL[1]
        self._eos_token = self.TOKENS_SPECIAL[2]
        self._unk_token = self.TOKENS_SPECIAL[3]

        # Drop caller-supplied special tokens to prevent keyword collisions
        # with the explicit values passed to super().__init__ below.
        for token_arg in ["pad_token", "bos_token", "eos_token", "unk_token"]:
            kwargs.pop(token_arg, None)

        # 1. Load Vocabulary (explicit dict > file > generated default).
        if vocab:
            self._vocab = vocab
        elif vocab_file and os.path.isfile(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._generate_vocabulary()

        # 2. Build ID-to-Token Map (inverse lookup for decoding).
        self._id_to_token = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    def _generate_vocabulary(self) -> Dict[str, int]:
        """Construct the fixed dictionary of tokens.

        Layout (and therefore id assignment) is deterministic:
        specials, actor tokens, source squares, target squares, promotions.
        """
        token_list = list(self.TOKENS_SPECIAL)
        # A. Actor Tokens (e.g., WP, BN)
        token_list.extend(
            f"{c}{p}" for c in self.CHARS_COLOR for p in self.CHARS_PIECE
        )
        # B. Coordinate Tokens (Source & Target), rank-major to keep ids stable.
        squares = [f"{f}{r}" for r in self.CHARS_RANK for f in self.CHARS_FILE]
        token_list.extend(f"{sq}_f" for sq in squares)  # From
        token_list.extend(f"{sq}_t" for sq in squares)  # To
        # C. Promotion Tokens (sorted because CHARS_PROMO is an unordered set).
        token_list.extend(sorted(self.CHARS_PROMO))
        return {token: idx for idx, token in enumerate(token_list)}

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a defensive copy of the token -> id mapping."""
        return self._vocab.copy()

    def _tokenize(self, text: str) -> List[str]:
        """
        Parse a string of moves into atomic tokens.

        Input: "WPe2e4 BNg8f6"
        Output: ["WP", "e2_f", "e4_t", "BN", "g8_f", "f6_t"]

        Special tokens pass through unchanged; unparseable items become the
        unknown token.  Promotions are recognized both in "=Q" form and as a
        bare trailing piece char (UCI style, e.g. "WPe7e8q").
        """
        if not text:
            return []

        tokens: List[str] = []
        raw_items = text.strip().split()
        special_set = set(self.TOKENS_SPECIAL)

        for item in raw_items:
            # Pass through special tokens immediately.
            if item in special_set:
                tokens.append(item)
                continue

            # Parse move structure.
            match = self.PATTERN_MOVE.match(item)
            if not match:
                tokens.append(self.unk_token)
                continue

            # Deconstruct parts.
            color, piece, src, dst, suffix = match.groups()
            # 1. Actor (Who)
            tokens.append(f"{color}{piece}")
            # 2. Origin (Where from)
            tokens.append(f"{src}_f")
            # 3. Destination (Where to)
            tokens.append(f"{dst}_t")

            # 4. Promotion (Transformation).
            if suffix:
                promo_char: Optional[str] = None
                eq_idx = suffix.find("=")
                if eq_idx != -1:
                    # SAN-style promotion: the char immediately after '='.
                    if eq_idx + 1 < len(suffix):
                        promo_char = suffix[eq_idx + 1].lower()
                else:
                    # Fix: bare trailing promotion char ("...e8q") was
                    # previously ignored despite the documented intent.
                    promo_char = suffix[0].lower()
                if promo_char in self.CHARS_PROMO:
                    tokens.append(promo_char)

        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the unknown-token id."""
        return self._vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to the unknown token."""
        return self._id_to_token.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens into a space-separated string, filtering out specials."""
        special_set = set(self.TOKENS_SPECIAL)
        return " ".join(t for t in tokens if t not in special_set)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to ``<save_directory>/[<prefix>-]vocab.json``.

        Returns a 1-tuple containing the written file path, per the
        ``PreTrainedTokenizer.save_vocabulary`` contract.
        """
        # exist_ok makes a prior existence check redundant.
        os.makedirs(save_directory, exist_ok=True)
        base_name = self.vocab_files_names["vocab_file"]
        # Bug fix: the prefixed name previously rendered as
        # "<prefix>-(unknown)" (no .json extension) instead of the
        # conventional "<prefix>-vocab.json".
        filename = f"{filename_prefix}-{base_name}" if filename_prefix else base_name
        full_path = os.path.join(save_directory, filename)
        with open(full_path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2, ensure_ascii=False)
        return (full_path,)

    @classmethod
    def build_vocab_from_dataset(cls, *args: Any, **kwargs: Any) -> "ChessTokenizer":
        """Compatibility method for training pipelines.

        The vocabulary is fixed, so dataset arguments are accepted for API
        compatibility but intentionally ignored.
        """
        return cls()