j-silv committed on
Commit 9467393 · 1 Parent(s): 54ba8ca

Fix bug missing sigmoid, refactor code, and add docstrings


I noticed that when I train this for around 60 epochs, we start to get
edge outputs that are not just 0 (empty). I'm wondering whether the
performance would improve significantly if I tried more hops.
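For reference, the missing sigmoid was in the edge gating that builds the graph latent 'z' (see molegen/model.py below). A minimal before/after sketch, assuming a, b, c, d are nn.Linear(embd, embd) layers and e, h_src, h_dest are per-edge feature tensors:

    import torch
    import torch.nn as nn

    embd = 16
    a, b, c, d = (nn.Linear(embd, embd) for _ in range(4))
    e = torch.randn(30, embd)      # edge embeddings (one row per edge)
    h_src = torch.randn(30, embd)  # source-node features gathered per edge
    h_dest = torch.randn(30, embd) # destination-node features gathered per edge

    # before: the gate term was unbounded, so it rescaled d(e) arbitrarily
    z_old = (a(e) + b(h_src) + c(h_dest)) * d(e)

    # after: sigmoid squashes the gate into (0, 1) before it multiplies d(e)
    z_new = torch.sigmoid(a(e) + b(h_src) + c(h_dest)) * d(e)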

Files changed (5)
  1. LICENSE +21 -0
  2. README.md +2 -2
  3. molegen/data.py +0 -3
  4. molegen/model.py +189 -119
  5. molegen/train.py +18 -30
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Justin Silver
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,5 +1,5 @@
  # Molecular generation
 
- This is a project to learn how generative models work and showcase a project for drug discovery.
 
- I am replicating the following paper: [A two-step graph convolutional decoder for molecular generation](https://arxiv.org/abs/1906.03412).

  # Molecular generation
 
+ I started this project to get hands-on experience with generative models and explore ML for drug discovery.
 
+ The code is a replication of the following paper: [A Two-Step Graph Convolutional Decoder for Molecule Generation](https://arxiv.org/abs/1906.03412) by Bresson and Laurent (2019).
molegen/data.py CHANGED
@@ -171,9 +171,6 @@ def map_rdkit_bond_types(rdkit_edge_types, verbose=False):
      return our_edge_types
 
 
- #########################################################################################################################
- #########################################################################################################################
-
  def main():
      tqdm.pandas()  # enable progress bars in pandas
      df = pd.read_csv("https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv", nrows=1000)

      return our_edge_types
 
 
  def main():
      tqdm.pandas()  # enable progress bars in pandas
      df = pd.read_csv("https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv", nrows=1000)
molegen/model.py CHANGED
@@ -1,40 +1,28 @@
  import torch.nn as nn
- from torch_geometric.data import Data, Batch
  from torch_geometric.nn import ResGatedGraphConv
  from torch_geometric.utils import scatter
- from torch_geometric.utils import dense_to_sparse
- import torch
 
- class MLPAtom(nn.Module):
-     """Multi-layer perceptron which gets soft bag of atoms
 
-     Implementation of Eq. 9 in paper
      """
      def __init__(self, embd=16, vocab_size=8, max_atoms=100):
          super().__init__()
 
          self.vocab_size = vocab_size
          self.max_atoms = max_atoms
 
-         # note that we have an input vector of size (num_graphs, embd)
-         # we need an output vector of size (vocab_size, R) where vocab_size
-         # is 'm' in the paper which means the different atoms we can select
-         # and R is basically a one-hot vector that goes up to the maximum number
-         # of total atoms in any one molecule in the training set
-         #
-         # (small optimization here is to not do the total number of atoms in training set
-         # but simply the highest number of any particular atom, i.e. take CO2:
-         # we would only need to have R == 2, since even though we have 3 atoms,
-         # the max number for any single atom is only 2, thus we don't need R == 3.
-         # i think they just simplified it in the paper)
-         #
-         # I'm not sure how to construct this though. I could do like a concatenation thing
-         # and have m linear layers with all different weights and then just concatenate the output
-         #
-         # ah okay the intuition here is that we are going to output a R*vocab_size vector,
-         # and then we will reshape it afterwards.
-
-         # not a lot of info on the structure of a simple MLP
          self.mlp = nn.Sequential(
              nn.Linear(embd, embd),
              nn.ReLU(),
@@ -42,15 +30,39 @@ class MLPAtom(nn.Module):
          )
 
      def forward(self, z):
          boa = self.mlp(z)
 
          return boa.view(-1, self.vocab_size, self.max_atoms)
 
 
- class MLPBond(nn.Module):
      """Multi-layer perceptron which predicts which bond for each edge
 
-     Implementation of Eq. 11 in paper
      """
      def __init__(self, embd=16, num_bonds=4):
          super().__init__()
@@ -64,14 +76,30 @@ class MLPBond(nn.Module):
          )
 
      def forward(self, e):
          bonds = self.mlp(e)
          return bonds
 
 
- class GCNAtomLayer(nn.Module):
-     """Intermediate class which performs ConvNet, BN, ReLU, and Residual for atoms
 
      Implementation of Eq. 4 in paper
      """
      def __init__(self, embd=16):
          super().__init__()
@@ -81,156 +109,198 @@ class GCNAtomLayer(nn.Module):
          self.relu = nn.ReLU()
 
      def forward(self, data):
          x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
          h = self.gcn(x, edge_index, edge_attr)
          h = self.bn(h)
          h = self.relu(h)
-         h = x + h # resid
          return h
 
- class GCNBondLayer(nn.Module):
-     """Intermediate class which performs a linear matrix multiply for bonds
 
      Implementation of Eq. 5 in paper
      """
 
      def __init__(self, embd=16):
          super().__init__()
-         self.v1 = nn.Linear(embd, embd)
-         self.v2 = nn.Linear(embd, embd)
-         self.v3 = nn.Linear(embd, embd)
          self.bn = nn.BatchNorm1d(embd)
          self.relu = nn.ReLU()
 
      def forward(self, data):
          x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
          h_src = x[edge_index[0]] # oh... apparently we get this for free?
          h_dest = x[edge_index[1]] # apparently pytorch geometric already handles large batch
 
-         e = self.v1(edge_attr) + self.v2(h_src) + self.v3(h_dest)
 
          e = self.bn(e)
          e = self.relu(e)
 
-         e = edge_attr + e # resid
          return e
 
  class MoleGen(nn.Module):
-     """Main model for the generative VAE"""
      def __init__(self, vocab_size=8, embd=16, num_layers=4, max_atoms=100, num_bonds=4):
          super().__init__()
          self.vocab_size = vocab_size
          self.embd = embd
          self.num_layers = num_layers
 
-         self.aembs = nn.Embedding(vocab_size, embd) # atom embeddings
-         self.bembs = nn.Embedding(vocab_size, embd) # bond embeddings
-
-         # GCN atom encoder
-         self.gcn_aenc = nn.ModuleList([GCNAtomLayer(embd) for i in range(num_layers)])
-
-         # GCN bond encoder
-         self.gcn_benc = nn.ModuleList([GCNBondLayer(embd) for i in range(num_layers)])
-
-         # GCN atom decoder
-         self.gcn_adec = nn.ModuleList([GCNAtomLayer(embd) for i in range(num_layers)])
-
-         # GCN bond decoder
-         self.gcn_bdec = nn.ModuleList([GCNBondLayer(embd) for i in range(num_layers)])
-
-         self.a = nn.Linear(embd, embd)
-         self.b = nn.Linear(embd, embd)
-         self.c = nn.Linear(embd, embd)
-         self.d = nn.Linear(embd, embd)
 
-         self.u = nn.Linear(embd, embd)
-
-         self.sig = nn.Sigmoid()
 
-         self.mlp_atom = MLPAtom(embd, vocab_size, max_atoms)
-         self.mlp_bond = MLPBond(embd, num_bonds)
-
      def forward(self, input_data):
-
-         x, edge_index, edge_attr = input_data.x, input_data.edge_index, input_data.edge_attr
 
 
-         # we look up the embeddings from our table
-         # data.x.shape = (num_atoms, 1)
-         # aemb.shape = (num_atoms, embedding_dim)
-         # data.edge_attr.shape = (num_bonds)
-         # bemb.shape = (num_bonds, embedding_dim)
-         aemb = self.aembs(x.view(-1)) # this is used later for the bond generation
-         bemb = self.bembs(edge_attr)
 
-         # we run GCN
-         h = aemb
-         e = bemb
 
-         # import pdb; pdb.set_trace()
          for i in range(self.num_layers):
-             data = Data(x=h, edge_index=edge_index, edge_attr=e)
-
-             h = self.gcn_aenc[i](data)
-             e = self.gcn_benc[i](data)
-
-             h_src = h[edge_index[0]]
-             h_dest = h[edge_index[1]]
 
-         z = (self.a(e) + self.b(h_src) + self.c(h_dest))*self.d(e)
 
          # the .batch attribute only maps to nodes
-         # to get a mapping to the edge -> graph (which is what we need)
-         # we simply index into the batch to get the indexes of which
-         # edge corresponds to which graph
-         batch_edge = input_data.batch[edge_index[0]]
-
-         # now batch_edge will be of shape (num_edges),
-         # and crucially it will resemble .batch but now have 0s for the first
-         # graph's edges, 1s for the second graph's edges, etc.
-         # this is because the edge_index is automatically incremented,
-         # imagine we had a graph with 25 nodes (N) and 30 edges (M)
-         # then edge_index[0, 0:M] will only contain values between 0 and 24 inclusive.
-         # thus, we will convert the 0 to 24 inclusive into a 0 to 29 inclusive
-         # since we have M entries
-         # then once we have a similar looking .batch for edges, we can do the
-         # scatter operation to get a per graph output
-         z = scatter(z, batch_edge, dim=0, reduce='sum') # (num_graphs, embedding_dim)
-
-         boa = self.mlp_atom(z) # (num_graphs, vocab_size, max_atoms)
 
 
          ######################################################################################
-         # VVVVVVVVVVVVV untested VVVVVVVVVVVVVVVVVVVVVV
          ######################################################################################
 
-         fc_bemb = self.u(z) # (num_graphs, embedding_dim) -> this needs to be applied to each edge_attr
 
-         # with this, we can create a new embedding attr matrix
-         # then we will index into the fc_bemb which has e.g. 32 graphs, and we want to apply fc_bemb[0] to
-         # the total number of edges in the first graph, fc_bemb[1] to total number of edges in second graph,
-         # etc... so this trick turns the node indexes into edge indexes
          fc_batch_edge = input_data.batch[input_data.fc_edge_index[0]]
 
-         # and then we simply index into our bond embeddings to get the bond embeddings
-         # for each edge per graph!
-         fc_edge_attr = fc_bemb[fc_batch_edge] # (num_fc_edges, embedding_dim)
 
-         # (num_atoms, embedding_dim)
-         fc_aemb = aemb # same as original because the fc graph still has the same nodes in it, just with different bonds
-
-         h = fc_aemb
          e = fc_edge_attr
          for i in range(self.num_layers):
              data = Data(x=h, edge_index=input_data.fc_edge_index, edge_attr=e)
-             h = self.gcn_adec[i](data)
-             e = self.gcn_bdec[i](data)
-
-         s = self.mlp_bond(e)
 
          return boa, z, s
  import torch.nn as nn
+ from torch_geometric.data import Data
  from torch_geometric.nn import ResGatedGraphConv
  from torch_geometric.utils import scatter
 
+ class AtomMLP(nn.Module):
+     """Multi-layer perceptron (MLP) which predicts the BOA for each graph
 
+     Implementation of Eq. 9 in paper. This MLP takes in a graph level embedding 'z'
+     and produces a BOA, which is essentially an integer count distribution over the predicted atoms
+
+     Args:
+         embd : embedding dimension for latent space
+         vocab_size : number of unique atoms in training set
+         max_atoms : max number of total atoms in any molecule
      """
      def __init__(self, embd=16, vocab_size=8, max_atoms=100):
+
          super().__init__()
 
          self.vocab_size = vocab_size
          self.max_atoms = max_atoms
 
+         # note there isn't much information in the paper on this MLP but I am assuming we
+         # use a nn.ReLU for the activation
          self.mlp = nn.Sequential(
              nn.Linear(embd, embd),
              nn.ReLU(),
 
          )
 
      def forward(self, z):
+         """Forward pass of Atom MLP
+
+         We compute an output vector of size (num_graphs, vocab_size, max_atoms) with
+         the last dimension essentially representing a one-hot vector that has a dimension size
+         up to the maximum number of atoms. We run the MLP with a flattened last dimension
+         (vocab_size*max_atoms) and then view it as a separate dimension for the loss calculation.
+
+         Note that a small optimization here is to use not the total number of atoms
+         in the training set but simply the highest count of any particular atom instead.
+
+         Args:
+             z : graph level embedding (num_graphs, embd)
+
+         Returns:
+             boa : graph level BOA (num_graphs, vocab_size, max_atoms)
+         """
          boa = self.mlp(z)
 
          return boa.view(-1, self.vocab_size, self.max_atoms)
 
 
+ class BondMLP(nn.Module):
      """Multi-layer perceptron which predicts which bond for each edge
 
+     Implementation of Eq. 11 in paper. This MLP takes in edge embeddings 'e'
+     and produces an integer distribution of the predicted bond types
+
+     Args:
+         embd : embedding dimension for latent space
+         num_bonds : number of different types of bonds for classification
+
      """
      def __init__(self, embd=16, num_bonds=4):
          super().__init__()
 
          )
 
      def forward(self, e):
+         """Forward pass of Bond MLP
+
+         We compute an output vector of size (num_edges, num_bonds) with
+         the last dimension representing a one-hot vector that has a dimension size
+         up to the different types of bonds.
+
+         Args:
+             e : edge embedding (num_edges, embd)
+
+         Returns:
+             bonds : bond type prediction for each edge (num_edges, num_bonds)
+         """
          bonds = self.mlp(e)
          return bonds
 
 
+ class AtomGCNLayer(nn.Module):
+     """Updates atom embeddings with Conv->BN->ReLU->Residual
 
      Implementation of Eq. 4 in paper
+
+     Args:
+         embd : embedding dimension for latent space
+
      """
      def __init__(self, embd=16):
          super().__init__()
 
          self.relu = nn.ReLU()
 
      def forward(self, data):
+         """Perform forward pass for atom GCN
+
+         Args:
+             data : PyG Data() object (possibly batched) with the following attributes:
+                 x : atom token, e.g. x[0] is the embedding vector for atom/node 0 (num_nodes, embd)
+                 edge_index : bond connectivity, e.g. atom at edge_index[0, 0] connects to edge_index[1, 0] (2, num_edges)
+                 edge_attr : bond token, e.g. edge_attr[0] is the embedding vector for bond/edge 0 (num_edges, embd)
+
+         Returns:
+             h : Updated node embedding (num_nodes, embd)
+
+         """
          x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
          h = self.gcn(x, edge_index, edge_attr)
          h = self.bn(h)
          h = self.relu(h)
+         h = x + h # residual pathway
          return h
 
+ class BondGCNLayer(nn.Module):
+     """Updates bond embeddings with Linear->BN->ReLU->Residual
 
      Implementation of Eq. 5 in paper
+
+     Args:
+         embd : embedding dimension for latent space
+
      """
 
      def __init__(self, embd=16):
          super().__init__()
+         self.v = nn.ModuleList([nn.Linear(embd, embd) for _ in range(3)])
          self.bn = nn.BatchNorm1d(embd)
          self.relu = nn.ReLU()
 
      def forward(self, data):
+         """Perform forward pass for bond GCN
+
+         Args:
+             data : PyG Data() object (possibly batched) with the following attributes:
+                 x : atom token, e.g. x[0] is the embedding vector for atom/node 0 (num_nodes, embd)
+                 edge_index : bond connectivity, e.g. atom at edge_index[0, 0] connects to edge_index[1, 0] (2, num_edges)
+                 edge_attr : bond token, e.g. edge_attr[0] is the embedding vector for bond/edge 0 (num_edges, embd)
+
+         Returns:
+             e : Updated edge embedding (num_edges, embd)
+
+         """
          x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
 
          h_src = x[edge_index[0]] # oh... apparently we get this for free?
          h_dest = x[edge_index[1]] # apparently pytorch geometric already handles large batch
 
+         e = self.v[0](edge_attr) + self.v[1](h_src) + self.v[2](h_dest)
 
          e = self.bn(e)
          e = self.relu(e)
 
+         e = edge_attr + e # residual pathway
          return e
 
  class MoleGen(nn.Module):
+     """Main model for VAE molecular generation
+
+     This model is an implementation of the paper:
+     A Two-Step Graph Convolutional Decoder for Molecule Generation
+     by Bresson and Laurent (2019). A molecule graph is first encoded
+     into a latent representation 'z', which is used to produce a
+     'Bag of Atoms' (BOA). This BOA tells us how many of each atom
+     we have in the predicted molecule, ignoring connectivity.
+     The second stage takes 'z' and the original input formula and
+     decodes the edge feature connectivity over a fully connected graph.
+
+     Args:
+         vocab_size : number of unique atoms in training set
+         embd : embedding dimension for latent space
+         num_layers : number of GNN layers (message passing/k-hop distance)
+         max_atoms : max number of total atoms in any molecule
+         num_bonds : number of different types of bonds for classification
+     """
      def __init__(self, vocab_size=8, embd=16, num_layers=4, max_atoms=100, num_bonds=4):
+
          super().__init__()
          self.vocab_size = vocab_size
          self.embd = embd
          self.num_layers = num_layers
 
+         self.atom_embeddings = nn.Embedding(vocab_size, embd)
+         self.bond_embeddings = nn.Embedding(vocab_size, embd)
+         self.atom_encoder = nn.ModuleList([AtomGCNLayer(embd) for _ in range(num_layers)])
+         self.bond_encoder = nn.ModuleList([BondGCNLayer(embd) for _ in range(num_layers)])
+         self.atom_decoder = nn.ModuleList([AtomGCNLayer(embd) for _ in range(num_layers)])
+         self.bond_decoder = nn.ModuleList([BondGCNLayer(embd) for _ in range(num_layers)])
+
+         self.linear = nn.ModuleDict(dict(
+             a=nn.Linear(embd, embd),
+             b=nn.Linear(embd, embd),
+             c=nn.Linear(embd, embd),
+             d=nn.Linear(embd, embd),
+             u=nn.Linear(embd, embd)
+         ))
+
+         self.sigmoid = nn.Sigmoid()
 
+         self.atom_mlp = AtomMLP(embd, vocab_size, max_atoms)
+         self.bond_mlp = BondMLP(embd, num_bonds)
 
      def forward(self, input_data):
+         """Forward pass for MoleGen
+
+         First we look up the atom and bond embeddings, and then apply GCN layers iteratively.
+         Afterwards we reduce the node and edge embeddings into a single graph latent vector 'z'.
+         We apply an MLP to 'z' to produce a BOA. Then we use the original 'x' tokens and the 'z' vector
+         to produce edge features over a fully connected graph, to which we again apply an MLP to get 's',
+         the predicted bond tokens for each edge of that graph.
+
+         Args:
+             input_data : PyG Data() object (possibly batched) with the following attributes:
+                 x : atom token, e.g. x[0] is the integer token for atom/node 0 (num_nodes, 1)
+                 edge_index : bond connectivity, e.g. atom at edge_index[0, 0] connects to edge_index[1, 0] (2, num_edges)
+                 edge_attr : bond token, e.g. edge_attr[0] is the integer token for bond/edge 0 (num_edges, )
+
+         Returns:
+             boa : Bag of Atoms prediction (num_graphs, vocab_size, max_atoms)
+             z : Per graph latent representation (num_graphs, embd)
+             s : Bond type prediction for each fully connected edge (num_fc_edges, num_bonds)
+         """
 
 
+         ######################################################################################
+         # Encoding step
+         ######################################################################################
 
+         # Get embeddings
+         atom_embedding = self.atom_embeddings(input_data.x.view(-1)) # indexing into embedding needs a flat vector
+         bond_embedding = self.bond_embeddings(input_data.edge_attr)
 
+         # Apply GCN layers iteratively (encoding)
+         h = atom_embedding
+         e = bond_embedding
          for i in range(self.num_layers):
+             data = Data(x=h, edge_index=input_data.edge_index, edge_attr=e)
+             h = self.atom_encoder[i](data)
+             e = self.bond_encoder[i](data)
 
+         # Extract source and destination atom features
+         h_src = h[input_data.edge_index[0]]
+         h_dest = h[input_data.edge_index[1]]
 
+         # Apply linear and activation before reduction step
+         z = self.sigmoid(self.linear['a'](e) + self.linear['b'](h_src) + self.linear['c'](h_dest))*self.linear['d'](e)
 
          # the .batch attribute only maps to nodes
+         # to get a mapping to the edge -> graph we simply index
+         # into the batch to get the indexes of which edge corresponds to which graph
+         batch_edge = input_data.batch[input_data.edge_index[0]] # (num_edges, )
+
+         # now batch_edge will resemble input_data.batch but for edges instead of nodes
+         # apply scatter operation to get a per graph output
+         z = scatter(z, batch_edge, dim=0, reduce='sum') # (num_graphs, embd)
+
+         boa = self.atom_mlp(z) # (num_graphs, vocab_size, max_atoms)
 
 
          ######################################################################################
+         # Decoding step
          ######################################################################################
 
+         # in the bond generation step, each edge gets the same initial feature vector
+         fc_bond_embedding = self.linear['u'](z) # (num_graphs, embd)
 
+         # same trick as before where we convert node mappings to edge mappings
          fc_batch_edge = input_data.batch[input_data.fc_edge_index[0]]
 
+         # we index into our bond embeddings to get the bond embeddings
+         # for each edge per graph, since fc_bond_embedding is batched
+         fc_edge_attr = fc_bond_embedding[fc_batch_edge] # (num_fc_edges, embd)
 
+         # re-use the original atom embedding since the fc graph has the same nodes, just with different bonds
+         fc_atom_embedding = atom_embedding # (num_atoms, embd)
+
+         # Apply GCN layers iteratively (decoding)
+         h = fc_atom_embedding
          e = fc_edge_attr
          for i in range(self.num_layers):
              data = Data(x=h, edge_index=input_data.fc_edge_index, edge_attr=e)
+             h = self.atom_decoder[i](data)
+             e = self.bond_decoder[i](data)
 
+         # now take each edge and apply the MLP to predict its bond type
+         s = self.bond_mlp(e)
 
          return boa, z, s
 
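Side note: the batch_edge trick in forward() is easy to sanity-check in isolation. A standalone sketch with toy numbers (not from the repo) showing how indexing .batch with edge_index[0] yields a per-edge graph id that scatter can reduce over:

    import torch
    from torch_geometric.utils import scatter

    # two toy graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
    batch = torch.tensor([0, 0, 0, 1, 1])

    # PyG offsets node ids across the batch, so graph 1's edges use nodes 3-4
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 0, 4]])

    # per-edge graph id: look up the graph of each edge's source node
    batch_edge = batch[edge_index[0]]  # tensor([0, 0, 0, 1])

    # sum per-edge features into one vector per graph
    z_edges = torch.ones(4, 2)         # 4 edges, embd = 2
    z = scatter(z_edges, batch_edge, dim=0, reduce='sum')
    print(z)                           # [[3., 3.], [1., 1.]]: graph 0 has 3 edges, graph 1 has 1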
molegen/train.py CHANGED
@@ -5,8 +5,8 @@ import torch
  from . import data as prepare_data
  from .model import MoleGen
 
-
  class DataFrameDataset(Dataset):
      def __init__(self, df, colname="data"):
          self.data = df[colname].values
 
@@ -17,32 +17,32 @@ class DataFrameDataset(Dataset):
          return self.data[idx]
 
  def main():
-     DEBUG = True
-
-     df, a2t, t2a, e2t, t2e, max_atoms = prepare_data.main()
-
-     torch.manual_seed(42) # reproducibility
-     dataset = DataFrameDataset(df, "data")
-     loader = DataLoader(dataset, batch_size=32, shuffle=True)
-
-     data = next(iter(loader))
-     print(data)
 
-     # import code; code.interact(local=locals())
-     # import pdb; pdb.set_trace()
-
      ###################################################
      embd = 16 # embedding size of a vocab index
-     vocab_size = len(a2t)
      num_layers = 1 # number of GCN layers
      lr = 0.001
      betas = (0.9, 0.999)
      eps = 1e-08
      epochs = 100
-     lambda_boa = 0
-     lambda_edge = 1
      ###################################################
 
      model = MoleGen(vocab_size=vocab_size, num_layers=num_layers, embd=embd, max_atoms=max_atoms)
 
      optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
@@ -57,9 +57,7 @@ def main():
          for idx, batch in enumerate(loader):
              boa, z, s = model(batch)
 
-
              if DEBUG:
-                 # import code; code.interact(local=locals())
                  boa_actual = batch.y_boa[0]
                  boa_pred = torch.argmax(boa[0], dim=-1)
@@ -85,9 +83,7 @@ def main():
 
                  if actual != 0 and predicted != 0:
                      print(f"[{t2a[src.item()]}][{t2a[dest.item()]}] ({actual}, {predicted})")
-
-             # import code; code.interact(local=locals())
-
              optimizer.zero_grad()
 
              # we have to permute because loss function expects (N, C, d1, d2, dK)
@@ -98,7 +94,6 @@ def main():
              # y_fc_edge_attr has shape (B,) with each value between 0 and C
              loss_edge = loss_fn(s, batch.y_fc_edge_attr.long())
 
-             # TODO add loss from other functions
              loss = lambda_boa*loss_boa + lambda_edge*loss_edge
 
              loss_avg += loss.item()
@@ -111,12 +106,5 @@ def main():
          print(f"Avg loss: {loss_avg/len(loader):.5f} | Epoch {epoch}")
          print()
 
-
-
-
-
-
-
-
  if __name__ == "__main__":
      main()
  from . import data as prepare_data
  from .model import MoleGen
 
  class DataFrameDataset(Dataset):
+     """Wrapper class to use the PyTorch DataLoader"""
      def __init__(self, df, colname="data"):
          self.data = df[colname].values
 
          return self.data[idx]
 
  def main():
 
      ###################################################
+     # Parameters
+     DEBUG = True
      embd = 16 # embedding size of a vocab index
      num_layers = 1 # number of GCN layers
      lr = 0.001
      betas = (0.9, 0.999)
      eps = 1e-08
      epochs = 100
+     lambda_boa = 0.5
+     lambda_edge = 0.5
+     batch_size = 32
      ###################################################
 
+     torch.manual_seed(42)
+
+     df, a2t, t2a, e2t, t2e, max_atoms = prepare_data.main()
+     vocab_size = len(a2t)
+
+     dataset = DataFrameDataset(df, "data")
+     loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+     data = next(iter(loader))
+     print(data)
+
      model = MoleGen(vocab_size=vocab_size, num_layers=num_layers, embd=embd, max_atoms=max_atoms)
 
      optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps)
 
          for idx, batch in enumerate(loader):
              boa, z, s = model(batch)
 
              if DEBUG:
                  boa_actual = batch.y_boa[0]
                  boa_pred = torch.argmax(boa[0], dim=-1)
 
                  if actual != 0 and predicted != 0:
                      print(f"[{t2a[src.item()]}][{t2a[dest.item()]}] ({actual}, {predicted})")
+
              optimizer.zero_grad()
 
              # we have to permute because loss function expects (N, C, d1, d2, dK)
 
              # y_fc_edge_attr has shape (B,) with each value between 0 and C
              loss_edge = loss_fn(s, batch.y_fc_edge_attr.long())
 
              loss = lambda_boa*loss_boa + lambda_edge*loss_edge
 
              loss_avg += loss.item()
 
          print(f"Avg loss: {loss_avg/len(loader):.5f} | Epoch {epoch}")
          print()
 
  if __name__ == "__main__":
      main()
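Side note on the permute comment in the loss: nn.CrossEntropyLoss wants the class dimension second, and for the BOA the 'classes' are the possible atom counts (0 to max_atoms-1), one classification per vocabulary entry. A toy sketch with made-up shapes (the repo's actual call may differ slightly):

    import torch
    import torch.nn as nn

    num_graphs, vocab_size, max_atoms = 4, 8, 100
    loss_fn = nn.CrossEntropyLoss()

    boa = torch.randn(num_graphs, vocab_size, max_atoms)           # model output
    y_boa = torch.randint(0, max_atoms, (num_graphs, vocab_size))  # true count of each atom type

    # CrossEntropyLoss expects (N, C, d1, ...): move max_atoms (the classes) to dim 1
    loss_boa = loss_fn(boa.permute(0, 2, 1), y_boa)
    print(loss_boa.item())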