"""
How to use this Python script:

Usage:
    python Features_extraction_esm2.py --output_folder_location Features_extraction --ion_symbols PO4 --ion_names Phosphate --distance 5.0 [other options]

To change the location of the PDB files, edit the variable "pdbs_folder_address".

When to run this script:
    Run it only after the PDB files have been downloaded from the RCSB database and the filtering steps have been completed.
"""
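# Illustrative example (not taken from the original docs): --ion_symbols and
# --ion_names are parsed as comma-separated lists, so several ions can be
# processed in a single run, e.g.:
#
#   python Features_extraction_esm2.py \
#       --output_folder_location Features_extraction \
#       --ion_symbols PO4,K \
#       --ion_names Phosphate,Potassium \
#       --distance 5.0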

import os
import sys
import re
import pickle
import argparse
from pathlib import Path

import requests
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import esm

from transformers import AutoTokenizer, AutoModel, pipeline

from Bio import PDB
from Bio.PDB import PDBParser, PDBIO
from Bio.PDB.Polypeptide import is_aa

sys.path.append("../Utils")
from phosbind_utils import *


def create_parser():
    parser = argparse.ArgumentParser(
        description="Extract ESM-2 embeddings for ion-binding sites from the filtered PDB files"
    )
    # Note: the required arguments are validated manually in __main__ so that
    # they can also be supplied through --input_file instead of the command line.
    parser.add_argument(
        "--ion_symbols",
        type=str,
        required=False,
        help="Comma-separated list of ion symbols as they appear in the PDB files (e.g. PO4).",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default=None,
        required=False,
        help="Load the arguments from this input file",
    )
    parser.add_argument(
        "--ion_names",
        type=str,
        required=False,
        help="Comma-separated list of ion names (e.g. Phosphate).",
    )
    parser.add_argument(
        "--output_folder_location",
        type=str,
        required=False,
        help="Folder where the extracted features will be saved",
    )
    parser.add_argument(
        "--distance",
        type=float,
        required=False,
        help="Cutoff distance (in Angstroms) used to define a binding site",
    )
    parser.add_argument(
        "--logfile",
        type=str,
        required=False,
        default="logfile.log",
        help="Path of the log file",
    )
    return parser


def parse_arguments_from_file(file_path):
    """Read command-line style arguments ("--flag value" pairs) from a text file."""
    with open(file_path, "r") as file:
        lines = file.readlines()

    # Split every line into whitespace-separated tokens and flatten them into a
    # single list, e.g. ["--distance", "5.0", "--ion_symbols", "PO4", ...].
    arguments = [line.strip().split() for line in lines]
    return sum(arguments, [])

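# Example contents of a text file passed through --input_file (one
# "--flag value" pair per line; the values here just mirror the usage example
# in the docstring):
#
#   --output_folder_location Features_extraction
#   --ion_symbols PO4
#   --ion_names Phosphate
#   --distance 5.0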


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    if len(sys.argv) == 1:
        print("Usage: python Features_extraction_esm2.py --output_folder_location Features_extraction --ion_symbols PO4 --ion_names Phosphate --distance 5.0 [other options]")
        sys.exit()
    if args.input_file:
        file_arguments = parse_arguments_from_file(args.input_file)

        # Any "--flag value" pair found in the file overrides the corresponding argument.
        for arg in vars(args):
            file_value = [file_arguments[i + 1] for i, val in enumerate(file_arguments) if val == f"--{arg}"]
            if file_value:
                setattr(args, arg, file_value[0])

    # Enforce the required arguments here so they can come from either the
    # command line or the input file.
    required_arguments = ["output_folder_location", "ion_symbols", "ion_names", "distance"]
    for argument in required_arguments:
        if not vars(args).get(argument):
            raise ValueError(f"--{argument} is a required argument. Please provide it on the command line or in the input file.")

    # Several ions can be passed at once as comma-separated lists,
    # e.g. --ion_names Phosphate,Potassium --ion_symbols PO4,K
    ion_names_in_a_list = []
    ion_symbols_in_a_list = []
    for ion_name, ion_symbol in zip(args.ion_names.split(","), args.ion_symbols.split(",")):
        ion_names_in_a_list.append(ion_name)
        ion_symbols_in_a_list.append(ion_symbol)

    # Load the pretrained ESM-2 (650M, 33 layers) model and its batch converter
    # once and reuse them for every ion and every structure.
    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    esm_model.eval()
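    # The per-residue embedding extraction below is delegated to
    # get_set_of_embeddings_at_binding_site_esm() from phosbind_utils. For
    # reference, this is a minimal sketch (illustrative only, with a made-up
    # sequence) of how per-residue ESM-2 embeddings are typically obtained
    # with the fair-esm batch converter:
    #
    #   data = [("example", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
    #   _, _, tokens = batch_converter(data)
    #   with torch.no_grad():
    #       out = esm_model(tokens, repr_layers=[33])
    #   reps = out["representations"][33]               # shape (1, L + 2, 1280)
    #   per_residue = reps[0, 1:len(data[0][1]) + 1]    # drop the BOS/EOS positions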

    for ion_name, ion_symbol in zip(ion_names_in_a_list, ion_symbols_in_a_list):
        pdbs_folder_address = f"../Data_new/Ion-Unique-PDBs/{ion_name}-unique-pdbs/"
        list_of_files = [pdbs_folder_address + item for item in os.listdir(pdbs_folder_address)]
        distance = float(args.distance)
        embeddings_output_folder = f"../{args.output_folder_location}/esm2/Distance-{distance}/{ion_name}"
        check_dir(embeddings_output_folder)

        # Per-ion bookkeeping
        seq_store_global = []
        binding_site_embeddings = []
        binding_site_index = 0
        long_chain_PDBS = []
        sites_with_non_standard_residues = []

        progress_tracker_file = f"{embeddings_output_folder}/run_{ion_name}-site-extraction-Progress-tracker_esm2"
        with open(progress_tracker_file, "w") as file:
            print("Starting the extraction run", file=file)

        for filenumber, filename in enumerate(list_of_files, start=1):
            with open(progress_tracker_file, "a") as file:
                print(f"working on {filenumber}/{len(list_of_files)} - {filename}", file=file)

            pdb_parser = PDBParser(QUIET=True)
            structure = pdb_parser.get_structure("protein", filename)
            first_model = next(structure.get_models())

            seq_store_local = []
            list_of_ion_atoms = []
            pdb_id = Path(filename).stem

            for chain in first_model:
                chain_id = chain.get_id()
                seq_label = f"{pdb_id}-chain-{chain_id}"
                sequence = get_chain_sequence(chain)
                seq_store_global.append([seq_label, sequence])
                seq_store_local.append([seq_label, sequence])

                # Collect the reference atom(s) of every target ion in this chain.
                for residue in chain:
                    if residue.get_resname() == ion_symbol:
                        ion_info = ion_classification.get(ion_symbol, {})
                        ion_type = ion_info.get("type", "single atom")
                        main_atom = ion_info.get("main_atom", None)

                        for atom in residue:
                            if ion_type == "molecule":
                                print(f"The {ion_symbol} is a molecule")
                                # For molecular ions keep only the designated main atom.
                                if atom.get_id() == main_atom:
                                    list_of_ion_atoms.append(atom)
                                    print("Found main atom in molecule:", main_atom)
                            else:
                                list_of_ion_atoms.append(atom)
                                print("Found single-atom ion:", atom.get_id())

            # Each collected ion atom defines one candidate binding site.
            for ref_atom in list_of_ion_atoms:
                binding_site_index += 1

                atoms, closest_residues, indexes_for_averaging = get_closest_all_atoms_and_residues_and_indices_v3(
                    first_model, ref_atom, distance
                )
                coordinating_number_atom = len(atoms)

                # Skip the site entirely if any coordinating residue is not a standard amino acid.
                non_standard = [r for r in closest_residues if not is_aa(r)]
                if non_standard:
                    print(f"Skipped {pdb_id}-{chain_id} due to non-standard amino acid(s) present at binding site.")
                    sites_with_non_standard_residues.append(f"{pdb_id}-{chain_id}")
                    binding_site_index -= 1
                    continue

                if not closest_residues:
                    print(f"No residues within {distance}Å of site-{binding_site_index}")

                for seq_label, sequence in seq_store_local:
                    cid = seq_label.split("-chain-")[1]
                    coords = indexes_for_averaging.get(cid, [])
                    coordinating_number_residues = len(coords)

                    # Chains at or above 1024 residues are logged and skipped.
                    if len(sequence) >= 1024:
                        long_chain_PDBS.append(seq_label)
                        continue

                    tensors = get_set_of_embeddings_at_binding_site_esm(
                        esm_model, alphabet, batch_converter, [seq_label, sequence], coords
                    )
                    if not tensors:
                        continue

                    # Stack the per-residue embeddings of the site and average them into one vector.
                    site_residues_tensor_embeddings = torch.stack(tensors, dim=0)
                    averaged_tensor_embedding = site_residues_tensor_embeddings.mean(dim=0)

                    binding_site_embeddings.append([
                        f"site-{binding_site_index}",
                        structure.header.get("name", pdb_id),
                        f"{structure.header.get('idcode', pdb_id)}-{cid}",
                        coordinating_number_atom,
                        coordinating_number_residues,
                        averaged_tensor_embedding,
                        site_residues_tensor_embeddings,
                    ])

                # Every 300 binding sites, write a restart checkpoint and drop the previous one.
                if binding_site_index % 300 == 0:
                    chkpt = f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-Upto{binding_site_index}-files.restart.pkl"
                    with open(chkpt, "wb") as chk_f:
                        pickle.dump(binding_site_embeddings, chk_f)
                    print(f"Checkpoint saved at: {chkpt}")

                    prev_idx = binding_site_index - 300
                    if prev_idx > 0:
                        prev_chkpt = f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-Upto{prev_idx}-files.restart.pkl"
                        if os.path.exists(prev_chkpt):
                            os.remove(prev_chkpt)
                            print(f"Previous checkpoint deleted: {prev_chkpt}")
                        else:
                            print(f"No previous checkpoint found at: {prev_chkpt}")

        final_pkl = f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-complete.pkl"
        with open(final_pkl, "wb") as f:
            pickle.dump(binding_site_embeddings, f)

        columns = ["site_id", "protein_name", "pdb_chain_id", "coordinating_atoms", "coordinating_residues", "avg_embedding", "residue_embeddings"]
        df = pd.DataFrame(binding_site_embeddings, columns=columns)

        # Tensors cannot be written to CSV directly, so convert them to plain strings.
        df["avg_embedding"] = df["avg_embedding"].apply(lambda x: ",".join(map(str, x.tolist())) if hasattr(x, "tolist") else str(x))
        df["residue_embeddings"] = df["residue_embeddings"].apply(str)
        df["coordinating_atoms"] = df["coordinating_atoms"].apply(str)
        df["coordinating_residues"] = df["coordinating_residues"].apply(str)

        csv_path = f"{embeddings_output_folder}/binding_site_embeddings.csv"
        df.to_csv(csv_path, index=False)
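        # Illustrative only: the complete pickle saved above keeps the raw tensors
        # (the CSV stores them as strings), and each entry follows the column order
        # defined in `columns`. Reloading it later could look like this:
        #
        #   with open(final_pkl, "rb") as handle:
        #       sites = pickle.load(handle)
        #   site_id, protein_name, pdb_chain_id, n_atoms, n_res, avg_emb, res_embs = sites[0]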

        with open(progress_tracker_file, "a") as file:
            print("The sequences skipped due to excessive sequence length are:\n", file=file)
            print(long_chain_PDBS, file=file)
            print("The binding sites skipped due to non-standard residues are:\n", file=file)
            print(sites_with_non_standard_residues, file=file)
            print(f"The {ion_name} ion is completed", file=file)