import argparse
import json
import os
import warnings
from datetime import datetime
from pathlib import Path

import torch
import yaml
from Bio import SeqIO

from data.scripts.data_utils import parse_PDB, modify_bfactor_biotite
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
from models.T5_encoder_per_token import PT5_classification_model
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
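
# Example invocations (hypothetical script and file names; adjust to the repo layout):
#   python inference.py --input_file examples/1BUI_C.pdb --modality 3D --output_name demo --output_enm
#   python inference.py --input_file dataset.jsonl --modality Seq --splits_file splits.json --split test --output_name eval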

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file")
    parser.add_argument("--modality", type=str, required=True, help="Indicate 'Seq' or '3D' to use Flexpert-Seq or Flexpert-3D?")
    parser.add_argument("--splits_file", type=str, required=False, help="Path to the file defining the splits, in case that input_file is a dataset which should be subsampled.")
    parser.add_argument("--split", type=str, required=False, help="Specify test/train/val to subselect the respective split. If specified, the splits file needs to be provided as well.")
    parser.add_argument("--output_enm", action='store_true', help="If true, the ENM values will be outputted in separate file(s).")
    parser.add_argument("--output_fasta", action='store_true', help="If true, the sequences used for the prediction will be outputted in a fasta file (can be relevant when working with input list of PDB files).")
    parser.add_argument("--output_name", type=str, required=False, help="Name of the output file.")
    args = parser.parse_args()

    args.modality = args.modality.upper()
    filename, suffix = os.path.splitext(args.input_file)
    
    if args.modality not in ["SEQ", "3D"]:
        raise ValueError("Modality must be either Seq or 3D")
    if args.splits_file is not None and args.split is None:
        raise ValueError("If splits_file is provided, split must be specified.")
    if args.split is not None and args.splits_file is None:
        raise ValueError("If split is specified, splits_file must be provided.")
    if args.split is not None and args.split not in ["test", "train", "val", "validation"]:
        raise ValueError("Split must be either 'test', 'train', 'val' or 'validation'")
    if args.output_enm and (args.modality not in ["3D"]):
        raise ValueError("Output ENM is only supported for 3D modality")
    if not args.output_name:
        default_name = 'untitled_{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
        args.output_name = default_name
        warnings.warn("Output name is not provided, using default name: {}".format(default_name))
        

    if args.splits_file is not None:
        with open(args.splits_file, 'r') as f:
            splits = json.load(f)
        if 'val' in splits.keys() and args.split == 'validation':
            args.split = 'val'
        elif 'validation' in splits.keys() and args.split == 'val':
            args.split = 'validation'
        
        datapoint_for_eval = splits[args.split]
    else:
        datapoint_for_eval = 'all'
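        # 'all' is a sentinel meaning no subselection: every record in the input file is evaluated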

    sequences = []
    names = []
    backbones = []
    pdb_files = []
    flucts_list = []
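    # The lists above are parallel, one entry per kept chain; each input branch below fills them in the same order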

    def process_pdb_file(pdb_file, backbones, sequences, names, pdb_files):
        parsed_name = os.path.splitext(os.path.basename(pdb_file))[0].split('_')
        if len(parsed_name) < 2 or len(parsed_name[0]) != 4 or len(parsed_name[1]) != 1 or not parsed_name[1].isalpha():
            raise ValueError("PDB file name is expected to be in the format 'name_chain.pdb', e.g.: 1BUI_C.pdb")
        _name = parsed_name[0]
        _chain = parsed_name[1]
        parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0]
        backbone, sequence = parsed_pdb['coords_chain_{}'.format(_chain)], parsed_pdb['seq_chain_{}'.format(_chain)]
        if len(sequence) > 1023:
            print("Sequence length is greater than 1023, skipping {}".format(_name + "." + _chain))
        else:
            backbones.append(backbone)
            sequences.append(sequence)
            names.append(_name + "." + _chain)
            pdb_files.append(pdb_file)  # kept in sync with names/sequences so that skipped chains do not misalign the lists
        return backbones, sequences, names, pdb_files

    if suffix == ".fasta":
        if args.modality == "3D":
            raise ValueError("Flexpert-3D needs the structure, fasta is not enough")

        # Load FASTA file using Biopython
        for record in SeqIO.parse(args.input_file, "fasta"):
            if '_' in record.name:
                dot_separated_name = '.'.join(record.name.split('_'))
            elif '.' in record.name:
                dot_separated_name = record.name
            else:
                raise ValueError("Sequence name must contain either an underscore or a dot to separate the PDB code and the chain code.")
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                names.append(dot_separated_name)
                sequences.append(str(record.seq))
                backbones.append(None)

    elif suffix == ".pdb":
        backbones, sequences, names, pdb_files = process_pdb_file(args.input_file, backbones, sequences, names, pdb_files)

    elif suffix == ".jsonl":
        for line in open(args.input_file, 'r'):
            _dict = json.loads(line)

            if 'fluctuations' in _dict.keys():
                print("Fluctuations are precomputed, using them.")
                dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
                if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                    names.append(_dict['pdb_name'])
                    backbones.append(None)
                    sequences.append(_dict['sequence'])
                    flucts_list.append(_dict['fluctuations'] + [0.0])  # padding for the end-of-sequence special token
                continue
            
            dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)
            
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                backbones.append(_dict['coords'])
                sequences.append(_dict['seq'])
                names.append(dot_separated_name)

    elif suffix == ".pdb_list":
        with open(args.input_file, 'r') as f:
            input_pdb_files = f.read().splitlines()
        for pdb_file in input_pdb_files:
            backbones, sequences, names, pdb_files = process_pdb_file(pdb_file, backbones, sequences, names, pdb_files)

    else:
        raise ValueError("Input file must be a fasta, pdb, jsonl file or a pdb list file")

    ### Set environment variables
    env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
    # Set folder for huggingface cache
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Set gpu device
    os.environ["CUDA_VISIBLE_DEVICES"]= env_config['gpus']['cuda_visible_device']

    config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    class_config.adaptor_architecture = 'no-adaptor' if args.modality == 'SEQ' else 'conv'
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)

    model.to(config['inference_args']['device'])
    if args.modality == 'SEQ':
        print("Loading Seq model from {}".format(config['inference_args']['seq_model_path']))
        state_dict = torch.load(config['inference_args']['seq_model_path'], map_location=config['inference_args']['device'])
        model.load_state_dict(state_dict, strict=False)
    elif args.modality == '3D':
        print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
        state_dict = torch.load(config['inference_args']['3d_model_path'], map_location=config['inference_args']['device'])
        model.load_state_dict(state_dict, strict=False)
    model.eval()

    data_to_collate = []
    for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
        
        if args.modality == '3D':
            if backbone is not None:
                _dict = {'coords': backbone, 'seq': sequence}
                flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type = config['inference_args']['enm_type'])
                flucts = flucts.tolist()
                flucts.append(0.0)  # pad to match the end-of-sequence special token
                flucts = torch.tensor(flucts).to(config['inference_args']['device'])
            else:
                flucts = flucts_list[idx]
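                # Note: flucts_list is indexed by the global datapoint index, so this assumes the
                # jsonl provided precomputed 'fluctuations' for either every entry or none of them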

        # Represent missing residues as 'X' rather than '-', since '-' is not in the tokenizer vocabulary
        sequence = sequence.replace('-', 'X')

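        # ProtT5-style tokenizers expect residues separated by spaces; with add_special_tokens=True
        # the tokenizer appends the end-of-sequence token that the extra 0.0 fluctuation accounts for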
        tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
        tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])
        
        if args.modality == '3D':
            data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:], 'enm_vals': flucts})
        elif args.modality == 'SEQ':
            data_to_collate.append({'input_ids': tokenized_seq[0,:], 'attention_mask': attention_mask[0,:]})

    # Collate the per-sequence feature dicts into one padded batch
    data_collator = DataCollatorForTokenRegression(tokenizer)

    batch = data_collator(data_to_collate)
    batch.to(model.device)
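    # Sanity check: print tensor shapes per key and, for the attention mask, the true (unpadded) lengths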
    for key in batch.keys():
        print("___________-", key, "-___________")
        for b in batch[key]:
            if key == 'attention_mask':
                print(b.sum())
            else:
                print(b.shape)

    # Predict
    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
        predictions = output_logits[:, :, 0]  # includes the prediction for the appended special token
        # padded positions are removed below using the attention mask
    
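    # prediction_output_dir in train_config.yaml is a template with three placeholders,
    # filled with the output name, the modality and the split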
    output_filename = Path(config['inference_args']['prediction_output_dir'].format(args.output_name, args.modality, 'all' if not args.split else args.split))
    output_filename.parent.mkdir(parents=True, exist_ok=True)

    # Write the predictions to files
    with open(output_filename.with_suffix('.txt'), 'w') as f:
        print("Saving predictions to {}.".format(output_filename.with_suffix('.txt')))
        for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
            prediction = prediction[mask.bool()]
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 ({})".format(len(prediction), len(sequence) + 1)
            if '.' in name:
                name = name.replace('.', '_')
            f.write('>' + name + '\n')
            f.write(', '.join([str(p) for p in prediction.tolist()[:-1]]) + '\n')
    
    if suffix == ".pdb" or suffix == ".pdb_list":
        for name, pdb_file, prediction in zip(names, pdb_files, predictions):
            chain_id = name.split('.')[1]
            _prediction = prediction[:-1].reshape(1,-1)
            _outname = output_filename.with_name(output_filename.stem + '_{}.pdb'.format(name.replace('.', '_')))
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, chain_id, _outname, _prediction) #writing the prediction without the last token

    if args.output_enm:
        _outname = output_filename.with_name(output_filename.stem + '_enm.txt')
        with open(_outname, 'w') as f:
            print("Saving ENM values to {}.".format(_outname))
            for enm_vals_single, name, mask in zip(batch['enm_vals'], names, batch['attention_mask']):
                enm_vals_single = enm_vals_single[mask.bool()]  # assumes the collator pads enm_vals alongside input_ids
                f.write('>' + name + '\n')
                f.write(', '.join([str(p) for p in enm_vals_single.tolist()[:-1]]) + '\n')

        if suffix == ".pdb" or suffix == ".pdb_list":
            for name, pdb_file, enm_vals_single, mask in zip(names, pdb_files, batch['enm_vals'], batch['attention_mask']):
                _outname = output_filename.with_name(output_filename.stem + '_enm_{}.pdb'.format(name.replace('.', '_')))  # '_enm' suffix avoids overwriting the prediction PDB
                print("Saving ENM values to {}.".format(_outname))
                chain_id = name.split('.')[1]
                _enm_vals = enm_vals_single[mask.bool()][:-1].reshape(1, -1)  # drop padding and the final special-token value
                modify_bfactor_biotite(pdb_file, chain_id, _outname, _enm_vals)
    
    if args.output_fasta:
        _outname = output_filename.with_name(output_filename.stem + '_fasta.fasta')
        with open(_outname, 'w') as f:
            print("Saving fasta to {}.".format(_outname))
            for name, sequence in zip(names, sequences):
                f.write('>' + name + '\n')
                f.write(sequence + '\n')