Spaces:
Sleeping
Sleeping
| """Module containing functions for running value network tuning with reinforcement learning | |
| approach.""" | |
| import os | |
| import random | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from random import shuffle | |
| from typing import Dict, List | |
| import torch | |
| from CGRtools.containers import MoleculeContainer | |
| from pytorch_lightning import Trainer | |
| from torch.utils.data import random_split | |
| from torch_geometric.data.lightning import LightningDataset | |
| from synplan.chem.precursor import compose_precursors | |
| from synplan.mcts.evaluation import ValueNetworkFunction | |
| from synplan.mcts.expansion import PolicyNetworkFunction | |
| from synplan.mcts.tree import Tree | |
| from synplan.ml.networks.value import ValueNetwork | |
| from synplan.ml.training.preprocessing import ValueNetworkDataset | |
| from synplan.utils.config import ( | |
| PolicyNetworkConfig, | |
| TuningConfig, | |
| TreeConfig, | |
| ValueNetworkConfig, | |
| ) | |
| from synplan.utils.files import MoleculeReader | |
| from synplan.utils.loading import ( | |
| load_building_blocks, | |
| load_reaction_rules, | |
| load_value_net, | |
| ) | |
| from synplan.utils.logging import DisableLogger, HiddenPrints | |
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Build a new value network and write its initial checkpoint.

    The network hyperparameters (``vector_dim``, ``batch_size``, ``dropout``,
    ``num_conv_layers`` and ``learning_rate``) are read from ``value_config``.
    The freshly built network is then saved as a checkpoint to
    ``value_config.weights_path``.

    :param value_config: The value network configuration.
    :return: The ValueNetwork to be trained/tuned.
    """
    checkpoint_path = Path(value_config.weights_path)

    hyperparameter_names = (
        "vector_dim",
        "batch_size",
        "dropout",
        "num_conv_layers",
        "learning_rate",
    )
    network = ValueNetwork(
        **{name: getattr(value_config, name) for name in hyperparameter_names}
    )

    # A default Trainer is attached to the network only to write the checkpoint.
    # NOTE(review): DisableLogger/HiddenPrints presumably silence logging and
    # stdout while this runs - confirm in synplan.utils.logging.
    with DisableLogger(), HiddenPrints():
        checkpoint_writer = Trainer()
        checkpoint_writer.strategy.connect(network)
        checkpoint_writer.save_checkpoint(checkpoint_path)

    return network
def create_targets_batch(
    targets: List[MoleculeContainer], batch_size: int
) -> List[List[MoleculeContainer]]:
    """Split the targets into consecutive batches for planning and tuning.

    Every batch holds ``batch_size`` targets except possibly the last one,
    which holds the remainder. An empty ``targets`` list yields no batches.

    :param targets: The list of target molecules.
    :param batch_size: The size of each target batch (must be at least 1).
    :return: The list of lists corresponding to each target batch.
    :raises ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_targets = len(targets)
    targets_batch_list = [
        list(targets[start : start + batch_size])
        for start in range(0, num_targets, batch_size)
    ]

    # Report what was actually built; the last batch may be smaller
    if len(targets_batch_list) == 1:
        print(f"1 batch was created with {num_targets} molecules")
    else:
        print(
            f"{len(targets_batch_list)} batches were created "
            f"with up to {batch_size} molecules each"
        )

    return targets_batch_list
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Build and fully run a search tree for one target molecule.

    Note that ``tree_config`` is modified in place: ``evaluation_type`` is set
    to ``"gcn"`` and ``silent`` is set to ``True`` before the tree is built.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration; its ``weights_path``
        is used to load the value function.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The built search tree for the given molecule.
    """
    # Load the expansion/evaluation functions and the chemistry data
    expansion = PolicyNetworkFunction(policy_config=policy_config)
    evaluation = ValueNetworkFunction(weights_path=value_config.weights_path)
    rules = load_reaction_rules(reaction_rules_path)
    blocks = load_building_blocks(building_blocks_path, standardize=True)

    # Force the settings used for tuning runs (mutates the caller's config)
    tree_config.evaluation_type = "gcn"
    tree_config.silent = True

    search_tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=rules,
        building_blocks=blocks,
        expansion_function=expansion,
        evaluation_function=evaluation,
    )
    search_tree._tqdm = False

    # The target itself must not count as an available building block
    target_smiles = str(target)
    if target_smiles in search_tree.building_blocks:
        search_tree.building_blocks.remove(target_smiles)

    # Exhaust the tree iterator; this is what runs the search
    for _ in search_tree:
        pass

    return search_tree
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Extract labelled precursors from built trees for value network tuning.

    For every solved node, the node and each of its ancestors (walking up
    through ``tree.parents`` until node id 1 or a falsy id is reached) are
    labelled as the positive class. Every node that is not solved is labelled
    as the negative class. Precursors are keyed by the string form of
    ``compose_precursors(node.new_precursors)``.

    A positive label always takes precedence. A precursor seen on a solved
    route is never downgraded to negative by another unsolved node or another
    tree with the same composed SMILES, so labels do not depend on iteration
    order.

    :param tree_list: The list of built search trees.
    :return: The dictionary with the precursor SMILES and its class
        (positive - 1.0 or negative - 0.0), in shuffled key order.
    """
    labels: Dict[str, float] = {}
    for tree in tree_list:
        for node_id, node in tree.nodes.items():
            if node.is_solved():
                # Mark the whole path from the solved node upwards as positive
                current = node_id
                while current and current != 1:
                    composed_smi = str(
                        compose_precursors(tree.nodes[current].new_precursors)
                    )
                    labels[composed_smi] = 1.0
                    current = tree.parents[current]
            else:
                composed_smi = str(compose_precursors(node.new_precursors))
                # Only record a negative if no label exists yet, so an
                # existing positive label is kept
                labels.setdefault(composed_smi, 0.0)

    # Shuffle the extracted precursors
    shuffled_keys = list(labels)
    shuffle(shuffled_keys)
    return {key: labels[key] for key in shuffled_keys}
def balance_extracted_precursor(
    extracted_precursor: Dict[str, float],
) -> Dict[str, float]:
    """Balance the positive and negative classes of the extracted precursors.

    All positive precursors (label 1) are kept. Negative precursors (label 0)
    are randomly down-sampled so that there are at most as many negatives as
    positives. If there are fewer negatives than positives, all negatives are
    kept. The input dictionary is not modified.

    If the input holds no positive precursors, the result is empty.

    :param extracted_precursor: The dictionary with the precursor SMILES and
        its class (positive - 1 or negative - 0).
    :return: A new dictionary with all positives followed by the sampled
        negatives.
    """
    positives = [smi for smi, label in extracted_precursor.items() if label == 1]
    negatives = [smi for smi, label in extracted_precursor.items() if label == 0]

    # Down-sample the negative class to the size of the positive class
    num_negatives = min(len(positives), len(negatives))
    kept_negatives = random.sample(negatives, num_negatives)

    balanced = {smi: extracted_precursor[smi] for smi in positives}
    balanced.update((smi, extracted_precursor[smi]) for smi in kept_negatives)
    return balanced
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Turn the extracted precursors into a train/validation tuning dataset.

    The precursors are first passed through ``balance_extracted_precursor``,
    wrapped in a ``ValueNetworkDataset`` and then split 60/40 into training
    and validation parts with a fixed random seed (42). The sizes of both
    parts are printed.

    :param extracted_precursor: The dictionary with the extracted precursor and their
        labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset object, which contains the tuning set for value network
        tuning.
    """
    balanced = balance_extracted_precursor(extracted_precursor)
    dataset = ValueNetworkDataset(balanced)

    # 60% for training, the remainder for validation
    total = len(dataset)
    n_train = int(0.6 * total)
    n_val = total - n_train

    # Fixed seed keeps the split reproducible between runs
    split_generator = torch.Generator().manual_seed(42)
    train_part, val_part = random_split(dataset, [n_train, n_val], split_generator)

    print(f"Training set size: {len(train_part)}")
    print(f"Validation set size: {len(val_part)}")

    return LightningDataset(
        train_part,
        val_part,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Train the value network on the tuning data and save it.

    The current weights are loaded from ``value_config.weights_path``, the
    network is fitted for ``value_config.num_epoch`` epochs and validated on
    the validation loader of ``datamodule``. The checkpoint is then written
    back to the same path and the validation balanced accuracy is printed.

    Training runs on the first GPU when one is available and falls back to
    the CPU otherwise, so the function also works on machines without a GPU.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """
    current_weights = value_config.weights_path
    value_network = load_value_net(ValueNetwork, current_weights)

    # Keep the original "gpu"/[0] setup wherever a GPU backend exists (CUDA or
    # Apple MPS); otherwise use a single CPU device instead of failing.
    mps_backend = getattr(torch.backends, "mps", None)
    gpu_available = torch.cuda.is_available() or (
        mps_backend is not None and mps_backend.is_available()
    )
    if gpu_available:
        accelerator, devices = "gpu", [0]
    else:
        accelerator, devices = "cpu", 1

    with DisableLogger(), HiddenPrints():
        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )

        trainer.fit(value_network, datamodule)
        # validate() returns a list of metric dicts; only one loader is used
        val_score = trainer.validate(value_network, datamodule.val_dataloader())[0]
        # Overwrite the checkpoint so the next planning round uses the new weights
        trainer.save_checkpoint(current_weights)

    # Printed outside the HiddenPrints context so it stays visible
    print(f"Value network balanced accuracy: {val_score['val_balanced_accuracy']}")
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Run the training stage of value network tuning.

    Builds the tuning dataset from the extracted precursors and retrains the
    value network on it. Although both parameters default to ``None``,
    ``value_config`` must be provided because its ``batch_size`` is read here.

    :param extracted_precursor: The precursor extracted from the planning simulations.
    :param value_config: The value network configuration.
    :return: None.
    """
    batch_size = value_config.batch_size

    # Build the dataset, then retrain the value network on it
    tuning_data = create_updating_set(
        extracted_precursor=extracted_precursor,
        batch_size=batch_size,
    )
    tune_value_network(datamodule=tuning_data, value_config=value_config)
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
):
    """Run the planning stage (tree search) for one batch of target molecules.

    A tree search is run for every target. A target whose search raises an
    exception is skipped after the exception is printed. The number of trees
    with at least one winning node is printed at the end.

    :param targets_batch: The target molecules of the current batch.
    :param tree_config: The search tree configuration; its ``silent`` flag is
        set to ``False`` here before the searches start.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The number of the batch, used only in the log
        message.
    :return: The list of trees built for the targets that did not fail.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_config.silent = False

    built_trees = []
    for molecule in tqdm(targets_batch):
        try:
            built_trees.append(
                run_tree_search(
                    target=molecule,
                    tree_config=tree_config,
                    policy_config=policy_config,
                    value_config=value_config,
                    reaction_rules_path=reaction_rules_path,
                    building_blocks_path=building_blocks_path,
                )
            )
        except Exception as error:
            # Best effort: one failing target must not abort the whole batch
            print(error)

    solved_count = sum(1 for built in built_trees if len(built.winning_nodes) > 0)
    print(f"Planning is finished with {solved_count} solved targets")

    return built_trees
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Perform the updating (tuning) of the value network.

    A fresh value network is created and saved to
    ``<results_root>/value_network.ckpt``. Then, for every batch of targets, a
    planning stage (tree search) is followed by a training stage on the
    precursors extracted from the built trees.

    Note that ``value_config.weights_path`` is overwritten with the checkpoint
    path inside ``results_root``.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reinforce_config: The value network tuning configuration; its
        ``batch_size`` sets the number of targets per batch.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where trained value network will be
        saved. Must be provided despite the ``None`` default; the directory and
        any missing parent directories are created.
    :return: None.
    """
    # Create the results folder, including missing parents; an already
    # existing directory is fine
    results_root = Path(results_root)
    results_root.mkdir(parents=True, exist_ok=True)

    # Load the targets list
    with MoleculeReader(targets_path) as reader:
        targets = list(reader)

    # Create the initial value network inside the results folder
    value_config.weights_path = os.path.join(results_root, "value_network.ckpt")
    create_value_network(value_config)

    # Split the targets into batches
    targets_batch_list = create_targets_batch(
        targets, batch_size=reinforce_config.batch_size
    )

    # Alternate planning and training, one round per batch
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):
        # Planning simulation for the current batch of targets
        tree_list = run_planning(
            targets_batch=targets_batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )

        # Extract positive and negative precursors from the built trees
        extracted_precursor = extract_tree_precursor(tree_list)

        # Train the value network on the extracted precursors
        run_training(extracted_precursor=extracted_precursor, value_config=value_config)