Spaces:
Sleeping
Sleeping
File size: 12,817 Bytes
72a3513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 |
"""Module containing functions for running value network tuning with reinforcement learning
approach."""
import os
import random
from collections import defaultdict
from pathlib import Path
from random import shuffle
from typing import Dict, List
import torch
from CGRtools.containers import MoleculeContainer
from pytorch_lightning import Trainer
from torch.utils.data import random_split
from torch_geometric.data.lightning import LightningDataset
from synplan.chem.precursor import compose_precursors
from synplan.mcts.evaluation import ValueNetworkFunction
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.mcts.tree import Tree
from synplan.ml.networks.value import ValueNetwork
from synplan.ml.training.preprocessing import ValueNetworkDataset
from synplan.utils.config import (
PolicyNetworkConfig,
TuningConfig,
TreeConfig,
ValueNetworkConfig,
)
from synplan.utils.files import MoleculeReader
from synplan.utils.loading import (
load_building_blocks,
load_reaction_rules,
load_value_net,
)
from synplan.utils.logging import DisableLogger, HiddenPrints
def create_value_network(value_config: ValueNetworkConfig) -> ValueNetwork:
    """Instantiate a fresh value network and write its initial checkpoint.

    The hyper-parameters are read from ``value_config``. The untrained network is
    saved to ``value_config.weights_path`` so that later stages, which read
    weights from that same path, start from it.

    :param value_config: The value network configuration.
    :return: The newly constructed ValueNetwork.
    """
    checkpoint_file = Path(value_config.weights_path)
    network = ValueNetwork(
        vector_dim=value_config.vector_dim,
        batch_size=value_config.batch_size,
        dropout=value_config.dropout,
        num_conv_layers=value_config.num_conv_layers,
        learning_rate=value_config.learning_rate,
    )

    # A default Trainer is used here only to write a Lightning checkpoint.
    # Its logging and printed output are suppressed.
    with DisableLogger(), HiddenPrints():
        checkpoint_writer = Trainer()
        checkpoint_writer.strategy.connect(network)
        checkpoint_writer.save_checkpoint(checkpoint_file)

    return network
def create_targets_batch(
    targets: List["MoleculeContainer"], batch_size: int
) -> List[List["MoleculeContainer"]]:
    """Creates the targets batches for planning simulations and value network tuning.

    Batches keep the input order. Every batch holds ``batch_size`` molecules,
    except possibly the last one, which holds the remainder.

    :param targets: The list of target molecules.
    :param batch_size: The maximum number of molecules per batch (must be >= 1).
    :return: The list of lists corresponding to each target batch (empty when
        ``targets`` is empty).
    :raises ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    targets_batch_list = [
        list(targets[start : start + batch_size])
        for start in range(0, len(targets), batch_size)
    ]

    # Report the real number of batches. The old message claimed "1 batch"
    # even for an empty target list.
    num_batches = len(targets_batch_list)
    if num_batches == 1:
        print(f"1 batch was created with {len(targets)} molecules")
    else:
        print(
            f"{num_batches} batches were created with up to {batch_size} molecules each"
        )
    return targets_batch_list
def run_tree_search(
    target: MoleculeContainer,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
) -> Tree:
    """Build and fully explore a search tree for one target molecule.

    The policy function, value function, reaction rules and building blocks are
    all loaded on every call. ``tree_config`` is modified in place:
    ``evaluation_type`` is set to ``"gcn"`` and ``silent`` to ``True``.

    :param target: The target molecule.
    :param tree_config: The planning configuration of tree search.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration (its ``weights_path`` is
        loaded for node evaluation).
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :return: The search tree after its iterator has been exhausted.
    """
    expansion_fn = PolicyNetworkFunction(policy_config=policy_config)
    evaluation_fn = ValueNetworkFunction(weights_path=value_config.weights_path)
    rules = load_reaction_rules(reaction_rules_path)
    blocks = load_building_blocks(building_blocks_path, standardize=True)

    # Force value-network ("gcn") evaluation and quiet output for tuning runs.
    tree_config.evaluation_type = "gcn"
    tree_config.silent = True

    search_tree = Tree(
        target=target,
        config=tree_config,
        reaction_rules=rules,
        building_blocks=blocks,
        expansion_function=expansion_fn,
        evaluation_function=evaluation_fn,
    )
    search_tree._tqdm = False

    # The target is removed from the building blocks. Presumably this keeps the
    # search from treating the target itself as already purchasable.
    target_smiles = str(target)
    if target_smiles in search_tree.building_blocks:
        search_tree.building_blocks.remove(target_smiles)

    # Drain the tree iterator; the tree grows as a side effect of iteration.
    for _ in search_tree:
        pass
    return search_tree
def extract_tree_precursor(tree_list: List[Tree]) -> Dict[str, float]:
    """Takes the built trees and extracts the precursors for value network tuning.

    Precursors from found retrosynthetic routes are labeled as the positive
    class (1.0). For every solved node, the walk goes up its parent chain and
    stops before node 1. Precursors of all other nodes are labeled as the
    negative class (0.0).

    A positive label always takes precedence. If the same composed precursor
    SMILES occurs both on a solved route and in an unsolved node, in any tree
    and in any visiting order, it is labeled 1.0.

    :param tree_list: The list of built search trees.
    :return: The dictionary with the precursor SMILES and its class
        (positive - 1.0 or negative - 0.0), in shuffled key order.
    """
    extracted_precursor: Dict[str, float] = {}
    for tree in tree_list:
        for idx, node in tree.nodes.items():
            if node.is_solved():
                # Label this node and its ancestors (excluding node 1) positive.
                parent = idx
                while parent and parent != 1:
                    composed_smi = str(
                        compose_precursors(tree.nodes[parent].new_precursors)
                    )
                    extracted_precursor[composed_smi] = 1.0
                    parent = tree.parents[parent]
            else:
                # NOTE(review): node 1 is never labeled positive by the walk
                # above but is labeled negative here. Confirm whether the
                # root/target should be excluded altogether.
                composed_smi = str(compose_precursors(node.new_precursors))
                # Never downgrade a precursor already known to be on a solved
                # route. Plain assignment let a later unsolved node overwrite 1.0.
                extracted_precursor.setdefault(composed_smi, 0.0)

    # shuffle extracted precursor
    processed_keys = list(extracted_precursor)
    shuffle(processed_keys)
    return {key: extracted_precursor[key] for key in processed_keys}
def balance_extracted_precursor(
    extracted_precursor: Dict[str, float],
) -> Dict[str, float]:
    """Balance positive and negative precursors by undersampling the negatives.

    All positive precursors (label 1) are kept. Negative precursors (label 0)
    are randomly sampled, without replacement, down to the number of positives.
    If there are fewer negatives than positives, all negatives are kept. The
    relative order of ``extracted_precursor`` is preserved in the result.

    :param extracted_precursor: Mapping of precursor SMILES to label (1.0 or 0.0).
    :return: A new dictionary containing all positives and the sampled negatives.
    """
    positives = [smi for smi, label in extracted_precursor.items() if label == 1]
    negatives = [smi for smi, label in extracted_precursor.items() if label == 0]

    # The sampled negatives must be copied into the result explicitly.
    # Otherwise the tuning set would contain a single class.
    kept_negatives = set(
        random.sample(negatives, min(len(negatives), len(positives)))
    )

    return {
        smi: label
        for smi, label in extracted_precursor.items()
        if label == 1 or smi in kept_negatives
    }
def create_updating_set(
    extracted_precursor: Dict[str, float], batch_size: int = 1
) -> LightningDataset:
    """Build the train/validation datamodule used to update the value network.

    The labelled precursors are first passed through
    ``balance_extracted_precursor``, then wrapped in a ValueNetworkDataset. The
    dataset is split 60/40 into training and validation subsets. The split
    permutation uses a fixed seed (42).

    :param extracted_precursor: The dictionary with the extracted precursors and
        their labels.
    :param batch_size: The size of the batch in value network updating.
    :return: A LightningDataset holding the training and validation subsets.
    """
    balanced = balance_extracted_precursor(extracted_precursor)
    dataset = ValueNetworkDataset(balanced)

    n_total = len(dataset)
    n_train = int(0.6 * n_total)
    split_generator = torch.Generator().manual_seed(42)
    train_subset, val_subset = random_split(
        dataset, [n_train, n_total - n_train], split_generator
    )

    print(f"Training set size: {len(train_subset)}")
    print(f"Validation set size: {len(val_subset)}")

    # NOTE(review): drop_last=True presumably reaches the DataLoaders, so a
    # split smaller than batch_size yields no batches at all.
    return LightningDataset(
        train_subset,
        val_subset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )
def tune_value_network(
    datamodule: LightningDataset, value_config: ValueNetworkConfig
) -> None:
    """Trains the value network on the given tuning data and saves it in place.

    The weights are loaded from ``value_config.weights_path``. The network is
    trained for ``value_config.num_epoch`` epochs and validated on the
    datamodule's validation loader. The checkpoint is then overwritten at the
    same path.

    Training runs on CUDA device 0 when CUDA is available and on the CPU
    otherwise. Previously a GPU was hard-coded, which failed on CPU-only hosts.

    :param datamodule: The tuning dataset (LightningDataset).
    :param value_config: The value network configuration.
    :return: None.
    """
    current_weights = value_config.weights_path
    value_network = load_value_net(ValueNetwork, current_weights)

    if torch.cuda.is_available():
        accelerator, devices = "gpu", [0]
    else:
        accelerator, devices = "cpu", 1

    with DisableLogger(), HiddenPrints():
        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=value_config.num_epoch,
            enable_checkpointing=False,
            logger=False,
            gradient_clip_val=1.0,
            enable_progress_bar=False,
        )
        trainer.fit(value_network, datamodule)
        val_score = trainer.validate(value_network, datamodule.val_dataloader())[0]
        trainer.save_checkpoint(current_weights)

    # Use .get so that a run without a logged metric (e.g. no validation
    # batches) does not crash after the checkpoint has been written.
    print(f"Value network balanced accuracy: {val_score.get('val_balanced_accuracy')}")
def run_training(
    extracted_precursor: Dict[str, float] = None,
    value_config: ValueNetworkConfig = None,
) -> None:
    """Execute one training round of value network tuning.

    Builds the train/validation datamodule from the extracted precursors, using
    ``value_config.batch_size``. It then fine-tunes the value network stored at
    ``value_config.weights_path``. Despite the ``None`` defaults, both arguments
    are needed: ``value_config.batch_size`` is read unconditionally.

    :param extracted_precursor: The precursors extracted from the planning
        simulations, mapped to their labels.
    :param value_config: The value network configuration.
    :return: None.
    """
    tune_value_network(
        datamodule=create_updating_set(
            extracted_precursor=extracted_precursor,
            batch_size=value_config.batch_size,
        ),
        value_config=value_config,
    )
def run_planning(
    targets_batch: List[MoleculeContainer],
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    targets_batch_id: int,
) -> List[Tree]:
    """Performs the planning stage (tree search) for a batch of target molecules.

    The built trees are returned so that precursors can later be extracted from
    them for the value network training stage. A target whose search raises an
    exception is reported and skipped. The returned list may therefore be
    shorter than ``targets_batch``.

    :param targets_batch: The target molecules of this batch.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param targets_batch_id: The batch number, used only in progress messages.
    :return: The list of successfully built search trees.
    """
    from tqdm import tqdm

    print(f"\nProcess batch number {targets_batch_id}")
    tree_list = []
    # NOTE: run_tree_search sets tree_config.silent back to True for each search.
    tree_config.silent = False
    for target in tqdm(targets_batch):
        try:
            tree = run_tree_search(
                target=target,
                tree_config=tree_config,
                policy_config=policy_config,
                value_config=value_config,
                reaction_rules_path=reaction_rules_path,
                building_blocks_path=building_blocks_path,
            )
        except Exception as e:  # best-effort: one failing target must not abort the batch
            print(f"Tree search failed for target {target}: {e}")
            continue
        tree_list.append(tree)

    num_solved = sum(len(tree.winning_nodes) > 0 for tree in tree_list)
    print(f"Planning is finished with {num_solved} solved targets")
    return tree_list
def run_updating(
    targets_path: str,
    tree_config: TreeConfig,
    policy_config: PolicyNetworkConfig,
    value_config: ValueNetworkConfig,
    reinforce_config: TuningConfig,
    reaction_rules_path: str,
    building_blocks_path: str,
    results_root: str = None,
) -> None:
    """Performs updating of value network.

    A fresh value network is created under ``results_root``. The targets are
    then processed in batches. Each batch is planned, the labelled precursors
    are extracted from the built trees, and the network is tuned on them. The
    checkpoint is overwritten after every batch.

    :param targets_path: The path to the file with target molecules.
    :param tree_config: The search tree configuration.
    :param policy_config: The policy network configuration.
    :param value_config: The value network configuration. Its ``weights_path``
        is overwritten to point into ``results_root``.
    :param reinforce_config: The value network tuning configuration.
    :param reaction_rules_path: The path to the file with reaction rules.
    :param building_blocks_path: The path to the file with building blocks.
    :param results_root: The path to the directory where trained value network
        will be saved. It is created, with missing parents, if absent.
    :return: None.
    :raises ValueError: If ``results_root`` is not given.
    """
    if results_root is None:
        raise ValueError(
            "results_root must be provided: the value network checkpoint is saved there"
        )

    # create results root folder (exist_ok avoids an exists()/mkdir() race)
    results_root = Path(results_root)
    results_root.mkdir(parents=True, exist_ok=True)

    # load targets list
    with MoleculeReader(targets_path) as reader:
        targets = list(reader)

    # create value neural network
    value_config.weights_path = os.path.join(results_root, "value_network.ckpt")
    create_value_network(value_config)

    # create targets batch
    targets_batch_list = create_targets_batch(
        targets, batch_size=reinforce_config.batch_size
    )

    # run value network tuning
    for batch_id, targets_batch in enumerate(targets_batch_list, start=1):
        # start tree planning simulation for batch of targets
        tree_list = run_planning(
            targets_batch=targets_batch,
            tree_config=tree_config,
            policy_config=policy_config,
            value_config=value_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            targets_batch_id=batch_id,
        )

        # extract pos and neg precursor from the list of built trees
        extracted_precursor = extract_tree_precursor(tree_list)
        if not extracted_precursor:
            # e.g. every tree search in this batch failed; nothing to train on
            print(f"No precursors extracted from batch {batch_id}, skipping training")
            continue

        # train value network for extracted precursor
        run_training(extracted_precursor=extracted_precursor, value_config=value_config)
|