Honzus24's picture
initial commit
7968cb0
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_sum, scatter_min
from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl
from src.modules.graphtrans_module import *
from src.modules.gca_module import Local_Module, Global_Module
class GCA_Model(nn.Module):
def __init__(self, args, **kwargs):
""" Graph labeling network """
super(GCA_Model, self).__init__()
self.node_features = args.hidden
self.edge_features = args.hidden
self.hidden = args.hidden
self.top_k = args.k_neighbors
self.num_rbf = 16
self.num_positional_embeddings = 16
vocab = args.vocab_size
num_encoder_layers = args.num_encoder_layers
num_decoder_layers = args.num_decoder_layers
is_attention = args.is_attention
dropout = args.dropout
# node_in, edge_in = 6, 39 - 16
node_in, edge_in = 12, 39 - 16
self.embeddings = PositionalEncodings(self.num_positional_embeddings)
self.node_embedding = nn.Linear(node_in, self.node_features, bias=True)
self.edge_embedding = nn.Linear(edge_in, self.edge_features, bias=True)
self.norm_nodes = Normalize(self.node_features)
self.norm_edges = Normalize(self.edge_features)
self.W_v = nn.Linear(self.node_features, self.hidden, bias=True)
self.W_e = nn.Linear(self.edge_features, self.hidden, bias=True)
self.W_f = nn.Linear(self.edge_features, self.hidden, bias=True)
self.W_s = nn.Embedding(vocab, self.hidden)
self.encoder_layers = nn.ModuleList([])
for _ in range(num_encoder_layers):
self.encoder_layers.append(nn.ModuleList([
Local_Module(self.hidden, self.hidden*2, is_attention=is_attention, dropout=dropout),
Global_Module(self.hidden, self.hidden*2, dropout=dropout)
]))
self.decoder_layers = nn.ModuleList([])
for _ in range(num_decoder_layers):
self.decoder_layers.append(
Local_Module(self.hidden, self.hidden*3, is_attention=is_attention, dropout=dropout)
)
self.W_out = nn.Linear(self.hidden, vocab, bias=True)
self._init_params()
self.encode_t = 0
self.decode_t = 0
def _init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def _autoregressive_mask(self, E_idx):
N_nodes = E_idx.size(1)
ii = torch.arange(N_nodes)
ii = ii.view((1, -1, 1)).to(E_idx.device)
mask = E_idx - ii < 0
mask = mask.type(torch.float32)
return mask
def _get_encoder_mask(self, idx, mask):
mask_attend = gather_nodes(mask.unsqueeze(-1), idx).squeeze(-1)
mask_attend = mask.unsqueeze(-1) * mask_attend
return mask_attend
def _get_decoder_mask(self, idx, mask):
mask_attend = self._autoregressive_mask(idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)
return mask_bw, mask_fw
def _full_dist(self, X, mask, top_k=30, eps=1E-6):
mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
D_max, _ = torch.max(D, -1, keepdim=True)
D_adjust = D + (1. - mask_2D) * D_max
D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
return D_neighbors, E_idx
def _encoder_network(self, h_V, h_P, h_F, P_idx, F_idx, mask):
'''
h_V: [batch, num_nodes, 128]
h_P: [batch, num_nodes, K, 128]
h_F: [batch, num_nodes, num_nodes, 128]
P_idx: [batch, num_nodes, K]
F_idx: [batch, num_nodes, num_nodes]
mask: [batch, num_nodes]
'''
P_idx_mask_attend = self._get_encoder_mask(P_idx, mask) # part
F_idx_mask_attend = self._get_encoder_mask(F_idx, mask) # full
for (local_layer, global_layer) in self.encoder_layers:
# local_layer
h_EV_local = cat_neighbors_nodes(h_V, h_P, P_idx) # [4, 312, 30, 256]
h_V = local_layer(h_V, h_EV_local, mask_V=mask, mask_attend=P_idx_mask_attend)
# global layer
h_EV_global = cat_neighbors_nodes(h_V, h_F, F_idx)
h_V = h_V + global_layer(h_V, h_EV_global, mask_V=mask, mask_attend=F_idx_mask_attend)
return h_V
def _get_sv_encoder(self, S, h_V, h_P, P_idx):
h_S = self.W_s(S)
h_PS = cat_neighbors_nodes(h_S, h_P, P_idx)
h_PS_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_P, P_idx)
h_PSV_encoder = cat_neighbors_nodes(h_V, h_PS_encoder, P_idx)
return h_PS, h_PSV_encoder
def _get_features(self, batch):
S, X, mask, chain_mask = batch['S'], batch['X'], batch['mask'], batch['chain_mask']
X_ca = X[:,:,1,:]
D_neighbors, F_idx = self._full_dist(X_ca, mask, 500)
P_idx = F_idx[:, :, :self.top_k].clone()
_V = _dihedrals(X) # node feature
_V = self.norm_nodes(self.node_embedding(_V))
_F = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, F_idx)), -1)
_F = self.norm_edges(self.edge_embedding(_F))
_P = _F[..., :self.top_k, :]
h_V = self.W_v(_V)
h_P, h_F = self.W_e(_P), self.W_f(_F)
batch.update({'S':S,
'h_V': h_V,
'h_P': h_P,
'h_F': h_F,
'P_idx': P_idx,
'F_idx': F_idx,
'mask': mask})
return batch
def sparse_to_dense(self, S, h_V, h_P, edge_idx_P, h_F, edge_idx_F, batch_id):
device = h_V.device
num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id)
batch = num_nodes.shape[0]
N = num_nodes.max()
S_ = torch.zeros([batch, N], device=device).long()
row = batch_id
col = torch.cat([torch.arange(0,n) for n in num_nodes]).to(device)
S_[row, col] = S
S = S_
# node feature
dim_V = h_V.shape[-1]
h_V_ = torch.zeros([batch, N, dim_V], device=device)
row = batch_id
col = torch.cat([torch.arange(0,n) for n in num_nodes]).to(device)
h_V_[row, col] = h_V
h_V = h_V_
mask = torch.zeros([batch, N], device=device)
mask[row, col] = 1
# edge feature
K = 30
dim_P = h_P.shape[-1]
h_P_ = torch.zeros([batch, N, K, dim_P], device=device)
row2 = batch_id[edge_idx_P[0]]
batch_shift, _ = scatter_min(edge_idx_P[0], batch_id[edge_idx_P[0]])
local_dst_idx = edge_idx_P[0] - batch_shift[batch_id[edge_idx_P[0]]]
local_src_idx = edge_idx_P[1] - batch_shift[batch_id[edge_idx_P[1]]]
nn_num = scatter_sum(torch.ones_like(edge_idx_P[0]), edge_idx_P[0])
nn_idx = torch.cat([torch.arange(0,n) for n in nn_num]).to(device)
h_P_[row2, local_dst_idx, nn_idx] = h_P
h_P = h_P_
nn_num = scatter_sum(torch.ones_like(edge_idx_P[0]), edge_idx_P[0])
nn_idx = torch.cat([torch.arange(0,n) for n in nn_num]).to(device)
P_idx = torch.arange(0, K, device=device).reshape(1,1,K).repeat(batch, N, 1)
P_idx[row2, local_dst_idx, nn_idx] = local_src_idx
# edge feature
K = N
dim_F = h_F.shape[-1]
h_F_ = torch.zeros([batch, N, K, dim_F], device=device)
row2 = batch_id[edge_idx_F[0]]
batch_shift, _ = scatter_min(edge_idx_F[0], batch_id[edge_idx_F[0]])
local_dst_idx = edge_idx_F[0] - batch_shift[batch_id[edge_idx_F[0]]]
local_src_idx = edge_idx_F[1] - batch_shift[batch_id[edge_idx_F[1]]]
nn_num = scatter_sum(torch.ones_like(edge_idx_F[0]), edge_idx_F[0])
nn_idx = torch.cat([torch.arange(0,n) for n in nn_num]).to(device)
h_F_[row2, local_dst_idx, nn_idx] = h_F
h_F = h_F_
nn_num = scatter_sum(torch.ones_like(edge_idx_F[0]), edge_idx_F[0])
nn_idx = torch.cat([torch.arange(0,n) for n in nn_num]).to(device)
F_idx = torch.arange(0, K, device=device).reshape(1,1,K).repeat(batch, N, 1)
F_idx[row2, local_dst_idx, nn_idx] = local_src_idx
return S, h_V, h_P, h_F, P_idx,F_idx, mask
def forward(self, batch):
h_V, h_P, h_F, P_idx, F_idx, S, mask = batch['h_V'], batch['h_P'], batch['h_F'], batch['P_idx'], batch['F_idx'], batch['S'], batch['mask']
t1 = time.time()
h_V = self._encoder_network(h_V, h_P, h_F, P_idx, F_idx, mask)
h_PS, h_PSV_encoder = self._get_sv_encoder(S, h_V, h_P, P_idx)
t2 = time.time()
# Decoder
P_idx_mask_bw, P_idx_mask_fw = self._get_decoder_mask(P_idx, mask)
for local_layer in self.decoder_layers:
# local_layer
h_PSV_local = cat_neighbors_nodes(h_V, h_PS, P_idx)
h_PSV_local = P_idx_mask_bw * h_PSV_local + P_idx_mask_fw * h_PSV_encoder
h_V = local_layer(h_V, h_PSV_local, mask_V=mask)
logits = self.W_out(h_V)
log_probs = F.log_softmax(logits, dim=-1)
t3 = time.time()
self.encode_t += t2-t1
self.decode_t += t3-t2
return {'log_probs':log_probs}
def sample(self, h_V, h_P, h_F, P_idx, F_idx, mask=None, temperature=0.1, **kwargs):
t1 = time.time()
h_V = self._encoder_network(h_V, h_P, h_F, P_idx, F_idx, mask)
t2 = time.time()
# Decoder
P_idx_mask_bw, P_idx_mask_fw = self._get_decoder_mask(P_idx, mask)
N_batch, N_nodes = h_V.size(0), h_V.size(1)
h_S = torch.zeros_like(h_V)
S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=h_V.device)
h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.decoder_layers))]
all_probs = []
for t in range(N_nodes):
# Hidden layers
P_idx_t = P_idx[:,t:t+1,:]
h_P_t = h_P[:,t:t+1,:,:]
h_PS_t = cat_neighbors_nodes(h_S, h_P_t, P_idx_t)
h_PSV_encoder_t = P_idx_mask_fw[:,t:t+1,:,:] * cat_neighbors_nodes(h_V, h_PS_t, P_idx_t)
for l, local_layer in enumerate(self.decoder_layers):
# local layer
h_PSV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_PS_t, P_idx_t)
h_V_t = h_V_stack[l][:,t:t+1,:]
h_PSV_t = P_idx_mask_bw[:,t:t+1,:,:] * h_PSV_decoder_t + h_PSV_encoder_t
h_V_stack[l+1][:,t,:] = local_layer(
h_V_t, h_PSV_t, mask_V=mask[:, t:t+1]
).squeeze(1)
# Sampling step
h_V_t = h_V_stack[-1][:,t,:]
logits = self.W_out(h_V_t) / temperature
probs = F.softmax(logits, dim=-1)
S_t = torch.multinomial(probs, 1).squeeze(-1)
# Update
h_S[:,t,:] = self.W_s(S_t)
S[:,t] = S_t
all_probs.append(probs)
self.probs = torch.cat(all_probs, dim=0)
t3 = time.time()
self.encode_t += t2-t1
self.decode_t += t3-t2
return S