"""Module containing basic pytorch architectures of policy and value neural networks."""
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
import torch
from adabelief_pytorch import AdaBelief
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
from torch.nn.functional import relu
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data.batch import Batch
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.pool import global_add_pool


class GraphEmbedding(Module):
    """Converts the per-atom feature vectors of a molecule into a single
    molecule-level embedding using graph convolutions."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Initializes the graph convolutional module.
        :param vector_dim: The dimensionality of the hidden layers and of the output
            layer of the graph convolution module.
        :param dropout: The dropout rate; dropout is a regularization technique that
            randomly zeroes a fraction of the inputs at each training update to
            prevent overfitting.
        :param num_conv_layers: The number of graph convolutional layers.
        """
super().__init__()
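        # 11 is the dimensionality of the input atom/bond feature vectors.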
self.expansion = Linear(11, vector_dim)
self.dropout = Dropout(dropout)
self.gcn_convs = ModuleList(
[
GCNConv(
vector_dim,
vector_dim,
improved=True,
)
for _ in range(num_conv_layers)
]
)

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
"""Takes a graph as input and performs graph convolution on it.
:param graph: The batch of molecular graphs, where each atom is represented by
the atom/bond vector.
:param batch_size: The size of the batch.
:return: Graph embedding.
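
        Example (a minimal sketch, assuming molecular graphs with 11-dimensional
        atom features packed into a :class:`~torch_geometric.data.Batch`)::

            >>> from torch_geometric.data import Batch, Data
            >>> embedder = GraphEmbedding(vector_dim=512)
            >>> mol = Data(x=torch.rand(6, 11), edge_index=torch.tensor([[0, 1], [1, 0]]))
            >>> embedder(Batch.from_data_list([mol, mol]), batch_size=2).shape
            torch.Size([2, 512])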
"""
atoms, connections = graph.x.float(), graph.edge_index.long()
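        # Compress the magnitude of the raw features before the linear expansion.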
atoms = torch.log(atoms + 1)
atoms = self.expansion(atoms)
for gcn_conv in self.gcn_convs:
atoms = atoms + self.dropout(relu(gcn_conv(atoms, connections)))
return global_add_pool(atoms, graph.batch, size=batch_size)


class GraphEmbeddingConcat(GraphEmbedding):
    """Variant of :class:`GraphEmbedding` that concatenates the outputs of all
    convolutional layers instead of summing them residually."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Initializes the module.
        :param vector_dim: The dimensionality of the output embedding; should be
            divisible by ``num_conv_layers``, since the per-layer outputs (each
            ``vector_dim // num_conv_layers`` channels wide) are concatenated.
        :param dropout: The dropout rate.
        :param num_conv_layers: The number of graph convolutional layers.
        """
        super().__init__()
        gcn_dim = vector_dim // num_conv_layers
        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
self.gcn_convs = ModuleList(
[
ModuleDict(
{
"gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
"activation": GELU(),
}
)
for _ in range(num_conv_layers)
]
)

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
"""Takes a graph as input and performs graph convolution on it.
:param graph: The batch of molecular graphs, where each atom is represented by
the atom/bond vector.
:param batch_size: The size of the batch.
:return: Graph embedding.
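
        Example (a minimal sketch; with the defaults each layer works at
        512 // 8 = 64 channels, so the concatenated output is 512-dimensional)::

            >>> from torch_geometric.data import Batch, Data
            >>> embedder = GraphEmbeddingConcat(vector_dim=512, num_conv_layers=8)
            >>> mol = Data(x=torch.rand(6, 11), edge_index=torch.tensor([[0, 1], [1, 0]]))
            >>> embedder(Batch.from_data_list([mol]), batch_size=1).shape
            torch.Size([1, 512])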
"""
atoms, connections = graph.x.float(), graph.edge_index.long()
atoms = torch.log(atoms + 1)
atoms = self.expansion(atoms)
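        # Unlike GraphEmbedding, keep every layer's output and concatenate the
        # outputs at the end instead of accumulating a residual sum.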
collected_atoms = []
for gcn_convs in self.gcn_convs:
atoms = gcn_convs["gcn"](atoms, connections)
atoms = gcn_convs["activation"](atoms)
atoms = self.dropout(atoms)
collected_atoms.append(atoms)
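        # Concatenation yields gcn_dim * num_conv_layers channels per atom
        # (equal to vector_dim when vector_dim is divisible by num_conv_layers).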
atoms = torch.cat(collected_atoms, dim=-1)
return global_add_pool(atoms, graph.batch, size=batch_size)


class MCTSNetwork(LightningModule, ABC):
    """Base class for policy and value networks."""

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The base class for MCTS graph convolutional neural networks (policy and
        value networks).
        :param vector_dim: The dimensionality of the hidden layers and of the output
            layer of the graph convolution module.
        :param batch_size: The batch size used when logging metrics.
        :param dropout: The dropout rate, a regularization technique used in neural
            networks to prevent overfitting.
        :param num_conv_layers: The number of layers in the graph convolutional
            module.
        :param learning_rate: The initial step size of the optimizer.
        :param gcn_concat: If True, embed graphs with :class:`GraphEmbeddingConcat`,
            which concatenates the outputs of all convolutional layers; otherwise
            use :class:`GraphEmbedding`, which sums them residually.
        """
super().__init__()
if gcn_concat:
self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
else:
self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
self.batch_size = batch_size
self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """Performs forward propagation through the neural network.
        :param batch: The batch of molecular graphs processed together in a single
            forward pass through the neural network.
        :return: The network output for the batch.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and auxiliary metrics for a given batch of data.
        :param batch: The batch of input data used to compute the loss.
        :return: A dictionary of metric names to values containing at least the
            key "loss".
        """

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
"""Calculates the loss for a given training batch and logs the loss value.
:param batch: The batch of data that is used for training.
:param batch_idx: The index of the batch.
:return: The value of the training loss.
"""
metrics = self._get_loss(batch)
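        # Log every metric with a "train_" prefix, both per step and per epoch.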
for name, value in metrics.items():
self.log(
"train_" + name,
value,
prog_bar=True,
on_step=True,
on_epoch=True,
batch_size=self.batch_size,
)
return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
"""Calculates the loss for a given validation batch and logs the loss value.
:param batch: The batch of data that is used for validation.
:param batch_idx: The index of the batch.
"""
metrics = self._get_loss(batch)
for name, value in metrics.items():
self.log("val_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
"""Calculates the loss for a given test batch and logs the loss value.
:param batch: The batch of data that is used for testing.
:param batch_idx: The index of the batch.
"""
metrics = self._get_loss(batch)
for name, value in metrics.items():
self.log("test_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(
self,
) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
"""Returns an optimizer and a learning rate scheduler for training a model using
the AdaBelief optimizer and ReduceLROnPlateau scheduler.
:return: The optimizer and a scheduler.
"""
optimizer = AdaBelief(
self.parameters(),
lr=self.lr,
eps=1e-16,
betas=(0.9, 0.999),
weight_decouple=True,
rectify=True,
weight_decay=0.01,
print_change_log=False,
)
lr_scheduler = ReduceLROnPlateau(
optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
)
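        # Lightning-style scheduler config: reduce the learning rate when the
        # epoch-level "val_loss" logged in validation_step stops improving.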
scheduler = {
"scheduler": lr_scheduler,
"reduce_on_plateau": True,
"monitor": "val_loss",
}
return [optimizer], [scheduler]