File size: 4,809 Bytes
406eb8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00c1ca5
 
406eb8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 15 16:22:05 2022

@author: ZNDX002
"""
from model import ModelCLR
import yaml
import os
import torch
import numpy as np
import re
from torch_geometric.data import Data, Batch
from dataloader.dataset_wrapper import MolToGraph
from rdkit import Chem

class ModelInference(object):
    """Embed SMILES strings and MS/MS spectra with a pretrained ``ModelCLR``.

    Both ``smiles_encode`` and ``ms2_encode`` L2-normalise their outputs along
    the last dimension, so ``get_cos_distance`` applied to two such outputs
    yields cosine similarities.
    """

    # Length a *single* spectrum is zero-padded to in ``ms2_encode`` (the value
    # used by the original implementation). Spectra with more peaks are padded
    # to their own length instead of making ``np.pad`` fail.
    SINGLE_SPECTRUM_PAD_LENGTH = 300

    def __init__(self, config_path, pretrain_model_path, device=None):
        """Load the model config and the pretrained weights, in eval mode.

        Args:
            config_path: Path to a YAML file whose ``model_config`` mapping is
                passed as keyword arguments to ``ModelCLR``.
            pretrain_model_path: Path to a state dict loadable by ``torch.load``.
            device: Torch device spec (e.g. ``"cpu"``, ``"cuda:0"``). ``None``
                selects CUDA when available, otherwise CPU.

        Raises:
            AssertionError: If ``config_path`` or ``pretrain_model_path`` is None.
        """
        assert config_path is not None, "config_path is None"
        assert pretrain_model_path is not None, "pretrain_model_path is None"

        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Use a context manager so the config file handle is closed.
        with open(config_path, "r") as config_file:
            self.config = yaml.load(config_file, Loader=yaml.FullLoader)
        self.model = ModelCLR(**self.config["model_config"]).to(self.device)
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted
        # files (newer torch versions support weights_only=True).
        state_dict = torch.load(pretrain_model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _project_smiles(self, graph):
        """Run smiles_encoder -> smi_esa -> smi_proj and L2-normalise the result."""
        smiles_tensor = self.model.smiles_encoder(graph)
        smiles_tensor = self.model.smi_esa(smiles_tensor, graph.batch)
        smiles_tensor = self.model.smi_proj(smiles_tensor)
        return smiles_tensor / smiles_tensor.norm(dim=-1, keepdim=True)

    def smiles_encode(self, smiles_str):
        """Encode one SMILES string or an iterable of SMILES strings.

        Args:
            smiles_str: A single SMILES ``str``, or an iterable of SMILES
                strings that are encoded together as one graph ``Batch``.

        Returns:
            Embedding tensor L2-normalised along its last dimension
            (presumably one row per molecule -- confirm against ModelCLR).

        Raises:
            ValueError: If an empty iterable is given.
        """
        with torch.no_grad():
            if isinstance(smiles_str, str):
                # Single molecule: use the graph from MolToGraph directly,
                # including whatever ``batch`` attribute it carries.
                graph = MolToGraph(smiles_str)
            else:
                graphs = [MolToGraph(smi) for smi in smiles_str]
                if not graphs:
                    raise ValueError("smiles_str is empty")
                # Batch.from_data_list concatenates the graphs and sets the
                # ``batch`` (node -> graph index) vector passed to smi_esa.
                graph = Batch.from_data_list(graphs)
            return self._project_smiles(graph.to(self.device))

    def _project_spectra(self, mz, intensities, num_peaks):
        """Move inputs to the device, run ms_encoder -> spec_esa -> spec_proj, L2-normalise."""
        mz = mz.to(self.device)
        intensities = intensities.to(self.device)
        num_peaks = num_peaks.to(self.device)
        spec_tensor, spec_mask = self.model.ms_encoder(mz, intensities, num_peaks)
        spec_tensor = self.model.spec_esa(spec_tensor, spec_mask)
        spec_tensor = self.model.spec_proj(spec_tensor)
        return spec_tensor / spec_tensor.norm(dim=-1, keepdim=True)

    def _single_spectrum_tensors(self, spectrum):
        """Build ``(mz, intensities, num_peaks)`` tensors of batch size 1 for one spectrum."""
        spec_mz = np.around(np.asarray(spectrum.mz), decimals=4)
        spec_intens = np.asarray(spectrum.intensities)
        num_peak = len(spec_mz)
        pad_len = max(self.SINGLE_SPECTRUM_PAD_LENGTH, num_peak)
        spec_mz = np.pad(spec_mz, (0, pad_len - num_peak),
                         mode='constant', constant_values=0)
        spec_intens = np.pad(spec_intens, (0, pad_len - len(spec_intens)),
                             mode='constant', constant_values=0)
        mz = torch.tensor(spec_mz, dtype=torch.float32).unsqueeze(0)
        intensities = torch.tensor(spec_intens, dtype=torch.float32).unsqueeze(0)
        # Peak count as a shape-(1,) LongTensor, matching the batch path.
        # (torch.LongTensor(n) would allocate an *uninitialised* tensor of
        # length n rather than a tensor holding the value n.)
        num_peaks = torch.tensor([num_peak], dtype=torch.long)
        return mz, intensities, num_peaks

    def _batch_spectrum_tensors(self, spectra):
        """Build zero-padded ``(mz, intensities, num_peaks)`` tensors for a list of spectra."""
        mzs = [
            torch.tensor(np.around(np.asarray(spec.mz), decimals=4), dtype=torch.float32)
            for spec in spectra
        ]
        intens = [
            torch.tensor(np.asarray(spec.intensities), dtype=torch.float32)
            for spec in spectra
        ]
        num_peaks = torch.tensor([len(mz) for mz in mzs], dtype=torch.long)
        # Right-pad every spectrum with zeros to the longest one in the batch.
        mzs_tensors = torch.nn.utils.rnn.pad_sequence(
            mzs, batch_first=True, padding_value=0)
        intens_tensors = torch.nn.utils.rnn.pad_sequence(
            intens, batch_first=True, padding_value=0)
        return mzs_tensors, intens_tensors, num_peaks

    def ms2_encode(self, ms2_list):
        """Encode one MS/MS spectrum or a collection of spectra.

        Args:
            ms2_list: Either a single spectrum -- any object exposing ``mz``
                and ``intensities`` array-likes (presumably a matchms
                ``Spectrum``) -- or an iterable of such spectra.

        Returns:
            Embedding tensor L2-normalised along its last dimension. The
            single-spectrum path feeds the model a batch of size 1.

        Raises:
            ValueError: If an empty collection is given.
        """
        with torch.no_grad():
            # Dispatch on the spectrum interface rather than on ``list`` so
            # tuples, generators, etc. of spectra are accepted as batches.
            if hasattr(ms2_list, "mz"):
                tensors = self._single_spectrum_tensors(ms2_list)
            else:
                spectra = list(ms2_list)
                if not spectra:
                    raise ValueError("ms2_list is empty")
                tensors = self._batch_spectrum_tensors(spectra)
            return self._project_spectra(*tensors)

    def get_cos_distance(self, input_1, input_2):
        """Return the pairwise similarity matrix ``input_1 @ input_2.T``.

        Despite the name this is a *similarity* (larger means closer), not a
        distance. For the L2-normalised embeddings returned by
        ``smiles_encode`` / ``ms2_encode`` it equals cosine similarity.
        """
        with torch.no_grad():
            return input_1 @ input_2.t()