Spaces:
Sleeping
Sleeping
File size: 1,653 Bytes
72a3513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
"""Module containing a class that represents a value function for prediction of
synthesisablity of new nodes in the tree search."""
from typing import List
import torch
from synplan.chem.precursor import Precursor, compose_precursors
from synplan.ml.networks.value import ValueNetwork
from synplan.ml.training import mol_to_pyg
class ValueNetworkFunction:
"""Value function implemented as a value neural network for node evaluation
(synthesisability prediction) in tree search."""
def __init__(self, weights_path: str) -> None:
"""The value function predicts the probability to synthesize the target molecule
with available building blocks starting from a given precursor.
:param weights_path: The value network weights file path.
"""
value_net = ValueNetwork.load_from_checkpoint(
weights_path, map_location=torch.device("cpu")
)
self.value_network = value_net.eval()
def predict_value(self, precursors: List[Precursor,]) -> float:
"""Predicts a value based on the given precursors from the node. For prediction,
precursors must be composed into a single molecule (product).
:param precursors: The list of precursors.
:return: The predicted float value ("synthesisability") of the node.
"""
molecule = compose_precursors(precursors=precursors, exclude_small=True)
pyg_graph = mol_to_pyg(molecule)
if pyg_graph:
with torch.no_grad():
value_pred = self.value_network.forward(pyg_graph)[0].item()
else:
value_pred = -1e6
return value_pred
|