"""Hugging Face compatible SMILES tokenizer wrapper.""" import os import json import torch from transformers import PreTrainedTokenizer from .smiles_tokenizer import SmilesTokenizer, SmilesVocabulary class HFSmilesTokenizer(PreTrainedTokenizer): """ Wrapper class for the SmilesTokenizer to make it compatible with the Hugging Face tokenizer interface. This allows the tokenizer to be used with any Hugging Face model, especially GPT-2. """ # Required for Hugging Face tokenizers model_input_names = ["input_ids", "attention_mask"] def __init__( self, vocab=None, pad_token="", eos_token="", unk_token="", bos_token="", **kwargs ): # Initialize the base tokenizer if vocab is None: self.smiles_tokenizer = SmilesTokenizer() else: vocabulary = SmilesVocabulary( pad=pad_token, eos=eos_token, unk=unk_token, go=bos_token ) # Add custom vocab symbols if provided if isinstance(vocab, list): for token in vocab: vocabulary.add_symbol(token) self.smiles_tokenizer = SmilesTokenizer(vocabulary=vocabulary) # Set up the vocabulary BEFORE calling super().__init__ self._vocab = { token: idx for idx, token in enumerate(self.smiles_tokenizer.vocabulary.symbols) } self._ids_to_tokens = { idx: token for token, idx in self._vocab.items() } # Initialize the PreTrainedTokenizer with our special tokens super().__init__( unk_token=unk_token, pad_token=pad_token, eos_token=eos_token, bos_token=bos_token, **kwargs ) @property def vocab_size(self): """Return the size of vocabulary.""" return len(self._vocab) def get_vocab(self): """Return the vocabulary dictionary.""" return self._vocab def _tokenize(self, text): """ Tokenize a string into a list of tokens. """ if isinstance(text, list): return self.smiles_tokenizer.tokenize(text, enclose=False)[0] return self.smiles_tokenizer.tokenize([text], enclose=False)[0] def _convert_token_to_id(self, token): """ Convert a token to its ID. """ return self.smiles_tokenizer.vocabulary.index(token) def _convert_id_to_token(self, index): """ Convert an ID to its token. """ return self.smiles_tokenizer.vocabulary[index] def convert_tokens_to_string(self, tokens): """ Convert a list of tokens to a string. """ return "".join(tokens) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): """ Build model inputs from a sequence by adding special tokens. Returns: List[int]: List of input IDs with special tokens added. """ bos_token_id = self.bos_token_id eos_token_id = self.eos_token_id if token_ids_1 is None: return [bos_token_id] + token_ids_0 + [eos_token_id] # For sequence pairs, we follow GPT-2 format: seq1 seq2 return [bos_token_id] + token_ids_0 + [eos_token_id] + token_ids_1 + [eos_token_id] def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): """ Retrieve sequence of special tokens mask. Returns: List[int]: A list of integers where 1 indicates a special token and 0 indicates a sequence token. """ if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True ) if token_ids_1 is None: return [1] + [0] * len(token_ids_0) + [1] return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1] def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): """ Create token type IDs for sequence pairs. Returns: List[int]: List of token type IDs. """ if token_ids_1 is None: return [0] * len(token_ids_0 + 2) # +2 for and # For GPT-2, we use all 0s for token type IDs return [0] * (len(token_ids_0) + len(token_ids_1) + 3) # +3 for and two def save_vocabulary(self, save_directory, filename_prefix=None): """ Save the tokenizer vocabulary to a directory. """ if not os.path.isdir(save_directory): os.makedirs(save_directory, exist_ok=True) vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json" ) with open(vocab_file, "w", encoding="utf-8") as f: json.dump(self._vocab, f, ensure_ascii=False, indent=2) return (vocab_file,) def encode_smiles(self, smiles, enclose=True, return_tensors=None): """ Encode a list of SMILES strings using the original SmilesTokenizer functionality. Args: smiles: A list of SMILES strings or a single SMILES string. enclose: Whether to add special tokens. return_tensors: The type of tensors to return ('pt' for PyTorch, None for lists). Returns: List of token IDs or PyTorch tensors. """ ids_list = self.smiles_tokenizer.encode(smiles, enclose=enclose, aslist=True) if return_tensors == "pt": return [torch.tensor(ids, dtype=torch.long) for ids in ids_list] return ids_list def decode_smiles(self, ids_list): """ Decode a list of token IDs back to SMILES strings using the original SmilesTokenizer functionality. Args: ids_list: A list of lists or tensors containing token IDs. Returns: List of SMILES strings. """ return self.smiles_tokenizer.decode(ids_list) def tokens_to_smiles(self, tokens): """ Convert generated tokens to SMILES strings. Args: tokens: List of token IDs. Returns: List of SMILES strings. """ return self.smiles_tokenizer.tokens_to_smiles(tokens)