File size: 8,468 Bytes
72a3513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""Module containing basic pytorch architectures of policy and value neural networks."""

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import torch
from adabelief_pytorch import AdaBelief
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
from torch.nn.functional import relu
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data.batch import Batch
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.pool import global_add_pool


class GraphEmbedding(Module):
    """Graph-convolutional encoder that collapses the per-atom vectors of a
    molecule into one fixed-size embedding via residual graph convolutions."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Builds the expansion layer and a stack of residual GCN convolutions.

        :param vector_dim: Width of the hidden layers and of the resulting
            graph embedding.
        :param dropout: Fraction of units randomly zeroed after each
            convolution during training; regularization against overfitting.
        :param num_conv_layers: Number of ``GCNConv`` layers to stack.
        """

        super().__init__()
        # NOTE(review): input atom-feature width is hard-coded to 11 —
        # presumably matches the upstream featurizer; confirm before reuse.
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)
        convs = [
            GCNConv(vector_dim, vector_dim, improved=True)
            for _ in range(num_conv_layers)
        ]
        self.gcn_convs = ModuleList(convs)

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Runs the convolution stack and pools atoms into one vector per graph.

        :param graph: Batch of molecular graphs; each atom is represented by
            an atom/bond feature vector in ``graph.x``.
        :param batch_size: Number of graphs in the batch, forwarded to the
            pooling operator.
        :return: Graph embedding of shape ``(batch_size, vector_dim)``.
        """
        node_feats = graph.x.float()
        edge_index = graph.edge_index.long()
        # Log transform compresses the dynamic range of raw features before
        # the linear expansion.
        node_feats = self.expansion(torch.log(node_feats + 1))
        for conv in self.gcn_convs:
            # Residual connection: add the convolved (ReLU + dropout) signal
            # back onto the running node features.
            node_feats = node_feats + self.dropout(relu(conv(node_feats, edge_index)))

        return global_add_pool(node_feats, graph.batch, size=batch_size)


class GraphEmbeddingConcat(GraphEmbedding, Module):
    """Multi-scale variant of :class:`GraphEmbedding`: runs ``num_conv_layers``
    narrow GCN layers (each ``vector_dim // num_conv_layers`` wide) and
    concatenates their outputs before pooling, instead of summing residually."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Initializes the concatenating graph convolution module.

        :param vector_dim: Target width of the concatenated embedding. If it is
            not divisible by ``num_conv_layers``, the actual output width is
            ``(vector_dim // num_conv_layers) * num_conv_layers``, i.e. smaller
            than ``vector_dim`` — pass a divisible pair to get exactly
            ``vector_dim`` channels.
        :param dropout: Fraction of units randomly zeroed after each
            convolution during training; regularization against overfitting.
        :param num_conv_layers: Number of GCN layers, each contributing one
            slice of the final concatenated embedding.
        """
        # Bypass GraphEmbedding.__init__ on purpose: calling it (as the
        # original `super().__init__()` did) builds a full-width expansion
        # layer and a default-sized GCNConv stack that are all immediately
        # overwritten below — pure wasted allocation. Module.__init__ performs
        # the only setup actually required (submodule registration machinery).
        Module.__init__(self)

        # Each layer contributes gcn_dim channels to the concatenated output.
        gcn_dim = vector_dim // num_conv_layers

        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Applies every GCN layer in sequence and concatenates all
        intermediate node representations before pooling.

        :param graph: The batch of molecular graphs, where each atom is
            represented by the atom/bond feature vector in ``graph.x``.
        :param batch_size: The size of the batch.
        :return: Graph embedding of shape
            ``(batch_size, gcn_dim * num_conv_layers)``.
        """

        atoms, connections = graph.x.float(), graph.edge_index.long()
        # Log transform compresses the dynamic range of raw features.
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)

        collected_atoms = []
        for layer in self.gcn_convs:
            atoms = layer["gcn"](atoms, connections)
            atoms = layer["activation"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)

        # Concatenate the per-layer feature maps along the channel axis to
        # form a multi-scale node representation.
        atoms = torch.cat(collected_atoms, dim=-1)

        return global_add_pool(atoms, graph.batch, size=batch_size)


class MCTSNetwork(LightningModule, ABC):
    """Basic class for policy and value networks."""

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The basic class for MCTS graph convolutional neural networks (policy and
        value network).

        :param vector_dim: The dimensionality of the hidden layers and output layer of
            graph convolution module.
        :param batch_size: Batch size passed to ``self.log`` so logged metrics
            are aggregated with the correct weighting.
        :param dropout: Dropout is a regularization technique used in neural networks to
            prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in a graph
            convolutional module.
        :param learning_rate: The learning rate determines how quickly the model learns
            from the training data.
        :param gcn_concat: If True, embed graphs with :class:`GraphEmbeddingConcat`
            (concatenates per-layer convolution outputs); otherwise use the
            residual :class:`GraphEmbedding`.
        """
        super().__init__()
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """The forward function takes a batch of input data and performs forward
        propagation through the neural network.

        :param batch: The batch of molecular graphs processed together in a single
            forward pass through the neural network.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculate the metrics for a given batch of data.

        Note: the original annotation said ``Tensor``, but every step method
        iterates ``metrics.items()`` and indexes ``metrics["loss"]``, so
        implementations must return a name->tensor mapping containing at
        least a ``"loss"`` entry.

        :param batch: The batch of input data that is used to compute the loss.
        :return: Mapping from metric name to its scalar tensor value; must
            include the key ``"loss"``.
        """

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        """Computes and logs all training metrics and returns the loss.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: The value of the training loss (the ``"loss"`` metric), used
            by Lightning for backpropagation.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log(
                "train_" + name,
                value,
                prog_bar=True,
                on_step=True,
                on_epoch=True,
                batch_size=self.batch_size,
            )
        return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Computes and logs all validation metrics (prefixed ``val_``).

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("val_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Computes and logs all test metrics (prefixed ``test_``).

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("test_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(
        self,
    ) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
        """Returns an optimizer and a learning rate scheduler for training a model using
        the AdaBelief optimizer and ReduceLROnPlateau scheduler.

        :return: The optimizer and a scheduler dict monitoring ``val_loss``.
        """

        optimizer = AdaBelief(
            self.parameters(),
            lr=self.lr,
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=True,
            weight_decay=0.01,
            print_change_log=False,
        )

        # NOTE(review): `verbose=` is deprecated in recent torch releases
        # (removed in torch>=2.4) — confirm the pinned torch version before
        # upgrading.
        lr_scheduler = ReduceLROnPlateau(
            optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
        )
        scheduler = {
            "scheduler": lr_scheduler,
            "reduce_on_plateau": True,
            # Lightning steps the plateau scheduler on this logged metric.
            "monitor": "val_loss",
        }

        return [optimizer], [scheduler]