gbyuvd commited on
Commit
8cc4c16
·
verified ·
1 Parent(s): 42c2638

Update FastChemTokenizerHF2.py

Browse files
Files changed (1) hide show
  1. FastChemTokenizerHF2.py +580 -813
FastChemTokenizerHF2.py CHANGED
@@ -1,813 +1,580 @@
1
- import torch
2
- import json
3
- import os
4
- from typing import List, Union, Optional, Tuple, Dict, Any
5
- from functools import lru_cache
6
-
7
- # Minimal equivalents for problematic imports
8
- class BatchEncoding:
9
- def __init__(self, data, tensor_type=None):
10
- self.data = data
11
- for key, value in data.items():
12
- setattr(self, key, value)
13
-
14
- def __getitem__(self, key):
15
- return self.data[key]
16
-
17
- def __contains__(self, key):
18
- return key in self.data # This should fix the issue
19
-
20
- def keys(self):
21
- return self.data.keys()
22
-
23
- def values(self):
24
- return self.data.values()
25
-
26
- def items(self):
27
- return self.data.items()
28
- class PreTrainedTokenizerBase:
29
- def __init__(self, **kwargs):
30
- # Set special tokens from kwargs
31
- for key, value in kwargs.items():
32
- if key.endswith('_token'):
33
- setattr(self, f"_{key}", value)
34
- # Don't set token_id here - let subclass handle it
35
-
36
- # Set other attributes
37
- self.model_max_length = kwargs.get('model_max_length', 512)
38
- self.padding_side = kwargs.get('padding_side', 'right')
39
- self.truncation_side = kwargs.get('truncation_side', 'right')
40
- self.chat_template = kwargs.get('chat_template')
41
-
42
- class TrieNode:
43
- __slots__ = ['children', 'token_id']
44
- def __init__(self):
45
- self.children = {}
46
- self.token_id = None # If set, this node completes a valid token
47
-
48
- class FastChemTokenizer(PreTrainedTokenizerBase):
49
- """
50
- Fully HuggingFace API compatible tokenizer for chemical representations.
51
- """
52
-
53
- def __init__(
54
- self,
55
- token_to_id=None,
56
- vocab_file=None,
57
- model_max_length=512,
58
- padding_side="right",
59
- truncation_side="right",
60
- chat_template=None,
61
- **kwargs
62
- ):
63
- # Handle vocab loading
64
- if token_to_id is None and vocab_file is None:
65
- raise ValueError("Either token_to_id or vocab_file must be provided")
66
-
67
- if vocab_file is not None:
68
- with open(vocab_file, "r", encoding="utf-8") as f:
69
- token_to_id = json.load(f)
70
- token_to_id = {str(k): int(v) for k, v in token_to_id.items()}
71
-
72
- self.token_to_id = token_to_id
73
- self.id_to_token = {v: k for k, v in token_to_id.items()}
74
-
75
- # Precompute max token length for possible use & clarity
76
- self.max_token_len = max(len(t) for t in token_to_id.keys()) if token_to_id else 0
77
-
78
- # Build trie for fast longest-match lookup
79
- self.trie_root = self._build_trie(token_to_id)
80
-
81
- # Validate required special tokens
82
- required_special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
83
- for tok in required_special_tokens:
84
- if tok not in token_to_id:
85
- raise KeyError(f"Required special token '{tok}' not found in vocab.")
86
-
87
- # Assign special token IDs explicitly - MUST be done before parent init
88
- self.bos_token_id = token_to_id["<s>"]
89
- self.eos_token_id = token_to_id["</s>"]
90
- self.pad_token_id = token_to_id["<pad>"]
91
- self.unk_token_id = token_to_id["<unk>"]
92
- self.mask_token_id = token_to_id["<mask>"]
93
-
94
- # Initialize parent class
95
- super().__init__(
96
- bos_token="<s>",
97
- eos_token="</s>",
98
- unk_token="<unk>",
99
- sep_token=None,
100
- pad_token="<pad>",
101
- cls_token=None,
102
- mask_token="<mask>",
103
- additional_special_tokens=[],
104
- model_max_length=model_max_length,
105
- padding_side=padding_side,
106
- truncation_side=truncation_side,
107
- chat_template=chat_template,
108
- **kwargs,
109
- )
110
-
111
- # Set all token attributes for compatibility
112
- self._unk_token = "<unk>"
113
- self._bos_token = "<s>"
114
- self._eos_token = "</s>"
115
- self._pad_token = "<pad>"
116
- self._mask_token = "<mask>"
117
- self._sep_token = None
118
- self._cls_token = None
119
-
120
- self.unk_token = "<unk>"
121
- self.bos_token = "<s>"
122
- self.eos_token = "</s>"
123
- self.pad_token = "<pad>"
124
- self.mask_token = "<mask>"
125
- self.sep_token = None
126
- self.cls_token = None
127
-
128
- def _build_trie(self, token_to_id):
129
- root = TrieNode()
130
- for token, tid in token_to_id.items():
131
- node = root
132
- for char in token:
133
- if char not in node.children:
134
- node.children[char] = TrieNode()
135
- node = node.children[char]
136
- node.token_id = tid
137
- return root
138
-
139
- @property
140
- def vocab_size(self):
141
- return len(self.token_to_id)
142
-
143
- def __len__(self):
144
- return len(self.token_to_id)
145
-
146
- def get_vocab(self) -> Dict[str, int]:
147
- return self.token_to_id.copy()
148
-
149
- @lru_cache(maxsize=10000)
150
- def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
151
- return tuple(self._encode_core(s))
152
-
153
- def _encode_core(self, text: str) -> List[int]:
154
- """Core encoding logic using Trie — no caching."""
155
- tokens = text
156
- result_ids = []
157
- i = 0
158
- n = len(tokens)
159
-
160
- while i < n:
161
- node = self.trie_root
162
- j = i
163
- last_match_id = None
164
- last_match_end = i
165
-
166
- while j < n and tokens[j] in node.children:
167
- node = node.children[tokens[j]]
168
- j += 1
169
- if node.token_id is not None:
170
- last_match_id = node.token_id
171
- last_match_end = j
172
-
173
- if last_match_id is not None:
174
- result_ids.append(last_match_id)
175
- i = last_match_end
176
- else:
177
- tok = tokens[i]
178
- result_ids.append(self.token_to_id.get(tok, self.unk_token_id))
179
- i += 1
180
-
181
- return result_ids
182
-
183
- def _tokenize(self, text: str, **kwargs) -> List[str]:
184
- token_ids = self._encode_core(text.strip())
185
- return [self.id_to_token[tid] for tid in token_ids]
186
-
187
- def _convert_token_to_id(self, token: str) -> int:
188
- return self.token_to_id.get(token, self.unk_token_id)
189
-
190
- def _convert_id_to_token(self, index: int) -> str:
191
- return self.id_to_token.get(index, self.unk_token)
192
-
193
- # ✅ Public methods
194
- def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
195
- if isinstance(tokens, str):
196
- return self._convert_token_to_id(tokens)
197
- return [self._convert_token_to_id(tok) for tok in tokens]
198
-
199
- def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
200
- if isinstance(ids, int):
201
- return self._convert_id_to_token(ids)
202
- return [self._convert_id_to_token(i) for i in ids]
203
-
204
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
205
- """SMILES-style decoding: no spaces between tokens."""
206
- return "".join(tokens)
207
-
208
- def encode(
209
- self,
210
- text: str,
211
- text_pair: Optional[str] = None,
212
- add_special_tokens: bool = True,
213
- padding: bool = False,
214
- truncation: bool = False,
215
- max_length: Optional[int] = None,
216
- return_tensors: Optional[str] = None,
217
- ) -> List[int]:
218
- """Simple encode method that returns list of token IDs."""
219
- if text_pair is not None:
220
- raise NotImplementedError("text_pair not supported in simple encode method")
221
-
222
- # Get core token IDs
223
- token_ids = list(self._cached_encode_str(text.strip()))
224
-
225
- # Add special tokens if requested
226
- if add_special_tokens:
227
- token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
228
-
229
- # Handle truncation
230
- if truncation:
231
- if max_length is None:
232
- max_length = self.model_max_length
233
- if len(token_ids) > max_length:
234
- token_ids = token_ids[:max_length]
235
-
236
- # Handle padding
237
- if padding:
238
- if max_length is None:
239
- max_length = self.model_max_length
240
- pad_len = max_length - len(token_ids)
241
- if pad_len > 0:
242
- if self.padding_side == "right":
243
- token_ids = token_ids + [self.pad_token_id] * pad_len
244
- else:
245
- token_ids = [self.pad_token_id] * pad_len + token_ids
246
-
247
- # Return as tensor if requested
248
- if return_tensors == "pt":
249
- token_ids = torch.tensor(token_ids, dtype=torch.long)
250
- if token_ids.dim() == 0: # scalar
251
- token_ids = token_ids.unsqueeze(0)
252
-
253
- return token_ids
254
-
255
- def decode(
256
- self,
257
- token_ids: Union[List[int], torch.Tensor],
258
- skip_special_tokens: bool = False,
259
- clean_up_tokenization_spaces: bool = None,
260
- **kwargs
261
- ) -> str:
262
- if isinstance(token_ids, torch.Tensor):
263
- token_ids = token_ids.tolist()
264
-
265
- if skip_special_tokens:
266
- special_ids = {
267
- self.bos_token_id,
268
- self.eos_token_id,
269
- self.pad_token_id,
270
- self.mask_token_id,
271
- }
272
- else:
273
- special_ids = set()
274
-
275
- tokens = []
276
- for tid in token_ids:
277
- if tid in special_ids:
278
- continue
279
- token = self.id_to_token.get(tid, self.unk_token)
280
- tokens.append(token)
281
-
282
- return "".join(tokens)
283
-
284
- def batch_decode(
285
- self,
286
- sequences: Union[List[List[int]], torch.Tensor],
287
- skip_special_tokens: bool = False,
288
- clean_up_tokenization_spaces: bool = None,
289
- **kwargs
290
- ) -> List[str]:
291
- """Batch decode sequences."""
292
- if isinstance(sequences, torch.Tensor):
293
- sequences = sequences.tolist()
294
-
295
- return [
296
- self.decode(
297
- seq,
298
- skip_special_tokens=skip_special_tokens,
299
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
300
- **kwargs
301
- )
302
- for seq in sequences
303
- ]
304
-
305
- def decode_with_trace(self, token_ids: List[int]) -> None:
306
- """Debug method to trace decoding process."""
307
- print(f"\n🔍 Decoding {len(token_ids)} tokens:")
308
- for i, tid in enumerate(token_ids):
309
- token = self.id_to_token.get(tid, self.unk_token)
310
- print(f" [{i:03d}] ID={tid:5d} → '{token}'")
311
-
312
- def __call__(
313
- self,
314
- text: Union[str, List[str]],
315
- text_pair: Optional[Union[str, List[str]]] = None,
316
- add_special_tokens: bool = True,
317
- padding: Union[bool, str] = False,
318
- truncation: Union[bool, str] = False,
319
- max_length: Optional[int] = None,
320
- stride: int = 0,
321
- is_split_into_words: bool = False,
322
- pad_to_multiple_of: Optional[int] = None,
323
- return_tensors: Optional[Union[str, Any]] = None,
324
- return_token_type_ids: Optional[bool] = None,
325
- return_attention_mask: Optional[bool] = None,
326
- return_overflowing_tokens: bool = False,
327
- return_special_tokens_mask: bool = False,
328
- return_offsets_mapping: bool = False,
329
- return_length: bool = False,
330
- verbose: bool = True,
331
- **kwargs
332
- ) -> BatchEncoding:
333
- """
334
- Main callable method that handles both single and batch inputs.
335
- """
336
- # Handle defaults
337
- if return_token_type_ids is None:
338
- return_token_type_ids = True
339
- if return_attention_mask is None:
340
- return_attention_mask = True
341
-
342
- if isinstance(text, list):
343
- if text_pair is not None:
344
- batch = [(t, p) for t, p in zip(text, text_pair)]
345
- else:
346
- batch = text
347
- return self.batch_encode_plus(
348
- batch,
349
- add_special_tokens=add_special_tokens,
350
- padding=padding,
351
- truncation=truncation,
352
- max_length=max_length,
353
- stride=stride,
354
- is_split_into_words=is_split_into_words,
355
- pad_to_multiple_of=pad_to_multiple_of,
356
- return_tensors=return_tensors,
357
- return_token_type_ids=return_token_type_ids,
358
- return_attention_mask=return_attention_mask,
359
- return_overflowing_tokens=return_overflowing_tokens,
360
- return_special_tokens_mask=return_special_tokens_mask,
361
- return_offsets_mapping=return_offsets_mapping,
362
- return_length=return_length,
363
- verbose=verbose,
364
- **kwargs
365
- )
366
- else:
367
- return self.encode_plus(
368
- text=text,
369
- text_pair=text_pair,
370
- add_special_tokens=add_special_tokens,
371
- padding=padding,
372
- truncation=truncation,
373
- max_length=max_length,
374
- stride=stride,
375
- is_split_into_words=is_split_into_words,
376
- pad_to_multiple_of=pad_to_multiple_of,
377
- return_tensors=return_tensors,
378
- return_token_type_ids=return_token_type_ids,
379
- return_attention_mask=return_attention_mask,
380
- return_overflowing_tokens=return_overflowing_tokens,
381
- return_special_tokens_mask=return_special_tokens_mask,
382
- return_offsets_mapping=return_offsets_mapping,
383
- return_length=return_length,
384
- verbose=verbose,
385
- **kwargs
386
- )
387
-
388
- def encode_plus(
389
- self,
390
- text: str,
391
- text_pair: Optional[str] = None,
392
- add_special_tokens: bool = True,
393
- padding: Union[bool, str] = False,
394
- truncation: Union[bool, str] = False,
395
- max_length: Optional[int] = None,
396
- stride: int = 0,
397
- is_split_into_words: bool = False,
398
- pad_to_multiple_of: Optional[int] = None,
399
- return_tensors: Optional[Union[str, Any]] = None,
400
- return_token_type_ids: Optional[bool] = True,
401
- return_attention_mask: Optional[bool] = True,
402
- return_overflowing_tokens: bool = False,
403
- return_special_tokens_mask: bool = False,
404
- return_offsets_mapping: bool = False,
405
- return_length: bool = False,
406
- verbose: bool = True,
407
- **kwargs
408
- ) -> BatchEncoding:
409
- if max_length is None:
410
- max_length = self.model_max_length
411
-
412
- ids_a = list(self._cached_encode_str(text.strip()))
413
-
414
- if text_pair is not None:
415
- ids_b = list(self._cached_encode_str(text_pair.strip()))
416
- else:
417
- ids_b = None
418
-
419
- input_ids = []
420
- token_type_ids = []
421
-
422
- if add_special_tokens:
423
- input_ids.append(self.bos_token_id)
424
- token_type_ids.append(0)
425
- if ids_b is not None:
426
- input_ids.extend(ids_a)
427
- token_type_ids.extend([0] * len(ids_a))
428
- input_ids.append(self.eos_token_id)
429
- token_type_ids.append(0)
430
-
431
- input_ids.extend(ids_b)
432
- token_type_ids.extend([1] * len(ids_b))
433
- input_ids.append(self.eos_token_id)
434
- token_type_ids.append(1)
435
- else:
436
- input_ids.extend(ids_a)
437
- token_type_ids.extend([0] * len(ids_a))
438
- input_ids.append(self.eos_token_id)
439
- token_type_ids.append(0)
440
- else:
441
- input_ids = ids_a.copy()
442
- token_type_ids = [0] * len(input_ids)
443
- if ids_b is not None:
444
- input_ids.extend(ids_b)
445
- token_type_ids.extend([1] * len(ids_b))
446
-
447
- # Handle truncation
448
- if truncation and len(input_ids) > max_length:
449
- input_ids = input_ids[:max_length]
450
- token_type_ids = token_type_ids[:max_length]
451
-
452
- # Handle padding
453
- if padding == True or padding == "max_length":
454
- pad_len = max_length - len(input_ids)
455
- if pad_len > 0:
456
- if self.padding_side == "right":
457
- input_ids.extend([self.pad_token_id] * pad_len)
458
- token_type_ids.extend([0] * pad_len)
459
- else:
460
- input_ids = [self.pad_token_id] * pad_len + input_ids
461
- token_type_ids = [0] * pad_len + token_type_ids
462
-
463
- attention_mask = [1 if tid != self.pad_token_id else 0 for tid in input_ids]
464
-
465
- encoded_dict = {
466
- "input_ids": input_ids,
467
- }
468
-
469
- if return_attention_mask:
470
- encoded_dict["attention_mask"] = attention_mask
471
-
472
- if return_token_type_ids:
473
- encoded_dict["token_type_ids"] = token_type_ids
474
-
475
- if return_special_tokens_mask:
476
- special_tokens_mask = [
477
- 1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
478
- for tid in input_ids
479
- ]
480
- encoded_dict["special_tokens_mask"] = special_tokens_mask
481
-
482
- if return_length:
483
- encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])
484
-
485
- if return_tensors == "pt":
486
- output = {}
487
- for k, v in encoded_dict.items():
488
- tensor = torch.tensor(v, dtype=torch.long)
489
- if tensor.ndim == 1:
490
- tensor = tensor.unsqueeze(0)
491
- output[k] = tensor
492
- else:
493
- output = encoded_dict
494
-
495
- return BatchEncoding(output, tensor_type=return_tensors)
496
-
497
- def batch_encode_plus(
498
- self,
499
- batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
500
- add_special_tokens: bool = True,
501
- padding: Union[bool, str] = False,
502
- truncation: Union[bool, str] = False,
503
- max_length: Optional[int] = None,
504
- stride: int = 0,
505
- is_split_into_words: bool = False,
506
- pad_to_multiple_of: Optional[int] = None,
507
- return_tensors: Optional[Union[str, Any]] = None,
508
- return_token_type_ids: Optional[bool] = True,
509
- return_attention_mask: Optional[bool] = True,
510
- return_overflowing_tokens: bool = False,
511
- return_special_tokens_mask: bool = False,
512
- return_offsets_mapping: bool = False,
513
- return_length: bool = False,
514
- verbose: bool = True,
515
- **kwargs
516
- ) -> BatchEncoding:
517
- all_input_ids = []
518
- all_attention_masks = []
519
- all_token_type_ids = []
520
- all_special_tokens_masks = []
521
- all_lengths = []
522
-
523
- for item in batch_text_or_text_pairs:
524
- if isinstance(item, tuple):
525
- text, text_pair = item
526
- else:
527
- text, text_pair = item, None
528
-
529
- encoded = self.encode_plus(
530
- text=text,
531
- text_pair=text_pair,
532
- add_special_tokens=add_special_tokens,
533
- padding=False, # We'll handle batch padding later
534
- truncation=truncation,
535
- max_length=max_length,
536
- stride=stride,
537
- is_split_into_words=is_split_into_words,
538
- pad_to_multiple_of=pad_to_multiple_of,
539
- return_tensors=None, # Don't convert to tensors yet
540
- return_token_type_ids=return_token_type_ids,
541
- return_attention_mask=return_attention_mask,
542
- return_overflowing_tokens=return_overflowing_tokens,
543
- return_special_tokens_mask=return_special_tokens_mask,
544
- return_offsets_mapping=return_offsets_mapping,
545
- return_length=return_length,
546
- verbose=verbose,
547
- **kwargs
548
- )
549
-
550
- all_input_ids.append(encoded["input_ids"])
551
- if "attention_mask" in encoded:
552
- all_attention_masks.append(encoded["attention_mask"])
553
- if "token_type_ids" in encoded:
554
- all_token_type_ids.append(encoded["token_type_ids"])
555
- if "special_tokens_mask" in encoded:
556
- all_special_tokens_masks.append(encoded["special_tokens_mask"])
557
- if "length" in encoded:
558
- all_lengths.append(encoded["length"])
559
-
560
- batched = {
561
- "input_ids": all_input_ids,
562
- }
563
-
564
- if all_attention_masks:
565
- batched["attention_mask"] = all_attention_masks
566
- if all_token_type_ids:
567
- batched["token_type_ids"] = all_token_type_ids
568
- if all_special_tokens_masks:
569
- batched["special_tokens_mask"] = all_special_tokens_masks
570
- if all_lengths:
571
- batched["length"] = all_lengths
572
-
573
- # Handle batch padding
574
- if padding == True or padding == "longest":
575
- max_len = max(len(ids) for ids in all_input_ids)
576
- for key in batched:
577
- if key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
578
- padded_seqs = []
579
- for seq in batched[key]:
580
- pad_len = max_len - len(seq)
581
- if pad_len > 0:
582
- if key == "input_ids":
583
- padding_value = self.pad_token_id
584
- else:
585
- padding_value = 0
586
-
587
- if self.padding_side == "right":
588
- padded_seq = seq + [padding_value] * pad_len
589
- else:
590
- padded_seq = [padding_value] * pad_len + seq
591
- else:
592
- padded_seq = seq
593
- padded_seqs.append(padded_seq)
594
- batched[key] = padded_seqs
595
-
596
- if return_tensors == "pt":
597
- def to_tensor_list(lst):
598
- return [torch.tensor(item, dtype=torch.long) for item in lst]
599
-
600
- for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
601
- if key in batched:
602
- batched[key] = torch.nn.utils.rnn.pad_sequence(
603
- to_tensor_list(batched[key]),
604
- batch_first=True,
605
- padding_value=self.pad_token_id if key == "input_ids" else 0
606
- )
607
-
608
- # Handle non-sequence data
609
- if "length" in batched:
610
- batched["length"] = torch.tensor(batched["length"], dtype=torch.long)
611
-
612
- return BatchEncoding(batched, tensor_type=return_tensors)
613
-
614
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
615
- """Save vocabulary to files."""
616
- if not os.path.isdir(save_directory):
617
- os.makedirs(save_directory)
618
-
619
- vocab_file = os.path.join(
620
- save_directory,
621
- (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
622
- )
623
-
624
- with open(vocab_file, "w", encoding="utf-8") as f:
625
- json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)
626
-
627
- return (vocab_file,)
628
-
629
- def save_pretrained(
630
- self,
631
- save_directory: Union[str, os.PathLike],
632
- legacy_format: bool = True,
633
- filename_prefix: Optional[str] = None,
634
- push_to_hub: bool = False,
635
- **kwargs
636
- ):
637
- """Save tokenizer to directory."""
638
- if not os.path.exists(save_directory):
639
- os.makedirs(save_directory)
640
-
641
- # Save vocabulary
642
- vocab_files = self.save_vocabulary(save_directory, filename_prefix)
643
-
644
- # Save tokenizer config
645
- tokenizer_config = {
646
- "tokenizer_class": self.__class__.__name__,
647
- "model_max_length": self.model_max_length,
648
- "padding_side": self.padding_side,
649
- "truncation_side": self.truncation_side,
650
- "special_tokens": {
651
- "bos_token": self.bos_token,
652
- "eos_token": self.eos_token,
653
- "pad_token": self.pad_token,
654
- "unk_token": self.unk_token,
655
- "mask_token": self.mask_token,
656
- }
657
- }
658
-
659
- config_file = os.path.join(save_directory, "tokenizer_config.json")
660
- with open(config_file, "w", encoding="utf-8") as f:
661
- json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
662
-
663
- print(f"✅ Tokenizer saved to: {save_directory}")
664
-
665
- return (save_directory,)
666
-
667
- @classmethod
668
- def from_pretrained(
669
- cls,
670
- pretrained_model_name_or_path: Union[str, os.PathLike],
671
- *init_inputs,
672
- **kwargs
673
- ):
674
- """Load tokenizer from pretrained directory or hub."""
675
- if os.path.isdir(pretrained_model_name_or_path):
676
- vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
677
- config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
678
-
679
- # Load config if available
680
- config = {}
681
- if os.path.exists(config_file):
682
- with open(config_file, "r", encoding="utf-8") as f:
683
- config = json.load(f)
684
-
685
- # Merge config with kwargs
686
- merged_config = {**config, **kwargs}
687
-
688
- return cls(vocab_file=vocab_file, **merged_config)
689
- else:
690
- raise NotImplementedError("Loading from HuggingFace Hub not implemented yet")
691
-
692
- def get_special_tokens_mask(
693
- self,
694
- token_ids_0: List[int],
695
- token_ids_1: Optional[List[int]] = None,
696
- already_has_special_tokens: bool = False
697
- ) -> List[int]:
698
- """Get special tokens mask."""
699
- if already_has_special_tokens:
700
- return [
701
- 1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id}
702
- else 0 for tid in token_ids_0
703
- ]
704
-
705
- mask = [1] # BOS
706
- mask.extend([0] * len(token_ids_0)) # Token sequence
707
- mask.append(1) # EOS
708
-
709
- if token_ids_1 is not None:
710
- mask.extend([0] * len(token_ids_1)) # Second sequence
711
- mask.append(1) # EOS
712
-
713
- return mask
714
-
715
- def create_token_type_ids_from_sequences(
716
- self,
717
- token_ids_0: List[int],
718
- token_ids_1: Optional[List[int]] = None
719
- ) -> List[int]:
720
- """Create token type IDs for sequences."""
721
- sep = [self.eos_token_id]
722
- cls = [self.bos_token_id]
723
-
724
- if token_ids_1 is None:
725
- return len(cls + token_ids_0 + sep) * [0]
726
-
727
- return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
728
-
729
- def build_inputs_with_special_tokens(
730
- self,
731
- token_ids_0: List[int],
732
- token_ids_1: Optional[List[int]] = None
733
- ) -> List[int]:
734
- """Build inputs with special tokens."""
735
- if token_ids_1 is None:
736
- return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
737
-
738
- return ([self.bos_token_id] + token_ids_0 + [self.eos_token_id] +
739
- token_ids_1 + [self.eos_token_id])
740
-
741
-
742
- class FastChemTokenizerSelfies(FastChemTokenizer):
743
- """
744
- SELFIES variant that handles whitespace-separated tokens.
745
- Uses trie-based longest-match encoding (same as original working version).
746
- """
747
-
748
- def _encode_core(self, text: str) -> List[int]:
749
- """Trie-based encoding for SELFIES with fragment + atom vocab."""
750
- result_ids = []
751
- i = 0
752
- n = len(text)
753
-
754
- while i < n:
755
- if text[i].isspace(): # skip literal whitespace
756
- i += 1
757
- continue
758
-
759
- node = self.trie_root
760
- j = i
761
- last_match_id = None
762
- last_match_end = i
763
-
764
- # Traverse trie character by character (including spaces if part of vocab key)
765
- while j < n and text[j] in node.children:
766
- node = node.children[text[j]]
767
- j += 1
768
- if node.token_id is not None:
769
- last_match_id = node.token_id
770
- last_match_end = j
771
-
772
- if last_match_id is not None:
773
- result_ids.append(last_match_id)
774
- i = last_match_end
775
- else:
776
- # Fallback: encode one char as unk or atom
777
- result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
778
- i += 1
779
-
780
- return result_ids
781
-
782
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
783
- """SELFIES decoding: join tokens with spaces (preserve original format)."""
784
- return " ".join(tokens)
785
-
786
- def decode(
787
- self,
788
- token_ids: Union[List[int], torch.Tensor],
789
- skip_special_tokens: bool = False,
790
- clean_up_tokenization_spaces: bool = None,
791
- **kwargs
792
- ) -> str:
793
- if isinstance(token_ids, torch.Tensor):
794
- token_ids = token_ids.tolist()
795
-
796
- if skip_special_tokens:
797
- special_ids = {
798
- self.bos_token_id,
799
- self.eos_token_id,
800
- self.pad_token_id,
801
- self.mask_token_id,
802
- }
803
- else:
804
- special_ids = set()
805
-
806
- tokens = []
807
- for tid in token_ids:
808
- if tid in special_ids:
809
- continue
810
- token = self.id_to_token.get(tid, self.unk_token)
811
- tokens.append(token)
812
-
813
- return " ".join(tokens) # ✅ preserve spaces return " ".join(tokens) # ✅ preserve spaces
 
1
+ import torch
2
+ import json
3
+ import os
4
+ from typing import List, Union, Optional, Tuple, Dict, Any
5
+ from functools import lru_cache
6
+ from collections.abc import Mapping
7
+
8
+
9
+ # ------------------------------
10
+ # BatchEncoding
11
+ # ------------------------------
12
+ class BatchEncoding(dict, Mapping):
13
+ """Minimal BatchEncoding compatible wrapper."""
14
+
15
+ def __init__(self, data: dict, tensor_type: Optional[str] = None):
16
+ data = {} if data is None else {k: v for k, v in data.items()}
17
+ super().__init__(data)
18
+ self.data = data
19
+ self.tensor_type = tensor_type
20
+ for k, v in data.items():
21
+ setattr(self, k, v)
22
+
23
+ def __getitem__(self, key): return self.data[key]
24
+ def __iter__(self): return iter(self.data)
25
+ def __len__(self): return len(self.data)
26
+ def keys(self): return self.data.keys()
27
+ def values(self): return self.data.values()
28
+ def items(self): return self.data.items()
29
+ def get(self, key, default=None): return self.data.get(key, default)
30
+
31
+ def to(self, device):
32
+ if self.tensor_type in ("pt", "torch"):
33
+ for k, v in list(self.data.items()):
34
+ if torch.is_tensor(v):
35
+ self.data[k] = v.to(device)
36
+ setattr(self, k, self.data[k])
37
+ return self
38
+
39
+ def cpu(self): return self.to("cpu")
40
+ def cuda(self): return self.to("cuda")
41
+ def detach(self):
42
+ if self.tensor_type in ("pt", "torch"):
43
+ for k, v in list(self.data.items()):
44
+ if torch.is_tensor(v):
45
+ self.data[k] = v.detach()
46
+ setattr(self, k, self.data[k])
47
+ return self
48
+
49
+ def __repr__(self):
50
+ keys = ", ".join(list(self.data.keys())[:10])
51
+ return f"BatchEncoding(keys=[{keys}], tensor_type={self.tensor_type})"
52
+
53
+
54
+ # ------------------------------
55
+ # Base class
56
+ # ------------------------------
57
+ class PreTrainedTokenizerBase:
58
+ def __init__(self, **kwargs):
59
+ for key, value in kwargs.items():
60
+ if key.endswith('_token'):
61
+ setattr(self, f"_{key}", value)
62
+ setattr(self, f"{key}_id", None)
63
+ self.model_max_length = kwargs.get('model_max_length', 512)
64
+ self.padding_side = kwargs.get('padding_side', 'right')
65
+ self.truncation_side = kwargs.get('truncation_side', 'right')
66
+ self.chat_template = kwargs.get('chat_template')
67
+
68
+
69
+ # ------------------------------
70
+ # Trie node
71
+ # ------------------------------
72
+ class TrieNode:
73
+ __slots__ = ['children', 'token_id']
74
+ def __init__(self):
75
+ self.children = {}
76
+ self.token_id = None
77
+
78
+
79
+ # ------------------------------
80
+ # FastChemTokenizer
81
+ # ------------------------------
82
+
83
+ class FastChemTokenizer(PreTrainedTokenizerBase):
84
+ def __init__(self, token_to_id=None, vocab_file=None, **kwargs):
85
+ if vocab_file is not None:
86
+ with open(vocab_file, "r", encoding="utf-8") as f:
87
+ token_to_id = json.load(f)
88
+ token_to_id = {str(k): int(v) for k, v in token_to_id.items()}
89
+
90
+ self.token_to_id = token_to_id
91
+ self.id_to_token = {v: k for k, v in token_to_id.items()}
92
+
93
+ # Build trie
94
+ self.trie_root = self._build_trie(self.token_to_id)
95
+
96
+ # ✅ Call parent (sets token *strings*, may reset *_id to None)
97
+ super().__init__(
98
+ bos_token="<s>",
99
+ eos_token="</s>",
100
+ unk_token="<unk>",
101
+ pad_token="<pad>",
102
+ mask_token="<mask>",
103
+ model_max_length=kwargs.get("model_max_length", 512),
104
+ padding_side=kwargs.get("padding_side", "right"),
105
+ truncation_side=kwargs.get("truncation_side", "right"),
106
+ **kwargs,
107
+ )
108
+
109
+ # ✅ Re-map token strings → IDs from vocab
110
+ self.bos_token_id = self.token_to_id.get("<s>", 0)
111
+ self.eos_token_id = self.token_to_id.get("</s>", 1)
112
+ self.pad_token_id = self.token_to_id.get("<pad>", 2)
113
+ self.unk_token_id = self.token_to_id.get("<unk>", 3)
114
+ self.mask_token_id = self.token_to_id.get("<mask>", 4)
115
+
116
+ # Ensure reverse mapping always valid
117
+ self.id_to_token[self.bos_token_id] = "<s>"
118
+ self.id_to_token[self.eos_token_id] = "</s>"
119
+ self.id_to_token[self.pad_token_id] = "<pad>"
120
+ self.id_to_token[self.unk_token_id] = "<unk>"
121
+ self.id_to_token[self.mask_token_id] = "<mask>"
122
+
123
+ # Debug
124
+ print("✅ Special tokens bound:",
125
+ self.bos_token_id, self.eos_token_id, self.pad_token_id,
126
+ self.unk_token_id, self.mask_token_id)
127
+
128
+ # ✅ Ensure token *strings* also exist (for decode fallback)
129
+ self.bos_token = "<s>"
130
+ self.eos_token = "</s>"
131
+ self.pad_token = "<pad>"
132
+ self.unk_token = "<unk>"
133
+ self.mask_token = "<mask>"
134
+
135
+
136
+ def _build_trie(self, token_to_id):
137
+ root = TrieNode()
138
+ for token, tid in token_to_id.items():
139
+ node = root
140
+ for char in token:
141
+ if char not in node.children:
142
+ node.children[char] = TrieNode()
143
+ node = node.children[char]
144
+ node.token_id = tid
145
+ return root
146
+
147
+ @property
148
+ def vocab_size(self): return len(self.token_to_id)
149
+ def __len__(self): return len(self.token_to_id)
150
+ def get_vocab(self) -> Dict[str, int]: return self.token_to_id.copy()
151
+
152
+ @lru_cache(maxsize=10000)
153
+ def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
154
+ return tuple(self._encode_core(s))
155
+
156
def _encode_core(self, text: str) -> List[int]:
    """Greedy longest-match tokenization of ``text`` using the vocab trie.

    At each position the longest matching vocabulary token is consumed;
    when nothing matches, the single character's id (or the unk id) is
    emitted and scanning advances by one.
    """
    ids: List[int] = []
    pos, length = 0, len(text)
    while pos < length:
        node = self.trie_root
        best_id, best_end = None, pos
        scan = pos
        while scan < length:
            child = node.children.get(text[scan])
            if child is None:
                break
            node = child
            scan += 1
            if node.token_id is not None:
                # Remember the longest complete token seen so far.
                best_id, best_end = node.token_id, scan
        if best_id is None:
            ids.append(self.token_to_id.get(text[pos], self.unk_token_id))
            pos += 1
        else:
            ids.append(best_id)
            pos = best_end
    return ids
175
+
176
+ # ------------------------------
177
+ # Converters
178
+ # ------------------------------
179
def _convert_token_to_id(self, token: str) -> int:
    """Map a single token string to its id; unknown tokens map to unk id."""
    return self.token_to_id.get(token, self.unk_token_id)

def _convert_id_to_token(self, index: int) -> str:
    """Map a single id back to its token string; unknown ids map to unk token."""
    return self.id_to_token.get(index, self.unk_token)
183
+
184
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]):
    """Convert one token (str) or a list of tokens to id(s)."""
    if isinstance(tokens, str):
        return self._convert_token_to_id(tokens)
    return [self._convert_token_to_id(token) for token in tokens]
187
+
188
def convert_ids_to_tokens(self, ids: Union[int, List[int]]):
    """Convert one id (int) or a list of ids to token string(s)."""
    if isinstance(ids, int):
        return self._convert_id_to_token(ids)
    return [self._convert_id_to_token(index) for index in ids]
191
+
192
+ def convert_tokens_to_string(self, tokens: List[str]) -> str: return "".join(tokens)
193
+
194
+ # ------------------------------
195
+ # Encoding / Decoding
196
+ # ------------------------------
197
+ # ------------------------------
198
+ # Convenience wrappers
199
+ # ------------------------------
200
def encode(
    self,
    text: str,
    text_pair: Optional[str] = None,
    add_special_tokens: bool = True,
    padding: bool = False,
    truncation: bool = False,
    max_length: Optional[int] = None,
    return_tensors: Optional[str] = None,
) -> List[int]:
    """Encode a single text (or text pair) and return a flat list of ids.

    Thin wrapper over ``encode_plus`` that extracts ``input_ids`` and
    converts a possible tensor result back into a plain Python list.
    """
    result = self.encode_plus(
        text=text,
        text_pair=text_pair,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors=return_tensors,
    )
    ids = result["input_ids"]
    if isinstance(ids, torch.Tensor):
        # encode_plus adds a leading batch dimension; drop it before listifying.
        if ids.dim() > 1:
            ids = ids.squeeze(0)
        ids = ids.tolist()
    return ids
225
+
226
def __call__(
    self,
    text: Union[str, List[str]],
    text_pair: Optional[Union[str, List[str]]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str] = False,
    truncation: Union[bool, str] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, Any]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """HuggingFace-compatible entry point.

    A single string dispatches to ``encode_plus``; a list of strings
    (optionally zipped with ``text_pair``) dispatches to
    ``batch_encode_plus``.
    """
    # Mirror HF defaults: emit both masks unless the caller disables them.
    if return_token_type_ids is None:
        return_token_type_ids = True
    if return_attention_mask is None:
        return_attention_mask = True

    shared = dict(
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        stride=stride,
        is_split_into_words=is_split_into_words,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
    )
    shared.update(kwargs)

    if isinstance(text, list):
        # Pair up texts when a parallel list of second segments is given.
        batch = list(zip(text, text_pair)) if text_pair is not None else text
        return self.batch_encode_plus(batch, **shared)
    return self.encode_plus(text=text, text_pair=text_pair, **shared)
298
+
299
def encode_plus(
    self,
    text: str,
    text_pair: Optional[str] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str] = False,
    truncation: Union[bool, str] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, Any]] = None,
    return_token_type_ids: Optional[bool] = True,
    return_attention_mask: Optional[bool] = True,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """Encode one text (or text pair) into a BatchEncoding.

    Layout with special tokens: ``<s> A </s> [B </s>]`` where segment A has
    token_type 0 and segment B token_type 1.

    Bug fix: padding used to be applied only when ``return_attention_mask``
    was True; it is now applied whenever ``padding`` requests it.
    """
    if max_length is None:
        max_length = self.model_max_length

    ids_a = list(self._cached_encode_str(text.strip()))
    ids_b = list(self._cached_encode_str(text_pair.strip())) if text_pair else None

    input_ids, token_type_ids = [], []
    if add_special_tokens:
        input_ids.append(self.bos_token_id); token_type_ids.append(0)
        input_ids.extend(ids_a); token_type_ids.extend([0] * len(ids_a))
        input_ids.append(self.eos_token_id); token_type_ids.append(0)
        if ids_b is not None:
            input_ids.extend(ids_b); token_type_ids.extend([1] * len(ids_b))
            input_ids.append(self.eos_token_id); token_type_ids.append(1)
    else:
        input_ids = ids_a.copy()
        token_type_ids = [0] * len(input_ids)
        if ids_b is not None:
            input_ids.extend(ids_b); token_type_ids.extend([1] * len(ids_b))

    if truncation and len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        token_type_ids = token_type_ids[:max_length]

    # Pad regardless of whether the attention mask is requested (bug fix).
    if padding is True or padding == "max_length":
        pad_len = max_length - len(input_ids)
        if pad_len > 0:
            if self.padding_side == "right":
                input_ids = input_ids + [self.pad_token_id] * pad_len
                token_type_ids = token_type_ids + [0] * pad_len
            else:
                input_ids = [self.pad_token_id] * pad_len + input_ids
                token_type_ids = [0] * pad_len + token_type_ids

    encoded_dict = {"input_ids": input_ids}
    if return_attention_mask:
        encoded_dict["attention_mask"] = [
            0 if tid == self.pad_token_id else 1 for tid in input_ids
        ]
    if return_token_type_ids:
        encoded_dict["token_type_ids"] = token_type_ids
    if return_special_tokens_mask:
        special_ids = {self.bos_token_id, self.eos_token_id,
                       self.pad_token_id, self.mask_token_id}
        encoded_dict["special_tokens_mask"] = [
            1 if tid in special_ids else 0 for tid in input_ids
        ]
    if return_length:
        # Length counts only non-padding positions.
        encoded_dict["length"] = sum(1 for tid in input_ids if tid != self.pad_token_id)

    if return_tensors in ("pt", "torch"):
        out = {}
        for key, value in encoded_dict.items():
            if isinstance(value, list):
                out[key] = torch.tensor(
                    [self.unk_token_id if x is None else int(x) for x in value],
                    dtype=torch.long,
                ).unsqueeze(0)  # add batch dimension
            else:
                out[key] = value
        return BatchEncoding(out, tensor_type=return_tensors)
    return BatchEncoding(encoded_dict, tensor_type=None)
374
+
375
def batch_encode_plus(
    self,
    batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
    add_special_tokens: bool = True,
    padding: Union[bool, str] = False,
    truncation: Union[bool, str] = False,
    max_length: Optional[int] = None,
    stride: int = 0,
    is_split_into_words: bool = False,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, Any]] = None,
    return_token_type_ids: Optional[bool] = True,
    return_attention_mask: Optional[bool] = True,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs
) -> BatchEncoding:
    """Encode a batch of texts (or text pairs) with optional batch padding.

    Bug fixes:
    - the minimal BatchEncoding defines ``__getitem__``/``__contains__``
      but no ``.get``; per-item results are now read via ``in`` + indexing
      instead of ``enc.get(...)`` (which raised AttributeError);
    - an empty batch with padding='longest' no longer crashes on ``max()``;
    - removed an unused parameter from the internal tensor helper.
    """
    if padding is True:
        padding = "longest"
    if padding == "max_length" and max_length is None:
        max_length = self.model_max_length

    all_input_ids, all_token_type_ids, all_attention_masks = [], [], []
    all_special_masks, all_lengths = [], []
    for item in batch_text_or_text_pairs:
        text, text_pair = item if isinstance(item, tuple) else (item, None)
        enc = self.encode_plus(
            text=text, text_pair=text_pair, add_special_tokens=add_special_tokens,
            padding=False, truncation=truncation, max_length=max_length,
            return_tensors=None, return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length, **kwargs
        )
        ids = enc["input_ids"]
        # BatchEncoding has no .get(); probe with `in` before indexing.
        tt = enc["token_type_ids"] if "token_type_ids" in enc else [0] * len(ids)
        am = enc["attention_mask"] if "attention_mask" in enc else [1] * len(ids)
        sm = enc["special_tokens_mask"] if "special_tokens_mask" in enc else [0] * len(ids)
        ln = enc["length"] if "length" in enc else sum(1 for x in ids if x != self.pad_token_id)
        all_input_ids.append(ids)
        all_token_type_ids.append(tt)
        all_attention_masks.append(am)
        all_special_masks.append(sm)
        all_lengths.append(ln)

    if padding == "longest":
        # default=0 keeps an empty batch from raising ValueError.
        pad_to = max((len(x) for x in all_input_ids), default=0)
    elif padding == "max_length":
        pad_to = max_length
    else:
        pad_to = None

    batched = {
        "input_ids": all_input_ids,
        "token_type_ids": all_token_type_ids if return_token_type_ids else None,
        "attention_mask": all_attention_masks if return_attention_mask else None,
        "special_tokens_mask": all_special_masks if return_special_tokens_mask else None,
        "length": all_lengths if return_length else None,
    }
    if pad_to is not None:
        for key in ("input_ids", "token_type_ids", "attention_mask", "special_tokens_mask"):
            seqs = batched.get(key)
            if seqs is None:
                continue
            pad_val = self.pad_token_id if key == "input_ids" else 0
            padded = []
            for seq in seqs:
                pad_len = pad_to - len(seq)
                if pad_len > 0:
                    if self.padding_side == "right":
                        seq = seq + [pad_val] * pad_len
                    else:
                        seq = [pad_val] * pad_len + seq
                padded.append(seq)
            batched[key] = padded

    if return_tensors in ("pt", "torch"):
        def to_tensor(rows):
            # None entries are defensively mapped to unk before tensorizing.
            return torch.tensor(
                [[self.unk_token_id if x is None else int(x) for x in row] for row in rows],
                dtype=torch.long,
            )
        out = {}
        for key in ("input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"):
            if batched.get(key) is not None:
                out[key] = to_tensor(batched[key])
        if return_length and batched.get("length") is not None:
            out["length"] = torch.tensor([int(x) for x in batched["length"]], dtype=torch.long)
        return BatchEncoding(out, tensor_type=return_tensors)
    return BatchEncoding({k: v for k, v in batched.items() if v is not None}, tensor_type=None)
446
+
447
+ # ------------------------------
448
+ # Decoding
449
+ # ------------------------------
450
def decode(self, token_ids, skip_special_tokens=False, **kwargs):
    """Convert ids to a string by direct concatenation (no separator).

    When ``skip_special_tokens`` is True, BOS/EOS/PAD/MASK ids are dropped
    before decoding; unknown ids render as the unk token string.
    """
    if isinstance(token_ids, torch.Tensor):
        token_ids = token_ids.tolist()
    skip = set()
    if skip_special_tokens:
        skip = {self.bos_token_id, self.eos_token_id,
                self.pad_token_id, self.mask_token_id}
    pieces = []
    for tid in token_ids:
        if tid in skip:
            continue
        pieces.append(self.id_to_token.get(tid, self.unk_token))
    return "".join(pieces)
455
+
456
def batch_decode(self, sequences, skip_special_tokens=False, **kwargs):
    """Decode a batch of id sequences; delegates each row to ``decode``."""
    if isinstance(sequences, torch.Tensor):
        sequences = sequences.tolist()
    return [
        self.decode(row, skip_special_tokens=skip_special_tokens, **kwargs)
        for row in sequences
    ]
459
+
460
def decode_with_trace(self, token_ids: List[int]):
    """Debug helper: print each id alongside its position and token string."""
    print(f"\n🔍 Decoding {len(token_ids)} tokens:")
    for position, tid in enumerate(token_ids):
        token = self.id_to_token.get(tid, self.unk_token)
        tid_str = "None" if tid is None else f"{tid:5d}"
        print(f" [{position:03d}] ID={tid_str} → '{token}'")
466
+
467
def pad(
    self,
    encoded_inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_tensors=None,
    **kwargs,
):
    """HuggingFace-style pad: pad a list (or single dict) of encodings.

    Pads ``input_ids``/``attention_mask`` to a common length, honouring
    ``pad_to_multiple_of`` and ``max_length`` and ``self.padding_side``.

    Bug fix: when ``max_length`` was shorter than the longest sequence,
    the negative pad length was silently ignored (``[x] * -n == []``),
    producing a ragged batch; over-long sequences are now truncated.
    """
    if isinstance(encoded_inputs, dict):
        encoded_inputs = [encoded_inputs]

    input_ids = [ei["input_ids"] for ei in encoded_inputs]
    attn_masks = [ei.get("attention_mask", [1] * len(ei["input_ids"])) for ei in encoded_inputs]

    # Target length: longest sequence, rounded up to the multiple,
    # capped by max_length when given.
    max_len = max(len(ids) for ids in input_ids)
    if pad_to_multiple_of:
        max_len = ((max_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
    if max_length is not None:
        max_len = min(max_len, max_length)

    padded_ids, padded_masks = [], []
    for ids, mask in zip(input_ids, attn_masks):
        if len(ids) > max_len:
            # Truncate so every row ends up exactly max_len (bug fix).
            ids, mask = ids[:max_len], mask[:max_len]
        pad_len = max_len - len(ids)
        if self.padding_side == "right":
            padded_ids.append(ids + [self.pad_token_id] * pad_len)
            padded_masks.append(mask + [0] * pad_len)
        else:
            padded_ids.append([self.pad_token_id] * pad_len + ids)
            padded_masks.append([0] * pad_len + mask)

    out = {"input_ids": padded_ids, "attention_mask": padded_masks}
    if return_tensors in ("pt", "torch"):
        out = {k: torch.tensor(v, dtype=torch.long) for k, v in out.items()}
    return out
506
+
507
+
508
+ # ------------------------------
509
+ # Save / Load
510
+ # ------------------------------
511
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Write the token -> id mapping to ``[prefix-]vocab.json``.

    Returns a 1-tuple with the written file's path (HF convention).
    """
    os.makedirs(save_directory, exist_ok=True)
    filename = f"{filename_prefix}-vocab.json" if filename_prefix else "vocab.json"
    vocab_file = os.path.join(save_directory, filename)
    with open(vocab_file, "w", encoding="utf-8") as handle:
        json.dump(self.token_to_id, handle, ensure_ascii=False, indent=2)
    return (vocab_file,)
516
+
517
def save_pretrained(self, save_directory: Union[str, os.PathLike], filename_prefix: Optional[str] = None, **kwargs):
    """Persist vocab plus a tokenizer_config.json describing this tokenizer.

    Returns a 1-tuple with the save directory (HF convention).
    """
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    self.save_vocabulary(save_directory, filename_prefix)

    config = {
        "tokenizer_class": self.__class__.__name__,
        "model_max_length": self.model_max_length,
        "padding_side": self.padding_side,
        "truncation_side": self.truncation_side,
        "special_tokens": {
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
            "unk_token": self.unk_token,
            "mask_token": self.mask_token,
        },
    }
    config_file = os.path.join(save_directory, "tokenizer_config.json")
    with open(config_file, "w", encoding="utf-8") as handle:
        json.dump(config, handle, ensure_ascii=False, indent=2)
    return (save_directory,)
536
+
537
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    """Load a tokenizer from a local directory written by ``save_pretrained``.

    Reads ``vocab.json`` plus an optional ``tokenizer_config.json``;
    explicit ``kwargs`` override values stored in the config.
    Loading from the HF Hub is not implemented.
    """
    if not os.path.isdir(pretrained_model_name_or_path):
        raise NotImplementedError("Loading from Hub not implemented yet")

    vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
    config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
    config = {}
    if os.path.exists(config_file):
        with open(config_file, "r", encoding="utf-8") as handle:
            config = json.load(handle)
    merged = {**config, **kwargs}
    return cls(vocab_file=vocab_file, **merged)
548
+
549
+
550
+ # ------------------------------
551
+ # SELFIES variant
552
+ # ------------------------------
553
class FastChemTokenizerSelfies(FastChemTokenizer):
    """SELFIES variant that handles whitespace-separated tokens.

    Fix: this docstring previously sat AFTER ``__init__`` as a dead
    string-expression statement, so the class had no ``__doc__``; it is
    now in proper class-docstring position.
    """

    def __init__(self, *args, **kwargs):
        # Delegate fully to the base class so BOS/EOS etc. are bound.
        super().__init__(*args, **kwargs)

    def _encode_core(self, text: str) -> List[int]:
        """Greedy longest-match tokenization that skips inter-token whitespace."""
        result_ids, i, n = [], 0, len(text)
        while i < n:
            if text[i].isspace():
                i += 1
                continue
            node, j = self.trie_root, i
            last_match_id, last_match_end = None, i
            while j < n and text[j] in node.children:
                node = node.children[text[j]]
                j += 1
                if node.token_id is not None:
                    last_match_id, last_match_end = node.token_id, j
            if last_match_id is not None:
                result_ids.append(last_match_id)
                i = last_match_end
            else:
                result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
                i += 1
        return result_ids

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens with single spaces (SELFIES tokens are space-separated)."""
        return " ".join(tokens)

    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
        """Decode ids to a space-separated SELFIES string."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        special_ids = (
            {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id}
            if skip_special_tokens else set()
        )
        tokens = [self.id_to_token.get(tid, self.unk_token)
                  for tid in token_ids if tid not in special_ids]
        return " ".join(tokens)