File size: 4,677 Bytes
e7251ec
 
89ea84b
e7251ec
 
 
 
 
 
89ea84b
e7251ec
 
 
 
 
 
 
89ea84b
 
e7251ec
 
 
 
 
89ea84b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7251ec
 
 
 
 
 
89ea84b
 
 
 
 
 
 
 
 
 
 
 
e7251ec
 
89ea84b
 
 
 
 
 
 
 
 
 
e7251ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89ea84b
 
e7251ec
89ea84b
 
 
 
 
e7251ec
 
89ea84b
e7251ec
 
 
 
 
89ea84b
 
e7251ec
 
 
89ea84b
e7251ec
 
 
89ea84b
e7251ec
 
 
 
 
89ea84b
e7251ec
89ea84b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Custom Chess Tokenizer for the Chess Challenge.
Strategy: Semantic Split (Piece, Square, Suffix)
"""

from __future__ import annotations

import json
import os
import re
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer

class ChessTokenizer(PreTrainedTokenizer):
    """Fixed-vocabulary tokenizer that splits chess moves into semantic parts.

    A move string such as ``"WPe2e4"`` is split into a piece token (``"WP"``),
    square tokens (``"e2"``, ``"e4"``) and optional suffix tokens (event
    markers such as ``"(x)"`` and promotion letters).

    The vocabulary is built purely from the class constants, in this order:
    the four special tokens, ``PIECES``, ``SQUARES`` and ``SUFFIXES``
    (4 + 12 + 64 + 13 = 93 entries), so token ids never depend on a dataset.
    """

    model_input_names = ["input_ids", "attention_mask"]

    # --- FIXED VOCABULARY ---
    # 1. Special tokens
    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    # 2. Pieces: color letter (W/B) followed by role letter (P, N, B, R, Q, K)
    PIECES = [
        "WP", "WN", "WB", "WR", "WQ", "WK",  # White
        "BP", "BN", "BB", "BR", "BQ", "BK",  # Black
    ]

    # 3. Squares a1..h8, enumerated file by file (a1, a2, ..., a8, b1, ...)
    SQUARES = [f"{c}{r}" for c in "abcdefgh" for r in "12345678"]

    # 4. Suffixes (capture, check, mate, castling, promotion).
    # Promotion letters are listed in both lower and upper case so either
    # notation is recognized.
    # NOTE(review): "(o)"/"(O)" presumably mark kingside/queenside castling —
    # confirm against the dataset format.
    SUFFIXES = [
        "(x)", "(+)", "(+*)", "(o)", "(O)",  # Event suffixes
        "q", "r", "b", "n", "Q", "R", "B", "N",  # Promotions
    ]

    def __init__(self, **kwargs):
        """Build the fixed vocabulary and the splitting regex.

        Any ``pad_token``/``bos_token``/``eos_token``/``unk_token`` passed in
        ``kwargs`` is discarded: the special tokens are always the class
        constants. All remaining ``kwargs`` are forwarded to
        ``PreTrainedTokenizer.__init__``.
        """
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Callers (e.g. from_pretrained) may pass their own special tokens;
        # drop them so they cannot override the fixed ones.
        for token in ("pad_token", "bos_token", "eos_token", "unk_token"):
            kwargs.pop(token, None)

        # The vocabulary is created before the base-class initializer runs so
        # that any vocabulary lookups made during that setup already work.
        self.all_tokens = (
            [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
            + self.PIECES
            + self.SQUARES
            + self.SUFFIXES
        )
        self._vocab = {token: idx for idx, token in enumerate(self.all_tokens)}
        self._ids_to_tokens = {idx: token for token, idx in self._vocab.items()}

        # Suffix alternatives are ordered longest first (by their raw text,
        # before escaping) so that e.g. "(+*)" is preferred over "(+)".
        suffix_pattern = "|".join(
            re.escape(s) for s in sorted(self.SUFFIXES, key=len, reverse=True)
        )

        # At each position the piece alternative is tried before the square,
        # and the square before the suffixes: "b1" is read as a square (not
        # promotion "b" + "1") and "BP" as a piece (not promotion "B" + "P").
        self.token_pattern = re.compile(
            r'([WB][PNBRQK])|([a-h][1-8])|(' + suffix_pattern + r')'
        )

        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the fixed base vocabulary (excluding runtime-added tokens)."""
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return a fresh token -> id mapping, including runtime-added tokens.

        Tokens registered through ``add_tokens``/``add_special_tokens`` are
        exposed by the base class as ``added_tokens_encoder``. Merging them
        here follows the convention of transformers' own slow tokenizers
        (e.g. ``BertTokenizer``). Recent transformers versions compute
        ``len(tokenizer)`` and the ids of newly added tokens from this
        mapping, so leaving them out could give later additions colliding ids.
        """
        vocab = dict(self._vocab)
        # The base class may call get_vocab() before its added-token storage
        # exists; getattr's default covers that window.
        vocab.update(getattr(self, "added_tokens_encoder", None) or {})
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        """Split a game string into piece, square and suffix tokens.

        Example: ``"WPe2e4"`` -> ``["WP", "e2", "e4"]``.

        Characters that match no alternative of ``token_pattern`` (for example
        whitespace between moves) are skipped; they are not mapped to
        ``[UNK]``.
        """
        # Every alternative is one whole capture group, so the full match text
        # is exactly the single non-empty group.
        return [match.group(0) for match in self.token_pattern.finditer(text)]

    def _convert_token_to_id(self, token: str) -> int:
        """Map a token to its id, falling back to the ``[UNK]`` id."""
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token, falling back to ``[UNK]``."""
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces, dropping the four special tokens.

        Note that ``[UNK]`` is dropped as well, so unknown tokens disappear
        from the decoded string. Tokens are not regrouped into whole moves.
        """
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """Write the fixed base vocabulary as JSON and return its path.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Optional prefix, joined to ``"vocab.json"`` with
                a hyphen.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        # exist_ok=True tolerates an existing directory but still raises if a
        # regular file occupies the path.
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    # --- Static/Class Methods Override ---

    @classmethod
    def build_vocab_from_dataset(cls, *args, **kwargs) -> "ChessTokenizer":
        """Return a freshly constructed tokenizer; all arguments are ignored.

        The vocabulary is fully determined by the class constants, so no
        dataset scan is performed.
        """
        print("Using fixed vocabulary (Pieces + Squares + Suffixes). No dataset scan needed.")
        return cls()