# chess-aj-split-v3 / tokenizer.py
# Chess Challenge submission by ali-javani (commit f9222c0, verified)
"""
4-Step Split Tokenizer
Splits moves into: [Piece] -> [From] -> [To] -> [Suffix]
Keeps the vocabulary small: 90 tokens (4 special + 12 pieces + 64 squares + 10 suffixes).
"""
from __future__ import annotations
import json
import os
import re
from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer, AutoTokenizer
class ChessTokenizer(PreTrainedTokenizer):
    """Slow Hugging Face tokenizer that splits every chess move into 4 tokens.

    Input text is a whitespace-separated sequence of moves. Each move that
    matches ``<piece><from-square><to-square><rest>`` (e.g. ``WPe2e4``; the
    exact dataset notation is inferred from the parsing regex) is emitted as
    ``[piece, from, to, suffix]``. A move with no suffix gets the explicit
    ``"(-)"`` token, so every parsed move occupies exactly 4 positions. A move
    that does not match the pattern becomes a single ``[UNK]`` token.
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # 1. Pieces: colour prefix (W/B) followed by the piece letter.
    PIECES = ["WP", "WN", "WB", "WR", "WQ", "WK", "BP", "BN", "BB", "BR", "BQ", "BK"]
    # 2. Squares: all 64 board squares.
    SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]
    # 3. Suffixes. "(-)" is the explicit "no suffix / quiet move" marker that
    #    keeps every move exactly 4 tokens long.
    SUFFIXES = ["(-)", "(x)", "(+)", "(#)", "(x+)", "(x#)", "(O)", "(o)", "(Q)", "=Q"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Suffix token used when a move carries no suffix; dropped again on decode.
    _NO_SUFFIX = "(-)"
    # Number of tokens each successfully parsed move expands to.
    _TOKENS_PER_MOVE = 4
    # (piece)(from square)(to square)(anything else = suffix). Compiled once at
    # class creation instead of on every _tokenize call.
    _MOVE_PATTERN = re.compile(r"([WB][PNBRQK])([a-h][1-8])([a-h][1-8])(.*)")

    def __init__(self, vocab_file: Optional[str] = None, vocab: Optional[Dict[str, int]] = None, **kwargs):
        """Load or build the vocabulary, then initialise the HF base class.

        Args:
            vocab_file: Path to a JSON file mapping token -> id. When the file
                exists it takes precedence over ``vocab``.
            vocab: In-memory token -> id mapping. It is copied, so later
                changes to the caller's dict do not affect the tokenizer.
            **kwargs: Forwarded to ``PreTrainedTokenizer``. Any ``pad_token``,
                ``bos_token``, ``eos_token`` or ``unk_token`` found here (e.g.
                from a saved tokenizer config) overrides the class defaults.
        """
        # Build/load the vocab *before* calling the base initialiser, as the
        # original code deliberately does (the base class presumably consults
        # the vocab while registering special tokens).
        loaded = dict(vocab) if vocab else None
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        # An empty or missing vocab falls back to the static split vocabulary.
        self._vocab = loaded or self._build_split_vocab()
        self._ids_to_tokens = {idx: tok for tok, idx in self._vocab.items()}

        # Pop the special tokens out of kwargs so they are not passed to the
        # base class twice ("got multiple values for keyword argument").
        # Values from kwargs win; otherwise the class constants are used.
        pad_token = kwargs.pop("pad_token", self.PAD_TOKEN)
        bos_token = kwargs.pop("bos_token", self.BOS_TOKEN)
        eos_token = kwargs.pop("eos_token", self.EOS_TOKEN)
        unk_token = kwargs.pop("unk_token", self.UNK_TOKEN)

        super().__init__(
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            **kwargs,
        )

    def _build_split_vocab(self) -> Dict[str, int]:
        """Return the static vocabulary: specials, pieces, squares and suffixes.

        Tokens are de-duplicated and sorted, and ids are assigned in sorted
        order. The ids therefore depend on the sort, so changing any token
        list changes ids of existing tokens.
        """
        tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        tokens += self.PIECES + self.SQUARES + self.SUFFIXES
        return {tok: idx for idx, tok in enumerate(sorted(set(tokens)))}

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping (required by Hugging Face)."""
        return dict(self._vocab)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary."""
        return len(self._vocab)

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split whitespace-separated moves into piece/from/to/suffix tokens.

        Extra keyword arguments are accepted and ignored so the method works
        with base-class versions that forward tokenize() kwargs.
        """
        tokens: List[str] = []
        for move in text.split():
            match = self._MOVE_PATTERN.match(move)
            if match is None:
                # Unparseable move: a single unknown token.
                tokens.append(self.UNK_TOKEN)
                continue
            piece, src, dst, suffix = match.groups()
            tokens.extend((piece, src, dst, suffix or self._NO_SUFFIX))
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the id of ``[UNK]``."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Reassemble moves from 4-token groups and join them with spaces.

        Special tokens (the class defaults and any instance overrides) are
        dropped first. The remaining tokens are grouped in fours; within a
        group the ``"(-)"`` marker is omitted. A trailing incomplete group is
        still emitted, and groups that reduce to an empty string are skipped.
        """
        specials = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        # __init__ lets the special tokens be overridden, so honour those too.
        specials.update(
            str(tok)
            for tok in (self.pad_token, self.bos_token, self.eos_token, self.unk_token)
            if tok is not None
        )
        clean = [t for t in tokens if t not in specials]

        step = self._TOKENS_PER_MOVE
        moves: List[str] = []
        for start in range(0, len(clean), step):
            move = "".join(t for t in clean[start:start + step] if t != self._NO_SUFFIX)
            if move:
                moves.append(move)
        return " ".join(moves)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the vocabulary as JSON and return a 1-tuple with its path.

        The directory is created if needed. The file is written as UTF-8 to
        match how ``__init__`` reads it back.
        """
        os.makedirs(save_directory, exist_ok=True)
        filename = (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        path = os.path.join(save_directory, filename)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (path,)

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        """Return a tokenizer with the static vocabulary; all arguments are ignored."""
        print("Using static 4-Step Split vocabulary.")
        return cls()
# Register ChessTokenizer with AutoTokenizer at import time.
# NOTE(review): AutoTokenizer.register takes a config *class* as its first
# argument (tokenizers are resolved from the model's config type). Here a plain
# string is passed, so the mapping is presumably stored under a string key that
# a config-type lookup never matches — confirm against the installed
# transformers version. If the goal is loading via trust_remote_code,
# ChessTokenizer.register_for_auto_class("AutoTokenizer") is the documented
# mechanism for custom tokenizers.
AutoTokenizer.register("ChessTokenizer", ChessTokenizer)