"""
Helper utilities for UncheatableEval visualization.

Contains TokenizerBytesConverter for mapping tokens to bytes.
"""

import json
import re
from typing import Dict, List, Optional, Tuple


def bytes_to_unicode() -> Dict[int, str]:
    """
    GPT-2 style byte-to-unicode mapping.
    Maps byte values 0-255 to printable Unicode characters.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(c) for c in cs]
    return dict(zip(bs, cs))
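
# Illustrative sanity check of the mapping above (comments only, not executed):
# printable ASCII maps to itself, while non-printable bytes are shifted to
# higher code points. For instance, the space byte 0x20 becomes "Ġ" (U+0120):
#
#     bytes_to_unicode()[ord("A")]  # -> 'A'
#     bytes_to_unicode()[0x20]      # -> 'Ġ'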


class TokenizerBytesConverter:
    """
    Universal Token-to-Bytes Converter for HuggingFace tokenizers.

    Supports two encoding schemes:
    1. ByteLevel BPE (Llama 3.x, Qwen, GPT-2 style)
    2. SentencePiece with ByteFallback (Mistral, early LLaMA)

    Usage:
        converter = TokenizerBytesConverter("meta-llama/Llama-3.2-1B")
        nested_bytes = converter.encode_to_bytes("Hello world")
        # Returns: [[72, 101, 108, 108, 111], [32, 119, 111, 114, 108, 100]]
    """

    # Class-level mapping table cache
    _BYTE_TO_UNICODE = bytes_to_unicode()
    _UNICODE_TO_BYTE = {v: k for k, v in _BYTE_TO_UNICODE.items()}

    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = True,
        tokenizer=None,
    ):
        """
        Initialize the converter.

        Args:
            model_name_or_path: HuggingFace model name or local path
            cache_dir: Directory to cache the downloaded tokenizer files
            trust_remote_code: Whether to trust remote code for custom tokenizers
            tokenizer: Optional pre-loaded tokenizer instance for encoding.
                       If provided, this tokenizer will be used for encode() calls,
                       while AutoTokenizer is still used to extract vocab/decoder config.
        """
        from transformers import AutoTokenizer

        # Always load AutoTokenizer for vocab extraction
        auto_tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            cache_dir=cache_dir,
            trust_remote_code=trust_remote_code,
        )

        # Use provided tokenizer for encoding, or fall back to auto_tokenizer
        self._tokenizer = tokenizer if tokenizer is not None else auto_tokenizer

        # Extract tokenizer.json from the AutoTokenizer's backend
        if hasattr(auto_tokenizer, "backend_tokenizer") and hasattr(auto_tokenizer.backend_tokenizer, "to_str"):
            tokenizer_json = json.loads(auto_tokenizer.backend_tokenizer.to_str())
        else:
            raise ValueError("Unsupported tokenizer: it must expose a backend_tokenizer with a to_str() method.")

        self._tokenizer_json = tokenizer_json
        self._vocab = tokenizer_json["model"]["vocab"]
        self._id_to_token: Dict[int, str] = {v: k for k, v in self._vocab.items()}

        # Detect encoding type
        self._decoder_type = self._detect_decoder_type()

        # Load added_tokens
        self._load_added_tokens()

    def _detect_decoder_type(self) -> str:
        """Detect the decoder type from tokenizer.json."""
        decoder = self._tokenizer_json.get("decoder", {})
        decoder_type = decoder.get("type", "")

        if decoder_type == "ByteLevel":
            return "bytelevel"
        elif decoder_type == "Sequence":
            decoders = decoder.get("decoders", [])
            for d in decoders:
                if d.get("type") == "ByteFallback":
                    return "sentencepiece"
            for d in decoders:
                if d.get("type") == "ByteLevel":
                    return "bytelevel"

        # Fallback: check model configuration
        model = self._tokenizer_json.get("model", {})
        if model.get("byte_fallback", False):
            return "sentencepiece"

        # Default to bytelevel
        return "bytelevel"

    def _load_added_tokens(self):
        """Load added_tokens into the vocabulary."""
        self._special_token_ids = set()
        added_tokens = self._tokenizer_json.get("added_tokens", [])
        for token_info in added_tokens:
            token_id = token_info["id"]
            content = token_info["content"]
            self._id_to_token[token_id] = content
            if token_info.get("special", False):
                self._special_token_ids.add(token_id)

    @property
    def decoder_type(self) -> str:
        """Return the detected decoder type."""
        return self._decoder_type

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size."""
        return len(self._id_to_token)

    @property
    def tokenizer(self):
        """Return the underlying HuggingFace tokenizer."""
        return self._tokenizer

    def get_token_string(self, token_id: int) -> Optional[str]:
        """Get the raw string for a token_id."""
        return self._id_to_token.get(token_id)

    def token_to_bytes(self, token_id: int) -> Optional[List[int]]:
        """
        Map a single token_id to its byte sequence.

        Args:
            token_id: The token ID

        Returns:
            List of byte values (0-255) as integers, or None if token_id doesn't exist
        """
        token_str = self._id_to_token.get(token_id)
        if token_str is None:
            return None

        if self._decoder_type == "bytelevel":
            return self._decode_bytelevel(token_str)
        else:
            return self._decode_sentencepiece(token_str)

    def _decode_bytelevel(self, token_str: str) -> List[int]:
        """
        ByteLevel decoding: map each Unicode character back to a byte.
        """
        result = []
        for char in token_str:
            if char in self._UNICODE_TO_BYTE:
                result.append(self._UNICODE_TO_BYTE[char])
            else:
                # Characters not in the mapping table are encoded as UTF-8
                result.extend(char.encode("utf-8"))
        return result
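
    # Example (ByteLevel): the vocab entry "Ġhi" decodes to [0x20, 0x68, 0x69],
    # i.e. the bytes of " hi", since "Ġ" is the byte-level stand-in for space.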

    def _decode_sentencepiece(self, token_str: str) -> List[int]:
        """
        SentencePiece decoding: handle ▁ and <0xXX> format.
        """
        result = []
        i = 0
        while i < len(token_str):
            # Check for <0xXX> format
            match = re.match(r"<0x([0-9A-Fa-f]{2})>", token_str[i:])
            if match:
                byte_val = int(match.group(1), 16)
                result.append(byte_val)
                i += 6
            elif token_str[i] == "▁":
                # Replace ▁ with space
                result.append(0x20)
                i += 1
            else:
                result.extend(token_str[i].encode("utf-8"))
                i += 1
        return result
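
    # Examples (SentencePiece): "▁Hello" decodes to [0x20, 72, 101, 108, 108, 111]
    # and the byte-fallback token "<0x0A>" decodes to [10], a newline byte.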

    def encode_to_bytes(
        self,
        text: str,
        add_special_tokens: bool = False,
        strip_leading_space: bool = True,
    ) -> List[List[int]]:
        """
        Encode text to a nested list of bytes.

        Each sub-list contains the byte values (as integers) for one token.

        Args:
            text: Input text to encode
            add_special_tokens: Whether to add special tokens (BOS, EOS, etc.)
            strip_leading_space: For SentencePiece, whether to strip the leading space
                                from the first token

        Returns:
            Nested list where each inner list contains byte values for one token.
            Example: [[72, 101, 108, 108, 111], [32, 119, 111, 114, 108, 100]]
        """
        token_ids = self._tokenizer.encode(text, add_special_tokens=add_special_tokens)

        result = []
        for idx, token_id in enumerate(token_ids):
            token_bytes = self.token_to_bytes(token_id)
            if token_bytes is not None:
                # Handle SentencePiece leading space
                if idx == 0 and self._decoder_type == "sentencepiece" and strip_leading_space and token_bytes and token_bytes[0] == 0x20:
                    token_bytes = token_bytes[1:]

                result.append(token_bytes)

        return result

    def encode_to_ids_and_bytes(
        self,
        text: str,
        add_special_tokens: bool = False,
        strip_leading_space: bool = True,
    ) -> List[Tuple[int, List[int]]]:
        """
        Encode text to (token_id, token_bytes) pairs.

        This is useful when the caller needs both the vocab token id and the exact
        byte sequence used by the tokenizer for alignment/visualization.
        """
        token_ids = self._tokenizer.encode(text, add_special_tokens=add_special_tokens)

        result = []
        for idx, token_id in enumerate(token_ids):
            token_bytes = self.token_to_bytes(token_id)
            if token_bytes is None:
                continue

            # Match encode_to_bytes() behavior for SentencePiece ByteFallback tokenizers.
            if idx == 0 and self._decoder_type == "sentencepiece" and strip_leading_space and token_bytes and token_bytes[0] == 0x20:
                token_bytes = token_bytes[1:]

            result.append((token_id, token_bytes))

        return result
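
    # Illustrative output shape (hypothetical token ids for a ByteLevel model):
    #     encode_to_ids_and_bytes("Hi") -> [(123, [72, 105])]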

    def encode_to_flat_bytes(
        self,
        text: str,
        add_special_tokens: bool = False,
        strip_leading_space: bool = True,
    ) -> bytes:
        """
        Encode text to a flat byte sequence.

        Args:
            text: Input text to encode
            add_special_tokens: Whether to add special tokens
            strip_leading_space: For SentencePiece, whether to strip the leading space

        Returns:
            Concatenated bytes from all tokens
        """
        nested = self.encode_to_bytes(text, add_special_tokens, strip_leading_space)
        result = []
        for token_bytes in nested:
            result.extend(token_bytes)
        return bytes(result)

    def get_all_token_bytes(self) -> Dict[int, List[int]]:
        """
        Get byte mapping for all tokens in the vocabulary.

        Returns:
            Dictionary mapping token_id to list of byte values
        """
        return {token_id: self.token_to_bytes(token_id) for token_id in self._id_to_token}
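

if __name__ == "__main__":
    # Minimal usage sketch. Assumes the `transformers` package is installed and
    # the tokenizer files are downloadable; "gpt2" is just an illustrative
    # choice, and any fast HuggingFace tokenizer should work here.
    converter = TokenizerBytesConverter("gpt2")
    print(converter.decoder_type)  # "bytelevel" for GPT-2
    print(converter.encode_to_bytes("Hello world"))
    # [[72, 101, 108, 108, 111], [32, 119, 111, 114, 108, 100]]
    print(converter.encode_to_flat_bytes("Hello world").decode("utf-8"))
    # "Hello world"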