"""Module containing basic pytorch architectures of policy and value neural networks."""

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import torch
from adabelief_pytorch import AdaBelief
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import GELU, Dropout, Linear, Module, ModuleDict, ModuleList
from torch.nn.functional import relu
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data.batch import Batch
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.pool import global_add_pool


class GraphEmbedding(Module):
    """Converts per-atom vectors of a molecular graph into a single graph-level
    vector using a residual stack of graph convolutions followed by global add
    pooling."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 5
    ):
        """Initializes a graph convolutional module.

        Needed to convert molecule atom vectors to the single vector using graph
        convolution.

        :param vector_dim: The dimensionality of the hidden layers and output
            layer of the graph convolution module.
        :param dropout: Dropout is a regularization technique used in neural
            networks to prevent overfitting. It randomly sets a fraction of
            input units to 0 at each update during training time.
        :param num_conv_layers: The number of convolutional layers in the graph
            convolutional module.
        """
        super().__init__()
        # Raw atom/bond feature vectors are 11-dimensional; expand to vector_dim.
        self.expansion = Linear(11, vector_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                GCNConv(vector_dim, vector_dim, improved=True)
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is
            represented by the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding of shape (batch_size, vector_dim).
        """
        atoms, connections = graph.x.float(), graph.edge_index.long()
        # Log-transform compresses the dynamic range of the raw features.
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)
        for gcn_conv in self.gcn_convs:
            # Residual connection: each layer refines rather than replaces.
            atoms = atoms + self.dropout(relu(gcn_conv(atoms, connections)))
        return global_add_pool(atoms, graph.batch, size=batch_size)


class GraphEmbeddingConcat(GraphEmbedding):
    """Variant of :class:`GraphEmbedding` that concatenates the outputs of all
    convolutional layers (instead of summing residuals) to form the final
    graph embedding."""

    def __init__(
        self, vector_dim: int = 512, dropout: float = 0.4, num_conv_layers: int = 8
    ):
        """Initializes the concatenating graph convolutional module.

        :param vector_dim: The dimensionality of the final (concatenated)
            embedding. Must be divisible by ``num_conv_layers``.
        :param dropout: Dropout is a regularization technique used in neural
            networks to prevent overfitting.
        :param num_conv_layers: The number of convolutional layers; each layer
            contributes ``vector_dim // num_conv_layers`` features to the
            concatenated output.
        :raises ValueError: If ``vector_dim`` is not divisible by
            ``num_conv_layers`` (the output would silently have a smaller
            dimensionality than ``vector_dim`` otherwise).
        """
        # Initialize torch.nn.Module machinery directly: GraphEmbedding.__init__
        # would build default-sized submodules that are immediately discarded.
        Module.__init__(self)
        if vector_dim % num_conv_layers != 0:
            raise ValueError(
                "vector_dim must be divisible by num_conv_layers, got "
                f"{vector_dim} and {num_conv_layers}"
            )
        gcn_dim = vector_dim // num_conv_layers
        self.expansion = Linear(11, gcn_dim)
        self.dropout = Dropout(dropout)
        self.gcn_convs = ModuleList(
            [
                ModuleDict(
                    {
                        "gcn": GCNConv(gcn_dim, gcn_dim, improved=True),
                        "activation": GELU(),
                    }
                )
                for _ in range(num_conv_layers)
            ]
        )

    def forward(self, graph: Batch, batch_size: int) -> Tensor:
        """Takes a graph as input and performs graph convolution on it.

        :param graph: The batch of molecular graphs, where each atom is
            represented by the atom/bond vector.
        :param batch_size: The size of the batch.
        :return: Graph embedding of shape (batch_size, vector_dim).
        """
        atoms, connections = graph.x.float(), graph.edge_index.long()
        atoms = torch.log(atoms + 1)
        atoms = self.expansion(atoms)
        collected_atoms = []
        for gcn_convs in self.gcn_convs:
            atoms = gcn_convs["gcn"](atoms, connections)
            atoms = gcn_convs["activation"](atoms)
            atoms = self.dropout(atoms)
            collected_atoms.append(atoms)
        # Concatenate the per-layer features back up to vector_dim.
        atoms = torch.cat(collected_atoms, dim=-1)
        return global_add_pool(atoms, graph.batch, size=batch_size)


class MCTSNetwork(LightningModule, ABC):
    """Basic class for policy and value networks."""

    def __init__(
        self,
        vector_dim: int,
        batch_size: int,
        dropout: float = 0.4,
        num_conv_layers: int = 5,
        learning_rate: float = 0.001,
        gcn_concat: bool = False,
    ):
        """The basic class for MCTS graph convolutional neural networks (policy
        and value network).

        :param vector_dim: The dimensionality of the hidden layers and output
            layer of the graph convolution module.
        :param batch_size: The batch size used when logging metrics.
        :param dropout: Dropout is a regularization technique used in neural
            networks to prevent overfitting.
        :param num_conv_layers: The number of convolutional layers in the graph
            convolutional module.
        :param learning_rate: The learning rate determines how quickly the model
            learns from the training data.
        :param gcn_concat: If True, use :class:`GraphEmbeddingConcat`, which
            concatenates per-layer convolution outputs; otherwise use the
            residual-sum :class:`GraphEmbedding`.
        """
        super().__init__()
        if gcn_concat:
            self.embedder = GraphEmbeddingConcat(vector_dim, dropout, num_conv_layers)
        else:
            self.embedder = GraphEmbedding(vector_dim, dropout, num_conv_layers)
        self.batch_size = batch_size
        self.lr = learning_rate

    @abstractmethod
    def forward(self, batch: Batch) -> Tensor:
        """The forward function takes a batch of input data and performs forward
        propagation through the neural network.

        :param batch: The batch of molecular graphs processed together in a
            single forward pass through the neural network.
        """

    @abstractmethod
    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculate the metrics for a given batch of data.

        :param batch: The batch of input data that is used to compute the loss.
        :return: A mapping from metric name to value; must contain at least the
            key ``"loss"`` (the step methods below log every entry and
            ``training_step`` returns ``metrics["loss"]``).
        """

    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
        """Calculates the loss for a given training batch and logs the loss
        value.

        :param batch: The batch of data that is used for training.
        :param batch_idx: The index of the batch.
        :return: The value of the training loss.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log(
                "train_" + name,
                value,
                prog_bar=True,
                on_step=True,
                on_epoch=True,
                batch_size=self.batch_size,
            )
        return metrics["loss"]

    def validation_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given validation batch and logs the loss
        value.

        :param batch: The batch of data that is used for validation.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("val_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def test_step(self, batch: Batch, batch_idx: int) -> None:
        """Calculates the loss for a given test batch and logs the loss value.

        :param batch: The batch of data that is used for testing.
        :param batch_idx: The index of the batch.
        """
        metrics = self._get_loss(batch)
        for name, value in metrics.items():
            self.log("test_" + name, value, on_epoch=True, batch_size=self.batch_size)

    def configure_optimizers(
        self,
    ) -> Tuple[List[AdaBelief], List[Dict[str, Union[bool, str, ReduceLROnPlateau]]]]:
        """Returns an optimizer and a learning rate scheduler for training a
        model using the AdaBelief optimizer and ReduceLROnPlateau scheduler.

        :return: The optimizer and a scheduler.
        """
        optimizer = AdaBelief(
            self.parameters(),
            lr=self.lr,
            eps=1e-16,
            betas=(0.9, 0.999),
            weight_decouple=True,
            rectify=True,
            weight_decay=0.01,
            print_change_log=False,
        )
        lr_scheduler = ReduceLROnPlateau(
            optimizer, patience=3, factor=0.8, min_lr=5e-5, verbose=True
        )
        # The scheduler monitors the epoch-level "val_loss" logged in
        # validation_step (via the "val_" + name prefix).
        scheduler = {
            "scheduler": lr_scheduler,
            "reduce_on_plateau": True,
            "monitor": "val_loss",
        }
        return [optimizer], [scheduler]