"""
Character-level Chess Tokenizer (Robust Version).
Fixes the [BOS] splitting issue and HuggingFace from_pretrained crashes.
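
Example (illustrative; assumes the default constructor below):

    >>> tok = ChessTokenizer()
    >>> tok._tokenize("[BOS]e2e4")
    ['[BOS]', 'e', '2', 'e', '4']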
"""

from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


class ChessTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(self, vocab_file=None, **kwargs):
        # vocab_file is accepted (and ignored) so from_pretrained() can pass it.
        # --- Define the special tokens ---
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # --- UCI alphabet + annotations ---
        # Note: 'B' appears twice in this string, so the second occurrence
        # overwrites the first in the vocab and one ID is left unused.
        chars = "abcdefgh12345678PNBRQKWBx+#=-O()"

        # --- Static vocabulary ---
        # Must be built BEFORE super().__init__(): the HuggingFace base
        # class calls get_vocab() while registering the special tokens.
        self._vocab = {
            self.PAD_TOKEN: 0,
            self.BOS_TOKEN: 1,
            self.EOS_TOKEN: 2,
            self.UNK_TOKEN: 3,
            " ": 4,
        }
        for i, char in enumerate(chars):
            self._vocab[char] = i + 5
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # --- Critical HF fix ---
        # from_pretrained() already passes these values through kwargs,
        # so we must NOT overwrite them when they are already set.
        kwargs.setdefault("pad_token", self.PAD_TOKEN)
        kwargs.setdefault("bos_token", self.BOS_TOKEN)
        kwargs.setdefault("eos_token", self.EOS_TOKEN)
        kwargs.setdefault("unk_token", self.UNK_TOKEN)
        super().__init__(**kwargs)

    # ------------------------------------------------------------------
    # Required HuggingFace properties
    # ------------------------------------------------------------------
    @property
    def vocab_size(self) -> int:
        # Deliberate hack: report a padded size (the actual vocab has only
        # 36 entries) so that an out-of-range ID produced by the model does
        # not trigger a CUDA-side indexing crash in the embedding layer.
        return 128

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    # ------------------------------------------------------------------
    # Tokenization
    # ------------------------------------------------------------------
    def _tokenize(self, text: str) -> List[str]:
        """
        Robust splitting that never breaks a special token into characters.
        """
        if text in [
            self.BOS_TOKEN,
            self.EOS_TOKEN,
            self.PAD_TOKEN,
            self.UNK_TOKEN,
        ]:
            return [text]
        # Match whole special tokens first, then fall back to single
        # characters. re.split() with a capturing group returns the matched
        # separators; the empty strings in between are filtered out.
        pattern = r"(\[PAD\]|\[BOS\]|\[EOS\]|\[UNK\]|.)"
        tokens = [t for t in re.split(pattern, text) if t]
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab[self.UNK_TOKEN])

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Special tokens are dropped when reassembling the move string.
        return "".join(
            t
            for t in tokens
            if t
            not in [
                self.PAD_TOKEN,
                self.BOS_TOKEN,
                self.EOS_TOKEN,
                self.UNK_TOKEN,
            ]
        )

    # ------------------------------------------------------------------
    # Utility methods
    # ------------------------------------------------------------------
    @classmethod
    def build_vocab_from_dataset(cls, **kwargs):
        print("Using static character-level vocab (no build needed).")
        return cls()

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> tuple:
        os.makedirs(save_directory, exist_ok=True)
        # Honor filename_prefix, as the HuggingFace API expects.
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_path = os.path.join(save_directory, prefix + "vocab.json")
        with open(vocab_path, "w") as f:
            json.dump(self._vocab, f)
        return (vocab_path,)
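

# ----------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the library
# API): round-trips a move string and exercises the save / from_pretrained
# path this version is meant to fix. The temporary directory is arbitrary.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    tok = ChessTokenizer()
    ids = tok.encode("e2e4 e7e5")
    print(ids)
    print(tok.decode(ids))  # -> "e2e4 e7e5"

    with tempfile.TemporaryDirectory() as tmpdir:
        tok.save_pretrained(tmpdir)
        reloaded = ChessTokenizer.from_pretrained(tmpdir)
        assert reloaded.encode("e2e4") == tok.encode("e2e4")
        print("from_pretrained round-trip OK")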