| |
| |
| """ |
| Created on Fri Jun 16 14:27:44 2023 |
| |
| @author: mheinzinger |
| """ |
|
|
| import argparse |
| import time |
| from pathlib import Path |
|
|
| from urllib import request |
| import shutil |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| from transformers import T5EncoderModel, T5Tokenizer |
|
|
|
|
# Pick the first GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
|
|
|
|
| |
class CNN(nn.Module):
    """Two-layer convolutional head mapping ProstT5 embeddings to 3Di classes."""

    def __init__(self):
        super(CNN, self).__init__()
        # Conv2d with kernel (7, 1) over an (L, 1) spatial grid acts as a 1D
        # convolution along the sequence axis; padding (3, 0) keeps L unchanged.
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 32, kernel_size=(7, 1), padding=(3, 0)),
            nn.ReLU(),
            nn.Dropout(0.0),
            nn.Conv2d(32, 20, kernel_size=(7, 1), padding=(3, 0)),
        )

    def forward(self, x):
        """Map embeddings of shape (B, L, 1024) to logits of shape (B, 20, L).

        B = batch size, L = protein length,
        1024 = embedding features, 20 = number of 3Di classes.
        """
        # (B, L, F) -> (B, F, L, 1): features become Conv2d channels.
        x = x.permute(0, 2, 1).unsqueeze(dim=-1)
        logits = self.classifier(x)
        # Drop the trailing singleton spatial dimension: (B, 20, L, 1) -> (B, 20, L).
        return logits.squeeze(dim=-1)
|
|
def one_hot_3di_sequence(sequence, vocab):
    """Return a (1, len(sequence), len(vocab)) float32 one-hot encoding.

    Each position of *sequence* is looked up in *vocab* (token -> column
    index); tokens missing from the vocabulary trigger an AssertionError.
    """
    encoding = torch.zeros(len(sequence), len(vocab), dtype=torch.float32)
    for position, token in enumerate(sequence):
        assert token in vocab
        encoding[position, vocab[token]] = 1
    # Prepend a batch dimension of size 1.
    return encoding.unsqueeze(0)
|
|
|
|
def get_T5_model(model_dir):
    """Load the ProstT5 encoder (eval mode, on the global `device`) and its tokenizer.

    *model_dir* may be a local checkpoint directory or a HuggingFace repo id.
    """
    print("Loading T5 from: {}".format(model_dir))
    encoder = T5EncoderModel.from_pretrained(model_dir).to(device).eval()
    tokenizer = T5Tokenizer.from_pretrained(model_dir, do_lower_case=False)
    return encoder, tokenizer
|
|
|
|
def read_fasta( fasta_path, split_char, id_field ):
    '''
    Reads in fasta file containing multiple sequences.

    Parameters
    ----------
    fasta_path : str or Path
        Path to the FASTA-formatted input file.
    split_char : str
        Character used to split each header line.
    id_field : int
        Index of the identifier field after splitting the header.

    Returns
    -------
    dict or None
        Mapping of sequence identifier to (upper-case) amino-acid sequence.
        Returns None if the file contains lower-case input, which indicates
        3Di sequences that this predictor does not accept.

    Raises
    ------
    ValueError
        If sequence data appears before the first '>' header line
        (previously this surfaced as a confusing NameError).
    '''
    sequences = dict()
    uniprot_id = None
    with open( fasta_path, 'r' ) as fasta_f:
        for line in fasta_f:
            if line.startswith('>'):
                uniprot_id = line.replace('>', '').strip().split(split_char)[id_field]
                # Replace characters that would break file paths derived from the id.
                uniprot_id = uniprot_id.replace("/","_").replace(".","_")
                sequences[ uniprot_id ] = ''
            else:
                # Drop all whitespace and gap characters.
                s = ''.join( line.split() ).replace("-","")
                if s.islower():
                    print("The input file was in lower-case which indicates 3Di-input." +
                          "This predictor only operates on amino-acid-input (upper-case)." +
                          "Exiting now ..."
                          )
                    return None
                if uniprot_id is None:
                    # Guard against malformed files: the original code would
                    # raise NameError on `uniprot_id` here.
                    raise ValueError(
                        "Malformed FASTA: sequence data found before the first '>' header."
                    )
                sequences[ uniprot_id ] += s
    return sequences
|
|
def write_predictions(predictions, out_path):
    """Write predicted 3Di states to *out_path* in FASTA format.

    *predictions* maps sequence id -> iterable of per-residue class indices
    (0..19); indices are translated to the 20 3Di letters.
    """
    # Class index -> 3Di letter (alphabetical, 20 states).
    ss_mapping = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))

    records = []
    for seq_id, yhats in predictions.items():
        states = "".join(ss_mapping[int(yhat)] for yhat in yhats)
        records.append(">{}\n{}".format(seq_id, states))

    with open(out_path, 'w+') as out_f:
        out_f.write('\n'.join(records))
    print(f"Finished writing results to {out_path}")
    return None
|
|
def predictions_to_dict(predictions):
    """Translate per-residue class indices into 3Di strings, keyed by sequence id."""
    # Class index -> 3Di letter (alphabetical, 20 states).
    ss_mapping = dict(enumerate("ACDEFGHIKLMNPQRSTVWY"))

    results = dict()
    for seq_id, yhats in predictions.items():
        results[seq_id] = "".join(ss_mapping[int(yhat)] for yhat in yhats)
    return results
|
|
def toCPU(tensor):
    """Detach *tensor*, move it to CPU and return it as a NumPy array.

    Multi-dimensional tensors additionally get a trailing singleton
    dimension squeezed away (if present); 1-D tensors are returned as-is.
    """
    detached = tensor.detach().cpu()
    if detached.dim() > 1:
        detached = detached.squeeze(dim=-1)
    return detached.numpy()
|
|
|
|
def download_file(url, local_path):
    """Download *url* to *local_path*, creating parent directories as needed.

    Parameters
    ----------
    url : str
        Source URL (any scheme urllib supports).
    local_path : pathlib.Path
        Destination file path; missing parent directories are created.
    """
    # parents=True/exist_ok=True: the original bare mkdir() raised
    # FileNotFoundError when the grandparent directory was missing and
    # FileExistsError when another process created the directory first.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    print("Downloading: {}".format(url))
    # Some servers reject requests without a browser-like User-Agent.
    req = request.Request(url, headers={
        'User-Agent' : 'Mozilla/5.0 (Windows NT 6.1; Win64; x64)'
    })

    # Stream straight to disk to avoid holding the whole file in memory.
    with request.urlopen(req) as response, open(local_path, 'wb') as outfile:
        shutil.copyfileobj(response, outfile)
    return None
|
|
| |
def load_predictor( weights_link="https://rostlab.org/~deepppi/prostt5/cnn_chkpnt/model.pt" , device=torch.device("cpu")):
    """Return the 3Di CNN head with pre-trained weights on *device* (eval mode).

    Weights are cached under ./cnn_chkpnt/model.pt relative to the current
    working directory and downloaded from *weights_link* when missing.
    """
    predictor = CNN()
    checkpoint_p = Path.cwd() / "cnn_chkpnt" / "model.pt"

    if not checkpoint_p.exists():
        download_file(weights_link, checkpoint_p)

    # NOTE(review): torch.load unpickles a file fetched over the network;
    # consider weights_only=True or checksum verification for safety.
    state = torch.load(checkpoint_p, map_location=device)
    predictor.load_state_dict(state["state_dict"])

    predictor = predictor.eval().to(device)
    return predictor
|
|
|
|
def get_3di_sequences( seq_dict, model_dir, device,
                  max_residues=4000, max_seq_len=1000, max_batch=100,report_fn=print,error_fn=print,half_precision=False):
    """Predict a 3Di class index per residue for every amino-acid sequence.

    Embeds each sequence with the ProstT5 encoder, then applies the CNN
    head to the residue embeddings.

    Parameters
    ----------
    seq_dict : dict
        Mapping of identifier -> amino-acid sequence (upper-case).
    model_dir : str
        Local checkpoint directory or HuggingFace repo id for ProstT5.
    device : torch.device
        Device used for both the encoder and the CNN head.
    max_residues : int
        Soft upper bound on total residues per batch.
    max_seq_len : int
        Sequences longer than this force an immediate batch flush.
    max_batch : int
        Upper bound on sequences per batch.
    report_fn, error_fn : callable
        Sinks for progress / error messages (default: print).
    half_precision : bool
        If True, run encoder and CNN head in fp16.

    Returns
    -------
    dict
        identifier -> numpy array of per-residue 3Di class indices (int8).
    """

    predictions = dict()

    # ProstT5 direction token: <AA2fold> marks amino-acid input.
    prefix = "<AA2fold>"

    model, vocab = get_T5_model(model_dir)
    predictor = load_predictor(device=device)

    if half_precision:
        model = model.half()
        predictor = predictor.half()

    report_fn('Total number of sequences: {}'.format(len(seq_dict)))

    avg_length = sum([ len(seq) for _, seq in seq_dict.items()]) / len(seq_dict)
    n_long = sum([ 1 for _, seq in seq_dict.items() if len(seq)>max_seq_len])

    # Sort by length, longest first, so similarly-sized sequences share a
    # batch and padding waste stays low. Note: seq_dict becomes a list of
    # (id, seq) pairs from here on.
    seq_dict = sorted( seq_dict.items(), key=lambda kv: len( seq_dict[kv[0]] ), reverse=True )

    report_fn("Average sequence length: {}".format(avg_length))
    report_fn("Number of sequences >{}: {}".format(max_seq_len, n_long))

    start = time.time()
    batch = list()
    for seq_idx, (pdb_id, seq) in enumerate(seq_dict,1):
        # Map rare/ambiguous amino acids to X (unknown).
        seq = seq.replace('U','X').replace('Z','X').replace('O','X')
        seq_len = len(seq)
        # ProstT5 expects space-separated single letters after the prefix.
        seq = prefix + ' ' + ' '.join(list(seq))
        batch.append((pdb_id,seq,seq_len))

        # Flush the accumulated batch once any limit is hit or this is the
        # last sequence. (seq_len is counted once via the batch sum and once
        # explicitly, matching the original accounting.)
        n_res_batch = sum([ s_len for _, _, s_len in batch ]) + seq_len
        if len(batch) >= max_batch or n_res_batch>=max_residues or seq_idx==len(seq_dict) or seq_len>max_seq_len:
            pdb_ids, seqs, seq_lens = zip(*batch)
            batch = list()

            token_encoding = vocab.batch_encode_plus(seqs,
                                                     add_special_tokens=True,
                                                     padding="longest",
                                                     return_tensors='pt'
                                                     ).to(device)
            try:
                with torch.no_grad():
                    embedding_repr = model(token_encoding.input_ids,
                                           attention_mask=token_encoding.attention_mask
                                           )
            except RuntimeError:
                # Typically out-of-memory; skip this batch and continue.
                # NOTE(review): pdb_id/seq_len here are the batch's *last*
                # sequence only, so the message may under-report the failure.
                error_fn("RuntimeError during embedding for {} (L={})".format(
                    pdb_id, seq_len)
                )
                continue

            # Zero the attention mask at the EOS position: tokens are laid
            # out as [prefix, residue_1..residue_s_len, EOS], so EOS sits at
            # index s_len+1.
            for idx, s_len in enumerate(seq_lens):
                token_encoding.attention_mask[idx,s_len+1] = 0

            residue_embedding = embedding_repr.last_hidden_state.detach()
            # Zero out embeddings of padding/special positions.
            residue_embedding = residue_embedding*token_encoding.attention_mask.unsqueeze(dim=-1)
            # Drop the <AA2fold> prefix token.
            residue_embedding = residue_embedding[:,1:]

            prediction = predictor(residue_embedding)
            # Argmax over the class dimension; int8 suffices for 20 classes.
            prediction = toCPU(torch.max( prediction, dim=1, keepdim=True )[1] ).astype(np.byte)

            # Trim padding and store per-residue class indices per sequence.
            for batch_idx, identifier in enumerate(pdb_ids):
                s_len = seq_lens[batch_idx]
                # NOTE(review): for s_len == 1 this squeeze yields a 0-d
                # array, which would then fail the len() check below.
                predictions[identifier] = prediction[batch_idx,:, 0:s_len].squeeze()
                # NOTE(review): the error_fn call is the assert's message
                # expression, so it only runs when the assertion fails —
                # and under `python -O` the whole check is stripped.
                assert s_len == len(predictions[identifier]), error_fn(f"Length mismatch for {identifier}: is:{len(predictions[identifier])} vs should:{s_len}")

    end = time.time()
    report_fn('Total number of predictions: {}'.format(len(predictions)))
    report_fn('Total time: {:.2f}[s]; time/prot: {:.4f}[s]; avg. len= {:.2f}'.format(
        end-start, (end-start)/len(predictions), avg_length))

    return predictions
|
|
|
|
def create_arg_parser():
    """Create and return the ArgumentParser object for this script's CLI."""
    description = (
        'embed.py creates ProstT5-Encoder embeddings for a given text '
        + ' file containing sequence(s) in FASTA-format.'
        + 'Example: python predict_3Di.py --input /path/to/some_AA_sequences.fasta --output /path/to/some_3Di_sequences.fasta --half 1'
    )
    parser = argparse.ArgumentParser(description=description)

    # Required input/output paths.
    parser.add_argument(
        '-i', '--input', required=True, type=str,
        help='A path to a fasta-formatted text file containing protein sequence(s).',
    )
    parser.add_argument(
        '-o', '--output', required=True, type=str,
        help='A path for saving the created embeddings as NumPy npz file.',
    )

    # Model checkpoint source.
    parser.add_argument(
        '--model', required=False, type=str, default="Rostlab/ProstT5",
        help='Either a path to a directory holding the checkpoint for a pre-trained model or a huggingface repository link.',
    )

    # FASTA header parsing options.
    parser.add_argument(
        '--split_char', type=str, default='!',
        help='The character for splitting the FASTA header in order to retrieve '
             "the protein identifier. Should be used in conjunction with --id."
             "Default: '!' ",
    )
    parser.add_argument(
        '--id', type=int, default=0,
        help='The index for the uniprot identifier field after splitting the '
             "FASTA header after each symbole in ['|', '#', ':', ' ']."
             'Default: 0',
    )

    # Precision toggle (int, not bool, to keep the original CLI contract).
    parser.add_argument(
        '--half', type=int, default=1,
        help="Whether to use half_precision or not. Default: 1 (half-precision)",
    )
    return parser
|
|
def main():
    """CLI entry point: read FASTA input, predict 3Di states, write FASTA output."""
    parser = create_arg_parser()
    args = parser.parse_args()

    seq_path = Path( args.input )
    out_path = Path( args.output)
    model_dir = args.model

    if out_path.is_file():
        print("Output file is already existing and will be overwritten ...")

    split_char = args.split_char
    id_field = args.id

    half_precision = int(args.half) != 0
    # Comparing a torch.device to the string "cpu" is unreliable across torch
    # versions, and `assert` is stripped under `python -O`; check the device
    # type explicitly and fail loudly instead.
    if half_precision and device.type == "cpu":
        raise RuntimeError("Running fp16 on CPU is not supported, yet")

    seq_dict = read_fasta( seq_path, split_char, id_field )
    if seq_dict is None:
        # read_fasta already explained the problem (lower-case/3Di input).
        return

    # Fix: the original call omitted the required positional `device`
    # argument (TypeError at runtime) and never forwarded the --half flag.
    predictions = get_3di_sequences(
        seq_dict,
        model_dir,
        device,
        half_precision=half_precision,
    )

    print("Writing results now to disk ...")
    write_predictions(predictions,out_path)
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()