# -*- coding: utf-8 -*-
"""
Created on Thu Sep 15 16:22:05 2022

@author: ZNDX002
"""

from model import ModelCLR
import yaml
import os
import torch
import numpy as np
import re
from torch_geometric.data import Data, Batch
from dataloader.dataset_wrapper import MolToGraph
from rdkit import Chem


class ModelInference(object):
    """Inference wrapper around a pretrained ``ModelCLR`` network.

    Loads the model described by a YAML config, restores its weights, and
    exposes helpers that turn SMILES strings and MS2 spectra into
    L2-normalised embeddings that can be compared with a dot product.
    """

    # Length a *single* spectrum is zero-padded to before encoding
    # (value carried over from the original implementation).
    _SINGLE_SPEC_PAD_LEN = 300

    def __init__(self, config_path, pretrain_model_path, device):
        """Build the model and load pretrained weights.

        Args:
            config_path: Path to a YAML file containing a ``model_config``
                mapping that is passed as keyword arguments to ``ModelCLR``.
            pretrain_model_path: Path to a state dict loadable by
                ``torch.load``.
            device: Anything accepted by ``torch.device``; ``None`` selects
                CUDA when available and CPU otherwise.
        """
        assert config_path is not None, "config_path is None"
        assert pretrain_model_path is not None, "pretrain_model_path is None"

        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Use a context manager so the config file handle is always closed.
        # NOTE(review): FullLoader can construct more than plain data types;
        # only load configuration files from a trusted source.
        with open(config_path, "r") as config_file:
            self.config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.model = ModelCLR(**self.config["model_config"]).to(self.device)
        state_dict = torch.load(pretrain_model_path,
                                map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _embed_graph(self, graph):
        """Run encoder -> ESA pooling -> projection and L2-normalise.

        ``graph`` must already live on ``self.device``.
        """
        features = self.model.smiles_encoder(graph)
        features = self.model.smi_esa(features, graph.batch)
        features = self.model.smi_proj(features)
        return features / features.norm(dim=-1, keepdim=True)

    def _embed_spectra(self, mzs, intens, num_peaks):
        """Move padded spectra to the device, encode and L2-normalise them."""
        mzs = mzs.to(self.device)
        intens = intens.to(self.device)
        num_peaks = num_peaks.to(self.device)
        features, mask = self.model.ms_encoder(mzs, intens, num_peaks)
        features = self.model.spec_esa(features, mask)
        features = self.model.spec_proj(features)
        return features / features.norm(dim=-1, keepdim=True)

    def smiles_encode(self, smiles_str):
        """Embed one SMILES string or an iterable of SMILES strings.

        Args:
            smiles_str: A single SMILES ``str``, or an iterable of them.

        Returns:
            Tensor of embeddings normalised to unit length along the last
            dimension.
        """
        with torch.no_grad():
            if isinstance(smiles_str, str):
                # Single molecule: pass the graph object through as-is.
                graph = MolToGraph(smiles_str)
            else:
                # Several molecules: collate their graphs into one batch.
                graph = Batch.from_data_list(
                    [MolToGraph(smi) for smi in smiles_str])
            return self._embed_graph(graph.to(self.device))

    def ms2_encode(self, ms2_list):
        """Embed one spectrum or a list of spectra.

        Args:
            ms2_list: Either a ``list`` of spectrum objects or a single
                spectrum object. Each spectrum must expose ``mz`` and
                ``intensities`` (the batch path requires NumPy arrays
                because it uses ``torch.from_numpy``).

        Returns:
            Tensor of embeddings normalised to unit length along the last
            dimension.
        """
        with torch.no_grad():
            if not isinstance(ms2_list, list):
                # --- single spectrum ---
                mz = np.around(ms2_list.mz, decimals=4)
                intensities = ms2_list.intensities
                num_peak = len(mz)

                pad_len = self._SINGLE_SPEC_PAD_LEN
                # Clamp at 0 so spectra with more peaks than pad_len do not
                # make np.pad raise on a negative width; the batch path
                # likewise accepts arbitrary lengths.
                mz = np.pad(mz, (0, max(0, pad_len - len(mz))),
                            mode='constant', constant_values=0)
                intensities = np.pad(
                    intensities, (0, max(0, pad_len - len(intensities))),
                    mode='constant', constant_values=0)

                mzs = torch.tensor(mz).float().unsqueeze(0)
                intens = torch.tensor(intensities).float().unsqueeze(0)
                # Wrap the count in a list: torch.LongTensor(int) would
                # allocate an *uninitialised* tensor of that length rather
                # than a tensor holding the value. Shape (1,) matches the
                # (batch,) layout used by the batch path below.
                num_peaks = torch.LongTensor([num_peak])
            else:
                # --- batch of spectra ---
                mz_arrays = [np.around(spec.mz, decimals=4)
                             for spec in ms2_list]
                intensity_arrays = [spec.intensities for spec in ms2_list]
                num_peaks = torch.LongTensor([len(mz) for mz in mz_arrays])

                # Right-pad every spectrum with zeros to the longest one.
                mzs = torch.nn.utils.rnn.pad_sequence(
                    [torch.from_numpy(mz).float() for mz in mz_arrays],
                    batch_first=True, padding_value=0)
                intens = torch.nn.utils.rnn.pad_sequence(
                    [torch.from_numpy(arr).float()
                     for arr in intensity_arrays],
                    batch_first=True, padding_value=0)

            # Both paths share device placement and the encoding pipeline,
            # so single spectra also work when the model is on a GPU.
            return self._embed_spectra(mzs, intens, num_peaks)

    def get_cos_distance(self, input_1, input_2):
        """Return the pairwise dot products ``input_1 @ input_2.T``.

        For the unit-normalised embeddings produced by ``smiles_encode``
        and ``ms2_encode`` this equals cosine *similarity* (larger means
        more similar), despite the method's name.
        """
        with torch.no_grad():
            return input_1 @ input_2.t()