nuriamimbreropelegri committed on
Commit
a770a4f
·
verified ·
1 Parent(s): b66ca3e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +150 -39
README.md CHANGED
@@ -198,45 +198,156 @@ but this score would enforce a high sequence similarity (thus not *de novo* desi
198
  We recommend generating many sequences and selecting them by pLDDT, as well as other metrics.
199
 
200
  ```python
201
- from datasets import load_from_disk
202
- from transformers import AutoTokenizer
203
- from transformers import T5Tokenizer, T5ForConditionalGeneration
204
- import math
205
- import torch
206
- from tqdm import tqdm
207
- import pickle
208
- tokenizer_aa = AutoTokenizer.from_pretrained('/path/to//tokenizer_aa')
209
- tokenizer_smiles = AutoTokenizer.from_pretrained('/path/to//tokenizer_smiles')
210
-
211
- model = T5ForConditionalGeneration.from_pretrained("/path/to/REXzyme").cuda()
212
- print(model.generation_config)
213
- reactions = ["NC1=NC=NC2=C1N=CN2[C@@H]1O[C@H](COP(=O)([O-])OP(=O)([O-])OP(=O)([O-])[O-])[C@@H](O)[C@H]1O.*N[C@@H](CO)C(*)=O>>NC1=NC=NC2=C1N=CN2[C@@H]1O[C@H](COP(=O)([O-])OP(=O)([O-])[O-])[C@@H](O)[C@H]1O.[H+].*N[C@@H](COP(=O)([O-])[O-])C(*)=O"]
214
-
215
- def calculatePerplexity(inputs,model):
216
- '''Function to compute perplexity'''
217
- a=tokenizer_aa.decode(inputs)
218
- b=tokenizer_aa(a, return_tensors="pt").input_ids.to(device='cuda')
219
- b = torch.stack([[b[b!=tokenizer_aa.pad_token_id]] for label in b][0])
220
- with torch.no_grad():
221
- outputs = model(b, labels=b)
222
- loss, logits = outputs[:2]
223
- return math.exp(loss)
224
-
225
-
226
- for idx,i in tqdm(enumerate(reactions)):
227
- input_ids = tokenizer_smiles(f"r2s{i}</s>", return_tensors="pt").input_ids.to(device='cuda')
228
- print(f'Generating for {i}')
229
- ppls_total = []
230
- for _ in range(4):
231
- outputs = model.generate(input_ids,
232
- top_k=15,
233
- top_p = 0.92,
234
- repetition_penalty=1.2,
235
- max_length=1024,
236
- do_sample=True,
237
- num_return_sequences=25)
238
- ppls = [(tokenizer_aa.decode(output,skip_special_tokens=True), calculatePerplexity(output, model),len(tokenizer_aa.decode(output))) for output in tqdm(outputs)]
239
- ppls_total.extend(ppls)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  ```
241
 
242
  ## **A word of caution**
 
198
  We recommend generating many sequences and selecting them by pLDDT, as well as other metrics.
199
 
200
  ```python
201
+ """Inference on a SMILES txt. Saved as fastas
202
+ Previously called generate_comparison"""
203
+
204
+ if __name__ == '__main__':
205
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,AutoModelForCausalLM #T5ForConditionalGeneration
206
+ import argparse
207
+ import os
208
+ import torch
209
+ import json
210
+
211
+ parser = argparse.ArgumentParser(description='Mol2Pro inference',
212
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
213
+ parser.add_argument('--input_file', default='../inference/random_smiles2.txt', type=str,
214
+ help='File with the input molecule SMILES')
215
+ parser.add_argument('--model_path', default='./output03/checkpoint-60000', type=str, help='Path to model to load')
216
+ parser.add_argument('--tokenizer_aa',
217
+ default='/home/woody/b114cb/b114cb10/mol2pro/1.training-different-sizes/1.all-data-16M-tokenizernuria/tokenizer_aa', type=str,
218
+ help='Path to amino acid tokenizer')
219
+ parser.add_argument('--tokenizer_mol',
220
+ default='/home/woody/b114cb/b114cb10/mol2pro/1.training-different-sizes/1.all-data-16M-tokenizernuria/nuria_tokenizer_smiles', type=str,
221
+ help='Path to SMILES tokenizer')
222
+ parser.add_argument('--top_k',
223
+ default=15,type=int,
224
+ help='K for top-k sampling')
225
+ parser.add_argument('--output_folder', default='fastas', type=str, help='Folder for saving results')
226
+ args = parser.parse_args()
227
+
228
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
229
+
230
+ if 'gatgpt' in args.model_path.lower():
231
+ GNN = True
232
+ print('Graph data mode')
233
+ else:
234
+ GNN = False
235
+ print('SMILES/SELFIES data mode')
236
+
237
+ # Load protein tokenizer
238
+ if 'ape' in args.tokenizer_aa:
239
+ from ape_tokenizer import APETokenizer
240
+
241
+ tokenizer_aa = APETokenizer.from_pretrained(args.tokenizer_aa)
242
+ else:
243
+ tokenizer_aa = AutoTokenizer.from_pretrained(args.tokenizer_aa)
244
+
245
+ # Load molecule tokenizer
246
+ if GNN:
247
+ tokenizer_mol = None
248
+ else:
249
+ if 'ape' in args.tokenizer_mol:
250
+ from ape_tokenizer import APETokenizer
251
+
252
+ tokenizer_mol = APETokenizer.from_pretrained(args.tokenizer_mol)
253
+ else:
254
+ tokenizer_mol = AutoTokenizer.from_pretrained(args.tokenizer_mol)
255
+
256
+ # Load model
257
+ dec_only = False
258
+ if GNN:
259
+ from transformers import GPT2Config, Trainer
260
+ from models import GATGPT2Config, GATGPT2
261
+ from torch_geometric.data import Batch, Data
262
+
263
+ config = GATGPT2Config.from_pretrained(args.model_path)
264
+
265
+ # Load model weights
266
+ model = GATGPT2.from_pretrained(args.model_path, config=config)
267
+
268
+ model.eval()
269
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
270
+ else:
271
+ try:
272
+ print('Attempt Seq2Seq model load... ')
273
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path).cuda()
274
+ except:
275
+ print('Attempt CausalLM model load... ')
276
+ model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda()
277
+ model.config.eos_token_id = tokenizer_mol.eos_token_id
278
+ model.config.pad_token_id = tokenizer_mol.pad_token_id
279
+ print(
280
+ f"Set `eos_token_id` to {tokenizer_mol.eos_token_id} and `pad_token_id` to {tokenizer_mol.pad_token_id}.")
281
+ dec_only = True
282
+ print('Model Loaded')
283
+
284
+
285
+ smiles_list = []
286
+ with open(args.input_file, 'r') as input_file:
287
+ for line in input_file:
288
+ smiles_list.append(line.strip())
289
+
290
+ molecule_json = {}
291
+ for index,smiles in enumerate(smiles_list):
292
+ sequences=[]
293
+ if GNN:
294
+ from build_tokenized_dataset import convert_smiles_to_graph
295
+
296
+ node_feats, edge_index, edge_feats = convert_smiles_to_graph(smiles)
297
+ node_feats_tensor = torch.tensor(node_feats, dtype=torch.float, device=device)
298
+ edge_index_tensor = torch.tensor(edge_index, dtype=torch.long, device=device).T.contiguous()
299
+ edge_feats_tensor = torch.tensor(edge_feats, dtype=torch.float, device=device)
300
+
301
+ # Input to decoder is only bos
302
+ start_token = tokenizer_aa.bos_token_id or tokenizer_aa.convert_tokens_to_ids("▁") # fallback to the space which is always appended by our tokenizer
303
+ text_input_ids = torch.tensor([[start_token]], dtype=torch.long, device=device)
304
+
305
+ input_ids = {
306
+ "graph_node_feats": node_feats_tensor, # shape (N, 3)
307
+ "graph_edge_index": edge_index_tensor, # shape (2, E)
308
+ "graph_edge_feats": edge_feats_tensor, # shape (E, 2)
309
+ "batch": torch.full((len(node_feats),), 0, dtype=torch.long, device=device), # shape (N,)
310
+ "input_ids": text_input_ids
311
+ }
312
+
313
+ elif 'ape' in args.tokenizer_mol:
314
+ input_ids = tokenizer_mol(smiles, return_tensors="pt")["input_ids"].to(device='cuda')
315
+ else:
316
+ input_ids = tokenizer_mol(smiles, return_tensors="pt").input_ids.to(device='cuda')
317
+ if not GNN:
318
+ print(f'Generating for {smiles} (input ids: {input_ids})')
319
+ else:
320
+ print(f'Generating for {smiles}')
321
+
322
+ # top_k = Choose at random from the first K tokens (weighted by softmax score)
323
+ # num_return_sequences = The number of independently computed returned sequences for each element in the batch.
324
+ if dec_only:
325
+ attention_mask = torch.ones_like(input_ids).cuda()
326
+ outputs = model.generate(input_ids, top_k=args.top_k, attention_mask = attention_mask, repetition_penalty=1.2, max_length=1024, do_sample=True, num_return_sequences=25)
327
+ else:
328
+ outputs = model.generate(input_ids, top_k=args.top_k, repetition_penalty=1.2, max_length=1024, do_sample=True, num_return_sequences=25)
329
+
330
+ sequences = [tokenizer_aa.decode(output, skip_special_tokens=True) for output in outputs]
331
+
332
+ if not os.path.exists(args.output_folder):
333
+ os.makedirs(args.output_folder)
334
+
335
+ filename = f'{args.output_folder}/output_topk{args.top_k}_file-{index}.fasta'
336
+ with open(filename, 'w') as fn:
337
+ for idx, seq in enumerate(sequences):
338
+ fn.write(f">{idx}\n{seq}\n")
339
+
340
+ # Store molecule name
341
+ molecule_json[filename] = smiles
342
+
343
+ # Save metadata
344
+ metadata_path = os.path.join(args.output_folder, 'molecule_input_metadata.json')
345
+ try:
346
+ with open(metadata_path, 'w') as json_file:
347
+ json.dump(molecule_json, json_file, indent=4)
348
+ print(f"Metadata successfully written to {metadata_path}")
349
+ except Exception as e:
350
+ print(f"An error occurred while writing to JSON: {e}")
351
  ```
352
 
353
  ## **A word of caution**