anthonym21 commited on
Commit
7e0e0d5
Β·
verified Β·
1 Parent(s): f3a6aa4

Upload json_tokenizer/tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. json_tokenizer/tokenizer.py +572 -0
json_tokenizer/tokenizer.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ JSON-optimized tokenizer.
3
+
4
+ Design principles:
5
+ 1. Structural tokens: JSON grammar symbols ({, }, [, ], :, comma) each get
6
+ a dedicated single token β€” no wasted subword splits on syntax.
7
+ 2. Key vocabulary: Frequently occurring JSON keys get their own tokens
8
+ (Key(name), Key(id), etc.), massively reducing token count for
9
+ repetitive schemas.
10
+ 3. Type-prefixed values: Values are prefixed with a type marker
11
+ (STR:, NUM:, BOOL:, NULL) so the tokenizer preserves JSON types
12
+ for lossless roundtrip.
13
+ 4. BPE for value content: String and number content is tokenized via
14
+ a BPE codec trained on JSON value distributions.
15
+ 5. Nesting tokens: [OBJ_START]/[OBJ_END] and Array(N) tokens encode
16
+ hierarchy without ambiguity.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import json
22
+ import re
23
+ from collections import Counter
24
+ from typing import Any, Optional, Union
25
+
26
+ from json_tokenizer.bpe import BPETrainer
27
+
28
+
29
+ # ── Structural token constants ──────────────────────────────────────────
30
class StructuralTokens:
    """Reserved token IDs for JSON grammar elements.

    IDs 0-15 are assigned; IDs 16-31 are held back for future structural
    tokens, so dynamic vocabularies (keys, BPE) start at RESERVED_END.
    """

    PAD = 0
    START = 1  # start of JSON document
    END = 2  # end of JSON document
    OBJ_START = 3  # {
    OBJ_END = 4  # }
    ARR_START = 5  # [ (generic, length encoded separately)
    ARR_END = 6  # ]
    COLON = 7  # :
    COMMA = 8  # ,
    NULL = 9  # null value
    TRUE = 10  # true
    FALSE = 11  # false
    STR_DELIM = 12  # marks start/end of a string value
    NUM_PREFIX = 13  # marks start of a number value
    KEY_PREFIX = 14  # marks start of a key (if not in key vocab)
    UNK = 15  # unknown token

    # IDs 16-31 reserved for future structural tokens
    RESERVED_END = 32

    # id -> printable name, built once at class-creation time instead of
    # being reconstructed on every name() call (the original rebuilt the
    # dict per lookup).
    _NAMES = {
        PAD: "[PAD]",
        START: "[START]",
        END: "[END]",
        OBJ_START: "{",
        OBJ_END: "}",
        ARR_START: "[",
        ARR_END: "]",
        COLON: ":",
        COMMA: ",",
        NULL: "null",
        TRUE: "true",
        FALSE: "false",
        STR_DELIM: "[STR]",
        NUM_PREFIX: "[NUM]",
        KEY_PREFIX: "[KEY]",
        UNK: "[UNK]",
    }

    @classmethod
    def name(cls, token_id: int) -> str:
        """Return a human-readable name for a structural token ID.

        Unassigned IDs (including the reserved 16-31 range) map to
        "[RESERVED_<id>]".
        """
        return cls._NAMES.get(token_id, f"[RESERVED_{token_id}]")
74
+
75
+
76
class JSONTokenizer:
    """Tokenizer optimized for JSON structures.

    Produces a compact token sequence using:
    - one dedicated token per structural element ({, }, [, ], :, comma),
    - dedicated Key(...) tokens for frequently seen object keys,
    - BPE subword tokens for string/number content,
    - full roundtrip fidelity (encode -> decode == original).

    Usage:
        tokenizer = JSONTokenizer()
        tokenizer.train_from_json_files(["data1.json", "data2.json"])
        ids = tokenizer.encode('{"name": "Alice", "age": 30}')
        decoded = tokenizer.decode(ids)
    """

    def __init__(
        self,
        bpe_vocab_size: int = 4096,
        max_key_vocab: int = 1024,
        min_key_freq: int = 2,
        bpe_min_freq: int = 2,
    ):
        # Training configuration.
        self.bpe_vocab_size = bpe_vocab_size
        self.max_key_vocab = max_key_vocab
        self.min_key_freq = min_key_freq
        self.bpe_min_freq = bpe_min_freq

        # Key vocabulary (key string <-> token id); filled in by train().
        # Key ids start directly after the reserved structural range.
        self._key_to_id: dict[str, int] = {}
        self._id_to_key: dict[int, str] = {}
        self._key_offset = StructuralTokens.RESERVED_END

        # BPE codec for key/value content; its ids are shifted past the
        # key vocabulary once that vocabulary is known.
        self._bpe = BPETrainer(vocab_size=bpe_vocab_size, min_frequency=bpe_min_freq)
        self._bpe_offset = 0  # finalized after the key vocab is built

        # Combined id<->token lookup across all tiers; built after training.
        self._id_to_token: dict[int, str] = {}
        self._token_to_id: dict[str, int] = {}
        self._trained = False
117
+
118
+ @property
119
+ def vocab_size(self) -> int:
120
+ """Total vocabulary size."""
121
+ if not self._trained:
122
+ return StructuralTokens.RESERVED_END
123
+ return self._bpe_offset + len(self._bpe.vocab)
124
+
125
+ # ── Training ────────────────────────────────────────────────────────
126
+
127
+ def train(self, json_objects: list[Any]) -> None:
128
+ """Train the tokenizer from a list of parsed JSON objects.
129
+
130
+ Extracts keys for the key vocabulary and values for BPE training.
131
+
132
+ Args:
133
+ json_objects: List of parsed JSON values (dicts, lists, primitives).
134
+ """
135
+ key_counter: Counter[str] = Counter()
136
+ value_strings: list[str] = []
137
+
138
+ for obj in json_objects:
139
+ self._extract_keys_and_values(obj, key_counter, value_strings)
140
+
141
+ # Build key vocabulary from most common keys
142
+ top_keys = [
143
+ k
144
+ for k, count in key_counter.most_common(self.max_key_vocab)
145
+ if count >= self.min_key_freq
146
+ ]
147
+
148
+ self._key_to_id = {}
149
+ self._id_to_key = {}
150
+ for i, key in enumerate(top_keys):
151
+ tid = self._key_offset + i
152
+ self._key_to_id[key] = tid
153
+ self._id_to_key[tid] = key
154
+
155
+ # BPE offset is after key vocab
156
+ self._bpe_offset = self._key_offset + len(self._key_to_id)
157
+
158
+ # Train BPE on value strings
159
+ if value_strings:
160
+ self._bpe.train(value_strings)
161
+
162
+ # Build full vocab lookup
163
+ self._build_vocab_lookup()
164
+ self._trained = True
165
+
166
+ def train_from_json_strings(self, json_strings: list[str]) -> None:
167
+ """Train from raw JSON strings."""
168
+ objects = []
169
+ for s in json_strings:
170
+ try:
171
+ objects.append(json.loads(s))
172
+ except json.JSONDecodeError:
173
+ continue
174
+ self.train(objects)
175
+
176
+ def train_from_json_files(self, file_paths: list[str]) -> None:
177
+ """Train from JSON files (one JSON object per file, or JSONL)."""
178
+ objects = []
179
+ for path in file_paths:
180
+ with open(path) as f:
181
+ content = f.read().strip()
182
+ # Try as single JSON object
183
+ try:
184
+ obj = json.loads(content)
185
+ if isinstance(obj, list):
186
+ objects.extend(obj)
187
+ else:
188
+ objects.append(obj)
189
+ continue
190
+ except json.JSONDecodeError:
191
+ pass
192
+ # Try as JSONL
193
+ for line in content.splitlines():
194
+ line = line.strip()
195
+ if line:
196
+ try:
197
+ objects.append(json.loads(line))
198
+ except json.JSONDecodeError:
199
+ continue
200
+ self.train(objects)
201
+
202
+ def _extract_keys_and_values(
203
+ self,
204
+ obj: Any,
205
+ key_counter: Counter[str],
206
+ value_strings: list[str],
207
+ ) -> None:
208
+ """Recursively extract keys and value strings from a JSON object."""
209
+ if isinstance(obj, dict):
210
+ for key, value in obj.items():
211
+ key_counter[key] += 1
212
+ # Also train BPE on key strings (they appear as values too)
213
+ value_strings.append(key)
214
+ self._extract_keys_and_values(value, key_counter, value_strings)
215
+ elif isinstance(obj, list):
216
+ for item in obj:
217
+ self._extract_keys_and_values(item, key_counter, value_strings)
218
+ elif isinstance(obj, str):
219
+ value_strings.append(obj)
220
+ elif isinstance(obj, (int, float)):
221
+ value_strings.append(str(obj))
222
+ # bool and None don't need BPE (they're structural tokens)
223
+
224
+ def _build_vocab_lookup(self) -> None:
225
+ """Build the complete id↔token mappings."""
226
+ self._id_to_token = {}
227
+ self._token_to_id = {}
228
+
229
+ # Structural tokens
230
+ for i in range(StructuralTokens.RESERVED_END):
231
+ name = StructuralTokens.name(i)
232
+ self._id_to_token[i] = name
233
+ self._token_to_id[name] = i
234
+
235
+ # Key tokens
236
+ for key, tid in self._key_to_id.items():
237
+ token_name = f"Key({key})"
238
+ self._id_to_token[tid] = token_name
239
+ self._token_to_id[token_name] = tid
240
+
241
+ # BPE tokens
242
+ for bpe_token, bpe_id in self._bpe.vocab.items():
243
+ full_id = self._bpe_offset + bpe_id
244
+ self._id_to_token[full_id] = f"BPE({bpe_token})"
245
+ self._token_to_id[f"BPE({bpe_token})"] = full_id
246
+
247
+ # ── Encoding ────────────────────────────────────────────────────────
248
+
249
+ def encode(self, json_input: Union[str, Any]) -> list[int]:
250
+ """Encode a JSON string or parsed object into token IDs.
251
+
252
+ Args:
253
+ json_input: Either a JSON string or an already-parsed Python object.
254
+
255
+ Returns:
256
+ List of integer token IDs.
257
+ """
258
+ if isinstance(json_input, str):
259
+ try:
260
+ obj = json.loads(json_input)
261
+ except json.JSONDecodeError:
262
+ raise ValueError(f"Invalid JSON: {json_input[:100]}...")
263
+ else:
264
+ obj = json_input
265
+
266
+ tokens = [StructuralTokens.START]
267
+ self._encode_value(obj, tokens)
268
+ tokens.append(StructuralTokens.END)
269
+ return tokens
270
+
271
+ def _encode_value(self, value: Any, tokens: list[int]) -> None:
272
+ """Recursively encode a JSON value into tokens."""
273
+ if isinstance(value, dict):
274
+ self._encode_object(value, tokens)
275
+ elif isinstance(value, list):
276
+ self._encode_array(value, tokens)
277
+ elif isinstance(value, str):
278
+ self._encode_string(value, tokens)
279
+ elif isinstance(value, bool):
280
+ # Must check bool before int (bool is subclass of int in Python)
281
+ tokens.append(StructuralTokens.TRUE if value else StructuralTokens.FALSE)
282
+ elif isinstance(value, (int, float)):
283
+ self._encode_number(value, tokens)
284
+ elif value is None:
285
+ tokens.append(StructuralTokens.NULL)
286
+ else:
287
+ tokens.append(StructuralTokens.UNK)
288
+
289
+ def _encode_object(self, obj: dict, tokens: list[int]) -> None:
290
+ """Encode a JSON object."""
291
+ tokens.append(StructuralTokens.OBJ_START)
292
+ for i, (key, value) in enumerate(obj.items()):
293
+ if i > 0:
294
+ tokens.append(StructuralTokens.COMMA)
295
+ self._encode_key(key, tokens)
296
+ tokens.append(StructuralTokens.COLON)
297
+ self._encode_value(value, tokens)
298
+ tokens.append(StructuralTokens.OBJ_END)
299
+
300
+ def _encode_array(self, arr: list, tokens: list[int]) -> None:
301
+ """Encode a JSON array."""
302
+ tokens.append(StructuralTokens.ARR_START)
303
+ for i, item in enumerate(arr):
304
+ if i > 0:
305
+ tokens.append(StructuralTokens.COMMA)
306
+ self._encode_value(item, tokens)
307
+ tokens.append(StructuralTokens.ARR_END)
308
+
309
+ def _encode_key(self, key: str, tokens: list[int]) -> None:
310
+ """Encode a JSON key β€” uses key vocab if available, else BPE."""
311
+ if key in self._key_to_id:
312
+ tokens.append(self._key_to_id[key])
313
+ else:
314
+ tokens.append(StructuralTokens.KEY_PREFIX)
315
+ bpe_ids = self._bpe.encode_to_ids(key)
316
+ tokens.extend(self._bpe_offset + bid for bid in bpe_ids)
317
+
318
+ def _encode_string(self, value: str, tokens: list[int]) -> None:
319
+ """Encode a JSON string value."""
320
+ tokens.append(StructuralTokens.STR_DELIM)
321
+ if value: # don't BPE-encode empty strings
322
+ bpe_ids = self._bpe.encode_to_ids(value)
323
+ tokens.extend(self._bpe_offset + bid for bid in bpe_ids)
324
+ tokens.append(StructuralTokens.STR_DELIM)
325
+
326
+ def _encode_number(self, value: Union[int, float], tokens: list[int]) -> None:
327
+ """Encode a JSON number value."""
328
+ tokens.append(StructuralTokens.NUM_PREFIX)
329
+ # Preserve int vs float distinction
330
+ if isinstance(value, float) and value == int(value) and "." in str(value):
331
+ text = str(value)
332
+ elif isinstance(value, int):
333
+ text = str(value)
334
+ else:
335
+ text = repr(value)
336
+ bpe_ids = self._bpe.encode_to_ids(text)
337
+ tokens.extend(self._bpe_offset + bid for bid in bpe_ids)
338
+
339
+ # ── Decoding ────────────────────────────────────────────────────────
340
+
341
+ def decode(self, token_ids: list[int]) -> str:
342
+ """Decode token IDs back to a JSON string.
343
+
344
+ Args:
345
+ token_ids: List of integer token IDs from encode().
346
+
347
+ Returns:
348
+ JSON string faithful to the original.
349
+ """
350
+ obj = self._decode_to_object(token_ids)
351
+ return json.dumps(obj, ensure_ascii=False)
352
+
353
+ def decode_to_object(self, token_ids: list[int]) -> Any:
354
+ """Decode token IDs back to a Python object."""
355
+ return self._decode_to_object(token_ids)
356
+
357
+ def _decode_to_object(self, token_ids: list[int]) -> Any:
358
+ """Parse token IDs back into a Python object."""
359
+ # Strip START/END
360
+ ids = list(token_ids)
361
+ if ids and ids[0] == StructuralTokens.START:
362
+ ids = ids[1:]
363
+ if ids and ids[-1] == StructuralTokens.END:
364
+ ids = ids[:-1]
365
+
366
+ result, _ = self._parse_value(ids, 0)
367
+ return result
368
+
369
+ def _parse_value(self, ids: list[int], pos: int) -> tuple[Any, int]:
370
+ """Parse a single value starting at position pos."""
371
+ if pos >= len(ids):
372
+ return None, pos
373
+
374
+ tid = ids[pos]
375
+
376
+ if tid == StructuralTokens.OBJ_START:
377
+ return self._parse_object(ids, pos)
378
+ elif tid == StructuralTokens.ARR_START:
379
+ return self._parse_array(ids, pos)
380
+ elif tid == StructuralTokens.STR_DELIM:
381
+ return self._parse_string(ids, pos)
382
+ elif tid == StructuralTokens.NUM_PREFIX:
383
+ return self._parse_number(ids, pos)
384
+ elif tid == StructuralTokens.NULL:
385
+ return None, pos + 1
386
+ elif tid == StructuralTokens.TRUE:
387
+ return True, pos + 1
388
+ elif tid == StructuralTokens.FALSE:
389
+ return False, pos + 1
390
+ else:
391
+ return None, pos + 1
392
+
393
    def _parse_object(self, ids: list[int], pos: int) -> tuple[dict, int]:
        """Parse a JSON object from token IDs.

        Expects ids[pos] to be OBJ_START. Consumes key/value pairs until
        OBJ_END or end of input (lenient: a missing COLON or missing
        closing token is tolerated). Returns the dict and the position
        just past the consumed span.
        """
        assert ids[pos] == StructuralTokens.OBJ_START
        pos += 1
        result: dict[str, Any] = {}

        while pos < len(ids) and ids[pos] != StructuralTokens.OBJ_END:
            # Separators carry no content; just skip them.
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue

            # Parse key
            key, pos = self._parse_key(ids, pos)

            # Expect colon (tolerated if absent on malformed sequences)
            if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                pos += 1

            # Parse value; a duplicate key overwrites the earlier entry.
            value, pos = self._parse_value(ids, pos)
            result[key] = value

        # Consume the closing OBJ_END when present (may be absent if the
        # token sequence was truncated).
        if pos < len(ids) and ids[pos] == StructuralTokens.OBJ_END:
            pos += 1

        return result, pos
419
+
420
+ def _parse_array(self, ids: list[int], pos: int) -> tuple[list, int]:
421
+ """Parse a JSON array from token IDs."""
422
+ assert ids[pos] == StructuralTokens.ARR_START
423
+ pos += 1
424
+ result: list[Any] = []
425
+
426
+ while pos < len(ids) and ids[pos] != StructuralTokens.ARR_END:
427
+ if ids[pos] == StructuralTokens.COMMA:
428
+ pos += 1
429
+ continue
430
+
431
+ value, pos = self._parse_value(ids, pos)
432
+ result.append(value)
433
+
434
+ if pos < len(ids) and ids[pos] == StructuralTokens.ARR_END:
435
+ pos += 1
436
+
437
+ return result, pos
438
+
439
    def _parse_key(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a key from token IDs.

        Three cases: a dedicated key-vocab token, a KEY_PREFIX followed
        by BPE ids, or an unrecognized id (returned as a placeholder so
        decoding can continue).
        """
        tid = ids[pos]

        # Case 1: dedicated key-vocabulary token.
        if tid in self._id_to_key:
            return self._id_to_key[tid], pos + 1

        # Case 2: KEY_PREFIX → BPE-encoded key. Consume ids while they
        # are in the BPE range (>= _bpe_offset).
        if tid == StructuralTokens.KEY_PREFIX:
            pos += 1
            bpe_tokens: list[str] = []
            while pos < len(ids) and ids[pos] >= self._bpe_offset:
                bpe_id = ids[pos] - self._bpe_offset
                bpe_tokens.append(self._bpe.id_to_token(bpe_id))
                pos += 1
                # Stop before COLON (the >= check would already stop at a
                # structural id; this break is a redundant safeguard).
                if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                    break
            return self._bpe.decode_tokens(bpe_tokens), pos

        # Case 3: malformed sequence — emit a placeholder key and move on.
        return f"<unknown_key_{tid}>", pos + 1
461
+
462
    def _parse_string(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a string value from token IDs.

        Expects ids[pos] to be the opening STR_DELIM; consumes ids until
        the closing STR_DELIM (or end of input) and BPE-decodes them.
        """
        assert ids[pos] == StructuralTokens.STR_DELIM
        pos += 1

        bpe_tokens: list[str] = []
        while pos < len(ids) and ids[pos] != StructuralTokens.STR_DELIM:
            # NOTE(review): assumes every id in the span is a BPE id
            # (>= _bpe_offset); a stray structural id would yield a
            # negative bpe_id — relies on well-formed input from encode().
            bpe_id = ids[pos] - self._bpe_offset
            bpe_tokens.append(self._bpe.id_to_token(bpe_id))
            pos += 1

        # Skip closing delimiter (absent only on truncated sequences)
        if pos < len(ids) and ids[pos] == StructuralTokens.STR_DELIM:
            pos += 1

        return self._bpe.decode_tokens(bpe_tokens), pos
478
+
479
    def _parse_number(self, ids: list[int], pos: int) -> tuple[Union[int, float], int]:
        """Parse a number value from token IDs.

        Expects ids[pos] to be NUM_PREFIX; consumes BPE ids until the
        next non-BPE id, then parses the decoded text. Returns a float
        when the text contains '.' or an exponent, else an int. Falls
        back to 0 on unparseable text (best-effort — never raises).
        """
        assert ids[pos] == StructuralTokens.NUM_PREFIX
        pos += 1

        bpe_tokens: list[str] = []
        while pos < len(ids):
            tid = ids[pos]
            if tid < self._bpe_offset:
                break  # hit a structural token
            bpe_id = tid - self._bpe_offset
            bpe_tokens.append(self._bpe.id_to_token(bpe_id))
            pos += 1

        # strip(): BPE decoding may reintroduce boundary whitespace
        text = self._bpe.decode_tokens(bpe_tokens).strip()
        try:
            if "." in text or "e" in text.lower():
                return float(text), pos
            return int(text), pos
        except ValueError:
            # Deliberate best-effort: malformed numeric text decodes to 0
            # rather than aborting the whole decode.
            return 0, pos
500
+
501
+ # ── Inspection / Debug ──────────────────────────────────────────────
502
+
503
+ def decode_tokens_readable(self, token_ids: list[int]) -> list[str]:
504
+ """Convert token IDs to human-readable token names."""
505
+ result: list[str] = []
506
+ for tid in token_ids:
507
+ if tid in self._id_to_token:
508
+ result.append(self._id_to_token[tid])
509
+ elif tid in self._id_to_key:
510
+ result.append(f"Key({self._id_to_key[tid]})")
511
+ else:
512
+ bpe_id = tid - self._bpe_offset
513
+ token_str = self._bpe.id_to_token(bpe_id)
514
+ result.append(f"BPE({repr(token_str)})")
515
+ return result
516
+
517
    def token_count(self, json_input: Union[str, Any]) -> int:
        """Return the number of tokens encode() produces for this input.

        Note: this is a convenience wrapper that encodes the full token
        list and takes its length — it is not a streaming count.
        """
        return len(self.encode(json_input))
520
+
521
+ # ── Persistence ─────────────────────────────────────────────────────
522
+
523
+ def save(self, directory: str) -> None:
524
+ """Save the full tokenizer state to a directory."""
525
+ import os
526
+
527
+ os.makedirs(directory, exist_ok=True)
528
+
529
+ # Save BPE model
530
+ self._bpe.save(os.path.join(directory, "bpe_model.json"))
531
+
532
+ # Save key vocabulary and config
533
+ config = {
534
+ "version": "json-tokenizer-v1",
535
+ "bpe_vocab_size": self.bpe_vocab_size,
536
+ "max_key_vocab": self.max_key_vocab,
537
+ "min_key_freq": self.min_key_freq,
538
+ "bpe_min_freq": self.bpe_min_freq,
539
+ "key_vocab": self._key_to_id,
540
+ "key_offset": self._key_offset,
541
+ "bpe_offset": self._bpe_offset,
542
+ }
543
+ with open(os.path.join(directory, "tokenizer_config.json"), "w") as f:
544
+ json.dump(config, f, indent=2)
545
+
546
+ @classmethod
547
+ def load(cls, directory: str) -> "JSONTokenizer":
548
+ """Load a trained tokenizer from a directory."""
549
+ import os
550
+
551
+ with open(os.path.join(directory, "tokenizer_config.json")) as f:
552
+ config = json.load(f)
553
+
554
+ tokenizer = cls(
555
+ bpe_vocab_size=config["bpe_vocab_size"],
556
+ max_key_vocab=config["max_key_vocab"],
557
+ min_key_freq=config["min_key_freq"],
558
+ bpe_min_freq=config["bpe_min_freq"],
559
+ )
560
+
561
+ # Restore key vocab
562
+ tokenizer._key_to_id = config["key_vocab"]
563
+ tokenizer._id_to_key = {int(v): k for k, v in config["key_vocab"].items()}
564
+ tokenizer._key_offset = config["key_offset"]
565
+ tokenizer._bpe_offset = config["bpe_offset"]
566
+
567
+ # Load BPE
568
+ tokenizer._bpe = BPETrainer.load(os.path.join(directory, "bpe_model.json"))
569
+
570
+ tokenizer._build_vocab_lookup()
571
+ tokenizer._trained = True
572
+ return tokenizer