anthonym21 commited on
Commit
f3a6aa4
Β·
verified Β·
1 Parent(s): cdef25b

Upload json_tokenizer/hf_compat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. json_tokenizer/hf_compat.py +362 -0
json_tokenizer/hf_compat.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace Transformers-compatible wrapper for JSONTokenizer.
2
+
3
+ Provides JSONPreTrainedTokenizer, a PreTrainedTokenizer subclass that
4
+ wraps JSONTokenizer for use with the HuggingFace ecosystem:
5
+ - save_pretrained / from_pretrained
6
+ - AutoTokenizer.from_pretrained (with trust_remote_code=True)
7
+ - tokenizer(json_string) -> BatchEncoding
8
+ - Padding, truncation, batch processing, return_tensors
9
+
10
+ Requires: pip install json-tokenizer[huggingface]
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import os
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ try:
20
+ from transformers import PreTrainedTokenizer
21
+ except ImportError:
22
+ raise ImportError(
23
+ "The HuggingFace transformers library is required for this module. "
24
+ "Install it with: pip install json-tokenizer[huggingface]"
25
+ )
26
+
27
+ from json_tokenizer.tokenizer import JSONTokenizer, StructuralTokens
28
+ from json_tokenizer.bpe import BPETrainer
29
+
30
+
31
+ VOCAB_FILES_NAMES = {"vocab_file": "json_tokenizer_vocab.json"}
32
+
33
+ # Structural token ID -> HF-compatible string name.
34
+ # Uses <name> format which cannot collide with BPE tokens because
35
+ # the BPE pre-tokenizer splits <, >, : into separate tokens.
36
+ _STRUCTURAL_TOKEN_NAMES = {
37
+ StructuralTokens.PAD: "<pad>",
38
+ StructuralTokens.START: "<s>",
39
+ StructuralTokens.END: "</s>",
40
+ StructuralTokens.OBJ_START: "<obj_start>",
41
+ StructuralTokens.OBJ_END: "<obj_end>",
42
+ StructuralTokens.ARR_START: "<arr_start>",
43
+ StructuralTokens.ARR_END: "<arr_end>",
44
+ StructuralTokens.COLON: "<colon>",
45
+ StructuralTokens.COMMA: "<comma>",
46
+ StructuralTokens.NULL: "<null>",
47
+ StructuralTokens.TRUE: "<true>",
48
+ StructuralTokens.FALSE: "<false>",
49
+ StructuralTokens.STR_DELIM: "<str_delim>",
50
+ StructuralTokens.NUM_PREFIX: "<num_prefix>",
51
+ StructuralTokens.KEY_PREFIX: "<key_prefix>",
52
+ StructuralTokens.UNK: "<unk>",
53
+ }
54
+
55
+ _STRUCTURAL_NAME_TO_ID = {v: k for k, v in _STRUCTURAL_TOKEN_NAMES.items()}
56
+
57
+
58
+ class JSONPreTrainedTokenizer(PreTrainedTokenizer):
59
+ """HuggingFace-compatible wrapper around JSONTokenizer.
60
+
61
+ Usage:
62
+ # From a trained JSONTokenizer:
63
+ tok = JSONTokenizer(bpe_vocab_size=4096)
64
+ tok.train(data)
65
+ hf_tok = JSONPreTrainedTokenizer.from_json_tokenizer(tok)
66
+
67
+ # Encode/decode via HF API:
68
+ output = hf_tok('{"name": "Alice", "age": 30}')
69
+ print(output["input_ids"])
70
+ print(hf_tok.decode(output["input_ids"]))
71
+
72
+ # Save and reload:
73
+ hf_tok.save_pretrained("./my_tokenizer")
74
+ loaded = JSONPreTrainedTokenizer.from_pretrained("./my_tokenizer")
75
+ """
76
+
77
+ vocab_files_names = VOCAB_FILES_NAMES
78
+ model_input_names = ["input_ids", "attention_mask"]
79
+
80
+ def __init__(
81
+ self,
82
+ vocab_file: Optional[str] = None,
83
+ unk_token: str = "<unk>",
84
+ bos_token: str = "<s>",
85
+ eos_token: str = "</s>",
86
+ pad_token: str = "<pad>",
87
+ **kwargs,
88
+ ):
89
+ # Internal state β€” populated from vocab_file or from_json_tokenizer
90
+ if not hasattr(self, "_json_tokenizer"):
91
+ self._json_tokenizer: Optional[JSONTokenizer] = None
92
+ if not hasattr(self, "_hf_vocab"):
93
+ self._hf_vocab: Dict[str, int] = {}
94
+ if not hasattr(self, "_hf_id_to_token"):
95
+ self._hf_id_to_token: Dict[int, str] = {}
96
+
97
+ if vocab_file is not None and os.path.isfile(vocab_file):
98
+ self._load_vocab_file(vocab_file)
99
+
100
+ super().__init__(
101
+ unk_token=unk_token,
102
+ bos_token=bos_token,
103
+ eos_token=eos_token,
104
+ pad_token=pad_token,
105
+ **kwargs,
106
+ )
107
+
108
+ # ── Factory ────────────────────────────────────────────────────────
109
+
110
+ @classmethod
111
+ def from_json_tokenizer(
112
+ cls, tokenizer: JSONTokenizer, **kwargs
113
+ ) -> "JSONPreTrainedTokenizer":
114
+ """Create from a trained JSONTokenizer instance.
115
+
116
+ Args:
117
+ tokenizer: A trained JSONTokenizer.
118
+ **kwargs: Additional arguments passed to __init__.
119
+
120
+ Returns:
121
+ A new JSONPreTrainedTokenizer wrapping the provided tokenizer.
122
+ """
123
+ if not tokenizer._trained:
124
+ raise ValueError("JSONTokenizer must be trained before wrapping.")
125
+
126
+ instance = cls.__new__(cls)
127
+ instance._json_tokenizer = tokenizer
128
+ instance._hf_vocab = {}
129
+ instance._hf_id_to_token = {}
130
+ instance._build_hf_vocab()
131
+ instance.__init__(vocab_file=None, **kwargs)
132
+ return instance
133
+
134
+ # ── Vocab building ─────────────────────────────────────────────────
135
+
136
+ def _load_vocab_file(self, vocab_file: str) -> None:
137
+ """Reconstruct a JSONTokenizer from our saved vocab file."""
138
+ with open(vocab_file, "r", encoding="utf-8") as f:
139
+ data = json.load(f)
140
+
141
+ config = data["config"]
142
+ tok = JSONTokenizer(
143
+ bpe_vocab_size=config["bpe_vocab_size"],
144
+ max_key_vocab=config["max_key_vocab"],
145
+ min_key_freq=config["min_key_freq"],
146
+ bpe_min_freq=config["bpe_min_freq"],
147
+ )
148
+ tok._key_to_id = {k: int(v) for k, v in data["key_vocab"].items()}
149
+ tok._id_to_key = {int(v): k for k, v in data["key_vocab"].items()}
150
+ tok._key_offset = config["key_offset"]
151
+ tok._bpe_offset = config["bpe_offset"]
152
+
153
+ bpe_data = data["bpe_model"]
154
+ bpe = BPETrainer(
155
+ vocab_size=bpe_data["vocab_size"],
156
+ min_frequency=bpe_data["min_frequency"],
157
+ )
158
+ bpe.merges = [tuple(m) for m in bpe_data["merges"]]
159
+ bpe.vocab = bpe_data["vocab"]
160
+ bpe._id_to_tok = None
161
+ tok._bpe = bpe
162
+
163
+ tok._build_vocab_lookup()
164
+ tok._trained = True
165
+
166
+ self._json_tokenizer = tok
167
+ self._build_hf_vocab()
168
+
169
+ def _build_hf_vocab(self) -> None:
170
+ """Build the unified {token_string: id} mapping across all tiers."""
171
+ tok = self._json_tokenizer
172
+ self._hf_vocab = {}
173
+ self._hf_id_to_token = {}
174
+
175
+ # Structural tokens (0-15)
176
+ for tid, name in _STRUCTURAL_TOKEN_NAMES.items():
177
+ self._hf_vocab[name] = tid
178
+ self._hf_id_to_token[tid] = name
179
+
180
+ # Reserved tokens (16-31)
181
+ for tid in range(16, StructuralTokens.RESERVED_END):
182
+ name = f"<reserved_{tid}>"
183
+ self._hf_vocab[name] = tid
184
+ self._hf_id_to_token[tid] = name
185
+
186
+ # Key vocabulary tokens
187
+ for key_str, tid in tok._key_to_id.items():
188
+ name = f"<key:{key_str}>"
189
+ self._hf_vocab[name] = tid
190
+ self._hf_id_to_token[tid] = name
191
+
192
+ # BPE tokens
193
+ for bpe_token, bpe_local_id in tok._bpe.vocab.items():
194
+ full_id = tok._bpe_offset + bpe_local_id
195
+ # Collision guard (only <UNK> from BPE could theoretically collide)
196
+ if bpe_token in self._hf_vocab:
197
+ bpe_token_name = f"bpe:{bpe_token}"
198
+ else:
199
+ bpe_token_name = bpe_token
200
+ self._hf_vocab[bpe_token_name] = full_id
201
+ self._hf_id_to_token[full_id] = bpe_token_name
202
+
203
+ # ── Required PreTrainedTokenizer overrides ─────────────────────────
204
+
205
+ @property
206
+ def vocab_size(self) -> int:
207
+ if self._json_tokenizer is None:
208
+ return len(_STRUCTURAL_TOKEN_NAMES)
209
+ return self._json_tokenizer.vocab_size
210
+
211
+ def get_vocab(self) -> Dict[str, int]:
212
+ vocab = dict(self._hf_vocab)
213
+ vocab.update(self.added_tokens_encoder)
214
+ return vocab
215
+
216
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
217
+ """Tokenize a JSON string into HF token strings.
218
+
219
+ The HF pipeline calls: tokenize(text) -> _tokenize -> list[str]
220
+ then convert_tokens_to_ids maps those to IDs.
221
+
222
+ We parse the JSON, encode via JSONTokenizer (skipping START/END
223
+ since HF adds special tokens via build_inputs_with_special_tokens),
224
+ then convert IDs to our HF token string names.
225
+ """
226
+ if self._json_tokenizer is None:
227
+ return [self.unk_token]
228
+
229
+ try:
230
+ ids = self._json_tokenizer.encode(text)
231
+ except (ValueError, json.JSONDecodeError):
232
+ # Not valid JSON β€” encode as raw string via BPE
233
+ ids = [StructuralTokens.START]
234
+ self._json_tokenizer._encode_string(text, ids)
235
+ ids.append(StructuralTokens.END)
236
+
237
+ # Strip START/END β€” HF adds them via build_inputs_with_special_tokens
238
+ if ids and ids[0] == StructuralTokens.START:
239
+ ids = ids[1:]
240
+ if ids and ids[-1] == StructuralTokens.END:
241
+ ids = ids[:-1]
242
+
243
+ return [self._hf_id_to_token.get(tid, self.unk_token) for tid in ids]
244
+
245
+ def _convert_token_to_id(self, token: str) -> int:
246
+ return self._hf_vocab.get(
247
+ token, self._hf_vocab.get(self.unk_token, StructuralTokens.UNK)
248
+ )
249
+
250
+ def _convert_id_to_token(self, index: int) -> str:
251
+ return self._hf_id_to_token.get(index, self.unk_token)
252
+
253
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
254
+ """Reconstruct a JSON string from token strings.
255
+
256
+ Converts token strings -> IDs, wraps with START/END,
257
+ and delegates to JSONTokenizer.decode().
258
+ """
259
+ if self._json_tokenizer is None:
260
+ return ""
261
+
262
+ ids = [StructuralTokens.START]
263
+ for token in tokens:
264
+ tid = self._convert_token_to_id(token)
265
+ ids.append(tid)
266
+ ids.append(StructuralTokens.END)
267
+
268
+ try:
269
+ return self._json_tokenizer.decode(ids)
270
+ except Exception:
271
+ return " ".join(tokens)
272
+
273
+ # ── Special tokens ────────────────────────────���────────────────────
274
+
275
+ def build_inputs_with_special_tokens(
276
+ self,
277
+ token_ids_0: List[int],
278
+ token_ids_1: Optional[List[int]] = None,
279
+ ) -> List[int]:
280
+ """Wrap with START (bos) and END (eos) tokens."""
281
+ bos = [self.bos_token_id]
282
+ eos = [self.eos_token_id]
283
+ if token_ids_1 is None:
284
+ return bos + token_ids_0 + eos
285
+ return bos + token_ids_0 + eos + bos + token_ids_1 + eos
286
+
287
+ def get_special_tokens_mask(
288
+ self,
289
+ token_ids_0: List[int],
290
+ token_ids_1: Optional[List[int]] = None,
291
+ already_has_special_tokens: bool = False,
292
+ ) -> List[int]:
293
+ """1 for special tokens (START/END), 0 for content tokens."""
294
+ if already_has_special_tokens:
295
+ return super().get_special_tokens_mask(
296
+ token_ids_0=token_ids_0,
297
+ token_ids_1=token_ids_1,
298
+ already_has_special_tokens=True,
299
+ )
300
+ if token_ids_1 is None:
301
+ return [1] + [0] * len(token_ids_0) + [1]
302
+ return (
303
+ [1] + [0] * len(token_ids_0) + [1]
304
+ + [1] + [0] * len(token_ids_1) + [1]
305
+ )
306
+
307
+ def create_token_type_ids_from_sequences(
308
+ self,
309
+ token_ids_0: List[int],
310
+ token_ids_1: Optional[List[int]] = None,
311
+ ) -> List[int]:
312
+ """Segment IDs: 0 for first sequence, 1 for second."""
313
+ bos_eos = 2 # one bos + one eos
314
+ if token_ids_1 is None:
315
+ return [0] * (len(token_ids_0) + bos_eos)
316
+ return [0] * (len(token_ids_0) + bos_eos) + [1] * (len(token_ids_1) + bos_eos)
317
+
318
+ # ── Persistence ────────────────────────────────────────────────────
319
+
320
+ def save_vocabulary(
321
+ self,
322
+ save_directory: str,
323
+ filename_prefix: Optional[str] = None,
324
+ ) -> Tuple[str]:
325
+ """Save the vocabulary to a single JSON file.
326
+
327
+ This file contains everything needed to reconstruct the
328
+ JSONTokenizer: config, key vocab, and BPE model.
329
+ """
330
+ if not os.path.isdir(save_directory):
331
+ raise ValueError(f"Not a directory: {save_directory}")
332
+
333
+ vocab_file = os.path.join(
334
+ save_directory,
335
+ (filename_prefix + "-" if filename_prefix else "")
336
+ + VOCAB_FILES_NAMES["vocab_file"],
337
+ )
338
+
339
+ tok = self._json_tokenizer
340
+ data = {
341
+ "version": "json-tokenizer-hf-v1",
342
+ "config": {
343
+ "bpe_vocab_size": tok.bpe_vocab_size,
344
+ "max_key_vocab": tok.max_key_vocab,
345
+ "min_key_freq": tok.min_key_freq,
346
+ "bpe_min_freq": tok.bpe_min_freq,
347
+ "key_offset": tok._key_offset,
348
+ "bpe_offset": tok._bpe_offset,
349
+ },
350
+ "key_vocab": tok._key_to_id,
351
+ "bpe_model": {
352
+ "vocab_size": tok._bpe.vocab_size,
353
+ "min_frequency": tok._bpe.min_frequency,
354
+ "merges": [list(m) for m in tok._bpe.merges],
355
+ "vocab": tok._bpe.vocab,
356
+ },
357
+ }
358
+
359
+ with open(vocab_file, "w", encoding="utf-8") as f:
360
+ json.dump(data, f, indent=2, ensure_ascii=False)
361
+
362
+ return (vocab_file,)