"""Hugging Face compatible SMILES tokenizer wrapper."""

import json
import os

import torch
from transformers import PreTrainedTokenizer

from .smiles_tokenizer import SmilesTokenizer, SmilesVocabulary


class HFSmilesTokenizer(PreTrainedTokenizer):
    """
    Wrapper class that makes SmilesTokenizer compatible with the Hugging Face
    tokenizer interface. This allows the tokenizer to be used with any
    Hugging Face model, especially GPT-2.
    """

    model_input_names = ["input_ids", "attention_mask"]
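
    # Minimal usage sketch (the token list below is hypothetical, for
    # illustration only; use the vocabulary from your own training data):
    #
    #     tokenizer = HFSmilesTokenizer(vocab=["C", "O", "N", "(", ")", "=", "1"])
    #     enc = tokenizer("CCO", return_tensors="pt")
    #     enc["input_ids"]       # <go> C C O </s> as token IDs
    #     enc["attention_mask"]  # all ones for an unpadded sequence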

    def __init__(
        self,
        vocab=None,
        pad_token="<pad>",
        eos_token="</s>",
        unk_token="<unk>",
        bos_token="<go>",
        **kwargs
    ):
        # Use the tokenizer's default vocabulary when none is given; otherwise
        # build a SmilesVocabulary from the provided token list.
        if vocab is None:
            self.smiles_tokenizer = SmilesTokenizer()
        else:
            vocabulary = SmilesVocabulary(
                pad=pad_token,
                eos=eos_token,
                unk=unk_token,
                go=bos_token,
            )

            if isinstance(vocab, list):
                for token in vocab:
                    vocabulary.add_symbol(token)
            self.smiles_tokenizer = SmilesTokenizer(vocabulary=vocabulary)

        # Plain dict views of the vocabulary. These are built before
        # super().__init__() so that get_vocab() already works while the
        # base class registers the special tokens.
        self._vocab = {
            token: idx for idx, token in enumerate(self.smiles_tokenizer.vocabulary.symbols)
        }
        self._ids_to_tokens = {idx: token for token, idx in self._vocab.items()}

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            eos_token=eos_token,
            bos_token=bos_token,
            **kwargs
        )

    @property
    def vocab_size(self):
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self):
        """Return the vocabulary as a token-to-ID dictionary."""
        return self._vocab

    def _tokenize(self, text):
        """Tokenize a SMILES string into a list of tokens."""
        # SmilesTokenizer.tokenize operates on a list of strings and returns
        # a list of token lists; wrap bare strings and take the first result.
        if isinstance(text, list):
            return self.smiles_tokenizer.tokenize(text, enclose=False)[0]
        return self.smiles_tokenizer.tokenize([text], enclose=False)[0]

    def _convert_token_to_id(self, token):
        """Convert a token to its integer ID."""
        return self.smiles_tokenizer.vocabulary.index(token)

    def _convert_id_to_token(self, index):
        """Convert an integer ID back to its token."""
        return self.smiles_tokenizer.vocabulary[index]

    def convert_tokens_to_string(self, tokens):
        """Concatenate a list of tokens back into a single SMILES string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence, or a pair of sequences, by adding
        special tokens: ``<go> A </s>`` for a single sequence and
        ``<go> A </s> B </s>`` for a pair.

        Returns:
            List[int]: List of input IDs with special tokens added.
        """
        bos_token_id = self.bos_token_id
        eos_token_id = self.eos_token_id

        if token_ids_1 is None:
            return [bos_token_id] + token_ids_0 + [eos_token_id]

        return [bos_token_id] + token_ids_0 + [eos_token_id] + token_ids_1 + [eos_token_id]
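
    # For illustration: if bos_token_id were 2 and eos_token_id 3 (the actual
    # values depend on the vocabulary), token_ids_0=[7, 8] would become
    # [2, 7, 8, 3], and the pair ([7, 8], [9]) would become [2, 7, 8, 3, 9, 3].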

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieve the special-tokens mask for a sequence or sequence pair.

        Returns:
            List[int]: A list of integers where 1 indicates a special token
            and 0 indicates a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]

        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
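
    # For illustration: token_ids_0=[7, 8] yields the mask [1, 0, 0, 1],
    # marking the added <go> and </s> positions as special.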

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create token type IDs for a sequence or sequence pair. Every position
        is assigned type 0, matching a single-segment layout.

        Returns:
            List[int]: List of token type IDs (all zeros).
        """
        if token_ids_1 is None:
            # +2 accounts for the added <go> and </s> tokens.
            return [0] * (len(token_ids_0) + 2)

        # +3 accounts for <go> and the two </s> separators.
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the tokenizer vocabulary to a directory as a JSON file.
        """
        os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    def encode_smiles(self, smiles, enclose=True, return_tensors=None):
        """
        Encode SMILES strings using the underlying SmilesTokenizer directly.

        Args:
            smiles: A list of SMILES strings or a single SMILES string.
            enclose: Whether to add the special ``<go>``/``</s>`` tokens.
            return_tensors: The type of tensors to return ("pt" for PyTorch,
                None for plain Python lists).

        Returns:
            A list of token-ID lists, or a list of PyTorch tensors.
        """
        ids_list = self.smiles_tokenizer.encode(smiles, enclose=enclose, aslist=True)

        if return_tensors == "pt":
            return [torch.tensor(ids, dtype=torch.long) for ids in ids_list]

        return ids_list
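
    # Usage sketch (assumes the vocabulary covers the tokens in the input):
    #
    #     ids = tokenizer.encode_smiles(["CCO", "c1ccccc1"], return_tensors="pt")
    #     ids[0]  # one 1-D LongTensor per molecule; lengths differ, so no batching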

    def decode_smiles(self, ids_list):
        """
        Decode token IDs back to SMILES strings using the underlying
        SmilesTokenizer directly.

        Args:
            ids_list: A list of lists or tensors containing token IDs.

        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.decode(ids_list)

    def tokens_to_smiles(self, tokens):
        """
        Convert generated tokens to SMILES strings.

        Args:
            tokens: List of token IDs.

        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.tokens_to_smiles(tokens)
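

# Round-trip sketch (a minimal example, not part of the class; assumes the
# default SmilesTokenizer vocabulary covers these molecules):
#
#     tokenizer = HFSmilesTokenizer()
#     ids = tokenizer.encode_smiles(["CCO", "c1ccccc1"], enclose=True)
#     tokenizer.decode_smiles(ids)  # expected to return the input SMILES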