Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import copy | |
| from torch_geometric.nn import radius_graph, knn_graph | |
class PositionalEncodings(nn.Module):
    """Sinusoidal encoding of (integer) index offsets.

    Every input value is multiplied by a bank of geometrically decaying
    frequencies; the cosines and sines of the resulting angles are
    concatenated along a new trailing axis (cosines first).
    """

    def __init__(self, num_embeddings):
        """Remember the nominal encoding width `num_embeddings`."""
        super().__init__()
        self.num_embeddings = num_embeddings

    def forward(self, E_idx):
        """Return ``cat(cos(angles), sin(angles))`` for the offsets in `E_idx`.

        One frequency is generated per even integer in
        ``[0, num_embeddings)``, so the trailing axis of the result holds
        twice that many values.
        """
        width = self.num_embeddings
        # Frequencies start at 1 and decay towards 1/10000.
        decay = -(np.log(10000.0) / width)
        even_steps = torch.arange(0, width, 2, dtype=torch.float32)
        # Built on the default device first, then moved next to the input.
        frequency = torch.exp(even_steps * decay).to(E_idx.device)
        phase = E_idx.unsqueeze(-1) * frequency
        return torch.cat((torch.cos(phase), torch.sin(phase)), -1)
class ProteinFeatures(nn.Module):
    """Edge featurizer for a k-nearest-neighbour graph over protein residues.

    `forward` builds, for every k-NN edge, the concatenation of
    (1) a sinusoidal encoding of the sequence-index offset,
    (2) a radial-basis expansion of the edge length and
    (3) a 7-dim relative orientation (unit direction + quaternion).
    It additionally creates edges from the residues selected by
    `edit_residue` to one extra node per graph; those edges only carry the
    distance expansion, the other feature slots are zero.

    The remaining private helpers (`_contacts`, `_hbonds`, `_AD_features`,
    `_dihedrals`) are not used by `forward`; they operate on batched
    tensors with an explicit leading batch axis.
    """

    def __init__(self, num_positional_embeddings=16, num_rbf=16, top_k=8, features_type='backbone', direction='forward'):
        """ Extract protein features """
        super(ProteinFeatures, self).__init__()
        self.top_k = top_k
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.direction = direction

        # Feature types
        self.features_type = features_type
        # positional encoding + RBF + (3 direction + 4 quaternion) components
        self.feature_dimensions = num_positional_embeddings + num_rbf + 7

        # Positional encoding
        self.pe = PositionalEncodings(num_positional_embeddings)

    def _rbf(self, D):
        """Expand distances `D` in Gaussian radial basis functions.

        Centres are `num_rbf` evenly spaced values in [0, 20]; the width is
        ``20 / num_rbf``.  The result has shape ``D.shape + (num_rbf,)``.
        """
        D_min, D_max, D_count = 0., 20., self.num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
        D_sigma = (D_max - D_min) / D_count
        # Plain broadcasting against the trailing centre axis keeps every
        # leading axis of D intact (including axes of size 1).
        return torch.exp(-((D.unsqueeze(-1) - D_mu) / D_sigma) ** 2)

    def _quaternions(self, R):
        """ Convert a batch of 3D rotations [R] to quaternions [Q]
            R [...,3,3]
            Q [...,4]

        The components are ordered (x, y, z, w) and the result is
        normalised to unit length.
        """
        # Simple Wikipedia version
        # en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # For other options see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
        diag = torch.diagonal(R, dim1=-2, dim2=-1)
        Rxx, Ryy, Rzz = diag.unbind(-1)
        magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
            Rxx - Ryy - Rzz,
            - Rxx + Ryy - Rzz,
            - Rxx - Ryy + Rzz
        ], -1)))
        # Ellipsis indexing so any number of leading batch axes is accepted,
        # as the docstring promises.
        signs = torch.sign(torch.stack([
            R[..., 2, 1] - R[..., 1, 2],
            R[..., 0, 2] - R[..., 2, 0],
            R[..., 1, 0] - R[..., 0, 1]
        ], -1))
        xyz = signs * magnitudes
        # The relu enforces a non-negative trace
        w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
        Q = torch.cat((xyz, w), -1)
        return F.normalize(Q, dim=-1)

    def _contacts(self, D_neighbors, E_idx, mask_neighbors, cutoff=8):
        """ Contacts

        Returns `mask_neighbors` multiplied by a float indicator of
        ``D_neighbors < cutoff`` (a trailing axis is added to the
        distances first).  `E_idx` is accepted but not used.
        """
        D_neighbors = D_neighbors.unsqueeze(-1)
        neighbor_C = mask_neighbors * (D_neighbors < cutoff).type(torch.float32)
        return neighbor_C

    def _hbonds(self, X, E_idx, mask_neighbors, eps=1E-3):
        """ Hydrogen bonds and contact map

        `X` is unbound along axis 2 into the atoms N, CA, C, O, so it is
        expected to be laid out as [batch, residue, atom, xyz].
        """
        X_atoms = dict(zip(['N', 'CA', 'C', 'O'], torch.unbind(X, 2)))

        # Virtual hydrogens
        # C of the *previous* residue: shift right by one along the residue
        # axis and zero-fill the first position.
        X_atoms['C_prev'] = F.pad(X_atoms['C'][:, :-1, :], (0, 0, 1, 0), 'constant', 0)
        # `dim` must be passed by keyword: the second positional argument of
        # F.normalize is the norm order `p`, not the axis.
        X_atoms['H'] = X_atoms['N'] + F.normalize(
            F.normalize(X_atoms['N'] - X_atoms['C_prev'], dim=-1)
            + F.normalize(X_atoms['N'] - X_atoms['CA'], dim=-1),
            dim=-1)

        def _distance(X_a, X_b):
            return torch.norm(X_a[:, None, :, :] - X_b[:, :, None, :], dim=-1)

        def _inv_distance(X_a, X_b):
            return 1. / (_distance(X_a, X_b) + eps)

        # DSSP vacuum electrostatics model
        U = (0.084 * 332) * (
            _inv_distance(X_atoms['O'], X_atoms['N'])
            + _inv_distance(X_atoms['C'], X_atoms['H'])
            - _inv_distance(X_atoms['O'], X_atoms['H'])
            - _inv_distance(X_atoms['C'], X_atoms['N'])
        )

        HB = (U < -0.5).type(torch.float32)
        # NOTE(review): `gather_edges` is not defined or imported in the
        # visible part of this file - confirm it is available before use.
        neighbor_HB = mask_neighbors * gather_edges(HB.unsqueeze(-1), E_idx)
        return neighbor_HB

    def _AD_features(self, X, eps=1e-6):
        """Bond-angle / dihedral features of a batched point chain.

        `X` is indexed as [batch, position, xyz].  For each interior
        position the bond angle A and dihedral D are encoded as
        ``(cos A, sin A cos D, sin A sin D)``; one zero row is padded in
        front and two behind so the position axis keeps its length.
        """
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (explicit dim: without it torch.cross picks the
        # first axis of size 3, which may be the batch or position axis)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Bond angle calculation
        cosA = -(u_1 * u_0).sum(-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
        A = torch.acos(cosA)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Backbone features
        AD_features = torch.stack((torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2)
        return F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)

    def _orientations_coarse(self, X, edge_index, residue_batch, eps=1e-6):
        """Relative orientation features for every edge.

        `X` is indexed as [node, xyz].  A local 3x3 frame is built for each
        interior node from its two adjacent displacement vectors (first
        and last node get an all-zero frame).  For an edge (row, col) the
        result is the unit displacement ``X[col] - X[row]`` expressed in
        the frame of `col`, followed by the quaternion of
        ``O[row] @ O[col]^T`` - 7 values per edge.

        NOTE(review): `residue_batch` and `eps` are currently unused;
        frames that straddle the boundary between two graphs of a
        concatenated batch are not zeroed.
        """
        # Shifted slices of unit vectors
        dX = X[1:, :] - X[:-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:-1, :]
        u_1 = U[1:, :]
        # Backbone normals (explicit dim: with exactly 5 nodes the leading
        # axis has size 3 and would otherwise be used for the cross product)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        row, col = edge_index  # (E,) , (E,)

        # Build relative orientations
        o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.cat([o_1, n_2, torch.cross(o_1, n_2, dim=-1)], dim=-1)
        O = F.pad(O, (0, 0, 1, 1), 'constant', 0)
        # Re-view as rotation matrices
        O = O.view(O.shape[0], 3, 3)

        # Rotate into local reference frames
        dX = X[col] - X[row]
        dU = torch.matmul(O[col], dX.unsqueeze(-1)).squeeze(-1)
        dU = F.normalize(dU, dim=-1)
        R = torch.matmul(O[row], O[col].transpose(-1, -2))
        Q = self._quaternions(R)
        return torch.cat((dU, Q), dim=-1)

    def _dihedrals(self, X, eps=1e-7):
        """Cosine/sine dihedral features of a batched backbone.

        `X` is indexed as [batch, residue, atom, xyz]; only the first three
        atoms of each residue are used.  Returns ``cat(cos D, sin D)`` with
        three dihedrals per residue, i.e. 6 values per residue.
        """
        # First 3 coordinates are N, CA, C (slice the atom axis, then chain
        # all backbone atoms along one axis)
        X = X[:, :, :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (explicit dim, see _AD_features)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Three zeros are prepended so the angle axis is divisible by 3;
        # the first residue therefore gets all-zero angles.
        D = F.pad(D, (3, 0), 'constant', 0)
        D = D.view((D.size(0), D.size(1) // 3, 3))
        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        return D_features

    def forward(self, pos_ligand_coarse, edit_residue, X, S_id, batch):
        """ Featurize coordinates as an attributed graph

        Args:
            pos_ligand_coarse: one coordinate per graph, indexed by graph id.
            edit_residue: boolean mask over residues selecting those that
                get an extra edge to their graph's entry of
                `pos_ligand_coarse`.
            X: residue coordinates indexed as [residue, atom, xyz]; atom
                index 1 is used as the residue position.
            S_id: per-residue sequence index.
            batch: per-residue graph id (sorted, as k-NN batching expects).

        Returns:
            Tuple ``(E, edge_index, edge_length, extra_edge_index, E_new)``:
            k-NN edge features, k-NN edge index, k-NN edge lengths, a
            2 x M index whose second row is ``graph id + len(X)``
            (presumably the extra nodes are appended after the residues -
            verify against the caller), and the features of those M edges.
        """
        X_ca = X[:, 1, :]
        edge_index = knn_graph(X_ca, k=self.top_k, batch=batch, flow='target_to_source')
        edge_length = torch.norm(X_ca[edge_index[0]] - X_ca[edge_index[1]], dim=1)
        RBF = self._rbf(edge_length)
        E_idx = S_id[edge_index[1]] - S_id[edge_index[0]]
        E_positional = self.pe(E_idx)
        O_features = self._orientations_coarse(X_ca, edge_index, batch)
        E = torch.cat([E_positional, RBF, O_features], -1)

        # additional edge index
        row = torch.arange(len(edit_residue)).to(X.device)[edit_residue]
        # Graph id of every selected residue; identical to repeating each
        # graph id by its number of selected residues when `batch` is sorted.
        col = batch.to(X.device)[row]
        edge_length_new = torch.norm(X_ca[row] - pos_ligand_coarse[col], dim=1)
        RBF_new = self._rbf(edge_length_new)
        # Zero blocks take their widths from the real features so E_new
        # always lines up with E, whatever num_positional_embeddings is.
        E_new = torch.cat([
            RBF_new.new_zeros(len(row), E_positional.size(-1)),
            RBF_new,
            RBF_new.new_zeros(len(row), O_features.size(-1)),
        ], -1)
        extra_edge_index = torch.stack([row, col + len(X)], 0)
        return E, edge_index, edge_length, extra_edge_index, E_new