Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # | |
| # Portions of this file were adapted from the open source code | |
| # https://github.com/RosettaCommons/trRosetta2/tree/main | |
| # Modifications were made to loader() and get_coords6d() functions: | |
| # 1. addition of `set_diagonal` flag. | |
| # 2. addition of `allow_missing_residue_coords` flag. | |
| # Original License information below. | |
| # MIT License | |
| # Copyright (c) 2021 Ivan Anishchenko, Minkyung Baek, Naozumi Hiranuma, Ian Humphrey | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| import numpy as np | |
| import scipy | |
| import scipy.spatial | |
| import logging | |
| logger = logging.getLogger(__name__) | |
# Default hyper-parameters for the PDB loader (values follow trRosetta-v2).
PDB_LOADER_PARAMS_DEFAULT = {
    # distance binning: lower / upper limit and number of bins
    "DMIN": 2.0,
    "DMAX": 20.0,
    "DBINS": 36,
    # number of bins used for the angular features
    "ABINS": 36,
    "WMIN": 0.8,
    # length limits applied when a long chain gets sliced
    "LMIN": 150,
    "LMAX": 400,
    "EPOCHS": 10,
    "NCPU": 8,
    # slicing mode for chains longer than LMAX ("CONT" = one continuous window)
    "SLICE": "CONT",
    # both inclusive
    "contact_bin_cutoff": (0, 11),
}
# calculate dihedral angles defined by 4 sets of points
def get_dihedrals(a, b, c, d):
    """Return the dihedral angles (radians) defined by the points a-b-c-d.

    Args:
        a, b, c, d: array-likes whose last axis holds the point coordinates.
            Any number of matching leading axes is accepted (a single point,
            a list of points, a batch of lists, ...). Integer input is fine;
            the inputs are never modified.

    Returns:
        Angles in [-pi, pi] with the input shape minus its last axis.
    """
    a, b, c, d = (np.asarray(p) for p in (a, b, c, d))

    b0 = -1.0*(b - a)
    b1 = c - b
    b2 = d - c

    # Unit vector along the central bond. A non in-place division is used so
    # that integer arrays are promoted instead of raising a casting error.
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)

    # Components of the outer bonds perpendicular to the central bond.
    v = b0 - np.sum(b0*b1, axis=-1, keepdims=True)*b1
    w = b2 - np.sum(b2*b1, axis=-1, keepdims=True)*b1

    x = np.sum(v*w, axis=-1)
    y = np.sum(np.cross(b1, v)*w, axis=-1)

    return np.arctan2(y, x)
# calculate planar angles defined by 3 sets of points
def get_angles(a, b, c):
    """Return the planar angles (radians) a-b-c, with b as the vertex.

    Args:
        a, b, c: array-likes whose last axis holds the point coordinates.
            Any number of matching leading axes is accepted. Integer input is
            fine; the inputs are never modified.

    Returns:
        Angles in [0, pi] with the input shape minus its last axis.
    """
    a, b, c = (np.asarray(p) for p in (a, b, c))

    v = a - b
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)

    w = c - b
    w = w / np.linalg.norm(w, axis=-1, keepdims=True)

    cos_angle = np.sum(v*w, axis=-1)

    # Rounding can push the cosine of (anti-)parallel vectors marginally
    # outside [-1, 1], which would make arccos return NaN. Genuine NaNs in
    # the input still propagate through clip.
    return np.arccos(np.clip(cos_angle, -1.0, 1.0))
def nonans(arr):
    """Return True when `arr` holds no NaN entries, False otherwise.

    NaN is the only value that compares unequal to itself, so the
    element-wise self-comparison marks exactly the NaN positions.
    """
    nan_mask = arr != arr
    if nan_mask.any():
        return False
    return True
# get 6d coordinates from x,y,z coords of N,Ca,C atoms
def get_coords6d(xyz, dmax, allow_missing_residue_coords=False):
    """Compute inter-residue distance and orientation maps from backbone atoms.

    Args:
        xyz: coordinate array indexed as xyz[atom], where atoms 0, 1, 2 are
            N, CA and C; xyz.shape[1] is the number of residues.
        dmax: Cb-Cb distance cutoff; only residue pairs within it are filled in.
        allow_missing_residue_coords: when False, NaN coordinates are rejected;
            when True, residues with NaN coordinates are tolerated and their
            rows / columns are set to NaN in every returned map.

    Returns:
        (dist6d, omega6d, theta6d, phi6d), each of shape (nres, nres):
        Cb-Cb distances, Ca-Cb-Cb-Ca dihedrals, N-Ca-Cb-Cb dihedrals and
        Ca-Cb-Cb planar angles. Entries for pairs farther apart than `dmax`
        and the diagonal keep their fill value (999.9 for the distance,
        0 for the angles).

    Raises:
        AssertionError: if `xyz` contains NaNs and
            `allow_missing_residue_coords` is False.
    """
    nres = xyz.shape[1]

    # three anchor atoms
    N = xyz[0]
    Ca = xyz[1]
    C = xyz[2]

    if not nonans(xyz):
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`; the exception type is unchanged for callers.
        if not allow_missing_residue_coords:
            raise AssertionError("Missing residue coordinates")
        logger.warning("PDB xyz contains NaNs!! Loading anyway...")

    # recreate Cb given N,Ca,C
    b = Ca - N
    c = C - Ca
    a = np.cross(b, c)
    Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca

    # Residues whose virtual Cb could not be built (some backbone atom is NaN).
    missing = np.isnan(Cb).any(axis=-1)
    present = np.flatnonzero(~missing)

    # fast neighbors search to collect all Cb-Cb pairs within dmax.
    # The KD-tree is built from residues with complete coordinates only,
    # because it does not accept non-finite data.
    if present.size:
        kdCb = scipy.spatial.cKDTree(Cb[present])
        neighbours = kdCb.query_ball_tree(kdCb, dmax)
    else:
        neighbours = []

    # indices of contacting residues (self pairs excluded); the reshape keeps
    # a (2, 0) layout when no pair is in contact so the unpacking never fails
    pairs = [(i, j) for i, close in enumerate(neighbours) for j in close if i != j]
    local = np.array(pairs, dtype=int).reshape(-1, 2).T
    # map positions within `present` back to residue indices
    idx0 = present[local[0]]
    idx1 = present[local[1]]

    # Cb-Cb distance matrix
    dist6d = np.full((nres, nres), 999.9)
    dist6d[idx0, idx1] = np.linalg.norm(Cb[idx1] - Cb[idx0], axis=-1)

    # matrix of Ca-Cb-Cb-Ca dihedrals
    omega6d = np.zeros((nres, nres))
    omega6d[idx0, idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])

    # matrix of polar coord theta
    theta6d = np.zeros((nres, nres))
    theta6d[idx0, idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])

    # matrix of polar coord phi
    phi6d = np.zeros((nres, nres))
    phi6d[idx0, idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])

    if missing.any():
        # Set values to nan where there's no pdb 3d coordinates data:
        # every pair that involves at least one missing residue.
        nan_residues_mask = missing[:, None] | missing[None, :]
        dist6d[nan_residues_mask] = np.nan
        omega6d[nan_residues_mask] = np.nan
        theta6d[nan_residues_mask] = np.nan
        phi6d[nan_residues_mask] = np.nan

    return dist6d, omega6d, theta6d, phi6d
def parse_PDB(pdb_path, atoms=['N','CA','C'], chain=None, return_aligned_seq=False):
    '''
    Parse ATOM records (plus MSE HETATM records, treated as MET) of a PDB file.

    input:  pdb_path = PDB filename
            atoms = atoms to extract (optional); the default list is only read,
                    never modified
            chain = chain identifier to keep; None keeps every chain
            return_aligned_seq = also return the sequence in which residues
                    flagged as "doubles" (see below) are replaced by the gap token
    output: [coords, sequence] or [coords, sequence, aligned_sequence]
            coords has shape (length, atoms, coords=(x,y,z)) with NaN for
            atoms / residues that are absent from the file; each sequence is
            a one-element list holding a one-letter string.
            When no matching ATOM record exists, coords is empty and the
            sequences are empty strings.
    '''
    xyz, seq, doubles = {}, {}, {}
    min_resn, max_resn = np.inf, -np.inf

    with open(pdb_path, "rb") as f:
        for raw_line in f:
            line = raw_line.decode("utf-8", "ignore").rstrip()

            if line[:6] == "HETATM" and line[17:17+3] == "MSE":
                # The record name field is 6 characters wide; keep that width
                # so every fixed-column slice below stays aligned.
                line = "ATOM  " + line[6:]
                line = line.replace("MSE", "MET")

            if line[:4] != "ATOM":
                continue
            if chain is not None and line[21:22] != chain:
                continue

            atom = line[12:12+4].strip()
            resi = line[17:17+3]
            # residue name together with the preceding alternate-location column
            resi_extended = line[16:17+3].strip()
            resn = line[22:22+5].strip()
            x, y, z = [float(line[i:(i+8)]) for i in [30, 38, 46]]

            # split a trailing insertion code off the residue number and
            # shift the number to be zero-based
            if resn[-1].isalpha():
                resa, resn = resn[-1], int(resn[:-1]) - 1
            else:
                resa, resn = "", int(resn) - 1

            if resn < min_resn:
                min_resn = resn
            if resn > max_resn:
                max_resn = resn

            residue_atoms = xyz.setdefault(resn, {}).setdefault(resa, {})
            residue_names = seq.setdefault(resn, {})

            if resa not in residue_names:
                residue_names[resa] = resi
            elif residue_names[resa] != resi_extended:
                # doubles mark locations in the pdb file where multi residue entries are
                # present. There's a known bug in TmAlign binary that doesn't read / skip
                # these entries, so we mark them to create a sequence that is aligned with
                # gap tokens in such locations.
                doubles[resn] = True

            # keep the first occurrence of each atom
            if atom not in residue_atoms:
                residue_atoms[atom] = np.array([x, y, z])

    if seq:
        residue_numbers = range(min_resn, max_resn + 1)
    else:
        # min_resn / max_resn are still +/-inf here and cannot feed range()
        logger.warning("No ATOM records found in %s (chain=%s)", pdb_path, chain)
        residue_numbers = ()

    # convert to numpy arrays, fill in missing values
    seq_, xyz_, aligned_seq_ = [], [], []
    for resn in residue_numbers:
        if resn in seq:
            for k in sorted(seq[resn]):
                seq_.append(aa_3_N.get(seq[resn][k], 20))
                if resn in doubles:
                    aligned_seq_.append(20)
                else:
                    aligned_seq_.append(seq_[-1])
        else:
            seq_.append(20)

        if resn in xyz:
            for k in sorted(xyz[resn]):
                for atom in atoms:
                    if atom in xyz[resn][k]:
                        xyz_.append(xyz[resn][k][atom])
                    else:
                        xyz_.append(np.full(3, np.nan))
        else:
            for atom in atoms:
                xyz_.append(np.full(3, np.nan))

    res = [np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_))]
    if return_aligned_seq:
        res += [N_to_AA(np.array(aligned_seq_))]
    return res
def N_to_AA(x):
    """Decode integer-encoded sequences into one-letter strings.

    A 1-D input is treated as a single sequence; codes without an entry in
    the lookup table are rendered as "-".

    Example: [[0,1,2,3]] -> ["ARND"]
    """
    encoded = np.array(x)
    if encoded.ndim == 1:
        encoded = encoded[None]

    decoded = []
    for row in encoded:
        letters = [aa_N_1.get(code, "-") for code in row]
        decoded.append("".join(letters))
    return decoded
# Amino-acid alphabets in matching order; the last entry (index 20) is the gap token.
alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
alpha_3 = [
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP',
]

# three-letter code -> integer index
aa_3_N = dict(zip(alpha_3, range(len(alpha_3))))
# integer index -> one-letter code
aa_N_1 = dict(enumerate(alpha_1))
def loader(pdb_path, params, chain=None, set_diagonal=False, allow_missing_residue_coords=False):
    """
    Load a PDB file and turn it into binned 6D inter-residue features.

    Args:
        pdb_path: path to pdb file
        params: dict with hyperparams (reads DMIN, DMAX, DBINS, ABINS, LMIN,
            LMAX, SLICE and, when present, PHI_BINS)
        chain: <unused> (the file is always parsed with chain=None)
        set_diagonal: sets diagonal to specific default values.
        allow_missing_residue_coords: forwarded to get_coords6d().

    Returns:
        dict with keys "idx", "idx_raw", "coords6d", "sel", "dist", "fullseq",
        "no_angles" and "xyz". "coords6d" holds only the distance bins when
        any angle map contains NaN ("no_angles" is True in that case).
    """
    orig_xyz, seq = parse_PDB(pdb_path, atoms=['N','CA','C'], chain=None)
    # (residue, atom, xyz) -> (atom, residue, xyz)
    xyz = np.transpose(orig_xyz, (1, 0, 2))
    idx = np.arange(xyz.shape[1])

    # get 6D coords
    d, o, t, p = get_coords6d(xyz, params['DMAX'], allow_missing_residue_coords)
    no_angles = any(np.isnan(angle_map).any() for angle_map in (o, t, p))

    # bin edges for the distance map
    n_dbins = params['DBINS']
    if params['DMIN'] == 2.5:
        dbins = np.linspace(params['DMIN'], params['DMAX'], n_dbins)
    else:
        dstep = (params['DMAX'] - params['DMIN']) / n_dbins
        dbins = np.linspace(params['DMIN'] + dstep, params['DMAX'], n_dbins)

    # bin edges for the angle maps
    n_abins = params['ABINS']
    angle_step = 2.0 * np.pi / n_abins
    ab360 = np.linspace(-np.pi + angle_step, np.pi, n_abins)
    phi_last_bin = params['PHI_BINS'] if "PHI_BINS" in params else int(n_abins/2)
    ab180 = np.linspace(angle_step, np.pi, phi_last_bin)

    # bin 6D coords
    db = np.digitize(d, dbins).astype(np.uint8)  # distance
    ob = np.digitize(o, ab360).astype(np.uint8)  # omega
    tb = np.digitize(t, ab360).astype(np.uint8)  # theta
    pb = np.digitize(p, ab180).astype(np.uint8)  # phi

    # synchronize 'no contact' bins
    no_contact = db == n_dbins
    ob[no_contact] = n_abins
    tb[no_contact] = n_abins
    pb[no_contact] = phi_last_bin

    if set_diagonal:
        diagonal = np.eye(db.shape[0]).astype(bool)
        half_abins = int(n_abins/2)
        db[diagonal] = 0
        ob[diagonal] = half_abins
        tb[diagonal] = half_abins
        pb[diagonal] = 0
        # (11/20/2021)
        # Added masking of this diagonal as well.
        # This matches the way contact predictions were trained.
        # NOTE(review): assumed to belong to the set_diagonal branch - confirm.
        d[diagonal] = 0

    # stack all coords together
    if no_angles:
        c6d = db[None, :]  # only distogram, unsqueezed(0)
    else:
        c6d = np.stack([db, ob, tb, pb], axis=0)

    # slice long chains
    L = idx.shape[0]
    sel = np.arange(0, L)
    if L > params['LMAX'] and params['SLICE'] == 'CONT':
        # slice continuously: random window length, then random window start
        nres = np.random.randint(params['LMIN'], params['LMAX'])
        logger.warning("Slicing long chain %s into %s residues", pdb_path, nres)
        start = np.random.randint(L-nres+1)
        sel = np.arange(start, start + nres)

    # check if idx has no gaps; without contiguity a placeholder is returned
    selected_idx = idx[sel]
    contiguous = (selected_idx[1:] - selected_idx[:-1] == 1).all()
    coords6d = c6d[:, sel, :][:, :, sel] if contiguous else np.zeros(1)

    return {
        "idx": selected_idx,
        "idx_raw": idx,
        "coords6d": coords6d,
        "sel": sel,
        "dist": d[sel, :][:, sel],
        "fullseq": [seq[0]],
        'no_angles': no_angles,
        'xyz': orig_xyz,
    }