# NOTE(review): removed non-code residue ("Spaces: Sleeping") left over from a
# web-UI paste; it was not valid Python.
| import argparse | |
| import datetime | |
| import os | |
| import sys | |
| import yaml | |
| import optuna | |
| import time | |
| import logging | |
| import pandas as pd | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import pytorch_lightning as pl | |
| from pytorch_lightning import Trainer | |
| from optuna.integration import PyTorchLightningPruningCallback | |
| from pytorch_lightning.callbacks import Callback | |
| from flare.data.data_module import ContrastiveDataModule | |
| from flare.data.datasets import ContrastiveDataset | |
| from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer | |
| from flare.utils.models import get_model | |
| from flare.definitions import TEST_RESULTS_DIR | |
| from functools import partial | |
| from rdkit import RDLogger | |
| from massspecgym.models.base import Stage | |
# Suppress RDKit warnings so they don't flood the console during molecule
# featurization.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line interface: path to the base parameter YAML and the trial budget.
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
parser.add_argument("--n_trials", type=int, default=20)
class EpochLossTracker(Callback):
    """Lightning callback that records per-epoch train/val losses for a trial.

    At the end of fitting the collected curves are attached to the Optuna
    trial under the ``loss_history`` user attribute, so they can be persisted
    later (e.g. by ``save_trial_result``).
    """

    def __init__(self, trial):
        super().__init__()
        self.trial = trial
        # Keys intentionally match what save_trial_result reads back.
        self.history = {"train_loss": [], "val_loss": []}

    def _record(self, trainer, metric_key, history_key):
        # Pull the logged metric (if present) off-device as a plain float.
        metric = trainer.callback_metrics.get(metric_key)
        if metric is not None:
            self.history[history_key].append(float(metric.cpu().item()))

    def on_train_epoch_end(self, trainer, pl_module):
        self._record(trainer, "train_loss", "train_loss")

    def on_validation_epoch_end(self, trainer, pl_module):
        self._record(trainer, f"{Stage.VAL.to_pref()}loss", "val_loss")

    def on_fit_end(self, trainer, pl_module):
        # Expose the history via user attrs so save_trial_result can access it.
        self.trial.set_user_attr("loss_history", self.history)
class SafePruningCallback(PyTorchLightningPruningCallback, Callback):
    """Optuna pruning callback mixed with Lightning's ``Callback`` base so it
    can be passed to ``Trainer(callbacks=[...])`` like any other callback."""
def setup_logging(log_path):
    """Configure root logging to both a file and stderr.

    Any previously installed handlers are removed first so repeated calls do
    not duplicate output. Console records go to stderr so tqdm progress bars
    on stdout stay clean.

    Args:
        log_path: Path of the log file; records are appended.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Drop existing handlers to avoid duplicate or misdirected records.
    if root.hasHandlers():
        root.handlers.clear()

    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    for handler in (
        logging.FileHandler(log_path, mode="a"),  # persistent log on disk
        logging.StreamHandler(sys.stderr),        # live console output
    ):
        handler.setFormatter(fmt)
        root.addHandler(handler)
def save_trial_result(base_dir, trial, params, duration):
    """Append one trial's outcome to ``<base_dir>/trial_history.csv``.

    Persists the trial number, wall-clock duration, the per-epoch loss
    history (taken from the trial's ``loss_history`` user attribute) and
    every hyperparameter Optuna suggested for the trial.

    NOTE(review): ``params`` is currently unused — the suggested values come
    from ``trial.params`` instead. Kept for call-site compatibility; confirm
    whether the base params were meant to be recorded too.
    """
    history_path = os.path.join(base_dir, "trial_history.csv")

    # Loss curves are attached by EpochLossTracker.on_fit_end; they may be
    # absent if the trial crashed before fitting finished.
    loss_hist = trial.user_attrs.get("loss_history", {})

    record = {
        "number": trial.number,
        "duration_sec": duration,
        "train_loss": loss_hist.get("train_loss", []),
        "val_loss": loss_hist.get("val_loss", []),
        **trial.params,
    }

    # Read-modify-write so rows with differing columns still line up: concat
    # performs a schema union instead of corrupting the CSV.
    if os.path.exists(history_path):
        existing = pd.read_csv(history_path)
        out = pd.concat([existing, pd.DataFrame([record])], ignore_index=True)
    else:
        out = pd.DataFrame([record])
    out.to_csv(history_path, index=False)
def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
    """Optuna objective: train one contrastive model and return its val loss.

    Suggests training/architecture hyperparameters, builds the dataset,
    datamodule and model, fits with pruning enabled, and persists the trial
    outcome (including per-epoch loss curves) to CSV.

    Args:
        trial: Optuna trial providing the hyperparameter suggestions.
        base_params: Base configuration dict (from YAML); copied, not mutated.
        trial_times: Shared list of completed-trial durations, used for ETA.
        base_dir: Directory where trial results are appended.
        total_trials: Total number of trials, for the ETA estimate.

    Returns:
        The monitored validation loss, which Optuna minimizes.

    Raises:
        optuna.TrialPruned: re-raised when the pruning callback stops the run.
        Exception: any training failure, after the partial result is saved.
    """
    start_time = time.time()
    params = base_params.copy()
    try:
        # --- Training-related hyperparameters --------------------------------
        params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
        params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
        params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
        # NOTE(review): search-space name "contrastive_temp" differs from the
        # config key "contr_temp" — confirm this is intentional.
        params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.01, 0.1)

        # --- Spectra encoder hyperparameters ---------------------------------
        # NOTE(review): suggested under the name "peak_dropout" but stored as
        # "formula_dropout" — looks like a copy-paste leftover; confirm.
        params["formula_dropout"] = trial.suggest_float("peak_dropout", 0.1, 0.5)
        params["formula_attn_heads"] = trial.suggest_categorical("formula_attn_heads", [2, 4, 8])
        params["formula_transformer_layers"] = trial.suggest_categorical(
            "formula_transformer_layers", [1, 2, 3, 4, 5]
        )
        choice = trial.suggest_categorical(
            "formula_dims",
            ["64,128", "512,256", "256,512", "128", "256", "128,128", "512,512", "64,64,64,64"],
        )
        params["formula_dims"] = [int(x) for x in choice.split(",")]

        # --- Molecule encoder hyperparameters --------------------------------
        params["gnn_dropout"] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
        choice = trial.suggest_categorical(
            "gnn_channels",
            ["64,128", "128,256", "256,512", "64,128,128", "128,128", "64,64,64"],
        )
        params["gnn_channels"] = [int(x) for x in choice.split(",")]

        # Both encoders must project into the same embedding space, so the
        # final layer width is suggested once and appended to both stacks.
        final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [64, 256, 512, 1024])
        params["formula_dims"].append(final_embedding_dim)
        params["gnn_channels"].append(final_embedding_dim)
        logging.info(f"Formula dims: {params['formula_dims']}")
        logging.info(f"GNN channels: {params['gnn_channels']}")

        # Reproducibility across trials.
        pl.seed_everything(params["seed"])

        # --- Dataset + datamodule --------------------------------------------
        spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
        mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
        dataset = get_ms_dataset(
            params["spectra_view"], params["molecule_view"], spec_featurizer, mol_featurizer, params
        )
        collate_fn = partial(
            ContrastiveDataset.collate_fn,
            spec_enc=params["spec_enc"],
            spectra_view=params["spectra_view"],
            mask_peak_ratio=params["mask_peak_ratio"],
            aug_cands=params["aug_cands"],
        )
        data_module = ContrastiveDataModule(
            dataset=dataset,
            collate_fn=collate_fn,
            split_pth=params["split_pth"],
            batch_size=params["batch_size"],
            num_workers=params["num_workers"],
        )

        # --- Model + trainer --------------------------------------------------
        model = get_model(params["model"], params)

        # Metric to optimize; the pruning callback watches the same key.
        monitor_metric = f"{Stage.VAL.to_pref()}loss"
        callbacks = [
            SafePruningCallback(trial, monitor=monitor_metric),
            EpochLossTracker(trial),
        ]
        trainer = Trainer(
            accelerator=params["accelerator"],
            devices=params["devices"],
            max_epochs=params["max_epochs"],
            logger=False,                # no experiment logger during the sweep
            enable_checkpointing=False,  # per-trial checkpoints are not needed
            callbacks=callbacks,
        )

        data_module.prepare_data()
        data_module.setup()

        # Validate once before training so the monitored metric exists even
        # if the very first training epoch gets pruned.
        trainer.validate(model, datamodule=data_module)

        # Fit (the pruning callback may raise optuna.TrialPruned mid-run).
        trainer.fit(model, datamodule=data_module)

        # --- Book-keeping ------------------------------------------------------
        duration = time.time() - start_time
        trial_times.append(duration)
        avg_time = sum(trial_times) / len(trial_times)
        remaining = (total_trials - trial.number - 1) * avg_time
        logging.info(f"[Trial {trial.number}] Duration: {duration/60:.2f} min | Avg: {avg_time/60:.2f} min | ETA: {remaining/60:.2f} min")

        value = trainer.callback_metrics[monitor_metric].item()
        trial.set_user_attr("duration", duration)

        # Save progress after every completed trial.
        save_trial_result(base_dir, trial, base_params, duration)
        return value
    except optuna.TrialPruned:
        # Pruning is expected control flow, not a failure: record the partial
        # result quietly and re-raise so Optuna marks the trial as pruned.
        duration = time.time() - start_time
        logging.info(f"Trial {trial.number} pruned after {duration/60:.2f} min")
        save_trial_result(base_dir, trial, base_params, duration)
        raise
    except Exception as e:
        # Persist whatever we have before letting Optuna record the failure.
        duration = time.time() - start_time
        logging.exception(f"Trial {trial.number} failed: {e}")
        save_trial_result(base_dir, trial, base_params, duration)
        raise
def main(args):
    """Run the Optuna hyperparameter search and save the best configuration.

    Loads the base params from ``args.param_pth``, runs ``args.n_trials``
    trials against a SQLite-backed study, then merges the best trial's
    hyperparameters into the base params and dumps them to YAML.
    """
    # Base configuration from YAML.
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # NOTE(review): the experiment dir is currently pinned to a fixed path;
    # the date-based variant is kept below for reference.
    # now = datetime.datetime.now().strftime("%Y%m%d")
    # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
    base_dir = "../experiments/20250916_simple_model_optuna"
    os.makedirs(base_dir, exist_ok=True)
    params["experiment_dir"] = base_dir

    # Log to <base_dir>/optuna.log as well as stderr.
    setup_logging(os.path.join(base_dir, "optuna.log"))

    trial_times = []
    study = optuna.create_study(
        study_name="filip_contrastive",
        storage=f"sqlite:///{base_dir}/optuna_study.db",
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(),
        load_if_exists=True,
    )
    study.optimize(
        lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials),
        n_trials=args.n_trials,
    )

    # Report the best trial.
    logging.info("\nBest trial:")
    logging.info(study.best_trial.params)

    # Merge base params with the best trial's suggestions.
    best_params = params.copy()
    best_params.update(study.best_trial.params)

    # Persist the winning configuration for a follow-up training run.
    best_param_path = os.path.join(base_dir, "best_params.yaml")
    with open(best_param_path, "w") as f:
        yaml.dump(best_params, f)
    logging.info(f"\nBest parameters saved to: {best_param_path}")
    logging.info(f"Run training with: python train.py --param_pth {best_param_path}")
if __name__ == "__main__":
    # Parse real CLI args when run as a script; fall back to the declared
    # defaults when __file__ is absent (e.g. executed in a notebook/REPL).
    args = parser.parse_args([] if "__file__" not in globals() else None)
    main(args)