Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/tune/analysis/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/experiment_analysis.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/analysis/experiment_analysis.py +678 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/config.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/out_of_band_serialize_dataset.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/placeholder.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/test_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/tuner_internal.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/config.py +46 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/out_of_band_serialize_dataset.py +33 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/placeholder.py +244 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/test_utils.py +66 -0
- .venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py +669 -0
- .venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/resource_changing_scheduler.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__init__.py +153 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/concurrency_limiter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/sample.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_algorithm.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_generator.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/_mock.py +55 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/ax_search.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/basic_variant.py +421 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/concurrency_limiter.py +176 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/hebo_search.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/nevergrad_search.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/nevergrad_search.py +373 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__init__.py +3 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/optuna_search.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/optuna/optuna_search.py +733 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/repeater.py +199 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/sample.py +742 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/search_algorithm.py +127 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/search_generator.py +222 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/searcher.py +597 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/util.py +31 -0
- .venv/lib/python3.11/site-packages/ray/tune/search/variant_generator.py +523 -0
- .venv/lib/python3.11/site-packages/ray/tune/stopper/__init__.py +18 -0
- .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/experiment_plateau.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/function_stopper.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/tune/analysis/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.analysis.experiment_analysis import ExperimentAnalysis
|
| 2 |
+
|
| 3 |
+
__all__ = ["ExperimentAnalysis"]
|
.venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (310 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/analysis/__pycache__/experiment_analysis.cpython-311.pyc
ADDED
|
Binary file (35.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/analysis/experiment_analysis.py
ADDED
|
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from numbers import Number
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import pyarrow.fs
|
| 11 |
+
|
| 12 |
+
from ray.air.constants import EXPR_PROGRESS_FILE, EXPR_RESULT_FILE, TRAINING_ITERATION
|
| 13 |
+
from ray.train import Checkpoint
|
| 14 |
+
from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
|
| 15 |
+
from ray.tune.execution.experiment_state import _find_newest_experiment_checkpoint
|
| 16 |
+
from ray.tune.execution.tune_controller import TuneController
|
| 17 |
+
from ray.tune.experiment import Trial
|
| 18 |
+
from ray.tune.result import CONFIG_PREFIX, DEFAULT_METRIC
|
| 19 |
+
from ray.tune.utils import flatten_dict
|
| 20 |
+
from ray.tune.utils.serialization import TuneFunctionDecoder
|
| 21 |
+
from ray.tune.utils.util import is_nan, is_nan_or_inf, unflattened_lookup
|
| 22 |
+
from ray.util.annotations import PublicAPI
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import pandas as pd
|
| 26 |
+
from pandas import DataFrame
|
| 27 |
+
except ImportError:
|
| 28 |
+
pd = None
|
| 29 |
+
DataFrame = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@PublicAPI(stability="beta")
|
| 36 |
+
class ExperimentAnalysis:
|
| 37 |
+
"""Analyze results from a Ray Train/Tune experiment.
|
| 38 |
+
|
| 39 |
+
To use this class, the run must store the history of reported metrics
|
| 40 |
+
in log files (e.g., `result.json` and `progress.csv`).
|
| 41 |
+
This is the default behavior, unless default loggers are explicitly excluded
|
| 42 |
+
with the `TUNE_DISABLE_AUTO_CALLBACK_LOGGERS=1` environment variable.
|
| 43 |
+
|
| 44 |
+
Parameters:
|
| 45 |
+
experiment_checkpoint_path: Path to an `experiment_state.json` file,
|
| 46 |
+
or a directory that contains an `experiment_state.json` file.
|
| 47 |
+
default_metric: Default metric for comparing results. Can be
|
| 48 |
+
overwritten with the ``metric`` parameter in the respective
|
| 49 |
+
functions.
|
| 50 |
+
default_mode: Default mode for comparing results. Has to be one
|
| 51 |
+
of [min, max]. Can be overwritten with the ``mode`` parameter
|
| 52 |
+
in the respective functions.
|
| 53 |
+
trials: List of trials that can be accessed via `analysis.trials`.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
experiment_checkpoint_path: Union[str, os.PathLike],
|
| 59 |
+
*,
|
| 60 |
+
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
|
| 61 |
+
trials: Optional[List[Trial]] = None,
|
| 62 |
+
default_metric: Optional[str] = None,
|
| 63 |
+
default_mode: Optional[str] = None,
|
| 64 |
+
):
|
| 65 |
+
self.default_metric = default_metric
|
| 66 |
+
if default_mode and default_mode not in ["min", "max"]:
|
| 67 |
+
raise ValueError("`default_mode` has to be None or one of [min, max]")
|
| 68 |
+
self.default_mode = default_mode
|
| 69 |
+
if self.default_metric is None and self.default_mode is not None:
|
| 70 |
+
# If only a mode was passed, use anonymous metric
|
| 71 |
+
self.default_metric = DEFAULT_METRIC
|
| 72 |
+
|
| 73 |
+
# Resolve the filesystem if not specified.
|
| 74 |
+
if storage_filesystem:
|
| 75 |
+
self._fs = storage_filesystem
|
| 76 |
+
else:
|
| 77 |
+
self._fs, experiment_checkpoint_path = get_fs_and_path(
|
| 78 |
+
experiment_checkpoint_path
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Find the json state file.
|
| 82 |
+
experiment_checkpoint_path = str(experiment_checkpoint_path)
|
| 83 |
+
if experiment_checkpoint_path.endswith(".json"):
|
| 84 |
+
self._experiment_fs_path = os.path.dirname(experiment_checkpoint_path)
|
| 85 |
+
self._experiment_json_fs_path = experiment_checkpoint_path
|
| 86 |
+
else:
|
| 87 |
+
self._experiment_fs_path = experiment_checkpoint_path
|
| 88 |
+
|
| 89 |
+
experiment_json_fs_path = _find_newest_experiment_checkpoint(
|
| 90 |
+
experiment_path=self._experiment_fs_path, fs=self._fs
|
| 91 |
+
)
|
| 92 |
+
if experiment_json_fs_path is None:
|
| 93 |
+
pattern = TuneController.CKPT_FILE_TMPL.format("*")
|
| 94 |
+
raise ValueError(
|
| 95 |
+
f"No experiment snapshot file of form '{pattern}' was found at: "
|
| 96 |
+
f"({self._fs.type_name}, {self._experiment_fs_path})\n"
|
| 97 |
+
"Please check if you specified the correct experiment path, "
|
| 98 |
+
"which should be a combination of the `storage_path` and `name` "
|
| 99 |
+
"specified in your run."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
self._experiment_json_fs_path = experiment_json_fs_path
|
| 103 |
+
|
| 104 |
+
self.trials = trials or self._load_trials()
|
| 105 |
+
self._trial_dataframes = self._fetch_trial_dataframes()
|
| 106 |
+
self._configs = self.get_all_configs()
|
| 107 |
+
|
| 108 |
+
def _load_trials(self) -> List[Trial]:
|
| 109 |
+
with self._fs.open_input_stream(self._experiment_json_fs_path) as f:
|
| 110 |
+
experiment_state = json.loads(f.readall(), cls=TuneFunctionDecoder)
|
| 111 |
+
|
| 112 |
+
experiment_fs_path = Path(self._experiment_fs_path)
|
| 113 |
+
|
| 114 |
+
trials = []
|
| 115 |
+
trial_states = experiment_state["trial_data"]
|
| 116 |
+
for trial_json_state, trial_runtime_metadata in trial_states:
|
| 117 |
+
trial = Trial.from_json_state(trial_json_state, stub=True)
|
| 118 |
+
trial.restore_run_metadata(trial_runtime_metadata)
|
| 119 |
+
|
| 120 |
+
new_storage = copy.copy(trial.storage)
|
| 121 |
+
new_storage.storage_fs_path = experiment_fs_path.parent.as_posix()
|
| 122 |
+
new_storage.storage_filesystem = self._fs
|
| 123 |
+
new_storage.experiment_dir_name = experiment_fs_path.name
|
| 124 |
+
trial.set_storage(new_storage)
|
| 125 |
+
|
| 126 |
+
trials.append(trial)
|
| 127 |
+
return trials
|
| 128 |
+
|
| 129 |
+
def _fetch_trial_dataframe(self, trial: Trial) -> DataFrame:
|
| 130 |
+
force_dtype = {"trial_id": str} # Never convert trial_id to float.
|
| 131 |
+
|
| 132 |
+
# If there were no reported results, there will be no files into a DataFrame
|
| 133 |
+
if trial.last_result is None:
|
| 134 |
+
return DataFrame()
|
| 135 |
+
|
| 136 |
+
json_fs_path = Path(trial.storage.trial_fs_path, EXPR_RESULT_FILE).as_posix()
|
| 137 |
+
csv_fs_path = Path(trial.storage.trial_fs_path, EXPR_PROGRESS_FILE).as_posix()
|
| 138 |
+
# Prefer reading the JSON if it exists.
|
| 139 |
+
if _exists_at_fs_path(trial.storage.storage_filesystem, json_fs_path):
|
| 140 |
+
with trial.storage.storage_filesystem.open_input_stream(json_fs_path) as f:
|
| 141 |
+
content = f.readall().decode("utf-8").rstrip("\n")
|
| 142 |
+
if not content:
|
| 143 |
+
return DataFrame()
|
| 144 |
+
json_list = [json.loads(row) for row in content.split("\n")]
|
| 145 |
+
df = pd.json_normalize(json_list, sep="/")
|
| 146 |
+
# Fallback to reading the CSV.
|
| 147 |
+
elif _exists_at_fs_path(trial.storage.storage_filesystem, csv_fs_path):
|
| 148 |
+
with trial.storage.storage_filesystem.open_input_stream(csv_fs_path) as f:
|
| 149 |
+
csv_str = f.readall().decode("utf-8")
|
| 150 |
+
df = pd.read_csv(io.StringIO(csv_str), dtype=force_dtype)
|
| 151 |
+
else:
|
| 152 |
+
raise FileNotFoundError(
|
| 153 |
+
f"Could not fetch metrics for {trial}: both {EXPR_RESULT_FILE} and "
|
| 154 |
+
f"{EXPR_PROGRESS_FILE} were not found at {trial.storage.trial_fs_path}"
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
return df
|
| 158 |
+
|
| 159 |
+
def _fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
|
| 160 |
+
"""Fetches trial dataframes from files.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
A dictionary mapping trial_id -> pd.DataFrame
|
| 164 |
+
"""
|
| 165 |
+
failures = []
|
| 166 |
+
|
| 167 |
+
trial_dfs = {}
|
| 168 |
+
for trial in self.trials:
|
| 169 |
+
try:
|
| 170 |
+
trial_dfs[trial.trial_id] = self._fetch_trial_dataframe(trial)
|
| 171 |
+
except Exception as e:
|
| 172 |
+
failures.append((trial, e))
|
| 173 |
+
trial_dfs[trial.trial_id] = DataFrame()
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
if failures:
|
| 177 |
+
fail_str = "\n".join(
|
| 178 |
+
[f"- {trial}: {repr(error)}" for trial, error in failures]
|
| 179 |
+
)
|
| 180 |
+
logger.warning(
|
| 181 |
+
f"Failed to fetch metrics for {len(failures)} trial(s):\n{fail_str}"
|
| 182 |
+
)
|
| 183 |
+
return trial_dfs
|
| 184 |
+
|
| 185 |
+
def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
|
| 186 |
+
"""Returns all trial hyperparameter configurations.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
prefix: If True, flattens the config dict
|
| 190 |
+
and prepends `config/`.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
Dict[str, Dict]: Mapping trial_id -> config dict
|
| 194 |
+
"""
|
| 195 |
+
return {
|
| 196 |
+
trial.trial_id: (
|
| 197 |
+
flatten_dict({CONFIG_PREFIX: trial.config}) if prefix else trial.config
|
| 198 |
+
)
|
| 199 |
+
for trial in self.trials
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
@property
|
| 203 |
+
def experiment_path(self) -> str:
|
| 204 |
+
"""Path pointing to the experiment directory on persistent storage.
|
| 205 |
+
|
| 206 |
+
This can point to a remote storage location (e.g. S3) or to a local
|
| 207 |
+
location (path on the head node)."""
|
| 208 |
+
return self._experiment_fs_path
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def best_trial(self) -> Trial:
|
| 212 |
+
"""Get the best trial of the experiment
|
| 213 |
+
|
| 214 |
+
The best trial is determined by comparing the last trial results
|
| 215 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 216 |
+
|
| 217 |
+
If you didn't pass these parameters, use
|
| 218 |
+
`get_best_trial(metric, mode, scope)` instead.
|
| 219 |
+
"""
|
| 220 |
+
if not self.default_metric or not self.default_mode:
|
| 221 |
+
raise ValueError(
|
| 222 |
+
"To fetch the `best_trial`, pass a `metric` and `mode` "
|
| 223 |
+
"parameter to `tune.run()`. Alternatively, use the "
|
| 224 |
+
"`get_best_trial(metric, mode)` method to set the metric "
|
| 225 |
+
"and mode explicitly."
|
| 226 |
+
)
|
| 227 |
+
return self.get_best_trial(self.default_metric, self.default_mode)
|
| 228 |
+
|
| 229 |
+
@property
|
| 230 |
+
def best_config(self) -> Dict:
|
| 231 |
+
"""Get the config of the best trial of the experiment
|
| 232 |
+
|
| 233 |
+
The best trial is determined by comparing the last trial results
|
| 234 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 235 |
+
|
| 236 |
+
If you didn't pass these parameters, use
|
| 237 |
+
`get_best_config(metric, mode, scope)` instead.
|
| 238 |
+
"""
|
| 239 |
+
if not self.default_metric or not self.default_mode:
|
| 240 |
+
raise ValueError(
|
| 241 |
+
"To fetch the `best_config`, pass a `metric` and `mode` "
|
| 242 |
+
"parameter to `tune.run()`. Alternatively, use the "
|
| 243 |
+
"`get_best_config(metric, mode)` method to set the metric "
|
| 244 |
+
"and mode explicitly."
|
| 245 |
+
)
|
| 246 |
+
return self.get_best_config(self.default_metric, self.default_mode)
|
| 247 |
+
|
| 248 |
+
@property
|
| 249 |
+
def best_checkpoint(self) -> Checkpoint:
|
| 250 |
+
"""Get the checkpoint path of the best trial of the experiment
|
| 251 |
+
|
| 252 |
+
The best trial is determined by comparing the last trial results
|
| 253 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 254 |
+
|
| 255 |
+
If you didn't pass these parameters, use
|
| 256 |
+
`get_best_checkpoint(trial, metric, mode)` instead.
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
:class:`Checkpoint <ray.train.Checkpoint>` object.
|
| 260 |
+
"""
|
| 261 |
+
if not self.default_metric or not self.default_mode:
|
| 262 |
+
raise ValueError(
|
| 263 |
+
"To fetch the `best_checkpoint`, pass a `metric` and `mode` "
|
| 264 |
+
"parameter to `tune.run()`. Alternatively, use the "
|
| 265 |
+
"`get_best_checkpoint(trial, metric, mode)` method to set the "
|
| 266 |
+
"metric and mode explicitly."
|
| 267 |
+
)
|
| 268 |
+
best_trial = self.best_trial
|
| 269 |
+
if not best_trial:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"No best trial found. Please check if you specified the "
|
| 272 |
+
f"correct default metric ({self.default_metric}) and mode "
|
| 273 |
+
f"({self.default_mode})."
|
| 274 |
+
)
|
| 275 |
+
return self.get_best_checkpoint(
|
| 276 |
+
best_trial, self.default_metric, self.default_mode
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
@property
|
| 280 |
+
def best_dataframe(self) -> DataFrame:
|
| 281 |
+
"""Get the full result dataframe of the best trial of the experiment
|
| 282 |
+
|
| 283 |
+
The best trial is determined by comparing the last trial results
|
| 284 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 285 |
+
|
| 286 |
+
If you didn't pass these parameters, use
|
| 287 |
+
`get_best_trial(metric, mode)` and use it to look for the dataframe
|
| 288 |
+
in the `self.trial_dataframes` dict.
|
| 289 |
+
"""
|
| 290 |
+
if not self.default_metric or not self.default_mode:
|
| 291 |
+
raise ValueError(
|
| 292 |
+
"To fetch the `best_result`, pass a `metric` and `mode` "
|
| 293 |
+
"parameter to `tune.run()`."
|
| 294 |
+
)
|
| 295 |
+
return self.trial_dataframes[self.best_trial.trial_id]
|
| 296 |
+
|
| 297 |
+
@property
|
| 298 |
+
def best_result(self) -> Dict:
|
| 299 |
+
"""Get the last result of the best trial of the experiment
|
| 300 |
+
|
| 301 |
+
The best trial is determined by comparing the last trial results
|
| 302 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 303 |
+
|
| 304 |
+
If you didn't pass these parameters, use
|
| 305 |
+
`get_best_trial(metric, mode, scope).last_result` instead.
|
| 306 |
+
"""
|
| 307 |
+
if not self.default_metric or not self.default_mode:
|
| 308 |
+
raise ValueError(
|
| 309 |
+
"To fetch the `best_result`, pass a `metric` and `mode` "
|
| 310 |
+
"parameter to `tune.run()`. Alternatively, use "
|
| 311 |
+
"`get_best_trial(metric, mode).last_result` to set "
|
| 312 |
+
"the metric and mode explicitly and fetch the last result."
|
| 313 |
+
)
|
| 314 |
+
return self.best_trial.last_result
|
| 315 |
+
|
| 316 |
+
def _delimiter(self):
|
| 317 |
+
return os.environ.get("TUNE_RESULT_DELIM", "/")
|
| 318 |
+
|
| 319 |
+
@property
|
| 320 |
+
def best_result_df(self) -> DataFrame:
|
| 321 |
+
"""Get the best result of the experiment as a pandas dataframe.
|
| 322 |
+
|
| 323 |
+
The best trial is determined by comparing the last trial results
|
| 324 |
+
using the `metric` and `mode` parameters passed to `tune.run()`.
|
| 325 |
+
|
| 326 |
+
If you didn't pass these parameters, use
|
| 327 |
+
`get_best_trial(metric, mode, scope).last_result` instead.
|
| 328 |
+
"""
|
| 329 |
+
if not pd:
|
| 330 |
+
raise ValueError(
|
| 331 |
+
"`best_result_df` requires pandas. Install with "
|
| 332 |
+
"`pip install pandas`."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
best_result = flatten_dict(self.best_result, delimiter=self._delimiter())
|
| 336 |
+
return pd.DataFrame.from_records([best_result], index="trial_id")
|
| 337 |
+
|
| 338 |
+
@property
|
| 339 |
+
def results(self) -> Dict[str, Dict]:
|
| 340 |
+
"""Get the last result of the all trials of the experiment"""
|
| 341 |
+
return {trial.trial_id: trial.last_result for trial in self.trials}
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def results_df(self) -> DataFrame:
|
| 345 |
+
"""Get all the last results as a pandas dataframe."""
|
| 346 |
+
if not pd:
|
| 347 |
+
raise ValueError(
|
| 348 |
+
"`results_df` requires pandas. Install with `pip install pandas`."
|
| 349 |
+
)
|
| 350 |
+
return pd.DataFrame.from_records(
|
| 351 |
+
[
|
| 352 |
+
flatten_dict(trial.last_result, delimiter=self._delimiter())
|
| 353 |
+
for trial in self.trials
|
| 354 |
+
],
|
| 355 |
+
index="trial_id",
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
@property
|
| 359 |
+
def trial_dataframes(self) -> Dict[str, DataFrame]:
|
| 360 |
+
"""List of all dataframes of the trials.
|
| 361 |
+
|
| 362 |
+
Each dataframe is indexed by iterations and contains reported
|
| 363 |
+
metrics.
|
| 364 |
+
"""
|
| 365 |
+
return self._trial_dataframes
|
| 366 |
+
|
| 367 |
+
def dataframe(
|
| 368 |
+
self, metric: Optional[str] = None, mode: Optional[str] = None
|
| 369 |
+
) -> DataFrame:
|
| 370 |
+
"""Returns a pandas.DataFrame object constructed from the trials.
|
| 371 |
+
|
| 372 |
+
This function will look through all observed results of each trial
|
| 373 |
+
and return the one corresponding to the passed ``metric`` and
|
| 374 |
+
``mode``: If ``mode=min``, it returns the result with the lowest
|
| 375 |
+
*ever* observed ``metric`` for this trial (this is not necessarily
|
| 376 |
+
the last)! For ``mode=max``, it's the highest, respectively. If
|
| 377 |
+
``metric=None`` or ``mode=None``, the last result will be returned.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
metric: Key for trial info to order on. If None, uses last result.
|
| 381 |
+
mode: One of [None, "min", "max"].
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
pd.DataFrame: Constructed from a result dict of each trial.
|
| 385 |
+
"""
|
| 386 |
+
# Do not validate metric/mode here or set from default metric/mode!
|
| 387 |
+
# Otherwise we will get confusing results as the lowest ever observed
|
| 388 |
+
# result may not be the last result.
|
| 389 |
+
if mode and mode not in ["min", "max"]:
|
| 390 |
+
raise ValueError("If set, `mode` has to be one of [min, max]")
|
| 391 |
+
|
| 392 |
+
if mode and not metric:
|
| 393 |
+
raise ValueError(
|
| 394 |
+
"If a `mode` is passed to `ExperimentAnalysis.dataframe(),"
|
| 395 |
+
" you'll also have to pass a `metric`!"
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
rows = self._retrieve_rows(metric=metric, mode=mode)
|
| 399 |
+
all_configs = self.get_all_configs(prefix=True)
|
| 400 |
+
for path, config in all_configs.items():
|
| 401 |
+
if path in rows:
|
| 402 |
+
rows[path].update(config)
|
| 403 |
+
rows[path].update(logdir=path)
|
| 404 |
+
return pd.DataFrame(list(rows.values()))
|
| 405 |
+
|
| 406 |
+
def _get_trial_checkpoints_with_metric(
|
| 407 |
+
self, trial: Trial, metric: Optional[str] = None
|
| 408 |
+
) -> List[Tuple[Checkpoint, Number]]:
|
| 409 |
+
"""Get all checkpoints and a specified metric of a trial.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
trial: The log directory of a trial, or a trial instance.
|
| 413 |
+
metric: key for trial info to return, e.g. "mean_accuracy".
|
| 414 |
+
"training_iteration" is used by default if no value was
|
| 415 |
+
passed to ``self.default_metric``.
|
| 416 |
+
|
| 417 |
+
Returns:
|
| 418 |
+
List of [Checkpoint, metric] for all checkpoints of the trial.
|
| 419 |
+
"""
|
| 420 |
+
metric = metric or self.default_metric or TRAINING_ITERATION
|
| 421 |
+
|
| 422 |
+
best_checkpoint_results = (
|
| 423 |
+
trial.run_metadata.checkpoint_manager.best_checkpoint_results
|
| 424 |
+
)
|
| 425 |
+
best_checkpoints = [
|
| 426 |
+
(checkpoint_result.checkpoint, checkpoint_result.metrics)
|
| 427 |
+
for checkpoint_result in best_checkpoint_results
|
| 428 |
+
]
|
| 429 |
+
# Support nested metrics given as flattened strings, e.g.
|
| 430 |
+
# "info/learner/default_policy/policy_loss".
|
| 431 |
+
return [
|
| 432 |
+
(checkpoint, unflattened_lookup(metric, metrics))
|
| 433 |
+
for checkpoint, metrics in best_checkpoints
|
| 434 |
+
]
|
| 435 |
+
|
| 436 |
+
def get_best_checkpoint(
|
| 437 |
+
self,
|
| 438 |
+
trial: Trial,
|
| 439 |
+
metric: Optional[str] = None,
|
| 440 |
+
mode: Optional[str] = None,
|
| 441 |
+
) -> Optional[Checkpoint]:
|
| 442 |
+
"""Gets best persistent checkpoint path of provided trial.
|
| 443 |
+
|
| 444 |
+
Any checkpoints with an associated metric value of ``nan`` will be filtered out.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
trial: The log directory of a trial, or a trial instance.
|
| 448 |
+
metric: key of trial info to return, e.g. "mean_accuracy".
|
| 449 |
+
"training_iteration" is used by default if no value was
|
| 450 |
+
passed to ``self.default_metric``.
|
| 451 |
+
mode: One of [min, max]. Defaults to ``self.default_mode``.
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
A :class:`Checkpoint <ray.train.Checkpoint>` object
|
| 455 |
+
"""
|
| 456 |
+
metric = metric or self.default_metric or TRAINING_ITERATION
|
| 457 |
+
mode = self._validate_mode(mode)
|
| 458 |
+
|
| 459 |
+
checkpoints_and_metrics = self._get_trial_checkpoints_with_metric(trial, metric)
|
| 460 |
+
|
| 461 |
+
# Filter out nan. Sorting nan values leads to undefined behavior.
|
| 462 |
+
checkpoints_and_metrics = list(
|
| 463 |
+
filter(lambda x: not is_nan(x[1]), checkpoints_and_metrics)
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
if not checkpoints_and_metrics:
|
| 467 |
+
logger.error(f"No checkpoints have been found for trial {trial}.")
|
| 468 |
+
return None
|
| 469 |
+
|
| 470 |
+
score_order_factor = -1 if mode == "min" else 1
|
| 471 |
+
best_checkpoint, _ = max(
|
| 472 |
+
checkpoints_and_metrics, key=lambda x: score_order_factor * x[1]
|
| 473 |
+
)
|
| 474 |
+
return best_checkpoint
|
| 475 |
+
|
| 476 |
+
def get_best_trial(
|
| 477 |
+
self,
|
| 478 |
+
metric: Optional[str] = None,
|
| 479 |
+
mode: Optional[str] = None,
|
| 480 |
+
scope: str = "last",
|
| 481 |
+
filter_nan_and_inf: bool = True,
|
| 482 |
+
) -> Optional[Trial]:
|
| 483 |
+
"""Retrieve the best trial object.
|
| 484 |
+
|
| 485 |
+
Compares all trials' scores on ``metric``.
|
| 486 |
+
If ``metric`` is not specified, ``self.default_metric`` will be used.
|
| 487 |
+
If `mode` is not specified, ``self.default_mode`` will be used.
|
| 488 |
+
These values are usually initialized by passing the ``metric`` and
|
| 489 |
+
``mode`` parameters to ``tune.run()``.
|
| 490 |
+
|
| 491 |
+
Args:
|
| 492 |
+
metric: Key for trial info to order on. Defaults to
|
| 493 |
+
``self.default_metric``.
|
| 494 |
+
mode: One of [min, max]. Defaults to ``self.default_mode``.
|
| 495 |
+
scope: One of [all, last, avg, last-5-avg, last-10-avg].
|
| 496 |
+
If `scope=last`, only look at each trial's final step for
|
| 497 |
+
`metric`, and compare across trials based on `mode=[min,max]`.
|
| 498 |
+
If `scope=avg`, consider the simple average over all steps
|
| 499 |
+
for `metric` and compare across trials based on
|
| 500 |
+
`mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
|
| 501 |
+
consider the simple average over the last 5 or 10 steps for
|
| 502 |
+
`metric` and compare across trials based on `mode=[min,max]`.
|
| 503 |
+
If `scope=all`, find each trial's min/max score for `metric`
|
| 504 |
+
based on `mode`, and compare trials based on `mode=[min,max]`.
|
| 505 |
+
filter_nan_and_inf: If True (default), NaN or infinite
|
| 506 |
+
values are disregarded and these trials are never selected as
|
| 507 |
+
the best trial.
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
The best trial for the provided metric. If no trials contain the provided
|
| 511 |
+
metric, or if the value for the metric is NaN for all trials,
|
| 512 |
+
then returns None.
|
| 513 |
+
"""
|
| 514 |
+
if len(self.trials) == 1:
|
| 515 |
+
return self.trials[0]
|
| 516 |
+
|
| 517 |
+
metric = self._validate_metric(metric)
|
| 518 |
+
mode = self._validate_mode(mode)
|
| 519 |
+
|
| 520 |
+
if scope not in ["all", "last", "avg", "last-5-avg", "last-10-avg"]:
|
| 521 |
+
raise ValueError(
|
| 522 |
+
"ExperimentAnalysis: attempting to get best trial for "
|
| 523 |
+
'metric {} for scope {} not in ["all", "last", "avg", '
|
| 524 |
+
'"last-5-avg", "last-10-avg"]. '
|
| 525 |
+
"If you didn't pass a `metric` parameter to `tune.run()`, "
|
| 526 |
+
"you have to pass one when fetching the best trial.".format(
|
| 527 |
+
metric, scope
|
| 528 |
+
)
|
| 529 |
+
)
|
| 530 |
+
best_trial = None
|
| 531 |
+
best_metric_score = None
|
| 532 |
+
|
| 533 |
+
for trial in self.trials:
|
| 534 |
+
if metric not in trial.metric_analysis:
|
| 535 |
+
continue
|
| 536 |
+
|
| 537 |
+
if scope in ["last", "avg", "last-5-avg", "last-10-avg"]:
|
| 538 |
+
metric_score = trial.metric_analysis[metric][scope]
|
| 539 |
+
else:
|
| 540 |
+
metric_score = trial.metric_analysis[metric][mode]
|
| 541 |
+
|
| 542 |
+
if filter_nan_and_inf and is_nan_or_inf(metric_score):
|
| 543 |
+
continue
|
| 544 |
+
|
| 545 |
+
if best_metric_score is None:
|
| 546 |
+
best_metric_score = metric_score
|
| 547 |
+
best_trial = trial
|
| 548 |
+
continue
|
| 549 |
+
|
| 550 |
+
if (mode == "max") and (best_metric_score < metric_score):
|
| 551 |
+
best_metric_score = metric_score
|
| 552 |
+
best_trial = trial
|
| 553 |
+
elif (mode == "min") and (best_metric_score > metric_score):
|
| 554 |
+
best_metric_score = metric_score
|
| 555 |
+
best_trial = trial
|
| 556 |
+
|
| 557 |
+
if not best_trial:
|
| 558 |
+
logger.warning(
|
| 559 |
+
"Could not find best trial. Did you pass the correct `metric` "
|
| 560 |
+
"parameter?"
|
| 561 |
+
)
|
| 562 |
+
return best_trial
|
| 563 |
+
|
| 564 |
+
def get_best_config(
|
| 565 |
+
self,
|
| 566 |
+
metric: Optional[str] = None,
|
| 567 |
+
mode: Optional[str] = None,
|
| 568 |
+
scope: str = "last",
|
| 569 |
+
) -> Optional[Dict]:
|
| 570 |
+
"""Retrieve the best config corresponding to the trial.
|
| 571 |
+
|
| 572 |
+
Compares all trials' scores on `metric`.
|
| 573 |
+
If ``metric`` is not specified, ``self.default_metric`` will be used.
|
| 574 |
+
If `mode` is not specified, ``self.default_mode`` will be used.
|
| 575 |
+
These values are usually initialized by passing the ``metric`` and
|
| 576 |
+
``mode`` parameters to ``tune.run()``.
|
| 577 |
+
|
| 578 |
+
Args:
|
| 579 |
+
metric: Key for trial info to order on. Defaults to
|
| 580 |
+
``self.default_metric``.
|
| 581 |
+
mode: One of [min, max]. Defaults to ``self.default_mode``.
|
| 582 |
+
scope: One of [all, last, avg, last-5-avg, last-10-avg].
|
| 583 |
+
If `scope=last`, only look at each trial's final step for
|
| 584 |
+
`metric`, and compare across trials based on `mode=[min,max]`.
|
| 585 |
+
If `scope=avg`, consider the simple average over all steps
|
| 586 |
+
for `metric` and compare across trials based on
|
| 587 |
+
`mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
|
| 588 |
+
consider the simple average over the last 5 or 10 steps for
|
| 589 |
+
`metric` and compare across trials based on `mode=[min,max]`.
|
| 590 |
+
If `scope=all`, find each trial's min/max score for `metric`
|
| 591 |
+
based on `mode`, and compare trials based on `mode=[min,max]`.
|
| 592 |
+
"""
|
| 593 |
+
best_trial = self.get_best_trial(metric, mode, scope)
|
| 594 |
+
return best_trial.config if best_trial else None
|
| 595 |
+
|
| 596 |
+
def get_last_checkpoint(
|
| 597 |
+
self, trial=None, metric="training_iteration", mode="max"
|
| 598 |
+
) -> Optional[Checkpoint]:
|
| 599 |
+
"""Gets the last checkpoint of the provided trial,
|
| 600 |
+
i.e., with the highest "training_iteration".
|
| 601 |
+
|
| 602 |
+
If no trial is specified, it loads the best trial according to the
|
| 603 |
+
provided metric and mode (defaults to max. training iteration).
|
| 604 |
+
|
| 605 |
+
Args:
|
| 606 |
+
trial: If None, load the best trial automatically.
|
| 607 |
+
metric: If no trial is specified, use this metric to identify
|
| 608 |
+
the best trial and load the last checkpoint from this trial.
|
| 609 |
+
mode: If no trial is specified, use the metric and this mode
|
| 610 |
+
to identify the best trial and load the last checkpoint from it.
|
| 611 |
+
|
| 612 |
+
Returns:
|
| 613 |
+
Path for last checkpoint of trial
|
| 614 |
+
"""
|
| 615 |
+
trial = trial or self.get_best_trial(metric, mode)
|
| 616 |
+
return self.get_best_checkpoint(trial, TRAINING_ITERATION, "max")
|
| 617 |
+
|
| 618 |
+
def _validate_metric(self, metric: str) -> str:
|
| 619 |
+
if not metric and not self.default_metric:
|
| 620 |
+
raise ValueError(
|
| 621 |
+
"No `metric` has been passed and `default_metric` has "
|
| 622 |
+
"not been set. Please specify the `metric` parameter."
|
| 623 |
+
)
|
| 624 |
+
return metric or self.default_metric
|
| 625 |
+
|
| 626 |
+
def _validate_mode(self, mode: str) -> str:
|
| 627 |
+
if not mode and not self.default_mode:
|
| 628 |
+
raise ValueError(
|
| 629 |
+
"No `mode` has been passed and `default_mode` has "
|
| 630 |
+
"not been set. Please specify the `mode` parameter."
|
| 631 |
+
)
|
| 632 |
+
if mode and mode not in ["min", "max"]:
|
| 633 |
+
raise ValueError("If set, `mode` has to be one of [min, max]")
|
| 634 |
+
return mode or self.default_mode
|
| 635 |
+
|
| 636 |
+
def _retrieve_rows(
|
| 637 |
+
self, metric: Optional[str] = None, mode: Optional[str] = None
|
| 638 |
+
) -> Dict[str, Any]:
|
| 639 |
+
assert mode is None or mode in ["max", "min"]
|
| 640 |
+
assert not mode or metric
|
| 641 |
+
rows = {}
|
| 642 |
+
for path, df in self.trial_dataframes.items():
|
| 643 |
+
if df.empty:
|
| 644 |
+
continue
|
| 645 |
+
if metric not in df:
|
| 646 |
+
idx = -1
|
| 647 |
+
elif mode == "max":
|
| 648 |
+
idx = df[metric].idxmax()
|
| 649 |
+
elif mode == "min":
|
| 650 |
+
idx = df[metric].idxmin()
|
| 651 |
+
else:
|
| 652 |
+
idx = -1
|
| 653 |
+
try:
|
| 654 |
+
rows[path] = df.iloc[idx].to_dict()
|
| 655 |
+
except TypeError:
|
| 656 |
+
# idx is nan
|
| 657 |
+
logger.warning(
|
| 658 |
+
"Warning: Non-numerical value(s) encountered for {}".format(path)
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
return rows
|
| 662 |
+
|
| 663 |
+
def __getstate__(self) -> Dict[str, Any]:
|
| 664 |
+
"""Ensure that trials are marked as stubs when pickling,
|
| 665 |
+
so that they can be loaded later without the trainable
|
| 666 |
+
being registered.
|
| 667 |
+
"""
|
| 668 |
+
state = self.__dict__.copy()
|
| 669 |
+
|
| 670 |
+
def make_stub_if_needed(trial: Trial) -> Trial:
|
| 671 |
+
if trial.stub:
|
| 672 |
+
return trial
|
| 673 |
+
trial_copy = Trial(trial.trainable_name, stub=True)
|
| 674 |
+
trial_copy.__setstate__(trial.__getstate__())
|
| 675 |
+
return trial_copy
|
| 676 |
+
|
| 677 |
+
state["trials"] = [make_stub_if_needed(t) for t in state["trials"]]
|
| 678 |
+
return state
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (186 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (2.63 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/out_of_band_serialize_dataset.cpython-311.pyc
ADDED
|
Binary file (2.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/placeholder.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/test_utils.cpython-311.pyc
ADDED
|
Binary file (4.84 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/__pycache__/tuner_internal.cpython-311.pyc
ADDED
|
Binary file (28.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/impl/config.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from ray.air.config import CheckpointConfig as _CheckpointConfig
|
| 4 |
+
from ray.air.config import FailureConfig as _FailureConfig
|
| 5 |
+
from ray.air.config import RunConfig as _RunConfig
|
| 6 |
+
from ray.train.constants import _v2_migration_warnings_enabled
|
| 7 |
+
from ray.train.utils import _copy_doc, _log_deprecation_warning
|
| 8 |
+
|
| 9 |
+
# NOTE: This is just a pass-through wrapper around `ray.train.RunConfig`
|
| 10 |
+
# in order to detect whether the import module was correct (e.g. `ray.tune.RunConfig`).
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
@_copy_doc(_CheckpointConfig)
|
| 15 |
+
class CheckpointConfig(_CheckpointConfig):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
@_copy_doc(_FailureConfig)
|
| 21 |
+
class FailureConfig(_FailureConfig):
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
@_copy_doc(_RunConfig)
|
| 27 |
+
class RunConfig(_RunConfig):
|
| 28 |
+
def __post_init__(self):
|
| 29 |
+
self.checkpoint_config = self.checkpoint_config or CheckpointConfig()
|
| 30 |
+
self.failure_config = self.failure_config or FailureConfig()
|
| 31 |
+
|
| 32 |
+
super().__post_init__()
|
| 33 |
+
|
| 34 |
+
if not isinstance(self.checkpoint_config, CheckpointConfig):
|
| 35 |
+
if _v2_migration_warnings_enabled():
|
| 36 |
+
_log_deprecation_warning(
|
| 37 |
+
"The `CheckpointConfig` class should be imported from `ray.tune` "
|
| 38 |
+
"when passing it to the Tuner. Please update your imports."
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
if not isinstance(self.failure_config, FailureConfig):
|
| 42 |
+
if _v2_migration_warnings_enabled():
|
| 43 |
+
_log_deprecation_warning(
|
| 44 |
+
"The `FailureConfig` class should be imported from `ray.tune` "
|
| 45 |
+
"when passing it to the Tuner. Please update your imports."
|
| 46 |
+
)
|
.venv/lib/python3.11/site-packages/ray/tune/impl/out_of_band_serialize_dataset.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import traceback
|
| 3 |
+
|
| 4 |
+
import ray
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _deserialize_and_fully_execute_if_needed(serialized_ds: bytes):
|
| 8 |
+
ds = ray.data.Dataset.deserialize_lineage(serialized_ds)
|
| 9 |
+
return ds
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _reduce(ds: ray.data.Dataset):
|
| 13 |
+
tb_list = traceback.format_list(traceback.extract_stack())
|
| 14 |
+
_already_in_out_of_band_serialization = False
|
| 15 |
+
for tb in tb_list:
|
| 16 |
+
# TODO(xwjiang): Let's make this less hacky.
|
| 17 |
+
if "serialize_lineage" in tb:
|
| 18 |
+
_already_in_out_of_band_serialization = True
|
| 19 |
+
break
|
| 20 |
+
if not _already_in_out_of_band_serialization and ds.has_serializable_lineage():
|
| 21 |
+
return _deserialize_and_fully_execute_if_needed, (ds.serialize_lineage(),)
|
| 22 |
+
else:
|
| 23 |
+
return ds.__reduce__()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@contextlib.contextmanager
|
| 27 |
+
def out_of_band_serialize_dataset():
|
| 28 |
+
context = ray._private.worker.global_worker.get_serialization_context()
|
| 29 |
+
try:
|
| 30 |
+
context._register_cloudpickle_reducer(ray.data.Dataset, _reduce)
|
| 31 |
+
yield
|
| 32 |
+
finally:
|
| 33 |
+
context._unregister_cloudpickle_reducer(ray.data.Dataset)
|
.venv/lib/python3.11/site-packages/ray/tune/impl/placeholder.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Any, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
from ray.tune.search.sample import Categorical, Domain, Function
|
| 6 |
+
from ray.tune.search.variant_generator import assign_value
|
| 7 |
+
from ray.util.annotations import DeveloperAPI
|
| 8 |
+
|
| 9 |
+
ID_HASH_LENGTH = 8
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_resolvers_map():
|
| 13 |
+
return defaultdict(list)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _id_hash(path_tuple):
|
| 17 |
+
"""Compute a hash for the specific placeholder based on its path."""
|
| 18 |
+
return hashlib.sha1(str(path_tuple).encode("utf-8")).hexdigest()[:ID_HASH_LENGTH]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _FunctionResolver:
|
| 22 |
+
"""Replaced value for function typed objects."""
|
| 23 |
+
|
| 24 |
+
TOKEN = "__fn_ph"
|
| 25 |
+
|
| 26 |
+
def __init__(self, hash, fn):
|
| 27 |
+
self.hash = hash
|
| 28 |
+
self._fn = fn
|
| 29 |
+
|
| 30 |
+
def resolve(self, config: Dict):
|
| 31 |
+
"""Some functions take a resolved spec dict as input.
|
| 32 |
+
|
| 33 |
+
Note: Function placeholders are independently sampled during
|
| 34 |
+
resolution. Therefore their random states are not restored.
|
| 35 |
+
"""
|
| 36 |
+
return self._fn.sample(config=config)
|
| 37 |
+
|
| 38 |
+
def get_placeholder(self) -> str:
|
| 39 |
+
return (self.TOKEN, self.hash)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class _RefResolver:
|
| 43 |
+
"""Replaced value for all other non-primitive objects."""
|
| 44 |
+
|
| 45 |
+
TOKEN = "__ref_ph"
|
| 46 |
+
|
| 47 |
+
def __init__(self, hash, value):
|
| 48 |
+
self.hash = hash
|
| 49 |
+
self._value = value
|
| 50 |
+
|
| 51 |
+
def resolve(self):
|
| 52 |
+
return self._value
|
| 53 |
+
|
| 54 |
+
def get_placeholder(self) -> str:
|
| 55 |
+
return (self.TOKEN, self.hash)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _is_primitive(x):
|
| 59 |
+
"""Returns True if x is a primitive type.
|
| 60 |
+
|
| 61 |
+
Primitive types are int, float, str, bool, and None.
|
| 62 |
+
"""
|
| 63 |
+
return isinstance(x, (int, float, str, bool)) or x is None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@DeveloperAPI
|
| 67 |
+
def inject_placeholders(
|
| 68 |
+
config: Any,
|
| 69 |
+
resolvers: defaultdict,
|
| 70 |
+
id_prefix: Tuple = (),
|
| 71 |
+
path_prefix: Tuple = (),
|
| 72 |
+
) -> Dict:
|
| 73 |
+
"""Replaces reference objects contained by a config dict with placeholders.
|
| 74 |
+
|
| 75 |
+
Given a config dict, this function replaces all reference objects contained
|
| 76 |
+
by this dict with placeholder strings. It recursively expands nested dicts
|
| 77 |
+
and lists, and properly handles Tune native search objects such as Categorical
|
| 78 |
+
and Function.
|
| 79 |
+
This makes sure the config dict only contains primitive typed values, which
|
| 80 |
+
can then be handled by different search algorithms.
|
| 81 |
+
|
| 82 |
+
A few details about id_prefix and path_prefix. Consider the following config,
|
| 83 |
+
where "param1" is a simple grid search of 3 tuples.
|
| 84 |
+
|
| 85 |
+
config = {
|
| 86 |
+
"param1": tune.grid_search([
|
| 87 |
+
(Cat, None, None),
|
| 88 |
+
(None, Dog, None),
|
| 89 |
+
(None, None, Fish),
|
| 90 |
+
]),
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
We will replace the 3 objects contained with placeholders. And after trial
|
| 94 |
+
expansion, the config may look like this:
|
| 95 |
+
|
| 96 |
+
config = {
|
| 97 |
+
"param1": (None, (placeholder, hash), None)
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
Now you need 2 pieces of information to resolve the placeholder. One is the
|
| 101 |
+
path of ("param1", 1), which tells you that the first element of the tuple
|
| 102 |
+
under "param1" key is a placeholder that needs to be resolved.
|
| 103 |
+
The other is the mapping from the placeholder to the actual object. In this
|
| 104 |
+
case hash -> Dog.
|
| 105 |
+
|
| 106 |
+
id and path prefixes serve exactly this purpose here. The difference between
|
| 107 |
+
these two is that id_prefix is the location of the value in the pre-injected
|
| 108 |
+
config tree. So if a value is the second option in a grid_search, it gets an
|
| 109 |
+
id part of 1. Injected placeholders all get unique id prefixes. path prefix
|
| 110 |
+
identifies a placeholder in the expanded config tree. So for example, all
|
| 111 |
+
options of a single grid_search will get the same path prefix. This is how
|
| 112 |
+
we know which location has a placeholder to be resolved in the post-expansion
|
| 113 |
+
tree.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
config: The config dict to replace references in.
|
| 117 |
+
resolvers: A dict from path to replaced objects.
|
| 118 |
+
id_prefix: The prefix to prepend to id every single placeholders.
|
| 119 |
+
path_prefix: The prefix to prepend to every path identifying
|
| 120 |
+
potential locations of placeholders in an expanded tree.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
The config with all references replaced.
|
| 124 |
+
"""
|
| 125 |
+
if isinstance(config, dict) and "grid_search" in config and len(config) == 1:
|
| 126 |
+
config["grid_search"] = [
|
| 127 |
+
# Different options gets different id prefixes.
|
| 128 |
+
# But we should omit appending to path_prefix because after expansion,
|
| 129 |
+
# this level will not be there.
|
| 130 |
+
inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
|
| 131 |
+
for i, choice in enumerate(config["grid_search"])
|
| 132 |
+
]
|
| 133 |
+
return config
|
| 134 |
+
elif isinstance(config, dict):
|
| 135 |
+
return {
|
| 136 |
+
k: inject_placeholders(v, resolvers, id_prefix + (k,), path_prefix + (k,))
|
| 137 |
+
for k, v in config.items()
|
| 138 |
+
}
|
| 139 |
+
elif isinstance(config, list):
|
| 140 |
+
return [
|
| 141 |
+
inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
|
| 142 |
+
for i, elem in enumerate(config)
|
| 143 |
+
]
|
| 144 |
+
elif isinstance(config, tuple):
|
| 145 |
+
return tuple(
|
| 146 |
+
inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
|
| 147 |
+
for i, elem in enumerate(config)
|
| 148 |
+
)
|
| 149 |
+
elif _is_primitive(config):
|
| 150 |
+
# Primitive types.
|
| 151 |
+
return config
|
| 152 |
+
elif isinstance(config, Categorical):
|
| 153 |
+
config.categories = [
|
| 154 |
+
# Different options gets different id prefixes.
|
| 155 |
+
# But we should omit appending to path_prefix because after expansion,
|
| 156 |
+
# this level will not be there.
|
| 157 |
+
inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
|
| 158 |
+
for i, choice in enumerate(config.categories)
|
| 159 |
+
]
|
| 160 |
+
return config
|
| 161 |
+
elif isinstance(config, Function):
|
| 162 |
+
# Function type.
|
| 163 |
+
id_hash = _id_hash(id_prefix)
|
| 164 |
+
v = _FunctionResolver(id_hash, config)
|
| 165 |
+
resolvers[path_prefix].append(v)
|
| 166 |
+
return v.get_placeholder()
|
| 167 |
+
elif not isinstance(config, Domain):
|
| 168 |
+
# Other non-search space reference objects, dataset, actor handle, etc.
|
| 169 |
+
id_hash = _id_hash(id_prefix)
|
| 170 |
+
v = _RefResolver(id_hash, config)
|
| 171 |
+
resolvers[path_prefix].append(v)
|
| 172 |
+
return v.get_placeholder()
|
| 173 |
+
else:
|
| 174 |
+
# All the other cases, do nothing.
|
| 175 |
+
return config
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _get_placeholder(config: Any, prefix: Tuple, path: Tuple):
|
| 179 |
+
if not path:
|
| 180 |
+
return prefix, config
|
| 181 |
+
|
| 182 |
+
key = path[0]
|
| 183 |
+
if isinstance(config, tuple):
|
| 184 |
+
if config[0] in (_FunctionResolver.TOKEN, _RefResolver.TOKEN):
|
| 185 |
+
# Found a matching placeholder.
|
| 186 |
+
# Note that we do not require that the full path are consumed before
|
| 187 |
+
# declaring a match. Because this placeholder may be part of a nested
|
| 188 |
+
# search space. For example, the following config:
|
| 189 |
+
# config = {
|
| 190 |
+
# "param1": tune.grid_search([
|
| 191 |
+
# tune.grid_search([Object1, 2, 3]),
|
| 192 |
+
# tune.grid_search([Object2, 5, 6]),
|
| 193 |
+
# ]),
|
| 194 |
+
# }
|
| 195 |
+
# will result in placeholders under path ("param1", 0, 0).
|
| 196 |
+
# After expansion though, the choosen placeholder will live under path
|
| 197 |
+
# ("param1", 0) like this: config = {"param1": (Placeholder1, 2, 3)}
|
| 198 |
+
return prefix, config
|
| 199 |
+
elif key < len(config):
|
| 200 |
+
return _get_placeholder(
|
| 201 |
+
config[key], prefix=prefix + (path[0],), path=path[1:]
|
| 202 |
+
)
|
| 203 |
+
elif (isinstance(config, dict) and key in config) or (
|
| 204 |
+
isinstance(config, list) and key < len(config)
|
| 205 |
+
):
|
| 206 |
+
# Expand config tree recursively.
|
| 207 |
+
return _get_placeholder(config[key], prefix=prefix + (path[0],), path=path[1:])
|
| 208 |
+
|
| 209 |
+
# Can not find a matching placeholder.
|
| 210 |
+
return None, None
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@DeveloperAPI
|
| 214 |
+
def resolve_placeholders(config: Any, replaced: defaultdict):
|
| 215 |
+
"""Replaces placeholders contained by a config dict with the original values.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
config: The config to replace placeholders in.
|
| 219 |
+
replaced: A dict from path to replaced objects.
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
def __resolve(resolver_type, args):
|
| 223 |
+
for path, resolvers in replaced.items():
|
| 224 |
+
assert resolvers
|
| 225 |
+
|
| 226 |
+
if not isinstance(resolvers[0], resolver_type):
|
| 227 |
+
continue
|
| 228 |
+
|
| 229 |
+
prefix, ph = _get_placeholder(config, (), path)
|
| 230 |
+
if not ph:
|
| 231 |
+
# Represents an unchosen value. Just skip.
|
| 232 |
+
continue
|
| 233 |
+
|
| 234 |
+
for resolver in resolvers:
|
| 235 |
+
if resolver.hash != ph[1]:
|
| 236 |
+
continue
|
| 237 |
+
# Found the matching resolver.
|
| 238 |
+
assign_value(config, prefix, resolver.resolve(*args))
|
| 239 |
+
|
| 240 |
+
# RefResolvers first.
|
| 241 |
+
__resolve(_RefResolver, args=())
|
| 242 |
+
# Functions need to be resolved after RefResolvers, in case they are
|
| 243 |
+
# referencing values from the RefResolvers.
|
| 244 |
+
__resolve(_FunctionResolver, args=(config,))
|
.venv/lib/python3.11/site-packages/ray/tune/impl/test_utils.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.datasets import load_breast_cancer
|
| 2 |
+
|
| 3 |
+
from ray import tune
|
| 4 |
+
from ray.data import Dataset, Datasource, ReadTask, read_datasource
|
| 5 |
+
from ray.data.block import BlockMetadata
|
| 6 |
+
from ray.tune.impl.utils import execute_dataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# TODO(xwjiang): Enable this when Clark's out-of-band-serialization is landed.
|
| 10 |
+
class TestDatasource(Datasource):
|
| 11 |
+
def prepare_read(self, parallelism: int, **read_args):
|
| 12 |
+
import pyarrow as pa
|
| 13 |
+
|
| 14 |
+
def load_data():
|
| 15 |
+
data_raw = load_breast_cancer(as_frame=True)
|
| 16 |
+
dataset_df = data_raw["data"]
|
| 17 |
+
dataset_df["target"] = data_raw["target"]
|
| 18 |
+
return [pa.Table.from_pandas(dataset_df)]
|
| 19 |
+
|
| 20 |
+
meta = BlockMetadata(
|
| 21 |
+
num_rows=None,
|
| 22 |
+
size_bytes=None,
|
| 23 |
+
schema=None,
|
| 24 |
+
input_files=None,
|
| 25 |
+
exec_stats=None,
|
| 26 |
+
)
|
| 27 |
+
return [ReadTask(load_data, meta)]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def gen_dataset_func() -> Dataset:
|
| 31 |
+
test_datasource = TestDatasource()
|
| 32 |
+
return read_datasource(test_datasource)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test_grid_search():
|
| 36 |
+
ds1 = gen_dataset_func().lazy().map(lambda x: x)
|
| 37 |
+
ds2 = gen_dataset_func().lazy().map(lambda x: x)
|
| 38 |
+
assert not ds1._plan._has_final_stage_snapshot()
|
| 39 |
+
assert not ds2._plan._has_final_stage_snapshot()
|
| 40 |
+
param_space = {"train_dataset": tune.grid_search([ds1, ds2])}
|
| 41 |
+
execute_dataset(param_space)
|
| 42 |
+
executed_ds = param_space["train_dataset"]["grid_search"]
|
| 43 |
+
assert len(executed_ds) == 2
|
| 44 |
+
assert executed_ds[0]._plan._has_final_stage_snapshot()
|
| 45 |
+
assert executed_ds[1]._plan._has_final_stage_snapshot()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_choice():
|
| 49 |
+
ds1 = gen_dataset_func().lazy().map(lambda x: x)
|
| 50 |
+
ds2 = gen_dataset_func().lazy().map(lambda x: x)
|
| 51 |
+
assert not ds1._plan._has_final_stage_snapshot()
|
| 52 |
+
assert not ds2._plan._has_final_stage_snapshot()
|
| 53 |
+
param_space = {"train_dataset": tune.choice([ds1, ds2])}
|
| 54 |
+
execute_dataset(param_space)
|
| 55 |
+
executed_ds = param_space["train_dataset"].categories
|
| 56 |
+
assert len(executed_ds) == 2
|
| 57 |
+
assert executed_ds[0]._plan._has_final_stage_snapshot()
|
| 58 |
+
assert executed_ds[1]._plan._has_final_stage_snapshot()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
import sys
|
| 63 |
+
|
| 64 |
+
import pytest
|
| 65 |
+
|
| 66 |
+
sys.exit(pytest.main(["-v", "-x", __file__]))
|
.venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import io
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import (
|
| 7 |
+
TYPE_CHECKING,
|
| 8 |
+
Any,
|
| 9 |
+
Callable,
|
| 10 |
+
Dict,
|
| 11 |
+
List,
|
| 12 |
+
Optional,
|
| 13 |
+
Tuple,
|
| 14 |
+
Type,
|
| 15 |
+
Union,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
import pyarrow.fs
|
| 19 |
+
|
| 20 |
+
import ray.cloudpickle as pickle
|
| 21 |
+
import ray.train
|
| 22 |
+
from ray.air._internal.uri_utils import URI
|
| 23 |
+
from ray.air._internal.usage import AirEntrypoint
|
| 24 |
+
from ray.train import ScalingConfig
|
| 25 |
+
from ray.train._internal.storage import StorageContext, get_fs_and_path
|
| 26 |
+
from ray.train.constants import _v2_migration_warnings_enabled
|
| 27 |
+
from ray.train.utils import _log_deprecation_warning
|
| 28 |
+
from ray.tune import (
|
| 29 |
+
Experiment,
|
| 30 |
+
ExperimentAnalysis,
|
| 31 |
+
ResumeConfig,
|
| 32 |
+
RunConfig,
|
| 33 |
+
TuneConfig,
|
| 34 |
+
TuneError,
|
| 35 |
+
)
|
| 36 |
+
from ray.tune.registry import is_function_trainable
|
| 37 |
+
from ray.tune.result_grid import ResultGrid
|
| 38 |
+
from ray.tune.trainable import Trainable
|
| 39 |
+
from ray.tune.tune import _Config, run
|
| 40 |
+
from ray.tune.utils import flatten_dict
|
| 41 |
+
from ray.util import inspect_serializability
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
from ray.train.trainer import BaseTrainer
|
| 45 |
+
from ray.util.queue import Queue
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
_TUNER_PKL = "tuner.pkl"
|
| 49 |
+
_TRAINABLE_KEY = "_trainable"
|
| 50 |
+
_CONVERTED_TRAINABLE_KEY = "_converted_trainable"
|
| 51 |
+
_PARAM_SPACE_KEY = "_param_space"
|
| 52 |
+
_EXPERIMENT_ANALYSIS_KEY = "_experiment_analysis"
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
|
| 56 |
+
TrainableType = Union[str, Callable, Type[Trainable]]
|
| 57 |
+
TrainableTypeOrTrainer = Union[TrainableType, "BaseTrainer"]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TunerInternal:
    """The real implementation behind external facing ``Tuner``.

    The external facing ``Tuner`` multiplexes between local Tuner and remote Tuner
    depending on whether in Ray client mode.

    In Ray client mode, external ``Tuner`` wraps ``TunerInternal`` into a remote actor,
    which is guaranteed to be placed on head node.

    ``TunerInternal`` can be constructed from fresh, in which case, ``trainable`` needs
    to be provided, together with optional ``param_space``, ``tune_config`` and
    ``run_config``.

    It can also be restored from a previous failed run (given ``restore_path``).

    Args:
        restore_path: The path from where the Tuner can be restored. If provided, None
            of the rest args are needed.
        resume_config: Resume config to configure which trials to continue.
        trainable: The trainable to be tuned.
        param_space: Search space of the tuning job.
            One thing to note is that both preprocessor and dataset can be tuned here.
        tune_config: Tuning algorithm specific configs.
            Refer to ray.tune.tune_config.TuneConfig for more info.
        run_config: Runtime configuration that is specific to individual trials.
            If passed, this will overwrite the run config passed to the Trainer,
            if applicable. Refer to ray.tune.RunConfig for more info.
    """

    def __init__(
        self,
        restore_path: str = None,
        storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
        resume_config: Optional[ResumeConfig] = None,
        trainable: Optional[TrainableTypeOrTrainer] = None,
        param_space: Optional[Dict[str, Any]] = None,
        tune_config: Optional[TuneConfig] = None,
        run_config: Optional[RunConfig] = None,
        _tuner_kwargs: Optional[Dict] = None,
        _entrypoint: AirEntrypoint = AirEntrypoint.TUNER,
    ):
        # Imported lazily to avoid a circular import with ray.train.
        from ray.train.trainer import BaseTrainer

        if isinstance(trainable, BaseTrainer):
            # Passing a Trainer (rather than a plain trainable) is deprecated.
            if _v2_migration_warnings_enabled():
                _log_deprecation_warning(
                    "Passing a Trainer to the Tuner is deprecated. "
                    "See the section on hyperparameter optimization in this "
                    "REP for more information: "
                    "https://github.com/ray-project/enhancements/pull/57"
                )

            # The Trainer may carry its own RunConfig; resolve which one wins.
            run_config = self._choose_run_config(
                tuner_run_config=run_config,
                trainer=trainable,
                param_space=param_space,
            )

        self._tune_config = tune_config or TuneConfig()
        # Shallow-copy so later mutations (e.g. `name`, `storage_path`) don't
        # leak back into the caller-owned RunConfig object.
        self._run_config = copy.copy(run_config) or RunConfig()

        if not isinstance(self._run_config, RunConfig):
            # A RunConfig imported from a different module (e.g. ray.train)
            # still works here, but is deprecated.
            if _v2_migration_warnings_enabled():
                _log_deprecation_warning(
                    "The `RunConfig` class should be imported from `ray.tune` "
                    "when passing it to the Tuner. Please update your imports."
                )

        self._entrypoint = _entrypoint

        # Restore from Tuner checkpoint.
        if restore_path:
            self._restore_from_path_or_uri(
                path_or_uri=restore_path,
                trainable=trainable,
                overwrite_param_space=param_space,
                resume_config=resume_config,
                storage_filesystem=storage_filesystem,
            )
            return

        # Start from fresh
        if not trainable:
            raise TuneError("You need to provide a trainable to tune.")

        # The `trainable` setter also computes `converted_trainable`
        # (Trainer -> Tune trainable conversion, if applicable).
        self.trainable = trainable
        assert self.converted_trainable
        self._validate_trainable(self.converted_trainable)

        self.param_space = param_space

        self._resume_config = None
        self._is_restored = False
        self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
        self._experiment_analysis = None

        self._run_config.name = (
            self._run_config.name
            or StorageContext.get_experiment_dir_name(self.converted_trainable)
        )
        # The storage context here is only used to access the resolved
        # storage fs and experiment path, in order to avoid duplicating that logic.
        # This is NOT the storage context object that gets passed to remote workers.
        storage = StorageContext(
            storage_path=self._run_config.storage_path,
            experiment_dir_name=self._run_config.name,
            storage_filesystem=self._run_config.storage_filesystem,
        )

        # Snapshot the Tuner state to `tuner.pkl` up front so this run can
        # later be restored via `Tuner.restore(...)`.
        fs = storage.storage_filesystem
        fs.create_dir(storage.experiment_fs_path)
        with fs.open_output_stream(
            Path(storage.experiment_fs_path, _TUNER_PKL).as_posix()
        ) as f:
            f.write(pickle.dumps(self.__getstate__()))

    def get_run_config(self) -> RunConfig:
        """Return the resolved run config used by this Tuner."""
        return self._run_config

    # For Jupyter output with Ray Client
    def set_run_config_and_remote_string_queue(
        self, run_config: RunConfig, string_queue: "Queue"
    ):
        """Replace the run config and attach a remote queue for output streaming.

        Used in Ray client mode so progress output can be forwarded from the
        remote Tuner actor back to the driver (e.g. a Jupyter notebook).
        """
        self._run_config = run_config
        self._tuner_kwargs["_remote_string_queue"] = string_queue

    def clear_remote_string_queue(self):
        """Drop the remote output queue (no-op if none was attached)."""
        self._tuner_kwargs.pop("_remote_string_queue", None)

    def _expected_utilization(self, cpus_per_trial, cpus_total):
        """Estimate the fraction of cluster CPUs this run is expected to use.

        Concurrency is bounded by the number of trials (``num_samples``), the
        configured ``max_concurrent_trials``, and how many trials fit in the
        available CPUs.
        """
        num_samples = self._tune_config.num_samples
        if num_samples < 0:  # TODO: simplify this in Tune
            # Negative num_samples means "infinite" (run until stopped).
            num_samples = math.inf
        concurrent_trials = self._tune_config.max_concurrent_trials or 0
        if concurrent_trials < 1:  # TODO: simplify this in Tune
            # 0 / unset means no concurrency limit.
            concurrent_trials = math.inf

        actual_concurrency = min(
            (
                (cpus_total // cpus_per_trial) if cpus_per_trial else 0,
                num_samples,
                concurrent_trials,
            )
        )
        # +0.001 guards against division by zero when cpus_total == 0.
        return (actual_concurrency * cpus_per_trial) / (cpus_total + 0.001)

    def _validate_trainable(
        self, trainable: TrainableType, required_trainable_name: Optional[str] = None
    ):
        """Determines whether or not the trainable is valid.

        This includes checks on the serializability of the trainable, as well
        asserting that the trainable name is as expected on restoration.

        This trainable name validation is needed due to an implementation detail
        where the trainable name (which is differently generated depending on
        the trainable type) is saved in the Trial metadata and needs to match
        upon restoration. This does not affect the typical path, since `Tuner.restore`
        expects the exact same trainable (which will have the same name).

        Raises:
            ValueError: if the trainable name does not match or if the trainable
                is not serializable.
        """
        try:
            pickle.dumps(trainable)
        except TypeError as e:
            # Produce a helpful trace of which nested objects broke pickling.
            sio = io.StringIO()
            inspect_serializability(trainable, print_file=sio)
            msg = (
                "The provided trainable is not serializable, which is a requirement "
                "since the trainable is serialized and deserialized when transferred "
                "to remote workers. See below for a trace of the non-serializable "
                "objects that were found in your trainable:\n"
                f"{sio.getvalue()}"
            )
            raise TypeError(msg) from e

        if not required_trainable_name:
            return

        trainable_name = Experiment.get_trainable_name(trainable)

        if trainable_name != required_trainable_name:
            raise ValueError(
                "Invalid `trainable` input to `Tuner.restore()`. To fix this error, "
                "pass in the same trainable that was used to initialize the Tuner. "
                "Got a trainable with identifier "
                f"'{trainable_name}' but expected '{required_trainable_name}'."
            )

    def _set_trainable_on_restore(
        self, trainable: TrainableType, old_trainable_name: Optional[str]
    ):
        """Set and validate the (re-specified) trainable during restoration.

        The trainable must resolve to the same name as the original run's
        trainable. If the trainable is a Trainer, its RunConfig is overwritten
        with the restored run's RunConfig (with a warning if it differs from
        the default).
        """
        from ray.train.base_trainer import BaseTrainer

        self.trainable = trainable
        assert self.converted_trainable
        self._validate_trainable(
            trainable=self.converted_trainable,
            required_trainable_name=old_trainable_name,
        )

        if isinstance(self.trainable, BaseTrainer):
            # Log a warning in case the user tries to modify the
            # `RunConfig` from the Trainer
            trainer: BaseTrainer = self.trainable

            # Only log if the Trainer has a non-default RunConfig
            if trainer.run_config != RunConfig():
                logger.warning(
                    "The Tune experiment will restore using the original run's "
                    "`RunConfig`. If you made any changes to the `RunConfig` "
                    "within the Trainer you passed into `Tuner.restore`, "
                    "they will be ignored in the resumed run."
                )

            trainer.run_config = self._run_config

    def _validate_param_space_on_restore(
        self,
        new_param_space: Dict[str, Any],
        flattened_param_space_keys: Optional[List[str]],
    ):
        """Determines whether the (optionally) re-specified `param_space` is valid.

        This method performs very loose validation on the new param_space to
        prevent users from trying to specify new hyperparameters to tune over.

        Raises:
            ValueError: if not all keys match the original param_space.
        """
        if flattened_param_space_keys is None:
            # Backwards compatibility: skip validation
            return

        keys = sorted(flatten_dict(new_param_space).keys())
        if keys != flattened_param_space_keys:
            raise ValueError(
                "Invalid `param_space` input to `Tuner.restore()`. To fix this error, "
                "pass in the same `param_space` that was used to initialize the Tuner. "
                "Only re-specify the `param_space` to refresh Ray object references "
                "that no longer exist due to restoring from a new Ray cluster session. "
                "It should not be used to introduce new hyperparameters to tune."
                f"\n\nGot: {keys}\nExpected: {flattened_param_space_keys}"
            )

    def _set_param_space_on_restore(
        self,
        param_space: Optional[Dict[str, Any]],
        flattened_param_space_keys: Optional[List[str]],
    ):
        """Set and (loosely) validate a re-specified param_space on restore."""
        self.param_space = param_space

        if self.param_space is not None:
            # param_space = None -> use the original param_space
            self._validate_param_space_on_restore(
                new_param_space=self.param_space,
                flattened_param_space_keys=flattened_param_space_keys,
            )

    def _load_tuner_state(
        self, tuner_state: Dict[str, Any]
    ) -> Tuple[Optional[str], Optional[List[str]]]:
        """Loads Tuner state from the previously saved `tuner.pkl`.

        Args:
            tuner_state: The unpickled state dict from the `tuner.pkl` file
                saved during the original Tuner initialization.

        Returns:
            tuple: of `(old_trainable_name, flattened_param_space_keys)` used for
                validating the re-specified `trainable` and `param_space`.
        """
        # NOTE: These are magic keys used for validating restore args.
        # They are injected by `__getstate__` and must be stripped before
        # handing the dict to `__setstate__`.
        old_trainable_name = tuner_state.pop("__trainable_name", None)
        flattened_param_space_keys = tuner_state.pop(
            "__flattened_param_space_keys", None
        )

        self.__setstate__(tuner_state)

        return old_trainable_name, flattened_param_space_keys

    def _restore_from_path_or_uri(
        self,
        path_or_uri: str,
        trainable: TrainableTypeOrTrainer,
        overwrite_param_space: Optional[Dict[str, Any]],
        resume_config: ResumeConfig,
        storage_filesystem: Optional[pyarrow.fs.FileSystem],
    ):
        """Restore Tuner state from an experiment directory (local path or URI).

        Loads the saved `tuner.pkl`, validates the re-specified trainable and
        param_space against it, and re-derives `storage_path` / run `name` from
        the experiment directory.
        """
        fs, fs_path = get_fs_and_path(path_or_uri, storage_filesystem)
        with fs.open_input_file(Path(fs_path, _TUNER_PKL).as_posix()) as f:
            tuner_state = pickle.loads(f.readall())

        old_trainable_name, flattened_param_space_keys = self._load_tuner_state(
            tuner_state
        )

        # Perform validation and set the re-specified `trainable` and `param_space`
        self._set_trainable_on_restore(
            trainable=trainable, old_trainable_name=old_trainable_name
        )
        self._set_param_space_on_restore(
            param_space=overwrite_param_space,
            flattened_param_space_keys=flattened_param_space_keys,
        )

        # Update RunConfig to reflect changes in the experiment directory
        path_or_uri_obj = URI(path_or_uri)

        # Infer the `storage_path` and run `name` of the restored run using the
        # experiment directory.
        # Ex: ~/ray_results/exp_name -> ~/ray_results, exp_name
        # Ex: s3://bucket/exp_name -> s3://bucket, exp_name
        self._run_config.name = path_or_uri_obj.name
        self._run_config.storage_path = str(path_or_uri_obj.parent)
        # Update the storage_filesystem with the one passed in on restoration, if any.
        self._run_config.storage_filesystem = storage_filesystem

        # Load the experiment results at the point where it left off.
        try:
            self._experiment_analysis = ExperimentAnalysis(
                experiment_checkpoint_path=path_or_uri,
                default_metric=self._tune_config.metric,
                default_mode=self._tune_config.mode,
                storage_filesystem=storage_filesystem,
            )
        except Exception:
            # Best-effort: a missing/corrupt experiment checkpoint should not
            # block restoration; results will be unavailable until `fit()`.
            self._experiment_analysis = None

        self._resume_config = resume_config
        self._is_restored = True

    def _choose_run_config(
        self,
        tuner_run_config: Optional[RunConfig],
        trainer: "BaseTrainer",
        param_space: Optional[Dict[str, Any]],
    ) -> RunConfig:
        """Chooses which `RunConfig` to use when multiple can be passed in
        through a Trainer or the Tuner itself.

        Args:
            tuner_run_config: The run config passed into the Tuner constructor.
            trainer: The Trainer instance to use with Tune, which may have
                a RunConfig specified by the user.
            param_space: The param space passed to the Tuner.

        Raises:
            ValueError: if the `run_config` is specified as a hyperparameter.
        """
        if param_space and "run_config" in param_space:
            raise ValueError(
                "`RunConfig` cannot be tuned as part of the `param_space`! "
                "Move the run config to be a parameter of the `Tuner`: "
                "Tuner(..., run_config=RunConfig(...))"
            )

        # Both Tuner RunConfig + Trainer RunConfig --> prefer Tuner RunConfig
        if tuner_run_config and trainer.run_config != ray.train.RunConfig():
            logger.info(
                "A `RunConfig` was passed to both the `Tuner` and the "
                f"`{trainer.__class__.__name__}`. The run config passed to "
                "the `Tuner` is the one that will be used."
            )
            return tuner_run_config

        # No Tuner RunConfig -> pass the Trainer config through
        # This returns either a user-specified config, or the default RunConfig
        # if nothing was provided to both the Trainer or Tuner.
        if not tuner_run_config:
            return trainer.run_config

        # Tuner RunConfig + No Trainer RunConfig --> Use the Tuner config
        return tuner_run_config

    def _process_scaling_config(self) -> None:
        """Converts ``self._param_space["scaling_config"]`` to a dict.

        The dict is converted back to a dataclass by the Trainer, after the
        Tune search specification is resolved.
        """
        # TODO: introduce `ray.tune.sample.TuneableDataclass` and allow Tune to
        # natively resolve specs with dataclasses.
        scaling_config = self._param_space.get("scaling_config")
        if not isinstance(scaling_config, ScalingConfig):
            return
        self._param_space["scaling_config"] = scaling_config.__dict__.copy()

    @property
    def trainable(self) -> TrainableTypeOrTrainer:
        """The trainable as originally passed in (possibly a Trainer)."""
        return self._trainable

    @property
    def converted_trainable(self) -> TrainableType:
        """The Tune-compatible trainable derived from ``trainable``."""
        return self._converted_trainable

    @trainable.setter
    def trainable(self, trainable: TrainableTypeOrTrainer):
        # Keep `_converted_trainable` in sync with the raw trainable.
        self._trainable = trainable
        self._converted_trainable = self._convert_trainable(trainable)

    @property
    def param_space(self) -> Optional[Dict[str, Any]]:
        """The search space dict (or None if not set)."""
        return self._param_space

    @param_space.setter
    def param_space(self, param_space: Optional[Dict[str, Any]]):
        # Handle any configs that adhere to the `to_dict` interface.
        # Ex: AlgorithmConfig from RLlib
        if isinstance(param_space, _Config):
            param_space = param_space.to_dict()

        if not isinstance(param_space, dict) and param_space is not None:
            raise ValueError(
                "The `param_space` passed to the `Tuner` must be a dict. "
                f"Got '{type(param_space)}' instead."
            )

        self._param_space = param_space

        if param_space:
            self._process_scaling_config()

    def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType:
        """Converts a Trainer to a Tune trainable and saves the converted
        trainable. If not using a Trainer, this leaves the trainable as is."""
        from ray.train.trainer import BaseTrainer

        return (
            trainable.as_trainable()
            if isinstance(trainable, BaseTrainer)
            else trainable
        )

    def fit(self) -> ResultGrid:
        """Run (or resume) the tuning job and return its results."""
        trainable = self.converted_trainable
        # Deep-copy so the run doesn't mutate the stored param_space.
        param_space = copy.deepcopy(self.param_space)
        if not self._is_restored:
            analysis = self._fit_internal(trainable, param_space)
        else:
            analysis = self._fit_resume(trainable, param_space)

        self._experiment_analysis = analysis

        return ResultGrid(self._experiment_analysis)

    def get_results(self) -> ResultGrid:
        """Return the results of a previously run/restored experiment.

        Raises:
            RuntimeError: if the experiment has not been run yet.
        """
        if not self._experiment_analysis:
            raise RuntimeError(
                "Can't return results as experiment has not been run, yet. "
                "Call `Tuner.fit()` to run the experiment first."
            )
        return ResultGrid(self._experiment_analysis)

    def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
        """Get tune.run arguments common for both new and resumed runs."""
        # Avoid overwriting the originally configured checkpoint config.
        checkpoint_config = copy.deepcopy(self._run_config.checkpoint_config)

        if checkpoint_config.checkpoint_frequency:
            # Function trainables (and thus most of our trainers) usually don't handle
            # this argument.
            handle_checkpoint_freq = getattr(
                trainable, "_handles_checkpoint_freq", None
            )
            if handle_checkpoint_freq is False:
                # If we specifically know this trainable doesn't support the
                # argument, raise an error
                raise ValueError(
                    "You passed `checkpoint_frequency="
                    f"{checkpoint_config.checkpoint_frequency}` to your "
                    "CheckpointConfig, but this trainer does not support "
                    "this argument. If you passed in a Trainer that takes in a "
                    "custom training loop, you will need to "
                    "report a checkpoint every `checkpoint_frequency` iterations "
                    "within your training loop using "
                    "`ray.train.report(metrics=..., checkpoint=...)` "
                    "to get this behavior."
                )
            elif handle_checkpoint_freq is True:
                # If we specifically support it, it's handled in the training loop,
                # so we disable tune's bookkeeping.
                checkpoint_config.checkpoint_frequency = 0
            # Otherwise, the trainable is not a Trainer and we just keep the
            # user-supplied value.
            # Function trainables will raise a runtime error later if set > 0
        if checkpoint_config.checkpoint_at_end is not None:
            # Again, function trainables usually don't handle this argument.
            handle_cp_at_end = getattr(trainable, "_handles_checkpoint_at_end", None)
            if handle_cp_at_end is False:
                # If we specifically know we don't support it, raise an error.
                raise ValueError(
                    "You passed `checkpoint_at_end="
                    f"{checkpoint_config.checkpoint_at_end}` "
                    "to your CheckpointConfig, but this trainer does not support "
                    "this argument. If you passed in a Trainer that takes in a "
                    "custom training loop, you should include one last call to "
                    "`ray.train.report(metrics=..., checkpoint=...)` "
                    "at the end of your training loop to get this behavior."
                )
            elif handle_cp_at_end is True:
                # If we specifically support it, it's handled in the training loop,
                # so we disable tune's internal bookkeeping.
                checkpoint_config.checkpoint_at_end = False
            # If this is a user-defined trainable, just keep the value
            # Function trainables will raise a runtime error later if set to True
        else:
            # Set default to False for function trainables and True for everything else
            if is_function_trainable(trainable):
                checkpoint_config.checkpoint_at_end = False
            else:
                checkpoint_config.checkpoint_at_end = True

        return dict(
            storage_path=self._run_config.storage_path,
            storage_filesystem=self._run_config.storage_filesystem,
            name=self._run_config.name,
            mode=self._tune_config.mode,
            metric=self._tune_config.metric,
            callbacks=self._run_config.callbacks,
            sync_config=self._run_config.sync_config,
            stop=self._run_config.stop,
            max_failures=self._run_config.failure_config.max_failures,
            checkpoint_config=checkpoint_config,
            raise_on_failed_trial=False,
            fail_fast=(self._run_config.failure_config.fail_fast),
            progress_reporter=self._run_config.progress_reporter,
            verbose=self._run_config.verbose,
            reuse_actors=self._tune_config.reuse_actors,
            max_concurrent_trials=self._tune_config.max_concurrent_trials,
            time_budget_s=self._tune_config.time_budget_s,
            trial_name_creator=self._tune_config.trial_name_creator,
            trial_dirname_creator=self._tune_config.trial_dirname_creator,
            _entrypoint=self._entrypoint,
            # Deprecated
            chdir_to_trial_dir=self._tune_config.chdir_to_trial_dir,
        )

    def _fit_internal(
        self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
    ) -> ExperimentAnalysis:
        """Fitting for a fresh Tuner."""
        # `_tuner_kwargs` come last so internal overrides (e.g. the remote
        # string queue) take precedence over the computed arguments.
        args = {
            **self._get_tune_run_arguments(trainable),
            **dict(
                run_or_experiment=trainable,
                config=param_space,
                num_samples=self._tune_config.num_samples,
                search_alg=self._tune_config.search_alg,
                scheduler=self._tune_config.scheduler,
                log_to_file=self._run_config.log_to_file,
            ),
            **self._tuner_kwargs,
        }
        analysis = run(
            **args,
        )
        self.clear_remote_string_queue()
        return analysis

    def _fit_resume(
        self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
    ) -> ExperimentAnalysis:
        """Fitting for a restored Tuner."""
        assert self._resume_config

        args = {
            **self._get_tune_run_arguments(trainable),
            **dict(
                run_or_experiment=trainable,
                config=param_space,
                resume_config=self._resume_config,
                search_alg=self._tune_config.search_alg,
                scheduler=self._tune_config.scheduler,
            ),
            **self._tuner_kwargs,
        }
        analysis = run(**args)
        self.clear_remote_string_queue()
        return analysis

    def __getstate__(self):
        """Return a picklable state dict.

        The trainable, converted trainable, param_space, and experiment
        analysis are stripped (they may be unpicklable or stale); instead,
        the trainable name and flattened param_space keys are stored under
        magic keys so `Tuner.restore` can validate re-specified arguments.
        """
        state = self.__dict__.copy()
        # The remote string queue is a live actor handle; never pickle it.
        state["_tuner_kwargs"] = state["_tuner_kwargs"].copy()
        state["_tuner_kwargs"].pop("_remote_string_queue", None)
        state.pop(_TRAINABLE_KEY, None)
        trainable = state.pop(_CONVERTED_TRAINABLE_KEY, None)
        param_space = state.pop(_PARAM_SPACE_KEY, None)
        state.pop(_EXPERIMENT_ANALYSIS_KEY, None)

        state["__trainable_name"] = (
            Experiment.get_trainable_name(trainable) if trainable else None
        )
        state["__flattened_param_space_keys"] = (
            sorted(flatten_dict(param_space).keys())
            if param_space is not None
            else None
        )

        return state

    def __setstate__(self, state):
        """Restore state produced by `__getstate__` (magic keys stripped)."""
        # Make sure the magic metadata gets removed first.
        state.pop("__flattened_param_space_keys", None)
        state.pop("__trainable_name", None)

        self.__dict__.update(state)
|
.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (3.35 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/schedulers/__pycache__/resource_changing_scheduler.cpython-311.pyc
ADDED
|
Binary file (41 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/__init__.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray._private.utils import get_function_args
|
| 2 |
+
from ray.tune.search.basic_variant import BasicVariantGenerator
|
| 3 |
+
from ray.tune.search.concurrency_limiter import ConcurrencyLimiter
|
| 4 |
+
from ray.tune.search.repeater import Repeater
|
| 5 |
+
from ray.tune.search.search_algorithm import SearchAlgorithm
|
| 6 |
+
from ray.tune.search.search_generator import SearchGenerator
|
| 7 |
+
from ray.tune.search.searcher import Searcher
|
| 8 |
+
from ray.tune.search.variant_generator import grid_search
|
| 9 |
+
from ray.util import PublicAPI
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _import_variant_generator():
|
| 13 |
+
return BasicVariantGenerator
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _import_ax_search():
|
| 17 |
+
from ray.tune.search.ax.ax_search import AxSearch
|
| 18 |
+
|
| 19 |
+
return AxSearch
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _import_hyperopt_search():
|
| 23 |
+
from ray.tune.search.hyperopt.hyperopt_search import HyperOptSearch
|
| 24 |
+
|
| 25 |
+
return HyperOptSearch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _import_bayesopt_search():
|
| 29 |
+
from ray.tune.search.bayesopt.bayesopt_search import BayesOptSearch
|
| 30 |
+
|
| 31 |
+
return BayesOptSearch
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _import_bohb_search():
    """Deferred import of ``TuneBOHB`` (requires the BOHB dependencies)."""
    from ray.tune.search.bohb.bohb_search import TuneBOHB

    return TuneBOHB
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _import_nevergrad_search():
    """Deferred import of ``NevergradSearch`` (requires ``nevergrad``)."""
    from ray.tune.search.nevergrad.nevergrad_search import NevergradSearch

    return NevergradSearch
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _import_optuna_search():
    """Deferred import of ``OptunaSearch`` (requires ``optuna``)."""
    from ray.tune.search.optuna.optuna_search import OptunaSearch

    return OptunaSearch
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _import_zoopt_search():
    """Deferred import of ``ZOOptSearch`` (requires ``zoopt``)."""
    from ray.tune.search.zoopt.zoopt_search import ZOOptSearch

    return ZOOptSearch
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _import_hebo_search():
    """Deferred import of ``HEBOSearch`` (requires ``HEBO``)."""
    from ray.tune.search.hebo.hebo_search import HEBOSearch

    return HEBOSearch
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Registry mapping the user-facing search-algorithm names accepted by
# `create_searcher` to zero-argument import functions that return the
# corresponding Searcher class.  Keys are compared after lower-casing.
SEARCH_ALG_IMPORT = {
    "variant_generator": _import_variant_generator,
    "random": _import_variant_generator,
    "ax": _import_ax_search,
    "hyperopt": _import_hyperopt_search,
    "bayesopt": _import_bayesopt_search,
    "bohb": _import_bohb_search,
    "nevergrad": _import_nevergrad_search,
    "optuna": _import_optuna_search,
    "zoopt": _import_zoopt_search,
    "hebo": _import_hebo_search,
}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@PublicAPI(stability="beta")
def create_searcher(
    search_alg,
    **kwargs,
):
    """Instantiate a search algorithm from its string name.

    This is useful for swapping between different search algorithms.

    Args:
        search_alg: Case-insensitive name of the search algorithm; must be
            one of the keys of ``SEARCH_ALG_IMPORT`` (e.g. ``"ax"``,
            ``"hyperopt"``, ``"random"``).
        **kwargs: Additional parameters (for example ``metric`` or ``mode``)
            forwarded to the chosen searcher's constructor. Keyword
            arguments the constructor does not accept are silently dropped.

    Returns:
        ray.tune.search.Searcher: The search algorithm.

    Raises:
        ValueError: If ``search_alg`` is not a recognized algorithm name.

    Example:
        >>> from ray import tune  # doctest: +SKIP
        >>> search_alg = tune.create_searcher('ax')  # doctest: +SKIP
    """

    search_alg = search_alg.lower()
    if search_alg not in SEARCH_ALG_IMPORT:
        raise ValueError(
            f"The `search_alg` argument must be one of "
            f"{list(SEARCH_ALG_IMPORT)}. "
            f"Got: {search_alg}"
        )

    searcher_cls = SEARCH_ALG_IMPORT[search_alg]()

    # Keep only the keyword arguments the constructor actually accepts.
    accepted = get_function_args(searcher_cls)
    kept_kwargs = {key: value for key, value in kwargs.items() if key in accepted}

    return searcher_cls(**kept_kwargs)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# Shared error-message templates used by searcher implementations.  The
# `{par}`, `{cls}`, `{space}`, `{metric}` and `{mode}` placeholders are
# filled in via `str.format` at the call site.
# (The previous `str(...)` wrappers around these literals were redundant:
# the operand was already a plain string.)
UNRESOLVED_SEARCH_SPACE = (
    "You passed a `{par}` parameter to {cls} that contained unresolved search "
    "space definitions. {cls} should however be instantiated with fully "
    "configured search spaces only. To use Ray Tune's automatic search space "
    "conversion, pass the space definition as part of the `param_space` argument "
    "to `tune.Tuner()` instead."
)

UNDEFINED_SEARCH_SPACE = (
    "Trying to sample a configuration from {cls}, but no search "
    "space has been defined. Either pass the `{space}` argument when "
    "instantiating the search algorithm, or pass a `param_space` to "
    "`tune.Tuner()`."
)

UNDEFINED_METRIC_MODE = (
    "Trying to sample a configuration from {cls}, but the `metric` "
    "({metric}) or `mode` ({mode}) parameters have not been set. "
    "Either pass these arguments when instantiating the search algorithm, "
    "or pass them to `tune.TuneConfig()`."
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Public API of this package.  The UNRESOLVED/UNDEFINED message templates
# are exported for reuse by searcher implementations in subpackages.
__all__ = [
    "SearchAlgorithm",
    "Searcher",
    "ConcurrencyLimiter",
    "Repeater",
    "BasicVariantGenerator",
    "grid_search",
    "SearchGenerator",
    "UNRESOLVED_SEARCH_SPACE",
    "UNDEFINED_SEARCH_SPACE",
    "UNDEFINED_METRIC_MODE",
]
|
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/concurrency_limiter.cpython-311.pyc
ADDED
|
Binary file (9.22 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/sample.cpython-311.pyc
ADDED
|
Binary file (40.3 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_algorithm.cpython-311.pyc
ADDED
|
Binary file (5.87 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/__pycache__/search_generator.cpython-311.pyc
ADDED
|
Binary file (12.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/_mock.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List, Optional
|
| 2 |
+
|
| 3 |
+
from ray.tune.experiment import Trial
|
| 4 |
+
from ray.tune.search import ConcurrencyLimiter, Searcher
|
| 5 |
+
from ray.tune.search.search_generator import SearchGenerator
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class _MockSearcher(Searcher):
    """Minimal in-memory Searcher used by tests.

    Records every callback it receives:

    - ``live_trials``: trial_id -> 1 for trials that got a suggestion and
      have not completed yet.
    - ``counter``: number of ``on_trial_result`` / ``on_trial_complete`` calls.
    - ``results``: every intermediate result reported.
    - ``final_results``: results passed to ``on_trial_complete``.
    - ``stall``: when set to True, ``suggest`` returns None (no new trials).
    """

    def __init__(self, **kwargs):
        self.live_trials = {}
        self.counter = {"result": 0, "complete": 0}
        self.final_results = []
        self.stall = False
        self.results = []
        super(_MockSearcher, self).__init__(**kwargs)

    def suggest(self, trial_id: str):
        """Return a fixed config, or None when stalled."""
        if not self.stall:
            self.live_trials[trial_id] = 1
            return {"test_variable": 2}
        return None

    def on_trial_result(self, trial_id: str, result: Dict):
        self.counter["result"] += 1
        # `.append` instead of `+= [result]`: same effect, no throwaway list.
        self.results.append(result)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        self.counter["complete"] += 1
        if result:
            self._process_result(result)
        if trial_id in self.live_trials:
            del self.live_trials[trial_id]

    def _process_result(self, result: Dict):
        # Record the final result reported for a completed trial.
        self.final_results.append(result)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class _MockSuggestionAlgorithm(SearchGenerator):
    """SearchGenerator wrapper around a ``_MockSearcher`` for tests.

    When ``max_concurrent`` is given (and non-zero), the mock searcher is
    wrapped in a ``ConcurrencyLimiter`` before being handed to the
    generator.  The recorded trial/result state of the underlying searcher
    is exposed via read-only properties.
    """

    def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
        searcher = _MockSearcher(**kwargs)
        if max_concurrent:
            searcher = ConcurrencyLimiter(searcher, max_concurrent=max_concurrent)
        self.searcher = searcher
        super(_MockSuggestionAlgorithm, self).__init__(self.searcher)

    @property
    def live_trials(self) -> List[Trial]:
        # Delegate to the wrapped searcher's bookkeeping.
        return self.searcher.live_trials

    @property
    def results(self) -> List[Dict]:
        # All intermediate results the wrapped searcher has seen.
        return self.searcher.results
|
.venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (291 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/ax/__pycache__/ax_search.cpython-311.pyc
ADDED
|
Binary file (19.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/basic_variant.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import itertools
|
| 3 |
+
import os
|
| 4 |
+
import uuid
|
| 5 |
+
import warnings
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from ray.air._internal.usage import tag_searcher
|
| 12 |
+
from ray.tune.error import TuneError
|
| 13 |
+
from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
|
| 14 |
+
from ray.tune.search.sample import _BackwardsCompatibleNumpyRng, np_random_generator
|
| 15 |
+
from ray.tune.search.search_algorithm import SearchAlgorithm
|
| 16 |
+
from ray.tune.search.variant_generator import (
|
| 17 |
+
_count_spec_samples,
|
| 18 |
+
_count_variants,
|
| 19 |
+
_flatten_resolved_vars,
|
| 20 |
+
_get_preset_variants,
|
| 21 |
+
format_vars,
|
| 22 |
+
generate_variants,
|
| 23 |
+
)
|
| 24 |
+
from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
|
| 25 |
+
from ray.util import PublicAPI
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from ray.tune.experiment import Experiment
|
| 29 |
+
|
| 30 |
+
# Maximum number of pre-generated grid-search samples for which generator
# state is still serialized; above this, variants are evaluated lazily and
# resume ability is disabled (see `BasicVariantGenerator.add_configurations`).
SERIALIZATION_THRESHOLD = 1e6
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class _VariantIterator:
    """Yields resolved variants produced from the search space.

    Operates in one of two modes:

    - eager (``lazy_eval=False``): the full variant list is materialized up
      front, which keeps the object serializable;
    - lazy (``lazy_eval=True``): variants are pulled from the underlying
      generator one at a time, and the object cannot be serialized.
    """

    def __init__(self, iterable, lazy_eval=False):
        self.lazy_eval = lazy_eval
        self.iterable = iterable
        self._has_next = True
        if lazy_eval:
            # Pre-fetch the first item so `has_next` can answer immediately.
            self._load_value()
        else:
            # Materialize everything now; an empty result means nothing to yield.
            self.iterable = list(iterable)
            self._has_next = len(self.iterable) > 0

    def _load_value(self):
        # Advance the underlying generator, caching the next item; mark the
        # iterator exhausted once the generator runs dry.
        try:
            self.next_value = next(self.iterable)
        except StopIteration:
            self._has_next = False

    def has_next(self):
        return self._has_next

    def __next__(self):
        if self.lazy_eval:
            value = self.next_value
            self._load_value()
            return value
        value = self.iterable.pop(0)
        self._has_next = len(self.iterable) > 0
        return value
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class _TrialIterator:
    """Generates trials from the spec.

    Args:
        uuid_prefix: Used in creating the trial name.
        num_samples: Number of samples from distribution
            (same as tune.TuneConfig).
        unresolved_spec: Experiment specification
            that might have unresolved distributions.
        constant_grid_search: Should random variables be sampled
            first before iterating over grid variants (True) or not (False).
        points_to_evaluate: Configurations that will be tried out without sampling.
        lazy_eval: Whether variants should be generated
            lazily or eagerly. This is toggled depending
            on the size of the grid search.
        start: index at which to start counting trials.
        random_state (int | np.random.Generator | np.random.RandomState):
            Seed or numpy random generator to use for reproducible results.
            If None (default), will use the global numpy random generator
            (``np.random``). Please note that full reproducibility cannot
            be guaranteed in a distributed environment.
    """

    def __init__(
        self,
        uuid_prefix: str,
        num_samples: int,
        unresolved_spec: dict,
        constant_grid_search: bool = False,
        points_to_evaluate: Optional[List] = None,
        lazy_eval: bool = False,
        start: int = 0,
        random_state: Optional[
            Union[int, "np_random_generator", np.random.RandomState]
        ] = None,
    ):
        self.parser = _make_parser()
        self.num_samples = num_samples
        self.uuid_prefix = uuid_prefix
        # Remaining sample budget; each preset point or random batch consumes one.
        self.num_samples_left = num_samples
        self.unresolved_spec = unresolved_spec
        self.constant_grid_search = constant_grid_search
        self.points_to_evaluate = points_to_evaluate or []
        self.num_points_to_evaluate = len(self.points_to_evaluate)
        # Running trial index; also used to build unique trial ids.
        self.counter = start
        self.lazy_eval = lazy_eval
        # Current `_VariantIterator` batch; created on demand in `__next__`.
        self.variants = None
        self.random_state = random_state

    def create_trial(self, resolved_vars, spec):
        """Build a trial from a fully resolved spec and advance the counter."""
        # Zero-padded, so lexicographic order matches numeric order.
        trial_id = self.uuid_prefix + ("%05d" % self.counter)
        experiment_tag = str(self.counter)
        # Always append resolved vars to experiment tag?
        if resolved_vars:
            experiment_tag += "_{}".format(format_vars(resolved_vars))
        self.counter += 1
        return _create_trial_from_spec(
            spec,
            self.parser,
            evaluated_params=_flatten_resolved_vars(resolved_vars),
            trial_id=trial_id,
            experiment_tag=experiment_tag,
        )

    def __next__(self):
        """Generates Trial objects with the variant generation process.

        Uses a fixed point iteration to resolve variants. All trials
        should be able to be generated at once.

        See also: `ray.tune.search.variant_generator`.

        Returns:
            Trial object
        """

        if "run" not in self.unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(self.unresolved_spec))

        if self.variants and self.variants.has_next():
            # This block will be skipped upon instantiation.
            # `variants` will be set later after the first loop.
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)

        if self.points_to_evaluate:
            # Preset points take priority: each consumes one sample from the
            # budget and is expanded over the remaining grid dimensions.
            config = self.points_to_evaluate.pop(0)
            self.num_samples_left -= 1
            self.variants = _VariantIterator(
                _get_preset_variants(
                    self.unresolved_spec,
                    config,
                    constant_grid_search=self.constant_grid_search,
                    random_state=self.random_state,
                ),
                lazy_eval=self.lazy_eval,
            )
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)
        elif self.num_samples_left > 0:
            # Regular sampling: draw a fresh batch of resolved variants.
            self.variants = _VariantIterator(
                generate_variants(
                    self.unresolved_spec,
                    constant_grid_search=self.constant_grid_search,
                    random_state=self.random_state,
                ),
                lazy_eval=self.lazy_eval,
            )
            self.num_samples_left -= 1
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)
        else:
            raise StopIteration

    def __iter__(self):
        return self
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@PublicAPI
class BasicVariantGenerator(SearchAlgorithm):
    """Uses Tune's variant generation for resolving variables.

    This is the default search algorithm used if no other search algorithm
    is specified.


    Args:
        points_to_evaluate: Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.
        max_concurrent: Maximum number of concurrently running trials.
            If 0 (default), no maximum is enforced.
        constant_grid_search: If this is set to ``True``, Ray Tune will
            *first* try to sample random values and keep them constant over
            grid search parameters. If this is set to ``False`` (default),
            Ray Tune will sample new random parameters in each grid search
            condition.
        random_state:
            Seed or numpy random generator to use for reproducible results.
            If None (default), will use the global numpy random generator
            (``np.random``). Please note that full reproducibility cannot
            be guaranteed in a distributed environment.


    Example:

    .. code-block:: python

        from ray import tune

        # This will automatically use the `BasicVariantGenerator`
        tuner = tune.Tuner(
            lambda config: config["a"] + config["b"],
            tune_config=tune.TuneConfig(
                num_samples=4
            ),
            param_space={
                "a": tune.grid_search([1, 2]),
                "b": tune.randint(0, 3)
            },
        )
        tuner.fit()

    In the example above, 8 trials will be generated: For each sample
    (``4``), each of the grid search variants for ``a`` will be sampled
    once. The ``b`` parameter will be sampled randomly.

    The generator accepts a pre-set list of points that should be evaluated.
    The points will replace the first samples of each experiment passed to
    the ``BasicVariantGenerator``.

    Each point will replace one sample of the specified ``num_samples``. If
    grid search variables are overwritten with the values specified in the
    presets, the number of samples will thus be reduced.

    Example:

    .. code-block:: python

        from ray import tune
        from ray.tune.search.basic_variant import BasicVariantGenerator

        tuner = tune.Tuner(
            lambda config: config["a"] + config["b"],
            tune_config=tune.TuneConfig(
                search_alg=BasicVariantGenerator(points_to_evaluate=[
                    {"a": 2, "b": 2},
                    {"a": 1},
                    {"b": 2}
                ]),
                num_samples=4
            ),
            param_space={
                "a": tune.grid_search([1, 2]),
                "b": tune.randint(0, 3)
            },
        )
        tuner.fit()

    The example above will produce six trials via four samples:

    - The first sample will produce one trial with ``a=2`` and ``b=2``.
    - The second sample will produce one trial with ``a=1`` and ``b`` sampled
      randomly
    - The third sample will produce two trials, one for each grid search
      value of ``a``. It will be ``b=2`` for both of these trials.
    - The fourth sample will produce two trials, one for each grid search
      value of ``a``. ``b`` will be sampled randomly and independently for
      both of these trials.

    """

    # Filename template for serialized generator state; `{}` is filled with
    # the session string (see `save_to_dir` / `restore_from_dir`).
    CKPT_FILE_TMPL = "basic-variant-state-{}.json"

    def __init__(
        self,
        points_to_evaluate: Optional[List[Dict]] = None,
        max_concurrent: int = 0,
        constant_grid_search: bool = False,
        random_state: Optional[
            Union[int, "np_random_generator", np.random.RandomState]
        ] = None,
    ):
        tag_searcher(self)
        self._trial_generator = []
        self._iterators = []
        self._trial_iter = None
        self._finished = False
        self._random_state = _BackwardsCompatibleNumpyRng(random_state)

        self._points_to_evaluate = points_to_evaluate or []

        # Unique prefix for all trials generated, e.g., trial ids start as
        # 2f1ef_00001, 2f1ef_00002, 2f1ef_00003, etc. Overridable for testing.
        force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
        if force_test_uuid:
            self._uuid_prefix = force_test_uuid + "_"
        else:
            self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"

        self._total_samples = 0
        self.max_concurrent = max_concurrent
        self._constant_grid_search = constant_grid_search
        # Trial ids that received a suggestion and have not completed yet;
        # used to enforce `max_concurrent`.
        self._live_trials = set()

    @property
    def total_samples(self):
        # Total number of trials this generator will produce across all
        # experiments added so far.
        return self._total_samples

    def add_configurations(
        self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
    ):
        """Chains generator given experiment specifications.

        Arguments:
            experiments: Experiments to run.
        """
        from ray.tune.experiment import _convert_to_experiment_list

        experiment_list = _convert_to_experiment_list(experiments)

        for experiment in experiment_list:
            # Size of a single sample's variant grid; very large grids are
            # evaluated lazily, which disables checkpoint/resume (see below).
            grid_vals = _count_spec_samples(experiment.spec, num_samples=1)
            lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
            if lazy_eval:
                warnings.warn(
                    f"The number of pre-generated samples ({grid_vals}) "
                    "exceeds the serialization threshold "
                    f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
                    "disabled. To fix this, reduce the number of "
                    "dimensions/size of the provided grid search."
                )

            previous_samples = self._total_samples
            # Deep-copy so one experiment's iterator popping points does not
            # consume them for subsequently added experiments.
            points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
            self._total_samples += _count_variants(experiment.spec, points_to_evaluate)
            iterator = _TrialIterator(
                uuid_prefix=self._uuid_prefix,
                num_samples=experiment.spec.get("num_samples", 1),
                unresolved_spec=experiment.spec,
                constant_grid_search=self._constant_grid_search,
                points_to_evaluate=points_to_evaluate,
                lazy_eval=lazy_eval,
                start=previous_samples,
                random_state=self._random_state,
            )
            self._iterators.append(iterator)
            self._trial_generator = itertools.chain(self._trial_generator, iterator)

    def next_trial(self):
        """Provides one Trial object to be queued into the TrialRunner.

        Returns:
            Trial: Returns a single trial.
        """
        if self.is_finished():
            return None
        # Enforce the concurrency cap (0 means unlimited).
        if self.max_concurrent > 0 and len(self._live_trials) >= self.max_concurrent:
            return None
        if not self._trial_iter:
            self._trial_iter = iter(self._trial_generator)
        try:
            trial = next(self._trial_iter)
            self._live_trials.add(trial.trial_id)
            return trial
        except StopIteration:
            # All chained iterators are exhausted; mark the search finished.
            self._trial_generator = []
            self._trial_iter = None
            self.set_finished()
            return None

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        # Free a concurrency slot for the completed trial.
        if trial_id in self._live_trials:
            self._live_trials.remove(trial_id)

    def get_state(self):
        # Lazily evaluated iterators hold live generators and cannot be
        # serialized; signal that with False (mirrors `save_to_dir`).
        if any(iterator.lazy_eval for iterator in self._iterators):
            return False
        state = self.__dict__.copy()
        # The chained generator is not picklable; it is rebuilt from
        # `_iterators` in `set_state`.
        del state["_trial_generator"]
        return state

    def set_state(self, state):
        self.__dict__.update(state)
        # Rebuild the chained generator from the restored iterators.
        for iterator in self._iterators:
            self._trial_generator = itertools.chain(self._trial_generator, iterator)

    def save_to_dir(self, dirpath, session_str):
        """Atomically write the generator state into `dirpath`.

        Returns False (and writes nothing) when any iterator is lazily
        evaluated, since that state cannot be serialized.
        """
        if any(iterator.lazy_eval for iterator in self._iterators):
            return False
        state_dict = self.get_state()
        _atomic_save(
            state=state_dict,
            checkpoint_dir=dirpath,
            file_name=self.CKPT_FILE_TMPL.format(session_str),
            tmp_file_name=".tmp_generator",
        )

    def has_checkpoint(self, dirpath: str):
        """Whether a checkpoint file exists within dirpath."""
        return any(Path(dirpath).glob(self.CKPT_FILE_TMPL.format("*")))

    def restore_from_dir(self, dirpath: str):
        """Restores self + searcher + search wrappers from dirpath."""
        state_dict = _load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*"))
        if not state_dict:
            raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
        self.set_state(state_dict)
|
.venv/lib/python3.11/site-packages/ray/tune/search/concurrency_limiter.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
from ray.tune.search.searcher import Searcher
|
| 6 |
+
from ray.tune.search.util import _set_search_properties_backwards_compatible
|
| 7 |
+
from ray.util.annotations import PublicAPI
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@PublicAPI
|
| 13 |
+
class ConcurrencyLimiter(Searcher):
|
| 14 |
+
"""A wrapper algorithm for limiting the number of concurrent trials.
|
| 15 |
+
|
| 16 |
+
Certain Searchers have their own internal logic for limiting
|
| 17 |
+
the number of concurrent trials. If such a Searcher is passed to a
|
| 18 |
+
``ConcurrencyLimiter``, the ``max_concurrent`` of the
|
| 19 |
+
``ConcurrencyLimiter`` will override the ``max_concurrent`` value
|
| 20 |
+
of the Searcher. The ``ConcurrencyLimiter`` will then let the
|
| 21 |
+
Searcher's internal logic take over.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
searcher: Searcher object that the
|
| 25 |
+
ConcurrencyLimiter will manage.
|
| 26 |
+
max_concurrent: Maximum concurrent samples from the underlying
|
| 27 |
+
searcher.
|
| 28 |
+
batch: Whether to wait for all concurrent samples
|
| 29 |
+
to finish before updating the underlying searcher.
|
| 30 |
+
|
| 31 |
+
Example:
|
| 32 |
+
|
| 33 |
+
.. code-block:: python
|
| 34 |
+
|
| 35 |
+
from ray.tune.search import ConcurrencyLimiter
|
| 36 |
+
search_alg = HyperOptSearch(metric="accuracy")
|
| 37 |
+
search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
|
| 38 |
+
tuner = tune.Tuner(
|
| 39 |
+
trainable,
|
| 40 |
+
tune_config=tune.TuneConfig(
|
| 41 |
+
search_alg=search_alg
|
| 42 |
+
),
|
| 43 |
+
)
|
| 44 |
+
tuner.fit()
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
|
| 48 |
+
assert type(max_concurrent) is int and max_concurrent > 0
|
| 49 |
+
self.searcher = searcher
|
| 50 |
+
self.max_concurrent = max_concurrent
|
| 51 |
+
self.batch = batch
|
| 52 |
+
self.live_trials = set()
|
| 53 |
+
self.num_unfinished_live_trials = 0
|
| 54 |
+
self.cached_results = {}
|
| 55 |
+
self._limit_concurrency = True
|
| 56 |
+
|
| 57 |
+
if not isinstance(searcher, Searcher):
|
| 58 |
+
raise RuntimeError(
|
| 59 |
+
f"The `ConcurrencyLimiter` only works with `Searcher` "
|
| 60 |
+
f"objects (got {type(searcher)}). Please try to pass "
|
| 61 |
+
f"`max_concurrent` to the search generator directly."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self._set_searcher_max_concurrency()
|
| 65 |
+
|
| 66 |
+
super(ConcurrencyLimiter, self).__init__(
|
| 67 |
+
metric=self.searcher.metric, mode=self.searcher.mode
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def _set_searcher_max_concurrency(self):
|
| 71 |
+
# If the searcher has special logic for handling max concurrency,
|
| 72 |
+
# we do not do anything inside the ConcurrencyLimiter
|
| 73 |
+
self._limit_concurrency = not self.searcher.set_max_concurrency(
|
| 74 |
+
self.max_concurrent
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
def set_max_concurrency(self, max_concurrent: int) -> bool:
|
| 78 |
+
# Determine if this behavior is acceptable, or if it should
|
| 79 |
+
# raise an exception.
|
| 80 |
+
self.max_concurrent = max_concurrent
|
| 81 |
+
return True
|
| 82 |
+
|
| 83 |
+
def set_search_properties(
|
| 84 |
+
self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
|
| 85 |
+
) -> bool:
|
| 86 |
+
self._set_searcher_max_concurrency()
|
| 87 |
+
return _set_search_properties_backwards_compatible(
|
| 88 |
+
self.searcher.set_search_properties, metric, mode, config, **spec
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
def suggest(self, trial_id: str) -> Optional[Dict]:
|
| 92 |
+
if not self._limit_concurrency:
|
| 93 |
+
return self.searcher.suggest(trial_id)
|
| 94 |
+
|
| 95 |
+
assert (
|
| 96 |
+
trial_id not in self.live_trials
|
| 97 |
+
), f"Trial ID {trial_id} must be unique: already found in set."
|
| 98 |
+
if len(self.live_trials) >= self.max_concurrent:
|
| 99 |
+
logger.debug(
|
| 100 |
+
f"Not providing a suggestion for {trial_id} due to "
|
| 101 |
+
"concurrency limit: %s/%s.",
|
| 102 |
+
len(self.live_trials),
|
| 103 |
+
self.max_concurrent,
|
| 104 |
+
)
|
| 105 |
+
return
|
| 106 |
+
|
| 107 |
+
suggestion = self.searcher.suggest(trial_id)
|
| 108 |
+
if suggestion not in (None, Searcher.FINISHED):
|
| 109 |
+
self.live_trials.add(trial_id)
|
| 110 |
+
self.num_unfinished_live_trials += 1
|
| 111 |
+
return suggestion
|
| 112 |
+
|
| 113 |
+
def on_trial_complete(
|
| 114 |
+
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
|
| 115 |
+
):
|
| 116 |
+
if not self._limit_concurrency:
|
| 117 |
+
return self.searcher.on_trial_complete(trial_id, result=result, error=error)
|
| 118 |
+
|
| 119 |
+
if trial_id not in self.live_trials:
|
| 120 |
+
return
|
| 121 |
+
elif self.batch:
|
| 122 |
+
self.cached_results[trial_id] = (result, error)
|
| 123 |
+
self.num_unfinished_live_trials -= 1
|
| 124 |
+
if self.num_unfinished_live_trials <= 0:
|
| 125 |
+
# Update the underlying searcher once the
|
| 126 |
+
# full batch is completed.
|
| 127 |
+
for trial_id, (result, error) in self.cached_results.items():
|
| 128 |
+
self.searcher.on_trial_complete(
|
| 129 |
+
trial_id, result=result, error=error
|
| 130 |
+
)
|
| 131 |
+
self.live_trials.remove(trial_id)
|
| 132 |
+
self.cached_results = {}
|
| 133 |
+
self.num_unfinished_live_trials = 0
|
| 134 |
+
else:
|
| 135 |
+
return
|
| 136 |
+
else:
|
| 137 |
+
self.searcher.on_trial_complete(trial_id, result=result, error=error)
|
| 138 |
+
self.live_trials.remove(trial_id)
|
| 139 |
+
self.num_unfinished_live_trials -= 1
|
| 140 |
+
|
| 141 |
+
def on_trial_result(self, trial_id: str, result: Dict) -> None:
|
| 142 |
+
self.searcher.on_trial_result(trial_id, result)
|
| 143 |
+
|
| 144 |
+
def add_evaluated_point(
|
| 145 |
+
self,
|
| 146 |
+
parameters: Dict,
|
| 147 |
+
value: float,
|
| 148 |
+
error: bool = False,
|
| 149 |
+
pruned: bool = False,
|
| 150 |
+
intermediate_values: Optional[List[float]] = None,
|
| 151 |
+
):
|
| 152 |
+
return self.searcher.add_evaluated_point(
|
| 153 |
+
parameters, value, error, pruned, intermediate_values
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def get_state(self) -> Dict:
|
| 157 |
+
state = self.__dict__.copy()
|
| 158 |
+
del state["searcher"]
|
| 159 |
+
return copy.deepcopy(state)
|
| 160 |
+
|
| 161 |
+
def set_state(self, state: Dict):
|
| 162 |
+
self.__dict__.update(state)
|
| 163 |
+
|
| 164 |
+
def save(self, checkpoint_path: str):
|
| 165 |
+
self.searcher.save(checkpoint_path)
|
| 166 |
+
|
| 167 |
+
def restore(self, checkpoint_path: str):
|
| 168 |
+
self.searcher.restore(checkpoint_path)
|
| 169 |
+
|
| 170 |
+
# BOHB Specific.
|
| 171 |
+
# TODO(team-ml): Refactor alongside HyperBandForBOHB
|
| 172 |
+
def on_pause(self, trial_id: str):
|
| 173 |
+
self.searcher.on_pause(trial_id)
|
| 174 |
+
|
| 175 |
+
def on_unpause(self, trial_id: str):
|
| 176 |
+
self.searcher.on_unpause(trial_id)
|
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.search.hebo.hebo_search import HEBOSearch
|
| 2 |
+
|
| 3 |
+
__all__ = ["HEBOSearch"]
|
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (299 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/hebo/__pycache__/hebo_search.cpython-311.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.search.nevergrad.nevergrad_search import NevergradSearch
|
| 2 |
+
|
| 3 |
+
__all__ = ["NevergradSearch"]
|
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (320 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/__pycache__/nevergrad_search.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/nevergrad/nevergrad_search.py
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
import pickle
|
| 4 |
+
from typing import Dict, List, Optional, Sequence, Type, Union
|
| 5 |
+
|
| 6 |
+
from ray.tune.result import DEFAULT_METRIC
|
| 7 |
+
from ray.tune.search import (
|
| 8 |
+
UNDEFINED_METRIC_MODE,
|
| 9 |
+
UNDEFINED_SEARCH_SPACE,
|
| 10 |
+
UNRESOLVED_SEARCH_SPACE,
|
| 11 |
+
Searcher,
|
| 12 |
+
)
|
| 13 |
+
from ray.tune.search.sample import (
|
| 14 |
+
Categorical,
|
| 15 |
+
Domain,
|
| 16 |
+
Float,
|
| 17 |
+
Integer,
|
| 18 |
+
LogUniform,
|
| 19 |
+
Quantized,
|
| 20 |
+
)
|
| 21 |
+
from ray.tune.search.variant_generator import parse_spec_vars
|
| 22 |
+
from ray.tune.utils.util import flatten_dict, unflatten_dict
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import nevergrad as ng
|
| 26 |
+
from nevergrad.optimization import Optimizer
|
| 27 |
+
from nevergrad.optimization.base import ConfiguredOptimizer
|
| 28 |
+
|
| 29 |
+
Parameter = ng.p.Parameter
|
| 30 |
+
except ImportError:
|
| 31 |
+
ng = None
|
| 32 |
+
Optimizer = None
|
| 33 |
+
ConfiguredOptimizer = None
|
| 34 |
+
Parameter = None
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class NevergradSearch(Searcher):
|
| 40 |
+
"""Uses Nevergrad to optimize hyperparameters.
|
| 41 |
+
|
| 42 |
+
Nevergrad is an open source tool from Facebook for derivative free
|
| 43 |
+
optimization. More info can be found at:
|
| 44 |
+
https://github.com/facebookresearch/nevergrad.
|
| 45 |
+
|
| 46 |
+
You will need to install Nevergrad via the following command:
|
| 47 |
+
|
| 48 |
+
.. code-block:: bash
|
| 49 |
+
|
| 50 |
+
$ pip install nevergrad
|
| 51 |
+
|
| 52 |
+
Parameters:
|
| 53 |
+
optimizer: Optimizer class provided from Nevergrad.
|
| 54 |
+
See here for available optimizers:
|
| 55 |
+
https://facebookresearch.github.io/nevergrad/optimizers_ref.html#optimizers
|
| 56 |
+
This can also be an instance of a `ConfiguredOptimizer`. See the
|
| 57 |
+
section on configured optimizers in the above link.
|
| 58 |
+
optimizer_kwargs: Kwargs passed in when instantiating the `optimizer`
|
| 59 |
+
space: Nevergrad parametrization
|
| 60 |
+
to be passed to optimizer on instantiation, or list of parameter
|
| 61 |
+
names if you passed an optimizer object.
|
| 62 |
+
metric: The training result objective value attribute. If None
|
| 63 |
+
but a mode was passed, the anonymous metric `_metric` will be used
|
| 64 |
+
per default.
|
| 65 |
+
mode: One of {min, max}. Determines whether objective is
|
| 66 |
+
minimizing or maximizing the metric attribute.
|
| 67 |
+
points_to_evaluate: Initial parameter suggestions to be run
|
| 68 |
+
first. This is for when you already have some good parameters
|
| 69 |
+
you want to run first to help the algorithm make better suggestions
|
| 70 |
+
for future parameters. Needs to be a list of dicts containing the
|
| 71 |
+
configurations.
|
| 72 |
+
|
| 73 |
+
Tune automatically converts search spaces to Nevergrad's format:
|
| 74 |
+
|
| 75 |
+
.. code-block:: python
|
| 76 |
+
|
| 77 |
+
import nevergrad as ng
|
| 78 |
+
|
| 79 |
+
config = {
|
| 80 |
+
"width": tune.uniform(0, 20),
|
| 81 |
+
"height": tune.uniform(-100, 100),
|
| 82 |
+
"activation": tune.choice(["relu", "tanh"])
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
current_best_params = [{
|
| 86 |
+
"width": 10,
|
| 87 |
+
"height": 0,
|
| 88 |
+
"activation": relu",
|
| 89 |
+
}]
|
| 90 |
+
|
| 91 |
+
ng_search = NevergradSearch(
|
| 92 |
+
optimizer=ng.optimizers.OnePlusOne,
|
| 93 |
+
metric="mean_loss",
|
| 94 |
+
mode="min",
|
| 95 |
+
points_to_evaluate=current_best_params)
|
| 96 |
+
|
| 97 |
+
run(my_trainable, config=config, search_alg=ng_search)
|
| 98 |
+
|
| 99 |
+
If you would like to pass the search space manually, the code would
|
| 100 |
+
look like this:
|
| 101 |
+
|
| 102 |
+
.. code-block:: python
|
| 103 |
+
|
| 104 |
+
import nevergrad as ng
|
| 105 |
+
|
| 106 |
+
space = ng.p.Dict(
|
| 107 |
+
width=ng.p.Scalar(lower=0, upper=20),
|
| 108 |
+
height=ng.p.Scalar(lower=-100, upper=100),
|
| 109 |
+
activation=ng.p.Choice(choices=["relu", "tanh"])
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
ng_search = NevergradSearch(
|
| 113 |
+
optimizer=ng.optimizers.OnePlusOne,
|
| 114 |
+
space=space,
|
| 115 |
+
metric="mean_loss",
|
| 116 |
+
mode="min")
|
| 117 |
+
|
| 118 |
+
run(my_trainable, search_alg=ng_search)
|
| 119 |
+
|
| 120 |
+
"""
|
| 121 |
+
|
| 122 |
+
def __init__(
|
| 123 |
+
self,
|
| 124 |
+
optimizer: Optional[
|
| 125 |
+
Union[Optimizer, Type[Optimizer], ConfiguredOptimizer]
|
| 126 |
+
] = None,
|
| 127 |
+
optimizer_kwargs: Optional[Dict] = None,
|
| 128 |
+
space: Optional[Union[Dict, Parameter]] = None,
|
| 129 |
+
metric: Optional[str] = None,
|
| 130 |
+
mode: Optional[str] = None,
|
| 131 |
+
points_to_evaluate: Optional[List[Dict]] = None,
|
| 132 |
+
):
|
| 133 |
+
assert (
|
| 134 |
+
ng is not None
|
| 135 |
+
), """Nevergrad must be installed!
|
| 136 |
+
You can install Nevergrad with the command:
|
| 137 |
+
`pip install nevergrad`."""
|
| 138 |
+
if mode:
|
| 139 |
+
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
| 140 |
+
|
| 141 |
+
super(NevergradSearch, self).__init__(metric=metric, mode=mode)
|
| 142 |
+
|
| 143 |
+
self._space = None
|
| 144 |
+
self._opt_factory = None
|
| 145 |
+
self._nevergrad_opt = None
|
| 146 |
+
self._optimizer_kwargs = optimizer_kwargs or {}
|
| 147 |
+
|
| 148 |
+
if points_to_evaluate is None:
|
| 149 |
+
self._points_to_evaluate = None
|
| 150 |
+
elif not isinstance(points_to_evaluate, Sequence):
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"Invalid object type passed for `points_to_evaluate`: "
|
| 153 |
+
f"{type(points_to_evaluate)}. "
|
| 154 |
+
"Please pass a list of points (dictionaries) instead."
|
| 155 |
+
)
|
| 156 |
+
else:
|
| 157 |
+
self._points_to_evaluate = list(points_to_evaluate)
|
| 158 |
+
|
| 159 |
+
if isinstance(space, dict) and space:
|
| 160 |
+
resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
|
| 161 |
+
if domain_vars or grid_vars:
|
| 162 |
+
logger.warning(
|
| 163 |
+
UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self))
|
| 164 |
+
)
|
| 165 |
+
space = self.convert_search_space(space)
|
| 166 |
+
|
| 167 |
+
if isinstance(optimizer, Optimizer):
|
| 168 |
+
if space is not None and not isinstance(space, list):
|
| 169 |
+
raise ValueError(
|
| 170 |
+
"If you pass a configured optimizer to Nevergrad, either "
|
| 171 |
+
"pass a list of parameter names or None as the `space` "
|
| 172 |
+
"parameter."
|
| 173 |
+
)
|
| 174 |
+
if self._optimizer_kwargs:
|
| 175 |
+
raise ValueError(
|
| 176 |
+
"If you pass in optimizer kwargs, either pass "
|
| 177 |
+
"an `Optimizer` subclass or an instance of "
|
| 178 |
+
"`ConfiguredOptimizer`."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
self._parameters = space
|
| 182 |
+
self._nevergrad_opt = optimizer
|
| 183 |
+
elif (
|
| 184 |
+
inspect.isclass(optimizer) and issubclass(optimizer, Optimizer)
|
| 185 |
+
) or isinstance(optimizer, ConfiguredOptimizer):
|
| 186 |
+
self._opt_factory = optimizer
|
| 187 |
+
self._parameters = None
|
| 188 |
+
self._space = space
|
| 189 |
+
else:
|
| 190 |
+
raise ValueError(
|
| 191 |
+
"The `optimizer` argument passed to NevergradSearch must be "
|
| 192 |
+
"either an `Optimizer` or a `ConfiguredOptimizer`."
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
self._live_trial_mapping = {}
|
| 196 |
+
|
| 197 |
+
if self._nevergrad_opt or self._space:
|
| 198 |
+
self._setup_nevergrad()
|
| 199 |
+
|
| 200 |
+
def _setup_nevergrad(self):
|
| 201 |
+
if self._opt_factory:
|
| 202 |
+
self._nevergrad_opt = self._opt_factory(
|
| 203 |
+
self._space, **self._optimizer_kwargs
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# nevergrad.tell internally minimizes, so "max" => -1
|
| 207 |
+
if self._mode == "max":
|
| 208 |
+
self._metric_op = -1.0
|
| 209 |
+
elif self._mode == "min":
|
| 210 |
+
self._metric_op = 1.0
|
| 211 |
+
|
| 212 |
+
if self._metric is None and self._mode:
|
| 213 |
+
# If only a mode was passed, use anonymous metric
|
| 214 |
+
self._metric = DEFAULT_METRIC
|
| 215 |
+
|
| 216 |
+
if hasattr(self._nevergrad_opt, "instrumentation"): # added in v0.2.0
|
| 217 |
+
if self._nevergrad_opt.instrumentation.kwargs:
|
| 218 |
+
if self._nevergrad_opt.instrumentation.args:
|
| 219 |
+
raise ValueError("Instrumented optimizers should use kwargs only")
|
| 220 |
+
if self._parameters is not None:
|
| 221 |
+
raise ValueError(
|
| 222 |
+
"Instrumented optimizers should provide "
|
| 223 |
+
"None as parameter_names"
|
| 224 |
+
)
|
| 225 |
+
else:
|
| 226 |
+
if self._parameters is None:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"Non-instrumented optimizers should have "
|
| 229 |
+
"a list of parameter_names"
|
| 230 |
+
)
|
| 231 |
+
if len(self._nevergrad_opt.instrumentation.args) != 1:
|
| 232 |
+
raise ValueError("Instrumented optimizers should use kwargs only")
|
| 233 |
+
if self._parameters is not None and self._nevergrad_opt.dimension != len(
|
| 234 |
+
self._parameters
|
| 235 |
+
):
|
| 236 |
+
raise ValueError(
|
| 237 |
+
"len(parameters_names) must match optimizer "
|
| 238 |
+
"dimension for non-instrumented optimizers"
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if self._points_to_evaluate:
|
| 242 |
+
# Nevergrad is LIFO, so we add the points to evaluate in reverse
|
| 243 |
+
# order.
|
| 244 |
+
for i in range(len(self._points_to_evaluate) - 1, -1, -1):
|
| 245 |
+
self._nevergrad_opt.suggest(self._points_to_evaluate[i])
|
| 246 |
+
|
| 247 |
+
def set_search_properties(
|
| 248 |
+
self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
|
| 249 |
+
) -> bool:
|
| 250 |
+
if self._nevergrad_opt or self._space:
|
| 251 |
+
return False
|
| 252 |
+
space = self.convert_search_space(config)
|
| 253 |
+
self._space = space
|
| 254 |
+
|
| 255 |
+
if metric:
|
| 256 |
+
self._metric = metric
|
| 257 |
+
if mode:
|
| 258 |
+
self._mode = mode
|
| 259 |
+
|
| 260 |
+
self._setup_nevergrad()
|
| 261 |
+
return True
|
| 262 |
+
|
| 263 |
+
def suggest(self, trial_id: str) -> Optional[Dict]:
|
| 264 |
+
if not self._nevergrad_opt:
|
| 265 |
+
raise RuntimeError(
|
| 266 |
+
UNDEFINED_SEARCH_SPACE.format(
|
| 267 |
+
cls=self.__class__.__name__, space="space"
|
| 268 |
+
)
|
| 269 |
+
)
|
| 270 |
+
if not self._metric or not self._mode:
|
| 271 |
+
raise RuntimeError(
|
| 272 |
+
UNDEFINED_METRIC_MODE.format(
|
| 273 |
+
cls=self.__class__.__name__, metric=self._metric, mode=self._mode
|
| 274 |
+
)
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
suggested_config = self._nevergrad_opt.ask()
|
| 278 |
+
|
| 279 |
+
self._live_trial_mapping[trial_id] = suggested_config
|
| 280 |
+
# in v0.2.0+, output of ask() is a Candidate,
|
| 281 |
+
# with fields args and kwargs
|
| 282 |
+
if not suggested_config.kwargs:
|
| 283 |
+
if self._parameters:
|
| 284 |
+
return unflatten_dict(
|
| 285 |
+
dict(zip(self._parameters, suggested_config.args[0]))
|
| 286 |
+
)
|
| 287 |
+
return unflatten_dict(suggested_config.value)
|
| 288 |
+
else:
|
| 289 |
+
return unflatten_dict(suggested_config.kwargs)
|
| 290 |
+
|
| 291 |
+
def on_trial_complete(
|
| 292 |
+
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
|
| 293 |
+
):
|
| 294 |
+
"""Notification for the completion of trial.
|
| 295 |
+
|
| 296 |
+
The result is internally negated when interacting with Nevergrad
|
| 297 |
+
so that Nevergrad Optimizers can "maximize" this value,
|
| 298 |
+
as it minimizes on default.
|
| 299 |
+
"""
|
| 300 |
+
if result:
|
| 301 |
+
self._process_result(trial_id, result)
|
| 302 |
+
|
| 303 |
+
self._live_trial_mapping.pop(trial_id)
|
| 304 |
+
|
| 305 |
+
def _process_result(self, trial_id: str, result: Dict):
|
| 306 |
+
ng_trial_info = self._live_trial_mapping[trial_id]
|
| 307 |
+
self._nevergrad_opt.tell(ng_trial_info, self._metric_op * result[self._metric])
|
| 308 |
+
|
| 309 |
+
def save(self, checkpoint_path: str):
|
| 310 |
+
save_object = self.__dict__
|
| 311 |
+
with open(checkpoint_path, "wb") as outputFile:
|
| 312 |
+
pickle.dump(save_object, outputFile)
|
| 313 |
+
|
| 314 |
+
def restore(self, checkpoint_path: str):
|
| 315 |
+
with open(checkpoint_path, "rb") as inputFile:
|
| 316 |
+
save_object = pickle.load(inputFile)
|
| 317 |
+
self.__dict__.update(save_object)
|
| 318 |
+
|
| 319 |
+
@staticmethod
|
| 320 |
+
def convert_search_space(spec: Dict) -> Parameter:
|
| 321 |
+
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
| 322 |
+
|
| 323 |
+
if grid_vars:
|
| 324 |
+
raise ValueError(
|
| 325 |
+
"Grid search parameters cannot be automatically converted "
|
| 326 |
+
"to a Nevergrad search space."
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
# Flatten and resolve again after checking for grid search.
|
| 330 |
+
spec = flatten_dict(spec, prevent_delimiter=True)
|
| 331 |
+
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
| 332 |
+
|
| 333 |
+
def resolve_value(domain: Domain) -> Parameter:
|
| 334 |
+
sampler = domain.get_sampler()
|
| 335 |
+
if isinstance(sampler, Quantized):
|
| 336 |
+
logger.warning(
|
| 337 |
+
"Nevergrad does not support quantization. Dropped quantization."
|
| 338 |
+
)
|
| 339 |
+
sampler = sampler.get_sampler()
|
| 340 |
+
|
| 341 |
+
if isinstance(domain, Float):
|
| 342 |
+
if isinstance(sampler, LogUniform):
|
| 343 |
+
return ng.p.Log(
|
| 344 |
+
lower=domain.lower, upper=domain.upper, exponent=sampler.base
|
| 345 |
+
)
|
| 346 |
+
return ng.p.Scalar(lower=domain.lower, upper=domain.upper)
|
| 347 |
+
|
| 348 |
+
elif isinstance(domain, Integer):
|
| 349 |
+
if isinstance(sampler, LogUniform):
|
| 350 |
+
return ng.p.Log(
|
| 351 |
+
lower=domain.lower,
|
| 352 |
+
upper=domain.upper - 1, # Upper bound exclusive
|
| 353 |
+
exponent=sampler.base,
|
| 354 |
+
).set_integer_casting()
|
| 355 |
+
return ng.p.Scalar(
|
| 356 |
+
lower=domain.lower,
|
| 357 |
+
upper=domain.upper - 1, # Upper bound exclusive
|
| 358 |
+
).set_integer_casting()
|
| 359 |
+
|
| 360 |
+
elif isinstance(domain, Categorical):
|
| 361 |
+
return ng.p.Choice(choices=domain.categories)
|
| 362 |
+
|
| 363 |
+
raise ValueError(
|
| 364 |
+
"Nevergrad does not support parameters of type "
|
| 365 |
+
"`{}` with samplers of type `{}`".format(
|
| 366 |
+
type(domain).__name__, type(domain.sampler).__name__
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Parameter name is e.g. "a/b/c" for nested dicts
|
| 371 |
+
space = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}
|
| 372 |
+
|
| 373 |
+
return ng.p.Dict(**space)
|
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.search.optuna.optuna_search import OptunaSearch
|
| 2 |
+
|
| 3 |
+
__all__ = ["OptunaSearch"]
|
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (308 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/__pycache__/optuna_search.cpython-311.pyc
ADDED
|
Binary file (31.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/search/optuna/optuna_search.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import logging
|
| 3 |
+
import pickle
|
| 4 |
+
import time
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
from packaging import version
|
| 9 |
+
|
| 10 |
+
from ray.air.constants import TRAINING_ITERATION
|
| 11 |
+
from ray.tune.result import DEFAULT_METRIC
|
| 12 |
+
from ray.tune.search import (
|
| 13 |
+
UNDEFINED_METRIC_MODE,
|
| 14 |
+
UNDEFINED_SEARCH_SPACE,
|
| 15 |
+
UNRESOLVED_SEARCH_SPACE,
|
| 16 |
+
Searcher,
|
| 17 |
+
)
|
| 18 |
+
from ray.tune.search.sample import (
|
| 19 |
+
Categorical,
|
| 20 |
+
Domain,
|
| 21 |
+
Float,
|
| 22 |
+
Integer,
|
| 23 |
+
LogUniform,
|
| 24 |
+
Quantized,
|
| 25 |
+
Uniform,
|
| 26 |
+
)
|
| 27 |
+
from ray.tune.search.variant_generator import parse_spec_vars
|
| 28 |
+
from ray.tune.utils.util import flatten_dict, unflatten_dict, validate_warmstart
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
import optuna as ot
|
| 32 |
+
from optuna.distributions import BaseDistribution as OptunaDistribution
|
| 33 |
+
from optuna.samplers import BaseSampler
|
| 34 |
+
from optuna.storages import BaseStorage
|
| 35 |
+
from optuna.trial import Trial as OptunaTrial
|
| 36 |
+
from optuna.trial import TrialState as OptunaTrialState
|
| 37 |
+
except ImportError:
|
| 38 |
+
ot = None
|
| 39 |
+
OptunaDistribution = None
|
| 40 |
+
BaseSampler = None
|
| 41 |
+
BaseStorage = None
|
| 42 |
+
OptunaTrialState = None
|
| 43 |
+
OptunaTrial = None
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
# print a warning if define by run function takes longer than this to execute
|
| 48 |
+
DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _OptunaTrialSuggestCaptor:
|
| 52 |
+
"""Utility to capture returned values from Optuna's suggest_ methods.
|
| 53 |
+
|
| 54 |
+
This will wrap around the ``optuna.Trial` object and decorate all
|
| 55 |
+
`suggest_` callables with a function capturing the returned value,
|
| 56 |
+
which will be saved in the ``captured_values`` dict.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, ot_trial: OptunaTrial) -> None:
|
| 60 |
+
self.ot_trial = ot_trial
|
| 61 |
+
self.captured_values: Dict[str, Any] = {}
|
| 62 |
+
|
| 63 |
+
def _get_wrapper(self, func: Callable) -> Callable:
|
| 64 |
+
@functools.wraps(func)
|
| 65 |
+
def wrapper(*args, **kwargs):
|
| 66 |
+
# name is always the first arg for suggest_ methods
|
| 67 |
+
name = kwargs.get("name", args[0])
|
| 68 |
+
ret = func(*args, **kwargs)
|
| 69 |
+
self.captured_values[name] = ret
|
| 70 |
+
return ret
|
| 71 |
+
|
| 72 |
+
return wrapper
|
| 73 |
+
|
| 74 |
+
def __getattr__(self, item_name: str) -> Any:
|
| 75 |
+
item = getattr(self.ot_trial, item_name)
|
| 76 |
+
if item_name.startswith("suggest_") and callable(item):
|
| 77 |
+
return self._get_wrapper(item)
|
| 78 |
+
return item
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class OptunaSearch(Searcher):
|
| 82 |
+
"""A wrapper around Optuna to provide trial suggestions.
|
| 83 |
+
|
| 84 |
+
`Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
|
| 85 |
+
In contrast to other libraries, it employs define-by-run style
|
| 86 |
+
hyperparameter definitions.
|
| 87 |
+
|
| 88 |
+
This Searcher is a thin wrapper around Optuna's search algorithms.
|
| 89 |
+
You can pass any Optuna sampler, which will be used to generate
|
| 90 |
+
hyperparameter suggestions.
|
| 91 |
+
|
| 92 |
+
Multi-objective optimization is supported.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
space: Hyperparameter search space definition for
|
| 96 |
+
Optuna's sampler. This can be either a :class:`dict` with
|
| 97 |
+
parameter names as keys and ``optuna.distributions`` as values,
|
| 98 |
+
or a Callable - in which case, it should be a define-by-run
|
| 99 |
+
function using ``optuna.trial`` to obtain the hyperparameter
|
| 100 |
+
values. The function should return either a :class:`dict` of
|
| 101 |
+
constant values with names as keys, or None.
|
| 102 |
+
For more information, see https://optuna.readthedocs.io\
|
| 103 |
+
/en/stable/tutorial/10_key_features/002_configurations.html.
|
| 104 |
+
|
| 105 |
+
.. warning::
|
| 106 |
+
No actual computation should take place in the define-by-run
|
| 107 |
+
function. Instead, put the training logic inside the function
|
| 108 |
+
or class trainable passed to ``tune.Tuner()``.
|
| 109 |
+
|
| 110 |
+
metric: The training result objective value attribute. If
|
| 111 |
+
None but a mode was passed, the anonymous metric ``_metric``
|
| 112 |
+
will be used per default. Can be a list of metrics for
|
| 113 |
+
multi-objective optimization.
|
| 114 |
+
mode: One of {min, max}. Determines whether objective is
|
| 115 |
+
minimizing or maximizing the metric attribute. Can be a list of
|
| 116 |
+
modes for multi-objective optimization (corresponding to
|
| 117 |
+
``metric``).
|
| 118 |
+
points_to_evaluate: Initial parameter suggestions to be run
|
| 119 |
+
first. This is for when you already have some good parameters
|
| 120 |
+
you want to run first to help the algorithm make better suggestions
|
| 121 |
+
for future parameters. Needs to be a list of dicts containing the
|
| 122 |
+
configurations.
|
| 123 |
+
sampler: Optuna sampler used to
|
| 124 |
+
draw hyperparameter configurations. Defaults to ``MOTPESampler``
|
| 125 |
+
for multi-objective optimization with Optuna<2.9.0, and
|
| 126 |
+
``TPESampler`` in every other case.
|
| 127 |
+
See https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
|
| 128 |
+
for available Optuna samplers.
|
| 129 |
+
|
| 130 |
+
.. warning::
|
| 131 |
+
Please note that with Optuna 2.10.0 and earlier
|
| 132 |
+
default ``MOTPESampler``/``TPESampler`` suffer
|
| 133 |
+
from performance issues when dealing with a large number of
|
| 134 |
+
completed trials (approx. >100). This will manifest as
|
| 135 |
+
a delay when suggesting new configurations.
|
| 136 |
+
This is an Optuna issue and may be fixed in a future
|
| 137 |
+
Optuna release.
|
| 138 |
+
study_name: Optuna study name that uniquely identifies the trial
|
| 139 |
+
results. Defaults to ``"optuna"``.
|
| 140 |
+
storage: Optuna storage used for storing trial results to
|
| 141 |
+
storages other than in-memory storage,
|
| 142 |
+
for instance optuna.storages.RDBStorage.
|
| 143 |
+
seed: Seed to initialize sampler with. This parameter is only
|
| 144 |
+
used when ``sampler=None``. In all other cases, the sampler
|
| 145 |
+
you pass should be initialized with the seed already.
|
| 146 |
+
evaluated_rewards: If you have previously evaluated the
|
| 147 |
+
parameters passed in as points_to_evaluate you can avoid
|
| 148 |
+
re-running those trials by passing in the reward attributes
|
| 149 |
+
as a list so the optimiser can be told the results without
|
| 150 |
+
needing to re-compute the trial. Must be the same length as
|
| 151 |
+
points_to_evaluate.
|
| 152 |
+
|
| 153 |
+
.. warning::
|
| 154 |
+
When using ``evaluated_rewards``, the search space ``space``
|
| 155 |
+
must be provided as a :class:`dict` with parameter names as
|
| 156 |
+
keys and ``optuna.distributions`` instances as values. The
|
| 157 |
+
define-by-run search space definition is not yet supported with
|
| 158 |
+
this functionality.
|
| 159 |
+
|
| 160 |
+
Tune automatically converts search spaces to Optuna's format:
|
| 161 |
+
|
| 162 |
+
.. code-block:: python
|
| 163 |
+
|
| 164 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 165 |
+
|
| 166 |
+
config = {
|
| 167 |
+
"a": tune.uniform(6, 8)
|
| 168 |
+
"b": tune.loguniform(1e-4, 1e-2)
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
optuna_search = OptunaSearch(
|
| 172 |
+
metric="loss",
|
| 173 |
+
mode="min")
|
| 174 |
+
|
| 175 |
+
tuner = tune.Tuner(
|
| 176 |
+
trainable,
|
| 177 |
+
tune_config=tune.TuneConfig(
|
| 178 |
+
search_alg=optuna_search,
|
| 179 |
+
),
|
| 180 |
+
param_space=config,
|
| 181 |
+
)
|
| 182 |
+
tuner.fit()
|
| 183 |
+
|
| 184 |
+
If you would like to pass the search space manually, the code would
|
| 185 |
+
look like this:
|
| 186 |
+
|
| 187 |
+
.. code-block:: python
|
| 188 |
+
|
| 189 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 190 |
+
import optuna
|
| 191 |
+
|
| 192 |
+
space = {
|
| 193 |
+
"a": optuna.distributions.FloatDistribution(6, 8),
|
| 194 |
+
"b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
optuna_search = OptunaSearch(
|
| 198 |
+
space,
|
| 199 |
+
metric="loss",
|
| 200 |
+
mode="min")
|
| 201 |
+
|
| 202 |
+
tuner = tune.Tuner(
|
| 203 |
+
trainable,
|
| 204 |
+
tune_config=tune.TuneConfig(
|
| 205 |
+
search_alg=optuna_search,
|
| 206 |
+
),
|
| 207 |
+
)
|
| 208 |
+
tuner.fit()
|
| 209 |
+
|
| 210 |
+
# Equivalent Optuna define-by-run function approach:
|
| 211 |
+
|
| 212 |
+
def define_search_space(trial: optuna.Trial):
|
| 213 |
+
trial.suggest_float("a", 6, 8)
|
| 214 |
+
trial.suggest_float("b", 1e-4, 1e-2, log=True)
|
| 215 |
+
# training logic goes into trainable, this is just
|
| 216 |
+
# for search space definition
|
| 217 |
+
|
| 218 |
+
optuna_search = OptunaSearch(
|
| 219 |
+
define_search_space,
|
| 220 |
+
metric="loss",
|
| 221 |
+
mode="min")
|
| 222 |
+
|
| 223 |
+
tuner = tune.Tuner(
|
| 224 |
+
trainable,
|
| 225 |
+
tune_config=tune.TuneConfig(
|
| 226 |
+
search_alg=optuna_search,
|
| 227 |
+
),
|
| 228 |
+
)
|
| 229 |
+
tuner.fit()
|
| 230 |
+
|
| 231 |
+
Multi-objective optimization is supported:
|
| 232 |
+
|
| 233 |
+
.. code-block:: python
|
| 234 |
+
|
| 235 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 236 |
+
import optuna
|
| 237 |
+
|
| 238 |
+
space = {
|
| 239 |
+
"a": optuna.distributions.FloatDistribution(6, 8),
|
| 240 |
+
"b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
# Note you have to specify metric and mode here instead of
|
| 244 |
+
# in tune.TuneConfig
|
| 245 |
+
optuna_search = OptunaSearch(
|
| 246 |
+
space,
|
| 247 |
+
metric=["loss1", "loss2"],
|
| 248 |
+
mode=["min", "max"])
|
| 249 |
+
|
| 250 |
+
# Do not specify metric and mode here!
|
| 251 |
+
tuner = tune.Tuner(
|
| 252 |
+
trainable,
|
| 253 |
+
tune_config=tune.TuneConfig(
|
| 254 |
+
search_alg=optuna_search,
|
| 255 |
+
),
|
| 256 |
+
)
|
| 257 |
+
tuner.fit()
|
| 258 |
+
|
| 259 |
+
You can pass configs that will be evaluated first using
|
| 260 |
+
``points_to_evaluate``:
|
| 261 |
+
|
| 262 |
+
.. code-block:: python
|
| 263 |
+
|
| 264 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 265 |
+
import optuna
|
| 266 |
+
|
| 267 |
+
space = {
|
| 268 |
+
"a": optuna.distributions.FloatDistribution(6, 8),
|
| 269 |
+
"b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
optuna_search = OptunaSearch(
|
| 273 |
+
space,
|
| 274 |
+
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
| 275 |
+
metric="loss",
|
| 276 |
+
mode="min")
|
| 277 |
+
|
| 278 |
+
tuner = tune.Tuner(
|
| 279 |
+
trainable,
|
| 280 |
+
tune_config=tune.TuneConfig(
|
| 281 |
+
search_alg=optuna_search,
|
| 282 |
+
),
|
| 283 |
+
)
|
| 284 |
+
tuner.fit()
|
| 285 |
+
|
| 286 |
+
Avoid re-running evaluated trials by passing the rewards together with
|
| 287 |
+
`points_to_evaluate`:
|
| 288 |
+
|
| 289 |
+
.. code-block:: python
|
| 290 |
+
|
| 291 |
+
from ray.tune.search.optuna import OptunaSearch
|
| 292 |
+
import optuna
|
| 293 |
+
|
| 294 |
+
space = {
|
| 295 |
+
"a": optuna.distributions.FloatDistribution(6, 8),
|
| 296 |
+
"b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
optuna_search = OptunaSearch(
|
| 300 |
+
space,
|
| 301 |
+
points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
|
| 302 |
+
evaluated_rewards=[0.89, 0.42]
|
| 303 |
+
metric="loss",
|
| 304 |
+
mode="min")
|
| 305 |
+
|
| 306 |
+
tuner = tune.Tuner(
|
| 307 |
+
trainable,
|
| 308 |
+
tune_config=tune.TuneConfig(
|
| 309 |
+
search_alg=optuna_search,
|
| 310 |
+
),
|
| 311 |
+
)
|
| 312 |
+
tuner.fit()
|
| 313 |
+
|
| 314 |
+
.. versionadded:: 0.8.8
|
| 315 |
+
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
def __init__(
    self,
    space: Optional[
        Union[
            Dict[str, "OptunaDistribution"],
            List[Tuple],
            Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
        ]
    ] = None,
    metric: Optional[Union[str, List[str]]] = None,
    mode: Optional[Union[str, List[str]]] = None,
    points_to_evaluate: Optional[List[Dict]] = None,
    sampler: Optional["BaseSampler"] = None,
    study_name: Optional[str] = None,
    storage: Optional["BaseStorage"] = None,
    seed: Optional[int] = None,
    evaluated_rewards: Optional[List] = None,
):
    # Optuna is an optional dependency; fail fast with an actionable message.
    assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
    super(OptunaSearch, self).__init__(metric=metric, mode=mode)

    if isinstance(space, dict) and space:
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
        if domain_vars or grid_vars:
            # A Ray Tune search space was passed directly: warn, then
            # convert it into Optuna distributions.
            logger.warning(
                UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__)
            )
            space = self.convert_search_space(space)
        else:
            # Flatten to support nested dicts
            space = flatten_dict(space, "/")

    self._space = space

    self._points_to_evaluate = points_to_evaluate or []
    self._evaluated_rewards = evaluated_rewards
    if study_name:
        self._study_name = study_name
    else:
        self._study_name = "optuna"  # Fixed study name for in-memory storage

    if sampler and seed:
        # ``seed`` only applies when this class constructs the sampler.
        logger.warning(
            "You passed an initialized sampler to `OptunaSearch`. The "
            "`seed` parameter has to be passed to the sampler directly "
            "and will be ignored."
        )
    elif sampler:
        assert isinstance(sampler, BaseSampler), (
            "You can only pass an instance of "
            "`optuna.samplers.BaseSampler` "
            "as a sampler to `OptunaSearcher`."
        )

    self._sampler = sampler
    self._seed = seed

    if storage:
        assert isinstance(storage, BaseStorage), (
            "The `storage` parameter in `OptunaSearcher` must be an instance "
            "of `optuna.storages.BaseStorage`."
        )
    # If storage is not provided, just set self._storage to None
    # so that the default in-memory storage is used.
    self._storage = storage

    # Trial ids that already reported a final result (guards double-telling).
    self._completed_trials = set()

    # Maps Tune trial ids to in-flight Optuna trial objects.
    self._ot_trials = {}
    self._ot_study = None
    # If no space was given, study creation is deferred to
    # ``set_search_properties`` (called by Tune with the param space).
    if self._space:
        self._setup_study(mode)
def _setup_study(self, mode: Union[str, list]):
    """Create (or load) the underlying Optuna study.

    Picks a sampler, translates Tune ``mode``(s) into Optuna study
    direction(s), creates the study, and enqueues or back-fills any
    ``points_to_evaluate``.

    Args:
        mode: "min"/"max", or a list thereof for multi-objective runs.
    """
    if self._metric is None and self._mode:
        if isinstance(self._mode, list):
            raise ValueError(
                "If ``mode`` is a list (multi-objective optimization "
                "case), ``metric`` must be defined."
            )
        # If only a mode was passed, use anonymous metric
        self._metric = DEFAULT_METRIC

    # Optuna-side pruning is disabled here (NopPruner).
    pruner = ot.pruners.NopPruner()

    if self._sampler:
        # User-provided sampler takes precedence.
        sampler = self._sampler
    elif isinstance(mode, list) and version.parse(ot.__version__) < version.parse(
        "2.9.0"
    ):
        # MOTPESampler deprecated in Optuna>=2.9.0
        sampler = ot.samplers.MOTPESampler(seed=self._seed)
    else:
        sampler = ot.samplers.TPESampler(seed=self._seed)

    if isinstance(mode, list):
        # Multi-objective: one optimization direction per metric.
        study_direction_args = dict(
            directions=["minimize" if m == "min" else "maximize" for m in mode],
        )
    else:
        study_direction_args = dict(
            direction="minimize" if mode == "min" else "maximize",
        )

    # ``load_if_exists`` lets a persistent storage resume a study with
    # the same name instead of erroring out.
    self._ot_study = ot.study.create_study(
        storage=self._storage,
        sampler=sampler,
        pruner=pruner,
        study_name=self._study_name,
        load_if_exists=True,
        **study_direction_args,
    )

    if self._points_to_evaluate:
        validate_warmstart(
            self._space,
            self._points_to_evaluate,
            self._evaluated_rewards,
            # Define-by-run spaces cannot be length-checked against points.
            validate_point_name_lengths=not callable(self._space),
        )
        if self._evaluated_rewards:
            # Results already known: insert finished trials directly.
            for point, reward in zip(
                self._points_to_evaluate, self._evaluated_rewards
            ):
                self.add_evaluated_point(point, reward)
        else:
            # Results unknown: queue these points to be evaluated first.
            for point in self._points_to_evaluate:
                self._ot_study.enqueue_trial(point)
def set_search_properties(
    self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
) -> bool:
    """Adopt metric/mode/space from Tune if none were set at construction.

    Returns:
        ``False`` when a search space already exists (nothing is changed),
        ``True`` after the config has been converted and the study created.
    """
    # A space fixed at construction time always wins.
    if self._space:
        return False

    self._space = self.convert_search_space(config)
    self._metric = metric or self._metric
    self._mode = mode or self._mode

    self._setup_study(self._mode)
    return True
def _suggest_from_define_by_run_func(
    self,
    func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
    ot_trial: "OptunaTrial",
) -> Dict:
    """Run the user's define-by-run function and collect its parameters.

    The trial is wrapped in a ``_OptunaTrialSuggestCaptor`` so every
    ``suggest_*`` call is recorded; any constants the function returns
    are merged on top of the captured values.
    """
    wrapper = _OptunaTrialSuggestCaptor(ot_trial)
    started = time.time()
    suggested = func(wrapper)
    elapsed = time.time() - started

    if elapsed > DEFINE_BY_RUN_WARN_THRESHOLD_S:
        # A slow space function is almost always misplaced training code.
        warnings.warn(
            "Define-by-run function passed in the `space` argument "
            f"took {elapsed} seconds to "
            "run. Ensure that actual computation, training takes "
            "place inside Tune's train functions or Trainables "
            "passed to `tune.Tuner()`."
        )

    if suggested is not None:
        if not isinstance(suggested, dict):
            raise TypeError(
                "The return value of the define-by-run function "
                "passed in the `space` argument should be "
                "either None or a `dict` with `str` keys. "
                f"Got {type(suggested)}."
            )
        if not all(isinstance(key, str) for key in suggested):
            raise TypeError(
                "At least one of the keys in the dict returned by the "
                "define-by-run function passed in the `space` argument "
                "was not a `str`."
            )

    # An empty/None return means only the captured suggestions apply.
    if not suggested:
        return wrapper.captured_values
    return {**wrapper.captured_values, **suggested}
def suggest(self, trial_id: str) -> Optional[Dict]:
    """Return a new hyperparameter configuration for ``trial_id``.

    Uses Optuna's ``study.ask()`` interface; the Optuna trial object is
    cached per Tune trial id so repeated calls are idempotent.

    Raises:
        RuntimeError: If no search space or no metric/mode is configured.
    """
    if not self._space:
        raise RuntimeError(
            UNDEFINED_SEARCH_SPACE.format(
                cls=self.__class__.__name__, space="space"
            )
        )
    if not self._metric or not self._mode:
        raise RuntimeError(
            UNDEFINED_METRIC_MODE.format(
                cls=self.__class__.__name__, metric=self._metric, mode=self._mode
            )
        )
    if callable(self._space):
        # Define-by-run case
        if trial_id not in self._ot_trials:
            self._ot_trials[trial_id] = self._ot_study.ask()

        ot_trial = self._ot_trials[trial_id]

        # Execute the user function against the trial to capture params.
        params = self._suggest_from_define_by_run_func(self._space, ot_trial)
    else:
        # Use Optuna ask interface (since version 2.6.0)
        if trial_id not in self._ot_trials:
            self._ot_trials[trial_id] = self._ot_study.ask(
                fixed_distributions=self._space
            )
        ot_trial = self._ot_trials[trial_id]
        params = ot_trial.params

    # Keys were flattened with "/" separators; restore the nesting.
    return unflatten_dict(params)
def on_trial_result(self, trial_id: str, result: Dict):
    """Forward an intermediate result to the corresponding Optuna trial."""
    if isinstance(self.metric, list):
        # Optuna doesn't support incremental results
        # for multi-objective optimization
        return
    if trial_id in self._completed_trials:
        logger.warning(
            f"Received additional result for trial {trial_id}, but "
            f"it already finished. Result: {result}"
        )
        return
    # Report the metric value at the current training iteration.
    self._ot_trials[trial_id].report(
        result[self.metric], result[TRAINING_ITERATION]
    )
def on_trial_complete(
    self, trial_id: str, result: Optional[Dict] = None, error: bool = False
):
    """Report a finished trial back to the Optuna study.

    Maps Tune's outcome onto an Optuna trial state: COMPLETE when a
    metric value is available, FAIL when the trial errored without a
    value, PRUNED otherwise.
    """
    if trial_id in self._completed_trials:
        logger.warning(
            f"Received additional completion for trial {trial_id}, but "
            f"it already finished. Result: {result}"
        )
        return

    ot_trial = self._ot_trials[trial_id]

    if result:
        if isinstance(self.metric, list):
            # Multi-objective: collect one value per metric (None if absent).
            val = [result.get(metric, None) for metric in self.metric]
        else:
            val = result.get(self.metric, None)
    else:
        val = None
    ot_trial_state = OptunaTrialState.COMPLETE
    if val is None:
        # No usable value: failed if the trial errored, pruned otherwise.
        if error:
            ot_trial_state = OptunaTrialState.FAIL
        else:
            ot_trial_state = OptunaTrialState.PRUNED
    try:
        self._ot_study.tell(ot_trial, val, state=ot_trial_state)
    except Exception as exc:
        logger.warning(exc)  # E.g. if NaN was reported

    # Remember the id so duplicate completions are ignored above.
    self._completed_trials.add(trial_id)
def add_evaluated_point(
    self,
    parameters: Dict,
    value: float,
    error: bool = False,
    pruned: bool = False,
    intermediate_values: Optional[List[float]] = None,
):
    """Insert an already-evaluated configuration into the Optuna study.

    Args:
        parameters: The evaluated configuration (flat dict).
        value: Final objective value of the configuration.
        error: Whether the evaluation failed (trial marked FAIL).
        pruned: Whether the evaluation was pruned (trial marked PRUNED).
        intermediate_values: Optional per-step values; reported to Optuna
            keyed by step index.

    Raises:
        RuntimeError: If no search space or no metric/mode is configured.
        TypeError: If the search space is a define-by-run callable, which
            is not supported for pre-evaluated points.
    """
    if not self._space:
        raise RuntimeError(
            UNDEFINED_SEARCH_SPACE.format(
                cls=self.__class__.__name__, space="space"
            )
        )
    if not self._metric or not self._mode:
        raise RuntimeError(
            UNDEFINED_METRIC_MODE.format(
                cls=self.__class__.__name__, metric=self._metric, mode=self._mode
            )
        )
    if callable(self._space):
        raise TypeError(
            "Define-by-run function passed in `space` argument is not "
            "yet supported when using `evaluated_rewards`. Please provide "
            "an `OptunaDistribution` dict or pass a Ray Tune "
            "search space to `tune.Tuner()`."
        )

    # Translate the (error, pruned) flags into an Optuna trial state.
    ot_trial_state = OptunaTrialState.COMPLETE
    if error:
        ot_trial_state = OptunaTrialState.FAIL
    elif pruned:
        ot_trial_state = OptunaTrialState.PRUNED

    if intermediate_values:
        # Optuna expects intermediate values keyed by step index.
        intermediate_values_dict = {
            i: value for i, value in enumerate(intermediate_values)
        }
    else:
        intermediate_values_dict = None

    # If the trial state is FAILED, the value must be `None` in Optuna==4.1.0
    # Reference: https://github.com/optuna/optuna/pull/5211
    # This is a temporary fix for the issue that Optuna enforces the value
    # to be `None` if the trial state is FAILED.
    # TODO (hpguo): A better solution may requires us to update the base class
    # to allow the `value` arg in `add_evaluated_point` being `Optional[float]`.
    if ot_trial_state == OptunaTrialState.FAIL:
        value = None

    trial = ot.trial.create_trial(
        state=ot_trial_state,
        value=value,
        params=parameters,
        distributions=self._space,
        intermediate_values=intermediate_values_dict,
    )

    self._ot_study.add_trial(trial)
def save(self, checkpoint_path: str):
    """Pickle the searcher's full state dict to ``checkpoint_path``."""
    with open(checkpoint_path, "wb") as f:
        pickle.dump(self.__dict__.copy(), f)
def restore(self, checkpoint_path: str):
    """Load searcher state previously written by :meth:`save`.

    Handles both the current dict-based format and the legacy
    tuple-based checkpoint layout.
    """
    with open(checkpoint_path, "rb") as f:
        state = pickle.load(f)
    if isinstance(state, dict):
        self.__dict__.update(state)
    else:
        # Backwards compatibility: old checkpoints stored a fixed tuple.
        (
            self._sampler,
            self._ot_trials,
            self._ot_study,
            self._points_to_evaluate,
            self._evaluated_rewards,
        ) = state
@staticmethod
def convert_search_space(spec: Dict) -> Dict[str, Any]:
    """Convert a Ray Tune search space into Optuna distributions.

    Args:
        spec: Ray Tune search space (possibly nested dict of Domains).

    Returns:
        Flat dict mapping "/"-joined parameter paths to
        ``optuna.distributions`` instances. Empty if no tunable
        parameters are found.

    Raises:
        ValueError: If the spec contains grid search parameters, or a
            domain/sampler combination Optuna cannot represent.
    """
    resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

    if not domain_vars and not grid_vars:
        return {}

    if grid_vars:
        raise ValueError(
            "Grid search parameters cannot be automatically converted "
            "to an Optuna search space."
        )

    # Flatten and resolve again after checking for grid search.
    spec = flatten_dict(spec, prevent_delimiter=True)
    resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

    def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
        # Map one Tune Domain (+ optional quantization) to an Optuna
        # distribution.
        quantize = None

        sampler = domain.get_sampler()
        if isinstance(sampler, Quantized):
            quantize = sampler.q
            sampler = sampler.sampler
            if isinstance(sampler, LogUniform):
                logger.warning(
                    "Optuna does not handle quantization in loguniform "
                    "sampling. The parameter will be passed but it will "
                    "probably be ignored."
                )

        if isinstance(domain, Float):
            if isinstance(sampler, LogUniform):
                if quantize:
                    logger.warning(
                        "Optuna does not support both quantization and "
                        "sampling from LogUniform. Dropped quantization."
                    )
                return ot.distributions.FloatDistribution(
                    domain.lower, domain.upper, log=True
                )

            elif isinstance(sampler, Uniform):
                if quantize:
                    return ot.distributions.FloatDistribution(
                        domain.lower, domain.upper, step=quantize
                    )
                return ot.distributions.FloatDistribution(
                    domain.lower, domain.upper
                )

        elif isinstance(domain, Integer):
            if isinstance(sampler, LogUniform):
                return ot.distributions.IntDistribution(
                    domain.lower, domain.upper - 1, step=quantize or 1, log=True
                )
            elif isinstance(sampler, Uniform):
                # Upper bound should be inclusive for quantization and
                # exclusive otherwise
                return ot.distributions.IntDistribution(
                    domain.lower,
                    domain.upper - int(bool(not quantize)),
                    step=quantize or 1,
                )
        elif isinstance(domain, Categorical):
            if isinstance(sampler, Uniform):
                return ot.distributions.CategoricalDistribution(domain.categories)

        raise ValueError(
            "Optuna search does not support parameters of type "
            "`{}` with samplers of type `{}`".format(
                type(domain).__name__, type(domain.sampler).__name__
            )
        )

    # Parameter name is e.g. "a/b/c" for nested dicts
    values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}

    return values
.venv/lib/python3.11/site-packages/ray/tune/search/repeater.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ray.tune.search.searcher import Searcher
|
| 8 |
+
from ray.tune.search.util import _set_search_properties_backwards_compatible
|
| 9 |
+
from ray.util import PublicAPI
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
TRIAL_INDEX = "__trial_index__"
|
| 14 |
+
"""str: A constant value representing the repeat index of the trial."""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _warn_num_samples(searcher: Searcher, num_samples: int):
    """Warn when ``num_samples`` is not a multiple of the Repeater's repeat."""
    if not isinstance(searcher, Repeater):
        return
    if num_samples % searcher.repeat == 0:
        return
    logger.warning(
        "`num_samples` is now expected to be the total number of trials, "
        "including the repeat trials. For example, set num_samples=15 if "
        "you intend to obtain 3 search algorithm suggestions and repeat "
        "each suggestion 5 times. Any leftover trials "
        "(num_samples mod repeat) will be ignored."
    )
class _TrialGroup:
|
| 29 |
+
"""Internal class for grouping trials of same parameters.
|
| 30 |
+
|
| 31 |
+
This is used when repeating trials for reducing training variance.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
primary_trial_id: Trial ID of the "primary trial".
|
| 35 |
+
This trial is the one that the Searcher is aware of.
|
| 36 |
+
config: Suggested configuration shared across all trials
|
| 37 |
+
in the trial group.
|
| 38 |
+
max_trials: Max number of trials to execute within this group.
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, primary_trial_id: str, config: Dict, max_trials: int = 1):
|
| 43 |
+
assert type(config) is dict, "config is not a dict, got {}".format(config)
|
| 44 |
+
self.primary_trial_id = primary_trial_id
|
| 45 |
+
self.config = config
|
| 46 |
+
self._trials = {primary_trial_id: None}
|
| 47 |
+
self.max_trials = max_trials
|
| 48 |
+
|
| 49 |
+
def add(self, trial_id: str):
|
| 50 |
+
assert len(self._trials) < self.max_trials
|
| 51 |
+
self._trials.setdefault(trial_id, None)
|
| 52 |
+
|
| 53 |
+
def full(self) -> bool:
|
| 54 |
+
return len(self._trials) == self.max_trials
|
| 55 |
+
|
| 56 |
+
def report(self, trial_id: str, score: float):
|
| 57 |
+
assert trial_id in self._trials
|
| 58 |
+
if score is None:
|
| 59 |
+
raise ValueError("Internal Error: Score cannot be None.")
|
| 60 |
+
self._trials[trial_id] = score
|
| 61 |
+
|
| 62 |
+
def finished_reporting(self) -> bool:
|
| 63 |
+
return (
|
| 64 |
+
None not in self._trials.values() and len(self._trials) == self.max_trials
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
def scores(self) -> List[Optional[float]]:
|
| 68 |
+
return list(self._trials.values())
|
| 69 |
+
|
| 70 |
+
def count(self) -> int:
|
| 71 |
+
return len(self._trials)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@PublicAPI
class Repeater(Searcher):
    """A wrapper algorithm for repeating trials of same parameters.

    Set tune.TuneConfig(num_samples=...) to be a multiple of `repeat`. For example,
    set num_samples=15 if you intend to obtain 3 search algorithm suggestions
    and repeat each suggestion 5 times. Any leftover trials
    (num_samples mod repeat) will be ignored.

    It is recommended that you do not run an early-stopping TrialScheduler
    simultaneously.

    Args:
        searcher: Searcher object that the
            Repeater will optimize. Note that the Searcher
            will only see 1 trial among multiple repeated trials.
            The result/metric passed to the Searcher upon
            trial completion will be averaged among all repeats.
        repeat: Number of times to generate a trial with a repeated
            configuration. Defaults to 1.
        set_index: Sets a tune.search.repeater.TRIAL_INDEX in
            Trainable/Function config which corresponds to the index of the
            repeated trial. This can be used for seeds. Defaults to True.

    Example:

    .. code-block:: python

        from ray.tune.search import Repeater

        search_alg = BayesOptSearch(...)
        re_search_alg = Repeater(search_alg, repeat=10)

        # Repeat 2 samples 10 times each.
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                search_alg=re_search_alg,
                num_samples=20,
            ),
        )
        tuner.fit()

    """

    def __init__(self, searcher: Searcher, repeat: int = 1, set_index: bool = True):
        self.searcher = searcher
        self.repeat = repeat
        self._set_index = set_index
        self._groups = []  # All trial groups created so far.
        self._trial_id_to_group = {}  # Maps each trial id to its group.
        self._current_group = None  # Group currently being filled.
        super(Repeater, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Suggest a config, opening a new group every ``repeat`` trials."""
        if self._current_group is None or self._current_group.full():
            # Ask the wrapped searcher for a fresh configuration and start
            # a new group of ``repeat`` trials around it.
            config = self.searcher.suggest(trial_id)
            if config is None:
                return config
            self._current_group = _TrialGroup(
                trial_id, copy.deepcopy(config), max_trials=self.repeat
            )
            self._groups.append(self._current_group)
            index_in_group = 0
        else:
            index_in_group = self._current_group.count()
            self._current_group.add(trial_id)

        config = self._current_group.config.copy()
        if self._set_index:
            # Expose the repeat index (e.g. for per-repeat seeding).
            config[TRIAL_INDEX] = index_in_group
        self._trial_id_to_group[trial_id] = self._current_group
        return config

    def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, **kwargs):
        """Stores the score for and keeps track of a completed trial.

        Stores the metric of a trial as nan if any of the following conditions
        are met:

        1. ``result`` is empty or not provided.
        2. ``result`` is provided but no metric was provided.

        Once every trial of a group has reported, the wrapped searcher is
        notified once with the nan-mean of the group's scores.
        """
        if trial_id not in self._trial_id_to_group:
            logger.error(
                "Trial {} not in group; cannot report score. "
                "Seen trials: {}".format(trial_id, list(self._trial_id_to_group))
            )
            # BUG FIX: previously execution fell through to the lookup
            # below and raised an unhandled KeyError right after logging
            # that the score cannot be reported. Bail out instead.
            return
        trial_group = self._trial_id_to_group[trial_id]
        if not result or self.searcher.metric not in result:
            score = np.nan
        else:
            score = result[self.searcher.metric]
        trial_group.report(trial_id, score)

        if trial_group.finished_reporting():
            scores = trial_group.scores()
            self.searcher.on_trial_complete(
                trial_group.primary_trial_id,
                result={self.searcher.metric: np.nanmean(scores)},
                **kwargs
            )

    def get_state(self) -> Dict:
        """Return picklable state, excluding the wrapped searcher."""
        self_state = self.__dict__.copy()
        del self_state["searcher"]
        return self_state

    def set_state(self, state: Dict):
        """Restore state produced by :meth:`get_state`."""
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate checkpointing to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoration to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Forward metric/mode/config to the wrapped searcher."""
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )
.venv/lib/python3.11/site-packages/ray/tune/search/sample.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from copy import copy
|
| 3 |
+
from inspect import signature
|
| 4 |
+
from math import isclose
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
# Backwards compatibility
|
| 10 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
# Added in numpy>=1.17 but we require numpy>=1.16
|
| 14 |
+
np_random_generator = np.random.Generator
|
| 15 |
+
LEGACY_RNG = False
|
| 16 |
+
except AttributeError:
|
| 17 |
+
|
| 18 |
+
class np_random_generator:
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
LEGACY_RNG = True
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class _BackwardsCompatibleNumpyRng:
|
| 27 |
+
"""Thin wrapper to ensure backwards compatibility between
|
| 28 |
+
new and old numpy randomness generators.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
_rng = None
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
generator_or_seed: Optional[
|
| 36 |
+
Union["np_random_generator", np.random.RandomState, int]
|
| 37 |
+
] = None,
|
| 38 |
+
):
|
| 39 |
+
if generator_or_seed is None or isinstance(
|
| 40 |
+
generator_or_seed, (np.random.RandomState, np_random_generator)
|
| 41 |
+
):
|
| 42 |
+
self._rng = generator_or_seed
|
| 43 |
+
elif LEGACY_RNG:
|
| 44 |
+
self._rng = np.random.RandomState(generator_or_seed)
|
| 45 |
+
else:
|
| 46 |
+
self._rng = np.random.default_rng(generator_or_seed)
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def legacy_rng(self) -> bool:
|
| 50 |
+
return not isinstance(self._rng, np_random_generator)
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def rng(self):
|
| 54 |
+
# don't set self._rng to np.random to avoid picking issues
|
| 55 |
+
return self._rng if self._rng is not None else np.random
|
| 56 |
+
|
| 57 |
+
def __getattr__(self, name: str) -> Any:
|
| 58 |
+
# https://numpy.org/doc/stable/reference/random/new-or-different.html
|
| 59 |
+
if self.legacy_rng:
|
| 60 |
+
if name == "integers":
|
| 61 |
+
name = "randint"
|
| 62 |
+
elif name == "random":
|
| 63 |
+
name = "rand"
|
| 64 |
+
return getattr(self.rng, name)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Any value accepted wherever sampling randomness can be configured:
# an existing wrapper, a numpy generator (new or legacy), an int seed,
# or None (use the global np.random module).
RandomState = Union[
    None, _BackwardsCompatibleNumpyRng, np_random_generator, np.random.RandomState, int
]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@DeveloperAPI
class Domain:
    """Base class to specify a type and valid range to sample parameters from.

    Concrete parameter spaces — float ranges (``Float``), integer ranges
    (``Integer``), categorical variables (``Categorical``) — subclass this.
    A ``Domain`` carries information about its valid values (e.g. minimum
    and maximum) and exposes methods that attach specific samplers
    (e.g. ``uniform()`` or ``loguniform()``).
    """

    sampler = None
    default_sampler_cls = None

    def cast(self, value):
        """Cast ``value`` to this domain's native type (identity here)."""
        return value

    def set_sampler(self, sampler, allow_override=False):
        """Attach ``sampler``; refuse to replace an existing one unless allowed."""
        if self.sampler and not allow_override:
            raise ValueError(
                "You can only choose one sampler for parameter "
                "domains. Existing sampler for parameter {}: "
                "{}. Tried to add {}".format(
                    self.__class__.__name__, self.sampler, sampler
                )
            )
        self.sampler = sampler

    def get_sampler(self):
        """Return the attached sampler, or a fresh default instance."""
        return self.sampler or self.default_sampler_cls()

    def sample(
        self,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Draw ``size`` samples via the attached (or default) sampler."""
        if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
            random_state = _BackwardsCompatibleNumpyRng(random_state)
        return self.get_sampler().sample(
            self, config=config, size=size, random_state=random_state
        )

    def is_grid(self):
        """True if this domain is configured for grid search."""
        return isinstance(self.sampler, Grid)

    def is_function(self):
        """True for function-backed domains (overridden by ``Function``)."""
        return False

    def is_valid(self, value: Any):
        """Returns True if `value` is a valid value in this domain."""
        raise NotImplementedError

    @property
    def domain_str(self):
        """Human-readable description of the domain's range."""
        return "(unknown)"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@DeveloperAPI
class Sampler:
    """Abstract strategy for drawing values from a :class:`Domain`."""

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        """Draw ``size`` samples from ``domain``; implemented by subclasses."""
        raise NotImplementedError
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@DeveloperAPI
class BaseSampler(Sampler):
    """Neutral sampler base; stringifies as "Base"."""

    def __str__(self):
        return "Base"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@DeveloperAPI
class Uniform(Sampler):
    """Sampler tag for uniform draws; stringifies as "Uniform"."""

    def __str__(self):
        return "Uniform"
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@DeveloperAPI
class LogUniform(Sampler):
    """Sampler tag for log-uniform draws with a configurable log base."""

    def __init__(self, base: float = 10):
        # NOTE: validation uses ``assert`` and is stripped under ``python -O``.
        self.base = base
        assert self.base > 0, "Base has to be strictly greater than 0"

    def __str__(self):
        return "LogUniform"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@DeveloperAPI
class Normal(Sampler):
    """Sampler tag for normal (Gaussian) draws with mean ``mean`` and
    standard deviation ``sd``."""

    def __init__(self, mean: float = 0.0, sd: float = 0.0):
        # NOTE(review): the default ``sd=0.0`` always fails the assertion
        # below, so callers must pass ``sd`` explicitly — confirm whether
        # the default was meant to be 1.0.
        self.mean = mean
        self.sd = sd

        assert self.sd > 0, "SD has to be strictly greater than 0"

    def __str__(self):
        return "Normal"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
@DeveloperAPI
class Grid(Sampler):
    """Dummy sampler used for grid search.

    Grid-search values are enumerated by the trial generator rather than
    sampled, so invoking ``sample()`` on a grid is always an error.
    """

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        # Bug fix: this error was previously *returned* instead of raised,
        # silently handing callers a RuntimeError object as the "sample".
        raise RuntimeError("Do not call `sample()` on grid.")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@DeveloperAPI
class Float(Domain):
    """A float-valued parameter domain with optional lower/upper bounds.

    Missing bounds are represented as ``-inf``/``inf``. Samplers are
    attached via ``uniform()``, ``loguniform()``, ``normal()``, and may be
    wrapped with ``quantized()``.
    """

    class _Uniform(Uniform):
        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > float("-inf"), "Uniform needs a lower bound"
            assert domain.upper < float("inf"), "Uniform needs a upper bound"
            items = random_state.uniform(domain.lower, domain.upper, size=size)
            # Return a scalar (cast) for size == 1, an array otherwise.
            return items if len(items) > 1 else domain.cast(items[0])

    class _LogUniform(LogUniform):
        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
            assert (
                0 < domain.upper < float("inf")
            ), "LogUniform needs a upper bound greater than 0"
            # Sample uniformly in log space, then exponentiate back.
            logmin = np.log(domain.lower) / np.log(self.base)
            logmax = np.log(domain.upper) / np.log(self.base)

            items = self.base ** (random_state.uniform(logmin, logmax, size=size))
            return items if len(items) > 1 else domain.cast(items[0])

    class _Normal(Normal):
        def sample(
            self,
            domain: "Float",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # Normal sampling is unbounded; reject finite bounds up front.
            assert not domain.lower or domain.lower == float(
                "-inf"
            ), "Normal sampling does not allow a lower value bound."
            assert not domain.upper or domain.upper == float(
                "inf"
            ), "Normal sampling does not allow a upper value bound."
            items = random_state.normal(self.mean, self.sd, size=size)
            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, lower: Optional[float], upper: Optional[float]):
        # Need to explicitly check for None: 0.0 is a valid (falsy) bound.
        self.lower = lower if lower is not None else float("-inf")
        self.upper = upper if upper is not None else float("inf")

    def cast(self, value):
        """Cast a sampled value to float."""
        return float(value)

    def uniform(self):
        """Attach a uniform sampler; both bounds must be finite.

        Raises:
            ValueError: If either bound is missing (infinite).
        """
        if not self.lower > float("-inf"):
            raise ValueError(
                "Uniform requires a lower bound. Make sure to set the "
                "`lower` parameter of `Float()`."
            )
        if not self.upper < float("inf"):
            raise ValueError(
                "Uniform requires a upper bound. Make sure to set the "
                "`upper` parameter of `Float()`."
            )
        new = copy(self)
        new.set_sampler(self._Uniform())
        return new

    def loguniform(self, base: float = 10):
        """Attach a log-uniform sampler with logarithm base ``base``.

        Raises:
            ValueError: If the bounds are not both strictly positive
                and finite (log-uniform sampling is undefined otherwise).
        """
        if not self.lower > 0:
            # Bug fix: added the missing space before "Got:" so the
            # message no longer renders as "...than 0.Got: ...".
            raise ValueError(
                "LogUniform requires a lower bound greater than 0. "
                f"Got: {self.lower}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        if not 0 < self.upper < float("inf"):
            # Bug fix: report the offending *upper* bound (was self.lower).
            raise ValueError(
                "LogUniform requires a upper bound greater than 0. "
                f"Got: {self.upper}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        new = copy(self)
        new.set_sampler(self._LogUniform(base))
        return new

    def normal(self, mean=0.0, sd=1.0):
        """Attach a normal sampler with the given mean and std deviation."""
        new = copy(self)
        new.set_sampler(self._Normal(mean, sd))
        return new

    def quantized(self, q: float):
        """Wrap the current sampler so results are multiples of ``q``.

        Raises:
            ValueError: If a finite bound is not itself a multiple of ``q``.
        """
        if self.lower > float("-inf") and not isclose(
            self.lower / q, round(self.lower / q)
        ):
            raise ValueError(
                f"Your lower variable bound {self.lower} is not divisible by "
                f"quantization factor {q}."
            )
        if self.upper < float("inf") and not isclose(
            self.upper / q, round(self.upper / q)
        ):
            raise ValueError(
                f"Your upper variable bound {self.upper} is not divisible by "
                f"quantization factor {q}."
            )

        new = copy(self)
        new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
        return new

    def is_valid(self, value: float):
        """True if ``value`` lies within the (inclusive) bounds."""
        return self.lower <= value <= self.upper

    @property
    def domain_str(self):
        return f"({self.lower}, {self.upper})"
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@DeveloperAPI
class Integer(Domain):
    """An integer-valued parameter domain over ``[lower, upper)``.

    Samplers are attached via ``uniform()`` or ``loguniform()`` and may be
    wrapped with ``quantized()``.
    """

    class _Uniform(Uniform):
        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # ``integers`` maps to legacy ``randint`` on old numpy; either
            # way the upper bound is exclusive.
            items = random_state.integers(domain.lower, domain.upper, size=size)
            return items if len(items) > 1 else domain.cast(items[0])

    class _LogUniform(LogUniform):
        def sample(
            self,
            domain: "Integer",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            assert domain.lower > 0, "LogUniform needs a lower bound greater than 0"
            assert (
                0 < domain.upper < float("inf")
            ), "LogUniform needs a upper bound greater than 0"
            # Sample uniformly in log space, exponentiate, then floor to int.
            logmin = np.log(domain.lower) / np.log(self.base)
            logmax = np.log(domain.upper) / np.log(self.base)

            items = self.base ** (random_state.uniform(logmin, logmax, size=size))
            items = np.floor(items).astype(int)
            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _Uniform

    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def cast(self, value):
        """Cast a sampled value to int."""
        return int(value)

    def quantized(self, q: int):
        """Wrap the current sampler so results are multiples of ``q``."""
        new = copy(self)
        new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True)
        return new

    def uniform(self):
        """Attach a uniform integer sampler."""
        new = copy(self)
        new.set_sampler(self._Uniform())
        return new

    def loguniform(self, base: float = 10):
        """Attach a log-uniform sampler with logarithm base ``base``.

        Raises:
            ValueError: If the bounds are not both strictly positive
                and finite (log-uniform sampling is undefined otherwise).
        """
        if not self.lower > 0:
            # Bug fix: added the missing space before "Got:" so the
            # message no longer renders as "...than 0.Got: ...".
            raise ValueError(
                "LogUniform requires a lower bound greater than 0. "
                f"Got: {self.lower}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        if not 0 < self.upper < float("inf"):
            # Bug fix: report the offending *upper* bound (was self.lower).
            raise ValueError(
                "LogUniform requires a upper bound greater than 0. "
                f"Got: {self.upper}. Did you pass a variable that has "
                "been log-transformed? If so, pass the non-transformed value "
                "instead."
            )
        new = copy(self)
        new.set_sampler(self._LogUniform(base))
        return new

    def is_valid(self, value: int):
        """True if ``value`` lies within the (inclusive) bounds."""
        return self.lower <= value <= self.upper

    @property
    def domain_str(self):
        return f"({self.lower}, {self.upper})"
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
@DeveloperAPI
class Categorical(Domain):
    """A domain over a fixed, ordered collection of arbitrary categories.

    Supports uniform sampling over the categories or exhaustive grid
    search, plus ``len()``, indexing, and membership via ``is_valid``.
    """

    class _Uniform(Uniform):
        def sample(
            self,
            domain: "Categorical",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            # Sample indices rather than calling .choice() on
            # domain.categories directly, which would coerce mixed-type
            # categories to a single dtype.
            chosen = random_state.choice(
                np.arange(0, len(domain.categories)), size=size
            )
            picks = [domain.categories[idx] for idx in chosen]
            return picks if len(picks) > 1 else domain.cast(picks[0])

    default_sampler_cls = _Uniform

    def __init__(self, categories: Sequence):
        self.categories = list(categories)

    def uniform(self):
        """Return a copy that samples uniformly over the categories."""
        clone = copy(self)
        clone.set_sampler(self._Uniform())
        return clone

    def grid(self):
        """Return a copy marked for exhaustive grid search."""
        clone = copy(self)
        clone.set_sampler(Grid())
        return clone

    def __len__(self):
        return len(self.categories)

    def __getitem__(self, item):
        return self.categories[item]

    def is_valid(self, value: Any):
        """True if ``value`` is one of the categories."""
        return value in self.categories

    @property
    def domain_str(self):
        return f"{self.categories}"
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@DeveloperAPI
class Function(Domain):
    """A domain whose values come from a user-supplied callable.

    The callable may take zero arguments or a single ``config`` dict of
    already-resolved values; any other signature is rejected at
    construction time.
    """

    class _CallSampler(BaseSampler):
        def __try_fn(self, domain: "Function", config: Dict[str, Any]):
            # First try calling with the plain config dict; on failure fall
            # back to the deprecated spec-dict calling convention and warn.
            try:
                return domain.func(config)
            except (AttributeError, KeyError):
                from ray.tune.search.variant_generator import _UnresolvedAccessGuard

                r = domain.func(_UnresolvedAccessGuard({"config": config}))
                logger.warning(
                    "sample_from functions that take a spec dict are "
                    "deprecated. Please update your function to work with "
                    "the config dict directly."
                )
                return r

        def sample(
            self,
            domain: "Function",
            config: Optional[Union[List[Dict], Dict]] = None,
            size: int = 1,
            random_state: "RandomState" = None,
        ):
            if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
                random_state = _BackwardsCompatibleNumpyRng(random_state)
            if domain.pass_config:
                # A list config provides one config per requested sample.
                items = [
                    self.__try_fn(domain, config[i])
                    if isinstance(config, list)
                    else self.__try_fn(domain, config)
                    for i in range(size)
                ]
            else:
                items = [domain.func() for i in range(size)]

            return items if len(items) > 1 else domain.cast(items[0])

    default_sampler_cls = _CallSampler

    def __init__(self, func: Callable):
        # Probe the signature: can func be called with one dict argument?
        sig = signature(func)

        pass_config = True  # whether we should pass `config` when calling `func`
        try:
            sig.bind({})
        except TypeError:
            pass_config = False

        if not pass_config:
            # Not unary — it must then be callable with no arguments.
            try:
                sig.bind()
            except TypeError as exc:
                raise ValueError(
                    "The function passed to a `Function` parameter must be "
                    "callable with either 0 or 1 parameters."
                ) from exc

        self.pass_config = pass_config
        self.func = func

    def is_function(self):
        return True

    def is_valid(self, value: Any):
        return True  # This is user-defined, so lets not assume anything

    @property
    def domain_str(self):
        return f"{self.func}()"
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
@DeveloperAPI
class Quantized(Sampler):
    """Wraps another sampler and rounds its output to multiples of ``q``."""

    def __init__(self, sampler: Sampler, q: Union[float, int]):
        self.sampler = sampler
        self.q = q

        assert self.sampler, "Quantized() expects a sampler instance"

    def get_sampler(self):
        """Return the wrapped (inner) sampler."""
        return self.sampler

    def sample(
        self,
        domain: Domain,
        config: Optional[Union[List[Dict], Dict]] = None,
        size: int = 1,
        random_state: "RandomState" = None,
    ):
        if not isinstance(random_state, _BackwardsCompatibleNumpyRng):
            random_state = _BackwardsCompatibleNumpyRng(random_state)

        if self.q == 1:
            # Rounding to multiples of 1 is a no-op for the underlying
            # sampler, so delegate directly.
            return self.sampler.sample(domain, config, size, random_state=random_state)

        # Sample from bounds shrunk inward to the nearest multiples of q so
        # the rounding below cannot push values outside the original domain.
        quantized_domain = copy(domain)
        quantized_domain.lower = np.ceil(domain.lower / self.q) * self.q
        quantized_domain.upper = np.floor(domain.upper / self.q) * self.q
        values = self.sampler.sample(
            quantized_domain, config, size, random_state=random_state
        )
        quantized = np.round(np.divide(values, self.q)) * self.q

        # size == 1 yields a numpy scalar: cast it to the domain's type.
        if not isinstance(quantized, np.ndarray):
            return domain.cast(quantized)
        return list(quantized)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
@PublicAPI
def sample_from(func: Callable[[Dict], Any]):
    """Specify that tune should sample configuration values from ``func``.

    Arguments:
        func: A callable (taking 0 arguments or the config dict) that
            produces one sample per call.
    """
    domain = Function(func)
    return domain
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
@PublicAPI
def uniform(lower: float, upper: float):
    """Sample a float value uniformly between ``lower`` and ``upper``.

    Sampling from ``tune.uniform(1, 10)`` is equivalent to sampling from
    ``np.random.uniform(1, 10))``
    """
    domain = Float(lower, upper)
    return domain.uniform()
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@PublicAPI
def quniform(lower: float, upper: float, q: float):
    """Sample a quantized float value uniformly between ``lower`` and ``upper``.

    Sampling from ``tune.quniform(1, 10, 2)`` draws uniformly like
    ``np.random.uniform(1, 10)`` and then rounds the result.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.
    """
    domain = Float(lower, upper)
    return domain.uniform().quantized(q)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
@PublicAPI
def loguniform(lower: float, upper: float, base: float = 10):
    """Sugar for sampling in different orders of magnitude.

    Args:
        lower: Lower boundary of the output interval (e.g. 1e-4)
        upper: Upper boundary of the output interval (e.g. 1e-2)
        base: Base of the log. Defaults to 10.
    """
    domain = Float(lower, upper)
    return domain.loguniform(base)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
@PublicAPI
def qloguniform(lower: float, upper: float, q: float, base: float = 10):
    """Sugar for quantized sampling in different orders of magnitude.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.

    Quantization makes the upper bound inclusive.

    Args:
        lower: Lower boundary of the output interval (e.g. 1e-4)
        upper: Upper boundary of the output interval (e.g. 1e-2)
        q: Quantization number. The result will be rounded to an
            integer increment of this value.
        base: Base of the log. Defaults to 10.
    """
    domain = Float(lower, upper)
    return domain.loguniform(base).quantized(q)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
@PublicAPI
def choice(categories: Sequence):
    """Sample a categorical value.

    Sampling from ``tune.choice([1, 2])`` is equivalent to sampling from
    ``np.random.choice([1, 2])``
    """
    domain = Categorical(categories)
    return domain.uniform()
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
@PublicAPI
def randint(lower: int, upper: int):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is exclusive — matching the
    semantics of ``np.random.randint(lower, upper)``.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.
    """
    domain = Integer(lower, upper)
    return domain.uniform()
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
@PublicAPI
def lograndint(lower: int, upper: int, base: float = 10):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``,
    with ``base`` being the base of logarithm.

    ``lower`` is inclusive, ``upper`` is exclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.
    """
    domain = Integer(lower, upper)
    return domain.loguniform(base)
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
@PublicAPI
def qrandint(lower: int, upper: int, q: int = 1):
    """Sample an integer value uniformly between ``lower`` and ``upper``.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.
    """
    domain = Integer(lower, upper)
    return domain.uniform().quantized(q)
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
@PublicAPI
def qlograndint(lower: int, upper: int, q: int, base: float = 10):
    """Sample an integer value log-uniformly between ``lower`` and ``upper``,
    with ``base`` being the base of logarithm.

    ``lower`` is inclusive, ``upper`` is also inclusive (!).

    The value will be quantized, i.e. rounded to an integer increment of ``q``.
    Quantization makes the upper bound inclusive.

    .. versionchanged:: 1.5.0
        When converting Ray Tune configs to searcher-specific search spaces,
        the lower and upper limits are adjusted to keep compatibility with
        the bounds stated in the docstring above.
    """
    domain = Integer(lower, upper)
    return domain.loguniform(base).quantized(q)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
@PublicAPI
def randn(mean: float = 0.0, sd: float = 1.0):
    """Sample a float value normally with ``mean`` and ``sd``.

    Args:
        mean: Mean of the normal distribution. Defaults to 0.
        sd: SD of the normal distribution. Defaults to 1.
    """
    # An unbounded Float domain (None bounds -> +-inf) with a normal sampler.
    domain = Float(None, None)
    return domain.normal(mean, sd)
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
@PublicAPI
def qrandn(mean: float, sd: float, q: float):
    """Sample a float value normally with ``mean`` and ``sd``.

    The value will be quantized, i.e. rounded to an integer increment of ``q``.

    Args:
        mean: Mean of the normal distribution.
        sd: SD of the normal distribution.
        q: Quantization number. The result will be rounded to an
            integer increment of this value.
    """
    # Unbounded domain; quantization wraps the normal sampler.
    domain = Float(None, None)
    return domain.normal(mean, sd).quantized(q)
|
.venv/lib/python3.11/site-packages/ray/tune/search/search_algorithm.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
| 2 |
+
|
| 3 |
+
from ray.util.annotations import DeveloperAPI
|
| 4 |
+
|
| 5 |
+
if TYPE_CHECKING:
|
| 6 |
+
from ray.tune.experiment import Experiment
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@DeveloperAPI
|
| 10 |
+
class SearchAlgorithm:
|
| 11 |
+
"""Interface of an event handler API for hyperparameter search.
|
| 12 |
+
|
| 13 |
+
Unlike TrialSchedulers, SearchAlgorithms will not have the ability
|
| 14 |
+
to modify the execution (i.e., stop and pause trials).
|
| 15 |
+
|
| 16 |
+
Trials added manually (i.e., via the Client API) will also notify
|
| 17 |
+
this class upon new events, so custom search algorithms should
|
| 18 |
+
maintain a list of trials ID generated from this class.
|
| 19 |
+
|
| 20 |
+
See also: `ray.tune.search.BasicVariantGenerator`.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
_finished = False
|
| 24 |
+
|
| 25 |
+
_metric = None
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def metric(self):
|
| 29 |
+
return self._metric
|
| 30 |
+
|
| 31 |
+
def set_search_properties(
|
| 32 |
+
self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
|
| 33 |
+
) -> bool:
|
| 34 |
+
"""Pass search properties to search algorithm.
|
| 35 |
+
|
| 36 |
+
This method acts as an alternative to instantiating search algorithms
|
| 37 |
+
with their own specific search spaces. Instead they can accept a
|
| 38 |
+
Tune config through this method.
|
| 39 |
+
|
| 40 |
+
The search algorithm will usually pass this method to their
|
| 41 |
+
``Searcher`` instance.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
metric: Metric to optimize
|
| 45 |
+
mode: One of ["min", "max"]. Direction to optimize.
|
| 46 |
+
config: Tune config dict.
|
| 47 |
+
**spec: Any kwargs for forward compatiblity.
|
| 48 |
+
Info like Experiment.PUBLIC_KEYS is provided through here.
|
| 49 |
+
"""
|
| 50 |
+
if self._metric and metric:
|
| 51 |
+
return False
|
| 52 |
+
if metric:
|
| 53 |
+
self._metric = metric
|
| 54 |
+
return True
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def total_samples(self):
|
| 58 |
+
"""Get number of total trials to be generated"""
|
| 59 |
+
return 0
|
| 60 |
+
|
| 61 |
+
def add_configurations(
|
| 62 |
+
self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
|
| 63 |
+
):
|
| 64 |
+
"""Tracks given experiment specifications.
|
| 65 |
+
|
| 66 |
+
Arguments:
|
| 67 |
+
experiments: Experiments to run.
|
| 68 |
+
"""
|
| 69 |
+
raise NotImplementedError
|
| 70 |
+
|
| 71 |
+
def next_trial(self):
|
| 72 |
+
"""Returns single Trial object to be queued into the TrialRunner.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
trial: Returns a Trial object.
|
| 76 |
+
"""
|
| 77 |
+
raise NotImplementedError
|
| 78 |
+
|
| 79 |
+
def on_trial_result(self, trial_id: str, result: Dict):
|
| 80 |
+
"""Called on each intermediate result returned by a trial.
|
| 81 |
+
|
| 82 |
+
This will only be called when the trial is in the RUNNING state.
|
| 83 |
+
|
| 84 |
+
Arguments:
|
| 85 |
+
trial_id: Identifier for the trial.
|
| 86 |
+
result: Result dictionary.
|
| 87 |
+
"""
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
def on_trial_complete(
|
| 91 |
+
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
|
| 92 |
+
):
|
| 93 |
+
"""Notification for the completion of trial.
|
| 94 |
+
|
| 95 |
+
Arguments:
|
| 96 |
+
trial_id: Identifier for the trial.
|
| 97 |
+
result: Defaults to None. A dict will
|
| 98 |
+
be provided with this notification when the trial is in
|
| 99 |
+
the RUNNING state AND either completes naturally or
|
| 100 |
+
by manual termination.
|
| 101 |
+
error: Defaults to False. True if the trial is in
|
| 102 |
+
the RUNNING state and errors.
|
| 103 |
+
"""
|
| 104 |
+
pass
|
| 105 |
+
|
| 106 |
+
def is_finished(self) -> bool:
|
| 107 |
+
"""Returns True if no trials left to be queued into TrialRunner.
|
| 108 |
+
|
| 109 |
+
Can return True before all trials have finished executing.
|
| 110 |
+
"""
|
| 111 |
+
return self._finished
|
| 112 |
+
|
| 113 |
+
def set_finished(self):
|
| 114 |
+
"""Marks the search algorithm as finished."""
|
| 115 |
+
self._finished = True
|
| 116 |
+
|
| 117 |
+
def has_checkpoint(self, dirpath: str) -> bool:
|
| 118 |
+
"""Should return False if restoring is not implemented."""
|
| 119 |
+
return False
|
| 120 |
+
|
| 121 |
+
def save_to_dir(self, dirpath: str, **kwargs):
|
| 122 |
+
"""Saves a search algorithm."""
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
def restore_from_dir(self, dirpath: str):
|
| 126 |
+
"""Restores a search algorithm along with its wrapped state."""
|
| 127 |
+
pass
|
.venv/lib/python3.11/site-packages/ray/tune/search/search_generator.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
from ray.tune.error import TuneError
|
| 6 |
+
from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
|
| 7 |
+
from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
|
| 8 |
+
from ray.tune.search.search_algorithm import SearchAlgorithm
|
| 9 |
+
from ray.tune.search.searcher import Searcher
|
| 10 |
+
from ray.tune.search.util import _set_search_properties_backwards_compatible
|
| 11 |
+
from ray.tune.search.variant_generator import _resolve_nested_dict, format_vars
|
| 12 |
+
from ray.tune.utils.util import (
|
| 13 |
+
_atomic_save,
|
| 14 |
+
_load_newest_checkpoint,
|
| 15 |
+
flatten_dict,
|
| 16 |
+
merge_dicts,
|
| 17 |
+
)
|
| 18 |
+
from ray.util.annotations import DeveloperAPI
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _warn_on_repeater(searcher, total_samples):
|
| 24 |
+
from ray.tune.search.repeater import _warn_num_samples
|
| 25 |
+
|
| 26 |
+
_warn_num_samples(searcher, total_samples)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@DeveloperAPI
|
| 30 |
+
class SearchGenerator(SearchAlgorithm):
|
| 31 |
+
"""Generates trials to be passed to the TrialRunner.
|
| 32 |
+
|
| 33 |
+
Uses the provided ``searcher`` object to generate trials. This class
|
| 34 |
+
transparently handles repeating trials with score aggregation
|
| 35 |
+
without embedding logic into the Searcher.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
searcher: Search object that subclasses the Searcher base class. This
|
| 39 |
+
is then used for generating new hyperparameter samples.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
CKPT_FILE_TMPL = "search_gen_state-{}.json"
|
| 43 |
+
|
| 44 |
+
def __init__(self, searcher: Searcher):
|
| 45 |
+
assert issubclass(
|
| 46 |
+
type(searcher), Searcher
|
| 47 |
+
), "Searcher should be subclassing Searcher."
|
| 48 |
+
self.searcher = searcher
|
| 49 |
+
self._parser = _make_parser()
|
| 50 |
+
self._experiment = None
|
| 51 |
+
self._counter = 0 # Keeps track of number of trials created.
|
| 52 |
+
self._total_samples = 0 # int: total samples to evaluate.
|
| 53 |
+
self._finished = False
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def metric(self):
|
| 57 |
+
return self.searcher.metric
|
| 58 |
+
|
| 59 |
+
def set_search_properties(
|
| 60 |
+
self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
|
| 61 |
+
) -> bool:
|
| 62 |
+
return _set_search_properties_backwards_compatible(
|
| 63 |
+
self.searcher.set_search_properties, metric, mode, config, **spec
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def total_samples(self):
|
| 68 |
+
return self._total_samples
|
| 69 |
+
|
| 70 |
+
def add_configurations(
|
| 71 |
+
self, experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]
|
| 72 |
+
):
|
| 73 |
+
"""Registers experiment specifications.
|
| 74 |
+
|
| 75 |
+
Arguments:
|
| 76 |
+
experiments: Experiments to run.
|
| 77 |
+
"""
|
| 78 |
+
assert not self._experiment
|
| 79 |
+
logger.debug("added configurations")
|
| 80 |
+
experiment_list = _convert_to_experiment_list(experiments)
|
| 81 |
+
assert (
|
| 82 |
+
len(experiment_list) == 1
|
| 83 |
+
), "SearchAlgorithms can only support 1 experiment at a time."
|
| 84 |
+
self._experiment = experiment_list[0]
|
| 85 |
+
experiment_spec = self._experiment.spec
|
| 86 |
+
self._total_samples = self._experiment.spec.get("num_samples", 1)
|
| 87 |
+
|
| 88 |
+
_warn_on_repeater(self.searcher, self._total_samples)
|
| 89 |
+
if "run" not in experiment_spec:
|
| 90 |
+
raise TuneError("Must specify `run` in {}".format(experiment_spec))
|
| 91 |
+
|
| 92 |
+
def next_trial(self):
|
| 93 |
+
"""Provides one Trial object to be queued into the TrialRunner.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Trial: Returns a single trial.
|
| 97 |
+
"""
|
| 98 |
+
if not self.is_finished():
|
| 99 |
+
return self.create_trial_if_possible(self._experiment.spec)
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
def create_trial_if_possible(self, experiment_spec: Dict) -> Optional[Trial]:
|
| 103 |
+
logger.debug("creating trial")
|
| 104 |
+
trial_id = Trial.generate_id()
|
| 105 |
+
suggested_config = self.searcher.suggest(trial_id)
|
| 106 |
+
if suggested_config == Searcher.FINISHED:
|
| 107 |
+
self._finished = True
|
| 108 |
+
logger.debug("Searcher has finished.")
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
if suggested_config is None:
|
| 112 |
+
return
|
| 113 |
+
spec = copy.deepcopy(experiment_spec)
|
| 114 |
+
spec["config"] = merge_dicts(spec["config"], copy.deepcopy(suggested_config))
|
| 115 |
+
|
| 116 |
+
# Create a new trial_id if duplicate trial is created
|
| 117 |
+
flattened_config = _resolve_nested_dict(spec["config"])
|
| 118 |
+
self._counter += 1
|
| 119 |
+
tag = "{0}_{1}".format(str(self._counter), format_vars(flattened_config))
|
| 120 |
+
trial = _create_trial_from_spec(
|
| 121 |
+
spec,
|
| 122 |
+
self._parser,
|
| 123 |
+
evaluated_params=flatten_dict(suggested_config),
|
| 124 |
+
experiment_tag=tag,
|
| 125 |
+
trial_id=trial_id,
|
| 126 |
+
)
|
| 127 |
+
return trial
|
| 128 |
+
|
| 129 |
+
def on_trial_result(self, trial_id: str, result: Dict):
|
| 130 |
+
"""Notifies the underlying searcher."""
|
| 131 |
+
self.searcher.on_trial_result(trial_id, result)
|
| 132 |
+
|
| 133 |
+
def on_trial_complete(
|
| 134 |
+
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
|
| 135 |
+
):
|
| 136 |
+
self.searcher.on_trial_complete(trial_id=trial_id, result=result, error=error)
|
| 137 |
+
|
| 138 |
+
def is_finished(self) -> bool:
|
| 139 |
+
return self._counter >= self._total_samples or self._finished
|
| 140 |
+
|
| 141 |
+
def get_state(self) -> Dict:
|
| 142 |
+
return {
|
| 143 |
+
"counter": self._counter,
|
| 144 |
+
"total_samples": self._total_samples,
|
| 145 |
+
"finished": self._finished,
|
| 146 |
+
"experiment": self._experiment,
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
def set_state(self, state: Dict):
|
| 150 |
+
self._counter = state["counter"]
|
| 151 |
+
self._total_samples = state["total_samples"]
|
| 152 |
+
self._finished = state["finished"]
|
| 153 |
+
self._experiment = state["experiment"]
|
| 154 |
+
|
| 155 |
+
def has_checkpoint(self, dirpath: str):
|
| 156 |
+
return bool(_load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")))
|
| 157 |
+
|
| 158 |
+
def save_to_dir(self, dirpath: str, session_str: str):
|
| 159 |
+
"""Saves self + searcher to dir.
|
| 160 |
+
|
| 161 |
+
Separates the "searcher" from its wrappers (concurrency, repeating).
|
| 162 |
+
This allows the user to easily restore a given searcher.
|
| 163 |
+
|
| 164 |
+
The save operation is atomic (write/swap).
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
dirpath: Filepath to experiment dir.
|
| 168 |
+
session_str: Unique identifier of the current run
|
| 169 |
+
session.
|
| 170 |
+
"""
|
| 171 |
+
searcher = self.searcher
|
| 172 |
+
search_alg_state = self.get_state()
|
| 173 |
+
while hasattr(searcher, "searcher"):
|
| 174 |
+
searcher_name = type(searcher).__name__
|
| 175 |
+
if searcher_name in search_alg_state:
|
| 176 |
+
logger.warning(
|
| 177 |
+
"There was a duplicate when saving {}. "
|
| 178 |
+
"Restore may not work properly.".format(searcher_name)
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
search_alg_state["name:" + searcher_name] = searcher.get_state()
|
| 182 |
+
searcher = searcher.searcher
|
| 183 |
+
base_searcher = searcher
|
| 184 |
+
# We save the base searcher separately for users to easily
|
| 185 |
+
# separate the searcher.
|
| 186 |
+
base_searcher.save_to_dir(dirpath, session_str)
|
| 187 |
+
_atomic_save(
|
| 188 |
+
state=search_alg_state,
|
| 189 |
+
checkpoint_dir=dirpath,
|
| 190 |
+
file_name=self.CKPT_FILE_TMPL.format(session_str),
|
| 191 |
+
tmp_file_name=".tmp_search_generator_ckpt",
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def restore_from_dir(self, dirpath: str):
|
| 195 |
+
"""Restores self + searcher + search wrappers from dirpath."""
|
| 196 |
+
|
| 197 |
+
searcher = self.searcher
|
| 198 |
+
search_alg_state = _load_newest_checkpoint(
|
| 199 |
+
dirpath, self.CKPT_FILE_TMPL.format("*")
|
| 200 |
+
)
|
| 201 |
+
if not search_alg_state:
|
| 202 |
+
raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
|
| 203 |
+
while hasattr(searcher, "searcher"):
|
| 204 |
+
searcher_name = "name:" + type(searcher).__name__
|
| 205 |
+
if searcher_name not in search_alg_state:
|
| 206 |
+
names = [
|
| 207 |
+
key.split("name:")[1]
|
| 208 |
+
for key in search_alg_state
|
| 209 |
+
if key.startswith("name:")
|
| 210 |
+
]
|
| 211 |
+
logger.warning(
|
| 212 |
+
"{} was not found in the experiment "
|
| 213 |
+
"state when restoring. Found {}.".format(searcher_name, names)
|
| 214 |
+
)
|
| 215 |
+
else:
|
| 216 |
+
searcher.set_state(search_alg_state.pop(searcher_name))
|
| 217 |
+
searcher = searcher.searcher
|
| 218 |
+
base_searcher = searcher
|
| 219 |
+
|
| 220 |
+
logger.debug(f"searching base {base_searcher}")
|
| 221 |
+
base_searcher.restore_from_dir(dirpath)
|
| 222 |
+
self.set_state(search_alg_state)
|
.venv/lib/python3.11/site-packages/ray/tune/search/searcher.py
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import glob
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
from ray.air._internal.usage import tag_searcher
|
| 9 |
+
from ray.tune.search.util import _set_search_properties_backwards_compatible
|
| 10 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 11 |
+
from ray.util.debug import log_once
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from ray.tune.analysis import ExperimentAnalysis
|
| 15 |
+
from ray.tune.experiment import Trial
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@DeveloperAPI
|
| 21 |
+
class Searcher:
|
| 22 |
+
"""Abstract class for wrapping suggesting algorithms.
|
| 23 |
+
|
| 24 |
+
Custom algorithms can extend this class easily by overriding the
|
| 25 |
+
`suggest` method provide generated parameters for the trials.
|
| 26 |
+
|
| 27 |
+
Any subclass that implements ``__init__`` must also call the
|
| 28 |
+
constructor of this class: ``super(Subclass, self).__init__(...)``.
|
| 29 |
+
|
| 30 |
+
To track suggestions and their corresponding evaluations, the method
|
| 31 |
+
`suggest` will be passed a trial_id, which will be used in
|
| 32 |
+
subsequent notifications.
|
| 33 |
+
|
| 34 |
+
Not all implementations support multi objectives.
|
| 35 |
+
|
| 36 |
+
Note to Tune developers: If a new searcher is added, please update
|
| 37 |
+
`air/_internal/usage.py`.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
metric: The training result objective value attribute. If
|
| 41 |
+
list then list of training result objective value attributes
|
| 42 |
+
mode: If string One of {min, max}. If list then
|
| 43 |
+
list of max and min, determines whether objective is minimizing
|
| 44 |
+
or maximizing the metric attribute. Must match type of metric.
|
| 45 |
+
|
| 46 |
+
.. code-block:: python
|
| 47 |
+
|
| 48 |
+
class ExampleSearch(Searcher):
|
| 49 |
+
def __init__(self, metric="mean_loss", mode="min", **kwargs):
|
| 50 |
+
super(ExampleSearch, self).__init__(
|
| 51 |
+
metric=metric, mode=mode, **kwargs)
|
| 52 |
+
self.optimizer = Optimizer()
|
| 53 |
+
self.configurations = {}
|
| 54 |
+
|
| 55 |
+
def suggest(self, trial_id):
|
| 56 |
+
configuration = self.optimizer.query()
|
| 57 |
+
self.configurations[trial_id] = configuration
|
| 58 |
+
|
| 59 |
+
def on_trial_complete(self, trial_id, result, **kwargs):
|
| 60 |
+
configuration = self.configurations[trial_id]
|
| 61 |
+
if result and self.metric in result:
|
| 62 |
+
self.optimizer.update(configuration, result[self.metric])
|
| 63 |
+
|
| 64 |
+
tuner = tune.Tuner(
|
| 65 |
+
trainable_function,
|
| 66 |
+
tune_config=tune.TuneConfig(
|
| 67 |
+
search_alg=ExampleSearch()
|
| 68 |
+
)
|
| 69 |
+
)
|
| 70 |
+
tuner.fit()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
FINISHED = "FINISHED"
|
| 76 |
+
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
|
| 77 |
+
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
metric: Optional[str] = None,
|
| 81 |
+
mode: Optional[str] = None,
|
| 82 |
+
):
|
| 83 |
+
tag_searcher(self)
|
| 84 |
+
self._metric = metric
|
| 85 |
+
self._mode = mode
|
| 86 |
+
|
| 87 |
+
if not mode or not metric:
|
| 88 |
+
# Early return to avoid assertions
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
assert isinstance(
|
| 92 |
+
metric, type(mode)
|
| 93 |
+
), "metric and mode must be of the same type"
|
| 94 |
+
if isinstance(mode, str):
|
| 95 |
+
assert mode in ["min", "max"], "if `mode` is a str must be 'min' or 'max'!"
|
| 96 |
+
elif isinstance(mode, list):
|
| 97 |
+
assert len(mode) == len(metric), "Metric and mode must be the same length"
|
| 98 |
+
assert all(
|
| 99 |
+
mod in ["min", "max", "obs"] for mod in mode
|
| 100 |
+
), "All of mode must be 'min' or 'max' or 'obs'!"
|
| 101 |
+
else:
|
| 102 |
+
raise ValueError("Mode most either be a list or string")
|
| 103 |
+
|
| 104 |
+
def set_search_properties(
|
| 105 |
+
self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
|
| 106 |
+
) -> bool:
|
| 107 |
+
"""Pass search properties to searcher.
|
| 108 |
+
|
| 109 |
+
This method acts as an alternative to instantiating search algorithms
|
| 110 |
+
with their own specific search spaces. Instead they can accept a
|
| 111 |
+
Tune config through this method. A searcher should return ``True``
|
| 112 |
+
if setting the config was successful, or ``False`` if it was
|
| 113 |
+
unsuccessful, e.g. when the search space has already been set.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
metric: Metric to optimize
|
| 117 |
+
mode: One of ["min", "max"]. Direction to optimize.
|
| 118 |
+
config: Tune config dict.
|
| 119 |
+
**spec: Any kwargs for forward compatiblity.
|
| 120 |
+
Info like Experiment.PUBLIC_KEYS is provided through here.
|
| 121 |
+
"""
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
def on_trial_result(self, trial_id: str, result: Dict) -> None:
|
| 125 |
+
"""Optional notification for result during training.
|
| 126 |
+
|
| 127 |
+
Note that by default, the result dict may include NaNs or
|
| 128 |
+
may not include the optimization metric. It is up to the
|
| 129 |
+
subclass implementation to preprocess the result to
|
| 130 |
+
avoid breaking the optimization process.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
trial_id: A unique string ID for the trial.
|
| 134 |
+
result: Dictionary of metrics for current training progress.
|
| 135 |
+
Note that the result dict may include NaNs or
|
| 136 |
+
may not include the optimization metric. It is up to the
|
| 137 |
+
subclass implementation to preprocess the result to
|
| 138 |
+
avoid breaking the optimization process.
|
| 139 |
+
"""
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
def on_trial_complete(
|
| 143 |
+
self, trial_id: str, result: Optional[Dict] = None, error: bool = False
|
| 144 |
+
) -> None:
|
| 145 |
+
"""Notification for the completion of trial.
|
| 146 |
+
|
| 147 |
+
Typically, this method is used for notifying the underlying
|
| 148 |
+
optimizer of the result.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
trial_id: A unique string ID for the trial.
|
| 152 |
+
result: Dictionary of metrics for current training progress.
|
| 153 |
+
Note that the result dict may include NaNs or
|
| 154 |
+
may not include the optimization metric. It is up to the
|
| 155 |
+
subclass implementation to preprocess the result to
|
| 156 |
+
avoid breaking the optimization process. Upon errors, this
|
| 157 |
+
may also be None.
|
| 158 |
+
error: True if the training process raised an error.
|
| 159 |
+
|
| 160 |
+
"""
|
| 161 |
+
raise NotImplementedError
|
| 162 |
+
|
| 163 |
+
def suggest(self, trial_id: str) -> Optional[Dict]:
|
| 164 |
+
"""Queries the algorithm to retrieve the next set of parameters.
|
| 165 |
+
|
| 166 |
+
Arguments:
|
| 167 |
+
trial_id: Trial ID used for subsequent notifications.
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
dict | FINISHED | None: Configuration for a trial, if possible.
|
| 171 |
+
If FINISHED is returned, Tune will be notified that
|
| 172 |
+
no more suggestions/configurations will be provided.
|
| 173 |
+
If None is returned, Tune will skip the querying of the
|
| 174 |
+
searcher for this step.
|
| 175 |
+
|
| 176 |
+
"""
|
| 177 |
+
raise NotImplementedError
|
| 178 |
+
|
| 179 |
+
def add_evaluated_point(
|
| 180 |
+
self,
|
| 181 |
+
parameters: Dict,
|
| 182 |
+
value: float,
|
| 183 |
+
error: bool = False,
|
| 184 |
+
pruned: bool = False,
|
| 185 |
+
intermediate_values: Optional[List[float]] = None,
|
| 186 |
+
):
|
| 187 |
+
"""Pass results from a point that has been evaluated separately.
|
| 188 |
+
|
| 189 |
+
This method allows for information from outside the
|
| 190 |
+
suggest - on_trial_complete loop to be passed to the search
|
| 191 |
+
algorithm.
|
| 192 |
+
This functionality depends on the underlying search algorithm
|
| 193 |
+
and may not be always available.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
parameters: Parameters used for the trial.
|
| 197 |
+
value: Metric value obtained in the trial.
|
| 198 |
+
error: True if the training process raised an error.
|
| 199 |
+
pruned: True if trial was pruned.
|
| 200 |
+
intermediate_values: List of metric values for
|
| 201 |
+
intermediate iterations of the result. None if not
|
| 202 |
+
applicable.
|
| 203 |
+
|
| 204 |
+
"""
|
| 205 |
+
raise NotImplementedError
|
| 206 |
+
|
| 207 |
+
def add_evaluated_trials(
|
| 208 |
+
self,
|
| 209 |
+
trials_or_analysis: Union["Trial", List["Trial"], "ExperimentAnalysis"],
|
| 210 |
+
metric: str,
|
| 211 |
+
):
|
| 212 |
+
"""Pass results from trials that have been evaluated separately.
|
| 213 |
+
|
| 214 |
+
This method allows for information from outside the
|
| 215 |
+
suggest - on_trial_complete loop to be passed to the search
|
| 216 |
+
algorithm.
|
| 217 |
+
This functionality depends on the underlying search algorithm
|
| 218 |
+
and may not be always available (same as ``add_evaluated_point``.)
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
trials_or_analysis: Trials to pass results form to the searcher.
|
| 222 |
+
metric: Metric name reported by trials used for
|
| 223 |
+
determining the objective value.
|
| 224 |
+
|
| 225 |
+
"""
|
| 226 |
+
if self.add_evaluated_point == Searcher.add_evaluated_point:
|
| 227 |
+
raise NotImplementedError
|
| 228 |
+
|
| 229 |
+
# lazy imports to avoid circular dependencies
|
| 230 |
+
from ray.tune.analysis import ExperimentAnalysis
|
| 231 |
+
from ray.tune.experiment import Trial
|
| 232 |
+
from ray.tune.result import DONE
|
| 233 |
+
|
| 234 |
+
if isinstance(trials_or_analysis, (list, tuple)):
|
| 235 |
+
trials = trials_or_analysis
|
| 236 |
+
elif isinstance(trials_or_analysis, Trial):
|
| 237 |
+
trials = [trials_or_analysis]
|
| 238 |
+
elif isinstance(trials_or_analysis, ExperimentAnalysis):
|
| 239 |
+
trials = trials_or_analysis.trials
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError(
|
| 242 |
+
"Expected input to be a `Trial`, a list of `Trial`s, or "
|
| 243 |
+
f"`ExperimentAnalysis`, got: {trials_or_analysis}"
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
any_trial_had_metric = False
|
| 247 |
+
|
| 248 |
+
def trial_to_points(trial: Trial) -> Dict[str, Any]:
|
| 249 |
+
nonlocal any_trial_had_metric
|
| 250 |
+
has_trial_been_pruned = (
|
| 251 |
+
trial.status == Trial.TERMINATED
|
| 252 |
+
and not trial.last_result.get(DONE, False)
|
| 253 |
+
)
|
| 254 |
+
has_trial_finished = (
|
| 255 |
+
trial.status == Trial.TERMINATED and trial.last_result.get(DONE, False)
|
| 256 |
+
)
|
| 257 |
+
if not any_trial_had_metric:
|
| 258 |
+
any_trial_had_metric = (
|
| 259 |
+
metric in trial.last_result and has_trial_finished
|
| 260 |
+
)
|
| 261 |
+
if Trial.TERMINATED and metric not in trial.last_result:
|
| 262 |
+
return None
|
| 263 |
+
return dict(
|
| 264 |
+
parameters=trial.config,
|
| 265 |
+
value=trial.last_result.get(metric, None),
|
| 266 |
+
error=trial.status == Trial.ERROR,
|
| 267 |
+
pruned=has_trial_been_pruned,
|
| 268 |
+
intermediate_values=None, # we do not save those
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
for trial in trials:
|
| 272 |
+
kwargs = trial_to_points(trial)
|
| 273 |
+
if kwargs:
|
| 274 |
+
self.add_evaluated_point(**kwargs)
|
| 275 |
+
|
| 276 |
+
if not any_trial_had_metric:
|
| 277 |
+
warnings.warn(
|
| 278 |
+
"No completed trial returned the specified metric. "
|
| 279 |
+
"Make sure the name you have passed is correct. "
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
def save(self, checkpoint_path: str):
|
| 283 |
+
"""Save state to path for this search algorithm.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
checkpoint_path: File where the search algorithm
|
| 287 |
+
state is saved. This path should be used later when
|
| 288 |
+
restoring from file.
|
| 289 |
+
|
| 290 |
+
Example:
|
| 291 |
+
|
| 292 |
+
.. code-block:: python
|
| 293 |
+
|
| 294 |
+
search_alg = Searcher(...)
|
| 295 |
+
|
| 296 |
+
tuner = tune.Tuner(
|
| 297 |
+
cost,
|
| 298 |
+
tune_config=tune.TuneConfig(
|
| 299 |
+
search_alg=search_alg,
|
| 300 |
+
num_samples=5
|
| 301 |
+
),
|
| 302 |
+
param_space=config
|
| 303 |
+
)
|
| 304 |
+
results = tuner.fit()
|
| 305 |
+
|
| 306 |
+
search_alg.save("./my_favorite_path.pkl")
|
| 307 |
+
|
| 308 |
+
.. versionchanged:: 0.8.7
|
| 309 |
+
Save is automatically called by `Tuner().fit()`. You can use
|
| 310 |
+
`Tuner().restore()` to restore from an experiment directory
|
| 311 |
+
such as `~/ray_results/trainable`.
|
| 312 |
+
|
| 313 |
+
"""
|
| 314 |
+
raise NotImplementedError
|
| 315 |
+
|
| 316 |
+
def restore(self, checkpoint_path: str):
|
| 317 |
+
"""Restore state for this search algorithm
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
checkpoint_path: File where the search algorithm
|
| 322 |
+
state is saved. This path should be the same
|
| 323 |
+
as the one provided to "save".
|
| 324 |
+
|
| 325 |
+
Example:
|
| 326 |
+
|
| 327 |
+
.. code-block:: python
|
| 328 |
+
|
| 329 |
+
search_alg.save("./my_favorite_path.pkl")
|
| 330 |
+
|
| 331 |
+
search_alg2 = Searcher(...)
|
| 332 |
+
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
|
| 333 |
+
search_alg2.restore(checkpoint_path)
|
| 334 |
+
tuner = tune.Tuner(
|
| 335 |
+
cost,
|
| 336 |
+
tune_config=tune.TuneConfig(
|
| 337 |
+
search_alg=search_alg2,
|
| 338 |
+
num_samples=5
|
| 339 |
+
),
|
| 340 |
+
)
|
| 341 |
+
tuner.fit()
|
| 342 |
+
|
| 343 |
+
"""
|
| 344 |
+
raise NotImplementedError
|
| 345 |
+
|
| 346 |
+
def set_max_concurrency(self, max_concurrent: int) -> bool:
|
| 347 |
+
"""Set max concurrent trials this searcher can run.
|
| 348 |
+
|
| 349 |
+
This method will be called on the wrapped searcher by the
|
| 350 |
+
``ConcurrencyLimiter``. It is intended to allow for searchers
|
| 351 |
+
which have custom, internal logic handling max concurrent trials
|
| 352 |
+
to inherit the value passed to ``ConcurrencyLimiter``.
|
| 353 |
+
|
| 354 |
+
If this method returns False, it signifies that no special
|
| 355 |
+
logic for handling this case is present in the searcher.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
max_concurrent: Number of maximum concurrent trials.
|
| 359 |
+
"""
|
| 360 |
+
return False
|
| 361 |
+
|
| 362 |
+
def get_state(self) -> Dict:
|
| 363 |
+
raise NotImplementedError
|
| 364 |
+
|
| 365 |
+
def set_state(self, state: Dict):
|
| 366 |
+
raise NotImplementedError
|
| 367 |
+
|
| 368 |
+
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
|
| 369 |
+
"""Automatically saves the given searcher to the checkpoint_dir.
|
| 370 |
+
|
| 371 |
+
This is automatically used by Tuner().fit() during a Tune job.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
checkpoint_dir: Filepath to experiment dir.
|
| 375 |
+
session_str: Unique identifier of the current run
|
| 376 |
+
session.
|
| 377 |
+
"""
|
| 378 |
+
tmp_search_ckpt_path = os.path.join(checkpoint_dir, ".tmp_searcher_ckpt")
|
| 379 |
+
success = True
|
| 380 |
+
try:
|
| 381 |
+
self.save(tmp_search_ckpt_path)
|
| 382 |
+
except NotImplementedError:
|
| 383 |
+
if log_once("suggest:save_to_dir"):
|
| 384 |
+
logger.warning("save not implemented for Searcher. Skipping save.")
|
| 385 |
+
success = False
|
| 386 |
+
|
| 387 |
+
if success and os.path.exists(tmp_search_ckpt_path):
|
| 388 |
+
os.replace(
|
| 389 |
+
tmp_search_ckpt_path,
|
| 390 |
+
os.path.join(checkpoint_dir, self.CKPT_FILE_TMPL.format(session_str)),
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
def restore_from_dir(self, checkpoint_dir: str):
|
| 394 |
+
"""Restores the state of a searcher from a given checkpoint_dir.
|
| 395 |
+
|
| 396 |
+
Typically, you should use this function to restore from an
|
| 397 |
+
experiment directory such as `~/ray_results/trainable`.
|
| 398 |
+
|
| 399 |
+
.. code-block:: python
|
| 400 |
+
|
| 401 |
+
tuner = tune.Tuner(
|
| 402 |
+
cost,
|
| 403 |
+
run_config=train.RunConfig(
|
| 404 |
+
name=self.experiment_name,
|
| 405 |
+
storage_path="~/my_results",
|
| 406 |
+
),
|
| 407 |
+
tune_config=tune.TuneConfig(
|
| 408 |
+
search_alg=search_alg,
|
| 409 |
+
num_samples=5
|
| 410 |
+
),
|
| 411 |
+
param_space=config
|
| 412 |
+
)
|
| 413 |
+
tuner.fit()
|
| 414 |
+
|
| 415 |
+
search_alg2 = Searcher()
|
| 416 |
+
search_alg2.restore_from_dir(
|
| 417 |
+
os.path.join("~/my_results", self.experiment_name)
|
| 418 |
+
"""
|
| 419 |
+
|
| 420 |
+
pattern = self.CKPT_FILE_TMPL.format("*")
|
| 421 |
+
full_paths = glob.glob(os.path.join(checkpoint_dir, pattern))
|
| 422 |
+
if not full_paths:
|
| 423 |
+
raise RuntimeError(
|
| 424 |
+
"Searcher unable to find checkpoint in {}".format(checkpoint_dir)
|
| 425 |
+
) # TODO
|
| 426 |
+
most_recent_checkpoint = max(full_paths)
|
| 427 |
+
self.restore(most_recent_checkpoint)
|
| 428 |
+
|
| 429 |
+
@property
|
| 430 |
+
def metric(self) -> str:
|
| 431 |
+
"""The training result objective value attribute."""
|
| 432 |
+
return self._metric
|
| 433 |
+
|
| 434 |
+
@property
|
| 435 |
+
def mode(self) -> str:
|
| 436 |
+
"""Specifies if minimizing or maximizing the metric."""
|
| 437 |
+
return self._mode
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Example:

    .. code-block:: python

        from ray.tune.search import ConcurrencyLimiter
        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable_function,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()

    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # Trial IDs currently occupying a concurrency slot.
        self.live_trials = set()
        # Count of live trials that have not yet reported completion.
        self.num_unfinished_live_trials = 0
        # Batch mode only: completed results held back until the whole
        # batch has finished, keyed by trial ID.
        self.cached_results = {}
        # False when the wrapped searcher handles concurrency itself
        # (see _set_searcher_max_concurrency).
        self._limit_concurrency = True

        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                f"The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                f"`max_concurrent` to the search generator directly."
            )

        self._set_searcher_max_concurrency()

        # Mirror the wrapped searcher's metric/mode on the wrapper.
        super(ConcurrencyLimiter, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode
        )

    def _set_searcher_max_concurrency(self):
        # If the searcher has special logic for handling max concurrency,
        # we do not do anything inside the ConcurrencyLimiter
        self._limit_concurrency = not self.searcher.set_max_concurrency(
            self.max_concurrent
        )

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        """Accept an externally imposed concurrency limit."""
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Propagate search properties (and the limit) to the wrapped searcher."""
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Return a new suggestion, or None when the concurrency limit is hit."""
        if not self._limit_concurrency:
            # Wrapped searcher manages concurrency itself; pass through.
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."
        if len(self.live_trials) >= self.max_concurrent:
            logger.debug(
                f"Not providing a suggestion for {trial_id} due to "
                "concurrency limit: %s/%s.",
                len(self.live_trials),
                self.max_concurrent,
            )
            return

        suggestion = self.searcher.suggest(trial_id)
        # Only occupy a slot for real suggestions (not None/FINISHED).
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Forward a final result; in batch mode, defer until the batch is done."""
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(trial_id, result=result, error=error)

        if trial_id not in self.live_trials:
            return
        elif self.batch:
            # Hold the result back until every live trial has finished.
            self.cached_results[trial_id] = (result, error)
            self.num_unfinished_live_trials -= 1
            if self.num_unfinished_live_trials <= 0:
                # Update the underlying searcher once the
                # full batch is completed.
                # NOTE: the loop variable shadows the ``trial_id`` parameter.
                for trial_id, (result, error) in self.cached_results.items():
                    self.searcher.on_trial_complete(
                        trial_id, result=result, error=error
                    )
                    self.live_trials.remove(trial_id)
                self.cached_results = {}
                self.num_unfinished_live_trials = 0
            else:
                return
        else:
            # Non-batch mode: forward immediately and free the slot.
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        """Forward an intermediate result to the wrapped searcher."""
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        """Forward an externally evaluated point to the wrapped searcher."""
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        # The wrapped searcher is checkpointed separately via save()/restore(),
        # so it is excluded from this state dict.
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate checkpointing to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoration to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)
|
.venv/lib/python3.11/site-packages/ray/tune/search/util.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, Optional
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _set_search_properties_backwards_compatible(
|
| 8 |
+
set_search_properties_func,
|
| 9 |
+
metric: Optional[str],
|
| 10 |
+
mode: Optional[str],
|
| 11 |
+
config: Dict,
|
| 12 |
+
**spec
|
| 13 |
+
) -> bool:
|
| 14 |
+
"""Wraps around set_search_properties() so that it is backward compatible.
|
| 15 |
+
|
| 16 |
+
Also outputs a warning to encourage custom searchers to be updated.
|
| 17 |
+
"""
|
| 18 |
+
try:
|
| 19 |
+
return set_search_properties_func(metric, mode, config, **spec)
|
| 20 |
+
except TypeError as e:
|
| 21 |
+
if str(e).startswith(
|
| 22 |
+
"set_search_properties() got an unexpected keyword argument"
|
| 23 |
+
):
|
| 24 |
+
logger.warning(
|
| 25 |
+
"Please update custom Searcher to take in function signature "
|
| 26 |
+
"as ``def set_search_properties(metric, mode, config, "
|
| 27 |
+
"**spec) -> bool``."
|
| 28 |
+
)
|
| 29 |
+
return set_search_properties_func(metric, mode, config)
|
| 30 |
+
else:
|
| 31 |
+
raise e
|
.venv/lib/python3.11/site-packages/ray/tune/search/variant_generator.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from collections.abc import Mapping
|
| 6 |
+
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy
|
| 9 |
+
|
| 10 |
+
from ray.tune.search.sample import Categorical, Domain, Function, RandomState
|
| 11 |
+
from ray.util.annotations import DeveloperAPI, PublicAPI
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@DeveloperAPI
def generate_variants(
    unresolved_spec: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Generates variants from a spec (dict) with unresolved values.

    Two kinds of unresolved values are supported:

    Grid search values expand combinatorially, e.g.::

        "activation": grid_search(["relu", "tanh"])
        "learning_rate": grid_search([1e-3, 1e-4, 1e-5])

    produces six distinct variants in combination.

    Lambda functions are evaluated to concrete values and can express
    dependencies, conditional distributions, or random search::

        "cpu": lambda spec: spec.config.num_workers
        "batch_size": lambda spec: random.uniform(1, 1000)

    To support plain JSON / YAML specs, the alternative forms
    ``{"grid_search": [...]}`` and ``{"eval": "..."}`` are also accepted.

    Use `format_vars` to format the returned dict of hyperparameters.

    Yields:
        (Dict of resolved variables, Spec object)
    """
    variant_stream = _generate_variants_internal(
        unresolved_spec,
        constant_grid_search=constant_grid_search,
        random_state=random_state,
    )
    for resolved_vars, resolved_spec in variant_stream:
        # By construction every yielded spec must be fully concrete.
        assert not _unresolved_values(resolved_spec)
        yield resolved_vars, resolved_spec
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@PublicAPI(stability="beta")
def grid_search(values: Iterable) -> Dict[str, Iterable]:
    """Specify a grid of values to search over.

    Every value in a grid search is guaranteed to be sampled. When several
    grid search variables are defined, their combinatorial product is
    taken, so every possible combination of values is sampled.

    Example:

        >>> from ray import tune
        >>> param_space={
        ...     "x": tune.grid_search([10, 20]),
        ...     "y": tune.grid_search(["a", "b", "c"])
        ... }

    This yields a grid of 6 samples: ``{"x": 10, "y": "a"}``,
    ``{"x": 10, "y": "b"}``, and so on.

    When ``num_samples`` is specified in the
    :class:`TuneConfig <ray.tune.tune_config.TuneConfig>`, that many random
    samples are drawn per grid search combination — e.g. ``num_samples=4``
    with the grid above starts 24 trials in total.

    Args:
        values: An iterable whose parameters will be used for creating a trial grid.
    """
    return {"grid_search": values}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Modules exposed under fixed names — presumably the evaluation namespace
# for lambda/"eval"-style config values; confirm at the use site (not
# visible in this file chunk).
_STANDARD_IMPORTS = {
    "random": random,
    "np": numpy,
}

# Upper bound on sampling sweeps in `_resolve_domain_vars`, which retries
# while unresolved inter-variable dependencies remain.
_MAX_RESOLUTION_PASSES = 20
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
|
| 106 |
+
"""Flattens a nested dict by joining keys into tuple of paths.
|
| 107 |
+
|
| 108 |
+
Can then be passed into `format_vars`.
|
| 109 |
+
"""
|
| 110 |
+
res = {}
|
| 111 |
+
for k, v in nested_dict.items():
|
| 112 |
+
if isinstance(v, dict):
|
| 113 |
+
for k_, v_ in _resolve_nested_dict(v).items():
|
| 114 |
+
res[(k,) + k_] = v_
|
| 115 |
+
else:
|
| 116 |
+
res[(k,)] = v
|
| 117 |
+
return res
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@DeveloperAPI
def format_vars(resolved_vars: Dict) -> str:
    """Format variables to be used as experiment tags.

    Experiment tags end up in directory names, so every key and value is
    sanitized (via ``_clean_value``) to be legal in directory names on all
    systems. Input is a dict of the form
    ``{("nested", "config", "path"): "value"}``; output is a comma
    separated string of ``last_key=value`` pairs, so here ``path=value``.

    Sanitizing can produce empty strings — this is expected and still
    yields valid directory names.

    Args:
        resolved_vars: Dictionary mapping from config path tuples to a value.

    Returns:
        Comma-separated key=value string.
    """
    filtered = resolved_vars.copy()
    # TrialRunner already has these in the experiment_tag
    for reserved_key in ("run", "env", "resources_per_trial"):
        filtered.pop(reserved_key, None)

    parts = [
        f"{_clean_value(path[-1])}={_clean_value(value)}"
        for path, value in sorted(filtered.items())
    ]
    return ",".join(parts)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _flatten_resolved_vars(resolved_vars: Dict) -> Dict:
|
| 153 |
+
"""Formats the resolved variable dict into a mapping of (str -> value)."""
|
| 154 |
+
flattened_resolved_vars_dict = {}
|
| 155 |
+
for pieces, value in resolved_vars.items():
|
| 156 |
+
if pieces[0] == "config":
|
| 157 |
+
pieces = pieces[1:]
|
| 158 |
+
pieces = [str(piece) for piece in pieces]
|
| 159 |
+
flattened_resolved_vars_dict["/".join(pieces)] = value
|
| 160 |
+
return flattened_resolved_vars_dict
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _clean_value(value: Any) -> str:
|
| 164 |
+
"""Format floats and replace invalid string characters with ``_``."""
|
| 165 |
+
if isinstance(value, float):
|
| 166 |
+
return f"{value:.4f}"
|
| 167 |
+
else:
|
| 168 |
+
# Define an invalid alphabet, which is the inverse of the
|
| 169 |
+
# stated regex characters
|
| 170 |
+
invalid_alphabet = r"[^a-zA-Z0-9_-]+"
|
| 171 |
+
return re.sub(invalid_alphabet, "_", str(value)).strip("_")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@DeveloperAPI
def parse_spec_vars(
    spec: Dict,
) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]]]:
    """Split a spec into resolved, domain (random), and grid variables.

    Returns:
        Three lists of ``(path, value)`` pairs: already-resolved values,
        sampled (domain) variables, and grid search variables (sorted by
        path for deterministic variant generation).
    """
    resolved, unresolved = _split_resolved_unresolved_values(spec)
    resolved_vars = list(resolved.items())

    if not unresolved:
        return resolved_vars, [], []

    grid_vars = []
    domain_vars = []
    for path, domain in unresolved.items():
        target = grid_vars if domain.is_grid() else domain_vars
        target.append((path, domain))
    grid_vars.sort()

    return resolved_vars, domain_vars, grid_vars
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Count samples for a specific spec.

    The grid contributes the product of all grid-variable sizes, repeated
    ``num_samples`` times.
    """
    _, _, grid_vars = parse_spec_vars(spec)
    grid_count = 1
    for _, domain in grid_vars:
        grid_count *= len(domain.categories)
    return num_samples * grid_count
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
    """Count the total number of trials that ``spec`` will generate.

    Args:
        spec: Experiment spec, possibly containing grid search values and
            a ``num_samples`` entry.
        presets: Optional list of pre-set config dicts. Each preset is
            merged over the spec's config and counted separately, and
            consumes one of the ``num_samples`` repetitions.

    Returns:
        Total number of samples (trials) across all presets plus the
        remaining plain repetitions.
    """

    # Helper function: Deep update dictionary
    def deep_update(d, u):
        for k, v in u.items():
            if isinstance(v, Mapping):
                d[k] = deep_update(d.get(k, {}), v)
            else:
                d[k] = v
        return d

    total_samples = 0
    total_num_samples = spec.get("num_samples", 1)
    # For each preset, overwrite the spec and count the samples generated
    # for this preset.
    # Bug fix: ``presets`` defaults to None, which previously raised a
    # TypeError when iterated; treat None as "no presets".
    for preset in presets or []:
        preset_spec = copy.deepcopy(spec)
        deep_update(preset_spec["config"], preset)
        total_samples += _count_spec_samples(preset_spec, 1)
        total_num_samples -= 1

    # Add the remaining samples
    if total_num_samples > 0:
        total_samples += _count_spec_samples(spec, total_num_samples)
    return total_samples
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _generate_variants_internal(
    spec: Dict, constant_grid_search: bool = False, random_state: "RandomState" = None
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Recursively resolve domain and grid variables in ``spec``.

    Yields ``(resolved_vars, spec)`` pairs where ``resolved_vars`` maps
    config path tuples to the concrete value chosen for them and ``spec``
    is a deep copy with those values assigned in place.
    """
    spec = copy.deepcopy(spec)
    _, domain_vars, grid_vars = parse_spec_vars(spec)

    if not domain_vars and not grid_vars:
        # Fully resolved already — nothing left to sample or expand.
        yield {}, spec
        return

    # Variables to resolve
    to_resolve = domain_vars

    all_resolved = True
    if constant_grid_search:
        # In this path, we first sample random variables and keep them constant
        # for grid search.
        # `_resolve_domain_vars` will alter `spec` directly
        all_resolved, resolved_vars = _resolve_domain_vars(
            spec, domain_vars, allow_fail=True, random_state=random_state
        )
        if not all_resolved:
            # Not all variables have been resolved, but remove those that have
            # from the `to_resolve` list.
            to_resolve = [(r, d) for r, d in to_resolve if r not in resolved_vars]
    grid_search = _grid_search_generator(spec, grid_vars)
    for resolved_spec in grid_search:
        if not constant_grid_search or not all_resolved:
            # In this path, we sample the remaining random variables
            _, resolved_vars = _resolve_domain_vars(
                resolved_spec, to_resolve, random_state=random_state
            )

        # Recurse: sampling may have introduced new unresolved values
        # (e.g. lambdas that expand into further search spaces).
        for resolved, spec in _generate_variants_internal(
            resolved_spec,
            constant_grid_search=constant_grid_search,
            random_state=random_state,
        ):
            # Record the concrete grid values chosen for this variant.
            for path, value in grid_vars:
                resolved_vars[path] = _get_value(spec, path)
            # Merge recursively resolved values, guarding against a path
            # resolving to two different concrete values.
            for k, v in resolved.items():
                if (
                    k in resolved_vars
                    and v != resolved_vars[k]
                    and _is_resolved(resolved_vars[k])
                ):
                    raise ValueError(
                        "The variable `{}` could not be unambiguously "
                        "resolved to a single value. Consider simplifying "
                        "your configuration.".format(k)
                    )
                resolved_vars[k] = v
            yield resolved_vars, spec
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _get_preset_variants(
    spec: Dict,
    config: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
):
    """Get variants according to a spec, initialized with a config.

    Variables from the spec are overwritten by the variables in the config.
    Thus, we may end up with less sampled parameters.

    This function also checks if values used to overwrite search space
    parameters are valid, and logs a warning if not.
    """
    spec = copy.deepcopy(spec)

    # Only fully resolved entries of the preset config are applied.
    resolved, _, _ = parse_spec_vars(config)

    for path, val in resolved:
        try:
            # Look up the search space entry the preset value replaces.
            domain = _get_value(spec["config"], path)
            if isinstance(domain, dict):
                if "grid_search" in domain:
                    # Treat the JSON-style grid spec as a categorical
                    # domain so the preset value can be validated below.
                    domain = Categorical(domain["grid_search"])
                else:
                    # If users want to overwrite an entire subdict,
                    # let them do it.
                    domain = None
        except IndexError as exc:
            raise ValueError(
                f"Pre-set config key `{'/'.join(path)}` does not correspond "
                f"to a valid key in the search space definition. Please add "
                f"this path to the `param_space` variable passed to `tune.Tuner()`."
            ) from exc

        if domain:
            if isinstance(domain, Domain):
                # Warn (but still apply) when the preset value falls
                # outside the declared search space.
                if not domain.is_valid(val):
                    logger.warning(
                        f"Pre-set value `{val}` is not within valid values of "
                        f"parameter `{'/'.join(path)}`: {domain.domain_str}"
                    )
            else:
                # domain is actually a fixed value
                if domain != val:
                    logger.warning(
                        f"Pre-set value `{val}` is not equal to the value of "
                        f"parameter `{'/'.join(path)}`: {domain}"
                    )
        assign_value(spec["config"], path, val)

    return _generate_variants_internal(
        spec, constant_grid_search=constant_grid_search, random_state=random_state
    )
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
@DeveloperAPI
def assign_value(spec: Dict, path: Tuple, value: Any):
    """Assign ``value`` at the nested location ``path`` inside ``spec``.

    Tuples are immutable, so when the final container on the path is a
    tuple it is rebuilt inside its parent with the updated element.
    """
    container = spec
    parent = None
    parent_key = None
    for step in path[:-1]:
        parent, parent_key = container, step
        container = container[step]
    last = path[-1]
    if isinstance(container, tuple):
        if parent is None:
            raise ValueError("Cannot assign value to a tuple.")
        assert isinstance(last, int), "Tuple key must be an int."
        # Rebuild the tuple with the element at ``last`` replaced.
        parent[parent_key] = container[:last] + (value,) + container[last + 1 :]
    else:
        # Mutable container — assign directly.
        container[last] = value
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _get_value(spec: Dict, path: Tuple) -> Any:
|
| 368 |
+
for k in path:
|
| 369 |
+
spec = spec[k]
|
| 370 |
+
return spec
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _resolve_domain_vars(
    spec: Dict,
    domain_vars: List[Tuple[Tuple, Domain]],
    allow_fail: bool = False,
    random_state: "RandomState" = None,
) -> Tuple[bool, Dict]:
    """Sample values for unresolved domains and assign them into ``spec``.

    Repeatedly passes over ``domain_vars`` (up to ``_MAX_RESOLUTION_PASSES``
    times) because a domain may reference other, not-yet-resolved entries of
    ``spec``; such references raise ``RecursiveDependencyError`` via the
    access guard and are retried on the next pass.

    Args:
        spec: Spec dict to resolve values into (mutated in place).
        domain_vars: List of ``(path, domain)`` pairs to sample.
        allow_fail: If True, return ``(False, resolved)`` instead of raising
            when some variables remain unresolved after the final pass.
        random_state: Random state forwarded to ``Domain.sample``.

    Returns:
        Tuple of (fully resolved?, dict mapping path -> sampled value).
    """
    resolved = {}
    error = True
    num_passes = 0
    while error and num_passes < _MAX_RESOLUTION_PASSES:
        num_passes += 1
        error = False
        for path, domain in domain_vars:
            if path in resolved:
                continue
            try:
                value = domain.sample(
                    _UnresolvedAccessGuard(spec), random_state=random_state
                )
            except RecursiveDependencyError as e:
                # Dependency on a not-yet-resolved variable: retry next pass.
                error = e
            except Exception as e:
                # Chain the original exception so the root cause stays visible.
                raise ValueError(
                    "Failed to evaluate expression: {}: {}".format(path, domain)
                ) from e
            else:
                assign_value(spec, path, value)
                resolved[path] = value
    if error:
        if not allow_fail:
            raise error
        else:
            return False, resolved
    return True, resolved
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def _grid_search_generator(
|
| 410 |
+
unresolved_spec: Dict, grid_vars: List
|
| 411 |
+
) -> Generator[Dict, None, None]:
|
| 412 |
+
value_indices = [0] * len(grid_vars)
|
| 413 |
+
|
| 414 |
+
def increment(i):
|
| 415 |
+
value_indices[i] += 1
|
| 416 |
+
if value_indices[i] >= len(grid_vars[i][1]):
|
| 417 |
+
value_indices[i] = 0
|
| 418 |
+
if i + 1 < len(value_indices):
|
| 419 |
+
return increment(i + 1)
|
| 420 |
+
else:
|
| 421 |
+
return True
|
| 422 |
+
return False
|
| 423 |
+
|
| 424 |
+
if not grid_vars:
|
| 425 |
+
yield unresolved_spec
|
| 426 |
+
return
|
| 427 |
+
|
| 428 |
+
while value_indices[-1] < len(grid_vars[-1][1]):
|
| 429 |
+
spec = copy.deepcopy(unresolved_spec)
|
| 430 |
+
for i, (path, values) in enumerate(grid_vars):
|
| 431 |
+
assign_value(spec, path, values[value_indices[i]])
|
| 432 |
+
yield spec
|
| 433 |
+
if grid_vars:
|
| 434 |
+
done = increment(0)
|
| 435 |
+
if done:
|
| 436 |
+
break
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _is_resolved(v) -> bool:
    """Return whether ``v`` is already a concrete value (not a Domain/eval/grid)."""
    return _try_resolve(v)[0]
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def _try_resolve(v) -> Tuple[bool, Any]:
    """Classify ``v`` as resolved or unresolved.

    Returns ``(False, domain)`` for values that still need sampling —
    ``Domain`` instances, ``{"eval": ...}`` lambda specs, and
    ``{"grid_search": ...}`` specs — and ``(True, v)`` otherwise.
    """
    if isinstance(v, Domain):
        # Domain to sample from
        return False, v
    if isinstance(v, dict) and len(v) == 1:
        if "eval" in v:
            # Lambda function in eval syntax
            return False, Function(
                lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec})
            )
        if "grid_search" in v:
            # Grid search values
            return False, Categorical(v["grid_search"]).grid()
    return True, v
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _split_resolved_unresolved_values(
    spec: Dict,
) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
    """Partition ``spec`` into resolved and unresolved variables.

    Recurses into nested dicts, lists and tuples; both returned dicts are
    keyed by the tuple path to each leaf value.
    """
    resolved_vars = {}
    unresolved_vars = {}
    for key, raw in spec.items():
        is_resolved, node = _try_resolve(raw)
        prefix = (key,)
        if not is_resolved:
            unresolved_vars[prefix] = node
            continue
        if isinstance(node, dict):
            # Recurse into a dict and re-prefix the child paths.
            sub_resolved, sub_unresolved = _split_resolved_unresolved_values(node)
            for sub_path, sub_value in sub_resolved.items():
                resolved_vars[prefix + sub_path] = sub_value
            for sub_path, sub_value in sub_unresolved.items():
                unresolved_vars[prefix + sub_path] = sub_value
        elif isinstance(node, (list, tuple)):
            # Recurse into a sequence, treating each index as a dict key.
            for idx, item in enumerate(node):
                sub_resolved, sub_unresolved = _split_resolved_unresolved_values(
                    {idx: item}
                )
                for sub_path, sub_value in sub_resolved.items():
                    resolved_vars[prefix + sub_path] = sub_value
                for sub_path, sub_value in sub_unresolved.items():
                    unresolved_vars[prefix + sub_path] = sub_value
        else:
            resolved_vars[prefix] = node
    return resolved_vars, unresolved_vars
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
    """Return only the unresolved variables of ``spec``, keyed by path."""
    _, unresolved = _split_resolved_unresolved_values(spec)
    return unresolved
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _has_unresolved_values(spec: Dict) -> bool:
    """Return True if ``spec`` contains any unresolved search variables."""
    # bool(...) replaces the redundant `True if x else False` idiom.
    return bool(_unresolved_values(spec))
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class _UnresolvedAccessGuard(dict):
    """Dict wrapper that raises when an unresolved value is accessed.

    Used while sampling dependent domains: reading an entry that is still a
    Domain/eval/grid placeholder raises ``RecursiveDependencyError`` so the
    caller can defer resolution to a later pass.
    """

    def __init__(self, *args, **kwds):
        super(_UnresolvedAccessGuard, self).__init__(*args, **kwds)
        # Alias the attribute namespace to the dict itself so entries can be
        # read via attribute access (spec.config.foo) as well as item access.
        self.__dict__ = self

    def __getattribute__(self, item):
        # Fetch through dict.__getattribute__ to avoid re-entering this hook.
        value = dict.__getattribute__(self, item)
        if not _is_resolved(value):
            # Accessing a still-unresolved entry means this domain depends on
            # another unsampled variable; signal the caller to retry later.
            raise RecursiveDependencyError(
                "`{}` recursively depends on {}".format(item, value)
            )
        elif isinstance(value, dict):
            # Wrap nested dicts so the guard applies recursively.
            return _UnresolvedAccessGuard(value)
        else:
            return value
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
@DeveloperAPI
class RecursiveDependencyError(Exception):
    """Raised when a sampled variable depends on another unresolved variable."""

    def __init__(self, msg: str):
        super().__init__(msg)
|
.venv/lib/python3.11/site-packages/ray/tune/stopper/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.tune.stopper.experiment_plateau import ExperimentPlateauStopper
|
| 2 |
+
from ray.tune.stopper.function_stopper import FunctionStopper
|
| 3 |
+
from ray.tune.stopper.maximum_iteration import MaximumIterationStopper
|
| 4 |
+
from ray.tune.stopper.noop import NoopStopper
|
| 5 |
+
from ray.tune.stopper.stopper import CombinedStopper, Stopper
|
| 6 |
+
from ray.tune.stopper.timeout import TimeoutStopper
|
| 7 |
+
from ray.tune.stopper.trial_plateau import TrialPlateauStopper
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"Stopper",
|
| 11 |
+
"CombinedStopper",
|
| 12 |
+
"ExperimentPlateauStopper",
|
| 13 |
+
"FunctionStopper",
|
| 14 |
+
"MaximumIterationStopper",
|
| 15 |
+
"NoopStopper",
|
| 16 |
+
"TimeoutStopper",
|
| 17 |
+
"TrialPlateauStopper",
|
| 18 |
+
]
|
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (894 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/experiment_plateau.cpython-311.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/tune/stopper/__pycache__/function_stopper.cpython-311.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|