# Site Extraction """ How to Use this python Script: Usage: python Features_extraction_esm2.py --output_folder_location Features_extraction --ion_symbols PO4 --ion_names Phosphate --distance 5.0 [other options]" If you want to change the location of the PDB files, you can do this: Edit the variable "pdbs_folder_address" When to run this script? Before you run this script, you need to download the PDB files from the RCSB database and perform the filtering steps. """ import os import sys import re import pickle import argparse from pathlib import Path import requests import numpy as np import pandas as pd from tqdm.auto import tqdm import torch import esm from transformers import AutoTokenizer, AutoModel, pipeline from Bio import PDB from Bio.PDB import PDBParser, PDBIO from Bio.PDB.Polypeptide import is_aa sys.path.append("../Utils") from phosbind_utils import * def create_parser(): parser = argparse.ArgumentParser( description="Downnload, and filter the PDB files from the RCBS PDB database" ) parser.add_argument( "--ion_symbols", type=str, required=False, help="In case of multi-ion mode, specify the list of ions you want to plot. Default is just K.", ) parser.add_argument( "--input_file", type=str, default=None, required=False, help="Load the arguments from this input file", ) parser.add_argument( "--ion_names", type=str, required=True, help="Specify ion name in a list. Default is just Potassium", ) parser.add_argument( "--output_folder_location", type=str, required=True, help="Specify location where to save the PDBs folder", ) parser.add_argument( "--distance", type=float, required=True, help="Specify the cutoff distance", ) parser.add_argument( "--logfile", type=str, required=False, default="logfile.log", help="Specofy which step to start from", ) return parser def parse_arguments_from_file(file_path): with open(file_path, 'r') as file: lines = file.readlines() arguments = [line.strip().split() for line in lines] return sum(arguments, []) # flatten the list if __name__ == "__main__": parser = create_parser() args = parser.parse_args() if all(value is None for value in vars(args).values()): print("Usage: python Features_extraction_esm2.py --output_folder_location Features_extraction --ion_symbols PO4 --ion_names Phosphate --distance 5.0 [other options]") exit() if args.input_file: file_arguments = parse_arguments_from_file(args.input_file) # Override default values with arguments from the file for arg in vars(args): file_value = [file_arguments[i + 1] for i, val in enumerate(file_arguments) if val == f'--{arg}'] if file_value: setattr(args, arg, file_value[0]) required_arguments = ["output_folder_location", "ion_symbols", "ion_names", "distance"] for argument in required_arguments: if argument not in vars(parser.parse_args()) or not vars(parser.parse_args()).get(argument): raise ValueError(f"--{argument} is a required argument. Please provide it in the text file.") #X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X-X- ion_names_in_a_list = [] ion_symbols_in_a_list = [] for ion_name,ion_symbol in zip(args.ion_names.split(","),args.ion_symbols.split(",")): ion_names_in_a_list.append(ion_name) ion_symbols_in_a_list.append(ion_symbol) # Load ESM-2 model esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() batch_converter = alphabet.get_batch_converter() esm_model.eval() # disables dropout for deterministic results for ion_name,ion_symbol in zip(ion_names_in_a_list,ion_symbols_in_a_list): pdbs_folder_address = f"../Data_new/Ion-Unique-PDBs/{ion_name}-unique-pdbs/" list_of_files= [pdbs_folder_address + item for item in os.listdir(pdbs_folder_address)] distance = float(args.distance) embeddings_output_folder = f"../{args.output_folder_location}/esm2/Distance-{distance}/{ion_name}" check_dir(embeddings_output_folder) # Initialize storage seq_store_global = [] binding_site_embeddings = [] binding_site_index = 0 long_chain_PDBS = [] sites_with_non_standard_residues = [] # Setup progress log progress_tracker_file = f"{embeddings_output_folder}/run_{ion_name}-site-extraction-Progress-tracker_esm2" with open(progress_tracker_file, 'w') as file: print(f"Starting my work", file=file) # Process each PDB for filenumber,filename in enumerate(list_of_files, start=1): with open(progress_tracker_file, 'a') as file: print(f"working on {filenumber}/{len(list_of_files)}- {filename}", file=file) # Parse structure parser = PDBParser(QUIET=True) structure = parser.get_structure("protein", filename) first_model = next(structure.get_models()) # Collect sequences and ion atoms seq_store_local = [] list_of_ion_atoms = [] pdb_id = Path(filename).stem for chain in first_model: chain_id = chain.get_id() seq_label = f"{pdb_id}-chain-{chain_id}" sequence = get_chain_sequence(chain) seq_store_global.append([seq_label, sequence]) seq_store_local.append([seq_label, sequence]) for residue in chain: if residue.get_resname() == ion_symbol: ion_info = ion_classification.get(ion_symbol, {}) ion_type = ion_info.get("type", "single atom") main_atom = ion_info.get("main_atom", None) for atom in residue: if ion_type == "molecule": print(f"The {ion_symbol} is molecule") if atom.get_id() == main_atom: list_of_ion_atoms.append(atom) print("Found main atom in molecule:", main_atom) else: list_of_ion_atoms.append(atom) print("Found single atom ion:", main_atom) # For each ion atom, find binding-site residues and extract embeddings for ref_atom in list_of_ion_atoms: binding_site_index += 1 # Get nearby atoms/residues atoms, closest_residues, indexes_for_averaging = (get_closest_all_atoms_and_residues_and_indices_v3(first_model, ref_atom, distance)) coordinating_number_atom = len(atoms) # Skip if any non-standard residues non_standard = [r for r in closest_residues if not is_aa(r)] if non_standard: print(f"Skipped {pdb_id}-{chain_id} due to non-standard amino acid(s) present at binding site.") sites_with_non_standard_residues.append(f"{pdb_id}-{chain_id}") binding_site_index -= 1 continue # Warn if no residues found if not closest_residues: print(f"No residues within {distance}Å of site-{binding_site_index}") # Iterate over each chain's sequence for embedding for seq_label, sequence in seq_store_local: cid = seq_label.split("-chain-")[1] coords = indexes_for_averaging.get(cid, []) coordinating_number_residues = len(coords) # Skip very long chains if len(sequence) >= 1024: long_chain_PDBS.append(seq_label) continue # Extract per-residue embeddings tensors = get_set_of_embeddings_at_binding_site_esm(esm_model, alphabet, batch_converter, [seq_label, sequence], coords) if not tensors: continue site_residues_tensor_embeddings = torch.stack(tensors, dim=0) averaged_tensor_embedding = site_residues_tensor_embeddings.mean(dim=0) # Record the site binding_site_embeddings.append([f"site-{binding_site_index}", structure.header.get("name", pdb_id), f"{structure.header.get('idcode', pdb_id)}-{cid}", coordinating_number_atom, coordinating_number_residues, averaged_tensor_embedding, site_residues_tensor_embeddings]) # Checkpoint every 300 sites if binding_site_index % 300 == 0: chkpt = (f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-Upto{binding_site_index}-files.restart.pkl") with open(chkpt, "wb") as chk_f: pickle.dump(binding_site_embeddings, chk_f) print(f"Checkpoint saved at: {chkpt}") # Remove previous chunk if exists prev_idx = binding_site_index - 300 if prev_idx > 0: prev_chkpt = (f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-Upto{prev_idx}-files.restart.pkl") if os.path.exists(prev_chkpt): os.remove(prev_chkpt) print(f"Previous checkpoint deleted: {prev_chkpt}") else: print(f"No previous checkpoint found at: {prev_chkpt}") # Final dump to pickle final_pkl = (f"{embeddings_output_folder}/{ion_name}BindingSiteEmbeddings-Distance{distance}angstroms-complete.pkl") with open(final_pkl, "wb") as f: pickle.dump(binding_site_embeddings, f) # Build DataFrame and CSV columns = ["site_id", "protein_name", "pdb_chain_id","coordinating_atoms", "coordinating_residues","avg_embedding", "residue_embeddings"] df = pd.DataFrame(binding_site_embeddings, columns=columns) # Serialize tensor fields to strings df["avg_embedding"] = df["avg_embedding"].apply(lambda x: ",".join(map(str, x.tolist())) if hasattr(x, "tolist") else str(x)) df["residue_embeddings"] = df["residue_embeddings"].apply(str) df["coordinating_atoms"] = df["coordinating_atoms"].apply(str) df["coordinating_residues"] = df["coordinating_residues"].apply(str) csv_path = f"{embeddings_output_folder}/binding_site_embeddings.csv" df.to_csv(csv_path, index=False) with open(progress_tracker_file, 'a') as file: # Print a line to the file print(f"The sequences skipped due to excessive sequence length are: \n", file=file) print(long_chain_PDBS, file=file) print(f"The {ion_name} ion is completed", file=file)