Upload tokenizer.py with huggingface_hub
tokenizer.py +130 -0
tokenizer.py
ADDED
@@ -0,0 +1,130 @@
"""
Custom Chess Tokenizer for the Chess Challenge.

This tokenizer treats each move as a single token using the extended UCI notation
from the Lichess dataset (e.g., WPe2e4, BNg8f6).

The dataset format uses:
- W/B prefix for White/Black
- Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
- Source and destination squares (e.g., e2e4)
- Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
"""

from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer

# Custom Chess Tokenizer - Normalized Version

# Regexes to extract the source square, the destination square and the promotion piece
MOVE_RE = re.compile(r"([a-h][1-8])([a-h][1-8])")
PROMO_RE = re.compile(r"=([NBRQ])")

def normalize_move(tok: str) -> str:
    """Turn 'WPe2e4(x)' into 'WPe2e4' to shrink the vocabulary."""
    # 1. Keep the core move information
    m = MOVE_RE.search(tok)
    if not m:
        return tok  # Fallback (will most likely map to UNK)

    fr, to = m.group(1), m.group(2)

    # 2. Handle promotion
    promo = ""
    pm = PROMO_RE.search(tok)
    if pm:
        promo = "=" + pm.group(1)

    # 3. Rebuild the standardized token.
    # Keep the WP/BN prefix (chars 0 and 1) to preserve the color/piece info,
    # but drop the (x), (+), etc. suffixes.
    prefix = tok[:2] if len(tok) >= 2 else "WP"
    return f"{prefix}{fr}{to}{promo}"

class ChessTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask"]

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(self, vocab_file=None, vocab=None, **kwargs):
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Drop any special-token kwargs: the class constants above always win
        for t in ["pad_token", "bos_token", "eos_token", "unk_token"]:
            kwargs.pop(t, None)

        if vocab:
            self._vocab = vocab
        elif vocab_file:
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = {t: i for i, t in enumerate([self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])}

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        super().__init__(pad_token=self.PAD_TOKEN, bos_token=self.BOS_TOKEN, eos_token=self.EOS_TOKEN, unk_token=self.UNK_TOKEN, **kwargs)

    @property
    def vocab_size(self):
        return len(self._vocab)

    def get_vocab(self):
        return dict(self._vocab)

    def _tokenize(self, text):
        # This is where the magic happens: moves are normalized on the fly
        return [normalize_move(t) for t in text.strip().split()]

    def _convert_token_to_id(self, token):
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index):
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens):
        return " ".join(t for t in tokens if t not in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])

    def save_vocabulary(self, save_directory, filename_prefix=None):
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)
        path = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json")
        with open(path, "w") as f:
            json.dump(self._vocab, f, indent=2)
        return (path,)

    @classmethod
    def build_vocab_from_dataset(cls, dataset_name, min_frequency=2, max_vocab_size=1200, **kwargs):
        """Build a compact, dense vocabulary from the dataset."""
        from collections import Counter

        from datasets import load_dataset

        # Stream the dataset to keep this fast
        ds = load_dataset(dataset_name, split="train", streaming=True)
        ds = ds.take(50000)  # 50k games are enough to see every possible move

        counter = Counter()
        for ex in ds:
            # Normalize before counting!
            moves = [normalize_move(t) for t in ex["text"].split()]
            counter.update(moves)

        # Keep the special tokens plus the N most frequent moves
        special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        most_common = counter.most_common(max_vocab_size - len(special))

        vocab = {t: i for i, t in enumerate(special + [t for t, c in most_common])}
        return cls(vocab=vocab)
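
For reference, a minimal usage sketch, not part of the uploaded file: it assumes tokenizer.py is importable from the working directory and a recent transformers install, shows how normalize_move strips annotation suffixes while keeping promotions, and round-trips a short move sequence through a ChessTokenizer built from a tiny hand-written vocabulary, so it runs offline without the dataset that build_vocab_from_dataset expects. The moves and the vocabulary below are made up for the demo.

    # Illustrative only: tiny vocabulary and move strings are invented for the demo
    from tokenizer import ChessTokenizer, normalize_move

    # Annotation suffixes are stripped, promotions are kept
    print(normalize_move("WPe4d5(x)"))    # -> WPe4d5
    print(normalize_move("BPe2e1=Q(+)"))  # -> BPe2e1=Q

    # Hand-built vocabulary: special tokens first, then a few normalized moves
    special = [ChessTokenizer.PAD_TOKEN, ChessTokenizer.BOS_TOKEN,
               ChessTokenizer.EOS_TOKEN, ChessTokenizer.UNK_TOKEN]
    moves = ["WPe2e4", "BPe7e5", "WNg1f3", "BNb8c6"]
    tok = ChessTokenizer(vocab={t: i for i, t in enumerate(special + moves)})

    enc = tok("WPe2e4 BPe7e5 WNg1f3 BNb8c6")
    print(enc["input_ids"])  # -> [4, 5, 6, 7] (one id per move; no specials added by default)
    print(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(enc["input_ids"])))
    # -> "WPe2e4 BPe7e5 WNg1f3 BNb8c6"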