Spaces:
Running
Running
Gilmullin Almaz
Refactor code structure and remove redundant sections for improved readability and maintainability
914ea41
| """Module for the preparation and training of a policy network used in the expansion of | |
| nodes in tree search. | |
| This module includes functions for creating training datasets and running the training | |
| process for the policy network. | |
| """ | |
| import warnings | |
| from pathlib import Path | |
| from typing import Union, List | |
| import os | |
| import torch | |
| from pytorch_lightning import Trainer | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from torch.utils.data import random_split | |
| from torch_geometric.data.lightning import LightningDataset | |
| from synplan.ml.networks.policy import PolicyNetwork | |
| from synplan.ml.training.preprocessing import ( | |
| FilteringPolicyDataset, | |
| RankingPolicyDataset, | |
| ) | |
| from synplan.utils.config import PolicyNetworkConfig | |
| from synplan.utils.logging import DisableLogger, HiddenPrints | |
| warnings.filterwarnings("ignore") | |
def create_policy_dataset(
    reaction_rules_path: str,
    molecules_or_reactions_path: str,
    output_path: str,
    dataset_type: str = "filtering",
    batch_size: int = 100,
    num_cpus: int = 1,
    training_data_ratio: float = 0.8,
):
    """
    Create a training dataset for a policy network.

    :param reaction_rules_path: Path to the reaction rules file.
    :param molecules_or_reactions_path: Path to the molecules or reactions file used to create the training set.
    :param output_path: Path to store the processed dataset.
    :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: The size of batch of molecules/reactions.
    :param training_data_ratio: Ratio of training data to total data.
    :param num_cpus: Number of CPUs to use for data processing.
    :return: A `LightningDataset` object containing training and validation datasets.
    :raises ValueError: If ``dataset_type`` is not 'filtering' or 'ranking'.
    """
    # Fail fast on an unknown dataset type: previously `full_dataset` was left
    # unbound and the function crashed later with a confusing NameError.
    if dataset_type not in ("filtering", "ranking"):
        raise ValueError(
            f"Invalid dataset_type: {dataset_type!r}. Expected 'filtering' or 'ranking'."
        )

    # Dataset construction is noisy; suppress logging and stdout while building.
    with DisableLogger(), HiddenPrints():
        if dataset_type == "filtering":
            full_dataset = FilteringPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                molecules_path=molecules_or_reactions_path,
                output_path=output_path,
                num_cpus=num_cpus,
            )
        else:  # "ranking"
            full_dataset = RankingPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                reactions_path=molecules_or_reactions_path,
                output_path=output_path,
            )

    train_size = int(training_data_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size
    # Fixed seed keeps the train/validation split reproducible across runs.
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )
    print(
        f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
    )
    datamodule = LightningDataset(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )
    return datamodule
def run_policy_training(
    datamodule: LightningDataset,
    config: PolicyNetworkConfig,
    results_path: str,
    weights_file_name: str = "policy_network",
    accelerator: str = "gpu",
    devices: Union[List[int], str, int] = "auto",
    silent: bool = False,
) -> None:
    """
    Trains a policy network using a given datamodule and training configuration.

    :param datamodule: A PyTorch Lightning `DataModule` class instance. It is responsible for loading, processing, and preparing the training data for the model.
    :param config: The dictionary that contains various configuration settings for the policy training process.
    :param results_path: Path to store the training results and logs.
    :param weights_file_name: The name of weights file to be saved. Default: "policy_network".
    :param accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances. Default: "gpu".
    :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
    :param silent: Run in the silent mode with no progress bars. Default: False.
    :return: None.
    """
    results_path = Path(results_path)
    # parents=True lets a nested results directory be created in one call
    # (plain mkdir would raise FileNotFoundError for a missing parent).
    results_path.mkdir(parents=True, exist_ok=True)

    network = PolicyNetwork(
        vector_dim=config.vector_dim,
        # The number of output classes is taken from the underlying dataset
        # wrapped by the train split of the datamodule.
        n_rules=datamodule.train_dataset.dataset.num_classes,
        batch_size=config.batch_size,
        dropout=config.dropout,
        num_conv_layers=config.num_conv_layers,
        learning_rate=config.learning_rate,
        policy_type=config.policy_type,
    )

    # Keep only the checkpoint with the lowest validation loss.
    checkpoint = ModelCheckpoint(
        dirpath=results_path, filename=weights_file_name, monitor="val_loss", mode="min"
    )

    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=config.num_epoch,
        callbacks=[checkpoint],
        logger=False,
        gradient_clip_val=1.0,
        enable_progress_bar=not silent,
    )

    if silent:
        # Suppress all logging and stdout produced during fitting.
        with DisableLogger(), HiddenPrints():
            trainer.fit(network, datamodule)
    else:
        trainer.fit(network, datamodule)

    ba = round(trainer.logged_metrics["train_balanced_accuracy_y_step"].item(), 3)
    print(f"Policy network balanced accuracy: {ba}")