from __future__ import print_function import numpy as np import torch import torch.utils from prody import * confProDy(verbosity="none") restype_1to3 = { "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL", "X": "UNK", } restype_str_to_int = { "A": 0, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6, "I": 7, "K": 8, "L": 9, "M": 10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15, "T": 16, "V": 17, "W": 18, "Y": 19, "X": 20, } restype_int_to_str = { 0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H", 7: "I", 8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q", 14: "R", 15: "S", 16: "T", 17: "V", 18: "W", 19: "Y", 20: "X", } alphabet = list(restype_str_to_int) element_list = [ "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo", ] element_list = [item.upper() for item in element_list] # element_dict = dict(zip(element_list, range(1,len(element_list)))) element_dict_rev = dict(zip(range(1, len(element_list)), element_list)) def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor): """ S : true sequence shape=[batch, length] S_pred : predicted sequence shape=[batch, length] mask : mask to compute average over the region shape=[batch, length] average : averaged sequence recovery shape=[batch] """ match = S == S_pred average = torch.sum(match * mask, dim=-1) / torch.sum(mask, dim=-1) return average def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor): """ S : true sequence shape=[batch, length] log_probs : predicted sequence shape=[batch, length] mask : mask to compute average over the region shape=[batch, length] average_loss : averaged categorical cross entropy (CCE) [batch] loss_per_resdue : per position CCE [batch, length] """ S_one_hot = torch.nn.functional.one_hot(S, 21) loss_per_residue = -(S_one_hot * log_probs).sum(-1) # [B, L] average_loss = torch.sum(loss_per_residue * mask, dim=-1) / ( torch.sum(mask, dim=-1) + 1e-8 ) return average_loss, loss_per_residue def write_full_PDB( save_path: str, X: np.ndarray, X_m: np.ndarray, b_factors: np.ndarray, R_idx: np.ndarray, chain_letters: np.ndarray, S: np.ndarray, other_atoms=None, icodes=None, force_hetatm=False, ): """ save_path : path where the PDB will be written to X : protein atom xyz coordinates shape=[length, 14, 3] X_m : protein atom mask shape=[length, 14] b_factors: shape=[length, 14] R_idx: protein residue indices shape=[length] chain_letters: protein chain letters shape=[length] S : protein amino acid sequence shape=[length] other_atoms: other atoms parsed by prody icodes: a list of insertion codes for the PDB; e.g. antibody loops """ restype_1to3 = { "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS", "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE", "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO", "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL", "X": "UNK", } restype_INTtoSTR = { 0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H", 7: "I", 8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q", 14: "R", 15: "S", 16: "T", 17: "V", 18: "W", 19: "Y", 20: "X", } restype_name_to_atom14_names = { "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], "ARG": [ "N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", "", ], "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""], "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""], "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], "GLN": [ "N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", "", ], "GLU": [ "N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", "", ], "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], "HIS": [ "N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", "", ], "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""], "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""], "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""], "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""], "PHE": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", "", ], "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""], "TRP": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2", ], "TYR": [ "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", "", ], "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""], "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], } S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]] X_list = [] b_factor_list = [] atom_name_list = [] element_name_list = [] residue_name_list = [] residue_number_list = [] chain_id_list = [] icodes_list = [] for i, AA in enumerate(S_str): sel = X_m[i].astype(np.int32) == 1 total = np.sum(sel) tmp = np.array(restype_name_to_atom14_names[AA])[sel] X_list.append(X[i][sel]) b_factor_list.append(b_factors[i][sel]) atom_name_list.append(tmp) element_name_list += [AA[:1] for AA in list(tmp)] residue_name_list += total * [AA] residue_number_list += total * [R_idx[i]] chain_id_list += total * [chain_letters[i]] icodes_list += total * [icodes[i]] X_stack = np.concatenate(X_list, 0) b_factor_stack = np.concatenate(b_factor_list, 0) atom_name_stack = np.concatenate(atom_name_list, 0) protein = prody.AtomGroup() protein.setCoords(X_stack) protein.setBetas(b_factor_stack) protein.setNames(atom_name_stack) protein.setResnames(residue_name_list) protein.setElements(element_name_list) protein.setOccupancies(np.ones([X_stack.shape[0]])) protein.setResnums(residue_number_list) protein.setChids(chain_id_list) protein.setIcodes(icodes_list) if other_atoms: other_atoms_g = prody.AtomGroup() other_atoms_g.setCoords(other_atoms.getCoords()) other_atoms_g.setNames(other_atoms.getNames()) other_atoms_g.setResnames(other_atoms.getResnames()) other_atoms_g.setElements(other_atoms.getElements()) other_atoms_g.setOccupancies(other_atoms.getOccupancies()) other_atoms_g.setResnums(other_atoms.getResnums()) other_atoms_g.setChids(other_atoms.getChids()) if force_hetatm: other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm")) writePDB(save_path, protein + other_atoms_g) else: writePDB(save_path, protein) def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str): """ protein_atoms: prody atom group CA_dict: mapping between chain_residue_idx_icodes and integers atom_name: atom to be parsed; e.g. CA """ atom_atoms = protein_atoms.select(f"name {atom_name}") if atom_atoms != None: atom_coords = atom_atoms.getCoords() atom_resnums = atom_atoms.getResnums() atom_chain_ids = atom_atoms.getChids() atom_icodes = atom_atoms.getIcodes() atom_coords_ = np.zeros([len(CA_dict), 3], np.float32) atom_coords_m = np.zeros([len(CA_dict)], np.int32) if atom_atoms != None: for i in range(len(atom_resnums)): code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i] if code in list(CA_dict): atom_coords_[CA_dict[code], :] = atom_coords[i] atom_coords_m[CA_dict[code]] = 1 return atom_coords_, atom_coords_m def parse_PDB( input_path: str, device: str = "cpu", chains: list = [], parse_all_atoms: bool = False, parse_atoms_with_zero_occupancy: bool = False ): """ input_path : path for the input PDB device: device for the torch.Tensor chains: a list specifying which chains need to be parsed; e.g. ["A", "B"] parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed """ element_list = [ "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo", ] element_list = [item.upper() for item in element_list] element_dict = dict(zip(element_list, range(1, len(element_list)))) restype_3to1 = { "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C", "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I", "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P", "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V", } restype_STRtoINT = { "A": 0, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6, "I": 7, "K": 8, "L": 9, "M": 10, "N": 11, "P": 12, "Q": 13, "R": 14, "S": 15, "T": 16, "V": 17, "W": 18, "Y": 19, "X": 20, } atom_order = { "N": 0, "CA": 1, "C": 2, "CB": 3, "O": 4, "CG": 5, "CG1": 6, "CG2": 7, "OG": 8, "OG1": 9, "SG": 10, "CD": 11, "CD1": 12, "CD2": 13, "ND1": 14, "ND2": 15, "OD1": 16, "OD2": 17, "SD": 18, "CE": 19, "CE1": 20, "CE2": 21, "CE3": 22, "NE": 23, "NE1": 24, "NE2": 25, "OE1": 26, "OE2": 27, "CH2": 28, "NH1": 29, "NH2": 30, "OH": 31, "CZ": 32, "CZ2": 33, "CZ3": 34, "NZ": 35, "OXT": 36, } if not parse_all_atoms: atom_types = ["N", "CA", "C", "O"] else: atom_types = [ "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2", "CZ3", "NZ", ] atoms = parsePDB(input_path) if not parse_atoms_with_zero_occupancy: atoms = atoms.select("occupancy > 0") if chains: str_out = "" for item in chains: str_out += " chain " + item + " or" atoms = atoms.select(str_out[1:-3]) protein_atoms = atoms.select("protein") backbone = protein_atoms.select("backbone") other_atoms = atoms.select("not protein and not water") water_atoms = atoms.select("water") CA_atoms = protein_atoms.select("name CA") CA_resnums = CA_atoms.getResnums() CA_chain_ids = CA_atoms.getChids() CA_icodes = CA_atoms.getIcodes() CA_dict = {} for i in range(len(CA_resnums)): code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i] CA_dict[code] = i xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32) xyz_37_m = np.zeros([len(CA_dict), 37], np.int32) for atom_name in atom_types: xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name) xyz_37[:, atom_order[atom_name], :] = xyz xyz_37_m[:, atom_order[atom_name]] = xyz_m N = xyz_37[:, atom_order["N"], :] CA = xyz_37[:, atom_order["CA"], :] C = xyz_37[:, atom_order["C"], :] O = xyz_37[:, atom_order["O"], :] N_m = xyz_37_m[:, atom_order["N"]] CA_m = xyz_37_m[:, atom_order["CA"]] C_m = xyz_37_m[:, atom_order["C"]] O_m = xyz_37_m[:, atom_order["O"]] mask = N_m * CA_m * C_m * O_m # must all 4 atoms exist b = CA - N c = C - CA a = np.cross(b, c, axis=-1) CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32) R_idx = np.array(CA_resnums, dtype=np.int32) S = CA_atoms.getResnames() S = [restype_3to1[AA] if AA in list(restype_3to1) else "X" for AA in list(S)] S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32) X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1) try: Y = np.array(other_atoms.getCoords(), dtype=np.float32) Y_t = list(other_atoms.getElements()) Y_t = np.array( [ element_dict[y_t.upper()] if y_t.upper() in element_list else 0 for y_t in Y_t ], dtype=np.int32, ) Y_m = (Y_t != 1) * (Y_t != 0) Y = Y[Y_m, :] Y_t = Y_t[Y_m] Y_m = Y_m[Y_m] except: Y = np.zeros([1, 3], np.float32) Y_t = np.zeros([1], np.int32) Y_m = np.zeros([1], np.int32) output_dict = {} output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32) output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32) output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32) output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32) output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32) output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32) output_dict["chain_labels"] = torch.tensor( chain_labels, device=device, dtype=torch.int32 ) output_dict["chain_letters"] = CA_chain_ids mask_c = [] chain_list = list(set(output_dict["chain_letters"])) chain_list.sort() for chain in chain_list: mask_c.append( torch.tensor( [chain == item for item in output_dict["chain_letters"]], device=device, dtype=bool, ) ) output_dict["mask_c"] = mask_c output_dict["chain_list"] = chain_list output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32) output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32) output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32) return output_dict, backbone, other_atoms, CA_icodes, water_atoms def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms): device = CB.device mask_CBY = mask[:, None] * Y_m[None, :] # [A,B] L2_AB = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1) L2_AB = L2_AB * mask_CBY + (1 - mask_CBY) * 1000.0 nn_idx = torch.argsort(L2_AB, -1)[:, :number_of_ligand_atoms] L2_AB_nn = torch.gather(L2_AB, 1, nn_idx) D_AB_closest = torch.sqrt(L2_AB_nn[:, 0]) Y_r = Y[None, :, :].repeat(CB.shape[0], 1, 1) Y_t_r = Y_t[None, :].repeat(CB.shape[0], 1) Y_m_r = Y_m[None, :].repeat(CB.shape[0], 1) Y_tmp = torch.gather(Y_r, 1, nn_idx[:, :, None].repeat(1, 1, 3)) Y_t_tmp = torch.gather(Y_t_r, 1, nn_idx) Y_m_tmp = torch.gather(Y_m_r, 1, nn_idx) Y = torch.zeros( [CB.shape[0], number_of_ligand_atoms, 3], dtype=torch.float32, device=device ) Y_t = torch.zeros( [CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device ) Y_m = torch.zeros( [CB.shape[0], number_of_ligand_atoms], dtype=torch.int32, device=device ) num_nn_update = Y_tmp.shape[1] Y[:, :num_nn_update] = Y_tmp Y_t[:, :num_nn_update] = Y_t_tmp Y_m[:, :num_nn_update] = Y_m_tmp return Y, Y_t, Y_m, D_AB_closest def featurize( input_dict, cutoff_for_score=8.0, use_atom_context=True, number_of_ligand_atoms=16, model_type="protein_mpnn", ): output_dict = {} if model_type == "ligand_mpnn": mask = input_dict["mask"] Y = input_dict["Y"] Y_t = input_dict["Y_t"] Y_m = input_dict["Y_m"] N = input_dict["X"][:, 0, :] CA = input_dict["X"][:, 1, :] C = input_dict["X"][:, 2, :] b = CA - N c = C - CA a = torch.cross(b, c, axis=-1) CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA Y, Y_t, Y_m, D_XY = get_nearest_neighbours( CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms ) mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0] output_dict["mask_XY"] = mask_XY[None,] if "side_chain_mask" in list(input_dict): output_dict["side_chain_mask"] = input_dict["side_chain_mask"][None,] output_dict["Y"] = Y[None,] output_dict["Y_t"] = Y_t[None,] output_dict["Y_m"] = Y_m[None,] if not use_atom_context: output_dict["Y_m"] = 0.0 * output_dict["Y_m"] elif ( model_type == "per_residue_label_membrane_mpnn" or model_type == "global_label_membrane_mpnn" ): output_dict["membrane_per_residue_labels"] = input_dict[ "membrane_per_residue_labels" ][None,] R_idx_list = [] count = 0 R_idx_prev = -100000 for R_idx in list(input_dict["R_idx"]): if R_idx_prev == R_idx: count += 1 R_idx_list.append(R_idx + count) R_idx_prev = R_idx R_idx_renumbered = torch.tensor(R_idx_list, device=R_idx.device) output_dict["R_idx"] = R_idx_renumbered[None,] output_dict["R_idx_original"] = input_dict["R_idx"][None,] output_dict["chain_labels"] = input_dict["chain_labels"][None,] output_dict["S"] = input_dict["S"][None,] output_dict["chain_mask"] = input_dict["chain_mask"][None,] output_dict["mask"] = input_dict["mask"][None,] output_dict["X"] = input_dict["X"][None,] if "xyz_37" in list(input_dict): output_dict["xyz_37"] = input_dict["xyz_37"][None,] output_dict["xyz_37_m"] = input_dict["xyz_37_m"][None,] return output_dict