"""Module containing main class for value network.""" from abc import ABC from typing import Any, Dict import torch from pytorch_lightning import LightningModule from torch import Tensor from torch.nn import Linear from torch.nn.functional import binary_cross_entropy_with_logits from torch_geometric.data.batch import Batch from torchmetrics.functional.classification import ( binary_f1_score, binary_recall, binary_specificity, ) from synplan.ml.networks.modules import MCTSNetwork class ValueNetwork(MCTSNetwork, LightningModule, ABC): """Value network.""" def __init__(self, vector_dim: int, *args: Any, **kwargs: Any) -> None: """Initializes a value network, and creates linear layer for predicting the synthesisability of given precursor represented by molecular graph. :param vector_dim: The dimensionality of the output linear layer. """ super().__init__(vector_dim, *args, **kwargs) self.save_hyperparameters() self.predictor = Linear(vector_dim, 1) def forward(self, batch) -> torch.Tensor: """Takes a batch of molecular graphs, applies a graph convolution returns the synthesisability (probability given by sigmoid function) of a given precursor represented by molecular graph precessed by graph convolution. :param batch: The batch of molecular graphs. :return: The predicted synthesisability (between 0 and 1). """ x = self.embedder(batch, self.batch_size) x = torch.sigmoid(self.predictor(x)) return x def _get_loss(self, batch: Batch) -> Dict[str, Tensor]: """Calculates the loss and various classification metrics for a given batch for the precursor synthesysability prediction. :param batch: The batch of molecular graphs. :return: The dictionary with loss value and balanced accuracy of precursor synthesysability prediction. """ true_y = batch.y.float() true_y = torch.unsqueeze(true_y, -1) x = self.embedder(batch, self.batch_size) pred_y = self.predictor(x) # calc loss func loss = binary_cross_entropy_with_logits(pred_y, true_y) true_y = true_y.long() ba = (binary_recall(pred_y, true_y) + binary_specificity(pred_y, true_y)) / 2 f1 = binary_f1_score(pred_y, true_y) metrics = {"loss": loss, "balanced_accuracy": ba, "f1_score": f1} return metrics