File size: 2,481 Bytes
72a3513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Module containing main class for value network."""

from abc import ABC
from typing import Any, Dict

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits
from torch_geometric.data.batch import Batch
from torchmetrics.functional.classification import (
    binary_f1_score,
    binary_recall,
    binary_specificity,
)

from synplan.ml.networks.modules import MCTSNetwork


class ValueNetwork(MCTSNetwork, LightningModule, ABC):
    """Value network predicting the synthesisability of a precursor.

    A single linear layer (``self.predictor``) maps the embedding produced by
    ``self.embedder`` to one logit per molecular graph.
    """

    def __init__(self, vector_dim: int, *args: Any, **kwargs: Any) -> None:
        """Initializes the value network and creates the linear layer used to
        predict the synthesisability of a precursor represented by a molecular
        graph.

        :param vector_dim: The size of the embedding vector. It is forwarded to
            the parent initializer and used as the input width of the
            single-output predictor layer.
        :param args: Extra positional arguments forwarded to the parent
            initializer.
        :param kwargs: Extra keyword arguments forwarded to the parent
            initializer.
        """
        super().__init__(vector_dim, *args, **kwargs)
        self.save_hyperparameters()
        self.predictor = Linear(vector_dim, 1)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Takes a batch of molecular graphs, embeds them with ``self.embedder``
        and returns the predicted synthesisability of each precursor as a
        probability (sigmoid of the predictor output).

        :param batch: The batch of molecular graphs.
        :return: The predicted synthesisability (between 0 and 1).
        """
        # NOTE(review): ``embedder`` and ``batch_size`` are not defined in this
        # class; presumably they are provided by MCTSNetwork - confirm.
        embedding = self.embedder(batch, self.batch_size)
        return torch.sigmoid(self.predictor(embedding))

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Calculates the loss and classification metrics for a given batch for
        the precursor synthesisability prediction.

        :param batch: The batch of molecular graphs; ``batch.y`` holds the
            binary target labels.
        :return: The dictionary with the loss value (``"loss"``), the balanced
            accuracy (``"balanced_accuracy"``) and the F1 score (``"f1_score"``)
            of the precursor synthesisability prediction.
        """
        # Add a trailing axis so the targets line up with the predictor output,
        # whose last dimension is 1.
        targets = torch.unsqueeze(batch.y.float(), -1)

        embedding = self.embedder(batch, self.batch_size)
        logits = self.predictor(embedding)

        # The loss works on raw logits (the sigmoid is applied internally).
        loss = binary_cross_entropy_with_logits(logits, targets)

        # Convert to probabilities explicitly before computing metrics.
        # torchmetrics only applies a sigmoid when it sees values outside
        # [0, 1]; a batch whose logits all fall inside [0, 1] would otherwise
        # be thresholded as if the logits were probabilities, producing wrong
        # metrics. The metrics threshold the values, so no gradient is needed.
        probabilities = torch.sigmoid(logits.detach())
        labels = targets.long()

        recall = binary_recall(probabilities, labels)
        specificity = binary_specificity(probabilities, labels)
        balanced_accuracy = (recall + specificity) / 2
        f1 = binary_f1_score(probabilities, labels)

        return {"loss": loss, "balanced_accuracy": balanced_accuracy, "f1_score": f1}