"""Module containing main class for policy network."""

from abc import ABC
from typing import Dict, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
from torch_geometric.data.batch import Batch
from torchmetrics.functional.classification import f1_score, recall, specificity

from synplan.ml.networks.modules import MCTSNetwork


class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """Policy network predicting reaction rules from molecular graph embeddings.

    Two modes are supported, selected by ``policy_type``:

    * ``"ranking"``: a single linear head whose logits are turned into a
      softmax distribution over the reaction rules.
    * ``"filtering"``: two linear heads (regular and priority rules), each
      followed by an element-wise sigmoid.
    """

    # The only values of ``policy_type`` handled by ``forward``/``_get_loss``.
    _POLICY_TYPES = ("ranking", "filtering")

    def __init__(
        self, *args, n_rules: int, vector_dim: int, policy_type: str = "ranking", **kwargs
    ):
        """Initializes a policy network with the given number of reaction rules
        (output dimension) and vector graph embedding dimension, and creates
        linear layers for predicting the regular and priority reaction rules.

        :param n_rules: The number of reaction rules in the policy network.
        :param vector_dim: The dimensionality of the input vectors.
        :param policy_type: Either ``"ranking"`` (default) or ``"filtering"``.
            The priority head is only created for ``"filtering"``.
        :raises ValueError: If ``policy_type`` is not a supported value.
        """
        # Fail fast: an unsupported type would otherwise only surface later as
        # a ``None`` result from ``forward`` or an UnboundLocalError in
        # ``_get_loss``.
        if policy_type not in self._POLICY_TYPES:
            raise self._policy_type_error(policy_type)

        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.policy_type = policy_type
        self.n_rules = n_rules
        self.y_predictor = Linear(vector_dim, n_rules)

        if self.policy_type == "filtering":
            self.priority_predictor = Linear(vector_dim, n_rules)

    @classmethod
    def _policy_type_error(cls, policy_type: str) -> ValueError:
        """Builds the error raised for an unsupported ``policy_type``."""
        return ValueError(
            f"Unsupported policy_type {policy_type!r}; "
            f"expected one of {cls._POLICY_TYPES}."
        )

    @staticmethod
    def _balanced_accuracy_and_f1(
        preds: Tensor, target: Tensor, **task_kwargs
    ) -> Tuple[Tensor, Tensor]:
        """Computes balanced accuracy and F1 score for one prediction head.

        Balanced accuracy is the mean of recall and specificity.

        :param preds: Raw outputs of a linear head (logits are passed as-is to
            the torchmetrics functions).
        :param target: Integer ground-truth labels.
        :param task_kwargs: Task description forwarded unchanged to
            torchmetrics, e.g. ``task="multiclass", num_classes=...`` or
            ``task="multilabel", num_labels=...``.
        :return: A ``(balanced_accuracy, f1_score)`` tuple.
        """
        sensitivity = recall(preds, target, **task_kwargs)
        true_negative_rate = specificity(preds, target, **task_kwargs)
        balanced_accuracy = (sensitivity + true_negative_rate) / 2
        return balanced_accuracy, f1_score(preds, target, **task_kwargs)

    def forward(self, batch: Batch) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Embeds a batch of molecular graphs and predicts reaction rules.

        :param batch: The input batch of molecular graphs.
        :return: For ``"ranking"``, a tensor of softmax probabilities over the
            reaction rules. For ``"filtering"``, a tuple ``(y, priority)`` of
            sigmoid probabilities for the regular and priority reaction rules.
        :raises ValueError: If ``self.policy_type`` is not a supported value.
        """
        # ``embedder`` and ``batch_size`` are not defined in this class;
        # presumably they are provided by MCTSNetwork.
        x = self.embedder(batch, self.batch_size)
        y = self.y_predictor(x)

        if self.policy_type == "ranking":
            return torch.softmax(y, dim=-1)

        if self.policy_type == "filtering":
            y = torch.sigmoid(y)
            priority = torch.sigmoid(self.priority_predictor(x))
            return y, priority

        # Only reachable if ``policy_type`` was reassigned after construction.
        raise self._policy_type_error(self.policy_type)

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and classification metrics for a given batch for
        reaction rules prediction.

        :param batch: The batch of molecular graphs. ``batch.y_rules`` is read
            in both modes; ``batch.y_priority`` is read only for
            ``"filtering"``.
        :return: A dictionary with the keys ``"loss"``,
            ``"balanced_accuracy_y"`` and ``"f1_score_y"``; for ``"filtering"``
            it additionally holds ``"balanced_accuracy_priority"`` and
            ``"f1_score_priority"``.
        :raises ValueError: If ``self.policy_type`` is not a supported value.
        """
        true_y = batch.y_rules.long()
        x = self.embedder(batch, self.batch_size)
        pred_y = self.y_predictor(x)

        if self.policy_type == "ranking":
            # Cross-entropy against the one-hot encoded rule index.
            true_one_hot = one_hot(true_y, num_classes=self.n_rules)
            loss = cross_entropy(pred_y, true_one_hot.float())
            ba_y, f1_y = self._balanced_accuracy_and_f1(
                pred_y, true_y, task="multiclass", num_classes=self.n_rules
            )
            return {"loss": loss, "balanced_accuracy_y": ba_y, "f1_score_y": f1_y}

        if self.policy_type == "filtering":
            # Regular rules head.
            loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float())
            ba_y, f1_y = self._balanced_accuracy_and_f1(
                pred_y, true_y, task="multilabel", num_labels=self.n_rules
            )

            # Priority rules head: float targets for the loss, integer targets
            # for the metrics.
            true_priority = batch.y_priority.float()
            pred_priority = self.priority_predictor(x)
            loss_priority = binary_cross_entropy_with_logits(
                pred_priority, true_priority
            )
            ba_priority, f1_priority = self._balanced_accuracy_and_f1(
                pred_priority,
                true_priority.long(),
                task="multilabel",
                num_labels=self.n_rules,
            )

            return {
                "loss": loss_y + loss_priority,
                "balanced_accuracy_y": ba_y,
                "f1_score_y": f1_y,
                "balanced_accuracy_priority": ba_priority,
                "f1_score_priority": f1_priority,
            }

        # Only reachable if ``policy_type`` was reassigned after construction.
        raise self._policy_type_error(self.policy_type)