gbyuvd committed on
Commit
02f786b
·
verified ·
1 Parent(s): 4f7732b

Upload FastChemTokenizerHF.py

Browse files
Files changed (1) hide show
  1. FastChemTokenizerHF.py +41 -0
FastChemTokenizerHF.py CHANGED
@@ -463,6 +463,47 @@ class FastChemTokenizer(PreTrainedTokenizerBase):
463
  token = self.id_to_token.get(tid, self.unk_token)
464
  tid_str = "None" if tid is None else f"{tid:5d}"
465
  print(f" [{i:03d}] ID={tid_str} → '{token}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  # ------------------------------
468
  # Save / Load
 
463
  token = self.id_to_token.get(tid, self.unk_token)
464
  tid_str = "None" if tid is None else f"{tid:5d}"
465
  print(f" [{i:03d}] ID={tid_str} → '{token}'")
466
+
467
def pad(
    self,
    encoded_inputs,
    padding=True,
    max_length=None,
    pad_to_multiple_of=None,
    return_tensors=None,
    **kwargs,
):
    """
    HuggingFace-style pad. Takes a list/dict of encoded inputs and pads them.

    Sequences are never truncated: the pad target is always at least the
    length of the longest sequence, so the padded batch is rectangular.

    Args:
        encoded_inputs: One encoded example (a mapping with ``"input_ids"``
            and optionally ``"attention_mask"``), a list of such mappings,
            or a single mapping whose ``"input_ids"`` is already a list of
            sequences (a batched encoding).
        padding: ``True`` / ``"longest"`` pads to the longest sequence in
            the batch; ``"max_length"`` pads to ``max_length`` (or to the
            longest sequence if that is longer, or if ``max_length`` is
            ``None``); ``False`` / ``"do_not_pad"`` returns the sequences
            unpadded.
        max_length: Pad target used when ``padding == "max_length"``.
        pad_to_multiple_of: If truthy, the pad target is rounded up to the
            next multiple of this value.
        return_tensors: ``"pt"`` or ``"torch"`` converts both outputs to
            ``torch.long`` tensors; any other value returns nested lists.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A dict with ``"input_ids"`` and ``"attention_mask"``, one row per
        input example. Missing attention masks default to all ones; pad
        positions get ``self.pad_token_id`` and mask value ``0``, placed on
        the side given by ``self.padding_side``.
    """
    from collections.abc import Mapping

    # Normalise every accepted input shape to a list of per-example mappings.
    # Mapping (not dict) so dict-like containers are recognised as well.
    if isinstance(encoded_inputs, Mapping):
        ids_field = encoded_inputs["input_ids"]
        if len(ids_field) > 0 and isinstance(ids_field[0], (list, tuple)):
            # Already batched: split the columns back into per-example rows.
            mask_field = encoded_inputs.get("attention_mask")
            if mask_field is None:
                mask_field = [None] * len(ids_field)
            encoded_inputs = [
                {"input_ids": ids, "attention_mask": mask}
                for ids, mask in zip(ids_field, mask_field)
            ]
        else:
            encoded_inputs = [encoded_inputs]

    # Copy into fresh lists so tuples work and caller data is never mutated.
    input_ids = [list(ei["input_ids"]) for ei in encoded_inputs]
    attn_masks = []
    for ei, ids in zip(encoded_inputs, input_ids):
        mask = ei.get("attention_mask")
        attn_masks.append([1] * len(ids) if mask is None else list(mask))

    # Determine the pad target; None means "leave every sequence as-is".
    skip_padding = padding is False or padding == "do_not_pad"
    if skip_padding or not input_ids:
        target_len = None
    else:
        target_len = max(len(ids) for ids in input_ids)
        if padding == "max_length" and max_length is not None:
            # Never go below the longest sequence: that would leave the
            # batch ragged because nothing here truncates.
            target_len = max(target_len, max_length)
        if pad_to_multiple_of:
            target_len = (
                (target_len + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of

    pad_on_right = self.padding_side == "right"
    padded_ids, padded_masks = [], []
    for ids, mask in zip(input_ids, attn_masks):
        pad_len = 0 if target_len is None else target_len - len(ids)
        id_fill = [self.pad_token_id] * pad_len
        mask_fill = [0] * pad_len
        if pad_on_right:
            padded_ids.append(ids + id_fill)
            padded_masks.append(mask + mask_fill)
        else:
            padded_ids.append(id_fill + ids)
            padded_masks.append(mask_fill + mask)

    out = {"input_ids": padded_ids, "attention_mask": padded_masks}
    if return_tensors in ["pt", "torch"]:
        out = {k: torch.tensor(v, dtype=torch.long) for k, v in out.items()}
    return out
506
+
507
 
508
  # ------------------------------
509
  # Save / Load