# NOTE(review): removed web-scrape residue (hosting-page chrome: status lines,
# file size, commit hash, line-number gutter) that preceded the module and made
# the file syntactically invalid. The real module begins below.
""" base_hyperopt.py
Abstract away common hyperopt functionality
"""
import logging
import yaml
from pathlib import Path
from datetime import datetime
from typing import Callable
import pytorch_lightning as pl
import ray
from ray import tune
from ray.air.config import RunConfig
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from ray.tune.schedulers.async_hyperband import ASHAScheduler
import mist_cf.common as common
def add_hyperopt_args(parser):
    """Attach Ray Tune hyperparameter-search options to ``parser``.

    Registers a "Hyperopt Args" argument group (per-trial resources, sample
    count, ASHA grace period, concurrency cap, optional resume checkpoint)
    and redirects the default ``save_dir`` to a date-stamped results folder.
    """
    group = parser.add_argument_group("Hyperopt Args")
    group.add_argument("--cpus-per-trial", type=int, default=1)
    group.add_argument("--gpus-per-trial", type=float, default=1)
    group.add_argument("--num-h-samples", type=int, default=50)
    group.add_argument("--grace-period", type=int, default=60 * 15)
    group.add_argument("--max-concurrent", type=int, default=10)
    group.add_argument("--tune-checkpoint", default=None)

    # Point the default save location at a per-day hyperopt directory.
    today = datetime.now().strftime("%Y_%m_%d")
    parser.set_defaults(save_dir=f"results/{today}_hyperopt/")
def run_hyperopt(
    kwargs: dict,
    score_function: Callable,
    param_space_function: Callable,
    initial_points: list,
    gen_shared_data: Callable = lambda params: {},
):
    """Run an Optuna-backed Ray Tune hyperparameter search with ASHA.

    Args:
        kwargs: Combined hyperopt + training args (argparse namespace as dict).
        score_function: Trainable that sets up and trains one model per trial.
        param_space_function: Optuna define-by-run space-suggestion function.
        initial_points: List of initial configs to evaluate before sampling.
        gen_shared_data: Optional hook mapping ``kwargs`` to extra keyword
            objects shared with every trial (e.g. preloaded data); defaults
            to providing nothing.

    Side effects: starts a local Ray session, seeds RNGs, and writes
    ``args.yaml``, ``best_trial.yaml``, ``hyperopt.log``, and
    ``full_res_tbl.tsv`` under ``kwargs["save_dir"]``.
    """
    # Start a fresh local Ray session for this search.
    ray.init(address="local")

    # Trials request a GPU only when gpus_per_trial > 0.
    kwargs["gpu"] = kwargs.get("gpus_per_trial", 0) > 0

    # Debug mode: shrink the search so it finishes quickly.
    # Use .get so a missing "debug" key means "not debugging" instead of a
    # KeyError (matches the .get usage in setup_logger below).
    if kwargs.get("debug", False):
        kwargs["num_h_samples"] = 10
        kwargs["max_epochs"] = 5

    save_dir = kwargs["save_dir"]
    common.setup_logger(
        save_dir, log_name="hyperopt.log", debug=kwargs.get("debug", False)
    )
    pl.utilities.seed.seed_everything(kwargs.get("seed"))

    # Extra objects (beyond kwargs) shared with every trial.
    shared_args = gen_shared_data(kwargs)

    # Bind the fixed args and shared data into the trainable; orig_dir lets
    # trials resolve paths relative to the launch directory.
    trainable = tune.with_parameters(
        score_function, base_args=kwargs, orig_dir=Path().resolve(), **shared_args
    )

    # Persist the full arg dict for reproducibility.
    yaml_args = yaml.dump(kwargs)
    logging.info(f"\n{yaml_args}")
    with open(Path(save_dir) / "args.yaml", "w") as fp:
        fp.write(yaml_args)

    metric = "val_loss"

    # Two placement bundles per trial, PACKed onto one node: the trial itself
    # (CPU+GPU) plus a CPU-only bundle for its dataloader workers.
    trainable = tune.with_resources(
        trainable,
        resources=tune.PlacementGroupFactory(
            [
                {
                    "CPU": kwargs.get("cpus_per_trial"),
                    "GPU": kwargs.get("gpus_per_trial"),
                },
                {
                    "CPU": kwargs.get("num_workers"),
                },
            ],
            strategy="PACK",
        ),
    )

    # Optuna suggests configs (seeded with initial_points); the concurrency
    # limiter caps how many trials run at once.
    search_algo = OptunaSearch(
        metric=metric,
        mode="min",
        points_to_evaluate=initial_points,
        space=param_space_function,
    )
    search_algo = ConcurrencyLimiter(
        search_algo, max_concurrent=kwargs["max_concurrent"]
    )

    tuner = tune.Tuner(
        trainable,
        tune_config=tune.TuneConfig(
            mode="min",
            metric=metric,
            search_alg=search_algo,
            # ASHA halts weak trials early, budgeted on wall-clock seconds.
            scheduler=ASHAScheduler(
                max_t=24 * 60 * 60,  # 24h cap per trial (seconds)
                time_attr="time_total_s",
                grace_period=kwargs.get("grace_period"),
                reduction_factor=2,
            ),
            num_samples=kwargs.get("num_h_samples"),
        ),
        run_config=RunConfig(name=None, local_dir=kwargs["save_dir"]),
    )

    # Optionally resume a previous search from its checkpoint directory.
    if kwargs.get("tune_checkpoint") is not None:
        ckpt = str(Path(kwargs["tune_checkpoint"]).resolve())
        # NOTE(review): instance-level Tuner.restore is deprecated in newer
        # Ray; the classmethod Tuner.restore(path, trainable=...) replaces it.
        tuner = tuner.restore(path=ckpt, restart_errored=True)

    results = tuner.fit()

    # Record the best trial's score and config.
    best_trial = results.get_best_result()
    output = {"score": best_trial.metrics[metric], "config": best_trial.config}
    out_str = yaml.dump(output, indent=2)
    logging.info(out_str)
    with open(Path(save_dir) / "best_trial.yaml", "w") as f:
        f.write(out_str)

    # Output the full results table (index=False, not index=None, which only
    # worked because pandas treats None as falsy).
    results.get_dataframe().to_csv(
        Path(save_dir) / "full_res_tbl.tsv", sep="\t", index=False
    )
|