anthonym21 commited on
Commit
8ae05a9
·
verified ·
1 Parent(s): 7e0e0d5

Upload json_tokenizer/bpe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. json_tokenizer/bpe.py +229 -0
json_tokenizer/bpe.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Byte-Pair Encoding trainer and codec optimized for JSON value strings.
3
+
4
+ Uses incremental pair counting with pair→word index for fast merges.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import re
11
+ from collections import defaultdict
12
+ from typing import Optional
13
+
14
+
15
+ def _bytes_to_unicode() -> dict[int, str]:
16
+ """Map bytes 0-255 to unicode chars, avoiding control/whitespace collisions."""
17
+ bs = (
18
+ list(range(ord("!"), ord("~") + 1))
19
+ + list(range(ord("¡"), ord("¬") + 1))
20
+ + list(range(ord("®"), ord("ÿ") + 1))
21
+ )
22
+ cs = bs[:]
23
+ n = 0
24
+ for b in range(2**8):
25
+ if b not in bs:
26
+ bs.append(b)
27
+ cs.append(2**8 + n)
28
+ n += 1
29
+ return {b: chr(c) for b, c in zip(bs, cs)}
30
+
31
+
32
+ BYTE_ENCODER = _bytes_to_unicode()
33
+ BYTE_DECODER = {v: k for k, v in BYTE_ENCODER.items()}
34
+
35
+ _PRE_TOK_PAT = re.compile(
36
+ r"""'s|'t|'re|'ve|'m|'ll|'d| ?[a-zA-Z_]+| ?[0-9]+| ?[^\s\w]+|\s+|."""
37
+ )
38
+
39
+
40
+ class BPETrainer:
41
+ """Train a BPE vocabulary from a corpus of JSON value strings."""
42
+
43
+ def __init__(self, vocab_size: int = 4096, min_frequency: int = 2):
44
+ self.vocab_size = vocab_size
45
+ self.min_frequency = min_frequency
46
+ self.merges: list[tuple[str, str]] = []
47
+ self.vocab: dict[str, int] = {}
48
+ self._id_to_tok: dict[int, str] | None = None
49
+
50
+ def _pre_tokenize(self, text: str) -> list[str]:
51
+ return _PRE_TOK_PAT.findall(text)
52
+
53
+ def _text_to_bytes(self, text: str) -> tuple[str, ...]:
54
+ return tuple(BYTE_ENCODER[b] for b in text.encode("utf-8"))
55
+
56
+ def train(self, texts: list[str]) -> None:
57
+ """Train BPE with pair→word index for O(affected) merges."""
58
+ # Count word frequencies
59
+ word_freqs: dict[tuple[str, ...], int] = {}
60
+ for text in texts:
61
+ for word in self._pre_tokenize(text):
62
+ bw = self._text_to_bytes(word)
63
+ word_freqs[bw] = word_freqs.get(bw, 0) + 1
64
+
65
+ # Base vocab
66
+ base_vocab: set[str] = set()
67
+ for word in word_freqs:
68
+ base_vocab.update(word)
69
+
70
+ num_merges = self.vocab_size - len(base_vocab) - 1
71
+
72
+ # Word storage: idx → [symbols], freq
73
+ words: list[list[str]] = []
74
+ freqs: list[int] = []
75
+ for w, f in word_freqs.items():
76
+ words.append(list(w))
77
+ freqs.append(f)
78
+
79
+ # Pair counts and pair→word indices
80
+ pair_counts: dict[tuple[str, str], int] = defaultdict(int)
81
+ pair_to_words: dict[tuple[str, str], set[int]] = defaultdict(set)
82
+
83
+ for idx, (w, f) in enumerate(zip(words, freqs)):
84
+ for i in range(len(w) - 1):
85
+ p = (w[i], w[i + 1])
86
+ pair_counts[p] += f
87
+ pair_to_words[p].add(idx)
88
+
89
+ for _ in range(max(0, num_merges)):
90
+ if not pair_counts:
91
+ break
92
+
93
+ # Find best pair
94
+ best_pair = max(pair_counts, key=pair_counts.__getitem__)
95
+ if pair_counts[best_pair] < self.min_frequency:
96
+ break
97
+
98
+ a, b = best_pair
99
+ merged = a + b
100
+ self.merges.append(best_pair)
101
+
102
+ # Only process words that contain this pair
103
+ affected = list(pair_to_words.pop(best_pair, set()))
104
+ del pair_counts[best_pair]
105
+
106
+ for idx in affected:
107
+ w = words[idx]
108
+ f = freqs[idx]
109
+
110
+ # Find positions of the pair
111
+ new_w: list[str] = []
112
+ i = 0
113
+ while i < len(w):
114
+ if i < len(w) - 1 and w[i] == a and w[i + 1] == b:
115
+ # Decrement old adjacent pairs
116
+ if new_w:
117
+ old_left = (new_w[-1], a)
118
+ pair_counts[old_left] -= f
119
+ if pair_counts[old_left] <= 0:
120
+ pair_counts.pop(old_left, None)
121
+ pair_to_words[old_left].discard(idx)
122
+
123
+ if i + 2 < len(w):
124
+ old_right = (b, w[i + 2])
125
+ pair_counts[old_right] -= f
126
+ if pair_counts[old_right] <= 0:
127
+ pair_counts.pop(old_right, None)
128
+ pair_to_words[old_right].discard(idx)
129
+
130
+ new_w.append(merged)
131
+
132
+ # Increment new adjacent pairs
133
+ if len(new_w) >= 2:
134
+ nl = (new_w[-2], merged)
135
+ pair_counts[nl] += f
136
+ pair_to_words[nl].add(idx)
137
+
138
+ if i + 2 < len(w):
139
+ nr = (merged, w[i + 2])
140
+ pair_counts[nr] += f
141
+ pair_to_words[nr].add(idx)
142
+
143
+ i += 2
144
+ else:
145
+ new_w.append(w[i])
146
+ i += 1
147
+
148
+ words[idx] = new_w
149
+
150
+ # Prune dead entries periodically
151
+ if _ % 50 == 0:
152
+ pair_counts = defaultdict(int, {k: v for k, v in pair_counts.items() if v > 0})
153
+
154
+ # Build vocab
155
+ self.vocab = {}
156
+ idx = 0
157
+ for ch in sorted(base_vocab):
158
+ self.vocab[ch] = idx
159
+ idx += 1
160
+ for merge in self.merges:
161
+ m = merge[0] + merge[1]
162
+ if m not in self.vocab:
163
+ self.vocab[m] = idx
164
+ idx += 1
165
+ self.vocab["<UNK>"] = idx
166
+ self._id_to_tok = None
167
+
168
+ def _apply_merge(self, word: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
169
+ new: list[str] = []
170
+ i = 0
171
+ while i < len(word):
172
+ if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
173
+ new.append(pair[0] + pair[1])
174
+ i += 2
175
+ else:
176
+ new.append(word[i])
177
+ i += 1
178
+ return tuple(new)
179
+
180
+ def encode_word(self, word: str) -> list[str]:
181
+ bw = self._text_to_bytes(word)
182
+ if len(bw) == 1:
183
+ return [bw[0]]
184
+ for merge in self.merges:
185
+ bw = self._apply_merge(bw, merge)
186
+ return list(bw)
187
+
188
+ def encode(self, text: str) -> list[str]:
189
+ tokens: list[str] = []
190
+ for word in self._pre_tokenize(text):
191
+ tokens.extend(self.encode_word(word))
192
+ return tokens
193
+
194
+ def encode_to_ids(self, text: str) -> list[int]:
195
+ tokens = self.encode(text)
196
+ unk_id = self.vocab.get("<UNK>", 0)
197
+ return [self.vocab.get(t, unk_id) for t in tokens]
198
+
199
+ def id_to_token(self, token_id: int) -> str:
200
+ if self._id_to_tok is None:
201
+ self._id_to_tok = {v: k for k, v in self.vocab.items()}
202
+ return self._id_to_tok.get(token_id, "<UNK>")
203
+
204
+ def decode_ids(self, ids: list[int]) -> str:
205
+ return self.decode_tokens([self.id_to_token(i) for i in ids])
206
+
207
+ def decode_tokens(self, tokens: list[str]) -> str:
208
+ byte_str = "".join(tokens)
209
+ return bytearray(BYTE_DECODER.get(c, ord(c)) for c in byte_str).decode("utf-8", errors="replace")
210
+
211
+ def save(self, path: str) -> None:
212
+ with open(path, "w") as f:
213
+ json.dump({
214
+ "version": "json-tokenizer-bpe-v1",
215
+ "vocab_size": self.vocab_size,
216
+ "min_frequency": self.min_frequency,
217
+ "merges": [list(m) for m in self.merges],
218
+ "vocab": self.vocab,
219
+ }, f, indent=2)
220
+
221
+ @classmethod
222
+ def load(cls, path: str) -> "BPETrainer":
223
+ with open(path) as f:
224
+ data = json.load(f)
225
+ t = cls(vocab_size=data["vocab_size"], min_frequency=data["min_frequency"])
226
+ t.merges = [tuple(m) for m in data["merges"]]
227
+ t.vocab = data["vocab"]
228
+ t._id_to_tok = None
229
+ return t