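"""Inference script for ProteinMPNN-based sequence design (Flexpert-Design).

Loads trained model weights (from --checkpoint_path, or from the path given in
configs/Flexpert-Design-inference.yaml), runs prediction on the structures found
under --infer_path using the PDBInference dataset, and writes the predicted
amino-acid sequences to <infer_path>/predictions.txt in FASTA format.
"""
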
import os, sys, warnings, argparse, math, tqdm, datetime
import pytorch_lightning as pl
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.callbacks as plc
import pytorch_lightning.loggers as plog
from model_interface import MInterface
from data_interface import DInterface
from src.tools.logger import SetupCallback, BackupCodeCallback
from shutil import ignore_patterns
from transformers import AutoTokenizer
import numpy as np
import yaml
import wandb
warnings.filterwarnings("ignore")
def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--infer_path', type=str, help='Path from which to read the data to be predicted and where to save the predictions.')

    # Set-up parameters
    parser.add_argument('--res_dir', default='./train/results', type=str)
    parser.add_argument('--ex_name', default='debug', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
    parser.add_argument('--stage', default='predict', type=str)  # 'fit', 'test' or 'predict'
    parser.add_argument('--val_check_interval', default=0.5, type=float, help='Validation check interval')
    parser.add_argument('--dataset', default='PDBInference')  # AF2DB_dataset, CATH_dataset
    parser.add_argument('--model_name', default='ProteinMPNN', choices=['StructGNN', 'GraphTrans', 'GVP', 'GCA', 'AlphaDesign', 'ESMIF', 'PiFold', 'ProteinMPNN', 'KWDesign', 'E3PiFold'])
    # parser.add_argument('--lr', default=4e-4, type=float, help='Learning rate')
    # parser.add_argument('--lr_scheduler', default='onecycle')
    # parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=111, type=int)
    parser.add_argument('--num_workers', default=12, type=int)
    parser.add_argument('--pad', default=1024, type=int)
    parser.add_argument('--min_length', default=40, type=int)
    parser.add_argument('--data_root', default='./data/')

    # Training parameters
    # parser.add_argument('--epoch', default=10, type=int, help='end epoch')
    parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level')
    # parser.add_argument('--gpus', default=1, type=int, help='how many GPUs to train on')
    # parser.add_argument('--weight_decay', default=0.0, type=float, help='Weight decay for optimizer')

    # # Eval parameters
    # parser.add_argument('--eval_sequences_sampled', default=1, type=int, help='How many sequences to sample in evaluation.')
    # parser.add_argument('--eval_sequences_temperature', default=0, type=float, help='What temperature to use for the sampling in evaluation.')
    # parser.add_argument('--eval_output_dir', default=None, type=str, help='Where to save the evaluation output.')

    # Model parameters
    parser.add_argument('--use_dist', default=1, type=int)
    parser.add_argument('--use_product', default=0, type=int)
    parser.add_argument('--use_pmpnn_checkpoint', default=1, type=int, help='Set to 1 to start from the pretrained ProteinMPNN weights, 0 otherwise.')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to the model checkpoint to load weights from')

    # Dynamics aware parameters
    parser.add_argument('--use_dynamics', default=0, type=int)
    # parser.add_argument('--flex_loss_coeff', default=0.5, type=float)
    # parser.add_argument('--get_gt_flex_onthefly', default=0, type=int, help='Flag to get ground truth flexibility on-the-fly (with subsequent caching)')
    parser.add_argument('--init_flex_features', default=1, type=int, help="Set to 0 if no flexibility information should be passed on input to the node features h_V")
    # parser.add_argument('--loss_fn', default='MSE', type=str, help='Define what loss to use. Choose MSE, L1 or DPO.')
    # parser.add_argument('--grad_normalization', default=1, type=int, help="Set to 0 if the gradients of the seq and flex losses should not be normalized.")
    # parser.add_argument('--test_engineering', default=0, type=int, help="In this main.py should be set to 0 to not overwrite the training dataset.")

    args = parser.parse_args()
    return args

if __name__ == "__main__":
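    # Parse command-line arguments and force the settings required for inference
    # (batch size 1, dynamics-aware training disabled).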
    args = create_parser()
    args.batch_size = 1
    print('In the predict stage, defaulting batch size to 1.')
    assert args.use_dynamics == 0, "In the inference script this should be set to 0."

    # --infer_path and --dataset 'PDBInference' must be used together (logical XOR check).
    if bool(args.infer_path) != (args.dataset == 'PDBInference'):
        raise ValueError("You should only use --infer_path with --dataset 'PDBInference' and vice versa.")
    if args.infer_path and not os.path.exists(args.infer_path):
        os.makedirs(args.infer_path)

    # Load model weights either from the --checkpoint_path argument (if provided) or
    # from the path given in the Flexpert-Design inference config.
    if args.checkpoint_path is not None:
        trained_model_path = args.checkpoint_path
        print(f"Loading model weights from checkpoint passed by argument: {trained_model_path}")
    else:
        with open('configs/Flexpert-Design-inference.yaml', 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        trained_model_path = config['pmpnn_model_path']
        print(f"Loading model weights from checkpoint specified in Flexpert-Design-inference.yaml: {trained_model_path}")

    if os.path.exists(trained_model_path):
        print(f"Setting the starting ProteinMPNN checkpoint path in the model interface to the Flexpert-Design trained weights: {trained_model_path}")
        args.starting_checkpoint_path = trained_model_path
    else:
        raise FileNotFoundError(f"Checkpoint file not found at {trained_model_path}")
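
    # Seed everything, build the data and model interfaces from the parsed arguments,
    # and configure a single-GPU Lightning trainer used only for prediction.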
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.setup(stage='predict')

    model = MInterface(**vars(args))

    trainer_config = {
        'devices': 1,
        'max_epochs': 1,
        'num_nodes': 1,
        'strategy': 'ddp',
        'precision': '32',
        'accelerator': 'gpu',
        'val_check_interval': args.val_check_interval,
        'check_val_every_n_epoch': args.check_val_every_n_epoch
    }
    trainer = Trainer(**trainer_config)

    predictions = trainer.predict(model, data_module)
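
    # trainer.predict returns one entry per batch; with batch_size == 1 each entry holds
    # the per-residue log-probabilities and metadata for a single input structure.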
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir='./cache_dir/')  # mask token: 32

    serializable_predictions = []
    for pred_idx, pred in enumerate(predictions):
        logprobs = pred['log_probs'].cpu().numpy()[0]  # [L, 21]
        pmpnn_alphabet_tokens_argmax = logprobs.argmax(axis=-1)  # [L]
        aa_sequence = ''.join(tokenizer.decode(pmpnn_alphabet_tokens_argmax, skip_special_tokens=True).split())

        # Get probability of the predicted sequence (computed for reference; not written to the output file)
        seq_probs = np.exp(logprobs.max(axis=-1))  # [L]
        avg_prob = float(np.mean(seq_probs))

        serializable_predictions.append({
            'prediction_id': pred['batch']['title'][0],
            'amino_acid_sequence': aa_sequence
        })
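
    # Write the results in FASTA format: one record per input structure, headed by the
    # structure title and followed by the predicted sequence.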
    with open(f'{args.infer_path}/predictions.txt', 'w') as f:
        for pred in serializable_predictions:
            f.write(f'>{pred["prediction_id"]}\n{pred["amino_acid_sequence"]}\n')