chess-gpt-char-level-v7 / tokenizer.py
Vadim38's picture
Update tokenizer.py
93009e7 verified
from __future__ import annotations
import torch
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Union
from transformers import PreTrainedTokenizer
"""
Custom Chess Tokenizer (Character Level)
Compatible with HF Trainer & Evaluator (BatchEncoding support)
"""
class BatchEncoding(dict):
    """Dictionary of model inputs that can be moved to a device in one call.

    Wraps the tokenizer's output so that callers can write
    ``inputs.to(device)``, as they would with a Hugging Face BatchEncoding.
    """

    def to(self, device):
        """Return a new BatchEncoding with every movable value sent to ``device``.

        Values that expose a ``to`` method (e.g. torch tensors) are replaced by
        ``value.to(device)``; every other value is carried over unchanged.
        Key order is preserved and ``self`` itself is not modified.
        """
        return BatchEncoding(
            (key, value.to(device) if hasattr(value, "to") else value)
            for key, value in self.items()
        )
class ChessTokenizer:
    """Character-level chess tokenizer with a minimal Hugging Face-like API.

    Each character of the input string is one token. The vocabulary is static:
    ids 0-26 are the characters of ``"abcdefgh12345678PRNBQKxoO-="`` and ids
    27-29 are the special tokens ``<pad>``, ``<s>`` and ``</s>``. Characters
    outside the vocabulary are encoded as the ``<pad>`` id, which also plays the
    role of the unknown token.

    Only the subset of the ``transformers`` tokenizer interface used by the
    training / evaluation code is implemented (``__call__``, ``encode``,
    ``decode``, ``batch_decode``, ``save_pretrained``, ``from_pretrained``...).
    """

    def __init__(self):
        # Static vocabulary: list order defines the ids.
        self.chars = list("abcdefgh12345678PRNBQKxoO-=") + ["<pad>", "<s>", "</s>"]
        self.vocab = {ch: i for i, ch in enumerate(self.chars)}
        self.id_to_char = {i: ch for i, ch in enumerate(self.chars)}
        # Special tokens; "<pad>" doubles as the unknown token.
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<pad>"
        self.pad_token_id = self.vocab["<pad>"]
        self.bos_token_id = self.vocab["<s>"]
        self.eos_token_id = self.vocab["</s>"]
        self.unk_token_id = self.pad_token_id
        self.vocab_size = len(self.vocab)
        # Default target length when a call gives truncation/padding but no max_length.
        self.model_max_length = 1024
        # "right" (default) or "left"; read by __call__ when padding.
        self.padding_side = "right"

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs):
        """Return a fresh tokenizer; the vocabulary is static, so all arguments are ignored."""
        return cls()

    def __len__(self):
        """Return the vocabulary size (HF code often uses ``len(tokenizer)``)."""
        return len(self.vocab)

    def get_vocab(self):
        """Return a copy of the token -> id mapping."""
        return dict(self.vocab)

    def encode(self, text, **kwargs):
        """Return the token ids for ``text``, one id per character.

        Unknown characters map to the ``<pad>`` id. No BOS/EOS tokens are added;
        extra keyword arguments (e.g. ``add_special_tokens``) are accepted for
        HF compatibility and ignored.
        """
        fallback = self.pad_token_id
        return [self.vocab.get(ch, fallback) for ch in text]

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Convert token ids back into a string.

        Args:
            token_ids: a single id, a sequence of ids, or a torch tensor of ids.
            skip_special_tokens: accepted for HF compatibility only; the special
                tokens ``<pad>``, ``<s>`` and ``</s>`` are always removed.
            **kwargs: other HF decoding options, accepted and ignored.

        Returns:
            The concatenated characters; ids outside the vocabulary are dropped.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        special = {self.pad_token_id, self.bos_token_id, self.eos_token_id}
        return "".join(
            self.id_to_char.get(i, "") for i in token_ids if i not in special
        )

    def batch_decode(self, sequences, skip_special_tokens=False, **kwargs):
        """Decode each sequence of ids (list of lists or 2-D tensor) with ``decode``."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        return [
            self.decode(seq, skip_special_tokens=skip_special_tokens, **kwargs)
            for seq in sequences
        ]

    def __call__(self, text, max_length=None, padding=False, truncation=False, return_tensors=None, **kwargs):
        """Tokenize a single string, Hugging Face style.

        Args:
            text: the string to tokenize (one character per token).
            max_length: target length for truncation and ``"max_length"``
                padding. When omitted, ``self.model_max_length`` is used, as in HF.
            padding: only ``"max_length"`` pads; any other value leaves the
                sequence unpadded. Padding goes on ``self.padding_side``.
            truncation: any truthy value except ``"do_not_truncate"`` cuts the
                sequence to the target length.
            return_tensors: ``"pt"`` returns a BatchEncoding holding
                ``(1, seq_len)`` long tensors; otherwise plain lists are returned.
            **kwargs: other HF options, accepted and ignored.

        Returns:
            A mapping with ``"input_ids"`` and ``"attention_mask"``
            (1 for real tokens, 0 for padding).
        """
        ids = self.encode(text)
        target_len = self.model_max_length if max_length is None else max_length

        # "do_not_truncate" is a truthy string, so it must be excluded explicitly.
        if truncation and truncation != "do_not_truncate":
            ids = ids[:target_len]

        attention_mask = [1] * len(ids)
        pad_len = target_len - len(ids) if padding == "max_length" else 0
        if pad_len > 0:
            pad_ids = [self.pad_token_id] * pad_len
            pad_mask = [0] * pad_len
            if self.padding_side == "left":
                ids, attention_mask = pad_ids + ids, pad_mask + attention_mask
            else:
                ids, attention_mask = ids + pad_ids, attention_mask + pad_mask

        if return_tensors == "pt":
            # Leading batch dimension of 1 so callers can .squeeze(0).
            return BatchEncoding({
                "input_ids": torch.tensor([ids], dtype=torch.long),
                "attention_mask": torch.tensor([attention_mask], dtype=torch.long),
            })
        return {
            "input_ids": ids,
            "attention_mask": attention_mask,
        }

    def save_pretrained(self, save_directory, **kwargs):
        """Write the vocabulary to ``<save_directory>/vocab.json``.

        The directory is created if needed. Saving is best effort: an OS error
        produces a warning instead of an exception, since from_pretrained()
        never reads this file (it always rebuilds the static vocabulary).

        Returns:
            ``(vocab_path,)`` on success, ``()`` if the file could not be written.
        """
        import warnings

        vocab_path = os.path.join(save_directory, "vocab.json")
        try:
            if save_directory:
                os.makedirs(save_directory, exist_ok=True)
            with open(vocab_path, "w", encoding="utf-8") as f:
                json.dump(self.vocab, f)
        except OSError as err:
            warnings.warn(f"Could not save tokenizer vocabulary to {vocab_path!r}: {err}")
            return ()
        return (vocab_path,)

    @classmethod
    def from_pretrained(cls, save_directory, *args, **kwargs):
        """Return a fresh tokenizer; the static vocabulary needs nothing from disk."""
        return cls()

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """
        No-op hook expected by the server-side evaluation script.
        Per the original author's note, AutoTokenizer fails and the vocabulary
        does not load without it.
        """
        pass
'''"""
Custom Chess Tokenizer (Character Level) - Fully Compatible with HF Trainer
"""
class ChessTokenizer:
def __init__(self):
# Vocabulaire statique
self.chars = list("abcdefgh12345678PRNBQKxoO-=") + ["<pad>", "<s>", "</s>"]
self.vocab = {ch: i for i, ch in enumerate(self.chars)}
self.id_to_char = {i: ch for i, ch in enumerate(self.chars)}
# Attributs spéciaux (version texte)
self.pad_token = "<pad>"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.unk_token = "<pad>"
# Attributs spéciaux (version ID)
self.pad_token_id = self.vocab["<pad>"]
self.bos_token_id = self.vocab["<s>"]
self.eos_token_id = self.vocab["</s>"]
self.vocab_size = len(self.vocab)
# Config par défaut
self.model_max_length = 1024
self.padding_side = "right"
@classmethod
def build_vocab_from_dataset(cls, *args, **kwargs):
print("⚡ Utilisation du Tokenizer 'Char-Level' (Vocabulaire statique) ⚡")
return cls()
def encode(self, text):
return [self.vocab.get(c, self.pad_token_id) for c in text]
def decode(self, token_ids):
if isinstance(token_ids, torch.Tensor):
token_ids = token_ids.tolist()
if isinstance(token_ids, int):
token_ids = [token_ids]
tokens = [self.id_to_char.get(i, "") for i in token_ids]
return "".join(tokens).replace("<pad>", "").replace("<s>", "").replace("</s>", "")
def __call__(self, text, max_length=None, padding=False, truncation=False, return_tensors=None, **kwargs):
"""
Cette méthode est le coeur du problème. Elle imite le comportement
d'un tokenizer Hugging Face standard (Padding, Truncation, Tensors).
"""
# 1. Encodage brut
ids = self.encode(text)
# 2. Truncation (Couper si trop long)
if truncation and max_length is not None:
ids = ids[:max_length]
# 3. Padding (Remplir si trop court)
# On calcule le masque d'attention en même temps (1 pour les vrais tokens, 0 pour le padding)
attention_mask = [1] * len(ids)
if padding == "max_length" and max_length is not None:
if len(ids) < max_length:
pad_len = max_length - len(ids)
ids = ids + [self.pad_token_id] * pad_len
attention_mask = attention_mask + [0] * pad_len
# 4. Conversion en Tenseurs PyTorch
if return_tensors == "pt":
# data.py s'attend à une dimension de batch [1, seq_len] pour pouvoir faire .squeeze(0)
return {
"input_ids": torch.tensor([ids], dtype=torch.long),
"attention_mask": torch.tensor([attention_mask], dtype=torch.long)
}
# Fallback (liste simple)
return {
"input_ids": ids,
"attention_mask": attention_mask
}
def save_pretrained(self, save_directory):
pass
@classmethod
def from_pretrained(cls, save_directory):
return cls()
'''
"""
Custom Chess Tokenizer for the Chess Challenge.
This tokenizer treats each move as a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6).
The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
"""
'''class ChessTokenizer(PreTrainedTokenizer):
"""
A custom tokenizer for chess moves using extended UCI notation.
This tokenizer maps each possible chess move to a unique token ID.
The vocabulary is built from the training dataset to ensure all moves
encountered during training have a corresponding token.
Example:
>>> tokenizer = ChessTokenizer()
>>> tokenizer.encode("WPe2e4 BPe7e5")
[1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
"""
model_input_names = ["input_ids", "attention_mask"]
vocab_files_names = {"vocab_file": "vocab.json"}
# Special tokens
PAD_TOKEN = "[PAD]"
BOS_TOKEN = "[BOS]"
EOS_TOKEN = "[EOS]"
UNK_TOKEN = "[UNK]"
def __init__(
self,
vocab_file: Optional[str] = None,
vocab: Optional[Dict[str, int]] = None,
**kwargs,
):
"""
Initialize the chess tokenizer.
Args:
vocab_file: Path to a JSON file containing the vocabulary mapping.
vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
**kwargs: Additional arguments passed to PreTrainedTokenizer.
"""
# Initialize special tokens
self._pad_token = self.PAD_TOKEN
self._bos_token = self.BOS_TOKEN
self._eos_token = self.EOS_TOKEN
self._unk_token = self.UNK_TOKEN
# Remove any duplicate special-token entries passed through kwargs
# to avoid "multiple values for keyword" errors when loading from disk.
kwargs.pop("pad_token", None)
kwargs.pop("bos_token", None)
kwargs.pop("eos_token", None)
kwargs.pop("unk_token", None)
# Load or create vocabulary
if vocab is not None:
self._vocab = vocab
elif vocab_file is not None and os.path.exists(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as f:
self._vocab = json.load(f)
else:
# Create a minimal vocabulary with just special tokens
# The full vocabulary should be built from the dataset
self._vocab = self._create_default_vocab()
# Create reverse mapping
self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
# Call parent init AFTER setting up vocab
super().__init__(
pad_token=self._pad_token,
bos_token=self._bos_token,
eos_token=self._eos_token,
unk_token=self._unk_token,
**kwargs,
)
def _create_default_vocab(self) -> Dict[str, int]:
"""
Create a minimal default vocabulary with just special tokens.
For the full vocabulary, use `build_vocab_from_dataset()`.
This minimal vocab is just a placeholder - you should build from data.
"""
special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
vocab = {token: idx for idx, token in enumerate(special_tokens)}
return vocab
@classmethod
def build_vocab_from_iterator(
cls,
iterator,
min_frequency: int = 1,
) -> "ChessTokenizer":
"""
Build a tokenizer vocabulary from an iterator of game strings.
Args:
iterator: An iterator yielding game strings (space-separated moves).
min_frequency: Minimum frequency for a token to be included.
Returns:
A ChessTokenizer with the built vocabulary.
"""
from collections import Counter
token_counts = Counter()
for game in iterator:
moves = game.strip().split()
token_counts.update(moves)
# Filter by frequency
tokens = [
token for token, count in token_counts.items()
if count >= min_frequency
]
# Sort for reproducibility
tokens = sorted(tokens)
# Build vocabulary
special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
return cls(vocab=vocab)
@classmethod
def build_vocab_from_dataset(
cls,
dataset_name: str = "dlouapre/lichess_2025-01_1M",
split: str = "train",
column: str = "text",
min_frequency: int = 500,
max_samples: Optional[int] = 100000,
) -> "ChessTokenizer":
"""
Build a tokenizer vocabulary from a Hugging Face dataset.
Args:
dataset_name: Name of the dataset on Hugging Face Hub.
split: Dataset split to use.
column: Column containing the game strings.
min_frequency: Minimum frequency for a token to be included (default: 500).
max_samples: Maximum number of samples to process (default: 100k).
Returns:
A ChessTokenizer with the built vocabulary.
"""
from datasets import load_dataset
dataset = load_dataset(dataset_name, split=split)
if max_samples is not None:
dataset = dataset.select(range(min(max_samples, len(dataset))))
def game_iterator():
for example in dataset:
yield example[column]
return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
@property
def vocab_size(self) -> int:
"""Return the size of the vocabulary."""
return len(self._vocab)
def get_vocab(self) -> Dict[str, int]:
"""Return the vocabulary as a dictionary."""
return dict(self._vocab)
def _tokenize(self, text: str) -> List[str]:
"""
Tokenize a string of moves into a list of tokens.
Args:
text: A string of space-separated moves.
Returns:
List of move tokens.
"""
return text.strip().split()
def _convert_token_to_id(self, token: str) -> int:
"""Convert a token to its ID."""
return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
def _convert_id_to_token(self, index: int) -> str:
"""Convert an ID to its token."""
return self._ids_to_tokens.get(index, self.UNK_TOKEN)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Convert a list of tokens back to a string."""
# Filter out special tokens for cleaner output
special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
return " ".join(t for t in tokens if t not in special)
def save_vocabulary(
self,
save_directory: str,
filename_prefix: Optional[str] = None,
) -> tuple:
"""
Save the vocabulary to a JSON file.
Args:
save_directory: Directory to save the vocabulary.
filename_prefix: Optional prefix for the filename.
Returns:
Tuple containing the path to the saved vocabulary file.
"""
if not os.path.isdir(save_directory):
os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json",
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)'''
def count_vocab_from_dataset(
    dataset_name: str = "dlouapre/lichess_2025-01_1M",
    split: str = "train",
    column: str = "text",
    max_samples: Optional[int] = 10000,
) -> Dict[str, int]:
    """
    Count how often each whitespace-separated move appears in a dataset.

    Note that the "tokens" counted here are whole moves (each game string is
    split on whitespace), not the single characters used by the active
    character-level ChessTokenizer.

    Args:
        dataset_name: Name of the dataset on Hugging Face Hub.
        split: Dataset split to use.
        column: Column containing the game strings.
        max_samples: Maximum number of samples to process (None = whole split).
    Returns:
        Dictionary mapping each move to its frequency, in first-seen order.
    """
    from collections import Counter
    from datasets import load_dataset

    ds = load_dataset(dataset_name, split=split)
    if max_samples is not None:
        ds = ds.select(range(min(max_samples, len(ds))))

    move_counts = Counter(
        move
        for row in ds
        for move in row[column].strip().split()
    )
    return dict(move_counts)