# initial commit 7968cb0
import torch
import torch.nn.functional as F
import numpy as np
from collections.abc import Mapping, Sequence
# Thanks to StructTrans
# https://github.com/jingraham/neurips19-graph-protein-design
def nan_to_num(tensor, nan=0.0):
    """Replace NaN entries in `tensor` with `nan`.

    Uses out-of-place `torch.where` instead of boolean-mask assignment: the
    original in-place write mutated the caller's tensor and raises under
    autograd when `tensor` is a leaf requiring grad.

    Args:
        tensor: any floating-point tensor.
        nan: scalar substituted for each NaN entry (default 0.0).

    Returns:
        A tensor of the same shape with NaNs replaced.
    """
    return torch.where(torch.isnan(tensor),
                       torch.full_like(tensor, nan),
                       tensor)
def _normalize(tensor, dim=-1):
    """L2-normalize `tensor` along `dim`, mapping zero-norm slices to zero.

    Dividing by a zero norm produces NaN, which `nan_to_num` then zeroes out.
    """
    norm = torch.norm(tensor, dim=dim, keepdim=True)
    return nan_to_num(tensor / norm)
def cal_dihedral(X, eps=1e-7):
    """Compute signed dihedral angles from consecutive backbone bond vectors.

    Args:
        X: [B, 3N, 3] flattened backbone coordinates (three atoms per residue).
        eps: clamp margin keeping the acos argument strictly inside (-1, 1).

    Returns:
        [B, 3N-3] tensor of dihedral angles in radians.
    """
    dX = X[:, 1:, :] - X[:, :-1, :]  # consecutive bond vectors: CA-N, C-CA, N-C, ...
    U = _normalize(dX, dim=-1)
    u_0 = U[:, :-2, :]   # CA-N, C-CA, N-C, ...
    u_1 = U[:, 1:-1, :]  # C-CA, N-C, CA-N, ...
    u_2 = U[:, 2:, :]    # N-C, CA-N, C-CA, ...

    # Plane normals around the middle bond. `dim=-1` is passed explicitly:
    # torch.cross's legacy default picks the *first* size-3 dimension, which
    # silently selects the wrong axis whenever B == 3 or the length dim is 3.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    n_1 = _normalize(torch.cross(u_1, u_2, dim=-1), dim=-1)

    cosD = (n_0 * n_1).sum(-1)
    cosD = torch.clamp(cosD, -1 + eps, 1 - eps)

    # Sign of the angle from the orientation of n_0 x n_1 relative to the
    # middle bond direction u_1.
    v = _normalize(torch.cross(n_0, n_1, dim=-1), dim=-1)
    D = torch.sign((-v * u_1).sum(-1)) * torch.acos(cosD)
    return D
def _dihedrals(X, dihedral_type=0, eps=1e-7):
    """Per-residue backbone dihedral and bond-angle features.

    Args:
        X: [B, N, 4, 3] coordinates for atoms ['N', 'CA', 'C', 'O'].
        dihedral_type: unused; kept for interface compatibility.
        eps: clamp margin for the acos argument.

    Returns:
        [B, N, 12] tensor: cos/sin of the three torsions (psi, omega, phi)
        followed by cos/sin of the three backbone bond angles.
    """
    batch, n_res = X.shape[0], X.shape[1]
    # Flatten N/CA/C into one [B, 3N, 3] chain; the O atom is dropped.
    backbone = X[:, :, :3, :].reshape(batch, 3 * n_res, 3)

    # --- torsions: psi, omega, phi ---
    torsions = cal_dihedral(backbone)
    torsions = F.pad(torsions, (1, 2), 'constant', 0)  # align to residue grid
    torsions = torsions.view(batch, -1, 3)
    dihedral_feats = torch.cat([torch.cos(torsions), torch.sin(torsions)], dim=2)

    # --- bond angles: alpha, gamma, beta ---
    bonds = _normalize(backbone[:, 1:, :] - backbone[:, :-1, :], dim=-1)
    cos_angle = (bonds[:, :-2, :] * bonds[:, 1:-1, :]).sum(-1)
    cos_angle = torch.clamp(cos_angle, -1 + eps, 1 - eps)
    angles = F.pad(torch.acos(cos_angle), (1, 2), 'constant', 0)
    angles = angles.view(batch, -1, 3)
    angle_feats = torch.cat([torch.cos(angles), torch.sin(angles)], dim=2)

    return torch.cat([dihedral_feats, angle_feats], dim=2)
def _hbonds(X, E_idx, mask_neighbors, eps=1E-3):
    """Binary hydrogen-bond indicator for each neighbor edge.

    Args:
        X: [B, N, 4, 3] backbone coordinates ordered ['N', 'CA', 'C', 'O'].
        E_idx: [B, N, K] neighbor indices.
        mask_neighbors: multiplicative mask over the gathered edges.
        eps: guards the inverse distances against division by zero.

    Returns:
        [B, N, K, 1] H-bond map restricted to the K listed neighbors.
    """
    atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))
    # Carbonyl carbon of the previous residue (zero-padded at the chain end).
    atoms['C_prev'] = F.pad(atoms['C'][:, 1:, :], (0, 0, 0, 1), 'constant', 0)
    # Place the amide hydrogen on the bisector of the N->C_prev and N->CA bonds.
    atoms['H'] = atoms['N'] + _normalize(
        _normalize(atoms['N'] - atoms['C_prev'], -1)
        + _normalize(atoms['N'] - atoms['CA'], -1),
        -1)

    def _pair_inv_dist(A_set, B_set):
        # result[b, i, j] = 1 / (||A_set[b, j] - B_set[b, i]|| + eps)
        diff = A_set[:, None, :, :] - B_set[:, :, None, :]
        return 1. / (torch.norm(diff, dim=-1) + eps)

    # Electrostatic H-bond energy; 0.084 * 332 and the -0.5 threshold match
    # the DSSP (Kabsch-Sander) convention.
    U = (0.084 * 332) * (
        _pair_inv_dist(atoms['O'], atoms['N'])
        + _pair_inv_dist(atoms['C'], atoms['H'])
        - _pair_inv_dist(atoms['O'], atoms['H'])
        - _pair_inv_dist(atoms['C'], atoms['N'])
    )
    HB = (U < -0.5).type(torch.float32)
    return mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
def _rbf(D, num_rbf):
D_min, D_max, D_count = 0., 20., num_rbf
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu.view([1,1,1,-1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF
def _get_rbf(A, B, E_idx=None, num_rbf=16):
    """RBF expansion of distances between two atom sets.

    Args:
        A, B: [Bs, L, 3] coordinate tensors.
        E_idx: optional [Bs, L, K] neighbor indices. When given, the full
            [Bs, L, L] pairwise map is gathered down to the K neighbors;
            otherwise A and B are compared position-wise, yielding a
            [Bs, L, 1] distance.
        num_rbf: number of Gaussian basis functions.

    Returns:
        Distances expanded to shape [..., num_rbf].
    """
    if E_idx is None:
        # Row-aligned distances ||A[b,i] - B[b,i]||, shape [Bs, L, 1];
        # 1e-6 keeps sqrt differentiable at zero distance.
        d = torch.sqrt(torch.sum((A[:, :, None, :] - B[:, :, None, :]) ** 2, -1) + 1e-6)
        return _rbf(d, num_rbf)
    # Full pairwise map, then keep only the K listed neighbors.
    d_full = torch.sqrt(torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6)
    d_nbr = gather_edges(d_full[:, :, :, None], E_idx)[:, :, :, 0]  # [Bs, L, K]
    return _rbf(d_nbr, num_rbf)
def _get_dist(A, B, E_idx=None, num_rbf=None):
if E_idx is not None:
D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,None,:,:])**2,-1) + 1e-6) #[B, L, L]
D_A_B = gather_edges(D_A_B[:,:,:,None], E_idx)[:,:,:,0] #[B,L,K]
else:
D_A_B = torch.sqrt(torch.sum((A[:,:,None,:] - B[:,:,None,:])**2,-1) + 1e-6) #[B, L, L]
return D_A_B
def _orientations_coarse_gl(X, E_idx, eps=1e-6):
    """Local-frame orientation features between each residue and its neighbors.

    Args:
        X: [B, N, 4, 3] backbone coordinates (['N', 'CA', 'C', 'O']).
        E_idx: [B, N, K] neighbor indices.
        eps: unused; kept for interface compatibility.

    Returns:
        [B, N, K, 7] features: unit displacement to each neighbor expressed
        in the residue's local frame (3) plus the relative-rotation
        quaternion (4).
    """
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
    dX = X[:, 1:, :] - X[:, :-1, :]  # consecutive bond vectors
    U = _normalize(dX, dim=-1)
    u_0, u_1 = U[:, :-2, :], U[:, 1:-1, :]
    # dim=-1 is passed explicitly: torch.cross's legacy default picks the
    # *first* size-3 dimension and is wrong whenever B or a length dim is 3.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    b_1 = _normalize(u_0 - u_1, dim=-1)

    # Keep one frame per residue (every third backbone position).
    n_0 = n_0[:, ::3, :]
    b_1 = b_1[:, ::3, :]
    X = X[:, ::3, :]

    # Orthonormal local frame rows: (b_1, n_0, b_1 x n_0), flattened to 9.
    O = torch.stack((b_1, n_0, torch.cross(b_1, n_0, dim=-1)), 2)
    O = O.view(list(O.shape[:2]) + [9])
    O = F.pad(O, (0, 0, 0, 1), 'constant', 0)  # pad frame for the last residue

    O_neighbors = gather_nodes(O, E_idx)
    X_neighbors = gather_nodes(X, E_idx)

    O = O.view(list(O.shape[:2]) + [3, 3]).unsqueeze(2)
    O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])

    dX = X_neighbors - X.unsqueeze(-2)
    dU = torch.matmul(O, dX.unsqueeze(-1)).squeeze(-1)   # displacement in local frame
    R = torch.matmul(O.transpose(-1, -2), O_neighbors)   # relative rotation
    feat = torch.cat((_normalize(dU, dim=-1), _quaternions(R)), dim=-1)
    return feat
def _orientations_coarse_gl_tuple(X, E_idx, eps=1e-6):
    """Direction and orientation features for nodes and edges.

    Args:
        X: [B, N, 4, 3] backbone coordinates (['N', 'CA', 'C', 'O']).
        E_idx: [B, N, K] neighbor indices.
        eps: unused; kept for interface compatibility.

    Returns:
        V_direct: [B, N, 9] unit directions from the per-residue anchor to
            the residue's own N/C/O atoms, in the local frame.
        E_direct: [B, N, K, 12] unit directions to each neighbor's
            CA/N/C/O atoms, in the local frame.
        q: [B, N, K, 4] relative-rotation quaternions.
    """
    V = X.clone()
    X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
    dX = X[:, 1:, :] - X[:, :-1, :]
    U = _normalize(dX, dim=-1)
    u_0, u_1 = U[:, :-2, :], U[:, 1:-1, :]
    # dim=-1 is passed explicitly: torch.cross's legacy default picks the
    # *first* size-3 dimension and is wrong whenever B or a length dim is 3.
    n_0 = _normalize(torch.cross(u_0, u_1, dim=-1), dim=-1)
    b_1 = _normalize(u_0 - u_1, dim=-1)

    # One local frame per residue (every third backbone position).
    n_0 = n_0[:, ::3, :]
    b_1 = b_1[:, ::3, :]
    X = X[:, ::3, :]
    Q = torch.stack((b_1, n_0, torch.cross(b_1, n_0, dim=-1)), 2)
    Q = Q.view(list(Q.shape[:2]) + [9])
    Q = F.pad(Q, (0, 0, 0, 1), 'constant', 0)  # pad frame for the last residue

    Q_neighbors = gather_nodes(Q, E_idx)
    X_neighbors = gather_nodes(V[:, :, 1, :], E_idx)  # CA
    N_neighbors = gather_nodes(V[:, :, 0, :], E_idx)  # N
    C_neighbors = gather_nodes(V[:, :, 2, :], E_idx)  # C
    O_neighbors = gather_nodes(V[:, :, 3, :], E_idx)  # O

    Q = Q.view(list(Q.shape[:2]) + [3, 3]).unsqueeze(2)
    Q_neighbors = Q_neighbors.view(list(Q_neighbors.shape[:3]) + [3, 3])

    # Edge directions: neighbor CA/N/C/O relative to this residue's anchor,
    # rotated into the local frame.
    dX = torch.stack([X_neighbors, N_neighbors, C_neighbors, O_neighbors], dim=3) - X[:, :, None, None, :]
    dU = torch.matmul(Q[:, :, :, None, :, :], dX[..., None]).squeeze(-1)
    B, N, K = dU.shape[:3]
    E_direct = _normalize(dU, dim=-1).reshape(B, N, K, -1)
    R = torch.matmul(Q.transpose(-1, -2), Q_neighbors)
    q = _quaternions(R)

    # Node directions: this residue's N/C/O relative to its anchor.
    dX_inner = V[:, :, [0, 2, 3], :] - X.unsqueeze(-2)
    dU_inner = torch.matmul(Q, dX_inner.unsqueeze(-1)).squeeze(-1)
    V_direct = _normalize(dU_inner, dim=-1).reshape(B, N, -1)
    return V_direct, E_direct, q
def gather_edges(edges, neighbor_idx):
    """Select the edge features of each node's K neighbors.

    Args:
        edges: [B, N, N, C] dense pairwise edge features.
        neighbor_idx: [B, N, K] indices of each node's K neighbors.

    Returns:
        [B, N, K, C] features of the selected edges.
    """
    channels = edges.size(-1)
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, channels)
    return torch.gather(edges, 2, idx)
def gather_nodes(nodes, neighbor_idx):
    """Gather the node features of each node's K neighbors.

    Args:
        nodes: [B, N, C] node features.
        neighbor_idx: [B, N, K] neighbor indices.

    Returns:
        [B, N, K, C] features of each node's neighbors.
    """
    batch, n_nodes, k = neighbor_idx.shape[:3]
    # Flatten the (N, K) index grid so a single gather over dim 1 suffices.
    flat_idx = neighbor_idx.reshape(batch, n_nodes * k, 1).expand(-1, -1, nodes.size(2))
    gathered = torch.gather(nodes, 1, flat_idx)
    return gathered.view(batch, n_nodes, k, -1)
def _quaternions(R):
    """Convert rotation matrices to unit quaternions [x, y, z, w].

    Args:
        R: [B, N, K, 3, 3] rotation matrices.

    Returns:
        [B, N, K, 4] normalized quaternions.
    """
    diag = torch.diagonal(R, dim1=-2, dim2=-1)
    Rxx, Ryy, Rzz = diag.unbind(-1)
    # Squared magnitudes (x4) of the vector components, from the diagonal;
    # abs() guards against tiny negatives introduced by float error.
    mag_sq = torch.abs(1 + torch.stack([
        Rxx - Ryy - Rzz,
        -Rxx + Ryy - Rzz,
        -Rxx - Ryy + Rzz,
    ], -1))
    # Replace exact zeros so sqrt gets a strictly positive (if tiny) argument.
    mag_sq = torch.where(mag_sq == 0.0, torch.full_like(mag_sq, 1e-12), mag_sq)
    magnitudes = 0.5 * torch.sqrt(mag_sq)
    # Component signs come from the skew-symmetric part of R.
    elem = lambda i, j: R[:, :, :, i, j]
    signs = torch.sign(torch.stack([
        elem(2, 1) - elem(1, 2),
        elem(0, 2) - elem(2, 0),
        elem(1, 0) - elem(0, 1),
    ], -1))
    xyz = signs * magnitudes
    # Scalar part; relu guards the sqrt against negative traces.
    w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
    return _normalize(torch.cat((xyz, w), -1), dim=-1)
def cuda(obj, *args, **kwargs):
    """Recursively transfer a (possibly nested) container of tensors to CUDA.

    Args:
        obj: anything with a ``.cuda`` method (tensor/module), a Mapping,
            a non-string Sequence, or a numpy array; containers may nest.
        *args, **kwargs: forwarded to ``.cuda`` (and to ``torch.tensor``
            when converting numpy arrays).

    Returns:
        The same container structure with every leaf transferred.

    Raises:
        TypeError: for unsupported leaf types, including str/bytes.
    """
    if hasattr(obj, "cuda"):
        return obj.cuda(*args, **kwargs)
    elif isinstance(obj, Mapping):
        return type(obj)({k: cuda(v, *args, **kwargs) for k, v in obj.items()})
    elif isinstance(obj, (str, bytes)):
        # str/bytes are Sequences: without this guard, type(obj)(<generator>)
        # would return the generator's repr string instead of failing loudly.
        raise TypeError("Can't transfer object type `%s`" % type(obj))
    elif isinstance(obj, Sequence):
        return type(obj)(cuda(x, *args, **kwargs) for x in obj)
    elif isinstance(obj, np.ndarray):
        return torch.tensor(obj, *args, **kwargs)
    raise TypeError("Can't transfer object type `%s`" % type(obj))