File size: 5,493 Bytes
914ea41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Module for the preparation and training of a policy network used in the expansion of
nodes in tree search.

This module includes functions for creating training datasets and running the training
process for the policy network.
"""

import warnings
from pathlib import Path
from typing import Union, List

import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import random_split
from torch_geometric.data.lightning import LightningDataset

from synplan.ml.networks.policy import PolicyNetwork
from synplan.ml.training.preprocessing import (
    FilteringPolicyDataset,
    RankingPolicyDataset,
)
from synplan.utils.config import PolicyNetworkConfig
from synplan.utils.logging import DisableLogger, HiddenPrints

# NOTE(review): importing this module installs a process-wide "ignore" filter
# covering every warning category, which also affects callers of this module.
warnings.filterwarnings(action="ignore")


def create_policy_dataset(
    reaction_rules_path: str,
    molecules_or_reactions_path: str,
    output_path: str,
    dataset_type: str = "filtering",
    batch_size: int = 100,
    num_cpus: int = 1,
    training_data_ratio: float = 0.8,
):
    """
    Create a training dataset for a policy network.

    The full dataset is built from the reaction rules and either molecules
    (``"filtering"``) or reactions (``"ranking"``). It is split into training and
    validation subsets with a fixed random seed (42), so the split is
    reproducible, and wrapped into a ``LightningDataset``.

    :param reaction_rules_path: Path to the reaction rules file.
    :param molecules_or_reactions_path: Path to the molecules file ('filtering') or
        reactions file ('ranking') used to create the training set.
    :param output_path: Path to store the processed dataset.
    :param dataset_type: Type of the dataset to be created ('ranking' or 'filtering').
    :param batch_size: Number of samples per batch in the data loaders built by
        ``LightningDataset``.
    :param num_cpus: Number of CPUs to use for data processing (only forwarded to
        the 'filtering' dataset).
    :param training_data_ratio: Fraction of the full dataset used for training,
        in (0, 1]; the remainder is used for validation.

    :return: A `LightningDataset` object containing training and validation datasets.
    :raises ValueError: If ``dataset_type`` is not 'filtering' or 'ranking', or if
        ``training_data_ratio`` is outside (0, 1].

    """
    # Validate arguments before starting the (potentially expensive) preprocessing;
    # an unknown dataset_type previously left ``full_dataset`` unbound.
    if dataset_type not in ("filtering", "ranking"):
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected 'filtering' or 'ranking'."
        )
    if not 0 < training_data_ratio <= 1:
        raise ValueError(
            f"training_data_ratio must be in (0, 1], got {training_data_ratio!r}."
        )

    with DisableLogger(), HiddenPrints():
        if dataset_type == "filtering":
            full_dataset = FilteringPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                molecules_path=molecules_or_reactions_path,
                output_path=output_path,
                num_cpus=num_cpus,
            )
        else:  # "ranking" (validated above)
            # NOTE(review): num_cpus is not forwarded here; confirm whether
            # RankingPolicyDataset supports parallel processing.
            full_dataset = RankingPolicyDataset(
                reaction_rules_path=reaction_rules_path,
                reactions_path=molecules_or_reactions_path,
                output_path=output_path,
            )

    train_size = int(training_data_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # A fixed seed keeps the train/validation split reproducible across runs.
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], torch.Generator().manual_seed(42)
    )
    print(
        f"Training set size: {len(train_dataset)}, validation set size: {len(val_dataset)}"
    )

    datamodule = LightningDataset(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        pin_memory=True,
        drop_last=True,
    )

    return datamodule


def run_policy_training(
    datamodule: LightningDataset,
    config: PolicyNetworkConfig,
    results_path: str,
    weights_file_name: str = "policy_network",
    accelerator: str = "gpu",
    devices: Union[List[int], str, int] = "auto",
    silent: bool = False,
) -> None:
    """
    Train a policy network on the given datamodule and checkpoint the best weights.

    The network's number of output rules is taken from the training dataset. A
    ``ModelCheckpoint`` monitors ``val_loss`` (mode ``"min"``) and saves the weights
    into ``results_path`` under the name ``weights_file_name``. After training, the
    last logged training balanced accuracy is printed, if it was logged.

    :param datamodule: A ``LightningDataset`` (e.g. as returned by
        ``create_policy_dataset``). Its ``train_dataset`` must wrap the full dataset
        via a ``.dataset`` attribute exposing ``num_classes``.
    :param config: Policy network configuration; provides ``vector_dim``,
        ``batch_size``, ``dropout``, ``num_conv_layers``, ``learning_rate``,
        ``policy_type`` and ``num_epoch``.
    :param results_path: Directory to store the training results; it is created
        (including missing parent directories) if needed.
    :param weights_file_name: The name of weights file to be saved. Default: "policy_network".
    :param accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances. Default: "gpu".
    :param devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator. Default: "auto".
    :param silent: If True, disable the progress bar and suppress logging and
        printing during fitting. Default: False.

    :return: None.

    """
    results_path = Path(results_path)
    # parents=True so nested result directories can be passed directly.
    results_path.mkdir(parents=True, exist_ok=True)

    network = PolicyNetwork(
        vector_dim=config.vector_dim,
        n_rules=datamodule.train_dataset.dataset.num_classes,
        batch_size=config.batch_size,
        dropout=config.dropout,
        num_conv_layers=config.num_conv_layers,
        learning_rate=config.learning_rate,
        policy_type=config.policy_type,
    )

    checkpoint = ModelCheckpoint(
        dirpath=results_path, filename=weights_file_name, monitor="val_loss", mode="min"
    )

    trainer = Trainer(
        accelerator=accelerator,
        devices=devices,
        max_epochs=config.num_epoch,
        callbacks=[checkpoint],
        logger=False,
        gradient_clip_val=1.0,
        enable_progress_bar=not silent,
    )

    if silent:
        with DisableLogger(), HiddenPrints():
            trainer.fit(network, datamodule)
    else:
        trainer.fit(network, datamodule)

    # The metric may be missing (e.g. if no training step logged it); don't crash
    # with a KeyError after training has already finished and been checkpointed.
    ba_metric = trainer.logged_metrics.get("train_balanced_accuracy_y_step")
    if ba_metric is None:
        print("Policy network balanced accuracy: not available")
    else:
        ba = round(float(ba_metric), 3)
        print(f"Policy network balanced accuracy: {ba}")