import torch
import torch.nn as nn

from src.modules.alphadesign_module import ATDecoder, CNNDecoder, CNNDecoder2, StructureEncoder
from src.tools.design_utils import gather_nodes, _dihedrals, _rbf, _orientations_coarse_gl


class AlphaDesign_Model(nn.Module):
    def __init__(self, args, **kwargs):
        """ Graph labeling network """
        super(AlphaDesign_Model, self).__init__()
        self.args = args
        node_features = args.node_features
        edge_features = args.edge_features
        hidden_dim = args.hidden_dim
        dropout = args.dropout
        num_encoder_layers = args.num_encoder_layers
        self.top_k = args.k_neighbors
        self.num_rbf = 16
        self.num_positional_embeddings = 16

        if args.use_new_feat:
            node_in, edge_in = 12, 16+7  # 12 dihedral node feats; 16 RBF + 7 orientation edge feats
        else:
            node_in, edge_in = 6, 16+7   # legacy: only the first 6 dihedral channels
        self.node_embedding = nn.Linear(node_in, node_features, bias=True)
        self.edge_embedding = nn.Linear(edge_in, edge_features, bias=True)
        self.norm_nodes = nn.BatchNorm1d(node_features)
        self.norm_edges = nn.BatchNorm1d(edge_features)

        self.W_v = nn.Sequential(
            nn.Linear(node_features, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.LeakyReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim, bias=True)
        )

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)

        self.encoder = StructureEncoder(hidden_dim, num_encoder_layers, dropout, use_SGT=self.args.use_SGT)

        if args.autoregressive:
            self.decoder = ATDecoder(args, hidden_dim, dropout)
        else:
            self.decoder = CNNDecoder(hidden_dim, hidden_dim)
            self.decoder2 = CNNDecoder2(hidden_dim, hidden_dim)
            
        # self.chain_embed = nn.Embedding(2,16)
        self._init_params()
    
    def forward(self, batch, AT_test=False, return_logit=False):
        h_V, h_P, P_idx, batch_id = batch['_V'], batch['_E'], batch['E_idx'], batch['batch_id']
        h_V = self.W_v(self.norm_nodes(self.node_embedding(h_V)))
        h_P = self.W_e(self.norm_edges(self.edge_embedding(h_P)))
        
        h_V = self.encoder(h_V, h_P, P_idx, batch_id)
        log_probs0, logits = None, None
        if AT_test:
            # autoregressive sampling path (no logits are produced here)
            log_probs = self.decoder.sampling(h_V, h_P, P_idx, batch_id)
        else:
            log_probs0, logits = self.decoder(h_V, batch_id)
            log_probs, logits = self.decoder2(h_V, logits, batch_id)
        if return_logit:
            return {'log_probs': log_probs, 'logits': logits}
        else:
            return {'log_probs': log_probs, 'log_probs0': log_probs0}
        
    def _init_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _full_dist(self, X, mask, top_k=30, eps=1E-6):
        # Pairwise CA-CA distances; masked pairs get a large sentinel value so
        # padded positions never look close.
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = (1. - mask_2D) * 10000 + mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)

        # Push masked pairs past the per-row maximum so topk(largest=False)
        # always selects real neighbors first.
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1. - mask_2D) * (D_max + 1)
        D_neighbors, E_idx = torch.topk(D_adjust, min(top_k, D_adjust.shape[-1]), dim=-1, largest=False)
        return D_neighbors, E_idx

    def _get_features(self, batch):
        S, score, X, mask = batch['S'], batch['score'], batch['X'], batch['mask']
        mask_bool = (mask==1)
        
        B, N, _, _ = X.shape
        X_ca = X[:, :, 1, :]  # CA coordinates (atom order assumed N, CA, C, O)
        D_neighbors, E_idx = self._full_dist(X_ca, mask, self.top_k)

        # sequence
        S = torch.masked_select(S, mask_bool)
        if score is not None:
            score = torch.masked_select(score, mask_bool)

        # node feature
        _V = _dihedrals(X) 
        if not self.args.use_new_feat:
            _V = _V[...,:6]
        _V = torch.masked_select(_V, mask_bool.unsqueeze(-1)).reshape(-1,_V.shape[-1])

        # edge feature
        _E = torch.cat((_rbf(D_neighbors, self.num_rbf), _orientations_coarse_gl(X, E_idx)), -1)  # [B, N, K, 16+7]
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)  # neighbor mask: 1 if the neighbor node exists, 0 otherwise
        mask_attend = (mask.unsqueeze(-1) * mask_attend) == 1  # own mask * neighbor mask
        _E = torch.masked_select(_E, mask_attend.unsqueeze(-1)).reshape(-1, _E.shape[-1])
        
        # edge index
        shift = mask.sum(dim=1).cumsum(dim=0) - mask.sum(dim=1)  # start offset of each sample in the packed node list
        src = shift.view(B, 1, 1) + E_idx
        src = torch.masked_select(src, mask_attend).view(1, -1)
        dst = shift.view(B, 1, 1) + torch.arange(0, N, device=src.device).view(1, -1, 1).expand_as(mask_attend)
        dst = torch.masked_select(dst, mask_attend).view(1, -1)
        E_idx = torch.cat((dst, src), dim=0).long()
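        # Hypothetical worked example of the packing above (not in the original
        # code): with per-sample lengths mask.sum(1) = [3, 2], shift = [0, 3],
        # so residue j of sample b maps to packed node index shift[b] + j and
        # E_idx becomes a [2, num_edges] (dst, src) index over 5 packed nodes.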
        # 3D points: keep only valid residues and record which sample each came from
        sparse_idx = mask.nonzero()
        X = X[sparse_idx[:, 0], sparse_idx[:, 1], :, :]
        batch_id = sparse_idx[:, 0]

        mask = torch.masked_select(mask, mask_bool)

        batch.update({'X': X,
                      'S': S,
                      'score': score,
                      '_V': _V,
                      '_E': _E,
                      'E_idx': E_idx,
                      'batch_id': batch_id,
                      'mask': mask})

        return batch
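

# --- Minimal usage sketch (not part of the original file) ---
# Assumes the repo's `src` package is importable and that `args` only needs the
# fields read above; every hyperparameter value and tensor shape below is an
# illustrative assumption, not the project's actual configuration.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(node_features=128, edge_features=128, hidden_dim=128,
                     dropout=0.1, num_encoder_layers=3, k_neighbors=30,
                     use_new_feat=True, use_SGT=False, autoregressive=False)
    model = AlphaDesign_Model(args)

    B, N = 2, 50
    raw_batch = {
        'S': torch.randint(0, 20, (B, N)),   # residue type labels
        'score': torch.rand(B, N),           # per-residue scores
        'X': torch.rand(B, N, 4, 3),         # backbone coords, atom order assumed N/CA/C/O
        'mask': torch.ones(B, N),            # 1 = valid residue
    }
    batch = model._get_features(raw_batch)
    out = model(batch)
    print(out['log_probs'].shape)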