# flare/models/spec_encoder.py
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool

from flare.models.encoders import MLP


class SpecEncMLP_BIN(nn.Module):
    """Three-layer MLP encoder over a fixed-size binned m/z intensity vector."""

    def __init__(self, args, out_dim=None):
        super(SpecEncMLP_BIN, self).__init__()
if not out_dim:
out_dim = args.final_embedding_dim
bin_size = int(args.max_mz / args.bin_width)
self.dropout = nn.Dropout(args.fc_dropout)
self.mz_fc1 = nn.Linear(bin_size, out_dim * 2)
        self.mz_fc2 = nn.Linear(out_dim * 2, out_dim * 2)
self.mz_fc3 = nn.Linear(out_dim * 2, out_dim)
self.relu = nn.ReLU()
    def forward(self, mzi_b, n_peaks=None):
        # mzi_b: (batch, bin_size) binned intensities; n_peaks is unused by this encoder.
        h1 = self.mz_fc1(mzi_b)
h1 = self.relu(h1)
h1 = self.dropout(h1)
h1 = self.mz_fc2(h1)
h1 = self.relu(h1)
h1 = self.dropout(h1)
mz_vec = self.mz_fc3(h1)
mz_vec = self.dropout(mz_vec)
return mz_vec
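
# Illustrative usage sketch (not part of the original repo): the argument values below are
# placeholders chosen only to show the expected input shape, (batch, max_mz / bin_width).
#
#   from types import SimpleNamespace
#   _args = SimpleNamespace(max_mz=1000, bin_width=1.0, fc_dropout=0.1, final_embedding_dim=256)
#   enc = SpecEncMLP_BIN(_args)
#   out = enc(torch.rand(8, int(_args.max_mz / _args.bin_width)))  # -> (8, 256)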

class SpecMzIntTokenTransformer(nn.Module):
    """Transformer encoder over per-peak (m/z, intensity) tokens."""

    def __init__(self, args):
        super(SpecMzIntTokenTransformer, self).__init__()
        in_dim = 2  # each peak token is an (m/z, intensity) pair
        self.tokenEnc = MLP(in_dim, args.hidden_dims, dropout=args.peak_dropout)
self.returnEmb = False
if args.model in ('crossAttenContrastive', 'filipContrastive'):
self.returnEmb = True
assert(args.use_cls == False)
else:
            self.specEncoder = nn.Sequential(
                nn.Linear(args.hidden_dims[-1], args.final_embedding_dim),
                nn.Dropout(args.fc_dropout),
            )
self.use_cls = args.use_cls
if self.use_cls:
            self.cls_embed = nn.Embedding(1, args.hidden_dims[-1])
encoder_layer = nn.TransformerEncoderLayer(d_model=args.hidden_dims[-1], nhead=2, batch_first=True)
self.tokenTransformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
    def forward(self, spec, n_peaks=None):
        h = self.tokenEnc(spec)
        # Peaks whose features are all -5 are padding and are masked out of attention.
        pad = (spec == -5)
        pad = torch.all(pad, -1)
        if self.use_cls:
            # Prepend a learned [CLS] token and read the spectrum embedding from its position.
            cls_embed = self.cls_embed(torch.tensor(0).to(spec.device))
            h = torch.concat((cls_embed.repeat(spec.shape[0], 1).unsqueeze(1), h), dim=1)
            pad = torch.concat((torch.tensor(False).repeat(pad.shape[0], 1).to(spec.device), pad), dim=1)
            h = self.tokenTransformer(h, src_key_padding_mask=pad)
            h = h[:, 0, :]
        else:
            # Mean-pool the transformer outputs over the non-padded peaks.
            h = self.tokenTransformer(h, src_key_padding_mask=pad)
            if self.returnEmb:
                # Restore the -5 padding so downstream token-level losses can recover the mask.
                h[pad] = -5
return h
            # Map each non-padded peak back to its spectrum index, then mean-pool per spectrum.
            n_peaks_indices = torch.tensor(
                [i for i, count in enumerate(n_peaks) for _ in range(count)]
            ).to(spec.device)
            h = h[~pad].reshape(-1, h.shape[-1])
            h = global_mean_pool(h, n_peaks_indices)
h = self.specEncoder(h)
return h
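
# Illustrative usage sketch (not part of the original repo): field names mirror the args
# consumed above, but every concrete value (and the model name) is a placeholder.
#
#   from types import SimpleNamespace
#   _args = SimpleNamespace(model='supervised', hidden_dims=[64, 128], peak_dropout=0.1,
#                           fc_dropout=0.1, final_embedding_dim=256, use_cls=False)
#   enc = SpecMzIntTokenTransformer(_args)
#   spec = torch.full((4, 10, 2), -5.0)  # (batch, max_peaks, [m/z, intensity]); -5 marks padding
#   n_peaks = [3, 5, 2, 4]
#   for i, n in enumerate(n_peaks):
#       spec[i, :n] = torch.rand(n, 2)
#   out = enc(spec, n_peaks)  # -> (4, 256)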

class SpecFormulaEncMLP(nn.Module):
    """MLP encoder that embeds each peak's formula features and sums them over peaks."""

    def __init__(self, args, out_dim=None):
        super(SpecFormulaEncMLP, self).__init__()
        in_dim = len(args.element_list)
        if args.add_intensities:  # optional intensity feature
            in_dim += 1
        if args.spectra_view == "SpecFormulaMz":  # optional m/z feature
            in_dim += 1
        self.formulaEnc = MLP(in_dim, args.formula_dims, dropout=args.formula_dropout)
if not out_dim:
out_dim = args.final_embedding_dim
self.mz_fc1 = nn.Linear(args.formula_dims[-1], out_dim)
self.dropout = nn.Dropout(args.fc_dropout)
    def forward(self, spec, n_peaks):
        h = self.formulaEnc(spec)
        # Sum the per-peak embeddings over the peak dimension.
        h = torch.sum(h, dim=1)
        h = self.mz_fc1(h)
        h = self.dropout(h)
        return h
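
# Illustrative usage sketch (not part of the original repo): element_list and all dimensions
# are placeholders; the input is one formula-feature vector per peak, shape (batch, peaks, in_dim).
#
#   from types import SimpleNamespace
#   _args = SimpleNamespace(element_list=['C', 'H', 'N', 'O'], add_intensities=True,
#                           spectra_view='SpecFormula', formula_dims=[64, 128],
#                           formula_dropout=0.1, fc_dropout=0.1, final_embedding_dim=256)
#   enc = SpecFormulaEncMLP(_args)
#   out = enc(torch.rand(4, 10, 5), n_peaks=None)  # 4 elements + intensity -> (4, 256)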

class SpecFormulaTransformer(nn.Module):
    """Transformer encoder over per-peak formula (and optional intensity/m/z) features."""

    def __init__(self, args, out_dim=None):
        super(SpecFormulaTransformer, self).__init__()
        in_dim = len(args.element_list)
        if args.add_intensities:  # optional intensity feature
            in_dim += 1
        if args.spectra_view == "SpecFormulaMz":  # optional m/z feature
            in_dim += 1
        self.returnEmb = False
        if args.model in ('crossAttenContrastive', 'filipContrastive', 'filipGlobalContrastive'):
            self.returnEmb = True
            assert(args.use_cls == False)
        self.formulaEnc = MLP(in_dim=in_dim, hidden_dims=args.formula_dims, dropout=args.formula_dropout)
self.use_cls = args.use_cls
if args.use_cls:
            self.cls_embed = nn.Embedding(1, args.formula_dims[-1])
encoder_layer = nn.TransformerEncoderLayer(d_model=args.formula_dims[-1], nhead=args.formula_attn_heads, batch_first=True)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=args.formula_transformer_layers)
if not self.returnEmb:
if not out_dim:
out_dim = args.final_embedding_dim
self.fc = nn.Linear(args.formula_dims[-1], out_dim)
    def forward(self, spec, n_peaks=None):
        h = self.formulaEnc(spec)
        # Peaks whose features are all -5 are padding and are masked out of attention.
        pad = (spec == -5)
        pad = torch.all(pad, -1)
        if self.use_cls:
            # Prepend a learned [CLS] token and read the spectrum embedding from its position.
            cls_embed = self.cls_embed(torch.tensor(0).to(spec.device))
            h = torch.concat((cls_embed.repeat(spec.shape[0], 1).unsqueeze(1), h), dim=1)
            pad = torch.concat((torch.tensor(False).repeat(pad.shape[0], 1).to(spec.device), pad), dim=1)
            h = self.transformer(h, src_key_padding_mask=pad)
            h = h[:, 0, :]
else:
h = self.transformer(h, src_key_padding_mask=pad)
            if self.returnEmb:
                # Restore the -5 padding so downstream token-level losses can recover the mask.
                h[pad] = -5
                return h
h = h[~pad].reshape(-1, h.shape[-1])
            # Map each non-padded peak back to its spectrum index, then mean-pool per spectrum.
            indices = torch.tensor(
                [i for i, count in enumerate(n_peaks) for _ in range(count)]
            ).to(h.device)
            h = global_mean_pool(h, indices)
h = self.fc(h)
return h
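
# Illustrative usage sketch (not part of the original repo): all values are placeholders,
# and padded peaks are filled with -5 to match the mask computed in forward().
#
#   from types import SimpleNamespace
#   _args = SimpleNamespace(model='supervised', element_list=['C', 'H', 'N', 'O'],
#                           add_intensities=False, spectra_view='SpecFormula',
#                           formula_dims=[64, 128], formula_dropout=0.1,
#                           formula_attn_heads=2, formula_transformer_layers=2,
#                           use_cls=False, final_embedding_dim=256)
#   enc = SpecFormulaTransformer(_args)
#   spec = torch.full((4, 10, 4), -5.0)
#   n_peaks = [3, 5, 2, 4]
#   for i, n in enumerate(n_peaks):
#       spec[i, :n] = torch.rand(n, 4)
#   out = enc(spec, n_peaks)  # -> (4, 256)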

class SpecFormula_mz_Encoder(nn.Module):
    '''
    Joint encoder that concatenates a formula-transformer embedding with a binned m/z
    MLP embedding and projects the result to the final embedding dimension.
    '''
    def __init__(self, args):
        super(SpecFormula_mz_Encoder, self).__init__()
        self.formula_encoder = SpecFormulaTransformer(args, out_dim=args.final_embedding_dim // 4)
        self.mz_encoder = SpecEncMLP_BIN(args, out_dim=args.final_embedding_dim // 4)
        self.fc = nn.Sequential(
            nn.Linear(args.final_embedding_dim // 2, args.final_embedding_dim),
            nn.ReLU(),
        )
    def forward(self, formulas, binned_mzs):
        h_formula = self.formula_encoder(formulas)
        h_bin = self.mz_encoder(binned_mzs)
        # Concatenate the two half-size views and project to the final embedding dimension.
        h_spec = torch.concat((h_formula, h_bin), dim=1)
        h = self.fc(h_spec)
        return h
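
# Illustrative end-to-end sketch (not part of the original repo): all values are placeholders.
# SpecFormula_mz_Encoder.forward calls the formula transformer without n_peaks, so this sketch
# assumes args.use_cls=True so that pooling goes through the [CLS] token rather than mean pooling.
#
#   from types import SimpleNamespace
#   _args = SimpleNamespace(model='supervised', element_list=['C', 'H', 'N', 'O'],
#                           add_intensities=False, spectra_view='SpecFormula',
#                           formula_dims=[64, 128], formula_dropout=0.1,
#                           formula_attn_heads=2, formula_transformer_layers=2,
#                           use_cls=True, max_mz=1000, bin_width=1.0,
#                           fc_dropout=0.1, final_embedding_dim=256)
#   enc = SpecFormula_mz_Encoder(_args)
#   formulas = torch.rand(4, 10, 4)  # (batch, peaks, formula features)
#   binned = torch.rand(4, 1000)     # (batch, max_mz / bin_width)
#   out = enc(formulas, binned)      # -> (4, 256)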