# flexpert / Flexpert-Design / src/models/structgnn_model.py
# Author: Honzus24 — initial commit (7968cb0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.modules.graphtrans_module import Struct2Seq, cat_neighbors_nodes, gather_nodes, ProteinFeatures
class StructGNN_Model(nn.Module):
    """Graph-conditioned autoregressive sequence model (StructGNN).

    Thin wrapper around a `Struct2Seq` network (MPNN variant) plus a
    `ProteinFeatures` featurizer: teacher-forced training via `forward`
    and autoregressive sequence sampling via `sample`.
    """

    def __init__(self, args):
        """Build the Struct2Seq backbone and the coordinate featurizer.

        Args:
            args: namespace with at least `vocab_size`, `hidden`,
                `k_neighbors`, `features`, `dropout`, `smoothing`.
        """
        super(StructGNN_Model, self).__init__()
        self.args = args
        # NOTE(review): hard-coded device; `sample` no longer depends on it
        # (it follows the input tensors), but consider making this configurable.
        self.device = 'cuda:0'
        # Label-smoothing factor; stored for use by the training loss,
        # not referenced in this class.
        self.smoothing = args.smoothing
        self.model = Struct2Seq(
            vocab=args.vocab_size,
            num_letters=args.vocab_size,
            node_features=args.hidden,
            edge_features=args.hidden,
            hidden_dim=args.hidden,
            k_neighbors=args.k_neighbors,
            protein_features=args.features,
            dropout=args.dropout,
            use_mpnn=True)
        self.featurizer = ProteinFeatures(
            args.hidden, args.hidden, top_k=args.k_neighbors,
            features_type=args.features,
            dropout=args.dropout
        )

    def _get_features(self, batch):
        """Featurize coordinates and add node/edge features to `batch` in place.

        Reads `batch['X']`, `batch['lengths']`, `batch['mask']` and adds
        node features `V`, edge features `E`, neighbor indices `E_idx`,
        and the (unchanged) `mask`. Returns the same dict.
        """
        X, lengths, mask = batch['X'], batch['lengths'], batch['mask']
        V, E, E_idx = self.featurizer(X, lengths, mask)
        batch.update({'V': V,
                      "E": E,
                      "E_idx": E_idx,
                      "mask": mask})
        return batch

    def forward(self, batch):
        """Graph-conditioned sequence model (teacher forcing).

        Args:
            batch: dict with sequence `S`, node features `V`, edge features
                `E`, neighbor indices `E_idx`, and residue `mask`.

        Returns:
            dict with `'log_probs'`: per-position log-softmax over the vocab.
        """
        S, V, E, E_idx, mask = batch['S'], batch['V'], batch['E'], batch['E_idx'], batch['mask']
        # Prepare node and edge embeddings
        h_V = self.model.W_v(V)
        h_E = self.model.W_e(E)

        # Encoder is unmasked self-attention (only padding is masked out).
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.model.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Concatenate sequence embeddings for the autoregressive decoder.
        h_S = self.model.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Encoder-side embeddings with sequence information zeroed out —
        # what "future" (not-yet-decoded) neighbors are allowed to expose.
        h_ES_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_ESV_encoder = cat_neighbors_nodes(h_V, h_ES_encoder, E_idx)

        # Decoder uses masked (causal) self-attention.
        mask_attend = self.model._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        # FIX: `mask_fw` was previously assigned only inside the branch below,
        # raising NameError whenever `forward_attention_decoder` was False.
        # Compute it unconditionally, and hoist the loop-invariant forward
        # term out of the decoder loop.
        mask_fw = mask_1D * (1. - mask_attend)
        if self.model.forward_attention_decoder:
            # Future positions contribute sequence-free encoder features.
            h_ESV_fw_term = mask_fw * h_ESV_encoder
        else:
            h_ESV_fw_term = 0.
        for layer in self.model.decoder_layers:
            # Decoded (past) positions attend to full sequence+structure info;
            # future positions see only the encoder term (or nothing).
            h_ESV_dec = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV_dec + h_ESV_fw_term
            h_V = layer(h_V, h_ESV, mask_V=mask)

        logits = self.model.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs}

    def sample(self, V, E, E_idx, mask, chain_mask=None, temperature=1.0):
        """Autoregressive decoding of the model.

        Args:
            V, E, E_idx: node features, edge features, neighbor indices.
            mask: [batch, nodes] residue validity mask.
            chain_mask: accepted for interface compatibility but currently
                unused. TODO(review): confirm whether it should constrain
                which positions are redesigned.
            temperature: softmax temperature for sampling.

        Returns:
            S: [batch, nodes] int64 tensor of sampled token indices.
            Side effect: stores per-step probabilities in `self.probs`.
        """
        # Prepare node and edge embeddings
        h_V = self.model.W_v(V)
        h_E = self.model.W_e(E)

        # Encoder is unmasked self-attention.
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.model.encoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
            h_V = layer(h_V, h_EV, mask_V=mask, mask_attend=mask_attend)

        # Decoder uses masked (causal) self-attention.
        mask_attend = self.model._autoregressive_mask(E_idx).unsqueeze(-1)
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = V.size(0), V.size(1)
        h_S = torch.zeros_like(h_V)
        # FIX: allocate S on the same device as the working tensors instead of
        # the hard-coded self.device — the original crashed on any device
        # other than 'cuda:0' when assigning S[:, t] = S_t.
        S = torch.zeros((N_batch, N_nodes), dtype=torch.int64, device=h_V.device)
        # Cache of per-layer node states: h_V_stack[l] is the input to
        # decoder layer l; h_V_stack[-1] holds the final decoder output.
        h_V_stack = [h_V] + [torch.zeros_like(h_V) for _ in range(len(self.model.decoder_layers))]
        all_probs = []
        for t in range(N_nodes):
            # Slice out position t's neighborhood.
            E_idx_t = E_idx[:, t:t+1, :]
            h_E_t = h_E[:, t:t+1, :, :]
            # Encoder contribution from not-yet-decoded neighbors
            # (sequence embeddings zeroed out).
            h_ES_enc_t = cat_neighbors_nodes(torch.zeros_like(h_S), h_E_t, E_idx_t)
            h_ESV_encoder_t = mask_fw[:, t:t+1, :, :] * cat_neighbors_nodes(h_V, h_ES_enc_t, E_idx_t)
            for l, layer in enumerate(self.model.decoder_layers):
                # Updated relational features using already-sampled sequence.
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)  # [batch, 1, K, 3*hidden]
                h_ESV_t = mask_bw[:, t:t+1, :, :] * h_ESV_decoder_t + h_ESV_encoder_t
                h_V_t = h_V_stack[l][:, t:t+1, :]  # [batch, 1, hidden]
                h_V_stack[l+1][:, t, :] = layer(
                    h_V_t, h_ESV_t, mask_V=mask[:, t:t+1]
                ).squeeze(1)
            # Sampling step: temperature-scaled categorical draw.
            h_V_t = h_V_stack[-1][:, t, :]
            logits = self.model.W_out(h_V_t) / temperature
            probs = F.softmax(logits, dim=-1)
            S_t = torch.multinomial(probs, 1).squeeze(-1)
            all_probs.append(probs)
            # Feed the sampled token back in for subsequent positions.
            h_S[:, t, :] = self.model.W_s(S_t)
            S[:, t] = S_t
        # NOTE(review): cat along dim 0 flattens (time, batch) into one axis,
        # giving [N_nodes * batch, vocab]; torch.stack(all_probs, dim=1) may
        # have been intended — confirm against consumers of self.probs.
        self.probs = torch.cat(all_probs, dim=0)
        return S