# flexpert / Flexpert-Design / src/models/alphadesign_model.py
# Author: Honzus24 — initial commit (7968cb0)
import torch
import torch.nn as nn
from src.modules.alphadesign_module import ATDecoder, CNNDecoder, CNNDecoder2, StructureEncoder
from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl
class AlphaDesign_Model(nn.Module):
    """Graph-based protein sequence design (inverse folding) model.

    Encodes per-residue node features and kNN edge features with a
    StructureEncoder, then decodes per-residue amino-acid log-probabilities
    with either an autoregressive decoder (ATDecoder) or a two-stage CNN
    decoder pipeline (CNNDecoder -> CNNDecoder2), selected by
    ``args.autoregressive``.
    """

    def __init__(self, args, **kwargs):
        """ Graph labeling network """
        super(AlphaDesign_Model, self).__init__()
        self.args = args
        node_features = args.node_features
        edge_features = args.edge_features
        hidden_dim = args.hidden_dim
        dropout = args.dropout
        num_encoder_layers = args.num_encoder_layers
        self.top_k = args.k_neighbors  # neighbors per residue in the kNN graph
        self.num_rbf = 16              # radial-basis bins for neighbor distances
        self.num_positional_embeddings = 16

        # Input widths: 12 dihedral-derived node features with the new
        # featurizer, otherwise only the first 6 (see _get_features);
        # edges are always 16 RBF + 7 coarse-orientation features.
        if args.use_new_feat:
            node_in, edge_in = 12, 16+7
        else:
            node_in, edge_in = 6, 16+7

        # Raw feature projections, normalized over the flattened node/edge
        # dimension (inputs are sparse [num_items, C] tensors, hence BatchNorm1d).
        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = nn.BatchNorm1d(node_features)
        self.norm_edges = nn.BatchNorm1d(edge_features)

        # Node MLP lifting embedded node features into the hidden dimension.
        self.W_v = nn.Sequential(
            nn.Linear(node_features, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True)
        )

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        # NOTE(review): W_f is never used in forward() as visible in this file
        # — possibly dead weight or used by an external caller; confirm.
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)

        self.encoder = StructureEncoder(hidden_dim, num_encoder_layers, dropout, use_SGT=self.args.use_SGT)

        if args.autoregressive:
            self.decoder = ATDecoder(args, hidden_dim, dropout)
        else:
            # Two-stage parallel decoding: decoder2 refines decoder's logits.
            self.decoder = CNNDecoder(hidden_dim, hidden_dim)
            self.decoder2 = CNNDecoder2(hidden_dim, hidden_dim)
        # self.chain_embed = nn.Embedding(2,16)
        self._init_params()

    def forward(self, batch, AT_test = False, return_logit=False):
        """Encode the structure graph and decode per-residue log-probabilities.

        Args:
            batch: dict with keys '_V' (flattened node features), '_E'
                (flattened edge features), 'E_idx' ([2, num_edges] dst/src
                global indices) and 'batch_id' (graph id per node), as
                produced by _get_features.
            AT_test: if True, sample with the autoregressive decoder
                (requires args.autoregressive=True so self.decoder is an
                ATDecoder).
            return_logit: if True, return raw logits alongside log_probs.

        Returns:
            dict with 'log_probs' plus either 'logits' (return_logit=True)
            or 'log_probs0' — the first-stage decoder output, which is None
            on the AT_test path.
        """
        h_V, h_P, P_idx, batch_id = batch['_V'], batch['_E'], batch['E_idx'], batch['batch_id']

        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))

        h_V = self.encoder(h_V, h_P, P_idx, batch_id)

        log_probs0 = None
        if AT_test:
            # NOTE(review): 'logits' is never assigned on this path, so
            # AT_test=True combined with return_logit=True would raise
            # NameError — confirm callers never use that combination.
            log_probs = self.decoder.sampling(h_V, h_P, P_idx, batch_id)
        else:
            log_probs0, logits = self.decoder(h_V, batch_id)
            log_probs, logits = self.decoder2(h_V, logits, batch_id)

        if return_logit:
            return {'log_probs': log_probs, 'logits': logits}
        else:
            return {'log_probs': log_probs, 'log_probs0': log_probs0}

    def _init_params(self):
        """Xavier-initialize all weight matrices.

        The dim > 1 guard skips biases and BatchNorm scale/shift vectors,
        which keep PyTorch's default initialization.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        """Return top_k nearest-neighbor distances and indices per residue.

        X is per-residue coordinates (presumably [B, N, 3] CA positions —
        the caller passes X[:,:,1,:]); mask is [B, N] with 1 for valid
        residues. Invalid pairs get a large sentinel distance so they are
        never selected while any valid neighbor exists.
        """
        mask_2D = torch.unsqueeze(mask,1) * torch.unsqueeze(mask,2)
        dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
        # eps keeps sqrt finite-gradient at zero distance; masked pairs get 10000.
        D = (1. - mask_2D)*10000 + mask_2D* torch.sqrt(torch.sum(dX**2, 3) + eps)

        # Push invalid pairs beyond the per-row max so topk(largest=False)
        # prefers every valid neighbor over any masked one.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max+1)
        # min(...) guards against sequences shorter than top_k.
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx

    def _get_features(self, batch):
        """Convert a padded batch into flattened (sparse) graph tensors.

        Expects batch keys: 'S' (residue labels), 'score' (optional
        per-residue scores), 'X' backbone coordinates of shape [B, N, 4, 3]
        (atom index 1 is used as CA — presumably N/CA/C/O ordering, TODO
        confirm), and 'mask' [B, N] with 1 for valid residues. Updates
        `batch` in place with flattened node/edge features and a global
        edge index, and returns it.
        """
        S, score, X, mask = batch['S'], batch['score'], batch['X'], batch['mask']
        mask_bool = (mask==1)
        B, N, _,_ = X.shape
        X_ca = X[:,:,1,:]
        D_neighbors, E_idx = self._full_dist(X_ca, mask, self.top_k)

        # sequence: keep labels (and scores, when given) of valid residues only
        S = torch.masked_select(S, mask_bool)
        if score is not None:
            score = torch.masked_select(score, mask_bool)

        # node feature: dihedral-based; old featurizer keeps only the first 6 dims
        _V = _dihedrals(X)
        if not self.args.use_new_feat:
            _V = _V[...,:6]
        _V = torch.masked_select(_V, mask_bool.unsqueeze(-1)).reshape(-1,_V.shape[-1])

        # edge feature: [B, N, K, 16+7] — RBF-encoded distance + coarse orientation
        _E = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, E_idx)), -1) # [4,387,387,23]
        # mask of first-order neighbor nodes: 1 = neighbor exists, 0 = absent
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        # an edge is valid only when both its own node and its neighbor are valid
        mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1
        _E = torch.masked_select(_E, mask_attend.unsqueeze(-1)).reshape(-1,_E.shape[-1])

        # edge index: offset per-sample residue indices into global node ids
        # by the cumulative count of valid nodes in preceding samples.
        # NOTE(review): this assumes padding is trailing (valid residues are
        # contiguous from position 0), so a padded-position index equals the
        # valid-node offset — confirm against the data pipeline.
        shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)
        src = shift.view(B,1,1) + E_idx
        src = torch.masked_select(src, mask_attend).view(1,-1)
        dst = shift.view(B,1,1) + torch.arange(0, N, device=src.device).view(1,-1,1).expand_as(mask_attend)
        dst = torch.masked_select(dst, mask_attend).view(1,-1)
        E_idx = torch.cat((dst, src), dim=0).long()

        # 3D point: gather coordinates of valid residues plus their sample ids
        sparse_idx = mask.nonzero()
        X = X[sparse_idx[:,0],sparse_idx[:,1],:,:]
        batch_id = sparse_idx[:,0]

        mask = torch.masked_select(mask, mask_bool)

        batch.update({'X':X,
                      'S':S,
                      'score':score,
                      '_V':_V,
                      '_E':_E,
                      'E_idx':E_idx,
                      'batch_id': batch_id,
                      'mask':mask})
        return batch