# Pocket-Gen / models / protein_features.py
# Web-page header residue (not code): "Zaixi's picture", "1", "dcacefd"
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from torch_geometric.nn import radius_graph, knn_graph
class PositionalEncodings(nn.Module):
    """Sinusoidal encoding of (relative) integer indices.

    Each input value is multiplied by a bank of geometrically spaced
    frequencies and expanded into its cosine and sine components.
    """

    def __init__(self, num_embeddings):
        """
        Args:
            num_embeddings: Width of the encoding. One frequency is created
                for every second index in ``range(num_embeddings)``, and each
                frequency contributes a cosine and a sine channel.
        """
        super().__init__()
        self.num_embeddings = num_embeddings

    def forward(self, E_idx):
        """Encode index offsets (e.g. ``i - j``) as cos/sin features.

        Args:
            E_idx: Tensor of index offsets, any shape.

        Returns:
            Tensor of shape ``E_idx.shape + (2 * n_freq,)`` where ``n_freq``
            is ``len(range(0, num_embeddings, 2))``; cosine channels first,
            then sine channels.
        """
        # Build the frequency table directly on the input's device instead of
        # creating it on the CPU and copying it over on every call.
        exponents = torch.arange(
            0, self.num_embeddings, 2,
            dtype=torch.float32, device=E_idx.device,
        )
        frequency = torch.exp(exponents * -(np.log(10000.0) / self.num_embeddings))
        angles = E_idx.unsqueeze(-1) * frequency
        return torch.cat((torch.cos(angles), torch.sin(angles)), -1)
class ProteinFeatures(nn.Module):
def __init__(self, num_positional_embeddings=16, num_rbf=16, top_k=8, features_type='backbone', direction='forward'):
""" Extract protein features """
super(ProteinFeatures, self).__init__()
self.top_k = top_k
self.num_rbf = num_rbf
self.num_positional_embeddings = num_positional_embeddings
self.direction = direction
# Feature types
self.features_type = features_type
self.feature_dimensions = num_positional_embeddings + num_rbf + 7
# Positional encoding
self.pe = PositionalEncodings(num_positional_embeddings)
def _rbf(self, D):
# Distance radial basis function
D_min, D_max, D_count = 0., 20., self.num_rbf
#D_mu = torch.linspace(D_min, D_max, D_count).cuda()
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu.view([1,1,1,-1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF.squeeze(0).squeeze(0)
def _quaternions(self, R):
""" Convert a batch of 3D rotations [R] to quaternions [Q]
R [...,3,3]
Q [...,4]
"""
# Simple Wikipedia version
# en.wikipedia.org/wiki/Rotation_matrix#Quaternion
# For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
diag = torch.diagonal(R, dim1=-2, dim2=-1)
Rxx, Ryy, Rzz = diag.unbind(-1)
magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
Rxx - Ryy - Rzz,
- Rxx + Ryy - Rzz,
- Rxx - Ryy + Rzz
], -1)))
_R = lambda i,j: R[:,i,j]
signs = torch.sign(torch.stack([
_R(2,1) - _R(1,2),
_R(0,2) - _R(2,0),
_R(1,0) - _R(0,1)
], -1))
xyz = signs * magnitudes
# The relu enforces a non-negative trace
w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
Q = torch.cat((xyz, w), -1)
Q = F.normalize(Q, dim=-1)
# Axis of rotation
# Replace bad rotation matrices with identity
# I = torch.eye(3).view((1,1,1,3,3))
# I = I.expand(*(list(R.shape[:3]) + [-1,-1]))
# det = (
# R[:,:,:,0,0] * (R[:,:,:,1,1] * R[:,:,:,2,2] - R[:,:,:,1,2] * R[:,:,:,2,1])
# - R[:,:,:,0,1] * (R[:,:,:,1,0] * R[:,:,:,2,2] - R[:,:,:,1,2] * R[:,:,:,2,0])
# + R[:,:,:,0,2] * (R[:,:,:,1,0] * R[:,:,:,2,1] - R[:,:,:,1,1] * R[:,:,:,2,0])
# )
# det_mask = torch.abs(det.unsqueeze(-1).unsqueeze(-1))
# R = det_mask * R + (1 - det_mask) * I
# DEBUG
# https://math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
# Columns of this are in rotation plane
# A = R - I
# v1, v2 = A[:,:,:,:,0], A[:,:,:,:,1]
# axis = F.normalize(torch.cross(v1, v2), dim=-1)
return Q
def _contacts(self, D_neighbors, E_idx, mask_neighbors, cutoff=8):
""" Contacts """
D_neighbors = D_neighbors.unsqueeze(-1)
neighbor_C = mask_neighbors * (D_neighbors < cutoff).type(torch.float32)
return neighbor_C
def _hbonds(self, X, E_idx, mask_neighbors, eps=1E-3):
""" Hydrogen bonds and contact map
"""
X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))
# Virtual hydrogens
X_atoms['C_prev'] = F.pad(X_atoms['C'][:,1:,:], (0,0,0,1), 'constant', 0)
X_atoms['H'] = X_atoms['N'] + F.normalize(
F.normalize(X_atoms['N'] - X_atoms['C_prev'], -1)
+ F.normalize(X_atoms['N'] - X_atoms['CA'], -1)
, -1)
def _distance(X_a, X_b):
return torch.norm(X_a[:,None,:,:] - X_b[:,:,None,:], dim=-1)
def _inv_distance(X_a, X_b):
return 1. / (_distance(X_a, X_b) + eps)
# DSSP vacuum electrostatics model
U = (0.084 * 332) * (
_inv_distance(X_atoms['O'], X_atoms['N'])
+ _inv_distance(X_atoms['C'], X_atoms['H'])
- _inv_distance(X_atoms['O'], X_atoms['H'])
- _inv_distance(X_atoms['C'], X_atoms['N'])
)
HB = (U < -0.5).type(torch.float32)
neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
# print(HB)
# HB = F.sigmoid(U)
# U_np = U.cpu().data.numpy()
# # plt.matshow(np.mean(U_np < -0.5, axis=0))
# plt.matshow(HB[0,:,:])
# plt.colorbar()
# plt.show()
# D_CA = _distance(X_atoms['CA'], X_atoms['CA'])
# D_CA = D_CA.cpu().data.numpy()
# plt.matshow(D_CA[0,:,:] < contact_D)
# # plt.colorbar()
# plt.show()
# exit(0)
return neighbor_HB
def _AD_features(self, X, eps=1e-6):
# Shifted slices of unit vectors
dX = X[:,1:,:] - X[:,:-1,:]
U = F.normalize(dX, dim=-1)
u_2 = U[:,:-2,:]
u_1 = U[:,1:-1,:]
u_0 = U[:,2:,:]
# Backbone normals
n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1)
# Bond angle calculation
cosA = -(u_1 * u_0).sum(-1)
cosA = torch.clamp(cosA, -1+eps, 1-eps)
A = torch.acos(cosA)
# Angle between normals
cosD = (n_2 * n_1).sum(-1)
cosD = torch.clamp(cosD, -1+eps, 1-eps)
D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
# Backbone features
AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
return F.pad(AD_features, (0,0,1,2), 'constant', 0)
def _orientations_coarse(self, X, edge_index, residue_batch, eps=1e-6):
# Shifted slices of unit vectors
dX = X[1:,:] - X[:-1,:]
U = F.normalize(dX, dim=-1)
u_2 = U[:-1,:]
u_1 = U[1:,:]
# Backbone normals
n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
row, col = edge_index # (E,) , (E,)
# Build relative orientations
o_1 = F.normalize(u_2 - u_1, dim=-1)
O = torch.cat([o_1, n_2, torch.cross(o_1, n_2)], dim=-1)
set_zeros_index = torch.cumsum(residue_batch.bincount(), dim=0)[:-1]
#O[set_zeros_index-1] = 0
#O[set_zeros_index-2] = 0
O = F.pad(O, (0,0,1,1), 'constant', 0)
# Re-view as rotation matrices
O = O.view(list(O.shape[:1]) + [3,3])
# Rotate into local reference frames
dX = X[col] - X[row]
dU = torch.matmul(O.reshape(O.shape[0],3,3)[col], dX.unsqueeze(-1)).squeeze(-1)
dU = F.normalize(dU, dim=-1)
R = torch.matmul(O[row], O[col].transpose(-1,-2))
Q = self._quaternions(R)
return torch.cat((dU,Q), dim=-1)
def _dihedrals(self, X, eps=1e-7):
# First 3 coordinates are N, CA, C
X = X[:,:3,:].reshape(X.shape[0], 3*X.shape[1], 3)
# Shifted slices of unit vectors
dX = X[:,1:,:] - X[:,:-1,:]
U = F.normalize(dX, dim=-1)
u_2 = U[:,:-2,:]
u_1 = U[:,1:-1,:]
u_0 = U[:,2:,:]
# Backbone normals
n_2 = F.normalize(torch.cross(u_2, u_1), dim=-1)
n_1 = F.normalize(torch.cross(u_1, u_0), dim=-1)
# Angle between normals
cosD = (n_2 * n_1).sum(-1)
cosD = torch.clamp(cosD, -1+eps, 1-eps)
D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
D = F.pad(D, (3,0), 'constant', 0)
D = D.view((D.size(0), int(D.size(1)/3), 3))
phi, psi, omega = torch.unbind(D,-1)
D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
return D_features
def forward(self, pos_ligand_coarse, edit_residue, X, S_id, batch):
""" Featurize coordinates as an attributed graph """
X_ca = X[:,1,:]
edge_index = knn_graph(X_ca, k=self.top_k, batch=batch, flow='target_to_source')
edge_length = torch.norm(X_ca[edge_index[0]] - X_ca[edge_index[1]], dim=1)
RBF = self._rbf(edge_length)
E_idx = S_id[edge_index[1]] - S_id[edge_index[0]]
E_positional = self.pe(E_idx)
O_features = self._orientations_coarse(X_ca, edge_index, batch)
E = torch.cat([E_positional, RBF, O_features], -1)
# additional edge index
row = torch.arange(len(edit_residue)).to(X.device)[edit_residue]
col = torch.cat([torch.ones(edit_residue[batch==s].sum(), dtype=torch.long)*s for s in range(batch.max().item()+1)]).to(X.device)
edge_length_new = torch.norm(X_ca[row] - pos_ligand_coarse[col], dim=1)
RBF = self._rbf(edge_length_new)
E_new = torch.cat([torch.zeros(len(row), 16, device=X.device), RBF, torch.zeros(len(row), 7, device=X.device)], -1)
return E, edge_index, edge_length, torch.cat([row.unsqueeze(0), (col+len(X)).unsqueeze(0)], 0), E_new