# """
# Improved Chess Tokenizer (Structured) for the Chess Challenge.
# Key idea:
# - Decompose each move into sub-tokens:
# SIDE_W / SIDE_B
# piece (P,N,B,R,Q,K)
# from-square (e2)
# to-square (e4)
# optional flags: CAPTURE, CHECK, MATE, CASTLE_SHORT, CASTLE_LONG
# Important implementation detail:
# - We MUST avoid token-string collisions. In particular, "B" is both:
# - Black side ("B")
# - Bishop piece ("B")
# If we used raw "W"/"B" for side, the vocab dict would overwrite one of them,
# creating non-contiguous IDs and leading to embedding "index out of range".
# """
# from __future__ import annotations
#
# import json
# import os
# import re
# from typing import Dict, List, Optional
#
# from transformers import PreTrainedTokenizer
#
#
# class ChessTokenizer(PreTrainedTokenizer):
#     model_input_names = ["input_ids", "attention_mask"]
#     vocab_files_names = {"vocab_file": "vocab.json"}
#
#     # Special tokens
#     PAD_TOKEN = "[PAD]"
#     BOS_TOKEN = "[BOS]"
#     EOS_TOKEN = "[EOS]"
#     UNK_TOKEN = "[UNK]"
#
#     # Side tokens (avoid collision with piece "B" for Bishop)
#     SIDE_W = "SIDE_W"
#     SIDE_B = "SIDE_B"
#
#     def __init__(
#         self,
#         vocab_file: Optional[str] = None,
#         vocab: Optional[Dict[str, int]] = None,
#         **kwargs,
#     ):
#         self._pad_token = self.PAD_TOKEN
#         self._bos_token = self.BOS_TOKEN
#         self._eos_token = self.EOS_TOKEN
#         self._unk_token = self.UNK_TOKEN
#
#         # Avoid duplicate kwargs when HF loads from disk
#         kwargs.pop("pad_token", None)
#         kwargs.pop("bos_token", None)
#         kwargs.pop("eos_token", None)
#         kwargs.pop("unk_token", None)
#
#         if vocab is not None:
#             self._vocab = {str(k): int(v) for k, v in vocab.items()}
#         elif vocab_file is not None and os.path.exists(vocab_file):
#             with open(vocab_file, "r", encoding="utf-8") as f:
#                 loaded = json.load(f)
#             self._vocab = {str(k): int(v) for k, v in loaded.items()}
#         else:
#             self._vocab = self._create_default_vocab()
#
#         # Ensure IDs are contiguous 0..(len-1) (robust to any old saved vocabs)
#         self._vocab = self._normalize_vocab_ids(self._vocab)
#         self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
#
#         super().__init__(
#             pad_token=self._pad_token,
#             bos_token=self._bos_token,
#             eos_token=self._eos_token,
#             unk_token=self._unk_token,
#             **kwargs,
#         )
#
#     @staticmethod
#     def _normalize_vocab_ids(vocab: Dict[str, int]) -> Dict[str, int]:
#         """
#         Re-map token IDs to be contiguous and deterministic.
#         Sort by old id, then by token string.
#         """
#         items = sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]))
#         return {tok: new_id for new_id, (tok, _) in enumerate(items)}
#
#     # ------------------------------------------------------------------
#     # REQUIRED compatibility method (train.py expects this to exist)
#     # ------------------------------------------------------------------
#     @classmethod
#     def build_vocab_from_dataset(
#         cls,
#         dataset_name: str = "dlouapre/lichess_2025-01_1M",
#         split: str = "train",
#         column: str = "text",
#         min_frequency: int = 1,
#         max_samples: Optional[int] = None,
#     ) -> "ChessTokenizer":
#         """
#         Compatibility hook.
#         For the structured tokenizer, the vocabulary is fixed and does not
#         depend on dataset statistics. We keep this method so src/train.py
#         (template code) does not need to change.
#         """
#         return cls()
#
#     # ------------------------------------------------------------------
#     # Vocabulary construction
#     # ------------------------------------------------------------------
#     def _create_default_vocab(self) -> Dict[str, int]:
#         special = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
#         sides = [self.SIDE_W, self.SIDE_B]  # no collision with piece tokens
#         pieces = ["P", "N", "B", "R", "Q", "K"]
#         files = list("abcdefgh")
#         ranks = list("12345678")
#         squares = [f + r for f in files for r in ranks]  # 64 tokens
#         flags = ["CAPTURE", "CHECK", "MATE", "CASTLE_SHORT", "CASTLE_LONG"]
#         tokens = special + sides + pieces + squares + flags
#         return {tok: i for i, tok in enumerate(tokens)}  # contiguous by construction
#
#     @property
#     def vocab_size(self) -> int:
#         return len(self._vocab)
#
#     def get_vocab(self) -> Dict[str, int]:
#         return dict(self._vocab)
#
#     # ------------------------------------------------------------------
#     # Tokenization logic
#     # ------------------------------------------------------------------
#     MOVE_REGEX = re.compile(
#         r"""
#         (?P<side>[WB])
#         (?P<piece>[PNBRQK])
#         (?P<from>[a-h][1-8])
#         (?P<to>[a-h][1-8])
#         (?P<suffix>.*)?
#         """,
#         re.VERBOSE,
#     )
#
#     def _tokenize(self, text: str) -> List[str]:
#         out: List[str] = []
#         for move in text.strip().split():
#             out.extend(self._decompose_move(move))
#         return out
#
#     def _decompose_move(self, move: str) -> List[str]:
#         m = self.MOVE_REGEX.match(move)
#         if not m:
#             return [self.UNK_TOKEN]
#         side_raw = m.group("side")
#         side_tok = self.SIDE_W if side_raw == "W" else self.SIDE_B
#         tokens = [
#             side_tok,
#             m.group("piece"),
#             m.group("from"),
#             m.group("to"),
#         ]
#         suffix = m.group("suffix") or ""
#         if "(x)" in suffix:
#             tokens.append("CAPTURE")
#         if "(+*)" in suffix:
#             tokens.append("MATE")
#         elif "(+)" in suffix:
#             tokens.append("CHECK")
#         if "(o)" in suffix:
#             tokens.append("CASTLE_SHORT")
#         if "(O)" in suffix:
#             tokens.append("CASTLE_LONG")
#         return tokens
#
#     # ------------------------------------------------------------------
#     # ID conversion
#     # ------------------------------------------------------------------
#     def _convert_token_to_id(self, token: str) -> int:
#         return self._vocab.get(token, self._vocab[self.UNK_TOKEN])
#
#     def _convert_id_to_token(self, index: int) -> str:
#         return self._ids_to_tokens.get(index, self.UNK_TOKEN)
#
#     def convert_tokens_to_string(self, tokens: List[str]) -> str:
#         special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
#         return " ".join(t for t in tokens if t not in special)
#
#     # ------------------------------------------------------------------
#     # Saving
#     # ------------------------------------------------------------------
#     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
#         os.makedirs(save_directory, exist_ok=True)
#         vocab_file = os.path.join(
#             save_directory,
#             (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
#         )
#         with open(vocab_file, "w", encoding="utf-8") as f:
#             json.dump(self._vocab, f, indent=2)
#         return (vocab_file,)
#
#
# def count_vocab_from_dataset(
#     dataset_name: str = "dlouapre/lichess_2025-01_1M",
#     split: str = "train",
#     column: str = "text",
#     max_samples: Optional[int] = 10000,
# ) -> Dict[str, int]:
#     """
#     Count token frequencies after structured tokenization.
#     (A static-analysis warning about the 'datasets' import can be ignored as
#     long as the script runs from a terminal with 'datasets' installed.)
#     """
#     from collections import Counter
#     from datasets import load_dataset
#
#     dataset = load_dataset(dataset_name, split=split)
#     if max_samples is not None:
#         dataset = dataset.select(range(min(max_samples, len(dataset))))
#     tok = ChessTokenizer()
#     counts = Counter()
#     for ex in dataset:
#         counts.update(tok._tokenize(ex[column]))
#     return dict(counts)
"""
Final Structured Chess Tokenizer for the Chess Challenge.
Design goals:
- Strong legality bias
- Fixed, collision-free vocabulary
- HF-compatible (Trainer, save/load, Hub)
- Evaluator-friendly (square extraction still works)
Move decomposition:
PIECE
FROM_<square>
TO_<square>
optional FLAGS
Example:
P FROM_e2 TO_e4
N FROM_g1 TO_f3 CHECK
K FROM_e1 TO_g1 CASTLE_SHORT
"""
from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    # Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # Fixed role prefixes
    FROM_PREFIX = "FROM_"
    TO_PREFIX = "TO_"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Avoid duplicate kwargs when loading
        for k in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(k, None)

        if vocab is not None:
            self._vocab = {str(k): int(v) for k, v in vocab.items()}
        elif vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = {str(k): int(v) for k, v in json.load(f).items()}
        else:
            self._vocab = self._create_default_vocab()

        # Ensure contiguous IDs
        self._vocab = self._normalize_vocab(self._vocab)
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )
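
    # NOTE: the vocabulary is built and normalized *before* super().__init__()
    # above, since recent versions of the Hugging Face slow-tokenizer base class
    # resolve the special tokens during initialization, which already requires
    # _convert_token_to_id() to have a vocabulary to look them up in.
    # _normalize_vocab (below) re-maps IDs to a contiguous 0..len(vocab)-1 range,
    # sorted by old id and then by token string; e.g. a stale saved vocab such as
    # {"[PAD]": 0, "P": 4, "FROM_e2": 9} (hypothetical) would become
    # {"[PAD]": 0, "P": 1, "FROM_e2": 2}.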
    @staticmethod
    def _normalize_vocab(vocab: Dict[str, int]) -> Dict[str, int]:
        items = sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]))
        return {tok: i for i, (tok, _) in enumerate(items)}
    # ------------------------------------------------------------
    # Required by train.py (kept for compatibility)
    # ------------------------------------------------------------
    @classmethod
    def build_vocab_from_dataset(
        cls,
        *args,
        **kwargs,
    ) -> "ChessTokenizer":
        # The structured vocabulary is fixed and does not depend on dataset
        # statistics, so this compatibility hook simply returns a default instance.
        return cls()

    # ------------------------------------------------------------
    # Vocabulary
    # ------------------------------------------------------------
    def _create_default_vocab(self) -> Dict[str, int]:
        special = [
            self.PAD_TOKEN,
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.UNK_TOKEN,
        ]
        pieces = ["P", "N", "B", "R", "Q", "K"]
        files = "abcdefgh"
        ranks = "12345678"
        squares = [f + r for f in files for r in ranks]
        from_tokens = [self.FROM_PREFIX + sq for sq in squares]
        to_tokens = [self.TO_PREFIX + sq for sq in squares]
        flags = [
            "CAPTURE",
            "CHECK",
            "MATE",
            "CASTLE_SHORT",
            "CASTLE_LONG",
        ]
        tokens = special + pieces + from_tokens + to_tokens + flags
        return {tok: i for i, tok in enumerate(tokens)}
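
    # Default vocabulary size: 4 special + 6 pieces + 64 FROM_* + 64 TO_* + 5 flags
    # = 143 tokens, with contiguous IDs by construction.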
    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    # ------------------------------------------------------------
    # Tokenization
    # ------------------------------------------------------------
    MOVE_REGEX = re.compile(
        r"""
        (?P<piece>[PNBRQK])
        (?P<from>[a-h][1-8])
        (?P<to>[a-h][1-8])
        (?P<suffix>.*)?
        """,
        re.VERBOSE,
    )
    def _tokenize(self, text: str) -> List[str]:
        out: List[str] = []
        for move in text.strip().split():
            out.extend(self._decompose_move(move))
        return out

    def _decompose_move(self, move: str) -> List[str]:
        m = self.MOVE_REGEX.search(move)
        if not m:
            return [self.UNK_TOKEN]
        tokens = [
            m.group("piece"),
            self.FROM_PREFIX + m.group("from"),
            self.TO_PREFIX + m.group("to"),
        ]
        suffix = m.group("suffix") or ""
        if "(x)" in suffix:
            tokens.append("CAPTURE")
        if "(+*)" in suffix:
            tokens.append("MATE")
        elif "(+)" in suffix:
            tokens.append("CHECK")
        if "(o)" in suffix:
            tokens.append("CASTLE_SHORT")
        if "(O)" in suffix:
            tokens.append("CASTLE_LONG")
        return tokens
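
    # Illustrative decompositions (the exact surface form of moves in the dataset
    # is an assumption here; MOVE_REGEX only requires piece + from + to, and
    # .search() skips any leading side marker such as "W"/"B"):
    #     "WPe2e4"       -> ["P", "FROM_e2", "TO_e4"]
    #     "BNf6e4(x)(+)" -> ["N", "FROM_f6", "TO_e4", "CAPTURE", "CHECK"]
    #     "WKe1g1(o)"    -> ["K", "FROM_e1", "TO_g1", "CASTLE_SHORT"]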
    # ------------------------------------------------------------
    # ID conversion
    # ------------------------------------------------------------
    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)
    # ------------------------------------------------------------
    # Saving
    # ------------------------------------------------------------
    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None,
    ) -> tuple:
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)
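

# ------------------------------------------------------------------
# Quick smoke test: a minimal sketch, not part of the submission API.
# It only uses methods defined above plus the standard PreTrainedTokenizer
# __call__ / convert_ids_to_tokens API. The sample move strings assume the
# "<side><piece><from><to><flags>" surface form implied by MOVE_REGEX and
# _decompose_move, which may differ from the actual dataset format.
# ------------------------------------------------------------------
if __name__ == "__main__":
    tok = ChessTokenizer()
    print("vocab_size:", tok.vocab_size)  # 143 with the default vocabulary

    sample = "WPe2e4 BPe7e5 WNg1f3 BNb8c6(+)"
    print("tokens:", tok.tokenize(sample))

    enc = tok(sample)
    round_trip = tok.convert_tokens_to_string(tok.convert_ids_to_tokens(enc["input_ids"]))
    print("round-trip:", round_trip)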