| | """ |
| | JSON-optimized tokenizer. |
| | |
| | Design principles: |
| | 1. Structural tokens: JSON grammar symbols ({, }, [, ], :, comma) each get |
   a dedicated single token — no wasted subword splits on syntax.
| | 2. Key vocabulary: Frequently occurring JSON keys get their own tokens |
| | (Key(name), Key(id), etc.), massively reducing token count for |
| | repetitive schemas. |
| | 3. Type-prefixed values: Values are prefixed with a type marker |
| | (STR:, NUM:, BOOL:, NULL) so the tokenizer preserves JSON types |
| | for lossless roundtrip. |
| | 4. BPE for value content: String and number content is tokenized via |
| | a BPE codec trained on JSON value distributions. |
| | 5. Nesting tokens: [OBJ_START]/[OBJ_END] and Array(N) tokens encode |
| | hierarchy without ambiguity. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import re |
| | from collections import Counter |
| | from typing import Any, Optional, Union |
| |
|
| | from json_tokenizer.bpe import BPETrainer |
| |
|
| |
|
| | |
class StructuralTokens:
    """Reserved token IDs for JSON grammar elements.

    IDs 0-15 are assigned below; IDs 16-31 are held back for future
    structural tokens, so the key/BPE vocabularies start at RESERVED_END.
    """

    PAD = 0
    START = 1
    END = 2
    OBJ_START = 3
    OBJ_END = 4
    ARR_START = 5
    ARR_END = 6
    COLON = 7
    COMMA = 8
    NULL = 9
    TRUE = 10
    FALSE = 11
    STR_DELIM = 12
    NUM_PREFIX = 13
    KEY_PREFIX = 14
    UNK = 15

    # First ID available to the key vocabulary.
    RESERVED_END = 32

    # Built once at class-creation time; previously name() rebuilt this
    # 16-entry dict on every single call.
    _NAMES = {
        PAD: "[PAD]",
        START: "[START]",
        END: "[END]",
        OBJ_START: "{",
        OBJ_END: "}",
        ARR_START: "[",
        ARR_END: "]",
        COLON: ":",
        COMMA: ",",
        NULL: "null",
        TRUE: "true",
        FALSE: "false",
        STR_DELIM: "[STR]",
        NUM_PREFIX: "[NUM]",
        KEY_PREFIX: "[KEY]",
        UNK: "[UNK]",
    }

    @classmethod
    def name(cls, token_id: int) -> str:
        """Return the human-readable name for a structural token ID."""
        return cls._NAMES.get(token_id, f"[RESERVED_{token_id}]")
| |
|
| |
|
| | class JSONTokenizer: |
| | """Tokenizer optimized for JSON structures. |
| | |
| | Encodes JSON into a compact token sequence with: |
| | - Single tokens for structural elements |
| | - Dedicated key tokens for common keys |
| | - BPE subword tokens for string/number values |
    - Full roundtrip fidelity (encode → decode == original)
| | |
| | Usage: |
| | tokenizer = JSONTokenizer() |
| | tokenizer.train_from_json_files(["data1.json", "data2.json"]) |
| | ids = tokenizer.encode('{"name": "Alice", "age": 30}') |
| | decoded = tokenizer.decode(ids) |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | bpe_vocab_size: int = 4096, |
| | max_key_vocab: int = 1024, |
| | min_key_freq: int = 2, |
| | bpe_min_freq: int = 2, |
| | ): |
| | self.bpe_vocab_size = bpe_vocab_size |
| | self.max_key_vocab = max_key_vocab |
| | self.min_key_freq = min_key_freq |
| | self.bpe_min_freq = bpe_min_freq |
| |
|
| | |
| | self._key_to_id: dict[str, int] = {} |
| | self._id_to_key: dict[int, str] = {} |
| | self._key_offset = StructuralTokens.RESERVED_END |
| |
|
| | |
| | self._bpe = BPETrainer(vocab_size=bpe_vocab_size, min_frequency=bpe_min_freq) |
| | self._bpe_offset = 0 |
| |
|
| | |
| | self._id_to_token: dict[int, str] = {} |
| | self._token_to_id: dict[str, int] = {} |
| | self._trained = False |
| |
|
| | @property |
| | def vocab_size(self) -> int: |
| | """Total vocabulary size.""" |
| | if not self._trained: |
| | return StructuralTokens.RESERVED_END |
| | return self._bpe_offset + len(self._bpe.vocab) |
| |
|
| | |
| |
|
| | def train(self, json_objects: list[Any]) -> None: |
| | """Train the tokenizer from a list of parsed JSON objects. |
| | |
| | Extracts keys for the key vocabulary and values for BPE training. |
| | |
| | Args: |
| | json_objects: List of parsed JSON values (dicts, lists, primitives). |
| | """ |
| | key_counter: Counter[str] = Counter() |
| | value_strings: list[str] = [] |
| |
|
| | for obj in json_objects: |
| | self._extract_keys_and_values(obj, key_counter, value_strings) |
| |
|
| | |
| | top_keys = [ |
| | k |
| | for k, count in key_counter.most_common(self.max_key_vocab) |
| | if count >= self.min_key_freq |
| | ] |
| |
|
| | self._key_to_id = {} |
| | self._id_to_key = {} |
| | for i, key in enumerate(top_keys): |
| | tid = self._key_offset + i |
| | self._key_to_id[key] = tid |
| | self._id_to_key[tid] = key |
| |
|
| | |
| | self._bpe_offset = self._key_offset + len(self._key_to_id) |
| |
|
| | |
| | if value_strings: |
| | self._bpe.train(value_strings) |
| |
|
| | |
| | self._build_vocab_lookup() |
| | self._trained = True |
| |
|
| | def train_from_json_strings(self, json_strings: list[str]) -> None: |
| | """Train from raw JSON strings.""" |
| | objects = [] |
| | for s in json_strings: |
| | try: |
| | objects.append(json.loads(s)) |
| | except json.JSONDecodeError: |
| | continue |
| | self.train(objects) |
| |
|
| | def train_from_json_files(self, file_paths: list[str]) -> None: |
| | """Train from JSON files (one JSON object per file, or JSONL).""" |
| | objects = [] |
| | for path in file_paths: |
| | with open(path) as f: |
| | content = f.read().strip() |
| | |
| | try: |
| | obj = json.loads(content) |
| | if isinstance(obj, list): |
| | objects.extend(obj) |
| | else: |
| | objects.append(obj) |
| | continue |
| | except json.JSONDecodeError: |
| | pass |
| | |
| | for line in content.splitlines(): |
| | line = line.strip() |
| | if line: |
| | try: |
| | objects.append(json.loads(line)) |
| | except json.JSONDecodeError: |
| | continue |
| | self.train(objects) |
| |
|
| | def _extract_keys_and_values( |
| | self, |
| | obj: Any, |
| | key_counter: Counter[str], |
| | value_strings: list[str], |
| | ) -> None: |
| | """Recursively extract keys and value strings from a JSON object.""" |
| | if isinstance(obj, dict): |
| | for key, value in obj.items(): |
| | key_counter[key] += 1 |
| | |
| | value_strings.append(key) |
| | self._extract_keys_and_values(value, key_counter, value_strings) |
| | elif isinstance(obj, list): |
| | for item in obj: |
| | self._extract_keys_and_values(item, key_counter, value_strings) |
| | elif isinstance(obj, str): |
| | value_strings.append(obj) |
| | elif isinstance(obj, (int, float)): |
| | value_strings.append(str(obj)) |
| | |
| |
|
| | def _build_vocab_lookup(self) -> None: |
| | """Build the complete idβtoken mappings.""" |
| | self._id_to_token = {} |
| | self._token_to_id = {} |
| |
|
| | |
| | for i in range(StructuralTokens.RESERVED_END): |
| | name = StructuralTokens.name(i) |
| | self._id_to_token[i] = name |
| | self._token_to_id[name] = i |
| |
|
| | |
| | for key, tid in self._key_to_id.items(): |
| | token_name = f"Key({key})" |
| | self._id_to_token[tid] = token_name |
| | self._token_to_id[token_name] = tid |
| |
|
| | |
| | for bpe_token, bpe_id in self._bpe.vocab.items(): |
| | full_id = self._bpe_offset + bpe_id |
| | self._id_to_token[full_id] = f"BPE({bpe_token})" |
| | self._token_to_id[f"BPE({bpe_token})"] = full_id |
| |
|
| | |
| |
|
| | def encode(self, json_input: Union[str, Any]) -> list[int]: |
| | """Encode a JSON string or parsed object into token IDs. |
| | |
| | Args: |
| | json_input: Either a JSON string or an already-parsed Python object. |
| | |
| | Returns: |
| | List of integer token IDs. |
| | """ |
| | if isinstance(json_input, str): |
| | try: |
| | obj = json.loads(json_input) |
| | except json.JSONDecodeError: |
| | raise ValueError(f"Invalid JSON: {json_input[:100]}...") |
| | else: |
| | obj = json_input |
| |
|
| | tokens = [StructuralTokens.START] |
| | self._encode_value(obj, tokens) |
| | tokens.append(StructuralTokens.END) |
| | return tokens |
| |
|
| | def _encode_value(self, value: Any, tokens: list[int]) -> None: |
| | """Recursively encode a JSON value into tokens.""" |
| | if isinstance(value, dict): |
| | self._encode_object(value, tokens) |
| | elif isinstance(value, list): |
| | self._encode_array(value, tokens) |
| | elif isinstance(value, str): |
| | self._encode_string(value, tokens) |
| | elif isinstance(value, bool): |
| | |
| | tokens.append(StructuralTokens.TRUE if value else StructuralTokens.FALSE) |
| | elif isinstance(value, (int, float)): |
| | self._encode_number(value, tokens) |
| | elif value is None: |
| | tokens.append(StructuralTokens.NULL) |
| | else: |
| | tokens.append(StructuralTokens.UNK) |
| |
|
| | def _encode_object(self, obj: dict, tokens: list[int]) -> None: |
| | """Encode a JSON object.""" |
| | tokens.append(StructuralTokens.OBJ_START) |
| | for i, (key, value) in enumerate(obj.items()): |
| | if i > 0: |
| | tokens.append(StructuralTokens.COMMA) |
| | self._encode_key(key, tokens) |
| | tokens.append(StructuralTokens.COLON) |
| | self._encode_value(value, tokens) |
| | tokens.append(StructuralTokens.OBJ_END) |
| |
|
| | def _encode_array(self, arr: list, tokens: list[int]) -> None: |
| | """Encode a JSON array.""" |
| | tokens.append(StructuralTokens.ARR_START) |
| | for i, item in enumerate(arr): |
| | if i > 0: |
| | tokens.append(StructuralTokens.COMMA) |
| | self._encode_value(item, tokens) |
| | tokens.append(StructuralTokens.ARR_END) |
| |
|
| | def _encode_key(self, key: str, tokens: list[int]) -> None: |
| | """Encode a JSON key β uses key vocab if available, else BPE.""" |
| | if key in self._key_to_id: |
| | tokens.append(self._key_to_id[key]) |
| | else: |
| | tokens.append(StructuralTokens.KEY_PREFIX) |
| | bpe_ids = self._bpe.encode_to_ids(key) |
| | tokens.extend(self._bpe_offset + bid for bid in bpe_ids) |
| |
|
| | def _encode_string(self, value: str, tokens: list[int]) -> None: |
| | """Encode a JSON string value.""" |
| | tokens.append(StructuralTokens.STR_DELIM) |
| | if value: |
| | bpe_ids = self._bpe.encode_to_ids(value) |
| | tokens.extend(self._bpe_offset + bid for bid in bpe_ids) |
| | tokens.append(StructuralTokens.STR_DELIM) |
| |
|
| | def _encode_number(self, value: Union[int, float], tokens: list[int]) -> None: |
| | """Encode a JSON number value.""" |
| | tokens.append(StructuralTokens.NUM_PREFIX) |
| | |
| | if isinstance(value, float) and value == int(value) and "." in str(value): |
| | text = str(value) |
| | elif isinstance(value, int): |
| | text = str(value) |
| | else: |
| | text = repr(value) |
| | bpe_ids = self._bpe.encode_to_ids(text) |
| | tokens.extend(self._bpe_offset + bid for bid in bpe_ids) |
| |
|
| | |
| |
|
| | def decode(self, token_ids: list[int]) -> str: |
| | """Decode token IDs back to a JSON string. |
| | |
| | Args: |
| | token_ids: List of integer token IDs from encode(). |
| | |
| | Returns: |
| | JSON string faithful to the original. |
| | """ |
| | obj = self._decode_to_object(token_ids) |
| | return json.dumps(obj, ensure_ascii=False) |
| |
|
| | def decode_to_object(self, token_ids: list[int]) -> Any: |
| | """Decode token IDs back to a Python object.""" |
| | return self._decode_to_object(token_ids) |
| |
|
| | def _decode_to_object(self, token_ids: list[int]) -> Any: |
| | """Parse token IDs back into a Python object.""" |
| | |
| | ids = list(token_ids) |
| | if ids and ids[0] == StructuralTokens.START: |
| | ids = ids[1:] |
| | if ids and ids[-1] == StructuralTokens.END: |
| | ids = ids[:-1] |
| |
|
| | result, _ = self._parse_value(ids, 0) |
| | return result |
| |
|
| | def _parse_value(self, ids: list[int], pos: int) -> tuple[Any, int]: |
| | """Parse a single value starting at position pos.""" |
| | if pos >= len(ids): |
| | return None, pos |
| |
|
| | tid = ids[pos] |
| |
|
| | if tid == StructuralTokens.OBJ_START: |
| | return self._parse_object(ids, pos) |
| | elif tid == StructuralTokens.ARR_START: |
| | return self._parse_array(ids, pos) |
| | elif tid == StructuralTokens.STR_DELIM: |
| | return self._parse_string(ids, pos) |
| | elif tid == StructuralTokens.NUM_PREFIX: |
| | return self._parse_number(ids, pos) |
| | elif tid == StructuralTokens.NULL: |
| | return None, pos + 1 |
| | elif tid == StructuralTokens.TRUE: |
| | return True, pos + 1 |
| | elif tid == StructuralTokens.FALSE: |
| | return False, pos + 1 |
| | else: |
| | return None, pos + 1 |
| |
|
| | def _parse_object(self, ids: list[int], pos: int) -> tuple[dict, int]: |
| | """Parse a JSON object from token IDs.""" |
| | assert ids[pos] == StructuralTokens.OBJ_START |
| | pos += 1 |
| | result: dict[str, Any] = {} |
| |
|
| | while pos < len(ids) and ids[pos] != StructuralTokens.OBJ_END: |
| | if ids[pos] == StructuralTokens.COMMA: |
| | pos += 1 |
| | continue |
| |
|
| | |
| | key, pos = self._parse_key(ids, pos) |
| |
|
| | |
| | if pos < len(ids) and ids[pos] == StructuralTokens.COLON: |
| | pos += 1 |
| |
|
| | |
| | value, pos = self._parse_value(ids, pos) |
| | result[key] = value |
| |
|
| | if pos < len(ids) and ids[pos] == StructuralTokens.OBJ_END: |
| | pos += 1 |
| |
|
| | return result, pos |
| |
|
| | def _parse_array(self, ids: list[int], pos: int) -> tuple[list, int]: |
| | """Parse a JSON array from token IDs.""" |
| | assert ids[pos] == StructuralTokens.ARR_START |
| | pos += 1 |
| | result: list[Any] = [] |
| |
|
| | while pos < len(ids) and ids[pos] != StructuralTokens.ARR_END: |
| | if ids[pos] == StructuralTokens.COMMA: |
| | pos += 1 |
| | continue |
| |
|
| | value, pos = self._parse_value(ids, pos) |
| | result.append(value) |
| |
|
| | if pos < len(ids) and ids[pos] == StructuralTokens.ARR_END: |
| | pos += 1 |
| |
|
| | return result, pos |
| |
|
| | def _parse_key(self, ids: list[int], pos: int) -> tuple[str, int]: |
| | """Parse a key from token IDs.""" |
| | tid = ids[pos] |
| |
|
| | |
| | if tid in self._id_to_key: |
| | return self._id_to_key[tid], pos + 1 |
| |
|
| | |
| | if tid == StructuralTokens.KEY_PREFIX: |
| | pos += 1 |
| | bpe_tokens: list[str] = [] |
| | while pos < len(ids) and ids[pos] >= self._bpe_offset: |
| | bpe_id = ids[pos] - self._bpe_offset |
| | bpe_tokens.append(self._bpe.id_to_token(bpe_id)) |
| | pos += 1 |
| | |
| | if pos < len(ids) and ids[pos] == StructuralTokens.COLON: |
| | break |
| | return self._bpe.decode_tokens(bpe_tokens), pos |
| |
|
| | return f"<unknown_key_{tid}>", pos + 1 |
| |
|
| | def _parse_string(self, ids: list[int], pos: int) -> tuple[str, int]: |
| | """Parse a string value from token IDs.""" |
| | assert ids[pos] == StructuralTokens.STR_DELIM |
| | pos += 1 |
| |
|
| | bpe_tokens: list[str] = [] |
| | while pos < len(ids) and ids[pos] != StructuralTokens.STR_DELIM: |
| | bpe_id = ids[pos] - self._bpe_offset |
| | bpe_tokens.append(self._bpe.id_to_token(bpe_id)) |
| | pos += 1 |
| |
|
| | |
| | if pos < len(ids) and ids[pos] == StructuralTokens.STR_DELIM: |
| | pos += 1 |
| |
|
| | return self._bpe.decode_tokens(bpe_tokens), pos |
| |
|
| | def _parse_number(self, ids: list[int], pos: int) -> tuple[Union[int, float], int]: |
| | """Parse a number value from token IDs.""" |
| | assert ids[pos] == StructuralTokens.NUM_PREFIX |
| | pos += 1 |
| |
|
| | bpe_tokens: list[str] = [] |
| | while pos < len(ids): |
| | tid = ids[pos] |
| | if tid < self._bpe_offset: |
| | break |
| | bpe_id = tid - self._bpe_offset |
| | bpe_tokens.append(self._bpe.id_to_token(bpe_id)) |
| | pos += 1 |
| |
|
| | text = self._bpe.decode_tokens(bpe_tokens).strip() |
| | try: |
| | if "." in text or "e" in text.lower(): |
| | return float(text), pos |
| | return int(text), pos |
| | except ValueError: |
| | return 0, pos |
| |
|
| | |
| |
|
| | def decode_tokens_readable(self, token_ids: list[int]) -> list[str]: |
| | """Convert token IDs to human-readable token names.""" |
| | result: list[str] = [] |
| | for tid in token_ids: |
| | if tid in self._id_to_token: |
| | result.append(self._id_to_token[tid]) |
| | elif tid in self._id_to_key: |
| | result.append(f"Key({self._id_to_key[tid]})") |
| | else: |
| | bpe_id = tid - self._bpe_offset |
| | token_str = self._bpe.id_to_token(bpe_id) |
| | result.append(f"BPE({repr(token_str)})") |
| | return result |
| |
|
| | def token_count(self, json_input: Union[str, Any]) -> int: |
| | """Count tokens for a JSON input without materializing full list.""" |
| | return len(self.encode(json_input)) |
| |
|
| | |
| |
|
| | def save(self, directory: str) -> None: |
| | """Save the full tokenizer state to a directory.""" |
| | import os |
| |
|
| | os.makedirs(directory, exist_ok=True) |
| |
|
| | |
| | self._bpe.save(os.path.join(directory, "bpe_model.json")) |
| |
|
| | |
| | config = { |
| | "version": "json-tokenizer-v1", |
| | "bpe_vocab_size": self.bpe_vocab_size, |
| | "max_key_vocab": self.max_key_vocab, |
| | "min_key_freq": self.min_key_freq, |
| | "bpe_min_freq": self.bpe_min_freq, |
| | "key_vocab": self._key_to_id, |
| | "key_offset": self._key_offset, |
| | "bpe_offset": self._bpe_offset, |
| | } |
| | with open(os.path.join(directory, "tokenizer_config.json"), "w") as f: |
| | json.dump(config, f, indent=2) |
| |
|
| | @classmethod |
| | def load(cls, directory: str) -> "JSONTokenizer": |
| | """Load a trained tokenizer from a directory.""" |
| | import os |
| |
|
| | with open(os.path.join(directory, "tokenizer_config.json")) as f: |
| | config = json.load(f) |
| |
|
| | tokenizer = cls( |
| | bpe_vocab_size=config["bpe_vocab_size"], |
| | max_key_vocab=config["max_key_vocab"], |
| | min_key_freq=config["min_key_freq"], |
| | bpe_min_freq=config["bpe_min_freq"], |
| | ) |
| |
|
| | |
| | tokenizer._key_to_id = config["key_vocab"] |
| | tokenizer._id_to_key = {int(v): k for k, v in config["key_vocab"].items()} |
| | tokenizer._key_offset = config["key_offset"] |
| | tokenizer._bpe_offset = config["bpe_offset"] |
| |
|
| | |
| | tokenizer._bpe = BPETrainer.load(os.path.join(directory, "bpe_model.json")) |
| |
|
| | tokenizer._build_vocab_lookup() |
| | tokenizer._trained = True |
| | return tokenizer |
| |
|