# CSU-MS2-T2 / infer.py
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 15 16:22:05 2022
@author: ZNDX002
"""
import yaml
import numpy as np
import torch
from torch_geometric.data import Batch

from model import ModelCLR
from dataloader.dataset_wrapper import MolToGraph


class ModelInference(object):
    """Wraps a trained ModelCLR checkpoint for MS2/SMILES embedding and similarity scoring."""

    def __init__(self, config_path, pretrain_model_path, device):
        assert config_path is not None, "config_path is None"
        assert pretrain_model_path is not None, "pretrain_model_path is None"
        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        with open(config_path, "r") as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.model = ModelCLR(**self.config["model_config"]).to(self.device)
        state_dict = torch.load(pretrain_model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
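
    # Minimal instantiation sketch (the paths below are hypothetical
    # placeholders, not files shipped alongside this script):
    #   mi = ModelInference(config_path="config.yaml",
    #                       pretrain_model_path="checkpoint.pth",
    #                       device=None)  # None auto-selects CUDA, else CPU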

    def smiles_encode(self, smiles_str):
        """Embed a SMILES string (or a list of SMILES strings) into the joint space."""
        with torch.no_grad():
            if isinstance(smiles_str, str):
                # Single SMILES: wrap it in a one-graph batch so `.batch`
                # is populated for the set-attention pooling below.
                v_d = Batch.from_data_list([MolToGraph(smiles_str)])
            else:
                # SMILES list: convert each string to a graph and collate.
                v_d = Batch.from_data_list([MolToGraph(smi) for smi in smiles_str])
            v_d = v_d.to(self.device)
            smiles_tensor = self.model.smiles_encoder(v_d)
            smiles_tensor = self.model.smi_esa(smiles_tensor, v_d.batch)
            smiles_tensor = self.model.smi_proj(smiles_tensor)
            # L2-normalize so downstream dot products are cosine similarities.
            smiles_tensor = smiles_tensor / smiles_tensor.norm(dim=-1, keepdim=True)
            return smiles_tensor
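
    # Example (hypothetical SMILES; each output row is unit-normalized):
    #   emb = mi.smiles_encode(["CCO", "c1ccccc1"])  # shape: (2, proj_dim)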

    def ms2_encode(self, ms2_list):
        """Embed an MS2 spectrum (or a list of spectra) into the joint space.

        Each spectrum object must expose `.mz` and `.intensities` arrays.
        """
        with torch.no_grad():
            if not isinstance(ms2_list, list):
                # Single spectrum: round m/z values and zero-pad the peak
                # arrays to the fixed length of 300 expected by the encoder.
                spec_mz = ms2_list.mz
                spec_intens = ms2_list.intensities
                num_peak = len(spec_mz)
                spec_mz = np.around(spec_mz, decimals=4)
                spec_mz = np.pad(spec_mz, (0, 300 - len(spec_mz)),
                                 mode='constant', constant_values=0)
                spec_intens = np.pad(spec_intens, (0, 300 - len(spec_intens)),
                                     mode='constant', constant_values=0)
                spec_mz = torch.tensor(spec_mz).float().unsqueeze(0).to(self.device)
                spec_intens = torch.tensor(spec_intens).float().unsqueeze(0).to(self.device)
                # torch.LongTensor(num_peak) would allocate an uninitialized
                # tensor of size num_peak; wrap the count in a list instead.
                num_peak = torch.LongTensor([num_peak]).to(self.device)
                spec_tensor, spec_mask = self.model.ms_encoder(spec_mz, spec_intens, num_peak)
                spec_tensor = self.model.spec_esa(spec_tensor, spec_mask)
                spec_tensor = self.model.spec_proj(spec_tensor)
                spec_tensor = spec_tensor / spec_tensor.norm(dim=-1, keepdim=True)
                return spec_tensor
            else:
                # Batch of spectra: pad to the longest peak list in the batch.
                spec_mzs = [spec.mz for spec in ms2_list]
                spec_intens_list = [spec.intensities for spec in ms2_list]
                num_peaks = torch.LongTensor([len(mz) for mz in spec_mzs])
                spec_mzs = [np.around(spec_mz, decimals=4) for spec_mz in spec_mzs]
                mzs = [torch.from_numpy(spec_mz).float() for spec_mz in spec_mzs]
                intens = [torch.from_numpy(si).float() for si in spec_intens_list]
                mzs_tensors = torch.nn.utils.rnn.pad_sequence(
                    mzs, batch_first=True, padding_value=0
                )
                intens_tensors = torch.nn.utils.rnn.pad_sequence(
                    intens, batch_first=True, padding_value=0
                )
                mzs_tensors = mzs_tensors.to(self.device)
                intens_tensors = intens_tensors.to(self.device)
                num_peaks = num_peaks.to(self.device)
                spec_tensor, spec_mask = self.model.ms_encoder(mzs_tensors, intens_tensors, num_peaks)
                spec_tensor = self.model.spec_esa(spec_tensor, spec_mask)
                spec_tensor = self.model.spec_proj(spec_tensor)
                spec_tensor = spec_tensor / spec_tensor.norm(dim=-1, keepdim=True)
                return spec_tensor
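
    # Example (any object exposing `.mz` and `.intensities` numpy arrays
    # works, e.g. a simple namespace or a parsed spectrum record):
    #   emb = mi.ms2_encode([spec_a, spec_b])  # shape: (2, proj_dim)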

    def get_cos_distance(self, input_1, input_2):
        """Pairwise cosine similarity: both inputs are unit-normalized by the
        encoders above, so a plain matrix product gives the cosine of each pair."""
        with torch.no_grad():
            return input_1 @ input_2.t()
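

if __name__ == "__main__":
    # End-to-end sketch of the intended workflow. The config/checkpoint
    # paths are hypothetical placeholders, and the spectrum below is a
    # stand-in object; any object exposing `.mz` and `.intensities`
    # numpy arrays can be passed to ms2_encode.
    from types import SimpleNamespace

    mi = ModelInference(
        config_path="config.yaml",             # hypothetical path
        pretrain_model_path="checkpoint.pth",  # hypothetical path
        device=None,                           # auto-select CUDA, else CPU
    )

    spectrum = SimpleNamespace(
        mz=np.array([81.0699, 109.1012, 151.1117]),
        intensities=np.array([0.12, 0.35, 1.0]),
    )
    candidates = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]

    spec_emb = mi.ms2_encode(spectrum)      # shape: (1, proj_dim)
    smi_emb = mi.smiles_encode(candidates)  # shape: (3, proj_dim)

    # Rows are unit-normalized, so the product is a (1, 3) cosine-similarity
    # matrix; the highest-scoring column is the best-matching candidate.
    scores = mi.get_cos_distance(spec_emb, smi_emb)
    print(scores)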