import argparse import datetime import os import sys import yaml import optuna import time import logging import pandas as pd sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) import pytorch_lightning as pl from pytorch_lightning import Trainer from optuna.integration import PyTorchLightningPruningCallback from pytorch_lightning.callbacks import Callback from flare.data.data_module import ContrastiveDataModule from flare.data.datasets import ContrastiveDataset from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer from flare.utils.models import get_model from flare.definitions import TEST_RESULTS_DIR from functools import partial from rdkit import RDLogger from massspecgym.models.base import Stage # Suppress RDKit warnings lg = RDLogger.logger() lg.setLevel(RDLogger.CRITICAL) parser = argparse.ArgumentParser() parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml") parser.add_argument("--n_trials", type=int, default=20) class EpochLossTracker(Callback): def __init__(self, trial): super().__init__() self.trial = trial self.history = {"train_loss": [], "val_loss": []} def on_train_epoch_end(self, trainer, pl_module): if "train_loss" in trainer.callback_metrics: self.history["train_loss"].append( float(trainer.callback_metrics["train_loss"].cpu().item()) ) def on_validation_epoch_end(self, trainer, pl_module): val_key = f"{Stage.VAL.to_pref()}loss" if val_key in trainer.callback_metrics: self.history["val_loss"].append( float(trainer.callback_metrics[val_key].cpu().item()) ) def on_fit_end(self, trainer, pl_module): # Attach to trial so save_trial_result can access it self.trial.set_user_attr("loss_history", self.history) class SafePruningCallback(PyTorchLightningPruningCallback, Callback): """Wraps Optuna pruning to make it a proper Lightning Callback.""" pass def setup_logging(log_path): """Setup logging without breaking tqdm progress bars.""" logger = logging.getLogger() logger.setLevel(logging.INFO) # Remove existing handlers (avoid duplicate or wrong outputs) if logger.hasHandlers(): logger.handlers.clear() # File handler file_handler = logging.FileHandler(log_path, mode="a") file_handler.setFormatter( logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") ) # Console handler (stderr so tqdm stays clean) console_handler = logging.StreamHandler(sys.stderr) console_handler.setFormatter( logging.Formatter("%(asctime)s [%(levelname)s] %(message)s") ) logger.addHandler(file_handler) logger.addHandler(console_handler) def save_trial_result(base_dir, trial, params, duration): """Append trial results to a CSV file after each trial.""" history_path = os.path.join(base_dir, "trial_history.csv") # Fetch losses from trial user_attrs loss_hist = trial.user_attrs.get("loss_history", {}) record = { "number": trial.number, "duration_sec": duration, "train_loss": loss_hist.get("train_loss", []), "val_loss": loss_hist.get("val_loss", []), **trial.params, } # Append to CSV safely if os.path.exists(history_path): df = pd.read_csv(history_path) df = pd.concat([df, pd.DataFrame([record])], ignore_index=True) else: df = pd.DataFrame([record]) df.to_csv(history_path, index=False) def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials): start_time = time.time() params = base_params.copy() try: # Training-related params params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128, 256]) params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True) params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True) params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.01, 0.1) # Spectra encoder-related params params['formula_dropout'] = trial.suggest_float("peak_dropout", 0.1, 0.5) params['formula_attn_heads'] = trial.suggest_categorical("formula_attn_heads", [2, 4, 8]) params['formula_transformer_layers'] = trial.suggest_categorical("formula_transformer_layers", [1,2,3,4,5]) choice = trial.suggest_categorical( "formula_dims", ["64,128", "512,256", "256,512", "128", "256", "128,128", "512,512", "64,64,64,64"] ) params["formula_dims"] = [int(x) for x in choice.split(",")] # Molecule encoder-related params params['gnn_dropout'] = trial.suggest_float("gnn_dropout", 0.1, 0.5) choice = trial.suggest_categorical( "gnn_channels", ["64,128", "128,256", "256,512", "64,128,128", "128,128", "64,64,64"] ) params["gnn_channels"] = [int(x) for x in choice.split(",")] # Ensure last layer matches final embedding dim final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [64,256,512,1024]) params['formula_dims'].append(final_embedding_dim) params['gnn_channels'].append(final_embedding_dim) logging.info(f"Formula dims: {params['formula_dims']}") logging.info(f"GNN channels: {params['gnn_channels']}") # Init seed pl.seed_everything(params["seed"]) # Init dataset + datamodule spec_featurizer = get_spec_featurizer(params["spectra_view"], params) mol_featurizer = get_mol_featurizer(params["molecule_view"], params) dataset = get_ms_dataset(params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params) collate_fn = partial( ContrastiveDataset.collate_fn, spec_enc=params["spec_enc"], spectra_view=params["spectra_view"], mask_peak_ratio=params["mask_peak_ratio"], aug_cands=params["aug_cands"], ) data_module = ContrastiveDataModule( dataset=dataset, collate_fn=collate_fn, split_pth=params["split_pth"], batch_size=params["batch_size"], num_workers=params["num_workers"], ) # Init model model = get_model(params["model"], params) # Metric to optimize callbacks = [] monitor_metric = f"{Stage.VAL.to_pref()}loss" pruning_cb = SafePruningCallback(trial, monitor=monitor_metric) callbacks.append(pruning_cb) loss_tracker = EpochLossTracker(trial) callbacks.append(loss_tracker) trainer = Trainer( accelerator=params["accelerator"], devices=params["devices"], max_epochs=params["max_epochs"], logger=False, enable_checkpointing=False, callbacks=callbacks, ) data_module.prepare_data() data_module.setup() # Validate before training trainer.validate(model, datamodule=data_module) # Fit (may be pruned early) trainer.fit(model, datamodule=data_module) # Duration duration = time.time() - start_time trial_times.append(duration) avg_time = sum(trial_times) / len(trial_times) remaining = (total_trials - trial.number - 1) * avg_time logging.info(f"[Trial {trial.number}] Duration: {duration/60:.2f} min | Avg: {avg_time/60:.2f} min | ETA: {remaining/60:.2f} min") value = trainer.callback_metrics[monitor_metric].item() trial.set_user_attr("duration", duration) # Save progress save_trial_result(base_dir, trial, base_params, duration, ) return value except Exception as e: duration = time.time() - start_time logging.exception(f"Trial {trial.number} failed: {e}") save_trial_result(base_dir, trial, base_params, duration) raise def main(args): with open(args.param_pth) as f: params = yaml.load(f, Loader=yaml.FullLoader) # now = datetime.datetime.now().strftime("%Y%m%d") # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna") base_dir = "../experiments/20250916_simple_model_optuna" os.makedirs(base_dir, exist_ok=True) params["experiment_dir"] = base_dir # Setup logging log_path = os.path.join(base_dir, "optuna.log") setup_logging(log_path) trial_times = [] study_name = "filip_contrastive" storage = f"sqlite:///{base_dir}/optuna_study.db" study = optuna.create_study(study_name=study_name, storage=storage, direction="minimize", pruner=optuna.pruners.MedianPruner(), load_if_exists=True) study.optimize(lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials), n_trials=args.n_trials) # Print best trial logging.info("\nBest trial:") logging.info(study.best_trial.params) # Merge base params with best trial best_params = params.copy() best_params.update(study.best_trial.params) # Save best params to YAML best_param_path = os.path.join(base_dir, "best_params.yaml") with open(best_param_path, "w") as f: yaml.dump(best_params, f) logging.info(f"\nBest parameters saved to: {best_param_path}") logging.info(f"Run training with: python train.py --param_pth {best_param_path}") if __name__ == "__main__": args = parser.parse_args([] if "__file__" not in globals() else None) main(args)