Spaces:
Sleeping
Sleeping
add files
Browse files- common/__init__.py +0 -0
- common/utils.py +5 -0
- configuration/__init__.py +0 -0
- configuration/config_default.py +18 -0
- configuration/opts.py +141 -0
- models/__init__.py +0 -0
- models/dataset.py +175 -0
- models/transformer/__init__.py +0 -0
- models/transformer/encode_decode/__init__.py +0 -0
- models/transformer/encode_decode/clones.py +8 -0
- models/transformer/encode_decode/decoder.py +23 -0
- models/transformer/encode_decode/decoder_layer.py +28 -0
- models/transformer/encode_decode/encoder.py +20 -0
- models/transformer/encode_decode/encoder_layer.py +20 -0
- models/transformer/encode_decode/layer_norm.py +17 -0
- models/transformer/encode_decode/model.py +73 -0
- models/transformer/encode_decode/sublayer_connection.py +18 -0
- models/transformer/module/__init__.py +0 -0
- models/transformer/module/decode.py +34 -0
- models/transformer/module/embeddings.py +13 -0
- models/transformer/module/generator.py +13 -0
- models/transformer/module/label_smoothing.py +29 -0
- models/transformer/module/multi_headed_attention.py +55 -0
- models/transformer/module/noam_opt.py +47 -0
- models/transformer/module/positional_encoding.py +27 -0
- models/transformer/module/positionwise_feedforward.py +15 -0
- models/transformer/module/simpleloss_compute.py +23 -0
- models/transformer/module/subsequent_mask.py +10 -0
- preprocess/__init__.py +0 -0
- preprocess/data_preparation.py +86 -0
- preprocess/property_change_encoder.py +73 -0
- preprocess/vocabulary.py +145 -0
common/__init__.py
ADDED
|
File without changes
|
common/utils.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from enum import Enum


class Data_Type(Enum):
    """Dataset flavour selector: fragment pairs ('frag') or whole molecules ('whole')."""

    frag = 'frag'
    whole = 'whole'
|
configuration/__init__.py
ADDED
|
File without changes
|
configuration/config_default.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math

# Data
# Shared sequence-handling defaults used across the project.
DATA_DEFAULT = {
    'max_sequence_length': 256,
    'padding_value': 0,
}

# Properties
# Property columns the models are conditioned on.
PROPERTIES = ['pki']


# For Test_Property test
# Bounds used by the property-range test set.
LOD_MIN = 1.0
LOD_MAX = 3.4
configuration/opts.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Implementation of all available options """
|
| 2 |
+
from __future__ import print_function
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def train_opts(parser):
|
| 6 |
+
# Transformer or Seq2Seq
|
| 7 |
+
parser.add_argument('--model-choice', required=True, help="transformer or seq2seq")
|
| 8 |
+
# Common training options
|
| 9 |
+
group = parser.add_argument_group('Training_options')
|
| 10 |
+
group.add_argument('--batch-size', type=int, default=512,
|
| 11 |
+
help='Batch size for training')
|
| 12 |
+
group.add_argument('--num-epoch', type=int, default=200,
|
| 13 |
+
help='Number of training steps')
|
| 14 |
+
group.add_argument('--starting-epoch', type=int, default=1,
|
| 15 |
+
help="Training from given starting epoch")
|
| 16 |
+
# Input output settings
|
| 17 |
+
group = parser.add_argument_group('Input-Output')
|
| 18 |
+
group.add_argument('--data-path', required=True,
|
| 19 |
+
help="""Input data path""")
|
| 20 |
+
group.add_argument('--save-directory', default='finetune-TLR7',
|
| 21 |
+
help="""Result save directory""")
|
| 22 |
+
|
| 23 |
+
subparsers = parser.add_subparsers()
|
| 24 |
+
transformer_parser = subparsers.add_parser('transformer')
|
| 25 |
+
train_opts_transformer(transformer_parser)
|
| 26 |
+
|
| 27 |
+
seq2seq_parser = subparsers.add_parser('seq2seq')
|
| 28 |
+
train_opts_seq2seq(seq2seq_parser)
|
| 29 |
+
|
| 30 |
+
def train_opts_transformer(parser):
    """Register architecture, regularization and optimization options for the transformer."""
    # Architecture hyper-parameters.
    group = parser.add_argument_group('Model')
    group.add_argument('--vocab-path', required=False, default='',
                       help="vocab path for finetuning")
    group.add_argument('--pretrain-path', default='',
                       help="pretrain directory")
    group.add_argument('-N', type=int, default=6,
                       help="number of encoder and decoder")
    group.add_argument('-H', type=int, default=8,
                       help="heads of attention")
    group.add_argument('-d-model', type=int, default=128,
                       help="embedding dimension, model dimension")
    group.add_argument('-d-ff', type=int, default=2048,
                       help="dimension in feed forward network")

    # Regularization.
    group.add_argument('--dropout', type=float, default=0.1,
                       help="Dropout probability; applied in LSTM stacks.")
    group.add_argument('--label-smoothing', type=float, default=0.0,
                       help="""Label smoothing value epsilon.
                       Probabilities of all non-true labels
                       will be smoothed by epsilon / (vocab_size - 1).
                       Set to zero to turn off label smoothing.
                       For more detailed information, see:
                       https://arxiv.org/abs/1512.00567""")

    # Optimizer / learning-rate schedule settings.
    group = parser.add_argument_group('Optimization')
    group.add_argument('--factor', type=float, default=1.0,
                       help="""Factor multiplied to the learning rate scheduler formula in NoamOpt.
                       For more information about the formula,
                       see paper Attention Is All You Need https://arxiv.org/pdf/1706.03762.pdf""")
    group.add_argument('--warmup-steps', type=int, default=4000,
                       help="""Number of warmup steps for custom decay.""")
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help="""The beta1 parameter for Adam optimizer""")
    group.add_argument('--adam-beta2', type=float, default=0.98,
                       help="""The beta2 parameter for Adam optimizer""")
    group.add_argument('--adam-eps', type=float, default=1e-9,
                       help="""The eps parameter for Adam optimizer""")
| 70 |
+
|
| 71 |
+
def train_opts_seq2seq(parser):
    """Register architecture and optimization options for the seq2seq baseline."""
    # RNN architecture settings.
    group = parser.add_argument_group('Model')
    group.add_argument("--num-layers", "-l", help="Number of RNN layers of the model",
                       default=5, type=int)
    group.add_argument("--layer-size", "-s", help="Size of each of the RNN layers",
                       default=512, type=int)
    group.add_argument("--cell-type", "-c",
                       help="Type of cell used in RNN [gru, lstm]",
                       default='lstm', type=str)
    group.add_argument("--embedding-layer-size", "-e", help="Size of the embedding layer",
                       default=256, type=int)
    group.add_argument("--dropout", "-d", help="Amount of dropout between layers ",
                       default=0.3, type=float)
    # NOTE: store_false means the encoder is bidirectional by default and the
    # flag turns bidirectionality OFF.
    group.add_argument("--bidirectional", "--bi", help="Encoder bidirectional", action="store_false")
    group.add_argument("--bidirect-model",
                       help="Method to use encoder hidden state for initialising decoder['concat', 'addition', 'none']",
                       default='addition', type=str)
    group.add_argument("--attn-model", help="Attention model ['dot', 'general', 'concat']",
                       default='dot', type=str)

    # Optimizer settings.
    group = parser.add_argument_group('Optimization')
    group.add_argument('--learning-rate', type=float, default=0.0001,
                       help="""Starting learning rate""")
    group.add_argument("--clip-gradient-norm", help="Clip gradients to a given norm",
                       default=1.0, type=float)
| 98 |
+
|
| 99 |
+
def generate_opts(parser):
    """Register options for sampling molecules from a trained model."""
    # Transformer or Seq2Seq
    parser.add_argument('--model-choice', required=True, help="transformer or seq2seq")

    # Input/output locations.
    group = parser.add_argument_group('Input-Output')
    group.add_argument('--data-path', required=True,
                       help="""Input data path""")
    group.add_argument('--test-file-name', required=True, help="""test file name without .csv,
                       [test, test_not_in_train, test_unseen_L-1_S01_C10_range]""")
    group.add_argument('--save-directory', default='evaluation',
                       help="""Result save directory""")
    group.add_argument('--vocab-path', required=False, default='',
                       help="vocab path for finetuning")

    # Checkpoint to sample from.
    group = parser.add_argument_group('Model')
    group.add_argument('--model-path', help="""Model path""", required=True)
    group.add_argument('--epoch', type=int, help="""Which epoch to use""", required=True)

    # Sampling behaviour.
    group = parser.add_argument_group('General')
    group.add_argument('--batch-size', type=int, default=64,
                       help='Batch size for training')
    group.add_argument('--num-samples', type=int, default=50,
                       help='Number of molecules to be generated')
    group.add_argument('--decode-type', type=str, default='multinomial', help='decode strategy')
    group.add_argument('--dev-no', type=int, default=0, help='using device')
    # NOTE(review): type=bool on a CLI option is misleading — any non-empty string
    # (including "False") parses as True. Kept as-is to preserve the interface.
    group.add_argument('--overwrite', type=bool, default=False, help='whether overwrite exist file')
| 126 |
+
|
| 127 |
+
def evaluation_opts(parser):
    """Evaluation options (compute properties)"""
    # Input data and sample count.
    general = parser.add_argument_group('General')
    general.add_argument('--data-path', required=True,
                         help="""Input data path for generated molecules""")
    general.add_argument('--num-samples', type=int, default=10,
                         help='Number of molecules generated')

    # How the property range is evaluated.
    evaluation = parser.add_argument_group('Evaluation')
    evaluation.add_argument('--range-evaluation', default='',
                            help='[ , lower, higher]; set lower when evaluating test_unseen_L-1_S01_C10_range')

    # Matched molecular pair analysis settings.
    mmp = parser.add_argument_group('MMP')
    mmp.add_argument('--mmpdb-path', help='mmpdb path; download from https://github.com/rdkit/mmpdb')
    mmp.add_argument('--train-path', help='Training data path')
    mmp.add_argument('--only-desirable', help='Only check generated molecules with desirable properties',
                     action="store_true")
models/__init__.py
ADDED
|
File without changes
|
models/dataset.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8

"""
Implementation of a SMILES dataset.
"""
import random

import pandas as pd
import torch
import torch.utils.data as tud
from torch.autograd import Variable

import rdkit.Chem as rkc
from rdkit.Chem.SaltRemover import SaltRemover

import configuration.config_default as cfgd
from common.utils import Data_Type
from models.transformer.module.subsequent_mask import subsequent_mask


class Dataset(tud.Dataset):
    """Custom PyTorch Dataset that takes a file containing
    Source_Mol_ID,Target_Mol_ID,Source_Mol,Target_Mol,
    Source_Mol_LogD,Target_Mol_LogD,Delta_LogD,
    Source_Mol_Solubility,Target_Mol_Solubility,Delta_Solubility,
    Source_Mol_Clint,Target_Mol_Clint,Delta_Clint,
    Transformation,Core"""

    def __init__(self, data, vocabulary, tokenizer, prediction_mode=False, use_random=False, data_type=Data_Type.frag.value):
        """
        :param data: dataframe read from training, validation or test file
        :param vocabulary: used to encode source/target tokens
        :param tokenizer: used to tokenize source/target smiles
        :param prediction_mode: if use target smiles or not (training or test)
        :param use_random: if True, source SMILES are randomized before tokenization
        :param data_type: Data_Type value selecting fragment-pair or whole-molecule columns
        """
        self._vocabulary = vocabulary
        self._tokenizer = tokenizer
        self._data = data
        self._prediction_mode = prediction_mode
        self._use_random = use_random
        self._data_type = data_type

    def smiles_preprocess(self, smiles, random_type="unrestricted"):
        """
        Returns a random SMILES given a SMILES of a molecule.
        :param smiles: a SMILES string
        :param random_type: The type (unrestricted, restricted) of randomization performed.
        :return: A random SMILES string of the same molecule or None if the molecule is invalid.
        """
        if not self._use_random:
            return smiles
        mol = rkc.MolFromSmiles(smiles)
        if not mol:
            return None

        remover = SaltRemover()  # default salt remover
        if random_type == "unrestricted":
            stripped = remover.StripMol(mol)
            if stripped is None:  # salt stripping failed; fall back to the input
                return smiles
            ret = rkc.MolToSmiles(stripped, canonical=False, doRandom=True, isomericSmiles=False)
            # MolToSmiles can return an empty string; keep the original in that case
            return ret if ret else smiles
        if random_type == "restricted":
            new_atom_order = list(range(mol.GetNumAtoms()))
            random.shuffle(new_atom_order)
            random_mol = rkc.RenumberAtoms(mol, newOrder=new_atom_order)
            ret = rkc.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
            return ret if ret else smiles
        raise ValueError("Type '{}' is not valid".format(random_type))

    def __getitem__(self, i):
        """
        Tokenize and encode source smile and/or target smile (if prediction_mode is True)
        :param i: row index into the dataframe
        :return: (source tensor, [target tensor,] row)
        """
        row = self._data.iloc[i]
        # tokenize and encode source smiles
        main_cls = row['main_cls']
        minor_cls = row['minor_cls']
        target_name = row['target_name']
        # target_name may be NaN (float) for unnamed targets; treat that as empty
        target_name = target_name if isinstance(target_name, str) else ''
        value = row['Delta_Value']
        source_tokens = []

        if self._data_type == Data_Type.frag.value:
            sourceConstant = self.smiles_preprocess(row['constantSMILES'])
            sourceVariable = self.smiles_preprocess(row['fromVarSMILES'])
            # variable fragment first, then the constant scaffold
            source_tokens.extend(self._tokenizer.tokenize(sourceVariable))
            source_tokens.extend(self._tokenizer.tokenize(sourceConstant))
        elif self._data_type == Data_Type.whole.value:
            sourceSmi = self.smiles_preprocess(row['cpd1SMILES'])
            source_tokens.extend(self._tokenizer.tokenize(sourceSmi))
        # conditioning tokens: major class (e.g. activity), minor class (e.g. Ki),
        # the delta value, then the target name one character at a time
        source_tokens.append(main_cls)
        source_tokens.append(minor_cls)
        source_tokens.append(value)
        source_tokens.extend(list(target_name))

        source_encoded = self._vocabulary.encode(source_tokens)

        # tokenize and encode target smiles if it is for training instead of evaluation
        if not self._prediction_mode:
            target_smi = ''
            if self._data_type == Data_Type.frag.value:
                target_smi = row['toVarSMILES']
            elif self._data_type == Data_Type.whole.value:
                target_smi = row['cpd2SMILES']
            target_tokens = self._tokenizer.tokenize(target_smi)
            target_encoded = self._vocabulary.encode(target_tokens)

            return torch.tensor(source_encoded, dtype=torch.long), torch.tensor(
                target_encoded, dtype=torch.long), row
        else:
            return torch.tensor(source_encoded, dtype=torch.long), row

    def __len__(self):
        return len(self._data)

    @classmethod
    def collate_fn(cls, data_all):
        """Pad a batch to rectangular tensors and build source/target attention masks."""
        # sort based on source sequence's length (longest first)
        data_all.sort(key=lambda x: len(x[0]), reverse=True)
        # prediction-mode items are (source, row); training items are (source, target, row)
        is_prediction_mode = len(data_all[0]) == 2
        if is_prediction_mode:
            source_encoded, data = zip(*data_all)
        else:
            source_encoded, target_encoded, data = zip(*data_all)
        data = pd.DataFrame(data)

        # pad source sequences with zeroes up to the batch maximum
        max_length_source = max(seq.size(0) for seq in source_encoded)
        collated_arr_source = torch.zeros(len(source_encoded), max_length_source, dtype=torch.long)
        for i, seq in enumerate(source_encoded):
            collated_arr_source[i, :seq.size(0)] = seq
        # length of each source sequence
        source_length = torch.tensor([seq.size(0) for seq in source_encoded])
        # mask out padding positions in the source
        src_mask = (collated_arr_source != 0).unsqueeze(-2)

        # target sequences (training only)
        if not is_prediction_mode:
            max_length_target = max(seq.size(0) for seq in target_encoded)
            collated_arr_target = torch.zeros(len(target_encoded), max_length_target, dtype=torch.long)
            for i, seq in enumerate(target_encoded):
                collated_arr_target[i, :seq.size(0)] = seq

            # padding mask combined with the causal (subsequent-position) mask
            trg_mask = (collated_arr_target != 0).unsqueeze(-2)
            trg_mask = trg_mask & Variable(subsequent_mask(collated_arr_target.size(-1)).type_as(trg_mask))
            trg_mask = trg_mask[:, :-1, :-1]  # save start token, skip end token
        else:
            trg_mask = None
            max_length_target = None
            collated_arr_target = None

        return collated_arr_source, source_length, collated_arr_target, src_mask, trg_mask, max_length_target, data
models/transformer/__init__.py
ADDED
|
File without changes
|
models/transformer/encode_decode/__init__.py
ADDED
|
File without changes
|
models/transformer/encode_decode/clones.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy

import torch.nn as nn


def clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of *module*."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
|
models/transformer/encode_decode/decoder.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn

from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.layer_norm import LayerNorm


class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        # Run the target representation through each decoder layer in turn,
        # attending to the encoder memory, then normalise the final output.
        for dec_layer in self.layers:
            x = dec_layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)
|
models/transformer/encode_decode/decoder_layer.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn

from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.sublayer_connection import SublayerConnection


class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # three residual sub-layers: self-attention, source attention, feed-forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        # masked self-attention over the target sequence
        x = self.sublayer[0](x, lambda t: self.self_attn(t, t, t, tgt_mask))
        # attention over the encoder output
        x = self.sublayer[1](x, lambda t: self.src_attn(t, m, m, src_mask))
        # position-wise feed-forward
        return self.sublayer[2](x, self.feed_forward)
models/transformer/encode_decode/encoder.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn

from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.layer_norm import LayerNorm


class Encoder(nn.Module):
    "Core encoder is a stack of N layers"

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        # N identical encoder layers followed by a final normalisation
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for enc_layer in self.layers:
            x = enc_layer(x, mask)
        return self.norm(x)
|
models/transformer/encode_decode/encoder_layer.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn

from models.transformer.encode_decode.clones import clones
from models.transformer.encode_decode.sublayer_connection import SublayerConnection


class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # two residual sub-layers: self-attention and feed-forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda t: self.self_attn(t, t, t, mask))
        return self.sublayer[1](x, self.feed_forward)
|
models/transformer/encode_decode/layer_norm.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # learnable per-feature gain and bias
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps  # guards against division by zero for constant inputs

    def forward(self, x):
        # normalise over the last dimension, then rescale and shift
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        normalised = (x - mean) / (std + self.eps)
        return self.a_2 * normalised + self.b_2
|
models/transformer/encode_decode/model.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import torch.nn as nn
import copy

from models.transformer.module.positional_encoding import PositionalEncoding
from models.transformer.module.positionwise_feedforward import PositionwiseFeedForward
from models.transformer.module.multi_headed_attention import MultiHeadedAttention
from models.transformer.module.embeddings import Embeddings
from models.transformer.encode_decode.encoder import Encoder
from models.transformer.encode_decode.decoder import Decoder
from models.transformer.encode_decode.encoder_layer import EncoderLayer
from models.transformer.encode_decode.decoder_layer import DecoderLayer
from models.transformer.module.generator import Generator


class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    @classmethod
    def make_model(cls, src_vocab, tgt_vocab, N=6,
                   d_model=256, d_ff=2048, h=8, dropout=0.1):
        "Helper: Construct a model from hyperparameters."
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)
        model = EncoderDecoder(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            Decoder(DecoderLayer(d_model, c(attn), c(attn),
                                 c(ff), dropout), N),
            nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
            nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
            Generator(d_model, tgt_vocab))

        # This was important from their code.
        # Initialize parameters with Glorot / fan_avg.
        # FIX: nn.init.xavier_uniform is deprecated in PyTorch; use the
        # in-place variant xavier_uniform_ (same initialization, no warning).
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        return model

    @classmethod
    def load_from_file(cls, file_path):
        """Rebuild a model from a checkpoint file and restore its weights."""
        # NOTE(review): map_location='cuda:0' requires a GPU at load time;
        # consider making the device configurable.
        checkpoint = torch.load(file_path, map_location='cuda:0')
        para_dict = checkpoint['model_parameters']
        vocab_size = para_dict['vocab_size']
        model = EncoderDecoder.make_model(vocab_size, vocab_size, para_dict['N'],
                                          para_dict['d_model'], para_dict['d_ff'],
                                          para_dict['H'], para_dict['dropout'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model
|
models/transformer/encode_decode/sublayer_connection.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn

from models.transformer.encode_decode.layer_norm import LayerNorm


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        # pre-norm residual: x + dropout(sublayer(norm(x)))
        residual = self.dropout(sublayer(self.norm(x)))
        return x + residual
|
models/transformer/module/__init__.py
ADDED
|
File without changes
|
models/transformer/module/decode.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
from torch.autograd import Variable

from models.transformer.module.subsequent_mask import subsequent_mask


def decode(model, src, src_mask, max_len, type):
    """Autoregressively generate token sequences from an encoder-decoder model.

    :param model: model exposing encode/decode/generator
    :param src: source token tensor [batch_size, src_len]
    :param src_mask: source padding mask
    :param max_len: maximum generated sequence length
    :param type: decode strategy, 'greedy' or 'multinomial'
    :return: generated token tensor ys [batch_size, <=max_len]
    """
    batch_size = src.shape[0]
    # every sequence starts with token id 1 (assumed start token — matches torch.ones)
    ys = torch.ones(1).repeat(batch_size, 1).view(batch_size, 1).type_as(src.data)
    encoder_outputs = model.encode(src, src_mask)
    # tracks which sequences have already emitted the end token (id 2)
    break_condition = torch.zeros(batch_size, dtype=torch.bool)

    for _ in range(max_len - 1):
        with torch.no_grad():
            out = model.decode(encoder_outputs, src_mask, Variable(ys),
                               Variable(subsequent_mask(ys.size(1)).type_as(src.data)))

            # distribution over the vocabulary for the last position
            log_prob = model.generator(out[:, -1])
            prob = torch.exp(log_prob)

            if type == 'greedy':
                _, next_word = torch.max(prob, dim=1)
                ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)  # [batch_size, i]
            elif type == 'multinomial':
                next_word = torch.multinomial(prob, 1)
                ys = torch.cat([ys, next_word], dim=1)  # [batch_size, i]
                next_word = torch.squeeze(next_word)

            break_condition = (break_condition | (next_word.to('cpu') == 2))
            if all(break_condition):  # every sequence has hit the end token
                break

    return ys
models/transformer/module/embeddings.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Embeddings(nn.Module):
    """Token embedding table whose output is scaled by sqrt(d_model).

    The scaling matches the convention of the Transformer architecture so the
    embeddings and the positional encodings have comparable magnitude.
    """

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        # One learned d_model-dimensional row per vocabulary entry.
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up token ids `x` and return scaled embedding vectors."""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale
|
models/transformer/module/generator.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Generator(nn.Module):
    "Define standard linear + softmax generation step."

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        # Projects decoder hidden states to vocabulary-sized logits.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return per-token log-probabilities over the vocabulary."""
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)
|
models/transformer/module/label_smoothing.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
from torch.autograd import Variable
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LabelSmoothing(nn.Module):
    """Label-smoothing loss.

    Builds a smoothed target distribution (confidence mass on the gold token,
    the remaining `smoothing` mass spread over the other non-padding classes)
    and measures the KL divergence against the model's log-probabilities.

    :param size: vocabulary size (must equal x.size(1) in forward)
    :param padding_idx: class index that never receives probability mass
    :param smoothing: total probability mass moved off the gold token
    """

    def __init__(self, size, padding_idx, smoothing=0.00):
        super(LabelSmoothing, self).__init__()
        # FIX: size_average=False is deprecated/removed; reduction='sum' is the
        # exact modern equivalent (sum over all elements, no averaging).
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        """Compute the loss.

        :param x: log-probabilities [batch, size]
        :param target: gold class indices [batch]
        :return: summed KL divergence (scalar tensor)
        """
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # size - 2 excludes the gold class and the padding class from the
        # smoothing mass.
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero out entire rows whose target is padding.
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        # FIX: detach() replaces the deprecated Variable(..., requires_grad=False);
        # the target distribution must not receive gradients.
        return self.criterion(x, true_dist.detach())
|
models/transformer/module/multi_headed_attention.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from models.transformer.encode_decode.clones import clones
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns the attended values together with the attention weights.
    """
    d_k = query.size(-1)
    scale = math.sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        # Blocked positions get a huge negative score -> ~0 weight after softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MultiHeadedAttention(nn.Module):
    """Multi-head attention: h parallel scaled dot-product attention heads.

    Uses four linear layers (query, key, value projections plus the final
    output projection), all of size d_model x d_model.
    """

    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # 4 copies: W_q, W_k, W_v and the output projection W_o.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Holds the attention weights of the most recent forward (for inspection).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        return self.linears[-1](x)
|
models/transformer/module/noam_opt.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class NoamOpt:
    """Optimizer wrapper implementing the Noam learning-rate schedule.

    lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5):
    linear warm-up for `warmup` steps, then inverse-square-root decay.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance one step: set the scheduled lr on every param group, then step."""
        self._step += 1
        lr = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self._rate = lr
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for `step` (defaults to the current step)."""
        step = self._step if step is None else step
        warm_term = step * self.warmup ** (-1.5)
        decay_term = step ** (-0.5)
        return self.factor * (self.model_size ** (-0.5) * min(decay_term, warm_term))

    def save_state_dict(self):
        """Serialize the schedule state together with the inner optimizer's state."""
        return {
            'inner_optimizer_state_dict': self.optimizer.state_dict(),
            'step': self._step,
            'warmup': self.warmup,
            'factor': self.factor,
            'model_size': self.model_size,
            'rate': self._rate
        }

    def load_state_dict(self, state_dict):
        """Restore the schedule and the inner optimizer from save_state_dict output."""
        self._rate = state_dict['rate']
        self._step = state_dict['step']
        self.optimizer.load_state_dict(state_dict['inner_optimizer_state_dict'])
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
models/transformer/module/positional_encoding.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import math
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the input, followed by dropout.

    Even dimensions get sin, odd dimensions get cos, with wavelengths forming
    a geometric progression from 2*pi to 10000*2*pi.

    :param d_model: embedding dimension
    :param dropout: dropout probability applied after adding the encoding
    :param max_len: longest sequence length supported by the precomputed table
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model] so it broadcasts over batch
        # Registered as a buffer: saved with the model, never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first x.size(1) positional encodings to x and apply dropout."""
        # FIX: the deprecated Variable(..., requires_grad=False) wrapper is
        # replaced by detach(); buffers carry no gradient either way.
        x = x + self.pe[:, :x.size(1)].detach()
        return self.dropout(x)
|
models/transformer/module/positionwise_feedforward.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Applied independently at every sequence position.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to d_ff, apply ReLU and dropout, project back to d_model."""
        hidden = F.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))
|
models/transformer/module/simpleloss_compute.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class SimpleLossCompute:
    """Applies the generator, computes the normalized loss and, when an
    optimizer wrapper is supplied, performs the backward/step/zero_grad cycle.
    """

    def __init__(self, generator, loss_function, opt):
        self.generator = generator
        self.loss_function = loss_function
        self.opt = opt  # NoamOpt-style wrapper (exposes .step() and .optimizer) or None

    def __call__(self, x, y, norm):
        """Return the un-normalized loss value for logging.

        :param x: decoder output to be projected by the generator
        :param y: gold token ids
        :param norm: normalization constant (loss is divided by it before backward)
        """
        log_probs = self.generator(x)

        flat_pred = log_probs.contiguous().view(-1, log_probs.size(-1))
        flat_gold = y.contiguous().view(-1)
        loss = self.loss_function(flat_pred, flat_gold) / norm

        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        # Multiply the normalization back out so callers see the raw total loss.
        return loss.data * norm
|
models/transformer/module/subsequent_mask.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def subsequent_mask(size):
    """Causal mask of shape [1, size, size].

    Entry (i, j) is True when position j may be attended from position i,
    i.e. j <= i — each position sees itself and everything before it.
    """
    upper_triangle = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    # Zeros of the strict upper triangle become the allowed (True) positions.
    return torch.from_numpy(upper_triangle) == 0
|
| 10 |
+
|
preprocess/__init__.py
ADDED
|
File without changes
|
preprocess/data_preparation.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
from sklearn.model_selection import train_test_split
|
| 5 |
+
|
| 6 |
+
import utils.file as uf
|
| 7 |
+
import configuration.config_default as cfgd
|
| 8 |
+
import preprocess.property_change_encoder as pce
|
| 9 |
+
|
| 10 |
+
SEED = 42
|
| 11 |
+
# SPLIT_RATIO = 0.8
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_smiles_list(file_name):
    """
    Get smiles list for building vocabulary
    :param file_name: CSV with constantSMILES / fromVarSMILES / toVarSMILES columns
    :return: unique SMILES strings across the three columns
    """
    pd_data = pd.read_csv(file_name, sep=",")

    print("Read %s file" % file_name)
    # ravel('K') flattens the 2-D column block in memory order before dedup.
    smiles_columns = ['constantSMILES', 'fromVarSMILES', 'toVarSMILES']
    smiles_list = pd.unique(pd_data[smiles_columns].values.ravel('K'))
    print("Number of SMILES in chemical transformations: %d" % len(smiles_list))

    return smiles_list
|
| 28 |
+
|
| 29 |
+
def split_data(input_transformations_path, SPLIT_RATIO, LOG=None):
    """
    Split data into training, validation and test set, write to files
    :param input_transformations_path: CSV of chemical transformations
    :param SPLIT_RATIO: intended fraction of data kept for training
    :param LOG: optional logger
    :return: (train, validation, test) dataframes
    """
    data = pd.read_csv(input_transformations_path, sep=",")
    if LOG:
        LOG.info("Read %s file" % input_transformations_path)

    # NOTE(review): both splits use test_size=(1-SPLIT_RATIO)/2 — the second
    # split takes that fraction of the already-reduced train set, so the final
    # proportions are only approximately SPLIT_RATIO; confirm this is intended.
    train, test = train_test_split(
        data, test_size=(1-SPLIT_RATIO)/2, random_state=SEED)
    train, validation = train_test_split(train, test_size=(1-SPLIT_RATIO)/2, random_state=SEED)
    if LOG:
        LOG.info("Train, Validation, Test: %d, %d, %d" % (len(train), len(validation), len(test)))

    # Splits are written next to the input file.
    parent = uf.get_parent_dir(input_transformations_path)
    train.to_csv(os.path.join(parent, "train.csv"), index=False)
    validation.to_csv(os.path.join(parent, "validation.csv"), index=False)
    test.to_csv(os.path.join(parent, "test.csv"), index=False)

    return train, validation, test
|
| 51 |
+
|
| 52 |
+
def save_df_property_encoded(file_name, property_change_encoder, LOG=None):
    """Replace each Delta_<property> column with its interval label and save.

    :param file_name: CSV containing Delta_<property> columns
    :param property_change_encoder: dict property_name -> (encoder, start_map_interval)
    :param LOG: optional logger
    :return: path of the written '<name>_encoded.csv' file
    """
    data = pd.read_csv(file_name, sep=",")
    for property_name in cfgd.PROPERTIES:
        # The three original branches ('pki', 'qed', 'sa') were identical —
        # collapsed into one.
        if property_name in ('pki', 'qed', 'sa'):
            encoder, start_map_interval = property_change_encoder[property_name]
            column = 'Delta_{}'.format(property_name)
            # FIX: 'encoder' was previously passed as a stray positional argument
            # to Series.apply (landing in the deprecated convert_dtype slot); it
            # was never used by the mapping and breaks on recent pandas.
            data[column] = data[column].apply(
                lambda x: pce.value_in_interval(x, start_map_interval))

    output_file = file_name.split('.csv')[0] + '_encoded.csv'
    # FIX: guard the logger like the sibling functions do — LOG defaults to None.
    if LOG:
        LOG.info("Saving encoded property change to file: {}".format(output_file))
    data.to_csv(output_file, index=False)
    return output_file
|
| 77 |
+
|
| 78 |
+
def prop_change(source, target, threshold):
    """Classify how a property moves across `threshold` from source to target."""
    source_low = source <= threshold
    target_low = target <= threshold
    if source_low and not target_low:
        return "low->high"
    if target_low and not source_low:
        return "high->low"
    # Both on the same side of the threshold.
    return "no_change"
|
preprocess/property_change_encoder.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
import configuration.config_default as cfgd
|
| 5 |
+
|
| 6 |
+
STEP_pki = 1
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def encode_property_change(input_data_path, LOG=None):
    """Build the interval encoders for every configured property.

    :param input_data_path: CSV used to determine the value ranges
    :param LOG: optional logger
    :return: dict property_name -> (intervals, start_map_interval)
    """
    property_change_encoder = {}
    for property_name in cfgd.PROPERTIES:
        if property_name == 'pki':
            # intervals look like ['(3, 4]', ...]
            # start_map_interval maps interval start value -> interval label,
            # used to look up which interval a value falls into
            intervals, start_map_interval = build_intervals(input_data_path, step=STEP_pki, LOG=LOG)

        if property_name == 'pki':
            property_change_encoder[property_name] = intervals, start_map_interval

    return property_change_encoder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def value_in_interval(value, start_map_interval):
    """Return the interval label whose start is the largest key <= value.

    `start_map_interval` maps each interval's start value to its label.
    """
    starts = sorted(start_map_interval)
    slot = np.searchsorted(starts, value, side='right') - 1
    return start_map_interval[starts[slot]]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def interval_to_onehot(interval, encoder):
    # One-hot encode a single interval label; assumes `encoder` is a fitted
    # sklearn-style OneHotEncoder whose transform returns a sparse matrix —
    # TODO confirm against the caller that builds the encoder.
    return encoder.transform([interval]).toarray()[0]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def build_intervals(input_transformations_path, step=STEP_pki, LOG=None):
    """Build the list of Delta_pki intervals and a lookup from interval start to label.

    Intervals are centered on 0 with width `step`; above +2 the width grows to 2,
    positives are capped at an open '(x, inf]' bucket near 10, and negatives
    extend down to the observed minimum, closed by a '(-inf, x]' bucket.

    :param input_transformations_path: CSV containing a Delta_pki column
    :param step: base interval width
    :param LOG: optional logger
    :return: (intervals, start_map_interval)
    """
    df = pd.read_csv(input_transformations_path)
    # df=input_transformations_path
    delta_pki = df['Delta_pki'].tolist()
    min_val, max_val = min(delta_pki), max(delta_pki)
    if LOG:
        LOG.info("pki min and max: {}, {}".format(min_val, max_val))

    start_map_interval = {}
    # Central interval straddling zero: (-step/2, step/2]
    interval_str = '({}, {}]'.format(round(-step/2, 2), round(step/2, 2))
    intervals = [interval_str]
    start_map_interval[-step/2] = interval_str

    smallStep = step
    bigStep = 2
    # Positive side: width `step` up to 2, then width 2 until reaching ~10.
    positives = step/2
    while positives < 10:
        if positives > 2:
            step = bigStep
        interval_str = '({}, {}]'.format(round(positives, 2), round(positives+step, 2))
        intervals.append(interval_str)
        start_map_interval[positives] = interval_str
        positives += step

    # Everything beyond the last positive boundary collapses into one bucket.
    interval_str = '({}, inf]'.format(round(positives, 2))
    intervals.append(interval_str)
    start_map_interval[positives] = interval_str

    # Negative side: restore the small step and walk down to the data minimum.
    step = smallStep
    negatives = -step/2
    while negatives > min_val:
        interval_str = '({}, {}]'.format(round(negatives-step, 2), round(negatives, 2))
        intervals.append(interval_str)
        # NOTE(review): `negatives` is decremented BEFORE being used as the map
        # key, so each label is keyed by its lower bound — confirm intended.
        negatives -= step
        start_map_interval[negatives] = interval_str
    interval_str = '(-inf, {}]'.format(round(negatives, 2))
    intervals.append(interval_str)
    start_map_interval[float('-inf')] = interval_str

    return intervals, start_map_interval
|
| 73 |
+
|
preprocess/vocabulary.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Vocabulary helper class
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import re
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
class Vocabulary:
    """Stores the tokens and their conversion to one-hot vectors."""

    def __init__(self, tokens=None, starting_id=0):
        # A single dict holds both directions: token -> id and id -> token.
        self._tokens = {}
        self._current_id = starting_id

        if tokens:
            for tok, tok_id in tokens.items():
                self._add(tok, tok_id)
                self._current_id = max(self._current_id, tok_id + 1)

    def __getitem__(self, token_or_id):
        return self._tokens[token_or_id]

    def add(self, token):
        """Adds a token and returns its id (None if it was already present)."""
        if not isinstance(token, str):
            raise TypeError("Token is not a string")
        if token in self:
            # raise ValueError("Token already present in the vocabulary")
            print(f'=== Token "{token}"already present in the vocabulary')
            return
        new_id = self._current_id
        self._add(token, new_id)
        self._current_id = new_id + 1
        return new_id

    def update(self, tokens):
        """Adds many tokens."""
        return [self.add(tok) for tok in tokens]

    def __delitem__(self, token_or_id):
        # Remove both directions of the mapping.
        paired = self._tokens[token_or_id]
        del self._tokens[paired]
        del self._tokens[token_or_id]

    def __contains__(self, token_or_id):
        return token_or_id in self._tokens

    def __eq__(self, other_vocabulary):
        return self._tokens == other_vocabulary._tokens

    def __len__(self):
        # Both directions are stored, so the real size is half the dict length.
        return len(self._tokens) // 2

    def encode(self, tokens):
        """Encodes a list of tokens into a float32 vector of their ids."""
        encoded = np.zeros(len(tokens), dtype=np.float32)
        for pos, tok in enumerate(tokens):
            try:
                encoded[pos] = self._tokens[tok]
            except KeyError:
                # Unknown tokens fall back to the "default_key" entry.
                encoded[pos] = self._tokens["default_key"]
        return encoded

    def decode(self, ohe_vect):
        """Decodes a vector of ids back to a list of tokens."""
        decoded = []
        for code in ohe_vect:
            try:
                decoded.append(self[code])
            except KeyError:
                decoded.append("default_key")
        return decoded

    def _add(self, token, idx):
        """Insert both directions; refuse to reuse an existing id."""
        if idx in self._tokens:
            raise ValueError("IDX already present in vocabulary")
        self._tokens[token] = idx
        self._tokens[idx] = token

    def tokens(self):
        """Returns the tokens from the vocabulary"""
        return [t for t in self._tokens if isinstance(t, str)]

    def word2idx(self):
        """Return only the token -> id half of the mapping."""
        return {k: v for k, v in self._tokens.items() if isinstance(k, str)}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class SMILESTokenizer:
    """Deals with the tokenization and untokenization of SMILES."""

    # Multi-character constructs captured as single tokens, in priority order.
    REGEXPS = {
        "brackets": re.compile(r"(\[[^\]]*\])"),
        "2_ring_nums": re.compile(r"(%\d{2})"),
        "brcl": re.compile(r"(Br|Cl)")
    }
    REGEXP_ORDER = ["brackets", "2_ring_nums", "brcl"]

    def tokenize(self, data, with_begin_and_end=True):
        """Tokenizes a SMILES string."""
        def _split(text, order):
            # Base case: no regex left -> one token per remaining character.
            if not order:
                return list(text)
            pieces = self.REGEXPS[order[0]].split(text)
            out = []
            for idx, piece in enumerate(pieces):
                if idx % 2:
                    # Odd indices are the captured multi-character tokens.
                    out.append(piece)
                else:
                    # Even indices still need the remaining regexes applied.
                    out.extend(_split(piece, order[1:]))
            return out

        tokens = _split(data, self.REGEXP_ORDER)
        if with_begin_and_end:
            tokens = ["^"] + tokens + ["$"]
        return tokens

    def untokenize(self, tokens):
        """Untokenizes a SMILES string."""
        parts = []
        for token in tokens:
            if token == "$":
                break
            if token != "^":
                parts.append(token)
        return "".join(parts)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def create_vocabulary(smiles_list, tokenizer, property_condition=None):
    """Creates a vocabulary for the SMILES syntax."""
    # Collect every distinct token across the corpus (without ^/$ markers).
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenizer.tokenize(smi, with_begin_and_end=False))

    vocabulary = Vocabulary()
    # Special tokens get fixed low ids; remaining tokens are sorted for
    # deterministic ids across runs.
    vocabulary.update(["*", "^", "$"] + sorted(tokens))  # pad=0, start=1, end=2
    if property_condition is not None:
        vocabulary.update(property_condition)
    # for random smiles: ensure ring-closure digit 8 exists even if unseen
    if "8" not in vocabulary.tokens():
        vocabulary.update(["8"])

    return vocabulary
|