| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class DihedralFeatures(nn.Module):
    """Embed backbone dihedral angles as node features.

    The three dihedral angles per residue are encoded as
    ``[cos(d1), cos(d2), cos(d3), sin(d1), sin(d2), sin(d3)]`` (6 values).
    They are then projected with a linear layer and normalized with
    ``Normalize``.
    """

    def __init__(self, node_embed_dim):
        """Embed dihedral angle features.

        Args:
            node_embed_dim: Size of the output node embedding.
        """
        super(DihedralFeatures, self).__init__()

        # cos and sin of the three per-residue dihedral angles.
        node_in = 6

        self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
        self.norm_nodes = Normalize(node_embed_dim)

    def forward(self, X):
        """Featurize coordinates as an attributed graph.

        Args:
            X: Coordinates in one of the layouts accepted by ``_dihedrals``
                (3-D or 4-D tensor).

        Returns:
            Embedded, normalized node features whose last dimension is
            ``node_embed_dim``.
        """
        # The angle features are computed without autograd, so no gradient
        # flows back into X.
        with torch.no_grad():
            V = self._dihedrals(X)
            # Removes the singleton axis 1 produced for 3-D input (and, as a
            # side effect, the residue axis of 4-D input when L == 1).
            V = V.squeeze(1)
        V = self.node_embedding(V)
        V = self.norm_nodes(V)
        return V

    @staticmethod
    def _dihedrals(X, eps=1e-7):
        """Compute cos/sin features of the backbone dihedral angles.

        Only the first three atoms of each residue are used. These are
        presumably N, CA, C (TODO: confirm the atom order against the caller).

        Args:
            X: Either ``(B, L, atoms, 3)`` (batched) or ``(L, atoms, 3)``
                (a single structure), with ``atoms >= 3``.
            eps: Margin that clamps the cosine into ``[-1 + eps, 1 - eps]``.
                This avoids the unbounded derivative of ``acos`` at +/-1.

        Returns:
            ``(B, L, 6)`` for 4-D input and ``(L, 1, 6)`` for 3-D input.
            The one leading and two trailing chain positions have no
            computable dihedral; they are padded with angle 0
            (cos = 1, sin = 0).

        Raises:
            ValueError: If ``X`` is neither 3-D nor 4-D.
        """
        if X.dim() == 4:
            single = False
            # Flatten the backbone atoms into one chain per batch element.
            X = X[..., :3, :].reshape(X.shape[0], 3 * X.shape[1], 3)
        elif X.dim() == 3:
            single = True
            # Treat a single structure as a batch of one chain. Slicing the
            # atoms of each residue without flattening would leave every
            # residue as its own 3-atom "chain". That chain has no dihedral,
            # so the features would be constant.
            X = X[:, :3, :].reshape(1, 3 * X.shape[0], 3)
        else:
            raise ValueError(
                f"expected a 3-D or 4-D coordinate tensor, got shape {tuple(X.shape)}"
            )

        # Unit bond vectors between consecutive backbone atoms.
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]

        # Unit normals of the two planes that share the middle bond u_1.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between the planes. Its sign says which side of plane n_1
        # the bond u_2 lies on.
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # A chain of 3L atoms yields 3L - 3 dihedrals. Pad one in front and
        # two behind so they regroup into L residues x 3 angles.
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view(D.size(0), D.size(1) // 3, 3)

        D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
        if single:
            # (1, L, 6) -> (L, 1, 6): keeps the 3-D output shape whose
            # axis 1 is squeezed away in forward().
            D_features = D_features.transpose(0, 1)
        return D_features
|
|
|
|
class Normalize(nn.Module):
    """Layer normalization with a learnable per-feature gain and bias.

    Args:
        features: Length of the normalized axis; sizes ``gain`` and ``bias``.
        epsilon: Stabilizer. It is added both inside the square root and to
            the resulting standard deviation before dividing.
    """

    def __init__(self, features, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.epsilon = epsilon

    def forward(self, x, dim=-1):
        """Normalize ``x`` along ``dim``, then apply ``gain`` and ``bias``.

        Variance uses torch's default ``Tensor.var`` estimator (unbiased).
        """
        centered = x - x.mean(dim, keepdim=True)
        std = (x.var(dim, keepdim=True) + self.epsilon).sqrt()
        scale, shift = self.gain, self.bias

        if dim != -1:
            # Reshape the 1-D parameters so they broadcast along `dim`.
            view_shape = [1] * x.dim()
            view_shape[dim] = self.gain.size(0)
            scale = scale.view(view_shape)
            shift = shift.view(view_shape)

        return scale * centered / (std + self.epsilon) + shift