File size: 1,653 Bytes
72a3513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Module containing a class that represents a value function for predicting the
synthesisability of new nodes in the tree search."""

from typing import List

import torch

from synplan.chem.precursor import Precursor, compose_precursors
from synplan.ml.networks.value import ValueNetwork
from synplan.ml.training import mol_to_pyg


class ValueNetworkFunction:
    """Value function implemented as a value neural network for node evaluation
    (synthesisability prediction) in tree search."""

    # Value returned by ``predict_value`` when the composed precursors could not be
    # turned into a graph for the network (``mol_to_pyg`` gave a falsy result).
    # Presumably chosen to be far below any real prediction so such nodes rank last.
    INVALID_GRAPH_VALUE: float = -1e6

    def __init__(self, weights_path: str) -> None:
        """The value function predicts the probability to synthesize the target molecule
        with available building blocks starting from a given precursor.

        :param weights_path: The value network weights file path.
        """

        # Checkpoint tensors are always mapped onto the CPU, regardless of the
        # device the checkpoint was saved from.
        value_net = ValueNetwork.load_from_checkpoint(
            weights_path, map_location=torch.device("cpu")
        )
        # ``eval()`` puts the network into evaluation mode and returns the module.
        self.value_network = value_net.eval()

    def predict_value(self, precursors: List[Precursor]) -> float:
        """Predicts a value based on the given precursors from the node. For prediction,
        precursors must be composed into a single molecule (product).

        :param precursors: The list of precursors.
        :return: The predicted float value ("synthesisability") of the node, or
            ``INVALID_GRAPH_VALUE`` (-1e6) if the composed molecule could not be
            converted into a graph for the network.
        """

        molecule = compose_precursors(precursors=precursors, exclude_small=True)
        pyg_graph = mol_to_pyg(molecule)

        # A falsy result (presumably None) means no graph could be built, so the
        # network is not run and the sentinel value is returned instead.
        if not pyg_graph:
            return self.INVALID_GRAPH_VALUE

        with torch.no_grad():
            # Call the module itself rather than ``.forward`` so that any registered
            # forward hooks are honoured.
            # NOTE(review): assumes element 0 of the output is this graph's scalar
            # prediction — confirm against ValueNetwork.forward.
            value_pred = self.value_network(pyg_graph)[0].item()

        return value_pred