import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.utils import add_self_loops, to_dense_batch

import nn_utils as nn_utils

num_atom_type = 119  # including the extra mask tokens
num_chirality_tag = 4
num_hybrid_type = 8
num_valence_tag = 8
num_degree_tag = 5

num_bond_type = 5  # including aromatic and self-loop edge
num_bond_direction = 3
num_bond_configuration = 6  # not used by the modules in this file
self_loop_bond_type = 4  # edge_attr[:, 0] value assigned to the added self-loop edges


class GINEConv(MessagePassing):
    """GIN-style graph convolution that adds learned bond embeddings to messages.

    Each message is ``x_j + edge_emb``, where ``edge_emb`` is the sum of a
    bond-type embedding (``edge_attr[:, 0]``) and a bond-direction embedding
    (``edge_attr[:, 1]``). A self-loop is added to every node with bond type
    ``self_loop_bond_type`` and direction 0. Messages are combined with
    MessagePassing's default aggregation (``aggr='add'`` in PyG) and the
    result is passed through a two-layer MLP.

    Args:
        emb_dim: width of the node features and of the edge embeddings.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 2 * emb_dim),
            nn.ReLU(),
            nn.Linear(2 * emb_dim, emb_dim),
        )
        self.edge_embedding1 = nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = nn.Embedding(num_bond_direction, emb_dim)
        nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        nn.init.xavier_uniform_(self.edge_embedding2.weight.data)

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        Args:
            x: node features, ``(num_nodes, emb_dim)``.
            edge_index: ``(2, num_edges)`` connectivity.
            edge_attr: integer edge features; column 0 is the bond type and
                column 1 the bond direction.

        Returns:
            Updated node features, ``(num_nodes, emb_dim)``.
        """
        num_nodes = x.size(0)
        # add_self_loops appends one (i, i) edge per node after the existing
        # edges, so the matching attributes are appended in the same order.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        self_loop_attr = torch.zeros(
            num_nodes, 2, dtype=edge_attr.dtype, device=edge_attr.device
        )
        self_loop_attr[:, 0] = self_loop_bond_type
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

        edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(
            edge_attr[:, 1]
        )
        return self.propagate(edge_index, x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)


class SmilesModel(nn.Module):
    """GINE encoder mapping a batched molecular graph to per-node embeddings.

    Args:
        num_layer (int): number of GINEConv layers.
        emb_dim (int): node embedding width.
        feat_dim (int): width of ``feat_lin`` / ``out_lin``. These heads are
            constructed but not used by ``forward``.
        drop_ratio (float): dropout rate applied after every layer.
        pool (str): 'mean', 'max' or 'add'; stored as ``self.pool`` but not
            used by ``forward``. ModelCLR pools the node embeddings with
            ESA_SMILES instead.

    Output:
        node representations of shape ``(num_nodes, emb_dim)``.
    """

    def __init__(self, num_layer=5, emb_dim=300, feat_dim=256, drop_ratio=0, pool='mean'):
        super().__init__()
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        # One embedding table per integer column of data.x.
        self.x_embedding1 = nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = nn.Embedding(num_chirality_tag, emb_dim)
        self.x_embedding3 = nn.Embedding(num_hybrid_type, emb_dim)
        self.x_embedding4 = nn.Embedding(num_valence_tag, emb_dim)
        self.x_embedding5 = nn.Embedding(num_degree_tag, emb_dim)
        for embedding in self._atom_embeddings():
            nn.init.xavier_uniform_(embedding.weight.data)

        self.gnns = nn.ModuleList([GINEConv(emb_dim) for _ in range(num_layer)])
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(emb_dim) for _ in range(num_layer)]
        )

        pools = {'mean': global_mean_pool, 'max': global_max_pool, 'add': global_add_pool}
        # Unknown values are tolerated: forward never reads self.pool.
        if pool in pools:
            self.pool = pools[pool]

        # Not used by forward; kept so the module structure / state_dict keys
        # stay the same.
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim, self.feat_dim // 2),
        )

    def _atom_embeddings(self):
        """Return the atom-feature embedding tables in data.x column order."""
        return (
            self.x_embedding1,
            self.x_embedding2,
            self.x_embedding3,
            self.x_embedding4,
            self.x_embedding5,
        )

    def forward(self, data):
        """Embed atoms and run the GINE stack.

        Args:
            data: a PyG batch exposing ``x`` (integer atom features; columns
                0-4 index the atom-type, chirality, hybridization, valence and
                degree tables), ``edge_index`` and ``edge_attr``.

        Returns:
            Node embeddings of shape ``(num_nodes, emb_dim)``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        h = sum(emb(x[:, col]) for col, emb in enumerate(self._atom_embeddings()))

        last_layer = self.num_layer - 1
        for layer, (gnn, norm) in enumerate(zip(self.gnns, self.batch_norms)):
            h = norm(gnn(h, edge_index, edge_attr))
            # Every layer except the last is followed by a ReLU.
            if layer != last_layer:
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training=self.training)
        return h


class FourierEmbedder(nn.Module):
    """Embed m/z float values with fixed (non-trainable) sinusoidal features.

    Uses ``ceil(d/2)`` frequencies ``2*pi / wavelength``. The wavelengths are
    spaced geometrically, starting at ``10**-logmin`` and reaching
    ``10**logmax`` exactly when ``d`` is even. Each m/z value maps to the
    sines followed by the cosines of its phases, truncated to ``d`` features.

    Args:
        spec_embed_dim: output width ``d``; must not be 2 (the spacing
            exponent divides by ``d - 2``).
        logmin: the smallest wavelength is ``10**-logmin``.
        logmax: the largest wavelength is ``10**logmax``.

    Raises:
        ValueError: if ``spec_embed_dim == 2``.
    """

    def __init__(self, spec_embed_dim, logmin=2.5, logmax=3.3):
        super().__init__()
        if spec_embed_dim == 2:
            raise ValueError("spec_embed_dim must not be 2 (frequency spacing divides by d - 2)")
        self.d = spec_embed_dim
        self.logmin = logmin
        self.logmax = logmax

        lambda_min = np.power(10, -logmin)
        lambda_max = np.power(10, logmax)
        index = torch.arange(np.ceil(self.d / 2))
        exp = torch.pow(lambda_max / lambda_min, (2 * index) / (self.d - 2))
        freqs = 2 * np.pi * (lambda_min * exp) ** (-1)

        # A Parameter with requires_grad=False: saved in the state_dict as
        # 'freqs' but never updated by gradients.
        self.freqs = nn.Parameter(freqs, requires_grad=False)

    def forward(self, mz: torch.FloatTensor):
        """Embed a padded batch of m/z values.

        Args:
            mz: FloatTensor of shape ``(batch_size, num_peaks)``.

        Returns:
            FloatTensor of shape ``(batch_size, num_peaks, d)``.
        """
        phases = torch.einsum("bi,j->bij", mz, self.freqs)
        embedded = torch.cat([torch.sin(phases), torch.cos(phases)], -1)
        # For odd d the sin/cos concatenation has one extra column.
        return embedded[:, :, : self.d]


class MSModel(nn.Module):
    """Transformer encoder over the peaks of a mass spectrum.

    Each peak is the Fourier embedding of its m/z concatenated with its
    intensity, projected back to ``spec_embed_dim``, then passed through
    ``layers`` clones of ``nn_utils.TransformerEncoderLayer`` (8 heads).

    Args:
        spec_embed_dim: peak embedding width.
        dropout: dropout rate inside the transformer layers.
        layers: number of transformer layers.
    """

    def __init__(self, spec_embed_dim, dropout, layers):
        super().__init__()
        self.mz_embedder = FourierEmbedder(spec_embed_dim)
        self.input_compress = nn.Linear(spec_embed_dim + 1, spec_embed_dim)
        peak_attn_layer = nn_utils.TransformerEncoderLayer(
            d_model=spec_embed_dim,
            nhead=8,
            dim_feedforward=spec_embed_dim * 4,
            dropout=dropout,
            additive_attn=False,
            pairwise_featurization=False,
        )
        self.peak_attn_layers = nn_utils.get_clones(peak_attn_layer, layers)
        # Not used by forward (ModelCLR pools peaks with ESA_SPEC); kept so
        # the module structure / state_dict keys stay the same.
        self.pooling_layer = nn.AdaptiveAvgPool1d(1)
        self.output_layer = nn.Linear(spec_embed_dim, spec_embed_dim)

    def forward(self, mzs, intens, num_peaks):
        """Encode a padded batch of spectra.

        Args:
            mzs: ``(batch, max_peaks)`` m/z values.
            intens: ``(batch, max_peaks)`` intensities.
            num_peaks: ``(batch,)`` number of real (non-padding) peaks per
                spectrum.

        Returns:
            ``(peak_tensor, attn_mask)``:
                - ``peak_tensor`` is ``(batch, max_peaks, features)``. This
                  assumes the nn_utils layers preserve the feature width —
                  TODO confirm.
                - ``attn_mask`` is a bool ``(batch, max_peaks)`` tensor that
                  is True at padding positions.
        """
        peak_tensor = torch.cat([self.mz_embedder(mzs), intens[:, :, None]], -1)
        peak_tensor = self.input_compress(peak_tensor)

        positions = torch.arange(peak_tensor.shape[1], device=mzs.device)
        # True where the position index is >= the spectrum's peak count.
        attn_mask = ~(positions[None, :] < num_peaks[:, None])

        # The encoder layers are fed (peaks, batch, features).
        peak_tensor = peak_tensor.transpose(0, 1)
        for peak_attn_layer in self.peak_attn_layers:
            peak_tensor, _ = peak_attn_layer(peak_tensor, src_key_padding_mask=attn_mask)
        return peak_tensor.transpose(0, 1), attn_mask


def _attentive_pool(cap_embes, scores, pad_mask):
    """Pool ``cap_embes`` over dim 1 with per-feature softmax weights.

    Args:
        cap_embes: ``(B, N, C)`` values to pool.
        scores: ``(B, N, C)`` unnormalised attention logits.
        pad_mask: ``(B, N)`` bool tensor, True at positions to ignore.

    Returns:
        ``(B, C)`` weighted sums. Rows whose positions are all padding give
        zeros.
    """
    pad = pad_mask.unsqueeze(-1)  # (B, N, 1), broadcast over features
    # A finite fill (instead of -inf) avoids NaN when every position is
    # padded and is representable in fp16.
    scores = scores.masked_fill(pad, -1e4)
    weights = torch.softmax(scores, dim=1).masked_fill(pad, 0)
    return torch.sum(weights * cap_embes, dim=1)


class ESA_SMILES(nn.Module):
    """Attentive pooling of per-node molecule embeddings into one vector per graph.

    Args:
        feature_dim: width of the incoming node embeddings.
        out_dim: width of the pooled output.
    """

    def __init__(self, feature_dim, out_dim):
        super().__init__()
        self.ln_f = nn.LayerNorm(feature_dim)
        self.linear = nn.Linear(feature_dim, out_dim)
        self.linear1 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states, data_batch):
        """Pool node embeddings per graph.

        Args:
            hidden_states: ``(num_nodes, feature_dim)`` node embeddings.
            data_batch: ``(num_nodes,)`` graph index of each node. It must be
                sorted, as produced by PyG batching, because
                ``to_dense_batch`` requires that.

        Returns:
            ``(num_graphs, out_dim)`` pooled features.
        """
        # Scatter nodes into a zero-padded (B, N_max, C) tensor. The mask comes
        # from node membership, so a real node whose embedding is all zeros is
        # not mistaken for padding.
        dense, node_mask = to_dense_batch(hidden_states, data_batch)
        cap_embes = self.linear(self.ln_f(dense))  # values, (B, N, out_dim)
        scores = self.linear1(cap_embes)  # logits, (B, N, out_dim)
        return _attentive_pool(cap_embes, scores, ~node_mask)


class ESA_SPEC(nn.Module):
    """Attentive pooling of per-peak spectrum embeddings into one vector per spectrum.

    Args:
        feature_dim: width of the incoming peak embeddings.
        out_dim: width of the pooled output.
    """

    def __init__(self, feature_dim, out_dim):
        super().__init__()
        self.ln_f = nn.LayerNorm(feature_dim)
        self.linear = nn.Linear(feature_dim, out_dim)
        self.linear1 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states, attention_mask):
        """Pool peak embeddings per spectrum.

        Args:
            hidden_states: ``(B, N, feature_dim)`` peak embeddings.
            attention_mask: ``(B, N)`` mask equal to 1/True at padding
                positions (the convention of MSModel's ``attn_mask``).

        Returns:
            ``(B, out_dim)`` pooled features.
        """
        cap_embes = self.linear(self.ln_f(hidden_states))  # values
        scores = self.linear1(cap_embes)  # logits
        return _attentive_pool(cap_embes, scores, attention_mask == 1)


class ModelCLR(nn.Module):
    """Encode a molecule graph and a mass spectrum into two ``embed_dim`` vectors.

    Each branch is encoded (SmilesModel / MSModel), pooled with an attentive
    head (ESA_SMILES / ESA_SPEC) and linearly projected.

    Args:
        num_layer, emb_dim, feat_dim, drop_ratio, pool: forwarded to
            SmilesModel.
        spec_embed_dim, dropout, layers: forwarded to MSModel.
        embed_dim: width of both output embeddings.
    """

    def __init__(self, num_layer, emb_dim, feat_dim, drop_ratio, pool, spec_embed_dim, dropout, layers, embed_dim):
        super().__init__()
        self.Smiles_model = SmilesModel(num_layer, emb_dim, feat_dim, drop_ratio, pool)
        self.MS_model = MSModel(spec_embed_dim, dropout, layers)
        self.smi_esa = ESA_SMILES(emb_dim, embed_dim)
        self.spec_esa = ESA_SPEC(spec_embed_dim, embed_dim)
        self.smi_proj = nn.Linear(embed_dim, embed_dim)
        self.spec_proj = nn.Linear(embed_dim, embed_dim)

    def smiles_encoder(self, xis):
        """Return per-node molecule embeddings for the PyG batch ``xis``."""
        return self.Smiles_model(xis)

    def ms_encoder(self, mzs, intens, num_peaks):
        """Return ``(peak_embeddings, padding_mask)`` for a padded spectrum batch."""
        return self.MS_model(mzs, intens, num_peaks)

    def forward(self, xis, mzs, intens, num_peaks):
        """Return ``(molecule_embedding, spectrum_embedding)``, each ``(B, embed_dim)``."""
        zis = self.smiles_encoder(xis)
        zls, attn_mask = self.ms_encoder(mzs, intens, num_peaks)

        zis_feat = self.smi_proj(self.smi_esa(zis, xis.batch))
        zls_feat = self.spec_proj(self.spec_esa(zls, attn_mask))
        return zis_feat, zls_feat