File size: 5,695 Bytes
d5233a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5946936
 
 
 
 
 
d5233a9
 
 
5946936
d5233a9
 
 
 
 
 
 
5946936
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
from torch import nn
import torch.nn.functional as F
from config import CFG
import utils
import math
import numpy as np
from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
from GNN import layers as gly

class MolGNNEncoder(nn.Module):
    """Molecular graph encoder: graph-conv blocks, attention pooling, readout.

    The encoder chains ``gly.GConvBlockNoGF`` blocks, pools their output with
    ``gly.MultiHeadGlobalAttention`` and then applies a stack of linear
    readout layers, each followed by ``QuickGELU``.

    Args:
        outdim: Output width of every readout ``nn.Linear`` layer.
        n_feats: Input feature width handed to the first graph-conv block.
        n_filters_list: Output width of each graph-conv block, in order.
            ``None`` entries are dropped. At least one non-``None`` entry
            must remain.
        n_head: Number of heads passed to the attention layer.
        mols: Forwarded unchanged to every ``GConvBlockNoGF``.
        adj_chans: Forwarded unchanged to every ``GConvBlockNoGF``.
        readout_layers: Number of readout linear layers. One layer is always
            created; values below 2 therefore yield a single layer.
        bias: Bias flag for the conv blocks, the attention layer and the
            first readout layer. The remaining readout layers use the
            ``nn.Linear`` default.

    Raises:
        ValueError: If ``n_filters_list`` has no non-``None`` entry.
    """

    def __init__(self,
                 outdim,
                 n_feats=74,  # 330 would be 74 + 256 (Morgan fingerprint, 256 bits)
                 n_filters_list=(256, 256, 256),
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):
        super().__init__()

        widths = [w for w in n_filters_list if w is not None]
        if not widths:
            # Without this guard the attention/readout construction below
            # would fail on an unbound width variable.
            raise ValueError(
                "n_filters_list must contain at least one non-None filter size"
            )

        # Block i consumes the output width of block i-1; the first block
        # consumes the raw node features.
        in_widths = [n_feats] + widths[:-1]
        self.block_layers = nn.ModuleList(
            [gly.GConvBlockNoGF(w_in, w_out, mols, adj_chans, bias)
             for w_in, w_out in zip(in_widths, widths)]
        )

        last_width = widths[-1]
        self.attention_layer = gly.MultiHeadGlobalAttention(
            last_width, n_head=n_head, concat=True, bias=bias
        )
        # The first readout layer takes last_width * n_head inputs; presumably
        # the attention layer concatenates its heads (concat=True) -- TODO
        # confirm against GNN.layers.
        readout = [nn.Linear(last_width * n_head, outdim, bias=bias)]
        readout.extend(nn.Linear(outdim, outdim) for _ in range(readout_layers - 1))
        self.readout_layers = nn.ModuleList(readout)
        self.gelu = QuickGELU()

    def forward(self, batch):
        """Encode one batch of molecular graphs.

        Args:
            batch: Mapping that provides the entries ``'V'``, ``'A'`` and
                ``'mol_size'``. ``V`` and ``A`` are passed to every graph-conv
                block; ``mol_size`` is passed to the attention layer.
                (Presumably node features, adjacency and per-molecule atom
                counts -- verify against the data loader.)

        Returns:
            The output of the last readout layer after ``QuickGELU``.
        """
        V = batch['V']
        A = batch['A']
        mol_size = batch['mol_size']

        for block in self.block_layers:
            V = block(V, A)

        X = self.attention_layer(V, mol_size)

        # The activation is applied after every readout layer, the last one
        # included.
        for layer in self.readout_layers:
            X = self.gelu(layer(X))

        return X

class ProjectionHead(nn.Module):
    """Project input features into the shared embedding space.

    A linear projection is followed by either a transformer encoder (when
    ``transformer`` is truthy) or a GELU activation, then by an optional LSTM,
    and finally by dropout.

    Args:
        embedding_dim: Input width of the linear projection.
        projection_dim: Output width of the projection; also the width used
            for the transformer and the LSTM.
        cfg: Configuration object. ``cfg.dropout`` is always read;
            ``cfg.tsfm_layers`` / ``cfg.tsfm_heads`` are read only when
            ``transformer`` is truthy and ``cfg.lstm_layers`` only when
            ``lstm`` is truthy.
        transformer: Whether to build the transformer encoder.
        lstm: Whether to build the LSTM.
    """

    def __init__(self, embedding_dim, projection_dim, cfg, transformer=True, lstm=False):
        super().__init__()

        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()  # QuickGELU()
        self.transformer = (
            MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
            if transformer
            else None
        )
        self.lstm = (
            nn.LSTM(input_size=projection_dim,
                    hidden_size=projection_dim,
                    num_layers=cfg.lstm_layers,
                    batch_first=True)
            if lstm
            else None
        )
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Return the projected (and optionally sequence-encoded) features.

        Args:
            x: Input tensor accepted by the linear projection.

        Returns:
            The tensor after projection, transformer-or-GELU, optional LSTM
            (only its output sequence is kept) and dropout.
        """
        out = self.projection(x)
        # The GELU is used only when no transformer was configured.
        out = self.gelu(out) if self.transformer is None else self.transformer(out)
        if self.lstm is not None:
            out, _ = self.lstm(out)
        return self.dropout(out)

# NOTE: this model is named CMSSPModel in the paper.
class FragSimiModel(nn.Module):
    """Embed mass spectra and molecules into one shared, L2-normalised space.

    This model is named CMSSPModel in the paper.

    Which molecular features are used is controlled by ``cfg.mol_encoder``,
    tested with substring/membership checks for ``'gnn'``, ``'fp'`` and
    ``'fm'``. The selected features are concatenated and projected by
    ``mol_projection``; spectra are projected by ``ms_projection``.

    Args:
        cfg: Configuration object. Reads ``mol_encoder``,
            ``mol_embedding_dim``, ``ms_embedding_dim``, ``projection_dim``,
            ``tsfm_in_ms``, ``lstm_in_ms``, ``tsfm_in_mol``, ``lstm_in_mol``
            and, when ``'gnn'`` is selected, ``molgnn_n_filters_list``,
            ``molgnn_nhead`` and ``molgnn_readout_layers``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim

        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
                                                 n_filters_list=cfg.molgnn_n_filters_list,
                                                 n_head=cfg.molgnn_nhead,
                                                 readout_layers=cfg.molgnn_readout_layers)
            if 'fp' in self.cfg.mol_encoder:
                # GNN output and fingerprint are concatenated in forward();
                # assumes batch["mol_fps"] is cfg.mol_embedding_dim wide --
                # TODO confirm.
                mol_embedding_dim = 2*cfg.mol_embedding_dim

        if 'fm' in self.cfg.mol_encoder:
            # Assumes batch["mol_fmvec"] carries 10 features -- TODO confirm.
            # NOTE(review): with 'fm' alone the width is still
            # cfg.mol_embedding_dim + 10; verify that is intended.
            mol_embedding_dim += 10

        self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
                                            cfg.projection_dim,
                                            cfg,
                                            cfg.tsfm_in_ms,
                                            cfg.lstm_in_ms)

        self.mol_projection = ProjectionHead(mol_embedding_dim,
                                             cfg.projection_dim,
                                             cfg,
                                             cfg.tsfm_in_mol,
                                             cfg.lstm_in_mol)

    def forward(self, batch):
        """Compute normalised molecule and spectrum embeddings.

        No loss is computed here; only the embeddings are returned.

        Args:
            batch: Mapping with ``"ms_bins"`` plus, depending on
                ``cfg.mol_encoder``: the entries read by the GNN encoder
                (``'gnn'``), ``"mol_fps"`` (``'fp'``) and ``"mol_fmvec"``
                (``'fm'``).

        Returns:
            Tuple ``(mol_embeddings, ms_embeddings)``, each L2-normalised
            along ``dim=1``.

        Raises:
            ValueError: If ``cfg.mol_encoder`` selects none of ``'gnn'``,
                ``'fp'`` or ``'fm'``.
        """
        ms_features = batch["ms_bins"]
        encoder_spec = self.cfg.mol_encoder

        mol_feat_list = []
        if 'gnn' in encoder_spec:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in encoder_spec:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in encoder_spec:
            mol_feat_list.append(batch["mol_fmvec"])

        if not mol_feat_list:
            raise ValueError(
                "cfg.mol_encoder must select at least one of 'gnn', 'fp' or "
                f"'fm'; got {encoder_spec!r}"
            )

        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        # Project both modalities to the same dimension.
        ms_embeddings = self.ms_projection(ms_features)
        mol_embeddings = self.mol_projection(mol_features)

        # Normalise the projected embeddings.
        mol_embeddings = F.normalize(mol_embeddings, p=2, dim=1)
        ms_embeddings = F.normalize(ms_embeddings, p=2, dim=1)

        return mol_embeddings, ms_embeddings