synplanner_dev / synplan /mcts /evaluation.py
Gilmullin Almaz
Refactor code structure for improved readability and maintainability
72a3513
"""Module containing a class that represents a value function for prediction of
synthesisablity of new nodes in the tree search."""
from typing import List
import torch
from synplan.chem.precursor import Precursor, compose_precursors
from synplan.ml.networks.value import ValueNetwork
from synplan.ml.training import mol_to_pyg
class ValueNetworkFunction:
"""Value function implemented as a value neural network for node evaluation
(synthesisability prediction) in tree search."""
def __init__(self, weights_path: str) -> None:
"""The value function predicts the probability to synthesize the target molecule
with available building blocks starting from a given precursor.
:param weights_path: The value network weights file path.
"""
value_net = ValueNetwork.load_from_checkpoint(
weights_path, map_location=torch.device("cpu")
)
self.value_network = value_net.eval()
def predict_value(self, precursors: List[Precursor,]) -> float:
"""Predicts a value based on the given precursors from the node. For prediction,
precursors must be composed into a single molecule (product).
:param precursors: The list of precursors.
:return: The predicted float value ("synthesisability") of the node.
"""
molecule = compose_precursors(precursors=precursors, exclude_small=True)
pyg_graph = mol_to_pyg(molecule)
if pyg_graph:
with torch.no_grad():
value_pred = self.value_network.forward(pyg_graph)[0].item()
else:
value_pred = -1e6
return value_pred