# Hugging Face Hub page residue (not part of the module):
#   uploader: anthonym21
#   commit 7e0e0d5 (verified): "Upload json_tokenizer/tokenizer.py with huggingface_hub"
#   file size: 21 kB
"""
JSON-optimized tokenizer.
Design principles:
1. Structural tokens: JSON grammar symbols ({, }, [, ], :, comma) each get
a dedicated single token — no wasted subword splits on syntax.
2. Key vocabulary: Frequently occurring JSON keys get their own tokens
(Key(name), Key(id), etc.), massively reducing token count for
repetitive schemas.
3. Type-prefixed values: Values are prefixed with a type marker
(STR:, NUM:, BOOL:, NULL) so the tokenizer preserves JSON types
for lossless roundtrip.
4. BPE for value content: String and number content is tokenized via
a BPE codec trained on JSON value distributions.
5. Nesting tokens: OBJ_START/OBJ_END and ARR_START/ARR_END tokens encode
hierarchy without ambiguity.
"""
from __future__ import annotations
import json
import re
from collections import Counter
from typing import Any, Optional, Union
from json_tokenizer.bpe import BPETrainer
# ── Structural token constants ──────────────────────────────────────────
class StructuralTokens:
    """Fixed token IDs for JSON grammar symbols and document framing."""

    PAD = 0
    START = 1  # start of JSON document
    END = 2  # end of JSON document
    OBJ_START = 3  # {
    OBJ_END = 4  # }
    ARR_START = 5  # [ (generic, length encoded separately)
    ARR_END = 6  # ]
    COLON = 7  # :
    COMMA = 8  # ,
    NULL = 9  # null value
    TRUE = 10  # true
    FALSE = 11  # false
    STR_DELIM = 12  # marks start/end of a string value
    NUM_PREFIX = 13  # marks start of a number value
    KEY_PREFIX = 14  # marks start of a key (if not in key vocab)
    UNK = 15  # unknown token
    # IDs 16-31 reserved for future structural tokens
    RESERVED_END = 32

    @classmethod
    def name(cls, token_id: int) -> str:
        """Return the display label for a structural token ID.

        Args:
            token_id: Token ID to look up.

        Returns:
            The label of the token, or "[RESERVED_<id>]" when the ID has
            no dedicated label.
        """
        # Position in this tuple is the token ID (0 through 15).
        ordered_labels = (
            "[PAD]",
            "[START]",
            "[END]",
            "{",
            "}",
            "[",
            "]",
            ":",
            ",",
            "null",
            "true",
            "false",
            "[STR]",
            "[NUM]",
            "[KEY]",
            "[UNK]",
        )
        label_by_id = dict(enumerate(ordered_labels))
        fallback = f"[RESERVED_{token_id}]"
        return label_by_id.get(token_id, fallback)
class JSONTokenizer:
    """Tokenizer optimized for JSON structures.

    Encodes JSON into a compact token sequence with:
    - Single tokens for structural elements
    - Dedicated key tokens for common keys
    - BPE subword tokens for string/number values
    - Full roundtrip fidelity (encode → decode == original)

    Token ID layout: structural tokens occupy
    ``[0, StructuralTokens.RESERVED_END)``, key tokens start at
    ``_key_offset``, and BPE tokens start at ``_bpe_offset`` (directly after
    the key vocabulary).

    Usage:
        tokenizer = JSONTokenizer()
        tokenizer.train_from_json_files(["data1.json", "data2.json"])
        ids = tokenizer.encode('{"name": "Alice", "age": 30}')
        decoded = tokenizer.decode(ids)
    """

    def __init__(
        self,
        bpe_vocab_size: int = 4096,
        max_key_vocab: int = 1024,
        min_key_freq: int = 2,
        bpe_min_freq: int = 2,
    ):
        """Create an untrained tokenizer.

        Args:
            bpe_vocab_size: Vocabulary size passed to the BPE trainer.
            max_key_vocab: Maximum number of keys that get a dedicated token.
            min_key_freq: Minimum occurrence count for a key to get a token.
            bpe_min_freq: Minimum frequency passed to the BPE trainer.
        """
        self.bpe_vocab_size = bpe_vocab_size
        self.max_key_vocab = max_key_vocab
        self.min_key_freq = min_key_freq
        self.bpe_min_freq = bpe_min_freq
        # Key vocabulary: key_string → token_id
        self._key_to_id: dict[str, int] = {}
        self._id_to_key: dict[int, str] = {}
        self._key_offset = StructuralTokens.RESERVED_END
        # BPE for values
        self._bpe = BPETrainer(vocab_size=bpe_vocab_size, min_frequency=bpe_min_freq)
        self._bpe_offset = 0  # set after key vocab is built
        # Full vocab
        self._id_to_token: dict[int, str] = {}
        self._token_to_id: dict[str, int] = {}
        self._trained = False

    @property
    def vocab_size(self) -> int:
        """Total vocabulary size (structural range only until trained)."""
        if not self._trained:
            return StructuralTokens.RESERVED_END
        return self._bpe_offset + len(self._bpe.vocab)

    # ── Training ────────────────────────────────────────────────────────
    def train(self, json_objects: list[Any]) -> None:
        """Train the tokenizer from a list of parsed JSON objects.

        Extracts keys for the key vocabulary and values for BPE training.

        Args:
            json_objects: List of parsed JSON values (dicts, lists, primitives).
        """
        key_counter: Counter[str] = Counter()
        value_strings: list[str] = []
        for obj in json_objects:
            self._extract_keys_and_values(obj, key_counter, value_strings)

        # Build key vocabulary from most common keys
        top_keys = [
            k
            for k, count in key_counter.most_common(self.max_key_vocab)
            if count >= self.min_key_freq
        ]
        self._key_to_id = {}
        self._id_to_key = {}
        for i, key in enumerate(top_keys):
            tid = self._key_offset + i
            self._key_to_id[key] = tid
            self._id_to_key[tid] = key

        # BPE offset is after key vocab
        self._bpe_offset = self._key_offset + len(self._key_to_id)

        # Train BPE on value strings
        if value_strings:
            self._bpe.train(value_strings)

        # Build full vocab lookup
        self._build_vocab_lookup()
        self._trained = True

    def train_from_json_strings(self, json_strings: list[str]) -> None:
        """Train from raw JSON strings; strings that fail to parse are skipped."""
        objects: list[Any] = []
        for s in json_strings:
            try:
                objects.append(json.loads(s))
            except json.JSONDecodeError:
                # Best-effort training: ignore malformed documents.
                continue
        self.train(objects)

    def train_from_json_files(self, file_paths: list[str]) -> None:
        """Train from JSON files (one JSON document per file, or JSONL).

        A file whose whole content parses as a JSON list contributes each
        list element as a separate training object. Files that are not a
        single JSON document are read line by line as JSONL; lines that fail
        to parse are skipped.

        Args:
            file_paths: Paths of the files to read.
        """
        objects: list[Any] = []
        for path in file_paths:
            # Read as UTF-8 explicitly instead of the locale-dependent default.
            with open(path, encoding="utf-8") as f:
                content = f.read().strip()

            # Try as single JSON document
            try:
                obj = json.loads(content)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, list):
                    objects.extend(obj)
                else:
                    objects.append(obj)
                continue

            # Try as JSONL
            for line in content.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    objects.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        self.train(objects)

    def _extract_keys_and_values(
        self,
        obj: Any,
        key_counter: Counter[str],
        value_strings: list[str],
    ) -> None:
        """Recursively collect key counts and BPE training strings.

        Args:
            obj: Parsed JSON value to walk.
            key_counter: Counter updated in place with every object key seen.
            value_strings: List extended in place with keys, string values
                and the text form of numbers.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                key_counter[key] += 1
                # Also train BPE on key strings (they appear as values too)
                value_strings.append(key)
                self._extract_keys_and_values(value, key_counter, value_strings)
        elif isinstance(obj, list):
            for item in obj:
                self._extract_keys_and_values(item, key_counter, value_strings)
        elif isinstance(obj, str):
            value_strings.append(obj)
        elif isinstance(obj, bool):
            # bool is a subclass of int, so it must be excluded before the
            # numeric branch; booleans are structural tokens, not BPE text.
            return
        elif isinstance(obj, (int, float)):
            value_strings.append(self._format_number(obj))
        # None doesn't need BPE either (it is a structural token)

    def _build_vocab_lookup(self) -> None:
        """Build the complete id↔token mappings."""
        self._id_to_token = {}
        self._token_to_id = {}

        # Structural tokens
        for i in range(StructuralTokens.RESERVED_END):
            name = StructuralTokens.name(i)
            self._id_to_token[i] = name
            self._token_to_id[name] = i

        # Key tokens
        for key, tid in self._key_to_id.items():
            token_name = f"Key({key})"
            self._id_to_token[tid] = token_name
            self._token_to_id[token_name] = tid

        # BPE tokens
        for bpe_token, bpe_id in self._bpe.vocab.items():
            full_id = self._bpe_offset + bpe_id
            token_name = f"BPE({bpe_token})"
            self._id_to_token[full_id] = token_name
            self._token_to_id[token_name] = full_id

    # ── Encoding ────────────────────────────────────────────────────────
    def encode(self, json_input: Union[str, Any]) -> list[int]:
        """Encode a JSON string or parsed object into token IDs.

        Args:
            json_input: Either a JSON string or an already-parsed Python object.

        Returns:
            List of integer token IDs, framed by START and END.

        Raises:
            ValueError: If ``json_input`` is a string that is not valid JSON.
        """
        if isinstance(json_input, str):
            try:
                obj = json.loads(json_input)
            except json.JSONDecodeError as exc:
                # Chain the parser error so the failing position is not lost.
                raise ValueError(f"Invalid JSON: {json_input[:100]}...") from exc
        else:
            obj = json_input
        tokens = [StructuralTokens.START]
        self._encode_value(obj, tokens)
        tokens.append(StructuralTokens.END)
        return tokens

    def _encode_value(self, value: Any, tokens: list[int]) -> None:
        """Recursively encode a JSON value, appending to ``tokens``.

        Values of unsupported types are encoded as a single UNK token.
        """
        if isinstance(value, dict):
            self._encode_object(value, tokens)
        elif isinstance(value, list):
            self._encode_array(value, tokens)
        elif isinstance(value, str):
            self._encode_string(value, tokens)
        elif isinstance(value, bool):
            # Must check bool before int (bool is subclass of int in Python)
            tokens.append(StructuralTokens.TRUE if value else StructuralTokens.FALSE)
        elif isinstance(value, (int, float)):
            self._encode_number(value, tokens)
        elif value is None:
            tokens.append(StructuralTokens.NULL)
        else:
            tokens.append(StructuralTokens.UNK)

    def _encode_object(self, obj: dict, tokens: list[int]) -> None:
        """Encode a JSON object as OBJ_START, key/COLON/value items, OBJ_END."""
        tokens.append(StructuralTokens.OBJ_START)
        for i, (key, value) in enumerate(obj.items()):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_key(key, tokens)
            tokens.append(StructuralTokens.COLON)
            self._encode_value(value, tokens)
        tokens.append(StructuralTokens.OBJ_END)

    def _encode_array(self, arr: list, tokens: list[int]) -> None:
        """Encode a JSON array as ARR_START, comma-separated items, ARR_END."""
        tokens.append(StructuralTokens.ARR_START)
        for i, item in enumerate(arr):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_value(item, tokens)
        tokens.append(StructuralTokens.ARR_END)

    def _encode_key(self, key: str, tokens: list[int]) -> None:
        """Encode a JSON key — uses key vocab if available, else BPE."""
        if key in self._key_to_id:
            tokens.append(self._key_to_id[key])
        else:
            tokens.append(StructuralTokens.KEY_PREFIX)
            bpe_ids = self._bpe.encode_to_ids(key)
            tokens.extend(self._bpe_offset + bid for bid in bpe_ids)

    def _encode_string(self, value: str, tokens: list[int]) -> None:
        """Encode a JSON string value between two STR_DELIM tokens."""
        tokens.append(StructuralTokens.STR_DELIM)
        if value:  # don't BPE-encode empty strings
            bpe_ids = self._bpe.encode_to_ids(value)
            tokens.extend(self._bpe_offset + bid for bid in bpe_ids)
        tokens.append(StructuralTokens.STR_DELIM)

    @staticmethod
    def _format_number(value: Union[int, float]) -> str:
        """Return the text used to tokenize a number.

        Floats use ``repr`` and everything else uses ``str``. No int
        conversion is attempted, so ``inf`` and ``nan`` are formatted instead
        of raising.
        """
        return repr(value) if isinstance(value, float) else str(value)

    def _encode_number(self, value: Union[int, float], tokens: list[int]) -> None:
        """Encode a JSON number as NUM_PREFIX followed by BPE tokens."""
        tokens.append(StructuralTokens.NUM_PREFIX)
        bpe_ids = self._bpe.encode_to_ids(self._format_number(value))
        tokens.extend(self._bpe_offset + bid for bid in bpe_ids)

    # ── Decoding ────────────────────────────────────────────────────────
    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to a JSON string.

        Args:
            token_ids: List of integer token IDs from encode().

        Returns:
            JSON string serialized with ``json.dumps(..., ensure_ascii=False)``.
        """
        obj = self._decode_to_object(token_ids)
        return json.dumps(obj, ensure_ascii=False)

    def decode_to_object(self, token_ids: list[int]) -> Any:
        """Decode token IDs back to a Python object."""
        return self._decode_to_object(token_ids)

    def _decode_to_object(self, token_ids: list[int]) -> Any:
        """Parse token IDs back into a Python object.

        A leading START and a trailing END token are stripped when present;
        tokens after the first complete value are ignored.
        """
        ids = list(token_ids)
        if ids and ids[0] == StructuralTokens.START:
            ids = ids[1:]
        if ids and ids[-1] == StructuralTokens.END:
            ids = ids[:-1]
        result, _ = self._parse_value(ids, 0)
        return result

    def _parse_value(self, ids: list[int], pos: int) -> tuple[Any, int]:
        """Parse a single value starting at position ``pos``.

        Returns:
            ``(value, next_pos)``. Past the end of ``ids`` the value is None
            and ``pos`` is unchanged; an unrecognized token yields None and
            is skipped.
        """
        if pos >= len(ids):
            return None, pos
        tid = ids[pos]
        if tid == StructuralTokens.OBJ_START:
            return self._parse_object(ids, pos)
        elif tid == StructuralTokens.ARR_START:
            return self._parse_array(ids, pos)
        elif tid == StructuralTokens.STR_DELIM:
            return self._parse_string(ids, pos)
        elif tid == StructuralTokens.NUM_PREFIX:
            return self._parse_number(ids, pos)
        elif tid == StructuralTokens.NULL:
            return None, pos + 1
        elif tid == StructuralTokens.TRUE:
            return True, pos + 1
        elif tid == StructuralTokens.FALSE:
            return False, pos + 1
        else:
            return None, pos + 1

    def _parse_object(self, ids: list[int], pos: int) -> tuple[dict, int]:
        """Parse a JSON object whose OBJ_START token is at ``pos``."""
        assert ids[pos] == StructuralTokens.OBJ_START
        pos += 1
        result: dict[str, Any] = {}
        while pos < len(ids) and ids[pos] != StructuralTokens.OBJ_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue
            # Parse key
            key, pos = self._parse_key(ids, pos)
            # Expect colon (tolerated if missing)
            if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                pos += 1
            # Parse value
            value, pos = self._parse_value(ids, pos)
            result[key] = value
        if pos < len(ids) and ids[pos] == StructuralTokens.OBJ_END:
            pos += 1
        return result, pos

    def _parse_array(self, ids: list[int], pos: int) -> tuple[list, int]:
        """Parse a JSON array whose ARR_START token is at ``pos``."""
        assert ids[pos] == StructuralTokens.ARR_START
        pos += 1
        result: list[Any] = []
        while pos < len(ids) and ids[pos] != StructuralTokens.ARR_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue
            value, pos = self._parse_value(ids, pos)
            result.append(value)
        if pos < len(ids) and ids[pos] == StructuralTokens.ARR_END:
            pos += 1
        return result, pos

    def _parse_key(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a key starting at ``pos``.

        Returns:
            ``(key, next_pos)``. A token that is neither a key-vocabulary ID
            nor KEY_PREFIX yields the placeholder ``<unknown_key_{id}>``.
        """
        tid = ids[pos]
        # Check key vocabulary
        if tid in self._id_to_key:
            return self._id_to_key[tid], pos + 1
        # KEY_PREFIX → BPE-encoded key
        if tid == StructuralTokens.KEY_PREFIX:
            pos += 1
            bpe_tokens: list[str] = []
            while pos < len(ids) and ids[pos] >= self._bpe_offset:
                bpe_id = ids[pos] - self._bpe_offset
                bpe_tokens.append(self._bpe.id_to_token(bpe_id))
                pos += 1
                # Stop before COLON
                if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                    break
            return self._bpe.decode_tokens(bpe_tokens), pos
        return f"<unknown_key_{tid}>", pos + 1

    def _parse_string(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a string value whose opening STR_DELIM is at ``pos``."""
        assert ids[pos] == StructuralTokens.STR_DELIM
        pos += 1
        bpe_tokens: list[str] = []
        while pos < len(ids) and ids[pos] != StructuralTokens.STR_DELIM:
            bpe_id = ids[pos] - self._bpe_offset
            bpe_tokens.append(self._bpe.id_to_token(bpe_id))
            pos += 1
        # Skip closing delimiter
        if pos < len(ids) and ids[pos] == StructuralTokens.STR_DELIM:
            pos += 1
        return self._bpe.decode_tokens(bpe_tokens), pos

    @staticmethod
    def _text_to_number(text: str) -> Union[int, float]:
        """Convert decoded number text to an int or float.

        Integer parsing is tried first so integers stay ints. Anything else
        that ``float`` accepts (decimals, exponents, "inf", "nan") becomes a
        float. Text that is neither falls back to 0.
        """
        try:
            return int(text)
        except ValueError:
            pass
        try:
            return float(text)
        except ValueError:
            return 0

    def _parse_number(self, ids: list[int], pos: int) -> tuple[Union[int, float], int]:
        """Parse a number value whose NUM_PREFIX token is at ``pos``."""
        assert ids[pos] == StructuralTokens.NUM_PREFIX
        pos += 1
        bpe_tokens: list[str] = []
        while pos < len(ids):
            tid = ids[pos]
            if tid < self._bpe_offset:
                break  # hit a structural token
            bpe_tokens.append(self._bpe.id_to_token(tid - self._bpe_offset))
            pos += 1
        text = self._bpe.decode_tokens(bpe_tokens).strip()
        return self._text_to_number(text), pos

    # ── Inspection / Debug ──────────────────────────────────────────────
    def decode_tokens_readable(self, token_ids: list[int]) -> list[str]:
        """Convert token IDs to human-readable token names."""
        result: list[str] = []
        for tid in token_ids:
            if tid in self._id_to_token:
                result.append(self._id_to_token[tid])
            elif tid in self._id_to_key:
                result.append(f"Key({self._id_to_key[tid]})")
            elif 0 <= tid < StructuralTokens.RESERVED_END:
                # The lookup table is empty before training; structural IDs
                # must not fall through to the BPE branch with a negative ID.
                result.append(StructuralTokens.name(tid))
            else:
                bpe_id = tid - self._bpe_offset
                token_str = self._bpe.id_to_token(bpe_id)
                result.append(f"BPE({repr(token_str)})")
        return result

    def token_count(self, json_input: Union[str, Any]) -> int:
        """Return the number of tokens ``encode`` produces for the input.

        The full token list is built by ``encode`` and then measured.
        """
        return len(self.encode(json_input))

    # ── Persistence ─────────────────────────────────────────────────────
    def save(self, directory: str) -> None:
        """Save the full tokenizer state to a directory.

        Writes ``bpe_model.json`` (via the BPE trainer) and
        ``tokenizer_config.json``; the directory is created if missing.
        """
        import os

        os.makedirs(directory, exist_ok=True)
        # Save BPE model
        self._bpe.save(os.path.join(directory, "bpe_model.json"))
        # Save key vocabulary and config
        config = {
            "version": "json-tokenizer-v1",
            "bpe_vocab_size": self.bpe_vocab_size,
            "max_key_vocab": self.max_key_vocab,
            "min_key_freq": self.min_key_freq,
            "bpe_min_freq": self.bpe_min_freq,
            "key_vocab": self._key_to_id,
            "key_offset": self._key_offset,
            "bpe_offset": self._bpe_offset,
        }
        config_path = os.path.join(directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls, directory: str) -> "JSONTokenizer":
        """Load a trained tokenizer from a directory written by ``save``."""
        import os

        config_path = os.path.join(directory, "tokenizer_config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        tokenizer = cls(
            bpe_vocab_size=config["bpe_vocab_size"],
            max_key_vocab=config["max_key_vocab"],
            min_key_freq=config["min_key_freq"],
            bpe_min_freq=config["bpe_min_freq"],
        )
        # Restore key vocab
        tokenizer._key_to_id = config["key_vocab"]
        tokenizer._id_to_key = {int(v): k for k, v in config["key_vocab"].items()}
        tokenizer._key_offset = config["key_offset"]
        tokenizer._bpe_offset = config["bpe_offset"]
        # Load BPE
        tokenizer._bpe = BPETrainer.load(os.path.join(directory, "bpe_model.json"))
        tokenizer._build_vocab_lookup()
        tokenizer._trained = True
        return tokenizer