File size: 7,860 Bytes
8ae05a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""
Byte-Pair Encoding trainer and codec optimized for JSON value strings.

Uses incremental pair counting with pair→word index for fast merges.
"""

from __future__ import annotations

import json
import re
from collections import defaultdict
from typing import Optional


def _bytes_to_unicode() -> dict[int, str]:
    """Map bytes 0-255 to unicode chars, avoiding control/whitespace collisions."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return {b: chr(c) for b, c in zip(bs, cs)}


# Raw byte value -> printable stand-in character.
BYTE_ENCODER = _bytes_to_unicode()
# Stand-in character -> raw byte value (exact inverse of BYTE_ENCODER).
BYTE_DECODER = {char: byte for byte, char in BYTE_ENCODER.items()}

# Pre-tokenizer: alternatives are tried left to right at each position.
# Note that in a str pattern ``\w`` is Unicode-aware, so non-ASCII letters are
# matched neither by the ASCII-letter run nor by the ``[^\s\w]`` run; they fall
# through to the final ``.`` one character at a time.
_PRE_TOK_PAT = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d"  # contraction suffixes
    r"| ?[a-zA-Z_]+"  # run of ASCII letters/underscores, optional leading space
    r"| ?[0-9]+"  # run of ASCII digits, optional leading space
    r"| ?[^\s\w]+"  # run of punctuation/symbols, optional leading space
    r"|\s+"  # run of whitespace
    r"|."  # fallback: any single remaining character
)


class BPETrainer:
    """Train a byte-level BPE vocabulary from JSON value strings and encode/decode with it.

    Text is split into pieces by ``_PRE_TOK_PAT``; each piece is UTF-8 encoded
    and every byte is replaced by its printable stand-in from ``BYTE_ENCODER``.
    Merges are learned and applied only within a single piece, never across
    piece boundaries.
    """

    def __init__(self, vocab_size: int = 4096, min_frequency: int = 2):
        # Target vocab size (base characters + merges + "<UNK>"); it caps how
        # many merges train() will learn.
        self.vocab_size = vocab_size
        # Smallest frequency-weighted pair count that is still merged.
        self.min_frequency = min_frequency
        # Learned merges in learning order; encode_word applies them in this order.
        self.merges: list[tuple[str, str]] = []
        # Token string -> integer id.
        self.vocab: dict[str, int] = {}
        # Lazily built inverse of ``vocab``; train() and load() reset it to None.
        self._id_to_tok: dict[int, str] | None = None

    def _pre_tokenize(self, text: str) -> list[str]:
        """Split *text* into pre-token pieces using ``_PRE_TOK_PAT``."""
        return _PRE_TOK_PAT.findall(text)

    def _text_to_bytes(self, text: str) -> tuple[str, ...]:
        """Return the ``BYTE_ENCODER`` stand-ins for the UTF-8 bytes of *text*."""
        return tuple(BYTE_ENCODER[b] for b in text.encode("utf-8"))

    def train(self, texts: list[str]) -> None:
        """Learn merges from *texts* and rebuild ``self.vocab``.

        Merges from any earlier ``train``/``load`` are discarded first.

        Pair counts are weighted by word frequency, and ``pair_to_words`` maps
        each pair to the indices of the distinct words that currently contain
        it, so each merge only revisits words that contain the chosen pair.
        Training stops when no pairs remain, when the best pair's count drops
        below ``min_frequency``, or after ``vocab_size - len(base chars) - 1``
        merges.

        Args:
            texts: Raw strings to learn from.
        """
        # Start from scratch so repeated calls do not stack merge lists.
        self.merges = []

        # Frequency of each distinct pre-token, as a tuple of stand-in chars.
        word_freqs: dict[tuple[str, ...], int] = {}
        for text in texts:
            for piece in self._pre_tokenize(text):
                symbols = self._text_to_bytes(piece)
                word_freqs[symbols] = word_freqs.get(symbols, 0) + 1

        base_vocab: set[str] = set()
        for symbols in word_freqs:
            base_vocab.update(symbols)

        # One slot is reserved for "<UNK>".
        num_merges = self.vocab_size - len(base_vocab) - 1

        words: list[tuple[str, ...]] = list(word_freqs)
        freqs: list[int] = list(word_freqs.values())

        pair_counts: dict[tuple[str, str], int] = {}
        pair_to_words: dict[tuple[str, str], set[int]] = defaultdict(set)
        for idx, (word, freq) in enumerate(zip(words, freqs)):
            for pair in zip(word, word[1:]):
                pair_counts[pair] = pair_counts.get(pair, 0) + freq
                pair_to_words[pair].add(idx)

        for _ in range(max(0, num_merges)):
            if not pair_counts:
                break

            # max() returns the first maximal key, so ties go to the pair that
            # has been in pair_counts the longest (dict insertion order).
            best_pair = max(pair_counts, key=pair_counts.__getitem__)
            if pair_counts[best_pair] < self.min_frequency:
                break
            self.merges.append(best_pair)

            del pair_counts[best_pair]
            # Sorted so the update order (and thus later tie-breaking) is deterministic.
            affected = sorted(pair_to_words.pop(best_pair, ()))
            for idx in affected:
                self._merge_in_word(idx, best_pair, words, freqs[idx], pair_counts, pair_to_words)

        self._build_vocab(base_vocab)

    def _merge_in_word(
        self,
        idx: int,
        pair: tuple[str, str],
        words: list[tuple[str, ...]],
        freq: int,
        pair_counts: dict[tuple[str, str], int],
        pair_to_words: dict[tuple[str, str], set[int]],
    ) -> None:
        """Apply *pair* to ``words[idx]`` and update the pair statistics in place."""
        old = words[idx]
        new = self._apply_merge(old, pair)
        words[idx] = new

        # Accumulate the net change per pair before touching pair_counts, so a
        # pair present both before and after keeps its dict position.
        delta: dict[tuple[str, str], int] = {}
        for p in zip(old, old[1:]):
            delta[p] = delta.get(p, 0) - freq
        for p in zip(new, new[1:]):
            delta[p] = delta.get(p, 0) + freq
        for p, d in delta.items():
            if d == 0:
                continue
            count = pair_counts.get(p, 0) + d
            if count > 0:
                pair_counts[p] = count
            else:
                pair_counts.pop(p, None)

        # Index membership must follow whole-word presence: a pair is dropped
        # only when NO occurrence of it is left in the word.
        old_pairs = set(zip(old, old[1:]))
        new_pairs = set(zip(new, new[1:]))
        for p in old_pairs - new_pairs:
            holders = pair_to_words.get(p)
            if holders is not None:
                holders.discard(idx)
                if not holders:
                    del pair_to_words[p]
        for p in new_pairs - old_pairs:
            pair_to_words.setdefault(p, set()).add(idx)

    def _build_vocab(self, base_vocab: set[str]) -> None:
        """Assign ids: sorted base chars, then each new merge result, then "<UNK>"."""
        self.vocab = {}
        for ch in sorted(base_vocab):
            self.vocab[ch] = len(self.vocab)
        for left, right in self.merges:
            # Different merges can spell the same string; keep its first id.
            self.vocab.setdefault(left + right, len(self.vocab))
        self.vocab["<UNK>"] = len(self.vocab)
        self._id_to_tok = None

    def _apply_merge(self, word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
        """Merge every non-overlapping occurrence of *pair* in *word*, scanning left to right."""
        new: list[str] = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                new.append(pair[0] + pair[1])
                i += 2
            else:
                new.append(word[i])
                i += 1
        return tuple(new)

    def encode_word(self, word: str) -> list[str]:
        """Encode one pre-token into BPE token strings.

        Every learned merge is applied in learning order.
        NOTE(review): this costs O(len(merges) * len(word)) per call.
        """
        bw = self._text_to_bytes(word)
        if len(bw) == 1:
            return [bw[0]]
        for merge in self.merges:
            bw = self._apply_merge(bw, merge)
        return list(bw)

    def encode(self, text: str) -> list[str]:
        """Pre-tokenize *text* and concatenate the BPE tokens of every piece."""
        tokens: list[str] = []
        for word in self._pre_tokenize(text):
            tokens.extend(self.encode_word(word))
        return tokens

    def encode_to_ids(self, text: str) -> list[int]:
        """Encode *text* to ids; tokens missing from ``vocab`` get the "<UNK>" id (or 0)."""
        tokens = self.encode(text)
        unk_id = self.vocab.get("<UNK>", 0)
        return [self.vocab.get(t, unk_id) for t in tokens]

    def id_to_token(self, token_id: int) -> str:
        """Return the token string for *token_id*, or "<UNK>" if the id is unknown."""
        if self._id_to_tok is None:
            self._id_to_tok = {v: k for k, v in self.vocab.items()}
        return self._id_to_tok.get(token_id, "<UNK>")

    def decode_ids(self, ids: list[int]) -> str:
        """Decode a list of ids back to text via :meth:`id_to_token`."""
        return self.decode_tokens([self.id_to_token(i) for i in ids])

    def decode_tokens(self, tokens: list[str]) -> str:
        """Concatenate *tokens* and map stand-in characters back to UTF-8 text.

        Characters in ``BYTE_DECODER`` become their original byte; any other
        character below U+0100 is kept as that raw byte. Characters at or above
        U+0100 that are not stand-ins cannot fit in one byte, so they are
        emitted as their own UTF-8 encoding instead of raising ``ValueError``.
        Invalid UTF-8 in the result is replaced with U+FFFD.
        """
        buf = bytearray()
        for ch in "".join(tokens):
            byte = BYTE_DECODER.get(ch, ord(ch))
            if byte < 256:
                buf.append(byte)
            else:
                buf += ch.encode("utf-8", "surrogatepass")
        return buf.decode("utf-8", errors="replace")

    def save(self, path: str) -> None:
        """Write settings, merges and vocab to *path* as indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "version": "json-tokenizer-bpe-v1",
                "vocab_size": self.vocab_size,
                "min_frequency": self.min_frequency,
                "merges": [list(m) for m in self.merges],
                "vocab": self.vocab,
            }, f, indent=2)

    @classmethod
    def load(cls, path: str) -> "BPETrainer":
        """Rebuild a trainer from a JSON file written by :meth:`save`.

        NOTE(review): the stored "version" field is not checked.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        t = cls(vocab_size=data["vocab_size"], min_frequency=data["min_frequency"])
        t.merges = [tuple(m) for m in data["merges"]]
        t.vocab = data["vocab"]
        t._id_to_tok = None
        return t