File size: 9,597 Bytes
994fb49
6c3d8a1
994fb49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c0063e
 
 
 
 
6c3d8a1
994fb49
 
 
 
 
 
 
 
 
6c3d8a1
 
 
 
 
 
994fb49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c3d8a1
994fb49
 
 
f695c70
994fb49
 
f695c70
994fb49
 
f695c70
 
 
994fb49
 
 
f695c70
994fb49
 
 
 
 
 
 
f695c70
994fb49
 
 
 
f695c70
994fb49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c3d8a1
 
994fb49
6c3d8a1
 
994fb49
 
 
 
 
 
 
 
f695c70
 
994fb49
f695c70
994fb49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import argparse
import copy
import datetime
import os
import sys
import yaml
import optuna
import time
import logging
import pandas as pd

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.callbacks import Callback

from flare.data.data_module import ContrastiveDataModule
from flare.data.datasets import ContrastiveDataset
from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from flare.utils.models import get_model
from flare.definitions import TEST_RESULTS_DIR
from flare.utils.config import default_param_path, load_param_file
from functools import partial
from rdkit import RDLogger
from massspecgym.models.base import Stage

# Suppress RDKit warnings: raise the threshold of RDKit's logger so that only
# CRITICAL-level messages are emitted (featurization otherwise prints warnings).
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

def _build_arg_parser():
    """Create the command-line parser for this Optuna search script.

    Options:
        --param_pth: base YAML config; when omitted (``None``) ``main`` falls
            back to ``default_param_path()``.
        --n_trials: number of Optuna trials to run in this session (default 20).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--param_pth",
        default=None,
        type=str,
        help="Base YAML (default: FLARE_PARAMS or repo params.yaml)",
    )
    cli.add_argument("--n_trials", default=20, type=int)
    return cli


parser = _build_arg_parser()

class EpochLossTracker(Callback):
    """Collect per-epoch training and validation losses for an Optuna trial.

    ``history`` holds two lists of floats, ``"train_loss"`` and ``"val_loss"``.
    A value is appended whenever the corresponding metric is present in
    ``trainer.callback_metrics`` at the end of an epoch.

    A snapshot of the history is stored on the trial as the ``"loss_history"``
    user attribute after every recorded value, and again at ``on_fit_end``.
    Publishing eagerly matters: when Optuna prunes a trial (or training
    crashes), an exception propagates out of ``fit`` and ``on_fit_end`` is
    never called. Publishing only at the end would lose the losses that
    ``save_trial_result`` reads for those trials.

    NOTE(review): validation epochs run by ``trainer.validate`` are recorded
    too. So when this callback is attached to a pre-training ``validate``
    call, the first ``val_loss`` entry is that baseline value.
    """

    def __init__(self, trial):
        super().__init__()
        self.trial = trial
        self.history = {"train_loss": [], "val_loss": []}

    def _record(self, trainer, metric_key, history_key):
        """Append ``metric_key`` from the trainer's callback metrics, if logged, and publish."""
        metrics = trainer.callback_metrics
        if metric_key not in metrics:
            return
        self.history[history_key].append(float(metrics[metric_key].cpu().item()))
        self._publish()

    def _publish(self):
        """Store a snapshot of ``history`` on the trial."""
        # Copy the lists so later appends never mutate an already-stored value.
        self.trial.set_user_attr(
            "loss_history", {key: list(values) for key, values in self.history.items()}
        )

    def on_train_epoch_end(self, trainer, pl_module):
        self._record(trainer, "train_loss", "train_loss")

    def on_validation_epoch_end(self, trainer, pl_module):
        # Lightning runs a sanity-check validation (a limited number of
        # batches) before the first training epoch. It is not a real epoch,
        # so skip it, just as Optuna's pruning callback does.
        if trainer.sanity_checking:
            return
        self._record(trainer, f"{Stage.VAL.to_pref()}loss", "val_loss")

    def on_fit_end(self, trainer, pl_module):
        # Final publish; also attaches the (possibly empty) history when no
        # epoch logged a loss.
        self._publish()



class SafePruningCallback(PyTorchLightningPruningCallback, Callback):
    """Optuna pruning callback that is also a ``pytorch_lightning`` Callback.

    Mixing in ``pytorch_lightning.callbacks.Callback`` lets this script's
    Trainer accept Optuna's callback. NOTE(review): presumably needed because
    the Optuna integration subclasses a different Lightning package's
    ``Callback``. Confirm against the installed versions.

    Intermediate values are forwarded to Optuna only while ``trainer.fit`` is
    running. A standalone ``trainer.validate`` run with this callback attached
    would otherwise report the untrained model's loss as step 0. That can prune
    the trial before it trains, and the real epoch-0 value reported later
    during ``fit`` is ignored by Optuna as a duplicate step.
    """

    def on_validation_end(self, trainer, pl_module):
        from pytorch_lightning.trainer.states import TrainerFn

        # ``trainer.state.fn`` is FITTING only inside ``trainer.fit``;
        # ``trainer.validate`` sets it to VALIDATING.
        if trainer.state.fn == TrainerFn.FITTING:
            super().on_validation_end(trainer, pl_module)


def setup_logging(log_path):
    """Route root-logger output to ``log_path`` and to stderr at INFO level.

    Any handlers already attached to the root logger are detached *and
    closed*. Calling this again (e.g. re-running ``main`` in an interactive
    session) therefore neither duplicates log lines nor leaks the previous
    log file's open handle.

    Args:
        log_path: File that log records are appended to (opened with mode "a").
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Detach and close previous handlers. Clearing ``root.handlers`` alone
    # would leave any FileHandler's file descriptor open.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    file_handler = logging.FileHandler(log_path, mode="a")
    file_handler.setFormatter(formatter)

    # Console output goes to stderr. NOTE(review): intended to keep log lines
    # from disturbing the training progress bar; confirm which stream the
    # progress bar in use writes to.
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setFormatter(formatter)

    root.addHandler(file_handler)
    root.addHandler(console_handler)


def save_trial_result(base_dir, trial, params, duration):
    """Append one trial's summary row to ``<base_dir>/trial_history.csv``.

    The row contains ``number``, ``duration_sec``, the per-epoch ``train_loss``
    and ``val_loss`` lists (empty if the trial has no ``"loss_history"`` user
    attribute), plus one column per sampled hyperparameter.

    Existing rows are re-read and concatenated, so columns stay aligned by
    name. The combined table is written to a temporary file and then moved
    over the history with ``os.replace``. An interrupted write therefore
    cannot truncate the rows of earlier trials.

    Args:
        base_dir: Experiment directory that holds ``trial_history.csv``.
        trial: Optuna trial; ``number``, ``params`` and ``user_attrs`` are read.
        params: Unused; kept so existing call sites keep working.
        duration: Wall-clock duration of the trial in seconds.
    """
    history_path = os.path.join(base_dir, "trial_history.csv")

    # Fetch losses from trial user_attrs
    loss_hist = trial.user_attrs.get("loss_history", {})
    table = pd.DataFrame([{
        "number": trial.number,
        "duration_sec": duration,
        "train_loss": loss_hist.get("train_loss", []),
        "val_loss": loss_hist.get("val_loss", []),
        **trial.params,
    }])

    if os.path.exists(history_path):
        table = pd.concat([pd.read_csv(history_path), table], ignore_index=True)

    # Write-then-rename so the previous history survives a crash mid-write.
    tmp_path = history_path + ".tmp"
    table.to_csv(tmp_path, index=False)
    os.replace(tmp_path, history_path)


def _suggest_hparams(trial, params):
    """Sample this study's search space from ``trial`` and write it into ``params`` (in place).

    Two Optuna parameter names differ from the config keys they feed:
    "contrastive_temp" -> ``contr_temp`` and "peak_dropout" -> ``formula_dropout``.
    Layer-width lists are sampled as comma-separated strings and parsed to
    lists of ints. ``final_embedding_dim`` is then appended to both
    ``formula_dims`` and ``gnn_channels`` so the last layers of both encoders
    match. The order of the ``suggest_*`` calls is kept unchanged.
    """
    # Training-related params
    params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
    params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.01, 0.1)

    # Spectra encoder-related params
    params["formula_dropout"] = trial.suggest_float("peak_dropout", 0.1, 0.5)
    params["formula_attn_heads"] = trial.suggest_categorical("formula_attn_heads", [2, 4, 8])
    params["formula_transformer_layers"] = trial.suggest_categorical(
        "formula_transformer_layers", [1, 2, 3, 4, 5]
    )
    formula_choice = trial.suggest_categorical(
        "formula_dims",
        ["64,128", "512,256", "256,512", "128", "256", "128,128", "512,512", "64,64,64,64"],
    )
    params["formula_dims"] = [int(x) for x in formula_choice.split(",")]

    # Molecule encoder-related params
    params["gnn_dropout"] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
    gnn_choice = trial.suggest_categorical(
        "gnn_channels",
        ["64,128", "128,256", "256,512", "64,128,128", "128,128", "64,64,64"],
    )
    params["gnn_channels"] = [int(x) for x in gnn_choice.split(",")]

    # Ensure last layer matches final embedding dim
    final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [64, 256, 512, 1024])
    params["formula_dims"].append(final_embedding_dim)
    params["gnn_channels"].append(final_embedding_dim)


def _log_trial_timing(trial, duration, trial_times, total_trials):
    """Record ``duration`` in ``trial_times`` and log duration, running average and ETA."""
    trial_times.append(duration)
    avg_time = sum(trial_times) / len(trial_times)
    # Count trials finished in *this* session. ``trial.number`` is study-wide,
    # and a study reopened with load_if_exists=True keeps numbering from where
    # it stopped, which used to make the ETA negative.
    remaining = max(total_trials - len(trial_times), 0) * avg_time
    logging.info(f"[Trial {trial.number}] Duration: {duration/60:.2f} min | Avg: {avg_time/60:.2f} min | ETA: {remaining/60:.2f} min")


def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
    """Train one sampled configuration and return its validation loss after ``fit``.

    Args:
        trial: Optuna trial used to sample hyperparameters and for pruning.
        base_params: Base config dict; deep-copied, never mutated.
        trial_times: Session-wide list of trial durations in seconds. This
            call appends its duration when the trial finishes or is pruned,
            and the list is used for the ETA log line.
        base_dir: Experiment directory where ``trial_history.csv`` is written.
        total_trials: Trials requested for this session (used for the ETA only).

    Returns:
        The ``f"{Stage.VAL.to_pref()}loss"`` value from
        ``trainer.callback_metrics`` after ``fit``.

    Raises:
        optuna.TrialPruned: re-raised so Optuna marks the trial as pruned.
        Exception: any other failure is logged with a traceback, recorded in
            the CSV and re-raised.
    """
    start_time = time.time()
    params = copy.deepcopy(base_params)

    try:
        _suggest_hparams(trial, params)
        logging.info(f"Formula dims: {params['formula_dims']}")
        logging.info(f"GNN channels: {params['gnn_channels']}")

        # Init seed
        pl.seed_everything(params["seed"])

        # Init dataset + datamodule
        spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
        mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
        dataset = get_ms_dataset(
            params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
        )

        collate_fn = partial(
            ContrastiveDataset.collate_fn,
            spec_enc=params["spec_enc"],
            spectra_view=params["spectra_view"],
        )

        data_module = ContrastiveDataModule(
            dataset=dataset,
            collate_fn=collate_fn,
            split_pth=params["split_pth"],
            batch_size=params["batch_size"],
            num_workers=params["num_workers"],
        )

        # Init model
        model = get_model(params["model"], params)

        # Validation loss drives both pruning and the returned objective value.
        monitor_metric = f"{Stage.VAL.to_pref()}loss"
        callbacks = [
            SafePruningCallback(trial, monitor=monitor_metric),
            EpochLossTracker(trial),
        ]

        trainer = Trainer(
            accelerator=params["accelerator"],
            devices=params["devices"],
            max_epochs=params["max_epochs"],
            logger=False,
            enable_checkpointing=False,
            callbacks=callbacks,
        )

        data_module.prepare_data()
        data_module.setup()

        # Validate before training
        trainer.validate(model, datamodule=data_module)

        # Fit (may be pruned early; the pruning callback raises optuna.TrialPruned)
        trainer.fit(model, datamodule=data_module)

        duration = time.time() - start_time
        _log_trial_timing(trial, duration, trial_times, total_trials)

        value = trainer.callback_metrics[monitor_metric].item()
        trial.set_user_attr("duration", duration)

        # Save progress
        save_trial_result(base_dir, trial, base_params, duration)

        return value

    except optuna.TrialPruned:
        # Pruning is Optuna's normal early-stop signal, not a failure: log it
        # without a traceback, keep the partial result, and let Optuna record
        # the trial as PRUNED.
        duration = time.time() - start_time
        _log_trial_timing(trial, duration, trial_times, total_trials)
        logging.info(f"Trial {trial.number} pruned.")
        save_trial_result(base_dir, trial, base_params, duration)
        raise

    except Exception as e:
        duration = time.time() - start_time
        logging.exception(f"Trial {trial.number} failed: {e}")
        save_trial_result(base_dir, trial, base_params, duration)
        raise


def _config_from_trial_params(base_params, trial_params):
    """Translate raw Optuna trial params into the config that ``objective`` trains with.

    This must mirror the search-space mapping in ``objective``; keep the two
    in sync when the search space changes:
    - "contrastive_temp" -> ``contr_temp`` and "peak_dropout" -> ``formula_dropout``.
    - ``formula_dims`` and ``gnn_channels`` are comma-separated strings in the
      trial; they are parsed to int lists with ``final_embedding_dim`` appended.
    - All other sampled values map to the config key of the same name.

    Returns a deep copy of ``base_params`` with these values applied.
    """
    config = copy.deepcopy(base_params)
    for key in (
        "batch_size",
        "lr",
        "weight_decay",
        "formula_attn_heads",
        "formula_transformer_layers",
        "gnn_dropout",
    ):
        config[key] = trial_params[key]
    config["contr_temp"] = trial_params["contrastive_temp"]
    config["formula_dropout"] = trial_params["peak_dropout"]

    final_dim = trial_params["final_embedding_dim"]
    config["formula_dims"] = [int(x) for x in trial_params["formula_dims"].split(",")] + [final_dim]
    config["gnn_channels"] = [int(x) for x in trial_params["gnn_channels"].split(",")] + [final_dim]
    return config


def main(args):
    """Run the Optuna search and save the best configuration as ``best_params.yaml``.

    Outputs go to ``TEST_RESULTS_DIR / f"{YYYYMMDD}_{run_name}_optuna"``. That
    directory holds the SQLite study (reopened if it already exists),
    ``optuna.log`` and ``best_params.yaml``. The saved YAML uses the same
    config keys and value formats that ``objective`` trains with.
    """
    param_path = args.param_pth or str(default_param_path())
    params = load_param_file(param_path)

    now = datetime.datetime.now().strftime("%Y%m%d")
    base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
    os.makedirs(base_dir, exist_ok=True)
    params["experiment_dir"] = base_dir

    # Setup logging
    log_path = os.path.join(base_dir, "optuna.log")
    setup_logging(log_path)

    trial_times = []
    study_name = "filip_contrastive"
    storage = f"sqlite:///{base_dir}/optuna_study.db"

    study = optuna.create_study(study_name=study_name, storage=storage, direction="minimize", pruner=optuna.pruners.MedianPruner(), load_if_exists=True)
    study.optimize(lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials), n_trials=args.n_trials)

    # Print best trial
    best_trial = study.best_trial
    logging.info("\nBest trial:")
    logging.info(best_trial.params)

    # Raw trial params use Optuna's names and encodings (e.g. "peak_dropout",
    # "64,128"). Merging them directly would leave the real config keys at
    # their base values, so map them the way ``objective`` does.
    best_params = _config_from_trial_params(params, best_trial.params)

    # Save best params to YAML
    best_param_path = os.path.join(base_dir, "best_params.yaml")
    with open(best_param_path, "w") as f:
        yaml.dump(best_params, f)
    logging.info(f"\nBest parameters saved to: {best_param_path}")
    logging.info(f"Run training with: python train.py --param_pth {best_param_path}")


if __name__ == "__main__":
    # Without ``__file__`` (e.g. an interactive session), parse an empty argv
    # instead of the host process's arguments.
    argv = None if "__file__" in globals() else []
    args = parser.parse_args(argv)
    main(args)