Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from collections.abc import Mapping, Sequence | |
| # Thanks to StructTrans | |
| # https://github.com/jingraham/neurips19-graph-protein-design | |
def nan_to_num(tensor, nan=0.0):
    """Return a copy of ``tensor`` with every NaN replaced by ``nan``.

    Unlike ``torch.nan_to_num``, +/-inf values are left untouched.

    The input is not modified. An in-place assignment would mutate the
    caller's tensor, and autograd rejects it on leaf tensors that require
    grad.

    Args:
        tensor: any floating-point tensor.
        nan: value written where ``tensor`` is NaN.

    Returns:
        New tensor of the same shape and dtype.
    """
    return tensor.masked_fill(torch.isnan(tensor), nan)
def _normalize(tensor, dim=-1):
    """Scale ``tensor`` to unit L2 norm along ``dim``.

    Slices with zero norm produce 0/0 = NaN. Those NaNs (and any NaN already
    present in the input) are replaced by 0 via ``nan_to_num``.
    """
    norms = torch.norm(tensor, dim=dim, keepdim=True)
    unit = tensor / norms
    return nan_to_num(unit)
def cal_dihedral(X, eps=1e-7):
    """Signed dihedral angles along a chain of consecutive points.

    Args:
        X: tensor of shape ``[B, M, 3]``, a batch of M points in which
            consecutive points define the bonds. (``_dihedrals`` passes the
            first three backbone atoms of every residue interleaved.)
        eps: margin that clamps the cosine into ``[-1 + eps, 1 - eps]`` so
            ``acos`` stays finite and differentiable.

    Returns:
        Tensor of shape ``[B, M - 3]``. Entry k is the dihedral, in radians,
        defined by points k, k+1, k+2, k+3.
    """
    dX = X[:, 1:, :] - X[:, :-1, :]  # bond vectors: CA-N, C-CA, N-C, CA-N, ...
    U = _normalize(dX, dim=-1)
    u_0 = U[:, :-2, :]
    u_1 = U[:, 1:-1, :]  # central bond of each dihedral
    u_2 = U[:, 2:, :]
    # Explicit dim=-1: without it torch.cross uses the FIRST axis of size 3.
    # That is the batch axis when B == 3 (or the point axis for very short
    # chains), whereas the xyz coordinates always live on the last axis.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    n_1 = _normalize(torch.cross(u_1, u_2, dim=-1), dim=-1)
    cosD = torch.clamp((n_0 * n_1).sum(-1), -1 + eps, 1 - eps)
    v = _normalize(torch.cross(n_0, n_1, dim=-1), dim=-1)
    # NOTE(review): n_0 x n_1 is parallel to u_1, so this sign equals
    # -sign(u_0 . (u_1 x u_2)). That is the opposite of the common atan2-style
    # dihedral convention (original comment: "TODO: sign"). It is kept
    # unchanged because downstream features / trained weights may depend on
    # it. Confirm the intended convention.
    D = torch.sign((-v * u_1).sum(-1)) * torch.acos(cosD)
    return D
def _dihedrals(X, dihedral_type=0, eps=1e-7):
    """Backbone dihedral and bond-angle features for every residue.

    Args:
        X: backbone coordinates ``[B, N, A, 3]`` with A >= 3. Only atoms 0-2
            are used; the original comment gives the atom order as
            ['N', 'CA', 'C', 'O'].
        dihedral_type: unused; kept for interface compatibility.
        eps: clamp margin applied to cosines before ``acos``. It is also
            forwarded to ``cal_dihedral`` so both halves honour it.

    Returns:
        ``[B, N, 12]`` tensor, laid out as follows:
        - cos and sin of (phi_i, psi_i, omega_i);
        - cos and sin of the angles between consecutive unit bond vectors
          centred on atoms 0, 1 and 2 of residue i.
        Slots that do not exist at the chain ends hold the padded angle 0,
        i.e. (cos, sin) = (1, 0).
    """
    # Interleave atoms 0, 1, 2 of every residue into one chain of 3N points.
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)

    # 3N - 3 dihedrals; one pad slot in front and two behind align them as
    # (phi_i, psi_i, omega_i) per residue.
    D = cal_dihedral(X, eps=eps)
    D = F.pad(D, (1, 2), 'constant', 0)
    D = D.view(D.size(0), D.size(1) // 3, 3)
    Dihedral_Angle_features = torch.cat((torch.cos(D), torch.sin(D)), 2)

    # Angle between consecutive unit bond vectors (pi minus the bond angle at
    # the shared atom). The same padding centres them on atoms 0, 1, 2 of
    # each residue.
    dX = X[:, 1:, :] - X[:, :-1, :]
    U = _normalize(dX, dim=-1)
    u_0 = U[:, :-2, :]
    u_1 = U[:, 1:-1, :]
    cosD = torch.clamp((u_0 * u_1).sum(-1), -1 + eps, 1 - eps)
    D = torch.acos(cosD)
    D = F.pad(D, (1, 2), 'constant', 0)
    D = D.view(D.size(0), D.size(1) // 3, 3)
    Angle_features = torch.cat((torch.cos(D), torch.sin(D)), 2)

    return torch.cat((Dihedral_Angle_features, Angle_features), 2)
def _hbonds(X, E_idx, mask_neighbors, eps=1E-3):
    """Backbone hydrogen-bond indicator for every neighbour edge.

    Uses the Kabsch-Sander (DSSP) electrostatic energy
    ``0.084 * 332 * (1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN)`` with a virtual amide
    hydrogen. A bond is declared where the energy is below -0.5.

    Args:
        X: backbone coordinates ``[B, L, 4, 3]``, unbound along dim 2 as
            N, CA, C, O.
        E_idx: neighbour indices ``[B, L, K]`` passed to ``gather_edges``.
        mask_neighbors: multiplied into the result. Presumably
            ``[B, L, K, 1]``; TODO confirm against the caller.
        eps: added to every distance before it is inverted.

    Returns:
        Float tensor ``[B, L, K, 1]``. Entry ``[b, i, k]`` is 1 when the N-H
        of residue i and the C=O of residue ``E_idx[b, i, k]`` are bonded.
    """
    X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))

    # Carbonyl C of the PREVIOUS residue (C_{i-1}); residue 0 has none and
    # gets zeros. Shifting the other way (C[:, 1:] padded at the end) would
    # pick C_{i+1} and place the virtual H away from the N-H bond.
    X_atoms['C_prev'] = F.pad(X_atoms['C'][:, :-1, :], (0, 0, 1, 0), 'constant', 0)
    # Virtual H: one unit from N along the bisector of the C_prev->N and
    # CA->N directions.
    X_atoms['H'] = X_atoms['N'] + _normalize(
        _normalize(X_atoms['N'] - X_atoms['C_prev'], -1)
        + _normalize(X_atoms['N'] - X_atoms['CA'], -1),
        -1)

    def _distance(X_a, X_b):
        # [B, L, L]; entry [b, i, j] = |X_a[b, j] - X_b[b, i]|
        return torch.norm(X_a[:, None, :, :] - X_b[:, :, None, :], dim=-1)

    def _inv_distance(X_a, X_b):
        return 1. / (_distance(X_a, X_b) + eps)

    # U[b, i, j]: donor N/H of residue i, acceptor C/O of residue j.
    U = (0.084 * 332) * (
        _inv_distance(X_atoms['O'], X_atoms['N'])
        + _inv_distance(X_atoms['C'], X_atoms['H'])
        - _inv_distance(X_atoms['O'], X_atoms['H'])
        - _inv_distance(X_atoms['C'], X_atoms['N'])
    )
    HB = (U < -0.5).type(torch.float32)
    neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
    return neighbor_HB
| def _rbf(D, num_rbf): | |
| D_min, D_max, D_count = 0., 20., num_rbf | |
| D_mu = torch.linspace(D_min, D_max, D_count).to(D.device) | |
| D_mu = D_mu.view([1,1,1,-1]) | |
| D_sigma = (D_max - D_min) / D_count | |
| D_expand = torch.unsqueeze(D, -1) | |
| RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2) | |
| return RBF | |
def _get_rbf(A, B, E_idx=None, num_rbf=16):
    """Gaussian RBF encoding of the distances between point sets A and B.

    Distances come from ``_get_dist``:
    - with ``E_idx``: all-pairs distances gathered to each node's neighbours,
      shape ``[B, L, K]``;
    - without it: the per-index distance ``|A[i] - B[i]|``, shape ``[B, L, 1]``.

    ``_rbf`` then expands them into ``num_rbf`` channels on a new trailing
    axis.
    """
    distances = _get_dist(A, B, E_idx)
    return _rbf(distances, num_rbf)
| def _get_dist(A, B, E_idx=None, num_rbf=None): | |
| if E_idx is not None: | |
| D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2,-1) + 1e-6) #[B, L, L] | |
| D_A_B = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0] #[B,L,K] | |
| else: | |
| D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,:,None,:])**2,-1) + 1e-6) #[B, L, L] | |
| return D_A_B | |
def _orientations_coarse_gl(X, E_idx, eps=1e-6):
    """Relative direction and rotation features for every neighbour edge.

    Each residue gets an orthonormal frame whose rows are:
    - b_1: the normalised bisector ``u_0 - u_1`` of its first two unit bonds;
    - n_0: the normalised normal ``u_0 x u_1``;
    - ``b_1 x n_0``.
    The frame origin is atom 0 of the residue.

    Edge i -> j = ``E_idx[i, k]`` is described by two things:
    - the unit direction from i's origin to j's origin, expressed in i's frame;
    - the quaternion of the relative rotation between the two frames.

    Args:
        X: backbone coordinates ``[B, N, A, 3]`` (A >= 3; atoms 0-2 are used).
        E_idx: neighbour indices ``[B, N, K]``.
        eps: unused; kept for interface compatibility.

    Returns:
        ``[B, N, K, 7]``: 3 direction components, then 4 quaternion components.
    """
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
    dX = X[:, 1:, :] - X[:, :-1, :]  # bond vectors: CA-N, C-CA, N-C, CA-N, ...
    U = _normalize(dX, dim=-1)
    u_0, u_1 = U[:, :-2, :], U[:, 1:-1, :]
    # Explicit dim=-1: without it torch.cross uses the first axis of size 3,
    # i.e. the batch axis when B == 3 or the sequence axis for short chains.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    b_1 = _normalize(u_0 - u_1, dim=-1)
    # Every third entry pairs the two intra-residue bonds of one residue. This
    # yields N - 1 frames; the last residue is zero-padded below.
    n_0 = n_0[:, ::3, :]
    b_1 = b_1[:, ::3, :]
    X = X[:, ::3, :]  # frame origins: atom 0 of each residue
    O = torch.stack((b_1, n_0, torch.cross(b_1, n_0, dim=-1)), 2)
    O = O.view(list(O.shape[:2]) + [9])
    O = F.pad(O, (0, 0, 0, 1), 'constant', 0)  # all-zero frame for the last residue
    O_neighbors = gather_nodes(O, E_idx)
    X_neighbors = gather_nodes(X, E_idx)
    O = O.view(list(O.shape[:2]) + [3, 3]).unsqueeze(2)
    O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])
    dX = X_neighbors - X.unsqueeze(-2)
    dU = torch.matmul(O, dX.unsqueeze(-1)).squeeze(-1)  # offsets in the local frame
    R = torch.matmul(O.transpose(-1, -2), O_neighbors)  # relative rotations
    feat = torch.cat((_normalize(dU, dim=-1), _quaternions(R)), dim=-1)
    return feat
def _orientations_coarse_gl_tuple(X, E_idx, eps=1e-6):
    """Per-residue local frames plus node- and edge-level direction features.

    The frames are built exactly as in ``_orientations_coarse_gl``. The origin
    is atom 0 of each residue and the rows are b_1, n_0 and b_1 x n_0.

    Args:
        X: backbone coordinates ``[B, N, A, 3]`` (A >= 4; atoms 0-3 are used).
        E_idx: neighbour indices ``[B, N, K]``.
        eps: unused; kept for interface compatibility.

    Returns:
        A tuple ``(V_direct, E_direct, q)``:

        - V_direct ``[B, N, 9]``: unit directions, in the residue's own frame,
          from its origin to its atoms 0, 2 and 3.
          NOTE(review): atom 0 is the origin itself, so the first direction
          is always 0 (``_normalize`` of a zero vector). Confirm intent.
        - E_direct ``[B, N, K, 12]``: unit directions, in residue i's frame,
          from i's origin to atoms 1, 0, 2, 3 of each neighbour.
        - q ``[B, N, K, 4]``: quaternion of the relative rotation between the
          frames of i and each neighbour.
    """
    V = X.clone()  # untouched [B, N, A, 3] copy; X is rebound below
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
    dX = X[:, 1:, :] - X[:, :-1, :]  # bond vectors: CA-N, C-CA, N-C, CA-N, ...
    U = _normalize(dX, dim=-1)
    u_0, u_1 = U[:, :-2, :], U[:, 1:-1, :]
    # Explicit dim=-1: without it torch.cross uses the first axis of size 3,
    # i.e. the batch axis when B == 3 or the sequence axis for short chains.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    b_1 = _normalize(u_0 - u_1, dim=-1)
    n_0 = n_0[:, ::3, :]
    b_1 = b_1[:, ::3, :]
    X = X[:, ::3, :]  # frame origins: atom 0 of each residue
    Q = torch.stack((b_1, n_0, torch.cross(b_1, n_0, dim=-1)), 2)
    Q = Q.view(list(Q.shape[:2]) + [9])
    Q = F.pad(Q, (0, 0, 0, 1), 'constant', 0)  # all-zero frame for the last residue
    Q_neighbors = gather_nodes(Q, E_idx)
    X_neighbors = gather_nodes(V[:, :, 1, :], E_idx)
    N_neighbors = gather_nodes(V[:, :, 0, :], E_idx)
    C_neighbors = gather_nodes(V[:, :, 2, :], E_idx)
    O_neighbors = gather_nodes(V[:, :, 3, :], E_idx)
    Q = Q.view(list(Q.shape[:2]) + [3, 3]).unsqueeze(2)
    Q_neighbors = Q_neighbors.view(list(Q_neighbors.shape[:3]) + [3, 3])
    # [B, N, K, 4, 3]: neighbour atoms 1, 0, 2, 3 relative to residue i's origin.
    dX = torch.stack([X_neighbors, N_neighbors, C_neighbors, O_neighbors], dim=3) - X[:, :, None, None, :]
    dU = torch.matmul(Q[:, :, :, None, :, :], dX[..., None]).squeeze(-1)
    B, N, K = dU.shape[:3]
    E_direct = _normalize(dU, dim=-1)
    E_direct = E_direct.reshape(B, N, K, -1)
    R = torch.matmul(Q.transpose(-1, -2), Q_neighbors)
    q = _quaternions(R)
    dX_inner = V[:, :, [0, 2, 3], :] - X.unsqueeze(-2)
    dU_inner = torch.matmul(Q, dX_inner.unsqueeze(-1)).squeeze(-1)
    dU_inner = _normalize(dU_inner, dim=-1)
    V_direct = dU_inner.reshape(B, N, -1)
    return V_direct, E_direct, q
def gather_edges(edges, neighbor_idx):
    """Select per-neighbour edge features.

    Args:
        edges: ``[B, N, M, C]`` features for every (node, other node) pair.
        neighbor_idx: ``[B, N, K]`` indices into axis 2 of ``edges``.

    Returns:
        ``[B, N, K, C]`` with ``out[b, i, k] = edges[b, i, neighbor_idx[b, i, k]]``.
    """
    channels = edges.shape[-1]
    index = neighbor_idx[..., None].expand(-1, -1, -1, channels)
    return edges.gather(2, index)
def gather_nodes(nodes, neighbor_idx):
    """Collect node features for every neighbour index.

    Args:
        nodes: ``[B, N, C]`` node features.
        neighbor_idx: ``[B, N, K]`` indices into axis 1 of ``nodes``.

    Returns:
        ``[B, N, K, C]`` with ``out[b, i, k] = nodes[b, neighbor_idx[b, i, k]]``.
    """
    batch = neighbor_idx.shape[0]
    flat_idx = neighbor_idx.view(batch, -1)                       # [B, N*K]
    flat_idx = flat_idx[..., None].expand(-1, -1, nodes.size(2))  # [B, N*K, C]
    picked = nodes.gather(1, flat_idx)                            # [B, N*K, C]
    out_shape = [*neighbor_idx.shape[:3], -1]
    return picked.view(out_shape)                                 # [B, N, K, C]
def _quaternions(R):
    """Convert rotation matrices to unit quaternions ordered ``(x, y, z, w)``.

    Args:
        R: rotation matrices of shape ``[..., 3, 3]``. Any number of leading
            axes is accepted; previously only ``[B, N, K, 3, 3]`` worked.

    Returns:
        ``[..., 4]`` tensor, normalised with ``_normalize``. The vector part
        comes first and the scalar part last.
    """
    diag = torch.diagonal(R, dim1=-2, dim2=-1)
    Rxx, Ryy, Rzz = diag.unbind(-1)
    # |x|, |y|, |z| from the diagonal. abs() keeps the sqrt argument
    # non-negative (it already is for an exact rotation matrix). Exact zeros
    # are bumped to 1e-12 so the sqrt gradient below stays finite.
    magnitudes = torch.abs(1 + torch.stack([
        Rxx - Ryy - Rzz,
        -Rxx + Ryy - Rzz,
        -Rxx - Ryy + Rzz,
    ], -1))
    magnitudes[magnitudes == 0.0] = 1e-12
    magnitudes = 0.5 * torch.sqrt(magnitudes)
    # Signs of x, y, z come from the antisymmetric part of R. torch.sign(0)
    # is 0, so a component whose off-diagonal difference is exactly zero is
    # dropped.
    signs = torch.sign(torch.stack([
        R[..., 2, 1] - R[..., 1, 2],
        R[..., 0, 2] - R[..., 2, 0],
        R[..., 1, 0] - R[..., 0, 1],
    ], -1))
    xyz = signs * magnitudes
    # relu clamps 1 + trace at 0 before the sqrt.
    w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
    Q = torch.cat((xyz, w), -1)
    return _normalize(Q, dim=-1)
def cuda(obj, *args, **kwargs):
    """
    Transfer any nested container of tensors to CUDA.

    Dispatch rules:
    - Objects with a ``.cuda`` method are moved with ``obj.cuda(*args, **kwargs)``.
    - Mappings, namedtuples and other sequences are rebuilt with their own
      type around the transferred elements.
    - numpy arrays are converted to tensors and then moved the same way.
    - ``str`` and ``bytes`` leaves are returned unchanged.

    Raises:
        TypeError: for any other leaf type.
    """
    if hasattr(obj, "cuda"):
        return obj.cuda(*args, **kwargs)
    if isinstance(obj, (str, bytes)):
        # Text is a Sequence, but rebuilding it as str(<generator>) would
        # silently produce the generator's repr, so pass it through as-is.
        return obj
    if isinstance(obj, Mapping):
        return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()})
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple constructors take the fields positionally, not an iterable.
        return type(obj)(*(cuda(x, *args, **kwargs) for x in obj))
    if isinstance(obj, Sequence):
        return type(obj)(cuda(x, *args, **kwargs) for x in obj)
    if isinstance(obj, np.ndarray):
        # torch.tensor() alone would leave the data on the CPU, and it accepts
        # no positional arguments after the data.
        return torch.as_tensor(obj).cuda(*args, **kwargs)
    raise TypeError("Can't transfer object type `%s`" % type(obj))