import json
import os
import random
import shutil
import sys
import time
from typing import List, Optional, Tuple

import Bio.PDB
import Bio.SeqUtils
import numpy as np
import pandas as pd
import requests
from Bio import PDB
from Bio import pairwise2
from rdkit import Chem
from rdkit.Chem import AllChem

BASE_FOLDER = "/tmp/"
OUTPUT_FOLDER = f"{BASE_FOLDER}/processed"

# https://storage.googleapis.com/plinder/2024-06/v2/index/annotation_table.parquet
PLINDER_ANNOTATIONS = f'{BASE_FOLDER}/raw_data/2024-06_v2_index_annotation_table.parquet'
# https://storage.googleapis.com/plinder/2024-06/v2/splits/split.parquet
PLINDER_SPLITS = f'{BASE_FOLDER}/raw_data/2024-06_v2_splits_split.parquet'
# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dapo/links.parquet
PLINDER_LINKED_APO_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=apo_links.parquet"
# https://console.cloud.google.com/storage/browser/_details/plinder/2024-06/v2/links/kind%3Dpred/links.parquet
PLINDER_LINKED_PRED_MAP = f"{BASE_FOLDER}/raw_data/2024-06_v2_links_kind=pred_links.parquet"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/apo.zip
PLINDER_LINKED_APO_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_apo"
# https://storage.googleapis.com/plinder/2024-06/v2/linked_structures/pred.zip
PLINDER_LINKED_PRED_STRUCTURES = f"{BASE_FOLDER}/raw_data/2024-06_v2_linked_structures_pred"

GSUTIL_PATH = f"{BASE_FOLDER}/google-cloud-sdk/bin/gsutil"
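
# The raw inputs above are public PLINDER artifacts. A minimal fetch sketch
# (an assumption, mirroring how main() below uses gsutil; the gs:// paths are
# derived from the https URLs in the comments above):
#   os.system(f'{GSUTIL_PATH} cp gs://plinder/2024-06/v2/index/annotation_table.parquet {PLINDER_ANNOTATIONS}')
#   os.system(f'{GSUTIL_PATH} cp gs://plinder/2024-06/v2/splits/split.parquet {PLINDER_SPLITS}')
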
def get_cached_systems_to_train(recompute=False):
    output_path = os.path.join(OUTPUT_FOLDER, "to_train.pickle")
    if os.path.exists(output_path) and not recompute:
        return pd.read_pickle(output_path)

    """
    Expected output on the full dataset:
    loaded 1357906 409726 163816 433865
    loaded 990260 409726 125818 106411
    joined splits 409726
    Has splits 311008
    unique systems 311008
    split
    train    309140
    test       1036
    val         832
    Name: count, dtype: int64
    Has affinity 36856
    Has affinity by splits split
    train    36598
    test       142
    val        116
    Name: count, dtype: int64
    Total systems before pred 311008
    Total systems after pred 311008
    Has pred 83487
    Has apo 75127
    Has both 51506
    Has either 107108
    columns Index(['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
           'entry_release_date', 'system_pocket_UniProt',
           'system_num_protein_chains', 'system_num_ligand_chains', 'uniqueness',
           'split', 'cluster', 'cluster_for_val_split',
           'system_pass_validation_criteria', 'system_pass_statistics_criteria',
           'system_proper_num_ligand_chains', 'system_proper_pocket_num_residues',
           'system_proper_num_interactions',
           'system_proper_ligand_max_molecular_weight',
           'system_has_binding_affinity', 'system_has_apo_or_pred', '_bucket_id',
           'linked_pred_id', 'linked_apo_id'],
          dtype='object')
    total systems 311008
    """

    systems = pd.read_parquet(PLINDER_ANNOTATIONS,
                              columns=['system_id', 'entry_pdb_id', 'ligand_binding_affinity',
                                       'entry_release_date', 'system_pocket_UniProt', 'entry_resolution',
                                       'system_num_protein_chains', 'system_num_ligand_chains'])
    splits = pd.read_parquet(PLINDER_SPLITS)
    linked_pred = pd.read_parquet(PLINDER_LINKED_PRED_MAP)
    linked_apo = pd.read_parquet(PLINDER_LINKED_APO_MAP)
    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # remove duplicates
    systems = systems.drop_duplicates(subset=['system_id'])
    splits = splits.drop_duplicates(subset=['system_id'])
    linked_pred = linked_pred.drop_duplicates(subset=['reference_system_id'])
    linked_apo = linked_apo.drop_duplicates(subset=['reference_system_id'])
    print("loaded", len(systems), len(splits), len(linked_pred), len(linked_apo))

    # join splits
    systems = pd.merge(systems, splits, on='system_id', how='inner')
    print("joined splits", len(systems))
    systems['_bucket_id'] = systems['entry_pdb_id'].str[1:3]

    # keep only systems assigned to a train/val/test split
    systems = systems[systems['split'].isin(['train', 'val', 'test'])]
    print("Has splits", len(systems))
    print("unique systems", systems['system_id'].nunique())
    print(systems["split"].value_counts())
    print("Has affinity", len(systems[systems['ligand_binding_affinity'].notna()]))
    print("Has affinity by splits", systems[systems['ligand_binding_affinity'].notna()]['split'].value_counts())

    print("Total systems before pred", len(systems))
    # join linked structures - systems are allowed to have no linked structure
    systems = pd.merge(systems, linked_pred[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')
    print("Total systems after pred", len(systems))
    # rename the 'id' column from linked_pred to 'linked_pred_id'
    systems.rename(columns={'id': 'linked_pred_id'}, inplace=True)
    # merge the result with linked_apo on the same condition
    systems = pd.merge(systems, linked_apo[['reference_system_id', 'id']],
                       left_on='system_id', right_on='reference_system_id',
                       how='left')
    # rename the 'id' column from linked_apo to 'linked_apo_id'
    systems.rename(columns={'id': 'linked_apo_id'}, inplace=True)
    # drop the reference_system_id columns that were added during the merges
    systems.drop(columns=['reference_system_id_x', 'reference_system_id_y'], inplace=True)

    cluster_sizes = systems["cluster"].value_counts()
    systems["cluster_size"] = systems["cluster"].map(cluster_sizes)
    # print(systems[['system_id', 'cluster', 'cluster_size']])

    print("Has pred", systems['linked_pred_id'].notna().sum())
    print("Has apo", systems['linked_apo_id'].notna().sum())
    print("Has both", (systems['linked_pred_id'].notna() & systems['linked_apo_id'].notna()).sum())
    print("Has either", (systems['linked_pred_id'].notna() | systems['linked_apo_id'].notna()).sum())
    print("columns", systems.columns)

    systems.to_pickle(output_path)
    return systems
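
# Usage sketch (illustrative): load the cached metadata table and inspect it;
# pass recompute=True to rebuild the pickle from the raw parquet files.
#   systems = get_cached_systems_to_train()
#   print(systems[systems["split"] == "train"][["system_id", "cluster", "cluster_size"]].head())
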
def create_conformers(smiles, output_path, num_conformers=100, multiplier_samples=1):
    target_mol = Chem.MolFromSmiles(smiles)
    target_mol = Chem.AddHs(target_mol)

    params = AllChem.ETKDGv3()
    params.numThreads = 0  # use all available threads
    params.pruneRmsThresh = 0.1  # RMSD pruning threshold between kept conformers
    conformer_ids = AllChem.EmbedMultipleConfs(target_mol, numConfs=num_conformers * multiplier_samples,
                                               params=params)

    # Optional: optimize each conformer with the UFF force field
    # for conf_id in conformer_ids:
    #     AllChem.UFFOptimizeMolecule(target_mol, confId=conf_id)

    # remove hydrogen atoms
    target_mol = Chem.RemoveHs(target_mol)

    # save the first num_conformers conformers to an SDF file
    w = Chem.SDWriter(output_path)
    for i, conf_id in enumerate(conformer_ids):
        if i >= num_conformers:
            break
        w.write(target_mol, confId=conf_id)
    w.close()
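
# Usage sketch (the SMILES and output path are illustrative assumptions, not
# part of the pipeline): write 5 aspirin conformers to an SDF file.
#   create_conformers("CC(=O)Oc1ccccc1C(=O)O", "/tmp/aspirin_confs.sdf", num_conformers=5)
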
def do_robust_chain_object_renumber(chain: Bio.PDB.Chain.Chain, new_chain_id: str) -> Optional[Bio.PDB.Chain.Chain]:
    all_residues = [res for res in chain.get_residues()
                    if "CA" in res and Bio.SeqUtils.seq1(res.get_resname()) not in ("X", "", " ")]
    if not all_residues:
        return None

    res_and_res_id = [(res, res.get_id()[1]) for res in all_residues]

    min_res_id = min([i[1] for i in res_and_res_id])
    if min_res_id < 1:
        print("Negative res id", chain, min_res_id)
        factor = -1 * min_res_id + 1
        res_and_res_id = [(res, res_id + factor) for res, res_id in res_and_res_id]

    res_and_res_id_no_collisions = []
    for res, res_id in res_and_res_id[::-1]:
        if res_and_res_id_no_collisions and res_and_res_id_no_collisions[-1][1] == res_id:
            # there is a collision, usually an insertion residue
            res_and_res_id_no_collisions = [(i, j + 1) for i, j in res_and_res_id_no_collisions]
        res_and_res_id_no_collisions.append((res, res_id))

    first_res_id = min([i[1] for i in res_and_res_id_no_collisions])
    factor = 1 - first_res_id  # start from 1
    new_chain = Bio.PDB.Chain.Chain(new_chain_id)

    res_and_res_id_no_collisions.sort(key=lambda x: x[1])
    for res, res_id in res_and_res_id_no_collisions:
        chain.detach_child(res.id)
        res.id = (" ", res_id + factor, " ")
        new_chain.add(res)
    return new_chain
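
# Renumbering example (hypothetical chain): residue ids (-2, -1, 5, 5 with an
# insertion code) come out as (1, 2, 8, 9). Negative ids are shifted positive,
# id collisions from insertion codes are bumped by one, and the final numbering
# is shifted so the chain starts at 1.
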
def robust_renumber_protein(pdb_path: str, output_path: str):
    if pdb_path.endswith(".pdb"):
        pdb_parser = Bio.PDB.PDBParser(QUIET=True)
        pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
    elif pdb_path.endswith(".cif"):
        pdb_struct = Bio.PDB.MMCIFParser().get_structure("original_pdb", pdb_path)
    else:
        raise ValueError("Unknown file type", pdb_path)
    assert len(list(pdb_struct)) == 1, "can't extract if more than one model"
    model = next(iter(pdb_struct))
    chains = list(model.get_chains())
    new_model = Bio.PDB.Model.Model(0)
    chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    for chain, chain_id in zip(chains, chain_ids):
        new_chain = do_robust_chain_object_renumber(chain, chain_id)
        if new_chain is None:
            continue
        new_model.add(new_chain)
    new_struct = Bio.PDB.Structure.Structure("renumbered_pdb")
    new_struct.add(new_model)
    io = Bio.PDB.PDBIO()
    io.set_structure(new_struct)
    io.save(output_path)
def _get_extra(extra_to_save: int, res_before: List[int], res_after: List[int]) -> set:
    take_from_before = random.randint(0, extra_to_save)
    take_from_after = extra_to_save - take_from_before
    if take_from_before > len(res_before):
        take_from_after = extra_to_save - len(res_before)
        take_from_before = len(res_before)
    if take_from_after > len(res_after):
        take_from_before = extra_to_save - len(res_after)
        take_from_after = len(res_after)
    extra_to_add = set()
    if take_from_before > 0:
        extra_to_add.update(res_before[-take_from_before:])
    extra_to_add.update(res_after[:take_from_after])
    return extra_to_add
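
# Worked example (one possible outcome, since the before/after split is random):
#   _get_extra(4, res_before=[1, 2, 3], res_after=[10, 11])
# may draw take_from_before=3, take_from_after=1 and return {1, 2, 3, 10}.
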
def crop_protein_cont(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int,
                      distance_threshold: float):
    protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))
    all_res_ids_by_chain = {chain.id: sorted([res.id[1] for res in chain.get_residues() if "CA" in res])
                            for chain in gt_model.get_chains()}

    protein_conf = protein.GetConformer()
    protein_pos = protein_conf.GetPositions()
    protein_atoms = list(protein.GetAtoms())
    assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
    inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
    min_inter_dist_per_protein_atom = inter_dists.min(axis=0)

    res_to_save_count = max_length - ligand_size
    used_protein_idx = np.where(min_inter_dist_per_protein_atom < distance_threshold)[0]

    pocket_residues_by_chain = {}
    for idx in used_protein_idx:
        res = protein_atoms[idx].GetPDBResidueInfo()
        if res.GetIsHeteroAtom():
            continue
        if res.GetChainId() not in pocket_residues_by_chain:
            pocket_residues_by_chain[res.GetChainId()] = set()
        # record the residue number under its chain
        pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())
    if not pocket_residues_by_chain:
        print("No pocket residues found")
        return -1
    # print("pocket_residues_by_chain", pocket_residues_by_chain)

    complete_pocket = []
    extended_pocket_per_chain = {}
    for chain_id, pocket_residues in pocket_residues_by_chain.items():
        max_pocket_res = max(pocket_residues)
        min_pocket_res = min(pocket_residues)
        extended_pocket_per_chain[chain_id] = {res_id for res_id in all_res_ids_by_chain[chain_id]
                                               if min_pocket_res <= res_id <= max_pocket_res}
        for res_id in extended_pocket_per_chain[chain_id]:
            complete_pocket.append((chain_id, res_id))
    # print("extended_pocket_per_chain", extended_pocket_per_chain)

    if len(complete_pocket) > res_to_save_count:
        total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
        total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
        print(f"Too many residues all: {total_res_ids} pocket: {total_pocket_res} {len(complete_pocket)} "
              f"(ligand size: {ligand_size})")
        return -1

    extra_to_save = res_to_save_count - len(complete_pocket)
    # divide extra_to_save between the chains
    for chain_id, pocket_residues in extended_pocket_per_chain.items():
        extra_to_save_per_chain = extra_to_save // len(extended_pocket_per_chain)
        res_before = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id < min(pocket_residues)]
        res_after = [res_id for res_id in all_res_ids_by_chain[chain_id] if res_id > max(pocket_residues)]
        extra_to_add = _get_extra(extra_to_save_per_chain, res_before, res_after)
        for res_id in extra_to_add:
            complete_pocket.append((chain_id, res_id))

    total_res_ids = sum([len(res_ids) for res_ids in all_res_ids_by_chain.values()])
    total_pocket_res = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
    total_extended_res = sum([len(res_ids) for res_ids in extended_pocket_per_chain.values()])
    print(f"Found valid pocket all: {total_res_ids} pocket: {total_pocket_res} {total_extended_res} "
          f"{len(complete_pocket)} (ligand size: {ligand_size}) extra: {extra_to_save}")
    # print("all_res_ids_by_chain", all_res_ids_by_chain)
    # print("complete_pocket", sorted(complete_pocket))

    # remove everything outside the pocket, plus hetero residues and insertion codes
    res_to_remove = []
    for res in gt_model.get_residues():
        if (res.parent.id, res.id[1]) not in complete_pocket or res.id[0].strip() != "" or res.id[2].strip() != "":
            res_to_remove.append(res)
    for res in res_to_remove:
        gt_model[res.parent.id].detach_child(res.id)

    io = Bio.PDB.PDBIO()
    io.set_structure(gt_model)
    io.save(output_path)
    return len(complete_pocket)
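
# Usage sketch (paths are assumptions; max_length and distance_threshold mirror
# the values prepare_system uses below). Returns the pocket size, or -1 when no
# contiguous crop fits within max_length:
#   n_res = crop_protein_cont("/tmp/gt.pdb", ligand_pos, "/tmp/gt_cropped.pdb",
#                             max_length=350, distance_threshold=5)
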
def crop_protein_simple(gt_pdb_path: str, ligand_pos: np.ndarray, output_path: str, max_length: int):
    protein = Chem.MolFromPDBFile(gt_pdb_path, sanitize=False)
    ligand_size = ligand_pos.shape[0]
    res_to_save_count = max_length - ligand_size

    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    gt_model = next(iter(pdb_parser.get_structure("gt_pdb", gt_pdb_path)))

    protein_conf = protein.GetConformer()
    protein_pos = protein_conf.GetPositions()
    protein_atoms = list(protein.GetAtoms())
    assert len(protein_pos) == len(protein_atoms), f"Positions and atoms mismatch in {gt_pdb_path}"

    inter_dists = ligand_pos[:, np.newaxis, :] - protein_pos[np.newaxis, :, :]
    inter_dists = np.sqrt((inter_dists ** 2).sum(-1))
    min_inter_dist_per_protein_atom = inter_dists.min(axis=0)
    protein_idx_by_dist = np.argsort(min_inter_dist_per_protein_atom)

    # greedily keep the residues closest to the ligand until the budget is full
    pocket_residues_by_chain = {}
    total_found = 0
    for idx in protein_idx_by_dist:
        res = protein_atoms[idx].GetPDBResidueInfo()
        if res.GetIsHeteroAtom():
            continue
        if res.GetChainId() not in pocket_residues_by_chain:
            pocket_residues_by_chain[res.GetChainId()] = set()
        pocket_residues_by_chain[res.GetChainId()].add(res.GetResidueNumber())
        total_found = sum([len(res_ids) for res_ids in pocket_residues_by_chain.values()])
        if total_found >= res_to_save_count:
            break
    print("saved with simple", total_found)
    if not pocket_residues_by_chain:
        print("No pocket residues found")
        return -1

    res_to_remove = []
    for res in gt_model.get_residues():
        if res.id[1] not in pocket_residues_by_chain.get(res.parent.id, set()) \
                or res.id[0].strip() != "" or res.id[2].strip() != "":
            res_to_remove.append(res)
    for res in res_to_remove:
        gt_model[res.parent.id].detach_child(res.id)

    io = Bio.PDB.PDBIO()
    io.set_structure(gt_model)
    io.save(output_path)
    return total_found
def cif_to_pdb(cif_path: str, pdb_path: str):
    protein = Bio.PDB.MMCIFParser().get_structure("s_cif", cif_path)
    io = Bio.PDB.PDBIO()
    io.set_structure(protein)
    io.save(pdb_path)
def get_chain_object_to_seq(chain: Bio.PDB.Chain.Chain) -> str:
    res_id_to_res = {res.get_id()[1]: res for res in chain.get_residues() if "CA" in res}
    if len(res_id_to_res) == 0:
        print("skipping empty chain", chain.get_id())
        return ""
    seq = ""
    for i in range(1, max(res_id_to_res) + 1):
        if i in res_id_to_res:
            seq += Bio.SeqUtils.seq1(res_id_to_res[i].get_resname())
        else:
            seq += "X"
    return seq
def get_sequence_from_pdb(pdb_path: str) -> Tuple[str, List[int]]:
    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
    pdb_struct = pdb_parser.get_structure("original_pdb", pdb_path)
    # chain_to_seq = {chain.id: get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()}
    all_chain_seqs = [get_chain_object_to_seq(chain) for chain in pdb_struct.get_chains()]
    chain_lengths = [len(seq) for seq in all_chain_seqs]
    return ("X" * 20).join(all_chain_seqs), chain_lengths
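
# Example (illustrative): two chains with residues numbered 1..3 whose sequences
# are "MKV" and "GDS" are returned as "MKV" + "X" * 20 + "GDS" with
# chain_lengths == [3, 3]. The 20-X linker matches the 20-residue gap that
# prepare_system assumes when splitting the ESM fold back into chains.
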
def extract_sequence(chain):
    seq = ''
    residues = []
    for res in chain.get_residues():
        seq_res = Bio.SeqUtils.seq1(res.get_resname())
        if seq_res in ('X', "", " "):
            continue
        seq += seq_res
        residues.append(res)
    return seq, residues
def map_residues(alignment, residues_gt, residues_pred):
    idx_gt = 0
    idx_pred = 0
    mapping = []
    for i in range(len(alignment.seqA)):
        aa_gt = alignment.seqA[i]
        aa_pred = alignment.seqB[i]
        res_gt = None
        res_pred = None
        if aa_gt != '-':
            res_gt = residues_gt[idx_gt]
            idx_gt += 1
        if aa_pred != '-':
            res_pred = residues_pred[idx_pred]
            idx_pred += 1
        if res_gt and res_pred:
            mapping.append((res_gt, res_pred))
    return mapping
class ResidueSelect(PDB.Select):
    def __init__(self, residues_to_select):
        self.residues_to_select = set(residues_to_select)

    def accept_residue(self, residue):
        return residue in self.residues_to_select
def align_gt_and_input(gt_pdb_path, input_pdb_path, output_gt_path, output_input_path):
    parser = PDB.PDBParser(QUIET=True)
    gt_structure = parser.get_structure('gt', gt_pdb_path)
    pred_structure = parser.get_structure('pred', input_pdb_path)

    matched_residues_gt = []
    matched_residues_pred = []
    used_chain_pred = []
    total_mapping_size = 0
    for chain_gt in gt_structure.get_chains():
        seq_gt, residues_gt = extract_sequence(chain_gt)
        best_alignment = None
        best_chain_pred = None
        best_score = -1
        best_residues_pred = None
        # find the best matching chain in pred
        for chain_pred in pred_structure.get_chains():
            print("checking", chain_pred.get_id(), chain_gt.get_id())
            if chain_pred in used_chain_pred:
                continue
            seq_pred, residues_pred = extract_sequence(chain_pred)
            print(seq_gt)
            print(seq_pred)
            alignments = pairwise2.align.globalxx(seq_gt, seq_pred, one_alignment_only=True)
            if not alignments:
                continue
            print("checking2", chain_pred.get_id(), chain_gt.get_id())
            alignment = alignments[0]
            score = alignment.score
            if score > best_score:
                best_score = score
                best_alignment = alignment
                best_chain_pred = chain_pred
                best_residues_pred = residues_pred
        if best_alignment:
            mapping = map_residues(best_alignment, residues_gt, best_residues_pred)
            total_mapping_size += len(mapping)
            used_chain_pred.append(best_chain_pred)
            for res_gt, res_pred in mapping:
                matched_residues_gt.append(res_gt)
                matched_residues_pred.append(res_pred)
        else:
            print(f"No matching chain found for chain {chain_gt.get_id()}")
    print(f"Total mapping size: {total_mapping_size}")

    # write new PDB files with only the matched residues
    io = PDB.PDBIO()
    io.set_structure(gt_structure)
    io.save(output_gt_path, ResidueSelect(matched_residues_gt))
    io.set_structure(pred_structure)
    io.save(output_input_path, ResidueSelect(matched_residues_pred))
def validate_matching_input_gt(gt_pdb_path, input_pdb_path):
    gt_residues = [res for res in PDB.PDBParser().get_structure('gt', gt_pdb_path).get_residues()]
    input_residues = [res for res in PDB.PDBParser().get_structure('input', input_pdb_path).get_residues()]
    if len(gt_residues) != len(input_residues):
        print(f"Residue count mismatch: {len(gt_residues)} vs {len(input_residues)}")
        return -1
    for res_gt, res_input in zip(gt_residues, input_residues):
        if res_gt.get_resname() != res_input.get_resname():
            print(f"Residue name mismatch: {res_gt.get_resname()} vs {res_input.get_resname()}")
            return -1
    return len(input_residues)
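
# Usage sketch tying alignment and validation together (paths are assumptions):
#   align_gt_and_input("gt.pdb", "apo.pdb", "gt_matched.pdb", "input_matched.pdb")
#   n_res = validate_matching_input_gt("gt_matched.pdb", "input_matched.pdb")
#   assert n_res > -1, "gt/input residue mismatch"
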
def prepare_system(row, system_folder, output_models_folder, output_jsons_folder, should_overwrite=False):
    output_json_path = os.path.join(output_jsons_folder, f"{row['system_id']}.json")
    if os.path.exists(output_json_path) and not should_overwrite:
        return "Already exists"

    plinder_gt_pdb_path = os.path.join(system_folder, "receptor.pdb")
    plinder_gt_ligand_paths = []
    plinder_gt_ligands_folder = os.path.join(system_folder, "ligand_files")

    gt_output_path = os.path.join(output_models_folder, f"{row['system_id']}_gt.pdb")
    gt_output_relative_path = f"plinder_models/{row['system_id']}_gt.pdb"
    tmp_input_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_input.pdb")
    protein_input_path = os.path.join(output_models_folder, f"{row['system_id']}_input.pdb")
    protein_input_relative_path = f"plinder_models/{row['system_id']}_input.pdb"

    print("Copying ground truth files")
    if not os.path.exists(plinder_gt_pdb_path):
        print("no receptor", plinder_gt_pdb_path)
        return "No receptor"
    tmp_gt_pdb_path = os.path.join(output_models_folder, f"tmp_{row['system_id']}_gt.pdb")
    robust_renumber_protein(plinder_gt_pdb_path, tmp_gt_pdb_path)

    ligand_pos_list = []
    for ligand_file in os.listdir(plinder_gt_ligands_folder):
        if not ligand_file.endswith(".sdf"):
            continue
        plinder_gt_ligand_paths.append(os.path.join(plinder_gt_ligands_folder, ligand_file))
        loaded_ligand = Chem.MolFromMolFile(os.path.join(plinder_gt_ligands_folder, ligand_file))
        # check for a failed load before touching the molecule
        if loaded_ligand is None:
            print("failed to load", plinder_gt_ligand_paths[-1])
            return "Failed to load ligand"
        ligand_pos_list.append(loaded_ligand.GetConformer().GetPositions())

    # Crop the ground truth protein; this also removes insertion codes
    ligand_pos = np.concatenate(ligand_pos_list, axis=0)
    res_count_in_protein = crop_protein_cont(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350,
                                             distance_threshold=5)
    if res_count_in_protein == -1:
        print("Failed to crop protein continuously, using simple crop")
        crop_protein_simple(tmp_gt_pdb_path, ligand_pos, gt_output_path, max_length=350)
    os.remove(tmp_gt_pdb_path)

    # Generate the input protein structure: prefer apo, then pred, then ESM, then GT
    input_protein_source = None
    if pd.notna(row["linked_apo_id"]):
        apo_pdb_path = os.path.join(PLINDER_LINKED_APO_STRUCTURES, f"{row['linked_apo_id']}.cif")
        try:
            robust_renumber_protein(apo_pdb_path, tmp_input_path)
            input_protein_source = "apo"
            print("Using input apo", row['linked_apo_id'])
        except Exception as e:
            print("Problem with apo", e, row["linked_apo_id"], apo_pdb_path)
    if not os.path.exists(tmp_input_path) and pd.notna(row["linked_pred_id"]):
        pred_pdb_path = os.path.join(PLINDER_LINKED_PRED_STRUCTURES, f"{row['linked_pred_id']}.cif")
        try:
            # cif_to_pdb(pred_pdb_path, tmp_input_path)
            robust_renumber_protein(pred_pdb_path, tmp_input_path)
            input_protein_source = "pred"
            print("Using input pred", row['linked_pred_id'])
        except Exception as e:
            print("Problem with pred", e, row["linked_pred_id"], pred_pdb_path)
    if not os.path.exists(tmp_input_path):
        print("No linked structure found, running ESM")
        url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
        sequence, chain_lengths = get_sequence_from_pdb(gt_output_path)
        if len(sequence) <= 400:
            try:
                response = requests.post(url, data=sequence)
                response.raise_for_status()
                pdb_text = response.text
                with open(tmp_input_path, "w") as f:
                    f.write(pdb_text)
                # ESM returns a single chain; split it back into the original chains
                if len(chain_lengths) > 1:
                    pdb_parser = Bio.PDB.PDBParser(QUIET=True)
                    pdb_struct = pdb_parser.get_structure("original_pdb", tmp_input_path)
                    pdb_model = next(iter(pdb_struct))
                    chain_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[:len(chain_lengths)]
                    start_ind = 1
                    esm_chain = next(pdb_model.get_chains())
                    new_model = Bio.PDB.Model.Model(0)
                    for chain_length, chain_id in zip(chain_lengths, chain_ids):
                        end_ind = start_ind + chain_length
                        new_chain = Bio.PDB.Chain.Chain(chain_id)
                        for res in esm_chain.get_residues():
                            if start_ind <= res.id[1] <= end_ind:
                                new_chain.add(res)
                        new_model.add(new_chain)
                        start_ind = end_ind + 20  # 20 is the X-linker gap used in get_sequence_from_pdb
                    io = Bio.PDB.PDBIO()
                    io.set_structure(new_model)
                    io.save(tmp_input_path)
                input_protein_source = "esm"
                print("Using input ESM")
            except requests.exceptions.RequestException as e:
                print(f"An error occurred in ESM: {e}")
                # return "No linked structure found"
        else:
            print("Sequence too long for ESM")
    if not os.path.exists(tmp_input_path):
        print("Using input GT")
        shutil.copyfile(gt_output_path, tmp_input_path)
        input_protein_source = "gt"

    align_gt_and_input(gt_output_path, tmp_input_path, gt_output_path, protein_input_path)
    protein_size = validate_matching_input_gt(gt_output_path, protein_input_path)
    assert protein_size > -1, "Failed to validate matching input and gt"
    os.remove(tmp_input_path)

    rel_gt_lig_paths = []
    rel_ref_lig_paths = []
    input_smiles = []
    for i, ligand_path in enumerate(sorted(plinder_gt_ligand_paths)):
        gt_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_gt_{i}.sdf")
        rel_gt_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_gt_{i}.sdf")
        shutil.copyfile(ligand_path, gt_ligand_output_path)
        loaded_ligand = Chem.MolFromMolFile(gt_ligand_output_path)
        input_smiles.append(Chem.MolToSmiles(loaded_ligand))

        ref_ligand_output_path = os.path.join(output_models_folder, f"{row['system_id']}_ligand_ref_{i}.sdf")
        rel_ref_lig_paths.append(f"plinder_models/{row['system_id']}_ligand_ref_{i}.sdf")
        create_conformers(input_smiles[-1], ref_ligand_output_path, num_conformers=1)
        # if conformer generation wrote an empty file, fall back to the gt ligand
        if os.path.getsize(ref_ligand_output_path) == 0:
            print("Empty ref ligand, copying from gt", ref_ligand_output_path)
            shutil.copyfile(gt_ligand_output_path, ref_ligand_output_path)

    affinity = row["ligand_binding_affinity"]
    if pd.isna(affinity):
        affinity = None

    json_data = {
        "input_structure": protein_input_relative_path,
        "gt_structure": gt_output_relative_path,
        "gt_sdf_list": rel_gt_lig_paths,
        "input_smiles_list": input_smiles,
        "resolution": row.fillna(99)["entry_resolution"],
        "release_year": row["entry_release_date"],
        "affinity": affinity,
        "protein_seq_len": protein_size,
        "uniprot": row["system_pocket_UniProt"],
        "ligand_num_atoms": ligand_pos.shape[0],
        "cluster": row["cluster"],
        "cluster_size": row["cluster_size"],
        "input_protein_source": input_protein_source,
        "ref_sdf_list": rel_ref_lig_paths,
        "pdb_id": row["system_id"],
    }
    with open(output_json_path, "w") as f:
        f.write(json.dumps(json_data, indent=4))
    return "success"
# use linked structures (legacy, commented-out approach)
# input_structure_to_use = None
# apo_linked_structure = os.path.join(linked_structures_folder, "apo", system_id)
# pred_linked_structure = os.path.join(linked_structures_folder, "pred", system_id)
# if os.path.exists(apo_linked_structure):
#     for folder in os.listdir(apo_linked_structure):
#         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
#             continue
#         for filename in os.listdir(os.path.join(apo_linked_structure, folder)):
#             if filename.endswith(".cif"):
#                 input_structure_to_use = os.path.join(apo_linked_structure, folder, filename)
#                 break
#         if input_structure_to_use:
#             break
#     print(system_id, "found apo", input_structure_to_use)
# elif os.path.exists(pred_linked_structure):
#     for folder in os.listdir(pred_linked_structure):
#         if not os.path.isdir(os.path.join(pred_linked_structure, folder)):
#             continue
#         for filename in os.listdir(os.path.join(pred_linked_structure, folder)):
#             if filename.endswith(".cif"):
#                 input_structure_to_use = os.path.join(pred_linked_structure, folder, filename)
#                 break
#         if input_structure_to_use:
#             break
#     print(system_id, "found pred", input_structure_to_use)
# else:
#     print(system_id, "no linked structure found")
#     return "No linked structure found"
def main(prefix_bucket_id: str = "*"):
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    systems = get_cached_systems_to_train()
    print("total systems", len(systems))
    print("clusters", systems["cluster"].value_counts())
    # systems = systems[systems["system_num_protein_chains"] > 1]
    # return
    print("splits", systems["split"].value_counts())
    val_or_test = systems[(systems["split"] == "val") | (systems["split"] == "test")]
    print("validation or test", len(val_or_test))

    output_models_folder = os.path.join(OUTPUT_FOLDER, "plinder_models")
    output_train_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_train")
    output_val_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_val")
    output_test_jsons_folder = os.path.join(OUTPUT_FOLDER, "plinder_jsons_test")
    output_info = os.path.join(OUTPUT_FOLDER, "plinder_generation_info.csv")
    if prefix_bucket_id != "*":
        output_info = os.path.join(OUTPUT_FOLDER, f"plinder_generation_info_{prefix_bucket_id}.csv")

    os.makedirs(output_models_folder, exist_ok=True)
    os.makedirs(output_train_jsons_folder, exist_ok=True)
    os.makedirs(output_val_jsons_folder, exist_ok=True)
    os.makedirs(output_test_jsons_folder, exist_ok=True)
    split_to_folder = {
        "train": output_train_jsons_folder,
        "val": output_val_jsons_folder,
        "test": output_test_jsons_folder,
    }

    output_info_file = open(output_info, "a+")
    for bucket_id, bucket_systems in systems.groupby('_bucket_id', sort=True):
        if prefix_bucket_id != "*" and not str(bucket_id).startswith(prefix_bucket_id):
            continue
        # if bucket_id != "z2":
        #     continue
        # systems_folder = "{BASE_FOLDER}/processed/tmp_z2/systems"
        print("Starting bucket", bucket_id, len(bucket_systems))
        print(len(bucket_systems), bucket_systems["system_num_ligand_chains"].value_counts())

        # download and extract this bucket's systems
        tmp_output_models_folder = os.path.join(OUTPUT_FOLDER, f"tmp_{bucket_id}")
        os.makedirs(tmp_output_models_folder, exist_ok=True)
        os.system(f'{GSUTIL_PATH} -m cp -r "gs://plinder/2024-06/v2/systems/{bucket_id}.zip" {tmp_output_models_folder}')
        systems_folder = os.path.join(tmp_output_models_folder, "systems")
        os.system(f'unzip -o {os.path.join(tmp_output_models_folder, f"{bucket_id}.zip")} -d {systems_folder}')

        for i, row in bucket_systems.iterrows():
            # if not str(row['system_id']).startswith("4z22__1__1.A__1.C"):
            #     continue
            print("doing", row['system_id'], row["system_num_protein_chains"], row["system_num_ligand_chains"])
            system_folder = os.path.join(systems_folder, row['system_id'])
            try:
                success = prepare_system(row, system_folder, output_models_folder, split_to_folder[row["split"]])
                print("done", row['system_id'], success)
                output_info_file.write(f"{bucket_id},{row['system_id']},{success}\n")
            except Exception as e:
                print("Failed", row['system_id'], e)
                output_info_file.write(f"{bucket_id},{row['system_id']},Failed\n")
            output_info_file.flush()
        shutil.rmtree(tmp_output_models_folder)
    output_info_file.close()
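
# CLI sketch (the script filename is an assumption): process all buckets, or
# only buckets whose two-character id starts with a given prefix:
#   python prepare_plinder.py        # all buckets
#   python prepare_plinder.py a      # only buckets a0..az
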
if __name__ == '__main__':
    prefix_bucket_id = "*"
    if len(sys.argv) > 1:
        prefix_bucket_id = sys.argv[1]
    main(prefix_bucket_id)