# NOTE: This file was recovered from a web capture (a Hugging Face "Spaces"
# page whose status banner read "Runtime error"); the banner text is kept
# here only as provenance and has no effect on the code below.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from diffab.modules.common.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm | |
| from diffab.modules.common.layers import mask_zero, LayerNorm | |
| from diffab.utils.protein.constants import BBHeavyAtom | |
| def _alpha_from_logits(logits, mask, inf=1e5): | |
| """ | |
| Args: | |
| logits: Logit matrices, (N, L_i, L_j, num_heads). | |
| mask: Masks, (N, L). | |
| Returns: | |
| alpha: Attention weights. | |
| """ | |
| N, L, _, _ = logits.size() | |
| mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *) | |
| mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *) | |
| logits = torch.where(mask_pair, logits, logits - inf) | |
| alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads) | |
| alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha)) | |
| return alpha | |
| def _heads(x, n_heads, n_ch): | |
| """ | |
| Args: | |
| x: (..., num_heads * num_channels) | |
| Returns: | |
| (..., num_heads, num_channels) | |
| """ | |
| s = list(x.size())[:-1] + [n_heads, n_ch] | |
| return x.view(*s) | |
class GABlock(nn.Module):
    """
    Geometric attention block in the style of Invariant Point Attention.

    Attention logits are the scaled sum of three terms: a node-feature
    query/key dot product, a learned per-head bias projected from pair
    features, and a distance penalty between query/key points expressed
    in global coordinates. The shared attention weights then aggregate
    pair features, node values and 3D value points; the node update is
    residual with two layer norms and a transition MLP.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
                 num_value_points=8, num_heads=12, bias=False):
        """
        Args:
            node_feat_dim:    Dimension of per-residue node features.
            pair_feat_dim:    Dimension of pairwise features.
            value_dim:        Channels per head for node values.
            query_key_dim:    Channels per head for queries and keys.
            num_query_points: 3D points per head used for queries and keys.
            num_value_points: 3D points per head used for values.
            num_heads:        Number of attention heads.
            bias:             Whether projection layers use bias terms.
        """
        super().__init__()
        self.node_feat_dim = node_feat_dim
        self.pair_feat_dim = pair_feat_dim
        self.value_dim = value_dim
        self.query_key_dim = query_key_dim
        self.num_query_points = num_query_points
        self.num_value_points = num_value_points
        self.num_heads = num_heads
        # Node-feature attention projections.
        self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)
        # Pair-feature bias projection (one scalar logit per head).
        self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)
        # Spatial (point) attention. `spatial_coef` is the per-head gamma,
        # stored pre-softplus and initialized so that softplus(coef) == 1.
        self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
                                         requires_grad=True)
        self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)
        # Output transform over the concatenation of pair, node and spatial
        # aggregations; each value point contributes 3 (coords) + 3
        # (direction) + 1 (distance) features.
        self.out_transform = nn.Linear(
            in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
                    num_heads * num_value_points * (3 + 3 + 1)),
            out_features=node_feat_dim,
        )
        self.layer_norm_1 = LayerNorm(node_feat_dim)
        self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
                                            nn.Linear(node_feat_dim, node_feat_dim))
        self.layer_norm_2 = LayerNorm(node_feat_dim)

    def _node_logits(self, x):
        """Scaled query/key dot-product logits, (N, L, L, num_heads)."""
        query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim)  # (N, L, n_heads, qk_ch)
        logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
                       (1 / np.sqrt(self.query_key_dim))).sum(-1)  # (N, L, L, num_heads)
        return logits_node

    def _pair_logits(self, z):
        """Per-head bias logits from pair features, (N, L, L, num_heads)."""
        logits_pair = self.proj_pair_bias(z)
        return logits_pair

    def _spatial_logits(self, R, t, x):
        """Distance-based logits between query and key points, (N, L, L, num_heads)."""
        N, L, _ = t.size()
        # Query points, mapped from each residue's local frame to global coordinates.
        query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
                              3)  # (N, L, n_heads * n_pnts, 3)
        query_points = local_to_global(R, t, query_points)  # Global query coordinates, (N, L, n_heads * n_pnts, 3)
        query_s = query_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
        # Key points, same treatment (same number of points as queries).
        key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
                            3)  # (N, L, n_heads * n_pnts, 3)
        key_points = local_to_global(R, t, key_points)  # Global key coordinates, (N, L, n_heads * n_pnts, 3)
        key_s = key_points.reshape(N, L, self.num_heads, -1)  # (N, L, n_heads, n_pnts*3)
        # Sum of squared distances between all query/key point sets.
        sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1)  # (N, L, L, n_heads)
        gamma = F.softplus(self.spatial_coef)  # keep the learned per-head weight positive
        logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
                                        / 2)  # (N, L, L, n_heads)
        return logits_spatial

    def _pair_aggregation(self, alpha, z):
        """Attention-weighted sum of pair features, (N, L, n_heads*C)."""
        N, L = z.shape[:2]
        feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2)  # (N, L, L, n_heads, C)
        feat_p2n = feat_p2n.sum(dim=2)  # (N, L, n_heads, C)
        return feat_p2n.reshape(N, L, -1)

    def _node_aggregation(self, alpha, x):
        """Attention-weighted sum of node values, (N, L, n_heads*v_ch)."""
        N, L = x.shape[:2]
        # BUGFIX: split value heads with `value_dim` channels per head.
        # The previous code used `query_key_dim`, which breaks whenever
        # value_dim != query_key_dim because `proj_value` outputs
        # value_dim * num_heads features (it only worked because both
        # default to 32).
        value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim)  # (N, L, n_heads, v_ch)
        feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1)  # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
        feat_node = feat_node.sum(dim=2)  # (N, L, n_heads, v_ch)
        return feat_node.reshape(N, L, -1)

    def _spatial_aggregation(self, alpha, R, t, x):
        """
        Aggregate value points in global space, then re-express them locally.

        Returns concatenated local coordinates, distances and unit
        directions, (N, L, n_heads * n_v_pnts * (3 + 1 + 3)).
        """
        N, L, _ = t.size()
        value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
                              3)  # (N, L, n_heads * n_v_pnts, 3)
        value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
                                                                  3))  # (N, L, n_heads, n_v_pnts, 3)
        aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
                      value_points.unsqueeze(1)  # (N, *, L, n_heads, n_pnts, 3)
        aggr_points = aggr_points.sum(dim=2)  # (N, L, n_heads, n_pnts, 3)
        # Back into each residue's local frame so the features are
        # invariant to global rotation/translation.
        feat_points = global_to_local(R, t, aggr_points)  # (N, L, n_heads, n_pnts, 3)
        feat_distance = feat_points.norm(dim=-1)  # (N, L, n_heads, n_pnts)
        feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4)  # (N, L, n_heads, n_pnts, 3)
        feat_spatial = torch.cat([
            feat_points.reshape(N, L, -1),
            feat_distance.reshape(N, L, -1),
            feat_direction.reshape(N, L, -1),
        ], dim=-1)
        return feat_spatial

    def forward(self, R, t, x, z, mask):
        """
        Args:
            R: Frame basis matrices, (N, L, 3, 3_index).
            t: Frame external (absolute) coordinates, (N, L, 3).
            x: Node-wise features, (N, L, F).
            z: Pair-wise features, (N, L, L, C).
            mask: Masks, (N, L).
        Returns:
            x': Updated node-wise features, (N, L, F).
        """
        # Attention logits from the three sources.
        logits_node = self._node_logits(x)
        logits_pair = self._pair_logits(z)
        logits_spatial = self._spatial_logits(R, t, x)
        # Sum the three logit terms, scale by 1/sqrt(3), and apply a
        # masked softmax to obtain shared attention weights.
        logits_sum = logits_node + logits_pair + logits_spatial
        alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask)  # (N, L, L, n_heads)
        # Aggregate features with the shared attention weights.
        feat_p2n = self._pair_aggregation(alpha, z)
        feat_node = self._node_aggregation(alpha, x)
        feat_spatial = self._spatial_aggregation(alpha, R, t, x)
        # Output transform, then residual + LayerNorm and a transition MLP.
        feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1))  # (N, L, F)
        feat_all = mask_zero(mask.unsqueeze(-1), feat_all)  # zero out padded positions
        x_updated = self.layer_norm_1(x + feat_all)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated
class GAEncoder(nn.Module):
    """Stack of `GABlock` layers applied sequentially to node features."""

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt=None):
        """
        Args:
            node_feat_dim: Dimension of per-residue node features.
            pair_feat_dim: Dimension of pairwise features.
            num_layers:    Number of stacked GABlocks.
            ga_block_opt:  Optional dict of extra GABlock keyword arguments.
        """
        super().__init__()
        # Avoid the mutable-default-argument anti-pattern; `None` means
        # "no extra options" and behaves exactly like the old `{}` default.
        ga_block_opt = {} if ga_block_opt is None else ga_block_opt
        self.blocks = nn.ModuleList([
            GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt)
            for _ in range(num_layers)
        ])

    def forward(self, R, t, res_feat, pair_feat, mask):
        """Run every block in order, threading node features through."""
        for block in self.blocks:
            res_feat = block(R, t, res_feat, pair_feat, mask)
        return res_feat