Spaces:
Sleeping
Sleeping
File size: 3,873 Bytes
72a3513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
"""Module containing a class that represents a policy function for node expansion in the
tree search."""
from typing import Iterator, List, Tuple, Union
import torch
import torch_geometric
from CGRtools.reactor.reactor import Reactor
from synplan.chem.precursor import Precursor
from synplan.ml.networks.policy import PolicyNetwork
from synplan.ml.training import mol_to_pyg
from synplan.utils.config import PolicyNetworkConfig
class PolicyNetworkFunction:
"""Policy function implemented as a policy neural network for node expansion in tree
search."""
def __init__(
self, policy_config: PolicyNetworkConfig, compile: bool = False
) -> None:
"""Initializes the expansion function (ranking or filter policy network).
:param policy_config: An expansion policy configuration.
:param compile: Is supposed to speed up the training with model compilation.
"""
self.config = policy_config
policy_net = PolicyNetwork.load_from_checkpoint(
self.config.weights_path,
map_location=torch.device("cpu"),
batch_size=1,
dropout=0,
)
policy_net = policy_net.eval()
if compile:
self.policy_net = torch_geometric.compile(policy_net, dynamic=True)
else:
self.policy_net = policy_net
def predict_reaction_rules(
self, precursor: Precursor, reaction_rules: List[Reactor]
) -> Iterator[Union[Iterator, Iterator[Tuple[float, Reactor, int]]]]:
"""The policy function predicts the list of reaction rules for a given precursor.
:param precursor: The current precursor for which the reaction rules are predicted.
:param reaction_rules: The list of reaction rules from which applicable reaction
rules are predicted and selected.
:return: Yielding the predicted probability for the reaction rule, reaction rule
and reaction rule id.
"""
out_dim = list(self.policy_net.modules())[-1].out_features
if out_dim != len(reaction_rules):
raise Exception(
f"The policy network output dimensionality is {out_dim}, but the number of reaction rules is {len(reaction_rules)}. "
"Probably you use a different version of the policy network. Be sure to retain the policy network "
"with the current set of reaction rules"
)
pyg_graph = mol_to_pyg(precursor.molecule, canonicalize=False)
if pyg_graph:
with torch.no_grad():
if self.policy_net.policy_type == "filtering":
probs, priority = self.policy_net.forward(pyg_graph)
if self.policy_net.policy_type == "ranking":
probs = self.policy_net.forward(pyg_graph)
del pyg_graph
else:
return []
probs = probs[0].double()
if self.policy_net.policy_type == "filtering":
priority = priority[0].double()
priority_coef = self.config.priority_rules_fraction
probs = (1 - priority_coef) * probs + priority_coef * priority
sorted_probs, sorted_rules = torch.sort(probs, descending=True)
sorted_probs, sorted_rules = (
sorted_probs[: self.config.top_rules],
sorted_rules[: self.config.top_rules],
)
if self.policy_net.policy_type == "filtering":
sorted_probs = torch.softmax(sorted_probs, -1)
sorted_probs, sorted_rules = sorted_probs.tolist(), sorted_rules.tolist()
for prob, rule_id in zip(sorted_probs, sorted_rules):
if (
prob > self.config.rule_prob_threshold
): # search may fail if rule_prob_threshold is too low (recommended value is 0.0)
yield prob, reaction_rules[rule_id], rule_id
|