| | """UCI Tokenizer for ChessGPT, compatible with HuggingFace transformers.""" |
| |
|
| | import json |
| | import os |
| | from typing import Dict, List, Optional, Tuple |
| |
|
| | from transformers import PreTrainedTokenizer |
| |
|
| |
|
class UCITokenizer(PreTrainedTokenizer):
    """Maps UCI move strings to integer token IDs and back.

    Vocab layout (deterministic, see ``_build_vocab``):
        - <PAD>=0, <BOS>=1, <EOS>=2
        - All src->dst moves between distinct squares: 64 * 63 = 4032
        - Promotions: 22 (file, adjacent-file) pairs * 4 pieces * 2 colors = 176
        - Total: 3 + 4032 + 176 = 4211 tokens
          (<UNK> aliases <PAD>, so it occupies no extra slot)
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    # Fixed IDs of the special tokens; must match _build_vocab().
    PAD_ID = 0
    BOS_ID = 1
    EOS_ID = 2

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        pad_token: str = "<PAD>",
        unk_token: str = "<PAD>",  # unknown moves collapse onto <PAD>
        **kwargs,
    ):
        """Load the vocab from ``vocab_file`` if it exists, else build it.

        Args:
            vocab_file: Optional path to a ``vocab.json`` produced by
                :meth:`save_vocabulary`. Ignored when missing.
            bos_token / eos_token / pad_token / unk_token: Special-token
                strings forwarded to ``PreTrainedTokenizer``.
            **kwargs: Extra arguments forwarded to the base class.
        """
        # The vocab must exist BEFORE super().__init__(), because the base
        # class may call get_vocab()/vocab_size during its own initialization.
        if vocab_file is not None and os.path.isfile(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self.encoder: Dict[str, int] = json.load(f)
        else:
            self.encoder = self._build_vocab()

        # Reverse mapping for decoding; safe because the vocab is injective.
        self.decoder: Dict[int, str] = {v: k for k, v in self.encoder.items()}

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            **kwargs,
        )

    @staticmethod
    def _build_vocab() -> Dict[str, int]:
        """Build the UCI move vocabulary deterministically.

        Returns:
            Mapping of token string -> ID with 4211 entries: the three
            special tokens, every src->dst square pair, and all pawn
            promotions (straight and capture) for both colors.
        """
        vocab: Dict[str, int] = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2}
        idx = 3

        # a1, a2, ..., h8 in file-major order.
        squares = [f + r for f in "abcdefgh" for r in "12345678"]

        # Every ordered pair of distinct squares (64 * 63 = 4032 moves).
        for src in squares:
            for dst in squares:
                if src != dst:
                    vocab[src + dst] = idx
                    idx += 1

        # Promotions: a pawn on the 7th (white) or 2nd (black) rank moving
        # straight or capturing one file left/right, promoting to q/r/b/n.
        for f_idx, f in enumerate("abcdefgh"):
            for df in (-1, 0, 1):
                nf_idx = f_idx + df
                if 0 <= nf_idx < 8:
                    nf = "abcdefgh"[nf_idx]
                    for promo in "qrbn":
                        # White promotion, e.g. "e7e8q".
                        vocab[f + "7" + nf + "8" + promo] = idx
                        idx += 1
                        # Black promotion, e.g. "e2e1q".
                        vocab[f + "2" + nf + "1" + promo] = idx
                        idx += 1

        return vocab

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the vocabulary (4211 for the built-in vocab)."""
        return len(self.encoder)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> ID mapping."""
        return dict(self.encoder)

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split UCI move string on whitespace. Each move is one token."""
        return text.strip().split()

    def _convert_token_to_id(self, token: str) -> int:
        """Map a move string to its ID; unknown moves fall back to <PAD>."""
        return self.encoder.get(token, self.PAD_ID)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an ID back to its move string; out-of-range IDs yield <PAD>."""
        return self.decoder.get(index, "<PAD>")

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Add BOS at the start and EOS after each segment.

        Single sequence: ``<BOS> A <EOS>``.
        Pair of sequences: ``<BOS> A <EOS> B <EOS>``.

        BUGFIX: the previous version silently discarded ``token_ids_1``.
        """
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is not None:
            output = output + token_ids_1 + [self.eos_token_id]
        return output

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Write the vocab to ``<prefix->vocab.json`` in ``save_directory``.

        Raises:
            ValueError: If ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(f"save_directory ({save_directory}) is not a directory")
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.encoder, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    @property
    def move_to_id(self) -> Dict[str, int]:
        """Direct (non-copying) access to the move -> ID mapping."""
        return self.encoder

    @property
    def id_to_move(self) -> Dict[int, str]:
        """Direct (non-copying) access to the ID -> move mapping."""
        return self.decoder
|