File size: 6,775 Bytes
6425080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""Hugging Face compatible SMILES tokenizer wrapper."""

import os
import json
import torch
from transformers import PreTrainedTokenizer

from .smiles_tokenizer import SmilesTokenizer, SmilesVocabulary

class HFSmilesTokenizer(PreTrainedTokenizer):
    """
    Wrapper class for the SmilesTokenizer to make it compatible with the Hugging Face tokenizer interface.
    This allows the tokenizer to be used with any Hugging Face model, especially GPT-2.
    """
    
    # Required for Hugging Face tokenizers
    model_input_names = ["input_ids", "attention_mask"]
    
    def __init__(
        self,
        vocab=None,
        pad_token="<pad>",
        eos_token="</s>",
        unk_token="<unk>",
        bos_token="<go>",
        **kwargs
    ):
        """
        Build the wrapped SmilesTokenizer and register the special tokens.
        
        Args:
            vocab: Optional list of extra vocabulary symbols. When None, the
                default SmilesTokenizer vocabulary is used unchanged; when a
                list, each entry is added to a fresh SmilesVocabulary built
                with the special tokens below.
            pad_token, eos_token, unk_token, bos_token: Special-token strings
                forwarded both to the SmilesVocabulary and to the Hugging Face
                base class.
            **kwargs: Passed through to PreTrainedTokenizer.__init__.
        """
        # Initialize the base tokenizer
        if vocab is None:
            self.smiles_tokenizer = SmilesTokenizer()
        else:
            vocabulary = SmilesVocabulary(
                pad=pad_token, 
                eos=eos_token, 
                unk=unk_token, 
                go=bos_token
            )
            # Add custom vocab symbols if provided
            if isinstance(vocab, list):
                for token in vocab:
                    vocabulary.add_symbol(token)
            self.smiles_tokenizer = SmilesTokenizer(vocabulary=vocabulary)
        
        # Set up the vocabulary BEFORE calling super().__init__ — the HF base
        # class calls back into vocab_size/get_vocab during initialization.
        self._vocab = {
            token: idx for idx, token in enumerate(self.smiles_tokenizer.vocabulary.symbols)
        }
        self._ids_to_tokens = {
            idx: token for token, idx in self._vocab.items()
        }
        
        # Initialize the PreTrainedTokenizer with our special tokens
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            eos_token=eos_token,
            bos_token=bos_token,
            **kwargs
        )
    
    @property
    def vocab_size(self):
        """Return the size of vocabulary."""
        return len(self._vocab)
    
    def get_vocab(self):
        """Return the vocabulary dictionary (a copy, so callers cannot mutate internal state)."""
        # Returning a copy matches the HF convention and protects the
        # token<->id maps built in __init__ from accidental modification.
        return dict(self._vocab)
    
    def _tokenize(self, text):
        """
        Tokenize a string into a list of tokens.
        
        Accepts either a single string or a list; in both cases only the
        first tokenized sequence is returned, matching the HF per-text
        tokenization contract.
        """
        if isinstance(text, list):
            return self.smiles_tokenizer.tokenize(text, enclose=False)[0]
        return self.smiles_tokenizer.tokenize([text], enclose=False)[0]
    
    def _convert_token_to_id(self, token):
        """
        Convert a token to its ID using the wrapped vocabulary.
        """
        return self.smiles_tokenizer.vocabulary.index(token)
    
    def _convert_id_to_token(self, index):
        """
        Convert an ID to its token using the wrapped vocabulary.
        """
        return self.smiles_tokenizer.vocabulary[index]
    
    def convert_tokens_to_string(self, tokens):
        """
        Convert a list of tokens to a string (SMILES tokens concatenate with no separator).
        """
        return "".join(tokens)
    
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence by adding special tokens.
        
        Single sequence: <bos> seq <eos>.
        Pair: <bos> seq1 <eos> seq2 <eos> (GPT-2 style).
        
        Returns:
            List[int]: List of input IDs with special tokens added.
        """
        bos_token_id = self.bos_token_id
        eos_token_id = self.eos_token_id
        
        if token_ids_1 is None:
            return [bos_token_id] + token_ids_0 + [eos_token_id]
        
        # For sequence pairs, we follow GPT-2 format: <bos> seq1 <eos> seq2 <eos>
        return [bos_token_id] + token_ids_0 + [eos_token_id] + token_ids_1 + [eos_token_id]
    
    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieve sequence of special tokens mask.
        
        Returns:
            List[int]: A list of integers where 1 indicates a special token and 0 indicates a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
    
    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Create token type IDs for sequence pairs (all zeros, GPT-2 style).
        
        Returns:
            List[int]: List of token type IDs, one per token including the
            special tokens added by build_inputs_with_special_tokens.
        """
        if token_ids_1 is None:
            # BUGFIX: was `len(token_ids_0 + 2)`, which raises TypeError
            # (list + int). The length must account for <bos> and <eos>.
            return [0] * (len(token_ids_0) + 2)  # +2 for <bos> and <eos>
        
        # For GPT-2, we use all 0s for token type IDs
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)  # +3 for <bos> and two <eos>
    
    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the tokenizer vocabulary to a directory as JSON.
        
        Returns:
            Tuple[str]: One-element tuple with the written vocab file path.
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        
        vocab_file = os.path.join(
            save_directory, 
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        
        return (vocab_file,)
    
    def encode_smiles(self, smiles, enclose=True, return_tensors=None):
        """
        Encode a list of SMILES strings using the original SmilesTokenizer functionality.
        
        Args:
            smiles: A list of SMILES strings or a single SMILES string.
            enclose: Whether to add special tokens.
            return_tensors: The type of tensors to return ('pt' for PyTorch, None for lists).
            
        Returns:
            List of token IDs or PyTorch tensors.
        """
        ids_list = self.smiles_tokenizer.encode(smiles, enclose=enclose, aslist=True)
        
        if return_tensors == "pt":
            return [torch.tensor(ids, dtype=torch.long) for ids in ids_list]
        
        return ids_list
    
    def decode_smiles(self, ids_list):
        """
        Decode a list of token IDs back to SMILES strings using the original SmilesTokenizer functionality.
        
        Args:
            ids_list: A list of lists or tensors containing token IDs.
            
        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.decode(ids_list)
    
    def tokens_to_smiles(self, tokens):
        """
        Convert generated tokens to SMILES strings.
        
        Args:
            tokens: List of token IDs.
            
        Returns:
            List of SMILES strings.
        """
        return self.smiles_tokenizer.tokens_to_smiles(tokens)