Spaces:
Sleeping
Sleeping
| """Module containing a class that represents a policy function for node expansion in the | |
| tree search.""" | |
| from typing import Iterator, List, Tuple, Union | |
| import torch | |
| import torch_geometric | |
| from CGRtools.reactor.reactor import Reactor | |
| from synplan.chem.precursor import Precursor | |
| from synplan.ml.networks.policy import PolicyNetwork | |
| from synplan.ml.training import mol_to_pyg | |
| from synplan.utils.config import PolicyNetworkConfig | |
class PolicyNetworkFunction:
    """Policy function implemented as a policy neural network for node expansion in tree
    search."""

    def __init__(
        self, policy_config: PolicyNetworkConfig, compile: bool = False
    ) -> None:
        """Initializes the expansion function (ranking or filter policy network).

        The network is loaded from ``policy_config.weights_path`` onto the CPU with
        ``batch_size=1`` and ``dropout=0`` and switched to evaluation mode.

        :param policy_config: An expansion policy configuration.
        :param compile: If True, wrap the loaded network with
            ``torch_geometric.compile(..., dynamic=True)``, which is intended to speed up
            inference.
        """
        self.config = policy_config

        policy_net = PolicyNetwork.load_from_checkpoint(
            self.config.weights_path,
            map_location=torch.device("cpu"),
            batch_size=1,
            dropout=0,
        )
        policy_net = policy_net.eval()

        if compile:
            self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
        else:
            self.policy_net = policy_net

    def predict_reaction_rules(
        self, precursor: Precursor, reaction_rules: List[Reactor]
    ) -> Iterator[Tuple[float, Reactor, int]]:
        """The policy function predicts the list of reaction rules for a given precursor.

        Rules are yielded in descending order of score, limited to the
        ``config.top_rules`` best ones and to those whose score is strictly greater than
        ``config.rule_prob_threshold``. For a "filtering" network the score blends the
        rule probabilities with the priority-rule output (weighted by
        ``config.priority_rules_fraction``) and is softmax-renormalised over the
        retained top rules. For a "ranking" network the raw network output is used.

        This is a generator: the checks below run on the first ``next()`` call, not when
        the method is called.

        :param precursor: The current precursor for which the reaction rules are predicted.
        :param reaction_rules: The list of reaction rules from which applicable reaction
            rules are predicted and selected. Its length must match the network's output
            dimensionality.
        :return: Yields ``(score, reaction_rule, rule_id)`` tuples, where ``rule_id`` is
            the rule's index in ``reaction_rules``. Nothing is yielded when the
            precursor's molecule cannot be converted to a graph.
        :raises ValueError: If the network's output size differs from
            ``len(reaction_rules)``, or if the network has an unknown ``policy_type``.
        """
        # The last module is presumably the final linear layer; its width must equal the
        # number of rules, since output index i is interpreted as reaction_rules[i].
        out_dim = list(self.policy_net.modules())[-1].out_features
        if out_dim != len(reaction_rules):
            raise ValueError(
                f"The policy network output dimensionality is {out_dim}, but the number of reaction rules is {len(reaction_rules)}. "
                "Probably you use a different version of the policy network. Be sure to retrain the policy network "
                "with the current set of reaction rules"
            )

        pyg_graph = mol_to_pyg(precursor.molecule, canonicalize=False)
        if not pyg_graph:
            # No usable graph for this molecule: nothing to predict, so yield nothing.
            return

        policy_type = self.policy_net.policy_type
        with torch.no_grad():
            # Call the module rather than .forward() so that module hooks run.
            if policy_type == "filtering":
                probs, priority = self.policy_net(pyg_graph)
            elif policy_type == "ranking":
                probs = self.policy_net(pyg_graph)
            else:
                raise ValueError(
                    f"Unknown policy network type {policy_type!r}; "
                    "expected 'filtering' or 'ranking'"
                )
        del pyg_graph

        # Take the first (only, since batch_size=1) sample of the batch.
        probs = probs[0].double()
        if policy_type == "filtering":
            # Blend rule-applicability probabilities with the priority-rule scores.
            priority = priority[0].double()
            priority_coef = self.config.priority_rules_fraction
            probs = (1 - priority_coef) * probs + priority_coef * priority

        sorted_probs, sorted_rules = torch.sort(probs, descending=True)
        top_rules = self.config.top_rules
        sorted_probs = sorted_probs[:top_rules]
        sorted_rules = sorted_rules[:top_rules]

        if policy_type == "filtering":
            # Renormalise the retained scores so they sum to 1.
            sorted_probs = torch.softmax(sorted_probs, -1)

        for prob, rule_id in zip(sorted_probs.tolist(), sorted_rules.tolist()):
            # Rules at or below rule_prob_threshold are skipped. A threshold set too high
            # can leave no rules to expand and make the search fail (recommended value
            # is 0.0).
            if prob > self.config.rule_prob_threshold:
                yield prob, reaction_rules[rule_id], rule_id