Songyou committed
Commit f3b11f9 · Parent: 9581086

add files

common/__init__.py ADDED
File without changes
common/utils.py ADDED
@@ -0,0 +1,5 @@
+ from enum import Enum
+
+ class Data_Type(Enum):
+     frag = 'frag'
+     whole = 'whole'
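
For illustration (editor's example), the enum distinguishes fragment-level from whole-molecule data:

    from common.utils import Data_Type

    print(Data_Type.frag.value)    # 'frag'
    print(Data_Type('whole'))      # Data_Type.whole
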
configuration/__init__.py ADDED
File without changes
configuration/config_default.py ADDED
@@ -0,0 +1,18 @@
+ import math
+
+ # Data
+ DATA_DEFAULT = {
+     'max_sequence_length': 256,
+     'padding_value': 0
+ }
+
+ # Properties
+ PROPERTIES = ['pki']
+
+ # For Test_Property test
+ LOD_MIN = 1.0
+ LOD_MAX = 3.4
configuration/opts.py ADDED
@@ -0,0 +1,141 @@
+ """ Implementation of all available options """
+ from __future__ import print_function
+
+
+ def train_opts(parser):
+     # Transformer or Seq2Seq
+     parser.add_argument('--model-choice', required=True, help="transformer or seq2seq")
+     # Common training options
+     group = parser.add_argument_group('Training_options')
+     group.add_argument('--batch-size', type=int, default=512,
+                        help='Batch size for training')
+     group.add_argument('--num-epoch', type=int, default=200,
+                        help='Number of training epochs')
+     group.add_argument('--starting-epoch', type=int, default=1,
+                        help="Train from the given starting epoch")
+     # Input-output settings
+     group = parser.add_argument_group('Input-Output')
+     group.add_argument('--data-path', required=True,
+                        help="Input data path")
+     group.add_argument('--save-directory', default='finetune-TLR7',
+                        help="Result save directory")
+
+     subparsers = parser.add_subparsers()
+     transformer_parser = subparsers.add_parser('transformer')
+     train_opts_transformer(transformer_parser)
+
+     seq2seq_parser = subparsers.add_parser('seq2seq')
+     train_opts_seq2seq(seq2seq_parser)
+
+
+ def train_opts_transformer(parser):
+     # Model architecture options
+     group = parser.add_argument_group('Model')
+     group.add_argument('--vocab-path', required=False, default='',
+                        help="vocab path for finetuning")
+     group.add_argument('--pretrain-path', default='',
+                        help="pretrain directory")
+     group.add_argument('-N', type=int, default=6,
+                        help="number of encoder and decoder layers")
+     group.add_argument('-H', type=int, default=8,
+                        help="number of attention heads")
+     group.add_argument('-d-model', type=int, default=128,
+                        help="embedding dimension, model dimension")
+     group.add_argument('-d-ff', type=int, default=2048,
+                        help="dimension of the feed-forward network")
+     # Regularization
+     group.add_argument('--dropout', type=float, default=0.1,
+                        help="Dropout probability applied between sublayers.")
+     group.add_argument('--label-smoothing', type=float, default=0.0,
+                        help="""Label smoothing value epsilon.
+                        Probabilities of all non-true labels
+                        will be smoothed by epsilon / (vocab_size - 1).
+                        Set to zero to turn off label smoothing.
+                        For more detailed information, see:
+                        https://arxiv.org/abs/1512.00567""")
+     # Optimization options
+     group = parser.add_argument_group('Optimization')
+     group.add_argument('--factor', type=float, default=1.0,
+                        help="""Factor multiplied into the learning rate schedule of NoamOpt.
+                        For more information about the formula, see the paper
+                        Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf""")
+     group.add_argument('--warmup-steps', type=int, default=4000,
+                        help="Number of warmup steps for the custom decay.")
+     group.add_argument('--adam-beta1', type=float, default=0.9,
+                        help="The beta1 parameter for the Adam optimizer")
+     group.add_argument('--adam-beta2', type=float, default=0.98,
+                        help="The beta2 parameter for the Adam optimizer")
+     group.add_argument('--adam-eps', type=float, default=1e-9,
+                        help="The eps parameter for the Adam optimizer")
+
+
+ def train_opts_seq2seq(parser):
+     # Model architecture options
+     group = parser.add_argument_group('Model')
+     group.add_argument("--num-layers", "-l", help="Number of RNN layers of the model",
+                        default=5, type=int)
+     group.add_argument("--layer-size", "-s", help="Size of each of the RNN layers",
+                        default=512, type=int)
+     group.add_argument("--cell-type", "-c",
+                        help="Type of cell used in the RNN [gru, lstm]",
+                        default='lstm', type=str)
+     group.add_argument("--embedding-layer-size", "-e", help="Size of the embedding layer",
+                        default=256, type=int)
+     group.add_argument("--dropout", "-d", help="Amount of dropout between layers",
+                        default=0.3, type=float)
+     group.add_argument("--bidirectional", "--bi", help="Disable the bidirectional encoder (enabled by default)",
+                        action="store_false")
+     group.add_argument("--bidirect-model",
+                        help="Method for initialising the decoder from the encoder hidden state ['concat', 'addition', 'none']",
+                        default='addition', type=str)
+     group.add_argument("--attn-model", help="Attention model ['dot', 'general', 'concat']",
+                        default='dot', type=str)
+     # Optimization options
+     group = parser.add_argument_group('Optimization')
+     group.add_argument('--learning-rate', type=float, default=0.0001,
+                        help="Starting learning rate")
+     group.add_argument("--clip-gradient-norm", help="Clip gradients to a given norm",
+                        default=1.0, type=float)
+
+
+ def generate_opts(parser):
+     # Transformer or Seq2Seq
+     parser.add_argument('--model-choice', required=True, help="transformer or seq2seq")
+     # Input-output settings
+     group = parser.add_argument_group('Input-Output')
+     group.add_argument('--data-path', required=True,
+                        help="Input data path")
+     group.add_argument('--test-file-name', required=True, help="""test file name without .csv,
+                        [test, test_not_in_train, test_unseen_L-1_S01_C10_range]""")
+     group.add_argument('--save-directory', default='evaluation',
+                        help="Result save directory")
+     group.add_argument('--vocab-path', required=False, default='',
+                        help="vocab path for finetuning")
+     # Model to be used for generating molecules
+     group = parser.add_argument_group('Model')
+     group.add_argument('--model-path', help="Model path", required=True)
+     group.add_argument('--epoch', type=int, help="Which epoch to use", required=True)
+     # General
+     group = parser.add_argument_group('General')
+     group.add_argument('--batch-size', type=int, default=64,
+                        help='Batch size for generation')
+     group.add_argument('--num-samples', type=int, default=50,
+                        help='Number of molecules to be generated')
+     group.add_argument('--decode-type', type=str, default='multinomial',
+                        help='decode strategy [greedy, multinomial]')
+     group.add_argument('--dev-no', type=int, default=0,
+                        help='device number to use')
+     group.add_argument('--overwrite', action='store_true',
+                        help='whether to overwrite existing files')
+
+
+ def evaluation_opts(parser):
+     """Evaluation options (compute properties)"""
+     group = parser.add_argument_group('General')
+     group.add_argument('--data-path', required=True,
+                        help="Input data path for generated molecules")
+     group.add_argument('--num-samples', type=int, default=10,
+                        help='Number of molecules generated')
+     group = parser.add_argument_group('Evaluation')
+     group.add_argument('--range-evaluation', default='',
+                        help="[ , lower, higher]; set to 'lower' when evaluating test_unseen_L-1_S01_C10_range")
+     group = parser.add_argument_group('MMP')
+     group.add_argument('--mmpdb-path', help='mmpdb path; download from https://github.com/rdkit/mmpdb')
+     group.add_argument('--train-path', help='Training data path')
+     group.add_argument('--only-desirable', help='Only check generated molecules with desirable properties',
+                        action="store_true")
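
A minimal sketch of how these option builders compose (editor's example; the flag values are placeholders): the top-level parser takes the shared options, and the `transformer`/`seq2seq` subcommand pulls in its model-specific group.

    import argparse
    from configuration import opts

    parser = argparse.ArgumentParser()
    opts.train_opts(parser)
    # shared options first, then the subcommand with its model-specific options
    args = parser.parse_args(['--model-choice', 'transformer',
                              '--data-path', 'data/train.csv',
                              'transformer', '--label-smoothing', '0.1'])
    print(args.batch_size, args.label_smoothing)   # 512 0.1
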
models/__init__.py ADDED
File without changes
models/dataset.py ADDED
@@ -0,0 +1,175 @@
+ # coding=utf-8
+
+ """
+ Implementation of a SMILES dataset.
+ """
+ import pandas as pd
+
+ import torch
+ import torch.utils.data as tud
+ from torch.autograd import Variable
+
+ import configuration.config_default as cfgd
+ from models.transformer.module.subsequent_mask import subsequent_mask
+
+ from rdkit.Chem.SaltRemover import SaltRemover
+ import random
+ import rdkit.Chem as rkc
+ from common.utils import Data_Type
+
+
+ class Dataset(tud.Dataset):
+     """Custom PyTorch Dataset that takes a file containing
+     Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
+     Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
+     Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
+     Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
+     Transformation,Core"""
+
+     def __init__(self, data, vocabulary, tokenizer, prediction_mode=False, use_random=False, data_type=Data_Type.frag.value):
+         """
+         :param data: dataframe read from training, validation or test file
+         :param vocabulary: used to encode source/target tokens
+         :param tokenizer: used to tokenize source/target SMILES
+         :param prediction_mode: whether target SMILES are unavailable (False for training, True for test)
+         """
+         self._vocabulary = vocabulary
+         self._tokenizer = tokenizer
+         self._data = data
+         self._prediction_mode = prediction_mode
+         self._use_random = use_random
+         self._data_type = data_type
+
+     def smiles_preprocess(self, smiles, random_type="unrestricted"):
+         """
+         Returns a random SMILES given a SMILES of a molecule.
+         :param smiles: A SMILES string
+         :param random_type: The type (unrestricted, restricted) of randomization performed.
+         :return: A random SMILES string of the same molecule or None if the molecule is invalid.
+         """
+         if not self._use_random:
+             return smiles
+         mol = rkc.MolFromSmiles(smiles)
+         if not mol:
+             return None
+
+         remover = SaltRemover()  # default salt remover
+         if random_type == "unrestricted":
+             stripped = remover.StripMol(mol)
+             if stripped is None:
+                 return smiles
+             ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
+             if not ret:
+                 return smiles
+             return ret
+         if random_type == "restricted":
+             new_atom_order = list(range(mol.GetNumAtoms()))
+             random.shuffle(new_atom_order)
+             random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
+             ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
+             if not ret:
+                 return smiles
+             return ret
+         raise ValueError("Type '{}' is not valid".format(random_type))
+
+     def __getitem__(self, i):
+         """
+         Tokenize and encode the source SMILES, and the target SMILES as well
+         unless prediction_mode is True.
+         :param i: row index
+         :return: encoded tensor(s) plus the dataframe row
+         """
+         row = self._data.iloc[i]
+         # tokenize and encode source smiles
+         main_cls = row['main_cls']
+         minor_cls = row['minor_cls']
+         target_name = row['target_name']
+         target_name = target_name if isinstance(target_name, str) else ''
+         value = row['Delta_Value']
+         # value = row['Delta_pki']
+         source_tokens = []
+
+         if self._data_type == Data_Type.frag.value:
+             sourceConstant = self.smiles_preprocess(row['constantSMILES'])
+             sourceVariable = self.smiles_preprocess(row['fromVarSMILES'])
+             # the variable part first
+             source_tokens.extend(self._tokenizer.tokenize(sourceVariable))  # add source variable SMILES tokens
+             # then the constant part
+             source_tokens.extend(self._tokenizer.tokenize(sourceConstant))  # add source constant SMILES tokens
+         elif self._data_type == Data_Type.whole.value:
+             sourceSmi = self.smiles_preprocess(row['cpd1SMILES'])
+             source_tokens.extend(self._tokenizer.tokenize(sourceSmi))
+         # then the major class, e.g. activity
+         source_tokens.append(main_cls)
+         # then the minor class, e.g. Ki
+         source_tokens.append(minor_cls)
+         # then the value
+         source_tokens.append(value)
+         # then the target name
+         source_tokens.extend(list(target_name))
+
+         source_encoded = self._vocabulary.encode(source_tokens)
+
+         # print(source_tokens, '\n=====\n', source_encoded)
+         # tokenize and encode target smiles if it is for training instead of evaluation
+         if not self._prediction_mode:
+             target_smi = ''
+             if self._data_type == Data_Type.frag.value:
+                 target_smi = row['toVarSMILES']
+             elif self._data_type == Data_Type.whole.value:
+                 target_smi = row['cpd2SMILES']
+             target_tokens = self._tokenizer.tokenize(target_smi)
+             target_encoded = self._vocabulary.encode(target_tokens)
+
+             return torch.tensor(source_encoded, dtype=torch.long), torch.tensor(
+                 target_encoded, dtype=torch.long), row
+         else:
+             return torch.tensor(source_encoded, dtype=torch.long), row
+
+     def __len__(self):
+         return len(self._data)
+
+     @classmethod
+     def collate_fn(cls, data_all):
+         # sort based on the source sequence's length
+         data_all.sort(key=lambda x: len(x[0]), reverse=True)
+         is_prediction_mode = len(data_all[0]) == 2
+         if is_prediction_mode:
+             source_encoded, data = zip(*data_all)
+         else:
+             source_encoded, target_encoded, data = zip(*data_all)
+         data = pd.DataFrame(data)
+
+         # maximum length of source sequences
+         max_length_source = max([seq.size(0) for seq in source_encoded])
+         # print('=====max len', max_length_source)
+         # source sequences padded with zeroes
+         collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
+         for i, seq in enumerate(source_encoded):
+             collated_arr_source[i, :seq.size(0)] = seq
+         # length of each source sequence
+         source_length = [seq.size(0) for seq in source_encoded]
+         source_length = torch.tensor(source_length)
+         # mask of source seqs
+         src_mask = (collated_arr_source != 0).unsqueeze(-2)
+
+         # target seq
+         if not is_prediction_mode:
+             max_length_target = max([seq.size(0) for seq in target_encoded])
+             collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
+             for i, seq in enumerate(target_encoded):
+                 collated_arr_target[i, :seq.size(0)] = seq
+
+             trg_mask = (collated_arr_target != 0).unsqueeze(-2)
+             trg_mask = trg_mask & Variable(subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask))
+             trg_mask = trg_mask[:, :-1, :-1]  # keep the start token, skip the end token
+         else:
+             trg_mask = None
+             max_length_target = None
+             collated_arr_target = None
+
+         return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data
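
A minimal sketch of wiring this dataset into a DataLoader (editor's example; the single row and its SMILES/property values are placeholders chosen to match the columns the fragment branch reads):

    import pandas as pd
    import torch.utils.data as tud
    from models.dataset import Dataset
    from preprocess.vocabulary import SMILESTokenizer, create_vocabulary

    tokenizer = SMILESTokenizer()
    df = pd.DataFrame([{'constantSMILES': 'c1ccccc1[*:1]', 'fromVarSMILES': 'C[*:1]',
                        'toVarSMILES': 'CC[*:1]', 'main_cls': 'activity',
                        'minor_cls': 'Ki', 'target_name': 'TLR7',
                        'Delta_Value': '(0.5, 1.5]'}])
    vocab = create_vocabulary(['c1ccccc1[*:1]', 'C[*:1]', 'CC[*:1]'], tokenizer,
                              property_condition=['activity', 'Ki', '(0.5, 1.5]',
                                                  'T', 'L', 'R', '7'])
    dataset = Dataset(df, vocab, tokenizer, prediction_mode=False)
    loader = tud.DataLoader(dataset, batch_size=1, collate_fn=Dataset.collate_fn)
    src, src_len, trg, src_mask, trg_mask, max_trg_len, rows = next(iter(loader))
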
models/transformer/__init__.py ADDED
File without changes
models/transformer/encode_decode/__init__.py ADDED
File without changes
models/transformer/encode_decode/clones.py ADDED
@@ -0,0 +1,8 @@
+ import copy
+
+ import torch.nn as nn
+
+
+ def clones(module, N):
+     "Produce N identical layers."
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
models/transformer/encode_decode/decoder.py ADDED
@@ -0,0 +1,23 @@
+ import torch.nn as nn
+
+ from models.transformer.encode_decode.clones import clones
+ from models.transformer.encode_decode.layer_norm import LayerNorm
+
+
+ class Decoder(nn.Module):
+     "Generic N layer decoder with masking."
+
+     def __init__(self, layer, N):
+         super(Decoder, self).__init__()
+         self.layers = clones(layer, N)
+         self.norm = LayerNorm(layer.size)
+
+     def forward(self, x, memory, src_mask, tgt_mask):
+         for layer in self.layers:
+             x = layer(x, memory, src_mask, tgt_mask)
+         return self.norm(x)
models/transformer/encode_decode/decoder_layer.py ADDED
@@ -0,0 +1,28 @@
+ import torch.nn as nn
+
+ from models.transformer.encode_decode.clones import clones
+ from models.transformer.encode_decode.sublayer_connection import SublayerConnection
+
+
+ class DecoderLayer(nn.Module):
+     "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
+
+     def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
+         super(DecoderLayer, self).__init__()
+         self.size = size
+         self.self_attn = self_attn
+         self.src_attn = src_attn
+         self.feed_forward = feed_forward
+         self.sublayer = clones(SublayerConnection(size, dropout), 3)
+
+     def forward(self, x, memory, src_mask, tgt_mask):
+         "Follow Figure 1 (right) for connections."
+         m = memory
+         x = self.sublayer[0](x, lambda x: self.self_attn(
+             x, x, x, tgt_mask))
+         x = self.sublayer[1](x, lambda x: self.src_attn(
+             x, m, m, src_mask))
+         return self.sublayer[2](x, self.feed_forward)
models/transformer/encode_decode/encoder.py ADDED
@@ -0,0 +1,20 @@
+ import torch.nn as nn
+
+ from models.transformer.encode_decode.clones import clones
+ from models.transformer.encode_decode.layer_norm import LayerNorm
+
+
+ class Encoder(nn.Module):
+     "Core encoder is a stack of N layers"
+
+     def __init__(self, layer, N):
+         super(Encoder, self).__init__()
+         self.layers = clones(layer, N)
+         self.norm = LayerNorm(layer.size)
+
+     def forward(self, x, mask):
+         "Pass the input (and mask) through each layer in turn."
+         for layer in self.layers:
+             x = layer(x, mask)
+         return self.norm(x)
models/transformer/encode_decode/encoder_layer.py ADDED
@@ -0,0 +1,20 @@
+ import torch.nn as nn
+
+ from models.transformer.encode_decode.clones import clones
+ from models.transformer.encode_decode.sublayer_connection import SublayerConnection
+
+
+ class EncoderLayer(nn.Module):
+     "Encoder is made up of self-attn and feed forward (defined below)"
+
+     def __init__(self, size, self_attn, feed_forward, dropout):
+         super(EncoderLayer, self).__init__()
+         self.self_attn = self_attn
+         self.feed_forward = feed_forward
+         self.sublayer = clones(SublayerConnection(size, dropout), 2)
+         self.size = size
+
+     def forward(self, x, mask):
+         "Follow Figure 1 (left) for connections."
+         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
+         return self.sublayer[1](x, self.feed_forward)
models/transformer/encode_decode/layer_norm.py ADDED
@@ -0,0 +1,17 @@
+ import torch.nn as nn
+ import torch
+
+
+ class LayerNorm(nn.Module):
+     "Construct a layernorm module (See citation for details)."
+
+     def __init__(self, features, eps=1e-6):
+         super(LayerNorm, self).__init__()
+         self.a_2 = nn.Parameter(torch.ones(features))
+         self.b_2 = nn.Parameter(torch.zeros(features))
+         self.eps = eps
+
+     def forward(self, x):
+         mean = x.mean(-1, keepdim=True)
+         std = x.std(-1, keepdim=True)
+         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
models/transformer/encode_decode/model.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ import copy
+
+ from models.transformer.module.positional_encoding import PositionalEncoding
+ from models.transformer.module.positionwise_feedforward import PositionwiseFeedForward
+ from models.transformer.module.multi_headed_attention import MultiHeadedAttention
+ from models.transformer.module.embeddings import Embeddings
+ from models.transformer.encode_decode.encoder import Encoder
+ from models.transformer.encode_decode.decoder import Decoder
+ from models.transformer.encode_decode.encoder_layer import EncoderLayer
+ from models.transformer.encode_decode.decoder_layer import DecoderLayer
+ from models.transformer.module.generator import Generator
+
+
+ class EncoderDecoder(nn.Module):
+     """
+     A standard Encoder-Decoder architecture.
+     """
+
+     def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
+         super(EncoderDecoder, self).__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.src_embed = src_embed
+         self.tgt_embed = tgt_embed
+         self.generator = generator
+
+     def forward(self, src, tgt, src_mask, tgt_mask):
+         "Take in and process masked src and target sequences."
+         return self.decode(self.encode(src, src_mask), src_mask,
+                            tgt, tgt_mask)
+
+     def encode(self, src, src_mask):
+         return self.encoder(self.src_embed(src), src_mask)
+
+     def decode(self, memory, src_mask, tgt, tgt_mask):
+         return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+
+     @classmethod
+     def make_model(cls, src_vocab, tgt_vocab, N=6,
+                    d_model=256, d_ff=2048, h=8, dropout=0.1):
+         "Helper: Construct a model from hyperparameters."
+         c = copy.deepcopy
+         attn = MultiHeadedAttention(h, d_model)
+         ff = PositionwiseFeedForward(d_model, d_ff, dropout)
+         position = PositionalEncoding(d_model, dropout)
+         model = EncoderDecoder(
+             Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
+             Decoder(DecoderLayer(d_model, c(attn), c(attn),
+                                  c(ff), dropout), N),
+             nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
+             nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
+             Generator(d_model, tgt_vocab))
+
+         # This was important from their code.
+         # Initialize parameters with Glorot / fan_avg.
+         for p in model.parameters():
+             if p.dim() > 1:
+                 nn.init.xavier_uniform_(p)
+
+         return model
+
+     @classmethod
+     def load_from_file(cls, file_path):
+         # Load model
+         checkpoint = torch.load(file_path, map_location='cuda:0')
+         para_dict = checkpoint['model_parameters']
+         vocab_size = para_dict['vocab_size']
+         model = EncoderDecoder.make_model(vocab_size, vocab_size, para_dict['N'],
+                                           para_dict['d_model'], para_dict['d_ff'],
+                                           para_dict['H'], para_dict['dropout'])
+         model.load_state_dict(checkpoint['model_state_dict'])
+         return model
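
A minimal construction sketch (editor's example; the vocabulary size 40 is a placeholder):

    from models.transformer.encode_decode.model import EncoderDecoder

    model = EncoderDecoder.make_model(src_vocab=40, tgt_vocab=40,
                                      N=6, d_model=256, d_ff=2048, h=8, dropout=0.1)
    print(sum(p.numel() for p in model.parameters()))  # parameter count
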
models/transformer/encode_decode/sublayer_connection.py ADDED
@@ -0,0 +1,18 @@
+ import torch.nn as nn
+ from models.transformer.encode_decode.layer_norm import LayerNorm
+
+
+ class SublayerConnection(nn.Module):
+     """
+     A residual connection followed by a layer norm.
+     Note for code simplicity the norm is first as opposed to last.
+     """
+
+     def __init__(self, size, dropout):
+         super(SublayerConnection, self).__init__()
+         self.norm = LayerNorm(size)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, sublayer):
+         "Apply residual connection to any sublayer with the same size."
+         return x + self.dropout(sublayer(self.norm(x)))
models/transformer/module/__init__.py ADDED
File without changes
models/transformer/module/decode.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ from torch.autograd import Variable
+
+ from models.transformer.module.subsequent_mask import subsequent_mask
+
+
+ def decode(model, src, src_mask, max_len, type):
+     # start every sequence with the start token (id 1)
+     ys = torch.ones(1)
+     ys = ys.repeat(src.shape[0], 1).view(src.shape[0], 1).type_as(src.data)
+     # ys shape [batch_size, 1]
+     encoder_outputs = model.encode(src, src_mask)
+     break_condition = torch.zeros(src.shape[0], dtype=torch.bool)
+     for i in range(max_len - 1):
+         with torch.no_grad():
+             out = model.decode(encoder_outputs, src_mask, Variable(ys),
+                                Variable(subsequent_mask(ys.size(1)).type_as(src.data)))
+             log_prob = model.generator(out[:, -1])
+             prob = torch.exp(log_prob)
+
+             if type == 'greedy':
+                 _, next_word = torch.max(prob, dim=1)
+                 ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)  # [batch_size, i]
+             elif type == 'multinomial':
+                 next_word = torch.multinomial(prob, 1)
+                 ys = torch.cat([ys, next_word], dim=1)  # [batch_size, i]
+                 next_word = torch.squeeze(next_word)
+
+             break_condition = (break_condition | (next_word.to('cpu') == 2))
+             if all(break_condition):  # all sequences emitted the end token (id 2)
+                 break
+
+     return ys
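
A minimal sampling sketch (editor's example; the tiny model and random source tokens are placeholders):

    import torch
    from models.transformer.encode_decode.model import EncoderDecoder
    from models.transformer.module.decode import decode

    model = EncoderDecoder.make_model(40, 40, N=2, d_model=64, d_ff=256, h=4)
    src = torch.randint(1, 40, (2, 10))   # placeholder token ids, no padding
    src_mask = (src != 0).unsqueeze(-2)
    ys = decode(model, src, src_mask, max_len=20, type='multinomial')
    print(ys.shape)                        # [2, <=20] token ids
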
models/transformer/module/embeddings.py ADDED
@@ -0,0 +1,13 @@
+ import torch.nn as nn
+ import math
+
+
+ class Embeddings(nn.Module):
+     def __init__(self, d_model, vocab):
+         super(Embeddings, self).__init__()
+         # weight matrix; each row represents one token
+         self.lut = nn.Embedding(vocab, d_model)
+         self.d_model = d_model
+
+     def forward(self, x):
+         return self.lut(x) * math.sqrt(self.d_model)
models/transformer/module/generator.py ADDED
@@ -0,0 +1,13 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Generator(nn.Module):
+     "Define standard linear + softmax generation step."
+
+     def __init__(self, d_model, vocab):
+         super(Generator, self).__init__()
+         self.proj = nn.Linear(d_model, vocab)
+
+     def forward(self, x):
+         return F.log_softmax(self.proj(x), dim=-1)
models/transformer/module/label_smoothing.py ADDED
@@ -0,0 +1,29 @@
+ import torch.nn as nn
+ import torch
+ from torch.autograd import Variable
+
+
+ class LabelSmoothing(nn.Module):
+     "Implement label smoothing."
+
+     def __init__(self, size, padding_idx, smoothing=0.0):
+         super(LabelSmoothing, self).__init__()
+         self.criterion = nn.KLDivLoss(reduction='sum')
+         self.padding_idx = padding_idx
+         self.confidence = 1.0 - smoothing
+         self.smoothing = smoothing
+         self.size = size
+         self.true_dist = None
+
+     def forward(self, x, target):
+         assert x.size(1) == self.size
+         true_dist = x.data.clone()
+         true_dist.fill_(self.smoothing / (self.size - 2))
+         true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
+         true_dist[:, self.padding_idx] = 0
+         mask = torch.nonzero(target.data == self.padding_idx)
+
+         if mask.dim() > 0:
+             true_dist.index_fill_(0, mask.squeeze(), 0.0)
+         self.true_dist = true_dist
+         return self.criterion(x, Variable(true_dist, requires_grad=False))
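
A minimal sketch of the criterion on toy log-probabilities (editor's example; the criterion expects log-probs, matching the Generator's log_softmax output):

    import torch
    from models.transformer.module.label_smoothing import LabelSmoothing

    crit = LabelSmoothing(size=5, padding_idx=0, smoothing=0.1)
    log_probs = torch.log_softmax(torch.randn(3, 5), dim=-1)  # model output
    target = torch.tensor([2, 4, 1])                          # 0 is padding
    loss = crit(log_probs, target)
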
models/transformer/module/multi_headed_attention.py ADDED
@@ -0,0 +1,55 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from models.transformer.encode_decode.clones import clones
+
+
+ def attention(query, key, value, mask=None, dropout=None):
+     "Compute 'Scaled Dot Product Attention'"
+     d_k = query.size(-1)
+     scores = torch.matmul(query, key.transpose(-2, -1)) \
+              / math.sqrt(d_k)
+     if mask is not None:
+         scores = scores.masked_fill(mask == 0, -1e9)
+     p_attn = F.softmax(scores, dim=-1)
+     if dropout is not None:
+         p_attn = dropout(p_attn)
+     return torch.matmul(p_attn, value), p_attn
+
+
+ class MultiHeadedAttention(nn.Module):
+     def __init__(self, h, d_model, dropout=0.1):
+         "Take in model size and number of heads."
+         super(MultiHeadedAttention, self).__init__()
+         assert d_model % h == 0
+         # We assume d_v always equals d_k
+         self.d_k = d_model // h
+         self.h = h
+         self.linears = clones(nn.Linear(d_model, d_model), 4)
+         self.attn = None
+         self.dropout = nn.Dropout(p=dropout)
+
+     def forward(self, query, key, value, mask=None):
+         "Implements Figure 2"
+         if mask is not None:
+             # Same mask applied to all h heads.
+             mask = mask.unsqueeze(1)
+         nbatches = query.size(0)
+
+         # 1) Do all the linear projections in batch from d_model => h x d_k
+         query, key, value = \
+             [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
+              for l, x in zip(self.linears, (query, key, value))]
+
+         # 2) Apply attention on all the projected vectors in batch.
+         x, self.attn = attention(query, key, value, mask=mask,
+                                  dropout=self.dropout)
+
+         # 3) "Concat" using a view and apply a final linear.
+         x = x.transpose(1, 2).contiguous() \
+              .view(nbatches, -1, self.h * self.d_k)
+
+         return self.linears[-1](x)
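
A minimal shape check (editor's example with placeholder sizes):

    import torch
    from models.transformer.module.multi_headed_attention import MultiHeadedAttention

    attn = MultiHeadedAttention(h=8, d_model=128)
    x = torch.randn(2, 10, 128)                    # [batch, seq_len, d_model]
    mask = torch.ones(2, 1, 10, dtype=torch.bool)  # no positions masked
    out = attn(x, x, x, mask=mask)                 # self-attention
    print(out.shape)                               # torch.Size([2, 10, 128])
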
models/transformer/module/noam_opt.py ADDED
@@ -0,0 +1,47 @@
+ class NoamOpt:
+     "Optim wrapper that implements rate."
+
+     def __init__(self, model_size, factor, warmup, optimizer):
+         self.optimizer = optimizer
+         self._step = 0
+         self.warmup = warmup
+         self.factor = factor
+         self.model_size = model_size
+         self._rate = 0
+
+     def step(self):
+         "Update parameters and rate"
+         self._step += 1
+         rate = self.rate()
+         for p in self.optimizer.param_groups:
+             p['lr'] = rate
+         self._rate = rate
+         self.optimizer.step()
+
+     def rate(self, step=None):
+         "Implement `lrate` above"
+         if step is None:
+             step = self._step
+         return self.factor * \
+             (self.model_size ** (-0.5) *
+              min(step ** (-0.5), step * self.warmup ** (-1.5)))
+
+     def save_state_dict(self):
+         return {
+             'inner_optimizer_state_dict': self.optimizer.state_dict(),
+             'step': self._step,
+             'warmup': self.warmup,
+             'factor': self.factor,
+             'model_size': self.model_size,
+             'rate': self._rate
+         }
+
+     def load_state_dict(self, state_dict):
+         self._rate = state_dict['rate']
+         self._step = state_dict['step']
+         self.optimizer.load_state_dict(state_dict['inner_optimizer_state_dict'])
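
A minimal sketch of the schedule (editor's example; the linear layer is a stand-in for the transformer):

    import torch
    import torch.nn as nn
    from models.transformer.module.noam_opt import NoamOpt

    model = nn.Linear(256, 256)  # stand-in module
    opt = NoamOpt(model_size=256, factor=1.0, warmup=4000,
                  optimizer=torch.optim.Adam(model.parameters(), lr=0,
                                             betas=(0.9, 0.98), eps=1e-9))
    print(opt.rate(step=1), opt.rate(step=4000))  # rate warms up, then decays
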
models/transformer/module/positional_encoding.py ADDED
@@ -0,0 +1,27 @@
+ import torch.nn as nn
+ import torch
+ import math
+ from torch.autograd import Variable
+
+
+ class PositionalEncoding(nn.Module):
+     "Implement the PE function."
+
+     def __init__(self, d_model, dropout, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         # Compute the positional encodings once in log space.
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0., d_model, 2) *
+                              -(math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + Variable(self.pe[:, :x.size(1)],
+                          requires_grad=False)
+         return self.dropout(x)
models/transformer/module/positionwise_feedforward.py ADDED
@@ -0,0 +1,15 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class PositionwiseFeedForward(nn.Module):
+     "Implements FFN equation."
+
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super(PositionwiseFeedForward, self).__init__()
+         self.w_1 = nn.Linear(d_model, d_ff)
+         self.w_2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.w_2(self.dropout(F.relu(self.w_1(x))))
models/transformer/module/simpleloss_compute.py ADDED
@@ -0,0 +1,23 @@
+ class SimpleLossCompute:
+     "A simple loss compute and train function."
+
+     def __init__(self, generator, loss_function, opt):
+         self.generator = generator
+         self.loss_function = loss_function
+         self.opt = opt
+
+     def __call__(self, x, y, norm):
+         x = self.generator(x)
+
+         loss = self.loss_function(x.contiguous().view(-1, x.size(-1)),
+                                   y.contiguous().view(-1)) / norm
+
+         if self.opt is not None:
+             loss.backward()
+             self.opt.step()
+             self.opt.optimizer.zero_grad()
+         # print("loss from simplelosscompute:", loss)
+         # print("norm from simplelosscompute:", norm)
+         return loss.data * norm
models/transformer/module/subsequent_mask.py ADDED
@@ -0,0 +1,10 @@
+ import numpy as np
+ import torch
+
+
+ def subsequent_mask(size):
+     "Mask out subsequent positions."
+     attn_shape = (1, size, size)
+     subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
+     return torch.from_numpy(subsequent_mask) == 0
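
For illustration (editor's example), the mask is lower-triangular, so position i may only attend to positions <= i:

    from models.transformer.module.subsequent_mask import subsequent_mask

    print(subsequent_mask(3))
    # tensor([[[ True, False, False],
    #          [ True,  True, False],
    #          [ True,  True,  True]]])
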
preprocess/__init__.py ADDED
File without changes
preprocess/data_preparation.py ADDED
@@ -0,0 +1,86 @@
+ import pandas as pd
+ import os
+
+ from sklearn.model_selection import train_test_split
+
+ import utils.file as uf
+ import configuration.config_default as cfgd
+ import preprocess.property_change_encoder as pce
+
+ SEED = 42
+ # SPLIT_RATIO = 0.8
+
+
+ def get_smiles_list(file_name):
+     """
+     Get the SMILES list for building the vocabulary
+     :param file_name:
+     :return:
+     """
+     pd_data = pd.read_csv(file_name, sep=",")
+
+     print("Read %s file" % file_name)
+     # ravel('K') flattens the 2-D array into a 1-D array
+     smiles_list = pd.unique(pd_data[['constantSMILES', 'fromVarSMILES', 'toVarSMILES']].values.ravel('K'))
+     print("Number of SMILES in chemical transformations: %d" % len(smiles_list))
+
+     return smiles_list
+
+
+ def split_data(input_transformations_path, SPLIT_RATIO, LOG=None):
+     """
+     Split data into training, validation and test sets, and write them to files
+     :param input_transformations_path:
+     :param SPLIT_RATIO: fraction of the data kept for training
+     :return: dataframes
+     """
+     data = pd.read_csv(input_transformations_path, sep=",")
+     if LOG:
+         LOG.info("Read %s file" % input_transformations_path)
+
+     train, test = train_test_split(
+         data, test_size=(1 - SPLIT_RATIO) / 2, random_state=SEED)
+     train, validation = train_test_split(train, test_size=(1 - SPLIT_RATIO) / 2, random_state=SEED)
+     if LOG:
+         LOG.info("Train, Validation, Test: %d, %d, %d" % (len(train), len(validation), len(test)))
+
+     parent = uf.get_parent_dir(input_transformations_path)
+     train.to_csv(os.path.join(parent, "train.csv"), index=False)
+     validation.to_csv(os.path.join(parent, "validation.csv"), index=False)
+     test.to_csv(os.path.join(parent, "test.csv"), index=False)
+
+     return train, validation, test
+
+
+ def save_df_property_encoded(file_name, property_change_encoder, LOG=None):
+     data = pd.read_csv(file_name, sep=",")
+     for property_name in cfgd.PROPERTIES:
+         if property_name in ('pki', 'qed', 'sa'):
+             encoder, start_map_interval = property_change_encoder[property_name]
+             data['Delta_{}'.format(property_name)] = \
+                 data['Delta_{}'.format(property_name)].apply(
+                     lambda x: pce.value_in_interval(x, start_map_interval))
+
+     output_file = file_name.split('.csv')[0] + '_encoded.csv'
+     if LOG:
+         LOG.info("Saving encoded property change to file: {}".format(output_file))
+     data.to_csv(output_file, index=False)
+     return output_file
+
+
+ def prop_change(source, target, threshold):
+     if source <= threshold and target > threshold:
+         return "low->high"
+     elif source > threshold and target <= threshold:
+         return "high->low"
+     else:
+         return "no_change"
preprocess/property_change_encoder.py ADDED
@@ -0,0 +1,73 @@
+ import numpy as np
+ import pandas as pd
+
+ import configuration.config_default as cfgd
+
+ STEP_pki = 1
+
+
+ def encode_property_change(input_data_path, LOG=None):
+     property_change_encoder = {}
+     for property_name in cfgd.PROPERTIES:
+         if property_name == 'pki':
+             # intervals is a list of the form ['(3, 4]', ...]
+             # start_map_interval maps <interval start value, interval string> for lookups
+             intervals, start_map_interval = build_intervals(input_data_path, step=STEP_pki, LOG=LOG)
+             property_change_encoder[property_name] = intervals, start_map_interval
+
+     return property_change_encoder
+
+
+ def value_in_interval(value, start_map_interval):
+     start_vals = sorted(list(start_map_interval.keys()))
+     return start_map_interval[start_vals[np.searchsorted(start_vals, value, side='right') - 1]]
+
+
+ def interval_to_onehot(interval, encoder):
+     return encoder.transform([interval]).toarray()[0]
+
+
+ def build_intervals(input_transformations_path, step=STEP_pki, LOG=None):
+     df = pd.read_csv(input_transformations_path)
+     delta_pki = df['Delta_pki'].tolist()
+     min_val, max_val = min(delta_pki), max(delta_pki)
+     if LOG:
+         LOG.info("pki min and max: {}, {}".format(min_val, max_val))
+
+     start_map_interval = {}
+     interval_str = '({}, {}]'.format(round(-step / 2, 2), round(step / 2, 2))
+     intervals = [interval_str]
+     start_map_interval[-step / 2] = interval_str
+
+     # positive side: fine-grained steps up to 2, then coarser steps of 2 up to 10
+     smallStep = step
+     bigStep = 2
+     positives = step / 2
+     while positives < 10:
+         if positives > 2:
+             step = bigStep
+         interval_str = '({}, {}]'.format(round(positives, 2), round(positives + step, 2))
+         intervals.append(interval_str)
+         start_map_interval[positives] = interval_str
+         positives += step
+
+     interval_str = '({}, inf]'.format(round(positives, 2))
+     intervals.append(interval_str)
+     start_map_interval[positives] = interval_str
+
+     # negative side: fine-grained steps down to the minimum observed value
+     step = smallStep
+     negatives = -step / 2
+     while negatives > min_val:
+         interval_str = '({}, {}]'.format(round(negatives - step, 2), round(negatives, 2))
+         intervals.append(interval_str)
+         negatives -= step
+         start_map_interval[negatives] = interval_str
+     interval_str = '(-inf, {}]'.format(round(negatives, 2))
+     intervals.append(interval_str)
+     start_map_interval[float('-inf')] = interval_str
+
+     return intervals, start_map_interval
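
A minimal sketch of the interval lookup (editor's example with a hand-built map; build_intervals would normally produce it from the training CSV):

    from preprocess.property_change_encoder import value_in_interval

    start_map = {float('-inf'): '(-inf, -0.5]',
                 -0.5: '(-0.5, 0.5]',
                 0.5: '(0.5, 2.5]'}
    print(value_in_interval(0.3, start_map))   # (-0.5, 0.5]
    print(value_in_interval(1.7, start_map))   # (0.5, 2.5]
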
preprocess/vocabulary.py ADDED
@@ -0,0 +1,145 @@
+ # coding=utf-8
+
+ """
+ Vocabulary helper class
+ """
+
+ import re
+ import numpy as np
+
+
+ class Vocabulary:
+     """Stores the tokens and their conversion to indices."""
+
+     def __init__(self, tokens=None, starting_id=0):
+         self._tokens = {}
+         self._current_id = starting_id
+
+         if tokens:
+             for token, idx in tokens.items():
+                 self._add(token, idx)
+                 self._current_id = max(self._current_id, idx + 1)
+
+     def __getitem__(self, token_or_id):
+         return self._tokens[token_or_id]
+
+     def add(self, token):
+         """Adds a token."""
+         if not isinstance(token, str):
+             raise TypeError("Token is not a string")
+         if token in self:
+             # raise ValueError("Token already present in the vocabulary")
+             print(f'=== Token "{token}" already present in the vocabulary')
+             return self._tokens[token]
+         self._add(token, self._current_id)
+         self._current_id += 1
+         return self._current_id - 1
+
+     def update(self, tokens):
+         """Adds many tokens."""
+         return [self.add(token) for token in tokens]
+
+     def __delitem__(self, token_or_id):
+         other_val = self._tokens[token_or_id]
+         del self._tokens[other_val]
+         del self._tokens[token_or_id]
+
+     def __contains__(self, token_or_id):
+         return token_or_id in self._tokens
+
+     def __eq__(self, other_vocabulary):
+         return self._tokens == other_vocabulary._tokens
+
+     def __len__(self):
+         # each entry is stored both as token->id and id->token
+         return len(self._tokens) // 2
+
+     def encode(self, tokens):
+         """Encodes a list of tokens as a vector of token indices."""
+         ohe_vect = np.zeros(len(tokens), dtype=np.float32)
+         for i, token in enumerate(tokens):
+             try:
+                 ohe_vect[i] = self._tokens[token]
+             except KeyError:
+                 # unknown tokens fall back to "default_key" (must be in the vocabulary)
+                 ohe_vect[i] = self._tokens["default_key"]
+         return ohe_vect
+
+     def decode(self, ohe_vect):
+         """Decodes a vector of token indices back to a list of tokens."""
+         tokens = []
+         for ohv in ohe_vect:
+             try:
+                 tokens.append(self[ohv])
+             except KeyError:
+                 tokens.append("default_key")
+         return tokens
+
+     def _add(self, token, idx):
+         if idx not in self._tokens:
+             self._tokens[token] = idx
+             self._tokens[idx] = token
+         else:
+             raise ValueError("IDX already present in vocabulary")
+
+     def tokens(self):
+         """Returns the tokens from the vocabulary"""
+         return [t for t in self._tokens if isinstance(t, str)]
+
+     def word2idx(self):
+         return {k: self._tokens[k] for k in self._tokens if isinstance(k, str)}
+
+
+ class SMILESTokenizer:
+     """Deals with the tokenization and untokenization of SMILES."""
+
+     REGEXPS = {
+         "brackets": re.compile(r"(\[[^\]]*\])"),
+         "2_ring_nums": re.compile(r"(%\d{2})"),
+         "brcl": re.compile(r"(Br|Cl)")
+     }
+     REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]
+
+     def tokenize(self, data, with_begin_and_end=True):
+         """Tokenizes a SMILES string."""
+         def split_by(data, regexps):
+             if not regexps:
+                 return list(data)
+             regexp = self.REGEXPS[regexps[0]]
+             splitted = regexp.split(data)
+             tokens = []
+             for i, split in enumerate(splitted):
+                 if i % 2 == 0:
+                     tokens += split_by(split, regexps[1:])
+                 else:
+                     tokens.append(split)
+             return tokens
+
+         tokens = split_by(data, self.REGEXP_ORDER)
+         if with_begin_and_end:
+             tokens = ["^"] + tokens + ["$"]
+         return tokens
+
+     def untokenize(self, tokens):
+         """Untokenizes a SMILES string."""
+         smi = ""
+         for token in tokens:
+             if token == "$":
+                 break
+             if token != "^":
+                 smi += token
+         return smi
+
+
+ def create_vocabulary(smiles_list, tokenizer, property_condition=None):
+     """Creates a vocabulary for the SMILES syntax."""
+     tokens = set()
+     for smi in smiles_list:
+         tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False))
+
+     vocabulary = Vocabulary()
+     vocabulary.update(["*", "^", "$"] + sorted(tokens))  # pad=0, start=1, end=2
+     if property_condition is not None:
+         vocabulary.update(property_condition)
+     # for random smiles
+     if "8" not in vocabulary.tokens():
+         vocabulary.update(["8"])
+
+     return vocabulary
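
A minimal round-trip sketch (editor's example; the two SMILES strings are placeholders):

    from preprocess.vocabulary import SMILESTokenizer, create_vocabulary

    tokenizer = SMILESTokenizer()
    vocab = create_vocabulary(['CCO', 'c1ccccc1Br'], tokenizer)
    tokens = tokenizer.tokenize('CCO')                  # ['^', 'C', 'C', 'O', '$']
    encoded = vocab.encode(tokens)
    print(tokenizer.untokenize(vocab.decode(encoded)))  # CCO
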