""" Character-level Chess Tokenizer (Robust Version). Fixes the [BOS] splitting issue and HuggingFace from_pretrained crashes. """ from __future__ import annotations import json import os import re from typing import Dict, List, Optional from transformers import PreTrainedTokenizer class ChessTokenizer(PreTrainedTokenizer): model_input_names = ["input_ids", "attention_mask"] PAD_TOKEN = "[PAD]" BOS_TOKEN = "[BOS]" EOS_TOKEN = "[EOS]" UNK_TOKEN = "[UNK]" def __init__(self, vocab_file=None, **kwargs): # --- Définition des tokens spéciaux --- self._pad_token = self.PAD_TOKEN self._bos_token = self.BOS_TOKEN self._eos_token = self.EOS_TOKEN self._unk_token = self.UNK_TOKEN # --- Alphabet UCI + annotations --- chars = "abcdefgh12345678PNBRQKWBx+#=-O()" # --- Vocabulaire statique --- self._vocab = { self.PAD_TOKEN: 0, self.BOS_TOKEN: 1, self.EOS_TOKEN: 2, self.UNK_TOKEN: 3, " ": 4, } for i, char in enumerate(chars): self._vocab[char] = i + 5 self._ids_to_tokens = {v: k for k, v in self._vocab.items()} # --- FIX CRITIQUE HF --- # from_pretrained passe déjà ces valeurs via kwargs # donc on ne les écrase PAS si elles existent kwargs.setdefault("pad_token", self.PAD_TOKEN) kwargs.setdefault("bos_token", self.BOS_TOKEN) kwargs.setdefault("eos_token", self.EOS_TOKEN) kwargs.setdefault("unk_token", self.UNK_TOKEN) super().__init__(**kwargs) # ------------------------------------------------------------------ # Propriétés obligatoires HuggingFace # ------------------------------------------------------------------ @property def vocab_size(self) -> int: # Hack volontaire : évite les crashs CUDA si un ID dépasse return 128 def get_vocab(self) -> Dict[str, int]: return dict(self._vocab) # ------------------------------------------------------------------ # Tokenisation # ------------------------------------------------------------------ def _tokenize(self, text: str) -> List[str]: """ Découpe robuste qui ne casse jamais les tokens spéciaux. """ if text in [ self.BOS_TOKEN, self.EOS_TOKEN, self.PAD_TOKEN, self.UNK_TOKEN, ]: return [text] pattern = r"(\[PAD\]|\[BOS\]|\[EOS\]|\[UNK\]|.)" tokens = [t for t in re.split(pattern, text) if t] return tokens def _convert_token_to_id(self, token: str) -> int: return self._vocab.get(token, self._vocab[self.UNK_TOKEN]) def _convert_id_to_token(self, index: int) -> str: return self._ids_to_tokens.get(index, self.UNK_TOKEN) def convert_tokens_to_string(self, tokens: List[str]) -> str: return "".join( t for t in tokens if t not in [ self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN, ] ) # ------------------------------------------------------------------ # Méthodes utilitaires # ------------------------------------------------------------------ @classmethod def build_vocab_from_dataset(cls, **kwargs): print("Using static character-level vocab (no build needed).") return cls() def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None ) -> tuple: os.makedirs(save_directory, exist_ok=True) vocab_path = os.path.join(save_directory, "vocab.json") with open(vocab_path, "w") as f: json.dump(self._vocab, f) return (vocab_path,)