suku9 committed on
Commit
4b0a05b
·
verified ·
1 Parent(s): 3f0172f

Upload SMILES tokenizer package

Browse files
README.md CHANGED
@@ -1,10 +1,43 @@
1
- ---
2
- title: Smiles Tokenizer Package
3
- emoji:
4
- colorFrom: indigo
5
- colorTo: gray
6
- sdk: static
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SMILES Tokenizer
2
+
3
+ This is a custom tokenizer for SMILES (Simplified Molecular Input Line Entry System) strings.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install git+https://huggingface.co/suku9/smiles-tokenizer-package
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ # Basic usage
15
+ from smiles_tokenizer import SmilesTokenizer
16
+
17
+ tokenizer = SmilesTokenizer()
18
+ smiles = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin
19
+
20
+ # Tokenize
21
+ tokens = tokenizer.tokenize([smiles])[0]
22
+ print(tokens)
23
+
24
+ # Encode
25
+ encoded = tokenizer.encode([smiles])[0]
26
+ print(encoded)
27
+
28
+ # Use with GPT-2
29
+ from smiles_tokenizer.utils import prepare_for_gpt2
30
+
31
+ model, tokenizer_wrapper = prepare_for_gpt2(tokenizer)
32
+
33
+ # Now you can use it like a regular Hugging Face tokenizer
34
+ inputs = tokenizer_wrapper(smiles, return_tensors="pt")
35
+ outputs = model(**inputs)
36
+ ```
37
+
38
+ ## Features
39
+
40
+ - Specialized for SMILES strings
41
+ - Compatible with Hugging Face's transformers library
42
+ - Designed to work with GPT-2 models
43
+ - Preserves all functionality of the original SMILES tokenizer
setup.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Setup script for SMILES tokenizer package."""
2
+
3
+ from setuptools import setup, find_packages
4
+
5
# Runtime dependencies: torch backs the tensor outputs of the tokenizer and
# transformers provides the GPT-2 model / tokenizer base class used in utils.
_REQUIREMENTS = [
    "torch>=1.0.0",
    "transformers>=4.0.0",
]

# Package metadata collected in one place and handed to setuptools below.
_METADATA = {
    "name": "smiles_tokenizer",
    "version": "0.1.0",
    "description": "SMILES tokenizer from suku9/smiles-tokenizer-package",
    "packages": find_packages(),
    "install_requires": _REQUIREMENTS,
}

setup(**_METADATA)
smiles_tokenizer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """SMILES Tokenizer package."""
2
+
3
+ from .tokenizer import SmilesTokenizer
4
+ from .vocabulary import SmilesVocabulary, Vocabulary
5
+
6
+ __all__ = ["SmilesTokenizer", "SmilesVocabulary", "Vocabulary"]
smiles_tokenizer/tokenizer.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SMILES tokenizer implementation."""
2
+
3
+ import re
4
+ import json
5
+ import warnings
6
+ from re import Pattern
7
+ from typing import Dict, List, Optional, Union, Any
8
+ import torch
9
+
10
+ from .vocabulary import Vocabulary, SmilesVocabulary
11
+
12
+ Tokens = List[str]
13
+
14
class SmilesTokenizer:
    """Regex-based tokenizer that splits SMILES strings into vocabulary symbols.

    The tokenizer builds one regular expression out of the symbols of its
    vocabulary and uses it to cut SMILES strings into tokens. Characters that
    do not match any vocabulary symbol are skipped by the regex search, so
    they never show up in the token lists.
    """

    def __init__(
        self,
        vocabulary: Optional["Vocabulary"] = None,
    ) -> None:
        """Create a tokenizer.

        :param vocabulary: Vocabulary providing ``symbols``, ``index`` and the
            ``*_word`` control tokens; a fresh ``SmilesVocabulary`` is created
            when omitted.
        """
        if vocabulary is None:
            self.vocabulary = SmilesVocabulary()
        else:
            self.vocabulary = vocabulary
        # Compiled lazily on first access of the ``re`` property.
        self._re: Optional[Pattern] = None

    @property
    def re(self) -> Pattern:
        """Tokens Regex Object.

        The pattern is compiled from ``self.vocabulary.symbols`` on first use
        and cached afterwards; symbols added to the vocabulary later are not
        picked up by an already cached pattern.

        :return: Tokens Regex Object
        """
        if self._re is None:
            self._re = self._get_compiled_regex(self.vocabulary.symbols)
        return self._re

    def tokenize(
        self, smiles: Union[str, List[str]], enclose: bool = True
    ) -> List[List[str]]:
        """Convert SMILES strings into lists of tokens.

        :param smiles: A single SMILES string or a list of SMILES strings
        :param enclose: If True wrap each token list in the vocabulary's
            go / end-of-sequence tokens
        :return: One token list per input string
        """
        if isinstance(smiles, str):
            # A bare string is treated as a batch of one.
            smiles = [smiles]

        pattern = self.re
        tokenized_data = []
        for smi in smiles:
            tokens = pattern.findall(smi)
            if enclose:
                tokens = [self.vocabulary.go_word, *tokens, self.vocabulary.eos_word]
            tokenized_data.append(tokens)
        return tokenized_data

    def encode(
        self, smiles: Union[str, List[str]], enclose: bool = True, aslist=False
    ):
        """Convert SMILES strings into sequences of token indices.

        :param smiles: A single SMILES string or a list of SMILES strings
        :param enclose: Forwarded to :meth:`tokenize`
        :param aslist: If True return plain lists of ints, otherwise
            ``torch.long`` tensors
        :return: One index sequence per input string
        """
        ids_list = []
        for tokens in self.tokenize(smiles, enclose=enclose):
            ids = [self.vocabulary.index(token) for token in tokens]
            ids_list.append(ids if aslist else torch.tensor(ids, dtype=torch.long))
        return ids_list

    def detokenize(
        self,
        token_data: List[List[str]],
        include_control_tokens: bool = False,
        include_end_of_line_token: bool = False,
        truncate_at_end_token: bool = False,
    ) -> List[str]:
        """Detokenize lists of tokens into SMILES by concatenating the token strings.

        :param token_data: Token lists to join; they are not modified
        :param include_control_tokens: Keep leading/trailing go and
            end-of-sequence tokens (padding is always stripped)
        :param include_end_of_line_token: Append ``"\\n"`` to every string
        :param truncate_at_end_token: Drop everything after the first
            end-of-sequence token
        :return: One string per token list
        """
        suffix = "\n" if include_end_of_line_token else ""
        strings = []
        for tokens in token_data:
            kept = self._strip_list(
                tokens,
                strip_control_tokens=not include_control_tokens,
                truncate_at_end_token=truncate_at_end_token,
            )
            strings.append("".join(kept) + suffix)
        return strings

    def decode(self, ids_list: List[torch.Tensor]):
        """Decode index sequences back into SMILES strings.

        :param ids_list: Sequences of token indices; each entry is either a
            list or an object exposing ``tolist()`` (e.g. a tensor)
        :return: List of SMILES strings, truncated at the end-of-sequence token
        """
        tokenized_smiles = []
        for ids in ids_list:
            if not isinstance(ids, list):
                ids = ids.tolist()
            tokenized_smiles.append([self.vocabulary[i] for i in ids])
        return self.detokenize(tokenized_smiles, truncate_at_end_token=True)

    def tokens_to_smiles(self, tokens):
        """Convert generated token indices to SMILES strings.

        Works like :meth:`decode` and additionally removes every occurrence
        of the vocabulary's unknown-token string from the result.

        :param tokens: Sequences of token indices, as accepted by :meth:`decode`
        :return: List of SMILES strings
        """
        # Use the vocabulary's own unknown token rather than a hard-coded
        # "<unk>", so vocabularies built with a custom ``unk`` are handled too.
        unk_word = self.vocabulary.unk_word
        return [smi.replace(unk_word, "") for smi in self.decode(tokens)]

    def _strip_list(
        self,
        tokens: List[str],
        strip_control_tokens: bool = False,
        truncate_at_end_token: bool = False,
    ) -> List[str]:
        """Cleanup tokens list from control tokens.

        Padding tokens are always stripped from both ends. The input list is
        left untouched; a new list is returned.

        :param tokens: List of tokens
        :param strip_control_tokens: Flag to remove control tokens, defaults to False
        :param truncate_at_end_token: If True truncate tokens after end-token
        :return: The cleaned-up token list
        """
        eos_word = self.vocabulary.eos_word
        if truncate_at_end_token and eos_word in tokens:
            # Keep the end token itself; it is stripped below if requested.
            tokens = tokens[: tokens.index(eos_word) + 1]

        strip_characters = {self.vocabulary.pad_word}
        if strip_control_tokens:
            strip_characters.update((self.vocabulary.go_word, eos_word))

        # Walk inwards from both ends instead of popping, which avoids both
        # the O(n) cost of pop(0) and mutating the caller's list.
        start, end = 0, len(tokens)
        while start < end and tokens[start] in strip_characters:
            start += 1
        while end > start and tokens[end - 1] in strip_characters:
            end -= 1
        return list(tokens[start:end])

    def _get_compiled_regex(self, tokens: List[str]) -> Pattern:
        """Create a Regular Expression Object from a list of tokens.

        Every token is matched literally (all regex metacharacters are
        escaped). Longer tokens are tried first so that a token which is a
        prefix of another one (``"C"`` vs ``"Cl"``) cannot shadow it; tokens
        of equal length keep their vocabulary order.

        :param tokens: List of tokens
        :return: Regular Expression Object
        """
        # Empty strings would match everywhere and produce empty tokens.
        candidates = sorted((tok for tok in tokens if tok), key=len, reverse=True)
        if not candidates:
            # Nothing to match: use a pattern that never matches.
            return re.compile(r"(?!)")
        alternation = "|".join(re.escape(tok) for tok in candidates)
        return re.compile(f"({alternation})")
smiles_tokenizer/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for using SMILES tokenizer with transformers."""
2
+
3
+ import torch
4
+ from transformers import PreTrainedTokenizer, GPT2LMHeadModel
5
+
6
+ from .tokenizer import SmilesTokenizer
7
+
8
def get_tokenizer():
    """Build and return a fresh SmilesTokenizer with its default vocabulary."""
    smiles_tokenizer = SmilesTokenizer()
    return smiles_tokenizer
11
+
12
def prepare_for_gpt2(tokenizer, model_name="gpt2"):
    """Prepare a GPT-2 model to work with the SMILES tokenizer.

    Loads the pretrained model, wraps ``tokenizer`` in a
    ``PreTrainedTokenizer`` subclass and resizes the model's token embeddings
    to the length of that wrapper.

    Args:
        tokenizer: A SmilesTokenizer instance
        model_name: Name of the GPT-2 model to load from Hugging Face

    Returns:
        tuple: (model, tokenizer_wrapper)
    """

    class SmilesTokenizerWrapper(PreTrainedTokenizer):
        """Adapter exposing a SmilesTokenizer through the transformers API."""

        def __init__(self, smiles_tokenizer):
            self.smiles_tokenizer = smiles_tokenizer
            self.vocab = {
                token: idx
                for idx, token in enumerate(smiles_tokenizer.vocabulary.symbols)
            }
            # NOTE(review): the attributes above are presumably assigned before
            # the base initializer runs because it may query the vocabulary --
            # confirm against the transformers version in use.
            super().__init__()

        @property
        def vocab_size(self):
            return len(self.vocab)

        def get_vocab(self):
            return self.vocab

        def _tokenize(self, text):
            # Only the first entry of a list input is tokenized; control
            # tokens are not added here.
            if isinstance(text, list):
                return self.smiles_tokenizer.tokenize(text, enclose=False)[0]
            return self.smiles_tokenizer.tokenize([text], enclose=False)[0]

        def _convert_token_to_id(self, token):
            return self.smiles_tokenizer.vocabulary.index(token)

        def _convert_id_to_token(self, index):
            return self.smiles_tokenizer.vocabulary[index]

        def convert_tokens_to_string(self, tokens):
            return "".join(tokens)

        def __call__(self, text, return_tensors=None, **kwargs):
            """Encode one SMILES string or a list of them.

            With ``return_tensors="pt"`` the result holds 2-D ``input_ids``
            and ``attention_mask`` tensors, right-padded with the vocabulary's
            pad index. Otherwise only ``input_ids`` is returned, exactly as
            produced by ``SmilesTokenizer.encode``. Extra keyword arguments
            are accepted but ignored.
            """
            if isinstance(text, str):
                text = [text]
            encoded = self.smiles_tokenizer.encode(text, enclose=True)
            if return_tensors != "pt":
                return {"input_ids": encoded}

            if not encoded:
                # Empty batch: return empty tensors instead of failing on
                # encoded[0].
                empty = torch.empty((0, 0), dtype=torch.long)
                return {"input_ids": empty, "attention_mask": empty.clone()}

            # Tolerate tokenizers whose encode() yields plain lists.
            encoded = [
                ids if isinstance(ids, torch.Tensor) else torch.tensor(ids)
                for ids in encoded
            ]
            masks = [torch.ones_like(ids) for ids in encoded]

            # pad_sequence right-pads to the longest sequence and stacks; for
            # a single sequence this is simply a batch dimension of size 1.
            pad_sequence = torch.nn.utils.rnn.pad_sequence
            pad_index = self.smiles_tokenizer.vocabulary.pad_index
            return {
                "input_ids": pad_sequence(
                    encoded, batch_first=True, padding_value=pad_index
                ),
                "attention_mask": pad_sequence(
                    masks, batch_first=True, padding_value=0
                ),
            }

    # Load the GPT-2 model
    model = GPT2LMHeadModel.from_pretrained(model_name)
    # Create the tokenizer wrapper
    tokenizer_wrapper = SmilesTokenizerWrapper(tokenizer)
    # Resize the model embeddings to match our vocabulary size
    model.resize_token_embeddings(len(tokenizer_wrapper))
    return model, tokenizer_wrapper
smiles_tokenizer/vocabulary.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Vocabulary classes for SMILES tokenization."""
2
+
3
+ import os
4
+ import torch
5
+ from collections import Counter
6
+
7
class Vocabulary(object):
    """A mapping from symbols to consecutive integers.

    ``symbols`` holds the symbol for every index, ``indices`` the reverse
    mapping and ``count`` a per-symbol counter. The pad, end-of-sentence and
    unknown symbols are registered first, in that order.
    """

    def __init__(self, pad="<pad>", eos="</s>", unk="<unk>"):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}

        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        # Number of leading special symbols; finalize() and save() treat
        # everything before this index as special.
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        """Two vocabularies are equal when their symbol-to-index maps match."""
        try:
            return self.indices == other.indices
        except AttributeError:
            # Not vocabulary-like: let Python fall back to its default.
            return NotImplemented

    # Defining __eq__ already disables hashing; spelled out for clarity.
    __hash__ = None

    def __getitem__(self, idx):
        """Return the symbol at ``idx``, or the unknown word if out of range."""
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def index(self, sym):
        """Returns the index of the specified symbol (unk index if missing)"""
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(self, tensor, bpe_symbol=None, escape_unk=False):
        """Helper for converting a tensor of token indices to a string.

        Tokens are joined with single spaces and end-of-sentence indices are
        dropped. A 2-D tensor is rendered row by row, joined with newlines.
        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            # Forward the options so every row is rendered like a 1-D input.
            return "\n".join(
                self.string(row, bpe_symbol, escape_unk) for row in tensor
            )

        def token_string(i):
            if i == self.unk():
                return self.unk_string(escape_unk)
            else:
                return self[i]

        sent = " ".join(token_string(i) for i in tensor if i != self.eos())
        if bpe_symbol is not None:
            sent = (sent + " ").replace(bpe_symbol, "").rstrip()
        return sent

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary (or adds ``n`` to its count).

        Returns the index of the word.
        """
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary.

        Symbols unknown to this vocabulary are appended with their count.
        """
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores). Values of 1 or less disable padding.
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        c = Counter(
            dict(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
        )
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                # most_common is sorted, so everything after is rarer still.
                break

        if padding_factor > 1:
            i = 0
            while len(new_symbols) % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(0)
                i += 1
            # Checked only here so padding_factor=0 cannot divide by zero.
            assert len(new_symbols) % padding_factor == 0

        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f, ignore_utf_errors=False):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        ``f`` is either a path or an object providing ``readlines()``. Blank
        lines are skipped. A decoding failure while reading a path is
        re-raised as ``Exception`` with a hint to rebuild the dataset; other
        errors (missing file, malformed count) propagate unchanged.
        """
        if isinstance(f, str):
            errors = "ignore" if ignore_utf_errors else "strict"
            try:
                with open(f, "r", encoding="utf-8", errors=errors) as fd:
                    return cls.load(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err

        d = cls()
        for line in f.readlines():
            if not line.strip():
                continue
            # The count is everything after the last space.
            idx = line.rfind(" ")
            word = line[:idx]
            count = int(line[idx + 1 :])
            d.indices[word] = len(d.symbols)
            d.symbols.append(word)
            d.count.append(count)
        return d

    def save(self, f):
        """Stores dictionary into a text file.

        ``f`` is either a path (missing parent directories are created) or a
        writable text stream. Special symbols are not written.
        """
        if isinstance(f, str):
            directory = os.path.dirname(f)
            # A bare file name has no directory part; makedirs("") would fail.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(f, "w", encoding="utf-8") as fd:
                return self.save(fd)
        for symbol, count in zip(
            self.symbols[self.nspecial :], self.count[self.nspecial :]
        ):
            print("{} {}".format(symbol, count), file=f)

    def dummy_sentence(self, length):
        """Return a random index tensor of ``length`` ending in the eos index."""
        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
        t[-1] = self.eos()
        return t
201
+
202
+
203
class SmilesVocabulary(Vocabulary):
    """Vocabulary pre-populated with a fixed set of SMILES symbols.

    On top of the pad / end-of-sentence / unknown symbols of the base class it
    registers a "go" (start) symbol and then the built-in SMILES tokens, in a
    fixed order.
    """

    def __init__(self, pad="<pad>", eos="</s>", unk="<unk>", go="<go>"):
        # The base initializer registers pad, eos and unk (indices 0-2).
        super().__init__(pad=pad, eos=eos, unk=unk)
        self.go_word = go
        self.go_index = self.add_symbol(go)
        # The go symbol counts as special as well.
        self.nspecial = len(self.symbols)
        for smile_token in self.__get_smile_tokens():
            self.add_symbol(smile_token)

    def __get_smile_tokens(self):
        """Return the built-in SMILES tokens in their registration order."""
        return [
            "S", "O", "2", "n", "l", "F", "H", "C", "o",
            "5", "r", "s", "=", "6", "[", "N", "4", "c",
            "-", "3", ")", "#", "]", "B", "(", "1",
        ]

    def finalize(self, threshold=-1, nwords=-1, padding_factor=1):
        """Same as the base finalize, but with padding disabled by default."""
        super().finalize(
            threshold=threshold,
            nwords=nwords,
            padding_factor=padding_factor,
        )

    def go(self):
        """GO index."""
        return self.go_index

    @classmethod
    def load(cls, f=None, ignore_utf_errors=False):
        """Load function for SMILE data.

        Both arguments are ignored; a freshly initialized vocabulary is
        returned.
        """
        return cls()
test_tokenizer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test script for SMILES tokenizer."""
2
+
3
+ from smiles_tokenizer import SmilesTokenizer
4
+ from smiles_tokenizer.utils import prepare_for_gpt2
5
+
6
def main():
    """Run the tokenizer and the GPT-2 wrapper on one molecule, printing each step."""
    smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"  # Aspirin
    tokenizer = SmilesTokenizer()

    # Plain tokenization and encoding of a batch of one.
    print(f"Tokenizing SMILES: {smiles}")
    tokens = tokenizer.tokenize([smiles])[0]
    print(f"Tokens: {tokens}")

    encoded = tokenizer.encode([smiles])[0]
    print(f"Encoded: {encoded}")

    # Forward pass through GPT-2 using the wrapped tokenizer.
    print("Testing with GPT-2...")
    model, tokenizer_wrapper = prepare_for_gpt2(tokenizer)
    inputs = tokenizer_wrapper(smiles, return_tensors="pt")
    print(f"Model inputs: {inputs}")
    outputs = model(**inputs)
    print(f"Model output shape: {outputs.logits.shape}")
    print("Test completed successfully!")


if __name__ == "__main__":
    main()