cafierom committed on
Commit
f9819b3
·
verified ·
1 Parent(s): 93f819c

Upload finetune_gpt.py

Browse files
Files changed (1) hide show
  1. finetune_gpt.py +442 -0
finetune_gpt.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import deepchem as dc
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import random
5
+ import pandas as pd
6
+ from rdkit import Chem
7
+ from rdkit.Chem import Draw
8
+ import os
9
+
10
# Counter-ion fragments removed from SMILES strings, in removal order.
_COUNTER_ION_FRAGMENTS = (
    "[Na+].", "[Cl-].", ".[Cl-]", ".[Na+]",
    "[K+].", "[Br-].", ".[K+]", ".[Br-]",
    "[I-].", ".[I-]", "[Ca2+].", ".[Ca2+]",
)

# Size of the foundation model's position embedding (number of positions).
_CONTEXT_LENGTH = 166

# Size of the foundation model's token embedding and softmax output.
_FOUNDATION_VOCAB_SIZE = 100

# Number of freshly initialised transformer blocks stacked on the loaded ones.
_NUM_NEW_BLOCKS = 2


def finetune_gpt(df, chembl_id):
    '''
    Fine-tune the pre-trained SMILES GPT model on a dataset and generate molecules.

    Accepts a dataframe with SMILES and uses deepchem to tokenize the dataset,
    then uses tensorflow and a pre-trained model to fine tune the model on the dataset.
    The pretrained model was trained on 305K molecules from the ZN15 dataset, including
    at least 50K that are bioactive.

    If gen_smiles_{chembl_id}.csv already exists, no training is done and the
    molecules stored in that file are returned instead.

    Args:
        df: dataframe with a "SMILES" column; at most 2000 rows are used
            (a fixed-seed sample is taken when there are more).
        chembl_id: identifier used to name the gen_smiles_{chembl_id}.csv cache file.

    Returns:
        out_text: the generated molecules, one SMILES per line after a header line
        img: the image of the generated molecules

    Raises:
        ValueError: if no SMILES remain to train on after filtering.

    requires files:
        vocab.txt
        vocab_305K.txt
        GPT_ZN305_50epochs.weights.h5
        layer_store_GPT_ZN305_50epochs.txt
        ZN305K_smiles.csv
    '''
    cache_file = f"gen_smiles_{chembl_id}.csv"

    if os.path.exists(cache_file):
        # Reuse the molecules written by an earlier run for this id.
        cached_df = pd.read_csv(cache_file)
        final_smiles = cached_df["SMILES"].to_list()
        final_mols = [Chem.MolFromSmiles(smile) for smile in final_smiles]
    else:
        tokenizer, fx, fy = _prepare_training_data(df)
        gpt_ft, new_layers = _build_finetune_model()
        _train(gpt_ft, new_layers, fx, fy)
        token_ids = _generate_token_ids(gpt_ft, tokenizer, _make_prompts())

        gen_molecules = [tokenizer.decode(row) for row in token_ids]
        gen_molecules = [tokenizer.convert_tokens_to_string(text) for text in gen_molecules]
        gen_molecules = [strip_smiles(text) for text in gen_molecules]

        mols, smiles = mols_from_smiles(gen_molecules)

        # Keep the first occurrence of every distinct SMILES string.
        final_smiles = []
        final_mols = []
        seen = set()
        for smile, mol in zip(smiles, mols):
            if smile not in seen:
                seen.add(smile)
                final_smiles.append(smile)
                final_mols.append(mol)

        final_df = pd.DataFrame.from_dict({"SMILES": final_smiles})
        final_df.to_csv(cache_file, index=False)

        print(f"Generated {len(final_smiles)} unique molecules.")

    img = Draw.MolsToGridImage(final_mols, molsPerRow=3, legends=final_smiles)

    out_text = 'The generated molecules are: \n' + "".join(
        f'{smile}\n' for smile in final_smiles
    )

    return out_text, img


def _strip_counter_ions(smiles):
    '''Remove every fragment in _COUNTER_ION_FRAGMENTS from a SMILES string, in order.'''
    for fragment in _COUNTER_ION_FRAGMENTS:
        smiles = smiles.replace(fragment, "")
    return smiles


def _read_lines(path):
    '''Return the lines of a text file with newline characters removed.'''
    with open(path, "r") as handle:
        return [line.replace("\n", "") for line in handle]


def _token_length_range(token_lists):
    '''Return (longest, shortest) token counts; longest is at least 1, shortest at most 200.'''
    lengths = [len(tokens) for tokens in token_lists]
    return max([1, *lengths]), min([200, *lengths])


def _prepare_training_data(df):
    '''Clean, filter and tokenize the dataset; return (tokenizer, inputs, targets).'''
    if len(df) > 2000:
        df = df.sample(n=2000, random_state=42)

    candidates = [_strip_counter_ions(smiles) for smiles in df["SMILES"].to_list()]

    # First pass: tokenize with vocab.txt to find out which tokens the dataset uses.
    tokenizer = dc.feat.SmilesTokenizer(vocab_file="vocab.txt")
    encoded = [tokenizer.encode(smiles) for smiles in candidates]

    biggest, smallest = _token_length_range(encoded)
    print(biggest, smallest)
    max_length = biggest

    # Padding is applied before collecting ids so the padding id is part of the set.
    padded = [tokenizer.add_padding_tokens(tokens, max_length) for tokens in encoded]
    used_token_ids = set()
    for tokens in padded:
        used_token_ids.update(tokens)
    print("New vocabulary size: ", len(used_token_ids))

    standard_vocab = _read_lines("vocab_305K.txt")
    vocab_size = len(standard_vocab)
    print("Vocabulary size for standard dataset: ", vocab_size)

    known_tokens = set(standard_vocab)
    novel_items = []
    for token_id in used_token_ids:
        token = tokenizer.decode([token_id])
        token = tokenizer.convert_tokens_to_string(token)
        token = token.replace(" ", "")

        if token not in known_tokens:
            print(f"{token} not in standard vocabulary")
            novel_items.append(token)

    if novel_items:
        print("This dataset is not compatible with the Foundation model vocabulary")
    else:
        print("This dataset is compatible with the Foundation model vocabulary")

    if max_length > _CONTEXT_LENGTH:
        print("This dataset's context window is not compatible with the Foundation model.")
    else:
        print("This dataset's context window is compatible with the Foundation model")

    # Drop SMILES containing a token unknown to the foundation vocabulary, or too long.
    kept = [
        smiles for smiles in candidates
        if not any(token in smiles for token in novel_items)
        and len(smiles) <= _CONTEXT_LENGTH
    ]
    print(f"Removed {len(candidates) - len(kept)} entries from the list!")

    # Second pass: tokenize the surviving SMILES with the foundation vocabulary.
    tokenizer = dc.feat.SmilesTokenizer(vocab_file="vocab_305K.txt")
    encoded = [tokenizer.encode(smiles) for smiles in kept]

    # The character-count filter above does not bound the token count (the tokenizer
    # adds tokens of its own). The model input drops the last token of each padded
    # sequence, so a sequence may hold at most _CONTEXT_LENGTH + 1 tokens before
    # the position embedding would be indexed out of range.
    max_tokens = _CONTEXT_LENGTH + 1
    fitting = [tokens for tokens in encoded if len(tokens) <= max_tokens]
    if len(fitting) != len(encoded):
        print(f"Removed {len(encoded) - len(fitting)} entries longer than {max_tokens} tokens!")
    encoded = fitting

    if not encoded:
        raise ValueError("No SMILES left to fine-tune on after filtering the dataset.")

    biggest, smallest = _token_length_range(encoded)
    print(biggest, smallest)
    max_length = biggest

    padded = [tokenizer.add_padding_tokens(tokens, max_length) for tokens in encoded]
    print("Vocabulary size for this dataset: ", vocab_size)

    # Next-token objective: targets are the inputs shifted left by one position.
    fx = np.array([tokens[0:max_length - 1] for tokens in padded])
    fy = np.array([tokens[1:max_length] for tokens in padded])
    print("Number of features and datapoints, targets: ", fx.shape, fy.shape)

    return tokenizer, fx, fy


def _build_finetune_model():
    '''Build the model, rename its layers, load the stored weights; return (model, new_layers).'''
    embedding_dim = 256
    n_heads = 4
    key_dim = 256
    feed_forward_dim = 256

    inputs = tf.keras.layers.Input(shape=(None,), dtype=tf.int32)
    hidden = TokenAndPositionEmbedding(
        _CONTEXT_LENGTH, _FOUNDATION_VOCAB_SIZE, embedding_dim
    )(inputs)
    for _ in range(_NUM_NEW_BLOCKS + 2):
        hidden, attentions_scores = TransformerBlock(
            n_heads, key_dim, embedding_dim, feed_forward_dim
        )(hidden)
    outputs = tf.keras.layers.Dense(_FOUNDATION_VOCAB_SIZE, activation="softmax")(hidden)

    gpt_ft = tf.keras.models.Model(inputs=inputs, outputs=[outputs, attentions_scores])

    print("Reading in layers:")
    layer_name_store = _read_lines("layer_store_GPT_ZN305_50epochs.txt")
    for stored_name in layer_name_store:
        print(stored_name)
    print("===========================================")

    # The trailing layers (new blocks plus the output layer) get fresh names; the
    # leading layers take the names read from the layer-store file.
    new_layers = _NUM_NEW_BLOCKS + 1
    for i, layer in enumerate(gpt_ft.layers[:-new_layers]):
        layer.name = layer_name_store[i]
        print(f"{layer.name} has been named!")

    for i, layer in enumerate(gpt_ft.layers[-new_layers:-1]):
        layer.name = f"transformer_block_X_{i+1}"
        print(f"{layer.name} has been named!")

    gpt_ft.layers[-1].name = "dense_X"

    gpt_ft.load_weights("GPT_ZN305_50epochs.weights.h5", skip_mismatch=True)

    return gpt_ft, new_layers


def _train(gpt_ft, new_layers, fx, fy):
    '''Train the trailing new layers with the rest frozen, then train all layers together.'''
    batch_size = 512

    for layer in gpt_ft.layers[0:-new_layers]:
        layer.trainable = False
        print(f"setting layer {layer.name} untrainable.")

    for layer in gpt_ft.layers[-new_layers:]:
        layer.trainable = True
        print(f"setting layer {layer.name} trainable.")

    # Only the first model output (token probabilities) carries a loss.
    gpt_ft.compile("adam", loss=[tf.keras.losses.SparseCategoricalCrossentropy(), None])
    gpt_ft.fit(fx, fy, epochs=50, batch_size=batch_size)

    for layer in gpt_ft.layers:
        layer.trainable = True
        print(f"setting layer {layer.name} trainable.")

    gpt_ft.compile("adam", loss=[tf.keras.losses.SparseCategoricalCrossentropy(), None])
    gpt_ft.fit(fx, fy, epochs=25, batch_size=batch_size)


def _make_prompts():
    '''Return 50 prompts: the first two characters of randomly chosen ZN305K SMILES.'''
    df_prompts = pd.read_csv("ZN305K_smiles.csv")
    pool = [_strip_counter_ions(smiles) for smiles in df_prompts["SMILES"]]
    return [smile[:2] for smile in random.choices(pool, k=50)]


def _sample_with_temperature(probabilities, temperature):
    '''Sample an index from probabilities ** (1 / temperature), renormalised.'''
    # Work in float64 log space: raising float32 probabilities to a large power
    # directly can underflow every entry to zero and yield NaN after normalising.
    with np.errstate(divide="ignore"):
        scaled = np.log(np.asarray(probabilities, dtype=np.float64)) / temperature
    scaled -= scaled.max()
    weights = np.exp(scaled)
    return int(np.random.choice(len(weights), p=weights / weights.sum()))


def _generate_token_ids(gpt_ft, tokenizer, test_string):
    '''Extend each tokenized prompt one token per step up to 60 tokens; return the id array.'''
    tf.random.set_seed(42)

    batch_length = len(test_string)
    prompt_length = len(test_string[0])

    test_array = np.empty([batch_length, prompt_length], dtype=int)
    for i, prompt in enumerate(test_string):
        test_array[i][:] = tokenizer.encode(prompt)[:prompt_length]

    c_final = 60 - prompt_length
    sig_start = 0.10
    TEMP = 1.5
    c_o = int(c_final * sig_start)

    for c in range(c_final):
        # The temperature follows a sigmoid in the step index, rising towards TEMP.
        T_int = TEMP * (1 / (1 + np.exp(-(c - c_o))))

        results, _ = gpt_ft.predict(test_array)

        if T_int < 0.015:
            print(f"using zero temp generation with {T_int}.")
            preds = [int(np.argmax(results[j][-1])) for j in range(batch_length)]
        else:
            print(f"using variable temp generation with {T_int}.")
            preds = [
                _sample_with_temperature(results[j][-1], T_int)
                for j in range(batch_length)
            ]

        test_array = np.c_[test_array, preds]
        print(test_array.shape)

    return test_array
321
+
322
def casual_attention_mask(batch_size,n_dest,n_src,dtype):
    '''
    Build a causal attention mask and repeat it for every item in the batch.

    Entry (i, j) of the mask is set when i >= j - n_src + n_dest.

    Args:
        batch_size: scalar tensor giving the number of copies of the mask
        n_dest: number of rows of the mask
        n_src: number of columns of the mask
        dtype: dtype the mask is cast to
    Returns:
        tensor of shape [batch_size, n_dest, n_src]
    '''
    dest_index = tf.range(n_dest)[:, None]
    src_index = tf.range(n_src)
    visible = dest_index >= src_index - n_src + n_dest
    single_mask = tf.reshape(tf.cast(visible, dtype), [1, n_dest, n_src])
    # Tile only along the leading axis: [batch_size, 1, 1].
    tile_counts = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
        0,
    )
    return tf.tile(single_mask, tile_counts)
333
+
334
class TransformerBlock(tf.keras.layers.Layer):
    '''
    Transformer block with causally masked multi-head self-attention.

    Attention and a two-layer feed-forward network are each followed by
    dropout, a residual connection and layer normalization.
    '''
    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1, **kwargs):
        '''
        Args:
            num_heads: number of attention heads
            key_dim: key dimension passed to MultiHeadAttention
            embed_dim: output size of the attention layer and of the second dense layer
            ff_dim: width of the first (relu) feed-forward layer
            dropout_rate: rate used by both dropout layers
            **kwargs: base-layer arguments (such as name) forwarded to
                tf.keras.layers.Layer, so a config produced by get_config()
                can be passed back to the constructor
        '''
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate
        # Attribute names and creation order are kept stable so stored weights still map.
        self.attn = tf.keras.layers.MultiHeadAttention(self.num_heads, self.key_dim,
                                                       output_shape=self.embed_dim)
        self.dropout_1 = tf.keras.layers.Dropout(self.dropout_rate)
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=0.000001)
        self.ffn_1 = tf.keras.layers.Dense(self.ff_dim, activation="relu")
        self.ffn_2 = tf.keras.layers.Dense(self.embed_dim)
        self.dropout_2 = tf.keras.layers.Dropout(self.dropout_rate)
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=0.000001)

    def call(self, inputs):
        '''
        Apply the block to inputs.

        Returns:
            tuple of (block output, attention scores from the attention layer)
        '''
        input_shape = tf.shape(inputs)
        batch_size2 = input_shape[0]
        seq_len = input_shape[1]
        # Square mask: query and key/value are the same sequence.
        casual_mask = casual_attention_mask(batch_size2, seq_len, seq_len, tf.bool)
        attention_output, attention_scores = self.attn(inputs, inputs,
                                                       attention_mask=casual_mask,
                                                       return_attention_scores=True)
        attention_output = self.dropout_1(attention_output)
        out1 = self.ln_1(inputs + attention_output)
        ffn_1 = self.ffn_1(out1)
        ffn_2 = self.ffn_2(ffn_1)
        ffn_output = self.dropout_2(ffn_2)
        return (self.ln_2(out1 + ffn_output), attention_scores)

    def get_config(self):
        '''Return the base layer config extended with this block's constructor arguments.'''
        config = super().get_config()
        config.update({"key_dim": self.key_dim, "embed_dim": self.embed_dim,
                       "num_heads": self.num_heads, "ff_dim": self.ff_dim,
                       "dropout_rate": self.dropout_rate})
        return config
375
+
376
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    '''
    Embeds tokens and positions and returns their sum.
    '''
    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        '''
        Args:
            max_len: number of rows in the position embedding table
            vocab_size: number of rows in the token embedding table
            embed_dim: embedding width shared by both tables
            **kwargs: base-layer arguments (such as name) forwarded to
                tf.keras.layers.Layer, so a config produced by get_config()
                can be passed back to the constructor
        '''
        super().__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # Attribute names and creation order are kept stable so stored weights still map.
        self.token_emb = tf.keras.layers.Embedding(input_dim=vocab_size,
                                                   output_dim=embed_dim)
        self.pos_emb = tf.keras.layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        '''Return token embeddings of x plus embeddings of positions 0..len-1 of its last axis.'''
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        '''Return the base layer config extended with this layer's constructor arguments.'''
        config = super().get_config()
        config.update({"max_len": self.max_len, "vocab_size": self.vocab_size,
                       "embed_dim": self.embed_dim})
        return config
401
+
402
def strip_smiles(input_string):
    '''
    Cleans un-needed tokens from the SMILES string.

    Removes spaces, the [CLS]/[SEP]/[PAD] marker tokens, and the same
    counter-ion fragments that are stripped from the training and prompt
    SMILES elsewhere in this file.

    Args:
        input_string: SMILES string
    Returns:
        output_string: cleaned SMILES string
    '''
    # Spaces go first so marker tokens and ion fragments are contiguous when matched.
    removable = (
        " ", "[CLS]", "[SEP]", "[PAD]",
        "[Na+].", "[Cl-].", ".[Cl-]", ".[Na+]",
        "[K+].", "[Br-].", ".[K+]", ".[Br-]",
        "[I-].", ".[I-]", "[Ca2+].", ".[Ca2+]",
    )
    output_string = input_string
    for fragment in removable:
        output_string = output_string.replace(fragment, "")
    return output_string
414
+
415
def mols_from_smiles(input_smiles_list):
    '''
    Converts a list of SMILES strings to a list of RDKit molecules.

    Strings for which Chem.MolFromSmiles returns None are reported and skipped,
    so the two returned lists stay aligned index by index.

    Args:
        input_smiles_list: list of SMILES strings
    Returns:
        valid_mols: list of RDKit molecules
        valid_smiles: list of SMILES strings
    '''
    valid_mols = []
    valid_smiles = []

    for ti, smile in enumerate(input_smiles_list):
        temp_mol = Chem.MolFromSmiles(smile)
        if temp_mol is None:
            print(f"SMILES {ti} was not valid!")
            continue
        valid_mols.append(temp_mol)
        valid_smiles.append(smile)

    print(f"Generated a total of {len(valid_mols)} mol objects")
    return valid_mols, valid_smiles