"""
JSON-optimized tokenizer.

Design principles:

1. Structural tokens: JSON grammar symbols ({, }, [, ], :, comma) each get a
   dedicated single token — no wasted subword splits on syntax.
2. Key vocabulary: Frequently occurring JSON keys get their own tokens
   (Key(name), Key(id), etc.), massively reducing token count for repetitive
   schemas.
3. Type-marked values: String values are wrapped in [STR] delimiters, numbers
   are prefixed with [NUM], and true/false/null are single tokens, so the
   tokenizer preserves JSON types on roundtrip.
4. BPE for value content: String and number content is tokenized via a BPE
   codec trained on JSON value distributions.
5. Nesting tokens: dedicated object and array start/end tokens encode
   hierarchy without ambiguity.
"""

from __future__ import annotations

import json
import re
from collections import Counter
from typing import Any, Optional, Union

from json_tokenizer.bpe import BPETrainer


# ── Structural token constants ──────────────────────────────────────────


class StructuralTokens:
    """Reserved token IDs for JSON grammar elements."""

    PAD = 0
    START = 1  # start of JSON document
    END = 2  # end of JSON document
    OBJ_START = 3  # {
    OBJ_END = 4  # }
    ARR_START = 5  # [
    ARR_END = 6  # ]
    COLON = 7  # :
    COMMA = 8  # ,
    NULL = 9  # null value
    TRUE = 10  # true
    FALSE = 11  # false
    STR_DELIM = 12  # marks start/end of a string value
    NUM_PREFIX = 13  # marks start of a number value
    KEY_PREFIX = 14  # marks start of a key (if not in key vocab)
    UNK = 15  # unknown token
    # IDs 16-31 reserved for future structural tokens
    RESERVED_END = 32

    # Display names for the assigned IDs; built once rather than on every
    # name() call.
    _NAMES = {
        PAD: "[PAD]",
        START: "[START]",
        END: "[END]",
        OBJ_START: "{",
        OBJ_END: "}",
        ARR_START: "[",
        ARR_END: "]",
        COLON: ":",
        COMMA: ",",
        NULL: "null",
        TRUE: "true",
        FALSE: "false",
        STR_DELIM: "[STR]",
        NUM_PREFIX: "[NUM]",
        KEY_PREFIX: "[KEY]",
        UNK: "[UNK]",
    }

    @classmethod
    def name(cls, token_id: int) -> str:
        """Return the display name for a structural token ID.

        IDs without an assigned name are rendered as ``[RESERVED_<id>]``.
        """
        return cls._NAMES.get(token_id, f"[RESERVED_{token_id}]")


class JSONTokenizer:
    """Tokenizer optimized for JSON structures.

    Encodes JSON into a compact token sequence with:
    - Single tokens for structural elements
    - Dedicated key tokens for common keys
    - BPE subword tokens for string/number values
    - Roundtrip of the parsed value (decode re-serializes with json.dumps,
      so the original whitespace/formatting is not reproduced)

    Usage:
        tokenizer = JSONTokenizer()
        tokenizer.train_from_json_files(["data1.json", "data2.json"])
        ids = tokenizer.encode('{"name": "Alice", "age": 30}')
        decoded = tokenizer.decode(ids)
    """

    def __init__(
        self,
        bpe_vocab_size: int = 4096,
        max_key_vocab: int = 1024,
        min_key_freq: int = 2,
        bpe_min_freq: int = 2,
    ):
        """Create an untrained tokenizer.

        Args:
            bpe_vocab_size: Vocabulary size passed to the BPE trainer.
            max_key_vocab: Maximum number of keys that get a dedicated token.
            min_key_freq: Minimum occurrence count for a key to get a token.
            bpe_min_freq: Minimum frequency passed to the BPE trainer.
        """
        self.bpe_vocab_size = bpe_vocab_size
        self.max_key_vocab = max_key_vocab
        self.min_key_freq = min_key_freq
        self.bpe_min_freq = bpe_min_freq

        # Key vocabulary: key_string → token_id
        self._key_to_id: dict[str, int] = {}
        self._id_to_key: dict[int, str] = {}
        self._key_offset = StructuralTokens.RESERVED_END

        # BPE for values
        self._bpe = BPETrainer(vocab_size=bpe_vocab_size, min_frequency=bpe_min_freq)
        self._bpe_offset = 0  # set after key vocab is built

        # Full vocab
        self._id_to_token: dict[int, str] = {}
        self._token_to_id: dict[str, int] = {}

        self._trained = False

    @property
    def vocab_size(self) -> int:
        """Total vocabulary size (structural + key + BPE tokens).

        Before training only the reserved structural range is counted.
        """
        if not self._trained:
            return StructuralTokens.RESERVED_END
        return self._bpe_offset + len(self._bpe.vocab)

    # ── Training ────────────────────────────────────────────────────────

    def train(self, json_objects: list[Any]) -> None:
        """Train the tokenizer from a list of parsed JSON objects.

        Extracts keys for the key vocabulary and values for BPE training.

        Args:
            json_objects: List of parsed JSON values (dicts, lists, primitives).
        """
        key_counter: Counter[str] = Counter()
        value_strings: list[str] = []

        for obj in json_objects:
            self._extract_keys_and_values(obj, key_counter, value_strings)

        # Build key vocabulary from most common keys
        top_keys = [
            k
            for k, count in key_counter.most_common(self.max_key_vocab)
            if count >= self.min_key_freq
        ]

        self._key_to_id = {}
        self._id_to_key = {}
        for i, key in enumerate(top_keys):
            tid = self._key_offset + i
            self._key_to_id[key] = tid
            self._id_to_key[tid] = key

        # BPE offset is after key vocab
        self._bpe_offset = self._key_offset + len(self._key_to_id)

        # Train BPE on value strings
        if value_strings:
            self._bpe.train(value_strings)

        # Build full vocab lookup
        self._build_vocab_lookup()
        self._trained = True

    def train_from_json_strings(self, json_strings: list[str]) -> None:
        """Train from raw JSON strings.

        Strings that are not valid JSON are skipped (best-effort).
        """
        objects = []
        for s in json_strings:
            try:
                objects.append(json.loads(s))
            except json.JSONDecodeError:
                continue
        self.train(objects)

    def train_from_json_files(self, file_paths: list[str]) -> None:
        """Train from JSON files (one JSON document per file, or JSONL).

        A file whose top-level value is a list contributes each element as a
        separate training object. Files that do not parse as a single JSON
        document are read line by line as JSONL; unparseable lines are skipped.

        Args:
            file_paths: Paths of the files to read.
        """
        objects = []
        for path in file_paths:
            # JSON text is UTF-8; do not depend on the platform's locale.
            with open(path, encoding="utf-8") as f:
                content = f.read().strip()

            # Try as single JSON document
            try:
                obj = json.loads(content)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, list):
                    objects.extend(obj)
                else:
                    objects.append(obj)
                continue

            # Try as JSONL
            for line in content.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    objects.append(json.loads(line))
                except json.JSONDecodeError:
                    continue

        self.train(objects)

    def _extract_keys_and_values(
        self,
        obj: Any,
        key_counter: Counter[str],
        value_strings: list[str],
    ) -> None:
        """Recursively extract keys and value strings from a JSON object.

        Args:
            obj: Parsed JSON value to walk.
            key_counter: Counter updated in place with every object key seen.
            value_strings: List extended in place with BPE training strings.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                key_counter[key] += 1
                # Also train BPE on key strings (they appear as values too)
                value_strings.append(key)
                self._extract_keys_and_values(value, key_counter, value_strings)
        elif isinstance(obj, list):
            for item in obj:
                self._extract_keys_and_values(item, key_counter, value_strings)
        elif isinstance(obj, str):
            value_strings.append(obj)
        elif isinstance(obj, bool):
            # bool is a subclass of int; it must be excluded explicitly because
            # true/false are structural tokens and never go through BPE.
            return
        elif isinstance(obj, (int, float)):
            value_strings.append(str(obj))
        # None doesn't need BPE either (it is a structural token)

    def _build_vocab_lookup(self) -> None:
        """Build the complete id↔token mappings."""
        self._id_to_token = {}
        self._token_to_id = {}

        # Structural tokens
        for i in range(StructuralTokens.RESERVED_END):
            name = StructuralTokens.name(i)
            self._id_to_token[i] = name
            self._token_to_id[name] = i

        # Key tokens
        for key, tid in self._key_to_id.items():
            token_name = f"Key({key})"
            self._id_to_token[tid] = token_name
            self._token_to_id[token_name] = tid

        # BPE tokens
        for bpe_token, bpe_id in self._bpe.vocab.items():
            full_id = self._bpe_offset + bpe_id
            token_name = f"BPE({bpe_token})"
            self._id_to_token[full_id] = token_name
            self._token_to_id[token_name] = full_id

    # ── Encoding ────────────────────────────────────────────────────────

    def encode(self, json_input: Union[str, Any]) -> list[int]:
        """Encode a JSON string or parsed object into token IDs.

        Args:
            json_input: Either a JSON string or an already-parsed Python
                object. A ``str`` is always parsed as JSON text.

        Returns:
            List of integer token IDs, wrapped in START/END tokens.

        Raises:
            ValueError: If ``json_input`` is a string that is not valid JSON.
        """
        if isinstance(json_input, str):
            try:
                obj = json.loads(json_input)
            except json.JSONDecodeError as err:
                # Chain the decoder error so the position info is not lost.
                raise ValueError(f"Invalid JSON: {json_input[:100]}...") from err
        else:
            obj = json_input

        tokens = [StructuralTokens.START]
        self._encode_value(obj, tokens)
        tokens.append(StructuralTokens.END)
        return tokens

    def _encode_value(self, value: Any, tokens: list[int]) -> None:
        """Recursively encode a JSON value, appending its tokens to ``tokens``.

        Values of unsupported types are emitted as a single UNK token.
        """
        if isinstance(value, dict):
            self._encode_object(value, tokens)
        elif isinstance(value, list):
            self._encode_array(value, tokens)
        elif isinstance(value, str):
            self._encode_string(value, tokens)
        elif isinstance(value, bool):
            # Must check bool before int (bool is subclass of int in Python)
            tokens.append(StructuralTokens.TRUE if value else StructuralTokens.FALSE)
        elif isinstance(value, (int, float)):
            self._encode_number(value, tokens)
        elif value is None:
            tokens.append(StructuralTokens.NULL)
        else:
            tokens.append(StructuralTokens.UNK)

    def _encode_object(self, obj: dict, tokens: list[int]) -> None:
        """Encode a JSON object as ``{ key : value , ... }`` tokens."""
        tokens.append(StructuralTokens.OBJ_START)
        for i, (key, value) in enumerate(obj.items()):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_key(key, tokens)
            tokens.append(StructuralTokens.COLON)
            self._encode_value(value, tokens)
        tokens.append(StructuralTokens.OBJ_END)

    def _encode_array(self, arr: list, tokens: list[int]) -> None:
        """Encode a JSON array as ``[ value , ... ]`` tokens."""
        tokens.append(StructuralTokens.ARR_START)
        for i, item in enumerate(arr):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_value(item, tokens)
        tokens.append(StructuralTokens.ARR_END)

    def _encode_key(self, key: str, tokens: list[int]) -> None:
        """Encode a JSON key — uses key vocab if available, else BPE."""
        if key in self._key_to_id:
            tokens.append(self._key_to_id[key])
        else:
            tokens.append(StructuralTokens.KEY_PREFIX)
            bpe_ids = self._bpe.encode_to_ids(key)
            tokens.extend(self._bpe_offset + bid for bid in bpe_ids)

    def _encode_string(self, value: str, tokens: list[int]) -> None:
        """Encode a JSON string value between two STR_DELIM tokens."""
        tokens.append(StructuralTokens.STR_DELIM)
        if value:  # don't BPE-encode empty strings
            bpe_ids = self._bpe.encode_to_ids(value)
            tokens.extend(self._bpe_offset + bid for bid in bpe_ids)
        tokens.append(StructuralTokens.STR_DELIM)

    def _encode_number(self, value: Union[int, float], tokens: list[int]) -> None:
        """Encode a JSON number value as NUM_PREFIX followed by BPE tokens.

        Floats are rendered with ``float.__repr__`` (e.g. ``1.0``, ``1e+16``,
        ``nan``, ``inf``) and ints with ``str``, which preserves the int vs
        float distinction. No integer conversion is attempted, so NaN and
        ±Infinity (which ``json.loads`` accepts) no longer raise.
        """
        tokens.append(StructuralTokens.NUM_PREFIX)
        if isinstance(value, float):
            # float.__repr__ bypasses any subclass __repr__ override.
            text = float.__repr__(value)
        else:
            text = str(value)
        bpe_ids = self._bpe.encode_to_ids(text)
        tokens.extend(self._bpe_offset + bid for bid in bpe_ids)

    # ── Decoding ────────────────────────────────────────────────────────

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to a JSON string.

        The tokens are parsed into a Python object and re-serialized with
        ``json.dumps(..., ensure_ascii=False)``; formatting of the originally
        encoded text is therefore not reproduced.

        Args:
            token_ids: List of integer token IDs from encode().

        Returns:
            JSON string for the decoded value.
        """
        obj = self._decode_to_object(token_ids)
        return json.dumps(obj, ensure_ascii=False)

    def decode_to_object(self, token_ids: list[int]) -> Any:
        """Decode token IDs back to a Python object."""
        return self._decode_to_object(token_ids)

    def _decode_to_object(self, token_ids: list[int]) -> Any:
        """Parse token IDs back into a Python object."""
        # Strip START/END
        ids = list(token_ids)
        if ids and ids[0] == StructuralTokens.START:
            ids = ids[1:]
        if ids and ids[-1] == StructuralTokens.END:
            ids = ids[:-1]

        result, _ = self._parse_value(ids, 0)
        return result

    def _parse_value(self, ids: list[int], pos: int) -> tuple[Any, int]:
        """Parse a single value starting at position pos.

        Returns:
            ``(value, next_pos)``. An out-of-range position or an unexpected
            token yields ``None``.
        """
        if pos >= len(ids):
            return None, pos

        tid = ids[pos]

        if tid == StructuralTokens.OBJ_START:
            return self._parse_object(ids, pos)
        elif tid == StructuralTokens.ARR_START:
            return self._parse_array(ids, pos)
        elif tid == StructuralTokens.STR_DELIM:
            return self._parse_string(ids, pos)
        elif tid == StructuralTokens.NUM_PREFIX:
            return self._parse_number(ids, pos)
        elif tid == StructuralTokens.NULL:
            return None, pos + 1
        elif tid == StructuralTokens.TRUE:
            return True, pos + 1
        elif tid == StructuralTokens.FALSE:
            return False, pos + 1
        else:
            return None, pos + 1

    def _parse_object(self, ids: list[int], pos: int) -> tuple[dict, int]:
        """Parse a JSON object from token IDs; ``ids[pos]`` must be OBJ_START."""
        assert ids[pos] == StructuralTokens.OBJ_START
        pos += 1
        result: dict[str, Any] = {}

        while pos < len(ids) and ids[pos] != StructuralTokens.OBJ_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue

            # Parse key
            key, pos = self._parse_key(ids, pos)

            # Expect colon
            if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                pos += 1

            # Parse value
            value, pos = self._parse_value(ids, pos)
            result[key] = value

        if pos < len(ids) and ids[pos] == StructuralTokens.OBJ_END:
            pos += 1

        return result, pos

    def _parse_array(self, ids: list[int], pos: int) -> tuple[list, int]:
        """Parse a JSON array from token IDs; ``ids[pos]`` must be ARR_START."""
        assert ids[pos] == StructuralTokens.ARR_START
        pos += 1
        result: list[Any] = []

        while pos < len(ids) and ids[pos] != StructuralTokens.ARR_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue
            value, pos = self._parse_value(ids, pos)
            result.append(value)

        if pos < len(ids) and ids[pos] == StructuralTokens.ARR_END:
            pos += 1

        return result, pos

    def _parse_key(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a key from token IDs.

        Returns:
            ``(key, next_pos)``. A token that is neither a key-vocabulary ID
            nor KEY_PREFIX is consumed and yields an empty key.
        """
        tid = ids[pos]

        # Check key vocabulary
        if tid in self._id_to_key:
            return self._id_to_key[tid], pos + 1

        # KEY_PREFIX → BPE-encoded key
        if tid == StructuralTokens.KEY_PREFIX:
            pos += 1
            bpe_tokens: list[str] = []
            while pos < len(ids) and ids[pos] >= self._bpe_offset:
                bpe_id = ids[pos] - self._bpe_offset
                bpe_tokens.append(self._bpe.id_to_token(bpe_id))
                pos += 1
                # Stop before COLON
                if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                    break
            return self._bpe.decode_tokens(bpe_tokens), pos

        return "", pos + 1

    def _parse_string(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a string value from token IDs; ``ids[pos]`` must be STR_DELIM."""
        assert ids[pos] == StructuralTokens.STR_DELIM
        pos += 1
        bpe_tokens: list[str] = []

        while pos < len(ids) and ids[pos] != StructuralTokens.STR_DELIM:
            bpe_id = ids[pos] - self._bpe_offset
            bpe_tokens.append(self._bpe.id_to_token(bpe_id))
            pos += 1

        # Skip closing delimiter
        if pos < len(ids) and ids[pos] == StructuralTokens.STR_DELIM:
            pos += 1

        return self._bpe.decode_tokens(bpe_tokens), pos

    def _parse_number(self, ids: list[int], pos: int) -> tuple[Union[int, float], int]:
        """Parse a number value from token IDs; ``ids[pos]`` must be NUM_PREFIX.

        The decoded text is tried as an ``int`` first and as a ``float``
        second, so ``1.5``, ``1e+16``, ``nan``, ``inf`` and ``-inf`` all
        decode to floats. Text that is not a number at all falls back to ``0``.
        """
        assert ids[pos] == StructuralTokens.NUM_PREFIX
        pos += 1
        bpe_tokens: list[str] = []

        while pos < len(ids):
            tid = ids[pos]
            if tid < self._bpe_offset:
                break  # hit a structural token
            bpe_id = tid - self._bpe_offset
            bpe_tokens.append(self._bpe.id_to_token(bpe_id))
            pos += 1

        text = self._bpe.decode_tokens(bpe_tokens).strip()
        try:
            return int(text), pos
        except ValueError:
            pass
        try:
            return float(text), pos
        except ValueError:
            return 0, pos

    # ── Inspection / Debug ──────────────────────────────────────────────

    def decode_tokens_readable(self, token_ids: list[int]) -> list[str]:
        """Convert token IDs to human-readable token names."""
        result: list[str] = []
        for tid in token_ids:
            if tid in self._id_to_token:
                result.append(self._id_to_token[tid])
            elif tid in self._id_to_key:
                result.append(f"Key({self._id_to_key[tid]})")
            else:
                bpe_id = tid - self._bpe_offset
                token_str = self._bpe.id_to_token(bpe_id)
                result.append(f"BPE({repr(token_str)})")
        return result

    def token_count(self, json_input: Union[str, Any]) -> int:
        """Count the tokens ``encode`` produces for a JSON input.

        The input is fully encoded and the resulting list measured.
        """
        return len(self.encode(json_input))

    # ── Persistence ─────────────────────────────────────────────────────

    def save(self, directory: str) -> None:
        """Save the full tokenizer state to a directory.

        Writes ``bpe_model.json`` (via the BPE trainer) and
        ``tokenizer_config.json``; the directory is created if missing.
        """
        import os

        os.makedirs(directory, exist_ok=True)

        # Save BPE model
        self._bpe.save(os.path.join(directory, "bpe_model.json"))

        # Save key vocabulary and config
        config = {
            "version": "json-tokenizer-v1",
            "bpe_vocab_size": self.bpe_vocab_size,
            "max_key_vocab": self.max_key_vocab,
            "min_key_freq": self.min_key_freq,
            "bpe_min_freq": self.bpe_min_freq,
            "key_vocab": self._key_to_id,
            "key_offset": self._key_offset,
            "bpe_offset": self._bpe_offset,
        }
        config_path = os.path.join(directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls, directory: str) -> "JSONTokenizer":
        """Load a trained tokenizer from a directory written by ``save``."""
        import os

        config_path = os.path.join(directory, "tokenizer_config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        tokenizer = cls(
            bpe_vocab_size=config["bpe_vocab_size"],
            max_key_vocab=config["max_key_vocab"],
            min_key_freq=config["min_key_freq"],
            bpe_min_freq=config["bpe_min_freq"],
        )

        # Restore key vocab
        tokenizer._key_to_id = config["key_vocab"]
        tokenizer._id_to_key = {int(v): k for k, v in config["key_vocab"].items()}
        tokenizer._key_offset = config["key_offset"]
        tokenizer._bpe_offset = config["bpe_offset"]

        # Load BPE
        tokenizer._bpe = BPETrainer.load(os.path.join(directory, "bpe_model.json"))

        tokenizer._build_vocab_lookup()
        tokenizer._trained = True
        return tokenizer