"""HuggingFace Transformers-compatible wrapper for JSONTokenizer. Provides JSONPreTrainedTokenizer, a PreTrainedTokenizer subclass that wraps JSONTokenizer for use with the HuggingFace ecosystem: - save_pretrained / from_pretrained - AutoTokenizer.from_pretrained (with trust_remote_code=True) - tokenizer(json_string) -> BatchEncoding - Padding, truncation, batch processing, return_tensors Requires: pip install json-tokenizer[huggingface] """ from __future__ import annotations import json import os from typing import Any, Dict, List, Optional, Tuple, Union try: from transformers import PreTrainedTokenizer except ImportError: raise ImportError( "The HuggingFace transformers library is required for this module. " "Install it with: pip install json-tokenizer[huggingface]" ) from json_tokenizer.tokenizer import JSONTokenizer, StructuralTokens from json_tokenizer.bpe import BPETrainer VOCAB_FILES_NAMES = {"vocab_file": "json_tokenizer_vocab.json"} # Structural token ID -> HF-compatible string name. # Uses format which cannot collide with BPE tokens because # the BPE pre-tokenizer splits <, >, : into separate tokens. _STRUCTURAL_TOKEN_NAMES = { StructuralTokens.PAD: "", StructuralTokens.START: "", StructuralTokens.END: "", StructuralTokens.OBJ_START: "", StructuralTokens.OBJ_END: "", StructuralTokens.ARR_START: "", StructuralTokens.ARR_END: "", StructuralTokens.COLON: "", StructuralTokens.COMMA: "", StructuralTokens.NULL: "", StructuralTokens.TRUE: "", StructuralTokens.FALSE: "", StructuralTokens.STR_DELIM: "", StructuralTokens.NUM_PREFIX: "", StructuralTokens.KEY_PREFIX: "", StructuralTokens.UNK: "", } _STRUCTURAL_NAME_TO_ID = {v: k for k, v in _STRUCTURAL_TOKEN_NAMES.items()} class JSONPreTrainedTokenizer(PreTrainedTokenizer): """HuggingFace-compatible wrapper around JSONTokenizer. Usage: # From a trained JSONTokenizer: tok = JSONTokenizer(bpe_vocab_size=4096) tok.train(data) hf_tok = JSONPreTrainedTokenizer.from_json_tokenizer(tok) # Encode/decode via HF API: output = hf_tok('{"name": "Alice", "age": 30}') print(output["input_ids"]) print(hf_tok.decode(output["input_ids"])) # Save and reload: hf_tok.save_pretrained("./my_tokenizer") loaded = JSONPreTrainedTokenizer.from_pretrained("./my_tokenizer") """ vocab_files_names = VOCAB_FILES_NAMES model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab_file: Optional[str] = None, unk_token: str = "", bos_token: str = "", eos_token: str = "", pad_token: str = "", **kwargs, ): # Internal state — populated from vocab_file or from_json_tokenizer if not hasattr(self, "_json_tokenizer"): self._json_tokenizer: Optional[JSONTokenizer] = None if not hasattr(self, "_hf_vocab"): self._hf_vocab: Dict[str, int] = {} if not hasattr(self, "_hf_id_to_token"): self._hf_id_to_token: Dict[int, str] = {} if vocab_file is not None and os.path.isfile(vocab_file): self._load_vocab_file(vocab_file) super().__init__( unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, pad_token=pad_token, **kwargs, ) # ── Factory ──────────────────────────────────────────────────────── @classmethod def from_json_tokenizer( cls, tokenizer: JSONTokenizer, **kwargs ) -> "JSONPreTrainedTokenizer": """Create from a trained JSONTokenizer instance. Args: tokenizer: A trained JSONTokenizer. **kwargs: Additional arguments passed to __init__. Returns: A new JSONPreTrainedTokenizer wrapping the provided tokenizer. """ if not tokenizer._trained: raise ValueError("JSONTokenizer must be trained before wrapping.") instance = cls.__new__(cls) instance._json_tokenizer = tokenizer instance._hf_vocab = {} instance._hf_id_to_token = {} instance._build_hf_vocab() instance.__init__(vocab_file=None, **kwargs) return instance # ── Vocab building ───────────────────────────────────────────────── def _load_vocab_file(self, vocab_file: str) -> None: """Reconstruct a JSONTokenizer from our saved vocab file.""" with open(vocab_file, "r", encoding="utf-8") as f: data = json.load(f) config = data["config"] tok = JSONTokenizer( bpe_vocab_size=config["bpe_vocab_size"], max_key_vocab=config["max_key_vocab"], min_key_freq=config["min_key_freq"], bpe_min_freq=config["bpe_min_freq"], ) tok._key_to_id = {k: int(v) for k, v in data["key_vocab"].items()} tok._id_to_key = {int(v): k for k, v in data["key_vocab"].items()} tok._key_offset = config["key_offset"] tok._bpe_offset = config["bpe_offset"] bpe_data = data["bpe_model"] bpe = BPETrainer( vocab_size=bpe_data["vocab_size"], min_frequency=bpe_data["min_frequency"], ) bpe.merges = [tuple(m) for m in bpe_data["merges"]] bpe.vocab = bpe_data["vocab"] bpe._id_to_tok = None tok._bpe = bpe tok._build_vocab_lookup() tok._trained = True self._json_tokenizer = tok self._build_hf_vocab() def _build_hf_vocab(self) -> None: """Build the unified {token_string: id} mapping across all tiers.""" tok = self._json_tokenizer self._hf_vocab = {} self._hf_id_to_token = {} # Structural tokens (0-15) for tid, name in _STRUCTURAL_TOKEN_NAMES.items(): self._hf_vocab[name] = tid self._hf_id_to_token[tid] = name # Reserved tokens (16-31) for tid in range(16, StructuralTokens.RESERVED_END): name = f"" self._hf_vocab[name] = tid self._hf_id_to_token[tid] = name # Key vocabulary tokens for key_str, tid in tok._key_to_id.items(): name = f"" self._hf_vocab[name] = tid self._hf_id_to_token[tid] = name # BPE tokens for bpe_token, bpe_local_id in tok._bpe.vocab.items(): full_id = tok._bpe_offset + bpe_local_id # Collision guard (only from BPE could theoretically collide) if bpe_token in self._hf_vocab: bpe_token_name = f"bpe:{bpe_token}" else: bpe_token_name = bpe_token self._hf_vocab[bpe_token_name] = full_id self._hf_id_to_token[full_id] = bpe_token_name # ── Required PreTrainedTokenizer overrides ───────────────────────── @property def vocab_size(self) -> int: if self._json_tokenizer is None: return len(_STRUCTURAL_TOKEN_NAMES) return self._json_tokenizer.vocab_size def get_vocab(self) -> Dict[str, int]: vocab = dict(self._hf_vocab) vocab.update(self.added_tokens_encoder) return vocab def _tokenize(self, text: str, **kwargs) -> List[str]: """Tokenize a JSON string into HF token strings. The HF pipeline calls: tokenize(text) -> _tokenize -> list[str] then convert_tokens_to_ids maps those to IDs. We parse the JSON, encode via JSONTokenizer (skipping START/END since HF adds special tokens via build_inputs_with_special_tokens), then convert IDs to our HF token string names. """ if self._json_tokenizer is None: return [self.unk_token] try: ids = self._json_tokenizer.encode(text) except (ValueError, json.JSONDecodeError): # Not valid JSON — encode as raw string via BPE ids = [StructuralTokens.START] self._json_tokenizer._encode_string(text, ids) ids.append(StructuralTokens.END) # Strip START/END — HF adds them via build_inputs_with_special_tokens if ids and ids[0] == StructuralTokens.START: ids = ids[1:] if ids and ids[-1] == StructuralTokens.END: ids = ids[:-1] return [self._hf_id_to_token.get(tid, self.unk_token) for tid in ids] def _convert_token_to_id(self, token: str) -> int: return self._hf_vocab.get( token, self._hf_vocab.get(self.unk_token, StructuralTokens.UNK) ) def _convert_id_to_token(self, index: int) -> str: return self._hf_id_to_token.get(index, self.unk_token) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Reconstruct a JSON string from token strings. Converts token strings -> IDs, wraps with START/END, and delegates to JSONTokenizer.decode(). """ if self._json_tokenizer is None: return "" ids = [StructuralTokens.START] for token in tokens: tid = self._convert_token_to_id(token) ids.append(tid) ids.append(StructuralTokens.END) try: return self._json_tokenizer.decode(ids) except Exception: return " ".join(tokens) # ── Special tokens ───────────────────────────────────────────────── def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: """Wrap with START (bos) and END (eos) tokens.""" bos = [self.bos_token_id] eos = [self.eos_token_id] if token_ids_1 is None: return bos + token_ids_0 + eos return bos + token_ids_0 + eos + bos + token_ids_1 + eos def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False, ) -> List[int]: """1 for special tokens (START/END), 0 for content tokens.""" if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True, ) if token_ids_1 is None: return [1] + [0] * len(token_ids_0) + [1] return ( [1] + [0] * len(token_ids_0) + [1] + [1] + [0] * len(token_ids_1) + [1] ) def create_token_type_ids_from_sequences( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, ) -> List[int]: """Segment IDs: 0 for first sequence, 1 for second.""" bos_eos = 2 # one bos + one eos if token_ids_1 is None: return [0] * (len(token_ids_0) + bos_eos) return [0] * (len(token_ids_0) + bos_eos) + [1] * (len(token_ids_1) + bos_eos) # ── Persistence ──────────────────────────────────────────────────── def save_vocabulary( self, save_directory: str, filename_prefix: Optional[str] = None, ) -> Tuple[str]: """Save the vocabulary to a single JSON file. This file contains everything needed to reconstruct the JSONTokenizer: config, key vocab, and BPE model. """ if not os.path.isdir(save_directory): raise ValueError(f"Not a directory: {save_directory}") vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"], ) tok = self._json_tokenizer data = { "version": "json-tokenizer-hf-v1", "config": { "bpe_vocab_size": tok.bpe_vocab_size, "max_key_vocab": tok.max_key_vocab, "min_key_freq": tok.min_key_freq, "bpe_min_freq": tok.bpe_min_freq, "key_offset": tok._key_offset, "bpe_offset": tok._bpe_offset, }, "key_vocab": tok._key_to_id, "bpe_model": { "vocab_size": tok._bpe.vocab_size, "min_frequency": tok._bpe.min_frequency, "merges": [list(m) for m in tok._bpe.merges], "vocab": tok._bpe.vocab, }, } with open(vocab_file, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) return (vocab_file,)