Spaces:
Running on Zero
Running on Zero
| from __future__ import print_function | |
| import numpy as np | |
| import torch | |
| import torch.utils | |
| from prody import * | |
| confProDy(verbosity="none") | |
# One-letter -> three-letter residue names; "X" stands for an unknown residue.
restype_1to3 = {
    "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
    "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
    "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
    "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    "X": "UNK",
}

# Integer token of every one-letter code: the 20 standard residues in
# alphabetical order (0..19) followed by "X" (20).
restype_str_to_int = {
    letter: token for token, letter in enumerate("ACDEFGHIKLMNPQRSTVWYX")
}
# Inverse mapping: integer token -> one-letter code.
restype_int_to_str = {
    token: letter for letter, token in restype_str_to_int.items()
}
# One-letter codes ordered by token.
alphabet = list(restype_str_to_int)

# Element symbols ordered by position in the periodic table, upper-cased.
# NOTE(review): "Mb" is not an element symbol (molybdenum is "Mo"); it is kept
# unchanged because the numbering may be tied to trained weights - confirm.
element_list = (
    "H He Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mb Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
    "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
    "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
    "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
    "Rf Db Sg Bh Hs Mt Ds Rg Cn Uut Fl Uup Lv Uus Uuo"
).upper().split()

# 1-based element number -> symbol. The last symbol of element_list is left
# out on purpose: the original range stops at len(element_list) - 1.
element_dict_rev = dict(enumerate(element_list[:-1], start=1))
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
    """
    Fraction of masked positions where the prediction equals the true token.

    S : true sequence shape=[batch, length]
    S_pred : predicted sequence shape=[batch, length]
    mask : mask to compute average over the region shape=[batch, length]
    average : averaged sequence recovery shape=[batch]

    No epsilon is added to the denominator, so a row whose mask sums to zero
    produces a 0/0 division.
    """
    is_correct = torch.eq(S, S_pred)
    correct_in_region = (is_correct * mask).sum(dim=-1)
    region_size = mask.sum(dim=-1)
    return correct_in_region / region_size
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
    """
    Negative log-likelihood of the true sequence under predicted log-probs.

    S : true sequence (integer tokens in 0..20) shape=[batch, length]
    log_probs : predicted log-probabilities shape=[batch, length, 21]
    mask : mask to compute average over the region shape=[batch, length]
    average_loss : averaged categorical cross entropy (CCE) [batch]
    loss_per_residue : per position CCE [batch, length]
    """
    target = torch.nn.functional.one_hot(S, 21)
    # Keep only the log-probability of the true token at every position.
    loss_per_residue = -torch.sum(target * log_probs, dim=-1)  # [B, L]
    masked_total = (loss_per_residue * mask).sum(dim=-1)
    # The small epsilon keeps the division finite for an all-zero mask.
    region_size = mask.sum(dim=-1) + 1e-8
    average_loss = masked_total / region_size
    return average_loss, loss_per_residue
def write_full_PDB(
    save_path: str,
    X: np.ndarray,
    X_m: np.ndarray,
    b_factors: np.ndarray,
    R_idx: np.ndarray,
    chain_letters: np.ndarray,
    S: np.ndarray,
    other_atoms=None,
    icodes=None,
    force_hetatm=False,
):
    """
    Write protein atoms (14 slots per residue) and optional extra atoms to a PDB.

    save_path : path where the PDB will be written to
    X : protein atom xyz coordinates shape=[length, 14, 3]
    X_m : protein atom mask shape=[length, 14]
    b_factors: shape=[length, 14]
    R_idx: protein residue indices shape=[length]
    chain_letters: protein chain letters shape=[length]
    S : protein amino acid sequence shape=[length]
    other_atoms: other atoms parsed by prody; appended after the protein
    icodes: a list of insertion codes for the PDB; e.g. antibody loops.
        When None every residue gets an empty insertion code.
    force_hetatm: if True, copy the "hetatm" flag of other_atoms to the output
    """
    restype_1to3 = {
        "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
        "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
        "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
        "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
        "X": "UNK",
    }
    restype_INTtoSTR = dict(enumerate("ACDEFGHIKLMNPQRSTVWYX"))

    # Atom names per residue type, in slot order; padded to 14 slots below.
    # The order is significant: slot k of X / X_m / b_factors gets name k.
    heavy_atom_names = {
        "ALA": ["N", "CA", "C", "O", "CB"],
        "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
        "CYS": ["N", "CA", "C", "O", "CB", "SG"],
        "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
        "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
        "GLY": ["N", "CA", "C", "O"],
        "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
        "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
        "SER": ["N", "CA", "C", "O", "CB", "OG"],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
        "TRP": [
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2",
            "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2",
        ],
        "TYR": [
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2",
            "CE1", "CE2", "CZ", "OH",
        ],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
        "UNK": [],
    }
    # Built once as arrays so the per-residue boolean mask can index them.
    restype_name_to_atom14_names = {
        resname: np.array(names + [""] * (14 - len(names)))
        for resname, names in heavy_atom_names.items()
    }

    S_str = [restype_1to3[restype_INTtoSTR[aa]] for aa in S]

    # Fix: the signature defaults icodes to None, but the per-residue loop
    # indexes it; fall back to empty insertion codes instead of a TypeError.
    if icodes is None:
        icodes = [""] * len(S_str)

    X_list = []
    b_factor_list = []
    atom_name_list = []
    element_name_list = []
    residue_name_list = []
    residue_number_list = []
    chain_id_list = []
    icodes_list = []
    for i, resname in enumerate(S_str):
        present = X_m[i].astype(np.int32) == 1
        n_atoms = int(np.sum(present))
        names = restype_name_to_atom14_names[resname][present]
        X_list.append(X[i][present])
        b_factor_list.append(b_factors[i][present])
        atom_name_list.append(names)
        # The element is taken to be the first character of the atom name.
        element_name_list.extend(name[:1] for name in names)
        residue_name_list.extend([resname] * n_atoms)
        residue_number_list.extend([R_idx[i]] * n_atoms)
        chain_id_list.extend([chain_letters[i]] * n_atoms)
        icodes_list.extend([icodes[i]] * n_atoms)

    X_stack = np.concatenate(X_list, 0)
    b_factor_stack = np.concatenate(b_factor_list, 0)
    atom_name_stack = np.concatenate(atom_name_list, 0)

    protein = prody.AtomGroup()
    protein.setCoords(X_stack)
    protein.setBetas(b_factor_stack)
    protein.setNames(atom_name_stack)
    protein.setResnames(residue_name_list)
    protein.setElements(element_name_list)
    protein.setOccupancies(np.ones([X_stack.shape[0]]))
    protein.setResnums(residue_number_list)
    protein.setChids(chain_id_list)
    protein.setIcodes(icodes_list)

    if other_atoms:
        # Copy the extra atoms into a fresh group so they can be concatenated
        # with the rebuilt protein.
        other_atoms_g = prody.AtomGroup()
        other_atoms_g.setCoords(other_atoms.getCoords())
        other_atoms_g.setNames(other_atoms.getNames())
        other_atoms_g.setResnames(other_atoms.getResnames())
        other_atoms_g.setElements(other_atoms.getElements())
        other_atoms_g.setOccupancies(other_atoms.getOccupancies())
        other_atoms_g.setResnums(other_atoms.getResnums())
        other_atoms_g.setChids(other_atoms.getChids())
        if force_hetatm:
            other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
        writePDB(save_path, protein + other_atoms_g)
    else:
        writePDB(save_path, protein)
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
    """
    Collect the coordinates of one atom type, aligned to the residue order.

    protein_atoms: prody atom group
    CA_dict: mapping between chain_residue_idx_icodes and integers
    atom_name: atom to be parsed; e.g. CA

    Returns (coords, mask): coords is float32 with shape=[len(CA_dict), 3] and
    mask is int32 with shape=[len(CA_dict)]; mask is 1 where an atom with that
    name was found for the residue, and the coordinates stay zero elsewhere.
    """
    num_residues = len(CA_dict)
    atom_coords_ = np.zeros([num_residues, 3], np.float32)
    atom_coords_m = np.zeros([num_residues], np.int32)

    atom_atoms = protein_atoms.select(f"name {atom_name}")
    if atom_atoms is None:
        # Nothing matched the selection: every residue stays masked out.
        return atom_coords_, atom_coords_m

    per_atom = zip(
        atom_atoms.getCoords(),
        atom_atoms.getResnums(),
        atom_atoms.getChids(),
        atom_atoms.getIcodes(),
    )
    for xyz, resnum, chain_id, icode in per_atom:
        # Direct dict lookup instead of scanning a list of the keys per atom.
        row = CA_dict.get(f"{chain_id}_{resnum}_{icode}")
        if row is not None:
            atom_coords_[row, :] = xyz
            atom_coords_m[row] = 1
    return atom_coords_, atom_coords_m
def parse_PDB(
    input_path: str,
    device: str = "cpu",
    chains: list = [],
    parse_all_atoms: bool = False,
    parse_atoms_with_zero_occupancy: bool = False
):
    """
    Parse a PDB file into tensors describing the protein and its ligand atoms.

    input_path : path for the input PDB
    device: device for the torch.Tensor
    chains: a list specifying which chains need to be parsed; e.g. ["A", "B"]
        (the default list is only read, never modified)
    parse_all_atoms: if False parse only N,CA,C,O otherwise the 36 named
        heavy atoms (OXT has a slot in the 37-atom layout but is not parsed)
    parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed

    Returns (output_dict, backbone, other_atoms, CA_icodes, water_atoms).
    output_dict holds:
        X [L, 4, 3] N/CA/C/O coordinates, mask [L] (1 when all four exist),
        Y / Y_t / Y_m non-protein, non-water atom coordinates, element
        numbers and mask, R_idx, chain_labels, chain_letters, mask_c,
        chain_list, S, xyz_37 [L, 37, 3] and xyz_37_m [L, 37].
    """
    # NOTE(review): "Mb" is not an element symbol (molybdenum is "Mo"); kept
    # unchanged because the numbering may be tied to trained weights - confirm.
    element_list = (
        "H He Li Be B C N O F Ne "
        "Na Mg Al Si P S Cl Ar "
        "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
        "Rb Sr Y Zr Nb Mb Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
        "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
        "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
        "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
        "Rf Db Sg Bh Hs Mt Ds Rg Cn Uut Fl Uup Lv Uus Uuo"
    ).upper().split()
    # Symbol -> 1-based element number. zip() stops at the shorter input, so
    # the last symbol gets no number and is treated as unknown (0) below.
    element_dict = dict(zip(element_list, range(1, len(element_list))))

    restype_3to1 = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    }
    restype_STRtoINT = {
        letter: token for token, letter in enumerate("ACDEFGHIKLMNPQRSTVWYX")
    }

    # Slot order of the 37-atom representation; the position is the column.
    atom37_names = [
        "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG",
        "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1",
        "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2",
        "OH", "CZ", "CZ2", "CZ3", "NZ", "OXT",
    ]
    atom_order = {name: slot for slot, name in enumerate(atom37_names)}
    if parse_all_atoms:
        atom_types = atom37_names[:-1]  # everything except OXT
    else:
        atom_types = ["N", "CA", "C", "O"]

    atoms = parsePDB(input_path)
    if not parse_atoms_with_zero_occupancy:
        atoms = atoms.select("occupancy > 0")
    if chains:
        atoms = atoms.select(" or ".join(f"chain {chain_id}" for chain_id in chains))

    protein_atoms = atoms.select("protein")
    backbone = protein_atoms.select("backbone")
    other_atoms = atoms.select("not protein and not water")
    water_atoms = atoms.select("water")

    # One row per CA atom; residues are keyed by "<chain>_<resnum>_<icode>".
    CA_atoms = protein_atoms.select("name CA")
    CA_resnums = CA_atoms.getResnums()
    CA_chain_ids = CA_atoms.getChids()
    CA_icodes = CA_atoms.getIcodes()
    CA_dict = {
        f"{chain_id}_{resnum}_{icode}": row
        for row, (chain_id, resnum, icode) in enumerate(
            zip(CA_chain_ids, CA_resnums, CA_icodes)
        )
    }

    xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
    xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
    for atom_name in atom_types:
        slot = atom_order[atom_name]
        xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
        xyz_37[:, slot, :] = xyz
        xyz_37_m[:, slot] = xyz_m

    backbone_slots = [atom_order[name] for name in ("N", "CA", "C", "O")]
    X = xyz_37[:, backbone_slots, :]
    mask = np.prod(xyz_37_m[:, backbone_slots], axis=1)  # must all 4 atoms exist

    chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
    R_idx = np.array(CA_resnums, dtype=np.int32)
    # Residue names outside the 20 standard ones become "X".
    S = np.array(
        [
            restype_STRtoINT[restype_3to1.get(resname, "X")]
            for resname in CA_atoms.getResnames()
        ],
        np.int32,
    )

    if other_atoms is not None:
        Y = np.array(other_atoms.getCoords(), dtype=np.float32)
        # Unknown symbols map to 0.
        Y_t = np.array(
            [element_dict.get(symbol.upper(), 0) for symbol in other_atoms.getElements()],
            dtype=np.int32,
        )
        # Drop hydrogens (element number 1) and unknown elements (0).
        keep = (Y_t != 1) * (Y_t != 0)
        Y = Y[keep, :]
        Y_t = Y_t[keep]
        Y_m = keep[keep]
    else:
        # No non-protein, non-water atoms: a single masked-out placeholder.
        Y = np.zeros([1, 3], np.float32)
        Y_t = np.zeros([1], np.int32)
        Y_m = np.zeros([1], np.int32)

    output_dict = {}
    output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
    output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
    output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
    output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
    output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)
    output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
    output_dict["chain_labels"] = torch.tensor(
        chain_labels, device=device, dtype=torch.int32
    )
    output_dict["chain_letters"] = CA_chain_ids

    # One boolean residue mask per chain, in sorted chain-letter order.
    chain_list = sorted(set(CA_chain_ids))
    output_dict["mask_c"] = [
        torch.tensor(
            [chain == letter for letter in CA_chain_ids],
            device=device,
            dtype=bool,
        )
        for chain in chain_list
    ]
    output_dict["chain_list"] = chain_list
    output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)
    output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
    output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)

    return output_dict, backbone, other_atoms, CA_icodes, water_atoms
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
    """
    For every residue pick the closest ligand atoms, padded to a fixed count.

    CB : per-residue reference coordinates shape=[A, 3]
    mask : residue mask shape=[A]
    Y : ligand atom coordinates shape=[B, 3]
    Y_t : ligand atom types shape=[B]
    Y_m : ligand atom mask shape=[B]
    number_of_ligand_atoms : number of neighbour slots per residue

    Returns (Y, Y_t, Y_m, D_AB_closest):
        Y [A, number_of_ligand_atoms, 3] float32, Y_t and Y_m
        [A, number_of_ligand_atoms] int32 (slots beyond B stay zero), and
        D_AB_closest [A], the distance to the nearest ligand atom. Masked
        residue/atom pairs are given a squared distance of 1000.0.
    """
    device = CB.device
    num_res = CB.shape[0]
    Y_nn = torch.zeros(
        [num_res, number_of_ligand_atoms, 3], dtype=torch.float32, device=device
    )
    Y_t_nn = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    Y_m_nn = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=device
    )

    if Y.shape[0] == 0:
        # No ligand atoms at all: report the same distance a fully masked
        # pair would get instead of failing on an empty neighbour list.
        far = torch.full([num_res], 1000.0, dtype=CB.dtype, device=device)
        return Y_nn, Y_t_nn, Y_m_nn, torch.sqrt(far)

    mask_CBY = mask[:, None] * Y_m[None, :]  # [A,B]
    L2_AB = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1)
    # Push masked pairs far away so they sort after every real neighbour.
    L2_AB = L2_AB * mask_CBY + (1 - mask_CBY) * 1000.0
    nn_idx = torch.argsort(L2_AB, -1)[:, :number_of_ligand_atoms]
    L2_AB_nn = torch.gather(L2_AB, 1, nn_idx)
    D_AB_closest = torch.sqrt(L2_AB_nn[:, 0])

    # Index the per-atom tensors directly with the [A, k] neighbour indices;
    # this avoids materialising [A, B, ...] copies of the ligand tensors.
    num_nn_update = nn_idx.shape[1]
    Y_nn[:, :num_nn_update] = Y[nn_idx]
    Y_t_nn[:, :num_nn_update] = Y_t[nn_idx]
    Y_m_nn[:, :num_nn_update] = Y_m[nn_idx]
    return Y_nn, Y_t_nn, Y_m_nn, D_AB_closest
def featurize(
    input_dict,
    cutoff_for_score=8.0,
    use_atom_context=True,
    number_of_ligand_atoms=16,
    model_type="protein_mpnn",
):
    """
    Turn a parsed structure into model inputs with a leading batch axis of 1.

    input_dict : must hold "R_idx", "chain_labels", "S", "chain_mask", "mask"
        and "X"; "xyz_37"/"xyz_37_m" and "side_chain_mask" are optional.
        "ligand_mpnn" additionally reads "Y", "Y_t", "Y_m"; the membrane
        models read "membrane_per_residue_labels".
    cutoff_for_score : residues whose nearest ligand atom is closer than this
        are flagged in "mask_XY" (ligand_mpnn only)
    use_atom_context : if False the ligand mask "Y_m" is zeroed
    number_of_ligand_atoms : neighbour slots per residue (ligand_mpnn only)
    model_type : "protein_mpnn", "ligand_mpnn",
        "per_residue_label_membrane_mpnn" or "global_label_membrane_mpnn"

    Returns a dict where every tensor has shape [1, ...]. "R_idx" is
    renumbered so that consecutive residues sharing a residue number get
    distinct, increasing indices; "R_idx_original" keeps the input numbering.
    """
    output_dict = {}
    if model_type == "ligand_mpnn":
        mask = input_dict["mask"]
        Y = input_dict["Y"]
        Y_t = input_dict["Y_t"]
        Y_m = input_dict["Y_m"]
        N = input_dict["X"][:, 0, :]
        CA = input_dict["X"][:, 1, :]
        C = input_dict["X"][:, 2, :]
        # CB position computed from the N, CA, C backbone coordinates.
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
            CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms
        )
        # Residues with a real ligand atom closer than the cutoff.
        mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0]
        output_dict["mask_XY"] = mask_XY[None]
        if "side_chain_mask" in input_dict:
            output_dict["side_chain_mask"] = input_dict["side_chain_mask"][None]
        output_dict["Y"] = Y[None]
        output_dict["Y_t"] = Y_t[None]
        output_dict["Y_m"] = Y_m[None]
        if not use_atom_context:
            output_dict["Y_m"] = 0.0 * output_dict["Y_m"]
    elif model_type in (
        "per_residue_label_membrane_mpnn",
        "global_label_membrane_mpnn",
    ):
        output_dict["membrane_per_residue_labels"] = input_dict[
            "membrane_per_residue_labels"
        ][None]

    # Every residue whose number equals that of the residue just before it
    # shifts itself and all later residues by one. Vectorized as a running
    # count of such repeats; this also works for an empty sequence.
    R_idx = input_dict["R_idx"]
    is_repeat = torch.zeros_like(R_idx)
    is_repeat[1:] = R_idx[1:] == R_idx[:-1]
    # cumsum promotes integer inputs, so cast back to keep the input dtype.
    shift = torch.cumsum(is_repeat, dim=0).to(R_idx.dtype)
    R_idx_renumbered = R_idx + shift

    output_dict["R_idx"] = R_idx_renumbered[None]
    output_dict["R_idx_original"] = input_dict["R_idx"][None]
    output_dict["chain_labels"] = input_dict["chain_labels"][None]
    output_dict["S"] = input_dict["S"][None]
    output_dict["chain_mask"] = input_dict["chain_mask"][None]
    output_dict["mask"] = input_dict["mask"][None]
    output_dict["X"] = input_dict["X"][None]
    if "xyz_37" in input_dict:
        output_dict["xyz_37"] = input_dict["xyz_37"][None]
        output_dict["xyz_37_m"] = input_dict["xyz_37_m"][None]
    return output_dict