File size: 3,330 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from omegaconf import OmegaConf
from src.modules.E3PiFold import GaussianEncoder, TransformerEncoderWithPair
from src.tools import gather_nodes, _dihedrals, _get_rbf, _get_dist, _rbf, _orientations_coarse_gl_tuple

class E3PiFold(nn.Module):
    """Per-residue classifier over 33 output classes.

    Pipeline: 21-dim node features -> linear embedding -> ``GaussianEncoder``
    (which also produces an attention bias from coordinates) ->
    ``TransformerEncoderWithPair`` -> linear head -> log-softmax.
    """

    def __init__(self, config) -> None:
        """Build the model from ``config``.

        Args:
            config: attribute-style config (e.g. an OmegaConf DictConfig) that
                provides ``embed_dim``, ``kernel_num``, ``attention_heads``,
                ``use_dist``, ``use_product``, ``encoder_layers``,
                ``ffn_embed_dim``, ``emb_dropout``, ``dropout``,
                ``attention_dropout``, ``activation_dropout`` and
                ``max_seq_len``.
        """
        super().__init__()
        # Node features fed to forward() as batch['h_V'] must have 21 channels.
        self.node_embed = nn.Linear(21, config.embed_dim)
        self.protein_embedder = GaussianEncoder(config.kernel_num, config.embed_dim, config.attention_heads, config.use_dist, config.use_product)

        self.encoder = TransformerEncoderWithPair(
            config.encoder_layers,
            config.embed_dim,
            config.ffn_embed_dim,
            config.attention_heads,
            config.emb_dropout,
            config.dropout,
            config.attention_dropout,
            config.activation_dropout,
            config.max_seq_len,
        )
        # 33 output classes per residue.
        # NOTE(review): the class vocabulary is defined outside this file -- confirm.
        self.predictor = nn.Linear(config.embed_dim, 33)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        """Find the ``top_k`` nearest valid neighbours of every position.

        Args:
            X: coordinates of rank 3; the squared difference is summed over
                dim 3 of the pairwise tensor (presumably (B, N, 3) -- confirm).
            mask: per-position mask, 1 for valid and 0 for padding
                (assumed (B, N) and binary -- confirm).
            top_k: number of neighbours; clipped to N.
            eps: added inside the square root for numerical stability.

        Returns:
            ``(D_neighbors, E_idx)``: neighbour distances and indices, each of
            shape (B, N, min(top_k, N)), ordered nearest first.
        """
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        # Pairs involving a padded position get a large sentinel distance.
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)

        # Push padded pairs beyond every real distance so topk never prefers them.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx

    def _get_features(self, batch):
        """Compute geometric node features and store them as ``batch['h_V']``.

        Reads ``batch['X']`` (indexed along dim 2 by atom; atom index 1 is
        used for the neighbour search, presumably C-alpha) and
        ``batch['mask']``. Mutates ``batch`` in place and returns it.
        """
        X = batch['X']
        X_ca = X[:, :, 1, :]
        # Only the neighbour indices are needed by the orientation features.
        _, E_idx = self._full_dist(X_ca, batch['mask'], 30)
        V_angles = _dihedrals(X.float())
        V_direct, _, _ = _orientations_coarse_gl_tuple(X.float(), E_idx)
        h_V = torch.cat([V_angles, V_direct], dim=-1).to(X.dtype)
        batch['h_V'] = h_V
        return batch

    def forward(self, batch):
        """Predict per-residue log-probabilities.

        Args:
            batch: dict with
                ``'X'``    -- coordinates; atom index 1 along dim 2 is used,
                ``'h_V'``  -- node features with 21 channels (see
                              ``_get_features``),
                ``'mask'`` -- residue mask, 1 for valid positions
                              (assumed (B, N) -- confirm).

        Returns:
            ``{'log_probs': ...}``: log-softmax over the 33 output classes,
            computed from the transformer encoder's output.
        """
        X = batch['X'][:, :, 1]
        H = self.node_embed(batch['h_V'])
        seq_mask = batch['mask']
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        padding_mask = 1 - seq_mask
        x, graph_attn_bias = self.protein_embedder(X, H, pair_mask)
        # The encoder returns five outputs; only the node representation is used here.
        encoder_rep, _, _, _, _ = self.encoder(
            x, padding_mask=padding_mask, attn_mask=graph_attn_bias, pair_mask=pair_mask
        )
        # Predict from the encoder output, not the pre-encoder embedding `x`.
        # Using `x` would take every encoder layer out of the prediction path.
        logits = self.predictor(encoder_rep)
        log_probs = F.log_softmax(logits, dim=-1)

        return {'log_probs': log_probs}


if __name__ == '__main__':
    # Smoke test: build a random batch and run a single forward pass.
    config = {'encoder_layers': 12,
              'kernel_num': 16,
              'embed_dim': 768,
              'ffn_embed_dim': 3072,
              'attention_heads': 8,
              'emb_dropout': 0.1,
              'dropout': 0.1,
              'attention_dropout': 0.1,
              'activation_dropout': 0.0,
              'max_seq_len': 256,
              # E3PiFold.__init__ also reads these two keys.
              # NOTE(review): assumed to be boolean GaussianEncoder flags -- confirm.
              'use_dist': True,
              'use_product': True}
    config = OmegaConf.create(config)
    model = E3PiFold(config)

    # N is kept at max_seq_len in case the encoder relies on it
    # (the previous N=512 exceeded the configured 256).
    B, N = 2, config.max_seq_len
    seq_mask = (torch.ones(B, N) > 0.5).float()
    batch = {
        # forward() reads atom index 1 along dim 2, so X needs an atom axis.
        # NOTE(review): 4 atoms per residue is an assumption -- confirm.
        'X': torch.randn(B, N, 4, 3),
        'h_V': torch.randn(B, N, model.node_embed.in_features),
        'mask': seq_mask,
    }
    # forward() takes a single batch dict, not (X, H, seq_mask).
    with torch.no_grad():
        feat = model(batch)
    print(feat['log_probs'].shape)