Spaces:
Sleeping
Sleeping
| """Module containing a class that represents a value function for predicting the | |
| synthesisability of new nodes in the tree search.""" | |
| from typing import List | |
| import torch | |
| from synplan.chem.precursor import Precursor, compose_precursors | |
| from synplan.ml.networks.value import ValueNetwork | |
| from synplan.ml.training import mol_to_pyg | |
class ValueNetworkFunction:
    """Value function implemented as a value neural network for node evaluation
    (synthesisability prediction) in tree search."""

    # Value returned by ``predict_value`` when the composed molecule does not
    # yield a usable graph, so the network is never evaluated for that node.
    _INVALID_GRAPH_VALUE = -1e6

    def __init__(self, weights_path: str) -> None:
        """The value function predicts the probability to synthesize the target molecule
        with available building blocks starting from a given precursor.

        The network is restored from the checkpoint onto the CPU and switched to
        evaluation mode.

        :param weights_path: The value network weights file path.
        """
        value_net = ValueNetwork.load_from_checkpoint(
            weights_path, map_location=torch.device("cpu")
        )
        # ``eval()`` returns the module itself, now in evaluation mode.
        self.value_network = value_net.eval()

    def predict_value(self, precursors: List[Precursor]) -> float:
        """Predicts a value based on the given precursors from the node. For prediction,
        precursors must be composed into a single molecule (product).

        :param precursors: The list of precursors.
        :return: The predicted float value ("synthesisability") of the node, or
            ``-1e6`` when no graph could be built from the composed molecule.
        """
        molecule = compose_precursors(precursors=precursors, exclude_small=True)
        pyg_graph = mol_to_pyg(molecule)

        # Truthiness test kept as in the original: a falsy result from
        # ``mol_to_pyg`` (presumably ``None`` on failure; verify against its
        # implementation) means there is nothing to feed to the network.
        if not pyg_graph:
            return self._INVALID_GRAPH_VALUE

        with torch.no_grad():
            # Call the module instance rather than ``.forward`` directly so that
            # any registered forward hooks are honoured.
            prediction = self.value_network(pyg_graph)
        return prediction[0].item()