Tingxie committed on
Commit
406eb8d
·
1 Parent(s): aa1762f

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +106 -105
infer.py CHANGED
@@ -1,105 +1,106 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Thu Sep 15 16:22:05 2022
4
-
5
- @author: ZNDX002
6
- """
7
- from model import ModelCLR
8
- import yaml
9
- import os
10
- import torch
11
- import numpy as np
12
- import re
13
- from torch_geometric.data import Data, Batch
14
- from dataloader.dataset_wrapper import MolToGraph
15
- from rdkit import Chem
16
-
17
- class ModelInference(object):
18
- def __init__(self, config_path, pretrain_model_path, device):
19
- assert (config_path is not None, "config_path is None")
20
- assert (pretrain_model_path is not None, "pretrain_model_path is None")
21
-
22
- if device is None:
23
- self.device = torch.device(
24
- "cuda" if torch.cuda.is_available() else "cpu")
25
- else:
26
- self.device = torch.device(device)
27
-
28
- self.config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
29
- self.model = ModelCLR(**self.config["model_config"]).to(self.device)
30
- state_dict = torch.load(pretrain_model_path)
31
- self.model.load_state_dict(state_dict)
32
- self.model.eval()
33
-
34
-
35
- def smiles_encode(self, smiles_str):
36
- with torch.no_grad():
37
- if isinstance(smiles_str, str):
38
- #single smiles
39
- v_d = MolToGraph(smiles_str)
40
- v_d = v_d.to(self.device)
41
- smiles_tensor = self.model.smiles_encoder(v_d)
42
- smiles_tensor=self.model.smi_esa(smiles_tensor,v_d.batch)
43
- smiles_tensor = self.model.smi_proj(smiles_tensor)
44
- smiles_tensor = smiles_tensor/smiles_tensor.norm(dim=-1, keepdim=True)
45
- return smiles_tensor
46
- else:
47
- #smiles list
48
- graphs=[]
49
- for smi in smiles_str:
50
- v_d = MolToGraph(smi)
51
- graphs.append(v_d)
52
- v_ds = Batch.from_data_list(graphs)
53
- v_ds = v_ds.to(self.device)
54
- smiles_tensor = self.model.smiles_encoder(v_ds)
55
- smiles_tensor=self.model.smi_esa(smiles_tensor,v_ds.batch)
56
- smiles_tensor = self.model.smi_proj(smiles_tensor)
57
- smiles_tensor = smiles_tensor/smiles_tensor.norm(dim=-1, keepdim=True)
58
- return smiles_tensor
59
-
60
- def ms2_encode(self, ms2_list):
61
- with torch.no_grad():
62
- if not isinstance(ms2_list, list):
63
- #single ms2
64
- spec_mz = ms2_list.mz
65
- spec_intens = ms2_list.intensities
66
- num_peak = len(spec_mz)
67
- spec_mz = np.around(spec_mz, decimals=4)
68
- spec_mz = np.pad(spec_mz, (0, 300 - len(spec_mz)), mode='constant', constant_values=0)
69
- spec_intens = np.pad(spec_intens, (0, 300 - len(spec_intens)), mode='constant', constant_values=0)
70
- spec_mz= torch.tensor(spec_mz).float().unsqueeze(0)
71
- spec_intens= torch.tensor(spec_intens).float().unsqueeze(0)
72
- num_peak = torch.LongTensor(num_peak).unsqueeze(0)
73
- spec_tensor,spec_mask = self.model.ms_encoder(spec_mz,spec_intens,num_peak)
74
- spec_tensor=self.model.spec_esa(spec_tensor,spec_mask)
75
- spec_tensor = self.model.spec_proj(spec_tensor)
76
- spec_tensor = spec_tensor/spec_tensor.norm(dim=-1, keepdim=True)
77
- return spec_tensor
78
- else:
79
- # batch ms2
80
- spec_mzs = [spec.mz for spec in ms2_list]
81
- spec_intens = [spec.intensities for spec in ms2_list]
82
- num_peaks = [len(i) for i in spec_mzs]
83
- spec_mzs = [np.around(spec_mz, decimals=4) for spec_mz in spec_mzs]
84
- num_peaks = torch.LongTensor(num_peaks)
85
- mzs = [torch.from_numpy(spec_mz).float() for spec_mz in spec_mzs]
86
- intens = [torch.from_numpy(spec_intens).float() for spec_intens in spec_intens]
87
- mzs_tensors = torch.nn.utils.rnn.pad_sequence(
88
- mzs, batch_first=True, padding_value=0
89
- )
90
- intens_tensors = torch.nn.utils.rnn.pad_sequence(
91
- intens, batch_first=True, padding_value=0
92
- )
93
- mzs_tensors=mzs_tensors.to(self.device)
94
- intens_tensors=intens_tensors.to(self.device)
95
- num_peaks=num_peaks.to(self.device)
96
-
97
- spec_tensor,spec_mask = self.model.ms_encoder(mzs_tensors,intens_tensors,num_peaks)
98
- spec_tensor=self.model.spec_esa(spec_tensor,spec_mask)
99
- spec_tensor = self.model.spec_proj(spec_tensor)
100
- spec_tensor = spec_tensor/spec_tensor.norm(dim=-1, keepdim=True)
101
- return spec_tensor
102
-
103
- def get_cos_distance(self, input_1, input_2):
104
- with torch.no_grad():
105
- return input_1 @ input_2.t()
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Sep 15 16:22:05 2022
4
+
5
+ @author: ZNDX002
6
+ """
7
+ from model import ModelCLR
8
+ import yaml
9
+ import os
10
+ import torch
11
+ import numpy as np
12
+ import re
13
+ from torch_geometric.data import Data, Batch
14
+ from dataloader.dataset_wrapper import MolToGraph
15
+ from rdkit import Chem
16
+
17
class ModelInference(object):
    """Inference wrapper around a pretrained ``ModelCLR`` checkpoint.

    Provides helpers to embed SMILES strings and MS/MS spectra with the
    model's two encoder branches and to compare the resulting embeddings.
    """

    # Length a *single* spectrum is zero-padded to before encoding. The batch
    # path pads to the longest spectrum in the batch instead.
    # NOTE(review): 300 presumably matches the peak limit used in training --
    # confirm against the model configuration.
    _SINGLE_SPECTRUM_PAD_LEN = 300

    def __init__(self, config_path, pretrain_model_path, device=None):
        """Build the model from a YAML config and load pretrained weights.

        Args:
            config_path: Path to a YAML file containing a ``model_config``
                mapping that is passed to ``ModelCLR`` as keyword arguments.
            pretrain_model_path: Path to a state dict saved with
                ``torch.save``.
            device: Anything accepted by ``torch.device``. When ``None``,
                CUDA is used if available, otherwise the CPU.

        Raises:
            ValueError: If ``config_path`` or ``pretrain_model_path`` is None.
        """
        # Explicit checks: `assert (cond, "msg")` asserts a non-empty tuple,
        # which is always true, and asserts are stripped under `python -O`.
        if config_path is None:
            raise ValueError("config_path is None")
        if pretrain_model_path is None:
            raise ValueError("pretrain_model_path is None")

        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Context manager so the config file handle is closed deterministically.
        with open(config_path, "r") as config_file:
            self.config = yaml.load(config_file, Loader=yaml.FullLoader)

        self.model = ModelCLR(**self.config["model_config"])
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets a checkpoint saved on GPU be loaded on CPU.
        state_dict = torch.load(pretrain_model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Use the resolved device; the raw `device` argument may be None.
        self.model.to(self.device)
        self.model.eval()

    def _embed_graph(self, graph):
        """Run the SMILES branch on a graph already placed on ``self.device``."""
        features = self.model.smiles_encoder(graph)
        features = self.model.smi_esa(features, graph.batch)
        features = self.model.smi_proj(features)
        # L2-normalise along the last dimension so dot products are cosines.
        return features / features.norm(dim=-1, keepdim=True)

    def _embed_spectra(self, mzs, intensities, num_peaks):
        """Run the spectrum branch on padded m/z and intensity tensors."""
        mzs = mzs.to(self.device)
        intensities = intensities.to(self.device)
        num_peaks = num_peaks.to(self.device)
        features, mask = self.model.ms_encoder(mzs, intensities, num_peaks)
        features = self.model.spec_esa(features, mask)
        features = self.model.spec_proj(features)
        # L2-normalise along the last dimension so dot products are cosines.
        return features / features.norm(dim=-1, keepdim=True)

    def smiles_encode(self, smiles_str):
        """Embed one SMILES string or an iterable of SMILES strings.

        Args:
            smiles_str: A single SMILES ``str``, or an iterable of them. An
                iterable is collated with ``Batch.from_data_list``.

        Returns:
            The projected embedding tensor, L2-normalised along its last
            dimension.
        """
        with torch.no_grad():
            if isinstance(smiles_str, str):
                # Single molecule: passed through as returned by MolToGraph.
                graph = MolToGraph(smiles_str)
            else:
                graph = Batch.from_data_list(
                    [MolToGraph(smi) for smi in smiles_str])
            return self._embed_graph(graph.to(self.device))

    def ms2_encode(self, ms2_list):
        """Embed one spectrum or a list of spectra.

        Args:
            ms2_list: Either a ``list`` of spectrum objects or a single
                spectrum object. Each spectrum must expose ``mz`` and
                ``intensities`` sequences of equal length. Any non-``list``
                argument is treated as a single spectrum.

        Returns:
            The projected embedding tensor, L2-normalised along its last
            dimension.

        Raises:
            ValueError: If a single spectrum has more peaks than the fixed
                padding length.
        """
        with torch.no_grad():
            if not isinstance(ms2_list, list):
                # Single spectrum: zero-pad to the fixed length.
                pad_len = self._SINGLE_SPECTRUM_PAD_LEN
                mz = np.asarray(ms2_list.mz)
                intensities = np.asarray(ms2_list.intensities)
                peak_count = len(mz)
                if peak_count > pad_len or len(intensities) > pad_len:
                    raise ValueError(
                        "spectrum has {} peaks; at most {} are supported for "
                        "single-spectrum encoding".format(
                            max(peak_count, len(intensities)), pad_len))
                mz = np.around(mz, decimals=4)
                mz = np.pad(mz, (0, pad_len - len(mz)),
                            mode='constant', constant_values=0)
                intensities = np.pad(
                    intensities, (0, pad_len - len(intensities)),
                    mode='constant', constant_values=0)
                mz_tensor = torch.from_numpy(mz).float().unsqueeze(0)
                intens_tensor = torch.from_numpy(
                    intensities).float().unsqueeze(0)
                # torch.LongTensor(int) would allocate an *uninitialised*
                # tensor of that length; build a tensor holding the count,
                # shaped (1,) like the batch path's (batch,) vector.
                num_peaks = torch.tensor([peak_count], dtype=torch.long)
                return self._embed_spectra(mz_tensor, intens_tensor, num_peaks)

            # Batch of spectra: pad each to the longest spectrum in the batch.
            mz_arrays = [np.asarray(spec.mz) for spec in ms2_list]
            intens_arrays = [np.asarray(spec.intensities) for spec in ms2_list]
            num_peaks = torch.tensor(
                [len(mz) for mz in mz_arrays], dtype=torch.long)
            mz_tensors = [
                torch.from_numpy(np.around(mz, decimals=4)).float()
                for mz in mz_arrays
            ]
            intens_tensors = [
                torch.from_numpy(intens).float() for intens in intens_arrays
            ]
            padded_mzs = torch.nn.utils.rnn.pad_sequence(
                mz_tensors, batch_first=True, padding_value=0
            )
            padded_intens = torch.nn.utils.rnn.pad_sequence(
                intens_tensors, batch_first=True, padding_value=0
            )
            return self._embed_spectra(padded_mzs, padded_intens, num_peaks)

    def get_cos_distance(self, input_1, input_2):
        """Return the pairwise dot products ``input_1 @ input_2.t()``.

        For the L2-normalised embeddings produced by ``smiles_encode`` and
        ``ms2_encode`` this is the cosine *similarity* matrix (larger means
        more similar), despite the method's name.
        """
        with torch.no_grad():
            return input_1 @ input_2.t()