import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from omegaconf import OmegaConf

from src.modules.E3PiFold import GaussianEncoder, TransformerEncoderWithPair
from src.tools import _dihedrals, _orientations_coarse_gl_tuple


class E3PiFold(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        # h_V from _get_features: angle + direction features, 21 in total.
        self.node_embed = nn.Linear(21, config.embed_dim)
        self.protein_embedder = GaussianEncoder(
            config.kernel_num,
            config.embed_dim,
            config.attention_heads,
            config.use_dist,
            config.use_product,
        )
        self.encoder = TransformerEncoderWithPair(
            config.encoder_layers,
            config.embed_dim,
            config.ffn_embed_dim,
            config.attention_heads,
            config.emb_dropout,
            config.dropout,
            config.attention_dropout,
            config.activation_dropout,
            config.max_seq_len,
        )
        # Per-residue logits over a 33-token output vocabulary.
        self.predictor = nn.Linear(config.embed_dim, 33)

    def _full_dist(self, X, mask, top_k=30, eps=1e-6):
        """Pairwise CA-CA distances and indices of the top_k nearest neighbors.

        Padded positions are pushed beyond the largest real distance so they
        are never selected as neighbors.
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)

        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(
            D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False
        )
        return D_neighbors, E_idx

    def _get_features(self, batch):
        """Attach geometric node features batch['h_V'] ([B, N, 21]).

        batch['X'] has shape [B, N, atoms, 3], with the CA atom at index 1.
        """
        X = batch['X']
        X_ca = X[:, :, 1, :]
        D_neighbors, E_idx = self._full_dist(X_ca, batch['mask'], 30)

        V_angles = _dihedrals(X.float())
        V_direct, E_direct, E_angles = _orientations_coarse_gl_tuple(X.float(), E_idx)
        h_V = torch.cat([V_angles, V_direct], dim=-1).to(X.dtype)
        batch['h_V'] = h_V
        return batch

    def forward(self, batch):
        """batch: dict with 'X' [B, N, atoms, 3], 'h_V' [B, N, 21], 'mask' [B, N]."""
        X = batch['X'][:, :, 1]  # CA coordinates, [B, N, 3]
        H = self.node_embed(batch['h_V'])
        seq_mask = batch['mask']
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        padding_mask = 1 - seq_mask

        # Embed coordinates and node features into tokens plus a pairwise
        # attention bias for the encoder.
        x, graph_attn_bias = self.protein_embedder(X, H, pair_mask)
        (
            encoder_rep,
            encoder_pair_rep,
            delta_encoder_pair_rep,
            x_norm,
            delta_encoder_pair_rep_norm,
        ) = self.encoder(
            x, padding_mask=padding_mask, attn_mask=graph_attn_bias, pair_mask=pair_mask
        )

        # Predict per-residue logits from the encoder output.
        logits = self.predictor(encoder_rep)
        log_probs = F.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs}


if __name__ == '__main__':
    # Smoke test. forward() expects a batch dict; h_V must be 21-dimensional
    # to match node_embed, and N is kept within config.max_seq_len.
    B, N = 16, 256
    batch = {
        'X': torch.randn(B, N, 4, 3),   # backbone atoms, CA at index 1
        'h_V': torch.randn(B, N, 21),   # stand-in for _get_features output
        'mask': torch.ones(B, N),
    }
    config = OmegaConf.create({
        'encoder_layers': 12,
        'kernel_num': 16,
        'embed_dim': 768,
        'ffn_embed_dim': 3072,
        'attention_heads': 8,
        'emb_dropout': 0.1,
        'dropout': 0.1,
        'attention_dropout': 0.1,
        'activation_dropout': 0.0,
        'max_seq_len': 256,
        'use_dist': True,      # required by GaussianEncoder (value assumed)
        'use_product': True,   # required by GaussianEncoder (value assumed)
    })
    model = E3PiFold(config)
    out = model(batch)
    print(out['log_probs'].shape)  # expected: torch.Size([B, N, 33])
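

# ---------------------------------------------------------------------------
# Illustrative sketch only: this is NOT the GaussianEncoder imported from
# src.modules.E3PiFold, whose internals are not shown in this file. It
# demonstrates the general idea such encoders use (cf. Uni-Mol): expand
# pairwise CA-CA distances with learned Gaussian kernels and project the
# expansion to a per-head attention bias. The class name, kernel range, and
# masking convention below are all assumptions made for illustration.
# ---------------------------------------------------------------------------
class GaussianPairBiasSketch(nn.Module):
    def __init__(self, kernel_num: int, attention_heads: int):
        super().__init__()
        self.means = nn.Parameter(torch.linspace(0.0, 20.0, kernel_num))  # kernel centers (assumed range, in Angstroms)
        self.stds = nn.Parameter(torch.ones(kernel_num))                  # kernel widths
        self.proj = nn.Linear(kernel_num, attention_heads)                # kernels -> per-head bias

    def forward(self, X: Tensor, pair_mask: Tensor) -> Tensor:
        # X: [B, N, 3] CA coordinates; pair_mask: [B, N, N] with 1 = real pair.
        D = torch.cdist(X, X)                                              # [B, N, N] pairwise distances
        std = self.stds.abs().clamp(min=1e-3)
        g = torch.exp(-0.5 * ((D.unsqueeze(-1) - self.means) / std) ** 2)  # [B, N, N, K] kernel responses
        bias = self.proj(g).masked_fill(pair_mask.unsqueeze(-1) == 0, 0.0) # zero bias on padded pairs
        return bias.permute(0, 3, 1, 2)                                    # [B, heads, N, N] attention bias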
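

# ---------------------------------------------------------------------------
# Usage sketch, assuming the src.tools feature helpers are importable: derive
# h_V from raw backbone coordinates with _get_features before calling
# forward, rather than supplying random node features as the smoke test does.
# The helper name run_from_coordinates is introduced here for illustration.
# ---------------------------------------------------------------------------
def run_from_coordinates(model: E3PiFold, X: Tensor, mask: Tensor) -> dict:
    """X: [B, N, atoms, 3] backbone coordinates; mask: [B, N] sequence mask."""
    batch = {'X': X, 'mask': mask}
    batch = model._get_features(batch)  # attaches batch['h_V'] ([B, N, 21])
    return model(batch)                 # {'log_probs': [B, N, 33]}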
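

# ---------------------------------------------------------------------------
# Loss sketch (an assumption, not taken from this repo): since forward()
# returns log-probabilities, a standard training objective is the masked
# negative log-likelihood over the 33-way per-residue distribution, using the
# same sequence mask to ignore padded positions.
# ---------------------------------------------------------------------------
def masked_nll(log_probs: Tensor, targets: Tensor, mask: Tensor) -> Tensor:
    """log_probs: [B, N, 33]; targets: [B, N] integer labels; mask: [B, N]."""
    nll = F.nll_loss(log_probs.transpose(1, 2), targets, reduction='none')  # [B, N]
    return (nll * mask).sum() / mask.sum().clamp(min=1)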