Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/commands.py +306 -0
- .venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py +101 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py +285 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py +191 -0
- .venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py +185 -0
- .venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/experimental/output.py +1043 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py +32 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/aim.py +187 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/comet.py +3 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/csv.py +135 -0
- .venv/lib/python3.11/site-packages/ray/tune/logger/json.py +128 -0
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.31 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/callback.cpython-311.pyc
ADDED
|
Binary file (23.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/constants.cpython-311.pyc
ADDED
|
Binary file (988 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/context.cpython-311.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/error.cpython-311.pyc
ADDED
|
Binary file (2.54 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/progress_reporter.cpython-311.pyc
ADDED
|
Binary file (73.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/resources.cpython-311.pyc
ADDED
|
Binary file (3.62 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result.cpython-311.pyc
ADDED
|
Binary file (1.94 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/result_grid.cpython-311.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/syncer.cpython-311.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune.cpython-311.pyc
ADDED
|
Binary file (48 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tune_config.cpython-311.pyc
ADDED
|
Binary file (6.13 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/__pycache__/tuner.cpython-311.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/cli/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (185 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/commands.cpython-311.pyc
ADDED
|
Binary file (15.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/cli/__pycache__/scripts.cpython-311.pyc
ADDED
|
Binary file (4.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/cli/commands.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import operator
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
import click
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from pandas.api.types import is_numeric_dtype, is_string_dtype
|
| 13 |
+
|
| 14 |
+
from ray._private.thirdparty.tabulate.tabulate import tabulate
|
| 15 |
+
from ray.air.constants import EXPR_RESULT_FILE
|
| 16 |
+
from ray.tune import TuneError
|
| 17 |
+
from ray.tune.analysis import ExperimentAnalysis
|
| 18 |
+
from ray.tune.result import (
|
| 19 |
+
CONFIG_PREFIX,
|
| 20 |
+
DEFAULT_EXPERIMENT_INFO_KEYS,
|
| 21 |
+
DEFAULT_RESULT_KEYS,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
EDITOR = os.getenv("EDITOR", "vim")
|
| 27 |
+
|
| 28 |
+
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S (%A)"
|
| 29 |
+
|
| 30 |
+
DEFAULT_CLI_KEYS = DEFAULT_EXPERIMENT_INFO_KEYS + DEFAULT_RESULT_KEYS
|
| 31 |
+
|
| 32 |
+
DEFAULT_PROJECT_INFO_KEYS = (
|
| 33 |
+
"name",
|
| 34 |
+
"total_trials",
|
| 35 |
+
"last_updated",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
TERM_WIDTH, TERM_HEIGHT = shutil.get_terminal_size(fallback=(100, 100))
|
| 39 |
+
|
| 40 |
+
OPERATORS = {
|
| 41 |
+
"<": operator.lt,
|
| 42 |
+
"<=": operator.le,
|
| 43 |
+
"==": operator.eq,
|
| 44 |
+
"!=": operator.ne,
|
| 45 |
+
">=": operator.ge,
|
| 46 |
+
">": operator.gt,
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _check_tabulate():
    """Raise ``ImportError`` if the vendored tabulate module is unavailable."""
    if tabulate is not None:
        return
    raise ImportError("Tabulate not installed. Please run `pip install tabulate`.")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def print_format_output(dataframe):
    """Print the given dataframe as a table sized to fit the terminal.

    Fully-null columns are skipped; once the rendered table becomes wider
    than the terminal, the current column and every later column are
    dropped.

    Returns:
        table: Final rendered table string.
        dropped_cols: Columns dropped due to terminal size.
        empty_cols: Empty columns (dropped on default).
    """
    visible = pd.DataFrame()
    dropped_cols = []
    empty_cols = []
    # Columns are considered in the dataframe's own order, so callers
    # control display priority by ordering the input columns.
    for position, column in enumerate(dataframe):
        if dataframe[column].isnull().all():
            # Fully empty columns are never shown.
            empty_cols.append(column)
            continue

        visible[column] = dataframe[column]
        candidate = tabulate(visible, headers="keys", tablefmt="psql")
        # The index of the first newline is the width of the rendered table.
        if str(candidate).index("\n") > TERM_WIDTH:
            # Drop this column and everything after it.
            visible.drop(column, axis=1, inplace=True)
            dropped_cols.extend(list(dataframe.columns)[position:])
            break

    table = tabulate(visible, headers="keys", tablefmt="psql", showindex="never")
    print(table)

    if dropped_cols:
        click.secho("Dropped columns: {}".format(dropped_cols), fg="yellow")
        click.secho("Please increase your terminal size to view remaining columns.")
    if empty_cols:
        click.secho("Empty columns: {}".format(empty_cols), fg="yellow")

    return table, dropped_cols, empty_cols
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def list_trials(
    experiment_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists trials in the directory subtree starting at the given path.

    Args:
        experiment_path: Directory where trials are located.
            Like Experiment.local_dir/Experiment.name/experiment*.json.
        sort: Keys to sort by.
        output: Name of file where output is saved.
        filter_op: Filter operation in the format
            "<column> <operator> <value>".
        info_keys: Keys that are displayed.
        limit: Number of rows to display.
        desc: Sort ascending vs. descending.

    Raises:
        click.ClickException: If no trial data is found, a requested key or
            sort column is missing, the filter column dtype is unsupported,
            or ``output`` has an unsupported file extension.
    """
    _check_tabulate()

    try:
        checkpoints_df = ExperimentAnalysis(experiment_path).dataframe()  # last result
    except TuneError as e:
        raise click.ClickException("No trial data found!") from e

    config_prefix = CONFIG_PREFIX + "/"

    def key_filter(k):
        # Default view: the standard CLI columns plus any config/* columns.
        return k in DEFAULT_CLI_KEYS or k.startswith(config_prefix)

    col_keys = [k for k in checkpoints_df.columns if key_filter(k)]

    if info_keys:
        # Explicit key selection overrides the default column filter,
        # but every requested key must exist in the dataframe.
        for k in info_keys:
            if k not in checkpoints_df.columns:
                raise click.ClickException(
                    "Provided key invalid: {}. "
                    "Available keys: {}.".format(k, checkpoints_df.columns)
                )
        col_keys = [k for k in checkpoints_df.columns if k in info_keys]

    if not col_keys:
        raise click.ClickException("No columns to output.")

    checkpoints_df = checkpoints_df[col_keys]
    if "last_update_time" in checkpoints_df:
        # Treat +/-inf timestamps as missing so dropna() removes them.
        with pd.option_context("mode.use_inf_as_null", True):
            datetime_series = checkpoints_df["last_update_time"].dropna()

        # Render epoch timestamps as human-readable local times.
        datetime_series = datetime_series.apply(
            lambda t: datetime.fromtimestamp(t).strftime(TIMESTAMP_FORMAT)
        )
        checkpoints_df["last_update_time"] = datetime_series

    if "logdir" in checkpoints_df:
        # logdir often too long to view in table, so drop experiment_path
        # NOTE(review): pandas str.replace historically treats the pattern as
        # a regex; special characters in experiment_path may over-match.
        checkpoints_df["logdir"] = checkpoints_df["logdir"].str.replace(
            experiment_path, ""
        )

    if filter_op:
        # Parse "<column> <operator> <value>" (single-space separated).
        col, op, val = filter_op.split(" ")
        col_type = checkpoints_df[col].dtype
        if is_numeric_dtype(col_type):
            val = float(val)
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(val, col_type)
            )
        op = OPERATORS[op]
        filtered_index = op(checkpoints_df[col], val)
        checkpoints_df = checkpoints_df[filtered_index]

    if sort:
        # Validate every sort key before sorting.
        for key in sort:
            if key not in checkpoints_df:
                raise click.ClickException(
                    "{} not in: {}".format(key, list(checkpoints_df))
                )
        ascending = not desc
        checkpoints_df = checkpoints_df.sort_values(by=sort, ascending=ascending)

    if limit:
        checkpoints_df = checkpoints_df[:limit]

    print_format_output(checkpoints_df)

    if output:
        # Persist the filtered/sorted/truncated table if requested.
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            checkpoints_df.to_pickle(output)
        elif file_extension == ".csv":
            checkpoints_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def list_experiments(
    project_path: str,
    sort: Optional[List[str]] = None,
    output: Optional[str] = None,
    filter_op: Optional[str] = None,
    info_keys: Optional[List[str]] = None,
    limit: Optional[int] = None,
    desc: bool = False,
):
    """Lists experiments in the directory subtree.

    Args:
        project_path: Directory where experiments are located.
            Corresponds to Experiment.local_dir.
        sort: Keys to sort by.
        output: Name of file where output is saved.
        filter_op: Filter operation in the format
            "<column> <operator> <value>".
        info_keys: Keys that are displayed.
        limit: Number of rows to display.
        desc: Sort ascending vs. descending.

    Raises:
        click.ClickException: If no experiments are found, none of the
            requested keys exist, a sort column is missing, the filter column
            dtype is unsupported, or ``output`` has an unsupported extension.
    """
    _check_tabulate()
    # Only the immediate children of project_path are experiment folders.
    base, experiment_folders, _ = next(os.walk(project_path))

    experiment_data_collection = []

    for experiment_dir in experiment_folders:
        # A trial is counted for every directory containing a result file.
        num_trials = sum(
            EXPR_RESULT_FILE in files
            for _, _, files in os.walk(os.path.join(base, experiment_dir))
        )

        experiment_data = {"name": experiment_dir, "total_trials": num_trials}
        experiment_data_collection.append(experiment_data)

    if not experiment_data_collection:
        raise click.ClickException("No experiments found!")

    info_df = pd.DataFrame(experiment_data_collection)
    if not info_keys:
        info_keys = DEFAULT_PROJECT_INFO_KEYS
    col_keys = [k for k in list(info_keys) if k in info_df]
    if not col_keys:
        raise click.ClickException(
            "None of keys {} in experiment data!".format(info_keys)
        )
    info_df = info_df[col_keys]

    if filter_op:
        # Parse "<column> <operator> <value>" (single-space separated).
        col, op, val = filter_op.split(" ")
        col_type = info_df[col].dtype
        if is_numeric_dtype(col_type):
            val = float(val)
        elif is_string_dtype(col_type):
            val = str(val)
        # TODO(Andrew): add support for datetime and boolean
        else:
            raise click.ClickException(
                "Unsupported dtype for {}: {}".format(val, col_type)
            )
        op = OPERATORS[op]
        filtered_index = op(info_df[col], val)
        info_df = info_df[filtered_index]

    if sort:
        for key in sort:
            if key not in info_df:
                raise click.ClickException("{} not in: {}".format(key, list(info_df)))
        ascending = not desc
        info_df = info_df.sort_values(by=sort, ascending=ascending)

    if limit:
        info_df = info_df[:limit]

    print_format_output(info_df)

    if output:
        # Persist the filtered/sorted/truncated table if requested.
        file_extension = os.path.splitext(output)[1].lower()
        if file_extension in (".p", ".pkl", ".pickle"):
            info_df.to_pickle(output)
        elif file_extension == ".csv":
            info_df.to_csv(output, index=False)
        else:
            raise click.ClickException("Unsupported filetype: {}".format(output))
        click.secho("Output saved at {}".format(output), fg="green")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def add_note(path: str, filename: str = "note.txt"):
    """Opens a txt file at the given path where user can add and save notes.

    Args:
        path: Directory where note will be saved.
        filename: Name of note. Defaults to "note.txt"

    Raises:
        click.ClickException: If ``path`` is not an existing directory.
    """
    path = Path(path).expanduser()
    # Validate explicitly instead of using ``assert``, which is stripped
    # when Python runs with ``-O`` and would silently skip this check.
    if not path.is_dir():
        raise click.ClickException("{} is not a valid directory.".format(path))

    filepath = path / filename

    try:
        # Open the note in the user's $EDITOR (defaults to vim).
        subprocess.call([EDITOR, filepath.as_posix()])
    except Exception as exc:
        # Best-effort: report the failure but still show the note location.
        click.secho("Editing note failed: {}".format(str(exc)), fg="red")
    if filepath.exists():
        print("Note updated at:", filepath.as_posix())
    else:
        print("Note created at:", filepath.as_posix())
|
.venv/lib/python3.11/site-packages/ray/tune/cli/scripts.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import click
|
| 2 |
+
|
| 3 |
+
import ray.tune.cli.commands as commands
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@click.group()
def cli():
    # Root group for the tune CLI; subcommands are registered below via
    # cli.add_command. No docstring on purpose: click would surface it as
    # the group's help text.
    pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@cli.command()
@click.argument("experiment_path", required=True, type=str)
@click.option("--sort", default=None, type=str, help="Select which column to sort on.")
@click.option(
    "--output",
    "-o",
    default=None,
    type=str,
    help="Select file to output information to.",
)
@click.option(
    "--filter",
    "filter_op",
    default=None,
    type=str,
    help="Select filter in the format '<column> <operator> <value>'.",
)
@click.option(
    "--columns", default=None, type=str, help="Select columns to be displayed."
)
@click.option(
    "--limit", default=None, type=int, help="Select number of rows to display."
)
@click.option("--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_trials(experiment_path, sort, output, filter_op, columns, limit, desc):
    """Lists trials in the directory subtree starting at the given path."""
    # Comma-separated CLI strings become lists before delegating to the
    # implementation in ray.tune.cli.commands.
    sort_keys = sort.split(",") if sort else sort
    column_names = columns.split(",") if columns else columns
    commands.list_trials(
        experiment_path, sort_keys, output, filter_op, column_names, limit, desc
    )
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@cli.command()
@click.argument("project_path", required=True, type=str)
@click.option("--sort", default=None, type=str, help="Select which column to sort on.")
@click.option(
    "--output",
    "-o",
    default=None,
    type=str,
    help="Select file to output information to.",
)
@click.option(
    "--filter",
    "filter_op",
    default=None,
    type=str,
    help="Select filter in the format '<column> <operator> <value>'.",
)
@click.option(
    "--columns", default=None, type=str, help="Select columns to be displayed."
)
@click.option(
    "--limit", default=None, type=int, help="Select number of rows to display."
)
@click.option("--desc", default=False, type=bool, help="Sort ascending vs. descending.")
def list_experiments(project_path, sort, output, filter_op, columns, limit, desc):
    """Lists experiments in the directory subtree."""
    # Comma-separated CLI strings become lists before delegating to the
    # implementation in ray.tune.cli.commands.
    sort_keys = sort.split(",") if sort else sort
    column_names = columns.split(",") if columns else columns
    commands.list_experiments(
        project_path, sort_keys, output, filter_op, column_names, limit, desc
    )
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@cli.command()
@click.argument("path", required=True, type=str)
@click.option(
    "--filename", default="note.txt", type=str, help="Specify filename for note."
)
def add_note(path, filename):
    """Adds user notes as a text file at the given path."""
    # Thin CLI wrapper; all logic lives in ray.tune.cli.commands.add_note.
    commands.add_note(path, filename)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Register each subcommand under both a short alias and a long name.
cli.add_command(list_trials, name="ls")
cli.add_command(list_trials, name="list-trials")
cli.add_command(list_experiments, name="lsx")
cli.add_command(list_experiments, name="list-experiments")
cli.add_command(add_note, name="add-note")


def main():
    """Console-script entry point: dispatch to the click command group."""
    return cli()


if __name__ == "__main__":
    main()
|
.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/common.cpython-311.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_func.cpython-311.pyc
ADDED
|
Binary file (8.83 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/__pycache__/pbt_dcgan_mnist_trainable.cpython-311.pyc
ADDED
|
Binary file (9.53 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/common.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib.animation as animation
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.parallel
|
| 9 |
+
import torch.utils.data
|
| 10 |
+
import torchvision.datasets as dset
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
import torchvision.utils as vutils
|
| 13 |
+
from scipy.stats import entropy
|
| 14 |
+
from torch.autograd import Variable
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
|
| 19 |
+
# Training parameters
workers = 2  # DataLoader worker processes
batch_size = 64
image_size = 32  # MNIST images are resized to image_size x image_size

# Number of channels in the training images. For color images this is 3
nc = 1

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 32

# Size of feature maps in discriminator
ndf = 32

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# iterations of actual training in each Trainable _train
train_iterations_per_step = 5

# Pre-trained MNIST classifier checkpoint used for the inception score.
MODEL_PATH = os.path.expanduser("~/.ray/models/mnist_cnn.pt")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_data_loader(data_dir="~/data"):
    """Return a shuffled DataLoader over MNIST resized to ``image_size``.

    The dataset is downloaded into ``data_dir`` if not already present.
    """
    preprocessing = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    dataset = dset.MNIST(root=data_dir, download=True, transform=preprocessing)

    # Create the dataloader
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=workers
    )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# __GANmodel_begin__
|
| 67 |
+
# custom weights initialization called on netG and netD
|
| 68 |
+
def weights_init(m):
    """DCGAN weight initializer, applied via ``module.apply(weights_init)``.

    Conv layers get weights ~ N(0, 0.02); batch-norm layers get weights
    ~ N(1, 0.02) and zero bias. Other module types are left untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Generator Code
|
| 78 |
+
class Generator(nn.Module):
    """DCGAN generator: upsamples an (nz, 1, 1) latent vector to an image.

    Kept as a single ``self.main`` Sequential so checkpoint state_dict
    keys match the original layout.
    """

    def __init__(self):
        super(Generator, self).__init__()
        layers = [
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an (nc, H, W) image to a sigmoid score.

    Kept as a single ``self.main`` Sequential so checkpoint state_dict
    keys match the original layout.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# __GANmodel_end__
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# __INCEPTION_SCORE_begin__
|
| 124 |
+
class Net(nn.Module):
    """
    LeNet for MNist classification, used for inception_score.

    Expects a (batch, 1, 28, 28) input and returns (batch, 10) log
    probabilities via ``F.log_softmax``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 320 = 20 channels * 4 * 4 spatial after two conv+pool stages.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Dropout is active only in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def inception_score(imgs, mnist_model_ref, batch_size=32, splits=1):
    """Compute the inception score of ``imgs`` using a pretrained MNIST LeNet.

    Args:
        imgs: An indexable collection of generated images.
        mnist_model_ref: Ray object ref holding the classifier (a ``Net``).
        batch_size: Batch size for running the classifier.
        splits: Number of splits over which the score is averaged.

    Returns:
        Tuple of (mean, std) of the per-split inception scores.
    """
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    cm = ray.get(mnist_model_ref)  # Get the mnist model from Ray object store.
    # Resize generated images to the 28x28 input the classifier expects.
    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)

    def get_pred(x):
        x = up(x)
        x = cm(x)
        return F.softmax(x).data.cpu().numpy()

    preds = np.zeros((N, 10))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]
        preds[i * batch_size : i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits) : (k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            # entropy(p, q) computes KL(p || q); exponentiated mean KL
            # is the standard inception-score formula.
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# __INCEPTION_SCORE_end__
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def train_func(
    netD,
    netG,
    optimG,
    optimD,
    criterion,
    dataloader,
    iteration,
    device,
    mnist_model_ref,
):
    """Run one reporting step of adversarial training.

    Performs up to ``train_iterations_per_step`` (module-level constant)
    discriminator/generator updates, then computes the inception score of
    the final fake batch.

    Args:
        netD: Discriminator network.
        netG: Generator network.
        optimG: Generator optimizer.
        optimD: Discriminator optimizer.
        criterion: Loss (BCELoss in this example).
        dataloader: Real-image data loader.
        iteration: Current outer training iteration, used for logging.
        device: Torch device to train on.
        mnist_model_ref: Ray object ref of the classifier for the score.

    Returns:
        Tuple of (generator loss, discriminator loss, inception score).
    """
    real_label = 1
    fake_label = 0

    for i, data in enumerate(dataloader, 0):
        if i >= train_iterations_per_step:
            break

        # (1) Update D with an all-real batch: maximize log(D(x)).
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # (2) Update D with an all-fake batch: maximize log(1 - D(G(z))).
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() so this backward pass does not touch the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimD.step()

        # (3) Update G: maximize log(D(G(z))) by using real labels.
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimG.step()

        is_score, is_std = inception_score(fake, mnist_model_ref)

        # Output training stats
        if iteration % 10 == 0:
            print(
                "[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z))"
                ": %.4f / %.4f \tInception score: %.4f"
                % (
                    iteration,
                    len(dataloader),
                    errD.item(),
                    errG.item(),
                    D_x,
                    D_G_z1,
                    D_G_z2,
                    is_score,
                )
            )

    # NOTE(review): assumes the loop ran at least once; an empty dataloader
    # would leave errG/errD/is_score unbound — confirm against callers.
    return errG.item(), errD.item(), is_score
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def plot_images(dataloader):
    """Display a grid of the first 64 images from ``dataloader`` (blocking).

    Uses matplotlib (``plt``) and torchvision utils (``vutils``) imported at
    module level.
    """
    # Plot some training images
    real_batch = next(iter(dataloader))
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Original Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu(),
            # Move channels last for imshow: (C, H, W) -> (H, W, C).
            (1, 2, 0),
        )
    )

    plt.show()
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def demo_gan(checkpoint_paths):
    """Animate samples from each checkpointed generator and save a GIF.

    Loads ``checkpoint.pt`` from every directory in ``checkpoint_paths``,
    feeds the same fixed noise through each restored generator, and renders
    the resulting image grids as ``./generated.gif`` (requires imagemagick).
    """
    img_list = []
    # One shared noise batch so all frames show the same latent points.
    fixed_noise = torch.randn(64, nz, 1, 1)
    for path in checkpoint_paths:
        checkpoint_dict = torch.load(os.path.join(path, "checkpoint.pt"))

        loadedG = Generator()
        loadedG.load_state_dict(checkpoint_dict["netGmodel"])
        with torch.no_grad():
            fake = loadedG(fixed_noise).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
    ani = animation.ArtistAnimation(
        fig, ims, interval=1000, repeat_delay=1000, blit=True
    )
    ani.save("./generated.gif", writer="imagemagick", dpi=72)
    plt.show()
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_func.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Example of training DCGAN on MNIST using PBT with Tune's function API.
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os
|
| 7 |
+
import tempfile
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.parallel
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
import torch.utils.data
|
| 15 |
+
from filelock import FileLock
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
from ray import train, tune
|
| 19 |
+
from ray.train import Checkpoint
|
| 20 |
+
from ray.tune.examples.pbt_dcgan_mnist.common import (
|
| 21 |
+
MODEL_PATH,
|
| 22 |
+
Discriminator,
|
| 23 |
+
Generator,
|
| 24 |
+
Net,
|
| 25 |
+
beta1,
|
| 26 |
+
demo_gan,
|
| 27 |
+
get_data_loader,
|
| 28 |
+
plot_images,
|
| 29 |
+
train_func,
|
| 30 |
+
weights_init,
|
| 31 |
+
)
|
| 32 |
+
from ray.tune.schedulers import PopulationBasedTraining
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# __Train_begin__
|
| 36 |
+
def dcgan_train(config):
    """Tune function-API trainable: train one DCGAN on MNIST under PBT.

    Builds fresh networks/optimizers, optionally restores from a Tune
    checkpoint (applying any PBT-perturbed learning rates), and then loops
    forever, reporting metrics each step and checkpointing every
    ``config["checkpoint_interval"]`` steps.
    """
    use_cuda = config.get("use_gpu") and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    netD = Discriminator().to(device)
    netD.apply(weights_init)
    netG = Generator().to(device)
    netG.apply(weights_init)
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(
        netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
    )
    optimizerG = optim.Adam(
        netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
    )
    # Serialize dataset download across concurrent trials on one node.
    with FileLock(os.path.expanduser("~/ray_results/.data.lock")):
        dataloader = get_data_loader()

    step = 1
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))

        netD.load_state_dict(checkpoint_dict["netDmodel"])
        netG.load_state_dict(checkpoint_dict["netGmodel"])
        optimizerD.load_state_dict(checkpoint_dict["optimD"])
        optimizerG.load_state_dict(checkpoint_dict["optimG"])
        # Note: Make sure to increment the loaded step by 1 to get the
        # current step.
        last_step = checkpoint_dict["step"]
        step = last_step + 1

        # NOTE: It's important to set the optimizer learning rates
        # again, since we want to explore the parameters passed in by PBT.
        # Without this, we would continue using the exact same
        # configuration as the trial whose checkpoint we are exploiting.
        if "netD_lr" in config:
            for param_group in optimizerD.param_groups:
                param_group["lr"] = config["netD_lr"]
        if "netG_lr" in config:
            for param_group in optimizerG.param_groups:
                param_group["lr"] = config["netG_lr"]

    while True:
        lossG, lossD, is_score = train_func(
            netD,
            netG,
            optimizerG,
            optimizerD,
            criterion,
            dataloader,
            step,
            device,
            config["mnist_model_ref"],
        )
        metrics = {"lossg": lossG, "lossd": lossD, "is_score": is_score}

        if step % config["checkpoint_interval"] == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(
                    {
                        "netDmodel": netD.state_dict(),
                        "netGmodel": netG.state_dict(),
                        "optimD": optimizerD.state_dict(),
                        "optimG": optimizerG.state_dict(),
                        "step": step,
                    },
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                # Report while tmpdir still exists so Tune can copy it.
                train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))
        else:
            train.report(metrics)

        step += 1
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# __Train_end__
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def download_mnist_cnn():
    """Fetch the tiny pretrained MNIST classifier used for the inception score.

    Returns:
        The local path (``MODEL_PATH``) of the downloaded — or already
        present — model file.
    """
    import urllib.request

    # Download a pre-trained MNIST model for inception score calculation.
    # This is a tiny model (<100kb).
    if not os.path.exists(MODEL_PATH):
        print("downloading model")
        os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
        urllib.request.urlretrieve(
            "https://github.com/ray-project/ray/raw/master/python/ray/tune/"
            "examples/pbt_dcgan_mnist/mnist_cnn.pt",
            MODEL_PATH,
        )
    return MODEL_PATH
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--data-dir", type=str, default="~/data/", help="Set the path of the dataset."
    )
    args, _ = parser.parse_known_args()
    ray.init()

    download_mnist_cnn()

    dataloader = get_data_loader(args.data_dir)
    if not args.smoke_test:
        plot_images(dataloader)

    # __tune_begin__

    # load the pretrained mnist classification model for inception_score
    mnist_cnn = Net()
    mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
    mnist_cnn.eval()
    # Put the model in Ray object store.
    mnist_model_ref = ray.put(mnist_cnn)

    scheduler = PopulationBasedTraining(
        perturbation_interval=5,
        hyperparam_mutations={
            # distribution for resampling
            "netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
            "netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
        },
    )

    tune_iter = 5 if args.smoke_test else 300
    tuner = tune.Tuner(
        dcgan_train,
        run_config=train.RunConfig(
            name="pbt_dcgan_mnist",
            stop={"training_iteration": tune_iter},
            verbose=1,
        ),
        tune_config=tune.TuneConfig(
            metric="is_score",
            mode="max",
            num_samples=8,
            scheduler=scheduler,
        ),
        param_space={
            "netG_lr": tune.choice([0.0001, 0.0002, 0.0005]),
            "netD_lr": tune.choice([0.0001, 0.0002, 0.0005]),
            "mnist_model_ref": mnist_model_ref,
            # Bugfix: dcgan_train reads config["checkpoint_interval"]
            # unconditionally; omitting it raises a KeyError on step 1.
            "checkpoint_interval": 5,
        },
    )
    results = tuner.fit()
    # __tune_end__

    # demo of the trained Generators
    if not args.smoke_test:
        checkpoint_paths = [result.checkpoint.to_directory() for result in results]
        demo_gan(checkpoint_paths)
|
.venv/lib/python3.11/site-packages/ray/tune/examples/pbt_dcgan_mnist/pbt_dcgan_mnist_trainable.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
"""
|
| 3 |
+
Example of training DCGAN on MNIST using PBT with Tune's Trainable Class
|
| 4 |
+
API.
|
| 5 |
+
"""
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.parallel
|
| 14 |
+
import torch.optim as optim
|
| 15 |
+
import torch.utils.data
|
| 16 |
+
from filelock import FileLock
|
| 17 |
+
|
| 18 |
+
import ray
|
| 19 |
+
from ray import train, tune
|
| 20 |
+
from ray.tune.examples.pbt_dcgan_mnist.common import (
|
| 21 |
+
MODEL_PATH,
|
| 22 |
+
Discriminator,
|
| 23 |
+
Generator,
|
| 24 |
+
Net,
|
| 25 |
+
beta1,
|
| 26 |
+
demo_gan,
|
| 27 |
+
get_data_loader,
|
| 28 |
+
plot_images,
|
| 29 |
+
train_func,
|
| 30 |
+
weights_init,
|
| 31 |
+
)
|
| 32 |
+
from ray.tune.schedulers import PopulationBasedTraining
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# __Trainable_begin__
|
| 36 |
+
class PytorchTrainable(tune.Trainable):
    """Tune class-API trainable wrapping DCGAN training for PBT.

    Implements ``reset_config`` so trials can be reused (``reuse_actors``)
    when PBT perturbs the learning rates.
    """

    def setup(self, config):
        use_cuda = config.get("use_gpu") and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.netD = Discriminator().to(self.device)
        self.netD.apply(weights_init)
        self.netG = Generator().to(self.device)
        self.netG.apply(weights_init)
        self.criterion = nn.BCELoss()
        self.optimizerD = optim.Adam(
            self.netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
        )
        self.optimizerG = optim.Adam(
            self.netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999)
        )
        # Serialize dataset download across concurrent trials on one node.
        with FileLock(os.path.expanduser("~/.data.lock")):
            self.dataloader = get_data_loader(config.get("data_dir", "~/data"))
        self.mnist_model_ref = config["mnist_model_ref"]

    def step(self):
        """Run one training step and return the metrics Tune tracks."""
        lossG, lossD, is_score = train_func(
            self.netD,
            self.netG,
            self.optimizerG,
            self.optimizerD,
            self.criterion,
            self.dataloader,
            self._iteration,
            self.device,
            self.mnist_model_ref,
        )
        return {"lossg": lossG, "lossd": lossD, "is_score": is_score}

    def save_checkpoint(self, checkpoint_dir):
        """Persist both networks and optimizers into ``checkpoint_dir``."""
        path = os.path.join(checkpoint_dir, "checkpoint.pt")
        torch.save(
            {
                "netDmodel": self.netD.state_dict(),
                "netGmodel": self.netG.state_dict(),
                "optimD": self.optimizerD.state_dict(),
                "optimG": self.optimizerG.state_dict(),
            },
            path,
        )

        return checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        """Restore networks and optimizers saved by ``save_checkpoint``."""
        path = os.path.join(checkpoint_dir, "checkpoint.pt")
        checkpoint = torch.load(path)
        self.netD.load_state_dict(checkpoint["netDmodel"])
        self.netG.load_state_dict(checkpoint["netGmodel"])
        self.optimizerD.load_state_dict(checkpoint["optimD"])
        self.optimizerG.load_state_dict(checkpoint["optimG"])

    def reset_config(self, new_config):
        """Apply PBT-perturbed learning rates without recreating the actor."""
        if "netD_lr" in new_config:
            for param_group in self.optimizerD.param_groups:
                param_group["lr"] = new_config["netD_lr"]
        if "netG_lr" in new_config:
            for param_group in self.optimizerG.param_groups:
                param_group["lr"] = new_config["netG_lr"]

        self.config = new_config
        # Returning True tells Tune the in-place reset succeeded.
        return True
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# __Trainable_end__
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--data-dir", type=str, default="~/data/", help="Set the path of the dataset."
    )
    args, _ = parser.parse_known_args()
    ray.init()

    import urllib.request

    # Download a pre-trained MNIST model for inception score calculation.
    # This is a tiny model (<100kb).
    if not os.path.exists(MODEL_PATH):
        print("downloading model")
        os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
        urllib.request.urlretrieve(
            "https://github.com/ray-project/ray/raw/master/python/ray/tune/"
            "examples/pbt_dcgan_mnist/mnist_cnn.pt",
            MODEL_PATH,
        )

    dataloader = get_data_loader()
    if not args.smoke_test:
        plot_images(dataloader)

    # load the pretrained mnist classification model for inception_score
    mnist_cnn = Net()
    mnist_cnn.load_state_dict(torch.load(MODEL_PATH))
    mnist_cnn.eval()
    mnist_model_ref = ray.put(mnist_cnn)

    # __tune_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,
        hyperparam_mutations={
            # distribution for resampling
            "netG_lr": lambda: np.random.uniform(1e-2, 1e-5),
            "netD_lr": lambda: np.random.uniform(1e-2, 1e-5),
        },
    )

    tune_iter = 10 if args.smoke_test else 300
    tuner = tune.Tuner(
        PytorchTrainable,
        run_config=train.RunConfig(
            name="pbt_dcgan_mnist",
            stop={"training_iteration": tune_iter},
            verbose=1,
            checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
        ),
        tune_config=tune.TuneConfig(
            metric="is_score",
            mode="max",
            num_samples=8,
            scheduler=scheduler,
            # Safe because PytorchTrainable implements reset_config.
            reuse_actors=True,
        ),
        param_space={
            "netG_lr": tune.sample_from(
                lambda spec: random.choice([0.0001, 0.0002, 0.0005])
            ),
            "netD_lr": tune.sample_from(
                lambda spec: random.choice([0.0001, 0.0002, 0.0005])
            ),
            "mnist_model_ref": mnist_model_ref,
            "data_dir": args.data_dir,
        },
    )
    results = tuner.fit()

    # export_formats=[ExportFormat.MODEL]
    # __tune_end__

    # demo of the trained Generators
    if not args.smoke_test:
        checkpoint_paths = [result.checkpoint.to_directory() for result in results]
        demo_gan(checkpoint_paths)
|
.venv/lib/python3.11/site-packages/ray/tune/experimental/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experimental/__pycache__/output.cpython-311.pyc
ADDED
|
Binary file (45.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/experimental/output.py
ADDED
|
@@ -0,0 +1,1043 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import collections
|
| 3 |
+
import datetime
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import numbers
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import textwrap
|
| 10 |
+
import time
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from enum import IntEnum
|
| 13 |
+
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
|
| 18 |
+
import ray
|
| 19 |
+
from ray._private.dict import flatten_dict, unflattened_lookup
|
| 20 |
+
from ray._private.thirdparty.tabulate.tabulate import (
|
| 21 |
+
DataRow,
|
| 22 |
+
Line,
|
| 23 |
+
TableFormat,
|
| 24 |
+
tabulate,
|
| 25 |
+
)
|
| 26 |
+
from ray.air._internal.usage import AirEntrypoint
|
| 27 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 28 |
+
from ray.train import Checkpoint
|
| 29 |
+
from ray.tune.callback import Callback
|
| 30 |
+
from ray.tune.experiment.trial import Trial
|
| 31 |
+
from ray.tune.result import (
|
| 32 |
+
AUTO_RESULT_KEYS,
|
| 33 |
+
EPISODE_REWARD_MEAN,
|
| 34 |
+
MEAN_ACCURACY,
|
| 35 |
+
MEAN_LOSS,
|
| 36 |
+
TIME_TOTAL_S,
|
| 37 |
+
TIMESTEPS_TOTAL,
|
| 38 |
+
)
|
| 39 |
+
from ray.tune.search.sample import Domain
|
| 40 |
+
from ray.tune.utils.log import Verbosity
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
import rich
|
| 44 |
+
import rich.layout
|
| 45 |
+
import rich.live
|
| 46 |
+
except ImportError:
|
| 47 |
+
rich = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
logger = logging.getLogger(__name__)

# defines the mapping of the key in result and the key to be printed in table.
# Note this is ordered!
DEFAULT_COLUMNS = collections.OrderedDict(
    {
        MEAN_ACCURACY: "acc",
        MEAN_LOSS: "loss",
        TRAINING_ITERATION: "iter",
        TIME_TOTAL_S: "total time (s)",
        TIMESTEPS_TOTAL: "ts",
        EPISODE_REWARD_MEAN: "reward",
    }
)

# These keys are blacklisted for printing out training/tuning intermediate/final result!
BLACKLISTED_KEYS = {
    "config",
    "date",
    "done",
    "hostname",
    "iterations_since_restore",
    "node_ip",
    "pid",
    "time_since_restore",
    "timestamp",
    "trial_id",
    "experiment_tag",
    "should_checkpoint",
    "_report_on",  # LIGHTNING_REPORT_STAGE_KEY
}

# Scalar result types that may be rendered in the summary table.
VALID_SUMMARY_TYPES = {
    int,
    float,
    np.float32,
    np.float64,
    np.int32,
    np.int64,
    type(None),
}

# The order of summarizing trials.
ORDER = [
    Trial.RUNNING,
    Trial.TERMINATED,
    Trial.PAUSED,
    Trial.PENDING,
    Trial.ERROR,
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class AirVerbosity(IntEnum):
    """Verbosity levels for the new AIR console output (0 = silent)."""

    SILENT = 0
    DEFAULT = 1
    VERBOSE = 2

    def __repr__(self):
        # Render just the numeric level, e.g. "2" instead of
        # "<AirVerbosity.VERBOSE: 2>", for compact log/config display.
        return str(self.value)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Whether the driver runs inside a Jupyter notebook (may affect rendering).
IS_NOTEBOOK = ray.widgets.util.in_notebook()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_air_verbosity(
    verbose: Union[int, AirVerbosity, Verbosity]
) -> Optional[AirVerbosity]:
    """Translate a legacy verbosity setting into an ``AirVerbosity``.

    Returns None when the new output engine is disabled via the
    ``RAY_AIR_NEW_OUTPUT`` environment variable, which makes callers fall
    back to the legacy reporter.
    """
    if os.environ.get("RAY_AIR_NEW_OUTPUT", "1") == "0":
        # Legacy output explicitly requested.
        return None

    if isinstance(verbose, AirVerbosity):
        return verbose

    # Legacy `Verbosity` enums carry their numeric level in `.value`.
    level = verbose if isinstance(verbose, int) else verbose.value

    # Verbosity 2 and 3 both map to AirVerbosity 2
    return AirVerbosity(min(2, level))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _infer_params(config: Dict[str, Any]) -> List[str]:
    """Collect the (flattened) config keys that are search-space parameters.

    A key counts as a parameter if its value is a tune ``Domain``, or if the
    key itself is a ``grid_search`` specification.
    """
    discovered: List[str] = []
    suffix = "/grid_search"
    for key, value in flatten_dict(config).items():
        if isinstance(value, Domain):
            discovered.append(key)
        # Grid search is a special named field. Because we flattened
        # the whole config, we look it up per string
        if key.endswith(suffix):
            # Truncate the `/grid_search` suffix.
            discovered.append(key[: -len(suffix)])
    return discovered
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _get_time_str(start_time: float, current_time: float) -> Tuple[str, str]:
|
| 146 |
+
"""Get strings representing the current and elapsed time.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
start_time: POSIX timestamp of the start of the tune run
|
| 150 |
+
current_time: POSIX timestamp giving the current time
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Current time and elapsed time for the current run
|
| 154 |
+
"""
|
| 155 |
+
current_time_dt = datetime.datetime.fromtimestamp(current_time)
|
| 156 |
+
start_time_dt = datetime.datetime.fromtimestamp(start_time)
|
| 157 |
+
delta: datetime.timedelta = current_time_dt - start_time_dt
|
| 158 |
+
|
| 159 |
+
rest = delta.total_seconds()
|
| 160 |
+
days = int(rest // (60 * 60 * 24))
|
| 161 |
+
|
| 162 |
+
rest -= days * (60 * 60 * 24)
|
| 163 |
+
hours = int(rest // (60 * 60))
|
| 164 |
+
|
| 165 |
+
rest -= hours * (60 * 60)
|
| 166 |
+
minutes = int(rest // 60)
|
| 167 |
+
|
| 168 |
+
seconds = int(rest - minutes * 60)
|
| 169 |
+
|
| 170 |
+
running_for_str = ""
|
| 171 |
+
if days > 0:
|
| 172 |
+
running_for_str += f"{days:d}d "
|
| 173 |
+
|
| 174 |
+
if hours > 0 or running_for_str:
|
| 175 |
+
running_for_str += f"{hours:d}hr "
|
| 176 |
+
|
| 177 |
+
if minutes > 0 or running_for_str:
|
| 178 |
+
running_for_str += f"{minutes:d}min "
|
| 179 |
+
|
| 180 |
+
running_for_str += f"{seconds:d}s"
|
| 181 |
+
|
| 182 |
+
return f"{current_time_dt:%Y-%m-%d %H:%M:%S}", running_for_str
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _get_trials_by_state(trials: List[Trial]) -> Dict[str, List[Trial]]:
    """Group trials by their status string.

    Returns a defaultdict so callers may index missing statuses and get [].
    """
    grouped = collections.defaultdict(list)
    for trial in trials:
        grouped[trial.status].append(trial)
    return grouped
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _get_trials_with_error(trials: List[Trial]) -> List[Trial]:
    """Return only the trials that have recorded an error file."""
    return [trial for trial in trials if trial.error_file]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _infer_user_metrics(trials: List[Trial], limit: int = 4) -> List[str]:
    """Try to infer the metrics to print out.

    By default, only the first 4 meaningful metrics in `last_result` will be
    inferred as user implied metrics.
    """
    # An OrderedDict is used as an ordered set: only the keys matter.
    seen = collections.OrderedDict()
    for trial in trials:
        if not trial.last_result:
            continue
        for metric, value in trial.last_result.items():
            if (
                metric not in DEFAULT_COLUMNS
                and metric not in AUTO_RESULT_KEYS
                # `type(...) in` (not isinstance) intentionally excludes
                # subclasses such as bool.
                and type(value) in VALID_SUMMARY_TYPES
            ):
                seen[metric] = ""  # value is irrelevant

            if len(seen) >= limit:
                return list(seen)
    return list(seen)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def _current_best_trial(
    trials: List[Trial], metric: Optional[str], mode: Optional[str]
) -> Tuple[Optional[Trial], Optional[str]]:
    """
    Returns the best trial and the metric key. If anything is empty or None,
    returns a trivial result of None, None.

    Args:
        trials: List of trials.
        metric: Metric that trials are being ranked.
        mode: One of "min" or "max".

    Returns:
        Best trial and the metric key.
    """
    if not (trials and metric and mode):
        return None, None

    # Negate values for "min" so that a larger score is always better.
    sign = 1.0 if mode == "max" else -1.0
    best_score = float("-inf")
    leader = None
    for candidate in trials:
        if not candidate.last_result:
            continue
        raw_value = unflattened_lookup(metric, candidate.last_result, default=None)
        if pd.isnull(raw_value):
            continue
        score = raw_value * sign
        if leader is None or score > best_score:
            best_score = score
            leader = candidate
    return leader, metric
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@dataclass
class _PerStatusTrialTableData:
    """Trial table rows for a single trial status (e.g. all RUNNING trials)."""

    # One formatted row (list of cell strings) per displayed trial.
    trial_infos: List[List[str]]
    # Overflow note such as "5 more RUNNING" when rows were truncated,
    # otherwise None (set by `_get_trial_table_data_per_status`).
    more_info: Optional[str]
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
@dataclass
class _TrialTableData:
    """Complete trial-table payload: shared header plus per-status row groups."""

    # Column names shared by all per-status sub-tables.
    header: List[str]
    # One entry per status that has at least one trial, following ORDER.
    data: List[_PerStatusTrialTableData]
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _max_len(value: Any, max_len: int = 20, wrap: bool = False) -> Any:
|
| 264 |
+
"""Abbreviate a string representation of an object to `max_len` characters.
|
| 265 |
+
|
| 266 |
+
For numbers, booleans and None, the original value will be returned for
|
| 267 |
+
correct rendering in the table formatting tool.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
value: Object to be represented as a string.
|
| 271 |
+
max_len: Maximum return string length.
|
| 272 |
+
"""
|
| 273 |
+
if value is None or isinstance(value, (int, float, numbers.Number, bool)):
|
| 274 |
+
return value
|
| 275 |
+
|
| 276 |
+
string = str(value)
|
| 277 |
+
if len(string) <= max_len:
|
| 278 |
+
return string
|
| 279 |
+
|
| 280 |
+
if wrap:
|
| 281 |
+
# Maximum two rows.
|
| 282 |
+
# Todo: Make this configurable in the refactor
|
| 283 |
+
if len(value) > max_len * 2:
|
| 284 |
+
value = "..." + string[(3 - (max_len * 2)) :]
|
| 285 |
+
|
| 286 |
+
wrapped = textwrap.wrap(value, width=max_len)
|
| 287 |
+
return "\n".join(wrapped)
|
| 288 |
+
|
| 289 |
+
result = "..." + string[(3 - max_len) :]
|
| 290 |
+
return result
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _get_trial_info(
    trial: Trial, param_keys: List[str], metric_keys: List[str]
) -> List[str]:
    """Returns the following information about a trial:

    name | status | params... | metrics...

    Args:
        trial: Trial to get information for.
        param_keys: Names of parameters to include.
        metric_keys: Names of metrics to include.
    """
    last_result = trial.last_result
    row = [str(trial), trial.status]

    # Parameter cells, abbreviated for table rendering.
    row += [
        _max_len(unflattened_lookup(param, trial.config, default=None))
        for param in param_keys
    ]
    # Metric cells from the most recent result.
    row += [
        _max_len(unflattened_lookup(metric, last_result, default=None))
        for metric in metric_keys
    ]
    return row
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _get_trial_table_data_per_status(
    status: str,
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    force_max_rows: bool = False,
) -> Optional[_PerStatusTrialTableData]:
    """Gather all information of trials pertained to one `status`.

    Args:
        status: The trial status of interest.
        trials: all the trials of that status.
        param_keys: *Ordered* list of parameters to be displayed in the table.
        metric_keys: *Ordered* list of metrics to be displayed in the table.
            Including both default and user defined.
        force_max_rows: Whether or not to enforce a max row number for this status.
            If True, only a max of `5` rows will be shown.

    Returns:
        All information of trials pertained to the `status`.
    """
    if not trials:
        return None

    # TODO: make the row cap configurable.
    row_limit = 5 if force_max_rows else math.inf

    rows = []
    overflow_note = None
    for trial in trials:
        if len(rows) >= row_limit:
            overflow_note = f"{len(trials) - row_limit} more {status}"
            break
        rows.append(_get_trial_info(trial, param_keys, metric_keys))
    return _PerStatusTrialTableData(rows, overflow_note)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def _get_trial_table_data(
    trials: List[Trial],
    param_keys: List[str],
    metric_keys: List[str],
    all_rows: bool = False,
    wrap_headers: bool = False,
) -> _TrialTableData:
    """Generate a table showing the current progress of tuning trials.

    Args:
        trials: List of trials for which progress is to be shown.
        param_keys: Ordered list of parameters to be displayed in the table.
        metric_keys: Ordered list of metrics to be displayed in the table.
            Including both default and user defined.
            Will only be shown if at least one trial is having the key.
        all_rows: Force to show all rows.
        wrap_headers: If True, header columns can be wrapped onto two lines.

    Returns:
        Trial table data, including header and trial table per each status.
    """
    # TODO: make these limits configurable.
    max_trial_num_to_show = 20
    max_column_length = 20

    by_state = _get_trials_by_state(trials)

    # Only keep metrics that at least one trial has actually reported.
    visible_metrics = [
        key
        for key in metric_keys
        if any(
            unflattened_lookup(key, trial.last_result, default=None) is not None
            for trial in trials
        )
    ]

    def _format_column(name):
        return _max_len(name, max_len=max_column_length, wrap=wrap_headers)

    # Map default metrics to their abbreviations; abbreviate the rest.
    metric_header = [
        DEFAULT_COLUMNS[key] if key in DEFAULT_COLUMNS else _format_column(key)
        for key in visible_metrics
    ]
    param_header = [_format_column(key) for key in param_keys]

    header = ["Trial name", "status"] + param_header + metric_header

    # Only truncate per-status sections when there are many trials overall.
    truncate_rows = not all_rows and len(trials) > max_trial_num_to_show

    per_status_tables = []
    for state in ORDER:
        section = _get_trial_table_data_per_status(
            state,
            by_state[state],
            param_keys=param_keys,
            metric_keys=visible_metrics,
            force_max_rows=truncate_rows,
        )
        if section:
            per_status_tables.append(section)
    return _TrialTableData(header, per_status_tables)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def _best_trial_str(
    trial: Trial,
    metric: str,
):
    """Returns a readable message stating the current best trial.

    Example:
    Current best trial: 18ae7_00005 with loss=0.5918508041056858 and params={'train_loop_config': {'lr': 0.059253447253394785}}.  # noqa
    """
    metric_value = unflattened_lookup(metric, trial.last_result, default=None)
    config = trial.last_result.get("config", {})
    params = {name: unflattened_lookup(name, config) for name in list(config.keys())}
    return (
        f"Current best trial: {trial.trial_id} with {metric}={metric_value} and "
        f"params={params}"
    )
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def _render_table_item(
    key: str, item: Any, prefix: str = ""
) -> Iterable[Tuple[str, str]]:
    """Yield (key, rendered-value) pairs for one result entry.

    Dicts are flattened into one pair per leaf; floats get custom formatting;
    everything else is abbreviated via `_max_len`.
    """
    full_key = prefix + key

    if isinstance(item, argparse.Namespace):
        item = vars(item)

    if isinstance(item, float):
        # tabulate does not work well with mixed-type columns, so we format
        # numbers ourselves.
        yield full_key, f"{item:.5f}".rstrip("0")
    elif isinstance(item, dict):
        for sub_key, sub_value in sorted(flatten_dict(item).items()):
            yield full_key + "/" + str(sub_key), _max_len(sub_value)
    else:
        yield full_key, _max_len(item, 20)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def _get_dict_as_table_data(
    data: Dict,
    include: Optional[Collection] = None,
    exclude: Optional[Collection] = None,
    upper_keys: Optional[Collection] = None,
):
    """Get ``data`` dict as table rows.

    If specified, excluded keys are removed. Excluded keys can either be
    fully specified (e.g. ``foo/bar/baz``) or specify a top-level dictionary
    (e.g. ``foo``), but no intermediate levels (e.g. ``foo/bar``). If this is
    needed, we can revisit the logic at a later point.

    The same is true for included keys. If a top-level key is included (e.g. ``foo``)
    then all sub keys will be included, too, except if they are excluded.

    If keys are both excluded and included, exclusion takes precedence. Thus, if
    ``foo`` is excluded but ``foo/bar`` is included, it won't show up in the output.
    """
    include = include or set()
    exclude = exclude or set()
    upper_keys = upper_keys or set()

    upper_rows = []
    lower_rows = []

    for top_key, value in sorted(data.items()):
        # Exclude top-level keys
        if top_key in exclude:
            continue

        # Rows for `upper_keys` entries are listed first.
        target = upper_rows if top_key in upper_keys else lower_rows

        for full_key, rendered in _render_table_item(str(top_key), value):
            # full_key is the flattened subkey, e.g. config/nested/key.

            # We can exclude the full key
            if full_key in exclude:
                continue

            # If we specify includes, top-level includes should take precedence
            # (e.g. if `config` is in include, include config always).
            if include and top_key not in include and full_key not in include:
                continue

            target.append([full_key, rendered])

    # Concatenation covers the empty cases too ([] + x == x).
    return upper_rows + lower_rows
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# Choose the table border style once at import time, based on whether the
# attached stdout can render UTF-8 box-drawing characters.
if sys.stdout and sys.stdout.encoding and sys.stdout.encoding.startswith("utf"):
    # Copied/adjusted from tabulate
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("╭", "─", "─", "╮"),
        linebelowheader=Line("├", "─", "─", "┤"),
        linebetweenrows=None,
        linebelow=Line("╰", "─", "─", "╯"),
        headerrow=DataRow("│", " ", "│"),
        datarow=DataRow("│", " ", "│"),
        padding=1,
        with_header_hide=None,
    )
else:
    # For non-utf output, use ascii-compatible characters.
    # This prevents errors e.g. when legacy windows encoding is used.
    AIR_TABULATE_TABLEFMT = TableFormat(
        lineabove=Line("+", "-", "-", "+"),
        linebelowheader=Line("+", "-", "-", "+"),
        linebetweenrows=None,
        linebelow=Line("+", "-", "-", "+"),
        headerrow=DataRow("|", " ", "|"),
        datarow=DataRow("|", " ", "|"),
        padding=1,
        with_header_hide=None,
    )
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _print_dict_as_table(
    data: Dict,
    header: Optional[str] = None,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
    division: Optional[Collection[str]] = None,
):
    """Print ``data`` as a two-column table; no-op when nothing remains
    after include/exclude filtering.

    `division` keys are listed in the upper section of the table.
    """
    rows = _get_dict_as_table_data(
        data=data, include=include, exclude=exclude, upper_keys=division
    )

    if not rows:
        return

    print(
        tabulate(
            rows,
            headers=[header, ""] if header else [],
            colalign=("left", "right"),
            tablefmt=AIR_TABULATE_TABLEFMT,
        )
    )
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
class ProgressReporter(Callback):
    """Periodically prints out status update.

    Base class for the AIR console reporters. Subclasses configure the
    verbosity thresholds and the trial-addressing template below.
    """

    # TODO: Make this configurable
    _heartbeat_freq = 30  # every 30 sec
    # to be updated by subclasses.
    _heartbeat_threshold = None
    _start_end_verbosity = None
    _intermediate_result_verbosity = None
    # Format template for naming a trial in messages, e.g. "Trial {}".
    _addressing_tmpl = None

    def __init__(
        self,
        verbosity: AirVerbosity,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """
        Args:
            verbosity: AirVerbosity level.
            progress_metrics: If given, only these result keys are included
                when a trial result is printed as a table.
        """
        self._verbosity = verbosity
        # Fix: `self._start_time` was assigned twice in the original
        # __init__; once is enough. May be overwritten via `setup()`.
        self._start_time = time.time()
        self._last_heartbeat_time = float("-inf")
        self._progress_metrics = progress_metrics
        # trial_id -> training_iteration of the last printed result,
        # used to avoid printing the same result twice.
        self._trial_last_printed_results = {}

        # Identifier of the currently open console output block, if any.
        self._in_block = None

    @property
    def verbosity(self) -> AirVerbosity:
        return self._verbosity

    def setup(
        self,
        start_time: Optional[float] = None,
        **kwargs,
    ):
        # Align the reporter clock with the driver-provided start time.
        self._start_time = start_time

    def _start_block(self, indicator: Any):
        # Begin a new logical block of output; closing the previous block
        # prints a separating blank line.
        if self._in_block != indicator:
            self._end_block()
        self._in_block = indicator

    def _end_block(self):
        if self._in_block:
            print("")
        self._in_block = None

    def on_experiment_end(self, trials: List["Trial"], **info):
        self._end_block()

    def experiment_started(
        self,
        experiment_name: str,
        experiment_path: str,
        searcher_str: str,
        scheduler_str: str,
        total_num_samples: int,
        tensorboard_path: Optional[str] = None,
        **kwargs,
    ):
        """Print pointers to the experiment results (and TensorBoard)."""
        self._start_block("exp_start")
        print(f"\nView detailed results here: {experiment_path}")

        if tensorboard_path:
            print(
                f"To visualize your results with TensorBoard, run: "
                f"`tensorboard --logdir {tensorboard_path}`"
            )

    @property
    def _time_heartbeat_str(self):
        current_time_str, running_time_str = _get_time_str(
            self._start_time, time.time()
        )
        return (
            f"Current time: {current_time_str}. Total running time: " + running_time_str
        )

    def print_heartbeat(self, trials, *args, force: bool = False):
        """Print a status summary, rate-limited to `_heartbeat_freq` seconds
        unless `force` is set."""
        if self._verbosity < self._heartbeat_threshold:
            return
        if force or time.time() - self._last_heartbeat_time >= self._heartbeat_freq:
            self._print_heartbeat(trials, *args, force=force)
            self._last_heartbeat_time = time.time()

    def _print_heartbeat(self, trials, *args, force: bool = False):
        raise NotImplementedError

    def _print_result(self, trial, result: Optional[Dict] = None, force: bool = False):
        """Only print result if a different result has been reported, or force=True"""
        result = result or trial.last_result

        last_result_iter = self._trial_last_printed_results.get(trial.trial_id, -1)
        this_iter = result.get(TRAINING_ITERATION, 0)

        if this_iter != last_result_iter or force:
            _print_dict_as_table(
                result,
                header=f"{self._addressing_tmpl.format(trial)} result",
                include=self._progress_metrics,
                exclude=BLACKLISTED_KEYS,
                division=AUTO_RESULT_KEYS,
            )
            self._trial_last_printed_results[trial.trial_id] = this_iter

    def _print_config(self, trial):
        _print_dict_as_table(
            trial.config, header=f"{self._addressing_tmpl.format(trial)} config"
        )

    def on_trial_result(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        result: Dict,
        **info,
    ):
        if self.verbosity < self._intermediate_result_verbosity:
            return
        self._start_block(f"trial_{trial}_result_{result[TRAINING_ITERATION]}")
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"finished iteration {result[TRAINING_ITERATION]} "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial, result)

    def on_trial_complete(
        self, iteration: int, trials: List[Trial], trial: Trial, **info
    ):
        if self.verbosity < self._start_end_verbosity:
            return
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_complete")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"completed after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: " + running_time_str
        )
        self._print_result(trial)

    def on_trial_error(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        curr_time_str, running_time_str = _get_time_str(self._start_time, time.time())
        finished_iter = 0
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            finished_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_error")
        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"errored after {finished_iter} iterations "
            f"at {curr_time_str}. Total running time: {running_time_str}\n"
            f"Error file: {trial.error_file}"
        )
        self._print_result(trial)

    def on_trial_recover(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # A recovering trial is reported the same way as an errored one.
        self.on_trial_error(iteration=iteration, trials=trials, trial=trial, **info)

    def on_checkpoint(
        self,
        iteration: int,
        trials: List[Trial],
        trial: Trial,
        checkpoint: Checkpoint,
        **info,
    ):
        if self._verbosity < self._intermediate_result_verbosity:
            return
        # don't think this is supposed to happen but just to be safe.
        saved_iter = "?"
        if trial.last_result and TRAINING_ITERATION in trial.last_result:
            saved_iter = trial.last_result[TRAINING_ITERATION]

        self._start_block(f"trial_{trial}_result_{saved_iter}")

        loc = f"({checkpoint.filesystem.type_name}){checkpoint.path}"

        print(
            f"{self._addressing_tmpl.format(trial)} "
            f"saved a checkpoint for iteration {saved_iter} "
            f"at: {loc}"
        )

    def on_trial_start(self, iteration: int, trials: List[Trial], trial: Trial, **info):
        if self.verbosity < self._start_end_verbosity:
            return
        has_config = bool(trial.config)

        self._start_block(f"trial_{trial}_start")
        if has_config:
            print(
                f"{self._addressing_tmpl.format(trial)} " f"started with configuration:"
            )
            self._print_config(trial)
        else:
            print(
                f"{self._addressing_tmpl.format(trial)} "
                f"started without custom configuration."
            )
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
def _detect_reporter(
    verbosity: AirVerbosity,
    num_samples: int,
    entrypoint: Optional[AirEntrypoint] = None,
    metric: Optional[str] = None,
    mode: Optional[str] = None,
    config: Optional[Dict] = None,
    progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
):
    """Pick the reporter implementation matching the AIR entrypoint.

    Tune entrypoints get the Tune terminal reporter; everything else
    (e.g. trainers) gets the train reporter.
    """
    tune_entrypoints = (
        AirEntrypoint.TUNE_RUN,
        AirEntrypoint.TUNE_RUN_EXPERIMENTS,
        AirEntrypoint.TUNER,
    )
    if entrypoint in tune_entrypoints:
        return TuneTerminalReporter(
            verbosity,
            num_samples=num_samples,
            metric=metric,
            mode=mode,
            config=config,
            progress_metrics=progress_metrics,
        )
    return TrainReporter(verbosity, progress_metrics=progress_metrics)
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
class TuneReporterBase(ProgressReporter):
    """Shared logic for Tune reporters: heartbeat summary and trial table."""

    _heartbeat_threshold = AirVerbosity.DEFAULT
    # Whether table header columns may wrap onto two lines.
    _wrap_headers = False
    _intermediate_result_verbosity = AirVerbosity.VERBOSE
    _start_end_verbosity = AirVerbosity.DEFAULT
    _addressing_tmpl = "Trial {}"

    def __init__(
        self,
        verbosity: AirVerbosity,
        num_samples: int = 0,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        config: Optional[Dict] = None,
        progress_metrics: Optional[Union[List[str], List[Dict[str, str]]]] = None,
    ):
        """
        Args:
            verbosity: AirVerbosity level.
            num_samples: Total number of samples; may be updated in `setup()`.
            metric: Metric used to rank trials for the "current best" line.
            mode: One of "min" or "max".
            config: Search space; used to infer the parameter columns.
            progress_metrics: Result keys to include in printed result tables.
        """
        self._num_samples = num_samples
        self._metric = metric
        self._mode = mode
        # will be populated when first result comes in.
        self._inferred_metric = None
        self._inferred_params = _infer_params(config or {})
        # Fix: modern zero-argument super() instead of the legacy
        # super(TuneReporterBase, self) form.
        super().__init__(verbosity=verbosity, progress_metrics=progress_metrics)

    def setup(
        self,
        start_time: Optional[float] = None,
        total_samples: Optional[int] = None,
        **kwargs,
    ):
        super().setup(start_time=start_time)
        self._num_samples = total_samples

    def _get_overall_trial_progress_str(self, trials):
        """E.g. "Trial status: 1 RUNNING | 7 PENDING"."""
        # Fix: the comprehension variable previously shadowed the `trials`
        # parameter; renamed for clarity.
        counts = " | ".join(
            [
                f"{len(state_trials)} {status}"
                for status, state_trials in _get_trials_by_state(trials).items()
            ]
        )
        return f"Trial status: {counts}"

    # TODO: Return a more structured type to share code with Jupyter flow.
    def _get_heartbeat(
        self, trials, *sys_args, force_full_output: bool = False
    ) -> Tuple[List[str], _TrialTableData]:
        """Build the heartbeat summary lines and the trial table data.

        Args:
            trials: All trials of the run.
            sys_args: Extra pre-formatted lines (e.g. resource usage).
            force_full_output: Show all table rows (used at end of run).
        """
        result = []
        # Trial status: 1 RUNNING | 7 PENDING
        result.append(self._get_overall_trial_progress_str(trials))
        # Current time: 2023-02-24 12:35:39 (running for 00:00:37.40)
        result.append(self._time_heartbeat_str)
        # Logical resource usage: 8.0/64 CPUs, 0/0 GPUs
        result.extend(sys_args)
        # Current best trial: TRIAL NAME, metrics: {...}, parameters: {...}
        current_best_trial, metric = _current_best_trial(
            trials, self._metric, self._mode
        )
        if current_best_trial:
            result.append(_best_trial_str(current_best_trial, metric))

        # Now populating the trial table data.
        if not self._inferred_metric:
            # try inferring again.
            self._inferred_metric = _infer_user_metrics(trials)

        all_metrics = list(DEFAULT_COLUMNS.keys()) + self._inferred_metric

        trial_table_data = _get_trial_table_data(
            trials,
            param_keys=self._inferred_params,
            metric_keys=all_metrics,
            all_rows=force_full_output,
            wrap_headers=self._wrap_headers,
        )
        return result, trial_table_data

    def _print_heartbeat(self, trials, *sys_args, force: bool = False):
        raise NotImplementedError
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class TuneTerminalReporter(TuneReporterBase):
|
| 906 |
+
def experiment_started(
|
| 907 |
+
self,
|
| 908 |
+
experiment_name: str,
|
| 909 |
+
experiment_path: str,
|
| 910 |
+
searcher_str: str,
|
| 911 |
+
scheduler_str: str,
|
| 912 |
+
total_num_samples: int,
|
| 913 |
+
tensorboard_path: Optional[str] = None,
|
| 914 |
+
**kwargs,
|
| 915 |
+
):
|
| 916 |
+
if total_num_samples > sys.maxsize:
|
| 917 |
+
total_num_samples_str = "infinite"
|
| 918 |
+
else:
|
| 919 |
+
total_num_samples_str = str(total_num_samples)
|
| 920 |
+
|
| 921 |
+
print(
|
| 922 |
+
tabulate(
|
| 923 |
+
[
|
| 924 |
+
["Search algorithm", searcher_str],
|
| 925 |
+
["Scheduler", scheduler_str],
|
| 926 |
+
["Number of trials", total_num_samples_str],
|
| 927 |
+
],
|
| 928 |
+
headers=["Configuration for experiment", experiment_name],
|
| 929 |
+
tablefmt=AIR_TABULATE_TABLEFMT,
|
| 930 |
+
)
|
| 931 |
+
)
|
| 932 |
+
super().experiment_started(
|
| 933 |
+
experiment_name=experiment_name,
|
| 934 |
+
experiment_path=experiment_path,
|
| 935 |
+
searcher_str=searcher_str,
|
| 936 |
+
scheduler_str=scheduler_str,
|
| 937 |
+
total_num_samples=total_num_samples,
|
| 938 |
+
tensorboard_path=tensorboard_path,
|
| 939 |
+
**kwargs,
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
def _print_heartbeat(self, trials, *sys_args, force: bool = False):
|
| 943 |
+
if self._verbosity < self._heartbeat_threshold and not force:
|
| 944 |
+
return
|
| 945 |
+
heartbeat_strs, table_data = self._get_heartbeat(
|
| 946 |
+
trials, *sys_args, force_full_output=force
|
| 947 |
+
)
|
| 948 |
+
|
| 949 |
+
self._start_block("heartbeat")
|
| 950 |
+
for s in heartbeat_strs:
|
| 951 |
+
print(s)
|
| 952 |
+
# now print the table using Tabulate
|
| 953 |
+
more_infos = []
|
| 954 |
+
all_data = []
|
| 955 |
+
fail_header = table_data.header
|
| 956 |
+
for sub_table in table_data.data:
|
| 957 |
+
all_data.extend(sub_table.trial_infos)
|
| 958 |
+
if sub_table.more_info:
|
| 959 |
+
more_infos.append(sub_table.more_info)
|
| 960 |
+
|
| 961 |
+
print(
|
| 962 |
+
tabulate(
|
| 963 |
+
all_data,
|
| 964 |
+
headers=fail_header,
|
| 965 |
+
tablefmt=AIR_TABULATE_TABLEFMT,
|
| 966 |
+
showindex=False,
|
| 967 |
+
)
|
| 968 |
+
)
|
| 969 |
+
if more_infos:
|
| 970 |
+
print(", ".join(more_infos))
|
| 971 |
+
|
| 972 |
+
if not force:
|
| 973 |
+
# Only print error table at end of training
|
| 974 |
+
return
|
| 975 |
+
|
| 976 |
+
trials_with_error = _get_trials_with_error(trials)
|
| 977 |
+
if not trials_with_error:
|
| 978 |
+
return
|
| 979 |
+
|
| 980 |
+
self._start_block("status_errored")
|
| 981 |
+
print(f"Number of errored trials: {len(trials_with_error)}")
|
| 982 |
+
fail_header = ["Trial name", "# failures", "error file"]
|
| 983 |
+
fail_table_data = [
|
| 984 |
+
[
|
| 985 |
+
str(trial),
|
| 986 |
+
str(trial.run_metadata.num_failures)
|
| 987 |
+
+ ("" if trial.status == Trial.ERROR else "*"),
|
| 988 |
+
trial.error_file,
|
| 989 |
+
]
|
| 990 |
+
for trial in trials_with_error
|
| 991 |
+
]
|
| 992 |
+
print(
|
| 993 |
+
tabulate(
|
| 994 |
+
fail_table_data,
|
| 995 |
+
headers=fail_header,
|
| 996 |
+
tablefmt=AIR_TABULATE_TABLEFMT,
|
| 997 |
+
showindex=False,
|
| 998 |
+
colalign=("left", "right", "left"),
|
| 999 |
+
)
|
| 1000 |
+
)
|
| 1001 |
+
if any(trial.status == Trial.TERMINATED for trial in trials_with_error):
|
| 1002 |
+
print("* The trial terminated successfully after retrying.")
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
class TrainReporter(ProgressReporter):
|
| 1006 |
+
# the minimal verbosity threshold at which heartbeat starts getting printed.
|
| 1007 |
+
_heartbeat_threshold = AirVerbosity.VERBOSE
|
| 1008 |
+
_intermediate_result_verbosity = AirVerbosity.DEFAULT
|
| 1009 |
+
_start_end_verbosity = AirVerbosity.DEFAULT
|
| 1010 |
+
_addressing_tmpl = "Training"
|
| 1011 |
+
|
| 1012 |
+
def _get_heartbeat(self, trials: List[Trial], force_full_output: bool = False):
|
| 1013 |
+
# Training on iteration 1. Current time: 2023-03-22 15:29:25 (running for 00:00:03.24) # noqa
|
| 1014 |
+
if len(trials) == 0:
|
| 1015 |
+
return
|
| 1016 |
+
trial = trials[0]
|
| 1017 |
+
if trial.status != Trial.RUNNING:
|
| 1018 |
+
return " ".join(
|
| 1019 |
+
[f"Training is in {trial.status} status.", self._time_heartbeat_str]
|
| 1020 |
+
)
|
| 1021 |
+
if not trial.last_result or TRAINING_ITERATION not in trial.last_result:
|
| 1022 |
+
iter_num = 1
|
| 1023 |
+
else:
|
| 1024 |
+
iter_num = trial.last_result[TRAINING_ITERATION] + 1
|
| 1025 |
+
return " ".join(
|
| 1026 |
+
[f"Training on iteration {iter_num}.", self._time_heartbeat_str]
|
| 1027 |
+
)
|
| 1028 |
+
|
| 1029 |
+
def _print_heartbeat(self, trials, *args, force: bool = False):
|
| 1030 |
+
print(self._get_heartbeat(trials, force_full_output=force))
|
| 1031 |
+
|
| 1032 |
+
def on_trial_result(
|
| 1033 |
+
self,
|
| 1034 |
+
iteration: int,
|
| 1035 |
+
trials: List[Trial],
|
| 1036 |
+
trial: Trial,
|
| 1037 |
+
result: Dict,
|
| 1038 |
+
**info,
|
| 1039 |
+
):
|
| 1040 |
+
self._last_heartbeat_time = time.time()
|
| 1041 |
+
super().on_trial_result(
|
| 1042 |
+
iteration=iteration, trials=trials, trial=trial, result=result, **info
|
| 1043 |
+
)
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.logger.csv import CSVLogger, CSVLoggerCallback
|
| 2 |
+
from ray.tune.logger.json import JsonLogger, JsonLoggerCallback
|
| 3 |
+
from ray.tune.logger.logger import (
|
| 4 |
+
LegacyLoggerCallback,
|
| 5 |
+
Logger,
|
| 6 |
+
LoggerCallback,
|
| 7 |
+
pretty_print,
|
| 8 |
+
)
|
| 9 |
+
from ray.tune.logger.noop import NoopLogger
|
| 10 |
+
from ray.tune.logger.tensorboardx import TBXLogger, TBXLoggerCallback
|
| 11 |
+
|
| 12 |
+
DEFAULT_LOGGERS = (JsonLogger, CSVLogger, TBXLogger)
|
| 13 |
+
|
| 14 |
+
# isort: off
|
| 15 |
+
from ray.tune.logger.unified import UnifiedLogger # noqa: E402
|
| 16 |
+
|
| 17 |
+
# isort: on
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"Logger",
|
| 21 |
+
"LoggerCallback",
|
| 22 |
+
"LegacyLoggerCallback",
|
| 23 |
+
"pretty_print",
|
| 24 |
+
"CSVLogger",
|
| 25 |
+
"CSVLoggerCallback",
|
| 26 |
+
"JsonLogger",
|
| 27 |
+
"JsonLoggerCallback",
|
| 28 |
+
"NoopLogger",
|
| 29 |
+
"TBXLogger",
|
| 30 |
+
"TBXLoggerCallback",
|
| 31 |
+
"UnifiedLogger",
|
| 32 |
+
]
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (982 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/aim.cpython-311.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/comet.cpython-311.pyc
ADDED
|
Binary file (327 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/csv.cpython-311.pyc
ADDED
|
Binary file (7.66 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/json.cpython-311.pyc
ADDED
|
Binary file (8.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/logger.cpython-311.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/mlflow.cpython-311.pyc
ADDED
|
Binary file (331 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/noop.cpython-311.pyc
ADDED
|
Binary file (875 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/tensorboardx.cpython-311.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/unified.cpython-311.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/__pycache__/wandb.cpython-311.pyc
ADDED
|
Binary file (327 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/logger/aim.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 7 |
+
from ray.tune.logger.logger import LoggerCallback
|
| 8 |
+
from ray.tune.result import TIME_TOTAL_S, TIMESTEPS_TOTAL
|
| 9 |
+
from ray.tune.utils import flatten_dict
|
| 10 |
+
from ray.util.annotations import PublicAPI
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from ray.tune.experiment.trial import Trial
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from aim.sdk import Repo, Run
|
| 17 |
+
except ImportError:
|
| 18 |
+
Repo, Run = None, None
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@PublicAPI
|
| 26 |
+
class AimLoggerCallback(LoggerCallback):
|
| 27 |
+
"""Aim Logger: logs metrics in Aim format.
|
| 28 |
+
|
| 29 |
+
Aim is an open-source, self-hosted ML experiment tracking tool.
|
| 30 |
+
It's good at tracking lots (thousands) of training runs, and it allows you to
|
| 31 |
+
compare them with a performant and well-designed UI.
|
| 32 |
+
|
| 33 |
+
Source: https://github.com/aimhubio/aim
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
repo: Aim repository directory or a `Repo` object that the Run object will
|
| 37 |
+
log results to. If not provided, a default repo will be set up in the
|
| 38 |
+
experiment directory (one level above trial directories).
|
| 39 |
+
experiment: Sets the `experiment` property of each Run object, which is the
|
| 40 |
+
experiment name associated with it. Can be used later to query
|
| 41 |
+
runs/sequences.
|
| 42 |
+
If not provided, the default will be the Tune experiment name set
|
| 43 |
+
by `RunConfig(name=...)`.
|
| 44 |
+
metrics: List of metric names (out of the metrics reported by Tune) to
|
| 45 |
+
track in Aim. If no metric are specified, log everything that
|
| 46 |
+
is reported.
|
| 47 |
+
aim_run_kwargs: Additional arguments that will be passed when creating the
|
| 48 |
+
individual `Run` objects for each trial. For the full list of arguments,
|
| 49 |
+
please see the Aim documentation:
|
| 50 |
+
https://aimstack.readthedocs.io/en/latest/refs/sdk.html
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
VALID_HPARAMS = (str, bool, int, float, list, type(None))
|
| 54 |
+
VALID_NP_HPARAMS = (np.bool_, np.float32, np.float64, np.int32, np.int64)
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
repo: Optional[Union[str, "Repo"]] = None,
|
| 59 |
+
experiment_name: Optional[str] = None,
|
| 60 |
+
metrics: Optional[List[str]] = None,
|
| 61 |
+
**aim_run_kwargs,
|
| 62 |
+
):
|
| 63 |
+
"""
|
| 64 |
+
See help(AimLoggerCallback) for more information about parameters.
|
| 65 |
+
"""
|
| 66 |
+
assert Run is not None, (
|
| 67 |
+
"aim must be installed!. You can install aim with"
|
| 68 |
+
" the command: `pip install aim`."
|
| 69 |
+
)
|
| 70 |
+
self._repo_path = repo
|
| 71 |
+
self._experiment_name = experiment_name
|
| 72 |
+
if not (bool(metrics) or metrics is None):
|
| 73 |
+
raise ValueError(
|
| 74 |
+
"`metrics` must either contain at least one metric name, or be None, "
|
| 75 |
+
"in which case all reported metrics will be logged to the aim repo."
|
| 76 |
+
)
|
| 77 |
+
self._metrics = metrics
|
| 78 |
+
self._aim_run_kwargs = aim_run_kwargs
|
| 79 |
+
self._trial_to_run: Dict["Trial", Run] = {}
|
| 80 |
+
|
| 81 |
+
def _create_run(self, trial: "Trial") -> Run:
|
| 82 |
+
"""Initializes an Aim Run object for a given trial.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
trial: The Tune trial that aim will track as a Run.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
Run: The created aim run for a specific trial.
|
| 89 |
+
"""
|
| 90 |
+
experiment_dir = trial.local_experiment_path
|
| 91 |
+
run = Run(
|
| 92 |
+
repo=self._repo_path or experiment_dir,
|
| 93 |
+
experiment=self._experiment_name or trial.experiment_dir_name,
|
| 94 |
+
**self._aim_run_kwargs,
|
| 95 |
+
)
|
| 96 |
+
# Attach a few useful trial properties
|
| 97 |
+
run["trial_id"] = trial.trial_id
|
| 98 |
+
run["trial_log_dir"] = trial.path
|
| 99 |
+
trial_ip = trial.get_ray_actor_ip()
|
| 100 |
+
if trial_ip:
|
| 101 |
+
run["trial_ip"] = trial_ip
|
| 102 |
+
return run
|
| 103 |
+
|
| 104 |
+
def log_trial_start(self, trial: "Trial"):
|
| 105 |
+
if trial in self._trial_to_run:
|
| 106 |
+
# Cleanup an existing run if the trial has been restarted
|
| 107 |
+
self._trial_to_run[trial].close()
|
| 108 |
+
|
| 109 |
+
trial.init_local_path()
|
| 110 |
+
self._trial_to_run[trial] = self._create_run(trial)
|
| 111 |
+
|
| 112 |
+
if trial.evaluated_params:
|
| 113 |
+
self._log_trial_hparams(trial)
|
| 114 |
+
|
| 115 |
+
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
| 116 |
+
tmp_result = result.copy()
|
| 117 |
+
|
| 118 |
+
step = result.get(TIMESTEPS_TOTAL, None) or result[TRAINING_ITERATION]
|
| 119 |
+
|
| 120 |
+
for k in ["config", "pid", "timestamp", TIME_TOTAL_S, TRAINING_ITERATION]:
|
| 121 |
+
tmp_result.pop(k, None) # not useful to log these
|
| 122 |
+
|
| 123 |
+
# `context` and `epoch` are special keys that users can report,
|
| 124 |
+
# which are treated as special aim metrics/configurations.
|
| 125 |
+
context = tmp_result.pop("context", None)
|
| 126 |
+
epoch = tmp_result.pop("epoch", None)
|
| 127 |
+
|
| 128 |
+
trial_run = self._trial_to_run[trial]
|
| 129 |
+
path = ["ray", "tune"]
|
| 130 |
+
|
| 131 |
+
flat_result = flatten_dict(tmp_result, delimiter="/")
|
| 132 |
+
valid_result = {}
|
| 133 |
+
|
| 134 |
+
for attr, value in flat_result.items():
|
| 135 |
+
if self._metrics and attr not in self._metrics:
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
full_attr = "/".join(path + [attr])
|
| 139 |
+
if isinstance(value, tuple(VALID_SUMMARY_TYPES)) and not (
|
| 140 |
+
np.isnan(value) or np.isinf(value)
|
| 141 |
+
):
|
| 142 |
+
valid_result[attr] = value
|
| 143 |
+
trial_run.track(
|
| 144 |
+
value=value,
|
| 145 |
+
name=full_attr,
|
| 146 |
+
epoch=epoch,
|
| 147 |
+
step=step,
|
| 148 |
+
context=context,
|
| 149 |
+
)
|
| 150 |
+
elif (isinstance(value, (list, tuple, set)) and len(value) > 0) or (
|
| 151 |
+
isinstance(value, np.ndarray) and value.size > 0
|
| 152 |
+
):
|
| 153 |
+
valid_result[attr] = value
|
| 154 |
+
|
| 155 |
+
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
| 156 |
+
trial_run = self._trial_to_run.pop(trial)
|
| 157 |
+
trial_run.close()
|
| 158 |
+
|
| 159 |
+
def _log_trial_hparams(self, trial: "Trial"):
|
| 160 |
+
params = flatten_dict(trial.evaluated_params, delimiter="/")
|
| 161 |
+
flat_params = flatten_dict(params)
|
| 162 |
+
|
| 163 |
+
scrubbed_params = {
|
| 164 |
+
k: v for k, v in flat_params.items() if isinstance(v, self.VALID_HPARAMS)
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
np_params = {
|
| 168 |
+
k: v.tolist()
|
| 169 |
+
for k, v in flat_params.items()
|
| 170 |
+
if isinstance(v, self.VALID_NP_HPARAMS)
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
scrubbed_params.update(np_params)
|
| 174 |
+
removed = {
|
| 175 |
+
k: v
|
| 176 |
+
for k, v in flat_params.items()
|
| 177 |
+
if not isinstance(v, self.VALID_HPARAMS + self.VALID_NP_HPARAMS)
|
| 178 |
+
}
|
| 179 |
+
if removed:
|
| 180 |
+
logger.info(
|
| 181 |
+
"Removed the following hyperparameter values when "
|
| 182 |
+
"logging to aim: %s",
|
| 183 |
+
str(removed),
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
run = self._trial_to_run[trial]
|
| 187 |
+
run["hparams"] = scrubbed_params
|
.venv/lib/python3.11/site-packages/ray/tune/logger/comet.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.air.integrations.comet import CometLoggerCallback
|
| 2 |
+
|
| 3 |
+
CometLoggerCallback.__module__ = "ray.tune.logger.comet"
|
.venv/lib/python3.11/site-packages/ray/tune/logger/csv.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import TYPE_CHECKING, Dict, TextIO
|
| 5 |
+
|
| 6 |
+
from ray.air.constants import EXPR_PROGRESS_FILE
|
| 7 |
+
from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
|
| 8 |
+
from ray.tune.utils import flatten_dict
|
| 9 |
+
from ray.util.annotations import Deprecated, PublicAPI
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from ray.tune.experiment.trial import Trial # noqa: F401
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@Deprecated(
|
| 18 |
+
message=_LOGGER_DEPRECATION_WARNING.format(
|
| 19 |
+
old="CSVLogger", new="ray.tune.csv.CSVLoggerCallback"
|
| 20 |
+
),
|
| 21 |
+
warning=True,
|
| 22 |
+
)
|
| 23 |
+
@PublicAPI
|
| 24 |
+
class CSVLogger(Logger):
|
| 25 |
+
"""Logs results to progress.csv under the trial directory.
|
| 26 |
+
|
| 27 |
+
Automatically flattens nested dicts in the result dict before writing
|
| 28 |
+
to csv:
|
| 29 |
+
|
| 30 |
+
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
| 31 |
+
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def _init(self):
|
| 35 |
+
self._initialized = False
|
| 36 |
+
|
| 37 |
+
def _maybe_init(self):
|
| 38 |
+
"""CSV outputted with Headers as first set of results."""
|
| 39 |
+
if not self._initialized:
|
| 40 |
+
progress_file = Path(self.logdir, EXPR_PROGRESS_FILE)
|
| 41 |
+
self._continuing = (
|
| 42 |
+
progress_file.exists() and progress_file.stat().st_size > 0
|
| 43 |
+
)
|
| 44 |
+
self._file = progress_file.open("a")
|
| 45 |
+
self._csv_out = None
|
| 46 |
+
self._initialized = True
|
| 47 |
+
|
| 48 |
+
def on_result(self, result: Dict):
|
| 49 |
+
self._maybe_init()
|
| 50 |
+
|
| 51 |
+
tmp = result.copy()
|
| 52 |
+
if "config" in tmp:
|
| 53 |
+
del tmp["config"]
|
| 54 |
+
result = flatten_dict(tmp, delimiter="/")
|
| 55 |
+
if self._csv_out is None:
|
| 56 |
+
self._csv_out = csv.DictWriter(self._file, result.keys())
|
| 57 |
+
if not self._continuing:
|
| 58 |
+
self._csv_out.writeheader()
|
| 59 |
+
self._csv_out.writerow(
|
| 60 |
+
{k: v for k, v in result.items() if k in self._csv_out.fieldnames}
|
| 61 |
+
)
|
| 62 |
+
self._file.flush()
|
| 63 |
+
|
| 64 |
+
def flush(self):
|
| 65 |
+
if self._initialized and not self._file.closed:
|
| 66 |
+
self._file.flush()
|
| 67 |
+
|
| 68 |
+
def close(self):
|
| 69 |
+
if self._initialized:
|
| 70 |
+
self._file.close()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@PublicAPI
|
| 74 |
+
class CSVLoggerCallback(LoggerCallback):
|
| 75 |
+
"""Logs results to progress.csv under the trial directory.
|
| 76 |
+
|
| 77 |
+
Automatically flattens nested dicts in the result dict before writing
|
| 78 |
+
to csv:
|
| 79 |
+
|
| 80 |
+
{"a": {"b": 1, "c": 2}} -> {"a/b": 1, "a/c": 2}
|
| 81 |
+
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
_SAVED_FILE_TEMPLATES = [EXPR_PROGRESS_FILE]
|
| 85 |
+
|
| 86 |
+
def __init__(self):
|
| 87 |
+
self._trial_continue: Dict["Trial", bool] = {}
|
| 88 |
+
self._trial_files: Dict["Trial", TextIO] = {}
|
| 89 |
+
self._trial_csv: Dict["Trial", csv.DictWriter] = {}
|
| 90 |
+
|
| 91 |
+
def _setup_trial(self, trial: "Trial"):
|
| 92 |
+
if trial in self._trial_files:
|
| 93 |
+
self._trial_files[trial].close()
|
| 94 |
+
|
| 95 |
+
# Make sure logdir exists
|
| 96 |
+
trial.init_local_path()
|
| 97 |
+
local_file_path = Path(trial.local_path, EXPR_PROGRESS_FILE)
|
| 98 |
+
|
| 99 |
+
# Resume the file from remote storage.
|
| 100 |
+
self._restore_from_remote(EXPR_PROGRESS_FILE, trial)
|
| 101 |
+
|
| 102 |
+
self._trial_continue[trial] = (
|
| 103 |
+
local_file_path.exists() and local_file_path.stat().st_size > 0
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
self._trial_files[trial] = local_file_path.open("at")
|
| 107 |
+
self._trial_csv[trial] = None
|
| 108 |
+
|
| 109 |
+
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
| 110 |
+
if trial not in self._trial_files:
|
| 111 |
+
self._setup_trial(trial)
|
| 112 |
+
|
| 113 |
+
tmp = result.copy()
|
| 114 |
+
tmp.pop("config", None)
|
| 115 |
+
result = flatten_dict(tmp, delimiter="/")
|
| 116 |
+
|
| 117 |
+
if not self._trial_csv[trial]:
|
| 118 |
+
self._trial_csv[trial] = csv.DictWriter(
|
| 119 |
+
self._trial_files[trial], result.keys()
|
| 120 |
+
)
|
| 121 |
+
if not self._trial_continue[trial]:
|
| 122 |
+
self._trial_csv[trial].writeheader()
|
| 123 |
+
|
| 124 |
+
self._trial_csv[trial].writerow(
|
| 125 |
+
{k: v for k, v in result.items() if k in self._trial_csv[trial].fieldnames}
|
| 126 |
+
)
|
| 127 |
+
self._trial_files[trial].flush()
|
| 128 |
+
|
| 129 |
+
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
| 130 |
+
if trial not in self._trial_files:
|
| 131 |
+
return
|
| 132 |
+
|
| 133 |
+
del self._trial_csv[trial]
|
| 134 |
+
self._trial_files[trial].close()
|
| 135 |
+
del self._trial_files[trial]
|
.venv/lib/python3.11/site-packages/ray/tune/logger/json.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import TYPE_CHECKING, Dict, TextIO
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
import ray.cloudpickle as cloudpickle
|
| 9 |
+
from ray.air.constants import EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE, EXPR_RESULT_FILE
|
| 10 |
+
from ray.tune.logger.logger import _LOGGER_DEPRECATION_WARNING, Logger, LoggerCallback
|
| 11 |
+
from ray.tune.utils.util import SafeFallbackEncoder
|
| 12 |
+
from ray.util.annotations import Deprecated, PublicAPI
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from ray.tune.experiment.trial import Trial # noqa: F401
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
tf = None
|
| 20 |
+
VALID_SUMMARY_TYPES = [int, float, np.float32, np.float64, np.int32, np.int64]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@Deprecated(
|
| 24 |
+
message=_LOGGER_DEPRECATION_WARNING.format(
|
| 25 |
+
old="JsonLogger", new="ray.tune.json.JsonLoggerCallback"
|
| 26 |
+
),
|
| 27 |
+
warning=True,
|
| 28 |
+
)
|
| 29 |
+
@PublicAPI
|
| 30 |
+
class JsonLogger(Logger):
|
| 31 |
+
"""Logs trial results in json format.
|
| 32 |
+
|
| 33 |
+
Also writes to a results file and param.json file when results or
|
| 34 |
+
configurations are updated. Experiments must be executed with the
|
| 35 |
+
JsonLogger to be compatible with the ExperimentAnalysis tool.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def _init(self):
|
| 39 |
+
self.update_config(self.config)
|
| 40 |
+
local_file = Path(self.logdir, EXPR_RESULT_FILE)
|
| 41 |
+
self.local_out = local_file.open("a")
|
| 42 |
+
|
| 43 |
+
def on_result(self, result: Dict):
|
| 44 |
+
json.dump(result, self, cls=SafeFallbackEncoder)
|
| 45 |
+
self.write("\n")
|
| 46 |
+
self.local_out.flush()
|
| 47 |
+
|
| 48 |
+
def write(self, b):
|
| 49 |
+
self.local_out.write(b)
|
| 50 |
+
|
| 51 |
+
def flush(self):
|
| 52 |
+
if not self.local_out.closed:
|
| 53 |
+
self.local_out.flush()
|
| 54 |
+
|
| 55 |
+
def close(self):
|
| 56 |
+
self.local_out.close()
|
| 57 |
+
|
| 58 |
+
def update_config(self, config: Dict):
|
| 59 |
+
self.config = config
|
| 60 |
+
config_out = Path(self.logdir, EXPR_PARAM_FILE)
|
| 61 |
+
with open(config_out, "w") as f:
|
| 62 |
+
json.dump(self.config, f, indent=2, sort_keys=True, cls=SafeFallbackEncoder)
|
| 63 |
+
config_pkl = Path(self.logdir, EXPR_PARAM_PICKLE_FILE)
|
| 64 |
+
with config_pkl.open("wb") as f:
|
| 65 |
+
cloudpickle.dump(self.config, f)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@PublicAPI
|
| 69 |
+
class JsonLoggerCallback(LoggerCallback):
|
| 70 |
+
"""Logs trial results in json format.
|
| 71 |
+
|
| 72 |
+
Also writes to a results file and param.json file when results or
|
| 73 |
+
configurations are updated. Experiments must be executed with the
|
| 74 |
+
JsonLoggerCallback to be compatible with the ExperimentAnalysis tool.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
_SAVED_FILE_TEMPLATES = [EXPR_RESULT_FILE, EXPR_PARAM_FILE, EXPR_PARAM_PICKLE_FILE]
|
| 78 |
+
|
| 79 |
+
def __init__(self):
|
| 80 |
+
self._trial_configs: Dict["Trial", Dict] = {}
|
| 81 |
+
self._trial_files: Dict["Trial", TextIO] = {}
|
| 82 |
+
|
| 83 |
+
def log_trial_start(self, trial: "Trial"):
|
| 84 |
+
if trial in self._trial_files:
|
| 85 |
+
self._trial_files[trial].close()
|
| 86 |
+
|
| 87 |
+
# Update config
|
| 88 |
+
self.update_config(trial, trial.config)
|
| 89 |
+
|
| 90 |
+
# Make sure logdir exists
|
| 91 |
+
trial.init_local_path()
|
| 92 |
+
local_file = Path(trial.local_path, EXPR_RESULT_FILE)
|
| 93 |
+
|
| 94 |
+
# Resume the file from remote storage.
|
| 95 |
+
self._restore_from_remote(EXPR_RESULT_FILE, trial)
|
| 96 |
+
|
| 97 |
+
self._trial_files[trial] = local_file.open("at")
|
| 98 |
+
|
| 99 |
+
def log_trial_result(self, iteration: int, trial: "Trial", result: Dict):
|
| 100 |
+
if trial not in self._trial_files:
|
| 101 |
+
self.log_trial_start(trial)
|
| 102 |
+
json.dump(result, self._trial_files[trial], cls=SafeFallbackEncoder)
|
| 103 |
+
self._trial_files[trial].write("\n")
|
| 104 |
+
self._trial_files[trial].flush()
|
| 105 |
+
|
| 106 |
+
def log_trial_end(self, trial: "Trial", failed: bool = False):
|
| 107 |
+
if trial not in self._trial_files:
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
self._trial_files[trial].close()
|
| 111 |
+
del self._trial_files[trial]
|
| 112 |
+
|
| 113 |
+
def update_config(self, trial: "Trial", config: Dict):
|
| 114 |
+
self._trial_configs[trial] = config
|
| 115 |
+
|
| 116 |
+
config_out = Path(trial.local_path, EXPR_PARAM_FILE)
|
| 117 |
+
with config_out.open("w") as f:
|
| 118 |
+
json.dump(
|
| 119 |
+
self._trial_configs[trial],
|
| 120 |
+
f,
|
| 121 |
+
indent=2,
|
| 122 |
+
sort_keys=True,
|
| 123 |
+
cls=SafeFallbackEncoder,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
config_pkl = Path(trial.local_path, EXPR_PARAM_PICKLE_FILE)
|
| 127 |
+
with config_pkl.open("wb") as f:
|
| 128 |
+
cloudpickle.dump(self._trial_configs[trial], f)
|