File size: 6,775 Bytes
6425080 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
"""Hugging Face compatible SMILES tokenizer wrapper."""
import os
import json
import torch
from transformers import PreTrainedTokenizer
from .smiles_tokenizer import SmilesTokenizer, SmilesVocabulary
class HFSmilesTokenizer(PreTrainedTokenizer):
    """
    Wrapper class for the SmilesTokenizer to make it compatible with the
    Hugging Face tokenizer interface.

    This allows the tokenizer to be used with any Hugging Face model,
    especially GPT-2. SMILES-specific helpers (``encode_smiles``,
    ``decode_smiles``, ``tokens_to_smiles``) delegate straight to the
    wrapped :class:`SmilesTokenizer`.
    """

    # Keys that models consuming this tokenizer's output expect.
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab=None,
        pad_token="<pad>",
        eos_token="</s>",
        unk_token="<unk>",
        bos_token="<go>",
        **kwargs
    ):
        """
        Args:
            vocab: Optional list of extra vocabulary symbols added on top of
                the default SMILES vocabulary. If None, the stock
                SmilesTokenizer vocabulary is used as-is.
            pad_token: Padding token string (shared with the SMILES vocab).
            eos_token: End-of-sequence token string.
            unk_token: Unknown-token string.
            bos_token: Beginning-of-sequence token string (``go`` symbol in
                the SMILES vocabulary).
            **kwargs: Forwarded to :class:`PreTrainedTokenizer`.
        """
        # Initialize the base SMILES tokenizer first.
        if vocab is None:
            self.smiles_tokenizer = SmilesTokenizer()
        else:
            vocabulary = SmilesVocabulary(
                pad=pad_token,
                eos=eos_token,
                unk=unk_token,
                go=bos_token
            )
            # Add custom vocab symbols if provided as a list of tokens.
            if isinstance(vocab, list):
                for token in vocab:
                    vocabulary.add_symbol(token)
            self.smiles_tokenizer = SmilesTokenizer(vocabulary=vocabulary)

        # Build the token<->id maps BEFORE calling super().__init__:
        # recent transformers versions call get_vocab() during base-class
        # initialization, which reads self._vocab.
        self._vocab = {
            token: idx for idx, token in enumerate(self.smiles_tokenizer.vocabulary.symbols)
        }
        self._ids_to_tokens = {
            idx: token for token, idx in self._vocab.items()
        }

        # Initialize the PreTrainedTokenizer with our special tokens.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            eos_token=eos_token,
            bos_token=bos_token,
            **kwargs
        )

    @property
    def vocab_size(self):
        """Return the size of the vocabulary."""
        return len(self._vocab)

    def get_vocab(self):
        """Return the full token -> id dictionary."""
        return self._vocab

    def _tokenize(self, text):
        """
        Tokenize a SMILES string into a list of tokens.

        Accepts either a single string or (defensively) a list of strings;
        in both cases only the first tokenized sequence is returned, matching
        the one-string-in contract of PreTrainedTokenizer._tokenize.
        """
        if isinstance(text, list):
            return self.smiles_tokenizer.tokenize(text, enclose=False)[0]
        return self.smiles_tokenizer.tokenize([text], enclose=False)[0]

    def _convert_token_to_id(self, token):
        """Convert a token string to its integer ID."""
        return self.smiles_tokenizer.vocabulary.index(token)

    def _convert_id_to_token(self, index):
        """Convert an integer ID back to its token string."""
        return self.smiles_tokenizer.vocabulary[index]

    def convert_tokens_to_string(self, tokens):
        """Join a list of tokens into a SMILES string (no separator)."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence by adding special tokens.

        Returns:
            List[int]: List of input IDs with special tokens added.
        """
        bos_token_id = self.bos_token_id
        eos_token_id = self.eos_token_id
        if token_ids_1 is None:
            return [bos_token_id] + token_ids_0 + [eos_token_id]
        # For sequence pairs, we follow GPT-2 format: <bos> seq1 <eos> seq2 <eos>
        return [bos_token_id] + token_ids_0 + [eos_token_id] + token_ids_1 + [eos_token_id]

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieve sequence of special tokens mask.

        Returns:
            List[int]: A list of integers where 1 indicates a special token
            and 0 indicates a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        # Mask mirrors build_inputs_with_special_tokens: <bos> ... <eos> [... <eos>]
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create token type IDs for sequence pairs.

        GPT-2 style models use a single segment, so every position gets
        type 0; the length matches build_inputs_with_special_tokens output.

        Returns:
            List[int]: List of token type IDs (all zeros).
        """
        if token_ids_1 is None:
            # BUG FIX: original computed len(token_ids_0 + 2), which raised
            # TypeError (list + int). +2 accounts for <bos> and <eos>.
            return [0] * (len(token_ids_0) + 2)
        # +3 for <bos> and the two <eos> tokens in the pair layout.
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the tokenizer vocabulary to a directory as a JSON file.

        Returns:
            tuple: One-element tuple with the path of the written vocab file.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (vocab_file,)

    def encode_smiles(self, smiles, enclose=True, return_tensors=None):
        """
        Encode a list of SMILES strings using the original SmilesTokenizer
        functionality.

        Args:
            smiles: A list of SMILES strings or a single SMILES string.
            enclose: Whether to add special tokens.
            return_tensors: The type of tensors to return ('pt' for PyTorch,
                None for plain lists).

        Returns:
            List of token-ID lists, or a list of PyTorch LongTensors when
            return_tensors == 'pt'.
        """
        ids_list = self.smiles_tokenizer.encode(smiles, enclose=enclose, aslist=True)
        if return_tensors == "pt":
            return [torch.tensor(ids, dtype=torch.long) for ids in ids_list]
        return ids_list

    def decode_smiles(self, ids_list):
        """
        Decode a list of token IDs back to SMILES strings using the original
        SmilesTokenizer functionality.

        Args:
            ids_list: A list of lists or tensors containing token IDs.

        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.decode(ids_list)

    def tokens_to_smiles(self, tokens):
        """
        Convert generated tokens to SMILES strings.

        Args:
            tokens: List of token IDs.

        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.tokens_to_smiles(tokens)
|