AlistairXnigo commited on
Commit
d2e0a38
·
verified ·
1 Parent(s): 53e4ac8

Upload mpnn_pom.py

Browse files
Files changed (1) hide show
  1. mpnn_pom.py +601 -0
mpnn_pom.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List, Tuple, Union, Optional, Callable, Dict
5
+
6
+ from deepchem.models.losses import Loss, L2Loss
7
+ from deepchem.models.torch_models.torch_model import TorchModel
8
+ from deepchem.models.optimizers import Optimizer, LearningRateSchedule
9
+
10
+ from openpom.layers.pom_ffn import CustomPositionwiseFeedForward
11
+ from openpom.utils.loss import CustomMultiLabelLoss
12
+ from openpom.utils.optimizer import get_optimizer
13
+
14
+ try:
15
+ import dgl
16
+ from dgl import DGLGraph
17
+ from dgl.nn.pytorch import Set2Set
18
+ from openpom.layers.pom_mpnn_gnn import CustomMPNNGNN
19
+ except (ImportError, ModuleNotFoundError):
20
+ raise ImportError('This module requires dgl and dgllife')
21
+
22
+
23
+ class MPNNPOM(nn.Module):
24
+ """
25
+ MPNN model computes a principal odor map
26
+ using multilabel-classification based on the pre-print:
27
+ "A Principal Odor Map Unifies DiverseTasks in Human
28
+ Olfactory Perception" [1]
29
+
30
+ This model proceeds as follows:
31
+
32
+ * Combine latest node representations and edge features in
33
+ updating node representations, which involves multiple
34
+ rounds of message passing.
35
+ * For each graph, compute its representation by radius 0 combination
36
+ to fold atom and bond embeddings together, followed by
37
+ 'set2set' or 'global_sum_pooling' readout.
38
+ * Perform the final prediction using a feed-forward layer.
39
+
40
+ References
41
+ ----------
42
+ .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
43
+ Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
44
+ Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
45
+ Joel D. Mainland, Alexander B. Wiltschko
46
+ `A Principal Odor Map Unifies Diverse Tasks
47
+ in Human Olfactory Perception preprint
48
+ <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.
49
+
50
+ .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
51
+ Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko
52
+ `Machine Learning for Scent:
53
+ Learning Generalizable Perceptual Representations
54
+ of Small Molecules <https://arxiv.org/abs/1910.10685>`_.
55
+
56
+ .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
57
+ Oriol Vinyals, George E. Dahl.
58
+ "Neural Message Passing for Quantum Chemistry." ICML 2017.
59
+
60
+ Notes
61
+ -----
62
+ This class requires DGL (https://github.com/dmlc/dgl)
63
+ and DGL-LifeSci (https://github.com/awslabs/dgl-lifesci)
64
+ to be installed.
65
+ """
66
+
67
+ def __init__(self,
68
+ n_tasks: int,
69
+ node_out_feats: int = 64,
70
+ edge_hidden_feats: int = 128,
71
+ edge_out_feats: int = 64,
72
+ num_step_message_passing: int = 3,
73
+ mpnn_residual: bool = True,
74
+ message_aggregator_type: str = 'sum',
75
+ mode: str = 'classification',
76
+ number_atom_features: int = 134,
77
+ number_bond_features: int = 6,
78
+ n_classes: int = 1,
79
+ nfeat_name: str = 'x',
80
+ efeat_name: str = 'edge_attr',
81
+ readout_type: str = 'set2set',
82
+ num_step_set2set: int = 6,
83
+ num_layer_set2set: int = 3,
84
+ ffn_hidden_list: List = [300],
85
+ ffn_embeddings: int = 256,
86
+ ffn_activation: str = 'relu',
87
+ ffn_dropout_p: float = 0.0,
88
+ ffn_dropout_at_input_no_act: bool = True):
89
+ """
90
+ Parameters
91
+ ----------
92
+ n_tasks: int
93
+ Number of tasks.
94
+ node_out_feats: int
95
+ The length of the final node representation vectors
96
+ before readout. Default to 64.
97
+ edge_hidden_feats: int
98
+ The length of the hidden edge representation vectors
99
+ for mpnn edge network. Default to 128.
100
+ edge_out_feats: int
101
+ The length of the final edge representation vectors
102
+ before readout. Default to 64.
103
+ num_step_message_passing: int
104
+ The number of rounds of message passing. Default to 3.
105
+ mpnn_residual: bool
106
+ If true, adds residual layer to mpnn layer. Default to True.
107
+ message_aggregator_type: str
108
+ MPNN message aggregator type, 'sum', 'mean' or 'max'.
109
+ Default to 'sum'.
110
+ mode: str
111
+ The model type, 'classification' or 'regression'.
112
+ Default to 'classification'.
113
+ number_atom_features: int
114
+ The length of the initial atom feature vectors. Default to 134.
115
+ number_bond_features: int
116
+ The length of the initial bond feature vectors. Default to 6.
117
+ n_classes: int
118
+ The number of classes to predict per task
119
+ (only used when ``mode`` is 'classification'). Default to 1.
120
+ nfeat_name: str
121
+ For an input graph ``g``, the model assumes that it stores
122
+ node features in ``g.ndata[nfeat_name]`` and will retrieve
123
+ input node features from that. Default to 'x'.
124
+ efeat_name: str
125
+ For an input graph ``g``, the model assumes that it stores
126
+ edge features in ``g.edata[efeat_name]`` and will retrieve
127
+ input edge features from that. Default to 'edge_attr'.
128
+ readout_type: str
129
+ The Readout type, 'set2set' or 'global_sum_pooling'.
130
+ Default to 'set2set'.
131
+ num_step_set2set: int
132
+ Number of steps in set2set readout.
133
+ Used if, readout_type == 'set2set'.
134
+ Default to 6.
135
+ num_layer_set2set: int
136
+ Number of layers in set2set readout.
137
+ Used if, readout_type == 'set2set'.
138
+ Default to 3.
139
+ ffn_hidden_list: List
140
+ List of sizes of hidden layer in the feed-forward network layer.
141
+ Default to [300].
142
+ ffn_embeddings: int
143
+ Size of penultimate layer in the feed-forward network layer.
144
+ This determines the Principal Odor Map dimension.
145
+ Default to 256.
146
+ ffn_activation: str
147
+ Activation function to be used in feed-forward network layer.
148
+ Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
149
+ 'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
150
+ and 'elu' for ELU.
151
+ ffn_dropout_p: float
152
+ Dropout probability for the feed-forward network layer.
153
+ Default to 0.0
154
+ ffn_dropout_at_input_no_act: bool
155
+ If true, dropout is applied on the input tensor.
156
+ For single layer, it is not passed to an activation function.
157
+ """
158
+ if mode not in ['classification', 'regression']:
159
+ raise ValueError(
160
+ "mode must be either 'classification' or 'regression'")
161
+
162
+ super(MPNNPOM, self).__init__()
163
+
164
+ self.n_tasks: int = n_tasks
165
+ self.mode: str = mode
166
+ self.n_classes: int = n_classes
167
+ self.nfeat_name: str = nfeat_name
168
+ self.efeat_name: str = efeat_name
169
+ self.readout_type: str = readout_type
170
+ self.ffn_embeddings: int = ffn_embeddings
171
+ self.ffn_activation: str = ffn_activation
172
+ self.ffn_dropout_p: float = ffn_dropout_p
173
+
174
+ if mode == 'classification':
175
+ self.ffn_output: int = n_tasks * n_classes
176
+ else:
177
+ self.ffn_output = n_tasks
178
+
179
+ self.mpnn: nn.Module = CustomMPNNGNN(
180
+ node_in_feats=number_atom_features,
181
+ node_out_feats=node_out_feats,
182
+ edge_in_feats=number_bond_features,
183
+ edge_hidden_feats=edge_hidden_feats,
184
+ num_step_message_passing=num_step_message_passing,
185
+ residual=mpnn_residual,
186
+ message_aggregator_type=message_aggregator_type)
187
+
188
+ self.project_edge_feats: nn.Module = nn.Sequential(
189
+ nn.Linear(number_bond_features, edge_out_feats), nn.ReLU())
190
+
191
+ if self.readout_type == 'set2set':
192
+ self.readout_set2set: nn.Module = Set2Set(
193
+ input_dim=node_out_feats + edge_out_feats,
194
+ n_iters=num_step_set2set,
195
+ n_layers=num_layer_set2set)
196
+ ffn_input: int = 2 * (node_out_feats + edge_out_feats)
197
+ elif self.readout_type == 'global_sum_pooling':
198
+ ffn_input = node_out_feats + edge_out_feats
199
+ else:
200
+ raise Exception("readout_type invalid")
201
+
202
+ if ffn_embeddings is not None:
203
+ d_hidden_list: List = ffn_hidden_list + [ffn_embeddings]
204
+
205
+ self.ffn: nn.Module = CustomPositionwiseFeedForward(
206
+ d_input=ffn_input,
207
+ d_hidden_list=d_hidden_list,
208
+ d_output=self.ffn_output,
209
+ activation=ffn_activation,
210
+ dropout_p=ffn_dropout_p,
211
+ dropout_at_input_no_act=ffn_dropout_at_input_no_act)
212
+
213
+ def _readout(self, g: DGLGraph, node_encodings: torch.Tensor,
214
+ edge_feats: torch.Tensor) -> torch.Tensor:
215
+ """
216
+ Method to execute the readout phase.
217
+ (compute molecules encodings from atom hidden states)
218
+
219
+ Readout phase consists of radius 0 combination to fold atom
220
+ and bond embeddings together,
221
+ followed by:
222
+ - a reduce-sum across atoms
223
+ if `self.readout_type == 'global_sum_pooling'`
224
+ - set2set pooling
225
+ if `self.readout_type == 'set2set'`
226
+
227
+ Parameters
228
+ ----------
229
+ g: DGLGraph
230
+ A DGLGraph for a batch of graphs.
231
+ It stores the node features in
232
+ ``dgl_graph.ndata[self.nfeat_name]`` and edge features in
233
+ ``dgl_graph.edata[self.efeat_name]``.
234
+
235
+ node_encodings: torch.Tensor
236
+ Tensor containing node hidden states.
237
+
238
+ edge_feats: torch.Tensor
239
+ Tensor containing edge features.
240
+
241
+ Returns
242
+ -------
243
+ batch_mol_hidden_states: torch.Tensor
244
+ Tensor containing batchwise molecule encodings.
245
+ """
246
+
247
+ g.ndata['node_emb'] = node_encodings
248
+ g.edata['edge_emb'] = self.project_edge_feats(edge_feats)
249
+
250
+ def message_func(edges) -> Dict:
251
+ """
252
+ The message function to generate messages
253
+ along the edges for DGLGraph.send_and_recv()
254
+ """
255
+ src_msg: torch.Tensor = torch.cat(
256
+ (edges.src['node_emb'], edges.data['edge_emb']), dim=1)
257
+ return {'src_msg': src_msg}
258
+
259
+ def reduce_func(nodes) -> Dict:
260
+ """
261
+ The reduce function to aggregate the messages
262
+ for DGLGraph.send_and_recv()
263
+ """
264
+ src_msg_sum: torch.Tensor = torch.sum(nodes.mailbox['src_msg'],
265
+ dim=1)
266
+ return {'src_msg_sum': src_msg_sum}
267
+
268
+ # radius 0 combination to fold atom and bond embeddings together
269
+ g.send_and_recv(g.edges(),
270
+ message_func=message_func,
271
+ reduce_func=reduce_func)
272
+
273
+ if self.readout_type == 'set2set':
274
+ batch_mol_hidden_states: torch.Tensor = self.readout_set2set(
275
+ g, g.ndata['src_msg_sum'])
276
+ elif self.readout_type == 'global_sum_pooling':
277
+ batch_mol_hidden_states = dgl.sum_nodes(g, 'src_msg_sum')
278
+
279
+ # batch_size x (node_out_feats + edge_out_feats)
280
+ return batch_mol_hidden_states
281
+
282
+ def forward(
283
+ self, g: DGLGraph
284
+ ) -> Union[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
285
+ """
286
+ Foward pass for MPNNPOM class. It also returns embeddings for POM.
287
+
288
+ Parameters
289
+ ----------
290
+ g: DGLGraph
291
+ A DGLGraph for a batch of graphs. It stores the node features in
292
+ ``dgl_graph.ndata[self.nfeat_name]`` and edge features in
293
+ ``dgl_graph.edata[self.efeat_name]``.
294
+
295
+ Returns
296
+ -------
297
+ Union[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
298
+ The model output.
299
+
300
+ * When self.mode = 'regression',
301
+ its shape will be ``(dgl_graph.batch_size, self.n_tasks)``.
302
+ * When self.mode = 'classification',
303
+ the output consists of probabilities for classes.
304
+ Its shape will be
305
+ ``(dgl_graph.batch_size, self.n_tasks, self.n_classes)``
306
+ if self.n_tasks > 1;
307
+ its shape will be ``(dgl_graph.batch_size, self.n_classes)``
308
+ if self.n_tasks is 1.
309
+ """
310
+ node_feats: torch.Tensor = g.ndata[self.nfeat_name]
311
+ edge_feats: torch.Tensor = g.edata[self.efeat_name]
312
+
313
+ node_encodings: torch.Tensor = self.mpnn(g, node_feats, edge_feats)
314
+
315
+ molecular_encodings: torch.Tensor = self._readout(
316
+ g, node_encodings, edge_feats)
317
+ if self.readout_type == 'global_sum_pooling':
318
+ molecular_encodings = F.softmax(molecular_encodings, dim=1)
319
+
320
+ embeddings: torch.Tensor
321
+ out: torch.Tensor
322
+ embeddings, out = self.ffn(molecular_encodings)
323
+
324
+ if self.mode == 'classification':
325
+ if self.n_tasks == 1:
326
+ logits: torch.Tensor = out.view(-1, self.n_classes)
327
+ else:
328
+ logits = out.view(-1, self.n_tasks, self.n_classes)
329
+ proba: torch.Tensor = F.sigmoid(
330
+ logits) # (batch, n_tasks, classes)
331
+ if self.n_classes == 1:
332
+ proba = proba.squeeze(-1) # (batch, n_tasks)
333
+ return proba, logits, embeddings
334
+ else:
335
+ return out
336
+
337
+
338
+ class MPNNPOMModel(TorchModel):
339
+ """
340
+ MPNNPOMModel for obtaining a principal odor map
341
+ using multilabel-classification based on the pre-print:
342
+ "A Principal Odor Map Unifies DiverseTasks in Human
343
+ Olfactory Perception" [1]
344
+
345
+ * Combine latest node representations and edge features in
346
+ updating node representations, which involves multiple
347
+ rounds of message passing.
348
+ * For each graph, compute its representation by radius 0 combination
349
+ to fold atom and bond embeddings together, followed by
350
+ 'set2set' or 'global_sum_pooling' readout.
351
+ * Perform the final prediction using a feed-forward layer.
352
+
353
+ References
354
+ ----------
355
+ .. [1] Brian K. Lee, Emily J. Mayhew, Benjamin Sanchez-Lengeling,
356
+ Jennifer N. Wei, Wesley W. Qian, Kelsie Little, Matthew Andres,
357
+ Britney B. Nguyen, Theresa Moloy, Jane K. Parker, Richard C. Gerkin,
358
+ Joel D. Mainland, Alexander B. Wiltschko
359
+ `A Principal Odor Map Unifies Diverse Tasks
360
+ in Human Olfactory Perception preprint
361
+ <https://www.biorxiv.org/content/10.1101/2022.09.01.504602v4>`_.
362
+
363
+ .. [2] Benjamin Sanchez-Lengeling, Jennifer N. Wei, Brian K. Lee,
364
+ Richard C. Gerkin, Alán Aspuru-Guzik, Alexander B. Wiltschko
365
+ `Machine Learning for Scent:
366
+ Learning Generalizable Perceptual Representations
367
+ of Small Molecules <https://arxiv.org/abs/1910.10685>`_.
368
+
369
+ .. [3] Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley,
370
+ Oriol Vinyals, George E. Dahl.
371
+ "Neural Message Passing for Quantum Chemistry." ICML 2017.
372
+
373
+ Notes
374
+ -----
375
+ This class requires DGL (https://github.com/dmlc/dgl) and DGL-LifeSci
376
+ (https://github.com/awslabs/dgl-lifesci) to be installed.
377
+
378
+ The featurizer used with MPNNPOMModel must produce a Deepchem GraphData
379
+ object which should have both 'edge' and 'node' features.
380
+ """
381
+
382
+ def __init__(self,
383
+ n_tasks: int,
384
+ class_imbalance_ratio: Optional[List] = None,
385
+ loss_aggr_type: str = 'sum',
386
+ learning_rate: Union[float, LearningRateSchedule] = 0.001,
387
+ batch_size: int = 100,
388
+ node_out_feats: int = 64,
389
+ edge_hidden_feats: int = 128,
390
+ edge_out_feats: int = 64,
391
+ num_step_message_passing: int = 3,
392
+ mpnn_residual: bool = True,
393
+ message_aggregator_type: str = 'sum',
394
+ mode: str = 'regression',
395
+ number_atom_features: int = 134,
396
+ number_bond_features: int = 6,
397
+ n_classes: int = 1,
398
+ readout_type: str = 'set2set',
399
+ num_step_set2set: int = 6,
400
+ num_layer_set2set: int = 3,
401
+ ffn_hidden_list: List = [300],
402
+ ffn_embeddings: int = 256,
403
+ ffn_activation: str = 'relu',
404
+ ffn_dropout_p: float = 0.0,
405
+ ffn_dropout_at_input_no_act: bool = True,
406
+ weight_decay: float = 1e-5,
407
+ self_loop: bool = False,
408
+ optimizer_name: str = 'adam',
409
+ device_name: Optional[str] = None,
410
+ **kwargs):
411
+ """
412
+ Parameters
413
+ ----------
414
+ n_tasks: int
415
+ Number of tasks.
416
+ class_imbalance_ratio: Optional[List]
417
+ List of imbalance ratios per task.
418
+ loss_aggr_type: str
419
+ loss aggregation type; 'sum' or 'mean'. Default to 'sum'.
420
+ Only applies to CustomMultiLabelLoss for classification
421
+ learning_rate: Union[float, LearningRateSchedule]
422
+ Learning rate value or scheduler object. Default to 0.001.
423
+ batch_size: int
424
+ Batch size for training. Default to 100.
425
+ node_out_feats: int
426
+ The length of the final node representation vectors
427
+ before readout. Default to 64.
428
+ edge_hidden_feats: int
429
+ The length of the hidden edge representation vectors
430
+ for mpnn edge network. Default to 128.
431
+ edge_out_feats: int
432
+ The length of the final edge representation vectors
433
+ before readout. Default to 64.
434
+ num_step_message_passing: int
435
+ The number of rounds of message passing. Default to 3.
436
+ mpnn_residual: bool
437
+ If true, adds residual layer to mpnn layer. Default to True.
438
+ message_aggregator_type: str
439
+ MPNN message aggregator type, 'sum', 'mean' or 'max'.
440
+ Default to 'sum'.
441
+ mode: str
442
+ The model type, 'classification' or 'regression'.
443
+ Default to 'classification'.
444
+ number_atom_features: int
445
+ The length of the initial atom feature vectors. Default to 134.
446
+ number_bond_features: int
447
+ The length of the initial bond feature vectors. Default to 6.
448
+ n_classes: int
449
+ The number of classes to predict per task
450
+ (only used when ``mode`` is 'classification'). Default to 1.
451
+ readout_type: str
452
+ The Readout type, 'set2set' or 'global_sum_pooling'.
453
+ Default to 'set2set'.
454
+ num_step_set2set: int
455
+ Number of steps in set2set readout.
456
+ Used if, readout_type == 'set2set'.
457
+ Default to 6.
458
+ num_layer_set2set: int
459
+ Number of layers in set2set readout.
460
+ Used if, readout_type == 'set2set'.
461
+ Default to 3.
462
+ ffn_hidden_list: List
463
+ List of sizes of hidden layer in the feed-forward network layer.
464
+ Default to [300].
465
+ ffn_embeddings: int
466
+ Size of penultimate layer in the feed-forward network layer.
467
+ This determines the Principal Odor Map dimension.
468
+ Default to 256.
469
+ ffn_activation: str
470
+ Activation function to be used in feed-forward network layer.
471
+ Can choose between 'relu' for ReLU, 'leakyrelu' for LeakyReLU,
472
+ 'prelu' for PReLU, 'tanh' for TanH, 'selu' for SELU,
473
+ and 'elu' for ELU.
474
+ ffn_dropout_p: float
475
+ Dropout probability for the feed-forward network layer.
476
+ Default to 0.0
477
+ ffn_dropout_at_input_no_act: bool
478
+ If true, dropout is applied on the input tensor.
479
+ For single layer, it is not passed to an activation function.
480
+ weight_decay: float
481
+ weight decay value for L1 and L2 regularization. Default to 1e-5.
482
+ self_loop: bool
483
+ Whether to add self loops for the nodes, i.e. edges
484
+ from nodes to themselves. Generally, an MPNNPOMModel
485
+ does not require self loops. Default to False.
486
+ optimizer_name: str
487
+ Name of optimizer to be used from
488
+ [adam, adagrad, adamw, sparseadam, rmsprop, sgd, kfac]
489
+ Default to 'adam'.
490
+ device_name: Optional[str]
491
+ The device on which to run computations. If None, a device is
492
+ chosen automatically.
493
+ kwargs
494
+ This can include any keyword argument of TorchModel.
495
+ """
496
+ model: nn.Module = MPNNPOM(
497
+ n_tasks=n_tasks,
498
+ node_out_feats=node_out_feats,
499
+ edge_hidden_feats=edge_hidden_feats,
500
+ edge_out_feats=edge_out_feats,
501
+ num_step_message_passing=num_step_message_passing,
502
+ mpnn_residual=mpnn_residual,
503
+ message_aggregator_type=message_aggregator_type,
504
+ mode=mode,
505
+ number_atom_features=number_atom_features,
506
+ number_bond_features=number_bond_features,
507
+ n_classes=n_classes,
508
+ readout_type=readout_type,
509
+ num_step_set2set=num_step_set2set,
510
+ num_layer_set2set=num_layer_set2set,
511
+ ffn_hidden_list=ffn_hidden_list,
512
+ ffn_embeddings=ffn_embeddings,
513
+ ffn_activation=ffn_activation,
514
+ ffn_dropout_p=ffn_dropout_p,
515
+ ffn_dropout_at_input_no_act=ffn_dropout_at_input_no_act)
516
+
517
+ if class_imbalance_ratio and (len(class_imbalance_ratio) != n_tasks):
518
+ raise Exception("size of class_imbalance_ratio \
519
+ should be equal to n_tasks")
520
+
521
+ if mode == 'regression':
522
+ loss: Loss = L2Loss()
523
+ output_types: List = ['prediction']
524
+ else:
525
+ loss = CustomMultiLabelLoss(
526
+ class_imbalance_ratio=class_imbalance_ratio,
527
+ loss_aggr_type=loss_aggr_type,
528
+ device=device_name)
529
+ output_types = ['prediction', 'loss', 'embedding']
530
+
531
+ optimizer: Optimizer = get_optimizer(optimizer_name)
532
+ optimizer.learning_rate = learning_rate
533
+ if device_name is not None:
534
+ device: Optional[torch.device] = torch.device(device_name)
535
+ else:
536
+ device = None
537
+ super(MPNNPOMModel, self).__init__(model,
538
+ loss=loss,
539
+ output_types=output_types,
540
+ optimizer=optimizer,
541
+ learning_rate=learning_rate,
542
+ batch_size=batch_size,
543
+ device=device,
544
+ **kwargs)
545
+
546
+ self.weight_decay: float = weight_decay
547
+ self._self_loop: bool = self_loop
548
+ self.regularization_loss: Callable = self._regularization_loss
549
+
550
+ def _regularization_loss(self) -> torch.Tensor:
551
+ """
552
+ L1 and L2-norm losses for regularization
553
+
554
+ Returns
555
+ -------
556
+ torch.Tensor
557
+ sum of l1_norm and l2_norm
558
+ """
559
+ l1_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
560
+ l2_regularization: torch.Tensor = torch.tensor(0., requires_grad=True)
561
+ for name, param in self.model.named_parameters():
562
+ if 'bias' not in name:
563
+ l1_regularization = l1_regularization + torch.norm(param, p=1)
564
+ l2_regularization = l2_regularization + torch.norm(param, p=2)
565
+ l1_norm: torch.Tensor = self.weight_decay * l1_regularization
566
+ l2_norm: torch.Tensor = self.weight_decay * l2_regularization
567
+ return l1_norm + l2_norm
568
+
569
+ def _prepare_batch(
570
+ self, batch: Tuple[List, List, List]
571
+ ) -> Tuple[DGLGraph, List[torch.Tensor], List[torch.Tensor]]:
572
+ """Create batch data for MPNN.
573
+
574
+ Parameters
575
+ ----------
576
+ batch: Tuple[List, List, List]
577
+ The tuple is ``(inputs, labels, weights)``.
578
+
579
+ Returns
580
+ -------
581
+ g: DGLGraph
582
+ DGLGraph for a batch of graphs.
583
+ labels: list of torch.Tensor or None
584
+ The graph labels.
585
+ weights: list of torch.Tensor or None
586
+ The weights for each sample or
587
+ sample/task pair converted to torch.Tensor.
588
+ """
589
+ inputs: List
590
+ labels: List
591
+ weights: List
592
+
593
+ inputs, labels, weights = batch
594
+ dgl_graphs: List[DGLGraph] = [
595
+ graph.to_dgl_graph(self_loop=self._self_loop)
596
+ for graph in inputs[0]
597
+ ]
598
+ g: DGLGraph = dgl.batch(dgl_graphs).to(self.device)
599
+ _, labels, weights = super(MPNNPOMModel, self)._prepare_batch(
600
+ ([], labels, weights))
601
+ return g, labels, weights