Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,931 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
from data.scripts.data_utils import parse_PDB
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
from models.T5_encoder_per_token import PT5_classification_model
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
import argparse
import os
import yaml
import torch
from pathlib import Path
from Bio import SeqIO
import json
import warnings
from datetime import datetime
from data.scripts.data_utils import modify_bfactor_biotite
def _load_yaml(path):
    """Read a YAML file with ``yaml.FullLoader`` and close the handle afterwards.

    NOTE: ``FullLoader`` is kept from the original code; the files read here
    are project-local configs, not external input.
    """
    with open(path, 'r') as f:
        return yaml.load(f, Loader=yaml.FullLoader)


def process_pdb_file(pdb_file, backbones, sequences, names):
    """Parse one chain of a PDB file and append its data to the given lists.

    The PDB code and the chain id are taken from the file name, which must
    start with ``<4-character code>_<single chain letter>`` (e.g.
    ``1BUI_C.pdb``). Chains longer than 1023 residues are reported on stdout
    and skipped, in which case the lists are left unchanged.

    Args:
        pdb_file: Path to the PDB file.
        backbones: List extended in place with the ``coords_chain_<chain>``
            entry returned by ``parse_PDB``.
        sequences: List extended in place with the ``seq_chain_<chain>`` entry.
        names: List extended in place with ``"<code>.<chain>"``.

    Returns:
        The same ``(backbones, sequences, names)`` lists.

    Raises:
        ValueError: If the file name does not follow the expected pattern.
    """
    parsed_name = os.path.splitext(os.path.basename(pdb_file))[0].split('_')
    # Check the number of parts first, so a name without an underscore raises
    # the intended ValueError instead of an IndexError on parsed_name[1].
    if (len(parsed_name) < 2
            or len(parsed_name[0]) != 4
            or len(parsed_name[1]) != 1
            or not parsed_name[1].isalpha()):
        raise ValueError("PDB file name is expected to be in the format of 'name_chain.pdb', e.g.: 1BUI_C.pdb")
    _name, _chain = parsed_name[0], parsed_name[1]
    parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0]
    backbone = parsed_pdb['coords_chain_{}'.format(_chain)]
    sequence = parsed_pdb['seq_chain_{}'.format(_chain)]
    if len(sequence) > 1023:
        print("Sequence length is greater than 1023, skipping {}".format(_name + "." + _chain))
    else:
        backbones.append(backbone)
        sequences.append(sequence)
        names.append(_name + "." + _chain)
    return backbones, sequences, names


if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file")
    parser.add_argument("--modality", type=str, required=True, help="Indicate 'Seq' or '3D' to use Flexpert-Seq or Flexpert-3D?")
    parser.add_argument("--splits_file", type=str, required=False, help="Path to the file defining the splits, in case that input_file is a dataset which should be subsampled.")
    parser.add_argument("--split", type=str, required=False, help="Specify test/train/val to subselect the respective split. If specified, the splits file needs to be provided as well.")
    parser.add_argument("--output_enm", action='store_true', help="If true, the ENM values will be outputted in separate file(s).")
    parser.add_argument("--output_fasta", action='store_true', help="If true, the sequences used for the prediction will be outputted in a fasta file (can be relevant when working with input list of PDB files).")
    parser.add_argument("--output_name", type=str, required=False, help="Name of the output file.")
    args = parser.parse_args()

    args.modality = args.modality.upper()
    # The file extension selects the input format further below.
    suffix = os.path.splitext(args.input_file)[1]

    # ------------------------------------------------------------------
    # Argument validation
    # ------------------------------------------------------------------
    if args.modality not in ["SEQ", "3D"]:
        raise ValueError("Modality must be either Seq or 3D")
    if args.splits_file is not None and args.split is None:
        raise ValueError("If splits_file is provided, split must be specified.")
    if args.split is not None and args.splits_file is None:
        raise ValueError("If split is specified, splits_file must be provided.")
    if args.split is not None and args.split not in ["test", "train", "val", "validation"]:
        raise ValueError("Split must be either 'test', 'train', 'val' or 'validation'")
    if args.output_enm and (args.modality not in ["3D"]):
        raise ValueError("Output ENM is only supported for 3D modality")

    if not args.output_name:
        default_name = 'untitled_{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
        args.output_name = default_name
        warnings.warn("Output name is not provided, using default name: {}".format(default_name))

    # ------------------------------------------------------------------
    # Optional sub-selection of datapoints via a splits file
    # ------------------------------------------------------------------
    if args.splits_file is not None:
        with open(args.splits_file, 'r') as f:
            splits = json.load(f)
        # Accept 'val' and 'validation' interchangeably, whichever the file uses.
        if 'val' in splits and args.split == 'validation':
            args.split = 'val'
        elif 'validation' in splits and args.split == 'val':
            args.split = 'validation'
        if args.split not in splits:
            raise ValueError("Split '{}' is not defined in the splits file (available: {})".format(args.split, list(splits)))
        datapoint_for_eval = splits[args.split]
    else:
        # Sentinel meaning "do not filter".
        datapoint_for_eval = 'all'

    # Parallel lists, one entry per datapoint that will be predicted.
    sequences = []
    names = []
    backbones = []
    # Source PDB path of every kept datapoint (PDB inputs only).
    pdb_files = []
    # Kept parallel to `sequences` for .jsonl input: precomputed fluctuations
    # or None when they have to be computed from coordinates.
    flucts_list = []

    # ------------------------------------------------------------------
    # Input loading, dispatched on the file extension
    # ------------------------------------------------------------------
    if suffix == ".fasta":
        if args.modality == "3D":
            raise ValueError("Flexpert-3D needs the structure, fasta is not enough")
        # Load FASTA file using Biopython
        for record in SeqIO.parse(args.input_file, "fasta"):
            if '_' in record.name:
                dot_separated_name = '.'.join(record.name.split('_'))
            elif '.' in record.name:
                dot_separated_name = record.name
            else:
                raise ValueError("Sequence name must contain either an underscore or a dot to separate the PDB code and the chain code.")
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                names.append(dot_separated_name)
                sequences.append(str(record.seq))
                backbones.append(None)
    elif suffix == ".jsonl":
        with open(args.input_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                _dict = json.loads(line)
                if 'fluctuations' in _dict:
                    print("fluctuations are precomputed, using them")
                    dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
                    if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                        names.append(_dict['pdb_name'])
                        backbones.append(None)
                        sequences.append(_dict['sequence'])
                        flucts_list.append(_dict['fluctuations'] + [0.0])  # padding for end cls token
                    # Records with precomputed fluctuations never take the
                    # coordinate path below.
                    continue
                dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)
                if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                    backbones.append(_dict['coords'])
                    sequences.append(_dict['seq'])
                    names.append(dot_separated_name)
                    # Placeholder so flucts_list[idx] stays aligned with
                    # sequences[idx] when both record kinds are mixed.
                    flucts_list.append(None)
    elif suffix in (".pdb", ".pdb_list"):
        if suffix == ".pdb":
            candidate_pdb_files = [args.input_file]
        else:
            with open(args.input_file, 'r') as f:
                # One path per line; blank lines are ignored.
                candidate_pdb_files = [line.strip() for line in f if line.strip()]
        for pdb_file in candidate_pdb_files:
            n_before = len(names)
            backbones, sequences, names = process_pdb_file(pdb_file, backbones, sequences, names)
            # Only remember files whose chain was kept, so that pdb_files
            # stays aligned with names/predictions when a chain is skipped.
            if len(names) > n_before:
                pdb_files.append(pdb_file)
    else:
        raise ValueError("Input file must be a fasta, pdb, jsonl file or a pdb list file")

    # ------------------------------------------------------------------
    # Environment and model configuration
    # ------------------------------------------------------------------
    env_config = _load_yaml('configs/env_config.yaml')
    # Set folder for huggingface cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set gpu device; os.environ only accepts strings, YAML may yield an int.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(env_config['gpus']['cuda_visible_device'])

    config = _load_yaml('configs/train_config.yaml')
    inference_args = config['inference_args']
    device = inference_args['device']

    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'no-adaptor' if args.modality == 'SEQ' else 'conv'
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
    model.to(device)

    # Modality was validated above, so anything that is not SEQ is 3D.
    if args.modality == 'SEQ':
        model_path = inference_args['seq_model_path']
    else:
        model_path = inference_args['3d_model_path']
        print("Loading 3D model from {}".format(model_path))
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    # ------------------------------------------------------------------
    # Feature preparation
    # ------------------------------------------------------------------
    data_to_collate = []
    for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
        if args.modality == '3D':
            if backbone is not None:
                _dict = {'coords': backbone, 'seq': sequence}
                flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type=inference_args['enm_type'])
                flucts = flucts.tolist()
                flucts.append(0.0)  # To match the special token for the sequence
                flucts = torch.tensor(flucts).to(device)
            else:
                # NOTE(review): precomputed values stay a plain list while the
                # computed ones are tensors - confirm the collator accepts both.
                flucts = flucts_list[idx]

        # Ensure that the missing residues in the sequence are not represented
        # as '-' but as 'X' (due to the tokenizer vocabulary).
        sequence = sequence.replace('-', 'X')
        tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
        tokenized_seq = tokenizer_out['input_ids'].to(device)
        attention_mask = tokenizer_out['attention_mask'].to(device)

        features = {'input_ids': tokenized_seq[0, :], 'attention_mask': attention_mask[0, :]}
        if args.modality == '3D':
            features['enm_vals'] = flucts
        data_to_collate.append(features)

    # Use the data collator to process the input
    data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)
    batch.to(model.device)

    # Diagnostic dump of the collated batch.
    for key in batch.keys():
        print("___________-", key, "-___________")
        for b in batch[key]:
            if key == 'attention_mask':
                print(b.sum())
            else:
                print(b.shape)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, inference_args['batch_size'])
    predictions = output_logits[:, :, 0]  # includes the prediction for the added token

    output_filename = Path(inference_args['prediction_output_dir'].format(args.output_name, args.modality, 'all' if not args.split else args.split))
    output_filename.parent.mkdir(parents=True, exist_ok=True)
    writes_pdb = suffix in (".pdb", ".pdb_list")

    # ------------------------------------------------------------------
    # Write the predictions to files
    # ------------------------------------------------------------------
    # Predictions with padded positions removed, reused for the PDB output.
    masked_predictions = []
    with open(output_filename.with_suffix('.txt'), 'w') as f:
        print("Saving predictions to {}.".format(output_filename))
        for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
            # subselect the predictions using the attention mask
            prediction = prediction[mask.bool()]
            if len(prediction) != len(sequence) + 1:
                # Raised explicitly (not asserted) so the check survives `python -O`.
                raise ValueError("Prediction length {} is not equal to sequence length + 1 {}".format(len(prediction), len(sequence) + 1))
            masked_predictions.append(prediction)
            f.write('>' + name.replace('.', '_') + '\n')
            f.write(', '.join([str(p) for p in prediction.tolist()[:-1]]) + '\n')

    if writes_pdb:
        # Use the masked predictions here as well: the raw batch rows are
        # padded to a common length when chains of different length are
        # predicted together.
        for name, pdb_file, prediction in zip(names, pdb_files, masked_predictions):
            chain_id = name.split('.')[1]
            _prediction = prediction[:-1].reshape(1, -1)  # without the last token
            _outname = output_filename.with_name(output_filename.stem + '_{}.pdb'.format(name.replace('.', '_')))
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, chain_id, _outname, _prediction)

    if args.output_enm:
        # Drop padded positions the same way as for the predictions.
        # NOTE(review): assumes the collator pads 'enm_vals' to the same
        # shape as 'attention_mask' - confirm.
        masked_enm_vals = [
            enm_vals[mask.bool()]
            for enm_vals, mask in zip(batch['enm_vals'], batch['attention_mask'])
        ]
        _outname = output_filename.with_name(output_filename.stem + '_enm.txt')
        with open(_outname, 'w') as f:
            print("Saving ENM predictions to {}.".format(_outname))
            for enm_prediction, name in zip(masked_enm_vals, names):
                f.write('>' + name + '\n')
                f.write(', '.join([str(p) for p in enm_prediction.tolist()[:-1]]) + '\n')
        if writes_pdb:
            for name, pdb_file, enm_vals_single in zip(names, pdb_files, masked_enm_vals):
                # Dedicated '_enm_' name: the previous name was identical to
                # the prediction PDB written above and overwrote it.
                _outname = output_filename.with_name(output_filename.stem + '_enm_{}.pdb'.format(name.replace('.', '_')))
                print("Saving ENM prediction to {}.".format(_outname))
                chain_id = name.split('.')[1]
                _enm_vals = enm_vals_single[:-1].reshape(1, -1)  # without the last token
                modify_bfactor_biotite(pdb_file, chain_id, _outname, _enm_vals)

    if args.output_fasta:
        _outname = output_filename.with_name(output_filename.stem + '_fasta.fasta')
        with open(_outname, 'w') as f:
            print("Saving fasta to {}.".format(_outname))
            for name, sequence in zip(names, sequences):
                f.write('>' + name + '\n')
                f.write(sequence + '\n')