Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import torch | |
| from flare.models.encoders import MLP | |
| from torch_geometric.nn import global_mean_pool | |
class SpecEncMLP_BIN(nn.Module):
    """Three-layer MLP that encodes a binned m/z spectrum vector.

    The input width is the number of bins, ``int(args.max_mz / args.bin_width)``.
    Two hidden layers of width ``2 * out_dim`` (ReLU + dropout) are followed by a
    linear projection to ``out_dim``, which is also passed through dropout.

    Args:
        args: config object; reads ``final_embedding_dim``, ``max_mz``,
            ``bin_width`` and ``fc_dropout``.
        out_dim: output width; any falsy value (``None`` or ``0``) falls back to
            ``args.final_embedding_dim``.
    """

    def __init__(self, args, out_dim=None):
        super().__init__()
        out_dim = out_dim or args.final_embedding_dim
        n_bins = int(args.max_mz / args.bin_width)
        hidden = 2 * out_dim
        # Submodule names/order are kept stable so checkpoints keep loading.
        self.dropout = nn.Dropout(args.fc_dropout)
        self.mz_fc1 = nn.Linear(n_bins, hidden)
        self.mz_fc2 = nn.Linear(hidden, hidden)
        self.mz_fc3 = nn.Linear(hidden, out_dim)
        self.relu = nn.ReLU()

    def forward(self, mzi_b, n_peaks=None):
        """Return the ``(..., out_dim)`` embedding of the binned vector ``mzi_b``.

        ``n_peaks`` is accepted for call compatibility with the other encoders
        and is not used.
        """
        x = mzi_b
        for layer in (self.mz_fc1, self.mz_fc2):
            x = self.dropout(self.relu(layer(x)))
        return self.dropout(self.mz_fc3(x))
class SpecMzIntTokenTransformer(nn.Module):
    """Transformer encoder over per-peak (m/z, intensity) token rows.

    Each token row of ``spec`` is embedded by ``MLP`` (input width 2) into
    ``args.hidden_dims[-1]`` features, then contextualised by a 2-layer,
    2-head ``nn.TransformerEncoder``. Token rows whose features all equal the
    padding sentinel ``-5`` are excluded from attention via
    ``src_key_padding_mask``.

    Output depends on the configuration:

    * ``args.model`` in ``_TOKEN_MODELS`` (``returnEmb``): per-token outputs
      are returned with padded rows reset to ``-5``. ``args.use_cls`` must be
      false in this mode.
    * ``args.use_cls``: a learned CLS token is prepended, and its output is
      projected by ``specEncoder`` to ``args.final_embedding_dim``.
    * otherwise: the non-padded token outputs of each spectrum are mean-pooled
      and projected by ``specEncoder``.
    """

    # NOTE(review): SpecFormulaTransformer additionally treats
    # 'filipGlobalContrastive' as a per-token model; confirm whether that
    # model should also be listed here.
    _TOKEN_MODELS = ('crossAttenContrastive', 'filipContrastive')
    _PAD_VALUE = -5

    def __init__(self, args):
        super(SpecMzIntTokenTransformer, self).__init__()
        in_dim = 2  # one (m/z, intensity) pair per token
        self.tokenEnc = MLP(in_dim, args.hidden_dims, dropout=args.peak_dropout)
        self.returnEmb = args.model in self._TOKEN_MODELS
        if self.returnEmb:
            # With use_cls, forward() would take the CLS branch and call
            # specEncoder, which is not built in this mode.
            if args.use_cls:
                raise ValueError(
                    "use_cls must be False for models that consume per-token embeddings"
                )
        else:
            self.specEncoder = nn.Sequential(
                nn.Linear(args.hidden_dims[-1], args.final_embedding_dim),
                nn.Dropout(args.fc_dropout),
            )
        self.use_cls = args.use_cls
        if self.use_cls:
            self.cls_embed = torch.nn.Embedding(1, args.hidden_dims[-1])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=args.hidden_dims[-1], nhead=2, batch_first=True
        )
        self.tokenTransformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

    def forward(self, spec, n_peaks=None):
        """Encode a padded batch of peak tokens.

        Args:
            spec: tensor of shape ``(batch, tokens, features)``; padded token
                rows hold ``-5`` in every feature. ``features`` is presumably 2
                (the MLP's ``in_dim``) — confirm against the collate function.
            n_peaks: unused; kept for call compatibility. The per-spectrum
                token grouping is taken from the padding mask, so it always
                matches the rows that are pooled, and the call also works
                when ``n_peaks`` is omitted.

        Returns:
            With ``returnEmb``, the per-token embeddings
            ``(batch, tokens, hidden_dims[-1])`` with padded rows set to
            ``-5``. Otherwise a ``(batch, final_embedding_dim)`` spectrum
            embedding.
        """
        h = self.tokenEnc(spec)
        # A token is padding when every one of its input features is the sentinel.
        pad = torch.all(spec == self._PAD_VALUE, dim=-1)
        batch_size = pad.shape[0]
        if self.use_cls:
            # Embedding(0) is row 0 of the weight; broadcast it to one token per sample.
            cls_tok = self.cls_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
            h = torch.cat((cls_tok, h), dim=1)
            pad = torch.cat((pad.new_zeros(batch_size, 1), pad), dim=1)
            h = self.tokenTransformer(h, src_key_padding_mask=pad)
            h = h[:, 0, :]
        else:
            h = self.tokenTransformer(h, src_key_padding_mask=pad)
            if self.returnEmb:
                # Out-of-place re-padding so downstream code can rebuild the mask.
                return h.masked_fill(pad.unsqueeze(-1), self._PAD_VALUE)
            keep = ~pad
            # Sample index of every kept token, in the same row-major order as h[keep].
            token_batch = (
                torch.arange(batch_size, device=pad.device).unsqueeze(1).expand_as(pad)[keep]
            )
            # size= keeps one output row per sample even if trailing samples are empty.
            h = global_mean_pool(h[keep], token_batch, size=batch_size)
        return self.specEncoder(h)
class SpecFormulaEncMLP(nn.Module):
    """Sum-pooled MLP encoder over per-peak formula feature rows.

    Each peak row is embedded by ``MLP``. The embeddings of the real (non-padded)
    peaks are summed per spectrum and projected to ``out_dim`` by a Linear
    layer followed by dropout.

    Args:
        args: config object; reads ``element_list``, ``add_intensities``,
            ``spectra_view``, ``formula_dims``, ``formula_dropout``,
            ``final_embedding_dim`` and ``fc_dropout``.
        out_dim: output width; a falsy value falls back to
            ``args.final_embedding_dim``.
    """

    _PAD_VALUE = -5

    def __init__(self, args, out_dim=None):
        super(SpecFormulaEncMLP, self).__init__()
        # One input feature per entry of element_list, plus optional extras.
        in_dim = len(args.element_list)
        if args.add_intensities:
            in_dim += 1
        if args.spectra_view == "SpecFormulaMz":  # extra m/z feature
            in_dim += 1
        self.formulaEnc = MLP(in_dim, args.formula_dims, dropout=args.formula_dropout)
        if not out_dim:
            out_dim = args.final_embedding_dim
        self.mz_fc1 = nn.Linear(args.formula_dims[-1], out_dim)
        self.dropout = nn.Dropout(args.fc_dropout)

    def forward(self, spec, n_peaks=None):
        """Encode a padded batch of formula rows.

        Args:
            spec: tensor of shape ``(batch, peaks, in_dim)``, assumed from the
                sum over dim 1. Padded peak rows hold ``-5`` in every feature,
                the padding convention used by the transformer encoders here.
            n_peaks: unused; optional for consistency with the other encoders.

        Returns:
            ``(batch, out_dim)`` spectrum embedding.
        """
        h = self.formulaEnc(spec)
        # Padding rows still produce non-zero MLP outputs. Zero them so the sum
        # covers real peaks only and is independent of the batch's pad length.
        pad = torch.all(spec == self._PAD_VALUE, dim=-1, keepdim=True)
        h = h.masked_fill(pad, 0).sum(dim=1)
        h = self.mz_fc1(h)
        return self.dropout(h)
class SpecFormulaTransformer(nn.Module):
    """Transformer encoder over per-peak formula feature rows.

    Each row of ``spec`` is embedded by ``MLP`` into ``args.formula_dims[-1]``
    features and contextualised by an ``nn.TransformerEncoder``. That encoder
    has ``args.formula_transformer_layers`` layers and
    ``args.formula_attn_heads`` heads. Rows whose features all equal the
    padding sentinel ``-5`` are excluded from attention.

    Output depends on the configuration:

    * ``args.model`` in ``_TOKEN_MODELS`` (``returnEmb``): per-token outputs
      are returned with padded rows reset to ``-5``. ``args.use_cls`` must be
      false, and no output projection is built.
    * ``args.use_cls``: a learned CLS token is prepended, and its output is
      projected by ``fc`` to ``out_dim``.
    * otherwise: the non-padded token outputs of each spectrum are mean-pooled
      and projected by ``fc``.

    Args:
        args: config object; reads ``element_list``, ``add_intensities``,
            ``spectra_view``, ``model``, ``use_cls``, ``formula_dims``,
            ``formula_dropout``, ``formula_attn_heads``,
            ``formula_transformer_layers`` and ``final_embedding_dim``.
        out_dim: output width; a falsy value falls back to
            ``args.final_embedding_dim``.
    """

    _TOKEN_MODELS = ('crossAttenContrastive', 'filipContrastive', 'filipGlobalContrastive')
    _PAD_VALUE = -5

    def __init__(self, args, out_dim=None):
        super(SpecFormulaTransformer, self).__init__()
        in_dim = len(args.element_list)
        if args.add_intensities:  # extra intensity feature
            in_dim += 1
        if args.spectra_view == "SpecFormulaMz":  # extra m/z feature
            in_dim += 1
        self.returnEmb = args.model in self._TOKEN_MODELS
        # With use_cls, forward() would take the CLS branch and call fc,
        # which is not built in per-token mode.
        if self.returnEmb and args.use_cls:
            raise ValueError(
                "use_cls must be False for models that consume per-token embeddings"
            )
        self.formulaEnc = MLP(in_dim=in_dim, hidden_dims=args.formula_dims, dropout=args.formula_dropout)
        self.use_cls = args.use_cls
        if args.use_cls:
            self.cls_embed = torch.nn.Embedding(1, args.formula_dims[-1])
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=args.formula_dims[-1], nhead=args.formula_attn_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=args.formula_transformer_layers
        )
        if not self.returnEmb:
            if not out_dim:
                out_dim = args.final_embedding_dim
            self.fc = nn.Linear(args.formula_dims[-1], out_dim)

    def forward(self, spec, n_peaks=None):
        """Encode a padded batch of formula rows.

        Args:
            spec: tensor of shape ``(batch, tokens, in_dim)``; padded rows hold
                ``-5`` in every feature.
            n_peaks: unused; kept for call compatibility. The per-spectrum
                token grouping comes from the padding mask, so callers that
                omit ``n_peaks`` also work in mean-pooling mode.

        Returns:
            With ``returnEmb``, the per-token embeddings
            ``(batch, tokens, formula_dims[-1])`` with padded rows set to
            ``-5``. Otherwise a ``(batch, out_dim)`` spectrum embedding.
        """
        h = self.formulaEnc(spec)
        # A token is padding when every one of its input features is the sentinel.
        pad = torch.all(spec == self._PAD_VALUE, dim=-1)
        batch_size = pad.shape[0]
        if self.use_cls:
            # Embedding(0) is row 0 of the weight; broadcast it to one token per sample.
            cls_tok = self.cls_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
            h = torch.cat((cls_tok, h), dim=1)
            pad = torch.cat((pad.new_zeros(batch_size, 1), pad), dim=1)
            h = self.transformer(h, src_key_padding_mask=pad)
            h = h[:, 0, :]
        else:
            h = self.transformer(h, src_key_padding_mask=pad)
            if self.returnEmb:
                # Out-of-place re-padding so downstream code can rebuild the mask.
                return h.masked_fill(pad.unsqueeze(-1), self._PAD_VALUE)
            keep = ~pad
            # Sample index of every kept token, in the same row-major order as h[keep].
            token_batch = (
                torch.arange(batch_size, device=pad.device).unsqueeze(1).expand_as(pad)[keep]
            )
            # size= keeps one output row per sample even if trailing samples are empty.
            h = global_mean_pool(h[keep], token_batch, size=batch_size)
        return self.fc(h)
class SpecFormula_mz_Encoder(nn.Module):
    """Encode a spectrum from two views: per-peak formulas and a binned m/z vector.

    The formula view is encoded by ``SpecFormulaTransformer`` and the binned
    view by ``SpecEncMLP_BIN``. Each branch produces
    ``args.final_embedding_dim // 4`` features. The two branch embeddings are
    concatenated and projected to ``args.final_embedding_dim`` by a Linear
    layer followed by ReLU.
    """

    def __init__(self, args):
        super(SpecFormula_mz_Encoder, self).__init__()
        branch_dim = args.final_embedding_dim // 4
        if branch_dim < 1:
            # The sub-encoders treat out_dim=0 as "use final_embedding_dim",
            # which would silently break the concatenated width.
            raise ValueError("final_embedding_dim must be at least 4")
        self.formula_encoder = SpecFormulaTransformer(args, out_dim=branch_dim)
        self.mz_encoder = SpecEncMLP_BIN(args, out_dim=branch_dim)
        # Size the projection from the real concatenated width. final // 2 only
        # equals 2 * (final // 4) when final_embedding_dim is a multiple of 4.
        self.fc = nn.Sequential(
            nn.Linear(2 * branch_dim, args.final_embedding_dim),
            nn.ReLU(),
        )

    def forward(self, formulas, binned_mzs, n_peaks=None):
        """Return the ``(batch, final_embedding_dim)`` fused spectrum embedding.

        Args:
            formulas: padded per-peak formula tensor for the formula encoder.
            binned_mzs: binned m/z vector for the MLP encoder.
            n_peaks: optional per-spectrum peak counts, forwarded to the
                formula encoder (it defaults to ``None`` there as before).
        """
        h_formula = self.formula_encoder(formulas, n_peaks=n_peaks)
        h_bin = self.mz_encoder(binned_mzs)
        h_spec = torch.cat((h_formula, h_bin), dim=1)
        return self.fc(h_spec)