# FLARE / flare/tune.py
# Last commit: "update" (19a4dfc) by yzhouchen001
import argparse
import datetime
import os
import sys
import yaml
import optuna
import time
import logging
import pandas as pd
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from optuna.integration import PyTorchLightningPruningCallback
from pytorch_lightning.callbacks import Callback
from flare.data.data_module import ContrastiveDataModule
from flare.data.datasets import ContrastiveDataset
from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from flare.utils.models import get_model
from flare.definitions import TEST_RESULTS_DIR
from functools import partial
from rdkit import RDLogger
from massspecgym.models.base import Stage
# Suppress RDKit warnings: raise the RDKit logger's threshold to CRITICAL so
# that only critical messages are emitted.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
# Command-line interface of the tuning script. The formatter appends each
# option's default value to its help text in ``--help`` output.
parser = argparse.ArgumentParser(
    description="Tune hyperparameters with Optuna.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--param_pth",
    type=str,
    default="params_formSpec.yaml",
    help="Path to the YAML file holding the base parameters.",
)
parser.add_argument(
    "--n_trials",
    type=int,
    default=20,
    help="Number of Optuna trials to run.",
)
class EpochLossTracker(Callback):
    """Collect per-epoch losses and attach them to a trial.

    Values are read from ``trainer.callback_metrics`` and accumulated in
    ``self.history``. The history is stored on the trial under the
    ``"loss_history"`` user attribute when fitting ends, whether it ends
    normally or through an exception.
    """

    def __init__(self, trial):
        """
        Args:
            trial: Object exposing ``set_user_attr(key, value)``; it receives
                the collected loss history.
        """
        super().__init__()
        self.trial = trial
        self.history = {"train_loss": [], "val_loss": []}

    def _record(self, trainer, metric_key, history_key):
        """Append the current value of ``metric_key`` to the history, if it was logged."""
        metrics = trainer.callback_metrics
        if metric_key in metrics:
            self.history[history_key].append(
                float(metrics[metric_key].cpu().item())
            )

    def _publish(self):
        """Store the history collected so far on the trial."""
        self.trial.set_user_attr("loss_history", self.history)

    def on_train_epoch_end(self, trainer, pl_module):
        self._record(trainer, "train_loss", "train_loss")

    def on_validation_epoch_end(self, trainer, pl_module):
        self._record(trainer, f"{Stage.VAL.to_pref()}loss", "val_loss")

    def on_fit_end(self, trainer, pl_module):
        # Attach to trial so save_trial_result can access it
        self._publish()

    def on_exception(self, trainer, pl_module, exception):
        # Lightning does not run ``on_fit_end`` when the run is aborted by an
        # exception (this includes Optuna's ``TrialPruned`` raised by the
        # pruning callback), so publish the partial history here as well.
        # Otherwise pruned and failed trials would be saved without any losses.
        self._publish()
class SafePruningCallback(PyTorchLightningPruningCallback, Callback):
    """Optuna's Lightning pruning callback that also derives from ``Callback``.

    Adds no behaviour of its own; the extra base class makes instances a
    proper Lightning ``Callback``.
    """
def setup_logging(log_path):
    """Configure the root logger to write to a log file and to stderr.

    The root logger is set to INFO. Any handlers already attached to it are
    removed and closed, so calling this function again does not duplicate
    output or leak the previous log file's descriptor.

    Args:
        log_path: Path of the log file; it is opened in append mode.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove existing handlers (avoid duplicate or wrong outputs) and close
    # them so that file handlers release their underlying file.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    file_handler = logging.FileHandler(log_path, mode="a")
    # Console output goes to stderr (the author's stated intent is to keep
    # tqdm progress bars clean).
    console_handler = logging.StreamHandler(sys.stderr)

    for handler in (file_handler, console_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
def save_trial_result(base_dir, trial, params, duration):
    """Add one row describing ``trial`` to ``<base_dir>/trial_history.csv``.

    The row holds the trial number, its duration in seconds, the train and
    validation loss lists found under the trial's ``"loss_history"`` user
    attribute (empty lists when absent), and every entry of ``trial.params``.
    The file is created on first use; afterwards it is read back, extended by
    the new row and rewritten.

    Args:
        base_dir: Directory containing (or receiving) the CSV file.
        trial: Object exposing ``number``, ``params`` and ``user_attrs``.
        params: Accepted for the caller's convenience; not used here.
        duration: Trial duration in seconds.
    """
    losses = trial.user_attrs.get("loss_history", {})
    row = dict(
        number=trial.number,
        duration_sec=duration,
        train_loss=losses.get("train_loss", []),
        val_loss=losses.get("val_loss", []),
    )
    # Sampled hyperparameters become additional columns.
    row.update(trial.params)
    new_rows = pd.DataFrame([row])

    csv_path = os.path.join(base_dir, "trial_history.csv")
    if not os.path.exists(csv_path):
        table = new_rows
    else:
        previous = pd.read_csv(csv_path)
        table = pd.concat([previous, new_rows], ignore_index=True)
    table.to_csv(csv_path, index=False)
def _suggest_hyperparameters(trial, base_params):
    """Return a shallow copy of ``base_params`` overridden with values sampled from ``trial``.

    Some Optuna parameter names differ from the config keys they fill
    (``"contrastive_temp"`` -> ``contr_temp``, ``"peak_dropout"`` ->
    ``formula_dropout``). Layer sizes are sampled as comma-separated strings,
    parsed into lists of ints, and then extended with the sampled
    ``final_embedding_dim``.
    """
    params = base_params.copy()

    # Training-related params
    params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
    params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.01, 0.1)

    # Spectra encoder-related params
    params["formula_dropout"] = trial.suggest_float("peak_dropout", 0.1, 0.5)
    params["formula_attn_heads"] = trial.suggest_categorical("formula_attn_heads", [2, 4, 8])
    params["formula_transformer_layers"] = trial.suggest_categorical(
        "formula_transformer_layers", [1, 2, 3, 4, 5]
    )
    choice = trial.suggest_categorical(
        "formula_dims",
        ["64,128", "512,256", "256,512", "128", "256", "128,128", "512,512", "64,64,64,64"],
    )
    params["formula_dims"] = [int(x) for x in choice.split(",")]

    # Molecule encoder-related params
    params["gnn_dropout"] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
    choice = trial.suggest_categorical(
        "gnn_channels",
        ["64,128", "128,256", "256,512", "64,128,128", "128,128", "64,64,64"],
    )
    params["gnn_channels"] = [int(x) for x in choice.split(",")]

    # Ensure last layer matches final embedding dim
    final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [64, 256, 512, 1024])
    params["formula_dims"].append(final_embedding_dim)
    params["gnn_channels"].append(final_embedding_dim)
    return params


def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_trials):
    """Run one Optuna trial: sample hyperparameters, train, and return the validation loss.

    Args:
        trial: Optuna trial used to sample hyperparameters and to report
            pruning information.
        base_params: Base configuration dict; it is copied, not modified.
        trial_times: List collecting the duration (seconds) of every trial
            that finished training; appended to in place.
        base_dir: Directory that receives ``trial_history.csv``.
        total_trials: Number of trials requested for this run (used for the
            ETA log line only).

    Returns:
        The value of the ``f"{Stage.VAL.to_pref()}loss"`` metric after fitting.

    Raises:
        optuna.TrialPruned: Re-raised when the trial is pruned, so that Optuna
            records it as pruned.
        Exception: Any other error is logged with its traceback and re-raised.
        In both cases the trial is written to the history file first.
    """
    start_time = time.time()
    try:
        params = _suggest_hyperparameters(trial, base_params)
        logging.info(f"Formula dims: {params['formula_dims']}")
        logging.info(f"GNN channels: {params['gnn_channels']}")

        # Init seed
        pl.seed_everything(params["seed"])

        # Init dataset + datamodule
        spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
        mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
        dataset = get_ms_dataset(
            params["spectra_view"],
            params["molecule_view"],
            spec_featurizer,
            mol_featurizer,
            params,
        )
        collate_fn = partial(
            ContrastiveDataset.collate_fn,
            spec_enc=params["spec_enc"],
            spectra_view=params["spectra_view"],
            mask_peak_ratio=params["mask_peak_ratio"],
            aug_cands=params["aug_cands"],
        )
        data_module = ContrastiveDataModule(
            dataset=dataset,
            collate_fn=collate_fn,
            split_pth=params["split_pth"],
            batch_size=params["batch_size"],
            num_workers=params["num_workers"],
        )

        # Init model
        model = get_model(params["model"], params)

        # Metric to optimize; the pruning callback monitors the same metric.
        monitor_metric = f"{Stage.VAL.to_pref()}loss"
        callbacks = [
            SafePruningCallback(trial, monitor=monitor_metric),
            EpochLossTracker(trial),
        ]
        trainer = Trainer(
            accelerator=params["accelerator"],
            devices=params["devices"],
            max_epochs=params["max_epochs"],
            logger=False,
            enable_checkpointing=False,
            callbacks=callbacks,
        )
        data_module.prepare_data()
        data_module.setup()

        # Validate before training
        trainer.validate(model, datamodule=data_module)
        # Fit (may be pruned early)
        trainer.fit(model, datamodule=data_module)

        # Duration
        duration = time.time() - start_time
        trial_times.append(duration)
        avg_time = sum(trial_times) / len(trial_times)
        # ``trial.number`` keeps counting across resumed studies, so the
        # number of remaining trials is clamped to avoid a negative ETA.
        remaining = max(total_trials - trial.number - 1, 0) * avg_time
        logging.info(f"[Trial {trial.number}] Duration: {duration/60:.2f} min | Avg: {avg_time/60:.2f} min | ETA: {remaining/60:.2f} min")

        value = trainer.callback_metrics[monitor_metric].item()
        trial.set_user_attr("duration", duration)

        # Save progress
        save_trial_result(base_dir, trial, base_params, duration)
        return value
    except optuna.TrialPruned:
        # Pruning is a normal outcome, not a failure: record the trial and let
        # Optuna handle the exception without logging a traceback.
        duration = time.time() - start_time
        logging.info(f"Trial {trial.number} pruned after {duration/60:.2f} min")
        save_trial_result(base_dir, trial, base_params, duration)
        raise
    except Exception as e:
        duration = time.time() - start_time
        logging.exception(f"Trial {trial.number} failed: {e}")
        save_trial_result(base_dir, trial, base_params, duration)
        raise
def main(args):
with open(args.param_pth) as f:
params = yaml.load(f, Loader=yaml.FullLoader)
# now = datetime.datetime.now().strftime("%Y%m%d")
# base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
base_dir = "../experiments/20250916_simple_model_optuna"
os.makedirs(base_dir, exist_ok=True)
params["experiment_dir"] = base_dir
# Setup logging
log_path = os.path.join(base_dir, "optuna.log")
setup_logging(log_path)
trial_times = []
study_name = "filip_contrastive"
storage = f"sqlite:///{base_dir}/optuna_study.db"
study = optuna.create_study(study_name=study_name, storage=storage, direction="minimize", pruner=optuna.pruners.MedianPruner(), load_if_exists=True)
study.optimize(lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials), n_trials=args.n_trials)
# Print best trial
logging.info("\nBest trial:")
logging.info(study.best_trial.params)
# Merge base params with best trial
best_params = params.copy()
best_params.update(study.best_trial.params)
# Save best params to YAML
best_param_path = os.path.join(base_dir, "best_params.yaml")
with open(best_param_path, "w") as f:
yaml.dump(best_params, f)
logging.info(f"\nBest parameters saved to: {best_param_path}")
logging.info(f"Run training with: python train.py --param_pth {best_param_path}")
if __name__ == "__main__":
    # When no ``__file__`` global exists (e.g. the code runs in an interactive
    # session), parse an empty argument list so the defaults apply; otherwise
    # let argparse read ``sys.argv``.
    has_file = "__file__" in globals()
    args = parser.parse_args(None if has_file else [])
    main(args)