import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_sum, scatter_min

from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl
from src.modules.graphtrans_module import *
from src.modules.gca_module import Local_Module, Global_Module


class GCA_Model(nn.Module):
    def __init__(self, args, **kwargs):
        """Graph labeling network."""
        super(GCA_Model, self).__init__()
        self.node_features = args.hidden
        self.edge_features = args.hidden
        self.hidden = args.hidden
        self.top_k = args.k_neighbors
        self.num_rbf = 16
        self.num_positional_embeddings = 16

        vocab = args.vocab_size
        num_encoder_layers = args.num_encoder_layers
        num_decoder_layers = args.num_decoder_layers
        is_attention = args.is_attention
        dropout = args.dropout

        # node_in, edge_in = 6, 39 - 16
        node_in, edge_in = 12, 39 - 16
        self.embeddings = PositionalEncodings(self.num_positional_embeddings)
        self.node_embedding = nn.Linear(node_in, self.node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, self.edge_features, bias=True)
        self.norm_nodes = Normalize(self.node_features)
        self.norm_edges = Normalize(self.edge_features)

        self.W_v = nn.Linear(self.node_features, self.hidden, bias=True)
        self.W_e = nn.Linear(self.edge_features, self.hidden, bias=True)
        self.W_f = nn.Linear(self.edge_features, self.hidden, bias=True)
        self.W_s = nn.Embedding(vocab, self.hidden)

        self.encoder_layers = nn.ModuleList([])
        for _ in range(num_encoder_layers):
            self.encoder_layers.append(nn.ModuleList([
                Local_Module(self.hidden, self.hidden * 2, is_attention=is_attention, dropout=dropout),
                Global_Module(self.hidden, self.hidden * 2, dropout=dropout),
            ]))

        self.decoder_layers = nn.ModuleList([])
        for _ in range(num_decoder_layers):
            self.decoder_layers.append(
                Local_Module(self.hidden, self.hidden * 3, is_attention=is_attention, dropout=dropout)
            )

        self.W_out = nn.Linear(self.hidden, vocab, bias=True)
        self._init_params()

        self.encode_t = 0
        self.decode_t = 0

    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _autoregressive_mask(self, E_idx):
        # 1 where a neighbor's index precedes the query position (i.e. the
        # neighbor has already been decoded), 0 otherwise.
        N_nodes = E_idx.size(1)
        ii = torch.arange(N_nodes)
        ii = ii.view((1, -1, 1)).to(E_idx.device)
        mask = E_idx - ii < 0
        mask = mask.type(torch.float32)
        return mask

    def _get_encoder_mask(self, idx, mask):
        mask_attend = gather_nodes(mask.unsqueeze(-1), idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        return mask_attend

    def _get_decoder_mask(self, idx, mask):
        mask_attend = self._autoregressive_mask(idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend         # backward: already-decoded neighbors
        mask_fw = mask_1D * (1. - mask_attend)  # forward: not-yet-decoded neighbors
        return mask_bw, mask_fw

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        # Pairwise Ca-Ca distances; padded positions are pushed to D_max so
        # they are never selected by topk.
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx
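    # Illustrative sketch (not part of the original model): how `_full_dist`
    # is typically called. Shapes assume a hypothetical batch of 2 structures
    # with 100 residues each; `top_k=30` mirrors `args.k_neighbors`.
    #
    #     X_ca = torch.randn(2, 100, 3)           # Ca coordinates
    #     mask = torch.ones(2, 100)               # 1 = valid residue
    #     D, E_idx = model._full_dist(X_ca, mask, top_k=30)
    #     # D:     [2, 100, 30]  distances to the 30 nearest residues
    #     # E_idx: [2, 100, 30]  their indices, nearest first; E_idx[..., 0]
    #     #        is typically the residue itself (self-distance ~sqrt(eps))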
    def _encoder_network(self, h_V, h_P, h_F, P_idx, F_idx, mask):
        '''
        h_V:   [batch, num_nodes, 128]             node embeddings
        h_P:   [batch, num_nodes, K, 128]          local (k-NN) edge embeddings
        h_F:   [batch, num_nodes, num_nodes, 128]  global (full-graph) edge embeddings
        P_idx: [batch, num_nodes, K]
        F_idx: [batch, num_nodes, num_nodes]
        mask:  [batch, num_nodes]
        '''
        P_idx_mask_attend = self._get_encoder_mask(P_idx, mask)  # partial (k-NN) graph
        F_idx_mask_attend = self._get_encoder_mask(F_idx, mask)  # full graph
        for (local_layer, global_layer) in self.encoder_layers:
            # local layer
            h_EV_local = cat_neighbors_nodes(h_V, h_P, P_idx)  # [batch, num_nodes, K, 256]
            h_V = local_layer(h_V, h_EV_local, mask_V=mask, mask_attend=P_idx_mask_attend)
            # global layer (residual connection)
            h_EV_global = cat_neighbors_nodes(h_V, h_F, F_idx)
            h_V = h_V + global_layer(h_V, h_EV_global, mask_V=mask, mask_attend=F_idx_mask_attend)
        return h_V

    def _get_sv_encoder(self, S, h_V, h_P, P_idx):
        h_S = self.W_s(S)
        h_PS = cat_neighbors_nodes(h_S, h_P, P_idx)
        # Encoder-side context with the sequence information zeroed out.
        h_PS_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_P, P_idx)
        h_PSV_encoder = cat_neighbors_nodes(h_V, h_PS_encoder, P_idx)
        return h_PS, h_PSV_encoder

    def _get_features(self, batch):
        S, X, mask, chain_mask = batch['S'], batch['X'], batch['mask'], batch['chain_mask']
        X_ca = X[:, :, 1, :]
        D_neighbors, F_idx = self._full_dist(X_ca, mask, 500)
        P_idx = F_idx[:, :, :self.top_k].clone()

        _V = _dihedrals(X)  # node features
        _V = self.norm_nodes(self.node_embedding(_V))
        _F = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, F_idx)), -1)
        _F = self.norm_edges(self.edge_embedding(_F))
        _P = _F[..., :self.top_k, :]

        h_V = self.W_v(_V)
        h_P, h_F = self.W_e(_P), self.W_f(_F)

        batch.update({'S': S, 'h_V': h_V, 'h_P': h_P, 'h_F': h_F,
                      'P_idx': P_idx, 'F_idx': F_idx, 'mask': mask})
        return batch

    def sparse_to_dense(self, S, h_V, h_P, edge_idx_P, h_F, edge_idx_F, batch_id):
        device = h_V.device
        num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id)
        batch = num_nodes.shape[0]
        N = num_nodes.max()

        # sequence labels
        S_ = torch.zeros([batch, N], device=device).long()
        row = batch_id
        col = torch.cat([torch.arange(0, n) for n in num_nodes]).to(device)
        S_[row, col] = S
        S = S_

        # node features
        dim_V = h_V.shape[-1]
        h_V_ = torch.zeros([batch, N, dim_V], device=device)
        h_V_[row, col] = h_V
        h_V = h_V_

        mask = torch.zeros([batch, N], device=device)
        mask[row, col] = 1

        # local (k-NN) edge features
        K = 30
        dim_P = h_P.shape[-1]
        h_P_ = torch.zeros([batch, N, K, dim_P], device=device)
        row2 = batch_id[edge_idx_P[0]]
        # Shift global node indices to per-sample local indices.
        batch_shift, _ = scatter_min(edge_idx_P[0], batch_id[edge_idx_P[0]])
        local_dst_idx = edge_idx_P[0] - batch_shift[batch_id[edge_idx_P[0]]]
        local_src_idx = edge_idx_P[1] - batch_shift[batch_id[edge_idx_P[1]]]
        nn_num = scatter_sum(torch.ones_like(edge_idx_P[0]), edge_idx_P[0])
        nn_idx = torch.cat([torch.arange(0, n) for n in nn_num]).to(device)
        h_P_[row2, local_dst_idx, nn_idx] = h_P
        h_P = h_P_
        P_idx = torch.arange(0, K, device=device).reshape(1, 1, K).repeat(batch, N, 1)
        P_idx[row2, local_dst_idx, nn_idx] = local_src_idx

        # global (full-graph) edge features
        K = N
        dim_F = h_F.shape[-1]
        h_F_ = torch.zeros([batch, N, K, dim_F], device=device)
        row2 = batch_id[edge_idx_F[0]]
        batch_shift, _ = scatter_min(edge_idx_F[0], batch_id[edge_idx_F[0]])
        local_dst_idx = edge_idx_F[0] - batch_shift[batch_id[edge_idx_F[0]]]
        local_src_idx = edge_idx_F[1] - batch_shift[batch_id[edge_idx_F[1]]]
        nn_num = scatter_sum(torch.ones_like(edge_idx_F[0]), edge_idx_F[0])
        nn_idx = torch.cat([torch.arange(0, n) for n in nn_num]).to(device)
        h_F_[row2, local_dst_idx, nn_idx] = h_F
        h_F = h_F_
        F_idx = torch.arange(0, K, device=device).reshape(1, 1, K).repeat(batch, N, 1)
        F_idx[row2, local_dst_idx, nn_idx] = local_src_idx

        return S, h_V, h_P, h_F, P_idx, F_idx, mask
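    # Illustrative sketch (not part of the original model): the input layout
    # `sparse_to_dense` expects. For a hypothetical pair of graphs with 3 and
    # 2 nodes, nodes are numbered globally across the batch and edge_idx_* is
    # [2, num_edges] with row 0 = destination, row 1 = source:
    #
    #     batch_id   = torch.tensor([0, 0, 0, 1, 1])
    #     edge_idx_P = torch.tensor([[0, 0, 1, 3],    # dst (global ids)
    #                                [1, 2, 0, 4]])   # src (global ids)
    #
    # scatter_min recovers each sample's first global id (0 and 3 here), which
    # is subtracted to get per-sample indices; features are then scattered
    # into dense [batch, N, ...] tensors, zero-padded, with a validity mask.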
    def forward(self, batch):
        h_V, h_P, h_F, P_idx, F_idx, S, mask = batch['h_V'], batch['h_P'], batch['h_F'], \
            batch['P_idx'], batch['F_idx'], batch['S'], batch['mask']

        t1 = time.time()
        h_V = self._encoder_network(h_V, h_P, h_F, P_idx, F_idx, mask)
        h_PS, h_PSV_encoder = self._get_sv_encoder(S, h_V, h_P, P_idx)
        t2 = time.time()

        # Decoder: attend to sequence-conditioned features of already-decoded
        # neighbors (backward mask) and to structure-only encoder features
        # elsewhere (forward mask).
        P_idx_mask_bw, P_idx_mask_fw = self._get_decoder_mask(P_idx, mask)
        for local_layer in self.decoder_layers:
            h_PSV_local = cat_neighbors_nodes(h_V, h_PS, P_idx)
            h_PSV_local = P_idx_mask_bw * h_PSV_local + P_idx_mask_fw * h_PSV_encoder
            h_V = local_layer(h_V, h_PSV_local, mask_V=mask)

        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        t3 = time.time()

        self.encode_t += t2 - t1
        self.decode_t += t3 - t2
        return {'log_probs': log_probs}

    def sample(self, h_V, h_P, h_F, P_idx, F_idx, mask=None, temperature=0.1, **kwargs):
        t1 = time.time()
        h_V = self._encoder_network(h_V, h_P, h_F, P_idx, F_idx, mask)
        t2 = time.time()

        # Autoregressive decoding, one position at a time.
        P_idx_mask_bw, P_idx_mask_fw = self._get_decoder_mask(P_idx, mask)
        N_batch, N_nodes = h_V.size(0), h_V.size(1)
        h_S = torch.zeros_like(h_V)
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=h_V.device)
        # h_V_stack[l] caches the node states entering decoder layer l.
        h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.decoder_layers))]
        all_probs = []
        for t in range(N_nodes):
            # Hidden layers
            P_idx_t = P_idx[:, t:t+1, :]
            h_P_t = h_P[:, t:t+1, :, :]
            h_PS_t = cat_neighbors_nodes(h_S, h_P_t, P_idx_t)
            h_PSV_encoder_t = P_idx_mask_fw[:, t:t+1, :, :] * cat_neighbors_nodes(h_V, h_PS_t, P_idx_t)
            for l, local_layer in enumerate(self.decoder_layers):
                h_PSV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_PS_t, P_idx_t)
                h_V_t = h_V_stack[l][:, t:t+1, :]
                h_PSV_t = P_idx_mask_bw[:, t:t+1, :, :] * h_PSV_decoder_t + h_PSV_encoder_t
                h_V_stack[l + 1][:, t, :] = local_layer(
                    h_V_t, h_PSV_t, mask_V=mask[:, t:t+1]
                ).squeeze(1)

            # Sampling step
            h_V_t = h_V_stack[-1][:, t, :]
            logits = self.W_out(h_V_t) / temperature
            probs = F.softmax(logits, dim=-1)
            S_t = torch.multinomial(probs, 1).squeeze(-1)

            # Update
            h_S[:, t, :] = self.W_s(S_t)
            S[:, t] = S_t
            all_probs.append(probs)

        self.probs = torch.cat(all_probs, dim=0)
        t3 = time.time()
        self.encode_t += t2 - t1
        self.decode_t += t3 - t2
        return S
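
if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original file).
    # The argument values and tensor shapes below are assumptions; the repo's
    # feature modules imported at the top must be importable for this to run.
    from types import SimpleNamespace

    args = SimpleNamespace(hidden=128, k_neighbors=30, vocab_size=20,
                           num_encoder_layers=3, num_decoder_layers=3,
                           is_attention=1, dropout=0.1)
    model = GCA_Model(args)

    B, L = 2, 50
    batch = {
        'X': torch.randn(B, L, 4, 3),        # backbone atoms; X[:, :, 1] is read as Ca
        'S': torch.randint(0, 20, (B, L)),   # residue labels
        'mask': torch.ones(B, L),
        'chain_mask': torch.ones(B, L),
    }
    batch = model._get_features(batch)
    log_probs = model(batch)['log_probs']    # [B, L, vocab]
    S = model.sample(batch['h_V'], batch['h_P'], batch['h_F'],
                     batch['P_idx'], batch['F_idx'], mask=batch['mask'])
    print(log_probs.shape, S.shape)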