""" base_hyperopt.py Abstract away common hyperopt functionality """ import logging import yaml from pathlib import Path from datetime import datetime from typing import Callable import pytorch_lightning as pl import ray from ray import tune from ray.air.config import RunConfig from ray.tune.search import ConcurrencyLimiter from ray.tune.search.optuna import OptunaSearch from ray.tune.schedulers.async_hyperband import ASHAScheduler import mist_cf.common as common def add_hyperopt_args(parser): # Tune args ha = parser.add_argument_group("Hyperopt Args") ha.add_argument("--cpus-per-trial", default=1, type=int) ha.add_argument("--gpus-per-trial", default=1, type=float) ha.add_argument("--num-h-samples", default=50, type=int) ha.add_argument("--grace-period", default=60 * 15, type=int) ha.add_argument("--max-concurrent", default=10, type=int) ha.add_argument("--tune-checkpoint", default=None) # Overwrite default savedir time_name = datetime.now().strftime("%Y_%m_%d") save_default = f"results/{time_name}_hyperopt/" parser.set_defaults(save_dir=save_default) def run_hyperopt( kwargs: dict, score_function: Callable, param_space_function: Callable, initial_points: list, gen_shared_data: Callable = lambda params: {}, ): """run_hyperopt. Args: kwargs: All dictionary args for hyperopt and train score_function: Trainable function that sets up model train param_space_function: Function to suggest new params initial_points: List of initial params to try """ # init ray with new session ray.init(address="local") # Fix base_args based upon tune args kwargs["gpu"] = kwargs.get("gpus_per_trial", 0) > 0 # max_t = args.max_epochs if kwargs["debug"]: kwargs["num_h_samples"] = 10 kwargs["max_epochs"] = 5 save_dir = kwargs["save_dir"] common.setup_logger( save_dir, log_name="hyperopt.log", debug=kwargs.get("debug", False) ) pl.utilities.seed.seed_everything(kwargs.get("seed")) shared_args = gen_shared_data(kwargs) # Define score function trainable = tune.with_parameters( score_function, base_args=kwargs, orig_dir=Path().resolve(), **shared_args ) # Dump args yaml_args = yaml.dump(kwargs) logging.info(f"\n{yaml_args}") with open(Path(save_dir) / "args.yaml", "w") as fp: fp.write(yaml_args) metric = "val_loss" # Include cpus and gpus per trial trainable = tune.with_resources( trainable, resources=tune.PlacementGroupFactory( [ { "CPU": kwargs.get("cpus_per_trial"), "GPU": kwargs.get("gpus_per_trial"), }, { "CPU": kwargs.get("num_workers"), }, ], strategy="PACK", ), ) search_algo = OptunaSearch( metric=metric, mode="min", points_to_evaluate=initial_points, space=param_space_function, ) search_algo = ConcurrencyLimiter( search_algo, max_concurrent=kwargs["max_concurrent"] ) tuner = tune.Tuner( trainable, tune_config=tune.TuneConfig( mode="min", metric=metric, search_alg=search_algo, scheduler=ASHAScheduler( max_t=24 * 60 * 60, # max_t, time_attr="time_total_s", grace_period=kwargs.get("grace_period"), reduction_factor=2, ), num_samples=kwargs.get("num_h_samples"), ), run_config=RunConfig(name=None, local_dir=kwargs["save_dir"]), ) if kwargs.get("tune_checkpoint") is not None: ckpt = str(Path(kwargs["tune_checkpoint"]).resolve()) tuner = tuner.restore(path=ckpt, restart_errored=True) results = tuner.fit() best_trial = results.get_best_result() output = {"score": best_trial.metrics[metric], "config": best_trial.config} out_str = yaml.dump(output, indent=2) logging.info(out_str) with open(Path(save_dir) / "best_trial.yaml", "w") as f: f.write(out_str) # Output full res table results.get_dataframe().to_csv( Path(save_dir) / "full_res_tbl.tsv", sep="\t", index=None )