gbyuvd commited on
Commit
ad6a9f0
·
verified ·
1 Parent(s): 07e938b

Upload FastChemTokenizerHF2.py

Browse files
Files changed (1) hide show
  1. FastChemTokenizerHF2.py +813 -0
FastChemTokenizerHF2.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import os
4
+ from typing import List, Union, Optional, Tuple, Dict, Any
5
+ from functools import lru_cache
6
+
7
+ # Minimal equivalents for problematic imports
8
class BatchEncoding:
    """Minimal stand-in for ``transformers.BatchEncoding``.

    Wraps the tokenizer output dict and exposes both mapping-style access
    (``enc["input_ids"]``) and attribute access (``enc.input_ids``).
    """

    def __init__(self, data, tensor_type=None):
        # tensor_type is accepted for HF API compatibility; tensor conversion
        # happens in the tokenizer before construction, so it is unused here.
        self.data = data
        for key, value in data.items():
            # Bug fix: the original setattr unconditionally, which would
            # clobber the mapping methods (or `data` itself) for keys named
            # e.g. "items" or "data". Only mirror safe keys as attributes.
            if key != "data" and not hasattr(type(self), key):
                setattr(self, key, value)

    def __getitem__(self, key):
        return self.data[key]

    def __contains__(self, key):
        return key in self.data

    def __iter__(self):
        # Iterate keys, like a dict (backward-compatible addition).
        return iter(self.data)

    def __len__(self):
        # Number of stored fields (backward-compatible addition).
        return len(self.data)

    def get(self, key, default=None):
        # dict-style safe lookup (backward-compatible addition).
        return self.data.get(key, default)

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()
28
class PreTrainedTokenizerBase:
    """Lightweight replacement for the HF tokenizer base class.

    Stores every ``*_token`` keyword under a private ``_<name>`` slot and
    records the common tokenizer configuration knobs with HF defaults.
    The matching ``<name>_id`` attributes are left for subclasses to set.
    """

    def __init__(self, **kwargs):
        # Mirror special-token kwargs (bos_token, pad_token, ...) into
        # private attributes; ids are intentionally NOT derived here.
        for name, token in kwargs.items():
            if name.endswith('_token'):
                setattr(self, f"_{name}", token)

        # General configuration with the usual HuggingFace defaults.
        self.model_max_length = kwargs.get('model_max_length', 512)
        self.padding_side = kwargs.get('padding_side', 'right')
        self.truncation_side = kwargs.get('truncation_side', 'right')
        self.chat_template = kwargs.get('chat_template')
41
+
42
class TrieNode:
    """One node of the vocabulary trie used for longest-match encoding."""
    __slots__ = ('children', 'token_id')

    def __init__(self):
        self.children = {}    # maps a single character -> child TrieNode
        self.token_id = None  # non-None marks the end of a complete token
47
+
48
class FastChemTokenizer(PreTrainedTokenizerBase):
    """
    Fully HuggingFace API compatible tokenizer for chemical representations.
    """

    def __init__(
        self,
        token_to_id=None,
        vocab_file=None,
        model_max_length=512,
        padding_side="right",
        truncation_side="right",
        chat_template=None,
        **kwargs
    ):
        """Build the tokenizer from an in-memory vocab or a JSON vocab file.

        Exactly one of ``token_to_id`` (dict token -> id) or ``vocab_file``
        (path to a JSON mapping) must be provided. The vocab MUST contain
        the special tokens <s>, </s>, <pad>, <unk>, <mask>.

        Raises:
            ValueError: neither vocab source was given.
            KeyError: a required special token is missing from the vocab.
        """
        # Handle vocab loading
        if token_to_id is None and vocab_file is None:
            raise ValueError("Either token_to_id or vocab_file must be provided")

        if vocab_file is not None:
            with open(vocab_file, "r", encoding="utf-8") as f:
                token_to_id = json.load(f)
            # JSON object keys are always str; coerce values to int defensively.
            token_to_id = {str(k): int(v) for k, v in token_to_id.items()}

        self.token_to_id = token_to_id
        # Reverse mapping for decoding (assumes ids are unique in the vocab).
        self.id_to_token = {v: k for k, v in token_to_id.items()}

        # Precompute max token length for possible use & clarity
        self.max_token_len = max(len(t) for t in token_to_id.keys()) if token_to_id else 0

        # Build trie for fast longest-match lookup
        self.trie_root = self._build_trie(token_to_id)

        # Validate required special tokens
        required_special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
        for tok in required_special_tokens:
            if tok not in token_to_id:
                raise KeyError(f"Required special token '{tok}' not found in vocab.")

        # ✅ Assign special token IDs explicitly - MUST be done before parent init
        self.bos_token_id = token_to_id["<s>"]
        self.eos_token_id = token_to_id["</s>"]
        self.pad_token_id = token_to_id["<pad>"]
        self.unk_token_id = token_to_id["<unk>"]
        self.mask_token_id = token_to_id["<mask>"]

        # Initialize parent class (stores _<name>_token slots and config).
        super().__init__(
            bos_token="<s>",
            eos_token="</s>",
            unk_token="<unk>",
            sep_token=None,
            pad_token="<pad>",
            cls_token=None,
            mask_token="<mask>",
            additional_special_tokens=[],
            model_max_length=model_max_length,
            padding_side=padding_side,
            truncation_side=truncation_side,
            chat_template=chat_template,
            **kwargs,
        )

        # ✅ Set all token attributes for compatibility
        # Both the private (_x_token) and public (x_token) forms are set so
        # code written against either HF convention keeps working.
        self._unk_token = "<unk>"
        self._bos_token = "<s>"
        self._eos_token = "</s>"
        self._pad_token = "<pad>"
        self._mask_token = "<mask>"
        self._sep_token = None
        self._cls_token = None

        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"
        self.sep_token = None
        self.cls_token = None
127
+
128
+ def _build_trie(self, token_to_id):
129
+ root = TrieNode()
130
+ for token, tid in token_to_id.items():
131
+ node = root
132
+ for char in token:
133
+ if char not in node.children:
134
+ node.children[char] = TrieNode()
135
+ node = node.children[char]
136
+ node.token_id = tid
137
+ return root
138
+
139
    @property
    def vocab_size(self):
        # Number of entries in the vocabulary (specials included).
        return len(self.token_to_id)

    def __len__(self):
        # len(tokenizer) mirrors vocab_size, matching HF convention.
        return len(self.token_to_id)

    def get_vocab(self) -> Dict[str, int]:
        # Return a copy so callers cannot mutate the live vocab mapping.
        return self.token_to_id.copy()
148
+
149
    @lru_cache(maxsize=10000)
    def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
        # Memoized wrapper around _encode_core; returns a tuple so the
        # cached value is immutable/hashable (callers convert to list).
        # NOTE(review): lru_cache on an instance method keys on `self`, so
        # the cache keeps tokenizer instances alive for its lifetime —
        # fine for a long-lived singleton tokenizer, but worth confirming.
        return tuple(self._encode_core(s))
152
+
153
+ def _encode_core(self, text: str) -> List[int]:
154
+ """Core encoding logic using Trie — no caching."""
155
+ tokens = text
156
+ result_ids = []
157
+ i = 0
158
+ n = len(tokens)
159
+
160
+ while i < n:
161
+ node = self.trie_root
162
+ j = i
163
+ last_match_id = None
164
+ last_match_end = i
165
+
166
+ while j < n and tokens[j] in node.children:
167
+ node = node.children[tokens[j]]
168
+ j += 1
169
+ if node.token_id is not None:
170
+ last_match_id = node.token_id
171
+ last_match_end = j
172
+
173
+ if last_match_id is not None:
174
+ result_ids.append(last_match_id)
175
+ i = last_match_end
176
+ else:
177
+ tok = tokens[i]
178
+ result_ids.append(self.token_to_id.get(tok, self.unk_token_id))
179
+ i += 1
180
+
181
+ return result_ids
182
+
183
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
184
+ token_ids = self._encode_core(text.strip())
185
+ return [self.id_to_token[tid] for tid in token_ids]
186
+
187
+ def _convert_token_to_id(self, token: str) -> int:
188
+ return self.token_to_id.get(token, self.unk_token_id)
189
+
190
+ def _convert_id_to_token(self, index: int) -> str:
191
+ return self.id_to_token.get(index, self.unk_token)
192
+
193
+ # ✅ Public methods
194
+ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
195
+ if isinstance(tokens, str):
196
+ return self._convert_token_to_id(tokens)
197
+ return [self._convert_token_to_id(tok) for tok in tokens]
198
+
199
+ def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
200
+ if isinstance(ids, int):
201
+ return self._convert_id_to_token(ids)
202
+ return [self._convert_id_to_token(i) for i in ids]
203
+
204
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """SMILES-style decoding: no spaces between tokens."""
        # SMILES atoms/bonds concatenate directly; inserting whitespace
        # would corrupt the string, so tokens are joined with no separator.
        return "".join(tokens)
207
+
208
    def encode(
        self,
        text: str,
        text_pair: Optional[str] = None,
        add_special_tokens: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
    ) -> List[int]:
        """Simple encode method that returns list of token IDs.

        Pipeline order: special-token wrapping, then truncation, then
        padding. Returns a plain list of ids, or a 1-D LongTensor when
        ``return_tensors == "pt"``.

        Raises:
            NotImplementedError: if ``text_pair`` is given (pairs are only
                supported via ``encode_plus``).
        """
        if text_pair is not None:
            raise NotImplementedError("text_pair not supported in simple encode method")

        # Get core token IDs (cached trie encoding of the stripped text)
        token_ids = list(self._cached_encode_str(text.strip()))

        # Add special tokens if requested
        if add_special_tokens:
            token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]

        # Handle truncation (note: a trailing EOS may be cut when over limit)
        if truncation:
            if max_length is None:
                max_length = self.model_max_length
            if len(token_ids) > max_length:
                token_ids = token_ids[:max_length]

        # Handle padding up to max_length on the configured side
        if padding:
            if max_length is None:
                max_length = self.model_max_length
            pad_len = max_length - len(token_ids)
            if pad_len > 0:
                if self.padding_side == "right":
                    token_ids = token_ids + [self.pad_token_id] * pad_len
                else:
                    token_ids = [self.pad_token_id] * pad_len + token_ids

        # Return as tensor if requested
        if return_tensors == "pt":
            token_ids = torch.tensor(token_ids, dtype=torch.long)
            if token_ids.dim() == 0:  # scalar (defensive; lists yield >= 1-D)
                token_ids = token_ids.unsqueeze(0)

        return token_ids
254
+
255
+ def decode(
256
+ self,
257
+ token_ids: Union[List[int], torch.Tensor],
258
+ skip_special_tokens: bool = False,
259
+ clean_up_tokenization_spaces: bool = None,
260
+ **kwargs
261
+ ) -> str:
262
+ if isinstance(token_ids, torch.Tensor):
263
+ token_ids = token_ids.tolist()
264
+
265
+ if skip_special_tokens:
266
+ special_ids = {
267
+ self.bos_token_id,
268
+ self.eos_token_id,
269
+ self.pad_token_id,
270
+ self.mask_token_id,
271
+ }
272
+ else:
273
+ special_ids = set()
274
+
275
+ tokens = []
276
+ for tid in token_ids:
277
+ if tid in special_ids:
278
+ continue
279
+ token = self.id_to_token.get(tid, self.unk_token)
280
+ tokens.append(token)
281
+
282
+ return "".join(tokens)
283
+
284
+ def batch_decode(
285
+ self,
286
+ sequences: Union[List[List[int]], torch.Tensor],
287
+ skip_special_tokens: bool = False,
288
+ clean_up_tokenization_spaces: bool = None,
289
+ **kwargs
290
+ ) -> List[str]:
291
+ """Batch decode sequences."""
292
+ if isinstance(sequences, torch.Tensor):
293
+ sequences = sequences.tolist()
294
+
295
+ return [
296
+ self.decode(
297
+ seq,
298
+ skip_special_tokens=skip_special_tokens,
299
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
300
+ **kwargs
301
+ )
302
+ for seq in sequences
303
+ ]
304
+
305
    def decode_with_trace(self, token_ids: List[int]) -> None:
        """Debug method to trace decoding process.

        Prints one line per id showing its position, numeric id, and the
        token it maps to. Pure side-effect helper; returns None.
        """
        print(f"\n🔍 Decoding {len(token_ids)} tokens:")
        for i, tid in enumerate(token_ids):
            token = self.id_to_token.get(tid, self.unk_token)
            print(f"  [{i:03d}] ID={tid:5d} → '{token}'")
311
+
312
    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, Any]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """
        Main callable method that handles both single and batch inputs.

        Dispatches to ``batch_encode_plus`` when *text* is a list (zipping
        with *text_pair* if given) and to ``encode_plus`` otherwise, passing
        every option through unchanged.
        """
        # Handle defaults: attention mask and token type ids are returned
        # unless the caller explicitly opts out (HF default behavior).
        if return_token_type_ids is None:
            return_token_type_ids = True
        if return_attention_mask is None:
            return_attention_mask = True

        if isinstance(text, list):
            # Batch path: pair up texts when a parallel pair list is given.
            if text_pair is not None:
                batch = [(t, p) for t, p in zip(text, text_pair)]
            else:
                batch = text
            return self.batch_encode_plus(
                batch,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )
        else:
            return self.encode_plus(
                text=text,
                text_pair=text_pair,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )
387
+
388
    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, Any]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """Encode one text (or text pair) into a BatchEncoding.

        Layout with specials: ``<s> a </s>`` or ``<s> a </s> b </s>``
        (token_type_ids 0 for the first segment, 1 for the second).
        Order of operations: build ids, truncate, pad, derive masks.
        Only ``padding=True``/``"max_length"`` pads here; batch-level
        "longest" padding happens in batch_encode_plus.
        """
        if max_length is None:
            max_length = self.model_max_length

        ids_a = list(self._cached_encode_str(text.strip()))

        if text_pair is not None:
            ids_b = list(self._cached_encode_str(text_pair.strip()))
        else:
            ids_b = None

        input_ids = []
        token_type_ids = []

        if add_special_tokens:
            input_ids.append(self.bos_token_id)
            token_type_ids.append(0)
            if ids_b is not None:
                # Pair: <s> a </s> b </s>, segment ids 0 then 1.
                input_ids.extend(ids_a)
                token_type_ids.extend([0] * len(ids_a))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(0)

                input_ids.extend(ids_b)
                token_type_ids.extend([1] * len(ids_b))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(1)
            else:
                # Single: <s> a </s>, all segment 0.
                input_ids.extend(ids_a)
                token_type_ids.extend([0] * len(ids_a))
                input_ids.append(self.eos_token_id)
                token_type_ids.append(0)
        else:
            input_ids = ids_a.copy()
            token_type_ids = [0] * len(input_ids)
            if ids_b is not None:
                input_ids.extend(ids_b)
                token_type_ids.extend([1] * len(ids_b))

        # Handle truncation (may drop the trailing EOS when over limit)
        if truncation and len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
            token_type_ids = token_type_ids[:max_length]

        # Handle padding (single-sequence padding to max_length only)
        if padding == True or padding == "max_length":
            pad_len = max_length - len(input_ids)
            if pad_len > 0:
                if self.padding_side == "right":
                    input_ids.extend([self.pad_token_id] * pad_len)
                    token_type_ids.extend([0] * pad_len)
                else:
                    input_ids = [self.pad_token_id] * pad_len + input_ids
                    token_type_ids = [0] * pad_len + token_type_ids

        # Attention mask is derived from pad positions after padding.
        attention_mask = [1 if tid != self.pad_token_id else 0 for tid in input_ids]

        encoded_dict = {
            "input_ids": input_ids,
        }

        if return_attention_mask:
            encoded_dict["attention_mask"] = attention_mask

        if return_token_type_ids:
            encoded_dict["token_type_ids"] = token_type_ids

        if return_special_tokens_mask:
            special_tokens_mask = [
                1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
                for tid in input_ids
            ]
            encoded_dict["special_tokens_mask"] = special_tokens_mask

        if return_length:
            # Length counts non-pad positions only.
            encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])

        if return_tensors == "pt":
            # Convert every field to a (1, seq_len) LongTensor.
            output = {}
            for k, v in encoded_dict.items():
                tensor = torch.tensor(v, dtype=torch.long)
                if tensor.ndim == 1:
                    tensor = tensor.unsqueeze(0)
                output[k] = tensor
        else:
            output = encoded_dict

        return BatchEncoding(output, tensor_type=return_tensors)
496
+
497
    def batch_encode_plus(
        self,
        batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, Any]] = None,
        return_token_type_ids: Optional[bool] = True,
        return_attention_mask: Optional[bool] = True,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs
    ) -> BatchEncoding:
        """Encode a batch of texts or (text, pair) tuples.

        Each item is encoded individually via ``encode_plus`` (without
        per-item padding), then the batch is padded to the longest
        sequence when ``padding`` is True/"longest", and finally
        converted to tensors if requested.
        """
        all_input_ids = []
        all_attention_masks = []
        all_token_type_ids = []
        all_special_tokens_masks = []
        all_lengths = []

        for item in batch_text_or_text_pairs:
            # Items may be plain strings or (text, text_pair) tuples.
            if isinstance(item, tuple):
                text, text_pair = item
            else:
                text, text_pair = item, None

            encoded = self.encode_plus(
                text=text,
                text_pair=text_pair,
                add_special_tokens=add_special_tokens,
                padding=False,  # We'll handle batch padding later
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                is_split_into_words=is_split_into_words,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=None,  # Don't convert to tensors yet
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs
            )

            all_input_ids.append(encoded["input_ids"])
            if "attention_mask" in encoded:
                all_attention_masks.append(encoded["attention_mask"])
            if "token_type_ids" in encoded:
                all_token_type_ids.append(encoded["token_type_ids"])
            if "special_tokens_mask" in encoded:
                all_special_tokens_masks.append(encoded["special_tokens_mask"])
            if "length" in encoded:
                all_lengths.append(encoded["length"])

        batched = {
            "input_ids": all_input_ids,
        }

        if all_attention_masks:
            batched["attention_mask"] = all_attention_masks
        if all_token_type_ids:
            batched["token_type_ids"] = all_token_type_ids
        if all_special_tokens_masks:
            batched["special_tokens_mask"] = all_special_tokens_masks
        if all_lengths:
            batched["length"] = all_lengths

        # Handle batch padding: pad every sequence field to the longest row.
        if padding == True or padding == "longest":
            max_len = max(len(ids) for ids in all_input_ids)
            for key in batched:
                if key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
                    padded_seqs = []
                    for seq in batched[key]:
                        pad_len = max_len - len(seq)
                        if pad_len > 0:
                            # input_ids pad with pad_token_id; masks pad with 0.
                            if key == "input_ids":
                                padding_value = self.pad_token_id
                            else:
                                padding_value = 0

                            if self.padding_side == "right":
                                padded_seq = seq + [padding_value] * pad_len
                            else:
                                padded_seq = [padding_value] * pad_len + seq
                        else:
                            padded_seq = seq
                        padded_seqs.append(padded_seq)
                    batched[key] = padded_seqs

        if return_tensors == "pt":
            # NOTE(review): pad_sequence pads to the longest row even when
            # padding=False was requested — ragged batches cannot become a
            # rectangular tensor otherwise. Confirm this is intended.
            def to_tensor_list(lst):
                return [torch.tensor(item, dtype=torch.long) for item in lst]

            for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
                if key in batched:
                    batched[key] = torch.nn.utils.rnn.pad_sequence(
                        to_tensor_list(batched[key]),
                        batch_first=True,
                        padding_value=self.pad_token_id if key == "input_ids" else 0
                    )

            # Handle non-sequence data
            if "length" in batched:
                batched["length"] = torch.tensor(batched["length"], dtype=torch.long)

        return BatchEncoding(batched, tensor_type=return_tensors)
613
+
614
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
615
+ """Save vocabulary to files."""
616
+ if not os.path.isdir(save_directory):
617
+ os.makedirs(save_directory)
618
+
619
+ vocab_file = os.path.join(
620
+ save_directory,
621
+ (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
622
+ )
623
+
624
+ with open(vocab_file, "w", encoding="utf-8") as f:
625
+ json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)
626
+
627
+ return (vocab_file,)
628
+
629
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        legacy_format: bool = True,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs
    ):
        """Save tokenizer to directory.

        Writes vocab.json (via save_vocabulary) plus a tokenizer_config.json
        with the class name, length/padding settings and special tokens.
        ``legacy_format`` and ``push_to_hub`` are accepted for HF API
        compatibility but are not used here.
        """
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save vocabulary
        vocab_files = self.save_vocabulary(save_directory, filename_prefix)

        # Save tokenizer config
        tokenizer_config = {
            "tokenizer_class": self.__class__.__name__,
            "model_max_length": self.model_max_length,
            "padding_side": self.padding_side,
            "truncation_side": self.truncation_side,
            "special_tokens": {
                "bos_token": self.bos_token,
                "eos_token": self.eos_token,
                "pad_token": self.pad_token,
                "unk_token": self.unk_token,
                "mask_token": self.mask_token,
            }
        }

        config_file = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)

        print(f"✅ Tokenizer saved to: {save_directory}")

        return (save_directory,)
666
+
667
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *init_inputs,
        **kwargs
    ):
        """Load tokenizer from pretrained directory or hub.

        Only local directories are supported: reads vocab.json plus an
        optional tokenizer_config.json whose entries are merged with (and
        overridden by) explicit kwargs, then forwarded to ``cls(...)``.

        Raises:
            NotImplementedError: when the path is not a local directory
                (Hub download is not implemented).
        """
        if os.path.isdir(pretrained_model_name_or_path):
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")

            # Load config if available
            config = {}
            if os.path.exists(config_file):
                with open(config_file, "r", encoding="utf-8") as f:
                    config = json.load(f)

            # Merge config with kwargs (kwargs win on conflicts).
            # NOTE(review): config keys like "tokenizer_class" are forwarded
            # into **kwargs of __init__ — harmless today, but verify.
            merged_config = {**config, **kwargs}

            return cls(vocab_file=vocab_file, **merged_config)
        else:
            raise NotImplementedError("Loading from HuggingFace Hub not implemented yet")
691
+
692
+ def get_special_tokens_mask(
693
+ self,
694
+ token_ids_0: List[int],
695
+ token_ids_1: Optional[List[int]] = None,
696
+ already_has_special_tokens: bool = False
697
+ ) -> List[int]:
698
+ """Get special tokens mask."""
699
+ if already_has_special_tokens:
700
+ return [
701
+ 1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id}
702
+ else 0 for tid in token_ids_0
703
+ ]
704
+
705
+ mask = [1] # BOS
706
+ mask.extend([0] * len(token_ids_0)) # Token sequence
707
+ mask.append(1) # EOS
708
+
709
+ if token_ids_1 is not None:
710
+ mask.extend([0] * len(token_ids_1)) # Second sequence
711
+ mask.append(1) # EOS
712
+
713
+ return mask
714
+
715
+ def create_token_type_ids_from_sequences(
716
+ self,
717
+ token_ids_0: List[int],
718
+ token_ids_1: Optional[List[int]] = None
719
+ ) -> List[int]:
720
+ """Create token type IDs for sequences."""
721
+ sep = [self.eos_token_id]
722
+ cls = [self.bos_token_id]
723
+
724
+ if token_ids_1 is None:
725
+ return len(cls + token_ids_0 + sep) * [0]
726
+
727
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
728
+
729
+ def build_inputs_with_special_tokens(
730
+ self,
731
+ token_ids_0: List[int],
732
+ token_ids_1: Optional[List[int]] = None
733
+ ) -> List[int]:
734
+ """Build inputs with special tokens."""
735
+ if token_ids_1 is None:
736
+ return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
737
+
738
+ return ([self.bos_token_id] + token_ids_0 + [self.eos_token_id] +
739
+ token_ids_1 + [self.eos_token_id])
740
+
741
+
742
class FastChemTokenizerSelfies(FastChemTokenizer):
    """
    SELFIES variant that handles whitespace-separated tokens.
    Uses trie-based longest-match encoding (same as original working version).
    """

    def _encode_core(self, text: str) -> List[int]:
        """Trie-based encoding for SELFIES with fragment + atom vocab.

        Identical to the base greedy longest-match, except that literal
        whitespace between tokens is skipped before each match attempt
        (SELFIES input is space-separated, e.g. "[C] [O]").
        """
        result_ids = []
        i = 0
        n = len(text)

        while i < n:
            if text[i].isspace():  # skip literal whitespace
                i += 1
                continue

            node = self.trie_root
            j = i
            last_match_id = None
            last_match_end = i

            # Traverse trie character by character (including spaces if part of vocab key)
            while j < n and text[j] in node.children:
                node = node.children[text[j]]
                j += 1
                if node.token_id is not None:
                    last_match_id = node.token_id
                    last_match_end = j

            if last_match_id is not None:
                result_ids.append(last_match_id)
                i = last_match_end
            else:
                # Fallback: encode one char as unk or atom
                result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
                i += 1

        return result_ids
781
+
782
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """SELFIES decoding: join tokens with spaces (preserve original format)."""
        # Unlike SMILES, SELFIES tokens are whitespace-separated on input,
        # so decoding re-inserts a single space between tokens.
        return " ".join(tokens)
785
+
786
+ def decode(
787
+ self,
788
+ token_ids: Union[List[int], torch.Tensor],
789
+ skip_special_tokens: bool = False,
790
+ clean_up_tokenization_spaces: bool = None,
791
+ **kwargs
792
+ ) -> str:
793
+ if isinstance(token_ids, torch.Tensor):
794
+ token_ids = token_ids.tolist()
795
+
796
+ if skip_special_tokens:
797
+ special_ids = {
798
+ self.bos_token_id,
799
+ self.eos_token_id,
800
+ self.pad_token_id,
801
+ self.mask_token_id,
802
+ }
803
+ else:
804
+ special_ids = set()
805
+
806
+ tokens = []
807
+ for tid in token_ids:
808
+ if tid in special_ids:
809
+ continue
810
+ token = self.id_to_token.get(tid, self.unk_token)
811
+ tokens.append(token)
812
+
813
+ return " ".join(tokens) # ✅ preserve spaces return " ".join(tokens) # ✅ preserve spaces