synplanner_dev / synplan /ml /training /reinforcement.py
Gilmullin Almaz
Refactor code structure for improved readability and maintainability
72a3513
"""Module containing functions for running value network tuning with reinforcement learning
approach."""
import os
import random
from collections import defaultdict
from pathlib import Path
from random import shuffle
from typing import Dict, List
import torch
from CGRtools.containers import MoleculeContainer
from pytorch_lightning import Trainer
from torch.utils.data import random_split
from torch_geometric.data.lightning import LightningDataset
from synplan.chem.precursor import compose_precursors
from synplan.mcts.evaluation import ValueNetworkFunction
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.mcts.tree import Tree
from synplan.ml.networks.value import ValueNetwork
from synplan.ml.training.preprocessing import ValueNetworkDataset
from synplan.utils.config import (
PolicyNetworkConfig,
TuningConfig,
TreeConfig,
ValueNetworkConfig,
)
from synplan.utils.files import MoleculeReader
from synplan.utils.loading import (
load_building_blocks,
load_reaction_rules,
load_value_net,
)
from synplan.utils.logging import DisableLogger, HiddenPrints
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Build an untrained value network and store it as the initial checkpoint.

    Hyperparameters are read from ``value_config``; the freshly initialised
    weights are written to ``value_config.weights_path``.

    :param value_config: The value network configuration.
    :return: The newly created ValueNetwork, ready to be trained/tuned.
    """
    checkpoint_path = Path(value_config.weights_path)
    network = ValueNetwork(
        vector_dim=value_config.vector_dim,
        batch_size=value_config.batch_size,
        dropout=value_config.dropout,
        num_conv_layers=value_config.num_conv_layers,
        learning_rate=value_config.learning_rate,
    )

    # Keep Lightning quiet while the initial checkpoint is written.
    with DisableLogger(), HiddenPrints():
        saver = Trainer()
        # Connect the model to the trainer's strategy so it can be
        # checkpointed without running fit().
        saver.strategy.connect(network)
        saver.save_checkpoint(checkpoint_path)

    return network
def create_targets_batch(
    targets: List[MoleculeContainer], batch_size: int
) -> List[List[MoleculeContainer]]:
    """Splits the target molecules into consecutive batches.

    Every batch holds ``batch_size`` molecules except possibly the last one,
    which holds the remainder. An empty ``targets`` list yields no batches.

    :param targets: The list of target molecules.
    :param batch_size: The maximum number of molecules per batch (must be
        positive).
    :return: The list of target batches, in the original order of ``targets``.
    :raises ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_targets = len(targets)
    targets_batch_list = [
        list(targets[start : start + batch_size])
        for start in range(0, num_targets, batch_size)
    ]

    # Report what was actually created (the last batch may be smaller).
    num_batches = len(targets_batch_list)
    if num_batches == 0:
        print("0 batches were created: the list of targets is empty")
    elif num_batches == 1:
        print(f"1 batch was created with {num_targets} molecules")
    else:
        last_size = len(targets_batch_list[-1])
        message = (
            f"{num_batches} batches were created with {batch_size} molecules each"
        )
        if last_size != batch_size:
            message += f" (the last batch has {last_size})"
        print(message)

    return targets_batch_list
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Builds a search tree for one target molecule and runs the search on it.

    Note: ``tree_config`` is modified in place: ``evaluation_type`` is set to
    ``"gcn"`` and ``silent`` to ``True``.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration (its ``weights_path``
        is used for the evaluation function).
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The built search tree for the given molecule.
    """
    # Expansion (policy) and evaluation (value) functions.
    expansion = PolicyNetworkFunction(policy_config=policy_config)
    evaluation = ValueNetworkFunction(weights_path=value_config.weights_path)
    rules = load_reaction_rules(reaction_rules_path)
    blocks = load_building_blocks(building_blocks_path, standardize=True)

    tree_config.evaluation_type = "gcn"
    tree_config.silent = True
    tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=rules,
        building_blocks=blocks,
        expansion_function=expansion,
        evaluation_function=evaluation,
    )
    tree._tqdm = False

    # Drop the target from the building blocks (presumably so the search
    # cannot treat the target itself as already available).
    target_smiles = str(target)
    if target_smiles in tree.building_blocks:
        tree.building_blocks.remove(target_smiles)

    # Iterating the tree drives the search until it stops.
    for _ in tree:
        pass
    return tree
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Extracts labeled precursors from built search trees for value network tuning.

    For every solved node, the precursors of that node and of each ancestor
    up to, but excluding, node ``1`` (presumably the root) are labeled as the
    positive class (1.0). The precursors of every unsolved node are labeled as
    the negative class (0.0). If the same composed precursor SMILES occurs with
    both labels, the positive label wins regardless of visiting order.

    :param tree_list: The list of built search trees.
    :return: The dictionary mapping composed precursor SMILES to its class
        (positive - 1.0 or negative - 0.0), in shuffled order.
    """
    labels = {}
    for tree in tree_list:
        for idx, node in tree.nodes.items():
            if node.is_solved():
                # Walk from the solved node up the tree, marking the whole
                # route as positive.
                parent = idx
                while parent and parent != 1:
                    composed_smi = str(
                        compose_precursors(tree.nodes[parent].new_precursors)
                    )
                    labels[composed_smi] = 1.0
                    parent = tree.parents[parent]
            else:
                composed_smi = str(compose_precursors(node.new_precursors))
                # setdefault: never downgrade a precursor already known to be
                # on a solved route (or seen later on one, which overwrites).
                labels.setdefault(composed_smi, 0.0)

    # Shuffle so that precursors from the same tree are not grouped together.
    processed_keys = list(labels)
    shuffle(processed_keys)
    return {key: labels[key] for key in processed_keys}
def balance_extracted_precursor(
    extracted_precursor: Dict[str, float]
) -> Dict[str, float]:
    """Balances the positive and negative classes of the extracted precursors.

    All positive precursors (label 1) are kept. If there are more negative
    precursors (label 0) than positive ones, a random subset of negatives of
    the same size as the positive class is kept; otherwise all negatives are
    kept. Entries with any other label are dropped. The relative order of the
    input dictionary is preserved.

    :param extracted_precursor: The dictionary mapping precursor SMILES to
        their class label (1 or 0).
    :return: The class-balanced dictionary of precursors and their labels.
    """
    positives = [smi for smi, label in extracted_precursor.items() if label == 1]
    negatives = [smi for smi, label in extracted_precursor.items() if label == 0]

    # Downsample negatives to the size of the positive class and include them
    # in the result, so the tuning set actually contains both classes.
    if len(negatives) > len(positives):
        negatives = random.sample(negatives, len(positives))

    selected = set(positives).union(negatives)
    return {
        smi: label for smi, label in extracted_precursor.items() if smi in selected
    }
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Builds the value network tuning dataset from planning-extracted precursors.

    The precursors are first class-balanced via ``balance_extracted_precursor``,
    then split 60/40 into training and validation parts using a fixed seed (42)
    so the split is reproducible.

    :param extracted_precursor: The dictionary with the extracted precursors
        and their labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset holding the training and validation subsets.
    """
    balanced = balance_extracted_precursor(extracted_precursor)
    dataset = ValueNetworkDataset(balanced)

    n_total = len(dataset)
    n_train = int(0.6 * n_total)
    split_rng = torch.Generator().manual_seed(42)
    train_part, val_part = random_split(
        dataset, [n_train, n_total - n_train], generator=split_rng
    )

    print(f"Training set size: {len(train_part)}")
    print(f"Validation set size: {len(val_part)}")

    return LightningDataset(
        train_part,
        val_part,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Fine-tunes the value network on the given data and saves it back.

    The weights are loaded from ``value_config.weights_path``, trained for
    ``value_config.num_epoch`` epochs, evaluated on the validation loader and
    then saved to the same path. Training uses GPU 0 when CUDA is available
    and falls back to the CPU otherwise.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """
    current_weights = value_config.weights_path
    value_network = load_value_net(ValueNetwork, current_weights)

    # A hard-coded GPU makes tuning fail on CPU-only machines; keep GPU 0 when
    # CUDA exists, otherwise run on the CPU.
    use_gpu = torch.cuda.is_available()
    with DisableLogger(), HiddenPrints():
        trainer = Trainer(
            accelerator="gpu" if use_gpu else "cpu",
            devices=[0] if use_gpu else 1,
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )
        trainer.fit(value_network, datamodule)
        val_score = trainer.validate(value_network, datamodule.val_dataloader())[0]
        trainer.save_checkpoint(current_weights)

    # Reported outside HiddenPrints (which presumably suppresses stdout).
    print(f"Value network balanced accuracy: {val_score['val_balanced_accuracy']}")
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Runs the training stage in value network tuning.

    Builds the tuning set from ``extracted_precursor`` and fine-tunes the value
    network stored at ``value_config.weights_path``. If no precursors were
    extracted (e.g. every tree search of the batch failed), training is skipped
    instead of attempting to train on an empty dataset.

    :param extracted_precursor: The precursors extracted from the planning
        simulations, mapped to their class labels.
    :param value_config: The value network configuration (required).
    :return: None.
    :raises ValueError: If ``value_config`` is not given.
    """
    if value_config is None:
        raise ValueError("value_config is required to run value network training")

    if not extracted_precursor:
        print("No precursors were extracted, value network training is skipped")
        return None

    # create training set
    training_set = create_updating_set(
        extracted_precursor=extracted_precursor, batch_size=value_config.batch_size
    )
    # retrain value network
    tune_value_network(datamodule=training_set, value_config=value_config)
    return None
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
) -> List[Tree]:
    """Performs the planning stage (tree search) for a batch of target molecules.

    The returned trees are later used to extract precursors for tuning the
    value network in the training stage. If the search for one target raises,
    the error is reported and that target is skipped, so a single failing
    molecule does not abort the whole batch.

    :param targets_batch: The target molecules of this batch.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The (1-based) number of the batch, used for logging.
    :return: The list of successfully built search trees.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_list = []
    tree_config.silent = False
    for target in tqdm(targets_batch):
        try:
            tree = run_tree_search(
                target=target,
                tree_config=tree_config,
                policy_config=policy_config,
                value_config=value_config,
                reaction_rules_path=reaction_rules_path,
                building_blocks_path=building_blocks_path,
            )
        except Exception as e:  # best-effort: skip the failing target, keep the batch
            print(f"Tree search failed for target {target}: {type(e).__name__}: {e}")
            continue
        tree_list.append(tree)

    num_solved = sum(1 for tree in tree_list if len(tree.winning_nodes) > 0)
    print(f"Planning is finished with {num_solved} solved targets")
    return tree_list
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Performs updating (reinforcement tuning) of the value network.

    A fresh value network is created at ``<results_root>/value_network.ckpt``
    (``value_config.weights_path`` is set to that path). The targets are then
    processed batch by batch: tree search is run for each batch, precursors
    are extracted from the built trees, and the value network is tuned on them.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration (its ``weights_path``
        is overwritten).
    :param reinforce_config: The value network tuning configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where the trained value
        network will be saved (required; created with missing parents).
    :return: None.
    :raises ValueError: If ``results_root`` is not given.
    """
    if results_root is None:
        raise ValueError(
            "results_root must be provided: the tuned value network is saved there"
        )

    # create results root folder (including any missing parent directories)
    results_root = Path(results_root)
    results_root.mkdir(parents=True, exist_ok=True)

    # load targets list
    with MoleculeReader(targets_path) as reader:
        targets = list(reader)

    # create value neural network
    value_config.weights_path = str(results_root / "value_network.ckpt")
    create_value_network(value_config)

    # create targets batch
    targets_batch_list = create_targets_batch(
        targets, batch_size=reinforce_config.batch_size
    )

    # run value network tuning
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):
        # start tree planning simulation for batch of targets
        tree_list = run_planning(
            targets_batch=targets_batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )
        # extract pos and neg precursor from the list of built trees
        extracted_precursor = extract_tree_precursor(tree_list)
        # train value network for extracted precursor
        run_training(extracted_precursor=extracted_precursor, value_config=value_config)