diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de2f49c6015ae4616e5d4f34df7e14e7452eb45b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3946b685dd0aa2cc1e201f5e2f8aac1f6f1aa914 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..246cbcff8fe4c973530fe195fb5a92f1727a714f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..678e23a8305b89504e007b6ec56cbd35526bfdd8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07e6970766c5208d5437bc48e22086855a473b74 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fa7e13f93c530d5b478f36a147dce9729045c17 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a287be60f9ef3cadc25efb14e57899892bde6e6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e32852ddc507d0e4e801ac9ccb86773c348fb450 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a94fdc1ce408c1d6095fa80391d8b76d6606299 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6c90445292a85d30de50ea5028141c4f276547a8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb0768d9c67094069d788e470e8160ea6d6da243 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54d33d99b2b54973a462dea00fa85bfe483a6aa5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a606cdd88c7ab59e149f08c61dd924e8f63799cc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9ed6a6b93f9016fa7c26ecccd57face6e7c0f90 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1960c3c211a1f310504b09dbe3b3ee804cded9 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ceb43ac22d58b5fe79e3965fcfaede78290e936 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5abb4f08c74f9fdafb16c52b19eef0fe6467e537 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/commands.py b/.venv/lib/python3.11/site-packages/ray/tune/cli/commands.py new file mode 100644 index 0000000000000000000000000000000000000000..09070124124eaeb9a9e959dfc66e3c79d5fa7734 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/cli/commands.py @@ -0,0 +1,306 @@ +import logging +import operator +import os +import shutil +import subprocess +from datetime import datetime +from pathlib import Path +from typing import List, Optional + +import click +import pandas as pd +from pandas.api.types import is_numeric_dtype, is_string_dtype + +from ray._private.thirdparty.tabulate.tabulate import 
tabulate +from ray.air.constants import EXPR_RESULT_FILE +from ray.tune import TuneError +from ray.tune.analysis import ExperimentAnalysis +from ray.tune.result import ( + CONFIG_PREFIX, + DEFAULT_EXPERIMENT_INFO_KEYS, + DEFAULT_RESULT_KEYS, +) + +logger = logging.getLogger(__name__) + +EDITOR = os.getenv("EDITOR", "vim") + +TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)" + +DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS + +DEFAULT_PROJECT_INFO_KEYS = ( + "name", + "total_trials", + "last_updated", +) + +TERM_WIDTH, TERM_HEIGHT = shutil.get_terminal_size(fallback=(100, 100)) + +OPERATORS = { + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + ">=": operator.ge, + ">": operator.gt, +} + + +def _check_tabulate(): + """Checks whether tabulate is installed.""" + if tabulate is None: + raise ImportError("Tabulate not installed. Please run `pip install tabulate`.") + + +def print_format_output(dataframe): + """Prints output of given dataframe to fit into terminal. + + Returns: + table: Final outputted dataframe. + dropped_cols: Columns dropped due to terminal size. + empty_cols: Empty columns (dropped on default). 
+ """ + print_df = pd.DataFrame() + dropped_cols = [] + empty_cols = [] + # column display priority is based on the info_keys passed in + for i, col in enumerate(dataframe): + if dataframe[col].isnull().all(): + # Don't add col to print_df if is fully empty + empty_cols += [col] + continue + + print_df[col] = dataframe[col] + test_table = tabulate(print_df, headers="keys", tablefmt="psql") + if str(test_table).index("\n") > TERM_WIDTH: + # Drop all columns beyond terminal width + print_df.drop(col, axis=1, inplace=True) + dropped_cols += list(dataframe.columns)[i:] + break + + table = tabulate(print_df, headers="keys", tablefmt="psql", showindex="never") + + print(table) + if dropped_cols: + click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow") + click.secho("Please increase your terminal size to view remaining columns.") + if empty_cols: + click.secho("Empty columns: {}".format(empty_cols), fg="yellow") + + return table, dropped_cols, empty_cols + + +def list_trials( + experiment_path: str, + sort: Optional[List[str]] = None, + output: Optional[str] = None, + filter_op: Optional[str] = None, + info_keys: Optional[List[str]] = None, + limit: int = None, + desc: bool = False, +): + """Lists trials in the directory subtree starting at the given path. + + Args: + experiment_path: Directory where trials are located. + Like Experiment.local_dir/Experiment.name/experiment*.json. + sort: Keys to sort by. + output: Name of file where output is saved. + filter_op: Filter operation in the format + " ". + info_keys: Keys that are displayed. + limit: Number of rows to display. + desc: Sort ascending vs. descending. 
+ """ + _check_tabulate() + + try: + checkpoints_df = ExperimentAnalysis(experiment_path).dataframe() # last result + except TuneError as e: + raise click.ClickException("No trial data found!") from e + + config_prefix = CONFIG_PREFIX + "/" + + def key_filter(k): + return k in DEFAULT_CLI_KEYS or k.startswith(config_prefix) + + col_keys = [k for k in checkpoints_df.columns if key_filter(k)] + + if info_keys: + for k in info_keys: + if k not in checkpoints_df.columns: + raise click.ClickException( + "Provided key invalid: {}. " + "Available keys: {}.".format(k, checkpoints_df.columns) + ) + col_keys = [k for k in checkpoints_df.columns if k in info_keys] + + if not col_keys: + raise click.ClickException("No columns to output.") + + checkpoints_df = checkpoints_df[col_keys] + if "last_update_time" in checkpoints_df: + with pd.option_context("mode.use_inf_as_null", True): + datetime_series = checkpoints_df["last_update_time"].dropna() + + datetime_series = datetime_series.apply( + lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT) + ) + checkpoints_df["last_update_time"] = datetime_series + + if "logdir" in checkpoints_df: + # logdir often too long to view in table, so drop experiment_path + checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace( + experiment_path, "" + ) + + if filter_op: + col, op, val = filter_op.split(" ") + col_type = checkpoints_df[col].dtype + if is_numeric_dtype(col_type): + val = float(val) + elif is_string_dtype(col_type): + val = str(val) + # TODO(Andrew): add support for datetime and boolean + else: + raise click.ClickException( + "Unsupported dtype for {}: {}".format(val, col_type) + ) + op = OPERATORS[op] + filtered_index = op(checkpoints_df[col], val) + checkpoints_df = checkpoints_df[filtered_index] + + if sort: + for key in sort: + if key not in checkpoints_df: + raise click.ClickException( + "{} not in: {}".format(key, list(checkpoints_df)) + ) + ascending = not desc + checkpoints_df = 
checkpoints_df.sort_values(by=sort, ascending=ascending) + + if limit: + checkpoints_df = checkpoints_df[:limit] + + print_format_output(checkpoints_df) + + if output: + file_extension = os.path.splitext(output)[1].lower() + if file_extension in (".p", ".pkl", ".pickle"): + checkpoints_df.to_pickle(output) + elif file_extension == ".csv": + checkpoints_df.to_csv(output, index=False) + else: + raise click.ClickException("Unsupported filetype: {}".format(output)) + click.secho("Output saved at {}".format(output), fg="green") + + +def list_experiments( + project_path: str, + sort: Optional[List[str]] = None, + output: str = None, + filter_op: str = None, + info_keys: Optional[List[str]] = None, + limit: int = None, + desc: bool = False, +): + """Lists experiments in the directory subtree. + + Args: + project_path: Directory where experiments are located. + Corresponds to Experiment.local_dir. + sort: Keys to sort by. + output: Name of file where output is saved. + filter_op: Filter operation in the format + " ". + info_keys: Keys that are displayed. + limit: Number of rows to display. + desc: Sort ascending vs. descending. 
+ """ + _check_tabulate() + base, experiment_folders, _ = next(os.walk(project_path)) + + experiment_data_collection = [] + + for experiment_dir in experiment_folders: + num_trials = sum( + EXPR_RESULT_FILE in files + for _, _, files in os.walk(os.path.join(base, experiment_dir)) + ) + + experiment_data = {"name": experiment_dir, "total_trials": num_trials} + experiment_data_collection.append(experiment_data) + + if not experiment_data_collection: + raise click.ClickException("No experiments found!") + + info_df = pd.DataFrame(experiment_data_collection) + if not info_keys: + info_keys = DEFAULT_PROJECT_INFO_KEYS + col_keys = [k for k in list(info_keys) if k in info_df] + if not col_keys: + raise click.ClickException( + "None of keys {} in experiment data!".format(info_keys) + ) + info_df = info_df[col_keys] + + if filter_op: + col, op, val = filter_op.split(" ") + col_type = info_df[col].dtype + if is_numeric_dtype(col_type): + val = float(val) + elif is_string_dtype(col_type): + val = str(val) + # TODO(Andrew): add support for datetime and boolean + else: + raise click.ClickException( + "Unsupported dtype for {}: {}".format(val, col_type) + ) + op = OPERATORS[op] + filtered_index = op(info_df[col], val) + info_df = info_df[filtered_index] + + if sort: + for key in sort: + if key not in info_df: + raise click.ClickException("{} not in: {}".format(key, list(info_df))) + ascending = not desc + info_df = info_df.sort_values(by=sort, ascending=ascending) + + if limit: + info_df = info_df[:limit] + + print_format_output(info_df) + + if output: + file_extension = os.path.splitext(output)[1].lower() + if file_extension in (".p", ".pkl", ".pickle"): + info_df.to_pickle(output) + elif file_extension == ".csv": + info_df.to_csv(output, index=False) + else: + raise click.ClickException("Unsupported filetype: {}".format(output)) + click.secho("Output saved at {}".format(output), fg="green") + + +def add_note(path: str, filename: str = "note.txt"): + """Opens a txt file at the 
given path where user can add and save notes. + + Args: + path: Directory where note will be saved. + filename: Name of note. Defaults to "note.txt" + """ + path = Path(path).expanduser() + assert path.is_dir(), "{} is not a valid directory.".format(path) + + filepath = path / filename + + try: + subprocess.call([EDITOR, filepath.as_posix()]) + except Exception as exc: + click.secho("Editing note failed: {}".format(str(exc)), fg="red") + if filepath.exists(): + print("Note updated at:", filepath.as_posix()) + else: + print("Note created at:", filepath.as_posix()) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py b/.venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..5401d091c3ba622e78532985e60a18307a2484d9 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py @@ -0,0 +1,101 @@ +import click + +import ray.tune.cli.commands as commands + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument("experiment_path", required=True, type=str) +@click.option("--sort", default=None, type=str, help="Select which column to sort on.") +@click.option( + "--output", + "-o", + default=None, + type=str, + help="Select file to output information to.", +) +@click.option( + "--filter", + "filter_op", + default=None, + type=str, + help="Select filter in the format ' '.", +) +@click.option( + "--columns", default=None, type=str, help="Select columns to be displayed." +) +@click.option( + "--limit", default=None, type=int, help="Select number of rows to display." +) +@click.option("--desc", default=False, type=bool, help="Sort ascending vs. 
descending.") +def list_trials(experiment_path, sort, output, filter_op, columns, limit, desc): + """Lists trials in the directory subtree starting at the given path.""" + if sort: + sort = sort.split(",") + if columns: + columns = columns.split(",") + commands.list_trials(experiment_path, sort, output, filter_op, columns, limit, desc) + + +@cli.command() +@click.argument("project_path", required=True, type=str) +@click.option("--sort", default=None, type=str, help="Select which column to sort on.") +@click.option( + "--output", + "-o", + default=None, + type=str, + help="Select file to output information to.", +) +@click.option( + "--filter", + "filter_op", + default=None, + type=str, + help="Select filter in the format ' '.", +) +@click.option( + "--columns", default=None, type=str, help="Select columns to be displayed." +) +@click.option( + "--limit", default=None, type=int, help="Select number of rows to display." +) +@click.option("--desc", default=False, type=bool, help="Sort ascending vs. descending.") +def list_experiments(project_path, sort, output, filter_op, columns, limit, desc): + """Lists experiments in the directory subtree.""" + if sort: + sort = sort.split(",") + if columns: + columns = columns.split(",") + commands.list_experiments( + project_path, sort, output, filter_op, columns, limit, desc + ) + + +@cli.command() +@click.argument("path", required=True, type=str) +@click.option( + "--filename", default="note.txt", type=str, help="Specify filename for note." 
+) +def add_note(path, filename): + """Adds user notes as a text file at the given path.""" + commands.add_note(path, filename) + + +cli.add_command(list_trials, name="ls") +cli.add_command(list_trials, name="list-trials") +cli.add_command(list_experiments, name="lsx") +cli.add_command(list_experiments, name="list-experiments") +cli.add_command(add_note, name="add-note") + + +def main(): + return cli() + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9406b56167ba406a6744c48dc33f0863be5301dc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4ef6ca9d555ff84075eb549aa7d07bca7435c3e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af790dfef40cf0756f992304c848206155ccfcee Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6636a8c85fa15c2bb1ed2ad4d4a4d66a0d20d069 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10bf32b518004b0d43d3397b9cb3307ec5b315c8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f746c8ac3f67e3d378d5fb145698c5e15e47deb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3f76e9e17531f9cabdd6dd8a98cb08833a37d6a0 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py @@ -0,0 +1,285 @@ +import os + +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +import torchvision.datasets as dset +import torchvision.transforms as transforms +import torchvision.utils as vutils +from scipy.stats import entropy +from torch.autograd import Variable +from torch.nn import functional as F + +import ray + +# Training parameters +workers = 2 +batch_size = 64 +image_size = 32 + +# Number of channels in the training images. For color images this is 3 +nc = 1 + +# Size of z latent vector (i.e. size of generator input) +nz = 100 + +# Size of feature maps in generator +ngf = 32 + +# Size of feature maps in discriminator +ndf = 32 + +# Beta1 hyperparam for Adam optimizers +beta1 = 0.5 + +# iterations of actual training in each Trainable _train +train_iterations_per_step = 5 + +MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt") + + +def get_data_loader(data_dir="~/data"): + dataset = dset.MNIST( + root=data_dir, + download=True, + transform=transforms.Compose( + [ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + ] + ), + ) + + # Create the dataloader + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True, num_workers=workers + ) + + return dataloader + + +# __GANmodel_begin__ +# custom weights initialization called on netG and netD +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +# Generator Code +class Generator(nn.Module): + def __init__(self): + super(Generator, self).__init__() + self.main = nn.Sequential( + # input is Z, going into a 
convolution + nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False), + nn.BatchNorm2d(ngf * 4), + nn.ReLU(True), + nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf * 2), + nn.ReLU(True), + nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), + nn.BatchNorm2d(ngf), + nn.ReLU(True), + nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), + nn.Tanh(), + ) + + def forward(self, input): + return self.main(input) + + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + self.main = nn.Sequential( + nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), + nn.BatchNorm2d(ndf * 4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), + nn.Sigmoid(), + ) + + def forward(self, input): + return self.main(input) + + +# __GANmodel_end__ + + +# __INCEPTION_SCORE_begin__ +class Net(nn.Module): + """ + LeNet for MNist classification, used for inception_score + """ + + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def inception_score(imgs, mnist_model_ref, batch_size=32, splits=1): + N = len(imgs) + dtype = torch.FloatTensor + dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) + cm = ray.get(mnist_model_ref) # Get the mnist model from Ray object store. 
+ up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype) + + def get_pred(x): + x = up(x) + x = cm(x) + return F.softmax(x).data.cpu().numpy() + + preds = np.zeros((N, 10)) + for i, batch in enumerate(dataloader, 0): + batch = batch.type(dtype) + batchv = Variable(batch) + batch_size_i = batch.size()[0] + preds[i * batch_size : i * batch_size + batch_size_i] = get_pred(batchv) + + # Now compute the mean kl-div + split_scores = [] + for k in range(splits): + part = preds[k * (N // splits) : (k + 1) * (N // splits), :] + py = np.mean(part, axis=0) + scores = [] + for i in range(part.shape[0]): + pyx = part[i, :] + scores.append(entropy(pyx, py)) + split_scores.append(np.exp(np.mean(scores))) + + return np.mean(split_scores), np.std(split_scores) + + +# __INCEPTION_SCORE_end__ + + +def train_func( + netD, + netG, + optimG, + optimD, + criterion, + dataloader, + iteration, + device, + mnist_model_ref, +): + real_label = 1 + fake_label = 0 + + for i, data in enumerate(dataloader, 0): + if i >= train_iterations_per_step: + break + + netD.zero_grad() + real_cpu = data[0].to(device) + b_size = real_cpu.size(0) + label = torch.full((b_size,), real_label, dtype=torch.float, device=device) + output = netD(real_cpu).view(-1) + errD_real = criterion(output, label) + errD_real.backward() + D_x = output.mean().item() + + noise = torch.randn(b_size, nz, 1, 1, device=device) + fake = netG(noise) + label.fill_(fake_label) + output = netD(fake.detach()).view(-1) + errD_fake = criterion(output, label) + errD_fake.backward() + D_G_z1 = output.mean().item() + errD = errD_real + errD_fake + optimD.step() + + netG.zero_grad() + label.fill_(real_label) + output = netD(fake).view(-1) + errG = criterion(output, label) + errG.backward() + D_G_z2 = output.mean().item() + optimG.step() + + is_score, is_std = inception_score(fake, mnist_model_ref) + + # Output training stats + if iteration % 10 == 0: + print( + "[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))" + ": %.4f / %.4f 
\tInception score: %.4f" + % ( + iteration, + len(dataloader), + errD.item(), + errG.item(), + D_x, + D_G_z1, + D_G_z2, + is_score, + ) + ) + + return errG.item(), errD.item(), is_score + + +def plot_images(dataloader): + # Plot some training images + real_batch = next(iter(dataloader)) + plt.figure(figsize=(8, 8)) + plt.axis("off") + plt.title("Original Images") + plt.imshow( + np.transpose( + vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu(), + (1, 2, 0), + ) + ) + + plt.show() + + +def demo_gan(checkpoint_paths): + img_list = [] + fixed_noise = torch.randn(64, nz, 1, 1) + for path in checkpoint_paths: + checkpoint_dict = torch.load(os.path.join(path, "checkpoint.pt")) + + loadedG = Generator() + loadedG.load_state_dict(checkpoint_dict["netGmodel"]) + with torch.no_grad(): + fake = loadedG(fixed_noise).detach().cpu() + img_list.append(vutils.make_grid(fake, padding=2, normalize=True)) + + fig = plt.figure(figsize=(8, 8)) + plt.axis("off") + ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list] + ani = animation.ArtistAnimation( + fig, ims, interval=1000, repeat_delay=1000, blit=True + ) + ani.save("./generated.gif", writer="imagemagick", dpi=72) + plt.show() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py new file mode 100644 index 0000000000000000000000000000000000000000..acb1edae2a85a420fe6a757853a7cae1df818d86 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +""" +Example of training DCGAN on MNIST using PBT with Tune's function API. 
+""" +import argparse +import os +import tempfile + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +from filelock import FileLock + +import ray +from ray import train, tune +from ray.train import Checkpoint +from ray.tune.examples.pbt_dcgan_mnist.common import ( + MODEL_PATH, + Discriminator, + Generator, + Net, + beta1, + demo_gan, + get_data_loader, + plot_images, + train_func, + weights_init, +) +from ray.tune.schedulers import PopulationBasedTraining + + +# __Train_begin__ +def dcgan_train(config): + use_cuda = config.get("use_gpu") and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + netD = Discriminator().to(device) + netD.apply(weights_init) + netG = Generator().to(device) + netG.apply(weights_init) + criterion = nn.BCELoss() + optimizerD = optim.Adam( + netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999) + ) + optimizerG = optim.Adam( + netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999) + ) + with FileLock(os.path.expanduser("~/ray_results/.data.lock")): + dataloader = get_data_loader() + + step = 1 + checkpoint = train.get_checkpoint() + if checkpoint: + with checkpoint.as_directory() as checkpoint_dir: + checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt")) + netD.load_state_dict(checkpoint_dict["netDmodel"]) + netG.load_state_dict(checkpoint_dict["netGmodel"]) + optimizerD.load_state_dict(checkpoint_dict["optimD"]) + optimizerG.load_state_dict(checkpoint_dict["optimG"]) + # Note: Make sure to increment the loaded step by 1 to get the + # current step. + last_step = checkpoint_dict["step"] + step = last_step + 1 + + # NOTE: It's important to set the optimizer learning rates + # again, since we want to explore the parameters passed in by PBT. + # Without this, we would continue using the exact same + # configuration as the trial whose checkpoint we are exploiting. 
+ if "netD_lr" in config: + for param_group in optimizerD.param_groups: + param_group["lr"] = config["netD_lr"] + if "netG_lr" in config: + for param_group in optimizerG.param_groups: + param_group["lr"] = config["netG_lr"] + + while True: + lossG, lossD, is_score = train_func( + netD, + netG, + optimizerG, + optimizerD, + criterion, + dataloader, + step, + device, + config["mnist_model_ref"], + ) + metrics = {"lossg": lossG, "lossd": lossD, "is_score": is_score} + + if step % config["checkpoint_interval"] == 0: + with tempfile.TemporaryDirectory() as tmpdir: + torch.save( + { + "netDmodel": netD.state_dict(), + "netGmodel": netG.state_dict(), + "optimD": optimizerD.state_dict(), + "optimG": optimizerG.state_dict(), + "step": step, + }, + os.path.join(tmpdir, "checkpoint.pt"), + ) + train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir)) + else: + train.report(metrics) + + step += 1 + + +# __Train_end__ + + +def download_mnist_cnn(): + import urllib.request + + # Download a pre-trained MNIST model for inception score calculation. + # This is a tiny model (<100kb). + if not os.path.exists(MODEL_PATH): + print("downloading model") + os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True) + urllib.request.urlretrieve( + "https://github.com/ray-project/ray/raw/master/python/ray/tune/" + "examples/pbt_dcgan_mnist/mnist_cnn.pt", + MODEL_PATH, + ) + return MODEL_PATH + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) + parser.add_argument( + "--data-dir", type=str, default="~/data/", help="Set the path of the dataset." 
+ ) + args, _ = parser.parse_known_args() + ray.init() + + download_mnist_cnn() + + dataloader = get_data_loader(args.data_dir) + if not args.smoke_test: + plot_images(dataloader) + + # __tune_begin__ + + # load the pretrained mnist classification model for inception_score + mnist_cnn = Net() + mnist_cnn.load_state_dict(torch.load(MODEL_PATH)) + mnist_cnn.eval() + # Put the model in Ray object store. + mnist_model_ref = ray.put(mnist_cnn) + + scheduler = PopulationBasedTraining( + perturbation_interval=5, + hyperparam_mutations={ + # distribution for resampling + "netG_lr": lambda: np.random.uniform(1e-2, 1e-5), + "netD_lr": lambda: np.random.uniform(1e-2, 1e-5), + }, + ) + + tune_iter = 5 if args.smoke_test else 300 + tuner = tune.Tuner( + dcgan_train, + run_config=train.RunConfig( + name="pbt_dcgan_mnist", + stop={"training_iteration": tune_iter}, + verbose=1, + ), + tune_config=tune.TuneConfig( + metric="is_score", + mode="max", + num_samples=8, + scheduler=scheduler, + ), + param_space={ + "netG_lr": tune.choice([0.0001, 0.0002, 0.0005]), + "netD_lr": tune.choice([0.0001, 0.0002, 0.0005]), + "mnist_model_ref": mnist_model_ref, + }, + ) + results = tuner.fit() + # __tune_end__ + + # demo of the trained Generators + if not args.smoke_test: + checkpoint_paths = [result.checkpoint.to_directory() for result in results] + demo_gan(checkpoint_paths) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6678fbbff7b92fb530127c68f9c2dedfe79a20 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python +""" +Example of training DCGAN on MNIST using PBT with Tune's Trainable Class +API. 
+""" +import argparse +import os +import random + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +from filelock import FileLock + +import ray +from ray import train, tune +from ray.tune.examples.pbt_dcgan_mnist.common import ( + MODEL_PATH, + Discriminator, + Generator, + Net, + beta1, + demo_gan, + get_data_loader, + plot_images, + train_func, + weights_init, +) +from ray.tune.schedulers import PopulationBasedTraining + + +# __Trainable_begin__ +class PytorchTrainable(tune.Trainable): + def setup(self, config): + use_cuda = config.get("use_gpu") and torch.cuda.is_available() + self.device = torch.device("cuda" if use_cuda else "cpu") + self.netD = Discriminator().to(self.device) + self.netD.apply(weights_init) + self.netG = Generator().to(self.device) + self.netG.apply(weights_init) + self.criterion = nn.BCELoss() + self.optimizerD = optim.Adam( + self.netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999) + ) + self.optimizerG = optim.Adam( + self.netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999) + ) + with FileLock(os.path.expanduser("~/.data.lock")): + self.dataloader = get_data_loader(config.get("data_dir", "~/data")) + self.mnist_model_ref = config["mnist_model_ref"] + + def step(self): + lossG, lossD, is_score = train_func( + self.netD, + self.netG, + self.optimizerG, + self.optimizerD, + self.criterion, + self.dataloader, + self._iteration, + self.device, + self.mnist_model_ref, + ) + return {"lossg": lossG, "lossd": lossD, "is_score": is_score} + + def save_checkpoint(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "checkpoint.pt") + torch.save( + { + "netDmodel": self.netD.state_dict(), + "netGmodel": self.netG.state_dict(), + "optimD": self.optimizerD.state_dict(), + "optimG": self.optimizerG.state_dict(), + }, + path, + ) + + return checkpoint_dir + + def load_checkpoint(self, checkpoint_dir): + path = 
os.path.join(checkpoint_dir, "checkpoint.pt") + checkpoint = torch.load(path) + self.netD.load_state_dict(checkpoint["netDmodel"]) + self.netG.load_state_dict(checkpoint["netGmodel"]) + self.optimizerD.load_state_dict(checkpoint["optimD"]) + self.optimizerG.load_state_dict(checkpoint["optimG"]) + + def reset_config(self, new_config): + if "netD_lr" in new_config: + for param_group in self.optimizerD.param_groups: + param_group["lr"] = new_config["netD_lr"] + if "netG_lr" in new_config: + for param_group in self.optimizerG.param_groups: + param_group["lr"] = new_config["netG_lr"] + + self.config = new_config + return True + + +# __Trainable_end__ + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing" + ) + parser.add_argument( + "--data-dir", type=str, default="~/data/", help="Set the path of the dataset." + ) + args, _ = parser.parse_known_args() + ray.init() + + import urllib.request + + # Download a pre-trained MNIST model for inception score calculation. + # This is a tiny model (<100kb). 
+ if not os.path.exists(MODEL_PATH): + print("downloading model") + os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True) + urllib.request.urlretrieve( + "https://github.com/ray-project/ray/raw/master/python/ray/tune/" + "examples/pbt_dcgan_mnist/mnist_cnn.pt", + MODEL_PATH, + ) + + dataloader = get_data_loader() + if not args.smoke_test: + plot_images(dataloader) + + # load the pretrained mnist classification model for inception_score + mnist_cnn = Net() + mnist_cnn.load_state_dict(torch.load(MODEL_PATH)) + mnist_cnn.eval() + mnist_model_ref = ray.put(mnist_cnn) + + # __tune_begin__ + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + perturbation_interval=5, + hyperparam_mutations={ + # distribution for resampling + "netG_lr": lambda: np.random.uniform(1e-2, 1e-5), + "netD_lr": lambda: np.random.uniform(1e-2, 1e-5), + }, + ) + + tune_iter = 10 if args.smoke_test else 300 + tuner = tune.Tuner( + PytorchTrainable, + run_config=train.RunConfig( + name="pbt_dcgan_mnist", + stop={"training_iteration": tune_iter}, + verbose=1, + checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True), + ), + tune_config=tune.TuneConfig( + metric="is_score", + mode="max", + num_samples=8, + scheduler=scheduler, + reuse_actors=True, + ), + param_space={ + "netG_lr": tune.sample_from( + lambda spec: random.choice([0.0001, 0.0002, 0.0005]) + ), + "netD_lr": tune.sample_from( + lambda spec: random.choice([0.0001, 0.0002, 0.0005]) + ), + "mnist_model_ref": mnist_model_ref, + "data_dir": args.data_dir, + }, + ) + results = tuner.fit() + + # export_formats=[ExportFormat.MODEL] + # __tune_end__ + + # demo of the trained Generators + if not args.smoke_test: + checkpoint_paths = [result.checkpoint.to_directory() for result in results] + demo_gan(checkpoint_paths) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfe119bb4db3db2b5e1beef046ec309b32e856c5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b075195f981405d11715f87cb958326f42bc9a5f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/experimental/output.py b/.venv/lib/python3.11/site-packages/ray/tune/experimental/output.py new file mode 100644 index 0000000000000000000000000000000000000000..699217e7534abbc9be1ca409dffbce609d5b06e4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/experimental/output.py @@ -0,0 +1,1043 @@ +import argparse +import collections +import datetime +import logging +import math +import numbers +import os +import sys +import textwrap +import time +from dataclasses import dataclass +from enum import IntEnum +from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd + +import ray +from ray._private.dict import flatten_dict, unflattened_lookup +from ray._private.thirdparty.tabulate.tabulate import ( + DataRow, + Line, + TableFormat, + tabulate, +) +from ray.air._internal.usage import AirEntrypoint +from ray.air.constants import TRAINING_ITERATION +from ray.train import Checkpoint +from 
ray.tune.callback import Callback +from ray.tune.experiment.trial import Trial +from ray.tune.result import ( + AUTO_RESULT_KEYS, + EPISODE_REWARD_MEAN, + MEAN_ACCURACY, + MEAN_LOSS, + TIME_TOTAL_S, + TIMESTEPS_TOTAL, +) +from ray.tune.search.sample import Domain +from ray.tune.utils.log import Verbosity + +try: + import rich + import rich.layout + import rich.live +except ImportError: + rich = None + + +logger = logging.getLogger(__name__) + +# defines the mapping of the key in result and the key to be printed in table. +# Note this is ordered! +DEFAULT_COLUMNS = collections.OrderedDict( + { + MEAN_ACCURACY: "acc", + MEAN_LOSS: "loss", + TRAINING_ITERATION: "iter", + TIME_TOTAL_S: "total time (s)", + TIMESTEPS_TOTAL: "ts", + EPISODE_REWARD_MEAN: "reward", + } +) + +# These keys are blacklisted for printing out training/tuning intermediate/final result! +BLACKLISTED_KEYS = { + "config", + "date", + "done", + "hostname", + "iterations_since_restore", + "node_ip", + "pid", + "time_since_restore", + "timestamp", + "trial_id", + "experiment_tag", + "should_checkpoint", + "_report_on", # LIGHTNING_REPORT_STAGE_KEY +} + +VALID_SUMMARY_TYPES = { + int, + float, + np.float32, + np.float64, + np.int32, + np.int64, + type(None), +} + +# The order of summarizing trials. 
+ORDER = [ + Trial.RUNNING, + Trial.TERMINATED, + Trial.PAUSED, + Trial.PENDING, + Trial.ERROR, +] + + +class AirVerbosity(IntEnum): + SILENT = 0 + DEFAULT = 1 + VERBOSE = 2 + + def __repr__(self): + return str(self.value) + + +IS_NOTEBOOK = ray.widgets.util.in_notebook() + + +def get_air_verbosity( + verbose: Union[int, AirVerbosity, Verbosity] +) -> Optional[AirVerbosity]: + if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0": + return None + + if isinstance(verbose, AirVerbosity): + return verbose + + verbose_int = verbose if isinstance(verbose, int) else verbose.value + + # Verbosity 2 and 3 both map to AirVerbosity 2 + verbose_int = min(2, verbose_int) + + return AirVerbosity(verbose_int) + + +def _infer_params(config: Dict[str, Any]) -> List[str]: + params = [] + flat_config = flatten_dict(config) + for key, val in flat_config.items(): + if isinstance(val, Domain): + params.append(key) + # Grid search is a special named field. Because we flattened + # the whole config, we look it up per string + if key.endswith("/grid_search"): + # Truncate `/grid_search` + params.append(key[:-12]) + return params + + +def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]: + """Get strings representing the current and elapsed time. 
+ + Args: + start_time: POSIX timestamp of the start of the tune run + current_time: POSIX timestamp giving the current time + + Returns: + Current time and elapsed time for the current run + """ + current_time_dt = datetime.datetime.fromtimestamp(current_time) + start_time_dt = datetime.datetime.fromtimestamp(start_time) + delta: datetime.timedelta = current_time_dt - start_time_dt + + rest = delta.total_seconds() + days = int(rest // (60 * 60 * 24)) + + rest -= days * (60 * 60 * 24) + hours = int(rest // (60 * 60)) + + rest -= hours * (60 * 60) + minutes = int(rest // 60) + + seconds = int(rest - minutes * 60) + + running_for_str = "" + if days > 0: + running_for_str += f"{days:d}d " + + if hours > 0 or running_for_str: + running_for_str += f"{hours:d}hr " + + if minutes > 0 or running_for_str: + running_for_str += f"{minutes:d}min " + + running_for_str += f"{seconds:d}s" + + return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str + + +def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]: + trials_by_state = collections.defaultdict(list) + for t in trials: + trials_by_state[t.status].append(t) + return trials_by_state + + +def _get_trials_with_error(trials: List[Trial]) -> List[Trial]: + return [t for t in trials if t.error_file] + + +def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]: + """Try to infer the metrics to print out. + + By default, only the first 4 meaningful metrics in `last_result` will be + inferred as user implied metrics. + """ + # Using OrderedDict for OrderedSet. 
+ result = collections.OrderedDict() + for t in trials: + if not t.last_result: + continue + for metric, value in t.last_result.items(): + if metric not in DEFAULT_COLUMNS: + if metric not in AUTO_RESULT_KEYS: + if type(value) in VALID_SUMMARY_TYPES: + result[metric] = "" # not important + + if len(result) >= limit: + return list(result.keys()) + return list(result.keys()) + + +def _current_best_trial( + trials: List[Trial], metric: Optional[str], mode: Optional[str] +) -> Tuple[Optional[Trial], Optional[str]]: + """ + Returns the best trial and the metric key. If anything is empty or None, + returns a trivial result of None, None. + + Args: + trials: List of trials. + metric: Metric that trials are being ranked. + mode: One of "min" or "max". + + Returns: + Best trial and the metric key. + """ + if not trials or not metric or not mode: + return None, None + + metric_op = 1.0 if mode == "max" else -1.0 + best_metric = float("-inf") + best_trial = None + for t in trials: + if not t.last_result: + continue + metric_value = unflattened_lookup(metric, t.last_result, default=None) + if pd.isnull(metric_value): + continue + if not best_trial or metric_value * metric_op > best_metric: + best_metric = metric_value * metric_op + best_trial = t + return best_trial, metric + + +@dataclass +class _PerStatusTrialTableData: + trial_infos: List[List[str]] + more_info: str + + +@dataclass +class _TrialTableData: + header: List[str] + data: List[_PerStatusTrialTableData] + + +def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any: + """Abbreviate a string representation of an object to `max_len` characters. + + For numbers, booleans and None, the original value will be returned for + correct rendering in the table formatting tool. + + Args: + value: Object to be represented as a string. + max_len: Maximum return string length. 
+ """ + if value is None or isinstance(value, (int, float, numbers.Number, bool)): + return value + + string = str(value) + if len(string) <= max_len: + return string + + if wrap: + # Maximum two rows. + # Todo: Make this configurable in the refactor + if len(value) > max_len * 2: + value = "..." + string[(3 - (max_len * 2)) :] + + wrapped = textwrap.wrap(value, width=max_len) + return "\n".join(wrapped) + + result = "..." + string[(3 - max_len) :] + return result + + +def _get_trial_info( + trial: Trial, param_keys: List[str], metric_keys: List[str] +) -> List[str]: + """Returns the following information about a trial: + + name | status | metrics... + + Args: + trial: Trial to get information for. + param_keys: Names of parameters to include. + metric_keys: Names of metrics to include. + """ + result = trial.last_result + trial_info = [str(trial), trial.status] + + # params + trial_info.extend( + [ + _max_len( + unflattened_lookup(param, trial.config, default=None), + ) + for param in param_keys + ] + ) + # metrics + trial_info.extend( + [ + _max_len( + unflattened_lookup(metric, result, default=None), + ) + for metric in metric_keys + ] + ) + return trial_info + + +def _get_trial_table_data_per_status( + status: str, + trials: List[Trial], + param_keys: List[str], + metric_keys: List[str], + force_max_rows: bool = False, +) -> Optional[_PerStatusTrialTableData]: + """Gather all information of trials pertained to one `status`. + + Args: + status: The trial status of interest. + trials: all the trials of that status. + param_keys: *Ordered* list of parameters to be displayed in the table. + metric_keys: *Ordered* list of metrics to be displayed in the table. + Including both default and user defined. + force_max_rows: Whether or not to enforce a max row number for this status. + If True, only a max of `5` rows will be shown. + + Returns: + All information of trials pertained to the `status`. + """ + # TODO: configure it. 
+ max_row = 5 if force_max_rows else math.inf + if not trials: + return None + + trial_infos = list() + more_info = None + for t in trials: + if len(trial_infos) >= max_row: + remaining = len(trials) - max_row + more_info = f"{remaining} more {status}" + break + trial_infos.append(_get_trial_info(t, param_keys, metric_keys)) + return _PerStatusTrialTableData(trial_infos, more_info) + + +def _get_trial_table_data( + trials: List[Trial], + param_keys: List[str], + metric_keys: List[str], + all_rows: bool = False, + wrap_headers: bool = False, +) -> _TrialTableData: + """Generate a table showing the current progress of tuning trials. + + Args: + trials: List of trials for which progress is to be shown. + param_keys: Ordered list of parameters to be displayed in the table. + metric_keys: Ordered list of metrics to be displayed in the table. + Including both default and user defined. + Will only be shown if at least one trial is having the key. + all_rows: Force to show all rows. + wrap_headers: If True, header columns can be wrapped with ``\n``. + + Returns: + Trial table data, including header and trial table per each status. + """ + # TODO: configure + max_trial_num_to_show = 20 + max_column_length = 20 + trials_by_state = _get_trials_by_state(trials) + + # get the right metric to show. + metric_keys = [ + k + for k in metric_keys + if any( + unflattened_lookup(k, t.last_result, default=None) is not None + for t in trials + ) + ] + + # get header from metric keys + formatted_metric_columns = [ + _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in metric_keys + ] + + formatted_param_columns = [ + _max_len(k, max_len=max_column_length, wrap=wrap_headers) for k in param_keys + ] + + metric_header = [ + DEFAULT_COLUMNS[metric] if metric in DEFAULT_COLUMNS else formatted + for metric, formatted in zip(metric_keys, formatted_metric_columns) + ] + + param_header = formatted_param_columns + + # Map to the abbreviated version if necessary. 
+ header = ["Trial name", "status"] + param_header + metric_header + + trial_data = list() + for t_status in ORDER: + trial_data_per_status = _get_trial_table_data_per_status( + t_status, + trials_by_state[t_status], + param_keys=param_keys, + metric_keys=metric_keys, + force_max_rows=not all_rows and len(trials) > max_trial_num_to_show, + ) + if trial_data_per_status: + trial_data.append(trial_data_per_status) + return _TrialTableData(header, trial_data) + + +def _best_trial_str( + trial: Trial, + metric: str, +): + """Returns a readable message stating the current best trial.""" + # returns something like + # Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}. # noqa + val = unflattened_lookup(metric, trial.last_result, default=None) + config = trial.last_result.get("config", {}) + parameter_columns = list(config.keys()) + params = {p: unflattened_lookup(p, config) for p in parameter_columns} + return ( + f"Current best trial: {trial.trial_id} with {metric}={val} and " + f"params={params}" + ) + + +def _render_table_item( + key: str, item: Any, prefix: str = "" +) -> Iterable[Tuple[str, str]]: + key = prefix + key + + if isinstance(item, argparse.Namespace): + item = item.__dict__ + + if isinstance(item, float): + # tabulate does not work well with mixed-type columns, so we format + # numbers ourselves. + yield key, f"{item:.5f}".rstrip("0") + elif isinstance(item, dict): + flattened = flatten_dict(item) + for k, v in sorted(flattened.items()): + yield key + "/" + str(k), _max_len(v) + else: + yield key, _max_len(item, 20) + + +def _get_dict_as_table_data( + data: Dict, + include: Optional[Collection] = None, + exclude: Optional[Collection] = None, + upper_keys: Optional[Collection] = None, +): + """Get ``data`` dict as table rows. + + If specified, excluded keys are removed. Excluded keys can either be + fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary + (e.g. 
``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is + needed, we can revisit the logic at a later point. + + The same is true for included keys. If a top-level key is included (e.g. ``foo``) + then all sub keys will be included, too, except if they are excluded. + + If keys are both excluded and included, exclusion takes precedence. Thus, if + ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output. + """ + include = include or set() + exclude = exclude or set() + upper_keys = upper_keys or set() + + upper = [] + lower = [] + + for key, value in sorted(data.items()): + # Exclude top-level keys + if key in exclude: + continue + + for k, v in _render_table_item(str(key), value): + # k is now the full subkey, e.g. config/nested/key + + # We can exclude the full key + if k in exclude: + continue + + # If we specify includes, top-level includes should take precedence + # (e.g. if `config` is in include, include config always). + if include and key not in include and k not in include: + continue + + if key in upper_keys: + upper.append([k, v]) + else: + lower.append([k, v]) + + if not upper: + return lower + elif not lower: + return upper + else: + return upper + lower + + +if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"): + # Copied/adjusted from tabulate + AIR_TABULATE_TABLEFMT = TableFormat( + lineabove=Line("╭", "─", "─", "╮"), + linebelowheader=Line("├", "─", "─", "┤"), + linebetweenrows=None, + linebelow=Line("╰", "─", "─", "╯"), + headerrow=DataRow("│", " ", "│"), + datarow=DataRow("│", " ", "│"), + padding=1, + with_header_hide=None, + ) +else: + # For non-utf output, use ascii-compatible characters. + # This prevents errors e.g. when legacy windows encoding is used. 
+ AIR_TABULATE_TABLEFMT = TableFormat( + lineabove=Line("+", "-", "-", "+"), + linebelowheader=Line("+", "-", "-", "+"), + linebetweenrows=None, + linebelow=Line("+", "-", "-", "+"), + headerrow=DataRow("|", " ", "|"), + datarow=DataRow("|", " ", "|"), + padding=1, + with_header_hide=None, + ) + + +def _print_dict_as_table( + data: Dict, + header: Optional[str] = None, + include: Optional[Collection[str]] = None, + exclude: Optional[Collection[str]] = None, + division: Optional[Collection[str]] = None, +): + table_data = _get_dict_as_table_data( + data=data, include=include, exclude=exclude, upper_keys=division + ) + + headers = [header, ""] if header else [] + + if not table_data: + return + + print( + tabulate( + table_data, + headers=headers, + colalign=("left", "right"), + tablefmt=AIR_TABULATE_TABLEFMT, + ) + ) + + +class ProgressReporter(Callback): + """Periodically prints out status update.""" + + # TODO: Make this configurable + _heartbeat_freq = 30 # every 30 sec + # to be updated by subclasses. + _heartbeat_threshold = None + _start_end_verbosity = None + _intermediate_result_verbosity = None + _addressing_tmpl = None + + def __init__( + self, + verbosity: AirVerbosity, + progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None, + ): + """ + + Args: + verbosity: AirVerbosity level. 
+ """ + self._verbosity = verbosity + self._start_time = time.time() + self._last_heartbeat_time = float("-inf") + self._start_time = time.time() + self._progress_metrics = progress_metrics + self._trial_last_printed_results = {} + + self._in_block = None + + @property + def verbosity(self) -> AirVerbosity: + return self._verbosity + + def setup( + self, + start_time: Optional[float] = None, + **kwargs, + ): + self._start_time = start_time + + def _start_block(self, indicator: Any): + if self._in_block != indicator: + self._end_block() + self._in_block = indicator + + def _end_block(self): + if self._in_block: + print("") + self._in_block = None + + def on_experiment_end(self, trials: List["Trial"], **info): + self._end_block() + + def experiment_started( + self, + experiment_name: str, + experiment_path: str, + searcher_str: str, + scheduler_str: str, + total_num_samples: int, + tensorboard_path: Optional[str] = None, + **kwargs, + ): + self._start_block("exp_start") + print(f"\nView detailed results here: {experiment_path}") + + if tensorboard_path: + print( + f"To visualize your results with TensorBoard, run: " + f"`tensorboard --logdir {tensorboard_path}`" + ) + + @property + def _time_heartbeat_str(self): + current_time_str, running_time_str = _get_time_str( + self._start_time, time.time() + ) + return ( + f"Current time: {current_time_str}. 
Total running time: " + running_time_str + ) + + def print_heartbeat(self, trials, *args, force: bool = False): + if self._verbosity < self._heartbeat_threshold: + return + if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq: + self._print_heartbeat(trials, *args, force=force) + self._last_heartbeat_time = time.time() + + def _print_heartbeat(self, trials, *args, force: bool = False): + raise NotImplementedError + + def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False): + """Only print result if a different result has been reported, or force=True""" + result = result or trial.last_result + + last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1) + this_iter = result.get(TRAINING_ITERATION, 0) + + if this_iter != last_result_iter or force: + _print_dict_as_table( + result, + header=f"{self._addressing_tmpl.format(trial)} result", + include=self._progress_metrics, + exclude=BLACKLISTED_KEYS, + division=AUTO_RESULT_KEYS, + ) + self._trial_last_printed_results[trial.trial_id] = this_iter + + def _print_config(self, trial): + _print_dict_as_table( + trial.config, header=f"{self._addressing_tmpl.format(trial)} config" + ) + + def on_trial_result( + self, + iteration: int, + trials: List[Trial], + trial: Trial, + result: Dict, + **info, + ): + if self.verbosity < self._intermediate_result_verbosity: + return + self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}") + curr_time_str, running_time_str = _get_time_str(self._start_time, time.time()) + print( + f"{self._addressing_tmpl.format(trial)} " + f"finished iteration {result[TRAINING_ITERATION]} " + f"at {curr_time_str}. 
Total running time: " + running_time_str + ) + self._print_result(trial, result) + + def on_trial_complete( + self, iteration: int, trials: List[Trial], trial: Trial, **info + ): + if self.verbosity < self._start_end_verbosity: + return + curr_time_str, running_time_str = _get_time_str(self._start_time, time.time()) + finished_iter = 0 + if trial.last_result and TRAINING_ITERATION in trial.last_result: + finished_iter = trial.last_result[TRAINING_ITERATION] + + self._start_block(f"trial_{trial}_complete") + print( + f"{self._addressing_tmpl.format(trial)} " + f"completed after {finished_iter} iterations " + f"at {curr_time_str}. Total running time: " + running_time_str + ) + self._print_result(trial) + + def on_trial_error( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + curr_time_str, running_time_str = _get_time_str(self._start_time, time.time()) + finished_iter = 0 + if trial.last_result and TRAINING_ITERATION in trial.last_result: + finished_iter = trial.last_result[TRAINING_ITERATION] + + self._start_block(f"trial_{trial}_error") + print( + f"{self._addressing_tmpl.format(trial)} " + f"errored after {finished_iter} iterations " + f"at {curr_time_str}. Total running time: {running_time_str}\n" + f"Error file: {trial.error_file}" + ) + self._print_result(trial) + + def on_trial_recover( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info) + + def on_checkpoint( + self, + iteration: int, + trials: List[Trial], + trial: Trial, + checkpoint: Checkpoint, + **info, + ): + if self._verbosity < self._intermediate_result_verbosity: + return + # don't think this is supposed to happen but just to be safe. + saved_iter = "?" 
+ if trial.last_result and TRAINING_ITERATION in trial.last_result: + saved_iter = trial.last_result[TRAINING_ITERATION] + + self._start_block(f"trial_{trial}_result_{saved_iter}") + + loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}" + + print( + f"{self._addressing_tmpl.format(trial)} " + f"saved a checkpoint for iteration {saved_iter} " + f"at: {loc}" + ) + + def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info): + if self.verbosity < self._start_end_verbosity: + return + has_config = bool(trial.config) + + self._start_block(f"trial_{trial}_start") + if has_config: + print( + f"{self._addressing_tmpl.format(trial)} " f"started with configuration:" + ) + self._print_config(trial) + else: + print( + f"{self._addressing_tmpl.format(trial)} " + f"started without custom configuration." + ) + + +def _detect_reporter( + verbosity: AirVerbosity, + num_samples: int, + entrypoint: Optional[AirEntrypoint] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + config: Optional[Dict] = None, + progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None, +): + if entrypoint in { + AirEntrypoint.TUNE_RUN, + AirEntrypoint.TUNE_RUN_EXPERIMENTS, + AirEntrypoint.TUNER, + }: + reporter = TuneTerminalReporter( + verbosity, + num_samples=num_samples, + metric=metric, + mode=mode, + config=config, + progress_metrics=progress_metrics, + ) + else: + reporter = TrainReporter(verbosity, progress_metrics=progress_metrics) + return reporter + + +class TuneReporterBase(ProgressReporter): + _heartbeat_threshold = AirVerbosity.DEFAULT + _wrap_headers = False + _intermediate_result_verbosity = AirVerbosity.VERBOSE + _start_end_verbosity = AirVerbosity.DEFAULT + _addressing_tmpl = "Trial {}" + + def __init__( + self, + verbosity: AirVerbosity, + num_samples: int = 0, + metric: Optional[str] = None, + mode: Optional[str] = None, + config: Optional[Dict] = None, + progress_metrics: Optional[Union[List[str], List[Dict[str, 
str]]]] = None, + ): + self._num_samples = num_samples + self._metric = metric + self._mode = mode + # will be populated when first result comes in. + self._inferred_metric = None + self._inferred_params = _infer_params(config or {}) + super(TuneReporterBase, self).__init__( + verbosity=verbosity, progress_metrics=progress_metrics + ) + + def setup( + self, + start_time: Optional[float] = None, + total_samples: Optional[int] = None, + **kwargs, + ): + super().setup(start_time=start_time) + self._num_samples = total_samples + + def _get_overall_trial_progress_str(self, trials): + result = " | ".join( + [ + f"{len(trials)} {status}" + for status, trials in _get_trials_by_state(trials).items() + ] + ) + return f"Trial status: {result}" + + # TODO: Return a more structured type to share code with Jupyter flow. + def _get_heartbeat( + self, trials, *sys_args, force_full_output: bool = False + ) -> Tuple[List[str], _TrialTableData]: + result = list() + # Trial status: 1 RUNNING | 7 PENDING + result.append(self._get_overall_trial_progress_str(trials)) + # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40) + result.append(self._time_heartbeat_str) + # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs + result.extend(sys_args) + # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...} + current_best_trial, metric = _current_best_trial( + trials, self._metric, self._mode + ) + if current_best_trial: + result.append(_best_trial_str(current_best_trial, metric)) + # Now populating the trial table data. + if not self._inferred_metric: + # try inferring again. 
+ self._inferred_metric = _infer_user_metrics(trials) + + all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric + + trial_table_data = _get_trial_table_data( + trials, + param_keys=self._inferred_params, + metric_keys=all_metrics, + all_rows=force_full_output, + wrap_headers=self._wrap_headers, + ) + return result, trial_table_data + + def _print_heartbeat(self, trials, *sys_args, force: bool = False): + raise NotImplementedError + + +class TuneTerminalReporter(TuneReporterBase): + def experiment_started( + self, + experiment_name: str, + experiment_path: str, + searcher_str: str, + scheduler_str: str, + total_num_samples: int, + tensorboard_path: Optional[str] = None, + **kwargs, + ): + if total_num_samples > sys.maxsize: + total_num_samples_str = "infinite" + else: + total_num_samples_str = str(total_num_samples) + + print( + tabulate( + [ + ["Search algorithm", searcher_str], + ["Scheduler", scheduler_str], + ["Number of trials", total_num_samples_str], + ], + headers=["Configuration for experiment", experiment_name], + tablefmt=AIR_TABULATE_TABLEFMT, + ) + ) + super().experiment_started( + experiment_name=experiment_name, + experiment_path=experiment_path, + searcher_str=searcher_str, + scheduler_str=scheduler_str, + total_num_samples=total_num_samples, + tensorboard_path=tensorboard_path, + **kwargs, + ) + + def _print_heartbeat(self, trials, *sys_args, force: bool = False): + if self._verbosity < self._heartbeat_threshold and not force: + return + heartbeat_strs, table_data = self._get_heartbeat( + trials, *sys_args, force_full_output=force + ) + + self._start_block("heartbeat") + for s in heartbeat_strs: + print(s) + # now print the table using Tabulate + more_infos = [] + all_data = [] + fail_header = table_data.header + for sub_table in table_data.data: + all_data.extend(sub_table.trial_infos) + if sub_table.more_info: + more_infos.append(sub_table.more_info) + + print( + tabulate( + all_data, + headers=fail_header, + 
tablefmt=AIR_TABULATE_TABLEFMT, + showindex=False, + ) + ) + if more_infos: + print(", ".join(more_infos)) + + if not force: + # Only print error table at end of training + return + + trials_with_error = _get_trials_with_error(trials) + if not trials_with_error: + return + + self._start_block("status_errored") + print(f"Number of errored trials: {len(trials_with_error)}") + fail_header = ["Trial name", "# failures", "error file"] + fail_table_data = [ + [ + str(trial), + str(trial.run_metadata.num_failures) + + ("" if trial.status == Trial.ERROR else "*"), + trial.error_file, + ] + for trial in trials_with_error + ] + print( + tabulate( + fail_table_data, + headers=fail_header, + tablefmt=AIR_TABULATE_TABLEFMT, + showindex=False, + colalign=("left", "right", "left"), + ) + ) + if any(trial.status == Trial.TERMINATED for trial in trials_with_error): + print("* The trial terminated successfully after retrying.") + + +class TrainReporter(ProgressReporter): + # the minimal verbosity threshold at which heartbeat starts getting printed. + _heartbeat_threshold = AirVerbosity.VERBOSE + _intermediate_result_verbosity = AirVerbosity.DEFAULT + _start_end_verbosity = AirVerbosity.DEFAULT + _addressing_tmpl = "Training" + + def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False): + # Training on iteration 1. 
Current time: 2023-03-22 15:29:25 (running for 00:00:03.24) # noqa + if len(trials) == 0: + return + trial = trials[0] + if trial.status != Trial.RUNNING: + return " ".join( + [f"Training is in {trial.status} status.", self._time_heartbeat_str] + ) + if not trial.last_result or TRAINING_ITERATION not in trial.last_result: + iter_num = 1 + else: + iter_num = trial.last_result[TRAINING_ITERATION] + 1 + return " ".join( + [f"Training on iteration {iter_num}.", self._time_heartbeat_str] + ) + + def _print_heartbeat(self, trials, *args, force: bool = False): + print(self._get_heartbeat(trials, force_full_output=force)) + + def on_trial_result( + self, + iteration: int, + trials: List[Trial], + trial: Trial, + result: Dict, + **info, + ): + self._last_heartbeat_time = time.time() + super().on_trial_result( + iteration=iteration, trials=trials, trial=trial, result=result, **info + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fd315308bab32e576eeaaf3f4af10efc4fe6bf2e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py @@ -0,0 +1,32 @@ +from ray.tune.logger.csv import CSVLogger, CSVLoggerCallback +from ray.tune.logger.json import JsonLogger, JsonLoggerCallback +from ray.tune.logger.logger import ( + LegacyLoggerCallback, + Logger, + LoggerCallback, + pretty_print, +) +from ray.tune.logger.noop import NoopLogger +from ray.tune.logger.tensorboardx import TBXLogger, TBXLoggerCallback + +DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger) + +# isort: off +from ray.tune.logger.unified import UnifiedLogger # noqa: E402 + +# isort: on + +__all__ = [ + "Logger", + "LoggerCallback", + "LegacyLoggerCallback", + "pretty_print", + "CSVLogger", + "CSVLoggerCallback", + "JsonLogger", + "JsonLoggerCallback", + "NoopLogger", + "TBXLogger", + "TBXLoggerCallback", + "UnifiedLogger", +] diff --git 
a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f646f0675528fc461d689f50140f1130c34a3138 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24cac3c4dbd546489adc45d596391c017e3d352b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..192819013ad944d20bbda273edf27bf1a9c428b7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fd200c3e3db85f81440bd3b2455442393a0e82c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a4ef05c297768e1a1cdc7b5e2e83085e3f86340 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa388c6e92bfaea64970d80a23456edd54f62baa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd5d1fa79f8aa675dd5b34c2f36f6191c08c07b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..196fd87a3db26fafb22fb79f0ba0d99b1d16d2d3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85396d8e6e29a8a78e46c09691937e9bace2abf7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..46707a956025c09aab1603c2410804aaafb5b8b8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d72b05894d9c2259657334cba39775801f0344b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/aim.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/aim.py new file mode 100644 index 0000000000000000000000000000000000000000..863df7f46fe1f17793ad4afe04f0fe87806f3913 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/aim.py @@ -0,0 +1,187 @@ +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import numpy as np + +from ray.air.constants import TRAINING_ITERATION +from ray.tune.logger.logger import LoggerCallback +from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL +from ray.tune.utils import flatten_dict +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial + +try: + from aim.sdk import Repo, Run +except ImportError: + Repo, Run = None, None + +logger = logging.getLogger(__name__) + +VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64] + + +@PublicAPI +class AimLoggerCallback(LoggerCallback): + """Aim Logger: logs metrics in Aim format. + + Aim is an open-source, self-hosted ML experiment tracking tool. + It's good at tracking lots (thousands) of training runs, and it allows you to + compare them with a performant and well-designed UI. 
+ + Source: https://github.com/aimhubio/aim + + Args: + repo: Aim repository directory or a `Repo` object that the Run object will + log results to. If not provided, a default repo will be set up in the + experiment directory (one level above trial directories). + experiment: Sets the `experiment` property of each Run object, which is the + experiment name associated with it. Can be used later to query + runs/sequences. + If not provided, the default will be the Tune experiment name set + by `RunConfig(name=...)`. + metrics: List of metric names (out of the metrics reported by Tune) to + track in Aim. If no metric are specified, log everything that + is reported. + aim_run_kwargs: Additional arguments that will be passed when creating the + individual `Run` objects for each trial. For the full list of arguments, + please see the Aim documentation: + https://aimstack.readthedocs.io/en/latest/refs/sdk.html + """ + + VALID_HPARAMS = (str, bool, int, float, list, type(None)) + VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64) + + def __init__( + self, + repo: Optional[Union[str, "Repo"]] = None, + experiment_name: Optional[str] = None, + metrics: Optional[List[str]] = None, + **aim_run_kwargs, + ): + """ + See help(AimLoggerCallback) for more information about parameters. + """ + assert Run is not None, ( + "aim must be installed!. You can install aim with" + " the command: `pip install aim`." + ) + self._repo_path = repo + self._experiment_name = experiment_name + if not (bool(metrics) or metrics is None): + raise ValueError( + "`metrics` must either contain at least one metric name, or be None, " + "in which case all reported metrics will be logged to the aim repo." + ) + self._metrics = metrics + self._aim_run_kwargs = aim_run_kwargs + self._trial_to_run: Dict["Trial", Run] = {} + + def _create_run(self, trial: "Trial") -> Run: + """Initializes an Aim Run object for a given trial. 
+ + Args: + trial: The Tune trial that aim will track as a Run. + + Returns: + Run: The created aim run for a specific trial. + """ + experiment_dir = trial.local_experiment_path + run = Run( + repo=self._repo_path or experiment_dir, + experiment=self._experiment_name or trial.experiment_dir_name, + **self._aim_run_kwargs, + ) + # Attach a few useful trial properties + run["trial_id"] = trial.trial_id + run["trial_log_dir"] = trial.path + trial_ip = trial.get_ray_actor_ip() + if trial_ip: + run["trial_ip"] = trial_ip + return run + + def log_trial_start(self, trial: "Trial"): + if trial in self._trial_to_run: + # Cleanup an existing run if the trial has been restarted + self._trial_to_run[trial].close() + + trial.init_local_path() + self._trial_to_run[trial] = self._create_run(trial) + + if trial.evaluated_params: + self._log_trial_hparams(trial) + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + tmp_result = result.copy() + + step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION] + + for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]: + tmp_result.pop(k, None) # not useful to log these + + # `context` and `epoch` are special keys that users can report, + # which are treated as special aim metrics/configurations. 
+ context = tmp_result.pop("context", None) + epoch = tmp_result.pop("epoch", None) + + trial_run = self._trial_to_run[trial] + path = ["ray", "tune"] + + flat_result = flatten_dict(tmp_result, delimiter="/") + valid_result = {} + + for attr, value in flat_result.items(): + if self._metrics and attr not in self._metrics: + continue + + full_attr = "/".join(path + [attr]) + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not ( + np.isnan(value) or np.isinf(value) + ): + valid_result[attr] = value + trial_run.track( + value=value, + name=full_attr, + epoch=epoch, + step=step, + context=context, + ) + elif (isinstance(value, (list, tuple, set)) and len(value) > 0) or ( + isinstance(value, np.ndarray) and value.size > 0 + ): + valid_result[attr] = value + + def log_trial_end(self, trial: "Trial", failed: bool = False): + trial_run = self._trial_to_run.pop(trial) + trial_run.close() + + def _log_trial_hparams(self, trial: "Trial"): + params = flatten_dict(trial.evaluated_params, delimiter="/") + flat_params = flatten_dict(params) + + scrubbed_params = { + k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS) + } + + np_params = { + k: v.tolist() + for k, v in flat_params.items() + if isinstance(v, self.VALID_NP_HPARAMS) + } + + scrubbed_params.update(np_params) + removed = { + k: v + for k, v in flat_params.items() + if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS) + } + if removed: + logger.info( + "Removed the following hyperparameter values when " + "logging to aim: %s", + str(removed), + ) + + run = self._trial_to_run[trial] + run["hparams"] = scrubbed_params diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/comet.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/comet.py new file mode 100644 index 0000000000000000000000000000000000000000..31dfeafe670d70aa811db7152448390b7109b4d6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/comet.py @@ -0,0 +1,3 @@ +from ray.air.integrations.comet 
import CometLoggerCallback + +CometLoggerCallback.__module__ = "ray.tune.logger.comet" diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/csv.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/csv.py new file mode 100644 index 0000000000000000000000000000000000000000..5802b43f893d4240ad878787dca8398e75db0c2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/csv.py @@ -0,0 +1,135 @@ +import csv +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Dict, TextIO + +from ray.air.constants import EXPR_PROGRESS_FILE +from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback +from ray.tune.utils import flatten_dict +from ray.util.annotations import Deprecated, PublicAPI + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial # noqa: F401 + +logger = logging.getLogger(__name__) + + +@Deprecated( + message=_LOGGER_DEPRECATION_WARNING.format( + old="CSVLogger", new="ray.tune.csv.CSVLoggerCallback" + ), + warning=True, +) +@PublicAPI +class CSVLogger(Logger): + """Logs results to progress.csv under the trial directory. 
+ + Automatically flattens nested dicts in the result dict before writing + to csv: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + + """ + + def _init(self): + self._initialized = False + + def _maybe_init(self): + """CSV outputted with Headers as first set of results.""" + if not self._initialized: + progress_file = Path(self.logdir, EXPR_PROGRESS_FILE) + self._continuing = ( + progress_file.exists() and progress_file.stat().st_size > 0 + ) + self._file = progress_file.open("a") + self._csv_out = None + self._initialized = True + + def on_result(self, result: Dict): + self._maybe_init() + + tmp = result.copy() + if "config" in tmp: + del tmp["config"] + result = flatten_dict(tmp, delimiter="/") + if self._csv_out is None: + self._csv_out = csv.DictWriter(self._file, result.keys()) + if not self._continuing: + self._csv_out.writeheader() + self._csv_out.writerow( + {k: v for k, v in result.items() if k in self._csv_out.fieldnames} + ) + self._file.flush() + + def flush(self): + if self._initialized and not self._file.closed: + self._file.flush() + + def close(self): + if self._initialized: + self._file.close() + + +@PublicAPI +class CSVLoggerCallback(LoggerCallback): + """Logs results to progress.csv under the trial directory. + + Automatically flattens nested dicts in the result dict before writing + to csv: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + + """ + + _SAVED_FILE_TEMPLATES = [EXPR_PROGRESS_FILE] + + def __init__(self): + self._trial_continue: Dict["Trial", bool] = {} + self._trial_files: Dict["Trial", TextIO] = {} + self._trial_csv: Dict["Trial", csv.DictWriter] = {} + + def _setup_trial(self, trial: "Trial"): + if trial in self._trial_files: + self._trial_files[trial].close() + + # Make sure logdir exists + trial.init_local_path() + local_file_path = Path(trial.local_path, EXPR_PROGRESS_FILE) + + # Resume the file from remote storage. 
+ self._restore_from_remote(EXPR_PROGRESS_FILE, trial) + + self._trial_continue[trial] = ( + local_file_path.exists() and local_file_path.stat().st_size > 0 + ) + + self._trial_files[trial] = local_file_path.open("at") + self._trial_csv[trial] = None + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_files: + self._setup_trial(trial) + + tmp = result.copy() + tmp.pop("config", None) + result = flatten_dict(tmp, delimiter="/") + + if not self._trial_csv[trial]: + self._trial_csv[trial] = csv.DictWriter( + self._trial_files[trial], result.keys() + ) + if not self._trial_continue[trial]: + self._trial_csv[trial].writeheader() + + self._trial_csv[trial].writerow( + {k: v for k, v in result.items() if k in self._trial_csv[trial].fieldnames} + ) + self._trial_files[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial not in self._trial_files: + return + + del self._trial_csv[trial] + self._trial_files[trial].close() + del self._trial_files[trial] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/json.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/json.py new file mode 100644 index 0000000000000000000000000000000000000000..d248a40802967b376f6e489dd268800642599f30 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/json.py @@ -0,0 +1,128 @@ +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Dict, TextIO + +import numpy as np + +import ray.cloudpickle as cloudpickle +from ray.air.constants import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, EXPR_RESULT_FILE +from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback +from ray.tune.utils.util import SafeFallbackEncoder +from ray.util.annotations import Deprecated, PublicAPI + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial # noqa: F401 + +logger = logging.getLogger(__name__) + +tf = None +VALID_SUMMARY_TYPES = 
[int, float, np.float32, np.float64, np.int32, np.int64] + + +@Deprecated( + message=_LOGGER_DEPRECATION_WARNING.format( + old="JsonLogger", new="ray.tune.json.JsonLoggerCallback" + ), + warning=True, +) +@PublicAPI +class JsonLogger(Logger): + """Logs trial results in json format. + + Also writes to a results file and param.json file when results or + configurations are updated. Experiments must be executed with the + JsonLogger to be compatible with the ExperimentAnalysis tool. + """ + + def _init(self): + self.update_config(self.config) + local_file = Path(self.logdir, EXPR_RESULT_FILE) + self.local_out = local_file.open("a") + + def on_result(self, result: Dict): + json.dump(result, self, cls=SafeFallbackEncoder) + self.write("\n") + self.local_out.flush() + + def write(self, b): + self.local_out.write(b) + + def flush(self): + if not self.local_out.closed: + self.local_out.flush() + + def close(self): + self.local_out.close() + + def update_config(self, config: Dict): + self.config = config + config_out = Path(self.logdir, EXPR_PARAM_FILE) + with open(config_out, "w") as f: + json.dump(self.config, f, indent=2, sort_keys=True, cls=SafeFallbackEncoder) + config_pkl = Path(self.logdir, EXPR_PARAM_PICKLE_FILE) + with config_pkl.open("wb") as f: + cloudpickle.dump(self.config, f) + + +@PublicAPI +class JsonLoggerCallback(LoggerCallback): + """Logs trial results in json format. + + Also writes to a results file and param.json file when results or + configurations are updated. Experiments must be executed with the + JsonLoggerCallback to be compatible with the ExperimentAnalysis tool. 
+ """ + + _SAVED_FILE_TEMPLATES = [EXPR_RESULT_FILE, EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE] + + def __init__(self): + self._trial_configs: Dict["Trial", Dict] = {} + self._trial_files: Dict["Trial", TextIO] = {} + + def log_trial_start(self, trial: "Trial"): + if trial in self._trial_files: + self._trial_files[trial].close() + + # Update config + self.update_config(trial, trial.config) + + # Make sure logdir exists + trial.init_local_path() + local_file = Path(trial.local_path, EXPR_RESULT_FILE) + + # Resume the file from remote storage. + self._restore_from_remote(EXPR_RESULT_FILE, trial) + + self._trial_files[trial] = local_file.open("at") + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_files: + self.log_trial_start(trial) + json.dump(result, self._trial_files[trial], cls=SafeFallbackEncoder) + self._trial_files[trial].write("\n") + self._trial_files[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial not in self._trial_files: + return + + self._trial_files[trial].close() + del self._trial_files[trial] + + def update_config(self, trial: "Trial", config: Dict): + self._trial_configs[trial] = config + + config_out = Path(trial.local_path, EXPR_PARAM_FILE) + with config_out.open("w") as f: + json.dump( + self._trial_configs[trial], + f, + indent=2, + sort_keys=True, + cls=SafeFallbackEncoder, + ) + + config_pkl = Path(trial.local_path, EXPR_PARAM_PICKLE_FILE) + with config_pkl.open("wb") as f: + cloudpickle.dump(self._trial_configs[trial], f) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/logger.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ad14069c3c20e8f17f5a5275ccb7e89758dcd7d4 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/logger.py @@ -0,0 +1,259 @@ +import abc +import json +import logging +from pathlib import Path +from typing 
import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Type + +import pyarrow +import yaml + +from ray.air._internal.json import SafeFallbackEncoder +from ray.tune.callback import Callback +from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial # noqa: F401 + +logger = logging.getLogger(__name__) + + +# Apply flow style for sequences of this length +_SEQUENCE_LEN_FLOW_STYLE = 3 + +_LOGGER_DEPRECATION_WARNING = ( + "The `{old} interface is deprecated in favor of the " + "`{new}` interface and will be removed in Ray 2.7." +) + + +@Deprecated( + message=_LOGGER_DEPRECATION_WARNING.format( + old="Logger", new="ray.tune.logger.LoggerCallback" + ), +) +@DeveloperAPI +class Logger(abc.ABC): + """Logging interface for ray.tune. + + By default, the UnifiedLogger implementation is used which logs results in + multiple formats (TensorBoard, rllab/viskit, plain json, custom loggers) + at once. + + Arguments: + config: Configuration passed to all logger creators. + logdir: Directory for all logger creators to log to. + trial: Trial object for the logger to access. + """ + + def __init__(self, config: Dict, logdir: str, trial: Optional["Trial"] = None): + self.config = config + self.logdir = logdir + self.trial = trial + self._init() + + def _init(self): + pass + + def on_result(self, result): + """Given a result, appends it to the existing log.""" + + raise NotImplementedError + + def update_config(self, config): + """Updates the config for logger.""" + + pass + + def close(self): + """Releases all resources used by this logger.""" + + pass + + def flush(self): + """Flushes all disk writes to storage.""" + + pass + + +@PublicAPI +class LoggerCallback(Callback): + """Base class for experiment-level logger callbacks + + This base class defines a general interface for logging events, + like trial starts, restores, ends, checkpoint saves, and receiving + trial results. 
+ + Callbacks implementing this interface should make sure that logging + utilities are cleaned up properly on trial termination, i.e. when + ``log_trial_end`` is received. This includes e.g. closing files. + """ + + def log_trial_start(self, trial: "Trial"): + """Handle logging when a trial starts. + + Args: + trial: Trial object. + """ + pass + + def log_trial_restore(self, trial: "Trial"): + """Handle logging when a trial restores. + + Args: + trial: Trial object. + """ + pass + + def log_trial_save(self, trial: "Trial"): + """Handle logging when a trial saves a checkpoint. + + Args: + trial: Trial object. + """ + pass + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + """Handle logging when a trial reports a result. + + Args: + trial: Trial object. + result: Result dictionary. + """ + pass + + def log_trial_end(self, trial: "Trial", failed: bool = False): + """Handle logging when a trial ends. + + Args: + trial: Trial object. + failed: True if the Trial finished gracefully, False if + it failed (e.g. when it raised an exception). 
+ """ + pass + + def on_trial_result( + self, + iteration: int, + trials: List["Trial"], + trial: "Trial", + result: Dict, + **info, + ): + self.log_trial_result(iteration, trial, result) + + def on_trial_start( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_trial_start(trial) + + def on_trial_restore( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_trial_restore(trial) + + def on_trial_save( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_trial_save(trial) + + def on_trial_complete( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_trial_end(trial, failed=False) + + def on_trial_error( + self, iteration: int, trials: List["Trial"], trial: "Trial", **info + ): + self.log_trial_end(trial, failed=True) + + def _restore_from_remote(self, file_name: str, trial: "Trial") -> None: + if not trial.checkpoint: + # If there's no checkpoint, there's no logging artifacts to restore + # since we're starting from scratch. + return + + local_file = Path(trial.local_path, file_name).as_posix() + remote_file = Path(trial.storage.trial_fs_path, file_name).as_posix() + + try: + pyarrow.fs.copy_files( + remote_file, + local_file, + source_filesystem=trial.storage.storage_filesystem, + ) + logger.debug(f"Copied {remote_file} to {local_file}") + except FileNotFoundError: + logger.warning(f"Remote file not found: {remote_file}") + except Exception: + logger.exception(f"Error downloading {remote_file}") + + +@DeveloperAPI +class LegacyLoggerCallback(LoggerCallback): + """Supports logging to trial-specific `Logger` classes. + + Previously, Ray Tune logging was handled via `Logger` classes that have + been instantiated per-trial. This callback is a fallback to these + `Logger`-classes, instantiating each `Logger` class for each trial + and logging to them. 
+ + Args: + logger_classes: Logger classes that should + be instantiated for each trial. + + """ + + def __init__(self, logger_classes: Iterable[Type[Logger]]): + self.logger_classes = list(logger_classes) + self._class_trial_loggers: Dict[Type[Logger], Dict["Trial", Logger]] = {} + + def log_trial_start(self, trial: "Trial"): + trial.init_local_path() + + for logger_class in self.logger_classes: + trial_loggers = self._class_trial_loggers.get(logger_class, {}) + if trial not in trial_loggers: + logger = logger_class(trial.config, trial.local_path, trial) + trial_loggers[trial] = logger + self._class_trial_loggers[logger_class] = trial_loggers + + def log_trial_restore(self, trial: "Trial"): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].flush() + + def log_trial_save(self, trial: "Trial"): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].flush() + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].on_result(result) + + def log_trial_end(self, trial: "Trial", failed: bool = False): + for logger_class, trial_loggers in self._class_trial_loggers.items(): + if trial in trial_loggers: + trial_loggers[trial].close() + + +class _RayDumper(yaml.SafeDumper): + def represent_sequence(self, tag, sequence, flow_style=None): + if len(sequence) > _SEQUENCE_LEN_FLOW_STYLE: + return super().represent_sequence(tag, sequence, flow_style=True) + return super().represent_sequence(tag, sequence, flow_style=flow_style) + + +@DeveloperAPI +def pretty_print(result, exclude: Optional[Set[str]] = None): + result = result.copy() + result.update(config=None) # drop config from pretty print + result.update(hist_stats=None) # drop hist_stats from pretty print + out = {} + for k, v in 
result.items(): + if v is not None and (exclude is None or k not in exclude): + out[k] = v + + cleaned = json.dumps(out, cls=SafeFallbackEncoder) + return yaml.dump(json.loads(cleaned), Dumper=_RayDumper, default_flow_style=False) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/mlflow.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..ac00b3d5155441988e84c69a38ed5e2444e48b78 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/mlflow.py @@ -0,0 +1,3 @@ +from ray.air.integrations.mlflow import MLflowLoggerCallback + +MLflowLoggerCallback.__module__ = "ray.tune.logger.mlflow" diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/noop.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/noop.py new file mode 100644 index 0000000000000000000000000000000000000000..a9bae96b7cd7209cf61e8b2219c721e0eec88bcd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/noop.py @@ -0,0 +1,9 @@ +from ray.tune.logger.logger import Logger +from ray.util.annotations import Deprecated, PublicAPI + + +@Deprecated(message="`NoopLogger` will be removed in Ray 2.7.") +@PublicAPI +class NoopLogger(Logger): + def on_result(self, result): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/tensorboardx.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/tensorboardx.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b626ae5053c66693ea22bea56db0b57a9f689f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/tensorboardx.py @@ -0,0 +1,328 @@ +import logging +from typing import TYPE_CHECKING, Dict + +import numpy as np + +from ray.air.constants import TRAINING_ITERATION +from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback +from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL +from ray.tune.utils import flatten_dict +from ray.util.annotations import 
Deprecated, PublicAPI +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial # noqa: F401 + +logger = logging.getLogger(__name__) + +VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64] + + +@Deprecated( + message=_LOGGER_DEPRECATION_WARNING.format( + old="TBXLogger", new="ray.tune.tensorboardx.TBXLoggerCallback" + ), + warning=True, +) +@PublicAPI +class TBXLogger(Logger): + """TensorBoardX Logger. + + Note that hparams will be written only after a trial has terminated. + This logger automatically flattens nested dicts to show on TensorBoard: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + """ + + VALID_HPARAMS = (str, bool, int, float, list, type(None)) + VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64) + + def _init(self): + try: + from tensorboardX import SummaryWriter + except ImportError: + if log_once("tbx-install"): + logger.info('pip install "ray[tune]" to see TensorBoard files.') + raise + self._file_writer = SummaryWriter(self.logdir, flush_secs=30) + self.last_result = None + + def on_result(self, result: Dict): + step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] + + tmp = result.copy() + for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]: + if k in tmp: + del tmp[k] # not useful to log these + + flat_result = flatten_dict(tmp, delimiter="/") + path = ["ray", "tune"] + valid_result = {} + + for attr, value in flat_result.items(): + full_attr = "/".join(path + [attr]) + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not np.isnan(value): + valid_result[full_attr] = value + self._file_writer.add_scalar(full_attr, value, global_step=step) + elif (isinstance(value, list) and len(value) > 0) or ( + isinstance(value, np.ndarray) and value.size > 0 + ): + valid_result[full_attr] = value + + # Must be a single image. 
+ if isinstance(value, np.ndarray) and value.ndim == 3: + self._file_writer.add_image( + full_attr, + value, + global_step=step, + ) + continue + + # Must be a batch of images. + if isinstance(value, np.ndarray) and value.ndim == 4: + self._file_writer.add_images( + full_attr, + value, + global_step=step, + ) + continue + + # Must be video + if isinstance(value, np.ndarray) and value.ndim == 5: + self._file_writer.add_video( + full_attr, value, global_step=step, fps=20 + ) + continue + + try: + self._file_writer.add_histogram(full_attr, value, global_step=step) + # In case TensorboardX still doesn't think it's a valid value + # (e.g. `[[]]`), warn and move on. + except (ValueError, TypeError): + if log_once("invalid_tbx_value"): + logger.warning( + "You are trying to log an invalid value ({}={}) " + "via {}!".format(full_attr, value, type(self).__name__) + ) + + self.last_result = valid_result + self._file_writer.flush() + + def flush(self): + if self._file_writer is not None: + self._file_writer.flush() + + def close(self): + if self._file_writer is not None: + if self.trial and self.trial.evaluated_params and self.last_result: + flat_result = flatten_dict(self.last_result, delimiter="/") + scrubbed_result = { + k: value + for k, value in flat_result.items() + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) + } + self._try_log_hparams(scrubbed_result) + self._file_writer.close() + + def _try_log_hparams(self, result): + # TBX currently errors if the hparams value is None. 
+ flat_params = flatten_dict(self.trial.evaluated_params) + scrubbed_params = { + k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS) + } + + np_params = { + k: v.tolist() + for k, v in flat_params.items() + if isinstance(v, self.VALID_NP_HPARAMS) + } + + scrubbed_params.update(np_params) + + removed = { + k: v + for k, v in flat_params.items() + if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS) + } + if removed: + logger.info( + "Removed the following hyperparameter values when " + "logging to tensorboard: %s", + str(removed), + ) + + from tensorboardX.summary import hparams + + try: + experiment_tag, session_start_tag, session_end_tag = hparams( + hparam_dict=scrubbed_params, metric_dict=result + ) + self._file_writer.file_writer.add_summary(experiment_tag) + self._file_writer.file_writer.add_summary(session_start_tag) + self._file_writer.file_writer.add_summary(session_end_tag) + except Exception: + logger.exception( + "TensorboardX failed to log hparams. " + "This may be due to an unsupported type " + "in the hyperparameter values." + ) + + +@PublicAPI +class TBXLoggerCallback(LoggerCallback): + """TensorBoardX Logger. + + Note that hparams will be written only after a trial has terminated. 
+ This logger automatically flattens nested dicts to show on TensorBoard: + + {"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2} + """ + + _SAVED_FILE_TEMPLATES = ["events.out.tfevents.*"] + + VALID_HPARAMS = (str, bool, int, float, list, type(None)) + VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64) + + def __init__(self): + try: + from tensorboardX import SummaryWriter + + self._summary_writer_cls = SummaryWriter + except ImportError: + if log_once("tbx-install"): + logger.info('pip install "ray[tune]" to see TensorBoard files.') + raise + self._trial_writer: Dict["Trial", SummaryWriter] = {} + self._trial_result: Dict["Trial", Dict] = {} + + def log_trial_start(self, trial: "Trial"): + if trial in self._trial_writer: + self._trial_writer[trial].close() + trial.init_local_path() + self._trial_writer[trial] = self._summary_writer_cls( + trial.local_path, flush_secs=30 + ) + self._trial_result[trial] = {} + + def log_trial_result(self, iteration: int, trial: "Trial", result: Dict): + if trial not in self._trial_writer: + self.log_trial_start(trial) + + step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] + + tmp = result.copy() + for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]: + if k in tmp: + del tmp[k] # not useful to log these + + flat_result = flatten_dict(tmp, delimiter="/") + path = ["ray", "tune"] + valid_result = {} + + for attr, value in flat_result.items(): + full_attr = "/".join(path + [attr]) + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not np.isnan(value): + valid_result[full_attr] = value + self._trial_writer[trial].add_scalar(full_attr, value, global_step=step) + elif (isinstance(value, list) and len(value) > 0) or ( + isinstance(value, np.ndarray) and value.size > 0 + ): + valid_result[full_attr] = value + + # Must be a single image. 
+ if isinstance(value, np.ndarray) and value.ndim == 3: + self._trial_writer[trial].add_image( + full_attr, + value, + global_step=step, + ) + continue + + # Must be a batch of images. + if isinstance(value, np.ndarray) and value.ndim == 4: + self._trial_writer[trial].add_images( + full_attr, + value, + global_step=step, + ) + continue + + # Must be video + if isinstance(value, np.ndarray) and value.ndim == 5: + self._trial_writer[trial].add_video( + full_attr, value, global_step=step, fps=20 + ) + continue + + try: + self._trial_writer[trial].add_histogram( + full_attr, value, global_step=step + ) + # In case TensorboardX still doesn't think it's a valid value + # (e.g. `[[]]`), warn and move on. + except (ValueError, TypeError): + if log_once("invalid_tbx_value"): + logger.warning( + "You are trying to log an invalid value ({}={}) " + "via {}!".format(full_attr, value, type(self).__name__) + ) + + self._trial_result[trial] = valid_result + self._trial_writer[trial].flush() + + def log_trial_end(self, trial: "Trial", failed: bool = False): + if trial in self._trial_writer: + if trial and trial.evaluated_params and self._trial_result[trial]: + flat_result = flatten_dict(self._trial_result[trial], delimiter="/") + scrubbed_result = { + k: value + for k, value in flat_result.items() + if isinstance(value, tuple(VALID_SUMMARY_TYPES)) + } + self._try_log_hparams(trial, scrubbed_result) + self._trial_writer[trial].close() + del self._trial_writer[trial] + del self._trial_result[trial] + + def _try_log_hparams(self, trial: "Trial", result: Dict): + # TBX currently errors if the hparams value is None. 
+ flat_params = flatten_dict(trial.evaluated_params) + scrubbed_params = { + k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS) + } + + np_params = { + k: v.tolist() + for k, v in flat_params.items() + if isinstance(v, self.VALID_NP_HPARAMS) + } + + scrubbed_params.update(np_params) + + removed = { + k: v + for k, v in flat_params.items() + if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS) + } + if removed: + logger.info( + "Removed the following hyperparameter values when " + "logging to tensorboard: %s", + str(removed), + ) + + from tensorboardX.summary import hparams + + try: + experiment_tag, session_start_tag, session_end_tag = hparams( + hparam_dict=scrubbed_params, metric_dict=result + ) + self._trial_writer[trial].file_writer.add_summary(experiment_tag) + self._trial_writer[trial].file_writer.add_summary(session_start_tag) + self._trial_writer[trial].file_writer.add_summary(session_end_tag) + except Exception: + logger.exception( + "TensorboardX failed to log hparams. " + "This may be due to an unsupported type " + "in the hyperparameter values." 
+ ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/unified.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/unified.py new file mode 100644 index 0000000000000000000000000000000000000000..91af70cbd86c12044e012dd75302733fb4056c33 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/unified.py @@ -0,0 +1,74 @@ +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Type + +from ray.tune.logger import DEFAULT_LOGGERS +from ray.tune.logger.json import JsonLogger +from ray.tune.logger.logger import Logger +from ray.util import log_once +from ray.util.annotations import Deprecated, PublicAPI + +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from ray.tune.experiment.trial import Trial # noqa: F401 + + +@Deprecated(message="`UnifiedLogger` will be removed in Ray 2.7.", warning=True) +@PublicAPI +class UnifiedLogger(Logger): + """Unified result logger for TensorBoard, rllab/viskit, plain json. + + Arguments: + config: Configuration passed to all logger creators. + logdir: Directory for all logger creators to log to. + loggers: List of logger creators. Defaults to CSV, Tensorboard, + and JSON loggers. + """ + + def __init__( + self, + config: Dict, + logdir: str, + trial: Optional["Trial"] = None, + loggers: Optional[List[Type[Logger]]] = None, + ): + if loggers is None: + self._logger_cls_list = DEFAULT_LOGGERS + else: + self._logger_cls_list = loggers + if JsonLogger not in self._logger_cls_list: + if log_once("JsonLogger"): + logger.warning( + "JsonLogger not provided. The ExperimentAnalysis tool is " + "disabled." 
+ ) + + super(UnifiedLogger, self).__init__(config, logdir, trial) + + def _init(self): + self._loggers = [] + for cls in self._logger_cls_list: + try: + self._loggers.append(cls(self.config, self.logdir, self.trial)) + except Exception as exc: + if log_once(f"instantiate:{cls.__name__}"): + logger.warning( + "Could not instantiate %s: %s.", cls.__name__, str(exc) + ) + + def on_result(self, result): + for _logger in self._loggers: + _logger.on_result(result) + + def update_config(self, config): + for _logger in self._loggers: + _logger.update_config(config) + + def close(self): + for _logger in self._loggers: + _logger.close() + + def flush(self): + for _logger in self._loggers: + _logger.flush() diff --git a/.venv/lib/python3.11/site-packages/ray/tune/logger/wandb.py b/.venv/lib/python3.11/site-packages/ray/tune/logger/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..85e80b60bfe8ce0bfa187d0b29dd4caec39eb725 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/logger/wandb.py @@ -0,0 +1,3 @@ +from ray.air.integrations.wandb import WandbLoggerCallback + +WandbLoggerCallback.__module__ = "ray.tune.logger.wandb" diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__init__.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f40125e5e50eaadf38479102ef016d0b09ea5458 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__init__.py @@ -0,0 +1,96 @@ +import inspect + +from ray._private.utils import get_function_args +from ray.tune.schedulers.async_hyperband import ASHAScheduler, AsyncHyperBandScheduler +from ray.tune.schedulers.hb_bohb import HyperBandForBOHB +from ray.tune.schedulers.hyperband import HyperBandScheduler +from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule +from ray.tune.schedulers.pbt import ( + PopulationBasedTraining, + PopulationBasedTrainingReplay, +) +from 
ray.tune.schedulers.resource_changing_scheduler import ResourceChangingScheduler +from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.util import PublicAPI + + +def _pb2_importer(): + # PB2 introduces a GPy dependency which can be expensive, so we import + # lazily. + from ray.tune.schedulers.pb2 import PB2 + + return PB2 + + +# Values in this dictionary will be one two kinds: +# class of the scheduler object to create +# wrapper function to support a lazy import of the scheduler class +SCHEDULER_IMPORT = { + "fifo": FIFOScheduler, + "async_hyperband": AsyncHyperBandScheduler, + "asynchyperband": AsyncHyperBandScheduler, + "median_stopping_rule": MedianStoppingRule, + "medianstopping": MedianStoppingRule, + "hyperband": HyperBandScheduler, + "hb_bohb": HyperBandForBOHB, + "pbt": PopulationBasedTraining, + "pbt_replay": PopulationBasedTrainingReplay, + "pb2": _pb2_importer, + "resource_changing": ResourceChangingScheduler, +} + + +@PublicAPI(stability="beta") +def create_scheduler( + scheduler, + **kwargs, +): + """Instantiate a scheduler based on the given string. + + This is useful for swapping between different schedulers. + + Args: + scheduler: The scheduler to use. + **kwargs: Scheduler parameters. + These keyword arguments will be passed to the initialization + function of the chosen scheduler. + Returns: + ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler. + Example: + >>> from ray import tune + >>> pbt_kwargs = {} + >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs) # doctest: +SKIP + """ + + scheduler = scheduler.lower() + if scheduler not in SCHEDULER_IMPORT: + raise ValueError( + f"The `scheduler` argument must be one of " + f"{list(SCHEDULER_IMPORT)}. 
" + f"Got: {scheduler}" + ) + + SchedulerClass = SCHEDULER_IMPORT[scheduler] + + if inspect.isfunction(SchedulerClass): + # invoke the wrapper function to retrieve class + SchedulerClass = SchedulerClass() + + scheduler_args = get_function_args(SchedulerClass) + trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args} + + return SchedulerClass(**trimmed_kwargs) + + +__all__ = [ + "TrialScheduler", + "HyperBandScheduler", + "AsyncHyperBandScheduler", + "ASHAScheduler", + "MedianStoppingRule", + "FIFOScheduler", + "PopulationBasedTraining", + "PopulationBasedTrainingReplay", + "HyperBandForBOHB", + "ResourceChangingScheduler", +] diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/async_hyperband.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/async_hyperband.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c260df344220165e778a3948269866c7ffca8b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/async_hyperband.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hb_bohb.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hb_bohb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab0e0616e28b6582f818c0d52fab59221eeff529 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hb_bohb.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hyperband.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hyperband.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1570e0e5e1b34d90d1cb020e89d3179ede4458f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/hyperband.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/median_stopping_rule.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/median_stopping_rule.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fdb2e7fb33771d7845e9269160a6f1beee3a52e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/median_stopping_rule.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pb2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pb2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fff579f57193668feb87c1eeeac8300eb76def7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pb2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pbt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pbt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..900375d7f984f744a624f2aabc1a321ae290769d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/pbt.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/trial_scheduler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/trial_scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f02c7c2708e8928fbbd7d52b9fecd4df8bad076e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/trial_scheduler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/util.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/util.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d3b552f268b7838a5685061460212ffd6e22a1ee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/util.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/async_hyperband.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/async_hyperband.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf8204592ba647ca723de54d14fe81e8b95e40f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/async_hyperband.py @@ -0,0 +1,273 @@ +import logging +import pickle +from typing import TYPE_CHECKING, Dict, Optional, Union + +import numpy as np + +from ray.tune.experiment import Trial +from ray.tune.result import DEFAULT_METRIC +from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.util import PublicAPI + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + +logger = logging.getLogger(__name__) + + +@PublicAPI +class AsyncHyperBandScheduler(FIFOScheduler): + """Implements the Async Successive Halving. + + This should provide similar theoretical performance as HyperBand but + avoid straggler issues that HyperBand faces. One implementation detail + is when using multiple brackets, trial allocation to bracket is done + randomly with over a softmax probability. + + See https://arxiv.org/abs/1810.05934 + + Args: + time_attr: A training result attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. If None but a mode was passed, + the `ray.tune.result.DEFAULT_METRIC` will be used per default. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. 
+ max_t: max time units per trial. Trials will be stopped after + max_t time units (determined by time_attr) have passed. + grace_period: Only stop trials at least this old in time. + The units are the same as the attribute named by `time_attr`. + reduction_factor: Used to set halving rate and amount. This + is simply a unit-less scalar. + brackets: Number of brackets. Each bracket has a different + halving rate, specified by the reduction factor. + stop_last_trials: Whether to terminate the trials after + reaching max_t. Defaults to True. + """ + + def __init__( + self, + time_attr: str = "training_iteration", + metric: Optional[str] = None, + mode: Optional[str] = None, + max_t: int = 100, + grace_period: int = 1, + reduction_factor: float = 4, + brackets: int = 1, + stop_last_trials: bool = True, + ): + assert max_t > 0, "Max (time_attr) not valid!" + assert max_t >= grace_period, "grace_period must be <= max_t!" + assert grace_period > 0, "grace_period must be positive!" + assert reduction_factor > 1, "Reduction Factor not valid!" + assert brackets > 0, "brackets must be positive!" + if mode: + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" 
+ + super().__init__() + self._reduction_factor = reduction_factor + self._max_t = max_t + + self._trial_info = {} # Stores Trial -> Bracket + + # Tracks state for new trial add + self._brackets = [ + _Bracket( + grace_period, + max_t, + reduction_factor, + s, + stop_last_trials=stop_last_trials, + ) + for s in range(brackets) + ] + self._counter = 0 # for + self._num_stopped = 0 + self._metric = metric + self._mode = mode + self._metric_op = None + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + self._time_attr = time_attr + self._stop_last_trials = stop_last_trials + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + if self._metric and metric: + return False + if self._mode and mode: + return False + + if metric: + self._metric = metric + if mode: + self._mode = mode + + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + + if self._metric is None and self._mode: + # If only a mode was passed, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + if not self._metric or not self._metric_op: + raise ValueError( + "{} has been instantiated without a valid `metric` ({}) or " + "`mode` ({}) parameter. 
Either pass these parameters when " + "instantiating the scheduler, or pass them as parameters " + "to `tune.TuneConfig()`".format( + self.__class__.__name__, self._metric, self._mode + ) + ) + + sizes = np.array([len(b._rungs) for b in self._brackets]) + probs = np.e ** (sizes - sizes.max()) + normalized = probs / probs.sum() + idx = np.random.choice(len(self._brackets), p=normalized) + self._trial_info[trial.trial_id] = self._brackets[idx] + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + action = TrialScheduler.CONTINUE + if self._time_attr not in result or self._metric not in result: + return action + if result[self._time_attr] >= self._max_t and self._stop_last_trials: + action = TrialScheduler.STOP + else: + bracket = self._trial_info[trial.trial_id] + action = bracket.on_result( + trial, result[self._time_attr], self._metric_op * result[self._metric] + ) + if action == TrialScheduler.STOP: + self._num_stopped += 1 + return action + + def on_trial_complete( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + if self._time_attr not in result or self._metric not in result: + return + bracket = self._trial_info[trial.trial_id] + bracket.on_result( + trial, result[self._time_attr], self._metric_op * result[self._metric] + ) + del self._trial_info[trial.trial_id] + + def on_trial_remove(self, tune_controller: "TuneController", trial: Trial): + del self._trial_info[trial.trial_id] + + def debug_string(self) -> str: + out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped) + out += "\n" + "\n".join([b.debug_str() for b in self._brackets]) + return out + + def save(self, checkpoint_path: str): + save_object = self.__dict__ + with open(checkpoint_path, "wb") as outputFile: + pickle.dump(save_object, outputFile) + + def restore(self, checkpoint_path: str): + with open(checkpoint_path, "rb") as inputFile: + save_object = pickle.load(inputFile) + 
self.__dict__.update(save_object) + + +class _Bracket: + """Bookkeeping system to track the cutoffs. + + Rungs are created in reversed order so that we can more easily find + the correct rung corresponding to the current iteration of the result. + + Example: + >>> trial1, trial2, trial3 = ... # doctest: +SKIP + >>> b = _Bracket(1, 10, 2, 0) # doctest: +SKIP + >>> # CONTINUE + >>> b.on_result(trial1, 1, 2) # doctest: +SKIP + >>> # CONTINUE + >>> b.on_result(trial2, 1, 4) # doctest: +SKIP + >>> # rungs are reversed + >>> b.cutoff(b._rungs[-1][1]) == 3.0 # doctest: +SKIP + # STOP + >>> b.on_result(trial3, 1, 1) # doctest: +SKIP + >>> b.cutoff(b._rungs[3][1]) == 2.0 # doctest: +SKIP + """ + + def __init__( + self, + min_t: int, + max_t: int, + reduction_factor: float, + s: int, + stop_last_trials: bool = True, + ): + self.rf = reduction_factor + MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1) + self._rungs = [ + (min_t * self.rf ** (k + s), {}) for k in reversed(range(MAX_RUNGS)) + ] + self._stop_last_trials = stop_last_trials + + def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]: + if not recorded: + return None + return np.nanpercentile(list(recorded.values()), (1 - 1 / self.rf) * 100) + + def on_result(self, trial: Trial, cur_iter: int, cur_rew: Optional[float]) -> str: + action = TrialScheduler.CONTINUE + for milestone, recorded in self._rungs: + if ( + cur_iter >= milestone + and trial.trial_id in recorded + and not self._stop_last_trials + ): + # If our result has been recorded for this trial already, the + # decision to continue training has already been made. Thus we can + # skip new cutoff calculation and just continue training. + # We can also break as milestones are descending. 
+ break + if cur_iter < milestone or trial.trial_id in recorded: + continue + else: + cutoff = self.cutoff(recorded) + if cutoff is not None and cur_rew < cutoff: + action = TrialScheduler.STOP + if cur_rew is None: + logger.warning( + "Reward attribute is None! Consider" + " reporting using a different field." + ) + else: + recorded[trial.trial_id] = cur_rew + break + return action + + def debug_str(self) -> str: + # TODO: fix up the output for this + iters = " | ".join( + [ + "Iter {:.3f}: {}".format(milestone, self.cutoff(recorded)) + for milestone, recorded in self._rungs + ] + ) + return "Bracket: " + iters + + +ASHAScheduler = AsyncHyperBandScheduler + +if __name__ == "__main__": + sched = AsyncHyperBandScheduler(grace_period=1, max_t=10, reduction_factor=2) + print(sched.debug_string()) + bracket = sched._brackets[0] + print(bracket.cutoff({str(i): i for i in range(20)})) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hb_bohb.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hb_bohb.py new file mode 100644 index 0000000000000000000000000000000000000000..6c454d9efd35a52398fc183ab87ef007b1a59f21 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hb_bohb.py @@ -0,0 +1,176 @@ +import logging +from typing import TYPE_CHECKING, Dict, Optional + +from ray.tune.experiment import Trial +from ray.tune.schedulers.hyperband import HyperBandScheduler +from ray.tune.schedulers.trial_scheduler import TrialScheduler +from ray.util import PublicAPI + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + +logger = logging.getLogger(__name__) + + +@PublicAPI +class HyperBandForBOHB(HyperBandScheduler): + """Extends HyperBand early stopping algorithm for BOHB. + + This implementation removes the ``HyperBandScheduler`` pipelining. This + class introduces key changes: + + 1. Trials are now placed so that the bracket with the largest size is + filled first. + + 2. 
Trials will be paused even if the bracket is not filled. This allows + BOHB to insert new trials into the training. + + See ray.tune.schedulers.HyperBandScheduler for parameter docstring. + """ + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + """Adds new trial. + + On a new trial add, if current bracket is not filled, add to current + bracket. Else, if current band is not filled, create new bracket, add + to current bracket. Else, create new iteration, create new bracket, + add to bracket. + """ + if not self._metric or not self._metric_op: + raise ValueError( + "{} has been instantiated without a valid `metric` ({}) or " + "`mode` ({}) parameter. Either pass these parameters when " + "instantiating the scheduler, or pass them as parameters " + "to `tune.TuneConfig()`".format( + self.__class__.__name__, self._metric, self._mode + ) + ) + + cur_bracket = self._state["bracket"] + cur_band = self._hyperbands[self._state["band_idx"]] + if cur_bracket is None or cur_bracket.filled(): + retry = True + while retry: + # if current iteration is filled, create new iteration + if self._cur_band_filled(): + cur_band = [] + self._hyperbands.append(cur_band) + self._state["band_idx"] += 1 + + # MAIN CHANGE HERE - largest bracket first! + # cur_band will always be less than s_max_1 or else filled + s = self._s_max_1 - len(cur_band) - 1 + assert s >= 0, "Current band is filled!" + if self._get_r0(s) == 0: + logger.debug("BOHB: Bracket too small - Retrying...") + cur_bracket = None + else: + retry = False + cur_bracket = self._create_bracket(s) + cur_band.append(cur_bracket) + self._state["bracket"] = cur_bracket + + self._state["bracket"].add_trial(trial) + self._trial_info[trial] = cur_bracket, self._state["band_idx"] + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + """If bracket is finished, all trials will be stopped. 
+ + If a given trial finishes and bracket iteration is not done, + the trial will be paused and resources will be given up. + + This scheduler will not start trials but will stop trials. + The current running trial will not be handled, + as the trialrunner will be given control to handle it.""" + + result["hyperband_info"] = {} + bracket, _ = self._trial_info[trial] + bracket.update_trial_stats(trial, result) + + if bracket.continue_trial(trial): + return TrialScheduler.CONTINUE + + result["hyperband_info"]["budget"] = bracket._cumul_r + + # MAIN CHANGE HERE! + statuses = [(t, t.status) for t in bracket._live_trials] + if not bracket.filled() or any( + status != Trial.PAUSED for t, status in statuses if t is not trial + ): + # BOHB Specific. This hack existed in old Ray versions + # and was removed, but it needs to be brought back + # as otherwise the BOHB doesn't behave as intended. + # The default concurrency limiter works by discarding + # new suggestions if there are more running trials + # than the limit. That doesn't take into account paused + # trials. With BOHB, this leads to N trials finishing + # completely and then another N trials starting, + # instead of trials being paused and resumed in brackets + # as intended. + # There should be a better API for this. + # TODO(team-ml): Refactor alongside HyperBandForBOHB + tune_controller.search_alg.searcher.on_pause(trial.trial_id) + return TrialScheduler.PAUSE + + logger.debug(f"Processing bracket after trial {trial} result") + action = self._process_bracket(tune_controller, bracket) + if action == TrialScheduler.PAUSE: + tune_controller.search_alg.searcher.on_pause(trial.trial_id) + return action + + def _unpause_trial(self, tune_controller: "TuneController", trial: Trial): + # Hack. 
See comment in on_trial_result + tune_controller.search_alg.searcher.on_unpause(trial.trial_id) + + def choose_trial_to_run( + self, tune_controller: "TuneController", allow_recurse: bool = True + ) -> Optional[Trial]: + """Fair scheduling within iteration by completion percentage. + + List of trials not used since all trials are tracked as state + of scheduler. If iteration is occupied (ie, no trials to run), + then look into next iteration. + """ + + for hyperband in self._hyperbands: + # band will have None entries if no resources + # are to be allocated to that bracket. + scrubbed = [b for b in hyperband if b is not None] + for bracket in scrubbed: + for trial in bracket.current_trials(): + if ( + trial.status == Trial.PAUSED + and trial in bracket.trials_to_unpause + ) or trial.status == Trial.PENDING: + return trial + # MAIN CHANGE HERE! + if not any(t.status == Trial.RUNNING for t in tune_controller.get_trials()): + for hyperband in self._hyperbands: + for bracket in hyperband: + if bracket and any( + trial.status == Trial.PAUSED + for trial in bracket.current_trials() + ): + # This will change the trial state + logger.debug("Processing bracket since no trial is running.") + self._process_bracket(tune_controller, bracket) + + # If there are pending trials now, suggest one. + # This is because there might be both PENDING and + # PAUSED trials now, and PAUSED trials will raise + # an error before the trial runner tries again. + if allow_recurse and any( + ( + trial.status == Trial.PAUSED + and trial in bracket.trials_to_unpause + ) + or trial.status == Trial.PENDING + for trial in bracket.current_trials() + ): + return self.choose_trial_to_run( + tune_controller, allow_recurse=False + ) + # MAIN CHANGE HERE! 
+ return None diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hyperband.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hyperband.py new file mode 100644 index 0000000000000000000000000000000000000000..57503d97ee34a1ede4b37ba4534860f9d06b27d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/hyperband.py @@ -0,0 +1,606 @@ +import collections +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import numpy as np + +from ray.tune.error import TuneError +from ray.tune.experiment import Trial +from ray.tune.result import DEFAULT_METRIC +from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + +logger = logging.getLogger(__name__) + + +# Implementation notes: +# This implementation contains 3 logical levels. +# Each HyperBand iteration is a "band". There can be multiple +# bands running at once, and there can be 1 band that is incomplete. +# +# In each band, there are at most `s` + 1 brackets. +# `s` is a value determined by given parameters, and assigned on +# a cyclic basis. +# +# In each bracket, there are at most `n(s)` trials, indicating that +# `n` is a function of `s`. These trials go through a series of +# halving procedures, dropping lowest performers. Multiple +# brackets are running at once. +# +# Trials added will be inserted into the most recent bracket +# and band and will spill over to new brackets/bands accordingly. +# +# This maintains the bracket size and max trial count per band +# to 5 and 117 respectively, which correspond to that of +# `max_attr=81, eta=3` from the blog post. Trials will fill up +# from smallest bracket to largest, with largest +# having the most rounds of successive halving. +@PublicAPI +class HyperBandScheduler(FIFOScheduler): + """Implements the HyperBand early stopping algorithm. 
+ + HyperBandScheduler early stops trials using the HyperBand optimization + algorithm. It divides trials into brackets of varying sizes, and + periodically early stops low-performing trials within each bracket. + + To use this implementation of HyperBand with Tune, all you need + to do is specify the max length of time a trial can run `max_t`, the time + units `time_attr`, the name of the reported objective value `metric`, + and if `metric` is to be maximized or minimized (`mode`). + We automatically determine reasonable values for the other + HyperBand parameters based on the given values. + + For example, to limit trials to 10 minutes and early stop based on the + `episode_mean_reward` attr, construct: + + ``HyperBand('time_total_s', 'episode_reward_mean', max_t=600)`` + + Note that Tune's stopping criteria will be applied in conjunction with + HyperBand's early stopping mechanisms. + + See also: https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/ + + Args: + time_attr: The training result attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. If None but a mode was passed, + the `ray.tune.result.DEFAULT_METRIC` will be used per default. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + max_t: max time units per trial. Trials will be stopped after + max_t time units (determined by time_attr) have passed. + The scheduler will terminate trials after this time has passed. + Note that this is different from the semantics of `max_t` as + mentioned in the original HyperBand paper. + reduction_factor: Same as `eta`. Determines how sharp + the difference is between bracket space-time allocation ratios. 
+ stop_last_trials: Whether to terminate the trials after + reaching max_t. Defaults to True. + """ # noqa: E501 + + _supports_buffered_results = False + + def __init__( + self, + time_attr: str = "training_iteration", + metric: Optional[str] = None, + mode: Optional[str] = None, + max_t: int = 81, + reduction_factor: float = 3, + stop_last_trials: bool = True, + ): + assert max_t > 0, "Max (time_attr) not valid!" + if mode: + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + super().__init__() + self._eta = reduction_factor + self._s_max_1 = int(np.round(np.log(max_t) / np.log(reduction_factor))) + 1 + self._max_t_attr = max_t + # bracket max trials + self._get_n0 = lambda s: int(np.ceil(self._s_max_1 / (s + 1) * self._eta**s)) + # bracket initial iterations + self._get_r0 = lambda s: int((max_t * self._eta ** (-s))) + self._hyperbands = [[]] # list of hyperband iterations + self._trial_info = {} # Stores Trial -> Bracket, Band Iteration + + # Tracks state for new trial add + self._state = {"bracket": None, "band_idx": 0} + self._num_stopped = 0 + self._metric = metric + self._mode = mode + self._metric_op = None + + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + self._time_attr = time_attr + self._stop_last_trials = stop_last_trials + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + if self._metric and metric: + return False + if self._mode and mode: + return False + + if metric: + self._metric = metric + if mode: + self._mode = mode + + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + + if self._metric is None and self._mode: + # If only a mode was passed, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + """Adds new trial. 
+ + On a new trial add, if current bracket is not filled, + add to current bracket. Else, if current band is not filled, + create new bracket, add to current bracket. + Else, create new iteration, create new bracket, add to bracket.""" + if not self._metric or not self._metric_op: + raise ValueError( + "{} has been instantiated without a valid `metric` ({}) or " + "`mode` ({}) parameter. Either pass these parameters when " + "instantiating the scheduler, or pass them as parameters " + "to `tune.TuneConfig()`".format( + self.__class__.__name__, self._metric, self._mode + ) + ) + + cur_bracket = self._state["bracket"] + cur_band = self._hyperbands[self._state["band_idx"]] + if cur_bracket is None or cur_bracket.filled(): + retry = True + while retry: + # if current iteration is filled, create new iteration + if self._cur_band_filled(): + cur_band = [] + self._hyperbands.append(cur_band) + self._state["band_idx"] += 1 + + # cur_band will always be less than s_max_1 or else filled + s = len(cur_band) + assert s < self._s_max_1, "Current band is filled!" + if self._get_r0(s) == 0: + logger.info("Bracket too small - Retrying...") + cur_bracket = None + else: + retry = False + cur_bracket = self._create_bracket(s) + cur_band.append(cur_bracket) + self._state["bracket"] = cur_bracket + + self._state["bracket"].add_trial(trial) + self._trial_info[trial] = cur_bracket, self._state["band_idx"] + + def _create_bracket(self, s): + return _Bracket( + time_attr=self._time_attr, + max_trials=self._get_n0(s), + init_t_attr=self._get_r0(s), + max_t_attr=self._max_t_attr, + eta=self._eta, + s=s, + stop_last_trials=self._stop_last_trials, + ) + + def _cur_band_filled(self) -> bool: + """Checks if the current band is filled. 
+ + The size of the current band should be equal to s_max_1""" + + cur_band = self._hyperbands[self._state["band_idx"]] + return len(cur_band) == self._s_max_1 + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + """If bracket is finished, all trials will be stopped. + + If a given trial finishes and bracket iteration is not done, + the trial will be paused and resources will be given up. + + This scheduler will not start trials but will stop trials. + The current running trial will not be handled, + as the trialrunner will be given control to handle it.""" + + bracket, _ = self._trial_info[trial] + bracket.update_trial_stats(trial, result) + + if bracket.continue_trial(trial): + return TrialScheduler.CONTINUE + + logger.debug(f"Processing bracket after trial {trial} result") + action = self._process_bracket(tune_controller, bracket) + logger.debug( + f"{action} for {trial} on " + f"{self._time_attr}={result.get(self._time_attr)}" + ) + return action + + def _process_bracket( + self, tune_controller: "TuneController", bracket: "_Bracket" + ) -> str: + """This is called whenever a trial makes progress. + + When all live trials in the bracket have no more iterations left, + Trials will be successively halved. If bracket is done, all + non-running trials will be stopped and cleaned up, + and during each halving phase, bad trials will be stopped while good + trials will return to "PENDING". + + Note some implicit conditions here: In ``on_trial_result`` a trial is + either continued (e.g. if it didn't reach the time threshold for the bracket) + or this method (``_process_bracket``) is called. If there are other trials left + that still haven't reached the threshold, the trial is PAUSED. 
This means + that when the bracket is actually processed (``bracket.cur_iter_done``), there + is at most one RUNNING trial (which is the trial that is currently processed) + and the rest are either PAUSED (as explained above) or TERMINATED/ERRORED + (if they finish separately). + """ + + action = TrialScheduler.PAUSE + if bracket.cur_iter_done(): + if bracket.finished(): + bracket.cleanup_full(tune_controller) + return TrialScheduler.STOP + + bracket.is_being_processed = True + + good, bad = bracket.successive_halving(self._metric, self._metric_op) + + logger.debug( + f"Processing {len(good)} good and {len(bad)} bad trials in " + f"bracket {bracket}.\n" + f"Good: {good}\nBad: {bad}" + ) + + # kill bad trials + self._num_stopped += len(bad) + for t in bad: + if t.status == Trial.PAUSED or t.is_saving: + logger.debug(f"Stopping other trial {str(t)}") + tune_controller.stop_trial(t) + elif t.status == Trial.RUNNING: + # See the docstring: There can only be at most one RUNNING + # trial, which is the current trial. + logger.debug(f"Stopping current trial {str(t)}") + bracket.cleanup_trial(t) + action = TrialScheduler.STOP + else: + # Trials cannot be ERROR/TERMINATED, as then they would have + # been removed from the bracket (in `bracket.cleanup_trial`). + # Trials cannot be PENDING, as then they wouldn't have reported + # enough results to finish the bracket, and it wouldn't be + # processed. + raise TuneError( + f"Trial with unexpected bad status encountered: " + f"{str(t)} is {t.status}" + ) + + # ready the good trials - if trial is too far ahead, don't continue + for t in good: + if bracket.continue_trial(t): + # The scheduler should have cleaned up this trial already. + assert t.status not in (Trial.ERROR, Trial.TERMINATED), ( + f"Good trial {t.trial_id} is in an invalid state: {t.status}\n" + "Expected trial to be either PAUSED, PENDING, or RUNNING.\n" + "If you encounter this, please file an issue on the Ray Github." 
+ ) + if t.status == Trial.PAUSED or t.is_saving: + logger.debug(f"Unpausing trial {str(t)}") + self._unpause_trial(tune_controller, t) + bracket.trials_to_unpause.add(t) + elif t.status == Trial.RUNNING: + # See the docstring: There can only be at most one RUNNING + # trial, which is the current trial. + logger.debug(f"Continuing current trial {str(t)}") + action = TrialScheduler.CONTINUE + # else: PENDING trial (from a previous unpause) should stay as is. + elif bracket.finished() and bracket.stop_last_trials: + # Scheduler decides to not continue trial because the bracket + # reached max_t. In this case, stop the trials + if t.status == Trial.PAUSED or t.is_saving: + logger.debug(f"Bracket finished. Stopping other trial {str(t)}") + tune_controller.stop_trial(t) + elif t.status == Trial.RUNNING: + # See the docstring: There can only be at most one RUNNING + # trial, which is the current trial. + logger.debug( + f"Bracket finished. Stopping current trial {str(t)}" + ) + bracket.cleanup_trial(t) + action = TrialScheduler.STOP + return action + + def _unpause_trial(self, tune_controller: "TuneController", trial: Trial): + """No-op by default.""" + return + + def on_trial_remove(self, tune_controller: "TuneController", trial: Trial): + """Notification when trial terminates. + + Trial info is removed from bracket. 
Triggers halving if bracket is + not finished.""" + bracket, _ = self._trial_info[trial] + bracket.cleanup_trial(trial) + if not bracket.finished() and not bracket.is_being_processed: + logger.debug(f"Processing bracket after trial {trial} removed") + self._process_bracket(tune_controller, bracket) + + def on_trial_complete( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + """Cleans up trial info from bracket if trial completed early.""" + self.on_trial_remove(tune_controller, trial) + + def on_trial_error(self, tune_controller: "TuneController", trial: Trial): + """Cleans up trial info from bracket if trial errored early.""" + self.on_trial_remove(tune_controller, trial) + + def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]: + """Fair scheduling within iteration by completion percentage. + + List of trials not used since all trials are tracked as state + of scheduler. If iteration is occupied (ie, no trials to run), + then look into next iteration. + """ + + for hyperband in self._hyperbands: + # band will have None entries if no resources + # are to be allocated to that bracket. + scrubbed = [b for b in hyperband if b is not None] + for bracket in sorted(scrubbed, key=lambda b: b.completion_percentage()): + for trial in bracket.current_trials(): + if ( + trial.status == Trial.PAUSED + and trial in bracket.trials_to_unpause + ) or trial.status == Trial.PENDING: + return trial + return None + + def debug_string(self) -> str: + """This provides a progress notification for the algorithm. + + For each bracket, the algorithm will output a string as follows: + + Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%): + {PENDING: 2, RUNNING: 3, TERMINATED: 2} + + "Max Size" indicates the max number of pending/running experiments + set according to the Hyperband algorithm. + + "Milestone" indicates the iterations a trial will run for before + the next halving will occur. 
+ + "Completed" indicates an approximate progress metric. Some brackets, + like ones that are unfilled, will not reach 100%. + """ + out = "Using HyperBand: " + out += "num_stopped={} total_brackets={}".format( + self._num_stopped, sum(len(band) for band in self._hyperbands) + ) + for i, band in enumerate(self._hyperbands): + out += "\nRound #{}:".format(i) + for bracket in band: + if bracket: + out += "\n {}".format(bracket) + return out + + def state(self) -> Dict[str, int]: + return { + "num_brackets": sum(len(band) for band in self._hyperbands), + "num_stopped": self._num_stopped, + } + + +class _Bracket: + """Logical object for tracking Hyperband bracket progress. Keeps track + of proper parameters as designated by HyperBand. + + Also keeps track of progress to ensure good scheduling. + """ + + def __init__( + self, + time_attr: str, + max_trials: int, + init_t_attr: int, + max_t_attr: int, + eta: float, + s: int, + stop_last_trials: bool = True, + ): + self._live_trials = {} # maps trial -> current result + self._all_trials = [] + self._time_attr = time_attr # attribute to + + self._n = self._n0 = max_trials + self._r = self._r0 = init_t_attr + self._max_t_attr = max_t_attr + self._cumul_r = self._r0 + + self._eta = eta + self._halves = s + + self._total_work = self._calculate_total_work(self._n0, self._r0, s) + self._completed_progress = 0 + self.stop_last_trials = stop_last_trials + self.is_being_processed = False + + self.trials_to_unpause = set() + + def add_trial(self, trial: Trial): + """Add trial to bracket assuming bracket is not filled. + + At a later iteration, a newly added trial will be given equal + opportunity to catch up.""" + assert not self.filled(), "Cannot add trial to filled bracket!" + self._live_trials[trial] = None + self._all_trials.append(trial) + + def cur_iter_done(self) -> bool: + """Checks if all iterations have completed. 
+ + TODO(rliaw): also check that `t.iterations == self._r`""" + return all( + self._get_result_time(result) >= self._cumul_r + for result in self._live_trials.values() + ) + + def finished(self) -> bool: + if not self.stop_last_trials: + return False + return self._halves == 0 and self.cur_iter_done() + + def current_trials(self) -> List[Trial]: + return list(self._live_trials) + + def continue_trial(self, trial: Trial) -> bool: + result = self._live_trials[trial] + if not self.stop_last_trials and self._halves == 0: + return True + elif self._get_result_time(result) < self._cumul_r: + logger.debug( + f"Continuing trial {trial} as it hasn't reached the time threshold " + f"{self._cumul_r}, yet." + ) + return True + return False + + def filled(self) -> bool: + """Checks if bracket is filled. + + Only let new trials be added at current level minimizing the need + to backtrack and bookkeep previous medians.""" + + return len(self._live_trials) == self._n + + def successive_halving( + self, metric: str, metric_op: float + ) -> Tuple[List[Trial], List[Trial]]: + if self._halves == 0 and not self.stop_last_trials: + return self._live_trials, [] + assert self._halves > 0 + + # "Halving" is a misnomer. We're actually reducing by factor `eta`. + self._halves -= 1 + + # If we had 8 trials in the bracket and eta=2, we will keep 4. + # If we had 9 trials in the bracket and eta=3, we will keep 3. + self._n = int(np.ceil(self._n / self._eta)) + + # Likewise, we increase the number of iterations until we process the bracket + # again. + # Remember r0 = max_t * self._eta ** (-s) + # Let max_t=16, eta=2, s=1. Then r0=8, and we calculate r1=16. + # Let max_t=16, eta=2, s=2. Then r0=4, and we calculate r1=8, r2=16. + + # Let max_t=81, eta=3, s=1. Then r0=27, and we calculate r1=81. + # Let max_t=81, eta=3, s=2. Then r0=9, and we calculate r1=27, r2=81. 
+ self._r *= self._eta + self._r = int(min(self._r, self._max_t_attr)) + self._cumul_r = self._r + sorted_trials = sorted( + self._live_trials, key=lambda t: metric_op * self._live_trials[t][metric] + ) + + good, bad = sorted_trials[-self._n :], sorted_trials[: -self._n] + return good, bad + + def update_trial_stats(self, trial: Trial, result: Dict): + """Update result for trial. Called after trial has finished + an iteration - will decrement iteration count. + + TODO(rliaw): The other alternative is to keep the trials + in and make sure they're not set as pending later.""" + + assert trial in self._live_trials + assert self._get_result_time(result) >= 0 + observed_time = self._get_result_time(result) + last_observed = self._get_result_time(self._live_trials[trial]) + + delta = observed_time - last_observed + if delta <= 0: + logger.info( + "Restoring from a previous point in time. " + "Previous={}; Now={}".format(last_observed, observed_time) + ) + self._completed_progress += delta + self._live_trials[trial] = result + self.trials_to_unpause.discard(trial) + + def cleanup_trial(self, trial: Trial): + """Clean up statistics tracking for terminated trials (either by force + or otherwise). + + This may cause bad trials to continue for a long time, in the case + where all the good trials finish early and there are only bad trials + left in a bracket with a large max-iteration.""" + self._live_trials.pop(trial, None) + + def cleanup_full(self, tune_controller: "TuneController"): + """Cleans up bracket after bracket is completely finished. + + Lets the last trial continue to run until termination condition + kicks in.""" + for trial in self.current_trials(): + if trial.status == Trial.PAUSED: + tune_controller.stop_trial(trial) + + def completion_percentage(self) -> float: + """Returns a progress metric. 
This will not always end at 100, since dead trials
https://research.google.com/pubs/pub46180.html + + Args: + time_attr: The training result attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. If None but a mode was passed, + the `ray.tune.result.DEFAULT_METRIC` will be used per default. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + grace_period: Only stop trials at least this old in time. + The mean will only be computed from this time onwards. The units + are the same as the attribute named by `time_attr`. + min_samples_required: Minimum number of trials to compute median + over. + min_time_slice: Each trial runs at least this long before + yielding (assuming it isn't stopped). Note: trials ONLY yield if + there are not enough samples to evaluate performance for the + current result AND there are other trials waiting to run. + The units are the same as the attribute named by `time_attr`. + hard_stop: If False, pauses trials instead of stopping + them. When all other trials are complete, paused trials will be + resumed and allowed to run FIFO. + """ + + def __init__( + self, + time_attr: str = "time_total_s", + metric: Optional[str] = None, + mode: Optional[str] = None, + grace_period: float = 60.0, + min_samples_required: int = 3, + min_time_slice: int = 0, + hard_stop: bool = True, + ): + super().__init__() + self._stopped_trials = set() + self._grace_period = grace_period + self._min_samples_required = min_samples_required + self._min_time_slice = min_time_slice + self._metric = metric + self._worst = None + self._compare_op = None + + self._mode = mode + if mode: + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'." 
+ self._worst = float("-inf") if self._mode == "max" else float("inf") + self._compare_op = max if self._mode == "max" else min + + self._time_attr = time_attr + self._hard_stop = hard_stop + self._trial_state = {} + self._last_pause = collections.defaultdict(lambda: float("-inf")) + self._results = collections.defaultdict(list) + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + if self._metric and metric: + return False + if self._mode and mode: + return False + + if metric: + self._metric = metric + if mode: + self._mode = mode + + self._worst = float("-inf") if self._mode == "max" else float("inf") + self._compare_op = max if self._mode == "max" else min + + if self._metric is None and self._mode: + # If only a mode was passed, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + if not self._metric or not self._worst or not self._compare_op: + raise ValueError( + "{} has been instantiated without a valid `metric` ({}) or " + "`mode` ({}) parameter. Either pass these parameters when " + "instantiating the scheduler, or pass them as parameters " + "to `tune.TuneConfig()`".format( + self.__class__.__name__, self._metric, self._mode + ) + ) + + super(MedianStoppingRule, self).on_trial_add(tune_controller, trial) + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + """Callback for early stopping. + + This stopping rule stops a running trial if the trial's best objective + value by step `t` is strictly worse than the median of the running + averages of all completed trials' objectives reported up to step `t`. 
+ """ + if self._time_attr not in result or self._metric not in result: + return TrialScheduler.CONTINUE + + if trial in self._stopped_trials: + assert not self._hard_stop + # Fall back to FIFO + return TrialScheduler.CONTINUE + + time = result[self._time_attr] + self._results[trial].append(result) + + if time < self._grace_period: + return TrialScheduler.CONTINUE + + trials = self._trials_beyond_time(time) + trials.remove(trial) + + if len(trials) < self._min_samples_required: + action = self._on_insufficient_samples(tune_controller, trial, time) + if action == TrialScheduler.PAUSE: + self._last_pause[trial] = time + action_str = "Yielding time to other trials." + else: + action_str = "Continuing anyways." + logger.debug( + "MedianStoppingRule: insufficient samples={} to evaluate " + "trial {} at t={}. {}".format( + len(trials), trial.trial_id, time, action_str + ) + ) + return action + + median_result = self._median_result(trials, time) + best_result = self._best_result(trial) + logger.debug( + "Trial {} best res={} vs median res={} at t={}".format( + trial, best_result, median_result, time + ) + ) + + if self._compare_op(median_result, best_result) != best_result: + logger.debug("MedianStoppingRule: early stopping {}".format(trial)) + self._stopped_trials.add(trial) + if self._hard_stop: + return TrialScheduler.STOP + else: + return TrialScheduler.PAUSE + else: + return TrialScheduler.CONTINUE + + def on_trial_complete( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + self._results[trial].append(result) + + def debug_string(self) -> str: + return "Using MedianStoppingRule: num_stopped={}.".format( + len(self._stopped_trials) + ) + + def _on_insufficient_samples( + self, tune_controller: "TuneController", trial: Trial, time: float + ) -> str: + pause = time - self._last_pause[trial] > self._min_time_slice + pause = pause and [ + t + for t in tune_controller.get_live_trials() + if t.status in (Trial.PENDING, Trial.PAUSED) + ] + return 
TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE + + def _trials_beyond_time(self, time: float) -> List[Trial]: + trials = [ + trial + for trial in self._results + if self._results[trial][-1][self._time_attr] >= time + ] + return trials + + def _median_result(self, trials: List[Trial], time: float): + return np.median([self._running_mean(trial, time) for trial in trials]) + + def _running_mean(self, trial: Trial, time: float) -> np.ndarray: + results = self._results[trial] + # TODO(ekl) we could do interpolation to be more precise, but for now + # assume len(results) is large and the time diffs are roughly equal + scoped_results = [ + r for r in results if self._grace_period <= r[self._time_attr] <= time + ] + return np.mean([r[self._metric] for r in scoped_results]) + + def _best_result(self, trial): + results = self._results[trial] + return self._compare_op([r[self._metric] for r in results]) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..b80295d00ee4c5e4d3c7d16cb26ed45fceb2f5e3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2.py @@ -0,0 +1,507 @@ +import logging +from copy import deepcopy +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union + +import numpy as np +import pandas as pd + +from ray.tune import TuneError +from ray.tune.experiment import Trial +from ray.tune.schedulers import PopulationBasedTraining +from ray.tune.schedulers.pbt import _PBTTrialState +from ray.tune.utils.util import flatten_dict, unflatten_dict +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + + +def import_pb2_dependencies(): + try: + import GPy + except ImportError: + GPy = None + try: + import sklearn + except ImportError: + sklearn = None + return GPy, sklearn + + +GPy, has_sklearn = 
import_pb2_dependencies() + +if GPy and has_sklearn: + from ray.tune.schedulers.pb2_utils import ( + UCB, + TV_SquaredExp, + normalize, + optimize_acq, + select_length, + standardize, + ) + +logger = logging.getLogger(__name__) + + +def _fill_config( + config: Dict, hyperparam_bounds: Dict[str, Union[dict, list, tuple]] +) -> Dict: + """Fills missing hyperparameters in config by sampling uniformly from the + specified `hyperparam_bounds`. + Recursively fills the config if `hyperparam_bounds` is a nested dict. + + This is a helper used to set initial hyperparameter values if the user doesn't + specify them in the Tuner `param_space`. + + Returns the dict of filled hyperparameters. + """ + filled_hyperparams = {} + for param_name, bounds in hyperparam_bounds.items(): + if isinstance(bounds, dict): + if param_name not in config: + config[param_name] = {} + filled_hyperparams[param_name] = _fill_config(config[param_name], bounds) + elif isinstance(bounds, (list, tuple)) and param_name not in config: + if log_once(param_name + "-missing"): + logger.debug( + f"Cannot find {param_name} in config. Initializing by " + "sampling uniformly from the provided `hyperparam_bounds`." + ) + assert len(bounds) == 2 + low, high = bounds + config[param_name] = filled_hyperparams[param_name] = np.random.uniform( + low, high + ) + return filled_hyperparams + + +def _select_config( + Xraw: np.array, + yraw: np.array, + current: list, + newpoint: np.array, + bounds: dict, + num_f: int, +) -> np.ndarray: + """Selects the next hyperparameter config to try. + + This function takes the formatted data, fits the GP model and optimizes the + UCB acquisition function to select the next point. + + Args: + Xraw: The un-normalized array of hyperparams, Time and + Reward + yraw: The un-normalized vector of reward changes. + current: The hyperparams of trials currently running. This is + important so we do not select the same config twice. 
If there is + data here then we fit a second GP including it + (with fake y labels). The GP variance doesn't depend on the y + labels so it is ok. + newpoint: The Reward and Time for the new point. + We cannot change these as they are based on the *new weights*. + bounds: Bounds for the hyperparameters. Used to normalize. + num_f: The number of fixed params. Almost always 2 (reward+time) + + Return: + xt: A vector of new hyperparameters. + """ + length = select_length(Xraw, yraw, bounds, num_f) + + Xraw = Xraw[-length:, :] + yraw = yraw[-length:] + + base_vals = np.array(list(bounds.values())).T + oldpoints = Xraw[:, :num_f] + old_lims = np.concatenate( + (np.max(oldpoints, axis=0), np.min(oldpoints, axis=0)) + ).reshape(2, oldpoints.shape[1]) + limits = np.concatenate((old_lims, base_vals), axis=1) + + X = normalize(Xraw, limits) + y = standardize(yraw).reshape(yraw.size, 1) + + fixed = normalize(newpoint, oldpoints) + + kernel = TV_SquaredExp( + input_dim=X.shape[1], variance=1.0, lengthscale=1.0, epsilon=0.1 + ) + + try: + m = GPy.models.GPRegression(X, y, kernel) + except np.linalg.LinAlgError: + # add diagonal ** we would ideally make this something more robust... + X += np.eye(X.shape[0]) * 1e-3 + m = GPy.models.GPRegression(X, y, kernel) + + try: + m.optimize() + except np.linalg.LinAlgError: + # add diagonal ** we would ideally make this something more robust... + X += np.eye(X.shape[0]) * 1e-3 + m = GPy.models.GPRegression(X, y, kernel) + m.optimize() + + m.kern.lengthscale.fix(m.kern.lengthscale.clip(1e-5, 1)) + + if current is None: + m1 = deepcopy(m) + else: + # add the current trials to the dataset + padding = np.array([fixed for _ in range(current.shape[0])]) + current = normalize(current, base_vals) + current = np.hstack((padding, current)) + + Xnew = np.vstack((X, current)) + ypad = np.zeros(current.shape[0]) + ypad = ypad.reshape(-1, 1) + ynew = np.vstack((y, ypad)) + + # kernel = GPy.kern.RBF(input_dim=X.shape[1], variance=1., + # lengthscale=1.) 
+ kernel = TV_SquaredExp( + input_dim=X.shape[1], variance=1.0, lengthscale=1.0, epsilon=0.1 + ) + m1 = GPy.models.GPRegression(Xnew, ynew, kernel) + m1.optimize() + + xt = optimize_acq(UCB, m, m1, fixed, num_f) + + # convert back... + xt = xt * (np.max(base_vals, axis=0) - np.min(base_vals, axis=0)) + np.min( + base_vals, axis=0 + ) + + xt = xt.astype(np.float32) + return xt + + +def _explore( + data: pd.DataFrame, + bounds: Dict[str, Tuple[float, float]], + current: list, + base: Trial, + old: Trial, + config: Dict[str, Tuple[float, float]], +) -> Tuple[Dict, pd.DataFrame]: + """Returns next hyperparameter configuration to use. + + This function primarily processes the data from completed trials + and then requests the next config from the select_config function. + It then adds the new trial to the dataframe, so that the reward change + can be computed using the new weights. + It returns the new point and the dataframe with the new entry. + """ + + df = data.sort_values(by="Time").reset_index(drop=True) + + # Group by trial ID and hyperparams. + # Compute change in timesteps and reward. + df["y"] = df.groupby(["Trial"] + list(bounds.keys()))["Reward"].diff() + df["t_change"] = df.groupby(["Trial"] + list(bounds.keys()))["Time"].diff() + + # Delete entries without positive change in t. + df = df[df["t_change"] > 0].reset_index(drop=True) + df["R_before"] = df.Reward - df.y + + # Normalize the reward change by the update size. + # For example if trials took diff lengths of time. + df["y"] = df.y / df.t_change + df = df[~df.y.isna()].reset_index(drop=True) + df = df.sort_values(by="Time").reset_index(drop=True) + + # Only use the last 1k datapoints, so the GP is not too slow. + df = df.iloc[-1000:, :].reset_index(drop=True) + + # We need this to know the T and Reward for the weights. + dfnewpoint = df[df["Trial"] == str(base)] + + if not dfnewpoint.empty: + # N ow specify the dataset for the GP. 
+ y = np.array(df.y.values) + # Meta data we keep -> episodes and reward. + # (TODO: convert to curve) + t_r = df[["Time", "R_before"]] + hparams = df[bounds.keys()] + X = pd.concat([t_r, hparams], axis=1).values + newpoint = df[df["Trial"] == str(base)].iloc[-1, :][["Time", "R_before"]].values + new = _select_config(X, y, current, newpoint, bounds, num_f=len(t_r.columns)) + + new_config = config.copy() + values = [] + # Cast types for new hyperparameters. + for i, col in enumerate(hparams.columns): + # Use the type from the old config. Like this types + # should be passed on from the first config downwards. + type_ = type(config[col]) + new_config[col] = type_(new[i]) + values.append(type_(new[i])) + + new_T = df[df["Trial"] == str(base)].iloc[-1, :]["Time"] + new_Reward = df[df["Trial"] == str(base)].iloc[-1, :].Reward + + lst = [[str(old)] + [new_T] + values + [new_Reward]] + cols = ["Trial", "Time"] + list(bounds) + ["Reward"] + new_entry = pd.DataFrame(lst, columns=cols) + + # Create an entry for the new config, with the reward from the + # copied agent. + data = pd.concat([data, new_entry]).reset_index(drop=True) + + else: + new_config = config.copy() + + return new_config, data + + +class PB2(PopulationBasedTraining): + """Implements the Population Based Bandit (PB2) algorithm. + + PB2 trains a group of models (or agents) in parallel. Periodically, poorly + performing models clone the state of the top performers, and the hyper- + parameters are re-selected using GP-bandit optimization. The GP model is + trained to predict the improvement in the next training period. + + Like PBT, PB2 adapts hyperparameters during training time. This enables + very fast hyperparameter discovery and also automatically discovers + schedules. + + This Tune PB2 implementation is built on top of Tune's PBT implementation. + It considers all trials added as part of the PB2 population. 
If the number + of trials exceeds the cluster capacity, they will be time-multiplexed as to + balance training progress across the population. To run multiple trials, + use `tune.TuneConfig(num_samples=)`. + + In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in + `pb2_global.txt` and individual policy perturbations are recorded + in pb2_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag, + target trial iteration, clone trial iteration, old config, new config] + on each perturbation step. + + Args: + time_attr: The training result attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + perturbation_interval: Models will be considered for + perturbation at this interval of `time_attr`. Note that + perturbation incurs checkpoint overhead, so you shouldn't set this + to be too frequent. + hyperparam_bounds: Hyperparameters to mutate. The format is + as follows: for each key, enter a list of the form [min, max] + representing the minimum and maximum possible hyperparam values. + A key can also hold a dict for nested hyperparameters. + Tune will sample uniformly between the bounds provided by + `hyperparam_bounds` for the initial hyperparameter values if the + corresponding hyperparameters are not present in a trial's initial `config`. + quantile_fraction: Parameters are transferred from the top + `quantile_fraction` fraction of trials to the bottom + `quantile_fraction` fraction. Needs to be between 0 and 0.5. + Setting it to 0 essentially implies doing no exploitation at all. + custom_explore_fn: You can also specify a custom exploration + function. 
This function is invoked as `f(config)`, where the input + is the new config generated by Bayesian Optimization. This function + should return the `config` updated as needed. + log_config: Whether to log the ray config of each model to + local_dir at each exploit. Allows config schedule to be + reconstructed. + require_attrs: Whether to require time_attr and metric to appear + in result for every iteration. If True, error will be raised + if these values are not present in trial result. + synch: If False, will use asynchronous implementation of + PBT. Trial perturbations occur every perturbation_interval for each + trial independently. If True, will use synchronous implementation + of PBT. Perturbations will occur only after all trials are + synced at the same time_attr every perturbation_interval. + Defaults to False. See Appendix A.1 here + https://arxiv.org/pdf/1711.09846.pdf. + + Example: + + .. code-block:: python + + from ray import tune + from ray.tune.schedulers.pb2 import PB2 + from ray.tune.examples.pbt_function import pbt_function + # run "pip install gpy" to use PB2 + + pb2 = PB2( + metric="mean_accuracy", + mode="max", + perturbation_interval=20, + hyperparam_bounds={"lr": [0.0001, 0.1]}, + ) + tuner = tune.Tuner( + pbt_function, + tune_config=tune.TuneConfig( + scheduler=pb2, + num_samples=8, + ), + param_space={"lr": 0.0001}, + ) + tuner.fit() + + """ + + def __init__( + self, + time_attr: str = "time_total_s", + metric: Optional[str] = None, + mode: Optional[str] = None, + perturbation_interval: float = 60.0, + hyperparam_bounds: Dict[str, Union[dict, list, tuple]] = None, + quantile_fraction: float = 0.25, + log_config: bool = True, + require_attrs: bool = True, + synch: bool = False, + custom_explore_fn: Optional[Callable[[dict], dict]] = None, + ): + + gpy_available, sklearn_available = import_pb2_dependencies() + if not gpy_available: + raise RuntimeError("Please install GPy to use PB2.") + + if not sklearn_available: + raise 
RuntimeError("Please install scikit-learn to use PB2.") + + hyperparam_bounds = hyperparam_bounds or {} + + if not hyperparam_bounds: + raise TuneError( + "`hyperparam_bounds` must be specified to use PB2 scheduler." + ) + + super(PB2, self).__init__( + time_attr=time_attr, + metric=metric, + mode=mode, + perturbation_interval=perturbation_interval, + hyperparam_mutations=hyperparam_bounds, + quantile_fraction=quantile_fraction, + resample_probability=0, + custom_explore_fn=custom_explore_fn, + log_config=log_config, + require_attrs=require_attrs, + synch=synch, + ) + + self.last_exploration_time = 0 # when we last explored + self.data = pd.DataFrame() + + self._hyperparam_bounds = hyperparam_bounds + self._hyperparam_bounds_flat = flatten_dict( + hyperparam_bounds, prevent_delimiter=True + ) + self._validate_hyperparam_bounds(self._hyperparam_bounds_flat) + + # Current = trials running that have already re-started after reaching + # the checkpoint. When exploring we care if these trials + # are already in or scheduled to be in the next round. + self.current = None + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + filled_hyperparams = _fill_config(trial.config, self._hyperparam_bounds) + # Make sure that the params we sampled show up in the CLI output + trial.evaluated_params.update(flatten_dict(filled_hyperparams)) + super().on_trial_add(tune_controller, trial) + + def _validate_hyperparam_bounds(self, hyperparam_bounds: dict): + """Check that each hyperparam bound is of the form [low, high]. + + Raises: + ValueError: if any of the hyperparam bounds are of an invalid format. 
+ """ + for key, value in hyperparam_bounds.items(): + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError( + "`hyperparam_bounds` values must either be " + f"a list or tuple of size 2, but got {value} " + f"instead for the param '{key}'" + ) + low, high = value + if low > high: + raise ValueError( + "`hyperparam_bounds` values must be of the form [low, high] " + f"where low <= high, but got {value} instead for param '{key}'." + ) + + def _save_trial_state( + self, state: _PBTTrialState, time: int, result: Dict, trial: Trial + ): + score = super(PB2, self)._save_trial_state(state, time, result, trial) + + # Data logging for PB2. + + # Collect hyperparams names and current values for this trial. + names = list(self._hyperparam_bounds_flat.keys()) + flattened_config = flatten_dict(trial.config) + values = [flattened_config[key] for key in names] + + # Store trial state and hyperparams in dataframe. + # this needs to be made more general. + lst = [[trial, result[self._time_attr]] + values + [score]] + cols = ["Trial", "Time"] + names + ["Reward"] + entry = pd.DataFrame(lst, columns=cols) + + self.data = pd.concat([self.data, entry]).reset_index(drop=True) + self.data.Trial = self.data.Trial.astype("str") + + def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]: + """Gets new config for trial by exploring trial_to_clone's config using + Bayesian Optimization (BO) to choose the hyperparameter values to explore. + + Overrides `PopulationBasedTraining._get_new_config`. + + Args: + trial: The current trial that decided to exploit trial_to_clone. + trial_to_clone: The top-performing trial with a hyperparameter config + that the current trial will explore. + + Returns: + new_config: New hyperparameter configuration (after BO). + operations: Empty dict since PB2 doesn't explore in easily labeled ways + like PBT does. + """ + # If we are at a new timestep, we dont want to penalise for trials + # still going. 
+ if self.data["Time"].max() > self.last_exploration_time: + self.current = None + + new_config_flat, data = _explore( + self.data, + self._hyperparam_bounds_flat, + self.current, + trial_to_clone, + trial, + flatten_dict(trial_to_clone.config), + ) + + # Important to replace the old values, since we are copying across + self.data = data.copy() + + # If the current guy being selecting is at a point that is already + # done, then append the data to the "current" which contains the + # points in the current batch. + new = [new_config_flat[key] for key in self._hyperparam_bounds_flat] + + new = np.array(new) + new = new.reshape(1, new.size) + if self.data["Time"].max() > self.last_exploration_time: + self.last_exploration_time = self.data["Time"].max() + self.current = new.copy() + else: + self.current = np.concatenate((self.current, new), axis=0) + logger.debug(self.current) + + new_config = unflatten_dict(new_config_flat) + + if self._custom_explore_fn: + new_config = self._custom_explore_fn(new_config) + assert ( + new_config is not None + ), "Custom explore function failed to return a new config" + + return new_config, {} diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2_utils.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd06af9b4ef155a3cee6d99a200eaa8ad1e1438 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pb2_utils.py @@ -0,0 +1,183 @@ +import GPy +import numpy as np +from GPy.core import Param +from GPy.kern import Kern +from scipy.optimize import minimize +from sklearn.metrics import pairwise_distances +from sklearn.metrics.pairwise import euclidean_distances + + +class TV_SquaredExp(Kern): + """Time varying squared exponential kernel. 
+ For more info see the TV-GP-UCB paper: + http://proceedings.mlr.press/v51/bogunovic16.pdf + """ + + def __init__( + self, input_dim, variance=1.0, lengthscale=1.0, epsilon=0.0, active_dims=None + ): + super().__init__(input_dim, active_dims, "time_se") + self.variance = Param("variance", variance) + self.lengthscale = Param("lengthscale", lengthscale) + self.epsilon = Param("epsilon", epsilon) + self.link_parameters(self.variance, self.lengthscale, self.epsilon) + + def K(self, X, X2): + # time must be in the far left column + if self.epsilon > 0.5: # 0.5 + self.epsilon = 0.5 + if X2 is None: + X2 = np.copy(X) + T1 = X[:, 0].reshape(-1, 1) + T2 = X2[:, 0].reshape(-1, 1) + dists = pairwise_distances(T1, T2, "cityblock") + timekernel = (1 - self.epsilon) ** (0.5 * dists) + + X = X[:, 1:] + X2 = X2[:, 1:] + + RBF = self.variance * np.exp( + -np.square(euclidean_distances(X, X2)) / self.lengthscale + ) + + return RBF * timekernel + + def Kdiag(self, X): + return self.variance * np.ones(X.shape[0]) + + def update_gradients_full(self, dL_dK, X, X2): + if X2 is None: + X2 = np.copy(X) + T1 = X[:, 0].reshape(-1, 1) + T2 = X2[:, 0].reshape(-1, 1) + + X = X[:, 1:] + X2 = X2[:, 1:] + dist2 = np.square(euclidean_distances(X, X2)) / self.lengthscale + + dvar = np.exp(-np.square((euclidean_distances(X, X2)) / self.lengthscale)) + dl = -( + 2 * euclidean_distances(X, X2) ** 2 * self.variance * np.exp(-dist2) + ) * self.lengthscale ** (-2) + n = pairwise_distances(T1, T2, "cityblock") / 2 + deps = -n * (1 - self.epsilon) ** (n - 1) + + self.variance.gradient = np.sum(dvar * dL_dK) + self.lengthscale.gradient = np.sum(dl * dL_dK) + self.epsilon.gradient = np.sum(deps * dL_dK) + + +def normalize(data, wrt): + """Normalize data to be in range (0,1), with respect to (wrt) boundaries, + which can be specified. + """ + return (data - np.min(wrt, axis=0)) / ( + np.max(wrt, axis=0) - np.min(wrt, axis=0) + 1e-8 + ) + + +def standardize(data): + """Standardize to be Gaussian N(0,1). 
Clip final values.""" + data = (data - np.mean(data, axis=0)) / (np.std(data, axis=0) + 1e-8) + return np.clip(data, -2, 2) + + +def UCB(m, m1, x, fixed, kappa=None): + """UCB acquisition function. Interesting points to note: + 1) We concat with the fixed points, because we are not optimizing wrt + these. This is the Reward and Time, which we can't change. We want + to find the best hyperparameters *given* the reward and time. + 2) We use m to get the mean and m1 to get the variance. If we already + have trials running, then m1 contains this information. This reduces + the variance at points currently running, even if we don't have + their label. + Ref: https://jmlr.org/papers/volume15/desautels14a/desautels14a.pdf + + """ + + c1 = 0.2 + c2 = 0.4 + beta_t = c1 + max(0, np.log(c2 * m.X.shape[0])) + kappa = np.sqrt(beta_t) if kappa is None else kappa + + xtest = np.concatenate((fixed.reshape(-1, 1), np.array(x).reshape(-1, 1))).T + + try: + preds = m.predict(xtest) + preds = m.predict(xtest) + mean = preds[0][0][0] + except ValueError: + mean = -9999 + + try: + preds = m1.predict(xtest) + var = preds[1][0][0] + except ValueError: + var = 0 + return mean + kappa * var + + +def optimize_acq(func, m, m1, fixed, num_f): + """Optimize acquisition function.""" + + opts = {"maxiter": 200, "maxfun": 200, "disp": False} + + T = 10 + best_value = -999 + best_theta = m1.X[0, :] + + bounds = [(0, 1) for _ in range(m.X.shape[1] - num_f)] + + for ii in range(T): + x0 = np.random.uniform(0, 1, m.X.shape[1] - num_f) + + res = minimize( + lambda x: -func(m, m1, x, fixed), + x0, + bounds=bounds, + method="L-BFGS-B", + options=opts, + ) + + val = func(m, m1, res.x, fixed) + if val > best_value: + best_value = val + best_theta = res.x + + return np.clip(best_theta, 0, 1) + + +def select_length(Xraw, yraw, bounds, num_f): + """Select the number of datapoints to keep, using cross validation""" + min_len = 200 + + if Xraw.shape[0] < min_len: + return Xraw.shape[0] + else: + length = 
min_len - 10 + scores = [] + while length + 10 <= Xraw.shape[0]: + length += 10 + + base_vals = np.array(list(bounds.values())).T + X_len = Xraw[-length:, :] + y_len = yraw[-length:] + oldpoints = X_len[:, :num_f] + old_lims = np.concatenate( + (np.max(oldpoints, axis=0), np.min(oldpoints, axis=0)) + ).reshape(2, oldpoints.shape[1]) + limits = np.concatenate((old_lims, base_vals), axis=1) + + X = normalize(X_len, limits) + y = standardize(y_len).reshape(y_len.size, 1) + + kernel = TV_SquaredExp( + input_dim=X.shape[1], variance=1.0, lengthscale=1.0, epsilon=0.1 + ) + m = GPy.models.GPRegression(X, y, kernel) + m.optimize(messages=True) + + scores.append(m.log_likelihood()) + idx = np.argmax(scores) + length = (idx + int((min_len / 10))) * 10 + return length diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pbt.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pbt.py new file mode 100644 index 0000000000000000000000000000000000000000..0c389f76dcd0769a153102a78b7717617af3ed85 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/pbt.py @@ -0,0 +1,1182 @@ +import copy +import json +import logging +import math +import os +import random +import shutil +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union + +from ray.air.constants import TRAINING_ITERATION +from ray.train import Checkpoint +from ray.train._internal.session import _FutureTrainingResult, _TrainingResult +from ray.tune.error import TuneError +from ray.tune.experiment import Trial +from ray.tune.result import DEFAULT_METRIC +from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.tune.search import SearchGenerator +from ray.tune.search.sample import Domain, Function +from ray.tune.search.variant_generator import format_vars +from ray.tune.utils.util import SafeFallbackEncoder +from ray.util import PublicAPI +from ray.util.debug import log_once + +if 
TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + +logger = logging.getLogger(__name__) + + +class _PBTTrialState: + """Internal PBT state tracked per-trial.""" + + def __init__(self, trial: Trial): + self.orig_tag = trial.experiment_tag + self.last_score = None + self.last_checkpoint = None + self.last_perturbation_time = 0 + self.last_train_time = 0 # Used for synchronous mode. + self.last_result = None # Used for synchronous mode. + + def __repr__(self) -> str: + return str( + ( + self.last_score, + self.last_checkpoint, + self.last_train_time, + self.last_perturbation_time, + ) + ) + + +def _explore( + config: Dict, + mutations: Dict, + resample_probability: float, + perturbation_factors: Tuple[float], + custom_explore_fn: Optional[Callable], +) -> Tuple[Dict, Dict]: + """Return a perturbed config and string descriptors of the operations performed + on the original config to produce the new config. + + Args: + config: Original hyperparameter configuration. + mutations: Specification of mutations to perform as documented + in the PopulationBasedTraining scheduler. + resample_probability: Probability of allowing resampling of a + particular variable. + perturbation_factors: Scaling factors to choose between when mutating + a continuous hyperparameter. + custom_explore_fn: Custom explore function applied after built-in + config perturbations. + + Returns: + new_config: New hyperparameter configuration (after random mutations). 
+ operations: Map of hyperparams -> strings describing mutation operations + performed + """ + operations = {} + new_config = copy.deepcopy(config) + for key, distribution in mutations.items(): + if isinstance(distribution, dict): + # Handle nested hyperparameter configs by recursively perturbing them + nested_new_config, nested_ops = _explore( + config[key], + mutations[key], + resample_probability, + perturbation_factors, + custom_explore_fn=None, + ) + new_config.update({key: nested_new_config}) + operations.update({key: nested_ops}) + elif isinstance(distribution, (list, tuple)): + # Case 1: Hyperparameter resample distribution is a list/tuple + if ( + random.random() < resample_probability + or config[key] not in distribution + ): + # Resample a value from the list with `resample_probability` + new_config[key] = random.choice(distribution) + operations[key] = "resample" + else: + # Otherwise, perturb by shifting to the left or right of the list + shift = random.choice([-1, 1]) + old_idx = distribution.index(config[key]) + new_idx = old_idx + shift + new_idx = min(max(new_idx, 0), len(distribution) - 1) + new_config[key] = distribution[new_idx] + operations[key] = ( + f"shift {'left' if shift == -1 else 'right'}" + f"{' (noop)' if old_idx == new_idx else ''}" + ) + elif isinstance(distribution, (Domain, Callable)): + # Case 2: Hyperparameter resample distribution is: + # 1. a function (ex: lambda: np.random.uniform(0, 1)) + # 2. 
tune search Domain (ex: tune.uniform(0, 1)) + if random.random() < resample_probability: + # Resample a value from the function/domain with `resample_probability` + new_config[key] = ( + distribution.sample(None) + if isinstance(distribution, Domain) + else distribution() + ) + operations[key] = "resample" + else: + # Otherwise, perturb by multiplying the hyperparameter by one + # of the `perturbation_factors` + perturbation_factor = random.choice(perturbation_factors) + new_config[key] = config[key] * perturbation_factor + operations[key] = f"* {perturbation_factor}" + if isinstance(config[key], int): + # If this hyperparameter started out as an integer (ex: `batch_size`), + # convert the new value back + new_config[key] = int(new_config[key]) + else: + raise ValueError( + f"Unsupported hyperparameter distribution type: {type(distribution)}" + ) + if custom_explore_fn: + # The user can perform any additional hyperparameter exploration + # via `custom_explore_fn` + new_config = custom_explore_fn(new_config) + assert new_config is not None, "Custom explore fn failed to return new config" + return new_config, operations + + +def _make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str: + """Appends perturbed params to the trial name to show in the console.""" + + resolved_vars = {} + for k in mutations.keys(): + resolved_vars[("config", k)] = config[k] + return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars)) + + +def _fill_config( + config: Dict, attr: str, search_space: Union[dict, list, tuple, Callable, Domain] +): + """Add attr to config by sampling from search_space. + + This is a helper used to set initial hyperparameter values if the user doesn't + specify them in the Tuner `param_space`. 
+ """ + if isinstance(search_space, Callable): + config[attr] = search_space() + elif isinstance(search_space, Domain): + config[attr] = search_space.sample(None) + elif isinstance(search_space, (list, tuple)): + config[attr] = random.choice(search_space) + elif isinstance(search_space, dict): + config[attr] = {} + for k, v in search_space.items(): + _fill_config(config[attr], k, v) + + +def _filter_mutated_params_from_config( + config: Dict, hyperparam_mutations: Dict +) -> Dict: + """Filter out hyperparameters from a config so that only parameters specified + within hyperparam_mutations remain. This recursively filters nested configs. + + Example: + >>> config = { + ... "a": {"b": 2, "c": 0, "d": {"e": 0.1}}, + ... "f": {"g": 0.5}, + ... } + >>> hyperparam_mutations = { + ... "a": {"b": [1, 2], "c": [-1, 0]}, + ... } + >>> _filter_mutated_params_from_config(config, hyperparam_mutations) == { + ... "a": {"b": 2, "c": 0} + ... } + True + + Args: + config: The config dict that we want to filter. + hyperparam_mutations: A dict containing a subset of hyperparameters from + config, used to filter the config. + + Returns: + mutated_params: A copy of config containing only params specified in + hyperparam_mutations + """ + mutated_params = {} + for param_name in config: + if param_name not in hyperparam_mutations: + continue + + if isinstance(config[param_name], dict): + nested_params = _filter_mutated_params_from_config( + config[param_name], hyperparam_mutations[param_name] + ) + mutated_params[param_name] = nested_params + else: + mutated_params[param_name] = config[param_name] + return mutated_params + + +@PublicAPI +class PopulationBasedTraining(FIFOScheduler): + """Implements the Population Based Training (PBT) algorithm. + + https://www.deepmind.com/blog/population-based-training-of-neural-networks + + PBT trains a group of models (or agents) in parallel. 
Periodically, poorly + performing models clone the state of the top performers, and a random + mutation is applied to their hyperparameters in the hopes of + outperforming the current top models. + + Unlike other hyperparameter search algorithms, PBT mutates hyperparameters + during training time. This enables very fast hyperparameter discovery and + also automatically discovers good annealing schedules. + + This Tune PBT implementation considers all trials added as part of the + PBT population. If the number of trials exceeds the cluster capacity, + they will be time-multiplexed as to balance training progress across the + population. To run multiple trials, use `tune.TuneConfig(num_samples=)`. + + In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in + `pbt_global.txt` and individual policy perturbations are recorded + in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag, + target trial iteration, clone trial iteration, old config, new config] + on each perturbation step. + + Args: + time_attr: The training result attr to use for comparing time. + Note that you can pass in something non-temporal such as + `training_iteration` as a measure of progress, the only requirement + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. If None but a mode was passed, + the `ray.tune.result.DEFAULT_METRIC` will be used per default. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. + perturbation_interval: Models will be considered for + perturbation at this interval of `time_attr`. Note that + perturbation incurs checkpoint overhead, so you shouldn't set this + to be too frequent. + burn_in_period: Models will not be considered for + perturbation before this interval of `time_attr` has passed. 
This + guarantees that models are trained for at least a certain amount + of time or timesteps before being perturbed. + hyperparam_mutations: Hyperparams to mutate. The format is + as follows: for each key, either a list, function, + or a tune search space object (tune.loguniform, tune.uniform, + etc.) can be provided. A list specifies an allowed set of + categorical values. A function or tune search space object + specifies the distribution of a continuous parameter. You must + use tune.choice, tune.uniform, tune.loguniform, etc.. Arbitrary + tune.sample_from objects are not supported. + A key can also hold a dict for nested hyperparameters. + You must specify at least one of `hyperparam_mutations` or + `custom_explore_fn`. + Tune will sample the search space provided by + `hyperparam_mutations` for the initial hyperparameter values if the + corresponding hyperparameters are not present in a trial's initial `config`. + quantile_fraction: Parameters are transferred from the top + `quantile_fraction` fraction of trials to the bottom + `quantile_fraction` fraction. Needs to be between 0 and 0.5. + Setting it to 0 essentially implies doing no exploitation at all. + resample_probability: The probability of resampling from the + original distribution when applying `hyperparam_mutations`. If not + resampled, the value will be perturbed by a factor chosen from + `perturbation_factors` if continuous, or changed to an adjacent value + if discrete. + perturbation_factors: Scaling factors to choose between when mutating + a continuous hyperparameter. + custom_explore_fn: You can also specify a custom exploration + function. This function is invoked as `f(config)` after built-in + perturbations from `hyperparam_mutations` are applied, and should + return `config` updated as needed. You must specify at least one of + `hyperparam_mutations` or `custom_explore_fn`. + log_config: Whether to log the ray config of each model to + local_dir at each exploit. 
Allows config schedule to be + reconstructed. + require_attrs: Whether to require time_attr and metric to appear + in result for every iteration. If True, error will be raised + if these values are not present in trial result. + synch: If False, will use asynchronous implementation of + PBT. Trial perturbations occur every perturbation_interval for each + trial independently. If True, will use synchronous implementation + of PBT. Perturbations will occur only after all trials are + synced at the same time_attr every perturbation_interval. + Defaults to False. See Appendix A.1 here + https://arxiv.org/pdf/1711.09846.pdf. + + .. code-block:: python + + import random + from ray import tune + from ray.tune.schedulers import PopulationBasedTraining + + pbt = PopulationBasedTraining( + time_attr="training_iteration", + metric="episode_reward_mean", + mode="max", + perturbation_interval=10, # every 10 `time_attr` units + # (training_iterations in this case) + hyperparam_mutations={ + # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling + # resets it to a value sampled from the lambda function. + "factor_1": lambda: random.uniform(0.0, 20.0), + # Alternatively, use tune search space primitives. + # The search space for factor_1 is equivalent to factor_2. + "factor_2": tune.uniform(0.0, 20.0), + # Perturb factor3 by changing it to an adjacent value, e.g. + # 10 -> 1 or 10 -> 100. Resampling will choose at random. + "factor_3": [1, 10, 100, 1000, 10000], + # Using tune.choice is NOT equivalent to the above. + # factor_4 is treated as a continuous hyperparameter. 
+ "factor_4": tune.choice([1, 10, 100, 1000, 10000]), + }) + tuner = tune.Tuner( + trainable, + tune_config=tune.TuneConfig( + scheduler=pbt, + num_samples=8, + ), + ) + tuner.fit() + + """ + + def __init__( + self, + time_attr: str = "time_total_s", + metric: Optional[str] = None, + mode: Optional[str] = None, + perturbation_interval: float = 60.0, + burn_in_period: float = 0.0, + hyperparam_mutations: Dict[ + str, Union[dict, list, tuple, Callable, Domain] + ] = None, + quantile_fraction: float = 0.25, + resample_probability: float = 0.25, + perturbation_factors: Tuple[float, float] = (1.2, 0.8), + custom_explore_fn: Optional[Callable] = None, + log_config: bool = True, + require_attrs: bool = True, + synch: bool = False, + ): + hyperparam_mutations = hyperparam_mutations or {} + for value in hyperparam_mutations.values(): + if not isinstance(value, (dict, list, tuple, Domain, Callable)): + raise TypeError( + "`hyperparam_mutation` values must be either " + "a List, Tuple, Dict, a tune search space object, or " + "a callable." + ) + if isinstance(value, Function): + raise ValueError( + "arbitrary tune.sample_from objects are not " + "supported for `hyperparam_mutation` values." + "You must use other built in primitives like" + "tune.uniform, tune.loguniform, etc." + ) + + if not hyperparam_mutations and not custom_explore_fn: + raise TuneError( + "You must specify at least one of `hyperparam_mutations` " + "or `custom_explore_fn` to use PBT." + ) + + if quantile_fraction > 0.5 or quantile_fraction < 0: + raise ValueError( + "You must set `quantile_fraction` to a value between 0 and" + "0.5. Current value: '{}'".format(quantile_fraction) + ) + + if perturbation_interval <= 0: + raise ValueError( + "perturbation_interval must be a positive number greater " + "than 0. Current value: '{}'".format(perturbation_interval) + ) + + if mode: + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'." 
+ + super().__init__() + self._metric = metric + self._mode = mode + self._metric_op = None + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + self._time_attr = time_attr + self._perturbation_interval = perturbation_interval + self._burn_in_period = burn_in_period + self._hyperparam_mutations = hyperparam_mutations + self._quantile_fraction = quantile_fraction + self._resample_probability = resample_probability + self._perturbation_factors = perturbation_factors + self._trial_state = {} + self._custom_explore_fn = custom_explore_fn + self._log_config = log_config + self._require_attrs = require_attrs + self._synch = synch + self._next_perturbation_sync = max( + self._perturbation_interval, + self._burn_in_period, + ) + + # Metrics + self._num_checkpoints = 0 + self._num_perturbations = 0 + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + if self._metric and metric: + return False + if self._mode and mode: + return False + + if metric: + self._metric = metric + if mode: + self._mode = mode + + if self._mode == "max": + self._metric_op = 1.0 + elif self._mode == "min": + self._metric_op = -1.0 + + if self._metric is None and self._mode: + # If only a mode was passed, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + if tune_controller.search_alg is not None and isinstance( + tune_controller.search_alg, SearchGenerator + ): + raise ValueError( + "Search algorithms cannot be used with {} " + "schedulers. Please remove {}.".format( + self.__class__.__name__, tune_controller.search_alg + ) + ) + + if not self._metric or not self._metric_op: + raise ValueError( + "{} has been instantiated without a valid `metric` ({}) or " + "`mode` ({}) parameter. 
Either pass these parameters when " + "instantiating the scheduler, or pass them as parameters " + "to `tune.TuneConfig()`".format( + self.__class__.__name__, self._metric, self._mode + ) + ) + + checkpoint_config = trial.run_metadata.checkpoint_manager.checkpoint_config + if ( + checkpoint_config.num_to_keep + and checkpoint_config.num_to_keep <= 2 + and log_once("pbt_num_to_keep") + ): + warnings.warn( + "Using `CheckpointConfig.num_to_keep <= 2` with PBT can lead to " + "restoration problems when checkpoint are deleted too early for " + "other trials to exploit them. If this happens, increase the value " + "of `num_to_keep`." + ) + + self._trial_state[trial] = _PBTTrialState(trial) + + for attr in self._hyperparam_mutations.keys(): + if attr not in trial.config: + if log_once(attr + "-missing"): + logger.debug( + "Cannot find {} in config. Using search " + "space provided by hyperparam_mutations." + ) + # Add attr to trial's config by sampling search space from + # hyperparam_mutations. + _fill_config(trial.config, attr, self._hyperparam_mutations[attr]) + # Make sure this attribute is added to CLI output. + trial.evaluated_params[attr] = trial.config[attr] + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + if self._time_attr not in result: + time_missing_msg = ( + "Cannot find time_attr {} " + "in trial result {}. Make sure that this " + "attribute is returned in the " + "results of your Trainable.".format(self._time_attr, result) + ) + if self._require_attrs: + raise RuntimeError( + time_missing_msg + + "If this error is expected, you can change this to " + "a warning message by " + "setting PBT(require_attrs=False)" + ) + else: + if log_once("pbt-time_attr-error"): + logger.warning(time_missing_msg) + if self._metric not in result: + metric_missing_msg = ( + "Cannot find metric {} in trial result {}. 
" + "Make sure that this attribute is returned " + "in the " + "results of your Trainable.".format(self._metric, result) + ) + if self._require_attrs: + raise RuntimeError( + metric_missing_msg + "If this error is expected, " + "you can change this to a warning message by " + "setting PBT(require_attrs=False)" + ) + else: + if log_once("pbt-metric-error"): + logger.warning(metric_missing_msg) + + if self._metric not in result or self._time_attr not in result: + return TrialScheduler.CONTINUE + + time = result[self._time_attr] + state = self._trial_state[trial] + + # Continue training if burn-in period has not been reached, yet. + if time < self._burn_in_period: + logger.debug(f"Still in burn-in period: {time} < {self._burn_in_period}") + return TrialScheduler.CONTINUE + + # Continue training if perturbation interval has not been reached, yet. + time_since_perturb = time - state.last_perturbation_time + if time_since_perturb < self._perturbation_interval: + logger.debug( + f"Perturbation interval not reached: " + f"{time_since_perturb} < {self._perturbation_interval}" + ) + return TrialScheduler.CONTINUE # avoid checkpoint overhead + + logger.debug(f"Updating trial state for trial {trial} at time {time}") + self._save_trial_state(state, time, result, trial) + + if not self._synch: + state.last_perturbation_time = time + lower_quantile, upper_quantile = self._quantiles() + decision = TrialScheduler.CONTINUE + for other_trial in tune_controller.get_trials(): + if other_trial.status in [Trial.PENDING, Trial.PAUSED]: + decision = TrialScheduler.PAUSE + break + self._checkpoint_or_exploit( + trial, tune_controller, upper_quantile, lower_quantile + ) + return TrialScheduler.NOOP if trial.status == Trial.PAUSED else decision + else: + # Synchronous mode. + if any( + self._trial_state[t].last_train_time < self._next_perturbation_sync + and t != trial + for t in tune_controller.get_live_trials() + ): + logger.debug( + f"Sync: Other trials are not at perturb time, yet. 
" + f"Pausing trial {trial} to wait." + ) + else: + # All trials are synced at the same timestep. + logger.debug("Sync: All trials are at perturb time.") + lower_quantile, upper_quantile = self._quantiles() + all_trials = tune_controller.get_trials() + not_in_quantile = [] + for t in all_trials: + if t not in lower_quantile and t not in upper_quantile: + not_in_quantile.append(t) + + logger.debug( + "Trial statistics\n" + f"Upper quantile: {upper_quantile}\n" + f"Lower quantile: {lower_quantile}\n" + f"Not in quantile: {not_in_quantile}" + ) + + # Move upper quantile trials to beginning and lower quantile + # to end. This ensures that checkpointing of strong trials + # occurs before exploiting of weaker ones. + all_trials = upper_quantile + not_in_quantile + lower_quantile + for t in all_trials: + logger.debug(f"Perturbing trial {t}") + self._trial_state[t].last_perturbation_time = time + self._checkpoint_or_exploit( + t, tune_controller, upper_quantile, lower_quantile + ) + + all_train_times = [ + self._trial_state[t].last_train_time + for t in tune_controller.get_trials() + ] + max_last_train_time = max(all_train_times) + self._next_perturbation_sync = max( + self._next_perturbation_sync + self._perturbation_interval, + max_last_train_time, + ) + logger.debug(f"Next perturb at time {self._next_perturbation_sync}") + # In sync mode we should pause all trials once result comes in. + # Once a perturbation step happens for all trials, they should + # still all be paused. + # choose_trial_to_run will then pick the next trial to run out of + # the paused trials. + return ( + TrialScheduler.NOOP + if trial.status == Trial.PAUSED + else TrialScheduler.PAUSE + ) + + def _save_trial_state( + self, state: _PBTTrialState, time: int, result: Dict, trial: Trial + ): + """Saves necessary trial information when result is received. + Args: + state: The state object for the trial. + time: The current timestep of the trial. + result: The trial's result dictionary. 
+ trial: The trial object. + """ + + # This trial has reached its perturbation interval. + # Record new state in the state object. + score = self._metric_op * result[self._metric] + state.last_score = score + state.last_train_time = time + state.last_result = result + + return score + + def _checkpoint_or_exploit( + self, + trial: Trial, + tune_controller: "TuneController", + upper_quantile: List[Trial], + lower_quantile: List[Trial], + ): + """Checkpoint if in upper quantile, exploits if in lower.""" + state = self._trial_state[trial] + if trial in upper_quantile: + # The trial last result is only updated after the scheduler + # callback. So, we override with the current result. + logger.debug(f"Trial {trial} is in upper quantile. Saving checkpoint.") + if trial.status == Trial.PAUSED: + if trial.temporary_state.saving_to and isinstance( + trial.temporary_state.saving_to, _FutureTrainingResult + ): + logger.debug(f"Trial {trial} is still saving.") + state.last_checkpoint = trial.temporary_state.saving_to + else: + # Paused trial will always have an in-memory checkpoint. + logger.debug( + f"Trial {trial} is paused. Use last available " + f"checkpoint {trial.checkpoint}." + ) + state.last_checkpoint = trial.checkpoint + else: + logger.debug(f"Instructing {trial} to save.") + state.last_checkpoint = tune_controller._schedule_trial_save( + trial, result=state.last_result + ) + self._num_checkpoints += 1 + else: + state.last_checkpoint = None # not a top trial + + if trial in lower_quantile: + trial_to_clone = random.choice(upper_quantile) + assert trial is not trial_to_clone + clone_state = self._trial_state[trial_to_clone] + last_checkpoint = clone_state.last_checkpoint + + logger.debug( + f"Trial {trial} is in lower quantile. " + f"Exploiting trial {trial_to_clone}." 
+ ) + + if isinstance(last_checkpoint, _FutureTrainingResult): + training_result = last_checkpoint.resolve() + + if training_result: + clone_state.last_result = training_result.metrics + clone_state.last_checkpoint = training_result.checkpoint + last_checkpoint = clone_state.last_checkpoint + else: + logger.debug( + "PBT-scheduled checkpoint save resolved to None. Trial " + f"{trial_to_clone} didn't save any checkpoint before " + f"and can't be exploited." + ) + last_checkpoint = None + + if not last_checkpoint: + logger.info( + f"[pbt]: no checkpoint for trial {trial_to_clone}." + f" Skip exploit for Trial {trial}" + ) + return + self._exploit(tune_controller, trial, trial_to_clone) + + def _log_config_on_step( + self, + trial_state: _PBTTrialState, + new_state: _PBTTrialState, + trial: Trial, + trial_to_clone: Trial, + new_config: Dict, + ): + """Logs transition during exploit/exploit step. + + For each step, logs: [target trial tag, clone trial tag, target trial + iteration, clone trial iteration, old config, new config]. + """ + trial_name, trial_to_clone_name = (trial_state.orig_tag, new_state.orig_tag) + trial_id = trial.trial_id + trial_to_clone_id = trial_to_clone.trial_id + trial_path = os.path.join( + trial.local_experiment_path, "pbt_policy_" + trial_id + ".txt" + ) + trial_to_clone_path = os.path.join( + trial_to_clone.local_dir, "pbt_policy_" + trial_to_clone_id + ".txt" + ) + policy = [ + trial_name, + trial_to_clone_name, + trial.last_result.get(TRAINING_ITERATION, 0), + trial_to_clone.last_result.get(TRAINING_ITERATION, 0), + trial_to_clone.config, + new_config, + ] + # Log to global file. + with open( + os.path.join(trial.local_experiment_path, "pbt_global.txt"), "a+" + ) as f: + print(json.dumps(policy, cls=SafeFallbackEncoder), file=f) + # Overwrite state in target trial from trial_to_clone. + if os.path.exists(trial_to_clone_path): + shutil.copyfile(trial_to_clone_path, trial_path) + # Log new exploit in target trial log. 
+ with open(trial_path, "a+") as f: + f.write(json.dumps(policy, cls=SafeFallbackEncoder) + "\n") + + def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]: + """Gets new config for trial by exploring trial_to_clone's config. + + Args: + trial: The current trial that decided to exploit trial_to_clone. + trial_to_clone: The top-performing trial with a hyperparameter config + that the current trial will explore by perturbing. + + Returns: + new_config: New hyperparameter configuration (after random mutations). + operations: Map of hyperparams -> strings describing mutation operations + performed + """ + return _explore( + trial_to_clone.config, + self._hyperparam_mutations, + self._resample_probability, + self._perturbation_factors, + self._custom_explore_fn, + ) + + def _summarize_hyperparam_changes( + self, + old_params: Dict, + new_params: Dict, + operations: Optional[Dict] = None, + prefix: str = "", + ) -> str: + """Generates a summary of hyperparameter changes from a PBT "explore" step. + + Example: + Given the following hyperparam_mutations: + + hyperparam_mutations = { + "a": tune.uniform(0, 1), + "b": list(range(5)), + "c": { + "d": tune.uniform(2, 3), + "e": {"f": [-1, 0, 1]}, + }, + } + + This is an example summary output of the operations performed on old_params + to get new_params: + + a : 0.5 --- (* 0.8) --> 0.4 + b : 2 --- (resample) --> 4 + c : + d : 2.5 --- (* 1.2) --> 3.0 + e : + f : 0 --- (shift right) --> 1 + + The summary shows the old and new hyperparameter values, with the operation + used to perturb labeled in between. + If the operation for a certain hyperparameter is not provided, then the summary + will just contain arrows without a label. 
(ex: a : 0.5 -----> 0.4) + + Args: + old_params: Old values of hyperparameters that are perturbed to generate + the new config + new_params: The newly generated hyperparameter config from PBT exploration + operations: Map of hyperparams -> string descriptors the operations + performed to generate the values in `new_params` + prefix: Helper argument to format nested dict hyperparam configs + + Returns: + summary_str: The hyperparameter change summary to print/log. + """ + summary_str = "" + if not old_params: + return summary_str + for param_name in old_params: + old_val = old_params[param_name] + assert param_name in new_params, ( + "`old_params` and `new_params` " + f"must both contain the key: '{param_name}'\n" + f"old_params.keys() = {old_params.keys()}\n" + f"new_params.keys() = {new_params.keys()}" + ) + new_val = new_params[param_name] + summary_str += f"{prefix}{param_name} : " + if isinstance(old_val, Dict): + # Handle nested hyperparameters by recursively summarizing + summary_str += "\n" + nested_operations = operations.get(param_name, {}) + summary_str += self._summarize_hyperparam_changes( + old_val, + new_val, + operations=nested_operations, + prefix=prefix + " " * 4, + ) + else: + op = operations.get(param_name, None) + if not op: + arrow = "----->" + else: + arrow = f"--- ({op}) -->" + summary_str += f"{old_val} {arrow} {new_val}\n" + return summary_str + + def _exploit( + self, + tune_controller: "TuneController", + trial: Trial, + trial_to_clone: Trial, + ): + """Transfers perturbed state from trial_to_clone -> trial. + + If specified, also logs the updated hyperparam state. 
+ """ + trial_state = self._trial_state[trial] + new_state = self._trial_state[trial_to_clone] + class_name = self.__class__.__name__ + logger.info( + f"\n\n[{class_name}] [Exploit] Cloning trial " + "{} (score = {:4f}) into trial {} (score = {:4f})\n".format( + trial_to_clone.trial_id, + new_state.last_score, + trial.trial_id, + trial_state.last_score, + ) + ) + + new_config, operations = self._get_new_config(trial, trial_to_clone) + + # Only log mutated hyperparameters and not entire config. + old_params = _filter_mutated_params_from_config( + trial_to_clone.config, self._hyperparam_mutations + ) + new_params = _filter_mutated_params_from_config( + new_config, self._hyperparam_mutations + ) + explore_info_str = ( + f"\n\n[{class_name}] [Explore] Perturbed the hyperparameter config of trial" + f"{trial.trial_id}:\n" + ) + explore_info_str += ( + self._summarize_hyperparam_changes(old_params, new_params, operations) + or "No hyperparameters mutated." + ) + logger.info(explore_info_str) + + if self._log_config: + self._log_config_on_step( + trial_state, new_state, trial, trial_to_clone, new_config + ) + + new_tag = _make_experiment_tag( + trial_state.orig_tag, new_config, self._hyperparam_mutations + ) + if trial.status == Trial.PAUSED: + # If trial is paused we update it with a new checkpoint. + # When the trial is started again, the new checkpoint is used. + if not self._synch: + raise TuneError( + "Trials should be paused here only if in " + "synchronous mode. If you encounter this error" + " please raise an issue on Ray Github." 
+ ) + else: + tune_controller.pause_trial(trial, should_checkpoint=False) + trial.set_experiment_tag(new_tag) + # Clone hyperparameters from the `trial_to_clone` + trial.set_config(new_config) + + # Resume training from a shallow copy of `trial_to_clone`'s latest + # checkpoint + checkpoint_to_exploit: Checkpoint = copy.copy(new_state.last_checkpoint) + + trial.run_metadata.checkpoint_manager._latest_checkpoint_result = ( + _TrainingResult( + checkpoint=checkpoint_to_exploit, metrics=new_state.last_result + ) + ) + + self._num_perturbations += 1 + # Transfer over the last perturbation time as well + trial_state.last_perturbation_time = new_state.last_perturbation_time + trial_state.last_train_time = new_state.last_train_time + + def _quantiles(self) -> Tuple[List[Trial], List[Trial]]: + """Returns trials in the lower and upper `quantile` of the population. + + If there is not enough data to compute this, returns empty lists. + """ + trials = [] + for trial, state in self._trial_state.items(): + logger.debug("Trial {}, state {}".format(trial, state)) + if trial.is_finished(): + logger.debug("Trial {} is finished".format(trial)) + if state.last_score is not None and not trial.is_finished(): + trials.append(trial) + trials.sort(key=lambda t: self._trial_state[t].last_score) + + if len(trials) <= 1: + return [], [] + else: + num_trials_in_quantile = int( + math.ceil(len(trials) * self._quantile_fraction) + ) + if num_trials_in_quantile > len(trials) / 2: + num_trials_in_quantile = int(math.floor(len(trials) / 2)) + return (trials[:num_trials_in_quantile], trials[-num_trials_in_quantile:]) + + def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]: + """Ensures all trials get fair share of time (as defined by time_attr). + + This enables the PBT scheduler to support a greater number of + concurrent trials than can fit in the cluster at any given time. 
+ """ + candidates = [] + for trial in tune_controller.get_trials(): + if trial.status in [ + Trial.PENDING, + Trial.PAUSED, + ]: + if not self._synch: + candidates.append(trial) + elif ( + self._trial_state[trial].last_train_time + < self._next_perturbation_sync + ): + candidates.append(trial) + candidates.sort(key=lambda trial: self._trial_state[trial].last_train_time) + return candidates[0] if candidates else None + + # Unit test only. TODO(xwjiang): Remove test-specific APIs. + def reset_stats(self): + self._num_perturbations = 0 + self._num_checkpoints = 0 + + # Unit test only. TODO(xwjiang): Remove test-specific APIs. + def last_scores(self, trials: List[Trial]) -> List[float]: + scores = [] + for trial in trials: + state = self._trial_state[trial] + if state.last_score is not None and not trial.is_finished(): + scores.append(state.last_score) + return scores + + def debug_string(self) -> str: + return "PopulationBasedTraining: {} checkpoints, {} perturbs".format( + self._num_checkpoints, self._num_perturbations + ) + + +@PublicAPI +class PopulationBasedTrainingReplay(FIFOScheduler): + """Replays a Population Based Training run. + + Population Based Training does not return a single hyperparameter + configuration, but rather a schedule of configurations. For instance, + PBT might discover that a larger learning rate leads to good results + in the first training iterations, but that a smaller learning rate + is preferable later. + + This scheduler enables replaying these parameter schedules from + a finished PBT run. This requires that population based training has + been run with ``log_config=True``, which is the default setting. + + The scheduler will only accept and train a single trial. It will + start with the initial config of the existing trial and update the + config according to the schedule. + + Args: + policy_file: The PBT policy file. Usually this is + stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt`` + where ``xxx`` is the trial ID. 
+ + Example: + + .. code-block:: python + + # Replaying a result from ray.tune.examples.pbt_convnet_example + from ray import train, tune + + from ray.tune.examples.pbt_convnet_example import PytorchTrainable + from ray.tune.schedulers import PopulationBasedTrainingReplay + + replay = PopulationBasedTrainingReplay( + "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt") + + tuner = tune.Tuner( + PytorchTrainable, + run_config=train.RunConfig( + stop={"training_iteration": 100} + ), + tune_config=tune.TuneConfig( + scheduler=replay, + ), + ) + tuner.fit() + + + """ + + def __init__(self, policy_file: str): + policy_file = Path(policy_file).expanduser() + if not policy_file.exists(): + raise ValueError("Policy file not found: {}".format(policy_file.as_posix())) + + self.policy_file = policy_file.as_posix() + + # Find and read pbt policy file, potentially raise error + initial_config, self._policy = self._load_policy(self.policy_file) + + self.experiment_tag = "replay_{}".format(os.path.basename(self.policy_file)) + self.config = initial_config + self.current_config = self.config + + self._trial = None + self._current_step = 0 + self._num_perturbations = 0 + + self._policy_iter = iter(self._policy) + self._next_policy = next(self._policy_iter, None) + + def _load_policy(self, policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]: + raw_policy = [] + with open(policy_file, "rt") as fp: + for row in fp.readlines(): + try: + parsed_row = json.loads(row) + except json.JSONDecodeError: + raise ValueError( + "Could not read PBT policy file: {}.".format(policy_file) + ) from None + raw_policy.append(tuple(parsed_row)) + + # Loop through policy from end to start to obtain changepoints + policy = [] + last_new_tag = None + last_old_conf = None + for old_tag, new_tag, old_step, new_step, old_conf, new_conf in reversed( + raw_policy + ): + if last_new_tag and old_tag != last_new_tag: + # Tag chain ended. 
This means that previous changes were + # overwritten by the last change and should be ignored. + break + last_new_tag = new_tag + last_old_conf = old_conf + + policy.append((new_step, new_conf)) + + return last_old_conf, list(reversed(policy)) + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + if self._trial: + raise ValueError( + "More than one trial added to PBT replay run. This " + "means the same schedule will be trained multiple " + "times. Do you want to set `n_samples=1`?" + ) + self._trial = trial + if self._trial.config and self._policy: + logger.warning( + "Trial was initialized with a config, which was overwritten. " + "Did you start the PBT replay with a `config` parameter?" + ) + elif self._trial.config and not self._policy: + # Only train with initial policy + self.config = self._trial.config + elif not self._trial.config and not self._policy: + raise ValueError( + "No replay policy found and trial initialized without a " + "valid config. Either pass a `config` argument to `tune.Tuner()`" + "or consider not using PBT replay for this run." + ) + self._trial.set_config(self.config) + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + if TRAINING_ITERATION not in result: + # No time reported + return TrialScheduler.CONTINUE + + if not self._next_policy: + # No more changes in the config + return TrialScheduler.CONTINUE + + step = result[TRAINING_ITERATION] + self._current_step = step + + change_at, new_config = self._next_policy + + if step < change_at: + # Don't change the policy just yet + return TrialScheduler.CONTINUE + + logger.info( + "Population Based Training replay is now at step {}. 
" + "Configuration will be changed to {}.".format(step, new_config) + ) + + result = tune_controller._schedule_trial_save(trial, result=result) + training_result = result.resolve() + trial.run_metadata.checkpoint_manager._latest_checkpoint_result = ( + training_result + ) + + new_tag = _make_experiment_tag(self.experiment_tag, new_config, new_config) + + tune_controller.pause_trial(trial, should_checkpoint=False) + trial.set_experiment_tag(new_tag) + trial.set_config(new_config) + + self.current_config = new_config + self._num_perturbations += 1 + self._next_policy = next(self._policy_iter, None) + + return TrialScheduler.NOOP + + def debug_string(self) -> str: + return "PopulationBasedTraining replay: Step {}, perturb {}".format( + self._current_step, self._num_perturbations + ) diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/resource_changing_scheduler.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/resource_changing_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..24d437cf892f9eded05ac2b6ec511900e4e3b911 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/resource_changing_scheduler.py @@ -0,0 +1,871 @@ +import logging +import pickle +import warnings +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union + +import numpy as np + +from ray.air.execution.resources.request import _sum_bundles +from ray.tune.execution.placement_groups import PlacementGroupFactory +from ray.tune.experiment import Trial +from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="beta") +class DistributeResources: + """This class creates a basic uniform resource allocation function. 
+ + The function naively balances free resources (CPUs and GPUs) between + trials, giving them all equal priority, ensuring that all resources + are always being used. The free resources will be placed in new bundles. + The function assumes that all bundles are equal (there is no "head" + bundle). + + If for some reason a trial ends up with + more resources than there are free ones, it will adjust downwards. + It will also ensure that trial as at least as many resources as + it started with (``base_trial_resource``). + + The function returns a new ``PlacementGroupFactory`` with updated + resource requirements, or None. If the returned + ``PlacementGroupFactory`` is equal by value to the one the + trial has currently, the scheduler will skip the update process + internally (same with None). + + If you wish to implement your own resource distribution logic, + you can do so by extending this class, as it provides several + generic methods. You can also implement a function instead. + + Args: + add_bundles: If True, create new bundles from free resources. + Otherwise, spread them among base_trial_resource bundles. + increase_by: A dict with key-value + pairs representing an atomic unit of resources (name-amount) + the trial will be increased by. If not set, the trial will + increase by 1 CPU/GPU. + increase_by_times: If set to >=1 and ``increase_by`` is set, + the trial will increase by maximum of + ``increase_by_times * increase_by`` resources. If set to <1, + no upper limit is set. Ignored if ``increase_by`` is not set. + reserve_resources: A dict of + resource_name-amount pairs representing the resources + that will not be allocated to resized trials. 
+ """ + + def __init__( + self, + add_bundles: bool = False, + increase_by: Optional[Dict[str, float]] = None, + increase_by_times: int = -1, + reserve_resources: Optional[Dict[str, float]] = None, + ): + self.add_bundles = add_bundles + self.increase_by = increase_by or {} + self.increase_by_times = increase_by_times + self.reserve_resources = reserve_resources or {} + + def _validate( + self, base_trial_resource: PlacementGroupFactory, result: Dict[str, Any] + ) -> bool: + """Return False if we should keep the current resources outright.""" + if not isinstance(base_trial_resource, PlacementGroupFactory): + raise ValueError( + f"{self.__class__.__name__} only supports PlacementGroupFactories." + ) + + if not self.add_bundles and len(base_trial_resource.bundles) > 1: + raise ValueError( + "If `add_bundles` is False, the number of bundles in " + "`resources_per_trial` must be 1 " + f"(got {len(base_trial_resource.bundles)})." + ) + + # Don't bother if this is just the first iteration + if result["training_iteration"] < 1: + return False + return True + + def _get_total_available_resources( + self, tune_controller: "TuneController" + ) -> Tuple[float, float]: + """Get the number of CPUs and GPUs avaialble in total (not just free)""" + total_available_cpus = ( + tune_controller._resource_updater.get_num_cpus() + - self.reserve_resources.get("CPU", 0) + ) + total_available_gpus = ( + tune_controller._resource_updater.get_num_gpus() + - self.reserve_resources.get("GPU", 0) + ) + return total_available_cpus, total_available_gpus + + def _get_used_cpus_and_gpus(self, t: Trial) -> Tuple[float, float]: + """Check how many CPUs and GPUs a trial is using currently""" + return ( + t.placement_group_factory.required_resources.get("CPU", 0), + t.placement_group_factory.required_resources.get("GPU", 0), + ) + + def _get_resources_from_bundles( + self, bundles: List[Dict[str, float]] + ) -> Dict[str, float]: + """Get total sums of resources in bundles""" + if not bundles: + return 
{"CPU": 0, "GPU": 0} + return _sum_bundles(bundles) + + def _is_bundle_empty(self, bundle: Dict[str, float]) -> bool: + return not (bundle.get("CPU", 0) or bundle.get("GPU", 0)) + + def _add_two_bundles( + self, + bundles_a: List[Dict[str, float]], + bundles_b: List[Dict[str, float]], + increase_by: Dict[str, float], + limit_to_increase_by_times: bool, + max_increase_by_times: int = -1, + ): + """Add two bundles together. + + If ``limit_to_increase_by_times`` is True, ``self.increase_by_times`` > 0 + and ``max_increase_by_times`` > 0, ensure that the resulting number of + bundles is not above ``min(max_increase_by_times, self.increase_by_times)``. + + If ``limit_to_increase_by_times`` is True and ``self.increase_by_times`` > 0, + ensure that the resulting number of bundles is not above + `self.increase_by_times``. + """ + if limit_to_increase_by_times: + if max_increase_by_times > 0 and self.increase_by_times > 0: + max_increase_by_times = min( + max_increase_by_times, self.increase_by_times + ) + elif self.increase_by_times > 0: + max_increase_by_times = self.increase_by_times + + if self.add_bundles: + bundles = [b for b in bundles_a if not self._is_bundle_empty(b)] + [ + b for b in bundles_b if not self._is_bundle_empty(b) + ] + if max_increase_by_times > 0: + bundles = bundles[:max_increase_by_times] + else: + bundles_a = bundles_a or [{}] + bundles_b = bundles_b or [{}] + bundles = [ + { + "CPU": bundles_a[0].get("CPU", 0) + bundles_b[0].get("CPU", 0), + "GPU": bundles_a[0].get("GPU", 0) + bundles_b[0].get("GPU", 0), + } + ] + if max_increase_by_times > 0: + bundles[0]["CPU"] = min( + bundles[0]["CPU"], + increase_by.get("CPU", 0) * max_increase_by_times, + ) + bundles[0]["GPU"] = min( + bundles[0]["GPU"], + increase_by.get("GPU", 0) * max_increase_by_times, + ) + + return bundles + + def _get_multiplier( + self, + increase_by: Dict[str, float], + cpus: float = 0, + gpus: float = 0, + max_multiplier: int = -1, + ) -> int: + """Get how many times 
``increase_by`` bundles + occur in ``cpus`` and ``gpus``.""" + if increase_by.get("CPU", 0) and increase_by.get("GPU", 0): + multiplier = min( + cpus // increase_by.get("CPU", 0), + gpus // increase_by.get("GPU", 0), + ) + elif increase_by.get("GPU", 0): + multiplier = gpus // increase_by.get("GPU", 0) + else: + multiplier = cpus // increase_by.get("CPU", 0) + + if max_multiplier > 0 and multiplier > 0: + multiplier = min(max_multiplier, multiplier) + return int(multiplier) + + def _remove_bundles( + self, + bundles: List[Dict[str, float]], + increase_by: Dict[str, float], + multiplier: int, + ) -> List[Dict[str, float]]: + """Remove ``multiplier`` ``increase_by`` bundles from ``bundles``.""" + multiplier = -abs(multiplier) + if self.add_bundles: + bundles = bundles[:multiplier] + else: + bundles = deepcopy(bundles) + bundles[0]["CPU"] += increase_by.get("CPU", 0) * multiplier + bundles[0]["GPU"] += increase_by.get("GPU", 0) * multiplier + bundles[0]["CPU"] = max(bundles[0]["CPU"], 0) + bundles[0]["GPU"] = max(bundles[0]["GPU"], 0) + return bundles + + def _create_new_bundles( + self, + increase_by: Dict[str, float], + multiplier: int, + ) -> List[Dict[str, float]]: + """Create a list of new bundles containing ``increase_by`` * ``multiplier``.""" + multiplier = abs(multiplier) + + if self.add_bundles: + bundles = [increase_by] * int(multiplier) + else: + bundles = [{}] + bundles[0]["CPU"] = increase_by.get("CPU", 0) * multiplier + bundles[0]["GPU"] = increase_by.get("GPU", 0) * multiplier + + return bundles + + def _modify_bundles_with_free_resources( + self, + bundles: List[Dict[str, float]], + increase_by: Dict[str, float], + free_cpus: float, + free_gpus: float, + *, + max_multiplier: int = -1, + max_increase_by_times: int = -1, + ): + """Given free resources, increase/decrease the number of bundles in + ``bundles``.""" + multiplier = self._get_multiplier( + increase_by, free_cpus, free_gpus, max_multiplier + ) + if multiplier < 0: + bundles = 
self._remove_bundles(bundles, increase_by, multiplier) + elif multiplier > 0: + bundles_to_add = self._create_new_bundles(increase_by, multiplier) + bundles = self._add_two_bundles( + bundles, bundles_to_add, increase_by, True, max_increase_by_times + ) + return bundles + + def _get_added_bundles( + self, bundles: List[Dict[str, float]], base_bundles: List[Dict[str, float]] + ) -> List[Dict[str, float]]: + """Return the difference between bundles and base_bundles""" + if self.add_bundles: + added_bundles = bundles[len(base_bundles) :] + else: + if not bundles: + bundles = [{"CPU": 0, "GPU": 0}] + if not base_bundles: + base_bundles = [{"CPU": 0, "GPU": 0}] + added_bundles = [ + { + "CPU": bundles[0].get("CPU", 0) - base_bundles[0].get("CPU", 0), + "GPU": bundles[0].get("GPU", 0) - base_bundles[0].get("GPU", 0), + } + ] + return added_bundles + + def _are_bundles_below_limit( + self, + bundles: List[Dict[str, float]], + base_bundles: Optional[List[Dict[str, float]]] = None, + max_added_cpus: Optional[float] = None, + max_added_gpus: Optional[float] = None, + ): + if not max_added_cpus: + if self.increase_by_times > 0: + max_added_cpus = self.increase_by.get("CPU", 0) * self.increase_by_times + else: + max_added_cpus = np.inf + if not max_added_gpus: + if self.increase_by_times > 0: + max_added_gpus = self.increase_by.get("GPU", 0) * self.increase_by_times + else: + max_added_gpus = np.inf + added_resources = self._get_resources_from_bundles( + self._get_added_bundles(bundles, base_bundles) if base_bundles else bundles + ) + ret = ( + added_resources.get("CPU", -np.inf) < max_added_cpus + or added_resources.get("GPU", -np.inf) < max_added_gpus + ) + return ret + + def _get_new_added_bundles( + self, + trial: Trial, + all_trials: List[Trial], + base_bundles: List[Dict[str, float]], + increase_by: Dict[str, float], + total_available_cpus: float, + total_available_gpus: float, + used_cpus: float, + used_gpus: float, + ) -> List[Dict[str, float]]: + """Returns updated 
added bundles.""" + upper_limit_all_trials_bundles = [list() for _ in range(len(all_trials))] + + free_cpus = total_available_cpus - used_cpus + free_gpus = total_available_gpus - used_gpus + + base_resources = self._get_resources_from_bundles(base_bundles) + upper_limit_cpus_to_distribute = total_available_cpus - ( + base_resources.get("CPU", 0) * len(all_trials) + ) + upper_limit_gpus_to_distribute = total_available_gpus - ( + base_resources.get("GPU", 0) * len(all_trials) + ) + max_increase_by_times = 0 + + # First, calculate upper limits for uniform allocation + # This is done by simulating a clean slate scenario + # The loop runs until all resources are allocated or + # all trials are at their resource limits + i = 0 + trials_at_limit = set() + while ( + len(trials_at_limit) < len(all_trials) + # we have previously asserted that at least one resource has to be + # bigger than 0 + and upper_limit_cpus_to_distribute >= increase_by.get("CPU", 0) + and upper_limit_gpus_to_distribute >= increase_by.get("GPU", 0) + ): + idx = i % len(upper_limit_all_trials_bundles) + old_bundles = deepcopy(upper_limit_all_trials_bundles[idx]) + upper_limit_all_trials_bundles[ + idx + ] = self._modify_bundles_with_free_resources( + upper_limit_all_trials_bundles[idx], + increase_by, + upper_limit_cpus_to_distribute, + upper_limit_gpus_to_distribute, + max_multiplier=1, + ) + added_resources = self._get_resources_from_bundles( + self._get_added_bundles( + upper_limit_all_trials_bundles[idx], old_bundles + ) + ) + if not added_resources.get("CPU", 0) and not added_resources.get("GPU", 0): + trials_at_limit.add(idx) + elif idx == 0: + max_increase_by_times += 1 + upper_limit_cpus_to_distribute -= added_resources.get("CPU", 0) + upper_limit_gpus_to_distribute -= added_resources.get("GPU", 0) + i += 1 + + # Add new resourcs, but only up to calculated upper limits + # (max_increase_by_times) + return self._modify_bundles_with_free_resources( + self._get_added_bundles( + 
trial.placement_group_factory.bundles, base_bundles + ), + increase_by, + free_cpus, + free_gpus, + max_increase_by_times=max_increase_by_times, + ) + + def __call__( + self, + tune_controller: "TuneController", + trial: Trial, + result: Dict[str, Any], + scheduler: "ResourceChangingScheduler", + ) -> Optional[PlacementGroupFactory]: + """Run resource allocation logic. + + Returns a new ``PlacementGroupFactory`` with updated + resource requirements, or None. If the returned + ``PlacementGroupFactory`` is equal by value to the one the + trial has currently, the scheduler will skip the update process + internally (same with None). + + Args: + tune_controller: Trial runner for this Tune run. + Can be used to obtain information about other trials. + trial: The trial to allocate new resources to. + result: The latest results of trial. + scheduler: The scheduler calling + the function. + """ + # Get base trial resources as defined in + # ``tune.run(resources_per_trial)`` + base_trial_resource = scheduler.base_trial_resources + + if not self._validate(base_trial_resource=base_trial_resource, result=result): + return None + + # default values if resources_per_trial is unspecified + if base_trial_resource is None: + base_trial_resource = PlacementGroupFactory([{"CPU": 1, "GPU": 0}]) + + if self.increase_by: + increase_by = self.increase_by + assert not self._is_bundle_empty(increase_by) + assert increase_by.get("CPU", 0) >= 0 and increase_by.get("GPU", 0) >= 0 + elif self.add_bundles: + increase_by = base_trial_resource.bundles[-1] + elif base_trial_resource.bundles[0].get("GPU", 0): + increase_by = {"GPU": 1} + else: + increase_by = {"CPU": 1} + + base_bundles = deepcopy(base_trial_resource.bundles) + + ( + total_available_cpus, + total_available_gpus, + ) = self._get_total_available_resources(tune_controller=tune_controller) + + all_trials = tune_controller.get_live_trials() + + used_cpus_and_gpus = [self._get_used_cpus_and_gpus(t) for t in all_trials] + used_cpus, 
used_gpus = zip(*used_cpus_and_gpus) + used_cpus = sum(used_cpus) + used_gpus = sum(used_gpus) + + added_bundles = self._get_new_added_bundles( + trial, + all_trials, + base_bundles, + increase_by, + total_available_cpus, + total_available_gpus, + used_cpus, + used_gpus, + ) + + new_bundles = self._add_two_bundles( + base_bundles, added_bundles, increase_by, False + ) + + pgf = PlacementGroupFactory( + new_bundles, + strategy=base_trial_resource.strategy, + *base_trial_resource._args, + **base_trial_resource._kwargs, + ) + pgf._head_bundle_is_empty = base_trial_resource._head_bundle_is_empty + return pgf + + +@PublicAPI(stability="beta") +class DistributeResourcesToTopJob(DistributeResources): + """This class creates a "TopJob" resource allocation function. + + The function will assign all of the free resources to the best + performing trial (as defined by ``metric`` and ``mode``). The + previous best trials will not have their resources deallocated, + unless in the case outlined below. + + If for some reason a trial ends up with + more resources than there are free ones, it will adjust downwards. + It will also ensure that trial as at least as many resources as + it started with (``base_trial_resource``). + + The function returns a new ``PlacementGroupFactory`` with updated + resource requirements, or None. If the returned + ``PlacementGroupFactory`` is equal by value to the one the + trial has currently, the scheduler will skip the update process + internally (same with None). + + Args: + add_bundles: If True, create new bundles from free resources. + Otherwise, spread them among base_trial_resource bundles. + increase_by: A dict with key-value + pairs representing an atomic unit of resources (name-amount) + the trial will be increased by. If not set, the trial will + increase by 1 CPU/GPU. + increase_by_times: If set to >=1 and ``increase_by`` is set, + the trial will increase by maximum of + ``increase_by_times * increase_by`` resources. 
If set to <1, + no upper limit is set. Ignored if ``increase_by`` is not set. + reserve_resources: A dict of + resource_name-amount pairs representing the resources + that will not be allocated to resized trials. + is that the attribute should increase monotonically. + metric: The training result objective value attribute. Stopping + procedures will use this attribute. If None, will use the metric + of the scheduler. + mode: One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. If None, will use the metric + of the scheduler. + + """ + + def __init__( + self, + add_bundles: bool = False, + increase_by: Optional[Dict[str, float]] = None, + increase_by_times: int = -1, + reserve_resources: Optional[Dict[str, float]] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + ): + super().__init__(add_bundles, increase_by, increase_by_times, reserve_resources) + self.metric = metric + self.mode = mode + + @property + def _metric_op(self) -> float: + if self.mode not in ("min", "max"): + raise ValueError("The mode parameter can only be either min or max.") + if self.mode == "max": + return 1.0 + return -1.0 + + def _get_new_added_bundles( + self, + trial: Trial, + all_trials: List[Trial], + base_bundles: List[Dict[str, float]], + increase_by: Dict[str, float], + total_available_cpus: float, + total_available_gpus: float, + used_cpus: float, + used_gpus: float, + ) -> List[Dict[str, float]]: + if self.metric is None: + raise ValueError( + "The metric parameter cannot be None. The parameter can be set in " + "either `DistributeResourcesToTopJob`, the base scheduler or in " + "`tune.TuneConfig()` (highest to lowest priority)." 
+ ) + + free_cpus = total_available_cpus - used_cpus + free_gpus = total_available_gpus - used_gpus + + sorted_trials = sorted( + all_trials, + key=lambda t: -self._metric_op * t.last_result.get(self.metric, np.inf), + ) + + added_bundles = self._get_added_bundles( + trial.placement_group_factory.bundles, base_bundles + ) + + best_trial = next( + ( + t + for t in sorted_trials + if self._are_bundles_below_limit( + t.placement_group_factory.bundles, base_bundles + ) + ), + sorted_trials[0], + ) + + if ( + trial.trial_id != best_trial.trial_id + # Only reduce resources here + and self._get_multiplier(increase_by, free_cpus, free_gpus) >= 0 + ): + return added_bundles + + return self._modify_bundles_with_free_resources( + added_bundles, + increase_by, + free_cpus, + free_gpus, + ) + + +_DistributeResourcesDefault = DistributeResources(add_bundles=False) +_DistributeResourcesDistributedDefault = DistributeResources(add_bundles=True) + + +@PublicAPI(stability="beta") +class ResourceChangingScheduler(TrialScheduler): + """A utility scheduler to dynamically change resources of live trials. + + .. versionadded:: 1.5.0 + + .. note:: + Experimental. API may change in future releases. + + The ResourceChangingScheduler works by wrapping around any other + scheduler and adjusting the resource requirements of live trials + in response to the decisions of the wrapped scheduler + through a user-specified ``resources_allocation_function``. + + An example of such a function can be found in + :doc:`/tune/examples/includes/xgboost_dynamic_resources_example`. + + If the functional API is used, the current trial resources can be obtained + by calling `tune.get_trial_resources()` inside the training function. + The function should be able to + :ref:`load and save checkpoints ` + (the latter preferably every iteration). + + If the Trainable (class) API is used, you can obtain the current trial + resources through the ``Trainable.trial_resources`` property. 
+ + Cannot be used if ``reuse_actors`` is True in ``tune.TuneConfig()``. A ValueError + will be raised in that case. + + Args: + base_scheduler: The scheduler to provide decisions + about trials. If None, a default FIFOScheduler will be used. + resources_allocation_function: The callable used to change + live trial resource requiements during tuning. This callable + will be called on each trial as it finishes one step of training. + The callable must take four arguments: ``TrialRunner``, current + ``Trial``, current result :class:`dict` and the + ``ResourceChangingScheduler`` calling it. The callable must + return a ``PlacementGroupFactory`` + or None (signifying no need for an update). If + ``resources_allocation_function`` is None, no resource + requirements will be changed at any time. + By default, :class:`DistributeResources` will be used, + distributing available CPUs and GPUs over all running trials + in a robust way, without any prioritization. + + Warning: + If the ``resources_allocation_function`` sets trial resource + requirements to values bigger than possible, the trial will + not run. Ensure that your callable accounts for that possibility + by setting upper limits. Consult :class:`DistributeResources` + to see how that may be done. + + Example: + .. code-block:: python + + base_scheduler = ASHAScheduler(max_t=16) + def my_resources_allocation_function( + tune_controller: "TuneController", + trial: Trial, + result: Dict[str, Any], + scheduler: "ResourceChangingScheduler" + ) -> Optional[Union[PlacementGroupFactory, Resource]]: + # logic here + # usage of PlacementGroupFactory is strongly preferred + return PlacementGroupFactory(...) + scheduler = ResourceChangingScheduler( + base_scheduler, + my_resources_allocation_function + ) + + See :doc:`/tune/examples/includes/xgboost_dynamic_resources_example` for a + more detailed example. 
+ """ + + def __init__( + self, + base_scheduler: Optional[TrialScheduler] = None, + resources_allocation_function: Optional[ + Callable[ + [ + "TuneController", + Trial, + Dict[str, Any], + "ResourceChangingScheduler", + ], + Optional[PlacementGroupFactory], + ] + ] = _DistributeResourcesDefault, + ) -> None: + super().__init__() + if resources_allocation_function is None: + warnings.warn( + "`resources_allocation_function` is None. No resource " + "requirements will be changed at any time. Pass a " + "correctly defined function to enable functionality." + ) + self._resources_allocation_function = resources_allocation_function + self._base_scheduler = base_scheduler or FIFOScheduler() + self._base_trial_resources: Optional[PlacementGroupFactory] = None + self._trials_to_reallocate: Dict[ + Trial, Optional[Union[dict, PlacementGroupFactory]] + ] = {} + self._reallocated_trial_ids: Set[str] = set() + self._metric = None + self._mode = None + + @property + def metric(self): + return self._base_scheduler._metric + + @property + def base_trial_resources(self) -> Optional[PlacementGroupFactory]: + return self._base_trial_resources + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + self._metric = metric + self._mode = mode + return self._base_scheduler.set_search_properties(metric, mode, **spec) + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial, **kwargs): + # use the first trial resources as the base + if self._base_trial_resources is None: + self._base_trial_resources = trial.placement_group_factory + # Raise error if the resources of a newly added trial don't match + # base resources, but allow trials that have already had their + # resources changed by ResourceChangingScheduler + # (those can be added again during loading from a checkpoint) + elif trial.trial_id not in self._reallocated_trial_ids: + trial_resources = trial.placement_group_factory + if trial_resources != 
self._base_trial_resources: + raise RuntimeError( + "ResourceChangingScheduler doesn't support trials with " + "varying base resources. First trial had " + f"{self._base_trial_resources}, trial {trial} has " + f"{trial_resources}." + ) + + return self._base_scheduler.on_trial_add(tune_controller, trial, **kwargs) + + def on_trial_error(self, tune_controller: "TuneController", trial: Trial, **kwargs): + return self._base_scheduler.on_trial_error(tune_controller, trial, **kwargs) + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + base_scheduler_decision = self._base_scheduler.on_trial_result( + tune_controller, trial, result + ) + if base_scheduler_decision == TrialScheduler.CONTINUE: + new_resources = self.reallocate_trial_resources_if_needed( + tune_controller, trial, result + ) + if new_resources: + self._trials_to_reallocate[trial] = new_resources + return TrialScheduler.PAUSE + return base_scheduler_decision + + def on_trial_complete( + self, + tune_controller: "TuneController", + trial: Trial, + result: Dict, + **kwargs, + ): + return self._base_scheduler.on_trial_complete( + tune_controller, trial, result, **kwargs + ) + + def on_trial_remove( + self, tune_controller: "TuneController", trial: Trial, **kwargs + ): + return self._base_scheduler.on_trial_remove(tune_controller, trial, **kwargs) + + def choose_trial_to_run( + self, tune_controller: "TuneController", **kwargs + ) -> Optional[Trial]: + if getattr(tune_controller, "_reuse_actors", False): + raise ValueError( + "ResourceChangingScheduler cannot be used with " + "`reuse_actors=True`. FIX THIS by setting " + "`reuse_actors=False` in `tune.TuneConfig()`." 
+ ) + + any_resources_changed = False + + new_trials_to_reallocate = {} + for trial, new_resources in self._trials_to_reallocate.items(): + if trial.status == Trial.RUNNING: + new_trials_to_reallocate[trial] = new_resources + logger.debug(f"{trial} is still running, skipping for now") + continue + any_resources_changed = any_resources_changed or self.set_trial_resources( + trial, new_resources + ) + self._trials_to_reallocate = new_trials_to_reallocate + + trial = self._base_scheduler.choose_trial_to_run(tune_controller, **kwargs) + return trial + + def debug_string(self) -> str: + return "(ResourceChangingScheduler) " f"{self._base_scheduler.debug_string()}" + + def save(self, checkpoint_path: str): + save_object = self.__dict__ + with open(checkpoint_path, "wb") as outputFile: + pickle.dump(save_object, outputFile) + + def restore(self, checkpoint_path: str): + with open(checkpoint_path, "rb") as inputFile: + save_object = pickle.load(inputFile) + self.__dict__.update(save_object) + + def set_trial_resources( + self, trial: Trial, new_resources: Union[Dict, PlacementGroupFactory] + ) -> bool: + """Returns True if new_resources were set.""" + if new_resources: + logger.info( + f"Setting trial {trial} resource to {new_resources} " + f"with {new_resources._bundles}" + ) + trial.placement_group_factory = None + trial.update_resources(new_resources) + # keep track of all trials which had their resources changed + self._reallocated_trial_ids.add(trial.trial_id) + return True + return False + + def _are_resources_the_same( + self, + trial: Trial, + new_resources, + ) -> bool: + """Returns True if trial's resources are value equal to new_resources. + + Only checks for PlacementGroupFactories at this moment. 
+ """ + if ( + isinstance(new_resources, PlacementGroupFactory) + and trial.placement_group_factory == new_resources + ): + logger.debug( + f"{trial} PGF " + f"{trial.placement_group_factory.required_resources}" + f" and {new_resources.required_resources}" + f" are the same, skipping" + ) + return True + else: + return False + + def reallocate_trial_resources_if_needed( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> Optional[Union[dict, PlacementGroupFactory]]: + """Calls user defined resources_allocation_function. If the returned + resources are not none and not the same as currently present, returns + them. Otherwise, returns None.""" + if self._resources_allocation_function is None: + return None + + if not getattr(self._resources_allocation_function, "metric", None): + self._resources_allocation_function.metric = getattr( + self._base_scheduler, "_metric", self._metric + ) + if not getattr(self._resources_allocation_function, "mode", None): + self._resources_allocation_function.mode = getattr( + self._base_scheduler, "_mode", self._mode + ) + + new_resources = self._resources_allocation_function( + tune_controller, trial, result, self + ) + + # if we can check if the new resources are the same, + # we do that here and skip resource allocation + if new_resources and not self._are_resources_the_same(trial, new_resources): + return new_resources + return None diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/trial_scheduler.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/trial_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d99db5d24704c5cc891b0813837d39c32a2321ec --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/trial_scheduler.py @@ -0,0 +1,170 @@ +from typing import TYPE_CHECKING, Dict, Optional + +from ray.air._internal.usage import tag_scheduler +from ray.tune.experiment import Trial +from ray.tune.result import DEFAULT_METRIC +from 
ray.util.annotations import DeveloperAPI, PublicAPI + +if TYPE_CHECKING: + from ray.tune.execution.tune_controller import TuneController + + +@DeveloperAPI +class TrialScheduler: + """Interface for implementing a Trial Scheduler class. + + Note to Tune developers: If a new scheduler is added, please update + `air/_internal/usage.py`. + """ + + CONTINUE = "CONTINUE" #: Status for continuing trial execution + PAUSE = "PAUSE" #: Status for pausing trial execution + STOP = "STOP" #: Status for stopping trial execution + # Caution: Temporary and anti-pattern! This means Scheduler calls + # into Executor directly without going through TrialRunner. + # TODO(xwjiang): Deprecate this after we control the interaction + # between schedulers and executor. + NOOP = "NOOP" + + _metric = None + + _supports_buffered_results = True + + def __init__(self): + tag_scheduler(self) + + @property + def metric(self): + return self._metric + + @property + def supports_buffered_results(self): + return self._supports_buffered_results + + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], **spec + ) -> bool: + """Pass search properties to scheduler. + + This method acts as an alternative to instantiating schedulers + that react to metrics with their own `metric` and `mode` parameters. + + Args: + metric: Metric to optimize + mode: One of ["min", "max"]. Direction to optimize. + **spec: Any kwargs for forward compatiblity. + Info like Experiment.PUBLIC_KEYS is provided through here. + """ + if self._metric and metric: + return False + if metric: + self._metric = metric + + if self._metric is None: + # Per default, use anonymous metric + self._metric = DEFAULT_METRIC + + return True + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + """Called when a new trial is added to the trial runner.""" + + raise NotImplementedError + + def on_trial_error(self, tune_controller: "TuneController", trial: Trial): + """Notification for the error of trial. 
+ + This will only be called when the trial is in the RUNNING state.""" + + raise NotImplementedError + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + """Called on each intermediate result returned by a trial. + + At this point, the trial scheduler can make a decision by returning + one of CONTINUE, PAUSE, and STOP. This will only be called when the + trial is in the RUNNING state.""" + + raise NotImplementedError + + def on_trial_complete( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + """Notification for the completion of trial. + + This will only be called when the trial is in the RUNNING state and + either completes naturally or by manual termination.""" + + raise NotImplementedError + + def on_trial_remove(self, tune_controller: "TuneController", trial: Trial): + """Called to remove trial. + + This is called when the trial is in PAUSED or PENDING state. Otherwise, + call `on_trial_complete`.""" + + raise NotImplementedError + + def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]: + """Called to choose a new trial to run. + + This should return one of the trials in tune_controller that is in + the PENDING or PAUSED state. This function must be idempotent. 
+ + If no trial is ready, return None.""" + + raise NotImplementedError + + def debug_string(self) -> str: + """Returns a human readable message for printing to the console.""" + + raise NotImplementedError + + def save(self, checkpoint_path: str): + """Save trial scheduler to a checkpoint""" + raise NotImplementedError + + def restore(self, checkpoint_path: str): + """Restore trial scheduler from checkpoint.""" + raise NotImplementedError + + +@PublicAPI +class FIFOScheduler(TrialScheduler): + """Simple scheduler that just runs trials in submission order.""" + + def __init__(self): + super().__init__() + + def on_trial_add(self, tune_controller: "TuneController", trial: Trial): + pass + + def on_trial_error(self, tune_controller: "TuneController", trial: Trial): + pass + + def on_trial_result( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ) -> str: + return TrialScheduler.CONTINUE + + def on_trial_complete( + self, tune_controller: "TuneController", trial: Trial, result: Dict + ): + pass + + def on_trial_remove(self, tune_controller: "TuneController", trial: Trial): + pass + + def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]: + for trial in tune_controller.get_trials(): + if trial.status == Trial.PENDING: + return trial + for trial in tune_controller.get_trials(): + if trial.status == Trial.PAUSED: + return trial + return None + + def debug_string(self) -> str: + return "Using FIFO scheduling algorithm." 
diff --git a/.venv/lib/python3.11/site-packages/ray/tune/schedulers/util.py b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/util.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0e012a8367e89397a6258632e5f92e6d41d23d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/tune/schedulers/util.py @@ -0,0 +1,27 @@ +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +def _set_search_properties_backwards_compatible( + set_search_properties_func, metric: Optional[str], mode: Optional[str], **spec +) -> bool: + """Wraps around set_search_properties() so that it is backward compatible. + + Also outputs a warning to encourage custom schedulers to be updated. + """ + try: + return set_search_properties_func(metric, mode, **spec) + except TypeError as e: + if str(e).startswith( + "set_search_properties() got an unexpected keyword argument" + ): + logger.warning( + "Please update custom Scheduler to take in function signature " + "as ``def set_search_properties(metric, mode, " + "**spec) -> bool``." + ) + return set_search_properties_func(metric, mode) + else: + raise e