jinysun committed
Commit 2abe046 · verified · 1 Parent(s): 98e9b5a

Upload utils.py

Files changed (1)
  1. tool/comget/utils.py +275 -0
tool/comget/utils.py ADDED
@@ -0,0 +1,275 @@
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from .moses.utils import get_mol
from rdkit import Chem

import threading

def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def top_k_logits(logits, k):
    """Keep only the k largest logits per row; set the rest to -inf."""
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out
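
# Usage sketch (illustrative values): with k=2 every logit below the second-largest value in
# its row is set to -inf, so a later softmax assigns those tokens zero probability, e.g.
#   logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
#   top_k_logits(logits, 2)  # -> tensor([[-inf, 3.0, -inf, 2.0]])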

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, prop=None, scaffold=None):
    """
    Take a conditioning sequence of indices x (of shape (b, t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity, unlike an RNN which is only linear, and has a finite context
    window of block_size, unlike an RNN which has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()

    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
        logits, _, _ = model(x_cond, prop=prop, scaffold=scaffold)  # for liggpt
        # logits, _, _ = model(x_cond)  # for char_rnn
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely token
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x

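# Usage sketch (illustrative; `model`, `stoi`, `itos` and the conditioning values are assumed,
# not taken from this file). The model only needs to expose get_block_size() and to return
# (logits, ...) when called with prop/scaffold, as used above.
#   context = torch.tensor([[stoi['C']]] * 32, dtype=torch.long)          # (b, 1) start tokens
#   prop = torch.tensor([[2.5]] * 32, dtype=torch.float)                  # optional property conditioning
#   out = sample(model, context, steps=100, temperature=1.0, sample=True,
#                top_k=10, prop=prop, scaffold=None)
#   completions = ["".join(itos[int(i)] for i in row) for row in out]     # decode token ids
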
def check_novelty(gen_smiles, train_smiles):  # gen: say 788, train: 120803
    """Return the percentage of generated SMILES that do not occur in the training set."""
    if len(gen_smiles) == 0:
        novel_ratio = 0.
    else:
        duplicates = [1 for mol in gen_smiles if mol in train_smiles]  # [1]*45
        novel = len(gen_smiles) - sum(duplicates)  # 788 - 45 = 743
        novel_ratio = novel * 100. / len(gen_smiles)  # 743*100/788 = 94.289
    print("novelty: {:.3f}%".format(novel_ratio))
    return novel_ratio
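
# Worked example (toy numbers): with 4 generated SMILES of which 1 also occurs in the
# training set, novelty = (4 - 1) * 100 / 4 = 75.000%.
#   check_novelty(['CCO', 'CCN', 'c1ccccc1', 'CC(=O)O'], {'CC(=O)O'})  # prints "novelty: 75.000%"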

def canonic_smiles(smiles_or_mol):
    """Canonicalize a SMILES string (or RDKit Mol); return None if it cannot be parsed."""
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)

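# Worked example: two spellings of ethanol map to the same canonical SMILES (the exact
# canonical form is whatever RDKit produces; 'CCO' is typical here):
#   canonic_smiles('OCC') == canonic_smiles('C(O)C')  # both canonicalize to 'CCO'
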
# Experimental classes for SMILES enumeration; Iterator and SmilesIterator adapted from Keras 1.2.2

class Iterator(object):
    """Abstract base class for data iterators.
    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0
        self.total_batches_seen = 0
        self.lock = threading.Lock()
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')

    def reset(self):
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        # Ensure self.batch_index is 0.
        self.reset()
        while True:
            if seed is not None:
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                index_array = np.arange(n)
                if shuffle:
                    index_array = np.random.permutation(n)

            current_index = (self.batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                self.batch_index += 1
            else:
                current_batch_size = n - current_index
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (index_array[current_index: current_index + current_batch_size],
                   current_index, current_batch_size)

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        return self.next(*args, **kwargs)


class SmilesIterator(Iterator):
    """Iterator yielding data from a SMILES array.
    # Arguments
        x: Numpy array of SMILES input data.
        y: Numpy array of targets data.
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras.
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)

        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        super(SmilesIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """For python 2.x.
        # Returns
            The next batch.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The transformation of the SMILES is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros(tuple([current_batch_size] + [self.smiles_data_generator.pad, self.smiles_data_generator._charlen]), dtype=self.dtype)
        for i, j in enumerate(index_array):
            smiles = self.x[j:j + 1]
            x = self.smiles_data_generator.transform(smiles)
            batch_x[i] = x

        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y


class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer.

    # Arguments
        charset: string containing the characters for the vectorization;
            can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate the SMILES during transform
        canonical: use canonical SMILES during transform (overrides enum)
    """
    def __init__(self, charset='@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        self.charset = charset
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c, i) for i, c in enumerate(charset))
        self._int_to_char = dict((i, c) for i, c in enumerate(charset))

    def fit(self, smiles, extra_chars=[], extra_pad=5):
        """Performs extraction of the charset and length of a SMILES dataset and sets self.pad and self.charset.

        # Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: List of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        charset = set("".join(list(smiles)))
        self.charset = "".join(charset.union(set(extra_chars)))
        self.pad = max([len(smile) for smile in smiles]) + extra_pad

    def randomize_smiles(self, smiles):
        """Perform a randomization of a SMILES string; must be RDKit sanitizable."""
        m = Chem.MolFromSmiles(smiles)
        ans = list(range(m.GetNumAtoms()))
        np.random.shuffle(ans)
        nm = Chem.RenumberAtoms(m, ans)
        return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """Perform an enumeration (randomization) and vectorization of a Numpy array of SMILES strings.

        # Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
        """
        one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen), dtype=np.int8)

        if self.leftpad:
            for i, ss in enumerate(smiles):
                if self.enumerate:
                    ss = self.randomize_smiles(ss)
                l = len(ss)
                diff = self.pad - l
                for j, c in enumerate(ss):
                    one_hot[i, j + diff, self._char_to_int[c]] = 1
            return one_hot
        else:
            for i, ss in enumerate(smiles):
                if self.enumerate:
                    ss = self.randomize_smiles(ss)
                for j, c in enumerate(ss):
                    one_hot[i, j, self._char_to_int[c]] = 1
            return one_hot

    def reverse_transform(self, vect):
        """Performs a conversion of vectorized SMILES back to SMILES strings;
        the charset must be the same as used for vectorization.

        # Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # mask v: keep only rows that are one-hot (drop padding rows)
            v = v[v.sum(axis=1) == 1]
            # Find the one-hot encoded index with argmax, translate to char and join to string
            smile = "".join(self._int_to_char[i] for i in v.argmax(axis=1))
            smiles.append(smile)
        return np.array(smiles)
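
# End-to-end usage sketch (illustrative; the SMILES list, targets and sizes are made up):
#   sme = SmilesEnumerator(enum=True, canonical=False)
#   smiles_train = np.array(['CCO', 'c1ccccc1O', 'CC(=O)Nc1ccc(O)cc1'])
#   sme.fit(smiles_train, extra_pad=5)                 # derives sme.charset and sme.pad from the data
#   one_hot = sme.transform(smiles_train)              # (3, pad, charlen) int8 one-hot array
#   restored = sme.reverse_transform(one_hot)          # randomized-but-equivalent SMILES strings
#   it = SmilesIterator(smiles_train, np.array([0.1, 0.5, 0.9]), sme, batch_size=2, shuffle=True)
#   batch_x, batch_y = next(it)                        # one (2, pad, charlen) batch plus its targets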