Fix missing sigmoid bug, refactor code, and add docstrings
I noticed that when I train this for around 60 epochs, we start to get
edge outputs that are not just 0 (empty). I'm wondering whether the
performance improves significantly if I try more hops.
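
For reference, "hops" here corresponds to the num_layers argument of MoleGen (each GCN layer is one round of message passing), so the experiment is just a matter of raising that knob. A minimal sketch, reusing the hyperparameters from train.py:

    # hypothetical experiment, not part of this commit
    from molegen.model import MoleGen

    model_1hop = MoleGen(vocab_size=8, embd=16, num_layers=1, max_atoms=100)  # current setting
    model_4hop = MoleGen(vocab_size=8, embd=16, num_layers=4, max_atoms=100)  # more hops
    # then train exactly as in molegen/train.py and compare the edge outputs
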
- LICENSE +21 -0
- README.md +2 -2
- molegen/data.py +0 -3
- molegen/model.py +189 -119
- molegen/train.py +18 -30
LICENSE
ADDED

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Justin Silver
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED

@@ -1,5 +1,5 @@
 # Molecular generation
 
-
+I started this project to get hands-on experience with generative models and explore ML for drug discovery.
 
-
+The code is a replication of the following paper: [A Two-Step Graph Convolutional Decoder for Molecule Generation](https://arxiv.org/abs/1906.03412) by Bresson and Laurent (2019).
molegen/data.py
CHANGED

@@ -171,9 +171,6 @@ def map_rdkit_bond_types(rdkit_edge_types, verbose=False):
     return our_edge_types
 
 
-#########################################################################################################################
-#########################################################################################################################
-
 def main():
     tqdm.pandas() # enable progress bars in pandas
     df = pd.read_csv("https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv", nrows=1000)
molegen/model.py
CHANGED

@@ -1,40 +1,28 @@
 import torch.nn as nn
 from torch_geometric.data import Data
 from torch_geometric.nn import ResGatedGraphConv
 from torch_geometric.utils import scatter
-from torch_geometric.utils import dense_to_sparse
-import torch
 
-class MLPAtom(nn.Module):
-    """Multi-layer perceptron which
-
-    Implementation of Eq. 9 in paper
+class AtomMLP(nn.Module):
+    """Multi-layer perceptron (MLP) which predicts a BOA for each graph
+
+    Implementation of Eq. 9 in the paper. This MLP takes in a graph level embedding 'z'
+    and produces a BOA, essentially an integer distribution over the predicted atoms
+
+    Args:
+        embd : embedding dimension for latent space
+        vocab_size : number of unique atoms in training set
+        max_atoms : max number of total atoms in any molecule
     """
     def __init__(self, embd=16, vocab_size=8, max_atoms=100):
+
        super().__init__()
 
        self.vocab_size = vocab_size
        self.max_atoms = max_atoms
 
-        # note
-        #
-        # is 'm' in the paper which means the different atoms we can select
-        # and R is basically a one-hot vector that goes up to the maximum number
-        # of total atoms in any one molecule in the training set
-        #
-        # (small optimization here is to not do the total number of atoms in training set
-        # but simply the highest number of any particular atom, i.e. take CO2:
-        # we would only need to have R == 2, since even though we have 3 atoms,
-        # the max number for any single atom is only 2, thus we don't need R == 3.
-        # i think they just simplified it in the paper)
-        #
-        # I'm not sure how to construct this though. I could do like a concatenation thing
-        # and have m linear layers with all different weights and then just concatenate the output
-        #
-        # ah okay the intuition here is that we are going to output a R*vocab_size vector,
-        # and then we will reshape it afterwards.
-
-        # not a lot of info on the structure of a simple MLP
+        # note: there isn't much information in the paper on this MLP, but I am assuming we
+        # use a nn.ReLU for the activation
        self.mlp = nn.Sequential(
            nn.Linear(embd, embd),
            nn.ReLU(),
@@ -42,15 +30,39 @@ class MLPAtom(nn.Module):
        )
 
    def forward(self, z):
+        """Forward pass of the Atom MLP
+
+        We compute an output vector of size (num_graphs, vocab_size, max_atoms), with
+        the last dimension essentially representing a one-hot vector whose size goes
+        up to the maximum number of atoms. We run the MLP with a flattened last dimension
+        (vocab_size*max_atoms) and then view it as a separate dimension for the loss calculation.
+
+        Note that a small optimization here is to not use the total number of atoms in the
+        training set but simply the highest count of any particular atom instead.
+
+        Args:
+            z : graph level embedding (num_graphs, embd)
+
+        Returns:
+            boa : graph level BOA (num_graphs, vocab_size, max_atoms)
+        """
        boa = self.mlp(z)
 
        return boa.view(-1, self.vocab_size, self.max_atoms)
 
 
-class MLPBond(nn.Module):
+class BondMLP(nn.Module):
     """Multi-layer perceptron which predicts which bond for each edge
 
-    Implementation of Eq. 11 in paper
+    Implementation of Eq. 11 in the paper. This MLP takes in edge embeddings 'e'
+    and produces an integer distribution over the predicted bond types
+
+    Args:
+        embd : embedding dimension for latent space
+        num_bonds : number of different types of bonds for classification
+
     """
    def __init__(self, embd=16, num_bonds=4):
        super().__init__()
@@ -64,14 +76,30 @@ class MLPBond(nn.Module):
        )
 
    def forward(self, e):
+        """Forward pass of the Bond MLP
+
+        We compute an output vector of size (num_edges, num_bonds), with
+        the last dimension representing a one-hot vector over the
+        different types of bonds.
+
+        Args:
+            e : edge embedding (num_edges, embd)
+
+        Returns:
+            bonds : bond type logits for each edge (num_edges, num_bonds)
+        """
        bonds = self.mlp(e)
        return bonds
 
 
-class GCNAtomLayer(nn.Module):
-    """
+class AtomGCNLayer(nn.Module):
+    """Updates atom embeddings with Conv->BN->ReLU->Residual
 
     Implementation of Eq. 4 in paper
+
+    Args:
+        embd : embedding dimension for latent space
+
     """
    def __init__(self, embd=16):
        super().__init__()
@@ -81,156 +109,198 @@ class GCNAtomLayer(nn.Module):
        self.relu = nn.ReLU()
 
    def forward(self, data):
+        """Perform the forward pass for the atom GCN
+
+        Args:
+            data : PyG Data() object (possibly batched) with the following attributes:
+                x : atom embedding, e.g. x[0] is the embedding vector for atom/node 0 (num_nodes, embd)
+                edge_index : bond connectivity, e.g. atom at edge_index[0,0] connects to edge_index[1,0] (2, num_edges)
+                edge_attr : bond embedding, e.g. edge_attr[0] is the embedding vector for bond/edge 0 (num_edges, embd)
+
+        Returns:
+            h : updated node embedding (num_nodes, embd)
+        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
        h = self.gcn(x, edge_index, edge_attr)
        h = self.bn(h)
        h = self.relu(h)
-        h = x + h #
+        h = x + h  # residual pathway
        return h
 
-class GCNBondLayer(nn.Module):
-    """
+class BondGCNLayer(nn.Module):
+    """Updates bond embeddings with Linear->BN->ReLU->Residual
 
     Implementation of Eq. 5 in paper
+
+    Args:
+        embd : embedding dimension for latent space
+
     """
 
    def __init__(self, embd=16):
        super().__init__()
-        self.v1 = nn.Linear(embd, embd)
-        self.v2 = nn.Linear(embd, embd)
-        self.v3 = nn.Linear(embd, embd)
+        self.v = nn.ModuleList([nn.Linear(embd, embd) for _ in range(3)])
        self.bn = nn.BatchNorm1d(embd)
        self.relu = nn.ReLU()
 
    def forward(self, data):
+        """Perform the forward pass for the bond GCN
+
+        Args:
+            data : PyG Data() object (possibly batched) with the following attributes:
+                x : atom embedding, e.g. x[0] is the embedding vector for atom/node 0 (num_nodes, embd)
+                edge_index : bond connectivity, e.g. atom at edge_index[0,0] connects to edge_index[1,0] (2, num_edges)
+                edge_attr : bond embedding, e.g. edge_attr[0] is the embedding vector for bond/edge 0 (num_edges, embd)
+
+        Returns:
+            e : updated edge embedding (num_edges, embd)
+        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
        h_src = x[edge_index[0]]   # oh... apparently we get this for free?
        h_dest = x[edge_index[1]]  # apparently pytorch geometric already handles large batches
 
-        e = self.v1(edge_attr) + self.v2(h_src) + self.v3(h_dest)
+        e = self.v[0](edge_attr) + self.v[1](h_src) + self.v[2](h_dest)
 
        e = self.bn(e)
        e = self.relu(e)
 
-        e = edge_attr + e #
+        e = edge_attr + e  # residual pathway
        return e
 
 class MoleGen(nn.Module):
-    """Main model for VAE
+    """Main model for VAE molecular generation
+
+    This model is an implementation of the paper:
+    A Two-Step Graph Convolutional Decoder for Molecule Generation
+    by Bresson and Laurent (2019). A molecule graph is first encoded
+    into a latent representation 'z', which is used to produce a
+    'Bag of Atoms' (BOA). This BOA tells us how many of each atom
+    we have in the predicted molecule, ignoring connectivity.
+    The second stage takes 'z' and the original input formula and
+    decodes the edge feature connectivity over a fully connected graph.
+
+    Args:
+        vocab_size : number of unique atoms in training set
+        embd : embedding dimension for latent space
+        num_layers : number of GNN layers (message passing/k-hop distance)
+        max_atoms : max number of total atoms in any molecule
+        num_bonds : number of different types of bonds for classification
+    """
    def __init__(self, vocab_size=8, embd=16, num_layers=4, max_atoms=100, num_bonds=4):
+
        super().__init__()
        self.vocab_size = vocab_size
        self.embd = embd
        self.num_layers = num_layers
 
-        self.b = nn.Linear(embd, embd)
-        self.c = nn.Linear(embd, embd)
-        self.d = nn.Linear(embd, embd)
-
-        self.sig = nn.Sigmoid()
-
-        self.mlp_atom = MLPAtom(embd, vocab_size, max_atoms)
-        self.mlp_bond = MLPBond(embd, num_bonds)
+        self.atom_embeddings = nn.Embedding(vocab_size, embd)
+        self.bond_embeddings = nn.Embedding(vocab_size, embd)
+        self.atom_encoder = nn.ModuleList([AtomGCNLayer(embd) for _ in range(num_layers)])
+        self.bond_encoder = nn.ModuleList([BondGCNLayer(embd) for _ in range(num_layers)])
+        self.atom_decoder = nn.ModuleList([AtomGCNLayer(embd) for _ in range(num_layers)])
+        self.bond_decoder = nn.ModuleList([BondGCNLayer(embd) for _ in range(num_layers)])
+
+        self.linear = nn.ModuleDict(dict(
+            a=nn.Linear(embd, embd),
+            b=nn.Linear(embd, embd),
+            c=nn.Linear(embd, embd),
+            d=nn.Linear(embd, embd),
+            u=nn.Linear(embd, embd)
+        ))
+
+        self.sigmoid = nn.Sigmoid()
 
+        self.atom_mlp = AtomMLP(embd, vocab_size, max_atoms)
+        self.bond_mlp = BondMLP(embd, num_bonds)
 
    def forward(self, input_data):
+        """Forward pass for MoleGen
+
+        First we look up the atom and bond embeddings, then apply the GCN layers iteratively.
+        Afterwards we reduce the node and edge embeddings into a single graph latent vector 'z'.
+        We apply an MLP to 'z' to produce a BOA. Then we use the original 'x' tokens and the 'z'
+        vector to produce edge features over a fully connected graph, to which we again apply an
+        MLP to get 's', the predicted bond tokens for each edge of the fully connected graph.
+
+        Args:
+            input_data : PyG Data() object (possibly batched) with the following attributes:
+                x : atom token, e.g. x[0] is the integer token for atom/node 0 (num_nodes, 1)
+                edge_index : bond connectivity, e.g. atom at edge_index[0,0] connects to edge_index[1,0] (2, num_edges)
+                edge_attr : bond token, e.g. edge_attr[0] is the integer token for bond/edge 0 (num_edges,)
+
+        Returns:
+            boa : Bag of Atoms prediction (num_graphs, vocab_size, max_atoms)
+            z : per graph latent representation (num_graphs, embd)
+            s : edge probability matrix (num_fc_edges, num_bonds)
+        """
 
+        ######################################################################################
+        # Encoding step
+        ######################################################################################
 
-        # data.edge_attr.shape = (num_bonds)
-        # bemb.shape = (num_bonds, embedding_dim)
-        aemb = self.aembs(x.view(-1)) # this is used later for the bond generation
-        bemb = self.bembs(edge_attr)
+        # Get embeddings
+        atom_embedding = self.atom_embeddings(input_data.x.view(-1))  # indexing into an embedding needs a flat vector
+        bond_embedding = self.bond_embeddings(input_data.edge_attr)
 
+        # Apply GCN layers iteratively (encoding)
+        h = atom_embedding
+        e = bond_embedding
        for i in range(self.num_layers):
-            data = Data(x=h, edge_index=edge_index, edge_attr=e)
-            e = self.gcn_benc[i](data)
-
-            h_src = h[edge_index[0]]
-            h_dest = h[edge_index[1]]
+            data = Data(x=h, edge_index=input_data.edge_index, edge_attr=e)
+            h = self.atom_encoder[i](data)
+            e = self.bond_encoder[i](data)
 
+        # Extract source and destination atom features
+        h_src = h[input_data.edge_index[0]]
+        h_dest = h[input_data.edge_index[1]]
+
+        # Apply the linear layers and the sigmoid gate before the reduction step
+        z = (self.sigmoid(self.linear['a'](e) + self.linear['b'](h_src) + self.linear['c'](h_dest)))*self.linear['d'](e)
 
        # the .batch attribute only maps to nodes
-        # to get a mapping to the edge -> graph
-        #
-        # imagine we had a graph with 25 nodes (N) and 30 edges (M)
-        # then edge_index[0, 0:M] will only contain values between 0 and 24 inclusive.
-        # thus, we will convert the 0 to 24 inclusive into a 0 to 29 inclusive
-        # since we have M entries
-        # then once we have a similar looking .batch for edges, we can do the
-        # scatter operation to get a per graph output
+        # to get a mapping to the edge -> graph we simply index
+        # into the batch to get the index of the graph each edge belongs to
+        batch_edge = input_data.batch[input_data.edge_index[0]]  # (num_edges,)
+
+        # now batch_edge will resemble input_data.batch but for edges instead of nodes
+        # apply the scatter operation to get a per graph output
        z = scatter(z, batch_edge, dim=0, reduce='sum')  # (num_graphs, embd)
 
-        boa = self.mlp_atom(z) # (num_graphs, vocab_size, max_atoms)
+        boa = self.atom_mlp(z)  # (num_graphs, vocab_size, max_atoms)
 
 
        ######################################################################################
-        #
+        # Decoding step
        ######################################################################################
 
+        # in the bond generation step, each edge gets the same initial feature vector
+        fc_bond_embedding = self.linear['u'](z)  # (num_graphs, embd)
 
-        # then we will index into the fc_bemb which has e.g. 32 graphs, and we want to apply fc_bemb[0] to
-        # the total number of edges in the first graph, fc_bemb[1] to total number of edges in second graph,
-        # etc... so this trick turns the node indexes into edge indexes
+        # same trick as before where we convert node mappings to edge mappings
        fc_batch_edge = input_data.batch[input_data.fc_edge_index[0]]
 
-        # for each edge per graph
+        # we index into our bond embeddings to get the bond embedding
+        # for each edge per graph, since fc_bond_embedding is batched
+        fc_edge_attr = fc_bond_embedding[fc_batch_edge]  # (num_fc_edges, embd)
 
+        # re-use the original atom embedding since the fc graph has the same nodes, just with different bonds
+        fc_atom_embedding = atom_embedding  # (num_atoms, embd)
+
+        # Apply GCN layers iteratively (decoding)
+        h = fc_atom_embedding
        e = fc_edge_attr
        for i in range(self.num_layers):
            data = Data(x=h, edge_index=input_data.fc_edge_index, edge_attr=e)
-            h = self.
-            e = self.
+            h = self.atom_decoder[i](data)
+            e = self.bond_decoder[i](data)
 
-        s = self.mlp_bond(e)
+        # now take each edge and apply the MLP to predict a bond type for each
+        s = self.bond_mlp(e)
 
        return boa, z, s
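A note on the edge -> graph pooling trick used in MoleGen.forward above: indexing .batch with the edge source nodes yields a per-edge graph id, which scatter can then reduce over. A minimal standalone sketch with toy numbers (not from the repo):

    import torch
    from torch_geometric.utils import scatter

    # two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
    batch = torch.tensor([0, 0, 0, 1, 1])              # node -> graph id
    edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])  # 3 edges (source row, destination row)
    e = torch.randn(3, 16)                             # per-edge features (num_edges, embd)

    batch_edge = batch[edge_index[0]]                  # edge -> graph id: tensor([0, 0, 1])
    z = scatter(e, batch_edge, dim=0, reduce='sum')    # per-graph sum, shape (2, 16)
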
molegen/train.py
CHANGED

@@ -5,8 +5,8 @@ import torch
 from . import data as prepare_data
 from .model import MoleGen
 
-
 class DataFrameDataset(Dataset):
+    """Wrapper class to use the PyTorch DataLoader"""
    def __init__(self, df, colname="data"):
        self.data = df[colname].values
 
@@ -17,32 +17,32 @@ class DataFrameDataset(Dataset):
        return self.data[idx]
 
 def main():
-    DEBUG = True
-
-    df, a2t, t2a, e2t, t2e, max_atoms = prepare_data.main()
-
-    torch.manual_seed(42) # reproducibility
-    dataset = DataFrameDataset(df, "data")
-    loader = DataLoader(dataset, batch_size=32, shuffle=True)
-
-    data = next(iter(loader))
-    print(data)
 
-    # import code; code.interact(local=locals())
-    # import pdb; pdb.set_trace()
-
    ###################################################
+    # Parameters
+    DEBUG = True
    embd = 16 # embedding size of a vocab index
-    vocab_size = len(a2t)
    num_layers = 1 # number of GCN layers
    lr = 0.001
    betas = (0.9, 0.999)
    eps = 1e-08
    epochs = 100
-    lambda_boa = 0
-    lambda_edge =
+    lambda_boa = 0.5
+    lambda_edge = 0.5
+    batch_size = 32
    ###################################################
 
+    torch.manual_seed(42)
+
+    df, a2t, t2a, e2t, t2e, max_atoms = prepare_data.main()
+    vocab_size = len(a2t)
+
+    dataset = DataFrameDataset(df, "data")
+    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+    data = next(iter(loader))
+    print(data)
+
    model = MoleGen(vocab_size=vocab_size, num_layers=num_layers, embd=embd, max_atoms=max_atoms)
 
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
@@ -57,9 +57,7 @@ def main():
        for idx, batch in enumerate(loader):
            boa, z, s = model(batch)
 
-
            if DEBUG:
-                # import code; code.interact(local=locals())
                boa_actual = batch.y_boa[0]
                boa_pred = torch.argmax(boa[0], dim=-1)
 
@@ -85,9 +83,7 @@ def main():
 
                if actual != 0 and predicted != 0:
                    print(f"[{t2a[src.item()]}][{t2a[dest.item()]}] ({actual}, {predicted})")
-
-                # import code; code.interact(local=locals())
-
+
            optimizer.zero_grad()
 
            # we have to permute because loss function expects (N, C, d1, d2, dK)
@@ -98,7 +94,6 @@ def main():
            # y_fc_edge_attr has shape (B,) with each value between 0 and C
            loss_edge = loss_fn(s, batch.y_fc_edge_attr.long())
 
-            # TODO add loss from other functions
            loss = lambda_boa*loss_boa + lambda_edge*loss_edge
 
            loss_avg += loss.item()
@@ -111,12 +106,5 @@ def main():
        print(f"Avg loss: {loss_avg/len(loader):.5f} | Epoch {epoch}")
        print()
 
-
-
-
-
-
-
 if __name__ == "__main__":
    main()